├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── SMP.iml
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── other.xml
│   └── vcs.xml
├── LICENSE
├── README.md
├── config_cycles.yaml
├── config_multi_task.yaml
├── config_zinc.yaml
├── cycles_main.py
├── data
│   ├── .DS_Store
│   ├── datasets_kcycle_nsamples=10000
│   │   └── .DS_Store
│   └── multitask_dataset.pkl
├── datasets_generation
│   ├── __pycache__
│   │   ├── build_cycles.cpython-37.pyc
│   │   ├── graph_algorithms.cpython-37.pyc
│   │   ├── graph_generation.cpython-37.pyc
│   │   └── multitask_dataset.cpython-37.pyc
│   ├── build_cycles.py
│   ├── graph_algorithms.py
│   ├── graph_generation.py
│   └── multitask_dataset.py
├── models
│   ├── .DS_Store
│   ├── gin.py
│   ├── model_cycles.py
│   ├── model_multi_task.py
│   ├── model_zinc.py
│   ├── ppgn.py
│   ├── ring_gnn.py
│   ├── smp_layers.py
│   └── utils
│       ├── layers.py
│       ├── misc.py
│       └── transforms.py
├── multi_task_main.py
├── multi_task_utils
│   ├── train.py
│   └── util.py
├── requirements.txt
├── saved_models
│   ├── PPGN_4
│   │   └── epoch0.pkl
│   └── ZINC
│       └── Zinc_SMP.pkl
└── zinc_main.py

/.gitignore:
--------------------------------------------------------------------------------
models/__pycache__/
.idea/
.DS_Store
data/
multi_task_utils/__pycache__/
__pycache__/
multi_task_utils/__pycache__/
wandb/
data/.DS_Store
tests/
saved_models/
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
# Default ignored files
/shelf/
/workspace.xml
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
# Editor-based HTTP Client requests
/httpRequests/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Gabriele Corso, Luca Cavalleri

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Building powerful and equivariant graph neural networks with structural message-passing

This repository contains the code for the paper *Building powerful and equivariant graph neural networks with structural message-passing* (NeurIPS 2020) by
[Clément Vignac](https://cvignac.github.io/), [Andreas Loukas](https://andreasloukas.blog/) and [Pascal Frossard](https://www.epfl.ch/labs/lts4/people/people-current/frossard/).
[Link to the paper](https://papers.nips.cc/paper/2020/file/a32d7eeaae19821fd9ce317f3ce952a7-Paper.pdf)

Abstract:

Message-passing has proved to be an effective way to design graph neural networks,
as it is able to leverage both permutation equivariance and an inductive bias towards
learning local structures in order to achieve good generalization. However, current
message-passing architectures have a limited representation power and fail to learn
basic topological properties of graphs. We address this problem and propose a
powerful and equivariant message-passing framework based on two ideas: first,
we propagate a one-hot encoding of the nodes, in addition to the features, in order
to learn a local context matrix around each node. This matrix contains rich local
information about both features and topology and can eventually be pooled to build
node representations. Second, we propose methods for the parametrization of the
message and update functions that ensure permutation equivariance. Having a
representation that is independent of the specific choice of the one-hot encoding
permits inductive reasoning and leads to better generalization properties. Experimentally,
our model can predict various graph topological properties on synthetic
data more accurately than previous methods and achieves state-of-the-art results on
molecular graph regression on the ZINC dataset.
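To make the first idea concrete, here is a minimal sketch of the local-context initialization (toy shapes only, not the exact code of this repository; the helpers actually used live in `models/utils/misc.py`):

```python
import torch

n = 4                                 # toy graph with 4 nodes
u = torch.eye(n).unsqueeze(-1)        # local contexts, shape (n, n, 1): u[i] initially only "knows" node i
# Input node features can additionally be written into these contexts (the role of map_x_to_u in this
# repository). Message passing then fills u[i] with information about the neighbourhood of node i, and
# u[i] is finally pooled into a node representation that does not depend on the chosen node ordering.
```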
## Code overview

This folder contains the source code used for Structural Message Passing (SMP) on three tasks:
- Cycle detection
- The multi-task regression of graph properties presented in [https://arxiv.org/abs/2004.05718](https://arxiv.org/abs/2004.05718)
- Constrained solubility regression on ZINC

Source code for the second task is adapted from [https://github.com/lukecavabarrett/pna](https://github.com/lukecavabarrett/pna).


## Dependencies
[PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/) v1.6.1 was used. Please follow the instructions on the
website, as a simple installation via pip does not work. In particular, the PyTorch version used must match the one expected by torch-geometric.

Then install the other dependencies:
```
pip install -r requirements.txt
```

## Dataset generation

### Cycle detection
First, download the data from https://drive.switch.ch/index.php/s/hv65hmY48GrRAoN
and unzip it in data/datasets_kcycle_nsamples=10000. Then, run

```
python3 datasets_generation/build_cycles.py
```

### Multi-task regression
Simply run
```
python -m datasets_generation.multitask_dataset
```

### ZINC
We use the PyTorch Geometric downloader, so there is nothing to do by hand.

## Folder structure

- Each task is launched by running the corresponding *main* file (cycles_main, zinc_main, multi_task_main).
- The model parameters can be changed in the associated config .yaml file, while training parameters are modified
with command line arguments.
- The model used for each task is located in the models folder (model_cycles,
model_multi_task, model_zinc).
- They all use some of the SMP layers parametrized in the smp_layers file.
- All SMP layers use the same set of base functions in models/utils/layers.py. These functions map tensors of one order
to tensors of another order using a predefined set of equivariant transformations (a small illustration follows this list).
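For intuition, the following self-contained sketch shows what such order-changing equivariant maps look like. It is an illustration only, not the exact parametrization implemented in `models/utils/layers.py`: an order-2 local context (one matrix per node) is reduced to order-1 node features using operations that commute with node relabeling.

```python
import torch

def order2_to_order1(u: torch.Tensor) -> torch.Tensor:
    """Illustrative equivariant reduction of a local context u of shape (n, n, c) to node features (n, 3c)."""
    diag = torch.diagonal(u, dim1=0, dim2=1).T   # (n, c): what each node stores about itself
    row_sum = u.sum(dim=1)                       # (n, c): aggregation over the context axis
    row_max = u.max(dim=1).values                # (n, c): max over the context axis
    return torch.cat([diag, row_sum, row_max], dim=-1)

u = torch.randn(5, 5, 8)                         # toy local contexts for a graph with 5 nodes
perm = torch.randperm(5)
permuted = order2_to_order1(u[perm][:, perm])    # relabel the nodes, then extract features
assert torch.allclose(order2_to_order1(u)[perm], permuted, atol=1e-5)  # equivariance check
```

Permuting the nodes permutes the extracted features in exactly the same way, which is the equivariance property the SMP layers build on.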
## Train

### Cycle detection

In order to train SMP, specify the cycle length, the size of the graphs to use, and optionally the proportion of the training data
that is kept. For example,
```
python3 cycles_main.py --k 4 --n 12 --proportion 1.0 --gpu 0
```
will train 4-cycle detection on graphs with 12 nodes on average, using 1.0 * 100 = 100% of the training data.

In order to run another model, modify config_cycles.yaml. To run an MPNN that has the
same architecture as SMP, set use_x=True in this file.

For MPNN and GIN, transforms can be specified in order to add a one-hot encoding of the node degrees,
or one-hot identifiers. The available options can be seen by using
```
python3 cycles_main.py --help
```

### Multi-task regression

Specify the configuration in the file `config_multi_task.yaml`, and see the available options by using
```
python3 multi_task_main.py --help
```
To use default parameters, simply run:
```
python3 multi_task_main.py --gpu 0
```

### ZINC

The ZINC dataset is downloaded through PyTorch Geometric, but the destination folder should be specified at
the beginning of `zinc_main.py`. Model parameters can be changed in `config_zinc.yaml`.

To use default parameters, simply run:
```
python3 zinc_main.py --gpu 0
```

## Use SMP on new data

This code is currently not available as a library, so you will need to copy files into your own project and adapt them to your
own data.
While most of the code can be reused, you may need to adapt the model to your own problem. We advise you to look at the
different model files (model_cycles, model_multi_task, model_zinc) to see how they are built. They all follow the same
design:
- A local context is first created using the functions in models/utils/misc.py. If you have node features that
you wish to use in SMP, use `map_x_to_u` to include them in the local contexts.
- One of the three SMP layers (SMP, FastSMP, SimplifiedFastSMP) is used at each layer to update the local context.
Then either node-level features or graph-level features are extracted. For this purpose, you can use
the `NodeExtractor` and `GraphExtractor` classes in `models/utils/layers.py`.
- The extracted features are processed by a standard neural network. You can use a multi-layer perceptron here, or
a more complex structure such as a gated recurrent network that takes as input the features extracted at
each layer.

A skeleton that follows this design is sketched at the end of this section.

To sum up, you need to copy the following files to your own folder:
- models/smp_layers.py
- models/utils/layers.py and models/utils/misc.py

and to adapt the following files to your own problem:
- the main file (e.g. zinc_main.py)
- the config file (e.g. config_zinc.yaml)
- the model file (e.g. models/model_zinc.py)

We also advise you to use the "weights and biases" library, as we found it very convenient for storing results.
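The skeleton below shows how these pieces can fit together for a graph-level task. The imports refer to the files listed above, but every constructor argument and call signature is an assumption made for illustration; check `models/model_zinc.py` and `models/smp_layers.py` for the real interfaces.

```python
import torch.nn as nn

# NOTE: the constructor arguments and call signatures below are assumed for illustration only;
# see models/model_zinc.py for a complete, working example.
from models.smp_layers import FastSMP              # one of SMP / FastSMP / SimplifiedFastSMP
from models.utils.layers import GraphExtractor     # or NodeExtractor for node-level outputs
from models.utils.misc import map_x_to_u           # writes the input node features into the local context


class MySMPModel(nn.Module):
    """Sketch of an SMP model for graph-level regression on new data."""
    def __init__(self, num_layers: int, hidden: int, hidden_final: int, out_dim: int):
        super().__init__()
        # one SMP layer per round of message passing (arguments are hypothetical)
        self.smp_layers = nn.ModuleList([FastSMP(in_features=hidden, out_features=hidden, num_towers=8)
                                         for _ in range(num_layers)])
        # pools every local context into a graph-level feature vector (arguments are hypothetical)
        self.extractor = GraphExtractor(in_features=hidden, out_features=hidden_final)
        # final readout: a plain MLP (a GRU over the features extracted at each layer is another option)
        self.readout = nn.Sequential(nn.Linear(hidden_final, hidden_final), nn.ReLU(),
                                     nn.Linear(hidden_final, out_dim))

    def forward(self, data):
        u = map_x_to_u(data)                               # 1. build the local contexts (hypothetical call)
        for layer in self.smp_layers:
            u = layer(u, data.edge_index, data.batch)      # 2. update the contexts (hypothetical signature)
        out = self.extractor(u, data.batch)                # 3. extract graph-level features (hypothetical)
        return self.readout(out)
```

Training then follows the usual PyTorch / PyTorch Geometric loop, as in `zinc_main.py`.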
## License
MIT

## Cite this paper

    @inproceedings{NEURIPS2020_a32d7eea,
     author = {Vignac, Cl\'{e}ment and Loukas, Andreas and Frossard, Pascal},
     booktitle = {Advances in Neural Information Processing Systems},
     editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
     pages = {14143--14155},
     publisher = {Curran Associates, Inc.},
     title = {Building powerful and equivariant graph neural networks with structural message-passing},
     url = {https://proceedings.neurips.cc/paper/2020/file/a32d7eeaae19821fd9ce317f3ce952a7-Paper.pdf},
     volume = {33},
     year = {2020}
    }

--------------------------------------------------------------------------------
/config_cycles.yaml:
--------------------------------------------------------------------------------
# Model properties
model_name: GIN # PPGN, SMP, RING_GNN or GIN
num_towers: 1
hidden: 32
hidden_final: 128
dropout_prob: 0.5
num_classes: 2
use_x: False # use_x is used for ablation studies
num_layers: -1 # If -1, set num_layers = k

# Options specific to SMP
layer_type: FastSMP
simplified: False

# Options specific to GIN
one_hot: False # Use a one-hot encoding of the degree as node features
identifiers: False # Use a one-hot encoding of the nodes as node features
random: False # Use random identifiers as node features
relational_pooling: 0 # if set to p > 0, sum over p random permutations of the nodes
--------------------------------------------------------------------------------
/config_multi_task.yaml:
--------------------------------------------------------------------------------
# Model properties
model_name: SMP
num_layers: 8
hidden_u: 64
num_towers: 8
out_u: 32
hidden_gru: 16
layer_type: SMP # SMP or FastSMP
--------------------------------------------------------------------------------
/config_zinc.yaml:
--------------------------------------------------------------------------------
# Model properties
hidden: 32 # internal representation
num_towers: 8 # used within each SMP layer
hidden_final: 128 # Extracted feature
num_layers: 12
use_x: False # used for ablation study
use_batch_norm: True
map_x_to_u: True # map the initial node features to the local context
simplified: False # fewer layers in the feature extractor
residual: False # residual connections when transforming local contexts
use_edge_features: True
shared_extractor: True # share the feature extractor across layers
--------------------------------------------------------------------------------
/cycles_main.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8

import os
import torch
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.transforms import OneHotDegree
import argparse
import numpy as np
import time
import yaml
from models.model_cycles import SMP
from models.gin import GIN
from datasets_generation.build_cycles import FourCyclesDataset
from models.utils.transforms import EyeTransform, RandomId, DenseAdjMatrix
from models import ppgn
from models.ring_gnn import RingGNN
from easydict import EasyDict as edict


# Change the following to point to the folder where the datasets are stored
if os.path.isdir('/datasets2/'):
    rootdir = '/datasets2/CYCLE_DETECTION/'
else:
    rootdir = './data/datasets_kcycle_nsamples=10000/'
yaml_file = './config_cycles.yaml'
# yaml_file = './benchmark/kernel/config4cycles.yaml'
torch.manual_seed(0)
np.random.seed(0)

parser = 
argparse.ArgumentParser() 33 | parser.add_argument('--epochs', type=int, default=300) 34 | parser.add_argument('--k', type=int, default=4, 35 | help="Length of the cycles to detect") 36 | parser.add_argument('--n', type=int, help='Average number of nodes in the graphs') 37 | parser.add_argument('--save-model', action='store_true', 38 | help='Save the model once training is done') 39 | parser.add_argument('--wandb', action='store_true', 40 | help="Use weights and biases library") 41 | parser.add_argument('--gpu', type=int, help='Id of gpu device. By default use cpu') 42 | parser.add_argument('--lr', type=float, default=0.001, help="Initial learning rate") 43 | parser.add_argument('--batch-size', type=int, default=16) 44 | parser.add_argument('--weight-decay', type=float, default=1e-4) 45 | parser.add_argument('--clip', type=float, default=10, help="Gradient clipping") 46 | parser.add_argument('--name', type=str, help="Name for weights and biases") 47 | parser.add_argument('--proportion', type=float, default=1.0, 48 | help='Proportion of the training data that is kept') 49 | parser.add_argument('--generalization', action='store_true', 50 | help='Evaluate out of distribution accuracy') 51 | args = parser.parse_args() 52 | 53 | # Log parameters 54 | test_every_epoch = 5 55 | print_every_epoch = 1 56 | log_interval = 20 57 | 58 | # Store maximum number of nodes for each pair (k, n) -- this value is used by provably powerful graph networks 59 | max_num_nodes = {4: {12: 12, 20: 20, 28: 28, 36: 36}, 60 | 6: {20: 25, 31: 38, 42: 52, 56: 65}, 61 | 8: {28: 38, 50: 56, 66: 76, 72: 90}} 62 | # Store the maximum degree for the one-hot encoding 63 | max_degree = {4: {12: 4, 20: 6, 28: 7, 36: 7}, 64 | 6: {20: 4, 31: 6, 42: 8, 56: 7}, 65 | 8: {28: 4, 50: 6, 66: 7, 72: 8}} 66 | # Store the values of n to use for generalization experiments 67 | n_gener = {4: {'train': 20, 'val': 28, 'test': 36}, 68 | 6: {'train': 31, 'val': 42, 'test': 56}, 69 | 8: {'train': 50, 'val': 66, 'test': 72}} 70 | 71 | # Handle the device 72 | use_cuda = args.gpu is not None and torch.cuda.is_available() 73 | if use_cuda: 74 | device = torch.device("cuda:" + str(args.gpu)) 75 | torch.cuda.set_device(args.gpu) 76 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 77 | else: 78 | device = "cpu" 79 | args.device = device 80 | args.kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 81 | print('Device used:', device) 82 | 83 | # Load the config file of the model 84 | with open(yaml_file) as f: 85 | config = yaml.load(f, Loader=yaml.FullLoader) 86 | config['map_x_to_u'] = False # Not used here 87 | config = edict(config) 88 | print(config) 89 | 90 | model_name = config['model_name'] 91 | 92 | config.pop('model_name') 93 | if model_name == 'SMP': 94 | model_name = config['layer_type'] 95 | 96 | if args.name is None: 97 | if model_name != 'GIN': 98 | args.name = model_name 99 | else: 100 | if config.relational_pooling > 0: 101 | args.name = 'RP' 102 | elif config.one_hot: 103 | args.name = 'OneHotDeg' 104 | elif config.identifiers: 105 | args.name = 'OneHotNod' 106 | elif config.random: 107 | args.name = 'Random' 108 | else: 109 | args.name = 'GIN' 110 | args.name = args.name + '_' + str(args.k) 111 | if args.n is not None: 112 | args.name = args.name + '_' + str(args.n) 113 | 114 | # Create a folder for the saved models 115 | if not os.path.isdir('./saved_models/' + args.name) and args.generalization: 116 | os.mkdir('./saved_models/' + args.name) 117 | 118 | 119 | if args.name: 120 | args.wandb = True 121 | if 
args.wandb: 122 | import wandb 123 | wandb.init(project="smp", config=config, name=args.name) 124 | wandb.config.update(args) 125 | 126 | if args.n is None: 127 | args.n = n_gener[args.k]['train'] 128 | 129 | if config.num_layers == -1: 130 | config.num_layers = args.k 131 | 132 | 133 | def train(epoch): 134 | """ Train for one epoch. """ 135 | model.train() 136 | lr_scheduler(args.lr, epoch, optimizer) 137 | loss_all = 0 138 | if not config.relational_pooling: 139 | for batch_idx, data in enumerate(train_loader): 140 | data = data.to(device) 141 | optimizer.zero_grad() 142 | output = model(data) 143 | loss = F.nll_loss(output, data.y) 144 | loss.backward() 145 | loss_all += loss.item() * data.num_graphs 146 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 147 | optimizer.step() 148 | return loss_all / len(train_loader.dataset) 149 | else: 150 | # For relational pooling, we sample several permutations of each graph 151 | for batch_idx, data in enumerate(train_loader): 152 | for repetition in range(config.relational_pooling): 153 | for i in range(args.batch_size): 154 | n_nodes = int(torch.sum(data.batch == i).item()) 155 | p = torch.randperm(n_nodes) 156 | data.x[data.batch == i, :n_nodes] = data.x[data.batch == i, :n_nodes][p, :][:, p] 157 | data = data.to(device) 158 | optimizer.zero_grad() 159 | output = model(data) 160 | loss = F.nll_loss(output, data.y) 161 | loss.backward() 162 | loss_all += loss.item() * data.num_graphs 163 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 164 | optimizer.step() 165 | return loss_all / len(train_loader.dataset) 166 | 167 | 168 | def test(loader): 169 | model.eval() 170 | correct = 0 171 | for data in loader: 172 | data = data.to(device) 173 | output = model(data) 174 | pred = output.max(dim=1)[1] 175 | correct += pred.eq(data.y).sum().item() 176 | return correct / len(loader.dataset) 177 | 178 | 179 | def lr_scheduler(lr, epoch, optimizer): 180 | for param_group in optimizer.param_groups: 181 | param_group['lr'] = lr * (0.995 ** (epoch / 5)) 182 | 183 | 184 | # Define the transform to use in the dataset 185 | transform=None 186 | if 'GIN' or 'RP' in model_name: 187 | if config.one_hot: 188 | # Cannot always be used in an inductive setting, 189 | # because the maximal degree might be bigger than during training 190 | degree = max_degree[args.k][args.n] 191 | transform = OneHotDegree(degree, cat=False) 192 | config.num_input_features = degree + 1 193 | elif config.identifiers: 194 | # Cannot be used in an inductive setting 195 | transform = EyeTransform(max_num_nodes[args.k][args.n]) 196 | config.num_input_features = max_num_nodes[args.k][args.n] 197 | elif config.random: 198 | # Can be used in an inductive setting 199 | transform = RandomId() 200 | transform_val = RandomId() 201 | transform_test = RandomId() 202 | config.num_input_features = 1 203 | 204 | if transform is None: 205 | transform_val = None 206 | transform_test = None 207 | config.num_input_features = 1 208 | 209 | if 'SMP' in model_name: 210 | config.use_batch_norm = args.k > 6 or args.n > 30 211 | model = SMP(config.num_input_features, config.num_classes, config.num_layers, config.hidden, config.layer_type, 212 | config.hidden_final, config.dropout_prob, config.use_batch_norm, config.use_x, config.map_x_to_u, 213 | config.num_towers, config.simplified).to(device) 214 | 215 | elif model_name == 'PPGN': 216 | transform = DenseAdjMatrix(max_num_nodes[args.k][args.n]) 217 | transform_val = DenseAdjMatrix(max_num_nodes[args.k][n_gener[args.k]['val']]) 218 | 
transform_test = DenseAdjMatrix(max_num_nodes[args.k][n_gener[args.k]['test']]) 219 | model = ppgn.Powerful(config.num_classes, config.num_layers, config.hidden, 220 | config.hidden_final, config.dropout_prob, config.simplified) 221 | elif model_name == 'GIN': 222 | config.use_batch_norm = args.k > 6 or args.n > 50 223 | model = GIN(config.num_input_features, config.num_classes, config.num_layers, 224 | config.hidden, config.hidden_final, config.dropout_prob, config.use_batch_norm) 225 | elif model_name == 'RING_GNN': 226 | transform = DenseAdjMatrix(max_num_nodes[args.k][args.n]) 227 | transform_val = DenseAdjMatrix(max_num_nodes[args.k][n_gener[args.k]['val']]) 228 | transform_test = DenseAdjMatrix(max_num_nodes[args.k][n_gener[args.k]['test']]) 229 | model = RingGNN(config.num_classes, config.num_layers, config.hidden, config.hidden_final, config.dropout_prob, 230 | config.simplified) 231 | 232 | model = model.to(device) 233 | 234 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.5, weight_decay=args.weight_decay) 235 | # Load the data 236 | print("Transform used:", transform) 237 | 238 | batch_size = args.batch_size 239 | if args.generalization: 240 | train_data = FourCyclesDataset(args.k, n_gener[args.k]['train'], rootdir, train=True, transform=transform) 241 | test_data = FourCyclesDataset(args.k, n_gener[args.k]['train'], rootdir, train=False, transform=transform) 242 | gener_data_val = FourCyclesDataset(args.k, n_gener[args.k]['val'], rootdir, train=False, transform=transform_val) 243 | train_loader = DataLoader(train_data, batch_size, shuffle=True) 244 | test_loader = DataLoader(test_data, batch_size, shuffle=False) 245 | gener_val_loader = DataLoader(gener_data_val, batch_size, shuffle=False) 246 | 247 | else: 248 | train_data = FourCyclesDataset(args.k, args.n, rootdir, proportion=args.proportion, train=True, transform=transform) 249 | test_data = FourCyclesDataset(args.k, args.n, rootdir, proportion=args.proportion, train=False, transform=transform) 250 | train_loader = DataLoader(train_data, batch_size, shuffle=True) 251 | test_loader = DataLoader(test_data, batch_size, shuffle=False) 252 | 253 | print("Starting to train") 254 | start = time.time() 255 | best_epoch = -1 256 | best_generalization_acc = 0 257 | for epoch in range(args.epochs): 258 | epoch_start = time.time() 259 | tr_loss = train(epoch) 260 | if epoch % print_every_epoch == 0: 261 | acc_train = test(train_loader) 262 | current_lr = optimizer.param_groups[0]["lr"] 263 | duration = time.time() - epoch_start 264 | print(f'Time:{duration:2.2f} | {epoch:5d} | Loss: {tr_loss:2.5f} | Train Acc: {acc_train:2.5f} | LR: {current_lr:.6f}') 265 | if epoch % test_every_epoch == 0: 266 | acc_test = test(test_loader) 267 | print(f'Test accuracy: {acc_test:2.5f}') 268 | if args.generalization: 269 | acc_generalization = test(gener_val_loader) 270 | print("Validation generalization accuracy", acc_generalization) 271 | if args.wandb: 272 | wandb.log({"Epoch": epoch, "Duration": duration, "Train loss": tr_loss, "train accuracy": acc_train, 273 | "Test acc": acc_test, 'Gene eval': acc_generalization}) 274 | if acc_generalization > best_generalization_acc: 275 | print(f"New best generalization error + accuracy > 90% at epoch {epoch}") 276 | # Remove existing models 277 | folder = f'./saved_models/{args.name}/' 278 | files_in_folder = [f for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))] 279 | for file in files_in_folder: 280 | try: 281 | os.remove(folder + file) 282 | except: 283 | 
print("Could not remove file", file) 284 | # Save new model 285 | torch.save(model, f'./saved_models/{args.name}/epoch{epoch}.pkl') 286 | print(f"Model saved at epoch {epoch}.") 287 | best_epoch = epoch 288 | else: 289 | if args.wandb: 290 | wandb.log({"Epoch": epoch, "Duration": duration, "Train loss": tr_loss, "train accuracy": acc_train, 291 | "Test acc": acc_test}) 292 | else: 293 | if args.wandb: 294 | wandb.log({"Epoch": epoch, "Duration": duration, "Train loss": tr_loss, "train accuracy": acc_train}) 295 | 296 | cur_lr = optimizer.param_groups[0]["lr"] 297 | print(f'{epoch:2.5f} | Loss: {tr_loss:2.5f} | Train Acc: {acc_train:2.5f} | LR: {cur_lr:.6f} | Test Acc: {acc_test:2.5f}') 298 | print(f'Elapsed time: {(time.time() - start) / 60:.1f} minutes') 299 | print('done!') 300 | 301 | final_acc = test(test_loader) 302 | print(f"Final accuracy: {final_acc}") 303 | print("Done.") 304 | 305 | if args.generalization: 306 | new_n = n_gener[args.k]['test'] 307 | gener_data_test = FourCyclesDataset(args.k, new_n, rootdir, train=False, transform=transform_test) 308 | gener_test_loader = DataLoader(gener_data_test, batch_size, shuffle=False) 309 | model = torch.load(f"./saved_models/{args.name}/epoch{best_epoch}.pkl", map_location=device) 310 | model.eval() 311 | acc_test_generalization = test(gener_test_loader) 312 | print(f"Generalization accuracy on {args.k} cycles with {new_n} nodes", acc_test_generalization) 313 | if args.wandb: 314 | wandb.run.summary['test_generalization'] = acc_test_generalization 315 | -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvignac/SMP/95b55a880d0fc9149ddf32e8c2fdf5eac5b474b3/data/.DS_Store -------------------------------------------------------------------------------- /data/datasets_kcycle_nsamples=10000/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvignac/SMP/95b55a880d0fc9149ddf32e8c2fdf5eac5b474b3/data/datasets_kcycle_nsamples=10000/.DS_Store -------------------------------------------------------------------------------- /data/multitask_dataset.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvignac/SMP/95b55a880d0fc9149ddf32e8c2fdf5eac5b474b3/data/multitask_dataset.pkl -------------------------------------------------------------------------------- /datasets_generation/__pycache__/build_cycles.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvignac/SMP/95b55a880d0fc9149ddf32e8c2fdf5eac5b474b3/datasets_generation/__pycache__/build_cycles.cpython-37.pyc -------------------------------------------------------------------------------- /datasets_generation/__pycache__/graph_algorithms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvignac/SMP/95b55a880d0fc9149ddf32e8c2fdf5eac5b474b3/datasets_generation/__pycache__/graph_algorithms.cpython-37.pyc -------------------------------------------------------------------------------- /datasets_generation/__pycache__/graph_generation.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cvignac/SMP/95b55a880d0fc9149ddf32e8c2fdf5eac5b474b3/datasets_generation/__pycache__/graph_generation.cpython-37.pyc -------------------------------------------------------------------------------- /datasets_generation/__pycache__/multitask_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvignac/SMP/95b55a880d0fc9149ddf32e8c2fdf5eac5b474b3/datasets_generation/__pycache__/multitask_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /datasets_generation/build_cycles.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | from torch_geometric.data import InMemoryDataset, Data 5 | import numpy as np 6 | import networkx as nx 7 | import numpy.random as npr 8 | 9 | 10 | if os.path.isdir('/datasets2/'): 11 | rootdir = '/datasets2/CYCLE_DETECTION/' 12 | else: 13 | rootdir = './data/datasets_kcycle_nsamples=10000/' 14 | 15 | 16 | def build_dataset(): 17 | """ Given pickle files, split the dataset into one per value of n 18 | Run once before running the experiments. """ 19 | n_samples = 10000 20 | for k in [4, 6, 8]: 21 | with open(os.path.join(rootdir, 'datasets_kcycle_k={}_nsamples=10000.pickle'.format(k)), 'rb') as f: 22 | datasets_params, datasets = pickle.load(f) 23 | # Split by graph size 24 | for params, dataset in zip(datasets_params, datasets): 25 | n = params['n'] 26 | train, test = dataset[:n_samples], dataset[n_samples:] 27 | torch.save(train, rootdir + f'{k}cycles_n{n}_{n_samples}samples_train.pt') 28 | torch.save(test, rootdir + f'/{k}cycles_n{n}_{n_samples}samples_test.pt') 29 | # torch.save(test, '{}cycles_n{}_{}samples_test.pt'.format(k, n, n_samples)) 30 | 31 | 32 | class FourCyclesDataset(InMemoryDataset): 33 | def __init__(self, k, n, root, train, proportion=1.0, n_samples=10000, transform=None, pre_transform=None): 34 | self.train = train 35 | self.k, self.n, self.n_samples = k, n, n_samples 36 | self.root = root 37 | self.s = 'train' if train else 'test' 38 | self.proportion = proportion 39 | super().__init__(root, transform, pre_transform) 40 | self.data, self.slices = torch.load(self.processed_paths[0]) 41 | 42 | @property 43 | def raw_file_names(self): 44 | return ['{}cycles_n{}_{}samples_{}.pt'.format(self.k, self.n, self.n_samples, self.s)] 45 | 46 | @property 47 | def processed_file_names(self): 48 | if self.transform is None: 49 | st = 'no-transf' 50 | else: 51 | st = str(self.transform.__class__.__name__) 52 | return [f'processed_{self.k}cycles_n{self.n}_{self.n_samples}samples_{self.s}_{st}_{self.proportion}.pt'] 53 | 54 | def download(self): 55 | # Download to `self.raw_dir`. 56 | pass 57 | 58 | def process(self): 59 | # Read data into huge `Data` list. 
60 | dataset = torch.load(os.path.join(self.root, f'{self.k}cycles_n{self.n}_{self.n_samples}samples_{self.s}.pt')) 61 | 62 | data_list = [] 63 | for sample in dataset: 64 | graph, y, label = sample 65 | edge_list = nx.to_edgelist(graph) 66 | edges = [np.array([edge[0], edge[1]]) for edge in edge_list] 67 | edges2 = [np.array([edge[1], edge[0]]) for edge in edge_list] 68 | 69 | edge_index = torch.tensor(np.array(edges + edges2).T, dtype=torch.long) 70 | 71 | x = torch.ones(graph.number_of_nodes(), 1, dtype=torch.float) 72 | y = torch.tensor([1], dtype=torch.long) if label == 'has-kcycle' else torch.tensor([0], dtype=torch.long) 73 | 74 | data_list.append(Data(x=x, edge_index=edge_index, edge_attr=None, y=y)) 75 | # Subsample the data 76 | if self.train: 77 | all_data = len(data_list) 78 | to_select = int(all_data * self.proportion) 79 | print(to_select, "samples were selected") 80 | data_list = data_list[:to_select] 81 | data, slices = self.collate(data_list) 82 | torch.save((data, slices), self.processed_paths[0]) 83 | 84 | 85 | if __name__ == '__main__': 86 | build_dataset() -------------------------------------------------------------------------------- /datasets_generation/graph_algorithms.py: -------------------------------------------------------------------------------- 1 | import math 2 | from queue import Queue 3 | 4 | import numpy as np 5 | 6 | 7 | def is_connected(A): 8 | """ 9 | :param A:np.array the adjacency matrix 10 | :return:bool whether the graph is connected or not 11 | """ 12 | for _ in range(int(1 + math.ceil(math.log2(A.shape[0])))): 13 | A = np.dot(A, A) 14 | return np.min(A) > 0 15 | 16 | 17 | def identity(A, F): 18 | """ 19 | :param A:np.array the adjacency matrix 20 | :param F:np.array the nodes features 21 | :return:F 22 | """ 23 | return F 24 | 25 | 26 | def first_neighbours(A): 27 | """ 28 | :param A:np.array the adjacency matrix 29 | :param F:np.array the nodes features 30 | :return: for each node, the number of nodes reachable in 1 hop 31 | """ 32 | return np.sum(A > 0, axis=0) 33 | 34 | 35 | def second_neighbours(A): 36 | """ 37 | :param A:np.array the adjacency matrix 38 | :param F:np.array the nodes features 39 | :return: for each node, the number of nodes reachable in no more than 2 hops 40 | """ 41 | A = A > 0.0 42 | A = A + np.dot(A, A) 43 | np.fill_diagonal(A, 0) 44 | return np.sum(A > 0, axis=0) 45 | 46 | 47 | def kth_neighbours(A, k): 48 | """ 49 | :param A:np.array the adjacency matrix 50 | :param F:np.array the nodes features 51 | :return: for each node, the number of nodes reachable in k hops 52 | """ 53 | A = A > 0.0 54 | R = np.zeros(A.shape) 55 | for _ in range(k): 56 | R = np.dot(R, A) + A 57 | np.fill_diagonal(R, 0) 58 | return np.sum(R > 0, axis=0) 59 | 60 | 61 | def map_reduce_neighbourhood(A, F, f_reduce, f_map=None, hops=1, consider_itself=False): 62 | """ 63 | :param A:np.array the adjacency matrix 64 | :param F:np.array the nodes features 65 | :return: for each node, map its neighbourhood with f_map, and reduce it with f_reduce 66 | """ 67 | if f_map is not None: 68 | F = f_map(F) 69 | A = np.array(A) 70 | 71 | A = A > 0 72 | R = np.zeros(A.shape) 73 | for _ in range(hops): 74 | R = np.dot(R, A) + A 75 | np.fill_diagonal(R, 1 if consider_itself else 0) 76 | R = R > 0 77 | 78 | return np.array([f_reduce(F[R[i]]) for i in range(A.shape[0])]) 79 | 80 | 81 | def max_neighbourhood(A, F): 82 | """ 83 | :param A:np.array the adjacency matrix 84 | :param F:np.array the nodes features 85 | :return: for each node, the maximum in its neighbourhood 
86 | """ 87 | return map_reduce_neighbourhood(A, F, np.max, consider_itself=True) 88 | 89 | 90 | def min_neighbourhood(A, F): 91 | """ 92 | :param A:np.array the adjacency matrix 93 | :param F:np.array the nodes features 94 | :return: for each node, the minimum in its neighbourhood 95 | """ 96 | return map_reduce_neighbourhood(A, F, np.min, consider_itself=True) 97 | 98 | 99 | def std_neighbourhood(A, F): 100 | """ 101 | :param A:np.array the adjacency matrix 102 | :param F:np.array the nodes features 103 | :return: for each node, the standard deviation of its neighbourhood 104 | """ 105 | return map_reduce_neighbourhood(A, F, np.std, consider_itself=True) 106 | 107 | 108 | def mean_neighbourhood(A, F): 109 | """ 110 | :param A:np.array the adjacency matrix 111 | :param F:np.array the nodes features 112 | :return: for each node, the mean of its neighbourhood 113 | """ 114 | return map_reduce_neighbourhood(A, F, np.mean, consider_itself=True) 115 | 116 | 117 | def local_maxima(A, F): 118 | """ 119 | :param A:np.array the adjacency matrix 120 | :param F:np.array the nodes features 121 | :return: for each node, whether it is the maximum in its neighbourhood 122 | """ 123 | return F == map_reduce_neighbourhood(A, F, np.max, consider_itself=True) 124 | 125 | 126 | def graph_laplacian(A): 127 | """ 128 | :param A:np.array the adjacency matrix 129 | :return: the laplacian of the adjacency matrix 130 | """ 131 | L = (A > 0) * -1 132 | np.fill_diagonal(L, np.sum(A > 0, axis=0)) 133 | return L 134 | 135 | 136 | def graph_laplacian_features(A, F): 137 | """ 138 | :param A:np.array the adjacency matrix 139 | :param F:np.array the nodes features 140 | :return: the laplacian of the adjacency matrix multiplied by the features 141 | """ 142 | return np.matmul(graph_laplacian(A), F) 143 | 144 | 145 | def isomorphism(A1, A2, F1=None, F2=None): 146 | """ 147 | Takes two adjacency matrices (A1,A2) and (optionally) two lists of features. It uses Weisfeiler-Lehman algorithms, so false positives might arise 148 | :param A1: adj_matrix, N*N numpy matrix 149 | :param A2: adj_matrix, N*N numpy matrix 150 | :param F1: node_values, numpy array of size N 151 | :param F1: node_values, numpy array of size N 152 | :return: isomorphic: boolean which is false when the two graphs are not isomorphic, true when they probably are. 
153 | """ 154 | N = A1.shape[0] 155 | if (F1 is None) ^ (F2 is None): 156 | raise ValueError("either both or none between F1,F2 must be defined.") 157 | if F1 is None: 158 | # Assign same initial value to each node 159 | F1 = np.ones(N, int) 160 | F2 = np.ones(N, int) 161 | else: 162 | if not np.array_equal(np.sort(F1), np.sort(F2)): 163 | return False 164 | if F1.dtype() != int: 165 | raise NotImplementedError('Still have to implement this') 166 | 167 | p = 1000000007 168 | 169 | def mapping(F): 170 | return (F * 234 + 133) % 1000000007 171 | 172 | def adjacency_hash(F): 173 | F = np.sort(F) 174 | b = 257 175 | 176 | h = 0 177 | for f in F: 178 | h = (b * h + f) % 1000000007 179 | return h 180 | 181 | for i in range(N): 182 | F1 = map_reduce_neighbourhood(A1, F1, adjacency_hash, f_map=mapping, consider_itself=True, hops=1) 183 | F2 = map_reduce_neighbourhood(A2, F2, adjacency_hash, f_map=mapping, consider_itself=True, hops=1) 184 | if not np.array_equal(np.sort(F1), np.sort(F2)): 185 | return False 186 | return True 187 | 188 | 189 | def count_edges(A): 190 | """ 191 | :param A:np.array the adjacency matrix 192 | :return: the number of edges in the graph 193 | """ 194 | return np.sum(A) / 2 195 | 196 | 197 | def is_eulerian_cyclable(A): 198 | """ 199 | :param A:np.array the adjacency matrix 200 | :return: whether the graph has an eulerian cycle 201 | """ 202 | return is_connected(A) and np.count_nonzero(first_neighbours(A) % 2 == 1) == 0 203 | 204 | 205 | def is_eulerian_percorrible(A): 206 | """ 207 | :param A:np.array the adjacency matrix 208 | :return: whether the graph has an eulerian path 209 | """ 210 | return is_connected(A) and np.count_nonzero(first_neighbours(A) % 2 == 1) in [0, 2] 211 | 212 | 213 | def map_reduce_graph(A, F, f_reduce): 214 | """ 215 | :param A:np.array the adjacency matrix 216 | :param F:np.array the nodes features 217 | :return: the features of the nodes reduced by f_reduce 218 | """ 219 | return f_reduce(F) 220 | 221 | 222 | def mean_graph(A, F): 223 | """ 224 | :param A:np.array the adjacency matrix 225 | :param F:np.array the nodes features 226 | :return: the mean of the features 227 | """ 228 | return map_reduce_graph(A, F, np.mean) 229 | 230 | 231 | def max_graph(A, F): 232 | """ 233 | :param A:np.array the adjacency matrix 234 | :param F:np.array the nodes features 235 | :return: the maximum of the features 236 | """ 237 | return map_reduce_graph(A, F, np.max) 238 | 239 | 240 | def min_graph(A, F): 241 | """ 242 | :param A:np.array the adjacency matrix 243 | :param F:np.array the nodes features 244 | :return: the minimum of the features 245 | """ 246 | return map_reduce_graph(A, F, np.min) 247 | 248 | 249 | def std_graph(A, F): 250 | """ 251 | :param A:np.array the adjacency matrix 252 | :param F:np.array the nodes features 253 | :return: the standard deviation of the features 254 | """ 255 | return map_reduce_graph(A, F, np.std) 256 | 257 | 258 | def has_hamiltonian_cycle(A): 259 | """ 260 | :param A:np.array the adjacency matrix 261 | :return:bool whether the graph has an hamiltonian cycle 262 | """ 263 | A += np.transpose(A) 264 | A = A > 0 265 | V = A.shape[0] 266 | 267 | def ham_cycle_loop(pos): 268 | if pos == V: 269 | if A[path[pos - 1]][path[0]]: 270 | return True 271 | else: 272 | return False 273 | for v in range(1, V): 274 | if A[path[pos - 1]][v] and not used[v]: 275 | path[pos] = v 276 | used[v] = True 277 | if ham_cycle_loop(pos + 1): 278 | return True 279 | path[pos] = -1 280 | used[v] = False 281 | return False 282 | 283 | used = [False] 
* V 284 | path = [-1] * V 285 | path[0] = 0 286 | 287 | return ham_cycle_loop(1) 288 | 289 | 290 | def all_pairs_shortest_paths(A, inf_sub=math.inf): 291 | """ 292 | :param A:np.array the adjacency matrix 293 | :param inf_sub: the placeholder value to use for pairs which are not connected 294 | :return:np.array all pairs shortest paths 295 | """ 296 | A = np.array(A) 297 | N = A.shape[0] 298 | for i in range(N): 299 | for j in range(N): 300 | if A[i][j] == 0: 301 | A[i][j] = math.inf 302 | if i == j: 303 | A[i][j] = 0 304 | 305 | for k in range(N): 306 | for i in range(N): 307 | for j in range(N): 308 | A[i][j] = min(A[i][j], A[i][k] + A[k][j]) 309 | 310 | A = np.where(A == math.inf, inf_sub, A) 311 | return A 312 | 313 | 314 | def diameter(A): 315 | """ 316 | :param A:np.array the adjacency matrix 317 | :return: the diameter of the gra[h 318 | """ 319 | sum = np.sum(A) 320 | apsp = all_pairs_shortest_paths(A) 321 | apsp = np.where(apsp < sum + 1, apsp, -1) 322 | return np.max(apsp) 323 | 324 | 325 | def eccentricity(A): 326 | """ 327 | :param A:np.array the adjacency matrix 328 | :return: the eccentricity of the gra[h 329 | """ 330 | sum = np.sum(A) 331 | apsp = all_pairs_shortest_paths(A) 332 | apsp = np.where(apsp < sum + 1, apsp, -1) 333 | return np.max(apsp, axis=0) 334 | 335 | 336 | def sssp_predecessor(A, F): 337 | """ 338 | :param A:np.array the adjacency matrix 339 | :param F:np.array the nodes features 340 | :return: for each node, the best next step to reach the designated source 341 | """ 342 | assert (np.sum(F) == 1) 343 | assert (np.max(F) == 1) 344 | s = np.argmax(F) 345 | N = A.shape[0] 346 | P = np.zeros(A.shape) 347 | V = np.zeros(N) 348 | bfs = Queue() 349 | bfs.put(s) 350 | V[s] = 1 351 | while not bfs.empty(): 352 | u = bfs.get() 353 | for v in range(N): 354 | if A[u][v] > 0 and V[v] == 0: 355 | V[v] = 1 356 | P[v][u] = 1 357 | bfs.put(v) 358 | return P 359 | 360 | 361 | def max_eigenvalue(A): 362 | """ 363 | :param A:np.array the adjacency matrix 364 | :return: the maximum eigenvalue of A 365 | since A is positive symmetric, all the eigenvalues are guaranteed to be real 366 | """ 367 | [W, _] = np.linalg.eig(A) 368 | return W[np.argmax(np.absolute(W))].real 369 | 370 | 371 | def max_eigenvalues(A, k): 372 | """ 373 | :param A:np.array the adjacency matrix 374 | :param k:int the number of eigenvalues to be selected 375 | :return: the k greatest (by absolute value) eigenvalues of A 376 | """ 377 | [W, _] = np.linalg.eig(A) 378 | values = W[sorted(range(len(W)), key=lambda x: -np.absolute(W[x]))[:k]] 379 | return values.real 380 | 381 | 382 | def max_absolute_eigenvalues(A, k): 383 | """ 384 | :param A:np.array the adjacency matrix 385 | :param k:int the number of eigenvalues to be selected 386 | :return: the absolute value of the k greatest (by absolute value) eigenvalues of A 387 | """ 388 | return np.absolute(max_eigenvalues(A, k)) 389 | 390 | 391 | def max_absolute_eigenvalues_laplacian(A, n): 392 | """ 393 | :param A:np.array the adjacency matrix 394 | :param k:int the number of eigenvalues to be selected 395 | :return: the absolute value of the k greatest (by absolute value) eigenvalues of the laplacian of A 396 | """ 397 | A = graph_laplacian(A) 398 | return np.absolute(max_eigenvalues(A, n)) 399 | 400 | 401 | def max_eigenvector(A): 402 | """ 403 | :param A:np.array the adjacency matrix 404 | :return: the maximum (by absolute value) eigenvector of A 405 | since A is positive symmetric, all the eigenvectors are guaranteed to be real 406 | """ 407 | [W, V] = 
np.linalg.eig(A) 408 | return V[:, np.argmax(np.absolute(W))].real 409 | 410 | 411 | def spectral_radius(A): 412 | """ 413 | :param A:np.array the adjacency matrix 414 | :return: the maximum (by absolute value) eigenvector of A 415 | since A is positive symmetric, all the eigenvectors are guaranteed to be real 416 | """ 417 | return np.abs(max_eigenvalue(A)) 418 | 419 | 420 | def page_rank(A, F=None, iter=64): 421 | """ 422 | :param A:np.array the adjacency matrix 423 | :param F:np.array with initial weights. If None, uniform initialization will happen. 424 | :param iter: log2 of length of power iteration 425 | :return: for each node, its pagerank 426 | """ 427 | 428 | # normalize A rows 429 | A = np.array(A) 430 | A /= A.sum(axis=1)[:, np.newaxis] 431 | 432 | # power iteration 433 | for _ in range(iter): 434 | A = np.matmul(A, A) 435 | 436 | # generate prior distribution 437 | if F is None: 438 | F = np.ones(A.shape[-1]) 439 | else: 440 | F = np.array(F) 441 | 442 | # normalize prior 443 | F /= np.sum(F) 444 | 445 | # compute limit distribution 446 | return np.matmul(F, A) 447 | 448 | 449 | def tsp_length(A, F=None): 450 | """ 451 | :param A:np.array the adjacency matrix 452 | :param F:np.array determining which nodes are to be visited. If None, all of them are. 453 | :return: the length of the Traveling Salesman Problem shortest solution 454 | """ 455 | 456 | A = all_pairs_shortest_paths(A) 457 | N = A.shape[0] 458 | if F is None: 459 | F = np.ones(N) 460 | targets = np.nonzero(F)[0] 461 | T = targets.shape[0] 462 | S = (1 << T) 463 | dp = np.zeros((S, T)) 464 | 465 | def popcount(x): 466 | b = 0 467 | while x > 0: 468 | x &= x - 1 469 | b += 1 470 | return b 471 | 472 | msks = np.argsort(np.vectorize(popcount)(np.arange(S))) 473 | for i in range(T + 1): 474 | for j in range(T): 475 | if (1 << j) & msks[i] == 0: 476 | dp[msks[i]][j] = math.inf 477 | 478 | for i in range(T + 1, S): 479 | msk = msks[i] 480 | for u in range(T): 481 | if (1 << u) & msk == 0: 482 | dp[msk][u] = math.inf 483 | continue 484 | cost = math.inf 485 | for v in range(T): 486 | if v == u or (1 << v) & msk == 0: 487 | continue 488 | cost = min(cost, dp[msk ^ (1 << u)][v] + A[targets[v]][targets[u]]) 489 | dp[msk][u] = cost 490 | return np.min(dp[S - 1]) 491 | 492 | 493 | def get_nodes_labels(A, F): 494 | """ 495 | Takes the adjacency matrix and the list of nodes features (and a list of algorithms) and returns 496 | a set of labels for each node 497 | :param A: adj_matrix, N*N numpy matrix 498 | :param F: node_values, numpy array of size N 499 | :return: labels: KxN numpy matrix where K is the number of labels for each node 500 | """ 501 | labels = [identity(A, F), map_reduce_neighbourhood(A, F, np.mean, consider_itself=True), 502 | map_reduce_neighbourhood(A, F, np.max, consider_itself=True), 503 | map_reduce_neighbourhood(A, F, np.std, consider_itself=True), first_neighbours(A), second_neighbours(A), 504 | eccentricity(A)] 505 | return np.swapaxes(np.stack(labels), 0, 1) 506 | 507 | 508 | def get_graph_labels(A, F): 509 | """ 510 | Takes the adjacency matrix and the list of nodes features (and a list of algorithms) and returns 511 | a set of labels for the whole graph 512 | :param A: adj_matrix, N*N numpy matrix 513 | :param F: node_values, numpy array of size N 514 | :return: labels: numpy array of size K where K is the number of labels for the graph 515 | """ 516 | labels = [diameter(A)] 517 | return np.asarray(labels) 518 | -------------------------------------------------------------------------------- 
/datasets_generation/graph_generation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import networkx as nx 4 | import math 5 | import matplotlib.pyplot as plt # only required to plot 6 | from enum import Enum 7 | 8 | """ 9 | Generates random graphs of different types of a given size. 10 | Some of the graph are created using the NetworkX library, for more info see 11 | https://networkx.github.io/documentation/networkx-1.10/reference/generators.html 12 | """ 13 | 14 | 15 | class GraphType(Enum): 16 | RANDOM = 0 17 | ERDOS_RENYI = 1 18 | BARABASI_ALBERT = 2 19 | GRID = 3 20 | CAVEMAN = 5 21 | TREE = 6 22 | LADDER = 7 23 | LINE = 8 24 | STAR = 9 25 | CATERPILLAR = 10 26 | LOBSTER = 11 27 | 28 | 29 | # probabilities of each type in case of random type 30 | MIXTURE = [(GraphType.ERDOS_RENYI, 0.2), (GraphType.BARABASI_ALBERT, 0.2), (GraphType.GRID, 0.05), 31 | (GraphType.CAVEMAN, 0.05), (GraphType.TREE, 0.15), (GraphType.LADDER, 0.05), 32 | (GraphType.LINE, 0.05), (GraphType.STAR, 0.05), (GraphType.CATERPILLAR, 0.1), (GraphType.LOBSTER, 0.1)] 33 | 34 | 35 | def erdos_renyi(N, degree, seed): 36 | """ Creates an Erdős-Rényi or binomial graph of size N with degree/N probability of edge creation """ 37 | return nx.fast_gnp_random_graph(N, degree / N, seed, directed=False) 38 | 39 | 40 | def barabasi_albert(N, degree, seed): 41 | """ Creates a random graph according to the Barabási–Albert preferential attachment model 42 | of size N and where nodes are atteched with degree edges """ 43 | return nx.barabasi_albert_graph(N, degree, seed) 44 | 45 | 46 | def grid(N): 47 | """ Creates a m x k 2d grid graph with N = m*k and m and k as close as possible """ 48 | m = 1 49 | for i in range(1, int(math.sqrt(N)) + 1): 50 | if N % i == 0: 51 | m = i 52 | return nx.grid_2d_graph(m, N // m) 53 | 54 | 55 | def caveman(N): 56 | """ Creates a caveman graph of m cliques of size k, with m and k as close as possible """ 57 | m = 1 58 | for i in range(1, int(math.sqrt(N)) + 1): 59 | if N % i == 0: 60 | m = i 61 | return nx.caveman_graph(m, N // m) 62 | 63 | 64 | def tree(N, seed): 65 | """ Creates a tree of size N with a power law degree distribution """ 66 | return nx.random_powerlaw_tree(N, seed=seed, tries=10000) 67 | 68 | 69 | def ladder(N): 70 | """ Creates a ladder graph of N nodes: two rows of N/2 nodes, with each pair connected by a single edge. 71 | In case N is odd another node is attached to the first one. """ 72 | G = nx.ladder_graph(N // 2) 73 | if N % 2 != 0: 74 | G.add_node(N - 1) 75 | G.add_edge(0, N - 1) 76 | return G 77 | 78 | 79 | def line(N): 80 | """ Creates a graph composed of N nodes in a line """ 81 | return nx.path_graph(N) 82 | 83 | 84 | def star(N): 85 | """ Creates a graph composed by one center node connected N-1 outer nodes """ 86 | return nx.star_graph(N - 1) 87 | 88 | 89 | def caterpillar(N, seed): 90 | """ Creates a random caterpillar graph with a backbone of size b (drawn from U[1, N)), and N − b 91 | pendent vertices uniformly connected to the backbone. 
""" 92 | np.random.seed(seed) 93 | B = np.random.randint(low=1, high=N) 94 | G = nx.empty_graph(N) 95 | for i in range(1, B): 96 | G.add_edge(i - 1, i) 97 | for i in range(B, N): 98 | G.add_edge(i, np.random.randint(B)) 99 | return G 100 | 101 | 102 | def lobster(N, seed): 103 | """ Creates a random Lobster graph with a backbone of size b (drawn from U[1, N)), and p (drawn 104 | from U[1, N − b ]) pendent vertices uniformly connected to the backbone, and additional 105 | N − b − p pendent vertices uniformly connected to the previous pendent vertices """ 106 | np.random.seed(seed) 107 | B = np.random.randint(low=1, high=N) 108 | F = np.random.randint(low=B + 1, high=N + 1) 109 | G = nx.empty_graph(N) 110 | for i in range(1, B): 111 | G.add_edge(i - 1, i) 112 | for i in range(B, F): 113 | G.add_edge(i, np.random.randint(B)) 114 | for i in range(F, N): 115 | G.add_edge(i, np.random.randint(low=B, high=F)) 116 | return G 117 | 118 | 119 | def randomize(A): 120 | """ Adds some randomness by toggling some edges without chancing the expected number of edges of the graph """ 121 | BASE_P = 0.9 122 | 123 | # e is the number of edges, r the number of missing edges 124 | N = A.shape[0] 125 | e = np.sum(A) / 2 126 | r = N * (N - 1) / 2 - e 127 | 128 | # ep chance of an existing edge to remain, rp chance of another edge to appear 129 | if e <= r: 130 | ep = BASE_P 131 | rp = (1 - BASE_P) * e / r 132 | else: 133 | ep = BASE_P + (1 - BASE_P) * (e - r) / e 134 | rp = 1 - BASE_P 135 | 136 | array = np.random.uniform(size=(N, N), low=0.0, high=0.5) 137 | array = array + array.transpose() 138 | remaining = np.multiply(np.where(array < ep, 1, 0), A) 139 | appearing = np.multiply(np.multiply(np.where(array < rp, 1, 0), 1 - A), 1 - np.eye(N)) 140 | ans = np.add(remaining, appearing) 141 | 142 | # assert (np.all(np.multiply(ans, np.eye(N)) == np.zeros((N, N)))) 143 | # assert (np.all(ans >= 0)) 144 | # assert (np.all(ans <= 1)) 145 | # assert (np.all(ans == ans.transpose())) 146 | return ans 147 | 148 | 149 | def generate_graph(N, type=GraphType.RANDOM, seed=None, degree=None): 150 | """ 151 | Generates random graphs of different types of a given size. 
Note: 152 | - graph are undirected and without weights on edges 153 | - node values are sampled independently from U[0,1] 154 | 155 | :param N: number of nodes 156 | :param type: type chosen between the categories specified in GraphType enum 157 | :param seed: random seed 158 | :param degree: average degree of a node, only used in some graph types 159 | :return: adj_matrix: N*N numpy matrix 160 | node_values: numpy array of size N 161 | """ 162 | random.seed(seed) 163 | np.random.seed(seed) 164 | 165 | # sample which random type to use 166 | if type == GraphType.RANDOM: 167 | type = np.random.choice([t for (t, _) in MIXTURE], 1, p=[pr for (_, pr) in MIXTURE])[0] 168 | 169 | # generate the graph structure depending on the type 170 | if type == GraphType.ERDOS_RENYI: 171 | if degree == None: degree = random.random() * N 172 | G = erdos_renyi(N, degree, seed) 173 | elif type == GraphType.BARABASI_ALBERT: 174 | if degree == None: degree = int(random.random() * (N - 1)) + 1 175 | G = barabasi_albert(N, degree, seed) 176 | elif type == GraphType.GRID: 177 | G = grid(N) 178 | elif type == GraphType.CAVEMAN: 179 | G = caveman(N) 180 | elif type == GraphType.TREE: 181 | G = tree(N, seed) 182 | elif type == GraphType.LADDER: 183 | G = ladder(N) 184 | elif type == GraphType.LINE: 185 | G = line(N) 186 | elif type == GraphType.STAR: 187 | G = star(N) 188 | elif type == GraphType.CATERPILLAR: 189 | G = caterpillar(N, seed) 190 | elif type == GraphType.LOBSTER: 191 | G = lobster(N, seed) 192 | else: 193 | print("Type not defined") 194 | return 195 | 196 | # generate adjacency matrix and nodes values 197 | nodes = list(G) 198 | random.shuffle(nodes) 199 | adj_matrix = nx.to_numpy_array(G, nodes) 200 | node_values = np.random.uniform(low=0, high=1, size=N) 201 | 202 | # randomization 203 | adj_matrix = randomize(adj_matrix) 204 | 205 | # draw the graph created 206 | # nx.draw(G, pos=nx.spring_layout(G)) 207 | # plt.draw() 208 | 209 | return adj_matrix, node_values, type 210 | 211 | 212 | if __name__ == '__main__': 213 | for i in range(100): 214 | adj_matrix, node_values = generate_graph(10, GraphType.RANDOM, seed=i) 215 | print(adj_matrix) 216 | -------------------------------------------------------------------------------- /datasets_generation/multitask_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | 5 | import numpy as np 6 | import torch 7 | from inspect import signature 8 | 9 | from datasets_generation import graph_algorithms 10 | from datasets_generation.graph_generation import GraphType, generate_graph 11 | 12 | 13 | class DatasetMultitask: 14 | 15 | def __init__(self, n_graphs, N, seed, graph_type, get_nodes_labels, get_graph_labels, print_every, sssp, filename): 16 | self.adj = {} 17 | self.features = {} 18 | self.nodes_labels = {} 19 | self.graph_labels = {} 20 | 21 | def progress_bar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='█', printEnd=""): 22 | percent = ("{0:." 
+ str(decimals) + "f}").format(100 * (iteration / float(total))) 23 | filledLength = int(length * iteration // total) 24 | bar = fill * filledLength + '-' * (length - filledLength) 25 | print('\r{} |{}| {}% {}'.format(prefix, bar, percent, suffix), end=printEnd) 26 | 27 | def to_categorical(x, N): 28 | v = np.zeros(N) 29 | v[x] = 1 30 | return v 31 | 32 | for dset in N.keys(): 33 | if dset not in n_graphs: 34 | n_graphs[dset] = n_graphs['default'] 35 | 36 | total_n_graphs = sum(n_graphs[dset]) 37 | 38 | set_adj = [[] for _ in n_graphs[dset]] 39 | set_features = [[] for _ in n_graphs[dset]] 40 | set_nodes_labels = [[] for _ in n_graphs[dset]] 41 | set_graph_labels = [[] for _ in n_graphs[dset]] 42 | generated = 0 43 | 44 | progress_bar(0, total_n_graphs, prefix='Generating {:20}\t\t'.format(dset), 45 | suffix='({} of {})'.format(0, total_n_graphs)) 46 | 47 | for batch, batch_size in enumerate(n_graphs[dset]): 48 | for i in range(batch_size): 49 | # generate a random graph of type graph_type and size N 50 | seed += 1 51 | adj, features, type = generate_graph(N[dset][batch], graph_type, seed=seed) 52 | 53 | while np.min(np.max(adj, 0)) == 0.0: 54 | # remove graph with singleton nodes 55 | seed += 1 56 | adj, features, _ = generate_graph(N[dset][batch], type, seed=seed) 57 | 58 | generated += 1 59 | if generated % print_every == 0: 60 | progress_bar(generated, total_n_graphs, prefix='Generating {:20}\t\t'.format(dset), 61 | suffix='({} of {})'.format(generated, total_n_graphs)) 62 | 63 | # make sure there are no self connection 64 | assert np.all( 65 | np.multiply(adj, np.eye(N[dset][batch])) == np.zeros((N[dset][batch], N[dset][batch]))) 66 | 67 | if sssp: 68 | # define the source node 69 | source_node = np.random.randint(0, N[dset][batch]) 70 | 71 | # compute the labels with graph_algorithms; if sssp add the sssp 72 | node_labels = get_nodes_labels(adj, features, 73 | graph_algorithms.all_pairs_shortest_paths(adj, 0)[source_node] 74 | if sssp else None) 75 | graph_labels = get_graph_labels(adj, features) 76 | if sssp: 77 | # add the 1-hot feature determining the starting node 78 | features = np.stack([to_categorical(source_node, N[dset][batch]), features], axis=1) 79 | 80 | set_adj[batch].append(adj) 81 | set_features[batch].append(features) 82 | set_nodes_labels[batch].append(node_labels) 83 | set_graph_labels[batch].append(graph_labels) 84 | 85 | self.adj[dset] = [torch.from_numpy(np.asarray(adjs)).float() for adjs in set_adj] 86 | self.features[dset] = [torch.from_numpy(np.asarray(fs)).float() for fs in set_features] 87 | self.nodes_labels[dset] = [torch.from_numpy(np.asarray(nls)).float() for nls in set_nodes_labels] 88 | self.graph_labels[dset] = [torch.from_numpy(np.asarray(gls)).float() for gls in set_graph_labels] 89 | progress_bar(total_n_graphs, total_n_graphs, prefix='Generating {:20}\t\t'.format(dset), 90 | suffix='({} of {})'.format(total_n_graphs, total_n_graphs), printEnd='\n') 91 | 92 | self.save_as_pickle(filename) 93 | 94 | def save_as_pickle(self, filename): 95 | """" Saves the data into a pickle file at filename """ 96 | directory = os.path.dirname(filename) 97 | if not os.path.exists(directory): 98 | os.makedirs(directory) 99 | 100 | with open(filename, 'wb') as f: 101 | pickle.dump((self.adj, self.features, self.nodes_labels, self.graph_labels), f) 102 | 103 | 104 | if __name__ == '__main__': 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument('--out', type=str, default='./data/multitask_dataset.pkl', help='Data path.') 107 | parser.add_argument('--seed', 
type=int, default=1234, help='Random seed.') 108 | parser.add_argument('--graph_type', type=str, default='RANDOM', help='Type of graphs in train set') 109 | parser.add_argument('--nodes_labels', nargs='+', default=["eccentricity", "graph_laplacian_features", "sssp"]) 110 | parser.add_argument('--graph_labels', nargs='+', default=["is_connected", "diameter", "spectral_radius"]) 111 | parser.add_argument('--extrapolation', action='store_true', default=False, 112 | help='Generated various test sets of dimensions larger than train and validation.') 113 | parser.add_argument('--print_every', type=int, default=20, help='') 114 | args = parser.parse_args() 115 | 116 | if 'sssp' in args.nodes_labels: 117 | sssp = True 118 | args.nodes_labels.remove('sssp') 119 | else: 120 | sssp = False 121 | 122 | # gets the functions of graph_algorithms from the specified datasets 123 | nodes_labels_algs = list(map(lambda s: getattr(graph_algorithms, s), args.nodes_labels)) 124 | graph_labels_algs = list(map(lambda s: getattr(graph_algorithms, s), args.graph_labels)) 125 | 126 | 127 | def get_nodes_labels(A, F, initial=None): 128 | labels = [] if initial is None else [initial] 129 | for f in nodes_labels_algs: 130 | params = signature(f).parameters 131 | labels.append(f(A, F) if 'F' in params else f(A)) 132 | return np.swapaxes(np.stack(labels), 0, 1) 133 | 134 | 135 | def get_graph_labels(A, F): 136 | labels = [] 137 | for f in graph_labels_algs: 138 | params = signature(f).parameters 139 | labels.append(f(A, F) if 'F' in params else f(A)) 140 | return np.asarray(labels).flatten() 141 | 142 | 143 | data = DatasetMultitask(n_graphs={'train': [512] * 10, 'val': [128] * 5, 'default': [256] * 5}, 144 | N={**{'train': range(15, 25), 'val': range(15, 25)}, **( 145 | {'test-(20,25)': range(20, 25), 'test-(25,30)': range(25, 30), 146 | 'test-(30,35)': range(30, 35), 'test-(35,40)': range(35, 40), 147 | 'test-(40,45)': range(40, 45), 'test-(45,50)': range(45, 50), 148 | 'test-(60,65)': range(60, 65), 'test-(75,80)': range(75, 80), 149 | 'test-(95,100)': range(95, 100)} if args.extrapolation else 150 | {'test': range(15, 25)})}, 151 | seed=args.seed, graph_type=getattr(GraphType, args.graph_type), 152 | get_nodes_labels=get_nodes_labels, get_graph_labels=get_graph_labels, 153 | print_every=args.print_every, sssp=sssp, filename=args.out) 154 | 155 | data.save_as_pickle(args.out) 156 | -------------------------------------------------------------------------------- /models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvignac/SMP/95b55a880d0fc9149ddf32e8c2fdf5eac5b474b3/models/.DS_Store -------------------------------------------------------------------------------- /models/gin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import Sequential, Linear, ReLU, ModuleList 4 | import torch.nn.functional as F 5 | from torch_geometric.nn import GINConv 6 | from models.utils.layers import XtoGlobal 7 | 8 | 9 | class FeatureExtractor(nn.Module): 10 | def __init__(self, in_features: int, out_features: int): 11 | super().__init__() 12 | self.XtoG = XtoGlobal(in_features, out_features, bias=True) 13 | self.lin = Linear(out_features, out_features, bias=False) 14 | 15 | def forward(self, x, batch_info): 16 | """ x: (num_nodes, in_features) 17 | output: (batch_size, out_features). 
""" 18 | out = self.XtoG.forward(x, batch_info) 19 | out = out + self.lin.forward(F.relu(out)) 20 | return out 21 | 22 | 23 | class GINNetwork(nn.Module): 24 | def __init__(self, in_features, out_features): 25 | super().__init__() 26 | self.lin_1 = nn.Linear(in_features, in_features) 27 | self.lin_2 = nn.Linear(in_features, out_features) 28 | 29 | def forward(self, x): 30 | x = self.lin_2(x + torch.relu(self.lin_1(x))) 31 | return x 32 | 33 | 34 | class GIN(nn.Module): 35 | def __init__(self, num_input_features: int, num_classes: int, num_layers: int, 36 | hidden, hidden_final: int, dropout_prob: float, use_batch_norm: bool): 37 | super().__init__() 38 | self.use_batch_norm = use_batch_norm 39 | self.dropout_prob = dropout_prob 40 | self.no_prop = FeatureExtractor(num_input_features, hidden_final) 41 | self.initial_lin_x = nn.Linear(num_input_features, hidden) 42 | 43 | self.convs = nn.ModuleList([]) 44 | self.batch_norm_x = nn.ModuleList() 45 | self.feature_extractors = nn.ModuleList([]) 46 | for i in range(num_layers): 47 | self.convs.append(GINConv(GINNetwork(hidden, hidden))) 48 | self.feature_extractors.append(FeatureExtractor(hidden, hidden_final)) 49 | self.batch_norm_x.append(nn.BatchNorm1d(hidden)) 50 | 51 | self.after_conv = nn.Linear(hidden_final, hidden_final) 52 | self.final_lin = nn.Linear(hidden_final, num_classes) 53 | 54 | def forward(self, data): 55 | """ data.x: (num_nodes, num_features)""" 56 | x, edge_index, batch, batch_size = data.x, data.edge_index, data.batch, data.num_graphs 57 | 58 | # Compute some information about the batch 59 | # Count the number of nodes in each graph 60 | unique, n_per_graph = torch.unique(data.batch, return_counts=True) 61 | n_batch = torch.zeros_like(batch, dtype=torch.float) 62 | 63 | for value, n in zip(unique, n_per_graph): 64 | n_batch[batch == value] = n.float() 65 | 66 | # Aggregate into a dict 67 | batch_info = {'num_nodes': data.num_nodes, 68 | 'num_graphs': data.num_graphs, 69 | 'batch': data.batch} 70 | 71 | out = self.no_prop.forward(x, batch_info) 72 | x = self.initial_lin_x(x) 73 | for i, (conv, bn_x, extractor) in enumerate(zip(self.convs, self.batch_norm_x, self.feature_extractors)): 74 | if self.use_batch_norm and i > 0: 75 | x = bn_x(x) 76 | x = conv(x, edge_index) 77 | global_features = extractor.forward(x, batch_info) 78 | out += global_features 79 | 80 | out = F.relu(out) / len(self.convs) 81 | out = F.relu(self.after_conv(out)) + out 82 | out = F.dropout(out, p=self.dropout_prob, training=self.training) 83 | out = self.final_lin(out) 84 | return F.log_softmax(out, dim=-1) 85 | 86 | def __repr__(self): 87 | return self.__class__.__name__ -------------------------------------------------------------------------------- /models/model_cycles.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from models.smp_layers import SimplifiedFastSMPLayer, FastSMPLayer, SMPLayer 5 | from models.utils.layers import GraphExtractor, EdgeCounter, BatchNorm 6 | from models.utils.misc import create_batch_info, map_x_to_u 7 | 8 | 9 | class SMP(torch.nn.Module): 10 | def __init__(self, num_input_features: int, num_classes: int, num_layers: int, hidden: int, layer_type: str, 11 | hidden_final: int, dropout_prob: float, use_batch_norm: bool, use_x: bool, map_x_to_u: bool, 12 | num_towers: int, simplified: bool): 13 | """ num_input_features: number of node features 14 | layer_type: 'SMP', 'FastSMP' or 'SimplifiedFastSMP' 15 | 
hidden_final: size of the feature map after pooling 16 | use_x: for ablation study, run a MPNN instead of SMP 17 | map_x_to_u: map the node features to the local context 18 | num_towers: inside each SMP layers, use towers to reduce the number of parameters 19 | simplified: less layers in the feature extractor. 20 | """ 21 | super().__init__() 22 | self.map_x_to_u, self.use_x = map_x_to_u, use_x 23 | self.dropout_prob = dropout_prob 24 | self.use_batch_norm = use_batch_norm 25 | self.edge_counter = EdgeCounter() 26 | self.num_classes = num_classes 27 | 28 | self.no_prop = GraphExtractor(in_features=num_input_features, out_features=hidden_final, use_x=use_x) 29 | self.initial_lin = nn.Linear(num_input_features, hidden) 30 | 31 | layer_type_dict = {'SMP': SMPLayer, 'FastSMP': FastSMPLayer, 'SimplifiedFastSMP': SimplifiedFastSMPLayer} 32 | conv_layer = layer_type_dict[layer_type] 33 | 34 | self.convs = nn.ModuleList() 35 | self.batch_norm_list = nn.ModuleList() 36 | self.feature_extractors = torch.nn.ModuleList([]) 37 | for i in range(0, num_layers): 38 | self.convs.append(conv_layer(in_features=hidden, num_towers=num_towers, out_features=hidden, use_x=use_x)) 39 | self.batch_norm_list.append(BatchNorm(hidden, use_x)) 40 | self.feature_extractors.append(GraphExtractor(in_features=hidden, out_features=hidden_final, use_x=use_x, 41 | simplified=simplified)) 42 | 43 | # Last layers 44 | self.simplified = simplified 45 | self.after_conv = nn.Linear(hidden_final, hidden_final) 46 | self.final_lin = nn.Linear(hidden_final, num_classes) 47 | 48 | def forward(self, data): 49 | """ data.x: (num_nodes, num_features)""" 50 | x, edge_index = data.x, data.edge_index 51 | batch_info = create_batch_info(data, self.edge_counter) 52 | 53 | # Create the context matrix 54 | if self.use_x: 55 | assert x is not None 56 | u = x 57 | elif self.map_x_to_u: 58 | u = map_x_to_u(data, batch_info) 59 | else: 60 | u = data.x.new_zeros((data.num_nodes, batch_info['n_colors'])) 61 | u.scatter_(1, data.coloring, 1) 62 | u = u[..., None] 63 | 64 | # Forward pass 65 | out = self.no_prop(u, batch_info) 66 | u = self.initial_lin(u) 67 | for i, (conv, bn, extractor) in enumerate(zip(self.convs, self.batch_norm_list, self.feature_extractors)): 68 | if self.use_batch_norm and i > 0: 69 | u = bn(u) 70 | u = conv(u, edge_index, batch_info) 71 | global_features = extractor.forward(u, batch_info) 72 | out += global_features / len(self.convs) 73 | 74 | # Two layer MLP with dropout and residual connections: 75 | if not self.simplified: 76 | out = torch.relu(self.after_conv(out)) + out 77 | out = F.dropout(out, p=self.dropout_prob, training=self.training) 78 | out = self.final_lin(out) 79 | if self.num_classes > 1: 80 | # Classification 81 | return F.log_softmax(out, dim=-1) 82 | else: 83 | # Regression 84 | assert out.shape[1] == 1 85 | return out[:, 0] 86 | 87 | def reset_parameters(self): 88 | for layer in [self.no_prop, self.initial_lin, *self.convs, *self.batch_norm_list, *self.feature_extractors, 89 | self.after_conv, self.final_lin]: 90 | layer.reset_parameters() 91 | 92 | def __repr__(self): 93 | return self.__class__.__name__ 94 | -------------------------------------------------------------------------------- /models/model_multi_task.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.utils.layers import EdgeCounter, NodeExtractor, BatchNorm 4 | from models.smp_layers import FastSMPLayer, SMPLayer, SimplifiedFastSMPLayer 5 | from 
torch_geometric.nn import Set2Set 6 | from models.utils.misc import create_batch_info 7 | 8 | 9 | class SMP(torch.nn.Module): 10 | def __init__(self, num_input_features: int, nodes_out: int, graph_out: int, 11 | num_layers: int, num_towers: int, hidden_u: int, out_u: int, hidden_gru: int, 12 | layer_type: str): 13 | """ num_input_features: number of node features 14 | nodes_out: number of output features at each node's level (3 on the benchmark) 15 | graph_out: number of output features at the graph level (3 on the benchmark) 16 | num_towers: inside each SMP layers, use towers to reduce the number of parameters 17 | hidden_u: number of channels in the local contexts 18 | out_u: number of channels after extraction of node features 19 | hidden_gru: number of channels inside the gated recurrent unit 20 | layer_type: 'SMP', 'FastSMP' or 'SimplifiedFastSMP'. 21 | """ 22 | super().__init__() 23 | num_input_u = 1 + num_input_features 24 | 25 | self.edge_counter = EdgeCounter() 26 | self.initial_lin_u = nn.Linear(num_input_u, hidden_u) 27 | 28 | self.extractor = NodeExtractor(hidden_u, out_u) 29 | 30 | layer_type_dict = {'SMP': SMPLayer, 'FastSMP': FastSMPLayer, 'SimplifiedFastSMP': SimplifiedFastSMPLayer} 31 | conv_layer = layer_type_dict[layer_type] 32 | 33 | self.gru = nn.GRU(out_u, hidden_gru) 34 | self.convs = nn.ModuleList([]) 35 | self.batch_norm_u = nn.ModuleList([]) 36 | for i in range(0, num_layers): 37 | self.batch_norm_u.append(BatchNorm(hidden_u, use_x=False)) 38 | conv = conv_layer(in_features=hidden_u, out_features=hidden_u, num_towers=num_towers, use_x=False) 39 | self.convs.append(conv) 40 | 41 | # Process the extracted node features 42 | max_n = 19 43 | self.set2set = Set2Set(hidden_gru, max_n) 44 | 45 | self.final_node = nn.Sequential(nn.Linear(hidden_gru, hidden_gru), nn.LeakyReLU(), 46 | nn.Linear(hidden_gru, hidden_gru), nn.LeakyReLU(), 47 | nn.Linear(hidden_gru, nodes_out)) 48 | 49 | self.final_graph = nn.Sequential(nn.Linear(2 * hidden_gru, hidden_gru), nn.ReLU(), 50 | nn.BatchNorm1d(hidden_gru), 51 | nn.Linear(hidden_gru, hidden_gru), nn.LeakyReLU(), 52 | nn.BatchNorm1d(hidden_gru), 53 | nn.Linear(hidden_gru, graph_out)) 54 | 55 | def forward(self, data): 56 | """ data.x: (num_nodes, num_features)""" 57 | x, edge_index, batch, batch_size = data.x, data.edge_index, data.batch, data.num_graphs 58 | batch_info = create_batch_info(data, self.edge_counter) 59 | 60 | # Create the context matrix 61 | u = data.x.new_zeros((data.num_nodes, batch_info['n_colors'])) 62 | u.scatter_(1, data.coloring, 1) 63 | u = u[..., None] 64 | 65 | # Map x to u 66 | shortest_path_ids = x[:, 0] 67 | lap_feat = x[:, 1] 68 | u_shortest_path = torch.zeros_like(u) 69 | u_lap_feat = torch.zeros_like(u) 70 | non_zero = shortest_path_ids.nonzero(as_tuple=False)[:, 0] 71 | nonzero_batch = batch_info['batch'][non_zero] 72 | nonzero_color = batch_info['coloring'][non_zero][:, 0] 73 | for b, c in zip(nonzero_batch, nonzero_color): 74 | u_shortest_path[batch == b, c] = 1 75 | 76 | for i, feat in enumerate(lap_feat): 77 | u_lap_feat[i, batch_info['coloring'][i]] = feat 78 | 79 | u = torch.cat((u, u_shortest_path, u_lap_feat), dim=2) 80 | 81 | # Forward pass 82 | u = self.initial_lin_u(u) 83 | hidden_state = None 84 | for i, (conv, bn_u) in enumerate(zip(self.convs, self.batch_norm_u)): 85 | if i > 0: 86 | u = bn_u(u) 87 | u = conv(u, edge_index, batch_info) 88 | extracted = self.extractor(x, u, batch_info)[None, :, :] 89 | hidden_state = self.gru(extracted, hidden_state)[1] 90 | 91 | # Compute the final 
representation 92 | out = hidden_state[0, :, :] 93 | nodes_out = self.final_node(out) 94 | after_set2set = self.set2set(out, batch_info['batch']) 95 | graph_out = self.final_graph(after_set2set) 96 | 97 | return nodes_out, graph_out 98 | 99 | def __repr__(self): 100 | return self.__class__.__name__ 101 | -------------------------------------------------------------------------------- /models/model_zinc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from models.smp_layers import ZincSMPLayer 5 | from models.utils.layers import GraphExtractor, EdgeCounter, BatchNorm 6 | from models.utils.misc import create_batch_info, map_x_to_u 7 | 8 | 9 | class SMPZinc(torch.nn.Module): 10 | def __init__(self, num_input_features: int, num_edge_features: int, num_classes: int, num_layers: int, 11 | hidden: int, residual: bool, use_edge_features: bool, shared_extractor: bool, 12 | hidden_final: int, use_batch_norm: bool, use_x: bool, map_x_to_u: bool, 13 | num_towers: int, simplified: bool): 14 | """ num_input_features: number of node features 15 | num_edge_features: number of edge features 16 | num_classes: output dimension 17 | hidden: number of channels of the local contexts 18 | residual: use residual connexion after each SMP layer 19 | use_edge_features: if False, edge features are simply ignored 20 | shared extractor: share extractor among layers to reduce the number of parameters 21 | hidden_final: number of channels after extraction of graph features 22 | use_x: for ablation study, run a MPNN instead of SMP 23 | map_x_to_u: map the initial node features to the local context. If false, node features are ignored 24 | num_towers: inside each SMP layers, use towers to reduce the number of parameters 25 | simplified: if True, the feature extractor has less layers. 
26 | """ 27 | super().__init__() 28 | self.map_x_to_u, self.use_x = map_x_to_u, use_x 29 | self.use_batch_norm = use_batch_norm 30 | self.edge_counter = EdgeCounter() 31 | self.num_classes = num_classes 32 | self.residual = residual 33 | self.shared_extractor = shared_extractor 34 | 35 | self.no_prop = GraphExtractor(in_features=num_input_features, out_features=hidden_final, use_x=use_x) 36 | self.initial_lin = nn.Linear(num_input_features, hidden) 37 | 38 | self.convs = nn.ModuleList() 39 | self.batch_norm_list = nn.ModuleList() 40 | for i in range(0, num_layers): 41 | self.convs.append(ZincSMPLayer(in_features=hidden, num_towers=num_towers, out_features=hidden, 42 | edge_features=num_edge_features, use_x=use_x, 43 | use_edge_features=use_edge_features)) 44 | self.batch_norm_list.append(BatchNorm(hidden, use_x) if i > 0 else None) 45 | 46 | # Feature extractors 47 | if shared_extractor: 48 | self.feature_extractor = GraphExtractor(in_features=hidden, out_features=hidden_final, use_x=use_x, 49 | simplified=simplified) 50 | else: 51 | self.feature_extractors = torch.nn.ModuleList([]) 52 | for i in range(0, num_layers): 53 | self.feature_extractors.append(GraphExtractor(in_features=hidden, out_features=hidden_final, 54 | use_x=use_x, simplified=simplified)) 55 | 56 | # Last layers 57 | self.after_conv = nn.Linear(hidden_final, hidden_final) 58 | self.final_lin = nn.Linear(hidden_final, num_classes) 59 | 60 | def forward(self, data): 61 | """ data.x: (num_nodes, num_node_features)""" 62 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 63 | 64 | # Compute information about the batch 65 | batch_info = create_batch_info(data, self.edge_counter) 66 | 67 | # Create the context matrix 68 | if self.use_x: 69 | assert x is not None 70 | u = x 71 | elif self.map_x_to_u: 72 | u = map_x_to_u(data, batch_info) 73 | else: 74 | u = data.x.new_zeros((data.num_nodes, batch_info['n_colors'])) 75 | u.scatter_(1, data.coloring, 1) 76 | u = u[..., None] 77 | 78 | # Forward pass 79 | out = self.no_prop(u, batch_info) 80 | u = self.initial_lin(u) 81 | for i in range(len(self.convs)): 82 | conv = self.convs[i] 83 | bn = self.batch_norm_list[i] 84 | extractor = self.feature_extractor if self.shared_extractor else self.feature_extractors[i] 85 | if self.use_batch_norm and i > 0: 86 | u = bn(u) 87 | u = conv(u, edge_index, edge_attr, batch_info) + (u if self.residual else 0) 88 | global_features = extractor.forward(u, batch_info) 89 | out += global_features / len(self.convs) 90 | 91 | out = self.final_lin(torch.relu(self.after_conv(out)) + out) 92 | assert out.shape[1] == 1 93 | return out[:, 0] 94 | 95 | def __repr__(self): 96 | return self.__class__.__name__ 97 | -------------------------------------------------------------------------------- /models/ppgn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | 6 | 7 | class InvariantMaxLayer(nn.Module): 8 | def forward(self, x: Tensor): 9 | """ x: (batch_size, n_nodes, n_nodes, channels)""" 10 | bs, n, channels = x.shape[0], x.shape[1], x.shape[3] 11 | diag = torch.diagonal(x, dim1=1, dim2=2).contiguous() # batch, Channels, n_nodes 12 | # max_diag = diag.max(dim=2)[0] # Batch, channels 13 | max_diag = diag.sum(dim=2) 14 | mask = ~ torch.eye(n=x.shape[1], dtype=torch.bool, device=x.device)[None, :, :, None].expand(x.shape) 15 | x_off_diag = x[mask].reshape(bs, n, n - 1, channels) 16 | # max_off_diag = 
x_off_diag.max(dim=1)[0].max(dim=1)[0] 17 | max_off_diag = x_off_diag.sum(dim=1).sum(dim=1) 18 | out = torch.cat((max_diag, max_off_diag), dim=1) 19 | return out 20 | 21 | 22 | class UnitMLP(nn.Module): 23 | def __init__(self, in_feat: int, out_feat: int, num_layers): 24 | super().__init__() 25 | self.layers = nn.ModuleList() 26 | self.layers.append(nn.Conv2d(in_feat, out_feat, (1, 1))) 27 | for i in range(1, num_layers): 28 | self.layers.append(nn.Conv2d(out_feat, out_feat, (1, 1))) 29 | 30 | def forward(self, x: Tensor): 31 | """ x: batch x N x N x channels""" 32 | # Convert for conv2d 33 | x = x.permute(0, 3, 1, 2).contiguous() # channels, N, N 34 | for layer in self.layers[:-1]: 35 | x = F.relu(layer.forward(x)) 36 | x = self.layers[-1].forward(x) 37 | x = x.permute(0, 2, 3, 1) # batch_size, N, N, channels 38 | return x 39 | 40 | 41 | class PowerfulLayer(nn.Module): 42 | def __init__(self, in_feat: int, out_feat: int, num_layers: int): 43 | super().__init__() 44 | a = in_feat 45 | b = out_feat 46 | self.m1 = UnitMLP(a, b, num_layers) 47 | self.m2 = UnitMLP(a, b, num_layers) 48 | self.m4 = nn.Linear(a + b, b, bias=True) 49 | 50 | def forward(self, x): 51 | """ x: batch x N x N x in_feat""" 52 | out1 = self.m1.forward(x).permute(0, 3, 1, 2) # batch, out_feat, N, N 53 | out2 = self.m2.forward(x).permute(0, 3, 1, 2) # batch, out_feat, N, N 54 | out3 = x 55 | mult = out1 @ out2 # batch, out_feat, N, N 56 | out = torch.cat((mult.permute(0, 2, 3, 1), out3), dim=3) # batch, N, N, out_feat 57 | suffix = self.m4.forward(out) 58 | return suffix 59 | 60 | 61 | class FeatureExtractor(nn.Module): 62 | def __init__(self, in_features: int, out_features: int): 63 | super().__init__() 64 | self.lin1 = nn.Linear(in_features, out_features, bias=True) 65 | self.lin2 = nn.Linear(in_features, out_features, bias=False) 66 | self.lin3 = torch.nn.Linear(out_features, out_features, bias=False) 67 | 68 | def forward(self, u): 69 | """ u: (batch_size, num_nodes, num_nodes, in_features) 70 | output: (batch_size, out_features). 
""" 71 | n = u.shape[1] 72 | diag = u.diagonal(dim1=1, dim2=2) # batch_size, channels, num_nodes 73 | trace = torch.sum(diag, dim=2) 74 | out1 = self.lin1.forward(trace / n) 75 | 76 | s = (torch.sum(u, dim=[1, 2]) - trace) / (n * (n-1)) 77 | out2 = self.lin2.forward(s) # bs, out_feat 78 | out = out1 + out2 79 | out = out + self.lin3.forward(F.relu(out)) 80 | return out 81 | 82 | 83 | class Powerful(nn.Module): 84 | def __init__(self, num_classes: int, num_layers: int, hidden: int, hidden_final: int, dropout_prob: float, 85 | simplified: bool): 86 | super().__init__() 87 | layers_per_conv = 1 88 | self.layer_after_conv = not simplified 89 | self.dropout_prob = dropout_prob 90 | self.no_prop = FeatureExtractor(1, hidden_final) 91 | initial_conv = PowerfulLayer(1, hidden, layers_per_conv) 92 | self.convs = nn.ModuleList([initial_conv]) 93 | self.bns = nn.ModuleList([]) 94 | for i in range(1, num_layers): 95 | self.convs.append(PowerfulLayer(hidden, hidden, layers_per_conv)) 96 | 97 | self.feature_extractors = torch.nn.ModuleList([]) 98 | for i in range(num_layers): 99 | self.bns.append(nn.BatchNorm2d(hidden)) 100 | self.feature_extractors.append(FeatureExtractor(hidden, hidden_final)) 101 | if self.layer_after_conv: 102 | self.after_conv = nn.Linear(hidden_final, hidden_final) 103 | self.final_lin = nn.Linear(hidden_final, num_classes) 104 | 105 | def forward(self, data): 106 | u = data.A[..., None] # batch, N, N, 1 107 | out = self.no_prop.forward(u) 108 | for conv, extractor, bn in zip(self.convs, self.feature_extractors, self.bns): 109 | u = conv(u) 110 | u = bn(u.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 111 | out = out + extractor.forward(u) 112 | out = F.relu(out) / len(self.convs) 113 | if self.layer_after_conv: 114 | out = out + F.relu(self.after_conv(out)) 115 | out = F.dropout(out, p=self.dropout_prob, training=self.training) 116 | out = self.final_lin(out) 117 | return F.log_softmax(out, dim=-1) 118 | -------------------------------------------------------------------------------- /models/ring_gnn.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/leichen2018/Ring-GNN/blob/master/src/model.py 2 | 3 | import torch 4 | import torch as th 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | 9 | 10 | class FeatureExtractor(nn.Module): 11 | def __init__(self, in_features: int, out_features: int): 12 | super().__init__() 13 | self.lin1 = nn.Linear(in_features, out_features, bias=True) 14 | self.lin2 = nn.Linear(in_features, out_features, bias=False) 15 | self.lin3 = torch.nn.Linear(out_features, out_features, bias=False) 16 | 17 | def forward(self, u): 18 | """ u: (batch_size, num_nodes, num_nodes, in_features) 19 | output: (batch_size, out_features). 
""" 20 | n = u.shape[1] 21 | diag = u.diagonal(dim1=1, dim2=2) # batch_size, channels, num_nodes 22 | trace = torch.sum(diag, dim=2) 23 | out1 = self.lin1.forward(trace / n) 24 | 25 | s = (torch.sum(u, dim=[1, 2]) - trace) / (n * (n-1)) 26 | out2 = self.lin2.forward(s) # bs, out_feat 27 | out = out1 + out2 28 | out = out + self.lin3.forward(F.relu(out)) 29 | return out 30 | 31 | 32 | class RingGNN(nn.Module): 33 | def __init__(self, num_classes: int, num_layers: int, hidden: int, hidden_final: int, dropout_prob: float, 34 | simplified: bool): 35 | super().__init__() 36 | self.layer_after_conv = not simplified 37 | self.dropout_prob = dropout_prob 38 | self.no_prop = FeatureExtractor(1, hidden_final) 39 | initial_conv = equi_2_to_2(1, hidden) 40 | self.convs = nn.ModuleList([initial_conv]) 41 | self.bns = nn.ModuleList([]) 42 | for i in range(1, num_layers): 43 | self.convs.append(equi_2_to_2(hidden, hidden)) 44 | 45 | self.feature_extractors = torch.nn.ModuleList([]) 46 | for i in range(num_layers): 47 | self.bns.append(nn.BatchNorm2d(hidden)) 48 | self.feature_extractors.append(FeatureExtractor(hidden, hidden_final)) 49 | if self.layer_after_conv: 50 | self.after_conv = nn.Linear(hidden_final, hidden_final) 51 | self.final_lin = nn.Linear(hidden_final, num_classes) 52 | 53 | def forward(self, data): 54 | u = data.A[..., None] # batch, N, N, 1 55 | out = self.no_prop.forward(u) 56 | for conv, extractor, bn in zip(self.convs, self.feature_extractors, self.bns): 57 | u = conv(u) 58 | u = bn(u.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 59 | out = out + extractor.forward(u) 60 | out = F.relu(out) / len(self.convs) 61 | if self.layer_after_conv: 62 | out = out + F.relu(self.after_conv(out)) 63 | out = F.dropout(out, p=self.dropout_prob, training=self.training) 64 | out = self.final_lin(out) 65 | return F.log_softmax(out, dim=-1) 66 | 67 | 68 | 69 | class MLP(nn.Module): 70 | def __init__(self, feats): 71 | super(MLP, self).__init__() 72 | self.linears = nn.ModuleList([nn.Linear(m, n) for m, n in zip(feats[:-1], feats[1:])]) 73 | 74 | def forward(self, x): 75 | for layer in self.linears[:-1]: 76 | x = layer(x) 77 | x = F.relu(x) 78 | 79 | return self.linears[-1](x) 80 | 81 | 82 | class equi_2_to_2(nn.Module): 83 | def __init__(self, input_depth, output_depth, normalization='inf', normalization_val=1.0, radius=2, k2_init=0.1): 84 | super(equi_2_to_2, self).__init__() 85 | basis_dimension = 15 86 | self.radius = radius 87 | # coeffs_values = lambda i, j, k: th.randn([i, j, k]) * th.sqrt(2. / (i + j).float()) 88 | coeffs_values = lambda i, j, k: th.randn([i, j, k]) * np.sqrt(2. 
/ float((i + j))) 89 | self.diag_bias_list = nn.ParameterList([]) 90 | 91 | for i in range(radius): 92 | for j in range(i + 1): 93 | self.diag_bias_list.append(nn.Parameter(th.zeros(1, output_depth, 1, 1))) 94 | 95 | self.all_bias = nn.Parameter(th.zeros(1, output_depth, 1, 1)) 96 | self.coeffs_list = nn.ParameterList([]) 97 | 98 | for i in range(radius): 99 | for j in range(i + 1): 100 | self.coeffs_list.append(nn.Parameter(coeffs_values(input_depth, output_depth, basis_dimension))) 101 | 102 | self.switch = nn.ParameterList([nn.Parameter(th.FloatTensor([1])), nn.Parameter(th.FloatTensor([k2_init]))]) 103 | self.output_depth = output_depth 104 | 105 | self.normalization = normalization 106 | self.normalization_val = normalization_val 107 | 108 | def forward(self, inputs): 109 | inputs = inputs.permute(0, 3, 1, 2) # Convert to N x D x m x m 110 | m = inputs.size()[3] 111 | ops_out = ops_2_to_2(inputs, m, normalization=self.normalization) 112 | ops_out = th.stack(ops_out, dim=2) 113 | output_list = [] 114 | 115 | for i in range(self.radius): 116 | for j in range(i + 1): 117 | output_i = th.einsum('dsb,ndbij->nsij', self.coeffs_list[i * (i + 1) // 2 + j], ops_out) 118 | mat_diag_bias = th.eye(inputs.size()[3]).to(inputs.device).unsqueeze(0).unsqueeze(0) * self.diag_bias_list[ 119 | i * (i + 1) // 2 + j] 120 | if j == 0: 121 | output = output_i + mat_diag_bias 122 | else: 123 | output = th.einsum('abcd,abde->abce', output_i, output) 124 | 125 | output_list.append(output) 126 | 127 | output = 0 128 | for i in range(self.radius): 129 | output += output_list[i] * self.switch[i] 130 | 131 | output = output + self.all_bias 132 | output = output.permute(0, 2, 3, 1) 133 | return output 134 | 135 | 136 | def diag_offdiag_maxpool(input): 137 | max_diag = th.max(th.diagonal(input, dim1=2, dim2=3), dim=2)[0] 138 | 139 | max_val = th.max(max_diag) 140 | 141 | min_val = th.max(input * (-1.)) 142 | val = th.abs(max_val + min_val) 143 | min_mat = th.diag_embed(th.diagonal(input[0][0]) * 0 + val).unsqueeze(0).unsqueeze(0) 144 | max_offdiag = th.max(th.max(input - min_mat, dim=2)[0], dim=2)[0] 145 | 146 | return th.cat([max_diag, max_offdiag], dim=1) 147 | 148 | 149 | def ops_2_to_2(inputs, dim, normalization='inf', normalization_val=1.0): # N x D x m x m 150 | # input: N x D x m x m 151 | diag_part = th.diagonal(inputs, dim1=2, dim2=3) # N x D x m 152 | sum_diag_part = th.sum(diag_part, dim=2, keepdim=True) # N x D x 1 153 | sum_of_rows = th.sum(inputs, dim=3) # N x D x m 154 | sum_of_cols = th.sum(inputs, dim=2) # N x D x m 155 | sum_all = th.sum(sum_of_rows, dim=2) # N x D 156 | 157 | # op1 - (1234) - extract diag 158 | op1 = th.diag_embed(diag_part) # N x D x m x m 159 | 160 | # op2 - (1234) + (12)(34) - place sum of diag on diag 161 | op2 = th.diag_embed(sum_diag_part.repeat(1, 1, dim)) 162 | 163 | # op3 - (1234) + (123)(4) - place sum of row i on diag ii 164 | op3 = th.diag_embed(sum_of_rows) 165 | 166 | # op4 - (1234) + (124)(3) - place sum of col i on diag ii 167 | op4 = th.diag_embed(sum_of_cols) 168 | 169 | # op5 - (1234) + (124)(3) + (123)(4) + (12)(34) + (12)(3)(4) - place sum of all entries on diag 170 | op5 = th.diag_embed(sum_all.unsqueeze(2).repeat(1, 1, dim)) 171 | 172 | # op6 - (14)(23) + (13)(24) + (24)(1)(3) + (124)(3) + (1234) - place sum of col i on row i 173 | op6 = sum_of_cols.unsqueeze(3).repeat(1, 1, 1, dim) 174 | 175 | # op7 - (14)(23) + (23)(1)(4) + (234)(1) + (123)(4) + (1234) - place sum of row i on row i 176 | op7 = sum_of_rows.unsqueeze(3).repeat(1, 1, 1, dim) 177 | 178 | # op8 
- (14)(2)(3) + (134)(2) + (14)(23) + (124)(3) + (1234) - place sum of col i on col i 179 | op8 = sum_of_cols.unsqueeze(2).repeat(1, 1, dim, 1) 180 | 181 | # op9 - (13)(24) + (13)(2)(4) + (134)(2) + (123)(4) + (1234) - place sum of row i on col i 182 | op9 = sum_of_rows.unsqueeze(2).repeat(1, 1, dim, 1) 183 | 184 | # op10 - (1234) + (14)(23) - identity 185 | op10 = inputs 186 | 187 | # op11 - (1234) + (13)(24) - transpose 188 | op11 = th.transpose(inputs, -2, -1) 189 | 190 | # op12 - (1234) + (234)(1) - place ii element in row i 191 | op12 = diag_part.unsqueeze(3).repeat(1, 1, 1, dim) 192 | 193 | # op13 - (1234) + (134)(2) - place ii element in col i 194 | op13 = diag_part.unsqueeze(2).repeat(1, 1, dim, 1) 195 | 196 | # op14 - (34)(1)(2) + (234)(1) + (134)(2) + (1234) + (12)(34) - place sum of diag in all entries 197 | op14 = sum_diag_part.unsqueeze(3).repeat(1, 1, dim, dim) 198 | 199 | # op15 - sum of all ops - place sum of all entries in all entries 200 | op15 = sum_all.unsqueeze(2).unsqueeze(3).repeat(1, 1, dim, dim) 201 | 202 | # A_2 = th.einsum('abcd,abde->abce', inputs, inputs) 203 | # A_4 = th.einsum('abcd,abde->abce', A_2, A_2) 204 | # op16 = th.where(A_4>1, th.ones(A_4.size()), A_4) 205 | 206 | if normalization is not None: 207 | float_dim = float(dim) 208 | if normalization is 'inf': 209 | op2 = th.div(op2, float_dim) 210 | op3 = th.div(op3, float_dim) 211 | op4 = th.div(op4, float_dim) 212 | op5 = th.div(op5, float_dim ** 2) 213 | op6 = th.div(op6, float_dim) 214 | op7 = th.div(op7, float_dim) 215 | op8 = th.div(op8, float_dim) 216 | op9 = th.div(op9, float_dim) 217 | op14 = th.div(op14, float_dim) 218 | op15 = th.div(op15, float_dim ** 2) 219 | 220 | # return [op1, op2, op3, op4, op5, op6, op7, op8, op9, op10, op11, op12, op13, op14, op15, op16] 221 | ''' 222 | l = [op1, op2, op3, op4, op5, op6, op7, op8, op9, op10, op11, op12, op13, op14, op15] 223 | for i, ls in enumerate(l): 224 | print(i+1) 225 | print(th.sum(ls)) 226 | print("$%^&*(*&^%$#$%^&*(*&^%$%^&*(*&^%$%^&*(") 227 | ''' 228 | return [op1, op2, op3, op4, op5, op6, op7, op8, op9, op10, op11, op12, op13, op14, op15] 229 | -------------------------------------------------------------------------------- /models/smp_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | from torch_geometric.nn import MessagePassing 5 | from models.utils.layers import XtoX, UtoU, UtoU, EntrywiseU, EntryWiseX 6 | 7 | 8 | class SimplifiedFastSMPLayer(MessagePassing): 9 | def __init__(self, in_features: int, num_towers: int, out_features: int, use_x: bool): 10 | super().__init__(aggr='add', node_dim=-3) 11 | self.use_x = use_x 12 | self.message_nn = (XtoX if use_x else UtoU)(in_features, out_features, bias=True) 13 | if self.use_x: 14 | self.alpha = nn.Parameter(torch.zeros(1, out_features), requires_grad=True) 15 | else: 16 | self.alpha = nn.Parameter(torch.zeros(1, 1, out_features), requires_grad=True) 17 | 18 | def reset_parameters(self): 19 | self.message_nn.reset_parameters() 20 | self.alpha.requires_grad_(False) 21 | self.alpha[...] 
= 0 22 | self.alpha.requires_grad_(True) 23 | 24 | def forward(self, u, edge_index, batch_info): 25 | """ x corresponds either to node features or to the local context, depending on use_x.""" 26 | n = batch_info['num_nodes'] 27 | if self.use_x and u.dim() == 1: 28 | u = u.unsqueeze(-1) 29 | u = self.message_nn(u, batch_info) 30 | new_u = self.propagate(edge_index, size=(n, n), u=u) 31 | # Normalization 32 | if len(new_u.shape) == 2: 33 | # node features are used 34 | new_u /= batch_info['average_edges'][:, :, 0] 35 | else: 36 | # local contexts are used 37 | new_u /= batch_info['average_edges'] 38 | return new_u 39 | 40 | def message(self, u_j: Tensor): 41 | return u_j 42 | 43 | def update(self, aggr_u, u): 44 | return aggr_u + u + self.alpha * u * aggr_u 45 | 46 | 47 | class FastSMPLayer(MessagePassing): 48 | def __init__(self, in_features: int, num_towers: int, out_features: int, use_x: bool): 49 | super().__init__(aggr='add', node_dim=-2 if use_x else -3) 50 | self.use_x = use_x 51 | self.in_u, self.out_u = in_features, out_features 52 | if use_x: 53 | self.message_nn = XtoX(in_features, out_features, bias=True) 54 | self.linu_i = EntryWiseX(out_features, out_features, num_towers=out_features) 55 | self.linu_j = EntryWiseX(out_features, out_features, num_towers=out_features) 56 | else: 57 | self.message_nn = UtoU(in_features, out_features, n_groups=num_towers, residual=False) 58 | self.linu_i = EntrywiseU(out_features, out_features, num_towers=out_features) 59 | self.linu_j = EntrywiseU(out_features, out_features, num_towers=out_features) 60 | 61 | def forward(self, u, edge_index, batch_info): 62 | n = batch_info['num_nodes'] 63 | u = self.message_nn(u, batch_info) 64 | new_u = self.propagate(edge_index, size=(n, n), u=u) 65 | new_u /= batch_info['average_edges'] 66 | return new_u 67 | 68 | def message(self, u_j): 69 | return u_j 70 | 71 | def update(self, aggr_u, u): 72 | a_i = self.linu_i(u) 73 | a_j = self.linu_j(aggr_u) 74 | return aggr_u + u + a_i * a_j 75 | 76 | 77 | class SMPLayer(MessagePassing): 78 | def __init__(self, in_features: int, num_towers: int, out_features: int, use_x: bool): 79 | super().__init__(aggr='add', node_dim=-3) 80 | self.use_x = use_x 81 | self.in_u, self.out_u = in_features, out_features 82 | if use_x: 83 | self.message_nn = XtoX(in_features, out_features, bias=True) 84 | self.order2_i = EntryWiseX(out_features, out_features, num_towers) 85 | self.order2_j = EntryWiseX(out_features, out_features, num_towers) 86 | self.order2 = EntryWiseX(out_features, out_features, num_towers) 87 | else: 88 | self.message_nn = UtoU(in_features, out_features, n_groups=num_towers, residual=False) 89 | self.order2_i = EntrywiseU(out_features, out_features, num_towers) 90 | self.order2_j = EntrywiseU(out_features, out_features, num_towers) 91 | self.order2 = EntrywiseU(out_features, out_features, num_towers) 92 | self.update1 = nn.Linear(2 * out_features, out_features) 93 | self.update2 = nn.Linear(out_features, out_features) 94 | 95 | def forward(self, u, edge_index, batch_info): 96 | n = batch_info['num_nodes'] 97 | u = self.message_nn(u, batch_info) 98 | u1 = self.order2_i(u) 99 | u2 = self.order2_j(u) 100 | new_u = self.propagate(edge_index, size=(n, n), u=u, u1=u1, u2=u2) 101 | new_u /= batch_info['average_edges'] 102 | return new_u 103 | 104 | def message(self, u_j, u1_i, u2_j): 105 | order2 = self.order2(torch.relu(u1_i + u2_j)) 106 | return order2 107 | 108 | def update(self, aggr_u, u): 109 | up1 = self.update1(torch.cat((u, aggr_u), dim=-1)) 110 | up2 = up1 + 
self.update2(up1) 111 | return up2 112 | 113 | 114 | class ZincSMPLayer(MessagePassing): 115 | def __init__(self, in_features: int, num_towers: int, out_features: int, edge_features: int, use_x: bool, 116 | use_edge_features: bool): 117 | """ Use a MLP both for the update and message function + edge features. """ 118 | super().__init__(aggr='add', node_dim=-2 if use_x else -3) 119 | self.use_x, self.use_edge_features = use_x, use_edge_features 120 | self.in_u, self.out_u, self.edge_features = in_features, out_features, edge_features 121 | self.edge_nn = nn.Linear(edge_features, out_features) if use_edge_features else None 122 | 123 | self.message_nn = (EntryWiseX if use_x else UtoU)(in_features, out_features, 124 | n_groups=num_towers, residual=False) 125 | 126 | args_order2 = [out_features, out_features, num_towers] 127 | entry_wise = EntryWiseX if use_x else EntrywiseU 128 | self.order2_i = entry_wise(*args_order2) 129 | self.order2_j = entry_wise(*args_order2) 130 | self.order2 = entry_wise(*args_order2) 131 | 132 | self.update1 = nn.Linear(2 * out_features, out_features) 133 | self.update2 = nn.Linear(out_features, out_features) 134 | 135 | def forward(self, u, edge_index, edge_attr, batch_info): 136 | n = batch_info['num_nodes'] 137 | u = self.message_nn(u, batch_info) 138 | u1 = self.order2_i(u) 139 | u2 = self.order2_j(u) 140 | new_u = self.propagate(edge_index, size=(n, n), u=u, u1=u1, u2=u2, edge_attr=edge_attr) 141 | new_u /= batch_info['average_edges'][:, :, 0] if self.use_x else batch_info['average_edges'] 142 | return new_u 143 | 144 | def message(self, u_j, u1_i, u2_j, edge_attr): 145 | edge_feat = self.edge_nn(edge_attr) if self.use_edge_features else 0 146 | if not self.use_x: 147 | edge_feat = edge_feat.unsqueeze(1) 148 | order2 = self.order2(torch.relu(u1_i + u2_j + edge_feat)) 149 | u_j = u_j + order2 150 | return u_j 151 | 152 | def update(self, aggr_u, u): 153 | up1 = self.update1(torch.cat((u, aggr_u), dim=-1)) 154 | up2 = up1 + self.update2(up1) 155 | return up2 + u 156 | -------------------------------------------------------------------------------- /models/utils/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor as Tensor 4 | from torch.nn import Linear as Linear 5 | import torch.nn.init as init 6 | from torch.nn.init import _calculate_correct_fan, calculate_gain 7 | import torch.nn.functional as F 8 | from torch_geometric.nn import global_mean_pool, global_add_pool, global_max_pool, MessagePassing 9 | import math 10 | 11 | small_gain = 0.01 12 | 13 | 14 | def pooling(x: torch.Tensor, batch_info, method): 15 | if method == 'add': 16 | return global_add_pool(x, batch_info['batch'], batch_info['num_graphs']) 17 | elif method == 'mean': 18 | return global_mean_pool(x, batch_info['batch'], batch_info['num_graphs']) 19 | elif method == 'max': 20 | return global_max_pool(x, batch_info['batch'], batch_info['num_graphs']) 21 | else: 22 | raise ValueError("Pooling method not implemented") 23 | 24 | 25 | def kaiming_init_with_gain(x: Tensor, gain: float, a=0, mode='fan_in', nonlinearity='relu'): 26 | fan = _calculate_correct_fan(x, mode) 27 | non_linearity_gain = calculate_gain(nonlinearity, a) 28 | std = non_linearity_gain / math.sqrt(fan) 29 | bound = math.sqrt(3.0) * std * gain # Calculate uniform bounds from standard deviation 30 | with torch.no_grad(): 31 | return x.uniform_(-bound, bound) 32 | 33 | 34 | class BatchNorm(nn.Module): 35 | def __init__(self, 
channels: int, use_x: bool): 36 | super().__init__() 37 | self.bn = nn.BatchNorm1d(channels) 38 | self.use_x = use_x 39 | 40 | def reset_parameters(self): 41 | self.bn.reset_parameters() 42 | 43 | def forward(self, u): 44 | if self.use_x: 45 | return self.bn(u) 46 | else: 47 | return self.bn(u.transpose(1, 2)).transpose(1, 2) 48 | 49 | 50 | class EdgeCounter(MessagePassing): 51 | def __init__(self): 52 | super().__init__(aggr='add') 53 | 54 | def forward(self, x, edge_index, batch, batch_size): 55 | n_edges = self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x) 56 | return global_mean_pool(n_edges, batch, batch_size)[batch] 57 | 58 | 59 | class Linear(nn.Module): 60 | """ Linear layer with potentially smaller parameters at initialization. """ 61 | __constants__ = ['bias', 'in_features', 'out_features'] 62 | 63 | def __init__(self, in_features, out_features, bias=True, gain: float = 1.0): 64 | super().__init__() 65 | self.gain = gain 66 | self.lin = nn.Linear(in_features, out_features, bias) 67 | 68 | def reset_parameters(self): 69 | kaiming_init_with_gain(self.lin.weight, self.gain) 70 | if self.lin.bias is not None: 71 | nn.init.normal_(self.lin.bias, 0, self.gain / math.sqrt(self.lin.out_features)) 72 | 73 | def forward(self, x): 74 | return self.lin.forward(x) 75 | 76 | 77 | class XtoX(Linear): 78 | def forward(self, x, batch_info: dict = None): 79 | return self.lin.forward(x) 80 | 81 | 82 | class XtoGlobal(Linear): 83 | def forward(self, x: Tensor, batch_info: dict, method='mean'): 84 | """ x: (num_nodes, in_features). """ 85 | g = pooling(x, batch_info, method) # bs, N, in_feat or bs, in_feat 86 | return self.lin.forward(g) 87 | 88 | 89 | class EntrywiseU(nn.Module): 90 | def __init__(self, in_features: int, out_features: int, num_towers=None): 91 | super().__init__() 92 | if num_towers is None: 93 | num_towers = in_features 94 | self.lin1 = torch.nn.Conv1d(in_features, out_features, kernel_size=1, groups=num_towers, bias=False) 95 | 96 | def forward(self, u): 97 | """ u: N x colors x channels. """ 98 | u = u.transpose(1, 2) 99 | u = self.lin1(u) 100 | return u.transpose(1, 2) 101 | 102 | 103 | class EntryWiseX(nn.Module): 104 | def __init__(self, in_features: int, out_features: int, n_groups=None, residual=False): 105 | super().__init__() 106 | self.residual = residual 107 | if n_groups is None: 108 | n_groups = in_features 109 | self.lin1 = torch.nn.Conv1d(in_features, out_features, kernel_size=1, groups=n_groups, bias=False) 110 | 111 | def forward(self, x, batch_info=None): 112 | """ x: N x channels. 
""" 113 | new_x = self.lin1(x.unsqueeze(-1)).squeeze() 114 | return (new_x + x) if self.residual else new_x 115 | 116 | class UtoU(nn.Module): 117 | def __init__(self, in_features: int, out_features: int, residual=True, n_groups=None): 118 | super().__init__() 119 | if n_groups is None: 120 | n_groups = 1 121 | self.residual = residual 122 | self.lin1 = torch.nn.Conv1d(in_features, out_features, kernel_size=1, groups=n_groups, bias=True) 123 | self.lin2 = torch.nn.Conv1d(in_features, out_features, kernel_size=1, groups=n_groups, bias=False) 124 | self.lin3 = torch.nn.Conv1d(in_features, out_features, kernel_size=1, groups=n_groups, bias=False) 125 | 126 | def forward(self, u: Tensor, batch_info: dict = None): 127 | """ U: N x n_colors x channels""" 128 | old_u = u 129 | n = batch_info['num_nodes'] 130 | num_colors = u.shape[1] 131 | out_feat = self.lin1.out_channels 132 | 133 | mask = batch_info['mask'][..., None].expand(n, num_colors, out_feat) 134 | normalizer = batch_info['n_batch'] 135 | mean2 = torch.sum(u / normalizer, dim=1) # N, in_feat 136 | mean2 = mean2.unsqueeze(-1) # N, in_feat, 1 137 | # 1. Transform u element-wise 138 | u = u.permute(0, 2, 1) # In conv1d, channel dimension is second 139 | out = self.lin1(u).permute(0, 2, 1) 140 | 141 | # 2. Put in self of each line the sum over each line 142 | # The 0.1 factor is here to bias the network in favor of learning powers of the adjacency 143 | z2 = self.lin2(mean2) * 0.1 # N, out_feat, 1 144 | z2 = z2.transpose(1, 2) # N, 1, out_feat 145 | index_tensor = batch_info['coloring'][:, :, None].expand(out.shape[0], 1, out_feat) 146 | out.scatter_add_(1, index_tensor, z2) # n, n_colors, out_feat 147 | 148 | # 3. Put everywhere the sum over each line 149 | z3 = self.lin3(mean2) # N, out_feat, 1 150 | z3 = z3.transpose(1, 2) # N, 1, out_feat 151 | out3 = z3.expand(n, num_colors, out_feat) 152 | out += out3 * mask * 0.1 # Mask the extra colors 153 | if self.residual: 154 | return old_u + out 155 | return out 156 | 157 | 158 | class UtoGlobal(nn.Module): 159 | def __init__(self, in_features: int , out_features: int, bias: bool, gain: float): 160 | super().__init__() 161 | self.lin1 = Linear(in_features, out_features, bias, gain=gain) 162 | self.lin2 = Linear(in_features, out_features, bias, gain=gain) 163 | 164 | def reset_parameters(self): 165 | for layer in [self.lin1, self.lin2]: 166 | layer.reset_parameters() 167 | 168 | def forward(self, u, batch_info: dict, method='mean'): 169 | """ u: (num_nodes, colors, in_features) 170 | output: (batch_size, out_features). 
""" 171 | coloring = batch_info['coloring'] 172 | # Extract trace 173 | index_tensor = coloring[:, :, None].expand(u.shape[0], 1, u.shape[2]) 174 | extended_diag = u.gather(1, index_tensor)[:, 0, :] # n_nodes, in_feat 175 | mean_batch_trace = pooling(extended_diag, batch_info, 'mean') # n_graphs, in_feat 176 | out1 = self.lin1(mean_batch_trace) # bs, out_feat 177 | # Extract sum of elements - trace 178 | mean = torch.sum(u / batch_info['n_batch'], dim=1) # num_nodes, in_feat 179 | batch_sum = pooling(mean, batch_info, 'mean') # n_graphs, in_feat 180 | batch_sum = batch_sum - mean_batch_trace # make the basis orthogonal 181 | out2 = self.lin2(batch_sum) # bs, out_feat 182 | return out1 + out2 183 | 184 | 185 | class NodeExtractor(nn.Module): 186 | def __init__(self, in_features_u: int, out_features_u: int): 187 | super().__init__() 188 | # Extract from U with a Deep set 189 | self.lin1_u = nn.Linear(in_features_u, in_features_u) 190 | self.lin2_u = nn.Linear(in_features_u, in_features_u) 191 | self.combine1 = nn.Linear(3 * in_features_u, out_features_u) 192 | 193 | def forward(self, x: Tensor, u: Tensor, batch_info: dict): 194 | """ u: (num_nodes, num_nodes, in_features). 195 | output: (num_nodes, out_feat). 196 | this method can probably be made more efficient. 197 | """ 198 | # Extract u 199 | new_u = self.lin2_u(torch.relu(self.lin1_u(u))) 200 | # Aggregation 201 | # a. Extract the value in self 202 | index_tensor = batch_info['coloring'][:, :, None].expand(u.shape[0], 1, u.shape[-1]) 203 | x1 = torch.gather(new_u, 1, index_tensor) 204 | x1 = x1[:, 0, :] 205 | # b. Mean over the line 206 | x2 = torch.sum(new_u / batch_info['n_batch'], dim=1) # num_nodes x in_feat 207 | # c. Max over the line 208 | x3 = torch.max(new_u, dim=1)[0] # num_nodes x out_feat 209 | # Combine 210 | x_full = torch.cat((x1, x2, x3), dim=1) 211 | out = self.combine1(x_full) 212 | return out 213 | 214 | 215 | class GraphExtractor(nn.Module): 216 | def __init__(self, in_features: int, out_features: int, use_x: bool, simplified=False): 217 | super().__init__() 218 | self.use_x, self.simplified = use_x, simplified 219 | self.extractor = (XtoGlobal if self.use_x else UtoGlobal)(in_features, out_features, True, 1) 220 | self.lin = nn.Linear(out_features, out_features) 221 | 222 | def reset_parameters(self): 223 | for layer in [self.extractor, self.lin]: 224 | layer.reset_parameters() 225 | 226 | def forward(self, u: Tensor, batch_info: dict): 227 | out = self.extractor(u, batch_info) 228 | if self.simplified: 229 | return out 230 | out = out + self.lin(F.relu(out)) 231 | return out 232 | -------------------------------------------------------------------------------- /models/utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def create_batch_info(data, edge_counter): 5 | """ Compute some information about the batch that will be used by SMP.""" 6 | x, edge_index, batch, batch_size = data.x, data.edge_index, data.batch, data.num_graphs 7 | 8 | # Compute some information about the batch 9 | # Count the number of nodes in each graph 10 | unique, n_per_graph = torch.unique(data.batch, return_counts=True) 11 | n_batch = torch.zeros_like(batch, dtype=torch.float) 12 | 13 | for value, n in zip(unique, n_per_graph): 14 | n_batch[batch == value] = n.float() 15 | 16 | # Count the average number of edges per graph 17 | dummy = x.new_ones((data.num_nodes, 1)) 18 | average_edges = edge_counter(dummy, edge_index, batch, batch_size) 19 | 20 | # Create the coloring 
if it does not exist yet 21 | if not hasattr(data, 'coloring'): 22 | data.coloring = data.x.new_zeros(data.num_nodes, dtype=torch.long) 23 | for i in range(data.num_graphs): 24 | data.coloring[data.batch == i] = torch.arange(n_per_graph[i], device=data.x.device) 25 | data.coloring = data.coloring[:, None] 26 | n_colors = torch.max(data.coloring) + 1 # Indexing starts at 0 27 | 28 | mask = torch.zeros(data.num_nodes, n_colors, dtype=torch.bool, device=x.device) 29 | for value, n in zip(unique, n_per_graph): 30 | mask[batch == value, :n] = True 31 | 32 | # Aggregate into a dict 33 | batch_info = {'num_nodes': data.num_nodes, 34 | 'num_graphs': data.num_graphs, 35 | 'batch': data.batch, 36 | 'n_per_graph': n_per_graph, 37 | 'n_batch': n_batch[:, None, None].float(), 38 | 'average_edges': average_edges[:, :, None], 39 | 'coloring': data.coloring, 40 | 'n_colors': n_colors, 41 | 'mask': mask # Used because of batching - it tells which entries of u are not used by the graph 42 | } 43 | return batch_info 44 | 45 | 46 | def map_x_to_u(data, batch_info): 47 | """ map the node features to the right row of the initial local context.""" 48 | x = data.x 49 | u = x.new_zeros((data.num_nodes, batch_info['n_colors'])) 50 | u.scatter_(1, data.coloring, 1) 51 | u = u[..., None] 52 | 53 | u_x = u.new_zeros((u.shape[0], u.shape[1], x.shape[1])) 54 | 55 | n_features = x.shape[1] 56 | coloring = batch_info['coloring'] # N x 1 57 | expanded_colors = coloring[..., None].expand(-1, -1, n_features) 58 | 59 | u_x = u_x.scatter_(dim=1, index=expanded_colors, src=x[:, None, :]) 60 | 61 | u = torch.cat((u, u_x), dim=2) 62 | return u 63 | 64 | -------------------------------------------------------------------------------- /models/utils/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import to_dense_adj 3 | import networkx as nx 4 | import torch_geometric 5 | from torch_geometric.data import Data 6 | from networkx.algorithms.shortest_paths.unweighted import all_pairs_shortest_path_length 7 | from networkx.algorithms.coloring import greedy_color 8 | import numpy as np 9 | from sklearn.preprocessing import OneHotEncoder 10 | 11 | 12 | class EyeTransform(object): 13 | def __init__(self, max_num_nodes): 14 | self.max_num_nodes = max_num_nodes 15 | 16 | def __call__(self, data): 17 | n = data.x.shape[0] 18 | data.x = torch.eye(n, self.max_num_nodes, dtype=torch.float) 19 | return data 20 | 21 | def __repr__(self): 22 | return str(self.__class__.__name__) 23 | 24 | 25 | class RandomId(object): 26 | r"""Adds the node degree as one hot encodings to the node features. 27 | 28 | Args: 29 | max_degree (int): Maximum degree. 30 | in_degree (bool, optional): If set to :obj:`True`, will compute the 31 | in-degree of nodes instead of the out-degree. 32 | (default: :obj:`False`) 33 | cat (bool, optional): Concat node degrees to node features instead 34 | of replacing them. 
(default: :obj:`True`) 35 | """ 36 | def __init__(self): 37 | pass 38 | def __call__(self, data): 39 | n = data.x.shape[0] 40 | data.x = torch.randint(0, 100, (n, 1), dtype=torch.float) / 100 41 | # data.x = torch.randn(n, self.embedding_size, dtype=torch.float) 42 | return data 43 | 44 | def __repr__(self): 45 | return str(self.__class__.__name__) 46 | 47 | 48 | class DenseAdjMatrix(object): 49 | def __init__(self, n: int): 50 | """ n: number of nodes in the graph (should be constant)""" 51 | self.n = n 52 | 53 | def __call__(self, data): 54 | batch = data.edge_index.new_zeros(self.n) 55 | data.A = to_dense_adj(data.edge_index, batch) 56 | return data 57 | 58 | def __repr__(self): 59 | return str(self.__class__.__name__) 60 | 61 | 62 | class KHopColoringTransform(object): 63 | def __init__(self, k: int): 64 | self.k = k 65 | 66 | def __call__(self, data): 67 | """ Compute a coloring such that no node sees twice the same color in its k-hop neighbourhood.""" 68 | k = self.k 69 | g = torch_geometric.utils.to_networkx(data, to_undirected=True, remove_self_loops=True) 70 | lengths = all_pairs_shortest_path_length(g, cutoff=2 * k) 71 | lengths = [l for l in lengths] 72 | # Graph where 2k hop neighbors are connected 73 | k_hop_graph = nx.Graph() 74 | for lengths_tuple in lengths: 75 | origin = lengths_tuple[0] 76 | edges = [(origin, dest) for dest in lengths_tuple[1].keys()] 77 | k_hop_graph.add_edges_from(edges) 78 | # Color the k-hop graph 79 | best_n_colors = np.infty 80 | best_color_dict = None 81 | # for strategy in ['largest_first', 'random_sequential', 'saturation_largest_first']: 82 | for strategy in ['largest_first']: 83 | color_dict = greedy_color(k_hop_graph, strategy) 84 | n_colors = np.max([color for color in color_dict.values()]) + 1 85 | if n_colors < best_n_colors: 86 | best_n_colors = n_colors 87 | best_color_dict = color_dict 88 | # Convert back to torch-geometric. 
The coloring is contained in data.x 89 | data.coloring = torch.zeros((data.num_nodes, 1), dtype=torch.long) 90 | for key, val in best_color_dict.items(): 91 | data.coloring[key] = val 92 | print('Number of nodes: {} - Number of colors: {}'.format(data.num_nodes, data.coloring.max() + 1)) 93 | return data 94 | 95 | def __repr__(self): 96 | return '{}({})'.format(self.__class__.__name__, self.k) 97 | 98 | 99 | class OneHotNodeEdgeFeatures(object): 100 | def __init__(self, node_types, edge_types): 101 | self.c = node_types 102 | self.d = edge_types 103 | 104 | def __call__(self, data): 105 | n = data.x.shape[0] 106 | node_encoded = torch.zeros((n, self.c), dtype=torch.float32) 107 | node_encoded.scatter_(1, data.x.long(), 1) 108 | data.x = node_encoded 109 | e = data.edge_attr.shape[0] 110 | edge_encoded = torch.zeros((e, self.d), dtype=torch.float32) 111 | edge_attr = (data.edge_attr - 1).long().unsqueeze(-1) 112 | edge_encoded.scatter_(1, edge_attr, 1) 113 | data.edge_attr = edge_encoded 114 | return data 115 | 116 | def __repr__(self): 117 | return str(self.__class__.__name__) -------------------------------------------------------------------------------- /multi_task_main.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | import yaml 4 | from multi_task_utils.train import execute_train, build_arg_parser 5 | 6 | # Training settings 7 | parser = build_arg_parser() 8 | parser.add_argument('--wandb', action='store_true') 9 | parser.add_argument('--batch-size', type=int, default=16) 10 | parser.add_argument('--clip', type=float, default=5) 11 | parser.add_argument('--name', type=str, help="name for weights and biases") 12 | parser.add_argument('--debug', action='store_true') 13 | parser.add_argument('--load-from-epoch', type=int, default=-1) 14 | args = parser.parse_args() 15 | 16 | yaml_file = 'config_multi_task.yaml' 17 | with open(yaml_file) as f: 18 | model_config = yaml.load(f, Loader=yaml.FullLoader) 19 | print(model_config) 20 | 21 | model_name = model_config['model_name'] 22 | model_config.pop('model_name') 23 | print("Model name:", model_name) 24 | 25 | if args.wandb or args.name: 26 | import wandb 27 | args.wandb = True 28 | if args.name is None: 29 | args.name = model_name + f'_{args.k}_{args.n}' 30 | wandb.init(project="pna_v2", config=model_config, name=args.name) 31 | wandb.config.update(args) 32 | 33 | execute_train(gnn_args=model_config, args=args) 34 | -------------------------------------------------------------------------------- /multi_task_utils/train.py: -------------------------------------------------------------------------------- 1 | # This file was adapted from https://github.com/lukecavabarrett/pna 2 | 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import argparse 7 | import os 8 | import sys 9 | import time 10 | from types import SimpleNamespace 11 | import wandb 12 | import numpy as np 13 | import torch 14 | import torch.optim as optim 15 | import numpy.random as npr 16 | from torch_geometric.data import DataLoader 17 | from models.model_multi_task import SMP 18 | from multi_task_utils.util import load_dataset, to_torch_geom, specific_loss_torch_geom 19 | 20 | log_loss_tasks = ["log_shortest_path", "log_eccentricity", "log_laplacian", 21 | "log_connected", "log_diameter", "log_radius"] 22 | 23 | 24 | def build_arg_parser(): 25 | """ 26 | :return: argparse.ArgumentParser() filled with the standard arguments for a 
training session. 27 | Might need to be enhanced for some models. 28 | """ 29 | parser = argparse.ArgumentParser() 30 | 31 | parser.add_argument('--data', type=str, default='./data/multitask_dataset.pkl', help='Data path.') 32 | parser.add_argument('--gpu', type=int, help='Id of the GPU') 33 | parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA training.') 34 | parser.add_argument('--only_nodes', action='store_true', default=False, help='Evaluate only nodes labels.') 35 | parser.add_argument('--only_graph', action='store_true', default=False, help='Evaluate only graph labels.') 36 | parser.add_argument('--seed', type=int, default=42, help='Random seed.') 37 | parser.add_argument('--epochs', type=int, default=3000, help='Number of epochs to train.') 38 | parser.add_argument('--lr', type=float, default=0.001, help='Initial learning rate.') 39 | parser.add_argument('--weight_decay', type=float, default=1e-6, help='Weight decay (L2 loss on parameters).') 40 | parser.add_argument('--patience', type=int, default=1000, help='Patience') 41 | parser.add_argument('--loss', type=str, default='mse', help='Loss function to use.') 42 | parser.add_argument('--print_every', type=int, default=5, help='Print training results every') 43 | return parser 44 | 45 | 46 | def execute_train(gnn_args, args): 47 | """ 48 | :param gnn_args: the description of the model to be trained (expressed as arguments for GNN.__init__) 49 | :param args: the parameters of the training session 50 | """ 51 | if not os.path.isdir('./saved_models'): 52 | os.mkdir('./saved_models') 53 | if args.name is not None: 54 | save_dir = f'./saved_models/{args.name}' 55 | else: 56 | save_dir = f'./saved_models/' 57 | if args.name is not None and not os.path.isdir(save_dir): 58 | os.mkdir(save_dir) 59 | 60 | use_cuda = args.gpu is not None and torch.cuda.is_available() and not args.no_cuda 61 | if use_cuda: 62 | device = torch.device("cuda:" + str(args.gpu)) 63 | torch.cuda.set_device(args.gpu) 64 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 65 | else: 66 | device = "cpu" 67 | args.device = device 68 | args.kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 69 | print('Using device:', device) 70 | 71 | np.random.seed(args.seed) 72 | torch.manual_seed(args.seed) 73 | if use_cuda: 74 | torch.cuda.manual_seed(args.seed) 75 | 76 | # load data 77 | adj, features, node_labels, graph_labels = load_dataset(args.data, args.loss, args.only_nodes, args.only_graph, 78 | print_baseline=True) 79 | print("Processing torch geometric data") 80 | graphs = to_torch_geom(adj, features, node_labels, graph_labels, device, args.debug) 81 | train_loaders = [DataLoader(given_size, args.batch_size, shuffle=True) for given_size in graphs['train']] 82 | batch_sizes = {'train': args.batch_size, 'val': 128, 'test': 256} 83 | val_loaders = [DataLoader(given_size, 128) for given_size in graphs['val']] 84 | test_loaders = [DataLoader(given_size, 256) for given_size in graphs['test']] 85 | print("Data loaders created") 86 | # model and optimizer 87 | gnn_args = SimpleNamespace(**gnn_args) 88 | 89 | gnn_args.num_input_features = features['train'][0].shape[2] 90 | gnn_args.nodes_out = 3 91 | gnn_args.graph_out = 3 92 | model = SMP(**vars(gnn_args)).to(device) 93 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 94 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 95 | step_size=50, 96 | gamma=0.92) 97 | 98 | pytorch_total_params = sum(p.numel() for p in model.parameters() if 
p.requires_grad) 99 | print("Total params", pytorch_total_params) 100 | 101 | if args.load_from_epoch != -1: 102 | checkpoint = torch.load(os.path.join(save_dir, f'{args.load_from_epoch}.pkl')) 103 | model.load_state_dict(checkpoint['model_state_dict']) 104 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 105 | epoch = checkpoint['epoch'] 106 | else: 107 | epoch = 0 108 | 109 | 110 | def train(epoch): 111 | """ Execute a single epoch of the training loop 112 | epoch (int): the number of the epoch being performed (0-indexed).""" 113 | t = time.time() 114 | 115 | # 1. Train 116 | nan_counts = 0 117 | model.train() 118 | total_train_loss_per_task = 0 119 | npr.shuffle(train_loaders) 120 | for i, loader in enumerate(train_loaders): 121 | for j, data in enumerate(loader): 122 | # Optimization 123 | optimizer.zero_grad() 124 | output = model(data.to(device)) 125 | train_loss_per_task = specific_loss_torch_geom(output, (data.pos, data.y), data.batch, args.batch_size) 126 | loss_train = torch.mean(train_loss_per_task) 127 | if torch.isnan(loss_train): 128 | print(f"Warning: loss was nan at epoch {epoch} and batch {i}{j}.") 129 | nan_counts += 1 130 | if nan_counts < 20: 131 | continue 132 | else: 133 | raise ValueError(f"Too many NaNs. Stopping training at epoch {epoch}. Best epoch: {best_epoch}") 134 | loss_train.backward() 135 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 136 | optimizer.step() 137 | # Compute metrics 138 | total_train_loss_per_task += train_loss_per_task / len(loader) 139 | total_train_loss_per_task /= len(train_loaders) 140 | train_log_loss_per_task = torch.log10(total_train_loss_per_task).data.cpu().numpy() 141 | train_loss = torch.mean(total_train_loss_per_task).data.item() 142 | 143 | # validation epoch 144 | model.eval() 145 | val_loss_per_task = 0 146 | for loader in val_loaders: 147 | for i, data in enumerate(loader): 148 | if i > 0: 149 | print("Warning: not all the batch was loaded at once. It will lead to incorrect results.") 150 | output = model(data.to(device)) 151 | batch_loss_per_task = specific_loss_torch_geom(output, (data.pos, data.y), data.batch, batch_sizes['val']) 152 | val_loss_per_task += batch_loss_per_task.detach() / len(val_loaders) 153 | 154 | val_log_loss_per_task = torch.log10(val_loss_per_task).data.cpu().numpy() 155 | val_log_loss = torch.mean(val_loss_per_task).item() 156 | 157 | if epoch % args.print_every == 0: 158 | print('Epoch: {:04d}'.format(epoch + 1), 159 | 'loss.train: {:.4f}'.format(train_loss), 160 | 'log.loss.val: {:.4f}'.format(val_log_loss), 161 | 'time: {:.4f}s'.format(time.time() - t)) 162 | print(f'train loss per task (log10 scale): {train_log_loss_per_task}') 163 | print(f'val loss per task (log10 scale): {val_log_loss_per_task}') 164 | sys.stdout.flush() 165 | if args.wandb: 166 | wandb_dict = {"Epoch": epoch, "Duration": time.time() - t, "Train loss": train_loss, 167 | "Val log loss": val_log_loss} 168 | for loss, tr, val in zip(log_loss_tasks, train_log_loss_per_task, val_log_loss_per_task): 169 | wandb_dict[loss + 'tr'] = tr 170 | wandb_dict[loss + 'val'] = val 171 | wandb.log(wandb_dict) 172 | 173 | return val_log_loss 174 | 175 | def compute_test(): 176 | """ 177 | Evaluate the current model on all the sets of the dataset, printing results. 178 | This procedure is destructive on datasets. 
179 | """ 180 | model.eval() 181 | sets = list(features.keys()) 182 | for dset, loaders in zip(sets, [train_loaders, val_loaders, test_loaders]): 183 | final_specific_loss = 0 184 | final_total_loss = 0 185 | for loader in loaders: 186 | loader_total_loss = 0 187 | loader_specific_loss = 0 188 | for data in loader: 189 | output = model(data.to(device)) 190 | specific_loss = specific_loss_torch_geom(output, (data.pos, data.y), 191 | data.batch, batch_sizes[dset]).detach() 192 | loader_specific_loss += specific_loss 193 | loader_total_loss += torch.mean(specific_loss) 194 | # Average the loss over each loader 195 | loader_specific_loss /= len(loader) 196 | loader_total_loss /= len(loader) 197 | # Average the loss over the different loaders 198 | final_specific_loss += loader_specific_loss / len(loaders) 199 | final_total_loss += loader_total_loss / len(loaders) 200 | del output, loader_specific_loss 201 | 202 | print("Test set results ", dset, ": loss= {:.4f}".format(final_total_loss)) 203 | print(dset, ": ", final_specific_loss) 204 | print("Results in log scale", np.log10(final_specific_loss.detach().cpu()), 205 | np.log10(final_total_loss.detach().cpu().numpy())) 206 | if args.wandb: 207 | wandb.run.summary["test results"] = np.log10(final_specific_loss.detach().cpu()) 208 | # free unnecessary data 209 | 210 | 211 | final_specific_numpy = np.log10(final_specific_loss.detach().cpu()) 212 | del final_total_loss, final_specific_loss 213 | torch.cuda.empty_cache() 214 | return final_specific_numpy 215 | 216 | sys.stdout.flush() 217 | # Train model 218 | t_total = time.time() 219 | loss_values = [] 220 | bad_counter = 0 221 | best = args.epochs + 1 222 | best_epoch = -1 223 | 224 | sys.stdout.flush() 225 | 226 | while epoch < args.epochs: 227 | epoch += 1 228 | 229 | loss_values.append(train(epoch)) 230 | scheduler.step() 231 | if epoch % 100 == 0: 232 | print("Results on the test set:") 233 | results_test = compute_test() 234 | print('Test set results', results_test) 235 | print(f"Saving checkpoint at epoch {epoch}") 236 | torch.save({ 237 | 'epoch': epoch, 238 | 'model_state_dict': model.state_dict(), 239 | 'optimizer_state_dict': optimizer.state_dict(), 240 | }, os.path.join(save_dir, f'{epoch}.pkl')) 241 | 242 | if loss_values[-1] < best: 243 | # save current model 244 | if loss_values[-1] < best: 245 | print(f"New best validation error at epoch {epoch}") 246 | else: 247 | print(f"Saving checkpoint at epoch {epoch}") 248 | torch.save({ 249 | 'epoch': epoch, 250 | 'model_state_dict': model.state_dict(), 251 | 'optimizer_state_dict': optimizer.state_dict(), 252 | }, os.path.join(save_dir, f'{epoch}.pkl')) 253 | # remove previous model 254 | if best_epoch >= 0: 255 | f_name = os.path.join(save_dir, f'{best_epoch}.pkl') 256 | if os.path.isfile(f_name): 257 | os.remove(f_name) 258 | # update training variables 259 | best = loss_values[-1] 260 | best_epoch = epoch 261 | bad_counter = 0 262 | else: 263 | bad_counter += 1 264 | 265 | if bad_counter == args.patience: 266 | print('Early stop at epoch {} (no improvement in last {} epochs)'.format(epoch + 1, bad_counter)) 267 | break 268 | 269 | print("Optimization Finished!") 270 | print("Total time elapsed: {:.4f}s".format(time.time() - t_total)) 271 | 272 | # Restore best model 273 | print('Loading {}th epoch'.format(best_epoch + 1)) 274 | checkpoint = torch.load(os.path.join(save_dir, f'{best_epoch}.pkl')) 275 | model.load_state_dict(checkpoint['model_state_dict']) 276 | 277 | # Testing 278 | print("Results on the test set:") 279 | results_test = 
compute_test() 280 | print('Test set results', results_test) 281 | -------------------------------------------------------------------------------- /multi_task_utils/util.py: -------------------------------------------------------------------------------- 1 | # This file was adapted from https://github.com/lukecavabarrett/pna 2 | 3 | from __future__ import division 4 | from __future__ import print_function 5 | from torch_geometric import data 6 | from torch_geometric.utils import dense_to_sparse 7 | import pickle 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | from torch_geometric.nn import global_add_pool 12 | 13 | 14 | def to_torch_geom(adj, features, node_labels, graph_labels, device, debug): 15 | graphs = {} 16 | for key in adj.keys(): # train, val, test 17 | graphs[key] = [] 18 | for i in range(len(adj[key])): # Graph of a given size 19 | batch_i = [] 20 | for j in range(adj[key][i].shape[0]): # Number of graphs 21 | graph_adj = adj[key][i][j] 22 | graph = data.Data(x=features[key][i][j], 23 | edge_index=dense_to_sparse(graph_adj)[0], 24 | y=graph_labels[key][i][j].unsqueeze(0), 25 | pos=node_labels[key][i][j]) 26 | if not debug: 27 | batch_i.append(graph) 28 | if debug: 29 | batch_i.append(graph) 30 | graphs[key].append(batch_i) 31 | return graphs 32 | 33 | 34 | def load_dataset(data_path, loss, only_nodes, only_graph, print_baseline=True): 35 | with open(data_path, 'rb') as f: 36 | (adj, features, node_labels, graph_labels) = pickle.load(f) 37 | 38 | # normalize labels 39 | max_node_labels = torch.cat([nls.max(0)[0].max(0)[0].unsqueeze(0) for nls in node_labels['train']]).max(0)[0] 40 | max_graph_labels = torch.cat([gls.max(0)[0].unsqueeze(0) for gls in graph_labels['train']]).max(0)[0] 41 | for dset in node_labels.keys(): 42 | node_labels[dset] = [nls / max_node_labels for nls in node_labels[dset]] 43 | graph_labels[dset] = [gls / max_graph_labels for gls in graph_labels[dset]] 44 | 45 | if print_baseline: 46 | # calculate baseline 47 | mean_node_labels = torch.cat([nls.mean(0).mean(0).unsqueeze(0) for nls in node_labels['train']]).mean(0) 48 | mean_graph_labels = torch.cat([gls.mean(0).unsqueeze(0) for gls in graph_labels['train']]).mean(0) 49 | 50 | for dset in node_labels.keys(): 51 | if dset not in ['train', 'val']: 52 | baseline_nodes = [mean_node_labels.repeat(list(nls.shape[0:-1]) + [1]) for nls in node_labels[dset]] 53 | baseline_graph = [mean_graph_labels.repeat([gls.shape[0], 1]) for gls in graph_labels[dset]] 54 | 55 | print("Baseline loss ", dset, 56 | np.log10(specific_loss_multiple_batches((baseline_nodes, baseline_graph), 57 | (node_labels[dset], graph_labels[dset]), 58 | loss=loss, only_nodes=only_nodes, only_graph=only_graph))) 59 | 60 | return adj, features, node_labels, graph_labels 61 | 62 | 63 | SUPPORTED_ACTIVATION_MAP = {'ReLU', 'Sigmoid', 'Tanh', 'ELU', 'SELU', 'GLU', 'LeakyReLU', 'Softplus', 'None'} 64 | 65 | 66 | def get_activation(activation): 67 | """ returns the activation function represented by the input string """ 68 | if activation and callable(activation): 69 | # activation is already a function 70 | return activation 71 | # search in SUPPORTED_ACTIVATION_MAP a torch.nn.modules.activation 72 | activation = [x for x in SUPPORTED_ACTIVATION_MAP if activation.lower() == x.lower()] 73 | assert len(activation) == 1 and isinstance(activation[0], str), 'Unhandled activation function' 74 | activation = activation[0] 75 | if activation.lower() == 'none': 76 | return None 77 | return 
vars(torch.nn.modules.activation)[activation]() 78 | 79 | 80 | def get_loss(loss, output, target): 81 | if loss == "mse": 82 | return F.mse_loss(output, target) 83 | elif loss == "cross_entropy": 84 | if len(output.shape) > 2: 85 | (B, N, _) = output.shape 86 | output = output.reshape((B * N, -1)) 87 | target = target.reshape((B * N, -1)) 88 | _, target = target.max(dim=1) 89 | return F.cross_entropy(output, target) 90 | else: 91 | print("Error: loss function not supported") 92 | 93 | 94 | def specific_loss_torch_geom(output, target, batch, batch_size): 95 | """ output: list of len 2 containing node and graph outputs 96 | returns the average losses of each task """ 97 | average_nodes = output[0].shape[0] / batch_size # Average nb nodes in each graph 98 | # Node loss 99 | node_out = output[0] # N x 3 100 | loss = (node_out - target[0]) ** 2 101 | error = global_add_pool(loss, batch, batch_size) / average_nodes # N graphs x 3 102 | nodes_loss = torch.mean(error, dim=0) # 3 103 | graph_loss = torch.mean((output[1] - target[1]) ** 2, dim=0) # 3 104 | specific_loss = torch.cat((nodes_loss, graph_loss)) 105 | return specific_loss 106 | 107 | 108 | def total_loss_torch_geom(output, target, batch, batch_size): 109 | """ returns the average of the average losses of each task """ 110 | specific_loss = specific_loss_torch_geom(output, target, batch, batch_size) 111 | weighted_average = torch.mean(specific_loss) 112 | return weighted_average 113 | 114 | 115 | def total_loss(output, target, loss='mse', only_nodes=False, only_graph=False): 116 | """ returns the average of the average losses of each task """ 117 | assert not (only_nodes and only_graph) 118 | 119 | if only_nodes: 120 | nodes_loss = get_loss(loss, output[0], target[0]) 121 | return nodes_loss 122 | elif only_graph: 123 | graph_loss = get_loss(loss, output[1], target[1]) 124 | return graph_loss 125 | 126 | nodes_loss = get_loss(loss, output[0], target[0]) 127 | graph_loss = get_loss(loss, output[1], target[1]) 128 | weighted_average = (nodes_loss * output[0].shape[-1] + graph_loss * output[1].shape[-1]) / ( 129 | output[0].shape[-1] + output[1].shape[-1]) 130 | return weighted_average 131 | 132 | 133 | def total_loss_multiple_batches(output, target, loss='mse', only_nodes=False, only_graph=False): 134 | """ returns the average of the average losses of each task over all batches, 135 | batches are weighted equally regardless of their cardinality or graph size """ 136 | return sum([total_loss_torch_geom((output[0][batch], output[1][batch]), (target[0][batch], target[1][batch]), 137 | loss, only_nodes, only_graph).data.item() 138 | for batch in range(len(output[0]))]) / len(output[0]) 139 | 140 | 141 | def specific_loss(output, target, loss='mse', only_nodes=False, only_graph=False): 142 | """ returns the average loss for each task """ 143 | assert not (only_nodes and only_graph) 144 | 145 | if only_nodes: 146 | nodes_losses = [get_loss(loss, output[0][:, :, k], target[0][:, :, k]).item() for k in 147 | range(output[0].shape[-1])] 148 | return nodes_losses 149 | elif only_graph: 150 | graph_loss = [get_loss(loss, output[1][:, k], target[1][:, k]).item() for k in range(output[1].shape[-1])] 151 | return graph_loss 152 | 153 | nodes_losses = [get_loss(loss, output[0][:, :, k], target[0][:, :, k]).item() for k in range(output[0].shape[-1])] 154 | graph_loss = [get_loss(loss, output[1][:, k], target[1][:, k]).item() for k in range(output[1].shape[-1])] 155 | return nodes_losses + graph_loss 156 | 157 | 158 | def 
specific_loss_multiple_batches(output, target, loss='mse', only_nodes=False, only_graph=False): 159 | """ returns the average loss over all batches for each task, 160 | batches are weighted equally regardless of their cardinality or graph size """ 161 | assert not (only_nodes and only_graph) 162 | 163 | n_batches = len(output[0]) 164 | classes = (output[0][0].shape[-1] if not only_graph else 0) + (output[1][0].shape[-1] if not only_nodes else 0) 165 | 166 | sum_losses = [0] * classes 167 | for batch in range(n_batches): 168 | spec_loss = specific_loss((output[0][batch], output[1][batch]), (target[0][batch], target[1][batch]), loss, 169 | only_nodes, only_graph) 170 | for par in range(classes): 171 | sum_losses[par] += spec_loss[par] 172 | 173 | return [sum_loss / n_batches for sum_loss in sum_losses] 174 | 175 | 176 | def save_checkpoint(path, model, optimizer, epoch): 177 | torch.save({ 178 | 'epoch': epoch, 179 | 'model_state_dict': model.state_dict(), 180 | 'optimizer_state_dict': optimizer.state_dict()}, path) 181 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | yaml 3 | argparse 4 | numpy 5 | pickle 6 | networkx 7 | easydict 8 | wandb -------------------------------------------------------------------------------- /saved_models/PPGN_4/epoch0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvignac/SMP/95b55a880d0fc9149ddf32e8c2fdf5eac5b474b3/saved_models/PPGN_4/epoch0.pkl -------------------------------------------------------------------------------- /saved_models/ZINC/Zinc_SMP.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvignac/SMP/95b55a880d0fc9149ddf32e8c2fdf5eac5b474b3/saved_models/ZINC/Zinc_SMP.pkl -------------------------------------------------------------------------------- /zinc_main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import os 5 | import torch 6 | import torch.nn as nn 7 | from torch_geometric.data import DataLoader 8 | from torch_geometric.datasets import ZINC 9 | import argparse 10 | import numpy as np 11 | import time 12 | import yaml 13 | from models.model_zinc import SMPZinc 14 | from models.utils.transforms import OneHotNodeEdgeFeatures 15 | 16 | # Change the following to point to the the folder where the datasets are stored 17 | if os.path.isdir('/datasets2/'): 18 | rootdir = '/datasets2/ZINC/' 19 | else: 20 | rootdir = './data/ZINC/' 21 | yaml_file = './config_zinc.yaml' 22 | 23 | torch.manual_seed(0) 24 | np.random.seed(0) 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--epochs', type=int, default=3000) 28 | parser.add_argument('--wandb', action='store_true', 29 | help="Use weights and biases library") 30 | parser.add_argument('--gpu', type=int, help='Id of gpu device. 
By default use cpu') 31 | parser.add_argument('--lr', type=float, default=0.001, help="Initial learning rate") 32 | parser.add_argument('--batch-size', type=int, default=128) 33 | parser.add_argument('--weight-decay', type=float, default=1e-6) 34 | parser.add_argument('--clip', type=float, default=10, help="Gradient clipping") 35 | parser.add_argument('--name', type=str, help="Name for weights and biases") 36 | parser.add_argument('--full', action='store_true') 37 | parser.add_argument('--lr-reduce-factor', type=float, default=0.5) 38 | parser.add_argument('--lr_schedule_patience', type=int, default=100) 39 | parser.add_argument('--save-model', action='store_true', help='Save the model after training') 40 | parser.add_argument('--load-model', action='store_true', help='Evaluate a pretrained model') 41 | parser.add_argument('--lr-limit', type=float, default=5e-6, help='Stop training once it is reached') 42 | args = parser.parse_args() 43 | 44 | args.subset = not args.full # Train either on the full dataset or the subset of 10k samples 45 | 46 | # Handle the device 47 | use_cuda = args.gpu is not None and torch.cuda.is_available() 48 | if use_cuda: 49 | device = torch.device("cuda:" + str(args.gpu)) 50 | torch.cuda.set_device(args.gpu) 51 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 52 | else: 53 | device = "cpu" 54 | args.device = device 55 | args.kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 56 | print('Device used:', device) 57 | 58 | # Load the config file of the model 59 | with open(yaml_file) as f: 60 | model_config = yaml.load(f, Loader=yaml.FullLoader) 61 | print(model_config) 62 | model_config['num_input_features'] = 28 if model_config['use_x'] else 29 63 | model_config['num_edge_features'] = 3 64 | model_config['num_classes'] = 1 65 | 66 | 67 | # Create a name for weights and biases 68 | model_name = 'Zinc_SMP' 69 | if args.name: 70 | args.wandb = True 71 | if args.wandb: 72 | import wandb 73 | if args.name is None: 74 | args.name = model_name + \ 75 | f"_{model_config['num_layers']}_{model_config['hidden']}_{model_config['hidden_final']}" 76 | wandb.init(project="smp-zinc-subset" if args.subset else "smp-zinc", config=model_config, name=args.name) 77 | wandb.config.update(args) 78 | 79 | 80 | # The paths can be changed here 81 | if args.save_model or args.load_model: 82 | if os.path.isdir('/SCRATCH2/'): 83 | savedir = '/SCRATCH2/vignac/SMP/saved_models/ZINC/' 84 | else: 85 | savedir = './saved_models/ZINC/' 86 | if not os.path.isdir(savedir): 87 | os.makedirs(savedir) 88 | 89 | 90 | def train(): 91 | """ Train for one epoch. 
""" 92 | model.train() 93 | loss_all = 0 94 | for batch_idx, data in enumerate(train_loader): 95 | data = data.to(device) 96 | optimizer.zero_grad() 97 | output = model(data) 98 | loss = loss_fct(output, data.y) 99 | loss.backward() 100 | loss_all += loss.item() 101 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 102 | optimizer.step() 103 | return loss_all / len(train_loader.dataset) 104 | 105 | 106 | def test(loader): 107 | model.eval() 108 | total_mae = 0.0 109 | for data in loader: 110 | data = data.to(device) 111 | output = model(data) 112 | total_mae += loss_fct(output, data.y).item() 113 | average_mae = total_mae / len(loader.dataset) 114 | return average_mae 115 | 116 | 117 | start = time.time() 118 | 119 | model = SMPZinc(**model_config).to(device) 120 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 121 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', 122 | factor=args.lr_reduce_factor, 123 | patience=args.lr_schedule_patience, 124 | verbose=True) 125 | lr_limit = args.lr_limit 126 | 127 | if args.load_model: 128 | model = torch.load(savedir + model_name + '.pkl') 129 | 130 | pytorch_total_params = sum(p.numel() for p in model.parameters()) 131 | print("Total number of parameters", pytorch_total_params) 132 | 133 | loss_fct = nn.L1Loss(reduction='sum') 134 | 135 | # Load the data 136 | batch_size = args.batch_size 137 | transform = OneHotNodeEdgeFeatures(model_config['num_input_features'] - 1, model_config['num_edge_features']) 138 | 139 | train_data = ZINC(rootdir, subset=args.subset, split='train', pre_transform=transform) 140 | val_data = ZINC(rootdir, subset=args.subset, split='val', pre_transform=transform) 141 | test_data = ZINC(rootdir, subset=args.subset, split='test', pre_transform=transform) 142 | 143 | train_loader = DataLoader(train_data, batch_size, shuffle=True) 144 | val_loader = DataLoader(val_data, batch_size, shuffle=False) 145 | test_loader = DataLoader(test_data, batch_size, shuffle=False) 146 | 147 | print("Starting to train") 148 | for epoch in range(args.epochs): 149 | if args.load_model: 150 | break 151 | epoch_start = time.time() 152 | tr_loss = train() 153 | current_lr = optimizer.param_groups[0]["lr"] 154 | if current_lr < lr_limit: 155 | break 156 | duration = time.time() - epoch_start 157 | print(f'Time:{duration:2.2f} | {epoch:5d} | Train MAE: {tr_loss:2.5f} | LR: {current_lr:.6f}') 158 | mae_val = test(val_loader) 159 | scheduler.step(mae_val) 160 | print(f'MAE on the validation set: {mae_val:2.5f}') 161 | if args.wandb: 162 | wandb.log({"Epoch": epoch, "Duration": duration, "Train MAE": tr_loss, 163 | "Val MAE": mae_val}) 164 | 165 | if not args.load_model: 166 | cur_lr = optimizer.param_groups[0]["lr"] 167 | print(f'{epoch:2.5f} | Loss: {tr_loss:2.5f} | LR: {cur_lr:.6f} | Val MAE: {mae_val:2.5f}') 168 | print(f'Elapsed time: {(time.time() - start) / 60:.1f} minutes') 169 | print('done!') 170 | 171 | test_mae = test(test_loader) 172 | print(f"Final MAE on the test set: {test_mae}") 173 | print("Done.") 174 | 175 | if args.wandb: 176 | wandb.run.summary['Final test MAE'] = test_mae 177 | 178 | if args.save_model: 179 | torch.save(model, savedir + model_name + '.pkl') --------------------------------------------------------------------------------