├── exp
│   ├── __init__.py
│   ├── run_table.sh
│   └── run_tu.py
├── architecture
│   ├── __init__.py
│   └── cell_network.py
├── utils
│   ├── __init__.py
│   ├── RenzIO.py
│   └── utils.py
├── layers
│   ├── __init__.py
│   ├── signal_lift.py
│   └── cell_layers.py
├── results
│   └── cfg2acc_proteins.txt
├── configs
│   └── config.json
├── LICENSE
├── README.md
└── .gitignore

/exp/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/architecture/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import utils
--------------------------------------------------------------------------------
/layers/__init__.py:
--------------------------------------------------------------------------------
1 | import os, sys; sys.path.append(os.path.dirname(os.path.realpath(__file__)))
--------------------------------------------------------------------------------
/results/cfg2acc_proteins.txt:
--------------------------------------------------------------------------------
1 | {"dataset": "proteins", "max_epochs": 1000, "bs": 64, "negative_slope": 0.2, "features": [256, 256, 256], "cell_attention_heads": [2, 2, 2, 2], "dense": [64, 16, 2], "norm_strategy": "layer", "signal_heads": 32, "signal_lift_activation": "ELU(alpha=0.2)", "signal_lift_dropout": 0.0, "skip": "True", "signal_lift_readout": "cat", "lift_hidden_dim": 8, "cell_attention_activation": "ELU(alpha=0.2)", "cell_attention_dropout": 0.0, "cell_attention_readout": "cat", "cell_forward_activation": "ELU(alpha=0.2)", "cell_forward_dropout": 0.0, "dense_readout": "avg", "lr": 0.001, "wd": 0.001, "k_pool": 0.5, "max_ring_size": 6, "param_init": "normal", "val_acc": 0.8098958730697632, "train_acc": 0.7665539383888245, "fold": 1}
2 | 
--------------------------------------------------------------------------------
/configs/config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "dataset": "proteins",
3 |     "max_epochs": 1000,
4 |     "bs": 64,
5 |     "negative_slope": 0.2,
6 |     "features": [256, 256, 256],
7 |     "cell_attention_heads": [2, 2, 2, 2],
8 |     "dense": [64, 16],
9 |     "norm_strategy": "layer",
10 |     "signal_heads": 32,
11 |     "signal_lift_activation": "elu",
12 |     "signal_lift_dropout": 0.0,
13 |     "skip": "True",
14 |     "signal_lift_readout": "cat",
15 |     "lift_hidden_dim": 8,
16 |     "cell_attention_activation": "elu",
17 |     "cell_attention_dropout": 0.0,
18 |     "cell_attention_readout": "cat",
19 |     "cell_forward_activation": "elu",
20 |     "cell_forward_dropout": 0.0,
21 |     "dense_readout": "avg",
22 |     "lr": 0.001,
23 |     "wd": 0.001,
24 |     "k_pool": 0.5,
25 |     "max_ring_size": 6,
26 |     "param_init": "normal"
27 | }
28 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 XXX
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Cell Attention Networks
2 | 
3 | This repository contains the official code implementation for the paper
4 | [Cell Attention Networks](https://arxiv.org/abs/2209.08179).
5 | 
6 | Cell Attention Networks (CANs) propose a novel message-passing scheme that lifts the node feature vectors of a graph onto the edges of a regular cell complex. The information exchange between edges is weighted by learnable attention coefficients, which enhances the model's expressiveness and generalization.
7 | 
8 | ![Cell Attention Network Diagram](./images/can_diagram.png)
9 | 
10 | ## Table of Contents
11 | 
12 | - [Requirements](#requirements)
13 | - [Installation](#installation)
14 | - [Running experiments on TUDatasets](#running-experiments-on-tudatasets)
15 | - [Running all results on TUDatasets](#running-all-results-on-tudatasets)
16 | - [Citation](#citation)
17 | 
18 | ## Requirements
19 | 
20 | - Python 3.7+
21 | - PyTorch 1.9+
22 | - torchvision 0.10+
23 | - torch-geometric 2.0+
24 | - numpy 1.20+
25 | - tqdm 4.62+
26 | - pytorch-lightning
27 | - torch-scatter
28 | - scikit-learn
29 | - joblib
30 | - gudhi
31 | - graph-tool
32 | - networkx
33 | 
34 | ## Installation
35 | 
36 | To install the required dependencies, run the following command:
37 | 
38 | ```bash
39 | pip install -r requirements.txt
40 | ```
41 | 
42 | ## Running experiments on TUDatasets
43 | 
44 | ```commandline
45 | python ./exp/run_tu.py --seed=0 --pci_id=0 -c=configs/config.json -d=proteins
46 | ```
47 | 
48 | ## Running all results on TUDatasets
49 | 
50 | ```commandline
51 | sh ./exp/run_table.sh
52 | ```
53 | 
54 | ## Citation
55 | 
56 | If you find this work useful, please consider citing the paper:
57 | 
58 | ```
59 | @misc{giusti2022cell,
60 |       title={Cell Attention Networks},
61 |       author={Lorenzo Giusti and Claudio Battiloro and Lucia Testa and Paolo Di Lorenzo and Stefania Sardellitti and Sergio Barbarossa},
62 |       year={2022},
63 |       eprint={2209.08179},
64 |       archivePrefix={arXiv},
65 |       primaryClass={cs.LG}
66 | }
67 | ```
--------------------------------------------------------------------------------
/exp/run_table.sh:
--------------------------------------------------------------------------------
1 | for d in mutag ptc_mm enzymes proteins nci109 nci1
2 | do
3 | 
4 | 
5 | python run_tu.py --seed=0 --pci_id=0 -c=configs/config.json -d=$d &
6 | python run_tu.py --seed=1 --pci_id=0 -c=configs/config.json -d=$d &
7 | python run_tu.py --seed=2 --pci_id=1 -c=configs/config.json -d=$d &
8 | python run_tu.py --seed=3 --pci_id=1 -c=configs/config.json -d=$d &
9 | python run_tu.py --seed=4 --pci_id=2 -c=configs/config.json -d=$d &
10 | python run_tu.py --seed=5 --pci_id=2 -c=configs/config.json -d=$d &
11 | 
12 | 
13 | for job in `jobs -p`
14 | do
15 | echo $job
16 | wait $job
17 | done
18 | 
19 | 
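# Scheduling sketch: each block of six runs is spread two-per-GPU across
# pci_id 0, 1 and 2, and the `wait` loop after each block holds the script
# until all six background jobs finish before the next seeds are launched.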
20 | python run_tu.py --seed=6 --pci_id=0 -c=configs/config.json -d=$d &
21 | python run_tu.py --seed=7 --pci_id=0 -c=configs/config.json -d=$d &
22 | python run_tu.py --seed=8 --pci_id=1 -c=configs/config.json -d=$d &
23 | python run_tu.py --seed=9 --pci_id=1 -c=configs/config.json -d=$d &
24 | python run_tu.py --seed=10 --pci_id=2 -c=configs/config.json -d=$d &
25 | python run_tu.py --seed=11 --pci_id=2 -c=configs/config.json -d=$d &
26 | 
27 | 
28 | for job in `jobs -p`
29 | do
30 | echo $job
31 | wait $job
32 | done
33 | 
34 | 
35 | python run_tu.py --seed=12 --pci_id=0 -c=configs/config.json -d=$d &
36 | python run_tu.py --seed=13 --pci_id=0 -c=configs/config.json -d=$d &
37 | python run_tu.py --seed=14 --pci_id=1 -c=configs/config.json -d=$d &
38 | python run_tu.py --seed=15 --pci_id=1 -c=configs/config.json -d=$d &
39 | python run_tu.py --seed=16 --pci_id=2 -c=configs/config.json -d=$d &
40 | python run_tu.py --seed=17 --pci_id=2 -c=configs/config.json -d=$d &
41 | 
42 | 
43 | for job in `jobs -p`
44 | do
45 | echo $job
46 | wait $job
47 | done
48 | 
49 | 
50 | python run_tu.py --seed=18 --pci_id=0 -c=configs/config.json -d=$d &
51 | python run_tu.py --seed=19 --pci_id=1 -c=configs/config.json -d=$d &
52 | python run_tu.py --seed=20 --pci_id=2 -c=configs/config.json -d=$d &
53 | 
54 | 
55 | for job in `jobs -p`
56 | do
57 | echo $job
58 | wait $job
59 | done
60 | done
61 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
--------------------------------------------------------------------------------
/utils/RenzIO.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Sat Apr 9 11:18:07 2022
5 | 
6 | @author: ince
7 | """
8 | 
9 | 
10 | #import sys
11 | #sys.path.append("..")
12 | #sys.path.append(".")
13 | import graph_tool
14 | import os
15 | import pickle
16 | import pathlib
17 | from torch_geometric.data import InMemoryDataset
18 | from torch_geometric.datasets import TUDataset, QM9, ZINC, Planetoid
19 | 
20 | 
21 | import torch_geometric.transforms as T
22 | from .utils import compute_cell, ProgressParallel, compute_cell_complex_stat
23 | 
24 | from joblib import delayed
25 | 
26 | 
27 | TU_DATASETS = ['MUTAG', 'NCI1', 'NCI109', 'PROTEINS', 'PTC_MM', 'ENZYMES', 'COLORS-3']
28 | CITATION_NETWORKS = ['CORA', 'CITESEER', 'PUBMED']
29 | 
30 | 
31 | class RenzIO(InMemoryDataset):
32 |     def __init__(self, dataset_name, max_ring_size, compute_complex=False, split=None):
33 |         super(RenzIO, self).__init__(dataset_name)
34 |         self.max_ring_size = max_ring_size
35 |         root = "./datasets"
36 |         dataset_name = dataset_name.upper()
37 |         if dataset_name in TU_DATASETS:
38 |             dataset = TUDataset(root+"/TU/", dataset_name)
39 |             connectivity_path = root + "/TU/" + dataset_name + "/cell_conn/"
40 |         elif dataset_name == 'ZINC':
41 |             dataset = ZINC(root+"/ZINC", subset=True, split=split)
42 |             connectivity_path = root + "/ZINC/cell_conn/"+split+"/"
43 |         elif dataset_name == 'QM9':  # dataset_name is upper-cased above, so 'qm9' would never match
44 |             dataset = QM9(root+"/QM9")
45 |             connectivity_path = root + "/QM9/cell_conn/"
46 |         elif dataset_name in CITATION_NETWORKS:
47 |             dataset = Planetoid(root=root + "/CITATION/", name=dataset_name)
48 |             connectivity_path = root + "/CITATION/" + dataset_name + "/cell_conn/"
49 |         else:
50 |             raise ValueError("Unknown dataset: " + dataset_name)
51 | 
52 |         # make path to find connectivity information for the complex
53 |         #print(dataset)
54 |         #l = compute_cell_complex_stat(dataset, 6)
55 |         #print({k:v/len(dataset) for k,v in dict(Counter(l)).items()})
56 |         pathlib.Path(connectivity_path).mkdir(parents=False, exist_ok=True)
57 |         if compute_complex or "connectivity.pkl" not in os.listdir(connectivity_path):
58 |             print("Starting Processing Dataset")
59 | 
60 |             parallel = ProgressParallel(n_jobs=1, use_tqdm=True, total=len(dataset))
61 |             connectivity = parallel(delayed(compute_cell)(
62 |                 G, self.max_ring_size) for G in dataset)
63 | 
64 |             print("Processing Completed")
65 | 
66 |             pickle.dump(connectivity,
67 |                         open(connectivity_path+"connectivity.pkl", "wb"),
68 |                         protocol=pickle.HIGHEST_PROTOCOL)
69 |         else:
70 |             connectivity = pickle.load(open(connectivity_path+"connectivity.pkl", "rb"))
71 | 
72 |         self.data = dataset  # remove .data for TU
73 |         self.connectivities = connectivity
74 | 
75 | 
76 |     def __getitem__(self, idx):
77 |         return self.data[idx], self.connectivities[idx]
78 | 
79 |     def __len__(self):
80 |         return len(self.data)
81 | 
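# Minimal usage sketch (hypothetical, not part of the training entry point):
# assumes the PROTEINS TUDataset can be fetched by PyG into ./datasets and
# that the connectivity pickle is built on first use.
#
#   dataset = RenzIO("proteins", max_ring_size=6)
#   graph, connectivity = dataset[0]  # PyG Data object + its cell connectivity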
--------------------------------------------------------------------------------
/exp/run_tu.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 | # Import necessary libraries and modules
5 | import graph_tool
6 | import sys
7 | import os
8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
9 | import torch
10 | import torch.nn as nn
11 | import pytorch_lightning as pl
12 | import numpy as np
13 | import json
14 | import argparse
15 | import logging
16 | import torch_geometric
17 | from sklearn.model_selection import StratifiedKFold
18 | from architecture.cell_network import CellNetwork
19 | from utils.RenzIO import RenzIO
20 | from utils.utils import collate_complexes
21 | # NOTE: the plain torch DataLoader is used here (not torch_geometric's), since batching is handled by the custom collate_complexes
22 | from torch.utils.data import DataLoader
23 | from torch.utils.data.sampler import SubsetRandomSampler, SequentialSampler
24 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping
25 | 
26 | # Set up logging
27 | logging.getLogger("lightning").setLevel(logging.ERROR)
28 | 
29 | # Set up environment variable
30 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
31 | 
32 | # Initialize the argument parser
33 | parser = argparse.ArgumentParser()
34 | # ... the remaining code ...
35 | 
36 | # Set up device for running the experiment
37 | device = torch.device("cuda:" + args.pci_id) if torch.cuda.is_available() else torch.device("cpu")
38 | 
39 | # Load configuration parameters from JSON file
40 | config = json.load(open(args.config_file, "r"))
41 | 
42 | # Set the random seed for reproducibility
43 | seed = 12091996
44 | pl.seed_everything(seed)
45 | 
46 | # ... the remaining code ...
47 | 
48 | # Load dataset using RenzIO utility
49 | dataset = RenzIO(dataset_name, max_ring_size)
50 | 
51 | # Perform stratified k-fold cross-validation for the dataset
52 | skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
53 | labels = [G.y.item() for G in dataset.data]
54 | idx_list = []
55 | for idx in skf.split(np.zeros(len(labels)), labels):
56 |     idx_list.append(idx)
57 | train_idxs, test_idxs = idx_list[args.fold]
58 | 
59 | # Set up data loaders for training and testing
60 | train_sampler = SubsetRandomSampler(train_idxs)
61 | test_sampler = SubsetRandomSampler(test_idxs)
62 | train_dataloader = DataLoader(dataset, sampler=train_sampler, collate_fn=collate_complexes, batch_size=bs,
63 |                               num_workers=0, pin_memory=True)
64 | test_dataloader = DataLoader(dataset, sampler=test_sampler, collate_fn=collate_complexes, batch_size=bs,
65 |                              num_workers=0, pin_memory=True)
66 | 
67 | # Instantiate Cellular Attention Network (CellNetwork) model
68 | s = CellNetwork(in_features={'node': dataset.data.num_features,
69 |                              'edge': dataset.data.num_edge_features},
70 |                 n_class=dataset.data.num_classes,
71 |                 **config, device=device).to(device)
72 | 
73 | # Set up early stopping callback
74 | early_stop_callback = EarlyStopping(monitor="valid_acc", min_delta=0.001, patience=200,
75 |                                     verbose=True, mode="max")
76 | 
77 | # Set up TensorBoard logger for recording experiment results
78 | logger = pl.loggers.TensorBoardLogger(name=config['dataset'], save_dir='results')
79 | 
80 | # Initialize PyTorch Lightning trainer
81 | trainer = pl.Trainer(max_epochs=config['max_epochs'], logger=logger, callbacks=[early_stop_callback],
82 |                      devices=[int(args.pci_id)], accelerator="gpu", auto_select_gpus=False)
83 | 
84 | # Train and validate the model
85 | trainer.fit(s,
train_dataloader, test_dataloader) 86 | 87 | # Save the best validation and training accuracies in the configuration 88 | config['val_acc'] = max(s.valid_acc_epoch) 89 | config['train_acc'] = max(s.train_acc_epoch) 90 | config['fold'] = int(args.fold) 91 | config['signal_lift_activation'] = str(config['signal_lift_activation']) 92 | config['cell_attention_activation'] = str(config['cell_attention_activation']) 93 | config['cell_forward_activation'] = str(config['cell_forward_activation']) 94 | 95 | # Save the configuration and accuracies to a file 96 | with open("results/cfg2acc_" + dataset_name + ".txt", "a") as fout: 97 | fout.write(json.dumps(config)) 98 | fout.write("\n") 99 | -------------------------------------------------------------------------------- /layers/signal_lift.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from typing import TypeVar, Tuple, Callable 10 | 11 | 12 | 13 | NodeSignal = TypeVar('NodeSignal') 14 | EdgeSignal = TypeVar('EdgeSignal') 15 | Graph = TypeVar('Graph') 16 | 17 | 18 | class LiftLayer(nn.Module): 19 | """ 20 | A single lift layer for a Cell Attention Network (CAN). 21 | This layer is responsible for lifting node feature vectors into 22 | edge features using a learnable function 23 | 24 | Parameters 25 | ---------- 26 | F_in : int 27 | Number of input features for the single lift layer. 28 | signal_lift_activation : Callable 29 | Non-linear activation function for the signal lift. 30 | signal_lift_dropout : float 31 | Dropout rate applied after the lift. 32 | 33 | Examples 34 | -------- 35 | >>> lift_layer = LiftLayer(F_in=10, signal_lift_activation=torch.relu, signal_lift_dropout=0.5) 36 | """ 37 | def __init__(self, F_in: int, 38 | signal_lift_activation: Callable, 39 | signal_lift_dropout: float): 40 | super(LiftLayer, self).__init__() 41 | 42 | self.F_in = F_in 43 | self.att = nn.Parameter(torch.empty(size=(2 * F_in, 1))) 44 | self.signal_lift_activation = signal_lift_activation 45 | self.signal_lift_dropout = signal_lift_dropout 46 | self.reset_parameters() 47 | 48 | def __repr__(self): 49 | return "LiftLayer(" + \ 50 | "F_in=" + str(self.F_in) + \ 51 | ", Activation=" + str(self.signal_lift_activation) + \ 52 | ", Dropout=" + str(self.signal_lift_dropout) + ")" 53 | 54 | def to(self, device): 55 | super().to(device) 56 | return self 57 | 58 | def reset_parameters(self): 59 | """Reinitialize learnable parameters using Xavier uniform initialization.""" 60 | gain = nn.init.calculate_gain('relu') 61 | nn.init.xavier_uniform_(self.att.data, gain=gain) 62 | 63 | def forward(self, x: Tuple[NodeSignal, Graph]) -> EdgeSignal: 64 | """ 65 | Perform the forward pass for a single lift layer. 66 | 67 | Parameters 68 | ---------- 69 | x : Tuple[NodeSignal, Graph] 70 | Input tuple containing the node signal (node feature vectors) 71 | and the graph structure (Graph object). 72 | 73 | Returns 74 | ------- 75 | EdgeSignal 76 | The resulting edge signal after the lift layer operation. 77 | 78 | Notes 79 | ----- 80 | The forward pass can be described as follows: 81 | 1. Extract source and target nodes from the input graph's edge index. 82 | 2. Concatenate source and target node feature vectors. 83 | 3. Compute the output edge signal by applying the activation function to the 84 | matrix multiplication of the concatenated node features and the attention 85 | coefficients. 
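        Shape sketch (assuming N nodes, E edges and F_in features per node):
        node_signal is [N, F_in], the concatenated pair features are
        [E, 2 * F_in], and the returned edge signal is [E, 1], i.e. one
        scalar per edge for this head.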
86 | """ 87 | # Unpack input tuple into node signal and graph structure 88 | node_signal, graph = x 89 | 90 | # Extract source and target nodes from the graph's edge index 91 | source, target = graph.edge_index 92 | 93 | # Concatenate source and target node feature vectors 94 | node_features_stacked = torch.cat((node_signal[source], node_signal[target]), dim=1) 95 | 96 | # Compute the output edge signal by applying the activation function 97 | edge_signal = self.signal_lift_activation(node_features_stacked.mm(self.att)) 98 | 99 | return edge_signal 100 | 101 | 102 | class MultiHeadLiftLayer(nn.Module): 103 | """ 104 | A multi-head lift layer for a Cellular Graph Attention Network (GAT). 105 | This layer is responsible for lifting node feature vectors by 106 | exchanging information between edges weighted by learnable attention 107 | coefficients, for multiple attention heads. 108 | 109 | Parameters 110 | ---------- 111 | F_in : int 112 | Number of input features for the lift layers. 113 | K : int 114 | Number of attention heads. 115 | signal_lift_activation : Callable 116 | Non-linear activation function for the signal lift. 117 | signal_lift_dropout : float 118 | Dropout rate applied after the lift. 119 | signal_lift_readout : str 120 | The readout strategy for combining the output of multiple attention heads. 121 | Choices: 'cat', 'sum', 'avg', 'max' 122 | 123 | Examples 124 | -------- 125 | >>> multi_head_lift_layer = MultiHeadLiftLayer(F_in=10, K=3, signal_lift_activation=torch.relu, signal_lift_dropout=0.5, signal_lift_readout='sum') 126 | """ 127 | def __init__(self, F_in: int, K: int, 128 | signal_lift_activation: Callable, 129 | signal_lift_dropout: float, 130 | signal_lift_readout: str, *args, **kwargs): 131 | super(MultiHeadLiftLayer, self).__init__() 132 | 133 | self.F_in = F_in 134 | self.K = K 135 | self.signal_lift_readout = signal_lift_readout 136 | self.signal_lift_dropout = signal_lift_dropout 137 | self.signal_lift_activation = signal_lift_activation 138 | self.lifts = [LiftLayer(F_in=F_in, 139 | signal_lift_activation=signal_lift_activation, 140 | signal_lift_dropout=signal_lift_dropout) for _ in range(K)] 141 | 142 | def __repr__(self): 143 | str_F_out = '2' if self.signal_lift_readout != 'cat' else str(2 * self.K) 144 | s = "MultiHeadLiftLayer(" + \ 145 | "F_in=" + str(self.F_in) + ", F_out=" + str_F_out + \ 146 | ", heads=" + str(self.K) + ", readout=" + self.signal_lift_readout + "):" 147 | 148 | for idx, lift in enumerate(self.lifts): 149 | s += "\n\t(" + str(idx) + "): " + str(lift) 150 | return s 151 | 152 | def to(self, device): 153 | self.lifts = nn.ModuleList([lift.to(device) for lift in self.lifts]) 154 | return self 155 | 156 | def forward(self, x): 157 | """ 158 | Perform the forward pass for a multi-head lift layer. 159 | 160 | Parameters 161 | ---------- 162 | x : Tuple[NodeSignal, Graph] 163 | Input tuple containing the node signal (node feature vectors) 164 | and the graph structure (Graph object). 165 | 166 | Returns 167 | ------- 168 | EdgeSignal 169 | The resulting edge signal after the lift layer operation for multiple 170 | attention heads. 171 | Graph 172 | The input graph structure. 173 | 174 | Notes 175 | ----- 176 | The forward pass can be described as follows: 177 | 1. Unpack the input tuple into node signal and graph structure. 178 | 2. Lift the node signal for each attention head. 179 | 3. Combine the output edge signals using the specified readout strategy. 180 | 4. Apply dropout to the combined edge signal. 
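        Shape sketch (assuming E edges and K attention heads): each head
        produces an [E, 1] signal, so 'cat' yields [E, K] while 'sum',
        'avg' and 'max' all yield [E, 1].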
181 | """ 182 | # Unpack input tuple into node signal and graph structure 183 | node_signal, graph = x 184 | 185 | # Lift the node signal for each attention head 186 | edge_signals = [lift((node_signal, graph)) for lift in self.lifts] 187 | 188 | # Combine the output edge signals using the specified readout strategy 189 | if self.signal_lift_readout == 'cat': 190 | combined_edge_signal = torch.cat(edge_signals, dim=1) 191 | elif self.signal_lift_readout == 'sum': 192 | combined_edge_signal = torch.stack(edge_signals, dim=2).sum(dim=2) 193 | elif self.signal_lift_readout == 'avg': 194 | combined_edge_signal = torch.stack(edge_signals, dim=2).mean(dim=2) 195 | elif self.signal_lift_readout == 'max': 196 | combined_edge_signal = torch.stack(edge_signals, dim=2).max(dim=2).values 197 | else: 198 | raise ValueError("Invalid signal_lift_readout value. Choose from ['cat', 'sum', 'avg', 'max']") 199 | 200 | # Apply dropout to the combined edge signal 201 | combined_edge_signal = F.dropout(combined_edge_signal, self.signal_lift_dropout, training=self.training) 202 | 203 | return combined_edge_signal, graph 204 | 205 | -------------------------------------------------------------------------------- /architecture/cell_network.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import TypeVar, List, Callable 9 | import pytorch_lightning as pl 10 | from torch.optim.lr_scheduler import ReduceLROnPlateau 11 | from layers.cell_layers import MultiHeadCellAttentionLayer, TopologicalNorm, CAPooLayer 12 | from layers.signal_lift import MultiHeadLiftLayer 13 | from utils.utils import compute_cell, readout 14 | 15 | 16 | Graph = TypeVar('Graph') 17 | Laplacian = TypeVar('Laplacian') 18 | Signal = TypeVar('Signal') 19 | ConnectivityMask = TypeVar('ConnectivityMask') 20 | ReadoutIndexer = TypeVar('ReadoutIndexer') 21 | 22 | class CellNetwork(pl.LightningModule): 23 | """ 24 | A Cellular version of Graph Attention Networks that lifts the node feature vectors 25 | of GNNs. The information is exchanged between the edges weighted by learnable attention 26 | coefficients. 
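    In short, the forward pipeline is: MultiHeadLiftLayer (node features ->
    edge features) followed by TopologicalNorm, a stack of
    MultiHeadCellAttentionLayer + TopologicalNorm blocks ending in a
    CAPooLayer, and a dense MLP readout for classification.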
27 | 28 | Attributes: 29 | lr (float): Learning rate 30 | wd (float): Weight decay 31 | readout (str): Readout strategy for the dense layer 32 | lift (nn.Sequential): MultiHeadLiftLayer and TopologicalNorm layers for signal lifting 33 | dense (List[int]): A list of dense layers 34 | N_dense_layers (int): The number of dense layers 35 | cell_net (nn.Sequential): Sequential Cellular Attention layers 36 | mlp (nn.Sequential): Multi-layer perceptron 37 | max_acc (float): Maximum accuracy 38 | loss_fn (nn.BCEWithLogitsLoss): Binary Cross-Entropy with Logits Loss function 39 | train_acc (List[float]): Training accuracy 40 | valid_acc (List[float]): Validation accuracy 41 | valid_acc_epoch (List[float]): Validation accuracy per epoch 42 | train_acc_epoch (List[float]): Training accuracy per epoch 43 | """ 44 | def __init__(self, in_features: int, n_class: int, 45 | features: List[int], cell_attention_heads: List[int], 46 | dense: List[int], norm_strategy: str, signal_heads: int, 47 | signal_lift_activation: Callable, signal_lift_dropout: float, 48 | signal_lift_readout: str, 49 | cell_attention_activation: Callable, cell_attention_dropout: float, 50 | cell_attention_readout: str, 51 | cell_forward_activation: Callable, cell_forward_dropout: float, 52 | dense_readout: str, skip: bool, 53 | lr: float, wd: float, 54 | k_pool: float, device: str, param_init: str, **kwargs): 55 | """ 56 | Initializes the CellNetwork. 57 | 58 | Args: 59 | in_features (int): The input features 60 | n_class (int): The number of classes 61 | features (List[int]): List of features for each layer 62 | cell_attention_heads (List[int]): List of cell attention heads 63 | dense (List[int]): List of dense layers 64 | norm_strategy (str): Normalization strategy 65 | signal_heads (int): Number of signal heads 66 | signal_lift_activation (Callable): Activation function for signal lifting 67 | signal_lift_dropout (float): Dropout rate for signal lifting 68 | signal_lift_readout (str): Readout strategy for signal lifting 69 | cell_attention_activation (Callable): Activation function for cell attention 70 | cell_attention_dropout (float): Dropout rate for cell attention 71 | cell_attention_readout (str): Readout strategy for cell attention 72 | cell_forward_activation (Callable): Activation function for forward pass 73 | cell_forward_dropout (float): Dropout rate for forward pass 74 | dense_readout (str): Readout strategy for dense layer 75 | skip (bool): Whether to use skip connections 76 | lr (float): Learning rate 77 | wd (float): Weight decay 78 | k_pool (float): k value for the pooling layer 79 | device (str): Device to run the model on 80 | param_init (str): Parameter initialization strategy 81 | **kwargs: Additional keyword arguments 82 | """ 83 | super(CellNetwork, self).__init__() 84 | 85 | self.lr = lr 86 | self.wd = wd 87 | self.readout = dense_readout 88 | 89 | # Initialize lift layers 90 | lift = MultiHeadLiftLayer(F_in=in_features['node'], 91 | K=signal_heads, 92 | signal_lift_activation=signal_lift_activation, 93 | signal_lift_dropout=signal_lift_dropout, 94 | signal_lift_readout=signal_lift_readout).to(device) 95 | lift_out_feat = signal_heads 96 | norm = TopologicalNorm(feat_dim=lift_out_feat, 97 | strategy=norm_strategy) 98 | self.lift = nn.Sequential(lift, norm) 99 | 100 | # Initialize cell attention layers 101 | in_features = [lift_out_feat+ in_features['edge']] + \ 102 | [features[i]*cell_attention_heads[i] 103 | if cell_attention_readout == 'cat' 104 | else features[i] 105 | for i in range(len(features))] 106 | 
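        # Worked sketch (with configs/config.json): features=[256, 256, 256],
        # cell_attention_heads=[2, 2, 2, 2] and cell_attention_readout='cat'
        # give in_features = [lift_out_feat + edge_features, 512, 512, 512],
        # i.e. each attention block consumes the concatenated head outputs
        # of the block before it.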
self.dense = dense
107 |         self.dense.append(n_class)
108 |         self.N_dense_layers = len(dense)
109 |         ops = []
110 | 
111 |         for l in range(len(in_features)-1):
112 |             layer = MultiHeadCellAttentionLayer(cell_attention_heads=cell_attention_heads[l],
113 |                                                 F_in=in_features[l],
114 |                                                 F_out=features[l],
115 |                                                 skip=(str(skip) == 'True'),  # the config stores skip as the string "True"/"False"
116 |                                                 cell_attention_activation=cell_attention_activation,
117 |                                                 cell_attention_dropout=cell_attention_dropout,
118 |                                                 cell_forward_activation=cell_forward_activation,
119 |                                                 cell_forward_dropout=cell_forward_dropout,
120 |                                                 cell_attention_readout=cell_attention_readout,
121 |                                                 param_init=param_init).to(device)
122 |             CAN_out_feat = features[l]*cell_attention_heads[l] \
123 |                 if cell_attention_readout == 'cat' else features[l]
124 |             can_norm = TopologicalNorm(feat_dim=CAN_out_feat,
125 |                                        strategy=norm_strategy)
126 | 
127 |             ops.append(layer)
128 |             ops.append(can_norm)
129 | 
130 |             if l == len(in_features)-2:
131 |                 pool = CAPooLayer(k_pool=k_pool,
132 |                                   F_in=CAN_out_feat,
133 |                                   cell_forward_activation=cell_forward_activation)
134 |                 ops.append(pool)
135 | 
136 |         # Initialize MLP
137 |         mlp = []
138 |         cell_to_dense_feat = features[-1]*cell_attention_heads[-1] \
139 |             if cell_attention_readout == 'cat' else features[-1]
140 |         simplicial_to_dense = nn.Linear(
141 |             cell_to_dense_feat, dense[0])
142 |         mlp.extend([simplicial_to_dense])
143 | 
144 |         self.dropout = cell_forward_dropout
145 | 
146 |         for l in range(1, self.N_dense_layers):
147 |             mlp.extend([cell_forward_activation,
148 |                         nn.BatchNorm1d(dense[l-1]),
149 |                         nn.Dropout(cell_forward_dropout),
150 |                         nn.Linear(dense[l-1], dense[l])])
151 | 
152 |         self.cell_net = nn.Sequential(*ops)
153 |         self.mlp = nn.Sequential(*mlp)
154 | 
155 |         self.max_acc = 0.0
156 |         self.loss_fn = nn.BCEWithLogitsLoss()
157 |         self.train_acc = []
158 |         self.valid_acc = []
159 |         self.valid_acc_epoch = []
160 |         self.train_acc_epoch = []
161 |         print(self)
162 | 
163 |     def to(self, device):
164 |         """
165 |         Moves the CellNetwork to the specified device.
166 | 
167 |         Args:
168 |             device (str): The device to move the CellNetwork to
169 | 
170 |         Returns:
171 |             self: The CellNetwork after being moved to the device
172 |         """
173 |         super().to(device)
174 |         self.cell_net = self.cell_net.to(device)
175 |         self.lift = self.lift.to(device)
176 |         self.mlp = self.mlp.to(device)
177 |         return self
178 | 
179 |     def forward(self, G: Graph):
180 |         """
181 |         Forward pass of the CellNetwork.
182 | 
183 |         Args:
184 |             G (Graph): Input graph
185 | 
186 |         Returns:
187 |             h_mlp (torch.Tensor): Output tensor after the forward pass
188 |         """
189 |         X = G.x
190 |         Xe, _ = self.lift((X, G))
191 | 
192 |         if G.edge_attr is not None:
193 |             Xe = torch.cat((Xe, G.edge_attr.float()), dim=1)
194 | 
195 |         H, _ = self.cell_net((Xe, G))  # pooled readouts are accumulated in G.ros by CAPooLayer
196 | 
197 |         H_ro = torch.stack(G.ros, dim=2).sum(dim=2)
198 |         H_ro = F.dropout(H_ro, self.dropout, training=self.training)  # apply (rather than discard) the readout dropout
199 |         h_mlp = self.mlp(H_ro)
200 | 
201 |         return h_mlp
202 | 
203 |     def propagate(self, batch, batch_idx):
204 |         """
205 |         Propagate the batch through the model.
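        The loss is BCEWithLogitsLoss against one-hot targets, so y_hat must
        carry one logit per class; accuracy is the argmax agreement with G.y.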
206 | 
207 |         Args:
208 |             batch (Batch): Input batch
209 |             batch_idx (int): Batch index
210 | 
211 |         Returns:
212 |             loss (torch.Tensor): Loss for the batch
213 |             acc (torch.Tensor): Accuracy for the batch
214 |         """
215 |         G = batch
216 |         row, col = G.edge_index
217 |         G.edge_batch = G.batch[row]
218 |         y_hat = self(G)
219 |         loss = self.loss_fn(y_hat, torch.nn.functional.one_hot(G.y, num_classes=y_hat.size(1)).float())  # infer the class count from the logits instead of hardcoding 2
220 | 
221 |         acc = (y_hat.argmax(dim=1) == G.y).float().mean()
222 |         return loss, acc
223 | 
224 |     def training_step(self, batch, batch_idx):
225 |         """
226 |         Training step for the model.
227 | 
228 |         Args:
229 |             batch (Batch): Input batch
230 |             batch_idx (int): Batch index
231 | 
232 |         Returns:
233 |             loss (torch.Tensor): Loss for the training batch
234 |         """
235 |         loss, acc = self.propagate(batch, batch_idx)
236 | 
237 |         self.train_acc.append(acc)
238 | 
239 |         self.log('train_loss', loss.item(), on_step=False,
240 |                  on_epoch=True, prog_bar=True)
241 | 
242 |         return loss
243 | 
244 |     def validation_step(self, batch, batch_idx):
245 |         """
246 |         Validation step for the model.
247 | 
248 |         Args:
249 |             batch (Batch): Input batch
250 |             batch_idx (int): Batch index
251 | 
252 |         Returns:
253 |             loss (torch.Tensor): Loss for the validation batch
254 |         """
255 |         loss, acc = self.propagate(batch, batch_idx)
256 |         self.valid_acc.append(acc)
257 |         self.log('valid_loss', loss.item(), on_step=False,
258 |                  on_epoch=True, prog_bar=True)
259 | 
260 |         return loss
261 | 
262 |     def test_step(self, batch, batch_idx):
263 |         """
264 |         Test step for the model.
265 | 
266 |         Args:
267 |             batch (Batch): Input batch
268 |             batch_idx (int): Batch index
269 | 
270 |         Returns:
271 |             loss (torch.Tensor): Loss for the test batch
272 |         """
273 |         return self.validation_step(batch, batch_idx)
274 | 
275 |     def training_epoch_end(self, outs):
276 |         """
277 |         Operations at the end of a training epoch.
278 | 
279 |         Args:
280 |             outs (List): Output list
281 |         """
282 |         epoch_train_acc = torch.tensor(self.train_acc, dtype=torch.float).mean().item()
283 |         self.train_acc_epoch.append(epoch_train_acc)
284 |         self.log('train_acc', epoch_train_acc, on_step=False,
285 |                  on_epoch=True, prog_bar=True)
286 |         self.train_acc = []
287 | 
288 |     def validation_epoch_end(self, outs):
289 |         """
290 |         Operations at the end of a validation epoch.
291 | 
292 |         Args:
293 |             outs (List): Output list
294 |         """
295 |         epoch_valid_acc = torch.tensor(self.valid_acc, dtype=torch.float).mean().item()
296 |         self.valid_acc_epoch.append(epoch_valid_acc)
297 |         self.log('valid_acc', epoch_valid_acc, on_step=False,
298 |                  on_epoch=True, prog_bar=True)
299 |         self.valid_acc = []
300 | 
301 |     def configure_optimizers(self):
302 |         """
303 |         Configures the optimizer and learning rate scheduler.
304 | 
305 |         Returns:
306 |             dict: Dictionary containing the optimizer, learning rate scheduler, and monitor
307 |         """
308 |         optimizer = torch.optim.AdamW(
309 |             self.parameters(), lr=self.lr, weight_decay=self.wd)
310 |         scheduler = ReduceLROnPlateau(optimizer,
311 |                                       mode='max',
312 |                                       factor=0.5,
313 |                                       patience=50,
314 |                                       min_lr=7e-5,
315 |                                       verbose=True)
316 |         return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'valid_acc'}
317 | 
318 | 
--------------------------------------------------------------------------------
/layers/cell_layers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Fri Feb 4 20:00:00 2022
5 | 
6 | @author: ince
7 | """
8 | 
9 | 
10 | 
11 | from typing import Callable, TypeVar, Tuple, Union, List, Optional
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | spmm = torch.sparse.mm
16 | 
17 | EdgeSignal = TypeVar('EdgeSignal')
18 | Graph = TypeVar('Graph')
19 | 
20 | from torch_geometric.nn.pool.topk_pool import topk, filter_adj
21 | from utils.utils import readout
22 | 
23 | 
24 | 
25 | 
26 | def sp_softmax(indices, values, N):
27 |     """
28 |     Compute the sparse softmax of the given values.
29 | 
30 |     Parameters
31 |     ----------
32 |     indices : torch.tensor
33 |         The indices of the non-zero elements in the sparse tensor.
34 |     values : torch.tensor
35 |         The values of the non-zero elements in the sparse tensor.
36 |     N : int
37 |         The size of the output tensor.
38 | 
39 |     Returns
40 |     -------
41 |     softmax_v : torch.tensor
42 |         The softmax values computed for the sparse tensor.
43 |     """
44 |     source, _ = indices
45 |     v_max = values.max()
46 |     exp_v = torch.exp(values - v_max)
47 |     exp_sum = torch.zeros(N, 1, device=values.device)  # follow the input device instead of hardcoding 'cuda'
48 |     exp_sum.scatter_add_(0, source.unsqueeze(1), exp_v)
49 |     exp_sum += 1e-10
50 |     softmax_v = exp_v / exp_sum[source]
51 |     return softmax_v
52 | 
53 | 
54 | def sp_matmul(indices, values, mat):
55 |     """
56 |     Perform sparse matrix multiplication.
57 |     Parameters
58 |     ----------
59 |     indices : torch.tensor
60 |         The indices of the non-zero elements in the sparse tensor.
61 |     values : torch.tensor
62 |         The values of the non-zero elements in the sparse tensor.
63 |     mat : torch.tensor
64 |         The dense matrix to be multiplied with the sparse tensor.
65 | 
66 |     Returns
67 |     -------
68 |     out : torch.tensor
69 |         The result of the sparse matrix multiplication.
70 |     """
71 |     source, target = indices
72 |     out = torch.zeros_like(mat)
73 |     out.scatter_add_(0, source.expand(mat.size(1), -1).t(), values * mat[target])
74 |     return out
75 | 
76 | 
77 | class CellLayer(nn.Module):
78 |     """
79 |     A Cell layer for a Cellular Attention Network (CAN).
80 |     This layer is responsible for learning the cell update rules in the
81 |     network using learnable weight matrices.
82 | 
83 |     Parameters
84 |     ----------
85 |     F_in : int
86 |         Number of input features.
87 |     F_out : int
88 |         Number of output features.
89 |     cell_forward_activation : Callable
90 |         Non-linear activation function for the cell forward operation.
91 |     cell_forward_dropout : float
92 |         Dropout rate applied after the cell forward operation.
93 |     param_init : str
94 |         The initialization method for the learnable weight matrices.
95 |         Choices: 'uniform', 'normal'
96 | 
97 |     Examples
98 |     --------
99 |     >>> cell_layer = CellLayer(F_in=10, F_out=20, cell_forward_activation=torch.relu, cell_forward_dropout=0.5, param_init='uniform')
100 |     """
101 | 
102 |     def __init__(self, F_in: int, F_out: int,
103 |                  cell_forward_activation: Callable,
104 |                  cell_forward_dropout: float,
105 |                  param_init: str):
106 | 
107 |         super(CellLayer, self).__init__()
108 | 
109 |         self.F_in = F_in
110 |         self.F_out = F_out
111 |         self.Wirr = nn.Parameter(torch.empty(size=(F_in, F_out)))
112 |         self.Wsol = nn.Parameter(torch.empty(size=(F_in, F_out)))
113 |         self.Wskip = nn.Parameter(torch.empty(size=(F_in, F_out)))
114 |         self.W = nn.Parameter(torch.empty(size=(F_out, F_out)))
115 |         self.Wout = nn.Parameter(torch.empty(size=(F_out, F_out)))
116 | 
117 |         self.param_init = param_init
118 | 
119 |         self.cell_forward_activation = cell_forward_activation
120 |         self.cell_forward_dropout = cell_forward_dropout
121 |         self.reset_parameters()
122 | 
123 |     def __repr__(self):
124 |         s = "CellLayer(" + \
125 |             "F_in=" + str(self.F_in) + \
126 |             ", F_out=" + str(self.F_out) + \
127 |             ", Fwd_Activ=" + str(self.cell_forward_activation) + \
128 |             ", Fwd_Dropout=" + str(self.cell_forward_dropout) + ")"
129 |         return s
130 | 
131 |     def reset_parameters(self):
132 |         """
133 |         Reinitialize the learnable weight matrices of the cell layer.
134 |         The initialization method is determined by the `param_init` parameter.
135 |         """
136 |         gain = nn.init.calculate_gain('relu')
137 | 
138 |         if self.param_init == 'uniform':
139 |             nn.init.xavier_uniform_(self.Wirr.data, gain=gain)
140 |             nn.init.xavier_uniform_(self.Wsol.data, gain=gain)
141 |             nn.init.xavier_uniform_(self.Wskip.data, gain=gain)
142 |             nn.init.xavier_uniform_(self.W.data, gain=gain)
143 |             nn.init.xavier_uniform_(self.Wout.data, gain=gain)
144 |         else:
145 |             nn.init.xavier_normal_(self.Wirr.data, gain=gain)
146 |             nn.init.xavier_normal_(self.Wsol.data, gain=gain)
147 |             nn.init.xavier_normal_(self.Wskip.data, gain=gain)
148 |             nn.init.xavier_normal_(self.W.data, gain=gain)
149 |             nn.init.xavier_normal_(self.Wout.data, gain=gain)
150 | 
151 | 
152 | 
153 | class TopologicalNorm(torch.nn.Module):
154 |     def __init__(self, feat_dim, strategy):
155 |         """
156 |         Normalization wrapper for edge signals on a cell complex.
157 |         Applies the chosen normalization to the incoming edge signal while
158 |         forwarding the complex (graph) information through the network
159 |         unchanged, so it can be chained with the attention layers.
160 |         The normalization can also be disabled via the identity strategy.
161 | 
162 |         Parameters
163 |         ----------
164 |         feat_dim : int
165 |             dimension of the incoming signals' features.
166 |         strategy : str
167 |             Normalization strategy: 'layer', 'batch', or 'identity'/'id'.
168 | 
169 |         Returns
170 |         -------
171 |         norm(x), G : EdgeSignal, Graph.
172 | 
173 |         """
174 |         super(TopologicalNorm, self).__init__()
175 |         assert (feat_dim > 0), "feature dimension of the signal must be > 0"
176 |         assert strategy in ['layer', 'batch', 'identity', 'id'], "TopologicalNorm strategy must be one of: ['layer', 'batch', 'identity', 'id']"
177 |         if strategy == 'layer':
178 |             self.tn = nn.LayerNorm(feat_dim)
179 |         elif strategy == 'batch':
180 |             self.tn = nn.BatchNorm1d(feat_dim)
181 |         elif strategy in ('identity', 'id'):  # cover both spellings accepted by the assert above
182 |             self.tn = nn.Identity()
183 | 
184 |     def forward(self, x: Union[Tuple[EdgeSignal, Graph], EdgeSignal]):
185 |         x, G = x
186 |         return self.tn(x), G
187 | 
188 | 
189 | class CellAttentionLayer(CellLayer):
190 |     """
191 |     Attention-based cell layer of Cellular Attention Network.
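    Sketch of one neighbourhood pass (see forward): h = x W; for every pair
    (i, j) linked in the lower ('do') or upper ('up') connectivity, a score
    e_ij = act([h_i || h_j] a) is computed, normalised with a sparse softmax
    over each source's neighbours, and used to aggregate h into h'.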
192 | 193 | This layer inherits from the CellLayer class and adds an attention mechanism 194 | for the information exchange between the edges weighted by learnable attention 195 | coefficients. 196 | 197 | Parameters 198 | ---------- 199 | F_in : int 200 | Number of input features for the cell attention layer. 201 | F_out : int 202 | Number of output features for the cell attention layer. 203 | skip : bool 204 | Whether to add skip connections in the attention layer. 205 | cell_attention_activation : Callable 206 | Non-linear activation function for the cell attention mechanism. 207 | cell_forward_activation : Callable 208 | Non-linear activation function for the forward pass of the cell layer. 209 | cell_attention_dropout : float 210 | Dropout rate applied to the attention mechanism. 211 | cell_forward_dropout : float 212 | Dropout rate applied to the forward pass of the cell layer. 213 | param_init : str 214 | Parameter initialization method, either 'uniform' or 'normal'. 215 | """ 216 | 217 | def __init__(self, F_in: int, F_out: int, skip: bool, 218 | cell_attention_activation: Callable, 219 | cell_forward_activation: Callable, 220 | cell_attention_dropout: float, 221 | cell_forward_dropout: float, 222 | param_init: str): 223 | 224 | # Call the constructor of the parent class (CellLayer) 225 | super(CellAttentionLayer, self).__init__(F_in=F_in, 226 | F_out=F_out, 227 | cell_forward_activation=cell_forward_activation, 228 | cell_forward_dropout=cell_forward_dropout, 229 | param_init=param_init) 230 | 231 | # Define learnable parameters for the attention mechanism 232 | self.att_irr = nn.Parameter(torch.empty(size=(2 * self.F_out, 1))) 233 | self.att_sol = nn.Parameter(torch.empty(size=(2 * self.F_out, 1))) 234 | 235 | self.param_init = param_init 236 | self.skip = skip 237 | 238 | self.cell_attention_activation = cell_attention_activation 239 | self.dropout = cell_attention_dropout 240 | 241 | # Reset and initialize parameters 242 | self.reset_parameters() 243 | 244 | def __repr__(self): 245 | cell_repr = super().__repr__() 246 | s = "AttentionLayer(" + \ 247 | "Att Activ="+str(self.cell_attention_activation) +\ 248 | ", Att Dropout="+str(self.dropout) +\ 249 | ", Skip="+str(self.skip) +")\n\t\t" + cell_repr + "\n)" 250 | return s 251 | 252 | def to(self, device): 253 | super().to(device) 254 | return self 255 | 256 | def reset_parameters(self): 257 | """Reinitialize learnable parameters.""" 258 | super().reset_parameters() 259 | gain = nn.init.calculate_gain('relu') 260 | if self.param_init == 'uniform': 261 | nn.init.xavier_uniform_(self.att_irr.data, gain=gain) 262 | nn.init.xavier_uniform_(self.att_sol.data, gain=gain) 263 | else: 264 | nn.init.xavier_normal_(self.att_irr.data, gain=gain) 265 | nn.init.xavier_normal_(self.att_sol.data, gain=gain) 266 | 267 | 268 | def forward(self, x: Tuple[EdgeSignal, Graph]) -> EdgeSignal: 269 | x, G = x 270 | x = F.dropout(x, self.dropout, training=self.training) 271 | out = (1.001)*(x @ self.Wskip) if self.skip else torch.tensor(0.0) 272 | try: 273 | h = torch.matmul(x, self.Wirr) 274 | source, target = G.connectivities['do'] 275 | a = torch.cat([h[source], h[target]], dim=1) 276 | e = self.cell_attention_activation(torch.matmul(a, self.att_irr)) 277 | #e = F.dropout(e, self.dropout, training=self.training) 278 | 279 | attention = sp_softmax(G.connectivities['do'], e, h.size(0)) 280 | attention = F.dropout(attention, self.dropout, training=self.training) 281 | #h = F.dropout(h, self.dropout, training=self.training) 282 | h_prime = 
sp_matmul(G.connectivities['do'], attention, h)
283 |         except Exception:  # no usable lower ('do') connectivity for this batch
284 |             h_prime = torch.tensor(0.0)
285 | 
286 |         out = out + h_prime  # keep the skip-connection term computed above instead of overwriting it
287 | 
288 |         try:
289 |             h = torch.matmul(x, self.Wsol)
290 |             source, target = G.connectivities['up']
291 |             a = torch.cat([h[source], h[target]], dim=1)
292 |             e = self.cell_attention_activation(torch.matmul(a, self.att_sol))
293 |             #e = F.dropout(e, self.dropout, training=self.training)
294 | 
295 |             attention = sp_softmax(G.connectivities['up'], e, h.size(0))
296 |             attention = F.dropout(attention, self.dropout, training=self.training)
297 |             #h = F.dropout(h, self.dropout, training=self.training)
298 |             h_prime = sp_matmul(G.connectivities['up'], attention, h)
299 |         except Exception:  # no usable upper ('up') connectivity for this batch
300 |             h_prime = torch.tensor(0.0)
301 | 
302 |         out += h_prime
303 |         return self.cell_forward_activation(out), G
304 | 
305 | 
306 | 
307 | class MultiHeadCellAttentionLayer(nn.Module):
308 | 
309 |     """
310 |     mhcal = MultiHeadCellAttentionLayer(F_in=3, F_out=3,
311 |                                         sigma=torch.nn.ELU(), K=3,
312 |                                         p_dropout=0.2)
313 |     """
314 |     def __init__(self, F_in: int, F_out: int, skip: bool,
315 |                  cell_attention_heads: int,
316 |                  cell_attention_activation: Callable,
317 |                  cell_forward_activation: Callable,
318 |                  cell_attention_dropout: float,
319 |                  cell_forward_dropout: float,
320 |                  cell_attention_readout: str,
321 |                  param_init: str):
322 |         """
323 |         cell_attention_activation: Callable, cell_attention_dropout: float,
324 |         cell_attention_readout: str,
325 |         cell_forward_activation: Callable, cell_forward_dropout: float,
326 |         alpha_leaky_relu : float
327 |             negative slope of the leaky ReLU.
328 | 
329 |         readout : str
330 |             which function to use in order to perform readout operations
331 | 
332 |         K: Number of attention heads
333 |         """
334 |         super(MultiHeadCellAttentionLayer, self).__init__()
335 |         assert F_in > 0, ValueError("Number of input features must be > 0")
336 |         assert F_out > 0, ValueError("Number of output features must be > 0")
337 |         assert cell_attention_readout in ['cat', 'avg', 'sum', 'max'], ValueError("readout must be one of ['cat', 'avg', 'sum', 'max']")
338 |         assert param_init in ['uniform', 'normal'], ValueError("Param init must be one of ['uniform', 'normal']")
339 |         self.F_out = F_out
340 |         self.F_in = F_in
341 |         self.skip = skip
342 |         self.cell_attention_heads = cell_attention_heads
343 |         self.cell_attention_readout = cell_attention_readout
344 |         self.cell_forward_activation = cell_forward_activation
345 |         self.cell_forward_dropout = cell_forward_dropout
346 |         self.attentions = [CellAttentionLayer(F_in=F_in,
347 |                                               F_out=F_out,
348 |                                               cell_attention_dropout=cell_attention_dropout,
349 |                                               cell_attention_activation=cell_attention_activation,
350 |                                               cell_forward_activation=cell_forward_activation,
351 |                                               cell_forward_dropout=cell_forward_dropout,
352 |                                               param_init=param_init,
353 |                                               skip=skip)
354 |                            for _ in range(cell_attention_heads)]
355 | 
356 | 
357 | 
358 |     def __repr__(self):
359 |         _F_out = self.F_out*self.cell_attention_heads \
360 |             if self.cell_attention_readout == 'cat' else self.F_out
361 |         s = "MultiHeadCellAttentionLayer(" + \
362 |             "F_in="+str(self.F_in)+ \
363 |             ", F_out="+str(_F_out)+\
364 |             ", Heads=" +str(self.cell_attention_heads) + \
365 |             ", Readout=" +self.cell_attention_readout+ \
366 |             "):"
367 |         for idx, attention_layer in enumerate(self.attentions):
368 |             s+= "\n\t(" +str(idx)+"): "+ str(attention_layer)
369 |         return s
370 | 
371 |     def to(self, device):
372 |         self.attentions = nn.ModuleList([attention.to(device) for attention in self.attentions])
373 |         return self
374 | 
375 |     def forward(self, x:
Tuple[EdgeSignal, Graph], *args, **kwargs): 376 | G = x[1] 377 | H = [attention_layer(x, *args, **kwargs) for attention_layer in self.attentions] 378 | H = [hidden_signal[0] for hidden_signal in H] 379 | if self.cell_attention_readout == 'cat': 380 | H = torch.cat(H, dim=1) # Xe 381 | elif self.cell_attention_readout == 'sum': 382 | H = torch.stack(H, dim=2).sum(dim=2) 383 | elif self.cell_attention_readout == 'avg': 384 | H = torch.stack(H, dim=2).mean(dim=2) 385 | elif self.cell_attention_readout == 'max': 386 | H = torch.stack(H, dim=2).max(dim=2).values 387 | 388 | return H, G 389 | 390 | 391 | class CAPooLayer(nn.Module): 392 | """ 393 | CAPooLayer (Cellular Attention Pooling Layer) is responsible for pooling 394 | operations in the Cellular Graph Attention Network. 395 | 396 | This layer applies attention-based pooling to a given edge signal 397 | and updates the graph accordingly. 398 | 399 | Parameters 400 | ---------- 401 | k_pool : float 402 | Fraction of nodes to keep after the pooling operation. 403 | F_in : int 404 | Number of input features for the pooling layer. 405 | cell_forward_activation : Callable 406 | Non-linear activation function used in the forward pass. 407 | 408 | Returns 409 | ------- 410 | CAPooLayer. 411 | 412 | Examples 413 | ------- 414 | pool = CAPooLayer(k_pool=.75, 415 | F_in=3*att_heads, 416 | cell_forward_activation=nn.ReLU) 417 | """ 418 | 419 | def __init__(self, k_pool: float, F_in: int, cell_forward_activation: Callable): 420 | super(CAPooLayer, self).__init__() 421 | 422 | self.k_pool = k_pool 423 | self.cell_forward_activation = cell_forward_activation 424 | 425 | # Learnable attention parameter for the pooling operation 426 | self.att_pool = nn.Parameter(torch.empty(size=(F_in, 1))) 427 | 428 | # Initialize the attention parameter using Xavier initialization 429 | nn.init.xavier_normal_(self.att_pool.data, gain=1.41) 430 | 431 | 432 | def __repr__(self): 433 | s = "PoolLayer(" + \ 434 | "K Pool="+str(self.k_pool)+ ")" 435 | return s 436 | 437 | 438 | def forward(self, x: EdgeSignal) -> EdgeSignal: 439 | 440 | x, G = x 441 | shape = x.shape 442 | Zp = x @ self.att_pool 443 | idx = topk(Zp.view(-1), self.k_pool, G.edge_batch) 444 | x = x[idx] * self.cell_forward_activation(Zp)[idx].view(-1, 1) 445 | G.edge_batch = G.edge_batch[idx] 446 | G.ros.append(readout(x, G.edge_batch, 'sum')) 447 | G.connectivities['up'] = tuple(filter_adj(torch.stack(G.connectivities['up']), None, idx, shape[0])[0]) 448 | G.connectivities['do'] = tuple(filter_adj(torch.stack(G.connectivities['do']), None, idx, shape[0])[0]) 449 | 450 | 451 | return x, G 452 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri Feb 4 20:06:24 2022 5 | 6 | @author: ince 7 | """ 8 | 9 | ##stick graph_tool on top to avoid segfaults 10 | 11 | import graph_tool as gt 12 | import graph_tool.topology as top 13 | 14 | gt.openmp_set_num_threads(18)#os.cpu_count()-2) 15 | 16 | import torch 17 | import scipy.sparse as sp 18 | import scipy.sparse.linalg as spl 19 | import numpy as np 20 | 21 | import gudhi as gd 22 | import itertools 23 | import networkx as nx 24 | 25 | from typing import List, Dict 26 | from torch import Tensor 27 | from torch_scatter import scatter 28 | 29 | 30 | from tqdm.auto import tqdm 31 | from joblib import Parallel 32 | from copy import deepcopy 33 | from collections 
import defaultdict
34 | from torch_geometric.data import Batch
35 | 
36 | 
37 | def readout(value:torch.Tensor, labels:torch.LongTensor, op:str) -> torch.Tensor:
38 |     """Group-wise readout (sum or average) for (sparse) grouped tensors
39 | 
40 |     Args:
41 |         value (torch.Tensor): values to aggregate (# samples, latent dimension)
42 |         labels (torch.LongTensor): labels for embedding parameters (# samples,)
43 |         op (str): 'avg' averages each group; any other value keeps group sums
44 |     Returns:
45 |         result (torch.Tensor): per-group readout (# unique labels, latent dimension),
46 |             with rows ordered by the sorted unique labels
47 | 
48 |     Examples:
49 |         >>> samples = torch.Tensor([
50 |                      [0.15, 0.15, 0.15],    #-> group / class 1
51 |                      [0.2, 0.2, 0.2],    #-> group / class 3
52 |                      [0.4, 0.4, 0.4],    #-> group / class 3
53 |                      [0.0, 0.0, 0.0]    #-> group / class 0
54 |                 ])
55 |         >>> labels = torch.LongTensor([1, 5, 5, 0])
56 |         >>> result = readout(samples, labels, op='avg')
57 | 
58 |         >>> result
59 |         tensor([[0.0000, 0.0000, 0.0000],
60 |             [0.1500, 0.1500, 0.1500],
61 |             [0.3000, 0.3000, 0.3000]])
62 | 
63 |         Rows correspond to the sorted unique labels, here [0, 1, 5].
64 | 
65 |     """
66 |     uniques = labels.unique().tolist()
67 |     labels = labels.tolist()
68 | 
69 |     key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
70 |     val_key = {val: key for key, val in zip(uniques, range(len(uniques)))}
71 | 
72 |     labels = torch.LongTensor(list(map(key_val.get, labels)))
73 | 
74 |     labels = labels.view(labels.size(0), 1).expand(-1, value.size(1)).to(value.device)  # follow the input device instead of hardcoding CUDA
75 | 
76 |     unique_labels, labels_count = labels.unique(dim=0, return_counts=True)
77 |     result = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(0, labels, value)
78 |     if op == 'avg':
79 |         result = result / labels_count.float().unsqueeze(1)
80 | 
81 |     return result
82 | 
83 | 
84 | class SparseDropout(torch.nn.Module):
85 |     def __init__(self, p_dropout=0.5):
86 |         super(SparseDropout, self).__init__()
87 |         # dprob is ratio of dropout
88 |         # convert to keep probability
89 |         self.kprob = 1 - p_dropout
90 | 
91 |     def forward(self, x, training):
92 |         mask=((torch.rand(x._values().size())+(self.kprob)).floor()).type(torch.bool)
93 |         rc=x._indices()[:,mask]
94 |         val=x._values()[mask]*(1.0/self.kprob)
95 |         return torch.sparse_coo_tensor(rc, val, x.shape)
96 | 
97 | def compute_projection_matrix(L, eps, kappa):
98 |     P = (torch.eye(L.shape[0]) - eps*L)
99 |     for _ in range(kappa):
100 |         P = P @ P # approximate the limit
101 |     return P
102 | 
103 | def normalize(L, half_interval=False):
104 |     assert(L.shape[0] == L.shape[1])
105 |     topeig = torch.linalg.eigvalsh(L.to_dense()).max().item()
106 |     values = L.values()
107 |     if half_interval:
108 |         values *= 1.0/topeig
109 |     else:
110 |         values *= 2.0/topeig
111 | 
112 |     return torch.sparse_coo_tensor(L.indices(), values, size=L.shape).to_dense()
113 | 
114 | 
115 | def coo2tensor(A):
116 |     assert(sp.isspmatrix_coo(A))
117 |     idxs = torch.LongTensor(np.vstack((A.row, A.col)))
118 |     vals = torch.FloatTensor(A.data)
119 |     return torch.sparse_coo_tensor(idxs, vals, size = A.shape, requires_grad = False)
120 | 
121 | def normalize2(L,Lx, half_interval = False):
122 |     assert(sp.isspmatrix(L))
123 |     M = L.shape[0]
124 |     assert(M == L.shape[1])
125 |     topeig = spl.eigsh(L, k=1, which="LM", return_eigenvectors = False)[0] # we use the maximal eigenvalue of L to normalize
126 |     #print("Topeig = %f" %(topeig))
127 | 
128 |     ret = Lx.copy()
129 |     if half_interval:
130 |         ret *= 1.0/topeig
131 |     else:
132 |         ret *= 2.0/topeig
133 |     ret.setdiag(ret.diagonal(0) - np.ones(M), 0)
134 | 
135 |     return ret
136 | 
137 | 
138 | 
139 | 
140 | def normalize3(L, half_interval =
def normalize3(L, half_interval=False):
    # Same spectral normalization as normalize2, applied to L itself.
    assert sp.isspmatrix(L)
    M = L.shape[0]
    assert M == L.shape[1]
    topeig = spl.eigsh(L, k=1, which="LM", return_eigenvectors=False)[0]

    ret = L.copy()
    if half_interval:
        ret *= 1.0 / topeig
    else:
        ret *= 2.0 / topeig
    ret.setdiag(ret.diagonal(0) - np.ones(M), 0)

    return ret


def batch_mm(matrix, matrix_batch):
    """
    :param matrix: Sparse or dense matrix, size (m, n).
    :param matrix_batch: Batched dense matrices, size (b, n, k).
    :return: The batched matrix-matrix product, size (m, n) x (b, n, k) = (b, m, k).
    """
    batch_size = matrix_batch.shape[0]
    # Stack the vector batch into columns: (b, n, k) -> (n, b, k) -> (n, b*k).
    vectors = matrix_batch.transpose(0, 1).reshape(matrix.shape[1], -1)

    # A matrix-matrix product is a batched matrix-vector product of the columns,
    # followed by the reverse reshaping:
    # (m, n) x (n, b*k) = (m, b*k) -> (m, b, k) -> (b, m, k).
    return matrix.mm(vectors).reshape(matrix.shape[0], batch_size, -1).transpose(1, 0)


###### RING UTILS


def pyg_to_simplex_tree(edge_index: Tensor, size: int):
    """Constructs a simplex tree from a PyG graph.

    Args:
        edge_index: The edge_index of the graph (a tensor of shape [2, num_edges])
        size: The number of nodes in the graph.
    """
    st = gd.SimplexTree()
    # Add the vertices to the simplex tree.
    for v in range(size):
        st.insert([v])

    # Add the edges to the simplex tree.
    edges = edge_index.numpy()
    for e in range(edges.shape[1]):
        edge = [edges[0][e], edges[1][e]]
        st.insert(edge)

    return st


def get_simplex_boundaries(simplex):
    boundaries = itertools.combinations(simplex, len(simplex) - 1)
    return [tuple(boundary) for boundary in boundaries]


def build_tables(simplex_tree, size):
    complex_dim = simplex_tree.dimension()
    # Each of these data structures has a separate entry per dimension.
    id_maps = [{} for _ in range(complex_dim + 1)]         # simplex -> id
    simplex_tables = [[] for _ in range(complex_dim + 1)]  # matrix of simplices

    simplex_tables[0] = [[v] for v in range(size)]
    id_maps[0] = {tuple([v]): v for v in range(size)}

    for simplex, _ in simplex_tree.get_simplices():
        dim = len(simplex) - 1
        if dim == 0:
            continue

        # Assign this simplex the next unused ID.
        next_id = len(simplex_tables[dim])
        id_maps[dim][tuple(simplex)] = next_id
        simplex_tables[dim].append(simplex)

    return simplex_tables, id_maps
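
# --- Illustrative sketch (added commentary, not part of the original file) ---
# How the simplex-tree tables look on a toy triangle graph. The ids assigned
# to the edges follow the traversal order of gudhi's simplex tree, so the
# exact ordering is indicative rather than guaranteed. Defined but never
# called at import time.
def _demo_simplex_tables():
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
    st = pyg_to_simplex_tree(edge_index, size=3)
    tables, id_maps = build_tables(st, size=3)
    # tables[0] == [[0], [1], [2]]; tables[1] holds the three edges as sorted
    # vertex pairs, and id_maps[1] maps each edge tuple to a contiguous id.
    return tables, id_maps
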
def extract_boundaries_and_coboundaries_from_simplex_tree(simplex_tree, id_maps, complex_dim: int):
    """Build two maps: simplex -> its coboundaries and simplex -> its boundaries."""
    # The extra dimension is added just for convenience, to avoid treating
    # the top dimension as a special case.
    boundaries = [{} for _ in range(complex_dim + 2)]    # simplex -> boundaries
    coboundaries = [{} for _ in range(complex_dim + 2)]  # simplex -> coboundaries
    boundaries_tables = [[] for _ in range(complex_dim + 1)]

    for simplex, _ in simplex_tree.get_simplices():
        # Extract the relevant boundary and coboundary maps.
        simplex_dim = len(simplex) - 1
        level_coboundaries = coboundaries[simplex_dim]
        level_boundaries = boundaries[simplex_dim + 1]

        # Add the boundaries of the simplex to the boundaries table.
        if simplex_dim > 0:
            boundaries_ids = [id_maps[simplex_dim - 1][boundary] for boundary in get_simplex_boundaries(simplex)]
            boundaries_tables[simplex_dim].append(boundaries_ids)

        # This operation should be roughly O(dim_complex), so it is very efficient for us.
        # For details see pages 6-7 of https://hal.inria.fr/hal-00707901v1/document
        simplex_coboundaries = simplex_tree.get_cofaces(simplex, codimension=1)
        for coboundary, _ in simplex_coboundaries:
            assert len(coboundary) == len(simplex) + 1

            if tuple(simplex) not in level_coboundaries:
                level_coboundaries[tuple(simplex)] = list()
            level_coboundaries[tuple(simplex)].append(tuple(coboundary))

            if tuple(coboundary) not in level_boundaries:
                level_boundaries[tuple(coboundary)] = list()
            level_boundaries[tuple(coboundary)].append(tuple(simplex))

    return boundaries_tables, boundaries, coboundaries


def build_adj(boundaries: List[Dict], coboundaries: List[Dict], id_maps: List[Dict], complex_dim: int,
              include_down_adj: bool):
    """Builds the upper and lower adjacency data structures of the complex.

    Args:
        boundaries: A list of dictionaries of the form
            boundaries[dim][simplex] -> List[simplex] (the boundaries)
        coboundaries: A list of dictionaries of the form
            coboundaries[dim][simplex] -> List[simplex] (the coboundaries)
        id_maps: A list of dictionaries from simplex -> simplex_id, one per dimension
        complex_dim: The dimension of the complex
        include_down_adj: Whether to also build the lower adjacencies
    """
    def initialise_structure():
        return [[] for _ in range(complex_dim + 1)]

    upper_indexes, lower_indexes = initialise_structure(), initialise_structure()
    all_shared_boundaries, all_shared_coboundaries = initialise_structure(), initialise_structure()

    # Go through all dimensions of the complex...
    for dim in range(complex_dim + 1):
        # ...and through all the simplices at that dimension.
        for simplex, cell_id in id_maps[dim].items():
            # Add the upper adjacent neighbours from the level below.
            if dim > 0:
                for boundary1, boundary2 in itertools.combinations(boundaries[dim][simplex], 2):
                    id1, id2 = id_maps[dim - 1][boundary1], id_maps[dim - 1][boundary2]
                    upper_indexes[dim - 1].extend([[id1, id2], [id2, id1]])
                    all_shared_coboundaries[dim - 1].extend([cell_id, cell_id])

            # Add the lower adjacent neighbours from the level above.
            if include_down_adj and dim < complex_dim and simplex in coboundaries[dim]:
                for coboundary1, coboundary2 in itertools.combinations(coboundaries[dim][simplex], 2):
                    id1, id2 = id_maps[dim + 1][coboundary1], id_maps[dim + 1][coboundary2]
                    lower_indexes[dim + 1].extend([[id1, id2], [id2, id1]])
                    all_shared_boundaries[dim + 1].extend([cell_id, cell_id])

    return all_shared_boundaries, all_shared_coboundaries, lower_indexes, upper_indexes
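
# --- Illustrative sketch (added commentary, not part of the original file) ---
# Boundary/coboundary maps on the same toy triangle: every edge has its two
# endpoint vertices as boundaries, and every vertex has its incident edges as
# coboundaries. Hypothetical helper, never called at import time.
def _demo_boundaries():
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
    st = pyg_to_simplex_tree(edge_index, size=3)
    _, id_maps = build_tables(st, size=3)
    _, boundaries, coboundaries = extract_boundaries_and_coboundaries_from_simplex_tree(
        st, id_maps, st.dimension())
    # e.g. boundaries[1][(0, 1)] == [(0,), (1,)], while
    # coboundaries[0][(0,)] lists the edges incident to vertex 0.
    return boundaries, coboundaries
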
def construct_features(vx: Tensor, cell_tables, init_method: str) -> List:
    """Combines the features of the component vertices to initialise the cell features."""
    features = [vx]
    for dim in range(1, len(cell_tables)):
        aux_1 = []
        aux_0 = []
        # Build a (vertex, cell) incidence index for this dimension.
        for c, cell in enumerate(cell_tables[dim]):
            aux_1 += [c for _ in range(len(cell))]
            aux_0 += cell
        node_cell_index = torch.LongTensor([aux_0, aux_1])
        in_features = vx.index_select(0, node_cell_index[0])
        # Reduce the vertex features of each cell with the chosen init method.
        features.append(scatter(in_features, node_cell_index[1], dim=0,
                                dim_size=len(cell_tables[dim]), reduce=init_method))

    return features


def extract_labels(y, size):
    v_y, complex_y = None, None
    if y is None:
        return v_y, complex_y

    y_shape = list(y.size())

    if y_shape[0] == 1:
        # This is a label for the whole graph (for graph classification).
        # We will use it for the complex.
        complex_y = y
    else:
        # This is a label for the vertices of the complex.
        assert y_shape[0] == size
        v_y = y

    return v_y, complex_y


# ---- support for rings as cells

def get_rings(edge_index, max_k=7):
    if isinstance(edge_index, torch.Tensor):
        edge_index = edge_index.numpy()

    edge_list = edge_index.T
    graph_gt = gt.Graph(directed=False)
    graph_gt.add_edge_list(edge_list)

    gt.stats.remove_self_loops(graph_gt)
    gt.stats.remove_parallel_edges(graph_gt)
    # We represent rings with their original node ordering so that we can
    # easily read out the boundaries. The `sorted_rings` set allows us to
    # discard different isomorphisms that are associated with the same
    # original ring -- this happens due to the intrinsic symmetries of cycles.
    rings = set()
    sorted_rings = set()
    for k in range(3, max_k + 1):
        pattern = nx.cycle_graph(k)
        pattern_edge_list = list(pattern.edges)
        pattern_gt = gt.Graph(directed=False)
        pattern_gt.add_edge_list(pattern_edge_list)
        sub_isos = top.subgraph_isomorphism(pattern_gt, graph_gt, induced=True, subgraph=True,
                                            generator=True)
        sub_iso_sets = map(lambda isomorphism: tuple(isomorphism.a), sub_isos)
        for iso in sub_iso_sets:
            if tuple(sorted(iso)) not in sorted_rings:
                rings.add(iso)
                sorted_rings.add(tuple(sorted(iso)))
    rings = list(rings)
    return rings


def build_tables_with_rings(edge_index, simplex_tree, size, max_k):

    # Build simplex tables and id_maps up to edges by conveniently
    # invoking the code for simplicial complexes.
    cell_tables, id_maps = build_tables(simplex_tree, size)

    # Find rings in the graph.
    rings = get_rings(edge_index, max_k=max_k)

    if len(rings) > 0:
        # Extend the tables with rings as 2-cells.
        id_maps += [{}]
        cell_tables += [[]]
        assert len(cell_tables) == 3, cell_tables
        for cell in rings:
            next_id = len(cell_tables[2])
            id_maps[2][cell] = next_id
            cell_tables[2].append(list(cell))

    return cell_tables, id_maps


def get_ring_boundaries(ring):
    boundaries = list()
    for n in range(len(ring)):
        a = n
        # Wrap around so the last vertex connects back to the first.
        if n + 1 == len(ring):
            b = 0
        else:
            b = n + 1
        # We represent the boundaries in lexicographic order so as to be
        # compatible with the 0- and 1-dim cells extracted as simplices with gudhi.
        boundaries.append(tuple(sorted([ring[a], ring[b]])))
    return sorted(boundaries)
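
# --- Illustrative sketch (added commentary, not part of the original file) ---
# get_ring_boundaries on a 4-ring traversed as 0 -> 1 -> 3 -> 2: the circular
# edges (0,1), (1,3), (3,2), (2,0) are rewritten in lexicographic order and
# sorted, matching the edge tuples produced by gudhi.
def _demo_ring_boundaries():
    assert get_ring_boundaries((0, 1, 3, 2)) == [(0, 1), (0, 2), (1, 3), (2, 3)]
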
def extract_boundaries_and_coboundaries_with_rings(simplex_tree, id_maps):
    """Build two maps: cell -> its coboundaries and cell -> its boundaries."""

    # Find boundaries and coboundaries up to edges by conveniently
    # invoking the code for simplicial complexes.
    assert simplex_tree.dimension() <= 1
    boundaries_tables, boundaries, coboundaries = extract_boundaries_and_coboundaries_from_simplex_tree(
        simplex_tree, id_maps, simplex_tree.dimension())

    assert len(id_maps) <= 3
    if len(id_maps) == 3:
        # Extend the tables with the boundary and coboundary information of rings.
        boundaries += [{}]
        coboundaries += [{}]
        boundaries_tables += [[]]
        for cell in id_maps[2]:
            cell_boundaries = get_ring_boundaries(cell)
            boundaries[2][cell] = list()
            boundaries_tables[2].append([])
            for boundary in cell_boundaries:
                assert boundary in id_maps[1], boundary
                boundaries[2][cell].append(boundary)
                if boundary not in coboundaries[1]:
                    coboundaries[1][boundary] = list()
                coboundaries[1][boundary].append(cell)
                boundaries_tables[2][-1].append(id_maps[1][boundary])

    return boundaries_tables, boundaries, coboundaries
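
# --- Illustrative sketch (added commentary, not part of the original file) ---
# Incidence matrices of the toy triangle via compute_incidences (defined just
# below): B1 is 3x3 (vertices x edges) with one +1 (source) and one -1
# (target) per column, and B2 is 3x1 (edges x rings), orienting the boundary
# edges of the single 3-ring. Hypothetical helper, never called at import time.
def _demo_incidences():
    edges = torch.tensor([[0, 1, 2], [1, 2, 0]])
    B1, B2 = compute_incidences(edges, max_k=3)
    return B1.to_dense(), B2.to_dense()
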
def compute_incidences(edges, max_k=7):
    """
    Get the cellular incidence matrices of a graph.

    Parameters
    ----------
    edges : torch.Tensor
        Edge index of the graph (or batched graph), shape [2, num_edges].
    max_k : int, optional
        Max ring size. The default is 7.

    Returns
    -------
    B1 & B2 : torch.sparse tensors
        B1: node -> edge incidence map
        B2: edge -> cell incidence map
    """

    size = max(edges[0].max(), edges[1].max()) + 1

    # Compute the simplex tree, a utility for getting boundaries.
    simplex_tree = pyg_to_simplex_tree(edges, size)

    # Here we get the mapping from each cell_i to its id (label),
    # with dim(cell_i) in [0, 1, 2].
    _, id_maps = build_tables_with_rings(edges, simplex_tree, size, max_k)

    # Get the boundaries for computing B2.
    _, boundaries, _ = extract_boundaries_and_coboundaries_with_rings(simplex_tree, id_maps)

    # Build the inverted id map:
    # inv_id_maps[i][j] = cell_j, with i in [0, 1, 2] and the cells of each
    # dimension listed in id order.
    inv_id_maps = [list(id_map.keys()) for id_map in id_maps]

    nV = max(id_maps[0].values()) + 1  # number of vertices
    nE = max(id_maps[1].values()) + 1  # number of edges
    # Number of faces (2-cells); fall back to a single dummy column if the
    # graph has no rings.
    nF = max(id_maps[2].values()) + 1 if len(id_maps) == 3 else 1
    B1 = np.zeros((nV, nE))
    B2 = np.zeros((nE, nF))

    # Edge orientation is coherent with the ordering
    # of the vertices in the edge description.
    for idx, e in enumerate(inv_id_maps[1]):
        B1[e[0], idx] = +1
        B1[e[1], idx] = -1

    # Ring orientation is coherent with the orientation of its boundaries.
    for idx, (face, bonds) in enumerate(boundaries[2].items()):
        # Circular sliding window of length 2 over the vertices of the ring.
        pairs = [(face[i], face[i + 1]) for i in range(len(face) - 1)] + [(face[-1], face[0])]
        for idx_bond, bond in enumerate(bonds):
            if bond in pairs:
                B2[id_maps[1][bond], idx] = +1
            elif bond[::-1] in pairs:
                B2[id_maps[1][bond], idx] = -1

    # Dense incidences are easy to store but expensive to operate on,
    # so return them as sparse float tensors.
    B1 = torch.from_numpy(B1).to_sparse().float()
    B2 = torch.from_numpy(B2).to_sparse().float()
    return B1, B2


def compute_cell_complex_stat(data, max_k):
    cell_dim_list = []
    for G in data:
        _, B2 = compute_incidences(G.edge_index.cpu(), max_k)
        # Ring size = number of non-zero entries in the corresponding B2 column.
        cell_dim_list.extend(B2.to_dense().abs().sum(dim=0).to(int).tolist())

    return cell_dim_list


def compute_cell(G, max_k):

    B1, B2 = compute_incidences(G.edge_index.cpu(), max_k)

    # Remove self loops from the node adjacency.
    adj = torch.sparse.mm(B1, B1.t()).to_dense()
    adj = (adj - adj.diagonal() * torch.eye(adj.shape[0])).to_sparse()
    # This will be used as an indexer in the signal-lift attention mechanism.
    edge_indices = adj.indices()

    # Arranging the edge indices so that in the lift phase it is possible to
    # index the connectivity directly and reshape the tensor, since triu
    # indices are placed right before tril indices; permutation equivariance
    # ensures the symmetries. Kept here for reference:
    """
    num_nodes = G.num_nodes
    triu = torch.triu_indices(num_nodes, num_nodes).T
    idxs_u = []
    idxs_l = []
    for idx in range(adj._nnz()):
        if (edge_indices[:, idx] == triu).all(axis=1).any():  # check if edge is in the triu matrix
            idxs_u.append(idx)
        else:
            idxs_l.append(idx)

    edge_indices = torch.hstack((edge_indices[:, idxs_u],
                                 edge_indices[:, idxs_l]))
    """
    # Remove self loops from the lower adjacency neighborhood.
    Ldo = torch.sparse.mm(B1.t(), B1).coalesce().to_dense()
    Ldo = (Ldo - Ldo.diagonal() * torch.eye(Ldo.shape[0])).to_sparse()

    # Remove self loops from the upper adjacency neighborhood.
    Lup = torch.sparse.mm(B2, B2.t()).coalesce().to_dense()
    Lup = (Lup - Lup.diagonal() * torch.eye(Lup.shape[0])).to_sparse()

    ###
    ### Convolutions require the entire connectivity information of the complex,
    ### while attention deals only with the connectivity structure of the complex.
    ### To deal with graph batching we incorporate additional information into
    ### the graphs. The connectivity information will be collated and adjusted
    ### according to the reindexing mechanism of the collator.

    lower_neigh_connection = Ldo.coalesce().indices()
    upper_neigh_connection = Lup.coalesce().indices()

    return {'do': lower_neigh_connection,
            'up': upper_neigh_connection,
            'adj': edge_indices}
    #'P': compute_projection_matrix(normalize((Lup+Ldo).coalesce()), 0.88, 7)}


class ProgressParallel(Parallel):
    """A helper class for adding a tqdm progress bar to the joblib library."""

    def __init__(self, use_tqdm=True, total=None, *args, **kwargs):
        self._use_tqdm = use_tqdm
        self._total = total
        super().__init__(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar:
            return Parallel.__call__(self, *args, **kwargs)

    def print_progress(self):
        # Called by joblib as tasks are dispatched/completed; keep the tqdm
        # bar in sync with joblib's internal counters.
        if self._total is None:
            self._pbar.total = self.n_dispatched_tasks
        self._pbar.n = self.n_completed_tasks
        self._pbar.refresh()
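
# --- Illustrative sketch (added commentary, not part of the original file) ---
# ProgressParallel is a drop-in replacement for joblib.Parallel that drives a
# tqdm bar from joblib's print_progress hook. Toy usage; the threading backend
# is chosen here only so the lambda needn't be pickled.
def _demo_progress_parallel():
    from joblib import delayed
    squares = ProgressParallel(total=4, n_jobs=2, backend="threading")(
        delayed(lambda v: v * v)(v) for v in range(4)
    )
    return squares  # [0, 1, 4, 9]
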
def collate_complexes(samples):
    """
    Each input graph becomes one disjoint component of the batched graph.
    The node and edge connectivity matrices are relabeled into disjoint segments:

    ================= ========= ================= === =========================
                      graphs[0] graphs[1]         ... graphs[k]
    ================= ========= ================= === =========================
    Original node ID  0 ~ N_0   0 ~ N_1           ... 0 ~ N_k
    New node ID       0 ~ N_0   N_0+1 ~ N_0+N_1+1 ... 1+\\sum_{i=0}^{k-1} N_i ~
                                                      1+\\sum_{i=0}^{k} N_i
    ================= ========= ================= === =========================

    ------------------ EDGE REINDEXING ---------------------------------------

    ================= ========= ================= === =========================
                      graphs[0] graphs[1]         ... graphs[k]
    ================= ========= ================= === =========================
    Original edge ID  0 ~ E_0   0 ~ E_1           ... 0 ~ E_k
    New edge ID       0 ~ E_0   E_0+1 ~ E_0+E_1+1 ... 1+\\sum_{i=0}^{k-1} E_i ~
                                                      1+\\sum_{i=0}^{k} E_i
    ================= ========= ================= === =========================

    Parameters
    ----------
    samples : list of (torch_geometric graph, cell complex connectivity) pairs

    Returns
    -------
    batched_graph : Batch
        The batched complexes with adjusted connectivity.
    """
    # Collate for generating batched data.
    # `samples` is a list of pairs (graph, connectivity).
    graphs, connectivities = map(list, zip(*samples))
    batched_connectivities = defaultdict(list)
    prev_num_nodes = prev_num_edges = 0
    for idx, graph in enumerate(graphs):
        # Shift every index by the number of nodes/edges seen so far, so each
        # graph occupies its own disjoint segment of the batched complex.
        batched_connectivities['adj'].append(connectivities[idx]['adj'] + prev_num_nodes)
        batched_connectivities['do'].append(connectivities[idx]['do'] + prev_num_edges)
        batched_connectivities['up'].append(connectivities[idx]['up'] + prev_num_edges)
        prev_num_nodes += graph.num_nodes
        prev_num_edges += graph.num_edges

    batched_graph = Batch.from_data_list(graphs)
    batched_graph.connectivities = defaultdict(tuple)
    for k in batched_connectivities:
        batched_graph.connectivities[k] = tuple([torch.cat(x) for x in zip(*batched_connectivities[k])])

    batched_graph.ros = []
    batched_graph.default_connectivity = deepcopy(batched_graph.connectivities)

    batched_graph.y = batched_graph.y.long()
    return batched_graph
--------------------------------------------------------------------------------