├── .gitignore ├── LICENSE ├── README.md ├── img ├── classification.png ├── clustering.png └── sharp.png ├── pytorch ├── AsymCheegerCutPool.py ├── GTVConv.py ├── classification.py ├── clustering.py └── pytorch_environment.yml ├── tensorflow ├── AsymCheegerCutPool.py ├── GTVConv.py ├── classification.py ├── clustering.py └── tf_environment.yml ├── tvgnn_poster.pdf └── utils └── metrics.py /.gitignore: -------------------------------------------------------------------------------- 1 | /pytorch/data/* 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | # MACOS stuff 134 | .DS_store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Filippo Bianchi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![ICML](https://img.shields.io/badge/ICML-2023-blue)](https://icml.cc/virtual/2023/poster/24747) 2 | [![arXiv](https://img.shields.io/badge/arXiv-2211.06218-b31b1b.svg)](https://arxiv.org/abs/2211.06218) 3 | [![Poster](https://img.shields.io/badge/%E2%87%A9-Poster-%23228B22.svg)](https://github.com/FilippoMB/Total-variation-graph-neural-networks/blob/main/tvgnn_poster.pdf) 4 | [![Video](https://img.shields.io/badge/Presentation-%23FF0000.svg?logo=YouTube&logoColor=white)](https://youtu.be/Dyb1YJOez8w) 5 | 6 | Tensorflow and Pytorch implementation of the Total Variation Graph Neural Network (TVGNN) as presented in the [original paper](https://arxiv.org/abs/2211.06218). 7 | 8 | The TVGNN model can be used to **cluster** the vertices of an annotated graph, by accounting both for the graph topology and the vertex features. Compared to other GNNs for clustering, TVGNN creates *sharp* cluster assignments that better approximate the optimal (in the minimum cut sense) partition. 9 | 10 | smooth and sharp clustering assignments 11 | 12 | The TVGNN model can also be used to implement [graph pooling](https://gnn-pooling.notion.site/) in a deep GNN architecture for tasks such as graph classification. 13 | 14 | # Downstream tasks 15 | TVGNN can be used to perform vertex clustering and graph classification. Other tasks such as graph regression can also be done with the TVGNN model. 
16 | 17 | ### Vertex clustering 18 | This is an unsupervised task, where the goal is to generate a partition of the vertices based on the similarity of their vertex features and the graph topology. The GNN model is trained only by minimizing the unsupervised loss $\mathcal{L}$. 19 | 20 | clustering architecture 21 | 22 | ### Graph classification 23 | This is a supervised task with the goal of predicting the class of each graph. The GNN architecture for graph classification alternates GTVConv layers with graph pooling layers, which gradually distill the global label information from the vertex representations. The GNN is trained by minimizing the unsupervised loss $\mathcal{L}$ for each pooling layer and a supervised cross-entropy loss $\mathcal{L}_\text{cross-entr}$ between the true and predicted class labels. 24 | 25 | classification architecture 26 | 27 | # 💻 Implementation 28 | 29 | Tensorflow icon 30 | 31 | ### Tensorflow 32 | This implementation is based on the [Spektral](https://graphneural.network/) library and follows the [Select-Reduce-Connect](https://graphneural.network/layers/pooling/#srcpool) API. 33 | To execute the code, first install the conda environment from [tf_environment.yml](tensorflow/tf_environment.yml) 34 | 35 | conda env create -f tf_environment.yml 36 | 37 | The ``tensorflow/`` folder includes: 38 | 39 | - The implementation of the [GTVConv](/tensorflow/GTVConv.py) layer 40 | - The implementation of the [AsymCheegerCutPool](/tensorflow/AsymCheegerCutPool.py) layer 41 | - An example script to perform the [clustering](/tensorflow/clustering.py) task 42 | - An example script to perform the [classification](/tensorflow/classification.py) task 43 | 44 | Pytorch icon 45 | 46 | ### Pytorch 47 | This implementation is based on the [Pytorch Geometric](https://pytorch-geometric.readthedocs.io/) library. To execute the code, install the conda environment from [pytorch_environment.yml](pytorch/pytorch_environment.yml) 48 | 49 | conda env create -f pytorch_environment.yml 50 | 51 | The ``pytorch/`` folder includes: 52 | 53 | - The implementation of the [GTVConv](/pytorch/GTVConv.py) layer 54 | - The implementation of the [AsymCheegerCutPool](/pytorch/AsymCheegerCutPool.py) layer 55 | - An example script to perform the [clustering](/pytorch/clustering.py) task 56 | - An example script to perform the [classification](/pytorch/classification.py) task 57 | 58 | Tensorflow icon 59 | 60 | ### Spektral 61 | 62 | TVGNN is available on Spektral: 63 | 64 | - [GTVConv](https://graphneural.network/layers/convolution/#gtvconv) layer, 65 | - [AsymCheegerCutPool](https://graphneural.network/layers/pooling/#asymcheegercutpool) layer, 66 | - [Example script](https://github.com/danielegrattarola/spektral/blob/master/examples/other/node_clustering_tvgnn.py) to perform node clustering with TVGNN. 67 | 68 | # 📚 Citation 69 | If you use TVGNN in your research, please consider citing our work as 70 | 71 | ````bibtex 72 | @inproceedings{hansen2023total, 73 | title={Total variation graph neural networks}, 74 | author={Hansen, Jonas Berg and Bianchi, Filippo Maria}, 75 | booktitle={International Conference on Machine Learning}, 76 | pages={12445--12468}, 77 | year={2023}, 78 | organization={PMLR} 79 | } 80 | ```` 81 | 82 | # ⚙️ Technical details 83 | TVGNN consists of the GTVConv layer and the AsymCheegerCut pooling layer. 84 | 85 | ### GTVConv 86 | The GTVConv layer is a *message-passing* layer that minimizes the $L_1$-norm of the difference between features of adjacent nodes.
The $l$-th GTVConv layer updates the node features as 87 | 88 | $$\mathbf{X}^{(l+1)} = \sigma\left[ \left( \mathbf{I} - 2\delta \mathbf{L}_\Gamma^{(l)} \right) \mathbf{X}^{(l)}\mathbf{\Theta} \right] $$ 89 | 90 | where $\sigma$ is a nonlinearity, $\mathbf{\Theta}$ are the trainable weights of the layer, and $\delta$ is a hyperparameter. $\mathbf{L}^{(l)}_ \Gamma$ is a Laplacian defined as $\mathbf{L}^{(l)}_ \Gamma$ = $\mathbf{D}^{(l)}_ \Gamma - \mathbf{\Gamma}^{(l)}$, where $\mathbf{D}_\Gamma = \text{diag}(\mathbf{\Gamma} \boldsymbol{1})$ and 91 | 92 | $$ [\mathbf{\Gamma}]^{(l)}_ {ij} = \frac{a_ {ij}}{\texttt{max}\{ \lVert \boldsymbol{x}_i^{(l)} - \boldsymbol{x}_j^{(l)} \rVert_1, \epsilon \}}$$ 93 | 94 | where $a_{ij}$ is the $ij$-th entry of the adjacency matrix, $\boldsymbol{x}_i^{(l)}$ is the feature of vertex $i$ at layer $l$, and $\epsilon$ is a small constant that avoids division by zero. 95 | 96 | ### AsymCheegerCut 97 | The AsymCheegerCut is a *graph pooling* layer that internally contains an $\texttt{MLP}$ parametrized by $\mathbf{\Theta}_\text{MLP}$ and that computes: 98 | - a cluster assignment matrix $\mathbf{S} = \texttt{Softmax}(\texttt{MLP}(\mathbf{X}; \mathbf{\Theta}_\text{MLP})) \in \mathbb{R}^{N\times K}$, which maps the $N$ vertices into $K$ clusters, 99 | - an unsupervised loss $\mathcal{L} = \alpha_1 \mathcal{L}_ \text{GTV} + \alpha_2 \mathcal{L}_ \text{AN}$, where $\alpha_1$ and $\alpha_2$ are two hyperparameters, 100 | - the adjacency matrix and the vertex features of a coarsened graph 101 | 102 | $$\mathbf{A}^\text{pool} = \mathbf{S}^T \tilde{\mathbf{A}} \mathbf{S} \in\mathbb{R}^{K\times K}; \\ \mathbf{X}^\text{pool}=\mathbf{S}^T\mathbf{X} \in\mathbb{R}^{K\times F}. 103 | $$ 104 | 105 | The term $\mathcal{L}_ \text{GTV}$ in the loss minimizes the graph total variation of the cluster assignments $\mathbf{S}$ and is defined as 106 | 107 | $$\mathcal{L}_ \text{GTV} = \frac{\mathcal{L}_ \text{GTV}^*}{2E} \in [0, 1],$$ 108 | 109 | where $\mathcal{L}_ \text{GTV}^*$ = $\displaystyle\sum_{k=1}^K\sum_{i=1}^N \sum_{j=i}^N a_{i,j} |s_{i,k} - s_{j,k}|$, $s_{i,k}$ is the assignment of vertex $i$ to cluster $k$, and $E$ is the number of edges. 110 | 111 | The term $\mathcal{L}_\text{AN}$ encourages the partition to be balanced and is defined as 112 | 113 | $$\mathcal{L}_ {\text{AN}} = \frac{\beta - \mathcal{L}^*_ \text{AN}}{\beta} \in [0, 1],$$ 114 | 115 | where $\mathcal{L}_ \text{AN}^* = \displaystyle\sum^K_{k=1} ||\boldsymbol{s}_ {:,k}$ - $\text{quant}_ \rho (\boldsymbol{s}_ {:,k})||_ {1, \rho}$. 116 | When $\rho = K-1$, then $\beta = N\rho$. 117 | When $\rho$ takes different values, then $\beta = N\rho\min(1, K/(\rho+1))$. 118 | $\text{quant}_ \rho(\boldsymbol{s}_ k)$ denotes the $\rho$-quantile of $\boldsymbol{s}_ k$ and $||\cdot||_ {1,\rho}$ denotes an asymmetric $\ell_1$ norm, which for a vector $\boldsymbol{x} \in \mathbb{R}^{N\times 1}$ is $||\boldsymbol{x}||_ {1,\rho}$ = $\displaystyle\sum^N_{i=1} |x_{i}|_ \rho$, where $|x_i|_ \rho = \rho x_i$ if $x_i\geq 0$ and $|x_i|_ \rho = -x_i$ if $x_i < 0$.
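
To make these formulas concrete, here is a minimal NumPy sketch (illustrative only; the names `gtv_conv`, `tv_loss`, and `balance_loss` are ours, not part of this repository) that mirrors the dense-mode computations of the layers above, with $\rho = K-1$. Note that the bundled layer implementations apply $\mathbf{I} - \delta \mathbf{L}_\Gamma$ with $\delta$ = `delta_coeff`, i.e., the factor 2 is absorbed into the hyperparameter.

```python
import numpy as np

def gtv_conv(X, A, Theta, delta=1.0, eps=1e-3):
    """One GTVConv update: X' = ReLU[(I - 2*delta*L_Gamma) X Theta]."""
    XW = X @ Theta                                   # linear transform first, as in the layers
    dist = np.abs(XW[:, None, :] - XW[None, :, :]).sum(-1)
    Gamma = A / np.maximum(dist, eps)                # Gamma_ij = a_ij / max(||x_i - x_j||_1, eps)
    L = np.diag(Gamma.sum(-1)) - Gamma               # L_Gamma = D_Gamma - Gamma
    return np.maximum((np.eye(A.shape[0]) - 2 * delta * L) @ XW, 0.0)

def tv_loss(A, S):
    """L_GTV = L_GTV* / 2E; the full double sum below counts every edge twice."""
    l1 = np.abs(S[:, None, :] - S[None, :, :]).sum(-1)
    return (A * l1).sum() / (2 * np.count_nonzero(A))  # count_nonzero(A) = 2E (undirected)

def balance_loss(S):
    """L_AN = (beta - L_AN*) / beta with rho = K - 1 and beta = N * rho."""
    N, K = S.shape
    rho = K - 1
    quant = np.sort(S, axis=0)[::-1][int(np.floor(N / K))]  # rho-quantile of each column
    dev = S - quant
    asym = np.where(dev >= 0, rho * dev, -dev).sum()        # asymmetric l1-norm
    return (N * rho - asym) / (N * rho)

# Tiny usage example: random graph with N = 6 vertices, K = 2 clusters
rng = np.random.default_rng(0)
A = np.triu(rng.integers(0, 2, (6, 6)), 1); A = A + A.T
X, Theta = rng.standard_normal((6, 4)), rng.standard_normal((4, 3))
S = np.exp(rng.standard_normal((6, 2))); S /= S.sum(1, keepdims=True)
print(gtv_conv(X, A, Theta).shape, tv_loss(A, S), balance_loss(S))
```

Both losses land in $[0, 1]$, so their relative weighting is controlled entirely by $\alpha_1$ and $\alpha_2$ (the `totvar_coeff` and `balance_coeff` arguments of the pooling layers).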
119 | -------------------------------------------------------------------------------- /img/classification.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FilippoMB/Total-variation-graph-neural-networks/c21b427460fda14000a820a541e9709a909b3005/img/classification.png -------------------------------------------------------------------------------- /img/clustering.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FilippoMB/Total-variation-graph-neural-networks/c21b427460fda14000a820a541e9709a909b3005/img/clustering.png -------------------------------------------------------------------------------- /img/sharp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FilippoMB/Total-variation-graph-neural-networks/c21b427460fda14000a820a541e9709a909b3005/img/sharp.png -------------------------------------------------------------------------------- /pytorch/AsymCheegerCutPool.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | import math 3 | import torch 4 | from torch import Tensor 5 | from torch_geometric.nn.models.mlp import Linear 6 | from torch_geometric.nn.resolver import activation_resolver 7 | 8 | 9 | class AsymCheegerCutPool(torch.nn.Module): 10 | r""" 11 | The asymmetric Cheeger cut pooling layer from the `"Total Variation Graph Neural Networks" 12 | <https://arxiv.org/abs/2211.06218>`_ paper. 13 | 14 | Args: 15 | k (int): 16 | Number of clusters or output nodes. 17 | mlp_channels (int, list of int): 18 | Number of hidden units for each hidden layer in the MLP used to 19 | compute cluster assignments. First integer must match the number 20 | of input channels. 21 | mlp_activation (any): 22 | Activation function between hidden layers of the MLP. 23 | Must be compatible with `torch_geometric.nn.resolver`. 24 | return_selection (bool): 25 | Whether to return the selection matrix. Cannot be False 26 | if `return_pooled_graph` is False. (default: :obj:`False`) 27 | return_pooled_graph (bool): 28 | Whether to return pooled node features and adjacency. 29 | Cannot be False if `return_selection` is False. (default: :obj:`True`) 30 | bias (bool): 31 | Whether to add a bias term to the MLP layers. (default: :obj:`True`) 32 | totvar_coeff (float): 33 | Coefficient for graph total variation loss component. (default: :obj:`1.0`) 34 | balance_coeff (float): 35 | Coefficient for asymmetric norm loss component.
(default: :obj:`1.0`) 36 | """ 37 | 38 | def __init__(self, 39 | k: int, 40 | mlp_channels: Union[int, List[int]], 41 | mlp_activation="relu", 42 | return_selection: bool = False, 43 | return_pooled_graph: bool = True, 44 | bias: bool = True, 45 | totvar_coeff: float = 1.0, 46 | balance_coeff: float = 1.0, 47 | ): 48 | super().__init__() 49 | 50 | if not return_selection and not return_pooled_graph: 51 | raise ValueError("return_selection and return_pooled_graph cannot both be False") 52 | 53 | if isinstance(mlp_channels, int): 54 | mlp_channels = [mlp_channels] 55 | 56 | act = activation_resolver(mlp_activation) 57 | in_channels = mlp_channels[0] 58 | self.mlp = torch.nn.Sequential() 59 | for channels in mlp_channels[1:]: 60 | self.mlp.append(Linear(in_channels, channels, bias=bias)) 61 | in_channels = channels 62 | self.mlp.append(act) 63 | 64 | 65 | self.mlp.append(Linear(in_channels, k)) 66 | self.k = k 67 | self.return_selection = return_selection 68 | self.return_pooled_graph = return_pooled_graph 69 | self.totvar_coeff = totvar_coeff 70 | self.balance_coeff = balance_coeff 71 | 72 | self.reset_parameters() 73 | 74 | def reset_parameters(self): 75 | for layer in self.mlp: 76 | if isinstance(layer, Linear): 77 | torch.nn.init.xavier_uniform_(layer.weight) 78 | if layer.bias is not None: torch.nn.init.zeros_(layer.bias) 79 | 80 | def forward( 81 | self, 82 | x: Tensor, 83 | adj: Tensor, 84 | mask: Optional[Tensor] = None, 85 | ) -> Tuple[Tensor, ...]: 86 | r""" 87 | Args: 88 | x (Tensor): 89 | Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}` 90 | with batch-size :math:`B`, (maximum) number of nodes :math:`N` for each graph, 91 | and feature dimension :math:`F`. Note that the cluster assignment matrix 92 | :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}` is 93 | being created within this method. 94 | adj (Tensor): 95 | Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`. 96 | mask (BoolTensor, optional): 97 | Mask matrix :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` 98 | indicating the valid nodes for each graph.
(default: :obj:`None`) 99 | 100 | :rtype: (:class:`Tensor`, :class:`Tensor`, :class:`Tensor`, 101 | :class:`Tensor`, :class:`Tensor`, :class:`Tensor`) 102 | """ 103 | x = x.unsqueeze(0) if x.dim() == 2 else x 104 | adj = adj.unsqueeze(0) if adj.dim() == 2 else adj 105 | 106 | s = self.mlp(x) 107 | s = torch.softmax(s, dim=-1) 108 | 109 | batch_size, n_nodes, _ = x.size() 110 | 111 | if mask is not None: 112 | mask = mask.view(batch_size, n_nodes, 1).to(x.dtype) 113 | x, s = x * mask, s * mask 114 | 115 | # Pooled features and adjacency 116 | if self.return_pooled_graph: 117 | x_pool = torch.matmul(s.transpose(1, 2), x) 118 | adj_pool = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s) 119 | 120 | # Total variation loss 121 | tv_loss = self.totvar_coeff*torch.mean(self.totvar_loss(adj, s)) 122 | 123 | # Balance loss 124 | bal_loss = self.balance_coeff*torch.mean(self.balance_loss(s)) 125 | 126 | if self.return_selection and self.return_pooled_graph: 127 | return s, x_pool, adj_pool, tv_loss, bal_loss 128 | elif self.return_selection and not self.return_pooled_graph: 129 | return s, tv_loss, bal_loss 130 | else: 131 | return x_pool, adj_pool, tv_loss, bal_loss 132 | 133 | def totvar_loss(self, adj, s): 134 | l1_norm = torch.sum(torch.abs(s[..., None, :] - s[:, None, ...]), dim=-1) 135 | 136 | loss = torch.sum(adj * l1_norm, dim=(-1, -2)) 137 | 138 | # Normalize loss 139 | n_edges = torch.count_nonzero(adj, dim=(-1, -2)) 140 | loss *= 1 / (2 * n_edges) 141 | 142 | return loss 143 | 144 | def balance_loss(self, s): 145 | n_nodes = s.size()[-2] 146 | 147 | # k-quantile 148 | idx = int(math.floor(n_nodes / self.k)) 149 | quant = torch.sort(s, dim=-2, descending=True)[0][:, idx, :] # shape [B, K] 150 | 151 | # Asymmetric l1-norm 152 | loss = s - torch.unsqueeze(quant, dim=1) 153 | loss = (loss >= 0) * (self.k - 1) * loss + (loss < 0) * loss * -1 154 | loss = torch.sum(loss, dim=(-1, -2)) # shape [B] 155 | loss = 1 / (n_nodes * (self.k - 1)) * (n_nodes * (self.k - 1) - loss) 156 | 157 | return loss 158 | -------------------------------------------------------------------------------- /pytorch/GTVConv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Linear, Parameter 3 | from torch_geometric.nn import MessagePassing 4 | from torch import Tensor 5 | from torch_sparse import SparseTensor 6 | from torch_geometric.typing import Adj, OptTensor 7 | from torch_geometric.nn.inits import zeros 8 | from torch_geometric import utils 9 | from torch_scatter import scatter_add 10 | from torch_geometric.nn.resolver import activation_resolver 11 | 12 | def gtv_adj_weights(edge_index, edge_weight, num_nodes=None, flow="source_to_target", coeff=1.): 13 | 14 | fill_value = 0. 
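# NOTE: self-loops are added with weight 0 so that every node gets a diagonal entry; the node degrees (and later the identity term) are accumulated onto these diagonal slots below.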
15 | 16 | assert flow in ["source_to_target", "target_to_source"] 17 | 18 | edge_index, tmp_edge_weight = utils.add_remaining_self_loops( 19 | edge_index, edge_weight, fill_value, num_nodes) 20 | assert tmp_edge_weight is not None 21 | edge_weight = tmp_edge_weight 22 | 23 | # Compute degrees 24 | row, col = edge_index[0], edge_index[1] 25 | idx = col if flow == "source_to_target" else row 26 | deg = scatter_add(edge_weight, idx, dim=0, dim_size=num_nodes) 27 | 28 | # Compute laplacian: L = D - A = -A + D 29 | edge_weight = -edge_weight 30 | edge_weight[row == col] += deg 31 | 32 | # Compute adjusted laplacian: L_adjusted = I - delta*L = -delta*L + I 33 | edge_weight *= -coeff 34 | edge_weight[row == col] += 1 35 | 36 | return edge_index, edge_weight 37 | 38 | 39 | class GTVConv(MessagePassing): 40 | r""" 41 | The GTVConv layer from the `"Total Variation Graph Neural Networks" 42 | `_ paper 43 | 44 | Args: 45 | in_channels (int): 46 | Size of each input sample 47 | out_channels (int): 48 | Size of each output sample. 49 | bias (bool): 50 | If set to :obj:`False`, the layer will not learn 51 | an additive bias. (default: :obj:`True`) 52 | delta_coeff (float): 53 | Step size for gradient descent of GTV (default: :obj:`1.0`) 54 | eps (float): 55 | Small number used to numerically stabilize the computation of 56 | new adjacency weights. (default: :obj:`1e-3`) 57 | act (any): 58 | Activation function. Must be compatible with 59 | `torch_geometric.nn.resolver`. 60 | """ 61 | def __init__(self, in_channels: int, out_channels: int, bias: bool = True, 62 | delta_coeff: float = 1., eps: float = 1e-3, act = "relu"): 63 | super().__init__(aggr='add', flow="target_to_source") 64 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 65 | 66 | self.delta_coeff = delta_coeff 67 | self.eps = eps 68 | 69 | self.act = activation_resolver(act) 70 | 71 | if bias: 72 | self.bias = Parameter(torch.Tensor(out_channels)) 73 | else: 74 | self.register_parameter('bias', None) 75 | 76 | self.reset_parameters() 77 | 78 | def reset_parameters(self): 79 | torch.nn.init.kaiming_normal_(self.weight) 80 | zeros(self.bias) 81 | 82 | def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None, mask=None) -> Tensor: 83 | 84 | # Update node features 85 | x = x @ self.weight 86 | 87 | # Check if a dense adjacency is provided 88 | if isinstance(edge_index, Tensor) and edge_index.size(-1) == edge_index.size(-2): 89 | x = x.unsqueeze(0) if x.dim() == 2 else x 90 | edge_index = edge_index.unsqueeze(0) if edge_index.dim() == 2 else edge_index 91 | B, N, _ = edge_index.size() 92 | 93 | # Absolute differences between neighbouring nodes 94 | batch_idx, node_i, node_j = torch.nonzero(edge_index, as_tuple=True) 95 | abs_diff = torch.sum(torch.abs(x[batch_idx, node_i, :] - x[batch_idx, node_j, :]), dim=-1) # shape [B, E] 96 | 97 | # Gamma matrix 98 | mod_adj = torch.clone(edge_index) 99 | mod_adj[batch_idx, node_i, node_j] /= torch.clamp(abs_diff, min=self.eps) 100 | 101 | # Compute Laplacian L=D-A 102 | deg = torch.sum(mod_adj, dim=-1) 103 | mod_adj = -mod_adj 104 | mod_adj[:, range(N), range(N)] += deg 105 | 106 | # Compute modified laplacian: L_adjusted = I - delta*L 107 | mod_adj = -self.delta_coeff * mod_adj 108 | mod_adj[:, range(N), range(N)] += 1 109 | 110 | out = torch.matmul(mod_adj, x) 111 | 112 | if self.bias is not None: 113 | out = out + self.bias 114 | 115 | if mask is not None: 116 | out = out * mask.view(B, N, 1).to(x.dtype) 117 | 118 | else: 119 | if isinstance(edge_index, SparseTensor): 120 
| row, col, edge_weight = edge_index.coo() 121 | edge_index = torch.stack((row, col), dim=0) 122 | else: 123 | row, col = edge_index 124 | 125 | # Absolute differences between neighbouring nodes 126 | abs_diff = torch.abs(x[row, :] - x[col, :]) # shape [E, in_channels] 127 | abs_diff = abs_diff.sum(dim=1) # shape [E, ] 128 | 129 | # Gamma matrix 130 | denom = torch.clamp(abs_diff, min=self.eps) 131 | if edge_weight is None: 132 | gamma_vals = 1 / denom # shape [E] 133 | else: 134 | gamma_vals = edge_weight / denom # shape [E] 135 | 136 | # Laplacian L=D-A 137 | lap_index, lap_weight = utils.get_laplacian(edge_index, gamma_vals) 138 | 139 | # Modified laplacian: I-delta*L 140 | lap_weight *= -self.delta_coeff 141 | mod_lap_index, mod_lap_weight = utils.add_self_loops(lap_index, lap_weight, 142 | fill_value=1., num_nodes=x.size(0)) 143 | 144 | out = self.propagate(edge_index=mod_lap_index, x=x, edge_weight=mod_lap_weight, size=None) 145 | 146 | if self.bias is not None: 147 | out = out + self.bias 148 | 149 | return self.act(out) 150 | 151 | def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor: 152 | return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j 153 | -------------------------------------------------------------------------------- /pytorch/classification.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from collections import OrderedDict 3 | import numpy as np 4 | import torch 5 | from torch import Tensor 6 | import torch_geometric.transforms as transforms 7 | from torch_geometric.datasets import TUDataset 8 | from torch_geometric.nn import Sequential, Linear 9 | from torch_geometric.loader import DataLoader 10 | from torch_geometric.utils import to_dense_batch, to_dense_adj 11 | from sklearn.model_selection import StratifiedKFold, train_test_split 12 | from GTVConv import GTVConv 13 | from AsymCheegerCutPool import AsymCheegerCutPool 14 | 15 | 16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | 18 | ################################ 19 | # CONFIG 20 | ################################ 21 | mp_layers = 1 22 | mp_channels = 32 23 | mp_activation = "relu" 24 | delta_coeff = 2.0 25 | 26 | mlp_hidden_layers = 1 27 | mlp_hidden_channels = 32 28 | mlp_activation = "relu" 29 | totvar_coeff = 0.5 30 | balance_coeff = 0.5 31 | 32 | epochs = 100 33 | batch_size = 16 34 | learning_rate = 5e-4 35 | l2_reg_val = 0 36 | patience = 10 37 | 38 | results = {"acc_scores": []} 39 | 40 | ################################ 41 | # LOAD DATASET 42 | ################################ 43 | dataset_id = "PROTEINS" 44 | 45 | path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data', dataset_id) 46 | dataset = TUDataset(path, "PROTEINS", use_node_attr=True, cleaned=True) 47 | 48 | # Parameters 49 | N = max(graph.num_nodes for graph in dataset) 50 | n_out = dataset.num_classes # Dimension of target 51 | 52 | # Train/test split 53 | idxs = np.random.permutation(len(dataset)) 54 | split_va, split_te = int(0.8 * len(dataset)), int(0.9 * len(dataset)) 55 | idx_tr, idx_va, idx_te = np.split(idxs, [split_va, split_te]) 56 | dataset_tr = dataset[torch.tensor(idx_tr).long()] 57 | dataset_va = dataset[torch.tensor(idx_va).long()] 58 | dataset_te = dataset[torch.tensor(idx_te).long()] 59 | loader_tr = DataLoader(dataset_tr, batch_size=batch_size, shuffle=True) 60 | loader_va = DataLoader(dataset_va, batch_size=batch_size, shuffle=False) 61 | loader_te = DataLoader(dataset_te, batch_size=batch_size, 
shuffle=False) 62 | 63 | ################################ 64 | # MODEL 65 | ################################ 66 | class ClassificationModel(torch.nn.Module): 67 | def __init__(self, n_out, mp1, pool1, mp2, pool2, mp3): 68 | super().__init__() 69 | 70 | self.mp1 = mp1 71 | self.pool1 = pool1 72 | self.mp2 = mp2 73 | self.pool2 = pool2 74 | self.mp3 = mp3 75 | self.output_layer = Linear(mp_channels, n_out) 76 | 77 | 78 | def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor, batch: Tensor): 79 | 80 | # 1st block 81 | x = self.mp1(x, edge_index, edge_weight) 82 | x, mask = to_dense_batch(x, batch) 83 | adj = to_dense_adj(edge_index, edge_attr=edge_weight, batch=batch) 84 | x, adj, tv1, bal1 = self.pool1(x, adj, mask=mask) 85 | 86 | # 2nd block 87 | x = self.mp2(x, edge_index=adj, edge_weight=None) 88 | x, adj, tv2, bal2 = self.pool2(x, adj) 89 | 90 | # 3rd block 91 | x = self.mp3(x, edge_index=adj, edge_weight=None) 92 | x = x.mean(dim=1) # global mean pooling 93 | x = self.output_layer(x) 94 | 95 | return x, tv1 + tv2, bal1 + bal2 96 | 97 | 98 | MP1 = [ 99 | (GTVConv(dataset.num_features if i==0 else mp_channels, 100 | mp_channels, 101 | act=mp_activation, 102 | delta_coeff=delta_coeff), 103 | 'x, edge_index, edge_weight -> x') 104 | for i in range(mp_layers)] 105 | 106 | MP1 = Sequential('x, edge_index, edge_weight', MP1) 107 | 108 | 109 | Pool1 = AsymCheegerCutPool(int(N//2), 110 | mlp_channels=[mp_channels] + 111 | [mlp_hidden_channels for _ in range(mlp_hidden_layers)], 112 | mlp_activation=mlp_activation, 113 | totvar_coeff=totvar_coeff, 114 | balance_coeff=balance_coeff, 115 | return_selection=False, 116 | return_pooled_graph=True) 117 | 118 | 119 | MP2 = [ 120 | (GTVConv(mp_channels, 121 | mp_channels, 122 | act=mp_activation, 123 | delta_coeff=delta_coeff), 124 | 'x, edge_index, edge_weight -> x') 125 | for _ in range(mp_layers)] 126 | 127 | MP2 = Sequential('x, edge_index, edge_weight', MP2) 128 | 129 | 130 | Pool2 = AsymCheegerCutPool(int(N//4), 131 | mlp_channels=[mp_channels] + 132 | [mlp_hidden_channels for _ in range(mlp_hidden_layers)], 133 | mlp_activation=mlp_activation, 134 | totvar_coeff=totvar_coeff, 135 | balance_coeff=balance_coeff, 136 | return_selection=False, 137 | return_pooled_graph=True) 138 | 139 | 140 | MP3 = [ 141 | (GTVConv(mp_channels, 142 | mp_channels, 143 | act=mp_activation, 144 | delta_coeff=delta_coeff), 145 | 'x, edge_index, edge_weight -> x') 146 | for _ in range(mp_layers)] 147 | 148 | MP3 = Sequential('x, edge_index, edge_weight', MP3) 149 | 150 | 151 | # Setup model 152 | model = ClassificationModel(n_out, 153 | mp1=MP1, 154 | pool1=Pool1, 155 | mp2=MP2, 156 | pool2=Pool2, 157 | mp3=MP3).to(device) 158 | 159 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 160 | loss_fn = torch.nn.CrossEntropyLoss() 161 | 162 | 163 | ################################ 164 | # TRAIN AND TEST 165 | ################################ 166 | 167 | def train(): 168 | model.train() 169 | 170 | for data in loader_tr: 171 | data.to(device) 172 | out, tv_loss, bal_loss = model(data.x, data.edge_index, data.edge_weight, data.batch) 173 | loss = tv_loss + bal_loss 174 | loss += loss_fn(out, data.y) 175 | loss.backward() 176 | optimizer.step() 177 | optimizer.zero_grad() 178 | 179 | @torch.no_grad() 180 | def test(loader): 181 | model.eval() 182 | 183 | correct = 0 184 | for data in loader: 185 | data.to(device) 186 | out, tv_loss, bal_loss = model(data.x, data.edge_index, data.edge_weight, data.batch) 187 | loss = tv_loss + bal_loss + 
loss_fn(out, data.y) 188 | pred = out.argmax(dim=1) 189 | correct += int((pred == data.y).sum()) 190 | return loss, correct / len(loader.dataset) 191 | 192 | 193 | best_val_acc = 0 194 | patience_count = patience 195 | for epoch in range(1, epochs + 1): 196 | train() 197 | train_loss, train_acc = test(loader_tr) 198 | val_loss, val_acc = test(loader_va) 199 | test_loss, test_acc = test(loader_te) 200 | 201 | print(f"Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc: .4f}") 202 | 203 | if val_acc > best_val_acc: 204 | best_val_acc = val_acc 205 | test_loss_at_best_val = test_loss 206 | test_acc_at_best_val = test_acc 207 | patience_count = patience 208 | else: 209 | patience_count -= 1 210 | if patience_count == 0: 211 | break 212 | 213 | print("Test loss: {}. Test acc: {}".format(test_loss_at_best_val, test_acc_at_best_val)) 214 | 215 | -------------------------------------------------------------------------------- /pytorch/clustering.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(1, '../utils') 3 | 4 | import os.path as osp 5 | import torch 6 | from torch import Tensor 7 | import torch_geometric.transforms as transforms 8 | from torch_geometric.datasets import Planetoid, CitationFull 9 | from torch_geometric import utils 10 | from torch_geometric.nn import Sequential 11 | from sklearn.metrics.cluster import normalized_mutual_info_score as NMI 12 | from metrics import cluster_acc 13 | from GTVConv import GTVConv 14 | from AsymCheegerCutPool import AsymCheegerCutPool 15 | 16 | torch.manual_seed(1) 17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | 19 | ################################ 20 | # CONFIG 21 | ################################ 22 | dataset_id="Cora" 23 | mp_channels=512 24 | mp_layers=2 25 | mp_activation="elu" 26 | delta_coeff=0.311 27 | mlp_hidden_channels=256 28 | mlp_hidden_layers=1 29 | mlp_activation="relu" 30 | totvar_coeff=0.785 31 | balance_coeff=0.514 32 | learning_rate=1e-3 33 | epochs=500 34 | 35 | ################################ 36 | # LOAD DATASET 37 | ################################ 38 | path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data', dataset_id) 39 | if dataset_id in ["Cora", "CiteSeer", "PubMed"]: 40 | dataset = Planetoid(path, dataset_id, transform=transforms.NormalizeFeatures()) 41 | elif dataset_id == "DBLP": 42 | dataset = CitationFull(path, dataset_id, transform=transforms.NormalizeFeatures()) 43 | 44 | data = dataset[0] 45 | data = data.to(device) 46 | 47 | ############################################################################ 48 | # MODEL 49 | ############################################################################ 50 | 51 | class Net(torch.nn.Module): 52 | 53 | def __init__(self): 54 | super().__init__() 55 | 56 | # Message passing layers 57 | mp = [ 58 | (GTVConv(dataset.num_features if i==0 else mp_channels, 59 | mp_channels, 60 | act=mp_activation, 61 | delta_coeff=delta_coeff), 62 | 'x, edge_index, edge_weight -> x') 63 | for i in range(mp_layers)] 64 | 65 | self.mp = Sequential('x, edge_index, edge_weight', mp) 66 | 67 | # Pooling layer 68 | self.pool = AsymCheegerCutPool( 69 | dataset.num_classes, 70 | mlp_channels=[mp_channels] + [mlp_hidden_channels for _ in range(mlp_hidden_layers)], 71 | mlp_activation=mlp_activation, 72 | totvar_coeff=totvar_coeff, 73 | balance_coeff=balance_coeff, 74 | return_selection=True, 75 | return_pooled_graph=False) 76 | 77 | def forward(self, x: Tensor, edge_index: Tensor, edge_weight:
Tensor): 78 | 79 | # Propagate node features 80 | x = self.mp(x, edge_index, edge_weight) 81 | 82 | # Compute cluster assignment and obtain auxiliary losses 83 | adj = utils.to_dense_adj(edge_index, edge_attr=edge_weight) 84 | s, tv_loss, bal_loss = self.pool(x, adj) 85 | 86 | return s.squeeze(0), tv_loss, bal_loss 87 | 88 | model = Net().to(device) 89 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, eps=1e-7) 90 | 91 | ############################################################################ 92 | # TRAINING 93 | ############################################################################ 94 | def train(): 95 | model.train() 96 | optimizer.zero_grad() 97 | _, tv_loss, bal_loss = model(data.x, data.edge_index, data.edge_weight) 98 | loss = tv_loss + bal_loss 99 | loss.backward() 100 | optimizer.step() 101 | return loss.item() 102 | 103 | @torch.no_grad() 104 | def test(): 105 | model.eval() 106 | clust, _, _ = model(data.x, data.edge_index, data.edge_weight) 107 | return NMI(data.y.cpu(), clust.max(1)[1].cpu()), cluster_acc(data.y.cpu().numpy(), clust.max(1)[1].cpu().numpy())[0] 108 | 109 | patience = 50 110 | best_loss = 1 111 | nmi_at_best_loss = 0 112 | acc_at_best_loss = 0 113 | for epoch in range(1, epochs+1): 114 | train_loss = train() 115 | nmi, acc = test() 116 | print(f"Epoch: {epoch:03d}, Loss: {train_loss:.4f}, NMI: {nmi:.3f}, ACC: {acc*100: .3f}") 117 | if train_loss < best_loss: 118 | best_loss = train_loss 119 | nmi_at_best_loss = nmi 120 | acc_at_best_loss = acc 121 | patience = 50 122 | else: 123 | patience -= 1 124 | if patience == 0: 125 | break 126 | 127 | print(f"NMI: {nmi_at_best_loss:.3f}, ACC: {acc_at_best_loss*100:.1f}") -------------------------------------------------------------------------------- /pytorch/pytorch_environment.yml: -------------------------------------------------------------------------------- 1 | name: TVGNN-pytorch 2 | channels: 3 | - pyg 4 | - pytorch 5 | - nvidia 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=main 10 | - _openmp_mutex=5.1=1_gnu 11 | - blas=1.0=mkl 12 | - brotlipy=0.7.0=py310h7f8727e_1002 13 | - bzip2=1.0.8=h7b6447c_0 14 | - ca-certificates=2022.9.24=ha878542_0 15 | - certifi=2022.9.24=pyhd8ed1ab_0 16 | - cffi=1.15.1=py310h5eee18b_2 17 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 18 | - cryptography=38.0.1=py310h9ce1e76_0 19 | - cuda=11.7.1=0 20 | - cuda-cccl=11.7.91=0 21 | - cuda-command-line-tools=11.7.1=0 22 | - cuda-compiler=11.7.1=0 23 | - cuda-cudart=11.7.99=0 24 | - cuda-cudart-dev=11.7.99=0 25 | - cuda-cuobjdump=11.7.91=0 26 | - cuda-cupti=11.7.101=0 27 | - cuda-cuxxfilt=11.7.91=0 28 | - cuda-demo-suite=11.8.86=0 29 | - cuda-documentation=11.8.86=0 30 | - cuda-driver-dev=11.7.99=0 31 | - cuda-gdb=11.8.86=0 32 | - cuda-libraries=11.7.1=0 33 | - cuda-libraries-dev=11.7.1=0 34 | - cuda-memcheck=11.8.86=0 35 | - cuda-nsight=11.8.86=0 36 | - cuda-nsight-compute=11.8.0=0 37 | - cuda-nvcc=11.7.99=0 38 | - cuda-nvdisasm=11.8.86=0 39 | - cuda-nvml-dev=11.7.91=0 40 | - cuda-nvprof=11.8.87=0 41 | - cuda-nvprune=11.7.91=0 42 | - cuda-nvrtc=11.7.99=0 43 | - cuda-nvrtc-dev=11.7.99=0 44 | - cuda-nvtx=11.7.91=0 45 | - cuda-nvvp=11.8.87=0 46 | - cuda-runtime=11.7.1=0 47 | - cuda-sanitizer-api=11.8.86=0 48 | - cuda-toolkit=11.7.1=0 49 | - cuda-tools=11.7.1=0 50 | - cuda-visual-tools=11.7.1=0 51 | - ffmpeg=4.3=hf484d3e_0 52 | - fftw=3.3.9=h27cfd23_1 53 | - freetype=2.12.1=h4a9f257_0 54 | - gds-tools=1.4.0.31=0 55 | - giflib=5.2.1=h7b6447c_0 56 | - gmp=6.2.1=h295c915_3 57 | - 
gnutls=3.6.15=he1e5248_0 58 | - idna=3.4=py310h06a4308_0 59 | - intel-openmp=2021.4.0=h06a4308_3561 60 | - jinja2=3.1.2=py310h06a4308_0 61 | - joblib=1.1.1=py310h06a4308_0 62 | - jpeg=9e=h7f8727e_0 63 | - lame=3.100=h7b6447c_0 64 | - lcms2=2.12=h3be6417_0 65 | - ld_impl_linux-64=2.38=h1181459_1 66 | - lerc=3.0=h295c915_0 67 | - libcublas=11.11.3.6=0 68 | - libcublas-dev=11.11.3.6=0 69 | - libcufft=10.9.0.58=0 70 | - libcufft-dev=10.9.0.58=0 71 | - libcufile=1.4.0.31=0 72 | - libcufile-dev=1.4.0.31=0 73 | - libcurand=10.3.0.86=0 74 | - libcurand-dev=10.3.0.86=0 75 | - libcusolver=11.4.1.48=0 76 | - libcusolver-dev=11.4.1.48=0 77 | - libcusparse=11.7.5.86=0 78 | - libcusparse-dev=11.7.5.86=0 79 | - libdeflate=1.8=h7f8727e_5 80 | - libffi=3.4.2=h6a678d5_6 81 | - libgcc-ng=11.2.0=h1234567_1 82 | - libgfortran-ng=11.2.0=h00389a5_1 83 | - libgfortran5=11.2.0=h1234567_1 84 | - libgomp=11.2.0=h1234567_1 85 | - libiconv=1.16=h7f8727e_2 86 | - libidn2=2.3.2=h7f8727e_0 87 | - libnpp=11.8.0.86=0 88 | - libnpp-dev=11.8.0.86=0 89 | - libnvjpeg=11.9.0.86=0 90 | - libnvjpeg-dev=11.9.0.86=0 91 | - libpng=1.6.37=hbc83047_0 92 | - libstdcxx-ng=11.2.0=h1234567_1 93 | - libtasn1=4.16.0=h27cfd23_0 94 | - libtiff=4.4.0=hecacb30_2 95 | - libunistring=0.9.10=h27cfd23_0 96 | - libuuid=1.41.5=h5eee18b_0 97 | - libwebp=1.2.4=h11a3e52_0 98 | - libwebp-base=1.2.4=h5eee18b_0 99 | - lz4-c=1.9.3=h295c915_1 100 | - markupsafe=2.1.1=py310h7f8727e_0 101 | - mkl=2021.4.0=h06a4308_640 102 | - mkl-service=2.4.0=py310h7f8727e_0 103 | - mkl_fft=1.3.1=py310hd6ae3a3_0 104 | - mkl_random=1.2.2=py310h00e6091_0 105 | - munkres=1.1.4=pyh9f0ad1d_0 106 | - ncurses=6.3=h5eee18b_3 107 | - nettle=3.7.3=hbbd107a_1 108 | - nsight-compute=2022.3.0.22=0 109 | - numpy=1.23.4=py310hd5efca6_0 110 | - numpy-base=1.23.4=py310h8e6c178_0 111 | - openh264=2.1.1=h4ff587b_0 112 | - openssl=1.1.1s=h7f8727e_0 113 | - pillow=9.2.0=py310hace64e9_1 114 | - pip=22.2.2=py310h06a4308_0 115 | - pycparser=2.21=pyhd3eb1b0_0 116 | - pyg=2.1.0=py310_torch_1.13.0_cu117 117 | - pyopenssl=22.0.0=pyhd3eb1b0_0 118 | - pyparsing=3.0.9=py310h06a4308_0 119 | - pysocks=1.7.1=py310h06a4308_0 120 | - python=3.10.8=h7a1cb2a_1 121 | - pytorch=1.13.0=py3.10_cuda11.7_cudnn8.5.0_0 122 | - pytorch-cluster=1.6.0=py310_torch_1.13.0_cu117 123 | - pytorch-cuda=11.7=h67b0de4_0 124 | - pytorch-mutex=1.0=cuda 125 | - pytorch-scatter=2.1.0=py310_torch_1.13.0_cu117 126 | - pytorch-sparse=0.6.15=py310_torch_1.13.0_cu117 127 | - readline=8.2=h5eee18b_0 128 | - requests=2.28.1=py310h06a4308_0 129 | - scikit-learn=1.1.3=py310h6a678d5_0 130 | - scipy=1.9.3=py310hd5efca6_0 131 | - setuptools=65.5.0=py310h06a4308_0 132 | - six=1.16.0=pyhd3eb1b0_1 133 | - sqlite=3.40.0=h5082296_0 134 | - threadpoolctl=2.2.0=pyh0d69192_0 135 | - tk=8.6.12=h1ccaba5_0 136 | - torchaudio=0.13.0=py310_cu117 137 | - torchvision=0.14.0=py310_cu117 138 | - tqdm=4.64.1=py310h06a4308_0 139 | - typing_extensions=4.3.0=py310h06a4308_0 140 | - tzdata=2022f=h04d1e81_0 141 | - urllib3=1.26.12=py310h06a4308_0 142 | - wheel=0.37.1=pyhd3eb1b0_0 143 | - xz=5.2.6=h5eee18b_0 144 | - zlib=1.2.13=h5eee18b_0 145 | - zstd=1.5.2=ha4553b6_0 146 | -------------------------------------------------------------------------------- /tensorflow/AsymCheegerCutPool.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Sequential 3 | from tensorflow.keras.layers import Dense 4 | import tensorflow.keras.backend as K 5 | from spektral.layers import ops 6 | from 
spektral.layers.pooling.src import SRCPool 7 | 8 | class AsymCheegerCutPool(SRCPool): 9 | r""" 10 | An Asymmetric Cheeger Cut Pooling layer from the paper 11 | > [Clustering with Total Variation Graph Neural Networks](https://arxiv.org/abs/2211.06218) 12 | > Jonas Berg Hansen and Filippo Maria Bianchi 13 | 14 | **Mode**: single, batch 15 | 16 | This layer learns a soft clustering of the input graph as follows: 17 | $$ 18 | \begin{align} 19 | \S &= \textrm{MLP}(\X); \\ 20 | \X' &= \S^\top \X; \\ 21 | \A' &= \S^\top \A \S; \\ 22 | \end{align} 23 | $$ 24 | where \(\textrm{MLP}\) is a multi-layer perceptron with softmax output. 25 | 26 | The layer includes two auxiliary loss terms: 27 | A graph total variation loss given by 28 | $$ 29 | L_\text{GTV} = \frac{1}{2E} \sum_{k=1}^K \sum_{i=1}^N \sum_{j=i}^N a_{i,j} |s_{i,k} - s_{j,k}|, 30 | $$ 31 | where $$E$$ is the number of edges/links, $$K$$ is the number of clusters or output nodes, and $$N$$ is the number of nodes. 32 | 33 | An asymmetric norm term given by 34 | $$ 35 | L_\text{AN} = \frac{N(K - 1) - \sum_{k=1}^K ||\s_{:,k} - \textrm{quant}_\rho (\s_{:,k})||_{1, \rho}}{N(K-1)}, 36 | $$ 37 | where \(\textrm{quant}_\rho(\s_{:,k})\) is the \(\rho\)-quantile of column \(k\) and \(||\cdot||_{1, \rho}\) is an asymmetric \(\ell_1\) norm with \(\rho = K - 1\). 38 | The layer can be used without a supervised loss to compute node clustering by 39 | minimizing the two auxiliary losses. 40 | 41 | **Input** 42 | 43 | - Node features of shape `(batch, n_nodes_in, n_node_features)`; 44 | - Adjacency matrix of shape `(batch, n_nodes_in, n_nodes_in)`; 45 | 46 | **Output** 47 | 48 | - Reduced node features of shape `(batch, n_nodes_out, n_node_features)`; 49 | - If `return_selection=True`, the selection matrix of shape 50 | `(batch, n_nodes_in, n_nodes_out)`. 51 | 52 | **Arguments** 53 | - `k`: number of output nodes; 54 | - `mlp_hidden`: list of integers, number of hidden units for each hidden layer in 55 | the MLP used to compute cluster assignments (if `None`, the MLP has only one output 56 | layer); 57 | - `mlp_activation`: activation for the MLP layers; 58 | - `return_selection`: boolean, whether to return the selection matrix; 59 | - `use_bias`: use bias in the MLP; 60 | - `totvar_coeff`: coefficient for graph total variation loss component; 61 | - `balance_coeff`: coefficient for asymmetric norm loss component; 62 | - `kernel_initializer`: initializer for the weights of the MLP; 63 | - `bias_initializer`: initializer for the bias of the MLP; 64 | - `kernel_regularizer` / `bias_regularizer`: regularization applied to the weights / bias of the MLP; 65 | - `kernel_constraint`: constraint applied to the weights of the MLP; 66 | - `bias_constraint`: constraint applied to the bias of the MLP; 67 | """ 68 | 69 | def __init__(self, 70 | k, 71 | mlp_hidden=None, 72 | mlp_activation="relu", 73 | return_selection=False, 74 | use_bias=True, 75 | totvar_coeff=1.0, 76 | balance_coeff=1.0, 77 | kernel_initializer="glorot_uniform", 78 | bias_initializer="zeros", 79 | kernel_regularizer=None, 80 | bias_regularizer=None, 81 | kernel_constraint=None, 82 | bias_constraint=None, 83 | **kwargs 84 | ): 85 | super().__init__( 86 | k=k, 87 | mlp_hidden=mlp_hidden, 88 | mlp_activation=mlp_activation, 89 | return_selection=return_selection, 90 | use_bias=use_bias, 91 | kernel_initializer=kernel_initializer, 92 | bias_initializer=bias_initializer, 93 | kernel_regularizer=kernel_regularizer, 94 | bias_regularizer=bias_regularizer, 95 | kernel_constraint=kernel_constraint, 96 | bias_constraint=bias_constraint, 97 | **kwargs 98 | ) 99 | 100 | self.k = k 101 | self.mlp_hidden = mlp_hidden if mlp_hidden else [] 102 |
self.mlp_activation = mlp_activation 103 | self.totvar_coeff = totvar_coeff 104 | self.balance_coeff = balance_coeff 105 | 106 | def build(self, input_shape): 107 | layer_kwargs = dict( 108 | kernel_initializer=self.kernel_initializer, 109 | bias_initializer=self.bias_initializer, 110 | kernel_regularizer=self.kernel_regularizer, 111 | bias_regularizer=self.bias_regularizer, 112 | kernel_constraint=self.kernel_constraint, 113 | bias_constraint=self.bias_constraint, 114 | ) 115 | self.mlp = Sequential( 116 | [ 117 | Dense(channels, self.mlp_activation, **layer_kwargs) 118 | for channels in self.mlp_hidden 119 | ] 120 | + [Dense(self.k, "softmax", **layer_kwargs)] 121 | ) 122 | 123 | super().build(input_shape) 124 | 125 | def call(self, inputs, mask=None): 126 | x, a, i = self.get_inputs(inputs) 127 | return self.pool(x, a, i, mask=mask) 128 | 129 | def select(self, x, a, i, mask=None): 130 | s = self.mlp(x) 131 | if mask is not None: 132 | s *= mask[0] 133 | 134 | # Total variation loss 135 | cut_loss = self.totvar_loss(a, s) 136 | if K.ndim(a) == 3: 137 | cut_loss = K.mean(cut_loss) 138 | self.add_loss(self.totvar_coeff * cut_loss) 139 | 140 | # Asymmetric l1-norm loss 141 | bal_loss = self.balance_loss(s) 142 | if K.ndim(a) == 3: 143 | bal_loss = K.mean(bal_loss) 144 | self.add_loss(self.balance_coeff * bal_loss) 145 | 146 | return s 147 | 148 | def reduce(self, x, s, **kwargs): 149 | return ops.modal_dot(s, x, transpose_a=True) 150 | 151 | def connect(self, a, s, **kwargs): 152 | a_pool = ops.matmul_at_b_a(s, a) 153 | 154 | return a_pool 155 | 156 | def reduce_index(self, i, s, **kwargs): 157 | i_mean = tf.math.segment_mean(i, i) 158 | i_pool = ops.repeat(i_mean, tf.ones_like(i_mean) * self.k) 159 | 160 | return i_pool 161 | 162 | def totvar_loss(self, a, s): 163 | if K.is_sparse(a): 164 | index_i = a.indices[:, 0] 165 | index_j = a.indices[:, 1] 166 | 167 | n_edges = float(len(a.values)) 168 | 169 | loss = tf.math.reduce_sum(a.values[:, tf.newaxis] * 170 | tf.math.abs(tf.gather(s, index_i) - 171 | tf.gather(s, index_j)), 172 | axis=(-2, -1)) 173 | 174 | else: 175 | n_edges = tf.cast(tf.math.count_nonzero( 176 | a, axis=(-2, -1)), dtype=s.dtype) 177 | n_nodes = tf.shape(a)[-1] 178 | if K.ndim(a) == 3: 179 | loss = tf.math.reduce_sum(a * tf.math.reduce_sum(tf.math.abs(s[:, tf.newaxis, ...] - 180 | tf.repeat(s[..., tf.newaxis, :], 181 | n_nodes, axis=-2)), axis=-1), 182 | axis=(-2, -1)) 183 | else: 184 | loss = tf.math.reduce_sum(a * tf.math.reduce_sum(tf.math.abs(s - 185 | tf.repeat(s[..., tf.newaxis, :], 186 | n_nodes, axis=-2)), axis=-1), 187 | axis=(-2, -1)) 188 | 189 | loss *= 1 / (2 * n_edges) 190 | 191 | return loss 192 | 193 | def balance_loss(self, s): 194 | n_nodes = tf.cast(tf.shape(s, out_type=tf.int32)[-2], s.dtype) 195 | 196 | # k-quantile 197 | idx = tf.cast(tf.math.floor(n_nodes / self.k) + 1, dtype=tf.int32) 198 | med = tf.math.top_k(tf.linalg.matrix_transpose(s), 199 | k=idx).values[..., -1] 200 | # Asymmetric l1-norm 201 | if K.ndim(s) == 2: 202 | loss = s - med 203 | else: 204 | loss = s - med[:, tf.newaxis, ...] 
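# With rho = K - 1: deviations above the quantile are weighted by (K - 1), those below by 1 (the asymmetric l1-norm |x|_rho)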
205 | loss = ((tf.cast(loss >= 0, loss.dtype) * (self.k - 1) * loss) + 206 | (tf.cast(loss < 0, loss.dtype) * loss * -1.)) 207 | loss = tf.math.reduce_sum(loss, axis=(-2, -1)) 208 | loss = 1 / (n_nodes * (self.k - 1)) * (n_nodes * (self.k - 1) - loss) 209 | 210 | return loss 211 | 212 | def get_config(self): 213 | config = { 214 | "k": self.k, 215 | "mlp_hidden": self.mlp_hidden, 216 | "mlp_activation": self.mlp_activation, 217 | "totvar_coeff": self.totvar_coeff, 218 | "balance_coeff": self.balance_coeff 219 | } 220 | base_config = super().get_config() 221 | return {**base_config, **config} 222 | -------------------------------------------------------------------------------- /tensorflow/GTVConv.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import backend as K 3 | from spektral.layers import ops 4 | from spektral.layers.convolutional.conv import Conv 5 | 6 | 7 | class GTVConv(Conv): 8 | r""" 9 | A graph total variation convolutional layer (GTVConv) from the paper 10 | 11 | > [Clustering with Total Variation Graph Neural Networks](https://arxiv.org/abs/2211.06218) 12 | > Jonas Berg Hansen and Filippo Maria Bianchi 13 | 14 | **Mode**: single, batch 15 | 16 | This layer computes 17 | $$ 18 | \X' = \sigma\left[\left(\I - \delta{\hat{\Lb}_\mathbf{\Gamma}}\right) \X \W \right] 19 | $$ 20 | 21 | **Input** 22 | 23 | - Node features of shape `(batch, n_nodes, n_node_features)`; 24 | - Adjacency matrix of shape `(batch, n_nodes, n_nodes)`; 25 | 26 | **Output** 27 | 28 | - Node features with the same shape as the input, but with the last 29 | dimension changed to `channels`. 30 | 31 | **Arguments** 32 | 33 | - `channels`: number of output channels; 34 | - `delta_coeff`: step size for gradient descent of GTV 35 | - `epsilon`: small number used to numerically stabilize the computation of new adjacency weights 36 | - `activation`: activation function; 37 | - `use_bias`: bool, add a bias vector to the output; 38 | - `kernel_initializer`: initializer for the weights; 39 | - `bias_initializer`: initializer for the bias vector; 40 | - `kernel_regularizer`: regularization applied to the weights; 41 | - `bias_regularizer`: regularization applied to the bias vector; 42 | - `activity_regularizer`: regularization applied to the output; 43 | - `kernel_constraint`: constraint applied to the weights; 44 | - `bias_constraint`: constraint applied to the bias vector. 
45 | 46 | """ 47 | 48 | def __init__( 49 | self, 50 | channels, 51 | delta_coeff=1., 52 | epsilon=1e-3, 53 | activation=None, 54 | use_bias=True, 55 | kernel_initializer="he_normal", 56 | bias_initializer="zeros", 57 | kernel_regularizer=None, 58 | bias_regularizer=None, 59 | activity_regularizer=None, 60 | kernel_constraint=None, 61 | bias_constraint=None, 62 | **kwargs 63 | ): 64 | super().__init__( 65 | activation=activation, 66 | use_bias=use_bias, 67 | kernel_initializer=kernel_initializer, 68 | bias_initializer=bias_initializer, 69 | kernel_regularizer=kernel_regularizer, 70 | bias_regularizer=bias_regularizer, 71 | activity_regularizer=activity_regularizer, 72 | kernel_constraint=kernel_constraint, 73 | bias_constraint=bias_constraint, 74 | **kwargs 75 | ) 76 | self.channels = channels 77 | self.delta_coeff = delta_coeff 78 | self.epsilon = epsilon 79 | 80 | def build(self, input_shape): 81 | assert len(input_shape) >= 2 82 | input_dim = input_shape[0][-1] 83 | self.kernel = self.add_weight( 84 | shape=(input_dim, self.channels), 85 | initializer=self.kernel_initializer, 86 | name="kernel", 87 | regularizer=self.kernel_regularizer, 88 | constraint=self.kernel_constraint, 89 | ) 90 | if self.use_bias: 91 | self.bias = self.add_weight( 92 | shape=(self.channels,), 93 | initializer=self.bias_initializer, 94 | name="bias", 95 | regularizer=self.bias_regularizer, 96 | constraint=self.bias_constraint, 97 | ) 98 | self.built = True 99 | 100 | def call(self, inputs, mask=None): 101 | x, a = inputs 102 | 103 | mode = ops.autodetect_mode(x, a) 104 | 105 | # Update node features 106 | x = K.dot(x, self.kernel) 107 | 108 | if mode == ops.modes.SINGLE: 109 | output = self._call_single(x, a) 110 | 111 | elif mode == ops.modes.BATCH: 112 | output = self._call_batch(x, a) 113 | 114 | if self.use_bias: 115 | output = K.bias_add(output, self.bias) 116 | 117 | if mask is not None: 118 | output *= mask[0] 119 | 120 | output = self.activation(output) 121 | 122 | return output 123 | 124 | def _call_single(self, x, a): 125 | if K.is_sparse(a): 126 | index_i = a.indices[:, 0] 127 | index_j = a.indices[:, 1] 128 | 129 | n_nodes = tf.shape(a, out_type=index_i.dtype)[0] 130 | 131 | # Compute absolute differences between neighbouring nodes 132 | abs_diff = tf.math.abs(tf.transpose(tf.gather(x, index_i)) - 133 | tf.transpose(tf.gather(x, index_j))) 134 | abs_diff = tf.math.reduce_sum(abs_diff, axis=0) 135 | 136 | # Compute new adjacency matrix 137 | gamma = tf.sparse.map_values(tf.multiply, 138 | a, 139 | 1 / tf.math.maximum(abs_diff, self.epsilon)) 140 | 141 | # Compute degree matrix from gamma matrix 142 | d_gamma = tf.sparse.SparseTensor(tf.stack([tf.range(n_nodes)] * 2, axis=1), 143 | tf.sparse.reduce_sum(gamma, axis=-1), 144 | [n_nodes, n_nodes]) 145 | 146 | # Compute laplacian: L = D_gamma - Gamma 147 | l = tf.sparse.add(d_gamma, tf.sparse.map_values( 148 | tf.multiply, gamma, -1.)) 149 | 150 | # Compute adjusted laplacian: L_adjusted = I - delta*L 151 | l = tf.sparse.add(tf.sparse.eye(n_nodes), tf.sparse.map_values( 152 | tf.multiply, l, -self.delta_coeff)) 153 | 154 | # Aggregate features with adjusted laplacian 155 | output = ops.modal_dot(l, x) 156 | 157 | else: 158 | n_nodes = tf.shape(a)[-1] 159 | 160 | abs_diff = tf.math.abs(x[:, tf.newaxis, :] - x) 161 | abs_diff = tf.reduce_sum(abs_diff, axis=-1) 162 | 163 | gamma = a / tf.math.maximum(abs_diff, self.epsilon) 164 | 165 | degrees = tf.math.reduce_sum(gamma, axis=-1) 166 | l = -gamma 167 | l = tf.linalg.set_diag(l, degrees -
tf.linalg.diag_part(gamma)) 168 | l = tf.eye(n_nodes) - self.delta_coeff * l 169 | 170 | output = tf.matmul(l, x) 171 | 172 | return output 173 | 174 | def _call_batch(self, x, a): 175 | n_nodes = tf.shape(a)[-1] 176 | 177 | abs_diff = tf.reduce_sum(tf.math.abs(tf.expand_dims(x, 2) - 178 | tf.expand_dims(x, 1)), axis = -1) 179 | 180 | gamma = a / tf.math.maximum(abs_diff, self.epsilon) 181 | 182 | degrees = tf.math.reduce_sum(gamma, axis=-1) 183 | l = -gamma 184 | l = tf.linalg.set_diag(l, degrees - tf.linalg.diag_part(gamma)) 185 | l = tf.eye(n_nodes) - self.delta_coeff * l 186 | 187 | output = tf.matmul(l, x) 188 | 189 | return output 190 | -------------------------------------------------------------------------------- /tensorflow/classification.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.keras import Model 4 | from tensorflow.keras.callbacks import EarlyStopping 5 | from tensorflow.keras.regularizers import L2 6 | from tensorflow.keras.layers import Dense 7 | from spektral.data.loaders import BatchLoader 8 | from spektral.datasets import TUDataset 9 | from spektral.layers import GraphMasking 10 | from spektral.layers.pooling import GlobalAvgPool 11 | from GTVConv import GTVConv 12 | from AsymCheegerCutPool import AsymCheegerCutPool 13 | 14 | 15 | ################################ 16 | # CONFIG 17 | ################################ 18 | mp_layers = 1 19 | mp_channels = 32 20 | mp_activation = "relu" 21 | delta_coeff = 2.0 22 | 23 | mlp_hidden_layers = 1 24 | mlp_hidden_channels = 32 25 | mlp_activation = "relu" 26 | totvar_coeff = 0.5 27 | balance_coeff = 0.5 28 | 29 | batch_size = 16 30 | l2_reg_val = 0 31 | learning_rate = 5e-4 32 | epochs = 100 33 | patience = 10 34 | 35 | 36 | ################################ 37 | # LOAD DATASET 38 | ################################ 39 | dataset = TUDataset("PROTEINS", clean=True) 40 | 41 | # Parameters 42 | N = max(g.n_nodes for g in dataset) 43 | n_out = dataset.n_labels # Dimension of the target 44 | 45 | # Train/test split 46 | idxs = np.random.permutation(len(dataset)) 47 | split_va, split_te = int(0.8 * len(dataset)), int(0.9 * len(dataset)) 48 | idx_tr, idx_va, idx_te = np.split(idxs, [split_va, split_te]) 49 | dataset_tr = dataset[idx_tr] 50 | dataset_va = dataset[idx_va] 51 | dataset_te = dataset[idx_te] 52 | loader_tr = BatchLoader(dataset_tr, batch_size=batch_size, mask=True) 53 | loader_va = BatchLoader(dataset_va, batch_size=batch_size, shuffle=False, mask=True) 54 | loader_te = BatchLoader(dataset_te, batch_size=batch_size, shuffle=False, mask=True) 55 | 56 | 57 | ################################ 58 | # MODEL 59 | ################################ 60 | class ClassificationModel(Model): 61 | def __init__(self, n_out, mp1, pool1, mp2, pool2, mp3): 62 | super().__init__() 63 | 64 | self.mask = GraphMasking() 65 | self.mp1 = mp1 66 | self.pool1 = pool1 67 | self.mp2 = mp2 68 | self.pool2 = pool2 69 | self.mp3 = mp3 70 | self.global_pool = GlobalAvgPool() 71 | self.output_layer = Dense(n_out, activation = "softmax") 72 | 73 | def call(self, inputs): 74 | 75 | x, a = inputs 76 | x = self.mask(x) 77 | out = x 78 | 79 | # 1st block 80 | for _mp in self.mp1: 81 | out = _mp([out, a]) 82 | out, a_pool = self.pool1([out, a]) 83 | 84 | # 2nd block 85 | for _mp in self.mp2: 86 | out = _mp([out, a_pool]) 87 | out, a_pool = self.pool2([out, a_pool]) 88 | 89 | # 3rd block 90 | for _mp in self.mp3: 91 | out = _mp([out, a_pool]) 92 | out = 
93 |         out = self.output_layer(out)
94 | 
95 |         return out
96 | 
97 | 
98 | MP1 = [GTVConv(mp_channels,
99 |                delta_coeff=delta_coeff,
100 |                activation=mp_activation,
101 |                kernel_regularizer=L2(l2_reg_val))
102 |        for _ in range(mp_layers)]
103 | 
104 | Pool1 = AsymCheegerCutPool(int(N//2),
105 |                            mlp_hidden=[mlp_hidden_channels
106 |                                        for _ in range(mlp_hidden_layers)],
107 |                            mlp_activation=mlp_activation,
108 |                            totvar_coeff=totvar_coeff,
109 |                            balance_coeff=balance_coeff,
110 |                            kernel_regularizer=L2(l2_reg_val))
111 | 
112 | MP2 = [GTVConv(mp_channels,
113 |                delta_coeff=delta_coeff,
114 |                activation=mp_activation,
115 |                kernel_regularizer=L2(l2_reg_val))
116 |        for _ in range(mp_layers)]
117 | 
118 | Pool2 = AsymCheegerCutPool(int(N//4),
119 |                            mlp_hidden=[mlp_hidden_channels
120 |                                        for _ in range(mlp_hidden_layers)],
121 |                            mlp_activation=mlp_activation,
122 |                            totvar_coeff=totvar_coeff,
123 |                            balance_coeff=balance_coeff,
124 |                            kernel_regularizer=L2(l2_reg_val))
125 | 
126 | MP3 = [GTVConv(mp_channels,
127 |                delta_coeff=delta_coeff,
128 |                activation=mp_activation,
129 |                kernel_regularizer=L2(l2_reg_val))
130 |        for _ in range(mp_layers)]
131 | 
132 | 
133 | # Compile the model
134 | model = ClassificationModel(
135 |     n_out,
136 |     mp1=MP1,
137 |     pool1=Pool1,
138 |     mp2=MP2,
139 |     pool2=Pool2,
140 |     mp3=MP3)
141 | 
142 | opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)
143 | model.compile(optimizer=opt, loss="categorical_crossentropy", metrics=["acc"])
144 | 
145 | ################################
146 | # TRAIN AND TEST
147 | ################################
148 | 
149 | model.fit(
150 |     loader_tr.load(),
151 |     steps_per_epoch=loader_tr.steps_per_epoch,
152 |     epochs=epochs,
153 |     validation_data=loader_va.load(),
154 |     validation_steps=loader_va.steps_per_epoch,
155 |     callbacks=[EarlyStopping(patience=patience, restore_best_weights=True)],
156 |     verbose=2)
157 | 
158 | loss_te, acc_te = model.evaluate(loader_te.load(), steps=loader_te.steps_per_epoch)
159 | print("Test loss: {}. Test acc: {}".format(loss_te, acc_te))
160 | 
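Editor's note: one detail worth calling out in the script above is that the categorical cross-entropy passed to compile is not the whole objective. Each AsymCheegerCutPool layer registers its total variation and balance terms through Keras' add_loss mechanism, and fit adds them to the classification loss automatically. A quick way to inspect those auxiliary terms is sketched below; it assumes the model and loaders defined above are in scope and is not part of the repository.

# Pull one batch from the loader and repopulate model.losses (illustrative)
inputs, target = loader_te.__next__()
_ = model(inputs, training=False)
for aux_loss in model.losses:
    print(float(aux_loss))  # one total-variation and one balance term per pooling layer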
--------------------------------------------------------------------------------
/tensorflow/clustering.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(1, '../utils')
3 | 
4 | import numpy as np
5 | from tqdm import tqdm
6 | from sklearn.metrics.cluster import normalized_mutual_info_score
7 | import tensorflow as tf
8 | from tensorflow.keras import Model
9 | from spektral.utils.sparse import sp_matrix_to_sp_tensor
10 | from spektral.datasets.citation import Citation
11 | from spektral.datasets import DBLP
12 | from GTVConv import GTVConv
13 | from AsymCheegerCutPool import AsymCheegerCutPool
14 | from metrics import cluster_acc
15 | 
16 | tf.random.set_seed(1)
17 | 
18 | ################################
19 | # CONFIG
20 | ################################
21 | dataset_id = "cora"
22 | mp_channels = 512
23 | mp_layers = 2
24 | mp_activation = "elu"
25 | delta_coeff = 0.311
26 | mlp_hidden_channels = 256
27 | mlp_hidden_layers = 1
28 | mlp_activation = "relu"
29 | totvar_coeff = 0.785
30 | balance_coeff = 0.514
31 | learning_rate = 1e-3
32 | epochs = 500
33 | 
34 | ################################
35 | # LOAD DATASET
36 | ################################
37 | if dataset_id in ["cora", "citeseer", "pubmed"]:
38 |     dataset = Citation(dataset_id, normalize_x=True)
39 | elif dataset_id == "dblp":
40 |     dataset = DBLP(normalize_x=True)
41 | X = dataset.graphs[0].x
42 | A = dataset.graphs[0].a
43 | Y = dataset.graphs[0].y
44 | y = np.argmax(Y, axis=-1)
45 | n_clust = Y.shape[-1]
46 | 
47 | ################################
48 | # MODEL
49 | ################################
50 | class ClusteringModel(Model):
51 |     """
52 |     Defines the general model structure
53 |     """
54 | 
55 |     def __init__(self, aggr, pool):
56 |         super().__init__()
57 | 
58 |         self.mp = aggr
59 |         self.pool = pool
60 | 
61 |     def call(self, inputs):
62 |         x, a = inputs
63 | 
64 |         out = x
65 |         for _mp in self.mp:
66 |             out = _mp([out, a])
67 | 
68 |         _, _, s_pool = self.pool([out, a])
69 | 
70 |         return s_pool
71 | 
72 | # Define the message-passing layers
73 | MP_layers = [GTVConv(mp_channels,
74 |                      delta_coeff=delta_coeff,
75 |                      activation=mp_activation)
76 |              for _ in range(mp_layers)]
77 | 
78 | # Define the pooling layer
79 | pool_layer = AsymCheegerCutPool(n_clust,
80 |                                 mlp_hidden=[mlp_hidden_channels for _ in range(mlp_hidden_layers)],
81 |                                 mlp_activation=mlp_activation,
82 |                                 totvar_coeff=totvar_coeff,
83 |                                 balance_coeff=balance_coeff,
84 |                                 return_selection=True)
85 | 
86 | # Instantiate model and optimizer
87 | model = ClusteringModel(aggr=MP_layers, pool=pool_layer)
88 | opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)
89 | 
90 | ################################
91 | # TRAINING
92 | ################################
93 | @tf.function(input_signature=None)
94 | def train_step(model, inputs, labels):
95 |     # Note: the labels are unused; training is fully unsupervised and driven
96 |     # by the auxiliary losses registered by the pooling layer via add_loss
97 |     with tf.GradientTape() as tape:
98 |         _ = model(inputs, training=True)
99 |         loss = sum(model.losses)
100 |     gradients = tape.gradient(loss, model.trainable_variables)
101 |     opt.apply_gradients(zip(gradients, model.trainable_variables))
102 |     return model.losses
103 | 
104 | A = sp_matrix_to_sp_tensor(A)
105 | inputs = [X, A]
106 | loss_history = []
107 | 
108 | # Training loop
109 | for _ in tqdm(range(epochs)):
110 |     outs = train_step(model, inputs, Y)
111 |     loss_history.append([outs[i].numpy()
112 |                          for i in range(len(outs))])
113 | 
114 | ################################
115 | # INFERENCE
116 | ################################
117 | S_ = model(inputs, training=False)
118 | s = np.argmax(S_, axis=-1)
119 | nmi = normalized_mutual_info_score(y, s)
120 | acc, _, _ = cluster_acc(y, s)
121 | print("NMI: {:.3f}, ACC: {:.3f}".format(nmi, acc))
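Editor's note: since training in clustering.py is unsupervised, the loss curves are the main thing to monitor. loss_history stores, per epoch, the auxiliary terms returned by train_step. A plotting sketch follows; matplotlib is not in the pinned environment, and the order of the two terms in model.losses is an assumption here.

import matplotlib.pyplot as plt

history = np.asarray(loss_history)  # shape: (epochs, n_loss_terms)
plt.plot(history[:, 0], label="total variation (assumed order)")
plt.plot(history[:, 1], label="asymmetric-norm balance (assumed order)")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()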
--------------------------------------------------------------------------------
/tensorflow/tf_environment.yml:
--------------------------------------------------------------------------------
1 | name: TVGNN-tf
2 | channels:
3 |   - conda-forge
4 |   - defaults
5 | dependencies:
6 |   - _libgcc_mutex=0.1=conda_forge
7 |   - _openmp_mutex=4.5=2_gnu
8 |   - abseil-cpp=20210324.2=h9c3ff4c_0
9 |   - absl-py=1.2.0=pyhd8ed1ab_0
10 |   - aiohttp=3.8.3=py310h5764c6d_0
11 |   - aiosignal=1.2.0=pyhd8ed1ab_0
12 |   - astunparse=1.6.3=pyhd8ed1ab_0
13 |   - async-timeout=4.0.2=pyhd8ed1ab_0
14 |   - attrs=22.1.0=pyh71513ae_1
15 |   - blinker=1.4=py_1
16 |   - brotlipy=0.7.0=py310h5764c6d_1004
17 |   - bzip2=1.0.8=h7f98852_4
18 |   - c-ares=1.18.1=h7f98852_0
19 |   - ca-certificates=2022.9.14=ha878542_0
20 |   - cached-property=1.5.2=hd8ed1ab_1
21 |   - cached_property=1.5.2=pyha770c72_1
22 |   - cachetools=5.2.0=pyhd8ed1ab_0
23 |   - certifi=2022.9.14=pyhd8ed1ab_0
24 |   - cffi=1.15.1=py310h255011f_0
25 |   - charset-normalizer=2.1.1=pyhd8ed1ab_0
26 |   - click=8.1.3=py310hff52083_0
27 |   - cryptography=37.0.4=py310h597c629_0
28 |   - cudatoolkit=11.7.0=hd8887f6_10
29 |   - cudnn=8.4.1.50=hed8a83a_0
30 |   - frozenlist=1.3.1=py310h5764c6d_0
31 |   - gast=0.5.3=pyhd8ed1ab_0
32 |   - giflib=5.2.1=h36c2ea0_2
33 |   - google-auth=2.11.1=pyh1a96a4e_0
34 |   - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0
35 |   - google-pasta=0.2.0=pyh8c360ce_0
36 |   - grpc-cpp=1.45.2=h9d3bbbb_5
37 |   - grpcio=1.45.0=py310h44b9e0c_0
38 |   - h5py=3.7.0=nompi_py310h416281c_101
39 |   - hdf5=1.12.2=nompi_h2386368_100
40 |   - icu=70.1=h27087fc_0
41 |   - idna=3.4=pyhd8ed1ab_0
42 |   - importlib-metadata=4.11.4=py310hff52083_0
43 |   - jpeg=9e=h166bdaf_2
44 |   - keras=2.8.0=pyhd8ed1ab_0
45 |   - keras-preprocessing=1.1.2=pyhd8ed1ab_0
46 |   - keyutils=1.6.1=h166bdaf_0
47 |   - krb5=1.19.3=h3790be6_0
48 |   - ld_impl_linux-64=2.36.1=hea4e1c9_2
49 |   - libblas=3.9.0=16_linux64_openblas
50 |   - libcblas=3.9.0=16_linux64_openblas
51 |   - libcurl=7.83.1=h7bff187_0
52 |   - libedit=3.1.20191231=he28a2e2_2
53 |   - libev=4.33=h516909a_1
54 |   - libffi=3.4.2=h7f98852_5
55 |   - libgcc-ng=12.1.0=h8d9b700_16
56 |   - libgfortran-ng=12.1.0=h69a702a_16
57 |   - libgfortran5=12.1.0=hdcd56e2_16
58 |   - libgomp=12.1.0=h8d9b700_16
59 |   - liblapack=3.9.0=16_linux64_openblas
60 |   - libnghttp2=1.47.0=hdcd2b5c_1
61 |   - libnsl=2.0.0=h7f98852_0
62 |   - libopenblas=0.3.21=pthreads_h78a6416_3
63 |   - libpng=1.6.38=h753d276_0
64 |   - libprotobuf=3.20.1=h6239696_4
65 |   - libsqlite=3.39.3=h753d276_0
66 |   - libssh2=1.10.0=haa6b8db_3
67 |   - libstdcxx-ng=12.1.0=ha89aaad_16
68 |   - libuuid=2.32.1=h7f98852_1000
69 |   - libzlib=1.2.12=h166bdaf_3
70 |   - markdown=3.4.1=pyhd8ed1ab_0
71 |   - markupsafe=2.1.1=py310h5764c6d_1
72 |   - multidict=6.0.2=py310h5764c6d_1
73 |   - nccl=2.14.3.1=h0800d71_0
74 |   - ncurses=6.3=h27087fc_1
75 |   - numpy=1.23.3=py310h53a5b5f_0
76 |   - oauthlib=3.2.1=pyhd8ed1ab_0
77 |   - openssl=1.1.1q=h166bdaf_0
78 |   - opt_einsum=3.3.0=pyhd8ed1ab_1
79 |   - pip=22.2.2=pyhd8ed1ab_0
80 |   - protobuf=3.20.1=py310hd8f1fbe_0
81 |   - pyasn1=0.4.8=py_0
82 |   - pyasn1-modules=0.2.7=py_0
83 |   - pycparser=2.21=pyhd8ed1ab_0
84 |   - pyjwt=2.5.0=pyhd8ed1ab_0
85 |   - pyopenssl=22.0.0=pyhd8ed1ab_1
86 |   - pysocks=1.7.1=pyha2e5f31_6
87 |   - python=3.10.6=h582c2e5_0_cpython
88 |   - python-flatbuffers=2.0=pyhd8ed1ab_0
89 |   - python_abi=3.10=2_cp310
90 |   - pyu2f=0.1.5=pyhd8ed1ab_0
91 |   - re2=2022.06.01=h27087fc_0
92 |   - readline=8.1.2=h0f457ee_0
93 |   - requests=2.28.1=pyhd8ed1ab_1
94 |   - requests-oauthlib=1.3.1=pyhd8ed1ab_0
95 |   - rsa=4.9=pyhd8ed1ab_0
96 |   - scipy=1.9.1=py310hdfbd76f_0
97 |   - setuptools=65.3.0=pyhd8ed1ab_1
98 |   - six=1.16.0=pyh6c4a22f_0
99 |   - snappy=1.1.9=hbd366e4_1
100 |   - sqlite=3.39.3=h4ff8645_0
101 |   - tensorboard=2.8.0=pyhd8ed1ab_1
102 |   - tensorboard-data-server=0.6.0=py310h597c629_2
103 |   - tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0
104 |   - tensorflow=2.8.1=cuda112py310he87a039_0
105 |   - tensorflow-base=2.8.1=cuda112py310h666ff7d_0
106 |   - tensorflow-estimator=2.8.1=cuda112py310h2fa73eb_0
107 |   - tensorflow-gpu=2.8.1=cuda112py310h0bbbad9_0
108 |   - termcolor=2.0.1=pyhd8ed1ab_1
109 |   - tk=8.6.12=h27826a3_0
110 |   - typing-extensions=4.3.0=hd8ed1ab_0
111 |   - typing_extensions=4.3.0=pyha770c72_0
112 |   - tzdata=2022c=h191b570_0
113 |   - urllib3=1.26.11=pyhd8ed1ab_0
114 |   - werkzeug=2.2.2=pyhd8ed1ab_0
115 |   - wheel=0.37.1=pyhd8ed1ab_0
116 |   - wrapt=1.14.1=py310h5764c6d_0
117 |   - xz=5.2.6=h166bdaf_0
118 |   - yarl=1.7.2=py310h5764c6d_2
119 |   - zipp=3.8.1=pyhd8ed1ab_0
120 |   - zlib=1.2.12=h166bdaf_3
121 |   - pip:
122 |     - docker-pycreds==0.4.0
123 |     - gitdb==4.0.9
124 |     - gitpython==3.1.27
125 |     - joblib==1.2.0
126 |     - lxml==4.9.1
127 |     - networkx==2.8.6
128 |     - pandas==1.5.0
129 |     - pathtools==0.1.2
130 |     - promise==2.3
131 |     - psutil==5.9.2
132 |     - python-dateutil==2.8.2
133 |     - pytz==2022.2.1
134 |     - pyyaml==6.0
135 |     - scikit-learn==1.1.2
136 |     - sentry-sdk==1.9.8
137 |     - setproctitle==1.3.2
138 |     - shortuuid==1.0.9
139 |     - smmap==5.0.0
140 |     - spektral==1.2.0
141 |     - threadpoolctl==3.1.0
142 |     - tqdm==4.64.1
143 | 
--------------------------------------------------------------------------------
/tvgnn_poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FilippoMB/Total-variation-graph-neural-networks/c21b427460fda14000a820a541e9709a909b3005/tvgnn_poster.pdf
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from munkres import Munkres
3 | from sklearn import metrics
4 | 
5 | # taken from https://github.com/Tiger101010/DAEGC/blob/main/DAEGC/evaluation.py
6 | # similar to https://github.com/karenlatong/AGC-master/blob/master/metrics.py
7 | def cluster_acc(y_true, y_pred):
8 |     y_true = y_true - np.min(y_true)
9 | 
10 |     l1 = list(set(y_true))
11 |     numclass1 = len(l1)
12 | 
13 |     l2 = list(set(y_pred))
14 |     numclass2 = len(l2)
15 | 
16 |     ind = 0
17 |     if numclass1 != numclass2:
18 |         for i in l1:
19 |             if i in l2:
20 |                 pass
21 |             else:
22 |                 y_pred[ind] = i
23 |                 ind += 1
24 | 
25 |     l2 = list(set(y_pred))
26 |     numclass2 = len(l2)
27 | 
28 |     if numclass1 != numclass2:
29 |         # the two clusterings still have a different number of classes
30 |         raise ValueError("cluster_acc: mismatched number of classes")
31 | 
32 |     cost = np.zeros((numclass1, numclass2), dtype=int)
33 |     for i, c1 in enumerate(l1):
34 |         mps = [i1 for i1, e1 in enumerate(y_true) if e1 == c1]
35 |         for j, c2 in enumerate(l2):
36 |             mps_d = [i1 for i1 in mps if y_pred[i1] == c2]
37 |             cost[i][j] = len(mps_d)
38 | 
39 |     # match two clustering results by Munkres algorithm
40 |     m = Munkres()
41 |     cost = (-cost).tolist()  # maximize matches by minimizing negated counts
42 |     indexes = m.compute(cost)
43 | 
44 |     # get the match results
45 |     new_predict = np.zeros(len(y_pred))
46 |     for i, c in enumerate(l1):
47 |         # corresponding label in l2:
48 |         c2 = l2[indexes[i][1]]
49 | 
50 |         # ai is the index with label==c2 in the pred_label list
51 |         ai = [ind for ind, elm in enumerate(y_pred) if elm == c2]
52 |         new_predict[ai] = c
53 | 
54 |     acc = metrics.accuracy_score(y_true, new_predict)
55 |     f1_macro = metrics.f1_score(y_true, new_predict, average="macro")
56 |     f1_micro = metrics.f1_score(y_true, new_predict, average="micro")
57 |     return acc, f1_macro, f1_micro
58 | 
--------------------------------------------------------------------------------
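Editor's note: as a sanity check of the metric above, two labelings that agree up to a permutation of cluster IDs should reach an accuracy of 1.0 after the Munkres matching. A self-contained example (the values are illustrative, not from the repository):

import numpy as np
from metrics import cluster_acc

y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([2, 2, 0, 0, 1, 1])  # same partition, permuted cluster IDs
acc, f1_macro, f1_micro = cluster_acc(y_true, y_pred)
print(acc, f1_macro, f1_micro)  # expected: 1.0 1.0 1.0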