├── .gitignore ├── LICENSE ├── README.md ├── dataset ├── CSL │ ├── processed │ │ ├── data.pt │ │ ├── pre_filter.pt │ │ └── pre_transform.pt │ └── raw │ │ ├── X_eye_list_Kary_Deterministic_Graphs.pkl │ │ ├── X_unity_list_Kary_Deterministic_Graphs.pkl │ │ ├── __MACOSX │ │ ├── ._X_eye_list_Kary_Deterministic_Graphs.pkl │ │ ├── ._X_unity_list_Kary_Deterministic_Graphs.pkl │ │ ├── ._graphs_Kary_Deterministic_Graphs.pkl │ │ └── ._y_Kary_Deterministic_Graphs.pt │ │ ├── graphs_Kary_Deterministic_Graphs.pkl │ │ └── y_Kary_Deterministic_Graphs.pt ├── EXP │ ├── processed │ │ ├── data.pt │ │ ├── pre_filter.pt │ │ └── pre_transform.pt │ └── raw │ │ ├── GRAPHSAT.pkl │ │ └── newGRAPHSAT.pkl ├── sr25 │ ├── processed │ │ ├── data.pt │ │ ├── pre_filter.pt │ │ └── pre_transform.pt │ └── raw │ │ └── sr251256.g6 └── subgraphcount │ ├── processed │ ├── data.pt │ ├── pre_filter.pt │ └── pre_transform.pt │ └── raw │ └── randomgraph.mat ├── docs ├── Advance_MultipleTensor.md ├── BasicDataStructure.md ├── BasicOperators.md ├── HoData.md ├── Makefile ├── SpeedIssue.md ├── make.bat ├── mini_example.md └── source │ ├── index.rst │ ├── modules │ ├── backend.rst │ ├── hodata.rst │ └── honn.rst │ └── notes │ ├── datastructure.rst │ ├── hodata.rst │ ├── installation.rst │ ├── miniexample.rst │ ├── multtensor.rst │ └── operator.rst ├── example ├── lr_scheduler.py ├── minimal.py ├── reproduce.sh ├── work.sh └── zinc.py ├── pygho ├── __init__.py ├── backend │ ├── MaTensor.py │ ├── Mamamm.py │ ├── SpTensor.py │ ├── Spmamm.py │ ├── Spmm.py │ ├── Spspmm.py │ ├── __init__.py │ └── utils.py ├── hodata │ ├── MaData.py │ ├── MaTupleSampler.py │ ├── ParallelPreprocess.py │ ├── SpData.py │ ├── SpTupleSampler.py │ ├── Wrapper.py │ └── __init__.py └── honn │ ├── Conv.py │ ├── MaOperator.py │ ├── SpOperator.py │ ├── TensorOp.py │ ├── __init__.py │ └── utils.py ├── requirements.txt ├── setup.py └── tests ├── test_backend_masked.py └── test_backend_sparse.py /.gitignore: 
-------------------------------------------------------------------------------- 1 | out/* 2 | *.out 3 | *.test 4 | plot/* 5 | __pycache__ 6 | */__pycache__ 7 | */*/__pycache__ 8 | dataset/* 9 | *.db 10 | past 11 | mod/* 12 | *.ipynb 13 | !dataset/EXP 14 | !dataset/sr25 15 | !dataset/subgraphcount 16 | !dataset/CSL 17 | test/* 18 | opt/* 19 | pygho.egg-info 20 | *toml 21 | SWL 22 | subgraphcount 23 | ESAN 24 | I2GNN 25 | GNNAsKernel 26 | NestedGNN 27 | ProvablyPowerfulGraphNetworks_torch 28 | experiments 29 | SUN 30 | *.py 31 | .vscode 32 | */.vscode 33 | docs/build 34 | example/*.sh 35 | *.sh -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023 Graph PKU Team 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PygHO 2 | 3 | A library for high-order GNN based on torch_geometric. 4 | 5 | ## Installation 6 | First clone our repo 7 | ``` 8 | git clone https://github.com/GraphPKU/PygHO.git 9 | ``` 10 | Then install it locally 11 | ``` 12 | cd PygHO 13 | pip install -e ./ 14 | ``` 15 | `-e` enables modifying the library code dynamically and is optional. 16 | 17 | ## Usage 18 | 19 | Please refer to our [online document](https://graphpku.github.io/PyGHO_doc/) for more details. 20 | -------------------------------------------------------------------------------- /dataset/CSL/processed/data.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/CSL/processed/data.pt -------------------------------------------------------------------------------- /dataset/CSL/processed/pre_filter.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/CSL/processed/pre_filter.pt -------------------------------------------------------------------------------- /dataset/CSL/processed/pre_transform.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/CSL/processed/pre_transform.pt -------------------------------------------------------------------------------- /dataset/CSL/raw/X_eye_list_Kary_Deterministic_Graphs.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/CSL/raw/X_eye_list_Kary_Deterministic_Graphs.pkl 
-------------------------------------------------------------------------------- /dataset/CSL/raw/X_unity_list_Kary_Deterministic_Graphs.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/CSL/raw/X_unity_list_Kary_Deterministic_Graphs.pkl -------------------------------------------------------------------------------- /dataset/CSL/raw/__MACOSX/._X_eye_list_Kary_Deterministic_Graphs.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/CSL/raw/__MACOSX/._X_eye_list_Kary_Deterministic_Graphs.pkl -------------------------------------------------------------------------------- /dataset/CSL/raw/__MACOSX/._X_unity_list_Kary_Deterministic_Graphs.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/CSL/raw/__MACOSX/._X_unity_list_Kary_Deterministic_Graphs.pkl -------------------------------------------------------------------------------- /dataset/CSL/raw/__MACOSX/._graphs_Kary_Deterministic_Graphs.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/CSL/raw/__MACOSX/._graphs_Kary_Deterministic_Graphs.pkl -------------------------------------------------------------------------------- /dataset/CSL/raw/__MACOSX/._y_Kary_Deterministic_Graphs.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/CSL/raw/__MACOSX/._y_Kary_Deterministic_Graphs.pt -------------------------------------------------------------------------------- 
/dataset/CSL/raw/graphs_Kary_Deterministic_Graphs.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/CSL/raw/graphs_Kary_Deterministic_Graphs.pkl -------------------------------------------------------------------------------- /dataset/CSL/raw/y_Kary_Deterministic_Graphs.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/CSL/raw/y_Kary_Deterministic_Graphs.pt -------------------------------------------------------------------------------- /dataset/EXP/processed/data.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/EXP/processed/data.pt -------------------------------------------------------------------------------- /dataset/EXP/processed/pre_filter.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/EXP/processed/pre_filter.pt -------------------------------------------------------------------------------- /dataset/EXP/processed/pre_transform.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/EXP/processed/pre_transform.pt -------------------------------------------------------------------------------- /dataset/EXP/raw/GRAPHSAT.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/EXP/raw/GRAPHSAT.pkl -------------------------------------------------------------------------------- 
/dataset/EXP/raw/newGRAPHSAT.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/EXP/raw/newGRAPHSAT.pkl -------------------------------------------------------------------------------- /dataset/sr25/processed/data.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/sr25/processed/data.pt -------------------------------------------------------------------------------- /dataset/sr25/processed/pre_filter.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/sr25/processed/pre_filter.pt -------------------------------------------------------------------------------- /dataset/sr25/processed/pre_transform.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/sr25/processed/pre_transform.pt -------------------------------------------------------------------------------- /dataset/sr25/raw/sr251256.g6: -------------------------------------------------------------------------------- 1 | X}rM\QTeLEuUlQY[I\IgtWfTxCrhEOZo{FfwEew`LXMbWp}JDtM 2 | X}rM\QWh[jUYkiYYJMBDfSrXsLRgdOMusE{{JQhgxiMDSxVEplU 3 | X}rM\Qhd[fU`kdUSjDjHjWqxtOJdOgMxwEzYFRPgtdMDctfIpsu 4 | X}rUTEdmTpQxbkSxHkZceZBLtObTQHHmsbM{EFxwrroAtQtBtBr 5 | X}rUTEdmTpQybiSwhkjgeYbLrHBX`HI\wavYEFxwrroAt`tBsdr 6 | X}rUTIbmLqQybiTWh[jcXZCtqpBYPHE]wcuyEFxwrroAtQtBtBr 7 | X}rU\adeSetTjKWNJEYNR]PLjPBgUGVTkK^YKbipMcxbk`{DlXF 8 | X}r^SQbcdJQjesS[jLJQhPxTcZZcZ?S|krEYBlQolTZDhWuNKKN 9 | X}r^SQbcdQqjdwS[jLJJHQtTcZZcZ?L\krEYBlQostZDhWuNKKN 10 | X}vEKeLlTTUXjKXXK[ZKo]EXZAqxI``\khVUD]VOVptAqtXBbrL 11 | 
X}vEKiJklYUXjKXWk[jKo]EXYdQyD`amkgmuD]VGVpuAqtUBbrR 12 | X}vEKiJlTTUUjQWxKkZKo]EXZBQxH`a]kguuC}NGZquAptYBdrJ 13 | X}ve[IJd\FUImBSYmFJRIWfLLRYmJ?Tl[ZKYHhsokuXdhS{NKKN 14 | X~rM[Edd\RRBkpOxlMIqZXDxTWZdQGX][fI[EsholidDRdVNGk] 15 | X~rM[Ihi[eQVlQPTkTiqrWwxXgZgiGX^KfWqEdu?nG\dRP}NGk] 16 | -------------------------------------------------------------------------------- /dataset/subgraphcount/processed/data.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/subgraphcount/processed/data.pt -------------------------------------------------------------------------------- /dataset/subgraphcount/processed/pre_filter.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/subgraphcount/processed/pre_filter.pt -------------------------------------------------------------------------------- /dataset/subgraphcount/processed/pre_transform.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/subgraphcount/processed/pre_transform.pt -------------------------------------------------------------------------------- /dataset/subgraphcount/raw/randomgraph.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/dataset/subgraphcount/raw/randomgraph.mat -------------------------------------------------------------------------------- /docs/Advance_MultipleTensor.md: -------------------------------------------------------------------------------- 1 | # Multiple Tensor 2 | 3 | In our dataset preprocessing routine, the default computation involves two high-order tensors: the adjacency matrix `A` 
and the tuple feature `X`. However, in certain scenarios, there may be a need for additional high-order tensors. For instance, when using a Nested Graph Neural Network (NGNN) with a 2-hop GNN as the base GNN, Message Passing Neural Network (MPNN) operations are performed on each subgraph with an augmented adjacency matrix. In this case, two high-order tensors are required: the tuple feature and the augmented adjacency matrix. 4 | 5 | During data preprocessing, we can use multiple samplers, each responsible for sampling one tensor. For sparse data, the code might look like this: 6 | 7 | ```python 8 | trn_dataset = ParallelPreprocessDataset( 9 | "dataset/ZINC_trn", trn_dataset, 10 | pre_transform=Sppretransform( 11 | tuplesamplers=[ 12 | partial(KhopSampler, hop=3), 13 | partial(KhopSampler, hop=2) 14 | ], 15 | annotate=["tuplefeat", "2hopadj"], 16 | keys=keys 17 | ), 18 | num_workers=8 19 | ) 20 | ``` 21 | 22 | In this code, two tuple features are precomputed simultaneously and assigned different annotations: "tuplefeat" and "2hopadj" to distinguish between them. 23 | 24 | For dense data, the process is quite similar: 25 | 26 | ```python 27 | trn_dataset = ParallelPreprocessDataset( 28 | "dataset/ZINC_trn", 29 | trn_dataset, 30 | pre_transform=Mapretransform( 31 | [ 32 | partial(spdsampler, hop=3), 33 | partial(spdsampler, hop=2) 34 | ], 35 | annotate=["tuplefeat", "2hopadj"] 36 | ), 37 | num_workers=0 38 | ) 39 | ``` 40 | 41 | After passing the data through a dataloader, the batch will contain `Xtuplefeat` and `X2hopadj` as the high-order tensors that are needed. For dense models, this concludes the process. However, for sparse models, if you want to retrieve the correct keys, you will need to modify the operator symbols for sparse message passing layers. 
42 | 43 | Ordinarily, the `NGNNConv` is defined as: 44 | 45 | ```python 46 | NGNNConv(hiddim, hiddim, mlp=mlp) 47 | ``` 48 | 49 | This is equivalent to: 50 | 51 | ```python 52 | NGNNConv(hiddim, hiddim, mlp=mlp, optuplefeat="X", opadj="A") 53 | ``` 54 | 55 | To ensure that you retrieve the correct keys, you should use: 56 | 57 | ```python 58 | NGNNConv(hiddim, hiddim, mlp=mlp, optuplefeat="Xtuplefeat", opadj="X2hopadj") 59 | ``` 60 | 61 | Similar modifications should be made for other layers as needed. -------------------------------------------------------------------------------- /docs/BasicDataStructure.md: -------------------------------------------------------------------------------- 1 | # Refined Basic Data Structure 2 | 3 | In this section, we'll provide a refined explanation of the basic data structures, MaskedTensor and SparseTensor, used in HOGNNs to address their unique requirements. 4 | 5 | ## MaskedTensor 6 | 7 | HOGNNs demand specialized data structures to handle high-order tensors efficiently. One such structure is the **MaskedTensor**, consisting of two components: `data` and `mask`. 8 | 9 | - `data` has a shape of $(\text{masked shape}, \text{dense shape})$, residing in $\mathbb{R}^{n\times n\times d}$, where $n$ represents the number of nodes, and $d$ is the dimensionality of the data. 10 | - `mask` has a shape of $(\text{masked shape})$, containing Boolean values, typically $\{0,1\}^{n\times n}$. The element $(i,j)$ in `mask` is set to $1$ if the tuple $(i,j)$ exists in the tensor. 11 | 12 | Unused elements in `data` do not affect the output of the operators in this library. For example, when performing operations like summation, MaskedTensor treats the non-existing elements as $0$, effectively ignoring them. 
13 | 14 | Here's an example of creating a MaskedTensor: 15 | 16 | ```python 17 | from pygho import MaskedTensor 18 | import torch 19 | 20 | n, d = 3, 3 21 | data = torch.tensor([[4, 1, 4], [4, 4, 2], [3, 4, 4]]) 22 | mask = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.bool) 23 | A = MaskedTensor(data, mask) 24 | ``` 25 | 26 | The non-existing elements in `data` can be assigned arbitrary values. The created masked tensor is as follows: 27 | 28 | $$ 29 | \begin{bmatrix} 30 | -&1&-\\ 31 | -&-&2\\ 32 | 3&-&- 33 | \end{bmatrix} 34 | $$ 35 | 36 | ## SparseTensor 37 | 38 | On the other hand, the **SparseTensor** stores only existing elements, making it more efficient when a small ratio of valid elements is present. A SparseTensor, with shape `(sparse_shape, dense_shape)`, consists of two tensors: `indices` and `values`. 39 | 40 | - `indices` is an Integer Tensor with shape `(sparse_dim, nnz)`, where `sparse_dim` represents the number of dimensions in the sparse shape, and `nnz` stands for the count of existing elements. 41 | - `values` has a shape of `(nnz, dense_shape)`. 42 | 43 | The columns of `indices` and rows of `values` correspond to the non-zero elements, simplifying retrieval and manipulation of the required information. 44 | 45 | For instance, in the context of NGNN's representation $H\in \mathbb{R}^{n\times n\times d}$, where the total number of nodes in subgraphs is $m$, you can represent $H$ using `indices` $a\in \mathbb{N}^{2\times m}$ and `values` $v\in \mathbb{R}^{m\times d}$. Specifically, for $i=1,2,\ldots,m$, $H_{a_{1,i},a_{2,i}}=v_i$. 
46 | 47 | Creating a SparseTensor is illustrated in the following example: 48 | 49 | ```python 50 | from pygho import SparseTensor 51 | import torch 52 | 53 | n, d = 3, 3 54 | indices = torch.tensor([[0, 1, 2], [1, 2, 0]], dtype=torch.long) 55 | values = torch.tensor([1, 2, 3]) 56 | A = SparseTensor(indices, values, shape=(3, 3)) 57 | ``` 58 | representing the following matrix 59 | $$ 60 | \begin{bmatrix} 61 | -&1&-\\ 62 | -&-&2\\ 63 | 3&-&- 64 | \end{bmatrix} 65 | $$ 66 | 67 | Please note that in the SparseTensor format, each non-zero element is represented by `sparse_dim` int64 indices and the element itself. If the tensor is not sparse enough, SparseTensor may occupy more memory than a dense tensor. -------------------------------------------------------------------------------- /docs/BasicOperators.md: -------------------------------------------------------------------------------- 1 | # Operators 2 | In this section, we'll provide a detailed introduction of some basic operators on high order tensors. 3 | ## Code Architecture 4 | The code for these high-order graph neural network (HOGNN) operations is organized into three layers: 5 | 6 | **Layer 1: Backend:** This layer, found in the `pygho.backend` module, contains basic data structures and operations focused on tensor manipulations. It lacks graph-specific learning concepts and includes the following functionalities: 7 | 8 | - **Matrix multiplication:** This method supports general matrix multiplication capabilities, including operations on two SparseTensors, one sparse and one MaskedTensor, and two MaskedTensors. It also handles batched matrix multiplication and offers operations that replace the sum in traditional matrix multiplication with max and mean operations. 9 | - **Two matrix addition:** Operations for adding two sparse or two dense matrices. 10 | - **Reduce operations:** These operations include sum, mean, max, and min, which reduce dimensions in tensors. 
11 | - **Expand operation:** This operation adds new dimensions to tensors. 12 | - **Tuplewise apply(func):** It applies a given function to the underlying data tensor. 13 | - **Diagonal apply(func):** This operation applies a function to diagonal elements of tensors. 14 | 15 | **Layer 2: Graph operations:** Building upon Layer 1, the `pygho.honn.SpOperator` and `pygho.honn.MaOperator` modules provide graph operations specifically tailored for Sparse and Masked Tensor structures. Additionally, the `pygho.honn.TensorOp` layer wraps these operators, abstracting away the differences between Sparse and Masked Tensor data structures. These operations encompass: 16 | 17 | - **General message passing between tuples:** Facilitating message passing between tuples of nodes. 18 | - **Pooling:** This operation reduces high-order tensors to lower-order ones by summing, taking the maximum, or computing the mean across specific dimensions. 19 | - **Diagonal:** It reduces high-order tensors to lower-order ones by extracting diagonal elements. 20 | - **Unpooling:** This operation extends low-order tensors to high-order ones. 21 | 22 | **Layer 3: Models:** Building on Layer 2, this layer provides a collection of representative high-order GNN layers, including NGNN, GNNAK, DSSGNN, SUN, SSWL, PPGN, and I2GNN. Layer 3 offers numerous ready-to-use methods, and with Layer 2, users can design additional models using general graph operations. Layer 1 allows for the development of novel operations, expanding the library's flexibility and utility. Now let's explore these layers in more detail. 23 | 24 | ### Layer 1: Backend 25 | 26 | #### Spspmm 27 | One of the most complex operators in this layer is sparse-sparse matrix multiplication (Spspmm). Given two sparse matrices, C and D, and assuming their output is B: 28 | 29 | $$ 30 | B_{ij} = \sum_k C_{ik} D_{kj} 31 | $$ 32 | 33 | The Spspmm operator utilizes a coo format to represent elements. 
Assuming element C_{ik} corresponds to C.values[c_{ik}] and D_{kj} corresponds to D.values[d_{kj}], we can create a tensor `bcd` of shape (3, m), where m is the number of triples (i, j, k) where both C_{ik} and D_{kj} exist. The multiplication can be performed as follows: 34 | 35 | ```python 36 | B.values = zeros(...) 37 | for i in range(m): 38 | B.values[bcd[0, i]] += C.values[bcd[1, i]] * D.values[bcd[2, i]] 39 | ``` 40 | 41 | This summation process can be efficiently implemented in parallel on GPU using `torch.Tensor.scatter_reduce_`. The `bcd` tensor can be precomputed with `pygho.backend.Spspmm.spspmm_ind` and shared among matrices with the same indices. 42 | 43 | Hadamard product between two sparse matrices can also be implemented, where $C = A \odot B$: C can use the same indices as $A$. 44 | 45 | ```python 46 | C.values[a_{ij}] = A.values[a_{ij}] * B.values[b_{ij}] 47 | ``` 48 | 49 | The tensor `b2a` can be defined, where `b2a[b_ij] = a_ij` if A has the element (i, j); otherwise, it is set to -1. Then, the Hadamard product can be computed as follows: 50 | 51 | ```python 52 | C.values = zeros(...) 53 | for i in range(B.nnz): 54 | if b2a[i] >= 0: 55 | C.values[b2a[i]] = A.values[b2a[i]] * B.values[i] 56 | ``` 57 | 58 | The operation can also be efficiently implemented in parallel on a GPU. 59 | 60 | To compute $A\odot (CD)$, you can define a tensor `acd` of shape (3, m') where `acd[0] = b2a[bcd[0]]`, `acd[1] = bcd[1]`, and `acd[2] = bcd[2]`, and remove columns i where `acd[0, i] = -1`. The computation can be done as follows: 61 | 62 | ```python 63 | ret.values = zeros(...) 64 | for i in range(acd.shape[1]): 65 | ret.values[acd[0, i]] += A.values[acd[0, i]] * C.values[acd[1, i]] * D.values[acd[2, i]] 66 | ``` 67 | 68 | Like the previous operations, this can also be implemented efficiently in parallel on a GPU. Additionally, by setting `A.values[acd[0, i]]` to 1, A can act as a mask, ensuring that only elements existing in A are computed. 
69 | 70 | The overall wrapper for these functions is `pygho.backend.Spspmm.spspmm`, which can perform sparse-sparse matrix multiplication with precomputed indices. `pygho.backend.Spspmm.spspmpnn` provides a more complex operator that goes beyond matrix multiplication, allowing you to implement various graph operations. It can in fact implement the following framework. 71 | 72 | $$ 73 | ret_{ij} = \phi(\{(A_{ij}, B_{ik}, C_{kj})|B_{ik},C_{kj} \text{ elements exist}\}) 74 | $$ 75 | where `phi` is a general multiset function, which is a functional parameter of `spspmpnn`. With it, we can implement GAT on each subgraph as follows. 76 | 77 | ```python 78 | self.attentionnn1 = nn.Linear(hiddim, hiddim, hiddim) 79 | self.attentionnn2 = nn.Linear(hiddim, hiddim, hiddim) 80 | self.attentionnn3 = nn.Linear(hiddim, hiddim, hiddim) 81 | self.subggnn = NGNNConv(hiddim, hiddim, args.aggr, "SS", transfermlpparam(mlp), message_func=lambda a,b,c,tarid: scatter_softmax(self.attentionnn1(a) * b * self.attentionnn2(c), tarid) * self.attentionnn3(c)) 82 | 83 | ``` 84 | 85 | #### TuplewiseApply 86 | 87 | Both Sparse and Masked Tensors have the `tuplewiseapply` function. The most common usage is: 88 | 89 | ```python 90 | mlp = ... 91 | X.tuplewiseapply(mlp) 92 | ``` 93 | 94 | However, in practice, this function directly applies `mlp` to the values or data tensor. As linear layers, non-linearities, and layer normalization all operate on the last dimension, this operation is essentially equivalent to tuplewise apply. For batchnorm, we provide a version that is not affected by this problem in `pygho.honn.utils`. 95 | 96 | #### DiagonalApply 97 | 98 | Both Sparse and Masked Tensors have the `diagonalapply` function. Unlike `tuplewiseapply`, this function passes both data/values and a mask indicating whether the corresponding elements are on the diagonal to the input function. A common use case is: 99 | 100 | ```python 101 | mlp1 = ... 102 | mlp2 = ... 
103 | mlp = lambda x, diagonalmask: torch.where(diagonalmask, mlp1(x), mlp2(x)) 104 | X.diagonalapply(mlp) 105 | ``` 106 | 107 | Here, `mlp1` is applied to diagonal elements, and `mlp2` is applied to non-diagonal elements. You can also use `torch_geometric.nn.HeteroLinear` for a faster implementation. 108 | 109 | ### Layer 2: Operators 110 | 111 | `pygho.honn.SpOperator` and `pygho.honn.MaOperator` wrap the backend for SparseTensor and MaskedTensor respectively. Their APIs are unified in `pygho.honn.TensorOp`. The basic operators include `OpNodeMessagePassing` (node-level message passing), `OpMessagePassing` (tuple-level message passing, wrapping matrix multiplication), `OpPooling` (reduce high-order tensors to lower-order ones by sum, mean, max), `OpDiag` (reduce high-order tensors to lower-order ones by extracting diagonal elements), and `OpUnpooling` (extend lower-order tensors to higher-order ones). Special cases are also defined. 112 | 113 | #### Sparse OpMessagePassing 114 | 115 | As described in Layer 1, the `OpMessagePassing` operator wraps the properties of Spspmm and is defined with parameters like `op0`, `op1`, `dim1`, `op2`, `dim2`, and `aggr`. It retrieves precomputed data from a data dictionary during the forward process using keys like `f"{op0}___{op1}___{dim1}___{op2}___{dim2}"`. Here's the forward method signature: 116 | 117 | ```python 118 | def forward(self, 119 | A: SparseTensor, 120 | B: SparseTensor, 121 | datadict: Dict, 122 | tarX: Optional[SparseTensor] = None) -> SparseTensor: 123 | ``` 124 | 125 | In this signature, `tarX` corresponds to `op0`, providing the target indices, while `A` and `B` correspond to `op1` and `op2`. The `datadict` can be obtained from the data loader using `for batch in dataloader: batch.to_dict()`. 126 | 127 | ## Example 128 | 129 | To illustrate how these operators work, let's use NGNN as an example. Although our operators can handle batched data, for simplicity, we'll focus on the single-graph case. 
Let H represent the representation matrix in $\mathbb{R}^{n\times n\times d}$, and A denote the adjacency matrix in $\mathbb{R}^{n\times n}$. The Graph Isomorphism Network (GIN) operation on all subgraphs can be defined as: 130 | 131 | $$ 132 | h_{ij} \leftarrow \sum_{k\in N_i(j)} \text{MLP}(h_{ik}) 133 | $$ 134 | 135 | This operation can be represented using two steps: 136 | 137 | 1. Apply the MLP function to each tuple's representation: 138 | 139 | ```python 140 | X' = X.tuplewiseapply(MLP) 141 | ``` 142 | 143 | 2. Perform matrix multiplication to sum over neighbors: 144 | 145 | ```python 146 | X = X' * A^T 147 | ``` 148 | 149 | In the matrix multiplication step, batching is applied to the last dimension of X. This conversion may seem straightforward, but there are several key points to consider: 150 | 151 | - Optimization for induced subgraph input: The original equation involves a sum over neighbors in the subgraph, but the matrix multiplication version includes neighbors from the entire graph. However, our implementation optimizes for induced subgraph cases, where neighbors outside the subgraph are automatically handled by setting their values to zero. 152 | 153 | - Optimization for sparse output: The operation X' * A^T may produce non-zero elements for pairs (i, j) that do not exist in the subgraph. For sparse input tensors X and A, we optimize the multiplication to avoid computing such non-existent elements. 154 | 155 | Pooling processes can also be considered as a reduction of $X$. For instance: 156 | 157 | $$ 158 | h_i=\sum_{j\in V_i}\text{MLP}_2(h_{ij}) 159 | $$ 160 | 161 | can be implemented as follows: 162 | 163 | ``` 164 | # definition 165 | self.pool = OpPoolingSubg2D(...) 166 | ... 167 | # forward 168 | Xn = self.pool(X.tuplewiseapply(MLP_2)) 169 | ``` 170 | 171 | This example demonstrates how our library's operators can be used to efficiently implement various HOGNNs. 
-------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/SpeedIssue.md: -------------------------------------------------------------------------------- 1 | ## Speed issue 2 | 3 | You can use python -O to disable all `assert` when you are sure there is no bug. 4 | 5 | Changing the `transform` of dataset to `pre_transform` can accelerate significantly. 6 | 7 | Precompute spspmm's indice may provide some acceleration. (See the sparse data section) -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. 
Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/mini_example.md: -------------------------------------------------------------------------------- 1 | # Minimal Example 2 | Let's delve into the fundamental concepts of PyGHO using a minimal example. The complete code can be found in [example/minimal.py](https://github.com/GraphPKU/PygHO/tree/main/example/minimal.py). You can execute the code with the following command: 3 | 4 | ```shell 5 | python minimal.py 6 | ``` 7 | 8 | This example demonstrates the implementation of a basic HOGNN model **Nested Graph Neural Network (NGNN)** in the paper [Nested Graph Neural Network (NGNN)](https://arxiv.org/abs/2110.13197). NGNN works by first sampling a k-hop subgraph for each node $i$ and then applying Graph Neural Networks (GNN) on all these subgraphs simultaneously. It generates a 2-D representation $H\in \mathbb{R}^{n\times n\times d}$, where $H_{ij}$ represents the representation of node $j$ in the subgraph rooted at node $i$. The message passing within all subgraphs can be expressed as: 9 | 10 | $$ 11 | h_{ij}^{t+1} \leftarrow \sum_{k\in N_i(j)} \text{MLP}(h^t_{ik}), 12 | $$ 13 | 14 | where $N_i(j)$ represents the set of neighbors of node $j$ in the subgraph rooted at $i$. 
After several layers of message passing, tuple representations $H$ are pooled to generate the node representations: 15 | 16 | $$ 17 | h_i = P(\{h_{ij} | j\in V_i\}). 18 | $$ 19 | 20 | This example serves as a fundamental illustration of our work. 21 | 22 | ## Dataset Preprocessing 23 | 24 | As HOGNNs share tasks with ordinary GNNs, they can utilize datasets provided by PyG. However, NGNN still needs to sample subgraphs, equivalent to providing initial features for tuple representation $h_{ij}$. You can achieve this transformation with the following code: 25 | 26 | ```python 27 | # Load an ordinary PyG dataset 28 | from torch_geometric.datasets import ZINC 29 | trn_dataset = ZINC("dataset/ZINC", subset=True, split="train") 30 | 31 | # Transform it into a High-order graph dataset 32 | from pygho.hodata import Sppretransform, ParallelPreprocessDataset 33 | trn_dataset = ParallelPreprocessDataset( 34 | "dataset/ZINC_trn", trn_dataset, 35 | pre_transform=Sppretransform(tuplesamplers=[partial(KhopSampler, hop=3)], annotate=[""], keys=keys), num_workers=8) 36 | ``` 37 | 38 | The `ParallelPreprocessDataset` class takes a standard PyG dataset as input and performs transformations on each graph in parallel, utilizing 8 processes in this example. The `tuplesamplers` parameter represents functions that take a graph as input and produce a sparse tensor. In this example, we use `partial(KhopSampler, hop=3)`, a sampler designed for NGNN, to sample a 3-hop ego-network rooted at each node. The shortest path distance to the root node serves as the tuple features. The produced SparseTensor is then saved and can be effectively used to initialize tuple representations. 
The `keys` variable is a list of strings indicating the required precomputation, which can be automatically generated after defining a model: 39 | 40 | ```python 41 | from pygho.honn.SpOperator import parse_precomputekey 42 | keys = parse_precomputekey(model) 43 | ``` 44 | 45 | ## Mini-batch and DataLoader 46 | 47 | Enabling batch training in HOGNNs requires handling graphs of varying sizes, which can be challenging. This library concatenates the SparseTensors of each graph along the diagonal of a larger tensor. For instance, in a batch of $B$ graphs with adjacency matrices $A_i\in \mathbb{R}^{n_i\times n_i}$, node features $x_i\in \mathbb{R}^{n_i\times d}$, and tuple features $X_i\in \mathbb{R}^{n_i\times n_i\times d'}$ for $i=1,2,\ldots,B$, the features for the entire batch are represented as $A\in \mathbb{R}^{n\times n}$, $x\in \mathbb{R}^{n\times d}$, and $X\in \mathbb{R}^{n\times n\times d'}$, where $n=\sum_{i=1}^B n_i$. The concatenation is as follows: 48 | 49 | $$ 50 | A=\begin{bmatrix} 51 | A_1&0&0&\cdots &0\\ 52 | 0&A_2&0&\cdots &0\\ 53 | 0&0&A_3&\cdots &0\\ 54 | \vdots&\vdots&\vdots&\ddots&\vdots\\ 55 | 0&0&0&\cdots&A_B 56 | \end{bmatrix} 57 | ,x=\begin{bmatrix} 58 | x_1\\ 59 | x_2\\ 60 | x_3\\ 61 | \vdots\\ 62 | x_B 63 | \end{bmatrix} 64 | ,X=\begin{bmatrix} 65 | X_1&0&0&\cdots &0\\ 66 | 0&X_2&0&\cdots &0\\ 67 | 0&0&X_3&\cdots &0\\ 68 | \vdots&\vdots&\vdots&\ddots&\vdots\\ 69 | 0&0&0&\cdots&X_B 70 | \end{bmatrix} 71 | $$ 72 | 73 | We provide our DataLoader as part of PygHO.
It has compatible parameters with PyTorch's DataLoader and combines sparse tensors for different graphs: 74 | 75 | ```python 76 | from pygho.hodata import SpDataloader 77 | trn_dataloader = SpDataloader(trn_dataset, batch_size=128, shuffle=True, drop_last=True) 78 | ``` 79 | 80 | Using this DataLoader is similar to an ordinary PyG DataLoader: 81 | 82 | ```python 83 | for batch in dataloader: 84 | batch = batch.to(device, non_blocking=True) 85 | ``` 86 | 87 | However, in addition to PyG batch attributes (like `edge_index`, `x`, `batch`), this batch also contains a SparseTensor adjacency matrix `A` and initial tuple feature SparseTensor `X`. 88 | 89 | ### Learning Methods on Graphs 90 | 91 | To execute message passing on each subgraph simultaneously, you can utilize the NGNNConv in our library: 92 | 93 | ```python 94 | # Definition 95 | self.subggnns = nn.ModuleList([ 96 | NGNNConv(hiddim, hiddim, "sum", "SS", mlp) 97 | for _ in range(num_layer) 98 | ]) 99 | 100 | ... 101 | # Forward pass 102 | for conv in self.subggnns: 103 | tX = conv.forward(A, X, datadict) 104 | X = X.add(tX, True) 105 | ``` 106 | Here, `A` and `X` are SparseTensors representing the adjacency matrix and tuple representation, respectively. `X.add` implements a residual connection. 107 | 108 | We also provide other convolution layers, including [GNNAK](https://arxiv.org/abs/2110.03753), 109 | [DSSGNN](https://arxiv.org/abs/2110.02910), [SSWL](https://arxiv.org/abs/2302.07090), 110 | [PPGN](https://arxiv.org/abs/1905.11136), [SUN](https://arxiv.org/abs/2206.11140), [I2GNN](https://arxiv.org/abs/2210.13978), 111 | in [pygho.honn.Conv](https://github.com/GraphPKU/PygHO/tree/main/pygho/honn/Conv.py). -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. PyGHO documentation master file, created by 2 | sphinx-quickstart on Fri Sep 15 13:55:23 2023.
3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | :github_url: https://github.com/GraphPKU/PygHO 7 | 8 | PyTorch Geometric High Order Documentation 9 | ========================================== 10 | 11 | PygHO is a library for high-order GNNs. Ordinary GNNs, like GCN, GIN, GraphSage, all pass messages between nodes and produce node representations. The node representation forms a dense matrix of shape :math:`(n, d)`, where :math:`n` is the number of nodes and :math:`d` is the hidden dimension. Existing libraries like PyG can easily implement them. 12 | 13 | In contrast, higher-order GNNs (HOGNNs) use node tuples as the message passing unit and produce representations for the tuples. The tuple representation can be of shape :math:`(n, n, d)`, :math:`(n, n, n, d)`, and even more dimensions. Furthermore, to reduce complexity, the representation can be sparse. PyGHO is the first unified library for HOGNNs. 14 | 15 | .. code-block:: latex 16 | 17 | >@inproceedings{PyGHO, 18 | author = {Xiyuan Wang and Muhan Zhang}, 19 | title = {{PyGHO, a Library for High Order Graph Neural Networks}}, 20 | year = {2023}, 21 | } 22 | 23 | .. toctree:: 24 | :glob: 25 | :maxdepth: 2 26 | :caption: Notes 27 | 28 | notes/installation 29 | notes/miniexample 30 | notes/datastructure 31 | notes/hodata 32 | notes/operator 33 | 34 | .. toctree:: 35 | :glob: 36 | :maxdepth: 2 37 | :caption: Advanced Tutorial 38 | 39 | notes/multtensor 40 | 41 | .. toctree:: 42 | :glob: 43 | :maxdepth: 2 44 | :caption: Package Reference 45 | 46 | modules/backend 47 | modules/hodata 48 | modules/honn 49 | -------------------------------------------------------------------------------- /docs/source/modules/backend.rst: -------------------------------------------------------------------------------- 1 | pygho.backend package 2 | ===================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | pygho.backend.MaTensor module 8 | ----------------------------- 9 | 10 | ..
automodule:: pygho.backend.MaTensor 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | pygho.backend.Mamamm module 16 | --------------------------- 17 | 18 | .. automodule:: pygho.backend.Mamamm 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | pygho.backend.SpTensor module 24 | ----------------------------- 25 | 26 | .. automodule:: pygho.backend.SpTensor 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | pygho.backend.Spmamm module 32 | --------------------------- 33 | 34 | .. automodule:: pygho.backend.Spmamm 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | pygho.backend.Spmm module 40 | ------------------------- 41 | 42 | .. automodule:: pygho.backend.Spmm 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | pygho.backend.Spspmm module 48 | --------------------------- 49 | 50 | .. automodule:: pygho.backend.Spspmm 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | pygho.backend.utils module 56 | -------------------------- 57 | 58 | .. automodule:: pygho.backend.utils 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | Module contents 64 | --------------- 65 | 66 | .. automodule:: pygho.backend 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | -------------------------------------------------------------------------------- /docs/source/modules/hodata.rst: -------------------------------------------------------------------------------- 1 | pygho.hodata package 2 | ==================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | pygho.hodata.MaData module 8 | -------------------------- 9 | 10 | .. automodule:: pygho.hodata.MaData 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | pygho.hodata.MaTupleSampler module 16 | ---------------------------------- 17 | 18 | .. 
automodule:: pygho.hodata.MaTupleSampler 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | pygho.hodata.ParallelPreprocess module 24 | -------------------------------------- 25 | 26 | .. automodule:: pygho.hodata.ParallelPreprocess 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | pygho.hodata.SpData module 32 | -------------------------- 33 | 34 | .. automodule:: pygho.hodata.SpData 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | pygho.hodata.SpTupleSampler module 40 | ---------------------------------- 41 | 42 | .. automodule:: pygho.hodata.SpTupleSampler 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | pygho.hodata.Wrapper module 48 | --------------------------- 49 | 50 | .. automodule:: pygho.hodata.Wrapper 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | Module contents 56 | --------------- 57 | 58 | .. automodule:: pygho.hodata 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | -------------------------------------------------------------------------------- /docs/source/modules/honn.rst: -------------------------------------------------------------------------------- 1 | pygho.honn package 2 | ================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | pygho.honn.Conv module 8 | ---------------------- 9 | 10 | .. automodule:: pygho.honn.Conv 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | pygho.honn.MaOperator module 16 | ---------------------------- 17 | 18 | .. automodule:: pygho.honn.MaOperator 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | pygho.honn.SpOperator module 24 | ---------------------------- 25 | 26 | .. automodule:: pygho.honn.SpOperator 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | pygho.honn.TensorOp module 32 | -------------------------- 33 | 34 | .. 
automodule:: pygho.honn.TensorOp 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | pygho.honn.utils module 40 | ----------------------- 41 | 42 | .. automodule:: pygho.honn.utils 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | Module contents 48 | --------------- 49 | 50 | .. automodule:: pygho.honn 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | -------------------------------------------------------------------------------- /docs/source/notes/datastructure.rst: -------------------------------------------------------------------------------- 1 | Basic Data Structure 2 | ============================ 3 | 4 | In this section, we'll provide a refined explanation of the basic data 5 | structures, MaskedTensor and SparseTensor, used in HOGNNs to address 6 | their unique requirements. 7 | 8 | MaskedTensor 9 | ------------ 10 | 11 | HOGNNs demand specialized data structures to handle high-order tensors 12 | efficiently. One such structure is the **MaskedTensor**, consisting of 13 | two components: ``data`` and ``mask``. 14 | 15 | - ``data`` has a shape of 16 | :math:`(\text{masked shape}, \text{dense shape})`, residing in 17 | :math:`\mathbb{R}^{n\times n\times d}`, where :math:`n` represents 18 | the number of nodes, and :math:`d` is the dimensionality of the data. 19 | - ``mask`` has a shape of :math:`(\text{masked shape})`, containing 20 | Boolean values, typically :math:`\{0,1\}^{n\times n}`. The element 21 | :math:`(i,j)` in ``mask`` is set to :math:`1` if the tuple 22 | :math:`(i,j)` exists in the tensor. 23 | 24 | Unused elements in ``data`` do not affect the output of the operators in 25 | this library. For example, when performing operations like summation, 26 | MaskedTensor treats the non-existing elements as :math:`0`, effectively 27 | ignoring them. 28 | 29 | Here's an example of creating a MaskedTensor: 30 | 31 | .. 
code:: python 32 | 33 | from pygho import MaskedTensor 34 | import torch 35 | 36 | n, d = 3, 3 37 | data = torch.tensor([[4, 1, 4], [4, 4, 2], [3, 4, 4]]) 38 | mask = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.bool) 39 | A = MaskedTensor(data, mask) 40 | 41 | The non-existing elements in ``data`` can be assigned arbitrary values. 42 | The created masked tensor is as follows 43 | 44 | .. math:: 45 | 46 | 47 | \begin{bmatrix} 48 | -&1&-\\ 49 | -&-&2\\ 50 | 3&-&- 51 | \end{bmatrix} 52 | 53 | SparseTensor 54 | ------------ 55 | 56 | On the other hand, the **SparseTensor** stores only existing elements, 57 | making it more efficient when a small ratio of valid elements is 58 | present. A SparseTensor, with shape ``(sparse_shape, dense_shape)``, 59 | consists of two tensors: ``indices`` and ``values``. 60 | 61 | - ``indices`` is an Integer Tensor with shape ``(sparse_dim, nnz)``, 62 | where ``sparse_dim`` represents the number of dimensions in the 63 | sparse shape, and ``nnz`` stands for the count of existing elements. 64 | - ``values`` has a shape of ``(nnz, dense_shape)``. 65 | 66 | The columns of ``indices`` and rows of ``values`` correspond to the 67 | non-zero elements, simplifying retrieval and manipulation of the 68 | required information. 69 | 70 | For instance, in the context of NGNN's representation 71 | :math:`H\in \mathbb{R}^{n\times n\times d}`, where the total number of 72 | nodes in subgraphs is :math:`m`, you can represent :math:`H` using 73 | ``indices`` :math:`a\in \mathbb{N}^{2\times m}` and ``values`` 74 | :math:`v\in \mathbb{R}^{m\times d}`. Specifically, for 75 | :math:`i=1,2,\ldots,m`, :math:`H_{a_{1,i},a_{2,i}}=v_i`. 76 | 77 | Creating a SparseTensor is illustrated in the following example: 78 | 79 | ..
code:: python 80 | 81 | from pygho import SparseTensor 82 | import torch 83 | 84 | n, d = 3, 3 85 | indices = torch.tensor([[0, 1, 2], [1, 2, 0]], dtype=torch.long) 86 | values = torch.tensor([1, 2, 3]) 87 | A = SparseTensor(indices, values, shape=(3, 3)) 88 | 89 | representing the following matrix 90 | 91 | .. math:: 92 | 93 | 94 | \begin{bmatrix} 95 | -&1&-\\ 96 | -&-&2\\ 97 | 3&-&- 98 | \end{bmatrix} 99 | 100 | Please note that in the SparseTensor format, each non-zero element is 101 | represented by ``sparse_dim`` int64 indices and the element itself. If 102 | the tensor is not sparse enough, SparseTensor may occupy more memory 103 | than a dense tensor. 104 | -------------------------------------------------------------------------------- /docs/source/notes/hodata.rst: -------------------------------------------------------------------------------- 1 | .. _hodata-label: 2 | 3 | Efficient High Order Data Processing 4 | ==================================== 5 | 6 | In this section, we'll delve into the efficient high-order data 7 | processing capabilities provided by PyGHO, particularly focusing on the 8 | handling of high-order tensors, tuple feature precomputation, and data 9 | loading strategies for both sparse and masked tensor data structures. 10 | 11 | Adding High Order Features to PyG Dataset 12 | ----------------------------------------- 13 | 14 | HOGNNs and MPNNs share common tasks, allowing us to leverage PyTorch 15 | Geometric's (PyG) data processing routines. However, to cater to the 16 | unique requirements of HOGNNs, PyGHO significantly extends these 17 | routines while maintaining compatibility with PyG. This extension 18 | ensures convenient high-order feature precomputation and preservation. 19 | 20 | Efficient High-Order Feature Precomputation 21 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 22 | 23 | High-order feature precomputation can be efficiently conducted in 24 | parallel using the PyGHO library. 
Consider the following example: 25 | 26 | .. code:: python 27 | 28 | # Ordinary PyG dataset 29 | from torch_geometric.datasets import ZINC 30 | trn_dataset = ZINC("dataset/ZINC", subset=True, split="train") 31 | # High-order graph dataset 32 | from pygho.hodata import Sppretransform, ParallelPreprocessDataset 33 | trn_dataset = ParallelPreprocessDataset( 34 | "dataset/ZINC_trn", trn_dataset, 35 | pre_transform=Sppretransform(tuplesamplers=partial(KhopSampler, hop=3), keys=keys), num_workers=8) 36 | 37 | The ``ParallelPreprocessDataset`` class takes an ordinary PyG dataset as 38 | input and performs transformations on each graph in parallel (utilizing 39 | 8 processes in this example). The ``tuplesamplers`` parameter represents 40 | functions that take a graph as input and produce a sparse tensor. You 41 | can apply multiple samplers simultaneously, and the resulting output can 42 | be assigned specific names using the ``annotate`` parameter. In this 43 | example, we utilize ``partial(KhopSampler, hop=3)``, a sampler designed 44 | for NGNN, to sample a 3-hop ego-network rooted at each node. The 45 | shortest path distance to the root node serves as the tuple features. 46 | The produced SparseTensor is then saved and can be effectively used to 47 | initialize tuple representations. 48 | 49 | Since the dataset preprocessing routine is closely related to data 50 | structures, we have designed two separate routines for sparse and dense 51 | tensors. These routines only differ in the ``pre_transform`` function. 52 | For dense tensors, we can simply use 53 | ``Mapretransform(None, tuplesamplers)``. In this case, 54 | ``tuplesamplers`` is a function producing dense tuple features. In :py:mod:`pygho.hodata.MaTupleSampler`, we provide ``spdsampler`` and 55 | ``rdsampler`` to compute shortest path distance and resistance distance 56 | between nodes. One example is 57 | 58 | ..
code:: python 59 | 60 | trn_dataset = ParallelPreprocessDataset("dataset/ZINC_trn", 61 | trn_dataset, 62 | pre_transform=Mapretransform( 63 | partial(spdsampler, 64 | hop=4)), 65 | num_worker=0) 66 | 67 | Defining Custom Tuple Samplers 68 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 69 | 70 | In addition to the provided tuple samplers, you can define your own 71 | tuple sampler. For sparse data, a sampler is a function or callable 72 | object that takes a ``torch_geometric.data.Data`` object as input and 73 | produces a sparse tensor as output. Here's an example of a custom sparse 74 | tuple sampler that assigns ``0`` as a feature for each tuple ``(i, i)`` 75 | in the graph: 76 | 77 | .. code:: python 78 | 79 | def SparseToySampler(data: PygData) -> SparseTensor: 80 | """ 81 | Sample k-hop subgraph on a given PyG graph. 82 | 83 | Args: 84 | 85 | - data (PygData): The input PyG data. 86 | 87 | Returns: 88 | 89 | - SparseTensor for the precomputed tuple features. 90 | """ 91 | n = data.num_nodes 92 | tupleid = torch.stack((torch.arange(n), torch.arange(n))) 93 | tuplefeat = torch.zeros((n,)) 94 | ret = SparseTensor(tupleid, tuplefeat, shape=(n, n)) 95 | return ret 96 | 97 | For dense data, a sampler is a function or callable object that takes a 98 | ``torch_geometric.data.Data`` object as input and produces a tensor 99 | along with the masked shape of the features. Here's a custom dense tuple 100 | sampler that assigns ``0`` as a feature for each tuple ``(i, i)`` in the 101 | graph: 102 | 103 | .. code:: python 104 | 105 | def DenseToySampler(data: PygData) -> Tuple[Tensor, List[int]]: 106 | """ 107 | Sample k-hop subgraph on a given PyG graph. 108 | 109 | Args: 110 | 111 | - data (PygData): The input PyG data. 112 | 113 | Returns: 114 | 115 | - Tensor: The precomputed tuple features. 116 | - List[int]: The masked shape of the features. 
117 | """ 118 | n = data.num_nodes 119 | val = torch.eye(n) 120 | return val, [n, n] 121 | 122 | Please note that for dense data, the function returns a tuple consisting 123 | of the value and the masked shape, as opposed to returning a 124 | MaskedTensor. This is because the mask can typically be inferred from 125 | the feature itself, making it unnecessary to explicitly include it in 126 | the returned data. In such cases, the mask can be determined as val == 127 | 1 . 128 | 129 | Using Multiple Tuple Samplers 130 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 131 | 132 | You can use multiple tuple samplers simultaneously. For instance: 133 | 134 | .. code:: python 135 | 136 | trn_dataset = ParallelPreprocessDataset( 137 | "dataset/ZINC_trn", trn_dataset, 138 | pre_transform=Sppretransform(tuplesamplers=[partial(KhopSampler, hop=1),partial(KhopSampler, hop=2)], annotate=["1hop", "2hop"], keys=keys), num_workers=8) 139 | 140 | This code precomputes two tuple features simultaneously and assigns them 141 | different annotations, "1hop" and "2hop," to distinguish between them. 142 | 143 | For dense, it works similarly 144 | 145 | .. code:: python 146 | 147 | trn_dataset = ParallelPreprocessDataset( 148 | "dataset/ZINC_trn", 149 | trn_dataset, 150 | pre_transform=Mapretransform( 151 | [partial(spdsampler,hop=1),partial(spdsampler,hop=2)], 152 | annotate=["1hop","2hop"]), 153 | num_worker=0) 154 | 155 | Sparse-Sparse Matrix Multiplication Precomputation 156 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 157 | 158 | Efficient Sparse-Sparse Matrix Multiplication in our library can be 159 | achieved through precomputation. The ``keys`` parameter in 160 | ``Sppretransform`` is a list of strings, where each string indicates a 161 | specific precomputation. 
For example, consider the key: 162 | 163 | :: 164 | 165 | "X___A___1___X___0" 166 | 167 | Here, the precomputation involves sparse matrix multiplication 168 | :math:`AX`, but only computes the output elements that exist in 169 | :math:`X`. These precomputation results can be shared among matrices 170 | with the same indices. The key elements signify the following: 171 | 172 | - The first ``X`` refers to the target sparse matrix indices. 173 | - ``A`` and ``X`` represent the matrices involved in the multiplication, the adjacency matrix ``A``, and the tuple feature ``X``. 174 | - ``1`` denotes that dimension ``1`` of ``A`` will be reduced. 175 | - ``0`` signifies that dimension ``0`` of ``X`` will be reduced. 176 | 177 | You don't need to manually feed the precomputation results to the model. 178 | Converting the batch to a dictionary and using it as the ``datadict`` 179 | parameter is sufficient: 180 | 181 | .. code:: python 182 | 183 | for batch in dataloader: 184 | datadict = batch.to_dict() 185 | 186 | Dense data does not require precomputation currently. 187 | 188 | If you use annotate in transformation, for example, 189 | 190 | :: 191 | 192 | Sppretransform(tuplesamplers=partial(KhopSampler, hop=1),annotate=["1hop"], keys=keys) 193 | 194 | Then the key can be 195 | 196 | :: 197 | 198 | "X1hop___A___1___X1hop___0" 199 | 200 | More details are shown in :ref:`multi-tensor-tutorial-label` 201 | Mini-batch and DataLoader 202 | 203 | Enabling batch training in HOGNNs demands handling graphs of varying 204 | sizes, which presents a challenge. We employ different strategies for 205 | Sparse and Masked Tensor data structures. 206 | 207 | Sparse Tensor Data 208 | ~~~~~~~~~~~~~~~~~~ 209 | 210 | For Sparse Tensor data, we adopt a relatively straightforward solution. 211 | We concatenate the tensors of each graph along the diagonal of a larger 212 | tensor. 
For example, in a batch of :math:`B` graphs with adjacency 213 | matrices :math:`A_i\in \mathbb{R}^{n_i\times n_i}`, node features 214 | :math:`x_i\in \mathbb{R}^{n_i\times d}`, and tuple features 215 | :math:`X_i\in \mathbb{R}^{n_i\times n_i\times d'}` for 216 | :math:`i=1,2,\ldots,B`, the features for the entire batch are 217 | represented as :math:`A\in \mathbb{R}^{n\times n}`, 218 | :math:`x\in \mathbb{R}^{n\times d}`, and 219 | :math:`X\in \mathbb{R}^{n\times n\times d'}`, where 220 | :math:`n=\sum_{i=1}^B n_i`. This arrangement allows tensors in batched 221 | data to have the same number of dimensions as those of a single 222 | graph, facilitating the sharing of 223 | common operators. 224 | 225 | We provide PygHO's own DataLoader to simplify this process: 226 | 227 | .. code:: python 228 | 229 | from pygho.hodata import SpDataloader 230 | trn_dataloader = SpDataloader(trn_dataset, batch_size=32, shuffle=True, drop_last=True) 231 | 232 | Masked Tensor Data 233 | ~~~~~~~~~~~~~~~~~~ 234 | 235 | As concatenation along the diagonal leads to a lot of non-existing 236 | elements, handling Masked Tensor data involves a different strategy for 237 | saving space. In this case, tensors are padded to the same shape and 238 | stacked along a new axis. For instance, in a batch of :math:`B` graphs 239 | with adjacency matrices :math:`A_i\in \mathbb{R}^{n_i\times n_i}`, node 240 | features :math:`x_i\in \mathbb{R}^{n_i\times d}`, and tuple features 241 | :math:`X_i\in \mathbb{R}^{n_i\times n_i\times d'}` for 242 | :math:`i=1,2,\ldots,B`, the features for the entire batch are 243 | represented as 244 | :math:`A\in \mathbb{R}^{B\times \tilde{n}\times \tilde{n}}`, 245 | :math:`x\in \mathbb{R}^{B\times \tilde{n}\times d}`, and 246 | :math:`X\in \mathbb{R}^{B\times \tilde{n}\times \tilde{n}\times d'}`, 247 | where :math:`\tilde{n}=\max\{n_i|i=1,2,\ldots,B\}`. 248 | 249 | ..
math:: 250 | 251 | 252 | A=\begin{bmatrix} 253 | \begin{pmatrix} 254 | A_1&0_{n_1,\tilde n-n_1}\\ 255 | 0_{\tilde n-n_1, n_1}&0_{\tilde n-n_1,\tilde n-n_1}\\ 256 | \end{pmatrix}\\ 257 | \begin{pmatrix} 258 | A_2&0_{n_2,\tilde n-n_2}\\ 259 | 0_{\tilde n-n_2, n_2}&0_{\tilde n-n_2,\tilde n-n_2}\\ 260 | \end{pmatrix}\\ 261 | \vdots\\ 262 | \begin{pmatrix} 263 | A_B&0_{n_B,\tilde n-n_B}\\ 264 | 0_{\tilde n-n_B, n_B}&0_{\tilde n-n_B,\tilde n-n_B}\\ 265 | \end{pmatrix}\\ 266 | \end{bmatrix} 267 | ,x=\begin{bmatrix} 268 | \begin{pmatrix} 269 | x_1\\ 270 | 0_{\tilde n-n_1, d}\\ 271 | \end{pmatrix}\\ 272 | \begin{pmatrix} 273 | x_2\\ 274 | 0_{\tilde n-n_2, d}\\ 275 | \end{pmatrix}\\ 276 | \vdots\\ 277 | \begin{pmatrix} 278 | x_B\\ 279 | 0_{\tilde n-n_B, d}\\ 280 | \end{pmatrix}\\ 281 | \end{bmatrix} 282 | ,X=\begin{bmatrix} 283 | \begin{pmatrix} 284 | X_1&0_{n_1,\tilde n-n_1}\\ 285 | 0_{\tilde n-n_1, n_1}&0_{\tilde n-n_1,\tilde n-n_1}\\ 286 | \end{pmatrix}\\ 287 | \begin{pmatrix} 288 | X_2&0_{n_2,\tilde n-n_2}\\ 289 | 0_{\tilde n-n_2, n_2}&0_{\tilde n-n_2,\tilde n-n_2}\\ 290 | \end{pmatrix}\\ 291 | \vdots\\ 292 | \begin{pmatrix} 293 | X_B&0_{n_B,\tilde n-n_B}\\ 294 | 0_{\tilde n-n_B, n_B}&0_{\tilde n-n_B,\tilde n-n_B}\\ 295 | \end{pmatrix}\\ 296 | \end{bmatrix} 297 | 298 | The zero padding is masked out in the resulting MaskedTensor. 299 | 300 | We also provide a DataLoader for this purpose: 301 | 302 | .. code:: python 303 | 304 | from pygho.hodata import MaDataloader 305 | trn_dataloader = MaDataloader(trn_dataset, batch_size=256, device=device, shuffle=True, drop_last=True) 306 | 307 | This padding and stacking strategy ensures consistent shapes across 308 | tensors, allowing for efficient processing of dense data. 309 | 310 | -------------------------------------------------------------------------------- /docs/source/notes/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | The installation of PyGHO requires pytorch>=2.0 and pytorch_geometric>=2.3.
Please refer to their official sites for the installation. 5 | 6 | Once the required versions of PyTorch and PyTorch Geometric are installed, simply run: 7 | 8 | .. code-block:: bash 9 | 10 | $ git clone https://github.com/GraphPKU/PygHO.git 11 | $ cd PygHO 12 | $ pip install -e . 13 | 14 | To update PyGHO, simply run: 15 | 16 | .. code-block:: bash 17 | 18 | $ git pull 19 | 20 | -------------------------------------------------------------------------------- /docs/source/notes/miniexample.rst: -------------------------------------------------------------------------------- 1 | .. _miniexample-label: 2 | 3 | Minimal Example 4 | =============== 5 | 6 | Let's delve into the fundamental concepts of PyGHO using a minimal 7 | example. The complete code can be found in 8 | `example/minimal.py <https://github.com/GraphPKU/PygHO/tree/main/example/minimal.py>`__. 9 | You can execute the code with the following command: 10 | 11 | .. code:: shell 12 | 13 | python minimal.py 14 | 15 | This example demonstrates the implementation of a basic HOGNN model 16 | **Nested Graph Neural Network (NGNN)** in the paper `Nested Graph Neural 17 | Network (NGNN) <https://arxiv.org/abs/2110.13197>`__. NGNN works by 18 | first sampling a k-hop subgraph for each node :math:`i` and then 19 | applying Graph Neural Networks (GNN) on all these subgraphs 20 | simultaneously. It generates a 2-D representation 21 | :math:`H\in \mathbb{R}^{n\times n\times d}`, where :math:`H_{ij}` 22 | represents the representation of node :math:`j` in the subgraph rooted 23 | at node :math:`i`. The message passing within all subgraphs can be 24 | expressed as: 25 | 26 | .. math:: 27 | 28 | 29 | h_{ij}^{t+1} \leftarrow \sum_{k\in N_i(j)} \text{MLP}(h^t_{ik}), 30 | 31 | where :math:`N_i(j)` represents the set of neighbors of node :math:`j` 32 | in the subgraph rooted at :math:`i`. After several layers of message 33 | passing, tuple representations :math:`H` are pooled to generate the node 34 | representations: 35 | 36 | .. math:: 37 | 38 | 39 | h_i = P(\{h_{ij} | j\in V_i\}).
40 | 41 | This example serves as a fundamental illustration of our work. 42 | 43 | Dataset Preprocessing 44 | --------------------- 45 | 46 | As HOGNNs share tasks with ordinary GNNs, they can utilize datasets 47 | provided by PyG. However, NGNN still needs to sample subgraphs, 48 | equivalent to providing initial features for tuple representation 49 | :math:`h_{ij}`. You can achieve this transformation with the following 50 | code: 51 | 52 | .. code:: python 53 | 54 | # Load an ordinary PyG dataset 55 | from torch_geometric.datasets import ZINC 56 | trn_dataset = ZINC("dataset/ZINC", subset=True, split="train") 57 | 58 | # Transform it into a High-order graph dataset 59 | from pygho.hodata import Sppretransform, ParallelPreprocessDataset 60 | trn_dataset = ParallelPreprocessDataset( 61 | "dataset/ZINC_trn", trn_dataset, 62 | pre_transform=Sppretransform(tuplesamplers=[partial(KhopSampler, hop=3)], annotate=[""], keys=keys), num_workers=8) 63 | 64 | The ``ParallelPreprocessDataset`` class takes a standard PyG dataset as 65 | input and performs transformations on each graph in parallel, utilizing 66 | 8 processes in this example. The ``tuplesamplers`` parameter represents 67 | functions that take a graph as input and produce a sparse tensor. In 68 | this example, we use ``partial(KhopSampler, hop=3)``, a sampler designed 69 | for NGNN, to sample a 3-hop ego-network rooted at each node. The 70 | shortest path distance to the root node serves as the tuple features. 71 | The produced SparseTensor is then saved and can be effectively used to 72 | initialize tuple representations. The ``keys`` variable is a list of 73 | strings indicating the required precomputation, which can be 74 | automatically generated after defining a model: 75 | 76 | .. 
code:: python 77 | 78 | from pygho.honn.SpOperator import parse_precomputekey 79 | keys = parse_precomputekey(model) 80 | 81 | Mini-batch and DataLoader 82 | ------------------------- 83 | 84 | Enabling batch training in HOGNNs requires handling graphs of varying 85 | sizes, which can be challenging. This library concatenates the 86 | SparseTensors of each graph along the diagonal of a larger tensor. For 87 | instance, in a batch of :math:`B` graphs with adjacency matrices 88 | :math:`A_i\in \mathbb{R}^{n_i\times n_i}`, node features 89 | :math:`x_i\in \mathbb{R}^{n_i\times d}`, and tuple features 90 | :math:`X_i\in \mathbb{R}^{n_i\times n_i\times d'}` for 91 | :math:`i=1,2,\ldots,B`, the features for the entire batch are 92 | represented as :math:`A\in \mathbb{R}^{n\times n}`, 93 | :math:`x\in \mathbb{R}^{n\times d}`, and 94 | :math:`X\in \mathbb{R}^{n\times n\times d'}`, where 95 | :math:`n=\sum_{i=1}^B n_i`. The concatenation is as follows: 96 | 97 | .. math:: 98 | 99 | 100 | A=\begin{bmatrix} 101 | A_1&0&0&\cdots &0\\ 102 | 0&A_2&0&\cdots &0\\ 103 | 0&0&A_3&\cdots &0\\ 104 | \vdots&\vdots&\vdots&\vdots&\vdots\\ 105 | 0&0&0&\cdots&A_B 106 | \end{bmatrix} 107 | ,x=\begin{bmatrix} 108 | x_1\\ 109 | x_2\\ 110 | x_3\\ 111 | \vdots\\ 112 | x_B 113 | \end{bmatrix} 114 | ,X=\begin{bmatrix} 115 | X_1&0&0&\cdots &0\\ 116 | 0&X_2&0&\cdots &0\\ 117 | 0&0&X_3&\cdots &0\\ 118 | \vdots&\vdots&\vdots&\vdots&\vdots\\ 119 | 0&0&0&\cdots&X_B 120 | \end{bmatrix} 121 | 122 | We provide our DataLoader as part of PygHO. It has compatible parameters 123 | with PyTorch's DataLoader and combines sparse tensors for different 124 | graphs: 125 | 126 | .. code:: python 127 | 128 | from pygho.hodata import SpDataloader 129 | trn_dataloader = SpDataloader(trn_dataset, batch_size=128, shuffle=True, drop_last=True) 130 | 131 | Using this DataLoader is similar to an ordinary PyG DataLoader: 132 | 133 | .. 
code:: python 134 | 135 | for batch in dataloader: 136 | batch = batch.to(device, non_blocking=True) 137 | 138 | However, in addition to PyG batch attributes (like ``edge_index``, 139 | ``x``, ``batch``), this batch also contains a SparseTensor adjacency 140 | matrix ``A`` and initial tuple feature SparseTensor ``X``. 141 | 142 | Learning Methods on Graphs 143 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 144 | 145 | To execute message passing on each subgraph simultaneously, you can 146 | utilize the NGNNConv in our library: 147 | 148 | .. code:: python 149 | 150 | # Definition 151 | self.subggnns = nn.ModuleList([ 152 | NGNNConv(hiddim, hiddim, "sum", "SS", mlp) 153 | for _ in range(num_layer) 154 | ]) 155 | 156 | ... 157 | # Forward pass 158 | for conv in self.subggnns: 159 | tX = conv.forward(A, X, datadict) 160 | X = X.add(tX, True) 161 | 162 | Here, ``A`` and ``X`` are SparseTensors representing the adjacency 163 | matrix and tuple representation, respectively. ``X.add`` implements a 164 | residual connection. 165 | 166 | We also provide other convolution layers, including 167 | `GNNAK `__, 168 | `DSSGNN `__, 169 | `SSWL `__, 170 | `PPGN `__, 171 | `SUN `__, 172 | `I2GNN `__, in 173 | :py:mod:`pygho.honn.Conv`. 174 | -------------------------------------------------------------------------------- /docs/source/notes/multtensor.rst: -------------------------------------------------------------------------------- 1 | .. _multi-tensor-tutorial-label: 2 | 3 | Multiple Tensor 4 | =============== 5 | 6 | In our dataset preprocessing routine, the default computation involves 7 | two high-order tensors: the adjacency matrix ``A`` and the tuple feature 8 | ``X``. However, in certain scenarios, there may be a need for additional 9 | high-order tensors. For instance, when using a Nested Graph Neural 10 | Network (GNN) with a 2-hop GNN as the base GNN, Message Passing Neural 11 | Network (MPNN) operations are performed on each subgraph with an 12 | augmented adjacency matrix. 
In this case, two high-order tensors are 13 | required: the tuple feature and the augmented adjacency matrix. 14 | 15 | During data preprocessing, we can use multiple samplers, each 16 | responsible for sampling one tensor. For sparse data, the code might 17 | look like this: 18 | 19 | .. code:: python 20 | 21 | trn_dataset = ParallelPreprocessDataset( 22 | "dataset/ZINC_trn", trn_dataset, 23 | pre_transform=Sppretransform( 24 | tuplesamplers=[ 25 | partial(KhopSampler, hop=3), 26 | partial(KhopSampler, hop=2) 27 | ], 28 | annotate=["tuplefeat", "2hopadj"], 29 | keys=keys 30 | ), 31 | num_workers=8 32 | ) 33 | 34 | In this code, two tuple features are precomputed simultaneously and 35 | assigned different annotations: "tuplefeat" and "2hopadj" to distinguish 36 | between them. 37 | 38 | For dense data, the process is quite similar: 39 | 40 | .. code:: python 41 | 42 | trn_dataset = ParallelPreprocessDataset( 43 | "dataset/ZINC_trn", 44 | trn_dataset, 45 | pre_transform=Mapretransform( 46 | [ 47 | partial(spdsampler, hop=3), 48 | partial(spdsampler, hop=2) 49 | ], 50 | annotate=["tuplefeat", "2hopadj"] 51 | ), 52 | num_worker=0 53 | ) 54 | 55 | After passing the data through a dataloader, the batch will contain 56 | ``Xtuplefeat`` and ``X2hopadj`` as the high-order tensors that are 57 | needed. For dense models, this concludes the process. However, for 58 | sparse models, if you want to retrieve the correct keys, you will need 59 | to modify the operator symbols for sparse message passing layers. 60 | 61 | Ordinarily, the ``NGNNConv`` is defined as: 62 | 63 | .. code:: python 64 | 65 | NGNNConv(hiddim, hiddim, mlp=mlp) 66 | 67 | This is equivalent to: 68 | 69 | .. code:: python 70 | 71 | NGNNConv(hiddim, hiddim, mlp=mlp, optuplefeat="X", opadj="A") 72 | 73 | To ensure that you retrieve the correct keys, you should use: 74 | 75 | .. 
code:: python 76 | 77 | NGNNConv(hiddim, hiddim, mlp=mlp, optuplefeat="Xtuplefeat", opadj="X2hopadj") 78 | 79 | Similar modifications should be made for other layers as needed. 80 | -------------------------------------------------------------------------------- /docs/source/notes/operator.rst: -------------------------------------------------------------------------------- 1 | .. _operator-label: 2 | 3 | Operators 4 | ========= 5 | 6 | In this section, we'll provide a detailed introduction of some basic 7 | operators on high order tensors. 8 | 9 | Code Architecture 10 | ----------------- 11 | 12 | The code for these 13 | high-order graph neural network (HOGNN) operations is organized into 14 | three layers: 15 | 16 | **Layer 1: Backend:** This layer, found in the :py:mod:`pygho.backend` module, 17 | contains basic data structures and operations focused on tensor 18 | manipulations. It lacks graph-specific learning concepts and includes 19 | the following functionalities: 20 | 21 | - **Matrix multiplication:** This method supports general matrix 22 | multiplication capabilities, including operations on two 23 | SparseTensors, one sparse and one MaskedTensor, and two 24 | MaskedTensors. It also handles batched matrix multiplication and 25 | offers operations that replace the sum in traditional matrix 26 | multiplication with max and mean operations. 27 | - **Two matrix addition:** Operations for adding two sparse or two 28 | dense matrices. 29 | - **Reduce operations:** These operations include sum, mean, max, and 30 | min, which reduce dimensions in tensors. 31 | - **Expand operation:** This operation adds new dimensions to tensors. 32 | - **Tuplewise apply(func):** It applies a given function to the 33 | underlying data tensor. 34 | - **Diagonal apply(func):** This operation applies a function to 35 | diagonal elements of tensors. 
36 | 37 | **Layer 2: Graph operations:** Building upon Layer 1, the 38 | :py:mod:`pygho.honn.SpOperator` and :py:mod:`pygho.honn.MaOperator` modules provide 39 | graph operations specifically tailored for Sparse and Masked Tensor 40 | structures. Additionally, the :py:mod:`pygho.honn.TensorOp` layer wraps these 41 | operators, abstracting away the differences between Sparse and Masked 42 | Tensor data structures. These operations encompass: 43 | 44 | - **General message passing between tuples:** Facilitating message 45 | passing between tuples of nodes. 46 | - **Pooling:** This operation reduces high-order tensors to lower-order 47 | ones by summing, taking the maximum, or computing the mean across 48 | specific dimensions. 49 | - **Diagonal:** It reduces high-order tensors to lower-order ones by 50 | extracting diagonal elements. 51 | - **Unpooling:** This operation extends low-order tensors to high-order 52 | ones. 53 | 54 | **Layer 3: Models:** Building on Layer 2, this layer provides a 55 | collection of representative high-order GNN layers, including NGNN, 56 | GNNAK, DSSGNN, SUN, SSWL, PPGN, and I2GNN. Layer 3 offers numerous 57 | ready-to-use methods, and with Layer 2, users can design additional 58 | models using general graph operations. Layer 1 allows for the 59 | development of novel operations, expanding the library's flexibility and 60 | utility. Now let's explore these layers in more detail. 61 | 62 | Layer 1: Backend 63 | ~~~~~~~~~~~~~~~~ 64 | 65 | Spspmm 66 | ^^^^^^ 67 | 68 | One of the most complex operators in this layer is sparse-sparse matrix 69 | multiplication (Spspmm). Given two sparse matrices, C and D, and 70 | assuming their output is B: 71 | 72 | .. math:: 73 | 74 | 75 | B_{ij} = \sum_k C_{ik} D_{kj} 76 | 77 | The Spspmm operator utilizes a coo format to represent elements. 
78 | Assuming element C\_ik corresponds to C.values[c\_ik] and D\_kj 79 | corresponds to D.values[d\_kj], we can create a tensor ``bcd`` of 80 | shape (3, m), where m is the number of triples (i, j, k) where both 81 | C\_ik and D\_kj exist. The multiplication can be performed as 82 | follows: 83 | 84 | .. code:: python 85 | 86 | B.values = zeros(...) 87 | for i in range(m): 88 | B.values[bcd[0, i]] += C.values[bcd[1, i]] * D.values[bcd[2, i]] 89 | 90 | This summation process can be efficiently implemented in parallel on GPU 91 | using ``torch.Tensor.scatter_reduce_``. The ``bcd`` tensor can be 92 | precomputed with :py:func:`pygho.backend.Spspmm.spspmm_ind` and shared among 93 | matrices with the same indices. 94 | 95 | Hadamard product between two sparse matrices can also be implemented, 96 | where :math:`C = A \odot B`: C can use the same indices as :math:`A`. 97 | 98 | .. code:: python 99 | 100 | C.values[a_ij] = A.values[a_ij] * B.values[b_ij] 101 | 102 | The tensor ``b2a`` can be defined, where ``b2a[b_ij] = a_ij`` if A has 103 | the element (i, j); otherwise, it is set to -1. Then, the Hadamard 104 | product can be computed as follows: 105 | 106 | .. code:: python 107 | 108 | C.values = zeros(...) 109 | for i in range(B.nnz): 110 | if b2a[i] >= 0: 111 | C.values[b2a[i]] = A.values[b2a[i]] * B.values[i] 112 | 113 | The operation can also be efficiently implemented in parallel on a GPU. 114 | 115 | To compute :math:`A\odot (CD)`, you can define a tensor ``acd`` of shape 116 | (3, m') where ``acd[0] = b2a[bcd[0]]``, ``acd[1] = bcd[1]``, and 117 | ``acd[2] = bcd[2]``, and remove columns i where ``acd[0, i] = -1``. The 118 | computation can be done as follows: 119 | 120 | .. code:: python 121 | 122 | ret.values = zeros(...) 123 | for i in range(acd.shape[1]): 124 | ret.values[acd[0, i]] += A.values[acd[0, i]] * C.values[acd[1, i]] * D.values[acd[2, i]] 125 | 126 | Like the previous operations, this can also be implemented efficiently 127 | in parallel on a GPU. 
Additionally, by setting ``A.values[acd[0, i]]`` 128 | to 1, A can act as a mask, ensuring that only elements existing in A are 129 | computed. 130 | 131 | The overall wrapper for these functions is :py:func:`pygho.backend.Spspmm.spspmm`, 132 | which can perform sparse-sparse matrix multiplication with precomputed 133 | indices. :py:func:`pygho.backend.Spspmm.spspmpnn` provides a more complex operator 134 | that goes beyond matrix multiplication, allowing you to implement 135 | various graph operations. It can in fact implement the following 136 | framework. 137 | 138 | .. math:: 139 | 140 | 141 | ret_{ij} = \phi(\{(A_{ij}, B_{ik}, C_{kj})|B_{ik},C_{kj} \text{ elements exist}\}) 142 | 143 | 144 | where ``phi`` is a general multiset function, which is a functional 145 | parameter of ``spspmpnn``. For example, with it, we can implement GAT on each 146 | subgraph as follows. 147 | 148 | .. code:: python 149 | 150 | self.attentionnn1 = nn.Linear(hiddim, hiddim) 151 | self.attentionnn2 = nn.Linear(hiddim, hiddim) 152 | self.attentionnn3 = nn.Linear(hiddim, hiddim) 153 | self.subggnn = NGNNConv(hiddim, hiddim, args.aggr, 154 | "SS", transfermlpparam(mlp), 155 | message_func=lambda a,b,c,tarid: 156 | scatter_softmax( 157 | self.attentionnn1(a) * b * self.attentionnn2(c), 158 | tarid) 159 | * self.attentionnn3(c)) 160 | 161 | 162 | TuplewiseApply 163 | ^^^^^^^^^^^^^^ 164 | 165 | Both Sparse and Masked Tensors have the ``tuplewiseapply`` function. The 166 | most common usage is: 167 | 168 | .. code:: python 169 | 170 | mlp = ... 171 | X.tuplewiseapply(mlp) 172 | 173 | However, in practice, this function directly applies ``mlp`` to the values or data 174 | tensor. As linear layers, non-linearities, and layer 175 | normalization all operate on the last dimension, this operation is 176 | essentially equivalent to tuplewise apply. For batchnorm, we provide a 177 | version that is not affected by this problem in :py:mod:`pygho.honn.utils`. 
178 | 179 | DiagonalApply 180 | ^^^^^^^^^^^^^ 181 | 182 | Both Sparse and Masked Tensors have the ``diagonalapply`` function. 183 | Unlike ``tuplewiseapply``, this function passes both data/values and a 184 | mask indicating whether the corresponding elements are on the diagonal 185 | to the input function. A common use case is: 186 | 187 | .. code:: python 188 | 189 | mlp1 = ... 190 | mlp2 = ... 191 | mlp = lambda x, diagonalmask: torch.where(diagonalmask, mlp1(x), mlp2(x)) 192 | X.diagonalapply(mlp) 193 | 194 | Here, ``mlp1`` is applied to diagonal elements, and ``mlp2`` is applied 195 | to non-diagonal elements. You can also use 196 | ``torch_geometric.nn.HeteroLinear`` for a faster implementation. 197 | 198 | Layer 2: Operators 199 | ~~~~~~~~~~~~~~~~~~ 200 | 201 | :py:mod:`pygho.honn.SpOperator` and :py:mod:`pygho.honn.MaOperator` wrap the backend 202 | for SparseTensor and MaskedTensor separately. Their APIs are unified in 203 | :py:mod:`pygho.honn.TensorOp`. The basic operators include 204 | ``OpNodeMessagePassing`` (node-level message passing), 205 | ``OpMessagePassing`` (tuple-level message passing, wrapping matrix 206 | multiplication), ``OpPooling`` (reduce high-order tensors to lower-order 207 | ones by sum, mean, max), ``OpDiag`` (reduce high-order tensors to 208 | lower-order ones by extracting diagonal elements), and ``OpUnpooling`` 209 | (extend lower-order tensors to higher-order ones). Special cases are 210 | also defined. 211 | 212 | Sparse OpMessagePassing 213 | ^^^^^^^^^^^^^^^^^^^^^^^ 214 | 215 | As described in Layer 1, the ``OpMessagePassing`` operator wraps the 216 | properties of Spspmm and is defined with parameters like ``op0``, 217 | ``op1``, ``dim1``, ``op2``, ``dim2``, and ``aggr``. It retrieves 218 | precomputed data from a data dictionary during the forward process using 219 | keys like ``f"{op0}___{op1}___{dim1}___{op2}___{dim2}"``. Here's the 220 | forward method signature: 221 | 222 | .. 
code:: python 223 | 224 | def forward(self, 225 | A: SparseTensor, 226 | B: SparseTensor, 227 | datadict: Dict, 228 | tarX: Optional[SparseTensor] = None) -> SparseTensor: 229 | 230 | In this signature, ``tarX`` corresponds to ``op0``, providing the target 231 | indices, while ``A`` and ``B`` correspond to ``op1`` and ``op2``. The 232 | ``datadict`` can be obtained from the data loader using 233 | ``for batch in dataloader: batch.to_dict()``. 234 | 235 | Example 236 | ------- 237 | 238 | To illustrate how these operators work, let's use NGNN as an example. 239 | Although our operators can handle batched data, for simplicity, we'll 240 | focus on the single-graph case. Let H represent the representation 241 | matrix in :math:`\mathbb{R}^{n\times n\times d}`, and A denote the 242 | adjacency matrix in :math:`\mathbb{R}^{n\times n}`. The Graph 243 | Isomorphism Network (GIN) operation on all subgraphs can be defined as: 244 | 245 | .. math:: 246 | 247 | 248 | h_{ij} \leftarrow \sum_{k\in N_i(j)} \text{MLP}(h_{ik}) 249 | 250 | This operation can be represented using two steps: 251 | 252 | 1. Apply the MLP function to each tuple's representation: 253 | 254 | .. code:: python 255 | 256 | X' = X.tuplewiseapply(MLP) 257 | 258 | 2. Perform matrix multiplication to sum over neighbors: 259 | 260 | .. code:: python 261 | 262 | X = X' * A^T 263 | 264 | In the matrix multiplication step, batching is applied to the last 265 | dimension of X. This conversion may seem straightforward, but there are 266 | several key points to consider: 267 | 268 | - Optimization for induced subgraph input: The original equation 269 | involves a sum over neighbors in the subgraph, but the matrix 270 | multiplication version includes neighbors from the entire graph. 271 | However, our implementation optimizes for induced subgraph cases, 272 | where neighbors outside the subgraph are automatically handled by 273 | setting their values to zero. 
274 | 275 | - Optimization for sparse output: The operation X' \* A^T may produce 276 | non-zero elements for pairs (i, j) that do not exist in the subgraph. 277 | For sparse input tensors X and A, we optimize the multiplication to 278 | avoid computing such non-existent elements. 279 | 280 | Pooling processes can also be considered as a reduction of :math:`X`. 281 | For instance: 282 | 283 | .. math:: 284 | 285 | 286 | h_i=\sum_{j\in V_i}\text{MLP}_2(h_{ij}) 287 | 288 | can be implemented as follows: 289 | 290 | :: 291 | 292 | # definition 293 | self.pool = OpPoolingSubg2D(...) 294 | ... 295 | # forward 296 | Xn = self.pool(X.tuplewiseapply(MLP_1)) 297 | 298 | This example demonstrate how our library's operators can be used to 299 | efficiently implement various HOGNNs. 300 | -------------------------------------------------------------------------------- /example/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from torch.optim.lr_scheduler import LRScheduler 4 | 5 | class CosineAnnealingWarmRestarts(LRScheduler): 6 | 7 | def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, K = 0.0, K2 = 0.0, verbose=False): 8 | if T_mult < 1: 9 | raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult)) 10 | self.T_0 = T_0 11 | self.T_i = T_0 12 | self.T_mult = T_mult 13 | self.num_cos = 0 14 | self.K = K 15 | self.K2 = K2 16 | self.eta_min = eta_min 17 | self.T_cur = last_epoch 18 | super().__init__(optimizer, last_epoch, verbose) 19 | 20 | def get_lr(self): 21 | if not self._get_lr_called_within_step: 22 | warnings.warn("To get the last learning rate computed by the scheduler, " 23 | "please use `get_last_lr()`.", UserWarning) 24 | if self.T_0 < 1: 25 | return [base_lr for base_lr in self.base_lrs] 26 | else: 27 | return [(1/(1+self.K*self.num_cos+self.K2*self.num_cos**2))*(self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / 
self.T_i)) / 2) 28 | for base_lr in self.base_lrs] 29 | 30 | def step(self, epoch=None): 31 | if self.T_0 < 1: 32 | return 33 | if epoch is None and self.last_epoch < 0: 34 | epoch = 0 35 | 36 | if epoch is None: 37 | epoch = self.last_epoch + 1 38 | self.T_cur = self.T_cur + 1 39 | if self.T_cur >= self.T_i: 40 | self.T_cur = self.T_cur - self.T_i 41 | self.T_i = self.T_i * self.T_mult 42 | self.num_cos += 1 43 | else: 44 | if epoch < 0: 45 | raise ValueError("Expected non-negative epoch, but got {}".format(epoch)) 46 | if epoch >= self.T_0: 47 | if self.T_mult == 1: 48 | self.T_cur = epoch % self.T_0 49 | else: 50 | n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult)) 51 | self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1) 52 | self.T_i = self.T_0 * self.T_mult ** (n) 53 | else: 54 | self.T_i = self.T_0 55 | self.T_cur = epoch 56 | self.last_epoch = math.floor(epoch) 57 | 58 | class _enable_get_lr_call: 59 | 60 | def __init__(self, o): 61 | self.o = o 62 | 63 | def __enter__(self): 64 | self.o._get_lr_called_within_step = True 65 | return self 66 | 67 | def __exit__(self, type, value, traceback): 68 | self.o._get_lr_called_within_step = False 69 | return self 70 | 71 | with _enable_get_lr_call(self): 72 | for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())): 73 | param_group, lr = data 74 | param_group['lr'] = lr 75 | self.print_lr(self.verbose, i, lr, epoch) 76 | 77 | self._last_lr = [group['lr'] for group in self.optimizer.param_groups] -------------------------------------------------------------------------------- /example/minimal.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import Tensor 3 | from functools import partial 4 | import time 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from torch_geometric.datasets import ZINC 11 | from pygho import SparseTensor 12 | from 
pygho.hodata import SpDataloader, Sppretransform, ParallelPreprocessDataset 13 | from pygho.hodata.SpTupleSampler import KhopSampler 14 | from pygho.honn.SpOperator import parse_precomputekey 15 | from pygho.backend.utils import torch_scatter_reduce 16 | 17 | from pygho.honn.Conv import NGNNConv 18 | from pygho.honn.TensorOp import OpPoolingSubg2D 19 | from pygho.honn.utils import MLP 20 | 21 | 22 | class InputEncoderSp(nn.Module): 23 | 24 | def __init__(self, hiddim: int) -> None: 25 | super().__init__() 26 | self.x_encoder = nn.Embedding(32, hiddim) 27 | self.ea_encoder = nn.Embedding(16, hiddim) 28 | self.tuplefeat_encoder = nn.Embedding(16, hiddim) 29 | 30 | def forward(self, datadict: dict) -> dict: 31 | datadict["x"] = self.x_encoder(datadict["x"].flatten()) 32 | datadict["A"] = datadict["A"].tuplewiseapply(self.ea_encoder) 33 | datadict["X"] = datadict["X"].tuplewiseapply(self.tuplefeat_encoder) 34 | return datadict 35 | 36 | 37 | class SpModel(nn.Module): 38 | 39 | def __init__(self, num_tasks=1, num_layer=6, hiddim=128, mlp: dict = {}): 40 | ''' 41 | num_tasks (int): number of output dimensions 42 | npool: node level pooling 43 | lpool: subgraph pooling 44 | aggr: aggregation scheme in MPNN on each subgraph 45 | ln_out: use layernorm in output, 46 | a normalization method for classification problem 47 | ''' 48 | 49 | super().__init__() 50 | 51 | self.lin_tupleinit0 = nn.Linear(hiddim, hiddim) 52 | self.lin_tupleinit1 = nn.Linear(hiddim, hiddim) 53 | 54 | self.npool = "sum" 55 | self.lpool = OpPoolingSubg2D("S", "mean") 56 | self.poolmlp = MLP(hiddim, hiddim, 1, tailact=True, **mlp) 57 | self.data_encoder = InputEncoderSp(hiddim) 58 | 59 | self.pred_lin = MLP(hiddim, num_tasks, 2, tailact=False, **mlp) 60 | 61 | mlp.update({"numlayer": 1, "tailact": True}) 62 | self.subggnns = nn.ModuleList([ 63 | NGNNConv(hiddim, hiddim, "sum", "SS", mlp) 64 | for _ in range(num_layer) 65 | ]) 66 | 67 | def tupleinit(self, X: SparseTensor, x: Tensor): 68 | subgx0 = 
X.unpooling_fromdense1dim(0, self.lin_tupleinit0(x)) 69 | subgx1 = X.unpooling_fromdense1dim(1, self.lin_tupleinit1(x)) 70 | return X.tuplewiseapply(lambda val: subgx0.values * subgx1.values * val) 71 | 72 | def forward(self, datadict: dict): 73 | datadict = self.data_encoder(datadict) 74 | A = datadict["A"] 75 | X = datadict["X"] 76 | x = datadict["x"] 77 | X = self.tupleinit(X, x) 78 | for conv in self.subggnns: 79 | tX = conv.forward(A, X, datadict) 80 | X = X.add(tX, True) 81 | x = self.lpool(X) 82 | x = self.poolmlp(x) 83 | h_graph = torch_scatter_reduce(0, x, datadict["batch"], 84 | datadict["num_graphs"], self.npool) 85 | return self.pred_lin(h_graph) 86 | 87 | 88 | # 2 build models 89 | 90 | mlpdict = { 91 | "norm": "bn", 92 | "act": "silu", 93 | "dp": 0.0 94 | } # hyperparameter for multi-layer perceptrons in model. dropout ratio=0, use batchnorm, use SiLU activition function 95 | 96 | model = SpModel(mlp=mlpdict) 97 | 98 | device = torch.device("cuda") 99 | # 3 data set preprocessing 100 | 101 | # load pyg data 102 | trn_dataset = ZINC("dataset/ZINC", subset=True, split="train") 103 | val_dataset = ZINC("dataset/ZINC", subset=True, split="val") 104 | tst_dataset = ZINC("dataset/ZINC", subset=True, split="test") 105 | 106 | # initialize tuple feature 107 | keys = parse_precomputekey(model) 108 | trn_dataset = ParallelPreprocessDataset( 109 | "dataset/ZINC_trn", trn_dataset, 110 | Sppretransform(partial(KhopSampler, hop=3), [""], keys), 0) 111 | val_dataset = ParallelPreprocessDataset( 112 | "dataset/ZINC_val", val_dataset, 113 | Sppretransform(partial(KhopSampler, hop=3), [""], keys), 0) 114 | tst_dataset = ParallelPreprocessDataset( 115 | "dataset/ZINC_tst", tst_dataset, 116 | Sppretransform(partial(KhopSampler, hop=3), [""], keys), 0) 117 | 118 | # create sparse dataloader 119 | batch_size=128 120 | trn_dataloader = SpDataloader(trn_dataset, 121 | batch_size=batch_size, 122 | shuffle=True, 123 | drop_last=True, 124 | device=device) 125 | val_dataloader = 
SpDataloader(val_dataset, 126 | batch_size=batch_size, 127 | device=device) 128 | tst_dataloader = SpDataloader(tst_dataset, 129 | batch_size=batch_size, 130 | device=device) 131 | 132 | optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) 133 | 134 | model = model.to(device) 135 | 136 | 137 | # 4 training process 138 | def train(dataloader): 139 | model.train() 140 | losss = [] 141 | for batch in dataloader: 142 | batch = batch.to(device, non_blocking=True) 143 | optimizer.zero_grad() 144 | datadict = batch.to_dict() 145 | datadict["num_graphs"] = batch.num_graphs 146 | pred = model(datadict) 147 | loss = F.l1_loss(datadict["y"].unsqueeze(-1), pred, reduction="mean") 148 | loss.backward() 149 | optimizer.step() 150 | losss.append(loss) 151 | losss = np.average(list(map(lambda x: x.item(), losss))) 152 | return losss 153 | 154 | 155 | @torch.no_grad() 156 | def eval(dataloader): 157 | model.eval() 158 | loss = 0 159 | size = 0 160 | for batch in dataloader: 161 | batch = batch.to(device, non_blocking=True) 162 | datadict = batch.to_dict() 163 | datadict["num_graphs"] = batch.num_graphs 164 | pred = model(datadict) 165 | loss += F.l1_loss(datadict["y"].unsqueeze(-1), pred, reduction="sum") 166 | size += pred.shape[0] 167 | return (loss / size).item() 168 | 169 | 170 | out = [] 171 | 172 | best_val = float("inf") 173 | tst_score = float("inf") 174 | for epoch in range(1, 100 + 1): 175 | t1 = time.time() 176 | losss = train(trn_dataloader) 177 | t2 = time.time() 178 | val_score = eval(val_dataloader) 179 | if val_score < best_val: 180 | best_val = val_score 181 | tst_score = eval(tst_dataloader) 182 | t3 = time.time() 183 | print( 184 | f"epoch {epoch} trn time {t2-t1:.2f} val time {t3-t2:.2f} memory {torch.cuda.max_memory_allocated()/1024**3:.2f} GB l1loss {losss:.4f} val MAE {val_score:.4f} tst MAE {tst_score:.4f}" 185 | ) 186 | if np.isnan(losss) or np.isnan(val_score): 187 | break 188 | out.append(tst_score) 189 | 190 | print(f"All {np.average(tst_score)} 
{np.std(tst_score)}") -------------------------------------------------------------------------------- /example/reproduce.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 nohup python -O example/zinc.py --epochs 10 --repeat 1 --sparse --aggr sum --conv NGNN --npool sum --lpool mean --cpool mean --mlplayer 2 --norm bn --lr 1e-2 --wd 4.9e-5 --cosT 26 --dp 0.0 --outlayer 4 --normparam 1.94e-1 --minlr 8.4e-5 --K 4.9e-3 --K2 4.33e-6 > NGNN.compiled.out & 2 | CUDA_VISIBLE_DEVICES=2 nohup python -O example/zinc.py --sparse --epochs 10 --repeat 1 --aggr sum --conv I2GNN --npool sum --lpool mean --cpool mean --mlplayer 2 --norm bn --lr 3.4e-3 --wd 3.7e-2 --cosT 26 --dp 0.0 --outlayer 4 --normparam 0.31 --minlr 2.03e-5 --K 0.011 --K2 0.0073 > I2GNN.compiled.out & 3 | CUDA_VISIBLE_DEVICES=3 nohup python -O example/zinc.py --epochs 10 --repeat 1 --sparse --aggr sum --conv PPGN --npool sum --lpool mean --cpool mean --mlplayer 2 --norm bn --lr 4.5e-3 --wd 6.5e-6 --cosT 32 --dp 0.0 --outlayer 4 --normparam 1.85e-1 --minlr 7.0e-5 --K 1.04e-4 --K2 8.24e-5 > PPGN.compiled.out & 4 | CUDA_VISIBLE_DEVICES=4 nohup python -O example/zinc.py --epochs 10 --repeat 1 --aggr sum --conv SSWL --npool sum --lpool mean --cpool mean --mlplayer 2 --norm bn --lr 9e-3 --wd 6.5e-7 --cosT 40 --dp 0.0 --outlayer 4 --normparam 0.22 --minlr 8.4e-5 --K 1.4e-2 --K2 1.0e-7 > SSWL.compiled.out & 5 | CUDA_VISIBLE_DEVICES=5 nohup python -O example/zinc.py --epochs 10 --repeat 1 --sparse --aggr sum --conv DSSGNN --npool sum --lpool sum --cpool mean --mlplayer 2 --norm bn --lr 0.0086 --wd 0.012 --cosT 26 --dp 0.0 --outlayer 4 --normparam 0.31 --minlr 8.9e-6 --K 1.3e-3 --K2 2.8e-4 > DSSGNN.compiled.out & 6 | CUDA_VISIBLE_DEVICES=6 nohup python -O example/zinc.py --epochs 10 --repeat 1 --sparse --aggr sum --conv GNNAK --npool sum --lpool sum --cpool mean --mlplayer 2 --norm bn --lr 0.0086 --wd 0.012 --cosT 26 --dp 0.0 --outlayer 4 --normparam 0.31 
--minlr 8.9e-6 --K 1.3e-3 --K2 2.8e-4 > GNNAK.compiled.out & 7 | CUDA_VISIBLE_DEVICES=7 nohup python -O example/zinc.py --epochs 10 --repeat 1 --sparse --aggr sum --conv SUN --npool sum --lpool sum --cpool mean --mlplayer 2 --norm bn --lr 0.0086 --wd 0.0064 --cosT 26 --dp 0.0 --outlayer 4 --normparam 0.57 --minlr 2.4e-5 --K 5.7e-7 --K2 2.8e-4 > SUN.compiled.out & 8 | CUDA_VISIBLE_DEVICES=1 nohup python -O example/NGAT.py --epochs 1000 --repeat 10 --sparse --aggr sum --conv NGNN --npool sum --lpool sum --cpool mean --mlplayer 2 --norm bn --lr 1e-2 --wd 4.9e-5 --cosT 26 --dp 0.0 --outlayer 4 --normparam 1.94e-1 --minlr 8.4e-5 --K 4.9e-3 --K2 4.33e-6 > NGAT.compiled.out & -------------------------------------------------------------------------------- /example/work.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 nohup python -O example/zinc.py --epochs 10 --repeat 10 --sparse --aggr sum --conv NGNN --npool sum --lpool mean --cpool mean --mlplayer 2 --norm bn --lr 1e-2 --wd 4.9e-5 --cosT 26 --dp 0.0 --outlayer 4 --normparam 1.94e-1 --minlr 8.4e-5 --K 4.9e-3 --K2 4.33e-6 > SpNGNN.time.out & 2 | CUDA_VISIBLE_DEVICES=1 nohup python -O example/zinc.py --epochs 10 --repeat 10 --sparse --aggr sum --conv GNNAK --npool sum --lpool mean --cpool mean --mlplayer 2 --norm bn --lr 1e-2 --wd 4.9e-5 --cosT 26 --dp 0.0 --outlayer 4 --normparam 1.94e-1 --minlr 8.4e-5 --K 4.9e-3 --K2 4.33e-6 > SpGNNAK.time.out & 3 | CUDA_VISIBLE_DEVICES=2 nohup python -O example/zinc.py --epochs 10 --repeat 10 --sparse --aggr sum --conv DSSGNN --npool sum --lpool mean --cpool mean --mlplayer 2 --norm bn --lr 1e-2 --wd 4.9e-5 --cosT 26 --dp 0.0 --outlayer 4 --normparam 1.94e-1 --minlr 8.4e-5 --K 4.9e-3 --K2 4.33e-6 > SpDSSGNN.time.out & 4 | CUDA_VISIBLE_DEVICES=3 nohup python -O example/zinc.py --epochs 10 --repeat 10 --sparse --aggr sum --conv SSWL --npool sum --lpool mean --cpool mean --mlplayer 2 --norm bn --lr 1e-2 --wd 4.9e-5 --cosT 26 
--dp 0.0 --outlayer 4 --normparam 1.94e-1 --minlr 8.4e-5 --K 4.9e-3 --K2 4.33e-6 > SpSSWL.time.out & 5 | CUDA_VISIBLE_DEVICES=4 nohup python -O example/zinc.py --epochs 10 --repeat 10 --sparse --aggr sum --conv PPGN --npool sum --lpool mean --cpool mean --mlplayer 2 --norm bn --lr 1e-2 --wd 4.9e-5 --cosT 26 --dp 0.0 --outlayer 4 --normparam 1.94e-1 --minlr 8.4e-5 --K 4.9e-3 --K2 4.33e-6 > SpPPGN.time.out & 6 | CUDA_VISIBLE_DEVICES=5 nohup python -O example/zinc.py --epochs 10 --repeat 10 --sparse --aggr sum --conv SUN --npool sum --lpool mean --cpool mean --mlplayer 2 --norm bn --lr 1e-2 --wd 4.9e-5 --cosT 26 --dp 0.0 --outlayer 4 --normparam 1.94e-1 --minlr 8.4e-5 --K 4.9e-3 --K2 4.33e-6 > SpSUN.time.out & 7 | CUDA_VISIBLE_DEVICES=6 nohup python -O example/zinc.py --epochs 10 --repeat 10 --aggr sum --conv SSWL --npool sum --lpool mean --cpool mean --mlplayer 2 --norm bn --lr 1e-2 --wd 4.9e-5 --cosT 26 --dp 0.0 --outlayer 4 --normparam 1.94e-1 --minlr 8.4e-5 --K 4.9e-3 --K2 4.33e-6 > MaSSWL.time.out & 8 | CUDA_VISIBLE_DEVICES=7 nohup python -O example/zinc.py --epochs 10 --repeat 10 --aggr sum --conv PPGN --npool sum --lpool mean --cpool mean --mlplayer 2 --norm bn --lr 1e-2 --wd 4.9e-5 --cosT 26 --dp 0.0 --outlayer 4 --normparam 1.94e-1 --minlr 8.4e-5 --K 4.9e-3 --K2 4.33e-6 > MaPPGN.time.out & 9 | 10 | CUDA_VISIBLE_DEVICES=6 nohup python -O example/zinc.py --epochs 10 --repeat 1 --aggr sum --conv I2GNN --npool sum --lpool mean --cpool mean --mlplayer 2 --norm bn --lr 3.4e-3 --wd 3.7e-2 --cosT 26 --dp 0.0 --outlayer 4 --normparam 0.31 --minlr 2.03e-5 --K 0.011 --K2 0.0073 > I2GNN.time.out & 11 | 12 | 13 | {'lr': 0.0034065612285146232, 'wd': 0.03722265158992254, 'aggr': 'sum', 'npool': 'sum', 'lpool': 'sum', 'minlr': 2.0341235269027242e-05, 'normparam': 0.3130753368607271, 'cosT': 26, 'K': 0.011016896208476656, 'K2': 0.007270837470201833} 14 | 15 | CUDA_VISIBLE_DEVICES=6 nohup python -O example/zinc.py --epochs 10 --repeat 10 --aggr sum --conv I2GNN --npool sum 
--lpool mean --cpool mean --mlplayer 2 --norm bn --lr 3.4e-3 --wd 3.7e-2 --cosT 26 --dp 0.0 --outlayer 4 --normparam 0.31 --minlr 2.03e-5 --K 0.011 --K2 0.0073 > I2GNN.time.out & 16 | 17 | {'lr': 0.004543542861001459, 'wd': 6.477760912476973e-06, 'aggr': 'mean', 'npool': 'sum', 'lpool': 'sum', 'minlr': 7.030999869053724e-05, 'normparam': 0.18535052628942864, 'cosT': 32, 'K': 0.00010376840392440702, 'K2': 8.242613862862107e-05} 18 | CUDA_VISIBLE_DEVICES=4 nohup python -O example/zinc.py --epochs 10 --repeat 10 --sparse --aggr sum --conv PPGN --npool sum --lpool mean --cpool mean --mlplayer 2 --norm bn --lr 4.5e-3 --wd 6.5e-6 --cosT 32 --dp 0.0 --outlayer 4 --normparam 1.85e-1 --minlr 7.0e-5 --K 1.04e-4 --K2 8.24e-5 > SpPPGN.2.time.out & 19 | 20 | 21 | {'lr': 0.008917818793847022, 'wd': 6.478937487304309e-07, 'aggr': 'sum', 'npool': 'sum', 'lpool': 'mean', 'minlr': 1.5994362893084637e-06, 'normparam': 0.21991685589063176, 'cosT': 40, 'K': 0.014100560322057884, 'K2': 1.0347911649315998e-07} 22 | CUDA_VISIBLE_DEVICES=6 nohup python -O example/zinc.py --epochs 600 --repeat 10 --aggr sum --conv SSWL --npool sum --lpool mean --cpool mean --mlplayer 2 --norm bn --lr 9e-3 --wd 6.5e-7 --cosT 40 --dp 0.0 --outlayer 4 --normparam 0.22 --minlr 8.4e-5 --K 1.4e-2 --K2 1.0e-7 > MaSSWL.2.time.out & 23 | 24 | {'lr': 0.0086, 'wd': 0.012, 'aggr': 'sum', 'npool': 'sum', 'lpool': 'sum', 'minlr': 8.9e-06, 'normparam': 0.31, 'cosT': 42, 'K': 0.0013, 'K2': 0.00028 25 | CUDA_VISIBLE_DEVICES=0 nohup python -O example/zinc.py --epochs 10 --repeat 10 --sparse --aggr sum --conv DSSGNN --npool sum --lpool sum --cpool mean --mlplayer 2 --norm bn --lr 0.0086 --wd 0.012 --cosT 26 --dp 0.0 --outlayer 4 --normparam 0.31 --minlr 8.9e-6 --K 1.3e-3 --K2 2.8e-4 > SpDSSGNN.2.time.out & 26 | 27 | CUDA_VISIBLE_DEVICES=2 nohup python -O example/zinc.py --epochs 10 --repeat 10 --sparse --aggr sum --conv GNNAK --npool sum --lpool sum --cpool mean --mlplayer 2 --norm bn --lr 0.0086 --wd 0.012 --cosT 26 --dp 0.0 
--outlayer 4 --normparam 0.31 --minlr 8.9e-6 --K 1.3e-3 --K2 2.8e-4 > SpGNNAK.2.time.out & 28 | 29 | {'lr': 0.0086, 'wd': 0.0064, 'aggr': 'sum', 'npool': 'sum', 'lpool': 'sum', 'minlr': 2.36e-05, 'normparam': 0.57, 'cosT': 35, 'K': 5.7e-07, 'K2': 0.00028} 30 | CUDA_VISIBLE_DEVICES=1 nohup python -O example/zinc.py --epochs 10 --repeat 10 --sparse --aggr sum --conv SUN --npool sum --lpool sum --cpool mean --mlplayer 2 --norm bn --lr 0.0086 --wd 0.0064 --cosT 26 --dp 0.0 --outlayer 4 --normparam 0.57 --minlr 2.4e-5 --K 5.7e-7 --K2 2.8e-4 > SpSUN.2.time.out & 31 | 32 | {'lr': 0.0037306061580411167, 'wd': 0.0001365758890619353, 'aggr': 'sum', 'npool': 'sum', 'lpool': 'mean', 'minlr': 5.902538747285543e-07, 'normparam': 0.21560077784065987, 'cosT': 43, 'K': 0.036279045710338645, 'K2': 0.006338929618591591} 33 | CUDA_VISIBLE_DEVICES=3 nohup python -O example/zinc.py --epochs 10 --repeat 10 --sparse --aggr sum --conv GNNAK --npool sum --lpool mean --cpool mean --mlplayer 2 --norm bn --lr 3.7e-3 --wd 1.4e-4 --cosT 43 --dp 0.0 --outlayer 4 --normparam 2.2e-1 --minlr 5.9e-7 --K 3.6e-2 --K2 6.3e-3 > SpGNNAK.2.time.out & 34 | 35 | 36 | 37 | CUDA_VISIBLE_DEVICES=6 nohup python example/zinc.py --sparse --aggr sum --conv SUN --npool mean --lpool mean --cpool max > SpSun.debug.time.out & 38 | CUDA_VISIBLE_DEVICES=5 nohup python example/zinc.py --sparse --aggr sum --conv SSWL --npool mean --lpool mean --cpool max > SpSSWL.debug.time.out & 39 | CUDA_VISIBLE_DEVICES=6 nohup python example/zinc.py --sparse --aggr sum --conv NGNN --npool mean --lpool mean --cpool max > SpNGNN.debug.time.out & 40 | CUDA_VISIBLE_DEVICES=7 nohup python example/zinc.py --sparse --aggr sum --conv GNNAK --npool mean --lpool mean --cpool max > SpGNNAK.debug.time.out & 41 | CUDA_VISIBLE_DEVICES=3 nohup python example/zinc.py --sparse --aggr sum --conv DSSGNN --npool mean --lpool mean --cpool max > SpDSSGNN.debug.time.out & 42 | 43 | 44 | CUDA_VISIBLE_DEVICES=0 nohup python example/zinc.py --aggr sum --conv 
def filterinf(X: Tensor, filled_value: float = 0):
    """
    Replace every +inf / -inf entry of a tensor with a finite value.

    Args:

    - X (Tensor): The input tensor.
    - filled_value (float, optional): Replacement for infinite entries
      (default: 0).

    Returns:

    - Tensor: A new tensor (the input is not modified) in which infinite
      entries are replaced by `filled_value`.

    Example:

    ::

        input_tensor = torch.tensor([1.0, 2.0, torch.inf, -torch.inf, 3.0])
        result = filterinf(input_tensor, filled_value=999.0)

    """
    return X.masked_fill(torch.isinf(X), filled_value)


class MaskedTensor:
    r"""
    Dense tensor paired with a boolean validity mask.

    `data` has shape (\*maskedshape, \*denseshape) and `mask` has shape
    (\*maskedshape); `True` in the mask marks a valid entry. Entries of `data`
    at invalid positions are kept equal to `padvalue`, which lets reductions
    skip re-filling when the pad value is already the neutral element.

    Parameters:

    - data (Tensor): The underlying data tensor of shape
      (\*maskedshape, \*denseshape).
    - mask (BoolTensor): The mask tensor of shape (\*maskedshape).
    - padvalue (float, optional): Value stored at invalid positions.
      Defaults to 0.
    - is_filled (bool, optional): Pass True only if `data` already holds
      `padvalue` at every invalid position; the fill is then skipped.
      Defaults to False.

    Attributes:

    - data (Tensor): The underlying (filled) data tensor.
    - mask (BoolTensor): The mask tensor.
    - fullnegmask (BoolTensor): `logical_not(mask)` with one trailing
      singleton dimension per dense dimension, so it broadcasts against data.
    - padvalue (float): The value currently stored at invalid positions.
    - shape (torch.Size): The shape of the data tensor.
    - masked_dim (int): The number of dimensions in maskedshape.
    - dense_dim (int): The number of dimensions in denseshape.
    - maskedshape (torch.Size): Leading part of `shape` covered by the mask.
    - denseshape (torch.Size): Trailing part of `shape` not covered by it.
    """

    def __init__(self,
                 data: Tensor,
                 mask: BoolTensor,
                 padvalue: float = 0.0,
                 is_filled: bool = False):
        # mask: True for valid value, False for invalid value
        assert data.ndim >= mask.ndim, "data's #dim should be larger than mask "
        assert data.shape[:mask.ndim] == mask.shape, \
            "data and mask's first dimensions should match"
        self.__data = data
        self.__mask = mask
        self.__masked_dim = mask.ndim
        # Append one singleton axis per dense dimension so the negated mask
        # broadcasts against `data`.
        negmask = torch.logical_not(mask)
        dense_dim = data.ndim - mask.ndim
        if dense_dim > 0:
            negmask = negmask.reshape(tuple(mask.shape) + (1,) * dense_dim)
        self.__fullnegmask = negmask
        self.__padvalue = padvalue
        if not is_filled:
            # Fill directly: going through fill_masked_ would be a no-op here
            # because the recorded pad value already equals `padvalue`.
            self.__data = data.masked_fill(negmask, padvalue)

    def fill_masked_(self, val: float = 0.0) -> None:
        """
        Set the invalid positions of this tensor to `val` and record it as the
        new pad value. Nothing happens if `val` is already the pad value.
        """
        if self.padvalue == val:
            return
        self.__padvalue = val
        self.__data = self.data.masked_fill(self.fullnegmask, val)

    def fill_masked(self, val: float = 0.0) -> Tensor:
        """
        Return the data with invalid positions set to `val`. The stored data
        is returned as-is when `val` already is the pad value.
        """
        if self.__padvalue == val:
            return self.data
        return self.data.masked_fill(self.fullnegmask, val)

    def to(self, device: Union[torch.device, str], non_blocking: bool = True):
        """
        Move data, mask and the broadcast negated mask to `device`; returns
        self.
        """
        self.__data = self.__data.to(device, non_blocking=non_blocking)
        self.__mask = self.__mask.to(device, non_blocking=non_blocking)
        self.__fullnegmask = self.__fullnegmask.to(device,
                                                   non_blocking=non_blocking)
        return self

    @property
    def padvalue(self) -> float:
        return self.__padvalue

    @property
    def data(self) -> Tensor:
        return self.__data

    @property
    def mask(self) -> BoolTensor:
        return self.__mask

    @property
    def fullnegmask(self) -> BoolTensor:
        return self.__fullnegmask

    @property
    def shape(self) -> torch.Size:
        return self.__data.shape

    @property
    def masked_dim(self):
        return self.__masked_dim

    @property
    def dense_dim(self):
        return len(self.denseshape)

    @property
    def maskedshape(self):
        return self.shape[:self.masked_dim]

    @property
    def denseshape(self):
        return self.shape[self.masked_dim:]

    def _reduced_mask(self, dims, keepdim: bool) -> BoolTensor:
        # An output position is valid if any reduced input position was valid.
        return torch.amax(self.mask, dims, keepdim=keepdim)

    def sum(self, dims: Union[Iterable[int], int], keepdim: bool = False):
        """
        Sum of the valid entries along `dims` (invalid entries count as 0).
        """
        return MaskedTensor(torch.sum(self.fill_masked(0.),
                                      dim=dims,
                                      keepdim=keepdim),
                            self._reduced_mask(dims, keepdim),
                            padvalue=0,
                            is_filled=True)

    def mean(self, dims: Union[Iterable[int], int], keepdim: bool = False):
        """
        Mean of the valid entries along `dims`. The divisor is the number of
        valid entries, clamped to at least 1 to avoid division by zero.
        """
        count = torch.clamp_min_(
            torch.sum(torch.logical_not(self.fullnegmask),
                      dim=dims,
                      keepdim=keepdim), 1)
        valsum = self.sum(dims, keepdim)
        return MaskedTensor(valsum.data / count,
                            valsum.mask,
                            padvalue=valsum.padvalue,
                            is_filled=True)

    def max(self, dims: Union[Iterable[int], int], keepdim: bool = False):
        """
        Maximum of the valid entries along `dims`; positions with no valid
        entry yield 0.
        """
        tmp = self.fill_masked(-torch.inf)
        return MaskedTensor(filterinf(
            torch.amax(tmp, dim=dims, keepdim=keepdim), 0),
                            self._reduced_mask(dims, keepdim),
                            padvalue=0,
                            is_filled=True)

    def min(self, dims: Union[Iterable[int], int], keepdim: bool = False):
        """
        Minimum of the valid entries along `dims`; positions with no valid
        entry yield 0.
        """
        # Invalid entries are set to +inf so they never win the minimum.
        tmp = self.fill_masked(torch.inf)
        return MaskedTensor(filterinf(
            torch.amin(tmp, dim=dims, keepdim=keepdim), 0),
                            self._reduced_mask(dims, keepdim),
                            padvalue=0,
                            is_filled=True)

    def diag(self, dims: Iterable[int]):
        """
        Take the joint diagonal over `dims`; the resulting axis is placed at
        the position of the smallest dimension in `dims`.
        """
        assert len(dims) >= 2, "must diag several dims"
        dims = sorted(dims)
        tdata = torch.diagonal(self.data, 0, dims[0], dims[1])
        tmask = torch.diagonal(self.mask, 0, dims[0], dims[1])
        # torch.diagonal appends the diagonal as the last axis; fold each
        # further requested dimension into that axis.
        for i in range(2, len(dims)):
            tdata = torch.diagonal(tdata, 0, dims[i], -1)
            tmask = torch.diagonal(tmask, 0, dims[i], -1)
        tdata = torch.movedim(tdata, -1, dims[0])
        tmask = torch.movedim(tmask, -1, dims[0])
        return MaskedTensor(tdata, tmask, self.padvalue, True)

    def unpooling(self, dims: Union[int, Iterable[int]], tarX):
        """
        Insert new axes at `dims`, expand them to the sizes `tarX` has there,
        and return the result under `tarX`'s mask.
        """
        if isinstance(dims, int):
            dims = [dims]
        dims = sorted(dims)
        tdata = self.data
        for d in dims:
            tdata = tdata.unsqueeze(d)
        tdata = tdata.expand(*(-1 if i not in dims else tarX.shape[i]
                               for i in range(tdata.ndim)))
        return MaskedTensor(tdata, tarX.mask, self.padvalue, False)

    def tuplewiseapply(self, func: Callable[[Tensor], Tensor]):
        """
        Apply `func` to the zero-filled data and wrap the result with the same
        mask.
        """
        # it may cause nan in gradient and makes amp unable to update
        ndata = func(self.fill_masked(0.))
        return MaskedTensor(ndata, self.mask)

    def diagonalapply(self, func: Callable[[Tensor, LongTensor], Tensor]):
        """
        Call `func(data, diagonaltype)`, where `diagonaltype` is a long tensor
        shaped like the mask holding 1 on the (dim 1, dim 2) diagonal and 0
        elsewhere.
        """
        assert self.masked_dim == 3, "only implemented for 2D"
        diagonaltype = torch.eye(self.shape[1],
                                 self.shape[2],
                                 dtype=torch.long,
                                 device=self.data.device)
        diagonaltype = diagonaltype.unsqueeze(0).expand_as(self.mask)
        ndata = func(self.data, diagonaltype)
        return MaskedTensor(ndata, self.mask)

    def add(self, tarX, samesparse: bool):
        """
        Element-wise sum with another MaskedTensor.

        With `samesparse=True` both operands are assumed to share this
        tensor's mask; otherwise the result mask is the union of both masks
        and entries missing on one side count as 0.
        """
        assert isinstance(tarX, MaskedTensor)
        tarX: MaskedTensor = tarX
        if samesparse:
            # Invalid positions now hold self.padvalue + tarX.padvalue, which
            # equals self.padvalue only when the added pad is 0; in every
            # other case let the constructor restore the pad value.
            return MaskedTensor(tarX.data + self.data,
                                self.mask,
                                self.padvalue,
                                is_filled=(tarX.padvalue == 0))
        else:
            return MaskedTensor(
                tarX.fill_masked(0.) + self.fill_masked(0.),
                torch.logical_or(self.mask, tarX.mask), 0, True)

    def catvalue(self, tarX: Iterable, samesparse: bool):
        """
        Concatenate this tensor's data with the data of every tensor in
        `tarX` along the last dimension, keeping this tensor's mask.
        """
        assert samesparse, "must have the same sparcity to concat value"
        ndata = torch.concat([self.data] + [t.data for t in tarX], dim=-1)
        return MaskedTensor(ndata, self.mask)
def mamamm(A: MaskedTensor,
           dim1: int,
           B: MaskedTensor,
           dim2: int,
           mask: BoolTensor,
           broadcast_firstdim: bool = True) -> MaskedTensor:
    r"""
    Batched masked matrix multiplication of two MaskedTensors.

    Contracts masked dimension `dim1` of `A` with masked dimension `dim2` of
    `B`. Invalid entries of both operands are treated as 0.

    Args:

    - A (MaskedTensor): The first operand, shape (B, \*maskedshape1, \*denseshape).
    - dim1 (int): The masked dimension of `A` to contract.
    - B (MaskedTensor): The second operand, shape (B, \*maskedshape2, \*denseshape).
    - dim2 (int): The masked dimension of `B` to contract.
    - mask (BoolTensor): The mask attached to the result.
    - broadcast_firstdim (bool, optional): If True, dimension 0 of both
      operands is treated as a shared batch dimension instead of being
      combined. Default is True.

    Returns:

    - MaskedTensor: The product, with the remaining masked dimensions of `A`
      followed by those of `B`, wrapped with `mask`. An operand whose only
      non-batch masked dimension is the contracted one contributes a
      singleton dimension.
    """
    tA, tB = A.fill_masked(0.), B.fill_masked(0.)
    catdim1, catdim2 = A.masked_dim, B.masked_dim
    if broadcast_firstdim:
        assert dim1 > 0, "0 dim of A is batch, need to be broadcasted"
        assert dim2 > 0, "0 dim of B is batch, need to be broadcasted"
        # Park the batch axis at the end so it is carried through matmul as a
        # broadcast dimension.
        tA = torch.movedim(tA, 0, -1)
        tB = torch.movedim(tB, 0, -1)
        dim1 -= 1
        dim2 -= 1
        catdim1 -= 1
        catdim2 -= 1
    if catdim1 == 1:
        # unsqueeze is out-of-place: the result must be kept, otherwise the
        # shape bookkeeping below treats a dense axis as a masked one.
        tA = tA.unsqueeze(catdim1)
        catdim1 += 1
    if catdim2 == 1:
        tB = tB.unsqueeze(catdim2)
        catdim2 += 1
    assert catdim1 >= 2, "bug"
    assert catdim2 >= 2, "bug"
    # Move the contracted axis last, then flatten the remaining masked axes
    # into one so a single matmul covers them.
    tA, tB = tA.movedim(dim1, -1), tB.movedim(dim2, -1)
    catshape1, catshape2 = tA.shape[:catdim1 - 1], tB.shape[:catdim2 - 1]
    tA, tB = tA.flatten(0, catdim1 - 2), tB.flatten(0, catdim2 - 2)
    tA, tB = tA.movedim(0, -2), tB.movedim(0, -1)
    prod = torch.matmul(tA, tB)

    # Restore the flattened masked axes in front of the dense (and batch) axes.
    prod = prod.flatten(-2, -1).movedim(-1, 0)
    prod = prod.unflatten(0, catshape1 + catshape2)
    if broadcast_firstdim:
        prod = prod.movedim(-1, 0)
    return MaskedTensor(prod, mask)


# Value written over contributions coming from invalid entries of B: the
# neutral element of each supported reduction.
filled_value_dict = {"sum": 0, "max": -torch.inf, "min": torch.inf}
# Reductions whose neutral element is infinite; their output is passed
# through filterinf so positions that received no contribution become 0.
filter_inf_ops = ["max", "min"]


def spmamm(A: SparseTensor,
           dim1: int,
           B: MaskedTensor,
           dim2: int,
           mask: Optional[BoolTensor] = None,
           aggr: str = "sum") -> MaskedTensor:
    r"""
    SparseTensor-MaskedTensor multiplication.

    Contracts sparse dimension `dim1` of `A` with dimension `dim2` of `B`
    and reduces the products with `aggr`.

    Args:

    - A (SparseTensor): Must have exactly 3 sparse dims; dim 0 is used as the
      batch index.
    - dim1 (int): The sparse dimension of `A` to contract (1 or 2).
    - B (MaskedTensor): The masked operand; dim 0 is used as the batch index.
    - dim2 (int): The dimension of `B` to contract.
    - mask (BoolTensor, optional): Mask of the result. `B.mask` is used when
      omitted.
    - aggr (str, optional): "sum", "max" or "min" ("mean" is not
      implemented). Default is "sum".

    Returns:

    - MaskedTensor: The reduced product wrapped with `mask` (or `B.mask`).

    Raises:

    - NotImplementedError: If `dim1` is neither 1 nor 2.
    """
    assert A.sparse_dim == 3, f"A should have 3 sparse dims, but input has {A.sparse_dim}"
    assert aggr != "mean", "not implemented"
    if dim1 == 1:
        b, n = A.shape[0], A.shape[2]
        bij = A.indices[0], A.indices[1]
        tar_ind = n * A.indices[0] + A.indices[2]
    elif dim1 == 2:
        b, n = A.shape[0], A.shape[1]
        bij = A.indices[0], A.indices[2]
        tar_ind = n * A.indices[0] + A.indices[1]
    else:
        raise NotImplementedError
    Aval = A.values
    tB = torch.movedim(B.data, dim2, 1)
    tBmask = torch.movedim(B.mask, dim2, 1)
    if Aval is not None:
        mult = Aval.unsqueeze(1) * tB[bij[0], bij[1]]
    else:
        mult = tB[bij[0], bij[1]]
    validmask = tBmask[bij[0], bij[1]]
    # Overwrite products that come from invalid entries of B with the neutral
    # element of the reduction. masked_fill is out-of-place, so the result
    # has to be kept; the mask gets trailing singleton axes so it broadcasts
    # over the dense dimensions of `mult`.
    invalid = torch.logical_not(validmask)
    invalid = invalid.reshape(
        tuple(invalid.shape) + (1,) * (mult.ndim - invalid.ndim))
    mult = mult.masked_fill(invalid, filled_value_dict[aggr])
    val = torch_scatter_reduce(0, mult, tar_ind, b * n, aggr)
    ret = val.unflatten(0, (b, n))
    ret = torch.movedim(ret, 1, dim2)
    if aggr in filter_inf_ops:
        ret = filterinf(ret)
    return MaskedTensor(ret, mask if mask is not None else B.mask)
def spmm(A: "SparseTensor", dim1: int, X: Tensor, aggr: str = "sum") -> Tensor:
    """
    SparseTensor, Tensor matrix multiplication.

    Gathers rows of `X` with the indices of `A` along `dim1`, multiplies them
    by the values of `A` (if it has any) and reduces the products onto the
    other sparse dimension of `A`.

    Args:

    - A (SparseTensor): A SparseTensor with exactly 2 sparse dims.
    - dim1 (int): The sparse dimension of `A` that is reduced. 0 selects
      dimension 0; any other value selects dimension 1.
    - X (Tensor): The dense tensor; its dim 0 is indexed by `dim1` of `A`.
    - aggr (str, optional): The reduction passed to `torch_scatter_reduce`.
      Defaults to "sum".

    Returns:

    - Tensor: A dense tensor whose dim 0 has the size of the sparse dimension
      of `A` that was not reduced.
    """
    assert A.sparse_dim == 2, "can only use 2-dim sparse tensor"
    val = A.values
    if dim1 == 0:
        srcind, tarind, tarshape = A.indices[0], A.indices[1], A.shape[1]
    else:
        srcind, tarind, tarshape = A.indices[1], A.indices[0], A.shape[0]
    gathered = X[srcind]
    mult = gathered if val is None else val * gathered
    return torch_scatter_reduce(0, mult, tarind, tarshape, aggr)


def ptr2batch(ptr: LongTensor, dim_size: Optional[int] = None) -> LongTensor:
    """
    Converts a pointer tensor to a batch tensor.

    Args:

    - ptr (LongTensor): The pointer tensor, where `ptr[0] = 0` and
      `torch.all(diff(ptr) >= 0)` is true.
    - dim_size (int, optional): The total number of elements, i.e. `ptr[-1]`.
      It is only used to validate `ptr` and to let `repeat_interleave` know
      the output size in advance; when omitted it is derived from `ptr`.

    Returns:

    - LongTensor: A batch tensor of shape `(ptr[-1],)` where
      `batch[ptr[i]:ptr[i+1]] = i`.
    """
    assert ptr.ndim == 1, "ptr should be 1-d"
    assert ptr[0] == 0 and torch.all(
        torch.diff(ptr) >= 0), "should put in a ptr tensor"
    # The default dim_size=None must not trip the consistency check.
    assert dim_size is None or ptr[-1] == dim_size, \
        "dim_size should match ptr"
    return torch.repeat_interleave(torch.diff(ptr), output_size=dim_size)


def deg2batch(deg: LongTensor, dim_size: Optional[int] = None) -> LongTensor:
    """
    Converts a degree tensor to a batch tensor.

    Args:

    - deg (LongTensor): The degree tensor, where `deg[i]` is the number of
      times `i` appears in the returned tensor.
    - dim_size (int, optional): The total number of elements, i.e.
      `deg.sum()`; forwarded to `repeat_interleave` as `output_size`.

    Returns:

    - LongTensor: A batch tensor of shape `(deg.sum(),)`.
    """
    assert deg.ndim == 1, "deg should be 1-d"
    assert torch.all(deg >= 0), "should put in a degree tensor"
    return torch.repeat_interleave(deg, output_size=dim_size)
def spspmm_ind(ind1: LongTensor,
               dim1: int,
               ind2: LongTensor,
               dim2: int,
               is_k2_sorted: bool = False) -> Tuple[LongTensor, LongTensor]:
    """
    Sparse-sparse matrix multiplication for indices.

    Pairs every column of `ind1` with every column of `ind2` that has the same
    value in the contracted row (`ind1[dim1] == ind2[dim2]`), drops the
    contracted rows and concatenates the remaining rows of each pair.

    Args:

    - ind1 (LongTensor): The indices of the first sparse tensor of shape `(sparsedim1, M1)`.
    - dim1 (int): The dimension to eliminate in `ind1`.
    - ind2 (LongTensor): The indices of the second sparse tensor of shape `(sparsedim2, M2)`.
    - dim2 (int): The dimension to eliminate in `ind2`.
    - is_k2_sorted (bool, optional): Whether `ind2` is sorted along `dim2`. Defaults to `False`.
      When False and `dim2 != 0`, `ind2` is sorted here first.

    Returns:

    - tarind: LongTensor: The unique output indices, of shape
      `(sparsedim1 + sparsedim2 - 2, nnz_out)`.
    - bcd: LongTensor: of shape (3, number of matched pairs), rows (b, c, d).
      c indexes columns of `ind1`, d indexes columns of `ind2`, b indexes
      columns of `tarind`: val1[c[i]] * val2[d[i]] contributes to the b[i]-th
      output value. Columns are ordered by b.

    Example:

    ::

        ind1 = torch.tensor([[0, 1, 1, 2],
                            [2, 1, 0, 2]], dtype=torch.long)
        dim1 = 0
        ind2 = torch.tensor([[2, 1, 0, 1],
                            [1, 0, 2, 2]], dtype=torch.long)
        dim2 = 1
        result = spspmm_ind(ind1, dim1, ind2, dim2)

    """
    assert 0 <= dim1 < ind1.shape[
        0], f"ind1's reduced dim {dim1} is out of range"
    assert 0 <= dim2 < ind2.shape[
        0], f"ind2's reduced dim {dim2} is out of range"
    if dim2 != 0 and not (is_k2_sorted):
        # Sort ind2 along the contracted row, solve the sorted problem, then
        # map the d row back to positions in the caller's original ind2.
        perm = torch.argsort(ind2[dim2])
        tarind, bcd = spspmm_ind(ind1, dim1, ind2[:, perm], dim2, True)
        bcd[2] = perm[bcd[2]]
        return tarind, bcd
    else:
        nnz1, nnz2, sparsedim1, sparsedim2 = ind1.shape[1], ind2.shape[
            1], ind1.shape[0], ind2.shape[0]
        k1, k2 = ind1[dim1], ind2[dim2]

        # searchsorted below requires k2 in non-decreasing order
        assert torch.all(torch.diff(k2) >= 0), "ind2[0] should be sorted"

        # for each k in k1, it can match a interval of k2 as k2 is sorted:
        # [lowerbound, upperbound) is the run of k2 entries equal to that k
        upperbound = torch.searchsorted(k2, k1, right=True)
        lowerbound = torch.searchsorted(k2, k1, right=False)
        matched_num = torch.clamp_min_(upperbound - lowerbound, 0)

        # ptr[i] provide the offset to place pair of ind1[:, i] and the matched ind2
        # (exclusive prefix sum of matched_num)
        retptr = torch.zeros((nnz1 + 1),
                             dtype=matched_num.dtype,
                             device=matched_num.device)
        torch.cumsum(matched_num, dim=0, out=retptr[1:])
        retsize = retptr[-1]

        # fill the output with ptr
        ret = torch.zeros((3, retsize), device=ind1.device, dtype=ind1.dtype)
        # row 1: index into ind1, repeated once per match
        ret[1] = deg2batch(matched_num, retsize)
        # row 2: index into ind2 = (position within the match run) + run start
        torch.arange(retsize, out=ret[2], device=ret.device, dtype=ret.dtype)
        ret[2] += (lowerbound - retptr[:-1])[ret[1]]

        # compute the ind pair index: hash the concatenated non-contracted
        # rows of each pair and deduplicate to obtain the output positions
        combinedind = indicehash(
            torch.concat(
                ((torch.concat((ind1[:dim1], ind1[dim1 + 1:])))[:, ret[1]],
                 torch.concat((ind2[:dim2], ind2[dim2 + 1:]))[:, ret[2]])))
        combinedind, taridx = torch.unique(combinedind,
                                           sorted=True,
                                           return_inverse=True)
        tarind = decodehash(combinedind, sparsedim1 + sparsedim2 - 2)
        # row 0: output position of each pair
        ret[0] = taridx

        sorted_idx = torch.argsort(ret[0])  # sort is optional
        return tarind, ret[:, sorted_idx]
def spsphadamard_ind(tar_ind: LongTensor, ind: LongTensor) -> LongTensor:
    """
    Auxiliary function for SparseTensor-SparseTensor Hadamard product.

    For every column of `ind`, finds the column of `tar_ind` holding the same
    index tuple.

    Args:

    - tar_ind (LongTensor): The indices of sparse tensor A. They must be
      sorted and free of duplicates (checked through their hash).
    - ind (LongTensor): The indices of sparse tensor B, with the same number
      of rows as `tar_ind`.

    Returns:

    - LongTensor: `b2a` of shape `(ind.shape[1],)`. `ind[:, i]` matches
      `tar_ind[:, b2a[i]]`; `b2a[i] == -1` means `ind[:, i]` has no match.

    Example:

    ::

        tar_ind = torch.tensor([[0, 1, 1, 2],
                                [2, 0, 1, 2]], dtype=torch.long)
        ind = torch.tensor([[2, 1, 0, 1],
                            [1, 0, 2, 2]], dtype=torch.long)
        b2a = spsphadamard_ind(tar_ind, ind)  # tensor([-1, 1, 0, -1])

    """
    assert tar_ind.shape[0] == ind.shape[0]
    if tar_ind.shape[1] == 0:
        # Nothing can match an empty target; the lookup below would index
        # into an empty tensor.
        return torch.full((ind.shape[1],),
                          -1,
                          dtype=torch.long,
                          device=ind.device)
    combine_tar_ind = indicehash(tar_ind)
    assert torch.all(torch.diff(combine_tar_ind) >
                     0), "tar_ind should be sorted and coalesce"
    combine_ind = indicehash(ind)

    # Candidate = last target hash that is <= the query hash (clamped to a
    # valid position); it is a match only if the hashes are equal.
    b2a = torch.clamp_min_(
        torch.searchsorted(combine_tar_ind, combine_ind, right=True) - 1, 0)
    notmatchmask = (combine_ind != combine_tar_ind[b2a])
    b2a[notmatchmask] = -1
    return b2a


def filterind(tar_ind: LongTensor, ind: LongTensor,
              bcd: LongTensor) -> LongTensor:
    """
    A combination of Hadamard and Sparse Matrix Multiplication.

    Re-targets the pair list `bcd` of a product BC onto the sparsity pattern
    `tar_ind` of A, dropping every pair whose output position does not occur
    in `tar_ind`, so that
    `(A ⊙ (BC)).val[a] = A.val[a] * scatter(B.val[c] * C.val[d], a)`.

    Args:

    - tar_ind (LongTensor): The indices of sparse tensor A (sorted, no
      duplicates).
    - ind (LongTensor): The indices of sparse tensor BC.
    - bcd (LongTensor): Index array of shape (3, npairs); `bcd[0]` indexes
      columns of `ind`, `bcd[1]` and `bcd[2]` index values of B and C.

    Returns:

    - LongTensor: `acd` of shape (3, nkept); `acd[0]` indexes columns of
      `tar_ind`, `acd[1]` and `acd[2]` are the kept entries of `bcd[1]` and
      `bcd[2]`.
    """
    b2a = spsphadamard_ind(tar_ind, ind)
    a = b2a[bcd[0]]
    keep = a >= 0
    return torch.stack((a[keep], bcd[1][keep], bcd[2][keep]))
248 | """ 249 | assert A.is_coalesced(), "A should be coalesced" 250 | assert B.is_coalesced(), "B should be coalesced" 251 | assert A.sparseshape == B.sparseshape, "A, B should be of the same sparse shape" 252 | ind1, val1 = A.indices, A.values 253 | ind2, val2 = B.indices, B.values 254 | if b2a is None: 255 | b2a = spsphadamard_ind(ind1, ind2) 256 | mask = (b2a >= 0) 257 | if val1 is None: 258 | retval = val2[mask] 259 | elif val2 is None: 260 | retval = val1[b2a[mask]] 261 | else: 262 | retval = val1[b2a[mask]] * val2[mask] 263 | retind = ind2[:, mask] 264 | return SparseTensor(retind, 265 | retval, 266 | shape=A.sparseshape + retval.shape[1:], 267 | is_coalesced=True) 268 | 269 | 270 | def spspmm(A: SparseTensor, 271 | dim1: int, 272 | B: SparseTensor, 273 | dim2: int, 274 | aggr: str = "sum", 275 | bcd: Optional[LongTensor] = None, 276 | tar_ind: Optional[LongTensor] = None, 277 | acd: Optional[LongTensor] = None) -> SparseTensor: 278 | """ 279 | SparseTensor SparseTensor matrix multiplication at a specified sparse dimension. 280 | 281 | This function performs matrix multiplication between two SparseTensors, `A` and `B`, at the specified sparse dimensions `dim1` and `dim2`. The result is a SparseTensor containing the result of the multiplication. The `aggr` parameter specifies the reduction operation used for merging the resulting values. 282 | 283 | Args: 284 | 285 | - A (SparseTensor): The first SparseTensor. 286 | - dim1 (int): The dimension along which `A` is multiplied. 287 | - B (SparseTensor): The second SparseTensor. 288 | - dim2 (int): The dimension along which `B` is multiplied. 289 | - aggr (str, optional): The reduction operation to use for merging edge features ("sum", "min", "max", "mean"). Defaults to "sum". 290 | - bcd (LongTensor, optional): An optional auxiliary index array produced by spspmm_ind. 291 | - tar_ind (LongTensor, optional): An optional target index array for the output. If not provided, it will be computed. 
292 | - acd (LongTensor, optional): An optional auxiliary index array produced by filterind. 293 | 294 | Returns: 295 | 296 | - SparseTensor: A SparseTensor containing the result of the matrix multiplication. 297 | 298 | Notes: 299 | 300 | - Both `A` and `B` must be coalesced SparseTensors. 301 | - The dense shapes of `A` and `B` must be broadcastable. 302 | - This function allows for optional indices `bcd` and `tar_ind` for improved performance and control. 303 | 304 | """ 305 | assert A.is_coalesced(), "A should be coalesced" 306 | assert B.is_coalesced(), "B should be coalesced" 307 | if acd is not None: 308 | assert tar_ind is not None 309 | if A.values is None: 310 | mult = B.values[acd[2]] 311 | elif B.values is None: 312 | mult = A.values[acd[1]] 313 | else: 314 | mult = A.values[acd[1]] * B.values[acd[2]] 315 | retval = torch_scatter_reduce(0, mult, acd[0], tar_ind.shape[1], aggr) 316 | return SparseTensor(tar_ind, 317 | retval, 318 | shape=A.sparseshape[:dim1] + 319 | A.sparseshape[dim1 + 1:] + B.sparseshape[:dim2] + 320 | B.sparseshape[dim2 + 1:] + retval.shape[1:], 321 | is_coalesced=True) 322 | else: 323 | warnings.warn("acd is not found") 324 | if bcd is None: 325 | ind, bcd = spspmm_ind(A.indices, dim1, B.indices, dim2) 326 | if tar_ind is not None: 327 | acd = filterind(tar_ind, ind, bcd) 328 | return spspmm(A, dim1, B, dim2, aggr, acd=acd, tar_ind=tar_ind) 329 | else: 330 | warnings.warn("tar_ind is not found") 331 | return spspmm(A, dim1, B, dim2, aggr, acd=bcd, tar_ind=ind) 332 | 333 | 334 | def spspmpnn(A: SparseTensor, 335 | dim1: int, 336 | B: SparseTensor, 337 | dim2: int, 338 | C: SparseTensor, 339 | acd: LongTensor, 340 | message_func: Callable[[Tensor, Tensor, Tensor, LongTensor], 341 | Tensor], 342 | aggr: str = "sum") -> SparseTensor: 343 | """ 344 | SparseTensor SparseTensor matrix multiplication at a specified sparse dimension using a message function. 
345 | 346 | This function extend matrix multiplication between two SparseTensors, `A` and `B`, at the specified sparse dimensions `dim1` and `dim2`, while using a message function `message_func` to compute the messages sent from `A` to `B` and `C`. The result is a SparseTensor containing the result of the multiplication. The `aggr` parameter specifies the reduction operation used for merging the resulting values. 347 | 348 | Args: 349 | 350 | - A (SparseTensor): The first SparseTensor. 351 | - dim1 (int): The dimension along which `A` is multiplied. 352 | - B (SparseTensor): The second SparseTensor. 353 | - dim2 (int): The dimension along which `B` is multiplied. 354 | - C (SparseTensor): The third SparseTensor, providing the target indice 355 | - acd (LongTensor): The auxiliary index array produced by a previous operation. 356 | - message_func (Callable): A callable function that computes the messages between `A`, `B`, and `C`. 357 | - aggr (str, optional): The reduction operation to use for merging edge features ("sum", "min", "max", "mul", "any"). Defaults to "sum". 358 | 359 | Returns: 360 | 361 | - SparseTensor: A SparseTensor containing the result of the matrix multiplication. 362 | 363 | Notes: 364 | 365 | - Both `A` and `B` must be coalesced SparseTensors. 366 | - The dense shapes of `A`, `B`, and `C` must be broadcastable. 367 | - The `message_func` should take four arguments: `A_values`, `B_values`, `C_values`, and `acd`, and return messages based on custom logic. 
368 | 369 | """ 370 | mult = message_func(None if A.values is None else A.values[acd[1]], 371 | None if B.values is None else B.values[acd[2]], 372 | None if C.values is None else C.values[acd[0]], acd[0]) 373 | tar_ind = C.indices 374 | retval = torch_scatter_reduce(0, mult, acd[0], tar_ind.shape[1], aggr) 375 | return SparseTensor(tar_ind, 376 | retval, 377 | shape=A.sparseshape[:dim1] + A.sparseshape[dim1 + 1:] + 378 | B.sparseshape[:dim2] + B.sparseshape[dim2 + 1:] + 379 | retval.shape[1:], 380 | is_coalesced=True) -------------------------------------------------------------------------------- /pygho/backend/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PygHO/4cd5a20adf84c76842e001b54570d053747c9470/pygho/backend/__init__.py -------------------------------------------------------------------------------- /pygho/backend/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, LongTensor 3 | from typing import Tuple 4 | 5 | 6 | def torch_scatter_reduce(dim: int, src: Tensor, ind: LongTensor, dim_size: int, 7 | aggr: str) -> Tensor: 8 | """ 9 | Applies a reduction operation to scatter elements from `src` to `dim_size` 10 | locations based on the indices in `ind`. 11 | 12 | This function is a wrapper for `torch.Tensor.scatter_reduce_` and is designed 13 | to scatter elements from `src` to `dim_size` locations based on the specified 14 | dimension `dim` and the indices in `ind`. The reduction operation is specified 15 | by the `aggr` parameter, which can be 'sum', 'mean', 'min', 'max'. 16 | 17 | Args: 18 | 19 | - dim (int): The dimension along which to scatter elements (only dim=0 is currently supported). 20 | - src (Tensor): The source tensor of shape (nnz, denseshape). 21 | - ind (LongTensor): The indices tensor of shape (nnz). 
22 | - dim_size (int): The size of the target dimension for scattering. 23 | - aggr (str): The reduction operation to apply ('sum', 'mean', 'min', 'max', 'mul', 'any'). 24 | 25 | Returns: 26 | 27 | - Tensor: A tensor of shape (dim_size, denseshape) resulting from the scatter operation. 28 | 29 | Raises: 30 | 31 | - AssertionError: If `dim` is not 0, or if `ind` is not 1-dimensional. 32 | 33 | Example: 34 | 35 | :: 36 | src = torch.tensor([[1, 2], [4, 5], [7, 8], [9, 10]], dtype=torch.float) 37 | ind = torch.tensor([2, 2, 0, 1], dtype=torch.long) 38 | dim_size = 3 39 | aggr = 'sum' 40 | result = torch_scatter_reduce(0, src, ind, dim_size, aggr) 41 | 42 | 43 | """ 44 | assert dim == 0, "other dim not implemented" 45 | assert ind.ndim == 1, "indice must be 1-d" 46 | if aggr in ["min", "max"]: 47 | aggr = "a" + aggr 48 | onedim = src.ndim - 1 49 | dim_size = dim_size 50 | ret = torch.zeros_like(src[[0]].expand((dim_size, ) + (-1, ) * onedim)) 51 | ret.scatter_reduce_(dim, 52 | ind.reshape((-1, ) + (1, ) * onedim).expand_as(src), 53 | src, 54 | aggr, 55 | include_self=False) 56 | return ret -------------------------------------------------------------------------------- /pygho/hodata/MaData.py: -------------------------------------------------------------------------------- 1 | ''' 2 | utilities for dense high order data 3 | ''' 4 | from torch_geometric.data import Data as PygData, Batch as PygBatch 5 | import torch 6 | from torch import Tensor, LongTensor, BoolTensor 7 | from typing import Any, Callable, Optional, Tuple, List, Union, Iterable 8 | from ..backend.SpTensor import SparseTensor 9 | from ..backend.MaTensor import MaskedTensor 10 | from torch_geometric.utils import coalesce 11 | import torch 12 | 13 | 14 | class MaHoData(PygData): 15 | ''' 16 | a data class for dense high order graph data. 
class MaHoData(PygData):
    '''
    a data class for dense high order graph data.
    '''

    def __inc__(self, key: str, value: Any, *args, **kwargs):
        # edge_index gets no increment when collated, so node ids stay
        # local to each graph; every other key uses the default rule.
        if key == 'edge_index':
            return 0
        return super().__inc__(key, value, *args, **kwargs)


def to_dense_adj(edge_index: LongTensor,
                 edge_batch: LongTensor,
                 edge_attr: Optional[Tensor] = None,
                 max_num_nodes: Optional[int] = None,
                 batch_size: Optional[int] = None,
                 filled_value: float = 0) -> MaskedTensor:
    '''
    Convert sparse adjacency to dense matrix.

    Args:

    - edge_index (LongTensor): Coalesced edge indices of shape (2, nnz).
    - edge_batch (LongTensor): Batch assignments of shape (nnz).
    - edge_attr (Optional[Tensor]): Edge attributes of shape (nnz, *). Defaults to ones.
    - max_num_nodes (Optional[int]): Maximum number of nodes in the graph. Inferred from `edge_index` if omitted.
    - batch_size (Optional[int]): Batch size. Inferred from `edge_batch` if omitted.
    - filled_value (float): Value to fill in the dense matrix.

    Returns:

    - MaskedTensor: A masked dense tensor of shape (batch_size, max_num_nodes, max_num_nodes, *).

    '''
    idx0 = edge_batch
    idx1 = edge_index[0]
    idx2 = edge_index[1]

    if max_num_nodes is None:
        max_num_nodes = edge_index.max().item() + 1

    if edge_attr is None:
        edge_attr = torch.ones(idx0.shape[0], device=edge_index.device)

    if batch_size is None:
        batch_size = torch.max(edge_batch).item() + 1

    size = [batch_size, max_num_nodes, max_num_nodes] + list(
        edge_attr.shape)[1:]
    ret = torch.empty(size, dtype=edge_attr.dtype, device=edge_attr.device)
    ret.fill_(filled_value)
    ret[idx0, idx1, idx2] = edge_attr
    # The mask marks exactly the positions that hold a real edge.
    mask = torch.zeros([batch_size, max_num_nodes, max_num_nodes],
                       device=ret.device,
                       dtype=torch.bool)
    mask[idx0, idx1, idx2] = True
    return MaskedTensor(ret, mask, filled_value, True)


def to_sparse_adj(edge_index: LongTensor,
                  edge_batch: LongTensor,
                  edge_attr: Optional[Tensor] = None,
                  max_num_nodes: Optional[int] = None,
                  batch_size: Optional[int] = None) -> SparseTensor:
    '''
    Convert sparse edge_index and edge_attr to a SparseTensor.

    Args:

    - edge_index (LongTensor): Coalesced edge indices of shape (2, nnz).
    - edge_batch (LongTensor): Batch assignments of shape (nnz).
    - edge_attr (Optional[Tensor]): Edge attributes of shape (nnz, *). May be None for a value-less tensor.
    - max_num_nodes (Optional[int]): Maximum number of nodes in the graph. Inferred from `edge_index` if omitted.
    - batch_size (Optional[int]): Batch size. Inferred from `edge_batch` if omitted.

    Returns:

    - SparseTensor: A sparse tensor whose indices are (batch, row, col).

    '''
    if max_num_nodes is None:
        max_num_nodes = edge_index.max().item() + 1

    if batch_size is None:
        batch_size = torch.max(edge_batch).item() + 1

    size = [batch_size, max_num_nodes, max_num_nodes]
    # edge_attr is Optional; the unconditional `edge_attr.size()` crashed on
    # None. Only append the dense shape when values exist.
    if edge_attr is not None:
        size += list(edge_attr.size())[1:]
    return SparseTensor(torch.concatenate(
        (edge_batch.unsqueeze(0), edge_index), dim=0),
                        edge_attr,
                        shape=size,
                        is_coalesced=False)


def to_dense_x(nodeX: Tensor,
               Xptr: LongTensor,
               max_num_nodes: Optional[int] = None,
               batch_size: Optional[int] = None,
               filled_value: float = 0) -> MaskedTensor:
    '''
    Convert node features of different subgraphs to a dense matrix.

    Args:

    - nodeX (Tensor): Node features of shape (sum of number of nodes in a batch, *denseshape).
    - Xptr (LongTensor): Pointer to subgraphs. nodeX[Xptr[i]:Xptr[i+1]] represents the node feature for subgraph i.
    - max_num_nodes (Optional[int]): Maximum number of nodes in a subgraph. Inferred from `Xptr` if omitted.
    - batch_size (Optional[int]): Batch size. Inferred from `Xptr` if omitted.
    - filled_value (float): Value to fill in the dense matrix.

    Returns:

    - MaskedTensor: A masked dense tensor of shape (b, n, *denseshape).

    To align graphs of different sizes, padding is applied.

    '''
    if batch_size is None:
        batch_size = Xptr.shape[0] - 1

    if max_num_nodes is None:
        max_num_nodes = torch.diff(Xptr).max().item()

    # Row i gathers positions Xptr[i] .. Xptr[i] + max_num_nodes - 1.
    # Positions past the end of nodeX are clamped to the last valid row;
    # the mask below marks which gathered rows are real.
    idx = torch.arange(max_num_nodes, device=nodeX.device).unsqueeze(0)
    idx = idx + Xptr[:-1].reshape(-1, 1)
    idx.clamp_max_(Xptr[-1] - 1)

    ret = nodeX[idx]
    # Put a single False at column `num_nodes_i`, then a running minimum
    # turns every later column False as well; the extra column is dropped.
    mask = torch.ones((batch_size, max_num_nodes + 1),
                      dtype=torch.bool,
                      device=nodeX.device)
    mask[torch.arange(batch_size, device=nodeX.device),
         torch.diff(Xptr)] = False
    mask = mask.cummin(dim=-1)[0][:, :-1]
    return MaskedTensor(ret, mask, filled_value, False)
130 | 131 | ''' 132 | if batch_size is None: 133 | batch_size = Xptr.shape[0] - 1 134 | 135 | if max_num_nodes is None: 136 | max_num_nodes = torch.diff(Xptr).max().item() 137 | 138 | idx = torch.arange(max_num_nodes, device=nodeX.device).unsqueeze(0) 139 | idx = idx + Xptr[:-1].reshape(-1, 1) 140 | idx.clamp_max_(Xptr[-1] - 1) 141 | 142 | ret = nodeX[idx] 143 | mask = torch.ones((batch_size, max_num_nodes + 1), 144 | dtype=torch.bool, 145 | device=nodeX.device) 146 | mask[torch.arange(batch_size, device=nodeX.device), 147 | torch.diff(Xptr)] = False 148 | mask = mask.cummin(dim=-1)[0][:, :-1] 149 | return MaskedTensor(ret, mask, filled_value, False) 150 | 151 | 152 | def to_dense_tuplefeat( 153 | tuplefeat: Tensor, 154 | tupleshape: LongTensor, 155 | tuplefeatptr: LongTensor, 156 | max_tupleshape: Optional[LongTensor] = None, 157 | batch_size: Optional[int] = None, 158 | feat2mask: Callable[[Tensor], BoolTensor] = None) -> MaskedTensor: 159 | ''' 160 | Convert tuple features of different subgraphs to a dense matrix. 161 | 162 | Args: 163 | 164 | - tuplefeat (Tensor): Tuple features. (total number of tuples in batch,\*denseshapeshape) 165 | - tupleshape (LongTensor): Shape of tuple features. 166 | - tuplefeatptr (LongTensor): Pointer to tuple features. tuplefeat[tuplefeatptr[i]:tuplefeatptr[i+1]] represents the tuple feature for subgraph i 167 | - max_tupleshape (Optional[LongTensor]): Maximum shape of tuple features. 168 | - batch_size (Optional[int]): Batch size. 169 | - feat2mask (Callable[[Tensor], BoolTensor]): Function to generate masks for tuple features. 170 | 171 | Returns: 172 | 173 | - MaskedTensor: A masked dense tensor. of shape (b, n1, n2,..,\*denseshapeshape), whose ret[i] is of subgraph i. (n1, n2,...) is the maximum sizes of the tuplefeat of subgraphs. 174 | 175 | To align tuple features of different sizes, padding is applied. 
176 | 177 | ''' 178 | if batch_size is None: 179 | batch_size = tupleshape.shape[0] 180 | 181 | if max_tupleshape is None: 182 | max_tupleshape = torch.amax(tupleshape, dim=0) 183 | 184 | ndim = max_tupleshape.shape[0] 185 | fullidx = tuplefeatptr[:-1].reshape([-1] + [1] * ndim) 186 | cumshape = torch.ones_like(tupleshape[:, [0]]) 187 | # print(cumshape.shape) 188 | for i in range(ndim): 189 | tidx = (torch.arange(max_tupleshape[-i - 1], device=tuplefeat.device) * 190 | cumshape).reshape([batch_size] + [1] * (ndim - i - 1) + [-1] + 191 | [1] * i) 192 | # print(fullidx.shape, tidx.shape, max_tupleshape, ndim) 193 | fullidx = fullidx + tidx 194 | cumshape = cumshape * tupleshape[:, [-i - 1]] 195 | fullidx.clamp_max_(tuplefeat.shape[0] - 1) 196 | ret = tuplefeat[fullidx] 197 | if feat2mask is not None: 198 | mask = feat2mask(ret) 199 | else: 200 | mask = torch.ones([batch_size] + max_tupleshape.tolist(), 201 | device=ret.device, 202 | dtype=torch.bool) 203 | 204 | for i in range(ndim): 205 | tmask = torch.ones([batch_size] + [max_tupleshape[i] + 1] + [1] * 206 | (ndim - 1), 207 | dtype=torch.bool, 208 | device=ret.device) 209 | tmask[torch.arange(batch_size, device=ret.device), 210 | tupleshape[:, i]] = False 211 | tmask = torch.cummin(tmask, dim=1)[0] 212 | tmask = tmask[:, :-1] 213 | tmask = torch.movedim(tmask, 1, i + 1) 214 | mask.logical_and_(tmask) 215 | return MaskedTensor(ret, mask, 0, False) 216 | 217 | 218 | def batch2dense(batch: PygBatch, 219 | batch_size: int = None, 220 | max_num_nodes: int = None, 221 | denseadj: bool = False, 222 | keys: List[str] = [""]) -> PygBatch: 223 | ''' 224 | A main wrapper for converting and padding data in a batch object to dense forms. 225 | 226 | Args: 227 | 228 | - batch (PygBatch): The input batch object. 229 | - batch_size (int): Batch size. 230 | - max_num_nodes (int): Maximum number of nodes in the graph. 231 | - denseadj (bool): Whether to convert adjacency to dense or sparse. 
232 | - keys (List[str]): List of keys for additional attributes. 233 | 234 | Returns: 235 | 236 | - PygBatch: The processed batch object. 237 | 238 | ''' 239 | 240 | batch.x = to_dense_x(batch.x, batch.ptr, max_num_nodes, batch_size) 241 | batch_size, max_num_nodes = batch.x.shape[0], batch.x.shape[1] 242 | if denseadj: 243 | batch.A = to_dense_adj(batch.edge_index, batch.edge_index_batch, 244 | batch.edge_attr, max_num_nodes, batch_size) 245 | else: 246 | batch.A = to_sparse_adj(batch.edge_index, batch.edge_index_batch, 247 | batch.edge_attr, max_num_nodes, batch_size) 248 | for key in keys: 249 | tuplefeat = getattr(batch, f"tuplefeat{key}") 250 | tupleshape = getattr(batch, f"tupleshape{key}") 251 | tuplefeat_ptr = getattr(batch, f"tuplefeat{key}_ptr") 252 | X = to_dense_tuplefeat(tuplefeat, tupleshape, tuplefeat_ptr, None, 253 | batch_size, None) 254 | setattr(batch, f"X{key}", X) 255 | return batch 256 | 257 | 258 | def ma_datapreprocess(data: PygData, 259 | tuplesamplers: List[Callable[[PygData], 260 | Tuple[Tensor, List[int]]]], 261 | annotate: List[str] = [""]) -> MaHoData: 262 | ''' 263 | A wrapper for preprocessing dense data. 264 | 265 | Args: 266 | 267 | - data (PygData): Input data object. 268 | - tuplesamplers (Union[Callable[[PygData], Tuple[Tensor, List[int]]], List[Callable[[PygData], Tuple[Tensor, List[int]]]]]): Tuple samplers for extracting data. 269 | - annotate (List[str]): List of annotation strings. 270 | 271 | Returns: 272 | 273 | - MaHoData: Preprocessed data object. 
274 | 275 | ''' 276 | assert len(tuplesamplers) == len( 277 | annotate), "each tuplesampler need a different annotate" 278 | data.edge_index, data.edge_attr = coalesce(data.edge_index, 279 | data.edge_attr, 280 | num_nodes=data.num_nodes) 281 | 282 | datadict = data.to_dict() 283 | datadict.update({ 284 | "num_nodes": data.num_nodes, 285 | "num_edges": data.edge_index.shape[1], 286 | "x": data.x, 287 | "edge_index": data.edge_index, 288 | "edge_attr": data.edge_attr 289 | }) 290 | for i, tuplesampler in enumerate(tuplesamplers): 291 | tuplefeat, tupleshape = tuplesampler(data) 292 | datadict.update({ 293 | f"tuplefeat{annotate[i]}": 294 | tuplefeat, 295 | f"tupleshape{annotate[i]}": 296 | torch.LongTensor(tupleshape).reshape(1, -1), 297 | }) 298 | 299 | return MaHoData(**datadict) 300 | -------------------------------------------------------------------------------- /pygho/hodata/MaTupleSampler.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.data import Data as PygData 2 | import torch 3 | from torch import Tensor 4 | from typing import Tuple, List 5 | from torch_geometric.utils import to_scipy_sparse_matrix 6 | import scipy.sparse as ssp 7 | import numpy as np 8 | import scipy.linalg as spl 9 | 10 | 11 | def spdsampler(data: PygData, hop: int = 2) -> Tuple[Tensor, List[int]]: 12 | """ 13 | sample k-hop subgraph on a given PyG graph. 14 | 15 | Args: 16 | 17 | - data (PygData): The input PyG dataset. 18 | - hop (int, optional): The number of hops for subgraph sampling. Defaults to 2. 19 | 20 | Returns: 21 | 22 | - Tensor: the precomputed tuple features. 23 | - List[int]: the masked shape of the features. 
def spdsampler(data: PygData, hop: int = 2) -> Tuple[Tensor, List[int]]:
    """
    compute clipped shortest path distances between all node pairs.

    Args:

    - data (PygData): The input PyG graph.
    - hop (int, optional): Distances are clamped to at most `hop + 1`. Defaults to 2.

    Returns:

    - Tensor: the precomputed tuple features, flattened to one entry per node pair.
    - List[int]: the masked shape of the features, `[num_nodes, num_nodes]`.
    """
    adj = to_scipy_sparse_matrix(data.edge_index, num_nodes=data.num_nodes)
    dist_matrix = ssp.csgraph.shortest_path(adj,
                                            directed=False,
                                            unweighted=True,
                                            return_predecessors=False)
    # NOTE(review): unreachable pairs come back as inf from scipy; confirm
    # the integer conversion below behaves as intended for such graphs.
    ret = torch.LongTensor(dist_matrix).flatten()
    ret.clamp_max_(hop+1)
    return ret.reshape(-1), [data.num_nodes, data.num_nodes]


def rdsampler(data: PygData) -> Tuple[Tensor, List[int]]:
    """
    compute resistance distance between nodes.

    Args:

    - data (PygData): The input PyG graph. Its adjacency must be symmetric.

    Returns:

    - Tensor: the precomputed tuple features, of shape (num_nodes * num_nodes, 1).
    - List[int]: the masked shape of the features, `[num_nodes, num_nodes]`.
    """
    adj = to_scipy_sparse_matrix(data.edge_index, num_nodes=data.num_nodes)
    laplacian = ssp.csgraph.laplacian(adj).toarray()
    # A small multiple of the identity is added to the Laplacian before the
    # pseudo-inverse is taken.
    laplacian += 0.01 * np.eye(*laplacian.shape)
    assert spl.issymmetric(laplacian), "please use symmetric graph"
    L_inv = np.linalg.pinv(laplacian, hermitian=True)
    dL = np.diagonal(L_inv)
    # Entry (i, j) is L_inv[i, i] + L_inv[j, j] - L_inv[i, j] - L_inv[j, i].
    return torch.FloatTensor(
        (dL.reshape(-1, 1) + dL.reshape(1, -1) - L_inv - L_inv.T)).reshape(
            -1, 1), [data.num_nodes, data.num_nodes]


# --------------------------------------------------------------------------------
# /pygho/hodata/ParallelPreprocess.py:
# --------------------------------------------------------------------------------
import torch
from torch_geometric.data import InMemoryDataset, Data as PygData
from typing import Callable, Optional, Iterable
from multiprocessing import Pool
from pqdm.processes import pqdm
from tqdm import tqdm
from .Wrapper import _repr
import os.path as osp


class ParallelPreprocessDataset(InMemoryDataset):
    '''
    Parallelly transform a PyG dataset.

    This dataset class allows parallel preprocessing of a list of PyGData or PyGDataset instances.

    Args:

    - root (str): The directory to save processed data.
    - data_list (Iterable[PygData]): A list of PygData or PygDataset instances.
    - pre_transform (Callable[[PygData], PygData]): A function that maps PygData to PygData. It is executed only once for all data and is typically a tuple sampler.
    - num_worker (int): The number of processes for parallel preprocessing. It can be set to the number of available CPU cores. A value of 0 or less processes the data sequentially.
    - processedname (Optional[str]): The name to save the processed data. If None, the name is derived from the textual representation of pre_transform and pre_filter.
    - transform (Optional[Callable[[PygData], PygData]]): A function to dynamically transform data during data loading.
    '''

    def __init__(self,
                 root: str,
                 data_list: Iterable[PygData],
                 pre_transform: Callable[[PygData], PygData],
                 num_worker: int,
                 processedname: Optional[str] = None,
                 transform: Optional[Callable[[PygData], PygData]] = None):
        # These attributes are set before super().__init__ because
        # processed_dir and process() read them.
        self.tmp_data_list = list(data_list)
        self.num_worker = num_worker
        self.processedname = processedname
        super().__init__(root,
                         pre_transform=pre_transform,
                         transform=transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        return 'data.pt'

    @property
    def processed_dir(self) -> str:
        # Without an explicit name, the directory name encodes the
        # pre_transform and pre_filter representations.
        if self.processedname is None:
            return osp.join(
                self.root,
                f'processed__{_repr(self.pre_transform)}__{_repr(self.pre_filter)}'
            )
        else:
            return osp.join(self.root, f'processed__{self.processedname}')

    def process(self):
        # Apply pre_transform to every sample, in parallel when num_worker > 0,
        # then collate and save the result.
        if self.num_worker > 0:
            data_list = pqdm(self.tmp_data_list,
                             self.pre_transform,
                             n_jobs=self.num_worker)
        else:
            data_list = [
                self.pre_transform(_) for _ in tqdm(self.tmp_data_list)
            ]
        torch.save(self.collate(data_list), self.processed_paths[0])
# --------------------------------------------------------------------------------
# /pygho/hodata/SpData.py:
# --------------------------------------------------------------------------------
'''
utilities for sparse high order data
'''
from torch_geometric.data import Data as PygData, Batch as PygBatch
import torch
from typing import Any, List, Callable, Union, Tuple, Iterable
from torch import Tensor
from ..backend.Spspmm import spspmm_ind, filterind
from ..backend.SpTensor import SparseTensor
from ..honn.SpOperator import KEYSEP
from torch_geometric.utils import coalesce


def parseop(op: str) -> str:
    '''
    Get the name of the attribute holding the increment for a tensor when combining graphs.

    Args:

    - op (str): The operator string: "A" for the adjacency, or "X" followed by an annotation for a tuple tensor.

    Returns:

    - str: "num_edges" for "A", "num_tuples<annotation>" for "X<annotation>".

    Raises:

    - NotImplementedError: If the operator name is not recognised (including the empty string).
    '''
    if op.startswith("X"):
        return f"num_tuples{op[1:]}"
    if op == "A":
        return "num_edges"
    # The original *returned* a (NotImplementedError, message) tuple, so bad
    # operator names passed validation silently. Raise instead.
    raise NotImplementedError(f"operator name {op} not implemented now")


def parsekey(key: str) -> Tuple[str, str, int, str, int]:
    '''
    Parse the operators in precomputation keys.

    Args:

    - key (str): The precomputation key, five fields joined by KEYSEP: op0, op1, dim1, op2, dim2.

    Returns:

    - Tuple[str, str, int, str, int]: A tuple containing parsed operators and dimensions.
    '''
    assert len(key.split(KEYSEP)) == 5, "key format not match"
    op0, op1, dim1, op2, dim2 = key.split(KEYSEP)
    dim1 = int(dim1)
    dim2 = int(dim2)
    # Called only for validation: parseop raises on unknown operator names.
    parseop(op0)
    parseop(op1)
    parseop(op2)
    return op0, op1, dim1, op2, dim2


class SpHoData(PygData):
    '''
    A data class for sparse high order graph data.
    '''
    def __inc__(self, key: str, value: Any, *args, **kwargs):
        # Tuple indices are shifted by the matching tuple shape.
        if key.startswith('tupleid'):
            return getattr(self,
                           "tupleshape" + key.removeprefix("tupleid")).reshape(
                               -1, 1)
        # Each row of a precomputed acd index is shifted by the entry count
        # of the tensor it points into.
        if key.endswith(f"{KEYSEP}acd"):
            key = key.removesuffix(f"{KEYSEP}acd")
            op0, op1, _, op2, _ = parsekey(key)
            return torch.tensor(
                [[getattr(self, parseop(op0))], [getattr(self, parseop(op1))],
                 [getattr(self, parseop(op2))]],
                dtype=torch.long)
        return super().__inc__(key, value, *args, **kwargs)

    def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
        # Index arrays are stored as (rows, nnz) and concatenated along nnz.
        if key.startswith('tupleid') or key.endswith(f"{KEYSEP}acd"):
            return 1
        return super().__cat_dim__(key, value, *args, **kwargs)


def batch2sparse(batch: PygBatch, keys: List[str] = [""]) -> PygBatch:
    '''
    A main wrapper for converting data in a batch object to SparseTensor.

    Args:

    - batch (PygBatch): The batch object containing graph data.
    - keys (List[str]): The annotations of the tuple tensors to convert; each yields an attribute `X<key>`.

    Returns:

    - PygBatch: The same batch object, with `A` and `X<key>` attributes added.
    '''
    batch.A = SparseTensor(
        batch.edge_index,
        batch.edge_attr,
        [batch.num_nodes, batch.num_nodes] if batch.edge_attr is None else
        [batch.num_nodes, batch.num_nodes] + list(batch.edge_attr.shape[1:]),
        is_coalesced=True)
    for key in keys:
        # Per-graph tuple shapes are summed to obtain the batch-level shape.
        totaltupleshape = getattr(batch,
                                  f"tupleshape{key}").sum(dim=0).tolist()
        tupleid = getattr(batch, f"tupleid{key}")
        tuplefeat = getattr(batch, f"tuplefeat{key}")
        X = SparseTensor(
            tupleid,
            tuplefeat,
            shape=totaltupleshape if tuplefeat is None else totaltupleshape +
            list(tuplefeat.shape[1:]),
            is_coalesced=True)
        setattr(batch, f"X{key}", X)
    return batch


def sp_datapreprocess(data: PygData,
                      tuplesamplers: List[Callable[[PygData], SparseTensor]],
                      annotate: List[str] = [""],
                      keys: List[str] = [""]) -> SpHoData:
    '''
    A wrapper for preprocessing data for sparse high order graphs.

    Args:

    - data (PygData): The input data in PyG Data format. Its edges are coalesced in place.
    - tuplesamplers (List[Callable]): A list of tuple sampling functions, each returning a SparseTensor.
    - annotate (List[str]): A list of annotation strings, one per tuple sampler.
    - keys (List[str]): A list of precomputation keys (see `parsekey`). Empty strings are ignored.

    Returns:

    - SpHoData: The preprocessed sparse high order data in SpHoData format.
    '''
    data.edge_index, data.edge_attr = coalesce(data.edge_index,
                                               data.edge_attr,
                                               num_nodes=data.num_nodes)

    assert len(tuplesamplers) == len(
        annotate
    ), "number of tuple sampler should match the number of annotate"

    datadict = data.to_dict()
    datadict.update({
        "num_nodes": data.num_nodes,
        "num_edges": data.edge_index.shape[1],
        "x": data.x,
        "edge_index": data.edge_index,
        "edge_attr": data.edge_attr,
    })
    for i, tuplesampler in enumerate(tuplesamplers):
        feat = tuplesampler(data)
        tupleid, tuplefeat, tupleshape = feat.indices, feat.values, feat.sparseshape
        num_tuples = tupleid.shape[1]
        datadict.update({
            f"tupleid{annotate[i]}":
            tupleid,
            f"tuplefeat{annotate[i]}":
            tuplefeat,
            f"tupleshape{annotate[i]}":
            torch.LongTensor(tupleshape).reshape(1, -1),
            f"num_tuples{annotate[i]}":
            num_tuples
        })

    def _indices_of(op: str):
        # "X<annotation>" names a sampled tuple tensor; the only other
        # operator accepted by parsekey is the adjacency "A".
        if op[0] == "X":
            return datadict[f"tupleid{op[1:]}"]
        return datadict["edge_index"]

    for key in keys:
        if not key:
            # The default keys=[""] cannot satisfy parsekey's five-field
            # format and used to abort every default call; treat an empty
            # key as "nothing to precompute".
            continue
        op0, op1, dim1, op2, dim2 = parsekey(key)
        datadict[key + f"{KEYSEP}acd"] = filterind(
            _indices_of(op0),
            *spspmm_ind(_indices_of(op1), dim1, _indices_of(op2), dim2))
    return SpHoData(**datadict)


# --------------------------------------------------------------------------------
# /pygho/hodata/SpTupleSampler.py:
# --------------------------------------------------------------------------------
from torch_geometric.data import Data as PygData, Batch as PygBatch
from torch_geometric.utils import to_scipy_sparse_matrix, k_hop_subgraph
import torch
from typing import List, Optional, Tuple, Union
from torch import Tensor, LongTensor
from torch_geometric.utils.num_nodes import maybe_num_nodes
from typing import Tuple
import scipy.sparse as ssp
from ..backend.SpTensor import coalesce, SparseTensor


def k_hop_subgraph(
    node_idx: Union[int, List[int], LongTensor],
    num_hops: int,
    edge_index: LongTensor,
    relabel_nodes: bool = False,
    num_nodes: Optional[int] = None,
    flow: str = 'source_to_target',
    directed: bool = False,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """
    Compute the k-hop subgraph around a set of nodes in an edge list.

    Args:

    - node_idx (Union[int, List[int], LongTensor]): The root node(s) for the subgraph.
    - num_hops (int): The number of hops for the subgraph.
    - edge_index (LongTensor): The edge indices of the graph.
    - relabel_nodes (bool, optional): Whether to relabel node indices. Defaults to False.
    - num_nodes (Optional[int], optional): The total number of nodes. Defaults to None.
    - flow (str, optional): The direction of traversal ('source_to_target' or 'target_to_source'). Defaults to 'source_to_target'.
    - directed (bool, optional): If False, every edge with both endpoints in the subgraph is kept. If True, only the edges followed in the last expansion step are kept. Defaults to False.

    Returns:

    Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: A tuple containing:
    - subset (Tensor): The node indices in the subgraph.
    - edge_index (Tensor): The edge indices of the subgraph.
    - inv (Tensor): The positions of the root node(s) within `subset`.
    - edge_mask (Tensor): A mask indicating which edges are part of the subgraph.
    - dist (Tensor): The hop distance of each node in `subset` to the root node(s).
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    assert flow in ['source_to_target', 'target_to_source']
    if flow == 'target_to_source':
        row, col = edge_index
    else:
        col, row = edge_index

    # Scratch buffers, reused (overwritten in place) at every hop.
    node_mask = row.new_empty(num_nodes, dtype=torch.bool)
    edge_mask = row.new_empty(row.size(0), dtype=torch.bool)

    if isinstance(node_idx, (int, list, tuple)):
        node_idx = torch.tensor([node_idx], device=row.device).flatten()
    else:
        node_idx = node_idx.to(row.device)

    # subsets[h] holds the nodes reached by the h-th expansion step.
    subsets = [node_idx]
    dist = torch.empty_like(node_mask, dtype=torch.long).fill_(num_nodes + 1)
    for _ in range(num_hops):
        node_mask.fill_(False)
        node_mask[subsets[-1]] = True
        torch.index_select(node_mask, 0, row, out=edge_mask)
        subsets.append(col[edge_mask])

    # Written from the farthest hop down to 0, so that a node reached at
    # several hops ends up with the smallest one.
    for _ in range(num_hops, -1, -1):
        dist[subsets[_]] = _

    subset, inv = torch.cat(subsets).unique(return_inverse=True)
    inv = inv[:node_idx.numel()]

    dist = dist[subset]

    node_mask.fill_(False)
    node_mask[subset] = True

    if not directed:
        edge_mask = node_mask[row] & node_mask[col]

    edge_index = edge_index[:, edge_mask]

    if relabel_nodes:
        # Map original node ids to positions in `subset`; others become -1.
        node_idx = row.new_full((num_nodes, ), -1)
        node_idx[subset] = torch.arange(subset.size(0), device=row.device)
        edge_index = node_idx[edge_index]

    return subset, edge_index, inv, edge_mask, dist


def KhopSampler(
        data: PygData,
        hop: int = 2) -> SparseTensor:
    """
    sample k-hop subgraph on a given PyG graph.

    For every root node i, one tuple (i, j) is produced for each node j in
    the k-hop subgraph of i, with the hop distance from i to j as its
    feature.

    Args:

    - data (PygData): The input PyG graph.
    - hop (int, optional): The number of hops for subgraph sampling. Defaults to 2.

    Returns:

    SparseTensor for the precomputed tuple features.
    """

    subgraphs = []

    for i in range(data.num_nodes):
        subset, _, _, _, dist = k_hop_subgraph(i,
                                               hop,
                                               data.edge_index,
                                               relabel_nodes=True,
                                               num_nodes=data.num_nodes)
        # A subgraph containing only the root itself is rejected.
        assert subset.shape[0] > 1, "empty subgraph!"
        nodeidx1 = subset.clone()
        nodeidx1.fill_(i)
        subgraphs.append(
            PygData(
                x=dist,
                subg_nodeidx=torch.stack((nodeidx1, subset), dim=-1),
                num_nodes=subset.shape[0],
            ))
    subgbatch = PygBatch.from_data_list(subgraphs)
    tupleid, tuplefeat = subgbatch.subg_nodeidx.t(), subgbatch.x
    return SparseTensor(tupleid, tuplefeat, shape=2*[data.num_nodes]+list(tuplefeat.shape[1:]), is_coalesced=False, reduce="min")


def I2Sampler(
        data: PygData,
        hop: int = 3) -> SparseTensor:
    """
    Perform subgraph sampling on a given graph for I2GNN.

    For every edge (u, v), one tuple (u, v, j) is produced for each node j
    in the k-hop subgraph rooted at both u and v. Its feature is the pair of
    shortest path distances from u to j and from v to j.

    Args:

    - data (PygData): The input PyG graph.
    - hop (int, optional): The number of hops for subgraph sampling. Defaults to 3.

    Returns:

    SparseTensor for the precomputed tuple features.
    """
    subgraphs = []
    spadj = to_scipy_sparse_matrix(data.edge_index, num_nodes=data.num_nodes)
    dist_matrix = torch.from_numpy(
        ssp.csgraph.shortest_path(spadj,
                                  directed=False,
                                  unweighted=True,
                                  return_predecessors=False)).to(torch.long)
    ei = data.edge_index
    for i in range(ei.shape[1]):
        nodepair = ei[:, i]
        subset, _, _, _, _ = k_hop_subgraph(nodepair,
                                            hop,
                                            data.edge_index,
                                            relabel_nodes=True,
                                            num_nodes=data.num_nodes)
        assert subset.shape[0] > 1, "empty subgraph!"
        nodeidx1 = subset.clone()
        nodeidx1.fill_(nodepair[0])
        nodeidx2 = subset.clone()
        nodeidx2.fill_(nodepair[1])
        subgraphs.append(
            PygData(
                x=torch.stack((dist_matrix[nodepair[0].item()][subset],
                               dist_matrix[nodepair[1].item()][subset]),
                              dim=-1),
                subg_nodeidx=torch.stack((nodeidx1, nodeidx2, subset), dim=-1),
                num_nodes=subset.shape[0],
            ))
    subgbatch = PygBatch.from_data_list(subgraphs)
    tupleid, tuplefeat = subgbatch.subg_nodeidx.t(), subgbatch.x
    return SparseTensor(tupleid, tuplefeat, shape=3*[data.num_nodes]+list(tuplefeat.shape[1:]), is_coalesced=False, reduce="min")
160 | nodeidx1 = subset.clone() 161 | nodeidx1.fill_(nodepair[0]) 162 | nodeidx2 = subset.clone() 163 | nodeidx2.fill_(nodepair[1]) 164 | subgraphs.append( 165 | PygData( 166 | x=torch.stack((dist_matrix[nodepair[0].item()][subset], 167 | dist_matrix[nodepair[1].item()][subset]), 168 | dim=-1), 169 | subg_nodeidx=torch.stack((nodeidx1, nodeidx2, subset), dim=-1), 170 | num_nodes=subset.shape[0], 171 | )) 172 | subgbatch = PygBatch.from_data_list(subgraphs) 173 | tupleid, tuplefeat = subgbatch.subg_nodeidx.t(), subgbatch.x 174 | return SparseTensor(tupleid, tuplefeat, shape=3*[data.num_nodes]+list(tuplefeat.shape[1:]), is_coalesced=False, reduce="min") -------------------------------------------------------------------------------- /pygho/hodata/Wrapper.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataloader import _BaseDataLoaderIter 2 | from torch_geometric.data import Data as PygData, Dataset 3 | from torch_geometric.data.data import BaseData 4 | from torch_geometric.data.datapipes import DatasetAdapter 5 | from torch_geometric.loader import DataLoader as PygDataLoader 6 | import re 7 | from typing import Any, Callable, List, Iterable, Sequence, Tuple, Union, Optional 8 | from torch import Tensor 9 | from functools import partial 10 | from .SpData import sp_datapreprocess, batch2sparse 11 | from .MaData import ma_datapreprocess, batch2dense 12 | from torch_geometric.transforms import Compose 13 | from ..backend.SpTensor import SparseTensor 14 | from ..backend.MaTensor import MaskedTensor 15 | 16 | 17 | def _repr(obj: Any) -> str: 18 | if obj is None: 19 | return 'None' 20 | ret = re.sub('at 0x[0-9a-fA-F]+', "", str(obj)) 21 | ret = ret.replace("\n", " ") 22 | ret = ret.replace("functools.partial", " ") 23 | ret = ret.replace("function", " ") 24 | ret = ret.replace("<", " ") 25 | ret = ret.replace(">", " ") 26 | ret = ret.replace(" ", "") 27 | return ret 28 | 29 | 30 | def 
def Mapretransform(tuplesamplers: List[Callable[[PygData], MaskedTensor]]
                   | Callable[[PygData], MaskedTensor],
                   annotate: Optional[List[str]] = None):
    """
    Create a data pre-transformation function for dense data.

    Args:

    - tuplesamplers (Union[Callable[[PygData], MaskedTensor], List[Callable[[PygData], MaskedTensor]]]): A tuple sampler or a list of tuple samplers.
    - annotate (List[str], optional): A list of annotations. Defaults to [""].

    Returns:

    - Callable: ``ma_datapreprocess`` with ``tuplesamplers`` and ``annotate`` already bound.
    """
    # Build a fresh list per call so that the returned partials never share
    # (or expose) one module-level default list.
    annotate = [""] if annotate is None else list(annotate)
    # A bare sampler is wrapped so that downstream code always gets a list.
    if not isinstance(tuplesamplers, Iterable):
        tuplesamplers = [tuplesamplers]
    hopre_transform = partial(ma_datapreprocess,
                              tuplesamplers=tuplesamplers,
                              annotate=annotate)
    return hopre_transform
class MaDataloader(PygDataLoader):
    """
    A data loader for dense data that converts the inner data format to MaskedTensor.

    Args:

    - dataset (Dataset | Sequence[BaseData] | DatasetAdapter): The input dataset or data sequence.
    - batch_size (int, optional): Forwarded to the PyG DataLoader. Defaults to 1.
    - shuffle (bool, optional): Forwarded to the PyG DataLoader. Defaults to False.
    - follow_batch (List[str], optional): Keys to follow. "edge_index" and every
      "tuplefeat*" key of the first sample are added automatically; the list
      passed in is left untouched.
    - exclude_keys (List[str], optional): Forwarded to the PyG DataLoader.
    - device (optional): The device to place the data on. Defaults to None.
    - denseadj (bool, optional): Whether to use dense adjacency. Defaults to True.
    - other kwargs: Additional keyword arguments for DataLoader. Same as Pyg dataloader

    """

    def __init__(self,
                 dataset: Dataset | Sequence[BaseData] | DatasetAdapter,
                 batch_size: int = 1,
                 shuffle: bool = False,
                 follow_batch: List[str] | None = None,
                 exclude_keys: List[str] | None = None,
                 device=None,
                 denseadj: bool = True,
                 **kwargs):
        # Copy before extending so the caller's list is never mutated.
        follow_batch = [] if follow_batch is None else list(follow_batch)
        # Suffixes of all "tuplefeat*" attributes found on the first sample.
        keys = [
            k.removeprefix("tuplefeat") for k in dataset[0].to_dict().keys()
            if k.startswith("tuplefeat")
        ]
        self.keys = keys
        for name in ["edge_index"] + [f"tuplefeat{key}" for key in keys]:
            if name not in follow_batch:
                follow_batch.append(name)
        super().__init__(dataset, batch_size, shuffle, follow_batch,
                         exclude_keys, **kwargs)
        self.device = device
        self.denseadj = denseadj

    def __iter__(self) -> _BaseDataLoaderIter:
        # The parent iterator is wrapped so each batch is moved to
        # `self.device` (if set) and then converted by `batch2dense`.
        ret = super().__iter__()
        return IterWrapper(
            ret, partial(batch2dense, keys=self.keys, denseadj=self.denseadj),
            self.device)
# NGNNConv: Nested Graph Neural Network Convolution Layer
class NGNNConv(Module):
    """
    NGNNConv layer, following "Nested Graph Neural Networks" by Muhan Zhang and Pan Li, NeurIPS 2021.
    An MLP is applied to every tuple representation, then messages are passed on the 2D subgraph representations.

    Args:

    - indim (int): Input feature dimension.
    - outdim (int): Output feature dimension.
    - aggr (str): Aggregation method for message passing (e.g., "sum").
    - mode (str): Mode for specifying tensor types (e.g., "SS" for sparse adjacency and sparse X).
    - mlp (dict): Keyword arguments forwarded to the MLP layer.
    - optuplefeat (str): Forwarded to TensorOp.OpMessagePassingOnSubg2D. Defaults to "X".
    - opadj (str): Forwarded to TensorOp.OpMessagePassingOnSubg2D. Defaults to "A".
    - message_func (Optional[Callable]): Forwarded to TensorOp.OpMessagePassingOnSubg2D. Defaults to None.

    Methods:

    - forward(A: Union[SparseTensor, MaskedTensor], X: Union[SparseTensor, MaskedTensor], datadict: dict) -> Union[SparseTensor, MaskedTensor]:
      Forward pass of the NGNNConv layer.
    """

    def __init__(self,
                 indim: int,
                 outdim: int,
                 aggr: str = "sum",
                 mode: Literal["SD", "DD", "SS"] = "SS",
                 mlp: dict = {},
                 optuplefeat: str = "X",
                 opadj: str = "A",
                 message_func: Optional[Callable] = None):
        super().__init__()
        self.aggr = TensorOp.OpMessagePassingOnSubg2D(
            mode, aggr, optuplefeat, opadj, message_func)
        self.lin = MLP(indim, outdim, **mlp)

    def forward(self, A: Union[SparseTensor, MaskedTensor],
                X: Union[SparseTensor, MaskedTensor],
                datadict: dict) -> Union[SparseTensor, MaskedTensor]:
        # The transformed tuples serve both as the message source and as the
        # last argument of the aggregation operator.
        transformed = X.tuplewiseapply(self.lin)
        return self.aggr.forward(A, transformed, datadict, transformed)
# I2Conv: I2-GNN Convolution Layer
class I2Conv(Module):
    """
    I2Conv layer, following "Boosting the cycle counting power of graph neural networks with I2-GNNs" by Yinan Huang et al., ICLR 2023.
    An MLP is applied to every tuple representation, then messages are passed on the 3D subgraph representations.

    Args:

    - indim (int): Input feature dimension.
    - outdim (int): Output feature dimension.
    - aggr (str): Aggregation method for message passing (e.g., "sum").
    - mode (str): Mode for specifying tensor types (e.g., "SS" for sparse adjacency and sparse X).
    - mlp (dict): Keyword arguments forwarded to the MLP layer.
    - optuplefeat (str): Forwarded to TensorOp.OpMessagePassingOnSubg3D. Defaults to "X".
    - opadj (str): Forwarded to TensorOp.OpMessagePassingOnSubg3D. Defaults to "A".

    Methods:

    - forward(A: Union[SparseTensor, MaskedTensor], X: Union[SparseTensor, MaskedTensor], datadict: dict) -> Union[SparseTensor, MaskedTensor]:
      Forward pass of the I2Conv layer.
    """

    def __init__(self,
                 indim: int,
                 outdim: int,
                 aggr: str = "sum",
                 mode: Literal["SD", "DD", "SS"] = "SS",
                 mlp: dict = {},
                 optuplefeat: str = "X",
                 opadj: str = "A"):
        super().__init__()
        self.aggr = TensorOp.OpMessagePassingOnSubg3D(
            mode, aggr, optuplefeat, opadj)
        self.lin = MLP(indim, outdim, **mlp)

    def forward(self, A: Union[SparseTensor, MaskedTensor],
                X: Union[SparseTensor, MaskedTensor],
                datadict: dict) -> Union[SparseTensor, MaskedTensor]:
        # Transform every tuple once and reuse the result for both
        # tensor arguments of the aggregation operator.
        transformed = X.tuplewiseapply(self.lin)
        return self.aggr.forward(A, transformed, datadict, transformed)
# PPGNConv: Provably Powerful Graph Networks Convolution Layer
class PPGNConv(Module):
    """
    PPGNConv layer, following "Provably powerful graph networks" by Haggai Maron et al., NeurIPS 2019.
    Two independently parameterised MLPs transform the tuple representation and
    the two results are combined by TensorOp.Op2FWL.

    Args:

    - indim (int): Input feature dimension.
    - outdim (int): Output feature dimension.
    - aggr (str): Aggregation method for message passing (e.g., "sum").
    - mode (str): Mode for specifying tensor types (e.g., "SS" for sparse adjacency and sparse X).
    - mlp (dict): Keyword arguments forwarded to both MLP layers.
    - optuplefeat (str): Forwarded to TensorOp.Op2FWL. Defaults to "X".

    Methods:

    - forward(A: Union[SparseTensor, MaskedTensor], X: Union[SparseTensor, MaskedTensor], datadict: dict) -> Union[SparseTensor, MaskedTensor]:
      Forward pass of the PPGNConv layer. ``A`` is accepted for interface
      parity with the other layers but is not read.

    """

    def __init__(self,
                 indim: int,
                 outdim: int,
                 aggr: str = "sum",
                 mode: Literal["DD", "SS"] = "SS",
                 mlp: dict = {},
                 optuplefeat: str = "X"):
        super().__init__()
        self.op = TensorOp.Op2FWL(mode, aggr, optuplefeat)
        self.lin1 = MLP(indim, outdim, **mlp)
        self.lin2 = MLP(indim, outdim, **mlp)

    def forward(self, A: Union[SparseTensor, MaskedTensor],
                X: Union[SparseTensor, MaskedTensor],
                datadict: dict) -> Union[SparseTensor, MaskedTensor]:
        # Evaluate lin1 before lin2, matching the argument order of the operator.
        left = X.tuplewiseapply(self.lin1)
        right = X.tuplewiseapply(self.lin2)
        return self.op.forward(left, right, datadict, X)
260 | 261 | """ 262 | 263 | def __init__(self, 264 | indim: int, 265 | outdim: int, 266 | aggr: str = "sum", 267 | pool: str = "mean", 268 | mode: Literal["SD", "DD", "SS"] = "SS", 269 | mlp0: dict = {}, 270 | mlp1: dict = {}, 271 | ctx: bool = True, 272 | optuplefeat: str = "X", 273 | opadj: str = "A"): 274 | super().__init__() 275 | self.lin0 = MLP(indim, indim, **mlp0) 276 | self.aggr = TensorOp.OpMessagePassingOnSubg2D(mode, aggr, optuplefeat, 277 | opadj) 278 | self.diag = TensorOp.OpDiag2D(mode[1]) 279 | self.pool2subg = TensorOp.OpPoolingSubg2D(mode[1], pool) 280 | self.unpool4subg = TensorOp.OpUnpoolingSubgNodes2D(mode[1]) 281 | self.ctx = ctx 282 | if ctx: 283 | self.pool2node = TensorOp.OpPoolingCrossSubg2D(mode[1], pool) 284 | self.unpool4rootnode = TensorOp.OpUnpoolingRootNodes2D(mode[1]) 285 | self.lin = MLP(3 * indim if ctx else 2 * indim, outdim, **mlp1) 286 | 287 | def forward(self, A: Union[SparseTensor, MaskedTensor], 288 | X: Union[SparseTensor, MaskedTensor], 289 | datadict: dict) -> Union[SparseTensor, MaskedTensor]: 290 | X = self.aggr.forward(A, X.tuplewiseapply(self.lin0), datadict, X) 291 | X1 = self.unpool4subg.forward(self.diag.forward(X), X) 292 | X2 = self.unpool4subg.forward(self.pool2subg.forward(X), X) 293 | if self.ctx: 294 | X3 = self.unpool4rootnode.forward(self.pool2node.forward(X), X) 295 | return X2.catvalue([X1, X3], True).tuplewiseapply(self.lin) 296 | else: 297 | return X2.catvalue(X1, True).tuplewiseapply(self.lin) 298 | 299 | 300 | # SUNConv: Subgraph Union Network Convolution Layer 301 | class SUNConv(Module): 302 | """ 303 | Implementation of the SUNConv layer based on the paper "Understanding and extending subgraph GNNs by rethinking their symmetries" by Fabrizio Frasca et al., NeurIPS 2022. 304 | This layer performs message passing on 2D subgraph representations with subgraph and cross-subgraph pooling. 305 | 306 | Args: 307 | 308 | - indim (int): Input feature dimension. 309 | - outdim (int): Output feature dimension. 
310 | - aggr (str): Aggregation method for message passing (e.g., "sum"). 311 | - pool (str): Pooling method (e.g., "mean"). 312 | - mode (str): Mode for specifying tensor types (e.g., "SS" for sparse adjacency and sparse X). 313 | - mlp0 (dict): Parameters for the first MLP layer. 314 | - mlp1 (dict): Parameters for the second MLP layer. 315 | 316 | Methods: 317 | 318 | - forward(A: Union[SparseTensor, MaskedTensor], X: Union[SparseTensor, MaskedTensor], datadict: dict) -> Union[SparseTensor, MaskedTensor]: 319 | Forward pass of the SUNConv layer. 320 | 321 | Notes: 322 | 323 | - This layer is based on Symmetry Understanding Networks (SUN) and performs message passing on 2D subgraph representations with subgraph and cross-subgraph pooling. 324 | """ 325 | 326 | def __init__(self, 327 | indim: int, 328 | outdim: int, 329 | aggr: str = "sum", 330 | pool: str = "mean", 331 | mode: Literal["SD", "DD", "SS"] = "SS", 332 | mlp0: dict = {}, 333 | mlp1: dict = {}, 334 | optuplefeat: str = "X", 335 | opadj: str = "A"): 336 | super().__init__() 337 | self.lin0 = MLP(indim, indim, **mlp0) 338 | self.aggr = TensorOp.OpMessagePassingOnSubg2D(mode, aggr, optuplefeat, 339 | opadj) 340 | self.diag = TensorOp.OpDiag2D(mode[1]) 341 | self.pool2subg = TensorOp.OpPoolingSubg2D(mode[1], pool) 342 | self.unpool4subg = TensorOp.OpUnpoolingSubgNodes2D(mode[1]) 343 | self.pool2node = TensorOp.OpPoolingCrossSubg2D(mode[1], pool) 344 | self.unpool4rootnode = TensorOp.OpUnpoolingRootNodes2D(mode[1]) 345 | self.lin1_0 = HeteroLinear(7 * indim, indim, 2, False) 346 | self.lin1_1 = MLP(indim, outdim, **mlp1) 347 | 348 | def forward(self, A: Union[SparseTensor, MaskedTensor], 349 | X: Union[SparseTensor, MaskedTensor], 350 | datadict: dict) -> Union[SparseTensor, MaskedTensor]: 351 | X4 = self.aggr.forward(A, X.tuplewiseapply(self.lin0), datadict, X) 352 | Xdiag = self.diag.forward(X) 353 | X1 = X 354 | X2 = self.unpool4subg.forward(Xdiag, X) 355 | X3 = self.unpool4rootnode.forward(Xdiag, X) 356 
class NormMomentumScheduler:
    """
    Epoch-based scheduler for the ``momentum`` of normalization layers.

    Args:

    - mfunc (Callable): Maps the current epoch index (starting at 0) to a ratio
      that multiplies ``initmomentum``.
    - initmomentum (float): The base momentum.
    - normtype: Only modules whose exact type is ``normtype`` are updated.
      Defaults to nn.BatchNorm1d.

    Methods:

    - step(model: nn.Module) -> float:
      Advance one epoch, update the matching modules of ``model`` and return
      the momentum for that epoch.
    """

    def __init__(self,
                 mfunc: Callable,
                 initmomentum: float,
                 normtype=nn.BatchNorm1d) -> None:
        super().__init__()
        self.normtype = normtype
        self.mfunc = mfunc
        self.epoch = 0
        self.initmomentum = initmomentum
        # Set once a step has written a momentum value into a model.
        self._modified = False

    def step(self, model: nn.Module):
        ratio = self.mfunc(self.epoch)
        # Always advance: returning before the increment would freeze the
        # schedule forever as soon as mfunc yields 1 for some epoch.
        self.epoch += 1
        if 1 - 1e-6 < ratio < 1 + 1e-6:
            curm = self.initmomentum
            # Nothing has been rescaled yet, so leave the modules as configured.
            if not getattr(self, "_modified", False):
                return curm
        else:
            curm = self.initmomentum * ratio
        self._modified = True
        for mod in model.modules():
            # Exact type match on purpose: subclasses are not touched.
            if type(mod) is self.normtype:
                mod.momentum = curm
        return curm
class LayerNorm(nn.Module):
    """
    Thin wrapper around ``nn.LayerNorm`` with the same constructor shape as
    the other norm wrappers in this module.

    Args:

    - dim: Size of the normalized (last) dimension; also stored as ``num_features``.
    - normparam: Accepted for signature parity; not used by this class.
    """

    def __init__(self, dim, normparam=0.1) -> None:
        super().__init__()
        self.num_features = dim
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor):
        # Delegate directly to the wrapped nn.LayerNorm.
        normalized = self.norm(x)
        return normalized
109 | """ 110 | def __init__(self, 111 | hiddim: int, 112 | outdim: int, 113 | numlayer: int, 114 | tailact: bool, 115 | dp: float = 0, 116 | norm: str = "bn", 117 | act: str = "relu", 118 | tailbias=True, 119 | normparam: float = 0.1) -> None: 120 | super().__init__() 121 | assert numlayer >= 0 122 | if numlayer == 0: 123 | assert hiddim == outdim 124 | self.lins = NoneNorm() 125 | else: 126 | lin0 = nn.Sequential(nn.Linear(hiddim, outdim, bias=tailbias)) 127 | if tailact: 128 | lin0.append(normdict[norm](outdim, normparam)) 129 | if dp > 0: 130 | lin0.append(nn.Dropout(dp, inplace=True)) 131 | lin0.append(act_dict[act]) 132 | for _ in range(numlayer - 1): 133 | lin0.insert(0, act_dict[act]) 134 | if dp > 0: 135 | lin0.insert(0, nn.Dropout(dp, inplace=True)) 136 | lin0.insert(0, normdict[norm](hiddim, normparam)) 137 | lin0.insert(0, nn.Linear(hiddim, hiddim)) 138 | self.lins = lin0 139 | 140 | def forward(self, x: Tensor): 141 | # Forward pass through the MLP 142 | return self.lins(x) 143 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | networkx==2.8.4 2 | numpy==1.24.3 3 | ogb==1.3.6 4 | scikit_learn==1.2.2 5 | scipy==1.10.1 6 | torch==2.0.1 7 | torch_geometric==2.3.0 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | version = "0.0.1" 4 | 5 | setup( 6 | name="pygho", 7 | version=version, 8 | description="PygHO is a library for high-order GNNs", 9 | download_url="https://github.com/GraphPKU/PygHO", 10 | author="GraphPKU", 11 | python_requires=">=3.10", 12 | packages=find_packages(include=["pygho", "pygho.*"]), 13 | install_requires=[ 14 | "torch>=2.0", 15 | "torch_geometric>=2.3", 16 | "tqdm", 17 | "pqdm" 18 | ], 19 | ) 
# Tolerance used by floattensorequal.
EPS = 1e-5


def maxdiff(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the largest absolute element-wise difference between a and b."""
    gap = torch.abs(a - b)
    return torch.amax(gap).item()


def tensorequal(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Return True when every element of a equals the matching element of b."""
    same = a == b
    return torch.all(same).item()


def floattensorequal(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Return True when a and b differ by less than EPS everywhere."""
    worst = maxdiff(a, b)
    return worst < EPS
class SpmammTest(unittest.TestCase):
    """Compare Spmamm.spmamm and Spmamm.maspmm with a dense einsum reference."""

    def setUp(self) -> None:
        b, n, m, l, d = 5, 3, 7, 13, 11

        def build(shape):
            # Random dense values, a random mask over the first three dims,
            # and the masked / sparse views built from the same entries.
            dense = torch.rand(shape)
            keep = torch.rand_like(dense[:, :, :, 0]) > 0.9
            masked = MaskedTensor(dense, keep)
            idx = keep.to_sparse_coo().indices()
            sparse = SparseTensor(idx,
                                  dense[idx[0], idx[1], idx[2]],
                                  shape=masked.shape)
            return masked, sparse

        # A is built before B, as the random draws happen in that order.
        self.MA, self.SA = build((b, n, m, d))
        self.MB, self.SB = build((b, m, l, d))
        self.mask = torch.ones((b, n, l), dtype=torch.bool)
        return super().setUp()

    def test_spmamm(self):
        expected = torch.einsum("bnmd,bmld->bnld", self.MA.data, self.MB.data)
        actual = Spmamm.spmamm(self.SA, self.MB, self.mask).data
        self.assertTrue(floattensorequal(actual, expected), "spmamm error")

    def test_maspmm(self):
        expected = torch.einsum("bnmd,bmld->bnld", self.MA.data, self.MB.data)
        actual = Spmamm.maspmm(self.MA, self.SB, self.mask).data
        self.assertTrue(floattensorequal(actual, expected), "maspmm error")
# --------------------------------------------------------------------------------
# /tests/test_backend_sparse.py:
# --------------------------------------------------------------------------------
"""Unit tests for the sparse backend (SpTensor, Spspmm, Spmm)."""
import unittest
from pygho import SparseTensor
import pygho.backend.SpTensor as SpTensor
import pygho.backend.Spspmm as Spspmm
import pygho.backend.Spmm as Spmm
import torch
import numpy as np

# Default absolute tolerance for floating-point comparisons.
EPS = 5e-5
# Tighter tolerance used by the 3-D product and 3-D hadamard checks.
TIGHT_EPS = 1e-5
# Tests that used to demand CUDA run on the GPU when one is present and fall
# back to the CPU otherwise, so they no longer error on CPU-only hosts.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def maxdiff(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the largest absolute element-wise gap between ``a`` and ``b``."""
    return torch.max((a - b).abs()).item()


def tensorequal(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Return whether ``a`` and ``b`` have the same shape and equal elements.

    Checking the shape first turns a mismatch into a plain ``False`` (and so
    a readable assertion failure) instead of a broadcast or a RuntimeError.
    """
    if a.shape != b.shape:
        return False
    return bool(torch.all(a == b).item())


def lexsort(keys: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Return the columns of ``keys`` in ascending lexicographic order.

    ``np.lexsort`` treats the LAST key as the primary one, so the rows are
    flipped first to make row 0 the most significant key.
    """
    flipped = torch.flip(keys, dims=(0, ))
    order = np.lexsort(flipped.detach().cpu().numpy(), axis=dim)
    return keys[:, torch.from_numpy(order).to(keys.device)]


def _random_sparse(shape, zero_above: float, device=None) -> torch.Tensor:
    """Return a random COO tensor; entries whose coin exceeds ``zero_above``
    are zeroed before conversion."""
    dense = torch.rand(shape, device=device)
    dense[torch.rand_like(dense) > zero_above] = 0
    return dense.to_sparse_coo()


def _scatter_products(out_ind: torch.Tensor, triples, val1: torch.Tensor,
                      val2: torch.Tensor) -> torch.Tensor:
    """Multiply the value pairs selected by ``triples[1]``/``triples[2]`` and
    add them into the output slots given by ``triples[0]``.

    Returns a coalesced COO tensor indexed by ``out_ind``.
    """
    # Imported lazily so the module still imports without torch_scatter.
    from torch_scatter import scatter_add
    products = val1[triples[1]] * val2[triples[2]]
    summed = scatter_add(products, triples[0], dim_size=out_ind.shape[1])
    return torch.sparse_coo_tensor(out_ind, summed).coalesce()


class SpTensorTest(unittest.TestCase):
    """Index hashing and construction of ``SparseTensor``."""

    def test_hash_tight(self):
        """``indicehash_tight`` is the row-major linear index and is
        invertible with ``decodehash_tight``."""
        sshape, nnz = (2, 3, 7, 11, 13), 17
        indices = torch.stack(
            tuple(torch.randint(size, (nnz, )) for size in sshape))
        tsshape = torch.LongTensor(sshape)
        thash = SpTensor.indicehash_tight(indices, tsshape)

        # Horner evaluation of ((((i0)*s1 + i1)*s2 + i2)*s3 + i3)*s4 + i4.
        expected = torch.zeros_like(indices[0])
        for size, row in zip(sshape, indices):
            expected = expected * size + row
        self.assertTrue(tensorequal(thash, expected), "hash_tight wrong")

        decoded = SpTensor.decodehash_tight(thash, tsshape)
        self.assertTrue(tensorequal(decoded, indices),
                        "hash_tight decode wrong")

    def test_hash(self):
        """``indicehash`` keeps lexicographic order and is invertible with
        ``decodehash``."""
        sshape, nnz = (2, 3, 7, 11, 13), 17
        indices = torch.stack(
            tuple(torch.randint(size, (nnz, )) for size in sshape))
        indices = lexsort(indices, dim=-1)
        thash = SpTensor.indicehash(indices)
        self.assertTrue(bool(torch.all(torch.diff(thash) >= 0)),
                        "hash not keep order")
        decoded = SpTensor.decodehash(thash, sparse_dim=len(sshape))
        self.assertTrue(tensorequal(decoded, indices), "hash decode wrong")

    def test_create(self):
        """Construction and torch COO round-trips agree with
        ``torch.sparse_coo_tensor(...).coalesce()``."""
        n, m, l, nnz, d = 2, 3, 5, 23, 7
        indices = torch.stack(
            (torch.randint(0, n, (nnz, )), torch.randint(0, m, (nnz, )),
             torch.randint(0, l, (nnz, ))))
        values = torch.randn((nnz, d))

        A1 = torch.sparse_coo_tensor(indices, values, size=(n, m, l, d))
        A2 = SpTensor.SparseTensor(indices, values, (n, m, l, d), False)
        A2f = SpTensor.SparseTensor.from_torch_sparse_coo(A1)
        A1c = A1.coalesce()

        self.assertTrue(tensorequal(A2.indices, A1c.indices()),
                        "create indice not match")
        self.assertLessEqual(maxdiff(A2.values, A1c.values()), EPS,
                             "create value not match")

        self.assertTrue(tensorequal(A2.indices, A2f.indices),
                        "from coo indice not match")
        self.assertLessEqual(maxdiff(A2.values, A2f.values), EPS,
                             "from coo value not match")

        A1f = A2.to_torch_sparse_coo()
        self.assertLessEqual(maxdiff(A1f.to_dense(), A1c.to_dense()), EPS,
                             "to coo not match")


class SpspmmTest(unittest.TestCase):
    """Sparse-sparse products and hadamard products."""

    def _check_hadamard(self, Ap: torch.Tensor, Bp: torch.Tensor, tol: float,
                        msg: str) -> None:
        """Assert ``spsphadamard`` matches ``torch.multiply`` within ``tol``."""
        ours = Spspmm.spsphadamard(SparseTensor.from_torch_sparse_coo(Ap),
                                   SparseTensor.from_torch_sparse_coo(Bp))
        gap = (torch.multiply(Ap, Bp) -
               ours.to_torch_sparse_coo()).coalesce().values().abs()
        self.assertLessEqual(torch.max(gap).item(), tol, msg)

    def test_ptr2batch(self):
        """A CSR-style pointer expands to one batch id per element."""
        ptr = torch.tensor([0, 4, 4, 7, 8, 11, 11, 11, 16], dtype=torch.long)
        batch = torch.tensor([0, 0, 0, 0, 2, 2, 2, 3, 4, 4, 4, 7, 7, 7, 7, 7],
                             dtype=torch.long)
        self.assertTrue(tensorequal(Spspmm.ptr2batch(ptr, dim_size=16), batch),
                        "ptr2batch error")

    def test_2dmm(self):
        """2-D sparse @ sparse matches torch, with and without a target
        index filter."""
        n, m, l = 300, 200, 400
        A = _random_sparse((n, m), 0.9, device=DEVICE)
        B = _random_sparse((m, l), 0.9, device=DEVICE)
        ind1, val1 = A.indices(), A.values()
        ind2, val2 = B.indices(), B.values()

        C = (A @ B).coalesce()

        ind, bcd = Spspmm.spspmm_ind(ind1, 1, ind2, 0)
        out = _scatter_products(ind, bcd, val1, val2)
        self.assertTrue(tensorequal(C.indices(), out.indices()),
                        "spspmm indice not match")
        self.assertLessEqual(maxdiff(C.values(), out.values()), EPS,
                             "spspmm value not match")

        # Random, de-duplicated and sorted target positions.
        tar_ind = torch.stack(
            (torch.randint_like(ind1[0], n), torch.randint_like(ind1[0], l)))
        tar_ind = SpTensor.decodehash(
            torch.unique(SpTensor.indicehash(tar_ind), sorted=True), 2)
        acd = Spspmm.filterind(tar_ind, ind, bcd)
        maskedout = _scatter_products(tar_ind, acd, val1, val2)
        # spspmm restricted to the target positions
        self.assertLessEqual(
            maxdiff(maskedout.to_dense()[tar_ind[0], tar_ind[1]],
                    C.to_dense()[tar_ind[0], tar_ind[1]]), EPS,
            "spspmm with target ind value not match")

    def test_2dhadamard(self):
        """2-D sparse hadamard product matches ``torch.multiply``."""
        n, m = 300, 200
        Ap = _random_sparse((n, m), 0.9)
        Bp = _random_sparse((n, m), 0.9)
        self._check_hadamard(Ap, Bp, EPS, "hadamard error")

    def test_3dmm(self):
        """3-D operands contracted over dim 1 match a dense ``einsum``, with
        and without a target index filter."""
        n, m, l, k = 13, 5, 7, 11
        A = _random_sparse((n, k, m), 0.5)
        B = _random_sparse((l, k, n), 0.5)
        ind1, val1 = A.indices(), A.values()
        ind2, val2 = B.indices(), B.values()

        C = torch.einsum("nkm,lkd->nmld", A.to_dense(), B.to_dense())
        Cs = C.to_sparse_coo().coalesce()

        ind, bcd = Spspmm.spspmm_ind(ind1, 1, ind2, 1)
        out = _scatter_products(ind, bcd, val1, val2)

        self.assertTrue(tensorequal(Cs.indices(), out.indices()),
                        "spspmm indice not match")
        self.assertLessEqual(maxdiff(Cs.values(), out.values()), TIGHT_EPS,
                             "spspmm value not match")

        # Random, de-duplicated and sorted target positions; the last output
        # dim has size n because B has shape (l, k, n).
        tar_ind = torch.stack(
            (torch.randint_like(ind1[0], n), torch.randint_like(ind1[0], m),
             torch.randint_like(ind1[0], l), torch.randint_like(ind1[0], n)))
        tar_ind = SpTensor.decodehash(
            torch.unique(SpTensor.indicehash(tar_ind), sorted=True), 4)
        acd = Spspmm.filterind(tar_ind, ind, bcd)
        maskedout = _scatter_products(tar_ind, acd, val1, val2)
        # spspmm restricted to the target positions
        self.assertLessEqual(
            maxdiff(
                maskedout.to_dense()[tar_ind[0], tar_ind[1], tar_ind[2],
                                     tar_ind[3]],
                C[tar_ind[0], tar_ind[1], tar_ind[2], tar_ind[3]]), TIGHT_EPS,
            "spspmm with target indice value not match")

    def test_3dhadamard(self):
        """3-D sparse hadamard product matches ``torch.multiply``."""
        n, m, l = 3, 5, 11
        Ap = _random_sparse((n, m, l), 0.9)
        Bp = _random_sparse((n, m, l), 0.9)
        self._check_hadamard(Ap, Bp, TIGHT_EPS, "3d hadamard value error")


class SpmmTest(unittest.TestCase):
    """Sparse-dense and dense-sparse matrix products."""

    def test_spmm(self):
        """``Spmm.spmm`` (sparse @ dense) matches torch's own product."""
        n, m, l = 300, 200, 400
        A = _random_sparse((n, m), 0.9, device=DEVICE)
        X = torch.randn((m, l), device=DEVICE)
        Y1 = Spmm.spmm(
            SparseTensor(A.indices(),
                         A.values().unsqueeze(-1), A.shape + (1, )), X)
        Y2 = A @ X
        self.assertLessEqual(maxdiff(Y1, Y2), EPS, "spmm error")

    def test_mspmm(self):
        """``Spmm.mspmm`` (dense @ sparse) matches the dense product."""
        n, m, l = 300, 200, 400
        X = torch.randn((n, m))
        A = _random_sparse((m, l), 0.9)
        Y1 = Spmm.mspmm(
            X,
            SparseTensor(A.indices(),
                         A.values().unsqueeze(-1), A.shape + (1, )))
        Y2 = X @ A.to_dense()
        self.assertLessEqual(maxdiff(Y1, Y2), EPS, "mspmm error")
# --------------------------------------------------------------------------------