├── .gitignore
├── LICENSE
├── README.md
├── img
│   ├── classification.png
│   ├── clustering.png
│   └── sharp.png
├── pytorch
│   ├── AsymCheegerCutPool.py
│   ├── GTVConv.py
│   ├── classification.py
│   ├── clustering.py
│   └── pytorch_environment.yml
├── tensorflow
│   ├── AsymCheegerCutPool.py
│   ├── GTVConv.py
│   ├── classification.py
│   ├── clustering.py
│   └── tf_environment.yml
├── tvgnn_poster.pdf
└── utils
    └── metrics.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /pytorch/data/*
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
133 | # macOS stuff
134 | .DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Filippo Bianchi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [ICML 2023](https://icml.cc/virtual/2023/poster/24747)
2 | [arXiv](https://arxiv.org/abs/2211.06218)
3 | [Poster](https://github.com/FilippoMB/Total-variation-graph-neural-networks/blob/main/tvgnn_poster.pdf)
4 | [Video](https://youtu.be/Dyb1YJOez8w)
5 |
6 | TensorFlow and PyTorch implementations of the Total Variation Graph Neural Network (TVGNN), as presented in the [original paper](https://arxiv.org/abs/2211.06218).
7 |
8 | The TVGNN model can be used to **cluster** the vertices of an annotated graph by accounting for both the graph topology and the vertex features. Compared to other GNNs for clustering, TVGNN creates *sharp* cluster assignments that better approximate the optimal (in the minimum cut sense) partition.
9 |
10 |
11 |
12 | The TVGNN model can also be used to implement [graph pooling](https://gnn-pooling.notion.site/) in a deep GNN architecture for tasks such as graph classification.
13 |
14 | # Downstream tasks
15 | TVGNN can be used to perform vertex clustering and graph classification. Other tasks, such as graph regression, can also be addressed with the TVGNN model.
16 |
17 | ### Vertex clustering
18 | This is an unsupervised task, where the goal is to generate a partition of the vertices based on the similarity of their vertex features and the graph topology. The GNN model is trained only by minimizing the unsupervised loss $\mathcal{L}$.
19 |
20 |
21 |
22 | ### Graph classification
23 | This is a supervised task with the goal of predicting the class of each graph. The GNN architecture for graph classification alternates GTVConv layers with graph pooling layers, which gradually distill the global label information from the vertex representations. The GNN is trained by minimizing the unsupervised loss $\mathcal{L}$ for each pooling layer and a supervised cross-entropy loss $\mathcal{L}_\text{cross-entr}$ between the true and predicted class labels.
24 |
25 |
26 |
27 | # 💻 Implementation
28 |
29 |
30 |
31 | ### TensorFlow
32 | This implementation is based on the [Spektral](https://graphneural.network/) library and follows the [Select-Reduce-Connect](https://graphneural.network/layers/pooling/#srcpool) API.
33 | To execute the code, first install the conda environment from [tf_environment.yml](tensorflow/tf_environment.yml)
34 |
35 | conda env create -f tf_environment.yml
36 |
37 | The ``tensorflow/`` folder includes:
38 |
39 | - The implementation of the [GTVConv](/tensorflow/GTVConv.py) layer
40 | - The implementation of the [AsymCheegerCutPool](/tensorflow/AsymCheegerCutPool.py) layer
41 | - An example script to perform the [clustering](/tensorflow/clustering.py) task
42 | - An example script to perform the [classification](/tensorflow/classification.py) task
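
For a quick feel of the API, here is a minimal, illustrative batch-mode sketch (random padded batch; all sizes are arbitrary) mirroring the blocks used in [classification.py](/tensorflow/classification.py):

```python
import tensorflow as tf
from GTVConv import GTVConv
from AsymCheegerCutPool import AsymCheegerCutPool

# A toy batch: 8 graphs, padded to 20 nodes, with 16 features per node
x = tf.random.uniform((8, 20, 16))
a = tf.cast(tf.random.uniform((8, 20, 20)) > 0.8, tf.float32)
a = tf.maximum(a, tf.linalg.matrix_transpose(a))  # symmetrize the random adjacency

conv = GTVConv(32, delta_coeff=2.0, activation="relu")
pool = AsymCheegerCutPool(10, mlp_hidden=[32], mlp_activation="relu")

out = conv([x, a])            # shape [8, 20, 32]
out, a_pool = pool([out, a])  # shapes [8, 10, 32] and [8, 10, 10]
# The two auxiliary losses are registered via add_loss and exposed in pool.losses
```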
43 |
44 |
45 |
46 | ### PyTorch
47 | This implementation is based on the [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/) library. To execute the code, first install the conda environment from [pytorch_environment.yml](pytorch/pytorch_environment.yml)
48 |
49 | conda env create -f pytorch_environment.yml
50 |
51 | The ``pytorch/`` folder includes:
52 |
53 | - The implementation of the [GTVConv](/pytorch/GTVConv.py) layer
54 | - The implementation of the [AsymCheegerCutPool](/pytorch/AsymCheegerCutPool.py) layer
55 | - An example script to perform the [clustering](/pytorch/clustering.py) task
56 | - An example script to perform the [classification](/pytorch/classification.py) task
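
As a quick illustration, here is a minimal sketch (random toy graph; all sizes are arbitrary) of how the two layers are combined in [clustering.py](/pytorch/clustering.py):

```python
import torch
from torch_geometric import utils
from GTVConv import GTVConv
from AsymCheegerCutPool import AsymCheegerCutPool

# A toy graph: 10 nodes with 16 features and 40 random edges
x = torch.rand(10, 16)
edge_index = torch.randint(0, 10, (2, 40))

conv = GTVConv(16, 32, delta_coeff=0.311, act="elu")
pool = AsymCheegerCutPool(3,                      # number of clusters
                          mlp_channels=[32, 16],  # first entry matches the conv output size
                          return_selection=True,
                          return_pooled_graph=False)

x = conv(x, edge_index)               # message passing with the GTV operator
adj = utils.to_dense_adj(edge_index)  # the pooling layer works on a dense adjacency
s, tv_loss, bal_loss = pool(x, adj)   # soft assignments (with a leading batch dim) and losses
loss = tv_loss + bal_loss             # unsupervised clustering objective
```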
57 |
58 |
59 |
60 | ### Spektral
61 |
62 | TVGNN is also available in Spektral:
63 |
64 | - [GTVConv](https://graphneural.network/layers/convolution/#gtvconv) layer,
65 | - [AsymCheegerCutPool](https://graphneural.network/layers/pooling/#asymcheegercutpool) layer,
66 | - [Example script](https://github.com/danielegrattarola/spektral/blob/master/examples/other/node_clustering_tvgnn.py) to perform node clustering with TVGNN.
67 |
68 | # 📚 Citation
69 | If you use TVGNN in your research, please consider citing our work as
70 |
71 | ````bibtex
72 | @inproceedings{hansen2023total,
73 | title={Total variation graph neural networks},
74 | author={Hansen, Jonas Berg and Bianchi, Filippo Maria},
75 | booktitle={International Conference on Machine Learning},
76 | pages={12445--12468},
77 | year={2023},
78 | organization={PMLR}
79 | }
80 | ````
81 |
82 | # ⚙️ Technical details
83 | TVGNN consists of the GTVConv layer and the AsymCheegerCutPool layer.
84 |
85 | ### GTVConv
86 | The GTVConv layer is a *message-passing* layer that minimizes the $L_1$-norm of the difference between features of adjacent nodes. The $l$-th GTVConv layer updates the node features as
87 |
88 | $$\mathbf{X}^{(l+1)} = \sigma\left[ \left( \mathbf{I} - 2\delta \mathbf{L}_\Gamma^{(l)} \right) \mathbf{X}^{(l)}\mathbf{\Theta} \right] $$
89 |
90 | where $\sigma$ is a non-linearity, $\mathbf{\Theta}$ are the trainable weights of the layer, and $\delta$ is a hyperparameter. $\mathbf{L}^{(l)}_ \Gamma$ is a Laplacian defined as $\mathbf{L}^{(l)}_ \Gamma$ = $\mathbf{D}^{(l)}_ \Gamma - \mathbf{\Gamma}^{(l)}$, where $\mathbf{D}_\Gamma = \text{diag}(\mathbf{\Gamma} \boldsymbol{1})$ and
91 |
92 | $$ [\mathbf{\Gamma}]^{(l)}_ {ij} = \frac{a_ {ij}}{\texttt{max}\{ \lVert \boldsymbol{x}_i^{(l)} - \boldsymbol{x}_j^{(l)} \rVert_1, \epsilon \}}$$
93 |
94 | where $a_{ij}$ is the $ij$-th entry of the adjacency matrix, $\boldsymbol{x}_i^{(l)}$ is the feature vector of vertex $i$ at layer $l$, and $\epsilon$ is a small constant that avoids division by zero.
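
To make the update rule concrete, here is a minimal dense, single-graph sketch in plain PyTorch (names are illustrative; like the `GTVConv` implementations in this repository, it applies $\mathbf{\Theta}$ before computing $\mathbf{\Gamma}$):

```python
import torch

def gtv_conv_dense(x, a, theta, delta=1.0, eps=1e-3):
    x = x @ theta                                # linear transform X @ Theta
    abs_diff = torch.cdist(x, x, p=1)            # pairwise L1 distances, shape [N, N]
    gamma = a / torch.clamp(abs_diff, min=eps)   # Gamma: adjacency reweighted by feature distances
    lap = torch.diag(gamma.sum(dim=-1)) - gamma  # L_Gamma = D_Gamma - Gamma
    op = torch.eye(a.size(0)) - delta * lap      # adjusted operator I - delta * L_Gamma
    return torch.relu(op @ x)                    # sigma(...)
```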
95 |
96 | ### AsymCheegerCut
97 | The AsymCheegerCut is a *graph pooling* layer that internally contains an $\texttt{MLP}$ parametrized by $\mathbf{\Theta}_\text{MLP}$ and that computes:
98 | - a cluster assignment matrix $\mathbf{S} = \texttt{Softmax}(\texttt{MLP}(\mathbf{X}; \mathbf{\Theta}_\text{MLP})) \in \mathbb{R}^{N\times K}$, which maps the $N$ vertices into $K$ clusters,
99 | - an unsupervised loss $\mathcal{L} = \alpha_1 \mathcal{L}_ \text{GTV} + \alpha_2 \mathcal{L}_ \text{AN}$, where $\alpha_1$ and $\alpha_2$ are two hyperparameters,
100 | - the adjacency matrix and the vertex features of a coarsened graph
101 |
102 | $$\mathbf{A}^\text{pool} = \mathbf{S}^T \tilde{\mathbf{A}} \mathbf{S} \in\mathbb{R}^{K\times K}; \\ \mathbf{X}^\text{pool}=\mathbf{S}^T\mathbf{X} \in\mathbb{R}^{K\times F}.
103 | $$
104 |
105 | The term $\mathcal{L}_ \text{GTV}$ in the loss minimizes the graph total variation of the cluster assignments $\mathbf{S}$ and is defined as
106 |
107 | $$\mathcal{L}_ \text{GTV} = \frac{\mathcal{L}_ \text{GTV}^*}{2E} \in [0, 1],$$
108 |
109 | where $\mathcal{L}_ \text{GTV}^*$ = $\displaystyle\sum_{k=1}^K\sum_{i=1}^N \sum_{j=i}^N a_{i,j} |s_{i,k} - s_{j,k}|$, $s_{i,k}$ is the assignment of vertex $i$ to cluster $k$ and $E$ is the number of edges.
110 |
111 | The term $\mathcal{L}_\text{AN}$ encourages the partition to be balanced and is defined as
112 |
113 | $$\mathcal{L}_ {\text{AN}} = \frac{\beta - \mathcal{L}^*_ \text{AN}}{\beta} \in [0, 1],$$
114 |
115 | where $\mathcal{L}_ \text{AN}^* = \displaystyle\sum^K_{k=1} ||\boldsymbol{s}_ {:,k}$ - $\text{quant}_ \rho (\boldsymbol{s}_ {:,k})||_ {1, \rho}$.
116 | When $\rho = K-1$, $\beta = N\rho$.
117 | For other values of $\rho$, $\beta = N\rho\min(1, K/(\rho+1))$.
118 | $\text{quant}_ \rho(\boldsymbol{s}_ k)$ denotes the $\rho$-quantile of $\boldsymbol{s}_ k$ and $||\cdot||_ {1,\rho}$ denotes an asymmetric $\ell_1$ norm, which for a vector $\boldsymbol{x} \in \mathbb{R}^{N\times 1}$ is $||\boldsymbol{x}||_ {1,\rho}$ = $\displaystyle\sum^N_{i=1} |x_{i}|_ \rho$, where $|x_i|_ \rho = \rho x_i$ if $x_i\geq 0$ and $|x_i|_ \rho = -x_i$ if $x_i < 0$.
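
The two loss terms can be sketched for a single dense graph as follows (plain PyTorch; this assumes $\rho = K-1$, the value used by the pooling layers in this repository, and mirrors their `totvar_loss`/`balance_loss` methods):

```python
import torch

def tvgnn_losses(a, s):
    # L_GTV: graph total variation of the assignments, normalized by 2E
    n_edges = torch.count_nonzero(a)
    tv = (a * torch.cdist(s, s, p=1)).sum() / (2 * n_edges)

    # L_AN: asymmetric-norm balance term with rho = K - 1
    n, k = s.shape
    rho = k - 1
    quant = s.sort(dim=0, descending=True).values[n // k]  # rho-quantile of each cluster column
    d = s - quant
    asym = torch.where(d >= 0, rho * d, -d).sum()          # asymmetric l1 norm
    bal = (n * rho - asym) / (n * rho)
    return tv, bal
```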
119 |
--------------------------------------------------------------------------------
/img/classification.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FilippoMB/Total-variation-graph-neural-networks/c21b427460fda14000a820a541e9709a909b3005/img/classification.png
--------------------------------------------------------------------------------
/img/clustering.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FilippoMB/Total-variation-graph-neural-networks/c21b427460fda14000a820a541e9709a909b3005/img/clustering.png
--------------------------------------------------------------------------------
/img/sharp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FilippoMB/Total-variation-graph-neural-networks/c21b427460fda14000a820a541e9709a909b3005/img/sharp.png
--------------------------------------------------------------------------------
/pytorch/AsymCheegerCutPool.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple, Union
2 | import math
3 | import torch
4 | from torch import Tensor
5 | from torch_geometric.nn.models.mlp import Linear
6 | from torch_geometric.nn.resolver import activation_resolver
7 |
8 |
9 | class AsymCheegerCutPool(torch.nn.Module):
10 | r"""
11 |     The asymmetric Cheeger cut pooling layer from the `"Total Variation Graph
12 |     Neural Networks" <https://arxiv.org/abs/2211.06218>`_ paper.
13 |
14 | Args:
15 | k (int):
16 | Number of clusters or output nodes
17 | mlp_channels (int, list of int):
18 | Number of hidden units for each hidden layer in the MLP used to
19 | compute cluster assignments. First integer must match the number
20 | of input channels.
21 | mlp_activation (any):
22 | Activation function between hidden layers of the MLP.
23 | Must be compatible with `torch_geometric.nn.resolver`.
24 |         return_selection (bool):
25 |             Whether to return the selection matrix. Cannot be False
26 |             if `return_pooled_graph` is False. (default: :obj:`False`)
27 | return_pooled_graph (bool):
28 | Whether to return pooled node features and adjacency.
29 | Cannot be False if `return_selection` is False. (default: :obj:`True`)
30 | bias (bool):
31 | whether to add a bias term to the MLP layers. (default: :obj:`True`)
32 | totvar_coeff (float):
33 | Coefficient for graph total variation loss component. (default: :obj:`1.0`)
34 | balance_coeff (float):
35 | Coefficient for asymmetric norm loss component. (default: :obj:`1.0`)
36 | """
37 |
38 | def __init__(self,
39 | k: int,
40 | mlp_channels: Union[int, List[int]],
41 | mlp_activation="relu",
42 | return_selection: bool = False,
43 | return_pooled_graph: bool = True,
44 | bias: bool = True,
45 | totvar_coeff: float = 1.0,
46 | balance_coeff: float = 1.0,
47 | ):
48 | super().__init__()
49 |
50 | if not return_selection and not return_pooled_graph:
51 | raise ValueError("return_selection and return_pooled_graph can not both be False")
52 |
53 | if isinstance(mlp_channels, int):
54 | mlp_channels = [mlp_channels]
55 |
56 | act = activation_resolver(mlp_activation)
57 | in_channels = mlp_channels[0]
58 | self.mlp = torch.nn.Sequential()
59 | for channels in mlp_channels[1:]:
60 | self.mlp.append(Linear(in_channels, channels, bias=bias))
61 | in_channels = channels
62 | self.mlp.append(act)
63 |
64 |
65 | self.mlp.append(Linear(in_channels, k))
66 | self.k = k
67 | self.return_selection = return_selection
68 | self.return_pooled_graph = return_pooled_graph
69 | self.totvar_coeff = totvar_coeff
70 | self.balance_coeff = balance_coeff
71 |
72 | self.reset_parameters()
73 |
74 | def reset_parameters(self):
75 | for layer in self.mlp:
76 | if isinstance(layer, Linear):
77 |                 torch.nn.init.xavier_uniform_(layer.weight)
78 |                 if layer.bias is not None: torch.nn.init.zeros_(layer.bias)
79 |
80 | def forward(
81 | self,
82 | x: Tensor,
83 | adj: Tensor,
84 | mask: Optional[Tensor] = None,
85 |     ) -> Tuple[Tensor, ...]:
86 | r"""
87 | Args:
88 | x (Tensor):
89 | Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`
90 | with batch-size :math:`B`, (maximum) number of nodes :math:`N` for each graph,
91 | and feature dimension :math:`F`. Note that the cluster assignment matrix
92 | :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}` is
93 | being created within this method.
94 | adj (Tensor):
95 | Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
96 | mask (BoolTensor, optional):
97 | Mask matrix :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}`
98 | indicating the valid nodes for each graph. (default: :obj:`None`)
99 |
100 |         :rtype: (:class:`Tensor`, ...) with the selection matrix (if requested), the
101 |             pooled features and adjacency (if requested), and the two auxiliary losses.
102 | """
103 | x = x.unsqueeze(0) if x.dim() == 2 else x
104 | adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
105 |
106 | s = self.mlp(x)
107 | s = torch.softmax(s, dim=-1)
108 |
109 | batch_size, n_nodes, _ = x.size()
110 |
111 | if mask is not None:
112 | mask = mask.view(batch_size, n_nodes, 1).to(x.dtype)
113 | x, s = x * mask, s * mask
114 |
115 | # Pooled features and adjacency
116 | if self.return_pooled_graph:
117 | x_pool = torch.matmul(s.transpose(1, 2), x)
118 | adj_pool = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s)
119 |
120 | # Total variation loss
121 | tv_loss = self.totvar_coeff*torch.mean(self.totvar_loss(adj, s))
122 |
123 | # Balance loss
124 | bal_loss = self.balance_coeff*torch.mean(self.balance_loss(s))
125 |
126 | if self.return_selection and self.return_pooled_graph:
127 | return s, x_pool, adj_pool, tv_loss, bal_loss
128 | elif self.return_selection and not self.return_pooled_graph:
129 | return s, tv_loss, bal_loss
130 | else:
131 | return x_pool, adj_pool, tv_loss, bal_loss
132 |
133 | def totvar_loss(self, adj, s):
134 | l1_norm = torch.sum(torch.abs(s[..., None, :] - s[:, None, ...]), dim=-1)
135 |
136 | loss = torch.sum(adj * l1_norm, dim=(-1, -2))
137 |
138 | # Normalize loss
139 | n_edges = torch.count_nonzero(adj, dim=(-1, -2))
140 | loss *= 1 / (2 * n_edges)
141 |
142 | return loss
143 |
144 | def balance_loss(self, s):
145 | n_nodes = s.size()[-2]
146 |
147 | # k-quantile
148 | idx = int(math.floor(n_nodes / self.k))
149 | quant = torch.sort(s, dim=-2, descending=True)[0][:, idx, :] # shape [B, K]
150 |
151 | # Asymmetric l1-norm
152 | loss = s - torch.unsqueeze(quant, dim=1)
153 | loss = (loss >= 0) * (self.k - 1) * loss + (loss < 0) * loss * -1
154 | loss = torch.sum(loss, dim=(-1, -2)) # shape [B]
155 | loss = 1 / (n_nodes * (self.k - 1)) * (n_nodes * (self.k - 1) - loss)
156 |
157 | return loss
158 |
--------------------------------------------------------------------------------
/pytorch/GTVConv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Linear, Parameter
3 | from torch_geometric.nn import MessagePassing
4 | from torch import Tensor
5 | from torch_sparse import SparseTensor
6 | from torch_geometric.typing import Adj, OptTensor
7 | from torch_geometric.nn.inits import zeros
8 | from torch_geometric import utils
9 | from torch_scatter import scatter_add
10 | from torch_geometric.nn.resolver import activation_resolver
11 |
12 | def gtv_adj_weights(edge_index, edge_weight, num_nodes=None, flow="source_to_target", coeff=1.):
13 |
14 | fill_value = 0.
15 |
16 | assert flow in ["source_to_target", "target_to_source"]
17 |
18 | edge_index, tmp_edge_weight = utils.add_remaining_self_loops(
19 | edge_index, edge_weight, fill_value, num_nodes)
20 | assert tmp_edge_weight is not None
21 | edge_weight = tmp_edge_weight
22 |
23 | # Compute degrees
24 | row, col = edge_index[0], edge_index[1]
25 | idx = col if flow == "source_to_target" else row
26 | deg = scatter_add(edge_weight, idx, dim=0, dim_size=num_nodes)
27 |
28 | # Compute laplacian: L = D - A = -A + D
29 | edge_weight = -edge_weight
30 | edge_weight[row == col] += deg
31 |
32 | # Compute adjusted laplacian: L_adjusted = I - delta*L = -delta*L + I
33 | edge_weight *= -coeff
34 | edge_weight[row == col] += 1
35 |
36 | return edge_index, edge_weight
37 |
38 |
39 | class GTVConv(MessagePassing):
40 | r"""
41 |     The GTVConv layer from the `"Total Variation Graph Neural Networks"
42 |     <https://arxiv.org/abs/2211.06218>`_ paper
43 |
44 | Args:
45 | in_channels (int):
46 | Size of each input sample
47 | out_channels (int):
48 | Size of each output sample.
49 | bias (bool):
50 | If set to :obj:`False`, the layer will not learn
51 | an additive bias. (default: :obj:`True`)
52 | delta_coeff (float):
53 | Step size for gradient descent of GTV (default: :obj:`1.0`)
54 | eps (float):
55 | Small number used to numerically stabilize the computation of
56 | new adjacency weights. (default: :obj:`1e-3`)
57 | act (any):
58 | Activation function. Must be compatible with
59 | `torch_geometric.nn.resolver`.
60 | """
61 | def __init__(self, in_channels: int, out_channels: int, bias: bool = True,
62 | delta_coeff: float = 1., eps: float = 1e-3, act = "relu"):
63 | super().__init__(aggr='add', flow="target_to_source")
64 | self.weight = Parameter(torch.Tensor(in_channels, out_channels))
65 |
66 | self.delta_coeff = delta_coeff
67 | self.eps = eps
68 |
69 | self.act = activation_resolver(act)
70 |
71 | if bias:
72 | self.bias = Parameter(torch.Tensor(out_channels))
73 | else:
74 | self.register_parameter('bias', None)
75 |
76 | self.reset_parameters()
77 |
78 | def reset_parameters(self):
79 | torch.nn.init.kaiming_normal_(self.weight)
80 | zeros(self.bias)
81 |
82 | def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None, mask=None) -> Tensor:
83 |
84 | # Update node features
85 | x = x @ self.weight
86 |
87 | # Check if a dense adjacency is provided
88 | if isinstance(edge_index, Tensor) and edge_index.size(-1) == edge_index.size(-2):
89 | x = x.unsqueeze(0) if x.dim() == 2 else x
90 | edge_index = edge_index.unsqueeze(0) if edge_index.dim() == 2 else edge_index
91 | B, N, _ = edge_index.size()
92 |
93 | # Absolute differences between neighbouring nodes
94 | batch_idx, node_i, node_j = torch.nonzero(edge_index, as_tuple=True)
95 |             abs_diff = torch.sum(torch.abs(x[batch_idx, node_i, :] - x[batch_idx, node_j, :]), dim=-1)  # shape [total_edges]
96 |
97 | # Gamma matrix
98 | mod_adj = torch.clone(edge_index)
99 | mod_adj[batch_idx, node_i, node_j] /= torch.clamp(abs_diff, min=self.eps)
100 |
101 | # Compute Laplacian L=D-A
102 | deg = torch.sum(mod_adj, dim=-1)
103 | mod_adj = -mod_adj
104 | mod_adj[:, range(N), range(N)] += deg
105 |
106 | # Compute modified laplacian: L_adjusted = I - delta*L
107 | mod_adj = -self.delta_coeff * mod_adj
108 | mod_adj[:, range(N), range(N)] += 1
109 |
110 | out = torch.matmul(mod_adj, x)
111 |
112 | if self.bias is not None:
113 | out = out + self.bias
114 |
115 | if mask is not None:
116 | out = out * mask.view(B, N, 1).to(x.dtype)
117 |
118 | else:
119 | if isinstance(edge_index, SparseTensor):
120 | row, col, edge_weight = edge_index.coo()
121 | edge_index = torch.stack((row, col), dim=0)
122 | else:
123 | row, col = edge_index
124 |
125 | # Absolute differences between neighbouring nodes
126 |             abs_diff = torch.abs(x[row, :] - x[col, :])  # shape [E, out_channels]
127 |             abs_diff = abs_diff.sum(dim=1)  # shape [E]
128 |
129 | # Gamma matrix
130 | denom = torch.clamp(abs_diff, min=self.eps)
131 | if edge_weight is None:
132 | gamma_vals = 1 / denom # shape [E]
133 | else:
134 | gamma_vals = edge_weight / denom # shape [E]
135 |
136 | # Laplacian L=D-A
137 | lap_index, lap_weight = utils.get_laplacian(edge_index, gamma_vals)
138 |
139 | # Modified laplacian: I-delta*L
140 | lap_weight *= -self.delta_coeff
141 | mod_lap_index, mod_lap_weight = utils.add_self_loops(lap_index, lap_weight,
142 | fill_value=1., num_nodes=x.size(0))
143 |
144 | out = self.propagate(edge_index=mod_lap_index, x=x, edge_weight=mod_lap_weight, size=None)
145 |
146 | if self.bias is not None:
147 | out = out + self.bias
148 |
149 | return self.act(out)
150 |
151 | def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
152 | return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j
153 |
--------------------------------------------------------------------------------
/pytorch/classification.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | from collections import OrderedDict
3 | import numpy as np
4 | import torch
5 | from torch import Tensor
6 | import torch_geometric.transforms as transforms
7 | from torch_geometric.datasets import TUDataset
8 | from torch_geometric.nn import Sequential, Linear
9 | from torch_geometric.loader import DataLoader
10 | from torch_geometric.utils import to_dense_batch, to_dense_adj
11 | from sklearn.model_selection import StratifiedKFold, train_test_split
12 | from GTVConv import GTVConv
13 | from AsymCheegerCutPool import AsymCheegerCutPool
14 |
15 |
16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17 |
18 | ################################
19 | # CONFIG
20 | ################################
21 | mp_layers = 1
22 | mp_channels = 32
23 | mp_activation = "relu"
24 | delta_coeff = 2.0
25 |
26 | mlp_hidden_layers = 1
27 | mlp_hidden_channels = 32
28 | mlp_activation = "relu"
29 | totvar_coeff = 0.5
30 | balance_coeff = 0.5
31 |
32 | epochs = 100
33 | batch_size = 16
34 | learning_rate = 5e-4
35 | l2_reg_val = 0
36 | patience = 10
37 |
38 | results = {"acc_scores": []}
39 |
40 | ################################
41 | # LOAD DATASET
42 | ################################
43 | dataset_id = "PROTEINS"
44 |
45 | path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data', dataset_id)
46 | dataset = TUDataset(path, dataset_id, use_node_attr=True, cleaned=True)
47 |
48 | # Parameters
49 | N = max(graph.num_nodes for graph in dataset)
50 | n_out = dataset.num_classes # Dimension of target
51 |
52 | # Train/test split
53 | idxs = np.random.permutation(len(dataset))
54 | split_va, split_te = int(0.8 * len(dataset)), int(0.9 * len(dataset))
55 | idx_tr, idx_va, idx_te = np.split(idxs, [split_va, split_te])
56 | dataset_tr = dataset[torch.tensor(idx_tr).long()]
57 | dataset_va = dataset[torch.tensor(idx_va).long()]
58 | dataset_te = dataset[torch.tensor(idx_te).long()]
59 | loader_tr = DataLoader(dataset_tr, batch_size=batch_size, shuffle=True)
60 | loader_va = DataLoader(dataset_va, batch_size=batch_size, shuffle=False)
61 | loader_te = DataLoader(dataset_te, batch_size=batch_size, shuffle=False)
62 |
63 | ################################
64 | # MODEL
65 | ################################
66 | class ClassificationModel(torch.nn.Module):
67 | def __init__(self, n_out, mp1, pool1, mp2, pool2, mp3):
68 | super().__init__()
69 |
70 | self.mp1 = mp1
71 | self.pool1 = pool1
72 | self.mp2 = mp2
73 | self.pool2 = pool2
74 | self.mp3 = mp3
75 | self.output_layer = Linear(mp_channels, n_out)
76 |
77 |
78 | def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor, batch: Tensor):
79 |
80 | # 1st block
81 | x = self.mp1(x, edge_index, edge_weight)
82 | x, mask = to_dense_batch(x, batch)
83 | adj = to_dense_adj(edge_index, edge_attr=edge_weight, batch=batch)
84 | x, adj, tv1, bal1 = self.pool1(x, adj, mask=mask)
85 |
86 | # 2nd block
87 | x = self.mp2(x, edge_index=adj, edge_weight=None)
88 | x, adj, tv2, bal2 = self.pool2(x, adj)
89 |
90 | # 3rd block
91 | x = self.mp3(x, edge_index=adj, edge_weight=None)
92 | x = x.mean(dim=1) # global mean pooling
93 | x = self.output_layer(x)
94 |
95 | return x, tv1 + tv2, bal1 + bal2
96 |
97 |
98 | MP1 = [
99 | (GTVConv(dataset.num_features if i==0 else mp_channels,
100 | mp_channels,
101 | act=mp_activation,
102 | delta_coeff=delta_coeff),
103 | 'x, edge_index, edge_weight -> x')
104 | for i in range(mp_layers)]
105 |
106 | MP1 = Sequential('x, edge_index, edge_weight', MP1)
107 |
108 |
109 | Pool1 = AsymCheegerCutPool(int(N//2),
110 | mlp_channels=[mp_channels] +
111 | [mlp_hidden_channels for _ in range(mlp_hidden_layers)],
112 | mlp_activation=mlp_activation,
113 | totvar_coeff=totvar_coeff,
114 | balance_coeff=balance_coeff,
115 | return_selection=False,
116 | return_pooled_graph=True)
117 |
118 |
119 | MP2 = [
120 | (GTVConv(mp_channels,
121 | mp_channels,
122 | act=mp_activation,
123 | delta_coeff=delta_coeff),
124 | 'x, edge_index, edge_weight -> x')
125 | for _ in range(mp_layers)]
126 |
127 | MP2 = Sequential('x, edge_index, edge_weight', MP2)
128 |
129 |
130 | Pool2 = AsymCheegerCutPool(int(N//4),
131 | mlp_channels=[mp_channels] +
132 | [mlp_hidden_channels for _ in range(mlp_hidden_layers)],
133 | mlp_activation=mlp_activation,
134 | totvar_coeff=totvar_coeff,
135 | balance_coeff=balance_coeff,
136 | return_selection=False,
137 | return_pooled_graph=True)
138 |
139 |
140 | MP3 = [
141 | (GTVConv(mp_channels,
142 | mp_channels,
143 | act=mp_activation,
144 | delta_coeff=delta_coeff),
145 | 'x, edge_index, edge_weight -> x')
146 | for _ in range(mp_layers)]
147 |
148 | MP3 = Sequential('x, edge_index, edge_weight', MP3)
149 |
150 |
151 | # Setup model
152 | model = ClassificationModel(n_out,
153 | mp1=MP1,
154 | pool1=Pool1,
155 | mp2=MP2,
156 | pool2=Pool2,
157 | mp3=MP3).to(device)
158 |
159 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
160 | loss_fn = torch.nn.CrossEntropyLoss()
161 |
162 |
163 | ################################
164 | # TRAIN AND TEST
165 | ################################
166 |
167 | def train():
168 | model.train()
169 |
170 | for data in loader_tr:
171 | data.to(device)
172 | out, tv_loss, bal_loss = model(data.x, data.edge_index, data.edge_weight, data.batch)
173 | loss = tv_loss + bal_loss
174 | loss += loss_fn(out, data.y)
175 | loss.backward()
176 | optimizer.step()
177 | optimizer.zero_grad()
178 |
179 | @torch.no_grad()
180 | def test(loader):
181 | model.eval()
182 |
183 | correct = 0
184 | for data in loader:
185 | data.to(device)
186 | out, tv_loss, bal_loss = model(data.x, data.edge_index, data.edge_weight, data.batch)
187 | loss = tv_loss + bal_loss + loss_fn(out, data.y)
188 | pred = out.argmax(dim=1)
189 | correct += int((pred == data.y).sum())
190 | return loss, correct / len(loader.dataset)
191 |
192 |
193 | best_val_acc = 0
194 | patience_count = patience
195 | for epoch in range(1, epochs + 1):
196 | train()
197 | train_loss, train_acc = test(loader_tr)
198 | val_loss, val_acc = test(loader_va)
199 | test_loss, test_acc = test(loader_te)
200 |
201 | print(f"Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc: .4f}")
202 |
203 | if val_acc > best_val_acc:
204 | best_val_acc = val_acc
205 | test_loss_at_best_val = test_loss
206 | test_acc_at_best_val = test_acc
207 | patience_count = patience
208 | else:
209 | patience_count -= 1
210 | if patience_count == 0:
211 | break
212 |
213 | print("Test loss: {}. Test acc: {}".format())
214 |
215 |
--------------------------------------------------------------------------------
/pytorch/clustering.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(1, '../utils')
3 |
4 | import os.path as osp
5 | import torch
6 | from torch import Tensor
7 | import torch_geometric.transforms as transforms
8 | from torch_geometric.datasets import Planetoid, CitationFull
9 | from torch_geometric import utils
10 | from torch_geometric.nn import Sequential
11 | from sklearn.metrics.cluster import normalized_mutual_info_score as NMI
12 | from metrics import cluster_acc
13 | from GTVConv import GTVConv
14 | from AsymCheegerCutPool import AsymCheegerCutPool
15 |
16 | torch.manual_seed(1)
17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18 |
19 | ################################
20 | # CONFIG
21 | ################################
22 | dataset_id="Cora"
23 | mp_channels=512
24 | mp_layers=2
25 | mp_activation="elu"
26 | delta_coeff=0.311
27 | mlp_hidden_channels=256
28 | mlp_hidden_layers=1
29 | mlp_activation="relu"
30 | totvar_coeff=0.785
31 | balance_coeff=0.514
32 | learning_rate=1e-3
33 | epochs=500
34 |
35 | ################################
36 | # LOAD DATASET
37 | ################################
38 | path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data', dataset_id)
39 | if dataset_id in ["Cora", "CiteSeer", "PubMed"]:
40 | dataset = Planetoid(path, dataset_id, transform=transforms.NormalizeFeatures())
41 | elif dataset_id == "DBLP":
42 | dataset = CitationFull(path, dataset_id, transform=transforms.NormalizeFeatures())
43 |
44 | data = dataset[0]
45 | data = data.to(device)
46 |
47 | ############################################################################
48 | # MODEL
49 | ############################################################################
50 |
51 | class Net(torch.nn.Module):
52 |
53 | def __init__(self):
54 | super().__init__()
55 |
56 | # Message passing layers
57 | mp = [
58 | (GTVConv(dataset.num_features if i==0 else mp_channels,
59 | mp_channels,
60 | act=mp_activation,
61 | delta_coeff=delta_coeff),
62 | 'x, edge_index, edge_weight -> x')
63 | for i in range(mp_layers)]
64 |
65 | self.mp = Sequential('x, edge_index, edge_weight', mp)
66 |
67 | # Pooling layer
68 | self.pool = AsymCheegerCutPool(
69 | dataset.num_classes,
70 | mlp_channels=[mp_channels] + [mlp_hidden_channels for _ in range(mlp_hidden_layers)],
71 | mlp_activation=mlp_activation,
72 | totvar_coeff=totvar_coeff,
73 | balance_coeff=balance_coeff,
74 | return_selection=True,
75 | return_pooled_graph=False)
76 |
77 | def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor):
78 |
79 | # Propagate node features
80 | x = self.mp(x, edge_index, edge_weight)
81 |
82 | # Compute cluster assignment and obtain auxiliary losses
83 | adj = utils.to_dense_adj(edge_index, edge_attr=edge_weight)
84 | s, tv_loss, bal_loss = self.pool(x, adj)
85 |
86 | return s.squeeze(0), tv_loss, bal_loss
87 |
88 | model = Net().to(device)
89 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, eps=1e-7)
90 |
91 | ############################################################################
92 | # TRAINING
93 | ############################################################################
94 | def train():
95 | model.train()
96 | optimizer.zero_grad()
97 | _, tv_loss, bal_loss = model(data.x, data.edge_index, data.edge_weight)
98 | loss = tv_loss + bal_loss
99 | loss.backward()
100 | optimizer.step()
101 | return loss.item()
102 |
103 | @torch.no_grad()
104 | def test():
105 | model.eval()
106 | clust, _, _ = model(data.x, data.edge_index, data.edge_weight)
107 | return NMI(data.y.cpu(), clust.max(1)[1].cpu()), cluster_acc(data.y.cpu().numpy(), clust.max(1)[1].cpu().numpy())[0]
108 |
109 | patience = 50
110 | best_loss = float("inf")
111 | nmi_at_best_loss = 0
112 | acc_at_best_loss = 0
113 | for epoch in range(1, epochs+1):
114 | train_loss = train()
115 | nmi, acc = test()
116 | print(f"Epoch: {epoch:03d}, Loss: {train_loss:.4f}, NMI: {nmi:.3f}, ACC: {acc*100: .3f}")
117 | if train_loss < best_loss:
118 | best_loss = train_loss
119 | nmi_at_best_loss = nmi
120 | acc_at_best_loss = acc
121 | patience = 50
122 | else:
123 | patience -= 1
124 | if patience == 0:
125 | break
126 |
127 | print(f"NMI: {nmi_at_best_loss:.3f}, ACC: {acc_at_best_loss*100:.1f}")
--------------------------------------------------------------------------------
/pytorch/pytorch_environment.yml:
--------------------------------------------------------------------------------
1 | name: TVGNN-pytorch
2 | channels:
3 | - pyg
4 | - pytorch
5 | - nvidia
6 | - conda-forge
7 | - defaults
8 | dependencies:
9 | - _libgcc_mutex=0.1=main
10 | - _openmp_mutex=5.1=1_gnu
11 | - blas=1.0=mkl
12 | - brotlipy=0.7.0=py310h7f8727e_1002
13 | - bzip2=1.0.8=h7b6447c_0
14 | - ca-certificates=2022.9.24=ha878542_0
15 | - certifi=2022.9.24=pyhd8ed1ab_0
16 | - cffi=1.15.1=py310h5eee18b_2
17 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
18 | - cryptography=38.0.1=py310h9ce1e76_0
19 | - cuda=11.7.1=0
20 | - cuda-cccl=11.7.91=0
21 | - cuda-command-line-tools=11.7.1=0
22 | - cuda-compiler=11.7.1=0
23 | - cuda-cudart=11.7.99=0
24 | - cuda-cudart-dev=11.7.99=0
25 | - cuda-cuobjdump=11.7.91=0
26 | - cuda-cupti=11.7.101=0
27 | - cuda-cuxxfilt=11.7.91=0
28 | - cuda-demo-suite=11.8.86=0
29 | - cuda-documentation=11.8.86=0
30 | - cuda-driver-dev=11.7.99=0
31 | - cuda-gdb=11.8.86=0
32 | - cuda-libraries=11.7.1=0
33 | - cuda-libraries-dev=11.7.1=0
34 | - cuda-memcheck=11.8.86=0
35 | - cuda-nsight=11.8.86=0
36 | - cuda-nsight-compute=11.8.0=0
37 | - cuda-nvcc=11.7.99=0
38 | - cuda-nvdisasm=11.8.86=0
39 | - cuda-nvml-dev=11.7.91=0
40 | - cuda-nvprof=11.8.87=0
41 | - cuda-nvprune=11.7.91=0
42 | - cuda-nvrtc=11.7.99=0
43 | - cuda-nvrtc-dev=11.7.99=0
44 | - cuda-nvtx=11.7.91=0
45 | - cuda-nvvp=11.8.87=0
46 | - cuda-runtime=11.7.1=0
47 | - cuda-sanitizer-api=11.8.86=0
48 | - cuda-toolkit=11.7.1=0
49 | - cuda-tools=11.7.1=0
50 | - cuda-visual-tools=11.7.1=0
51 | - ffmpeg=4.3=hf484d3e_0
52 | - fftw=3.3.9=h27cfd23_1
53 | - freetype=2.12.1=h4a9f257_0
54 | - gds-tools=1.4.0.31=0
55 | - giflib=5.2.1=h7b6447c_0
56 | - gmp=6.2.1=h295c915_3
57 | - gnutls=3.6.15=he1e5248_0
58 | - idna=3.4=py310h06a4308_0
59 | - intel-openmp=2021.4.0=h06a4308_3561
60 | - jinja2=3.1.2=py310h06a4308_0
61 | - joblib=1.1.1=py310h06a4308_0
62 | - jpeg=9e=h7f8727e_0
63 | - lame=3.100=h7b6447c_0
64 | - lcms2=2.12=h3be6417_0
65 | - ld_impl_linux-64=2.38=h1181459_1
66 | - lerc=3.0=h295c915_0
67 | - libcublas=11.11.3.6=0
68 | - libcublas-dev=11.11.3.6=0
69 | - libcufft=10.9.0.58=0
70 | - libcufft-dev=10.9.0.58=0
71 | - libcufile=1.4.0.31=0
72 | - libcufile-dev=1.4.0.31=0
73 | - libcurand=10.3.0.86=0
74 | - libcurand-dev=10.3.0.86=0
75 | - libcusolver=11.4.1.48=0
76 | - libcusolver-dev=11.4.1.48=0
77 | - libcusparse=11.7.5.86=0
78 | - libcusparse-dev=11.7.5.86=0
79 | - libdeflate=1.8=h7f8727e_5
80 | - libffi=3.4.2=h6a678d5_6
81 | - libgcc-ng=11.2.0=h1234567_1
82 | - libgfortran-ng=11.2.0=h00389a5_1
83 | - libgfortran5=11.2.0=h1234567_1
84 | - libgomp=11.2.0=h1234567_1
85 | - libiconv=1.16=h7f8727e_2
86 | - libidn2=2.3.2=h7f8727e_0
87 | - libnpp=11.8.0.86=0
88 | - libnpp-dev=11.8.0.86=0
89 | - libnvjpeg=11.9.0.86=0
90 | - libnvjpeg-dev=11.9.0.86=0
91 | - libpng=1.6.37=hbc83047_0
92 | - libstdcxx-ng=11.2.0=h1234567_1
93 | - libtasn1=4.16.0=h27cfd23_0
94 | - libtiff=4.4.0=hecacb30_2
95 | - libunistring=0.9.10=h27cfd23_0
96 | - libuuid=1.41.5=h5eee18b_0
97 | - libwebp=1.2.4=h11a3e52_0
98 | - libwebp-base=1.2.4=h5eee18b_0
99 | - lz4-c=1.9.3=h295c915_1
100 | - markupsafe=2.1.1=py310h7f8727e_0
101 | - mkl=2021.4.0=h06a4308_640
102 | - mkl-service=2.4.0=py310h7f8727e_0
103 | - mkl_fft=1.3.1=py310hd6ae3a3_0
104 | - mkl_random=1.2.2=py310h00e6091_0
105 | - munkres=1.1.4=pyh9f0ad1d_0
106 | - ncurses=6.3=h5eee18b_3
107 | - nettle=3.7.3=hbbd107a_1
108 | - nsight-compute=2022.3.0.22=0
109 | - numpy=1.23.4=py310hd5efca6_0
110 | - numpy-base=1.23.4=py310h8e6c178_0
111 | - openh264=2.1.1=h4ff587b_0
112 | - openssl=1.1.1s=h7f8727e_0
113 | - pillow=9.2.0=py310hace64e9_1
114 | - pip=22.2.2=py310h06a4308_0
115 | - pycparser=2.21=pyhd3eb1b0_0
116 | - pyg=2.1.0=py310_torch_1.13.0_cu117
117 | - pyopenssl=22.0.0=pyhd3eb1b0_0
118 | - pyparsing=3.0.9=py310h06a4308_0
119 | - pysocks=1.7.1=py310h06a4308_0
120 | - python=3.10.8=h7a1cb2a_1
121 | - pytorch=1.13.0=py3.10_cuda11.7_cudnn8.5.0_0
122 | - pytorch-cluster=1.6.0=py310_torch_1.13.0_cu117
123 | - pytorch-cuda=11.7=h67b0de4_0
124 | - pytorch-mutex=1.0=cuda
125 | - pytorch-scatter=2.1.0=py310_torch_1.13.0_cu117
126 | - pytorch-sparse=0.6.15=py310_torch_1.13.0_cu117
127 | - readline=8.2=h5eee18b_0
128 | - requests=2.28.1=py310h06a4308_0
129 | - scikit-learn=1.1.3=py310h6a678d5_0
130 | - scipy=1.9.3=py310hd5efca6_0
131 | - setuptools=65.5.0=py310h06a4308_0
132 | - six=1.16.0=pyhd3eb1b0_1
133 | - sqlite=3.40.0=h5082296_0
134 | - threadpoolctl=2.2.0=pyh0d69192_0
135 | - tk=8.6.12=h1ccaba5_0
136 | - torchaudio=0.13.0=py310_cu117
137 | - torchvision=0.14.0=py310_cu117
138 | - tqdm=4.64.1=py310h06a4308_0
139 | - typing_extensions=4.3.0=py310h06a4308_0
140 | - tzdata=2022f=h04d1e81_0
141 | - urllib3=1.26.12=py310h06a4308_0
142 | - wheel=0.37.1=pyhd3eb1b0_0
143 | - xz=5.2.6=h5eee18b_0
144 | - zlib=1.2.13=h5eee18b_0
145 | - zstd=1.5.2=ha4553b6_0
146 |
--------------------------------------------------------------------------------
/tensorflow/AsymCheegerCutPool.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras import Sequential
3 | from tensorflow.keras.layers import Dense
4 | import tensorflow.keras.backend as K
5 | from spektral.layers import ops
6 | from spektral.layers.pooling.src import SRCPool
7 |
8 | class AsymCheegerCutPool(SRCPool):
9 | r"""
10 | An Asymmetric Cheeger Cut Pooling layer from the paper
11 | > [Clustering with Total Variation Graph Neural Networks](https://arxiv.org/abs/2211.06218)
12 | > Jonas Berg Hansen and Filippo Maria Bianchi
13 |
14 | **Mode**: single, batch
15 |
16 | This layer learns a soft clustering of the input graph as follows:
17 | $$
18 | \begin{align}
19 | \S &= \textrm{MLP}(\X); \\
20 | \X' &= \S^\top \X \\
21 | \A' &= \S^\top \A \S; \\
22 | \end{align}
23 | $$
24 | where \(\textrm{MLP}\) is a multi-layer perceptron with softmax output.
25 |
26 | The layer includes two auxiliary loss terms/components:
27 | A graph total variation loss given by
28 | $$
29 | L_\text{GTV} = \frac{1}{2E} \sum_{k=1}^K \sum_{i=1}^N \sum_{j=i}^N a_{i,j} |s_{i,k} - s_{j,k}|,
30 | $$
31 | where $$E$$ is the number of edges/links, $$K$$ is the number of clusters or output nodes, and $$N$$ is the number of nodes.
32 |
33 | An asymmetric norm term given by
34 | $$
35 | L_\text{AN} = \frac{N(K - 1) - \sum_{k=1}^K ||\s_{:,k} - \textrm{quant}_\rho (\s_{:,k})||_{1, \rho}}{N(K-1)}.
36 | $$
37 |
38 | The layer can be used without a supervised loss to compute node clustering by
39 | minimizing the two auxiliary losses.
40 |
41 | **Input**
42 |
43 | - Node features of shape `(batch, n_nodes_in, n_node_features)`;
44 | - Adjacency matrix of shape `(batch, n_nodes_in, n_nodes_in)`;
45 |
46 | **Output**
47 |
48 | - Reduced node features of shape `(batch, n_nodes_out, n_node_features)`;
49 | - If `return_selection=True`, the selection matrix of shape
50 | `(batch, n_nodes_in, n_nodes_out)`.
51 |
52 | **Arguments**
53 | - `k`: number of output nodes;
54 | - `mlp_hidden`: list of integers, number of hidden units for each hidden layer in
55 | the MLP used to compute cluster assignments (if `None`, the MLP has only one output
56 | layer);
57 | - `mlp_activation`: activation for the MLP layers;
58 | - `return_selection`: boolean, whether to return the selection matrix;
59 | - `use_bias`: use bias in the MLP;
60 | - `totvar_coeff`: coefficient for graph total variation loss component;
61 | - `balance_coeff`: coefficient for asymmetric norm loss component;
62 |     - `kernel_initializer`: initializer for the weights of the MLP;
63 |     - `bias_initializer`: initializer for the bias of the MLP;
64 |     - `kernel_regularizer`: regularization applied to the weights of the MLP;
65 |     - `bias_regularizer`: regularization applied to the bias of the MLP;
66 |     - `kernel_constraint` / `bias_constraint`: constraints applied to the weights and bias of the MLP;
67 | """
68 |
69 | def __init__(self,
70 | k,
71 | mlp_hidden=None,
72 | mlp_activation="relu",
73 | return_selection=False,
74 | use_bias=True,
75 | totvar_coeff=1.0,
76 | balance_coeff=1.0,
77 | kernel_initializer="glorot_uniform",
78 | bias_initializer="zeros",
79 | kernel_regularizer=None,
80 | bias_regularizer=None,
81 | kernel_constraint=None,
82 | bias_constraint=None,
83 | **kwargs
84 | ):
85 | super().__init__(
86 | k=k,
87 | mlp_hidden=mlp_hidden,
88 | mlp_activation=mlp_activation,
89 | return_selection=return_selection,
90 | use_bias=use_bias,
91 | kernel_initializer=kernel_initializer,
92 | bias_initializer=bias_initializer,
93 | kernel_regularizer=kernel_regularizer,
94 | bias_regularizer=bias_regularizer,
95 | kernel_constraint=kernel_constraint,
96 | bias_constraint=bias_constraint,
97 | **kwargs
98 | )
99 |
100 | self.k = k
101 | self.mlp_hidden = mlp_hidden if mlp_hidden else []
102 | self.mlp_activation = mlp_activation
103 | self.totvar_coeff = totvar_coeff
104 | self.balance_coeff = balance_coeff
105 |
106 | def build(self, input_shape):
107 | layer_kwargs = dict(
108 | kernel_initializer=self.kernel_initializer,
109 | bias_initializer=self.bias_initializer,
110 | kernel_regularizer=self.kernel_regularizer,
111 | bias_regularizer=self.bias_regularizer,
112 | kernel_constraint=self.kernel_constraint,
113 | bias_constraint=self.bias_constraint,
114 | )
115 | self.mlp = Sequential(
116 | [
117 | Dense(channels, self.mlp_activation, **layer_kwargs)
118 | for channels in self.mlp_hidden
119 | ]
120 | + [Dense(self.k, "softmax", **layer_kwargs)]
121 | )
122 |
123 | super().build(input_shape)
124 |
125 | def call(self, inputs, mask=None):
126 | x, a, i = self.get_inputs(inputs)
127 | return self.pool(x, a, i, mask=mask)
128 |
129 | def select(self, x, a, i, mask=None):
130 | s = self.mlp(x)
131 | if mask is not None:
132 | s *= mask[0]
133 |
134 | # Total variation loss
135 | cut_loss = self.totvar_loss(a, s)
136 | if K.ndim(a) == 3:
137 | cut_loss = K.mean(cut_loss)
138 | self.add_loss(self.totvar_coeff * cut_loss)
139 |
140 | # Asymmetric l1-norm loss
141 | bal_loss = self.balance_loss(s)
142 | if K.ndim(a) == 3:
143 | bal_loss = K.mean(bal_loss)
144 | self.add_loss(self.balance_coeff * bal_loss)
145 |
146 | return s
147 |
148 | def reduce(self, x, s, **kwargs):
149 | return ops.modal_dot(s, x, transpose_a=True)
150 |
151 | def connect(self, a, s, **kwargs):
152 | a_pool = ops.matmul_at_b_a(s, a)
153 |
154 | return a_pool
155 |
156 | def reduce_index(self, i, s, **kwargs):
157 | i_mean = tf.math.segment_mean(i, i)
158 | i_pool = ops.repeat(i_mean, tf.ones_like(i_mean) * self.k)
159 |
160 | return i_pool
161 |
162 | def totvar_loss(self, a, s):
163 | if K.is_sparse(a):
164 | index_i = a.indices[:, 0]
165 | index_j = a.indices[:, 1]
166 |
167 | n_edges = float(len(a.values))
168 |
169 | loss = tf.math.reduce_sum(a.values[:, tf.newaxis] *
170 | tf.math.abs(tf.gather(s, index_i) -
171 | tf.gather(s, index_j)),
172 | axis=(-2, -1))
173 |
174 | else:
175 | n_edges = tf.cast(tf.math.count_nonzero(
176 | a, axis=(-2, -1)), dtype=s.dtype)
177 | n_nodes = tf.shape(a)[-1]
178 | if K.ndim(a) == 3:
179 | loss = tf.math.reduce_sum(a * tf.math.reduce_sum(tf.math.abs(s[:, tf.newaxis, ...] -
180 | tf.repeat(s[..., tf.newaxis, :],
181 | n_nodes, axis=-2)), axis=-1),
182 | axis=(-2, -1))
183 | else:
184 | loss = tf.math.reduce_sum(a * tf.math.reduce_sum(tf.math.abs(s -
185 | tf.repeat(s[..., tf.newaxis, :],
186 | n_nodes, axis=-2)), axis=-1),
187 | axis=(-2, -1))
188 |
189 | loss *= 1 / (2 * n_edges)
190 |
191 | return loss
192 |
193 | def balance_loss(self, s):
194 | n_nodes = tf.cast(tf.shape(s, out_type=tf.int32)[-2], s.dtype)
195 |
196 | # k-quantile
197 | idx = tf.cast(tf.math.floor(n_nodes / self.k) + 1, dtype=tf.int32)
198 | med = tf.math.top_k(tf.linalg.matrix_transpose(s),
199 | k=idx).values[..., -1]
200 | # Asymmetric l1-norm
201 | if K.ndim(s) == 2:
202 | loss = s - med
203 | else:
204 | loss = s - med[:, tf.newaxis, ...]
205 | loss = ((tf.cast(loss >= 0, loss.dtype) * (self.k - 1) * loss) +
206 | (tf.cast(loss < 0, loss.dtype) * loss * -1.))
207 | loss = tf.math.reduce_sum(loss, axis=(-2, -1))
208 | loss = 1 / (n_nodes * (self.k - 1)) * (n_nodes * (self.k - 1) - loss)
209 |
210 | return loss
211 |
212 | def get_config(self):
213 | config = {
214 | "k": self.k,
215 | "mlp_hidden": self.mlp_hidden,
216 | "mlp_activation": self.mlp_activation,
217 | "totvar_coeff": self.totvar_coeff,
218 | "balance_coeff": self.balance_coeff
219 | }
220 | base_config = super().get_config()
221 | return {**base_config, **config}
222 |
--------------------------------------------------------------------------------
/tensorflow/GTVConv.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras import backend as K
3 | from spektral.layers import ops
4 | from spektral.layers.convolutional.conv import Conv
5 |
6 |
7 | class GTVConv(Conv):
8 | r"""
9 | A graph total variation convolutional layer (GTVConv) from the paper
10 |
11 | > [Clustering with Total Variation Graph Neural Networks](https://arxiv.org/abs/2211.06218)
12 | > Jonas Berg Hansen and Filippo Maria Bianchi
13 |
14 | **Mode**: single, batch
15 |
16 | This layer computes
17 | $$
18 | \X' = \sigma\left[\left(\I - \delta{\hat{\Lb}_\mathbf{\Gamma}}\right) \X \W \right]
19 | $$
20 |
21 | **Input**
22 |
23 | - Node features of shape `(batch, n_nodes, n_node_features)`;
24 | - Adjacency matrix of shape `(batch, n_nodes, n_nodes)`;
25 |
26 | **Output**
27 |
28 | - Node features with the same shape as the input, but with the last
29 | dimension changed to `channels`.
30 |
31 | **Arguments**
32 |
33 | - `channels`: number of output channels;
34 | - `delta_coeff`: step size for gradient descent of GTV
35 | - `epsilon`: small number used to numerically stabilize the computation of new adjacency weights
36 | - `activation`: activation function;
37 | - `use_bias`: bool, add a bias vector to the output;
38 | - `kernel_initializer`: initializer for the weights;
39 | - `bias_initializer`: initializer for the bias vector;
40 | - `kernel_regularizer`: regularization applied to the weights;
41 | - `bias_regularizer`: regularization applied to the bias vector;
42 | - `activity_regularizer`: regularization applied to the output;
43 | - `kernel_constraint`: constraint applied to the weights;
44 | - `bias_constraint`: constraint applied to the bias vector.
45 |
46 | """
47 |
48 | def __init__(
49 | self,
50 | channels,
51 | delta_coeff=1.,
52 | epsilon=1e-3,
53 | activation=None,
54 | use_bias=True,
55 | kernel_initializer="he_normal",
56 | bias_initializer="zeros",
57 | kernel_regularizer=None,
58 | bias_regularizer=None,
59 | activity_regularizer=None,
60 | kernel_constraint=None,
61 | bias_constraint=None,
62 | **kwargs
63 | ):
64 | super().__init__(
65 | activation=activation,
66 | use_bias=use_bias,
67 | kernel_initializer=kernel_initializer,
68 | bias_initializer=bias_initializer,
69 | kernel_regularizer=kernel_regularizer,
70 | bias_regularizer=bias_regularizer,
71 | activity_regularizer=activity_regularizer,
72 | kernel_constraint=kernel_constraint,
73 | bias_constraint=bias_constraint,
74 | **kwargs
75 | )
76 | self.channels = channels
77 | self.delta_coeff = delta_coeff
78 | self.epsilon = epsilon
79 |
80 | def build(self, input_shape):
81 | assert len(input_shape) >= 2
82 | input_dim = input_shape[0][-1]
83 | self.kernel = self.add_weight(
84 | shape=(input_dim, self.channels),
85 | initializer=self.kernel_initializer,
86 | name="kernel",
87 | regularizer=self.kernel_regularizer,
88 | constraint=self.kernel_constraint,
89 | )
90 | if self.use_bias:
91 | self.bias = self.add_weight(
92 | shape=(self.channels,),
93 | initializer=self.bias_initializer,
94 | name="bias",
95 | regularizer=self.bias_regularizer,
96 | constraint=self.bias_constraint,
97 | )
98 | self.built = True
99 |
100 | def call(self, inputs, mask=None):
101 | x, a = inputs
102 |
103 | mode = ops.autodetect_mode(x, a)
104 |
105 | # Update node features
106 | x = K.dot(x, self.kernel)
107 |
108 | if mode == ops.modes.SINGLE:
109 | output = self._call_single(x, a)
110 |
111 | elif mode == ops.modes.BATCH:
112 | output = self._call_batch(x, a)
113 |
114 | if self.use_bias:
115 | output = K.bias_add(output, self.bias)
116 |
117 | if mask is not None:
118 | output *= mask[0]
119 |
120 | output = self.activation(output)
121 |
122 | return output
123 |
124 | def _call_single(self, x, a):
125 | if K.is_sparse(a):
126 | index_i = a.indices[:, 0]
127 | index_j = a.indices[:, 1]
128 |
129 | n_nodes = tf.shape(a, out_type=index_i.dtype)[0]
130 |
131 | # Compute absolute differences between neighbouring nodes
132 | abs_diff = tf.math.abs(tf.transpose(tf.gather(x, index_i)) -
133 | tf.transpose(tf.gather(x, index_j)))
134 | abs_diff = tf.math.reduce_sum(abs_diff, axis=0)
135 |
136 | # Compute new adjacency matrix
137 | gamma = tf.sparse.map_values(tf.multiply,
138 | a,
139 | 1 / tf.math.maximum(abs_diff, self.epsilon))
140 |
141 | # Compute degree matrix from gamma matrix
142 | d_gamma = tf.sparse.SparseTensor(tf.stack([tf.range(n_nodes)] * 2, axis=1),
143 | tf.sparse.reduce_sum(gamma, axis=-1),
144 | [n_nodes, n_nodes])
145 |
146 |             # Compute laplacian: L = D_gamma - Gamma
147 | l = tf.sparse.add(d_gamma, tf.sparse.map_values(
148 | tf.multiply, gamma, -1.))
149 |
150 |             # Compute adjusted laplacian: L_adjusted = I - delta*L
151 | l = tf.sparse.add(tf.sparse.eye(n_nodes), tf.sparse.map_values(
152 | tf.multiply, l, -self.delta_coeff))
153 |
154 | # Aggregate features with adjusted laplacian
155 | output = ops.modal_dot(l, x)
156 |
157 | else:
158 | n_nodes = tf.shape(a)[-1]
159 |
160 | abs_diff = tf.math.abs(x[:, tf.newaxis, :] - x)
161 | abs_diff = tf.reduce_sum(abs_diff, axis=-1)
162 |
163 | gamma = a / tf.math.maximum(abs_diff, self.epsilon)
164 |
165 | degrees = tf.math.reduce_sum(gamma, axis=-1)
166 | l = -gamma
167 | l = tf.linalg.set_diag(l, degrees - tf.linalg.diag_part(gamma))
168 | l = tf.eye(n_nodes) - self.delta_coeff * l
169 |
170 | output = tf.matmul(l, x)
171 |
172 | return output
173 |
174 | def _call_batch(self, x, a):
175 | n_nodes = tf.shape(a)[-1]
176 |
177 | abs_diff = tf.reduce_sum(tf.math.abs(tf.expand_dims(x, 2) -
178 | tf.expand_dims(x, 1)), axis = -1)
179 |
180 | gamma = a / tf.math.maximum(abs_diff, self.epsilon)
181 |
182 | degrees = tf.math.reduce_sum(gamma, axis=-1)
183 | l = -gamma
184 | l = tf.linalg.set_diag(l, degrees - tf.linalg.diag_part(gamma))
185 | l = tf.eye(n_nodes) - self.delta_coeff * l
186 |
187 | output = tf.matmul(l, x)
188 |
189 | return output
190 |
--------------------------------------------------------------------------------
/tensorflow/classification.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from tensorflow.keras import Model
4 | from tensorflow.keras.callbacks import EarlyStopping
5 | from tensorflow.keras.regularizers import L2
6 | from tensorflow.keras.layers import Dense
7 | from spektral.data.loaders import BatchLoader
8 | from spektral.datasets import TUDataset
9 | from spektral.layers import GraphMasking
10 | from spektral.layers.pooling import GlobalAvgPool
11 | from GTVConv import GTVConv
12 | from AsymCheegerCutPool import AsymCheegerCutPool
13 |
14 |
15 | ################################
16 | # CONFIG
17 | ################################
18 | mp_layers = 1
19 | mp_channels = 32
20 | mp_activation = "relu"
21 | delta_coeff = 2.0
22 |
23 | mlp_hidden_layers = 1
24 | mlp_hidden_channels = 32
25 | mlp_activation = "relu"
26 | totvar_coeff = 0.5
27 | balance_coeff = 0.5
28 |
29 | batch_size = 16
30 | l2_reg_val = 0
31 | learning_rate = 5e-4
32 | epochs = 100
33 | patience = 10
34 |
35 |
36 | ################################
37 | # LOAD DATASET
38 | ################################
39 | dataset = TUDataset("PROTEINS", clean=True)
40 |
41 | # Parameters
42 | N = max(g.n_nodes for g in dataset)
43 | n_out = dataset.n_labels # Dimension of the target
44 |
45 | # Train/test split
46 | idxs = np.random.permutation(len(dataset))
47 | split_va, split_te = int(0.8 * len(dataset)), int(0.9 * len(dataset))
48 | idx_tr, idx_va, idx_te = np.split(idxs, [split_va, split_te])
49 | dataset_tr = dataset[idx_tr]
50 | dataset_va = dataset[idx_va]
51 | dataset_te = dataset[idx_te]
52 | loader_tr = BatchLoader(dataset_tr, batch_size=batch_size, mask=True)
53 | loader_va = BatchLoader(dataset_va, batch_size=batch_size, shuffle=False, mask=True)
54 | loader_te = BatchLoader(dataset_te, batch_size=batch_size, shuffle=False, mask=True)
55 |
56 |
57 | ################################
58 | # MODEL
59 | ################################
60 | class ClassificationModel(Model):
61 | def __init__(self, n_out, mp1, pool1, mp2, pool2, mp3):
62 | super().__init__()
63 |
64 | self.mask = GraphMasking()
65 | self.mp1 = mp1
66 | self.pool1 = pool1
67 | self.mp2 = mp2
68 | self.pool2 = pool2
69 | self.mp3 = mp3
70 | self.global_pool = GlobalAvgPool()
71 | self.output_layer = Dense(n_out, activation = "softmax")
72 |
73 | def call(self, inputs):
74 |
75 | x, a = inputs
76 | x = self.mask(x)
77 | out = x
78 |
79 | # 1st block
80 | for _mp in self.mp1:
81 | out = _mp([out, a])
82 | out, a_pool = self.pool1([out, a])
83 |
84 | # 2nd block
85 | for _mp in self.mp2:
86 | out = _mp([out, a_pool])
87 | out, a_pool = self.pool2([out, a_pool])
88 |
89 | # 3rd block
90 | for _mp in self.mp3:
91 | out = _mp([out, a_pool])
92 | out = self.global_pool(out)
93 | out = self.output_layer(out)
94 |
95 | return out
96 |
97 |
98 | MP1 = [GTVConv(mp_channels,
99 | delta_coeff=delta_coeff,
100 | activation=mp_activation,
101 | kernel_regularizer=L2(l2_reg_val))
102 | for _ in range(mp_layers)]
103 |
104 | Pool1 = AsymCheegerCutPool(int(N//2),
105 | mlp_hidden=[mlp_hidden_channels
106 | for _ in range(mlp_hidden_layers)],
107 | mlp_activation=mlp_activation,
108 | totvar_coeff=totvar_coeff,
109 | balance_coeff=balance_coeff,
110 | kernel_regularizer=L2(l2_reg_val))
111 |
112 | MP2 = [GTVConv(mp_channels,
113 | delta_coeff=delta_coeff,
114 | activation=mp_activation,
115 | kernel_regularizer=L2(l2_reg_val))
116 | for _ in range(mp_layers)]
117 |
118 | Pool2 = AsymCheegerCutPool(int(N//4),
119 | mlp_hidden=[mlp_hidden_channels
120 | for _ in range(mlp_hidden_layers)],
121 | mlp_activation=mlp_activation,
122 | totvar_coeff=totvar_coeff,
123 | balance_coeff=balance_coeff,
124 | kernel_regularizer=L2(l2_reg_val))
125 |
126 | MP3 = [GTVConv(mp_channels,
127 | delta_coeff=delta_coeff,
128 | activation=mp_activation,
129 | kernel_regularizer=L2(l2_reg_val))
130 | for _ in range(mp_layers)]
131 |
132 |
133 | # Compile the model
134 | model = ClassificationModel(
135 | n_out,
136 | mp1=MP1,
137 | pool1=Pool1,
138 | mp2=MP2,
139 | pool2=Pool2,
140 | mp3=MP3)
141 |
142 | opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)
143 | model.compile(optimizer=opt, loss="categorical_crossentropy", metrics=["acc"])
144 |
145 | ################################
146 | # TRAIN AND TEST
147 | ################################
148 |
149 | model.fit(
150 | loader_tr.load(),
151 | steps_per_epoch=loader_tr.steps_per_epoch,
152 | epochs=epochs,
153 |     validation_data=loader_va.load(),
154 | validation_steps=loader_va.steps_per_epoch,
155 | callbacks=[EarlyStopping(patience=patience, restore_best_weights=True)],
156 |     verbose=2)
157 |
158 | loss_te, acc_te = model.evaluate(loader_te.load(), steps=loader_te.steps_per_epoch)
159 | print("Test loss: {}. Test acc: {}".format(loss_te, acc_te))
160 |
--------------------------------------------------------------------------------
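Before training, a quick shape sanity check can help when adapting this script to other TUDataset graphs. The sketch below is our addition, not part of the script; it assumes Spektral's BatchLoader yields one ((x, a), y) batch per next() call, which holds for spektral 1.2:

inputs_b, y_b = next(loader_tr)  # x: (batch, N, F+1) with the mask column, a: (batch, N, N)
preds = model(inputs_b)          # builds the weights and returns class probabilities
print(preds.shape, y_b.shape)    # both (batch_size, n_out)

This consumes one batch from the training generator, which loops indefinitely, so it does not affect the subsequent fit.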
/tensorflow/clustering.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(1, '../utils')
3 |
4 | import numpy as np
5 | from tqdm import tqdm
6 | from sklearn.metrics.cluster import normalized_mutual_info_score
7 | import tensorflow as tf
8 | from tensorflow.keras import Model
9 | from spektral.utils.sparse import sp_matrix_to_sp_tensor
10 | from spektral.datasets.citation import Citation
11 | from spektral.datasets import DBLP
12 | from GTVConv import GTVConv
13 | from AsymCheegerCutPool import AsymCheegerCutPool
14 | from metrics import cluster_acc
15 |
16 | tf.random.set_seed(1)
17 |
18 | ################################
19 | # CONFIG
20 | ################################
21 | dataset_id = "cora"
22 | mp_channels = 512
23 | mp_layers = 2
24 | mp_activation = "elu"
25 | delta_coeff = 0.311
26 | mlp_hidden_channels = 256
27 | mlp_hidden_layers = 1
28 | mlp_activation = "relu"
29 | totvar_coeff = 0.785
30 | balance_coeff = 0.514
31 | learning_rate = 1e-3
32 | epochs = 500
33 |
34 | ################################
35 | # LOAD DATASET
36 | ################################
37 | if dataset_id in ["cora", "citeseer", "pubmed"]:
38 | dataset = Citation(dataset_id, normalize_x=True)
39 | elif dataset_id == "dblp":
40 | dataset = DBLP(normalize_x=True)
41 | X = dataset.graphs[0].x
42 | A = dataset.graphs[0].a
43 | Y = dataset.graphs[0].y
44 | y = np.argmax(Y, axis=-1)
45 | n_clust = Y.shape[-1]
46 |
47 | ################################
48 | # MODEL
49 | ################################
50 | class ClusteringModel(Model):
51 | """
52 | Defines the general model structure
53 | """
54 |
55 | def __init__(self, aggr, pool):
56 | super().__init__()
57 |
58 | self.mp = aggr
59 | self.pool = pool
60 |
61 | def call(self, inputs):
62 | x, a = inputs
63 |
64 | out = x
65 | for _mp in self.mp:
66 | out = _mp([out, a])
67 |
68 | _, _, s_pool = self.pool([out, a])
69 |
70 | return s_pool
71 |
72 | # Define the message-passing layers
73 | MP_layers = [GTVConv(mp_channels,
74 | delta_coeff=delta_coeff,
75 | activation=mp_activation)
76 | for _ in range(mp_layers)]
77 |
78 | # Define the pooling layer
79 | pool_layer = AsymCheegerCutPool(n_clust,
80 | mlp_hidden=[mlp_hidden_channels for _ in range(mlp_hidden_layers)],
81 | mlp_activation=mlp_activation,
82 | totvar_coeff=totvar_coeff,
83 | balance_coeff=balance_coeff,
84 | return_selection=True)
85 |
86 | # Instantiate model and optimizer
87 | model = ClusteringModel(aggr=MP_layers, pool=pool_layer)
88 | opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)
89 |
90 | ################################
91 | # TRAINING
92 | ################################
93 | @tf.function
94 | def train_step(model, inputs, labels):  # labels are unused: training is unsupervised
95 | with tf.GradientTape() as tape:
96 | _ = model(inputs, training=True)
97 | loss = sum(model.losses)
98 | gradients = tape.gradient(loss, model.trainable_variables)
99 | opt.apply_gradients(zip(gradients, model.trainable_variables))
100 | return model.losses
101 |
102 | A = sp_matrix_to_sp_tensor(A)
103 | inputs = [X, A]
104 | loss_history = []
105 |
106 | # Training loop
107 | for _ in tqdm(range(epochs)):
108 | outs = train_step(model, inputs, Y)
109 |     loss_history.append([loss.numpy()
110 |                          for loss in outs])
111 |
112 | ################################
113 | # INFERENCE
114 | ################################
115 | S_ = model(inputs, training=False)
116 | s = np.argmax(S_, axis=-1)
117 | nmi = normalized_mutual_info_score(y, s)
118 | acc, _, _ = cluster_acc(y, s)
119 | print("NMI: {:.3f}, ACC: {:.3f}".format(nmi, acc))
--------------------------------------------------------------------------------
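Training in this script is fully unsupervised: AsymCheegerCutPool registers its total-variation and balance terms through add_loss, so train_step simply sums model.losses. If intermediate cluster quality is of interest, the loop can be extended to log the loss components together with NMI; a minimal sketch under the same assumptions:

for epoch in range(epochs):
    losses = train_step(model, inputs, Y)
    loss_history.append([loss.numpy() for loss in losses])
    if epoch % 50 == 0:  # occasionally check cluster quality
        s_tmp = np.argmax(model(inputs, training=False), axis=-1)
        print("epoch {}: losses {}, NMI {:.3f}".format(
            epoch, [float(loss) for loss in losses],
            normalized_mutual_info_score(y, s_tmp)))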
/tensorflow/tf_environment.yml:
--------------------------------------------------------------------------------
1 | name: TVGNN-tf
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=conda_forge
7 | - _openmp_mutex=4.5=2_gnu
8 | - abseil-cpp=20210324.2=h9c3ff4c_0
9 | - absl-py=1.2.0=pyhd8ed1ab_0
10 | - aiohttp=3.8.3=py310h5764c6d_0
11 | - aiosignal=1.2.0=pyhd8ed1ab_0
12 | - astunparse=1.6.3=pyhd8ed1ab_0
13 | - async-timeout=4.0.2=pyhd8ed1ab_0
14 | - attrs=22.1.0=pyh71513ae_1
15 | - blinker=1.4=py_1
16 | - brotlipy=0.7.0=py310h5764c6d_1004
17 | - bzip2=1.0.8=h7f98852_4
18 | - c-ares=1.18.1=h7f98852_0
19 | - ca-certificates=2022.9.14=ha878542_0
20 | - cached-property=1.5.2=hd8ed1ab_1
21 | - cached_property=1.5.2=pyha770c72_1
22 | - cachetools=5.2.0=pyhd8ed1ab_0
23 | - certifi=2022.9.14=pyhd8ed1ab_0
24 | - cffi=1.15.1=py310h255011f_0
25 | - charset-normalizer=2.1.1=pyhd8ed1ab_0
26 | - click=8.1.3=py310hff52083_0
27 | - cryptography=37.0.4=py310h597c629_0
28 | - cudatoolkit=11.7.0=hd8887f6_10
29 | - cudnn=8.4.1.50=hed8a83a_0
30 | - frozenlist=1.3.1=py310h5764c6d_0
31 | - gast=0.5.3=pyhd8ed1ab_0
32 | - giflib=5.2.1=h36c2ea0_2
33 | - google-auth=2.11.1=pyh1a96a4e_0
34 | - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0
35 | - google-pasta=0.2.0=pyh8c360ce_0
36 | - grpc-cpp=1.45.2=h9d3bbbb_5
37 | - grpcio=1.45.0=py310h44b9e0c_0
38 | - h5py=3.7.0=nompi_py310h416281c_101
39 | - hdf5=1.12.2=nompi_h2386368_100
40 | - icu=70.1=h27087fc_0
41 | - idna=3.4=pyhd8ed1ab_0
42 | - importlib-metadata=4.11.4=py310hff52083_0
43 | - jpeg=9e=h166bdaf_2
44 | - keras=2.8.0=pyhd8ed1ab_0
45 | - keras-preprocessing=1.1.2=pyhd8ed1ab_0
46 | - keyutils=1.6.1=h166bdaf_0
47 | - krb5=1.19.3=h3790be6_0
48 | - ld_impl_linux-64=2.36.1=hea4e1c9_2
49 | - libblas=3.9.0=16_linux64_openblas
50 | - libcblas=3.9.0=16_linux64_openblas
51 | - libcurl=7.83.1=h7bff187_0
52 | - libedit=3.1.20191231=he28a2e2_2
53 | - libev=4.33=h516909a_1
54 | - libffi=3.4.2=h7f98852_5
55 | - libgcc-ng=12.1.0=h8d9b700_16
56 | - libgfortran-ng=12.1.0=h69a702a_16
57 | - libgfortran5=12.1.0=hdcd56e2_16
58 | - libgomp=12.1.0=h8d9b700_16
59 | - liblapack=3.9.0=16_linux64_openblas
60 | - libnghttp2=1.47.0=hdcd2b5c_1
61 | - libnsl=2.0.0=h7f98852_0
62 | - libopenblas=0.3.21=pthreads_h78a6416_3
63 | - libpng=1.6.38=h753d276_0
64 | - libprotobuf=3.20.1=h6239696_4
65 | - libsqlite=3.39.3=h753d276_0
66 | - libssh2=1.10.0=haa6b8db_3
67 | - libstdcxx-ng=12.1.0=ha89aaad_16
68 | - libuuid=2.32.1=h7f98852_1000
69 | - libzlib=1.2.12=h166bdaf_3
70 | - markdown=3.4.1=pyhd8ed1ab_0
71 | - markupsafe=2.1.1=py310h5764c6d_1
72 | - multidict=6.0.2=py310h5764c6d_1
73 | - nccl=2.14.3.1=h0800d71_0
74 | - ncurses=6.3=h27087fc_1
75 | - numpy=1.23.3=py310h53a5b5f_0
76 | - oauthlib=3.2.1=pyhd8ed1ab_0
77 | - openssl=1.1.1q=h166bdaf_0
78 | - opt_einsum=3.3.0=pyhd8ed1ab_1
79 | - pip=22.2.2=pyhd8ed1ab_0
80 | - protobuf=3.20.1=py310hd8f1fbe_0
81 | - pyasn1=0.4.8=py_0
82 | - pyasn1-modules=0.2.7=py_0
83 | - pycparser=2.21=pyhd8ed1ab_0
84 | - pyjwt=2.5.0=pyhd8ed1ab_0
85 | - pyopenssl=22.0.0=pyhd8ed1ab_1
86 | - pysocks=1.7.1=pyha2e5f31_6
87 | - python=3.10.6=h582c2e5_0_cpython
88 | - python-flatbuffers=2.0=pyhd8ed1ab_0
89 | - python_abi=3.10=2_cp310
90 | - pyu2f=0.1.5=pyhd8ed1ab_0
91 | - re2=2022.06.01=h27087fc_0
92 | - readline=8.1.2=h0f457ee_0
93 | - requests=2.28.1=pyhd8ed1ab_1
94 | - requests-oauthlib=1.3.1=pyhd8ed1ab_0
95 | - rsa=4.9=pyhd8ed1ab_0
96 | - scipy=1.9.1=py310hdfbd76f_0
97 | - setuptools=65.3.0=pyhd8ed1ab_1
98 | - six=1.16.0=pyh6c4a22f_0
99 | - snappy=1.1.9=hbd366e4_1
100 | - sqlite=3.39.3=h4ff8645_0
101 | - tensorboard=2.8.0=pyhd8ed1ab_1
102 | - tensorboard-data-server=0.6.0=py310h597c629_2
103 | - tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0
104 | - tensorflow=2.8.1=cuda112py310he87a039_0
105 | - tensorflow-base=2.8.1=cuda112py310h666ff7d_0
106 | - tensorflow-estimator=2.8.1=cuda112py310h2fa73eb_0
107 | - tensorflow-gpu=2.8.1=cuda112py310h0bbbad9_0
108 | - termcolor=2.0.1=pyhd8ed1ab_1
109 | - tk=8.6.12=h27826a3_0
110 | - typing-extensions=4.3.0=hd8ed1ab_0
111 | - typing_extensions=4.3.0=pyha770c72_0
112 | - tzdata=2022c=h191b570_0
113 | - urllib3=1.26.11=pyhd8ed1ab_0
114 | - werkzeug=2.2.2=pyhd8ed1ab_0
115 | - wheel=0.37.1=pyhd8ed1ab_0
116 | - wrapt=1.14.1=py310h5764c6d_0
117 | - xz=5.2.6=h166bdaf_0
118 | - yarl=1.7.2=py310h5764c6d_2
119 | - zipp=3.8.1=pyhd8ed1ab_0
120 | - zlib=1.2.12=h166bdaf_3
121 | - pip:
122 | - docker-pycreds==0.4.0
123 | - gitdb==4.0.9
124 | - gitpython==3.1.27
125 | - joblib==1.2.0
126 | - lxml==4.9.1
127 | - networkx==2.8.6
128 | - pandas==1.5.0
129 | - pathtools==0.1.2
130 | - promise==2.3
131 | - psutil==5.9.2
132 | - python-dateutil==2.8.2
133 | - pytz==2022.2.1
134 | - pyyaml==6.0
135 | - scikit-learn==1.1.2
136 | - sentry-sdk==1.9.8
137 | - setproctitle==1.3.2
138 | - shortuuid==1.0.9
139 | - smmap==5.0.0
140 | - spektral==1.2.0
141 | - threadpoolctl==3.1.0
142 | - tqdm==4.64.1
143 |
--------------------------------------------------------------------------------
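Assuming a working conda installation, the environment above can be instantiated with conda env create -f tf_environment.yml, which creates the TVGNN-tf environment named on its first line.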
/tvgnn_poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FilippoMB/Total-variation-graph-neural-networks/c21b427460fda14000a820a541e9709a909b3005/tvgnn_poster.pdf
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from munkres import Munkres
3 | from sklearn import metrics
4 |
5 | # taken from https://github.com/Tiger101010/DAEGC/blob/main/DAEGC/evaluation.py
6 | # similar to https://github.com/karenlatong/AGC-master/blob/master/metrics.py
7 | def cluster_acc(y_true, y_pred):
8 | y_true = y_true - np.min(y_true)
9 |
10 | l1 = list(set(y_true))
11 | numclass1 = len(l1)
12 |
13 | l2 = list(set(y_pred))
14 | numclass2 = len(l2)
15 |
16 | ind = 0
17 |     if numclass1 != numclass2:  # make every true label appear in y_pred
18 | for i in l1:
19 | if i in l2:
20 | pass
21 | else:
22 | y_pred[ind] = i
23 | ind += 1
24 |
25 | l2 = list(set(y_pred))
26 | numclass2 = len(l2)
27 |
28 | if numclass1 != numclass2:
29 |         raise ValueError("cluster counts of y_true and y_pred "
30 |                          "still differ after relabeling")
31 |
32 | cost = np.zeros((numclass1, numclass2), dtype=int)
33 | for i, c1 in enumerate(l1):
34 | mps = [i1 for i1, e1 in enumerate(y_true) if e1 == c1]
35 | for j, c2 in enumerate(l2):
36 | mps_d = [i1 for i1 in mps if y_pred[i1] == c2]
37 | cost[i][j] = len(mps_d)
38 |
39 | # match two clustering results by Munkres algorithm
40 | m = Munkres()
41 |     cost = (-cost).tolist()
42 | indexes = m.compute(cost)
43 |
44 | # get the match results
45 | new_predict = np.zeros(len(y_pred))
46 | for i, c in enumerate(l1):
47 |         # corresponding label in l2:
48 | c2 = l2[indexes[i][1]]
49 |
50 |         # ai holds the indices in y_pred whose label == c2
51 | ai = [ind for ind, elm in enumerate(y_pred) if elm == c2]
52 | new_predict[ai] = c
53 |
54 | acc = metrics.accuracy_score(y_true, new_predict)
55 | f1_macro = metrics.f1_score(y_true, new_predict, average="macro")
56 | f1_micro = metrics.f1_score(y_true, new_predict, average="micro")
57 | return acc, f1_macro, f1_micro
58 |
--------------------------------------------------------------------------------
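For reference, the same matching-based accuracy can be computed without the munkres dependency by building a confusion matrix and solving the assignment with SciPy's Hungarian solver. The sketch below (function name ours; assumes scipy is installed and returns only the accuracy) is equivalent for non-degenerate label sets:

import numpy as np
from scipy.optimize import linear_sum_assignment

def cluster_acc_lsa(y_true, y_pred):
    y_true = np.asarray(y_true) - np.min(y_true)
    y_pred = np.asarray(y_pred)
    d = max(y_true.max(), y_pred.max()) + 1
    w = np.zeros((d, d), dtype=np.int64)    # confusion matrix: true x predicted
    for t, p in zip(y_true, y_pred):
        w[t, p] += 1
    rows, cols = linear_sum_assignment(-w)  # maximize correctly matched counts
    return w[rows, cols].sum() / y_true.size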