├── LICENSE
├── eval_metrics.py
├── .gitignore
├── just_balance.py
├── README.md
├── example_clustering.py
├── example_clustering_tf.py
└── example_classification.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Filippo Bianchi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/eval_metrics.py:
--------------------------------------------------------------------------------
import numpy as np
from munkres import Munkres
from sklearn import metrics
from sklearn.metrics.cluster import normalized_mutual_info_score


def cluster_acc(y_true, y_pred):
    """Clustering accuracy: find the optimal cluster-to-class mapping with
    the Munkres (Hungarian) algorithm, then score with standard accuracy."""
    y_true = y_true - np.min(y_true)

    l1 = list(set(y_true))
    numclass1 = len(l1)

    l2 = list(set(y_pred))
    numclass2 = len(l2)

    # If the prediction uses fewer labels than the ground truth, inject the
    # missing labels into the first entries of y_pred so that both
    # partitions use the same number of labels.
    ind = 0
    if numclass1 != numclass2:
        for i in l1:
            if i in l2:
                pass
            else:
                y_pred[ind] = i
                ind += 1

    l2 = list(set(y_pred))
    numclass2 = len(l2)

    if numclass1 != numclass2:
        print("error: number of predicted clusters does not match the number of classes")
        return

    # Build the contingency (cost) matrix between the two partitions.
    cost = np.zeros((numclass1, numclass2), dtype=int)
    for i, c1 in enumerate(l1):
        mps = [i1 for i1, e1 in enumerate(y_true) if e1 == c1]
        for j, c2 in enumerate(l2):
            mps_d = [i1 for i1 in mps if y_pred[i1] == c2]
            cost[i][j] = len(mps_d)

    # Match the two clustering results with the Munkres algorithm
    # (maximize the overlap, hence the negated cost matrix).
    m = Munkres()
    cost = (-cost).tolist()
    indexes = m.compute(cost)

    # Get the match results.
    new_predict = np.zeros(len(y_pred))
    for i, c in enumerate(l1):
        # corresponding label in l2:
        c2 = l2[indexes[i][1]]

        # ai holds the indices with label == c2 in the predicted labels
        ai = [ind for ind, elm in enumerate(y_pred) if elm == c2]
        new_predict[ai] = c

    acc = metrics.accuracy_score(y_true, new_predict)
    return acc


def eval_metrics(y_true, y_pred):
    acc = cluster_acc(y_true, y_pred)
    nmi = normalized_mutual_info_score(y_true, y_pred)

    return acc, nmi
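A quick smoke test of `eval_metrics` (an editor's sketch, not a file in the repository; the labels below are made up for illustration): because accuracy is computed only after an optimal cluster-to-class matching, any pure relabeling of the ground-truth partition should score ACC = 1.0.

```python
# Hypothetical smoke test for eval_metrics.py: a cluster relabeling
# (0 -> 2, 1 -> 0, 2 -> 1) still yields perfect ACC after matching.
import numpy as np
from eval_metrics import eval_metrics

y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([2, 2, 0, 0, 1, 1])  # same partition, permuted labels

acc, nmi = eval_metrics(y_true, y_pred)
print(f"ACC: {acc:.2f}, NMI: {nmi:.2f}")  # both should be 1.00
```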
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
data/*

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/just_balance.py:
--------------------------------------------------------------------------------
import torch

EPS = 1e-15


def just_balance_pool(x, adj, s, mask=None, normalize=True):
    r"""The Just Balance pooling operator from the `"Simplifying Clustering
    with Graph Neural Networks" <https://arxiv.org/abs/2207.08779>`_ paper

    .. math::
        \mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
        \mathbf{X}

        \mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
        \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})

    based on dense learned assignments :math:`\mathbf{S} \in \mathbb{R}^{B
    \times N \times C}`.
    Returns the pooled node feature matrix, the coarsened and symmetrically
    normalized adjacency matrix and the following auxiliary objective:

    .. math::
        \mathcal{L} = - {\mathrm{Tr}(\sqrt{\mathbf{S}^{\top} \mathbf{S}})}

    Args:
        x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`
            with batch-size :math:`B`, (maximum) number of nodes :math:`N`
            for each graph, and feature dimension :math:`F`.
        adj (Tensor): Symmetrically normalized adjacency tensor
            :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
        s (Tensor): Assignment tensor :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}`
            with number of clusters :math:`C`. The softmax does not have to be
            applied beforehand, since it is executed within this method.
        mask (BoolTensor, optional): Mask matrix
            :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
            the valid nodes for each graph. (default: :obj:`None`)
        normalize (bool, optional): If set to :obj:`True`, the auxiliary loss
            is normalized by :math:`\sqrt{N \cdot C}`. (default: :obj:`True`)

    :rtype: (:class:`Tensor`, :class:`Tensor`, :class:`Tensor`)
    """

    x = x.unsqueeze(0) if x.dim() == 2 else x
    adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
    s = s.unsqueeze(0) if s.dim() == 2 else s

    (batch_size, num_nodes, _), k = x.size(), s.size(-1)

    s = torch.softmax(s, dim=-1)

    if mask is not None:
        mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)
        x, s = x * mask, s * mask

    out = torch.matmul(s.transpose(1, 2), x)
    out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s)

    # Loss: minimizing -Tr(sqrt(S^T S)) favors balanced, non-degenerate clusters.
    ss = torch.matmul(s.transpose(1, 2), s)
    ss_sqrt = torch.sqrt(ss + EPS)
    loss = torch.mean(-_rank3_trace(ss_sqrt))
    if normalize:
        loss = loss / torch.sqrt(torch.tensor(num_nodes * k))

    # Fix and normalize the coarsened adjacency matrix: zero the diagonal,
    # then apply symmetric degree normalization D^{-1/2} A' D^{-1/2}.
    ind = torch.arange(k, device=out_adj.device)
    out_adj[:, ind, ind] = 0
    d = torch.einsum('ijk->ij', out_adj)
    d = torch.sqrt(torch.clamp(d, min=EPS))[:, None]
    out_adj = (out_adj / d) / d.transpose(1, 2)

    return out, out_adj, loss


def _rank3_trace(x):
    return torch.einsum('ijj->i', x)
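As a quick illustration (an editor's sketch, not a file in the repository; all tensor values are random stand-ins), the pooling operator can be exercised on toy inputs to see the expected shapes:

```python
# Minimal shape check for just_balance_pool: 8 nodes with 4 features
# pooled into 3 clusters. The random symmetric matrix stands in for a
# symmetrically normalized adjacency.
import torch
from just_balance import just_balance_pool

N, F, C = 8, 4, 3
x = torch.randn(N, F)        # node features
adj = torch.rand(N, N)
adj = (adj + adj.T) / 2      # stand-in for a normalized adjacency
s = torch.randn(N, C)        # raw assignment logits; softmax is applied internally

out, out_adj, loss = just_balance_pool(x, adj, s)
print(out.shape, out_adj.shape)  # torch.Size([1, 3, 4]) torch.Size([1, 3, 3])
print(loss)                      # scalar auxiliary loss
```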
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
[![arXiv](https://img.shields.io/badge/arXiv-2207.08779-b31b1b.svg)](https://arxiv.org/abs/2207.08779)
[![poster](https://custom-icon-badges.demolab.com/badge/poster-pdf-orange.svg?logo=note&logoSource=feather&logoColor=white)](https://drive.google.com/file/d/1cXA0LTHcdTV8Q0-1cjabr7eayM7gKBbh/view?usp=share_link)

Software implementation and code to reproduce the results of the *Just Balance GNN* (JBGNN) model for graph clustering, as presented in the paper [Simplifying Clustering with Graph Neural Networks](https://arxiv.org/abs/2207.08779).

The JBGNN architecture consists of:

- a GCN layer operating on the connectivity matrix $\mathbf{I} - \delta ( \mathbf{I} - \mathbf{D}^{-\frac{1}{2}} \mathbf{A} \mathbf{D}^{-\frac{1}{2}} )$;
- a pooling layer that computes a cluster assignment matrix $\mathbf{S} \in \mathbb{R}^{N \times K}$ as

$$ \mathbf{S} = \texttt{softmax} \left( \texttt{MLP} \left( \mathbf{\bar X}, \boldsymbol{\Theta}_\text{MLP} \right) \right) $$

where $\mathbf{\bar X}$ are the node features returned by a stack of GCN layers.

Each pooling layer is associated with an unsupervised loss that balances the size of the clusters and prevents degenerate partitions:

$$\mathcal{L} = - \text{Tr}\left( \sqrt{\mathbf{S}^T\mathbf{S}} \right).$$
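For reference, the implementation in `just_balance.py` takes the element-wise square root of $\mathbf{S}^T\mathbf{S}$ before the trace. A minimal sketch (not part of the repository; the assignment matrix below is random) of the loss computation:

```python
# Sketch of the balance loss for a soft assignment matrix S (N=100, K=5).
import torch

S = torch.softmax(torch.randn(100, 5), dim=-1)   # soft cluster assignments
loss = -torch.sqrt(S.T @ S + 1e-15).trace()      # element-wise sqrt, then trace
print(loss)  # decreases (becomes more negative) as clusters become more balanced
```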
## Node clustering and graph classification

### Tensorflow
A TF/Keras implementation of the [JustBalancePool](https://graphneural.network/layers/pooling/#justbalancepool) layer is available in [Spektral](https://graphneural.network/getting-started/).

Run [``example_clustering_tf.py``](https://github.com/FilippoMB/Simplifying-Clustering-with-Graph-Neural-Networks/blob/main/example_clustering_tf.py) to perform node clustering.

### Pytorch
[``just_balance.py``](https://github.com/FilippoMB/Simplifying-Clustering-with-Graph-Neural-Networks/blob/main/just_balance.py) provides a Pytorch implementation based on [Pytorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/index.html#).

Run [``example_clustering.py``](https://github.com/FilippoMB/Simplifying-Clustering-with-Graph-Neural-Networks/blob/main/example_clustering.py) to perform node clustering in Pytorch.

Run [``example_classification.py``](https://github.com/FilippoMB/Simplifying-Clustering-with-Graph-Neural-Networks/blob/main/example_classification.py) to perform graph classification in Pytorch.

> [!IMPORTANT]
> The results in the paper were obtained using the Tensorflow/Spektral implementation.

## Citation

```bibtex
@misc{bianchi2022simplifying,
  doi = {10.48550/ARXIV.2207.08779},
  author = {Bianchi, Filippo Maria},
  title = {Simplifying Clustering with Graph Neural Networks},
  publisher = {arXiv},
  year = {2022},
}
```

--------------------------------------------------------------------------------
/example_clustering.py:
--------------------------------------------------------------------------------
import os.path as osp
import torch

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, Sequential
from torch_geometric.nn.models.mlp import MLP
from torch_geometric import utils

from sklearn.metrics import normalized_mutual_info_score as NMI

from just_balance import just_balance_pool

torch.manual_seed(1)  # for (inconsistent) reproducibility
torch.cuda.manual_seed(1)

# Load dataset
dataset = 'cora'
path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', dataset)
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]

# Compute the connectivity matrix A_tilde = I - delta * L_sym
delta = 0.85
edge_index, edge_weight = utils.get_laplacian(data.edge_index, data.edge_weight, normalization='sym')
L = utils.to_dense_adj(edge_index, edge_attr=edge_weight)
A = torch.eye(data.num_nodes) - delta * L
data.edge_index, data.edge_weight = utils.dense_to_sparse(A)


class Net(torch.nn.Module):
    def __init__(self,
                 mp_units,
                 mp_act,
                 in_channels,
                 n_clusters,
                 mlp_units=[],
                 mlp_act="Identity"):
        super().__init__()

        mp_act = getattr(torch.nn, mp_act)(inplace=True)
        mlp_act = getattr(torch.nn, mlp_act)(inplace=True)

        # Message-passing layers
        mp = [
            (GCNConv(in_channels, mp_units[0], normalize=False, cached=False), 'x, edge_index, edge_weight -> x'),
            mp_act
        ]
        for i in range(len(mp_units) - 1):
            mp.append((GCNConv(mp_units[i], mp_units[i + 1], normalize=False, cached=False), 'x, edge_index, edge_weight -> x'))
            mp.append(mp_act)
        self.mp = Sequential('x, edge_index, edge_weight', mp)
        out_chan = mp_units[-1]

        self.mlp = MLP([out_chan] + mlp_units + [n_clusters], act=mlp_act, norm=None)

    def forward(self, x, edge_index, edge_weight):

        # Propagate node features
        x = self.mp(x, edge_index, edge_weight)

        # Cluster assignments (logits)
        s = self.mlp(x)

        # Compute the balance loss
        adj = utils.to_dense_adj(edge_index, edge_attr=edge_weight)
        _, _, b_loss = just_balance_pool(x, adj, s)

        return torch.softmax(s, dim=-1), b_loss


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)
model = Net([64] * 10, "ReLU", dataset.num_features, dataset.num_classes, [16], "ReLU").to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


def train():
    model.train()
    optimizer.zero_grad()
    _, loss = model(data.x, data.edge_index, data.edge_weight)
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def test():
    model.eval()
    clust, _ = model(data.x, data.edge_index, data.edge_weight)
    return NMI(clust.max(1)[1].cpu(), data.y.cpu())


for epoch in range(1, 1001):
    train_loss = train()
    nmi = test()
    print(f'Epoch: {epoch:03d}, Loss: {train_loss:.4f}, NMI: {nmi:.3f}')
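After training, the hard partition can also be scored with the repository's `eval_metrics.py` (a sketch assuming the `model` and `data` objects defined in `example_clustering.py` above; note that `cluster_acc` prints an error and returns `None` if the model collapses onto fewer clusters than there are classes):

```python
# Hypothetical post-training evaluation for example_clustering.py,
# reusing cluster accuracy + NMI from eval_metrics.py.
from eval_metrics import eval_metrics

model.eval()
with torch.no_grad():
    clust, _ = model(data.x, data.edge_index, data.edge_weight)
y_pred = clust.argmax(dim=-1).cpu().numpy()
acc, nmi = eval_metrics(data.y.cpu().numpy(), y_pred)
print(f"ACC: {acc:.3f}, NMI: {nmi:.3f}")
```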
--------------------------------------------------------------------------------
/example_clustering_tf.py:
--------------------------------------------------------------------------------
from time import time
from tqdm import tqdm
import numpy as np
import scipy.sparse as sp
import matplotlib.pyplot as plt
from sklearn.metrics.cluster import normalized_mutual_info_score

import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

from spektral.utils.sparse import sp_matrix_to_sp_tensor
from spektral.datasets import Citation, DBLP
from spektral.utils.convolution import normalized_laplacian
from spektral.layers import GCNConv
from spektral.layers.pooling import JustBalancePool

from eval_metrics import eval_metrics

# Hyperparameters
dataset_name = 'cora'
delta = 0.85
mp_layers = 10
mp_channels = 64
mlp_hidden = [16]
learning_rate = 1e-4
epochs = 2000

# Load dataset
if dataset_name in ['cora', 'citeseer', 'pubmed']:
    dataset = Citation(dataset_name, normalize_x=True)

elif dataset_name == 'dblp':
    dataset = DBLP(normalize_x=True)

X = dataset.graphs[0].x
A = dataset.graphs[0].a
Y = dataset.graphs[0].y
y = np.argmax(Y, axis=-1)
n_clusters = Y.shape[-1]
N, F = X.shape

# Build the connectivity matrix A_tilde = I - delta * L_sym
A_tilde = sp.eye(N) - delta * normalized_laplacian(A)
A_tilde = sp_matrix_to_sp_tensor(A_tilde)

# Build model
x_in = Input(shape=(F,), name="X_in")
a_in = Input(shape=(None,), name="A_in", sparse=True)

x_bar = x_in
for _ in range(mp_layers):
    x_bar = GCNConv(mp_channels, activation='relu')([x_bar, a_in])

_, _, s = JustBalancePool(n_clusters,
                          mlp_hidden=mlp_hidden,
                          mlp_activation='relu',
                          return_selection=True)([x_bar, a_in])
model = Model([x_in, a_in], [s])

# Training
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)

@tf.function
def train_step(model, inputs):
    # Training is fully unsupervised: the only objective is the balance
    # loss that JustBalancePool registers on the model.
    with tf.GradientTape() as tape:
        _ = model(inputs, training=True)
        loss = sum(model.losses)
    gradients = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
    return model.losses

loss_history = []
nmi_history = []
ep_time = []
for _ in tqdm(range(epochs)):
    time_s = time()
    outs = train_step(model, [X, A_tilde])
    time_e = time()

    loss_history.append([outs[i].numpy() for i in range(len(outs))])
    ep_time.append(time_e - time_s)

    S_ = model([X, A_tilde], training=False)
    s = np.argmax(S_, axis=-1)
    nmi = normalized_mutual_info_score(y, s)
    nmi_history.append(nmi)

# Print results
S_ = model([X, A_tilde], training=False)
s = np.argmax(S_, axis=-1)
acc, nmi = eval_metrics(y, s)
ep_time.pop(0)  # discard the first epoch, which includes tracing/compilation time
print(f"ACC: {acc:.3f}, NMI: {nmi:.3f}, avg seconds/step: {np.average(ep_time):.3f}s")

# Plots
plt.figure(figsize=(10, 5))

plt.subplot(121)
plt.plot(loss_history, label="Balance loss")
plt.legend()
plt.ylabel("Loss")
plt.xlabel("Iteration")

plt.subplot(122)
plt.plot(nmi_history, label="NMI")
plt.legend()
plt.ylabel("NMI")
plt.xlabel("Iteration")

plt.tight_layout()
plt.show()
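Both clustering examples replace $\mathbf{A}$ with the connectivity matrix $\mathbf{I} - \delta \mathbf{L}_\text{sym}$ described in the README. A toy check of that transform (an editor's sketch on a made-up 3-node path graph, with $\delta = 0.85$ as in the scripts):

```python
# Sketch: the connectivity transform A -> I - delta * L_sym, using the
# same Spektral utility as example_clustering_tf.py.
import numpy as np
import scipy.sparse as sp
from spektral.utils.convolution import normalized_laplacian

A = sp.csr_matrix(np.array([[0., 1., 0.],
                            [1., 0., 1.],
                            [0., 1., 0.]]))
A_tilde = sp.eye(3) - 0.85 * normalized_laplacian(A)
print(A_tilde.toarray())  # 0.15 self-loops on the diagonal, rescaled edges elsewhere
```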
--------------------------------------------------------------------------------
/example_classification.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GINConv, DenseGINConv
from torch_geometric.nn.models.mlp import MLP
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import (to_dense_batch,
                                   get_laplacian,
                                   to_dense_adj,
                                   dense_to_sparse)

# Local imports
from just_balance import just_balance_pool


class NormalizeAdj(BaseTransform):
    """
    Applies the following transformation:

        A --> I - delta * L
    """
    def __init__(self, delta: float = 0.85) -> None:
        self.delta = delta
        super().__init__()

    def forward(self, data: Data) -> Data:
        edge_index, edge_weight = get_laplacian(data.edge_index, data.edge_weight, normalization='sym')
        L = to_dense_adj(edge_index, edge_attr=edge_weight)
        A_norm = torch.eye(data.num_nodes) - self.delta * L
        data.edge_index, data.edge_weight = dense_to_sparse(A_norm)
        return data


### Get the data
dataset = TUDataset(root="../data/TUDataset", name='NCI1', pre_transform=NormalizeAdj())
train_loader = DataLoader(dataset[:0.9], batch_size=32, shuffle=True)
test_loader = DataLoader(dataset[0.9:], batch_size=32)


### Model definition
class Net(torch.nn.Module):
    def __init__(self,
                 hidden_channels=64,
                 mlp_units=[16],
                 mlp_act="ReLU"
                 ):
        super().__init__()

        num_features = dataset.num_features
        num_classes = dataset.num_classes
        n_clusters = int(dataset._data.x.size(0) / len(dataset))  # average number of nodes per graph
        mlp_act = getattr(torch.nn, mlp_act)(inplace=True)

        # First MP layer
        self.conv1 = GINConv(
            torch.nn.Sequential(
                torch.nn.Linear(num_features, hidden_channels),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_channels, hidden_channels),
            )
        )

        self.mlp = MLP([hidden_channels] + mlp_units + [n_clusters], act=mlp_act, norm=None)

        # Second MP layer
        self.conv2 = DenseGINConv(
            torch.nn.Sequential(
                torch.nn.Linear(hidden_channels, hidden_channels),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_channels, hidden_channels),
            )
        )
        # Readout layer
        self.lin = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch=None):

        # First MP layer
        x = self.conv1(x, edge_index)

        # Transform to dense batch
        x, mask = to_dense_batch(x, batch)
        adj = to_dense_adj(edge_index, batch)

        # Cluster assignments (logits)
        s = self.mlp(x)

        # Pooling
        x_pool, adj_pool, aux_loss = just_balance_pool(x, adj, s, mask, normalize=True)

        # Second MP layer
        x = self.conv2(x_pool, adj_pool)

        # Global pooling
        x = x.mean(dim=1)

        # Readout layer
        x = self.lin(x)

        return F.log_softmax(x, dim=-1), aux_loss


### Model setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


def train():
    model.train()
    loss_all = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output, aux_loss = model(data.x, data.edge_index, data.batch)
        loss = F.nll_loss(output, data.y.view(-1)) + aux_loss
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        loss_all += data.y.size(0) * float(loss)
        optimizer.step()
    return loss_all / len(train_loader.dataset)


@torch.no_grad()
def test(loader):
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        pred = model(data.x, data.edge_index, data.batch)[0].max(dim=1)[1]
        correct += int(pred.eq(data.y.view(-1)).sum())
    return correct / len(loader.dataset)


### Training loop
for epoch in range(1, 501):
    train_loss = train()
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.3f}, '
          f'Test Acc: {test_acc:.3f}')
--------------------------------------------------------------------------------