├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── SMP.iml
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── other.xml
│   └── vcs.xml
├── LICENSE
├── README.md
├── config_cycles.yaml
├── config_multi_task.yaml
├── config_zinc.yaml
├── cycles_main.py
├── data
│   ├── .DS_Store
│   ├── datasets_kcycle_nsamples=10000
│   │   └── .DS_Store
│   └── multitask_dataset.pkl
├── datasets_generation
│   ├── __pycache__
│   │   ├── build_cycles.cpython-37.pyc
│   │   ├── graph_algorithms.cpython-37.pyc
│   │   ├── graph_generation.cpython-37.pyc
│   │   └── multitask_dataset.cpython-37.pyc
│   ├── build_cycles.py
│   ├── graph_algorithms.py
│   ├── graph_generation.py
│   └── multitask_dataset.py
├── models
│   ├── .DS_Store
│   ├── gin.py
│   ├── model_cycles.py
│   ├── model_multi_task.py
│   ├── model_zinc.py
│   ├── ppgn.py
│   ├── ring_gnn.py
│   ├── smp_layers.py
│   └── utils
│       ├── layers.py
│       ├── misc.py
│       └── transforms.py
├── multi_task_main.py
├── multi_task_utils
│   ├── train.py
│   └── util.py
├── requirements.txt
├── saved_models
│   ├── PPGN_4
│   │   └── epoch0.pkl
│   └── ZINC
│       └── Zinc_SMP.pkl
└── zinc_main.py
/.gitignore:
--------------------------------------------------------------------------------
1 | models/__pycache__/
2 | .idea/
3 | .DS_Store
4 | data/
5 | multi_task_utils/__pycache__/
6 | __pycache__/
7 | multi_task_utils/__pycache__/
8 | wandb/
9 | data/.DS_Store
10 | tests/
11 | saved_models/
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/.idea/SMP.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/other.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Gabriele Corso, Luca Cavalleri
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Building powerful and equivariant graph neural networks with structural message-passing
2 |
3 | This repository contains code for the paper *Building powerful and equivariant graph neural networks with structural message-passing* (NeurIPS 2020) by
4 | [Clément Vignac](https://cvignac.github.io/), [Andreas Loukas](https://andreasloukas.blog/) and [Pascal Frossard](https://www.epfl.ch/labs/lts4/people/people-current/frossard/).
5 | [Link to the paper](https://papers.nips.cc/paper/2020/file/a32d7eeaae19821fd9ce317f3ce952a7-Paper.pdf)
6 |
7 | Abstract:
8 |
9 | Message-passing has proved to be an effective way to design graph neural networks,
10 | as it is able to leverage both permutation equivariance and an inductive bias towards
11 | learning local structures in order to achieve good generalization. However, current
12 | message-passing architectures have a limited representation power and fail to learn
13 | basic topological properties of graphs. We address this problem and propose a
14 | powerful and equivariant message-passing framework based on two ideas: first,
15 | we propagate a one-hot encoding of the nodes, in addition to the features, in order
16 | to learn a local context matrix around each node. This matrix contains rich local
17 | information about both features and topology and can eventually be pooled to build
18 | node representations. Second, we propose methods for the parametrization of the
19 | message and update functions that ensure permutation equivariance. Having a
20 | representation that is independent of the specific choice of the one-hot encoding
21 | permits inductive reasoning and leads to better generalization properties. Experimentally,
22 | our model can predict various graph topological properties on synthetic
23 | data more accurately than previous methods and achieves state-of-the-art results on
24 | molecular graph regression on the ZINC dataset.
25 |
26 | ## Code overview
27 |
28 |
29 | This folder contains the source code used for Structural Message passing for three tasks:
30 | - Cycle detection
31 | - The multi-task regression of graph properties presented in [https://arxiv.org/abs/2004.05718](https://arxiv.org/abs/2004.05718)
32 | - Constrained solubility regression on ZINC
33 |
34 | Source code for the second task is adapted from [https://github.com/lukecavabarrett/pna](https://github.com/lukecavabarrett/pna).
35 |
36 |
37 | ## Dependencies
38 | [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/) v1.6.1 was used. Please follow the installation instructions on the
39 | website, as a plain pip install does not work. In particular, the installed PyTorch version must match the one that torch-geometric was built for.
40 |
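For instance, a quick way to check that the two libraries import together and report consistent versions (this snippet only uses standard attributes of both packages):

```
import torch
import torch_geometric

# The PyTorch build (and its CUDA version) must match the wheels used for
# torch-scatter / torch-sparse; if they do not, the import above already fails.
print(torch.__version__, torch.version.cuda)
print(torch_geometric.__version__)  # expected: 1.6.1
```
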
41 | Then install the other dependencies:
42 | ```
43 | pip install -r requirements.txt
44 | ```
45 |
46 | ## Dataset generation
47 |
48 | ### Cycle detection
49 | First, download the data from https://drive.switch.ch/index.php/s/hv65hmY48GrRAoN
50 | and unzip it in data/datasets_kcycle_nsamples=10000. Then, run
51 |
52 | ```
53 | python3 datasets_generation/build_cycles.py
54 | ```
55 |
56 | ### Multi-task regression
57 | Simply run
58 | ```
59 | python -m datasets_generation.multitask_dataset
60 | ```
61 |
62 | ### ZINC
63 | We use the PyTorch Geometric downloader, so there is nothing to do by hand.
64 | ## Folder structure
65 |
66 | - Each task is launched by running the corresponding *main* file (cycles_main, zinc_main, multi_task_main).
67 | - The model parameters can be changed in the associated config.yaml file, while training parameters are modified
68 | with command line arguments.
69 | - The model used for each task is located in the model folder (model_cycles,
70 | model_multi_task, model_zinc).
71 | - They all use some of the SMP layers parametrized in the smp_layers file.
72 | - All SMP layers use the same set of base functions in models/utils/layers.py. These functions map tensors of one order
73 | to tensors of another order using a predefined set of equivariant transformations (a small illustration follows below).
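
For intuition, here is a self-contained illustration of such an order-changing equivariant map. It is not taken from `models/utils/layers.py` (the function name and shapes are illustrative); it only shows the kind of operations those base functions are built from:

```
import torch

def order2_to_order1(u: torch.Tensor) -> torch.Tensor:
    """Map an order-2 tensor u of shape (n, n, channels) to an order-1 tensor
    of shape (n, 3 * channels) using permutation-equivariant operations."""
    diag = torch.diagonal(u, dim1=0, dim2=1).transpose(0, 1)  # (n, c): diagonal entries
    row_sum = u.sum(dim=1)                                    # (n, c): sum over the second node index
    col_sum = u.sum(dim=0)                                    # (n, c): sum over the first node index
    # Concatenating several such basis operations yields a richer equivariant map.
    return torch.cat([diag, row_sum, col_sum], dim=-1)

# Permuting the nodes (both node axes of u) permutes the output rows in the same way.
```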
74 |
75 | ## Train
76 |
77 | ### Cycle detection
78 |
79 | In order to train SMP, specify the cycle length, the size of the graphs to use, and optionally the proportion of the training data
80 | that is kept. For example,
81 | ```
82 | python3 cycles_main.py --k 4 --n 12 --proportion 1.0 --gpu 0
83 | ```
84 | will train SMP to detect 4-cycles on graphs with 12 nodes on average, using 1.0 * 100 = 100% of the training data.
85 |
86 | In order to run another model, modify `config_cycles.yaml`. To run an MPNN that has the
87 | same architecture as SMP, set `use_x: True` in this file.
88 |
89 | For MPNN and GIN, transforms can be specified in order to add a one-hot encoding of the node degrees,
90 | or one-hot identifiers. The available options can be seen by using
91 | ```
92 | python3 cycles_main.py --help
93 | ```
94 |
95 | ### Multi-task regression
96 |
97 | Specify the configuration in the file `config_multi_task.yaml`, and see the available options by using
98 | ```
99 | python3 multi_task_main.py --help
100 | ```
101 | To use default parameters, simply run:
102 | ```
103 | python3 multi_task_main.py --gpu 0
104 | ```
105 |
106 | ### ZINC
107 |
108 | The ZINC dataset is downloaded through pytorch geometric, but the destination folder should be specified at
109 | the beginning of `zinc_main.py`. Model parameters can be changed in `config_zinc.yaml`.
110 |
111 | To use default parameters, simply run:
112 | ```
113 | python3 zinc_main.py --gpu 0
114 | ```
115 |
116 | ## Use SMP on new data
117 |
118 | This code is currently not available as a library, so you will need to copy-paste files to adapt it to your
119 | own data.
120 | While most of the code can be reused, you may need to adapt the model to your own problem. We advise you to look at the
121 | different model files (model_cycles, model_multi_task, model_zinc) to see how they are built. They all follow the same
122 | design (a toy sketch is given after the list below):
123 | - A local context is first created using the functions in models.utils.misc. If you have node features that
124 | you wish to use in SMP, use `map_x_to_u` to include them in the local contexts.
125 | - One of the three SMP layers (SMP, FastSMP, SimplifiedFastSMP) is used at each layer to update the local context.
126 | Then either some node-level features or some graph-level features are extracted. For this purpose, you can use
127 | the `NodeExtractor` and `GraphExtractor` classes in `models.utils.layers.py`.
128 | - The extracted features are processed by a standard neural network. You can use a multi-layer perceptron here, or
129 | a more complex structure such as a Gated Recurrent Network that will take as input the features extracted at
130 | each layer.
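
To make these three steps concrete, here is a self-contained toy sketch written against plain PyTorch. It does not use the classes of this repository, and the layer names and shapes below are illustrative assumptions; for the real interface, see `models/model_zinc.py` and `models/smp_layers.py`.

```
import torch
import torch.nn as nn

class ToySMP(nn.Module):
    """Toy version of the SMP design: one-hot local context, equivariant
    updates, pooling to node features, then a small readout network."""
    def __init__(self, hidden: int = 16, out_dim: int = 1):
        super().__init__()
        self.hidden = hidden
        self.lin_msg = nn.Linear(hidden, hidden)   # transforms local contexts before aggregation
        self.lin_upd = nn.Linear(hidden, hidden)   # update after neighbourhood aggregation
        self.readout = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(),
                                     nn.Linear(hidden, out_dim))

    def forward(self, adj: torch.Tensor) -> torch.Tensor:
        # adj: (n, n) dense adjacency matrix of a single graph
        n = adj.shape[0]
        # Step 1: the local context u[i] is an (n, hidden) matrix per node;
        # channel 0 is initialised with a one-hot identifier of each node.
        u = torch.zeros(n, n, self.hidden)
        u[..., 0] = torch.eye(n)
        # Step 2: a few equivariant message-passing updates of the local context.
        for _ in range(2):
            msg = self.lin_msg(u)                        # (n, n, hidden)
            agg = torch.einsum('ij,jkh->ikh', adj, msg)  # aggregate neighbours' contexts
            u = torch.relu(self.lin_upd(agg) + u)
        # Step 3: pool the context to node features, then to one graph feature,
        # and process it with a standard network.
        node_feat = u.mean(dim=1)                        # (n, hidden)
        graph_feat = node_feat.mean(dim=0, keepdim=True)
        return self.readout(graph_feat)                  # (1, out_dim)

# Example usage on a random symmetric adjacency matrix:
# adj = torch.bernoulli(torch.full((6, 6), 0.3)); adj = ((adj + adj.t()) > 0).float()
# print(ToySMP()(adj))
```

In the real models, the update is performed by one of the SMP layers and the pooling by the extractor classes mentioned above.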
131 |
132 | To sum up, you need to copy the following files to your own folder:
133 | - models.smp_layers.py
134 | - models.utils.layers.py and models.utils.misc.py
135 |
136 | and to adapt the following files to your own problem:
137 | - the main file (e.g. zinc_main.py)
138 | - the config file (e.g. config_zinc.yaml)
139 | - the model file (e.g. models/model_zinc.py)
140 |
141 | We advise you to use the Weights & Biases (wandb) library as well, as we found it very convenient for storing results.
142 |
143 | ## License
144 | MIT
145 |
146 | ## Cite this paper
147 |
148 | @inproceedings{NEURIPS2020_a32d7eea,
149 | author = {Vignac, Cl\'{e}ment and Loukas, Andreas and Frossard, Pascal},
150 | booktitle = {Advances in Neural Information Processing Systems},
151 | editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
152 | pages = {14143--14155},
153 | publisher = {Curran Associates, Inc.},
154 | title = {Building powerful and equivariant graph neural networks with structural message-passing},
155 | url = {https://proceedings.neurips.cc/paper/2020/file/a32d7eeaae19821fd9ce317f3ce952a7-Paper.pdf},
156 | volume = {33},
157 | year = {2020}
158 | }
159 |
160 |
161 |
--------------------------------------------------------------------------------
/config_cycles.yaml:
--------------------------------------------------------------------------------
1 | # Model properties
2 | model_name: GIN # PPGN, SMP, RING_GNN or GIN
3 | num_towers: 1
4 | hidden: 32
5 | hidden_final: 128
6 | dropout_prob: 0.5
7 | num_classes: 2
8 | use_x: False # Use_x is used for ablation studies
9 | num_layers: -1 # If -1, the number of layers is set to k
10 |
11 | # Options specific to SMP
12 | layer_type: FastSMP
13 | simplified: False
14 |
15 | # Options specific to GIN
16 | one_hot: False # Use a one-hot encoding of the degree as node features
17 | identifiers: False # Use a one hot encoding of the nodes as node features
18 | random: False # Use random identifiers as node features
19 | relational_pooling: 0 # if == p > 0, sum over p random permutations of the nodes
20 |
--------------------------------------------------------------------------------
/config_multi_task.yaml:
--------------------------------------------------------------------------------
1 | # Model properties
2 | model_name: SMP
3 | num_layers: 8
4 | hidden_u: 64
5 | num_towers: 8
6 | out_u: 32
7 | hidden_gru: 16
8 | layer_type: SMP # SMP or FastSMP
--------------------------------------------------------------------------------
/config_zinc.yaml:
--------------------------------------------------------------------------------
1 | # Model properties
2 | hidden: 32 # internal representation
3 | num_towers: 8 # used within each SMP layer
4 | hidden_final: 128 # Extracted feature
5 | num_layers: 12
6 | use_x: False # used for ablation study
7 | use_batch_norm: True
8 | map_x_to_u: True # map the initial node features to the local context
9 | simplified: False # fewer layers in the feature extractor
10 | residual: False # residual connections when transforming local contexts
11 | use_edge_features: True
12 | shared_extractor: True # share the feature extractor across layers
--------------------------------------------------------------------------------
/cycles_main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | import os
5 | import torch
6 | import torch.nn.functional as F
7 | from torch_geometric.data import DataLoader
8 | from torch_geometric.transforms import OneHotDegree
9 | import argparse
10 | import numpy as np
11 | import time
12 | import yaml
13 | from models.model_cycles import SMP
14 | from models.gin import GIN
15 | from datasets_generation.build_cycles import FourCyclesDataset
16 | from models.utils.transforms import EyeTransform, RandomId, DenseAdjMatrix
17 | from models import ppgn
18 | from models.ring_gnn import RingGNN
19 | from easydict import EasyDict as edict
20 |
21 |
22 | # Change the following to point to the folder where the datasets are stored
23 | if os.path.isdir('/datasets2/'):
24 | rootdir = '/datasets2/CYCLE_DETECTION/'
25 | else:
26 | rootdir = './data/datasets_kcycle_nsamples=10000/'
27 | yaml_file = './config_cycles.yaml'
28 | # yaml_file = './benchmark/kernel/config4cycles.yaml'
29 | torch.manual_seed(0)
30 | np.random.seed(0)
31 |
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument('--epochs', type=int, default=300)
34 | parser.add_argument('--k', type=int, default=4,
35 | help="Length of the cycles to detect")
36 | parser.add_argument('--n', type=int, help='Average number of nodes in the graphs')
37 | parser.add_argument('--save-model', action='store_true',
38 | help='Save the model once training is done')
39 | parser.add_argument('--wandb', action='store_true',
40 | help="Use weights and biases library")
41 | parser.add_argument('--gpu', type=int, help='Id of gpu device. By default use cpu')
42 | parser.add_argument('--lr', type=float, default=0.001, help="Initial learning rate")
43 | parser.add_argument('--batch-size', type=int, default=16)
44 | parser.add_argument('--weight-decay', type=float, default=1e-4)
45 | parser.add_argument('--clip', type=float, default=10, help="Gradient clipping")
46 | parser.add_argument('--name', type=str, help="Name for weights and biases")
47 | parser.add_argument('--proportion', type=float, default=1.0,
48 | help='Proportion of the training data that is kept')
49 | parser.add_argument('--generalization', action='store_true',
50 | help='Evaluate out of distribution accuracy')
51 | args = parser.parse_args()
52 |
53 | # Log parameters
54 | test_every_epoch = 5
55 | print_every_epoch = 1
56 | log_interval = 20
57 |
58 | # Store maximum number of nodes for each pair (k, n) -- this value is used by provably powerful graph networks
59 | max_num_nodes = {4: {12: 12, 20: 20, 28: 28, 36: 36},
60 | 6: {20: 25, 31: 38, 42: 52, 56: 65},
61 | 8: {28: 38, 50: 56, 66: 76, 72: 90}}
62 | # Store the maximum degree for the one-hot encoding
63 | max_degree = {4: {12: 4, 20: 6, 28: 7, 36: 7},
64 | 6: {20: 4, 31: 6, 42: 8, 56: 7},
65 | 8: {28: 4, 50: 6, 66: 7, 72: 8}}
66 | # Store the values of n to use for generalization experiments
67 | n_gener = {4: {'train': 20, 'val': 28, 'test': 36},
68 | 6: {'train': 31, 'val': 42, 'test': 56},
69 | 8: {'train': 50, 'val': 66, 'test': 72}}
70 |
71 | # Handle the device
72 | use_cuda = args.gpu is not None and torch.cuda.is_available()
73 | if use_cuda:
74 | device = torch.device("cuda:" + str(args.gpu))
75 | torch.cuda.set_device(args.gpu)
76 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
77 | else:
78 | device = "cpu"
79 | args.device = device
80 | args.kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
81 | print('Device used:', device)
82 |
83 | # Load the config file of the model
84 | with open(yaml_file) as f:
85 | config = yaml.load(f, Loader=yaml.FullLoader)
86 | config['map_x_to_u'] = False # Not used here
87 | config = edict(config)
88 | print(config)
89 |
90 | model_name = config['model_name']
91 |
92 | config.pop('model_name')
93 | if model_name == 'SMP':
94 | model_name = config['layer_type']
95 |
96 | if args.name is None:
97 | if model_name != 'GIN':
98 | args.name = model_name
99 | else:
100 | if config.relational_pooling > 0:
101 | args.name = 'RP'
102 | elif config.one_hot:
103 | args.name = 'OneHotDeg'
104 | elif config.identifiers:
105 | args.name = 'OneHotNod'
106 | elif config.random:
107 | args.name = 'Random'
108 | else:
109 | args.name = 'GIN'
110 | args.name = args.name + '_' + str(args.k)
111 | if args.n is not None:
112 | args.name = args.name + '_' + str(args.n)
113 |
114 | # Create a folder for the saved models
115 | if not os.path.isdir('./saved_models/' + args.name) and args.generalization:
116 | os.mkdir('./saved_models/' + args.name)
117 |
118 |
119 | if args.name:
120 | args.wandb = True
121 | if args.wandb:
122 | import wandb
123 | wandb.init(project="smp", config=config, name=args.name)
124 | wandb.config.update(args)
125 |
126 | if args.n is None:
127 | args.n = n_gener[args.k]['train']
128 |
129 | if config.num_layers == -1:
130 | config.num_layers = args.k
131 |
132 |
133 | def train(epoch):
134 | """ Train for one epoch. """
135 | model.train()
136 | lr_scheduler(args.lr, epoch, optimizer)
137 | loss_all = 0
138 | if not config.relational_pooling:
139 | for batch_idx, data in enumerate(train_loader):
140 | data = data.to(device)
141 | optimizer.zero_grad()
142 | output = model(data)
143 | loss = F.nll_loss(output, data.y)
144 | loss.backward()
145 | loss_all += loss.item() * data.num_graphs
146 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
147 | optimizer.step()
148 | return loss_all / len(train_loader.dataset)
149 | else:
150 | # For relational pooling, we sample several permutations of each graph
151 | for batch_idx, data in enumerate(train_loader):
152 | for repetition in range(config.relational_pooling):
153 | for i in range(args.batch_size):
154 | n_nodes = int(torch.sum(data.batch == i).item())
155 | p = torch.randperm(n_nodes)
156 | data.x[data.batch == i, :n_nodes] = data.x[data.batch == i, :n_nodes][p, :][:, p]
157 | data = data.to(device)
158 | optimizer.zero_grad()
159 | output = model(data)
160 | loss = F.nll_loss(output, data.y)
161 | loss.backward()
162 | loss_all += loss.item() * data.num_graphs
163 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
164 | optimizer.step()
165 | return loss_all / len(train_loader.dataset)
166 |
167 |
168 | def test(loader):
169 | model.eval()
170 | correct = 0
171 | for data in loader:
172 | data = data.to(device)
173 | output = model(data)
174 | pred = output.max(dim=1)[1]
175 | correct += pred.eq(data.y).sum().item()
176 | return correct / len(loader.dataset)
177 |
178 |
179 | def lr_scheduler(lr, epoch, optimizer):
180 | for param_group in optimizer.param_groups:
181 | param_group['lr'] = lr * (0.995 ** (epoch / 5))
182 |
183 |
184 | # Define the transform to use in the dataset
185 | transform=None
186 | if 'GIN' in model_name or 'RP' in model_name:
187 | if config.one_hot:
188 | # Cannot always be used in an inductive setting,
189 | # because the maximal degree might be bigger than during training
190 | degree = max_degree[args.k][args.n]
191 | transform = OneHotDegree(degree, cat=False)
192 | config.num_input_features = degree + 1
193 | elif config.identifiers:
194 | # Cannot be used in an inductive setting
195 | transform = EyeTransform(max_num_nodes[args.k][args.n])
196 | config.num_input_features = max_num_nodes[args.k][args.n]
197 | elif config.random:
198 | # Can be used in an inductive setting
199 | transform = RandomId()
200 | transform_val = RandomId()
201 | transform_test = RandomId()
202 | config.num_input_features = 1
203 |
204 | if transform is None:
205 | transform_val = None
206 | transform_test = None
207 | config.num_input_features = 1
208 |
209 | if 'SMP' in model_name:
210 | config.use_batch_norm = args.k > 6 or args.n > 30
211 | model = SMP(config.num_input_features, config.num_classes, config.num_layers, config.hidden, config.layer_type,
212 | config.hidden_final, config.dropout_prob, config.use_batch_norm, config.use_x, config.map_x_to_u,
213 | config.num_towers, config.simplified).to(device)
214 |
215 | elif model_name == 'PPGN':
216 | transform = DenseAdjMatrix(max_num_nodes[args.k][args.n])
217 | transform_val = DenseAdjMatrix(max_num_nodes[args.k][n_gener[args.k]['val']])
218 | transform_test = DenseAdjMatrix(max_num_nodes[args.k][n_gener[args.k]['test']])
219 | model = ppgn.Powerful(config.num_classes, config.num_layers, config.hidden,
220 | config.hidden_final, config.dropout_prob, config.simplified)
221 | elif model_name == 'GIN':
222 | config.use_batch_norm = args.k > 6 or args.n > 50
223 | model = GIN(config.num_input_features, config.num_classes, config.num_layers,
224 | config.hidden, config.hidden_final, config.dropout_prob, config.use_batch_norm)
225 | elif model_name == 'RING_GNN':
226 | transform = DenseAdjMatrix(max_num_nodes[args.k][args.n])
227 | transform_val = DenseAdjMatrix(max_num_nodes[args.k][n_gener[args.k]['val']])
228 | transform_test = DenseAdjMatrix(max_num_nodes[args.k][n_gener[args.k]['test']])
229 | model = RingGNN(config.num_classes, config.num_layers, config.hidden, config.hidden_final, config.dropout_prob,
230 | config.simplified)
231 |
232 | model = model.to(device)
233 |
234 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.5, weight_decay=args.weight_decay)
235 | # Load the data
236 | print("Transform used:", transform)
237 |
238 | batch_size = args.batch_size
239 | if args.generalization:
240 | train_data = FourCyclesDataset(args.k, n_gener[args.k]['train'], rootdir, train=True, transform=transform)
241 | test_data = FourCyclesDataset(args.k, n_gener[args.k]['train'], rootdir, train=False, transform=transform)
242 | gener_data_val = FourCyclesDataset(args.k, n_gener[args.k]['val'], rootdir, train=False, transform=transform_val)
243 | train_loader = DataLoader(train_data, batch_size, shuffle=True)
244 | test_loader = DataLoader(test_data, batch_size, shuffle=False)
245 | gener_val_loader = DataLoader(gener_data_val, batch_size, shuffle=False)
246 |
247 | else:
248 | train_data = FourCyclesDataset(args.k, args.n, rootdir, proportion=args.proportion, train=True, transform=transform)
249 | test_data = FourCyclesDataset(args.k, args.n, rootdir, proportion=args.proportion, train=False, transform=transform)
250 | train_loader = DataLoader(train_data, batch_size, shuffle=True)
251 | test_loader = DataLoader(test_data, batch_size, shuffle=False)
252 |
253 | print("Starting to train")
254 | start = time.time()
255 | best_epoch = -1
256 | best_generalization_acc = 0
257 | for epoch in range(args.epochs):
258 | epoch_start = time.time()
259 | tr_loss = train(epoch)
260 | if epoch % print_every_epoch == 0:
261 | acc_train = test(train_loader)
262 | current_lr = optimizer.param_groups[0]["lr"]
263 | duration = time.time() - epoch_start
264 | print(f'Time:{duration:2.2f} | {epoch:5d} | Loss: {tr_loss:2.5f} | Train Acc: {acc_train:2.5f} | LR: {current_lr:.6f}')
265 | if epoch % test_every_epoch == 0:
266 | acc_test = test(test_loader)
267 | print(f'Test accuracy: {acc_test:2.5f}')
268 | if args.generalization:
269 | acc_generalization = test(gener_val_loader)
270 | print("Validation generalization accuracy", acc_generalization)
271 | if args.wandb:
272 | wandb.log({"Epoch": epoch, "Duration": duration, "Train loss": tr_loss, "train accuracy": acc_train,
273 | "Test acc": acc_test, 'Gene eval': acc_generalization})
274 | if acc_generalization > best_generalization_acc:
275 | print(f"New best generalization error + accuracy > 90% at epoch {epoch}")
276 | # Remove existing models
277 | folder = f'./saved_models/{args.name}/'
278 | files_in_folder = [f for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))]
279 | for file in files_in_folder:
280 | try:
281 | os.remove(folder + file)
282 | except:
283 | print("Could not remove file", file)
284 | # Save new model
285 | torch.save(model, f'./saved_models/{args.name}/epoch{epoch}.pkl')
286 | print(f"Model saved at epoch {epoch}.")
287 | best_epoch = epoch
288 | else:
289 | if args.wandb:
290 | wandb.log({"Epoch": epoch, "Duration": duration, "Train loss": tr_loss, "train accuracy": acc_train,
291 | "Test acc": acc_test})
292 | else:
293 | if args.wandb:
294 | wandb.log({"Epoch": epoch, "Duration": duration, "Train loss": tr_loss, "train accuracy": acc_train})
295 |
296 | cur_lr = optimizer.param_groups[0]["lr"]
297 | print(f'{epoch:2.5f} | Loss: {tr_loss:2.5f} | Train Acc: {acc_train:2.5f} | LR: {cur_lr:.6f} | Test Acc: {acc_test:2.5f}')
298 | print(f'Elapsed time: {(time.time() - start) / 60:.1f} minutes')
299 | print('done!')
300 |
301 | final_acc = test(test_loader)
302 | print(f"Final accuracy: {final_acc}")
303 | print("Done.")
304 |
305 | if args.generalization:
306 | new_n = n_gener[args.k]['test']
307 | gener_data_test = FourCyclesDataset(args.k, new_n, rootdir, train=False, transform=transform_test)
308 | gener_test_loader = DataLoader(gener_data_test, batch_size, shuffle=False)
309 | model = torch.load(f"./saved_models/{args.name}/epoch{best_epoch}.pkl", map_location=device)
310 | model.eval()
311 | acc_test_generalization = test(gener_test_loader)
312 | print(f"Generalization accuracy on {args.k} cycles with {new_n} nodes", acc_test_generalization)
313 | if args.wandb:
314 | wandb.run.summary['test_generalization'] = acc_test_generalization
315 |
--------------------------------------------------------------------------------
/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvignac/SMP/95b55a880d0fc9149ddf32e8c2fdf5eac5b474b3/data/.DS_Store
--------------------------------------------------------------------------------
/data/datasets_kcycle_nsamples=10000/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvignac/SMP/95b55a880d0fc9149ddf32e8c2fdf5eac5b474b3/data/datasets_kcycle_nsamples=10000/.DS_Store
--------------------------------------------------------------------------------
/data/multitask_dataset.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvignac/SMP/95b55a880d0fc9149ddf32e8c2fdf5eac5b474b3/data/multitask_dataset.pkl
--------------------------------------------------------------------------------
/datasets_generation/__pycache__/build_cycles.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvignac/SMP/95b55a880d0fc9149ddf32e8c2fdf5eac5b474b3/datasets_generation/__pycache__/build_cycles.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets_generation/__pycache__/graph_algorithms.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvignac/SMP/95b55a880d0fc9149ddf32e8c2fdf5eac5b474b3/datasets_generation/__pycache__/graph_algorithms.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets_generation/__pycache__/graph_generation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvignac/SMP/95b55a880d0fc9149ddf32e8c2fdf5eac5b474b3/datasets_generation/__pycache__/graph_generation.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets_generation/__pycache__/multitask_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvignac/SMP/95b55a880d0fc9149ddf32e8c2fdf5eac5b474b3/datasets_generation/__pycache__/multitask_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets_generation/build_cycles.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pickle
4 | from torch_geometric.data import InMemoryDataset, Data
5 | import numpy as np
6 | import networkx as nx
7 | import numpy.random as npr
8 |
9 |
10 | if os.path.isdir('/datasets2/'):
11 | rootdir = '/datasets2/CYCLE_DETECTION/'
12 | else:
13 | rootdir = './data/datasets_kcycle_nsamples=10000/'
14 |
15 |
16 | def build_dataset():
17 | """ Given pickle files, split the dataset into one per value of n
18 | Run once before running the experiments. """
19 | n_samples = 10000
20 | for k in [4, 6, 8]:
21 | with open(os.path.join(rootdir, 'datasets_kcycle_k={}_nsamples=10000.pickle'.format(k)), 'rb') as f:
22 | datasets_params, datasets = pickle.load(f)
23 | # Split by graph size
24 | for params, dataset in zip(datasets_params, datasets):
25 | n = params['n']
26 | train, test = dataset[:n_samples], dataset[n_samples:]
27 | torch.save(train, rootdir + f'{k}cycles_n{n}_{n_samples}samples_train.pt')
28 | torch.save(test, rootdir + f'/{k}cycles_n{n}_{n_samples}samples_test.pt')
29 | # torch.save(test, '{}cycles_n{}_{}samples_test.pt'.format(k, n, n_samples))
30 |
31 |
32 | class FourCyclesDataset(InMemoryDataset):
33 | def __init__(self, k, n, root, train, proportion=1.0, n_samples=10000, transform=None, pre_transform=None):
34 | self.train = train
35 | self.k, self.n, self.n_samples = k, n, n_samples
36 | self.root = root
37 | self.s = 'train' if train else 'test'
38 | self.proportion = proportion
39 | super().__init__(root, transform, pre_transform)
40 | self.data, self.slices = torch.load(self.processed_paths[0])
41 |
42 | @property
43 | def raw_file_names(self):
44 | return ['{}cycles_n{}_{}samples_{}.pt'.format(self.k, self.n, self.n_samples, self.s)]
45 |
46 | @property
47 | def processed_file_names(self):
48 | if self.transform is None:
49 | st = 'no-transf'
50 | else:
51 | st = str(self.transform.__class__.__name__)
52 | return [f'processed_{self.k}cycles_n{self.n}_{self.n_samples}samples_{self.s}_{st}_{self.proportion}.pt']
53 |
54 | def download(self):
55 | # Download to `self.raw_dir`.
56 | pass
57 |
58 | def process(self):
59 | # Read data into huge `Data` list.
60 | dataset = torch.load(os.path.join(self.root, f'{self.k}cycles_n{self.n}_{self.n_samples}samples_{self.s}.pt'))
61 |
62 | data_list = []
63 | for sample in dataset:
64 | graph, y, label = sample
65 | edge_list = nx.to_edgelist(graph)
66 | edges = [np.array([edge[0], edge[1]]) for edge in edge_list]
67 | edges2 = [np.array([edge[1], edge[0]]) for edge in edge_list]
68 |
69 | edge_index = torch.tensor(np.array(edges + edges2).T, dtype=torch.long)
70 |
71 | x = torch.ones(graph.number_of_nodes(), 1, dtype=torch.float)
72 | y = torch.tensor([1], dtype=torch.long) if label == 'has-kcycle' else torch.tensor([0], dtype=torch.long)
73 |
74 | data_list.append(Data(x=x, edge_index=edge_index, edge_attr=None, y=y))
75 | # Subsample the data
76 | if self.train:
77 | all_data = len(data_list)
78 | to_select = int(all_data * self.proportion)
79 | print(to_select, "samples were selected")
80 | data_list = data_list[:to_select]
81 | data, slices = self.collate(data_list)
82 | torch.save((data, slices), self.processed_paths[0])
83 |
84 |
85 | if __name__ == '__main__':
86 | build_dataset()
--------------------------------------------------------------------------------
/datasets_generation/graph_algorithms.py:
--------------------------------------------------------------------------------
1 | import math
2 | from queue import Queue
3 |
4 | import numpy as np
5 |
6 |
7 | def is_connected(A):
8 | """
9 | :param A:np.array the adjacency matrix
10 | :return:bool whether the graph is connected or not
11 | """
12 | for _ in range(int(1 + math.ceil(math.log2(A.shape[0])))):
13 | A = np.dot(A, A)
14 | return np.min(A) > 0
15 |
16 |
17 | def identity(A, F):
18 | """
19 | :param A:np.array the adjacency matrix
20 | :param F:np.array the nodes features
21 | :return:F
22 | """
23 | return F
24 |
25 |
26 | def first_neighbours(A):
27 | """
28 | :param A:np.array the adjacency matrix
29 | :param F:np.array the nodes features
30 | :return: for each node, the number of nodes reachable in 1 hop
31 | """
32 | return np.sum(A > 0, axis=0)
33 |
34 |
35 | def second_neighbours(A):
36 | """
37 | :param A:np.array the adjacency matrix
38 | :param F:np.array the nodes features
39 | :return: for each node, the number of nodes reachable in no more than 2 hops
40 | """
41 | A = A > 0.0
42 | A = A + np.dot(A, A)
43 | np.fill_diagonal(A, 0)
44 | return np.sum(A > 0, axis=0)
45 |
46 |
47 | def kth_neighbours(A, k):
48 | """
49 | :param A:np.array the adjacency matrix
50 | :param F:np.array the nodes features
51 | :return: for each node, the number of nodes reachable in k hops
52 | """
53 | A = A > 0.0
54 | R = np.zeros(A.shape)
55 | for _ in range(k):
56 | R = np.dot(R, A) + A
57 | np.fill_diagonal(R, 0)
58 | return np.sum(R > 0, axis=0)
59 |
60 |
61 | def map_reduce_neighbourhood(A, F, f_reduce, f_map=None, hops=1, consider_itself=False):
62 | """
63 | :param A:np.array the adjacency matrix
64 | :param F:np.array the nodes features
65 | :return: for each node, map its neighbourhood with f_map, and reduce it with f_reduce
66 | """
67 | if f_map is not None:
68 | F = f_map(F)
69 | A = np.array(A)
70 |
71 | A = A > 0
72 | R = np.zeros(A.shape)
73 | for _ in range(hops):
74 | R = np.dot(R, A) + A
75 | np.fill_diagonal(R, 1 if consider_itself else 0)
76 | R = R > 0
77 |
78 | return np.array([f_reduce(F[R[i]]) for i in range(A.shape[0])])
79 |
80 |
81 | def max_neighbourhood(A, F):
82 | """
83 | :param A:np.array the adjacency matrix
84 | :param F:np.array the nodes features
85 | :return: for each node, the maximum in its neighbourhood
86 | """
87 | return map_reduce_neighbourhood(A, F, np.max, consider_itself=True)
88 |
89 |
90 | def min_neighbourhood(A, F):
91 | """
92 | :param A:np.array the adjacency matrix
93 | :param F:np.array the nodes features
94 | :return: for each node, the minimum in its neighbourhood
95 | """
96 | return map_reduce_neighbourhood(A, F, np.min, consider_itself=True)
97 |
98 |
99 | def std_neighbourhood(A, F):
100 | """
101 | :param A:np.array the adjacency matrix
102 | :param F:np.array the nodes features
103 | :return: for each node, the standard deviation of its neighbourhood
104 | """
105 | return map_reduce_neighbourhood(A, F, np.std, consider_itself=True)
106 |
107 |
108 | def mean_neighbourhood(A, F):
109 | """
110 | :param A:np.array the adjacency matrix
111 | :param F:np.array the nodes features
112 | :return: for each node, the mean of its neighbourhood
113 | """
114 | return map_reduce_neighbourhood(A, F, np.mean, consider_itself=True)
115 |
116 |
117 | def local_maxima(A, F):
118 | """
119 | :param A:np.array the adjacency matrix
120 | :param F:np.array the nodes features
121 | :return: for each node, whether it is the maximum in its neighbourhood
122 | """
123 | return F == map_reduce_neighbourhood(A, F, np.max, consider_itself=True)
124 |
125 |
126 | def graph_laplacian(A):
127 | """
128 | :param A:np.array the adjacency matrix
129 | :return: the laplacian of the adjacency matrix
130 | """
131 | L = (A > 0) * -1
132 | np.fill_diagonal(L, np.sum(A > 0, axis=0))
133 | return L
134 |
135 |
136 | def graph_laplacian_features(A, F):
137 | """
138 | :param A:np.array the adjacency matrix
139 | :param F:np.array the nodes features
140 | :return: the laplacian of the adjacency matrix multiplied by the features
141 | """
142 | return np.matmul(graph_laplacian(A), F)
143 |
144 |
145 | def isomorphism(A1, A2, F1=None, F2=None):
146 | """
147 | Takes two adjacency matrices (A1,A2) and (optionally) two lists of features. It uses Weisfeiler-Lehman algorithms, so false positives might arise
148 | :param A1: adj_matrix, N*N numpy matrix
149 | :param A2: adj_matrix, N*N numpy matrix
150 | :param F1: node_values, numpy array of size N
151 | :param F2: node_values, numpy array of size N
152 | :return: isomorphic: boolean which is false when the two graphs are not isomorphic, true when they probably are.
153 | """
154 | N = A1.shape[0]
155 | if (F1 is None) ^ (F2 is None):
156 | raise ValueError("either both or none between F1,F2 must be defined.")
157 | if F1 is None:
158 | # Assign same initial value to each node
159 | F1 = np.ones(N, int)
160 | F2 = np.ones(N, int)
161 | else:
162 | if not np.array_equal(np.sort(F1), np.sort(F2)):
163 | return False
164 | if F1.dtype != int:
165 | raise NotImplementedError('Still have to implement this')
166 |
167 | p = 1000000007
168 |
169 | def mapping(F):
170 | return (F * 234 + 133) % 1000000007
171 |
172 | def adjacency_hash(F):
173 | F = np.sort(F)
174 | b = 257
175 |
176 | h = 0
177 | for f in F:
178 | h = (b * h + f) % 1000000007
179 | return h
180 |
181 | for i in range(N):
182 | F1 = map_reduce_neighbourhood(A1, F1, adjacency_hash, f_map=mapping, consider_itself=True, hops=1)
183 | F2 = map_reduce_neighbourhood(A2, F2, adjacency_hash, f_map=mapping, consider_itself=True, hops=1)
184 | if not np.array_equal(np.sort(F1), np.sort(F2)):
185 | return False
186 | return True
187 |
188 |
189 | def count_edges(A):
190 | """
191 | :param A:np.array the adjacency matrix
192 | :return: the number of edges in the graph
193 | """
194 | return np.sum(A) / 2
195 |
196 |
197 | def is_eulerian_cyclable(A):
198 | """
199 | :param A:np.array the adjacency matrix
200 | :return: whether the graph has an eulerian cycle
201 | """
202 | return is_connected(A) and np.count_nonzero(first_neighbours(A) % 2 == 1) == 0
203 |
204 |
205 | def is_eulerian_percorrible(A):
206 | """
207 | :param A:np.array the adjacency matrix
208 | :return: whether the graph has an eulerian path
209 | """
210 | return is_connected(A) and np.count_nonzero(first_neighbours(A) % 2 == 1) in [0, 2]
211 |
212 |
213 | def map_reduce_graph(A, F, f_reduce):
214 | """
215 | :param A:np.array the adjacency matrix
216 | :param F:np.array the nodes features
217 | :return: the features of the nodes reduced by f_reduce
218 | """
219 | return f_reduce(F)
220 |
221 |
222 | def mean_graph(A, F):
223 | """
224 | :param A:np.array the adjacency matrix
225 | :param F:np.array the nodes features
226 | :return: the mean of the features
227 | """
228 | return map_reduce_graph(A, F, np.mean)
229 |
230 |
231 | def max_graph(A, F):
232 | """
233 | :param A:np.array the adjacency matrix
234 | :param F:np.array the nodes features
235 | :return: the maximum of the features
236 | """
237 | return map_reduce_graph(A, F, np.max)
238 |
239 |
240 | def min_graph(A, F):
241 | """
242 | :param A:np.array the adjacency matrix
243 | :param F:np.array the nodes features
244 | :return: the minimum of the features
245 | """
246 | return map_reduce_graph(A, F, np.min)
247 |
248 |
249 | def std_graph(A, F):
250 | """
251 | :param A:np.array the adjacency matrix
252 | :param F:np.array the nodes features
253 | :return: the standard deviation of the features
254 | """
255 | return map_reduce_graph(A, F, np.std)
256 |
257 |
258 | def has_hamiltonian_cycle(A):
259 | """
260 | :param A:np.array the adjacency matrix
261 | :return:bool whether the graph has an hamiltonian cycle
262 | """
263 | A += np.transpose(A)
264 | A = A > 0
265 | V = A.shape[0]
266 |
267 | def ham_cycle_loop(pos):
268 | if pos == V:
269 | if A[path[pos - 1]][path[0]]:
270 | return True
271 | else:
272 | return False
273 | for v in range(1, V):
274 | if A[path[pos - 1]][v] and not used[v]:
275 | path[pos] = v
276 | used[v] = True
277 | if ham_cycle_loop(pos + 1):
278 | return True
279 | path[pos] = -1
280 | used[v] = False
281 | return False
282 |
283 | used = [False] * V
284 | path = [-1] * V
285 | path[0] = 0
286 |
287 | return ham_cycle_loop(1)
288 |
289 |
290 | def all_pairs_shortest_paths(A, inf_sub=math.inf):
291 | """
292 | :param A:np.array the adjacency matrix
293 | :param inf_sub: the placeholder value to use for pairs which are not connected
294 | :return:np.array all pairs shortest paths
295 | """
296 | A = np.array(A)
297 | N = A.shape[0]
298 | for i in range(N):
299 | for j in range(N):
300 | if A[i][j] == 0:
301 | A[i][j] = math.inf
302 | if i == j:
303 | A[i][j] = 0
304 |
305 | for k in range(N):
306 | for i in range(N):
307 | for j in range(N):
308 | A[i][j] = min(A[i][j], A[i][k] + A[k][j])
309 |
310 | A = np.where(A == math.inf, inf_sub, A)
311 | return A
312 |
313 |
314 | def diameter(A):
315 | """
316 | :param A:np.array the adjacency matrix
317 | :return: the diameter of the graph
318 | """
319 | sum = np.sum(A)
320 | apsp = all_pairs_shortest_paths(A)
321 | apsp = np.where(apsp < sum + 1, apsp, -1)
322 | return np.max(apsp)
323 |
324 |
325 | def eccentricity(A):
326 | """
327 | :param A:np.array the adjacency matrix
328 | :return: the eccentricity of the graph
329 | """
330 | sum = np.sum(A)
331 | apsp = all_pairs_shortest_paths(A)
332 | apsp = np.where(apsp < sum + 1, apsp, -1)
333 | return np.max(apsp, axis=0)
334 |
335 |
336 | def sssp_predecessor(A, F):
337 | """
338 | :param A:np.array the adjacency matrix
339 | :param F:np.array the nodes features
340 | :return: for each node, the best next step to reach the designated source
341 | """
342 | assert (np.sum(F) == 1)
343 | assert (np.max(F) == 1)
344 | s = np.argmax(F)
345 | N = A.shape[0]
346 | P = np.zeros(A.shape)
347 | V = np.zeros(N)
348 | bfs = Queue()
349 | bfs.put(s)
350 | V[s] = 1
351 | while not bfs.empty():
352 | u = bfs.get()
353 | for v in range(N):
354 | if A[u][v] > 0 and V[v] == 0:
355 | V[v] = 1
356 | P[v][u] = 1
357 | bfs.put(v)
358 | return P
359 |
360 |
361 | def max_eigenvalue(A):
362 | """
363 | :param A:np.array the adjacency matrix
364 | :return: the maximum eigenvalue of A
365 | since A is positive symmetric, all the eigenvalues are guaranteed to be real
366 | """
367 | [W, _] = np.linalg.eig(A)
368 | return W[np.argmax(np.absolute(W))].real
369 |
370 |
371 | def max_eigenvalues(A, k):
372 | """
373 | :param A:np.array the adjacency matrix
374 | :param k:int the number of eigenvalues to be selected
375 | :return: the k greatest (by absolute value) eigenvalues of A
376 | """
377 | [W, _] = np.linalg.eig(A)
378 | values = W[sorted(range(len(W)), key=lambda x: -np.absolute(W[x]))[:k]]
379 | return values.real
380 |
381 |
382 | def max_absolute_eigenvalues(A, k):
383 | """
384 | :param A:np.array the adjacency matrix
385 | :param k:int the number of eigenvalues to be selected
386 | :return: the absolute value of the k greatest (by absolute value) eigenvalues of A
387 | """
388 | return np.absolute(max_eigenvalues(A, k))
389 |
390 |
391 | def max_absolute_eigenvalues_laplacian(A, n):
392 | """
393 | :param A:np.array the adjacency matrix
394 | :param k:int the number of eigenvalues to be selected
395 | :return: the absolute value of the k greatest (by absolute value) eigenvalues of the laplacian of A
396 | """
397 | A = graph_laplacian(A)
398 | return np.absolute(max_eigenvalues(A, n))
399 |
400 |
401 | def max_eigenvector(A):
402 | """
403 | :param A:np.array the adjacency matrix
404 | :return: the maximum (by absolute value) eigenvector of A
405 | since A is positive symmetric, all the eigenvectors are guaranteed to be real
406 | """
407 | [W, V] = np.linalg.eig(A)
408 | return V[:, np.argmax(np.absolute(W))].real
409 |
410 |
411 | def spectral_radius(A):
412 | """
413 | :param A:np.array the adjacency matrix
414 | :return: the maximum (by absolute value) eigenvector of A
415 | since A is positive symmetric, all the eigenvectors are guaranteed to be real
416 | """
417 | return np.abs(max_eigenvalue(A))
418 |
419 |
420 | def page_rank(A, F=None, iter=64):
421 | """
422 | :param A:np.array the adjacency matrix
423 | :param F:np.array with initial weights. If None, uniform initialization will happen.
424 | :param iter: log2 of length of power iteration
425 | :return: for each node, its pagerank
426 | """
427 |
428 | # normalize A rows
429 | A = np.array(A, dtype=float)  # cast to float so the in-place row normalization below also works on integer adjacency matrices
430 | A /= A.sum(axis=1)[:, np.newaxis]
431 |
432 | # power iteration
433 | for _ in range(iter):
434 | A = np.matmul(A, A)
435 |
436 | # generate prior distribution
437 | if F is None:
438 | F = np.ones(A.shape[-1])
439 | else:
440 | F = np.array(F)
441 |
442 | # normalize prior
443 | F /= np.sum(F)
444 |
445 | # compute limit distribution
446 | return np.matmul(F, A)
447 |
448 |
449 | def tsp_length(A, F=None):
450 | """
451 | :param A:np.array the adjacency matrix
452 | :param F:np.array determining which nodes are to be visited. If None, all of them are.
453 | :return: the length of the Traveling Salesman Problem shortest solution
454 | """
455 |
456 | A = all_pairs_shortest_paths(A)
457 | N = A.shape[0]
458 | if F is None:
459 | F = np.ones(N)
460 | targets = np.nonzero(F)[0]
461 | T = targets.shape[0]
462 | S = (1 << T)
463 | dp = np.zeros((S, T))
464 |
465 | def popcount(x):
466 | b = 0
467 | while x > 0:
468 | x &= x - 1
469 | b += 1
470 | return b
471 |
472 | msks = np.argsort(np.vectorize(popcount)(np.arange(S)))
473 | for i in range(T + 1):
474 | for j in range(T):
475 | if (1 << j) & msks[i] == 0:
476 | dp[msks[i]][j] = math.inf
477 |
478 | for i in range(T + 1, S):
479 | msk = msks[i]
480 | for u in range(T):
481 | if (1 << u) & msk == 0:
482 | dp[msk][u] = math.inf
483 | continue
484 | cost = math.inf
485 | for v in range(T):
486 | if v == u or (1 << v) & msk == 0:
487 | continue
488 | cost = min(cost, dp[msk ^ (1 << u)][v] + A[targets[v]][targets[u]])
489 | dp[msk][u] = cost
490 | return np.min(dp[S - 1])
491 |
492 |
493 | def get_nodes_labels(A, F):
494 | """
495 | Takes the adjacency matrix and the list of nodes features (and a list of algorithms) and returns
496 | a set of labels for each node
497 | :param A: adj_matrix, N*N numpy matrix
498 | :param F: node_values, numpy array of size N
499 | :return: labels: KxN numpy matrix where K is the number of labels for each node
500 | """
501 | labels = [identity(A, F), map_reduce_neighbourhood(A, F, np.mean, consider_itself=True),
502 | map_reduce_neighbourhood(A, F, np.max, consider_itself=True),
503 | map_reduce_neighbourhood(A, F, np.std, consider_itself=True), first_neighbours(A), second_neighbours(A),
504 | eccentricity(A)]
505 | return np.swapaxes(np.stack(labels), 0, 1)
506 |
507 |
508 | def get_graph_labels(A, F):
509 | """
510 | Takes the adjacency matrix and the list of nodes features (and a list of algorithms) and returns
511 | a set of labels for the whole graph
512 | :param A: adj_matrix, N*N numpy matrix
513 | :param F: node_values, numpy array of size N
514 | :return: labels: numpy array of size K where K is the number of labels for the graph
515 | """
516 | labels = [diameter(A)]
517 | return np.asarray(labels)
518 |
--------------------------------------------------------------------------------
/datasets_generation/graph_generation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import networkx as nx
4 | import math
5 | import matplotlib.pyplot as plt # only required to plot
6 | from enum import Enum
7 |
8 | """
9 | Generates random graphs of different types of a given size.
10 | Some of the graph are created using the NetworkX library, for more info see
11 | https://networkx.github.io/documentation/networkx-1.10/reference/generators.html
12 | """
13 |
14 |
15 | class GraphType(Enum):
16 | RANDOM = 0
17 | ERDOS_RENYI = 1
18 | BARABASI_ALBERT = 2
19 | GRID = 3
20 | CAVEMAN = 5
21 | TREE = 6
22 | LADDER = 7
23 | LINE = 8
24 | STAR = 9
25 | CATERPILLAR = 10
26 | LOBSTER = 11
27 |
28 |
29 | # probabilities of each type in case of random type
30 | MIXTURE = [(GraphType.ERDOS_RENYI, 0.2), (GraphType.BARABASI_ALBERT, 0.2), (GraphType.GRID, 0.05),
31 | (GraphType.CAVEMAN, 0.05), (GraphType.TREE, 0.15), (GraphType.LADDER, 0.05),
32 | (GraphType.LINE, 0.05), (GraphType.STAR, 0.05), (GraphType.CATERPILLAR, 0.1), (GraphType.LOBSTER, 0.1)]
33 |
34 |
35 | def erdos_renyi(N, degree, seed):
36 | """ Creates an Erdős-Rényi or binomial graph of size N with degree/N probability of edge creation """
37 | return nx.fast_gnp_random_graph(N, degree / N, seed, directed=False)
38 |
39 |
40 | def barabasi_albert(N, degree, seed):
41 | """ Creates a random graph according to the Barabási–Albert preferential attachment model
42 | of size N and where nodes are attached with degree edges """
43 | return nx.barabasi_albert_graph(N, degree, seed)
44 |
45 |
46 | def grid(N):
47 | """ Creates a m x k 2d grid graph with N = m*k and m and k as close as possible """
48 | m = 1
49 | for i in range(1, int(math.sqrt(N)) + 1):
50 | if N % i == 0:
51 | m = i
52 | return nx.grid_2d_graph(m, N // m)
53 |
54 |
55 | def caveman(N):
56 | """ Creates a caveman graph of m cliques of size k, with m and k as close as possible """
57 | m = 1
58 | for i in range(1, int(math.sqrt(N)) + 1):
59 | if N % i == 0:
60 | m = i
61 | return nx.caveman_graph(m, N // m)
62 |
63 |
64 | def tree(N, seed):
65 | """ Creates a tree of size N with a power law degree distribution """
66 | return nx.random_powerlaw_tree(N, seed=seed, tries=10000)
67 |
68 |
69 | def ladder(N):
70 | """ Creates a ladder graph of N nodes: two rows of N/2 nodes, with each pair connected by a single edge.
71 | In case N is odd another node is attached to the first one. """
72 | G = nx.ladder_graph(N // 2)
73 | if N % 2 != 0:
74 | G.add_node(N - 1)
75 | G.add_edge(0, N - 1)
76 | return G
77 |
78 |
79 | def line(N):
80 | """ Creates a graph composed of N nodes in a line """
81 | return nx.path_graph(N)
82 |
83 |
84 | def star(N):
85 | """ Creates a graph composed by one center node connected N-1 outer nodes """
86 | return nx.star_graph(N - 1)
87 |
88 |
89 | def caterpillar(N, seed):
90 | """ Creates a random caterpillar graph with a backbone of size b (drawn from U[1, N)), and N − b
91 | pendent vertices uniformly connected to the backbone. """
92 | np.random.seed(seed)
93 | B = np.random.randint(low=1, high=N)
94 | G = nx.empty_graph(N)
95 | for i in range(1, B):
96 | G.add_edge(i - 1, i)
97 | for i in range(B, N):
98 | G.add_edge(i, np.random.randint(B))
99 | return G
100 |
101 |
102 | def lobster(N, seed):
103 | """ Creates a random Lobster graph with a backbone of size b (drawn from U[1, N)), and p (drawn
104 | from U[1, N − b ]) pendent vertices uniformly connected to the backbone, and additional
105 | N − b − p pendent vertices uniformly connected to the previous pendent vertices """
106 | np.random.seed(seed)
107 | B = np.random.randint(low=1, high=N)
108 | F = np.random.randint(low=B + 1, high=N + 1)
109 | G = nx.empty_graph(N)
110 | for i in range(1, B):
111 | G.add_edge(i - 1, i)
112 | for i in range(B, F):
113 | G.add_edge(i, np.random.randint(B))
114 | for i in range(F, N):
115 | G.add_edge(i, np.random.randint(low=B, high=F))
116 | return G
117 |
118 |
119 | def randomize(A):
120 | """ Adds some randomness by toggling some edges without chancing the expected number of edges of the graph """
121 | BASE_P = 0.9
122 |
123 | # e is the number of edges, r the number of missing edges
124 | N = A.shape[0]
125 | e = np.sum(A) / 2
126 | r = N * (N - 1) / 2 - e
127 |
128 | # ep chance of an existing edge to remain, rp chance of another edge to appear
129 | if e <= r:
130 | ep = BASE_P
131 | rp = (1 - BASE_P) * e / r
132 | else:
133 | ep = BASE_P + (1 - BASE_P) * (e - r) / e
134 | rp = 1 - BASE_P
135 |
136 | array = np.random.uniform(size=(N, N), low=0.0, high=0.5)
137 | array = array + array.transpose()
138 | remaining = np.multiply(np.where(array < ep, 1, 0), A)
139 | appearing = np.multiply(np.multiply(np.where(array < rp, 1, 0), 1 - A), 1 - np.eye(N))
140 | ans = np.add(remaining, appearing)
141 |
142 | # assert (np.all(np.multiply(ans, np.eye(N)) == np.zeros((N, N))))
143 | # assert (np.all(ans >= 0))
144 | # assert (np.all(ans <= 1))
145 | # assert (np.all(ans == ans.transpose()))
146 | return ans
147 |
148 |
149 | def generate_graph(N, type=GraphType.RANDOM, seed=None, degree=None):
150 | """
151 | Generates random graphs of different types of a given size. Note:
152 | - graph are undirected and without weights on edges
153 | - node values are sampled independently from U[0,1]
154 |
155 | :param N: number of nodes
156 | :param type: type chosen between the categories specified in GraphType enum
157 | :param seed: random seed
158 | :param degree: average degree of a node, only used in some graph types
159 | :return: adj_matrix: N*N numpy matrix
160 | node_values: numpy array of size N
161 | """
162 | random.seed(seed)
163 | np.random.seed(seed)
164 |
165 | # sample which random type to use
166 | if type == GraphType.RANDOM:
167 | type = np.random.choice([t for (t, _) in MIXTURE], 1, p=[pr for (_, pr) in MIXTURE])[0]
168 |
169 | # generate the graph structure depending on the type
170 | if type == GraphType.ERDOS_RENYI:
171 | if degree == None: degree = random.random() * N
172 | G = erdos_renyi(N, degree, seed)
173 | elif type == GraphType.BARABASI_ALBERT:
174 | if degree == None: degree = int(random.random() * (N - 1)) + 1
175 | G = barabasi_albert(N, degree, seed)
176 | elif type == GraphType.GRID:
177 | G = grid(N)
178 | elif type == GraphType.CAVEMAN:
179 | G = caveman(N)
180 | elif type == GraphType.TREE:
181 | G = tree(N, seed)
182 | elif type == GraphType.LADDER:
183 | G = ladder(N)
184 | elif type == GraphType.LINE:
185 | G = line(N)
186 | elif type == GraphType.STAR:
187 | G = star(N)
188 | elif type == GraphType.CATERPILLAR:
189 | G = caterpillar(N, seed)
190 | elif type == GraphType.LOBSTER:
191 | G = lobster(N, seed)
192 | else:
193 | print("Type not defined")
194 | return
195 |
196 | # generate adjacency matrix and nodes values
197 | nodes = list(G)
198 | random.shuffle(nodes)
199 | adj_matrix = nx.to_numpy_array(G, nodes)
200 | node_values = np.random.uniform(low=0, high=1, size=N)
201 |
202 | # randomization
203 | adj_matrix = randomize(adj_matrix)
204 |
205 | # draw the graph created
206 | # nx.draw(G, pos=nx.spring_layout(G))
207 | # plt.draw()
208 |
209 | return adj_matrix, node_values, type
210 |
211 |
212 | if __name__ == '__main__':
213 | for i in range(100):
214 | adj_matrix, node_values, _ = generate_graph(10, GraphType.RANDOM, seed=i)
215 | print(adj_matrix)
216 |
--------------------------------------------------------------------------------
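
Usage sketch (illustrative, not part of the repository): a minimal way to call generate_graph above; the size, graph type, seed and degree below are placeholder values.

    from datasets_generation.graph_generation import GraphType, generate_graph

    # one Erdos-Renyi graph with 20 nodes and an average degree of roughly 4
    adj_matrix, node_values, graph_type = generate_graph(20, GraphType.ERDOS_RENYI, seed=0, degree=4)
    print(adj_matrix.shape, node_values.shape, graph_type)  # (20, 20) (20,) GraphType.ERDOS_RENYI
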
/datasets_generation/multitask_dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 |
5 | import numpy as np
6 | import torch
7 | from inspect import signature
8 |
9 | from datasets_generation import graph_algorithms
10 | from datasets_generation.graph_generation import GraphType, generate_graph
11 |
12 |
13 | class DatasetMultitask:
14 |
15 | def __init__(self, n_graphs, N, seed, graph_type, get_nodes_labels, get_graph_labels, print_every, sssp, filename):
16 | self.adj = {}
17 | self.features = {}
18 | self.nodes_labels = {}
19 | self.graph_labels = {}
20 |
21 | def progress_bar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='█', printEnd=""):
22 | percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
23 | filledLength = int(length * iteration // total)
24 | bar = fill * filledLength + '-' * (length - filledLength)
25 | print('\r{} |{}| {}% {}'.format(prefix, bar, percent, suffix), end=printEnd)
26 |
27 | def to_categorical(x, N):
28 | v = np.zeros(N)
29 | v[x] = 1
30 | return v
31 |
32 | for dset in N.keys():
33 | if dset not in n_graphs:
34 | n_graphs[dset] = n_graphs['default']
35 |
36 | total_n_graphs = sum(n_graphs[dset])
37 |
38 | set_adj = [[] for _ in n_graphs[dset]]
39 | set_features = [[] for _ in n_graphs[dset]]
40 | set_nodes_labels = [[] for _ in n_graphs[dset]]
41 | set_graph_labels = [[] for _ in n_graphs[dset]]
42 | generated = 0
43 |
44 | progress_bar(0, total_n_graphs, prefix='Generating {:20}\t\t'.format(dset),
45 | suffix='({} of {})'.format(0, total_n_graphs))
46 |
47 | for batch, batch_size in enumerate(n_graphs[dset]):
48 | for i in range(batch_size):
49 | # generate a random graph of type graph_type and size N
50 | seed += 1
51 | adj, features, type = generate_graph(N[dset][batch], graph_type, seed=seed)
52 |
53 | while np.min(np.max(adj, 0)) == 0.0:
54 | # remove graph with singleton nodes
55 | seed += 1
56 | adj, features, _ = generate_graph(N[dset][batch], type, seed=seed)
57 |
58 | generated += 1
59 | if generated % print_every == 0:
60 | progress_bar(generated, total_n_graphs, prefix='Generating {:20}\t\t'.format(dset),
61 | suffix='({} of {})'.format(generated, total_n_graphs))
62 |
63 | # make sure there are no self-connections
64 | assert np.all(
65 | np.multiply(adj, np.eye(N[dset][batch])) == np.zeros((N[dset][batch], N[dset][batch])))
66 |
67 | if sssp:
68 | # define the source node
69 | source_node = np.random.randint(0, N[dset][batch])
70 |
71 | # compute the labels with graph_algorithms; if sssp add the sssp
72 | node_labels = get_nodes_labels(adj, features,
73 | graph_algorithms.all_pairs_shortest_paths(adj, 0)[source_node]
74 | if sssp else None)
75 | graph_labels = get_graph_labels(adj, features)
76 | if sssp:
77 | # add the 1-hot feature determining the starting node
78 | features = np.stack([to_categorical(source_node, N[dset][batch]), features], axis=1)
79 |
80 | set_adj[batch].append(adj)
81 | set_features[batch].append(features)
82 | set_nodes_labels[batch].append(node_labels)
83 | set_graph_labels[batch].append(graph_labels)
84 |
85 | self.adj[dset] = [torch.from_numpy(np.asarray(adjs)).float() for adjs in set_adj]
86 | self.features[dset] = [torch.from_numpy(np.asarray(fs)).float() for fs in set_features]
87 | self.nodes_labels[dset] = [torch.from_numpy(np.asarray(nls)).float() for nls in set_nodes_labels]
88 | self.graph_labels[dset] = [torch.from_numpy(np.asarray(gls)).float() for gls in set_graph_labels]
89 | progress_bar(total_n_graphs, total_n_graphs, prefix='Generating {:20}\t\t'.format(dset),
90 | suffix='({} of {})'.format(total_n_graphs, total_n_graphs), printEnd='\n')
91 |
92 | self.save_as_pickle(filename)
93 |
94 | def save_as_pickle(self, filename):
95 | """" Saves the data into a pickle file at filename """
96 | directory = os.path.dirname(filename)
97 | if not os.path.exists(directory):
98 | os.makedirs(directory)
99 |
100 | with open(filename, 'wb') as f:
101 | pickle.dump((self.adj, self.features, self.nodes_labels, self.graph_labels), f)
102 |
103 |
104 | if __name__ == '__main__':
105 | parser = argparse.ArgumentParser()
106 | parser.add_argument('--out', type=str, default='./data/multitask_dataset.pkl', help='Data path.')
107 | parser.add_argument('--seed', type=int, default=1234, help='Random seed.')
108 | parser.add_argument('--graph_type', type=str, default='RANDOM', help='Type of graphs in train set')
109 | parser.add_argument('--nodes_labels', nargs='+', default=["eccentricity", "graph_laplacian_features", "sssp"])
110 | parser.add_argument('--graph_labels', nargs='+', default=["is_connected", "diameter", "spectral_radius"])
111 | parser.add_argument('--extrapolation', action='store_true', default=False,
112 | help='Generate various test sets with sizes larger than train and validation.')
113 | parser.add_argument('--print_every', type=int, default=20, help='Print progress every this many generated graphs.')
114 | args = parser.parse_args()
115 |
116 | if 'sssp' in args.nodes_labels:
117 | sssp = True
118 | args.nodes_labels.remove('sssp')
119 | else:
120 | sssp = False
121 |
122 | # get the functions of graph_algorithms corresponding to the requested label names
123 | nodes_labels_algs = list(map(lambda s: getattr(graph_algorithms, s), args.nodes_labels))
124 | graph_labels_algs = list(map(lambda s: getattr(graph_algorithms, s), args.graph_labels))
125 |
126 |
127 | def get_nodes_labels(A, F, initial=None):
128 | labels = [] if initial is None else [initial]
129 | for f in nodes_labels_algs:
130 | params = signature(f).parameters
131 | labels.append(f(A, F) if 'F' in params else f(A))
132 | return np.swapaxes(np.stack(labels), 0, 1)
133 |
134 |
135 | def get_graph_labels(A, F):
136 | labels = []
137 | for f in graph_labels_algs:
138 | params = signature(f).parameters
139 | labels.append(f(A, F) if 'F' in params else f(A))
140 | return np.asarray(labels).flatten()
141 |
142 |
143 | data = DatasetMultitask(n_graphs={'train': [512] * 10, 'val': [128] * 5, 'default': [256] * 5},
144 | N={**{'train': range(15, 25), 'val': range(15, 25)}, **(
145 | {'test-(20,25)': range(20, 25), 'test-(25,30)': range(25, 30),
146 | 'test-(30,35)': range(30, 35), 'test-(35,40)': range(35, 40),
147 | 'test-(40,45)': range(40, 45), 'test-(45,50)': range(45, 50),
148 | 'test-(60,65)': range(60, 65), 'test-(75,80)': range(75, 80),
149 | 'test-(95,100)': range(95, 100)} if args.extrapolation else
150 | {'test': range(15, 25)})},
151 | seed=args.seed, graph_type=getattr(GraphType, args.graph_type),
152 | get_nodes_labels=get_nodes_labels, get_graph_labels=get_graph_labels,
153 | print_every=args.print_every, sssp=sssp, filename=args.out)
154 |
155 | data.save_as_pickle(args.out)
156 |
--------------------------------------------------------------------------------
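
Reading the dataset back (illustrative sketch, not part of the repository); the path matches the default --out argument above.

    import pickle

    with open('./data/multitask_dataset.pkl', 'rb') as f:
        adj, features, nodes_labels, graph_labels = pickle.load(f)

    # Each object is a dict keyed by split ('train', 'val', 'test' or 'test-(a,b)');
    # each value is a list of float tensors, one per batch of same-sized graphs.
    print(adj['train'][0].shape)  # e.g. torch.Size([512, n, n]) for the first training batch
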
/models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvignac/SMP/95b55a880d0fc9149ddf32e8c2fdf5eac5b474b3/models/.DS_Store
--------------------------------------------------------------------------------
/models/gin.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import Sequential, Linear, ReLU, ModuleList
4 | import torch.nn.functional as F
5 | from torch_geometric.nn import GINConv
6 | from models.utils.layers import XtoGlobal
7 |
8 |
9 | class FeatureExtractor(nn.Module):
10 | def __init__(self, in_features: int, out_features: int):
11 | super().__init__()
12 | self.XtoG = XtoGlobal(in_features, out_features, bias=True)
13 | self.lin = Linear(out_features, out_features, bias=False)
14 |
15 | def forward(self, x, batch_info):
16 | """ x: (num_nodes, in_features)
17 | output: (batch_size, out_features). """
18 | out = self.XtoG.forward(x, batch_info)
19 | out = out + self.lin.forward(F.relu(out))
20 | return out
21 |
22 |
23 | class GINNetwork(nn.Module):
24 | def __init__(self, in_features, out_features):
25 | super().__init__()
26 | self.lin_1 = nn.Linear(in_features, in_features)
27 | self.lin_2 = nn.Linear(in_features, out_features)
28 |
29 | def forward(self, x):
30 | x = self.lin_2(x + torch.relu(self.lin_1(x)))
31 | return x
32 |
33 |
34 | class GIN(nn.Module):
35 | def __init__(self, num_input_features: int, num_classes: int, num_layers: int,
36 | hidden, hidden_final: int, dropout_prob: float, use_batch_norm: bool):
37 | super().__init__()
38 | self.use_batch_norm = use_batch_norm
39 | self.dropout_prob = dropout_prob
40 | self.no_prop = FeatureExtractor(num_input_features, hidden_final)
41 | self.initial_lin_x = nn.Linear(num_input_features, hidden)
42 |
43 | self.convs = nn.ModuleList([])
44 | self.batch_norm_x = nn.ModuleList()
45 | self.feature_extractors = nn.ModuleList([])
46 | for i in range(num_layers):
47 | self.convs.append(GINConv(GINNetwork(hidden, hidden)))
48 | self.feature_extractors.append(FeatureExtractor(hidden, hidden_final))
49 | self.batch_norm_x.append(nn.BatchNorm1d(hidden))
50 |
51 | self.after_conv = nn.Linear(hidden_final, hidden_final)
52 | self.final_lin = nn.Linear(hidden_final, num_classes)
53 |
54 | def forward(self, data):
55 | """ data.x: (num_nodes, num_features)"""
56 | x, edge_index, batch, batch_size = data.x, data.edge_index, data.batch, data.num_graphs
57 |
58 | # Compute some information about the batch
59 | # Count the number of nodes in each graph
60 | unique, n_per_graph = torch.unique(data.batch, return_counts=True)
61 | n_batch = torch.zeros_like(batch, dtype=torch.float)
62 |
63 | for value, n in zip(unique, n_per_graph):
64 | n_batch[batch == value] = n.float()
65 |
66 | # Aggregate into a dict
67 | batch_info = {'num_nodes': data.num_nodes,
68 | 'num_graphs': data.num_graphs,
69 | 'batch': data.batch}
70 |
71 | out = self.no_prop.forward(x, batch_info)
72 | x = self.initial_lin_x(x)
73 | for i, (conv, bn_x, extractor) in enumerate(zip(self.convs, self.batch_norm_x, self.feature_extractors)):
74 | if self.use_batch_norm and i > 0:
75 | x = bn_x(x)
76 | x = conv(x, edge_index)
77 | global_features = extractor.forward(x, batch_info)
78 | out += global_features
79 |
80 | out = F.relu(out) / len(self.convs)
81 | out = F.relu(self.after_conv(out)) + out
82 | out = F.dropout(out, p=self.dropout_prob, training=self.training)
83 | out = self.final_lin(out)
84 | return F.log_softmax(out, dim=-1)
85 |
86 | def __repr__(self):
87 | return self.__class__.__name__
--------------------------------------------------------------------------------
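
Construction sketch for the GIN baseline above (the hyperparameter values are placeholders, not the settings from the configs):

    from models.gin import GIN

    model = GIN(num_input_features=1, num_classes=2, num_layers=4, hidden=64,
                hidden_final=32, dropout_prob=0.5, use_batch_norm=True)
    # model(data) expects a torch_geometric batch with x, edge_index and batch set,
    # and returns per-graph log-probabilities of shape (num_graphs, num_classes).
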
/models/model_cycles.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from models.smp_layers import SimplifiedFastSMPLayer, FastSMPLayer, SMPLayer
5 | from models.utils.layers import GraphExtractor, EdgeCounter, BatchNorm
6 | from models.utils.misc import create_batch_info, map_x_to_u
7 |
8 |
9 | class SMP(torch.nn.Module):
10 | def __init__(self, num_input_features: int, num_classes: int, num_layers: int, hidden: int, layer_type: str,
11 | hidden_final: int, dropout_prob: float, use_batch_norm: bool, use_x: bool, map_x_to_u: bool,
12 | num_towers: int, simplified: bool):
13 | """ num_input_features: number of node features
14 | layer_type: 'SMP', 'FastSMP' or 'SimplifiedFastSMP'
15 | hidden_final: size of the feature map after pooling
16 | use_x: for ablation studies, run an MPNN instead of SMP
17 | map_x_to_u: map the node features to the local context
18 | num_towers: inside each SMP layer, use towers to reduce the number of parameters
19 | simplified: fewer layers in the feature extractor.
20 | """
21 | super().__init__()
22 | self.map_x_to_u, self.use_x = map_x_to_u, use_x
23 | self.dropout_prob = dropout_prob
24 | self.use_batch_norm = use_batch_norm
25 | self.edge_counter = EdgeCounter()
26 | self.num_classes = num_classes
27 |
28 | self.no_prop = GraphExtractor(in_features=num_input_features, out_features=hidden_final, use_x=use_x)
29 | self.initial_lin = nn.Linear(num_input_features, hidden)
30 |
31 | layer_type_dict = {'SMP': SMPLayer, 'FastSMP': FastSMPLayer, 'SimplifiedFastSMP': SimplifiedFastSMPLayer}
32 | conv_layer = layer_type_dict[layer_type]
33 |
34 | self.convs = nn.ModuleList()
35 | self.batch_norm_list = nn.ModuleList()
36 | self.feature_extractors = torch.nn.ModuleList([])
37 | for i in range(0, num_layers):
38 | self.convs.append(conv_layer(in_features=hidden, num_towers=num_towers, out_features=hidden, use_x=use_x))
39 | self.batch_norm_list.append(BatchNorm(hidden, use_x))
40 | self.feature_extractors.append(GraphExtractor(in_features=hidden, out_features=hidden_final, use_x=use_x,
41 | simplified=simplified))
42 |
43 | # Last layers
44 | self.simplified = simplified
45 | self.after_conv = nn.Linear(hidden_final, hidden_final)
46 | self.final_lin = nn.Linear(hidden_final, num_classes)
47 |
48 | def forward(self, data):
49 | """ data.x: (num_nodes, num_features)"""
50 | x, edge_index = data.x, data.edge_index
51 | batch_info = create_batch_info(data, self.edge_counter)
52 |
53 | # Create the context matrix
54 | if self.use_x:
55 | assert x is not None
56 | u = x
57 | elif self.map_x_to_u:
58 | u = map_x_to_u(data, batch_info)
59 | else:
60 | u = data.x.new_zeros((data.num_nodes, batch_info['n_colors']))
61 | u.scatter_(1, data.coloring, 1)
62 | u = u[..., None]
63 |
64 | # Forward pass
65 | out = self.no_prop(u, batch_info)
66 | u = self.initial_lin(u)
67 | for i, (conv, bn, extractor) in enumerate(zip(self.convs, self.batch_norm_list, self.feature_extractors)):
68 | if self.use_batch_norm and i > 0:
69 | u = bn(u)
70 | u = conv(u, edge_index, batch_info)
71 | global_features = extractor.forward(u, batch_info)
72 | out += global_features / len(self.convs)
73 |
74 | # Two layer MLP with dropout and residual connections:
75 | if not self.simplified:
76 | out = torch.relu(self.after_conv(out)) + out
77 | out = F.dropout(out, p=self.dropout_prob, training=self.training)
78 | out = self.final_lin(out)
79 | if self.num_classes > 1:
80 | # Classification
81 | return F.log_softmax(out, dim=-1)
82 | else:
83 | # Regression
84 | assert out.shape[1] == 1
85 | return out[:, 0]
86 |
87 | def reset_parameters(self):
88 | for layer in [self.no_prop, self.initial_lin, *self.convs, *self.batch_norm_list, *self.feature_extractors,
89 | self.after_conv, self.final_lin]:
90 | layer.reset_parameters()
91 |
92 | def __repr__(self):
93 | return self.__class__.__name__
94 |
--------------------------------------------------------------------------------
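
Construction sketch for the cycle-detection SMP above; the values are placeholders (hidden should be divisible by num_towers because of the grouped convolutions used when use_x=False):

    from models.model_cycles import SMP

    model = SMP(num_input_features=1, num_classes=2, num_layers=4, hidden=32,
                layer_type='FastSMP', hidden_final=32, dropout_prob=0.5,
                use_batch_norm=True, use_x=False, map_x_to_u=False,
                num_towers=8, simplified=False)
    # With num_classes > 1 the model returns per-graph log-probabilities,
    # with num_classes == 1 a single regression value per graph.
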
/models/model_multi_task.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.utils.layers import EdgeCounter, NodeExtractor, BatchNorm
4 | from models.smp_layers import FastSMPLayer, SMPLayer, SimplifiedFastSMPLayer
5 | from torch_geometric.nn import Set2Set
6 | from models.utils.misc import create_batch_info
7 |
8 |
9 | class SMP(torch.nn.Module):
10 | def __init__(self, num_input_features: int, nodes_out: int, graph_out: int,
11 | num_layers: int, num_towers: int, hidden_u: int, out_u: int, hidden_gru: int,
12 | layer_type: str):
13 | """ num_input_features: number of node features
14 | nodes_out: number of output features at the node level (3 on the benchmark)
15 | graph_out: number of output features at the graph level (3 on the benchmark)
16 | num_towers: inside each SMP layer, use towers to reduce the number of parameters
17 | hidden_u: number of channels in the local contexts
18 | out_u: number of channels after extraction of node features
19 | hidden_gru: number of channels inside the gated recurrent unit
20 | layer_type: 'SMP', 'FastSMP' or 'SimplifiedFastSMP'.
21 | """
22 | super().__init__()
23 | num_input_u = 1 + num_input_features
24 |
25 | self.edge_counter = EdgeCounter()
26 | self.initial_lin_u = nn.Linear(num_input_u, hidden_u)
27 |
28 | self.extractor = NodeExtractor(hidden_u, out_u)
29 |
30 | layer_type_dict = {'SMP': SMPLayer, 'FastSMP': FastSMPLayer, 'SimplifiedFastSMP': SimplifiedFastSMPLayer}
31 | conv_layer = layer_type_dict[layer_type]
32 |
33 | self.gru = nn.GRU(out_u, hidden_gru)
34 | self.convs = nn.ModuleList([])
35 | self.batch_norm_u = nn.ModuleList([])
36 | for i in range(0, num_layers):
37 | self.batch_norm_u.append(BatchNorm(hidden_u, use_x=False))
38 | conv = conv_layer(in_features=hidden_u, out_features=hidden_u, num_towers=num_towers, use_x=False)
39 | self.convs.append(conv)
40 |
41 | # Process the extracted node features
42 | max_n = 19
43 | self.set2set = Set2Set(hidden_gru, max_n)
44 |
45 | self.final_node = nn.Sequential(nn.Linear(hidden_gru, hidden_gru), nn.LeakyReLU(),
46 | nn.Linear(hidden_gru, hidden_gru), nn.LeakyReLU(),
47 | nn.Linear(hidden_gru, nodes_out))
48 |
49 | self.final_graph = nn.Sequential(nn.Linear(2 * hidden_gru, hidden_gru), nn.ReLU(),
50 | nn.BatchNorm1d(hidden_gru),
51 | nn.Linear(hidden_gru, hidden_gru), nn.LeakyReLU(),
52 | nn.BatchNorm1d(hidden_gru),
53 | nn.Linear(hidden_gru, graph_out))
54 |
55 | def forward(self, data):
56 | """ data.x: (num_nodes, num_features)"""
57 | x, edge_index, batch, batch_size = data.x, data.edge_index, data.batch, data.num_graphs
58 | batch_info = create_batch_info(data, self.edge_counter)
59 |
60 | # Create the context matrix
61 | u = data.x.new_zeros((data.num_nodes, batch_info['n_colors']))
62 | u.scatter_(1, data.coloring, 1)
63 | u = u[..., None]
64 |
65 | # Map x to u
66 | shortest_path_ids = x[:, 0]
67 | lap_feat = x[:, 1]
68 | u_shortest_path = torch.zeros_like(u)
69 | u_lap_feat = torch.zeros_like(u)
70 | non_zero = shortest_path_ids.nonzero(as_tuple=False)[:, 0]
71 | nonzero_batch = batch_info['batch'][non_zero]
72 | nonzero_color = batch_info['coloring'][non_zero][:, 0]
73 | for b, c in zip(nonzero_batch, nonzero_color):
74 | u_shortest_path[batch == b, c] = 1
75 |
76 | for i, feat in enumerate(lap_feat):
77 | u_lap_feat[i, batch_info['coloring'][i]] = feat
78 |
79 | u = torch.cat((u, u_shortest_path, u_lap_feat), dim=2)
80 |
81 | # Forward pass
82 | u = self.initial_lin_u(u)
83 | hidden_state = None
84 | for i, (conv, bn_u) in enumerate(zip(self.convs, self.batch_norm_u)):
85 | if i > 0:
86 | u = bn_u(u)
87 | u = conv(u, edge_index, batch_info)
88 | extracted = self.extractor(x, u, batch_info)[None, :, :]
89 | hidden_state = self.gru(extracted, hidden_state)[1]
90 |
91 | # Compute the final representation
92 | out = hidden_state[0, :, :]
93 | nodes_out = self.final_node(out)
94 | after_set2set = self.set2set(out, batch_info['batch'])
95 | graph_out = self.final_graph(after_set2set)
96 |
97 | return nodes_out, graph_out
98 |
99 | def __repr__(self):
100 | return self.__class__.__name__
101 |
--------------------------------------------------------------------------------
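
Construction sketch for the multi-task SMP above, with placeholder widths; num_input_features=2 matches the two feature columns (sssp indicator and Laplacian feature) read in forward:

    from models.model_multi_task import SMP

    model = SMP(num_input_features=2, nodes_out=3, graph_out=3, num_layers=8,
                num_towers=8, hidden_u=64, out_u=64, hidden_gru=64, layer_type='SMP')
    # forward(data) returns a tuple (node_predictions, graph_predictions).
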
/models/model_zinc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from models.smp_layers import ZincSMPLayer
5 | from models.utils.layers import GraphExtractor, EdgeCounter, BatchNorm
6 | from models.utils.misc import create_batch_info, map_x_to_u
7 |
8 |
9 | class SMPZinc(torch.nn.Module):
10 | def __init__(self, num_input_features: int, num_edge_features: int, num_classes: int, num_layers: int,
11 | hidden: int, residual: bool, use_edge_features: bool, shared_extractor: bool,
12 | hidden_final: int, use_batch_norm: bool, use_x: bool, map_x_to_u: bool,
13 | num_towers: int, simplified: bool):
14 | """ num_input_features: number of node features
15 | num_edge_features: number of edge features
16 | num_classes: output dimension
17 | hidden: number of channels of the local contexts
18 | residual: use a residual connection after each SMP layer
19 | use_edge_features: if False, edge features are simply ignored
20 | shared_extractor: share the extractor among layers to reduce the number of parameters
21 | hidden_final: number of channels after extraction of graph features
22 | use_x: for ablation studies, run an MPNN instead of SMP
23 | map_x_to_u: map the initial node features to the local context. If False, node features are ignored
24 | num_towers: inside each SMP layer, use towers to reduce the number of parameters
25 | simplified: if True, the feature extractor has fewer layers.
26 | """
27 | super().__init__()
28 | self.map_x_to_u, self.use_x = map_x_to_u, use_x
29 | self.use_batch_norm = use_batch_norm
30 | self.edge_counter = EdgeCounter()
31 | self.num_classes = num_classes
32 | self.residual = residual
33 | self.shared_extractor = shared_extractor
34 |
35 | self.no_prop = GraphExtractor(in_features=num_input_features, out_features=hidden_final, use_x=use_x)
36 | self.initial_lin = nn.Linear(num_input_features, hidden)
37 |
38 | self.convs = nn.ModuleList()
39 | self.batch_norm_list = nn.ModuleList()
40 | for i in range(0, num_layers):
41 | self.convs.append(ZincSMPLayer(in_features=hidden, num_towers=num_towers, out_features=hidden,
42 | edge_features=num_edge_features, use_x=use_x,
43 | use_edge_features=use_edge_features))
44 | self.batch_norm_list.append(BatchNorm(hidden, use_x) if i > 0 else None)
45 |
46 | # Feature extractors
47 | if shared_extractor:
48 | self.feature_extractor = GraphExtractor(in_features=hidden, out_features=hidden_final, use_x=use_x,
49 | simplified=simplified)
50 | else:
51 | self.feature_extractors = torch.nn.ModuleList([])
52 | for i in range(0, num_layers):
53 | self.feature_extractors.append(GraphExtractor(in_features=hidden, out_features=hidden_final,
54 | use_x=use_x, simplified=simplified))
55 |
56 | # Last layers
57 | self.after_conv = nn.Linear(hidden_final, hidden_final)
58 | self.final_lin = nn.Linear(hidden_final, num_classes)
59 |
60 | def forward(self, data):
61 | """ data.x: (num_nodes, num_node_features)"""
62 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
63 |
64 | # Compute information about the batch
65 | batch_info = create_batch_info(data, self.edge_counter)
66 |
67 | # Create the context matrix
68 | if self.use_x:
69 | assert x is not None
70 | u = x
71 | elif self.map_x_to_u:
72 | u = map_x_to_u(data, batch_info)
73 | else:
74 | u = data.x.new_zeros((data.num_nodes, batch_info['n_colors']))
75 | u.scatter_(1, data.coloring, 1)
76 | u = u[..., None]
77 |
78 | # Forward pass
79 | out = self.no_prop(u, batch_info)
80 | u = self.initial_lin(u)
81 | for i in range(len(self.convs)):
82 | conv = self.convs[i]
83 | bn = self.batch_norm_list[i]
84 | extractor = self.feature_extractor if self.shared_extractor else self.feature_extractors[i]
85 | if self.use_batch_norm and i > 0:
86 | u = bn(u)
87 | u = conv(u, edge_index, edge_attr, batch_info) + (u if self.residual else 0)
88 | global_features = extractor.forward(u, batch_info)
89 | out += global_features / len(self.convs)
90 |
91 | out = self.final_lin(torch.relu(self.after_conv(out)) + out)
92 | assert out.shape[1] == 1
93 | return out[:, 0]
94 |
95 | def __repr__(self):
96 | return self.__class__.__name__
97 |
--------------------------------------------------------------------------------
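
Construction sketch for the ZINC model above; all dimensions are placeholders and should be taken from config_zinc.yaml in practice:

    from models.model_zinc import SMPZinc

    model = SMPZinc(num_input_features=28, num_edge_features=3, num_classes=1,
                    num_layers=8, hidden=32, residual=False, use_edge_features=True,
                    shared_extractor=True, hidden_final=32, use_batch_norm=True,
                    use_x=False, map_x_to_u=True, num_towers=8, simplified=False)
    # forward(data) returns one regression value per graph (out[:, 0]).
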
/models/ppgn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch import Tensor
5 |
6 |
7 | class InvariantMaxLayer(nn.Module):
8 | def forward(self, x: Tensor):
9 | """ x: (batch_size, n_nodes, n_nodes, channels)"""
10 | bs, n, channels = x.shape[0], x.shape[1], x.shape[3]
11 | diag = torch.diagonal(x, dim1=1, dim2=2).contiguous() # batch, Channels, n_nodes
12 | # max_diag = diag.max(dim=2)[0] # Batch, channels
13 | max_diag = diag.sum(dim=2)
14 | mask = ~ torch.eye(n=x.shape[1], dtype=torch.bool, device=x.device)[None, :, :, None].expand(x.shape)
15 | x_off_diag = x[mask].reshape(bs, n, n - 1, channels)
16 | # max_off_diag = x_off_diag.max(dim=1)[0].max(dim=1)[0]
17 | max_off_diag = x_off_diag.sum(dim=1).sum(dim=1)
18 | out = torch.cat((max_diag, max_off_diag), dim=1)
19 | return out
20 |
21 |
22 | class UnitMLP(nn.Module):
23 | def __init__(self, in_feat: int, out_feat: int, num_layers):
24 | super().__init__()
25 | self.layers = nn.ModuleList()
26 | self.layers.append(nn.Conv2d(in_feat, out_feat, (1, 1)))
27 | for i in range(1, num_layers):
28 | self.layers.append(nn.Conv2d(out_feat, out_feat, (1, 1)))
29 |
30 | def forward(self, x: Tensor):
31 | """ x: batch x N x N x channels"""
32 | # Convert for conv2d
33 | x = x.permute(0, 3, 1, 2).contiguous() # batch_size, channels, N, N
34 | for layer in self.layers[:-1]:
35 | x = F.relu(layer.forward(x))
36 | x = self.layers[-1].forward(x)
37 | x = x.permute(0, 2, 3, 1) # batch_size, N, N, channels
38 | return x
39 |
40 |
41 | class PowerfulLayer(nn.Module):
42 | def __init__(self, in_feat: int, out_feat: int, num_layers: int):
43 | super().__init__()
44 | a = in_feat
45 | b = out_feat
46 | self.m1 = UnitMLP(a, b, num_layers)
47 | self.m2 = UnitMLP(a, b, num_layers)
48 | self.m4 = nn.Linear(a + b, b, bias=True)
49 |
50 | def forward(self, x):
51 | """ x: batch x N x N x in_feat"""
52 | out1 = self.m1.forward(x).permute(0, 3, 1, 2) # batch, out_feat, N, N
53 | out2 = self.m2.forward(x).permute(0, 3, 1, 2) # batch, out_feat, N, N
54 | out3 = x
55 | mult = out1 @ out2 # batch, out_feat, N, N
56 | out = torch.cat((mult.permute(0, 2, 3, 1), out3), dim=3) # batch, N, N, out_feat
57 | suffix = self.m4.forward(out)
58 | return suffix
59 |
60 |
61 | class FeatureExtractor(nn.Module):
62 | def __init__(self, in_features: int, out_features: int):
63 | super().__init__()
64 | self.lin1 = nn.Linear(in_features, out_features, bias=True)
65 | self.lin2 = nn.Linear(in_features, out_features, bias=False)
66 | self.lin3 = torch.nn.Linear(out_features, out_features, bias=False)
67 |
68 | def forward(self, u):
69 | """ u: (batch_size, num_nodes, num_nodes, in_features)
70 | output: (batch_size, out_features). """
71 | n = u.shape[1]
72 | diag = u.diagonal(dim1=1, dim2=2) # batch_size, channels, num_nodes
73 | trace = torch.sum(diag, dim=2)
74 | out1 = self.lin1.forward(trace / n)
75 |
76 | s = (torch.sum(u, dim=[1, 2]) - trace) / (n * (n-1))
77 | out2 = self.lin2.forward(s) # bs, out_feat
78 | out = out1 + out2
79 | out = out + self.lin3.forward(F.relu(out))
80 | return out
81 |
82 |
83 | class Powerful(nn.Module):
84 | def __init__(self, num_classes: int, num_layers: int, hidden: int, hidden_final: int, dropout_prob: float,
85 | simplified: bool):
86 | super().__init__()
87 | layers_per_conv = 1
88 | self.layer_after_conv = not simplified
89 | self.dropout_prob = dropout_prob
90 | self.no_prop = FeatureExtractor(1, hidden_final)
91 | initial_conv = PowerfulLayer(1, hidden, layers_per_conv)
92 | self.convs = nn.ModuleList([initial_conv])
93 | self.bns = nn.ModuleList([])
94 | for i in range(1, num_layers):
95 | self.convs.append(PowerfulLayer(hidden, hidden, layers_per_conv))
96 |
97 | self.feature_extractors = torch.nn.ModuleList([])
98 | for i in range(num_layers):
99 | self.bns.append(nn.BatchNorm2d(hidden))
100 | self.feature_extractors.append(FeatureExtractor(hidden, hidden_final))
101 | if self.layer_after_conv:
102 | self.after_conv = nn.Linear(hidden_final, hidden_final)
103 | self.final_lin = nn.Linear(hidden_final, num_classes)
104 |
105 | def forward(self, data):
106 | u = data.A[..., None] # batch, N, N, 1
107 | out = self.no_prop.forward(u)
108 | for conv, extractor, bn in zip(self.convs, self.feature_extractors, self.bns):
109 | u = conv(u)
110 | u = bn(u.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
111 | out = out + extractor.forward(u)
112 | out = F.relu(out) / len(self.convs)
113 | if self.layer_after_conv:
114 | out = out + F.relu(self.after_conv(out))
115 | out = F.dropout(out, p=self.dropout_prob, training=self.training)
116 | out = self.final_lin(out)
117 | return F.log_softmax(out, dim=-1)
118 |
--------------------------------------------------------------------------------
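
Unlike the message-passing models, the network above consumes dense adjacency tensors (data.A) rather than edge_index. A self-contained sketch with a random dummy batch (the Batch stand-in and all sizes are hypothetical):

    import torch
    from models.ppgn import Powerful

    model = Powerful(num_classes=2, num_layers=4, hidden=16, hidden_final=16,
                     dropout_prob=0.0, simplified=False)

    class Batch:  # tiny stand-in for a torch_geometric batch carrying only A
        pass

    data = Batch()
    data.A = torch.randint(0, 2, (8, 12, 12)).float()  # 8 random graphs with 12 nodes each
    out = model(data)                                   # (8, num_classes) log-probabilities
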
/models/ring_gnn.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/leichen2018/Ring-GNN/blob/master/src/model.py
2 |
3 | import torch
4 | import torch as th
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import numpy as np
8 |
9 |
10 | class FeatureExtractor(nn.Module):
11 | def __init__(self, in_features: int, out_features: int):
12 | super().__init__()
13 | self.lin1 = nn.Linear(in_features, out_features, bias=True)
14 | self.lin2 = nn.Linear(in_features, out_features, bias=False)
15 | self.lin3 = torch.nn.Linear(out_features, out_features, bias=False)
16 |
17 | def forward(self, u):
18 | """ u: (batch_size, num_nodes, num_nodes, in_features)
19 | output: (batch_size, out_features). """
20 | n = u.shape[1]
21 | diag = u.diagonal(dim1=1, dim2=2) # batch_size, channels, num_nodes
22 | trace = torch.sum(diag, dim=2)
23 | out1 = self.lin1.forward(trace / n)
24 |
25 | s = (torch.sum(u, dim=[1, 2]) - trace) / (n * (n-1))
26 | out2 = self.lin2.forward(s) # bs, out_feat
27 | out = out1 + out2
28 | out = out + self.lin3.forward(F.relu(out))
29 | return out
30 |
31 |
32 | class RingGNN(nn.Module):
33 | def __init__(self, num_classes: int, num_layers: int, hidden: int, hidden_final: int, dropout_prob: float,
34 | simplified: bool):
35 | super().__init__()
36 | self.layer_after_conv = not simplified
37 | self.dropout_prob = dropout_prob
38 | self.no_prop = FeatureExtractor(1, hidden_final)
39 | initial_conv = equi_2_to_2(1, hidden)
40 | self.convs = nn.ModuleList([initial_conv])
41 | self.bns = nn.ModuleList([])
42 | for i in range(1, num_layers):
43 | self.convs.append(equi_2_to_2(hidden, hidden))
44 |
45 | self.feature_extractors = torch.nn.ModuleList([])
46 | for i in range(num_layers):
47 | self.bns.append(nn.BatchNorm2d(hidden))
48 | self.feature_extractors.append(FeatureExtractor(hidden, hidden_final))
49 | if self.layer_after_conv:
50 | self.after_conv = nn.Linear(hidden_final, hidden_final)
51 | self.final_lin = nn.Linear(hidden_final, num_classes)
52 |
53 | def forward(self, data):
54 | u = data.A[..., None] # batch, N, N, 1
55 | out = self.no_prop.forward(u)
56 | for conv, extractor, bn in zip(self.convs, self.feature_extractors, self.bns):
57 | u = conv(u)
58 | u = bn(u.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
59 | out = out + extractor.forward(u)
60 | out = F.relu(out) / len(self.convs)
61 | if self.layer_after_conv:
62 | out = out + F.relu(self.after_conv(out))
63 | out = F.dropout(out, p=self.dropout_prob, training=self.training)
64 | out = self.final_lin(out)
65 | return F.log_softmax(out, dim=-1)
66 |
67 |
68 |
69 | class MLP(nn.Module):
70 | def __init__(self, feats):
71 | super(MLP, self).__init__()
72 | self.linears = nn.ModuleList([nn.Linear(m, n) for m, n in zip(feats[:-1], feats[1:])])
73 |
74 | def forward(self, x):
75 | for layer in self.linears[:-1]:
76 | x = layer(x)
77 | x = F.relu(x)
78 |
79 | return self.linears[-1](x)
80 |
81 |
82 | class equi_2_to_2(nn.Module):
83 | def __init__(self, input_depth, output_depth, normalization='inf', normalization_val=1.0, radius=2, k2_init=0.1):
84 | super(equi_2_to_2, self).__init__()
85 | basis_dimension = 15
86 | self.radius = radius
87 | # coeffs_values = lambda i, j, k: th.randn([i, j, k]) * th.sqrt(2. / (i + j).float())
88 | coeffs_values = lambda i, j, k: th.randn([i, j, k]) * np.sqrt(2. / float((i + j)))
89 | self.diag_bias_list = nn.ParameterList([])
90 |
91 | for i in range(radius):
92 | for j in range(i + 1):
93 | self.diag_bias_list.append(nn.Parameter(th.zeros(1, output_depth, 1, 1)))
94 |
95 | self.all_bias = nn.Parameter(th.zeros(1, output_depth, 1, 1))
96 | self.coeffs_list = nn.ParameterList([])
97 |
98 | for i in range(radius):
99 | for j in range(i + 1):
100 | self.coeffs_list.append(nn.Parameter(coeffs_values(input_depth, output_depth, basis_dimension)))
101 |
102 | self.switch = nn.ParameterList([nn.Parameter(th.FloatTensor([1])), nn.Parameter(th.FloatTensor([k2_init]))])
103 | self.output_depth = output_depth
104 |
105 | self.normalization = normalization
106 | self.normalization_val = normalization_val
107 |
108 | def forward(self, inputs):
109 | inputs = inputs.permute(0, 3, 1, 2) # Convert to N x D x m x m
110 | m = inputs.size()[3]
111 | ops_out = ops_2_to_2(inputs, m, normalization=self.normalization)
112 | ops_out = th.stack(ops_out, dim=2)
113 | output_list = []
114 |
115 | for i in range(self.radius):
116 | for j in range(i + 1):
117 | output_i = th.einsum('dsb,ndbij->nsij', self.coeffs_list[i * (i + 1) // 2 + j], ops_out)
118 | mat_diag_bias = th.eye(inputs.size()[3]).to(inputs.device).unsqueeze(0).unsqueeze(0) * self.diag_bias_list[
119 | i * (i + 1) // 2 + j]
120 | if j == 0:
121 | output = output_i + mat_diag_bias
122 | else:
123 | output = th.einsum('abcd,abde->abce', output_i, output)
124 |
125 | output_list.append(output)
126 |
127 | output = 0
128 | for i in range(self.radius):
129 | output += output_list[i] * self.switch[i]
130 |
131 | output = output + self.all_bias
132 | output = output.permute(0, 2, 3, 1)
133 | return output
134 |
135 |
136 | def diag_offdiag_maxpool(input):
137 | max_diag = th.max(th.diagonal(input, dim1=2, dim2=3), dim=2)[0]
138 |
139 | max_val = th.max(max_diag)
140 |
141 | min_val = th.max(input * (-1.))
142 | val = th.abs(max_val + min_val)
143 | min_mat = th.diag_embed(th.diagonal(input[0][0]) * 0 + val).unsqueeze(0).unsqueeze(0)
144 | max_offdiag = th.max(th.max(input - min_mat, dim=2)[0], dim=2)[0]
145 |
146 | return th.cat([max_diag, max_offdiag], dim=1)
147 |
148 |
149 | def ops_2_to_2(inputs, dim, normalization='inf', normalization_val=1.0): # N x D x m x m
150 | # input: N x D x m x m
151 | diag_part = th.diagonal(inputs, dim1=2, dim2=3) # N x D x m
152 | sum_diag_part = th.sum(diag_part, dim=2, keepdim=True) # N x D x 1
153 | sum_of_rows = th.sum(inputs, dim=3) # N x D x m
154 | sum_of_cols = th.sum(inputs, dim=2) # N x D x m
155 | sum_all = th.sum(sum_of_rows, dim=2) # N x D
156 |
157 | # op1 - (1234) - extract diag
158 | op1 = th.diag_embed(diag_part) # N x D x m x m
159 |
160 | # op2 - (1234) + (12)(34) - place sum of diag on diag
161 | op2 = th.diag_embed(sum_diag_part.repeat(1, 1, dim))
162 |
163 | # op3 - (1234) + (123)(4) - place sum of row i on diag ii
164 | op3 = th.diag_embed(sum_of_rows)
165 |
166 | # op4 - (1234) + (124)(3) - place sum of col i on diag ii
167 | op4 = th.diag_embed(sum_of_cols)
168 |
169 | # op5 - (1234) + (124)(3) + (123)(4) + (12)(34) + (12)(3)(4) - place sum of all entries on diag
170 | op5 = th.diag_embed(sum_all.unsqueeze(2).repeat(1, 1, dim))
171 |
172 | # op6 - (14)(23) + (13)(24) + (24)(1)(3) + (124)(3) + (1234) - place sum of col i on row i
173 | op6 = sum_of_cols.unsqueeze(3).repeat(1, 1, 1, dim)
174 |
175 | # op7 - (14)(23) + (23)(1)(4) + (234)(1) + (123)(4) + (1234) - place sum of row i on row i
176 | op7 = sum_of_rows.unsqueeze(3).repeat(1, 1, 1, dim)
177 |
178 | # op8 - (14)(2)(3) + (134)(2) + (14)(23) + (124)(3) + (1234) - place sum of col i on col i
179 | op8 = sum_of_cols.unsqueeze(2).repeat(1, 1, dim, 1)
180 |
181 | # op9 - (13)(24) + (13)(2)(4) + (134)(2) + (123)(4) + (1234) - place sum of row i on col i
182 | op9 = sum_of_rows.unsqueeze(2).repeat(1, 1, dim, 1)
183 |
184 | # op10 - (1234) + (14)(23) - identity
185 | op10 = inputs
186 |
187 | # op11 - (1234) + (13)(24) - transpose
188 | op11 = th.transpose(inputs, -2, -1)
189 |
190 | # op12 - (1234) + (234)(1) - place ii element in row i
191 | op12 = diag_part.unsqueeze(3).repeat(1, 1, 1, dim)
192 |
193 | # op13 - (1234) + (134)(2) - place ii element in col i
194 | op13 = diag_part.unsqueeze(2).repeat(1, 1, dim, 1)
195 |
196 | # op14 - (34)(1)(2) + (234)(1) + (134)(2) + (1234) + (12)(34) - place sum of diag in all entries
197 | op14 = sum_diag_part.unsqueeze(3).repeat(1, 1, dim, dim)
198 |
199 | # op15 - sum of all ops - place sum of all entries in all entries
200 | op15 = sum_all.unsqueeze(2).unsqueeze(3).repeat(1, 1, dim, dim)
201 |
202 | # A_2 = th.einsum('abcd,abde->abce', inputs, inputs)
203 | # A_4 = th.einsum('abcd,abde->abce', A_2, A_2)
204 | # op16 = th.where(A_4>1, th.ones(A_4.size()), A_4)
205 |
206 | if normalization is not None:
207 | float_dim = float(dim)
208 | if normalization == 'inf':
209 | op2 = th.div(op2, float_dim)
210 | op3 = th.div(op3, float_dim)
211 | op4 = th.div(op4, float_dim)
212 | op5 = th.div(op5, float_dim ** 2)
213 | op6 = th.div(op6, float_dim)
214 | op7 = th.div(op7, float_dim)
215 | op8 = th.div(op8, float_dim)
216 | op9 = th.div(op9, float_dim)
217 | op14 = th.div(op14, float_dim)
218 | op15 = th.div(op15, float_dim ** 2)
219 |
220 | # return [op1, op2, op3, op4, op5, op6, op7, op8, op9, op10, op11, op12, op13, op14, op15, op16]
221 | '''
222 | l = [op1, op2, op3, op4, op5, op6, op7, op8, op9, op10, op11, op12, op13, op14, op15]
223 | for i, ls in enumerate(l):
224 | print(i+1)
225 | print(th.sum(ls))
226 | print("$%^&*(*&^%$#$%^&*(*&^%$%^&*(*&^%$%^&*(")
227 | '''
228 | return [op1, op2, op3, op4, op5, op6, op7, op8, op9, op10, op11, op12, op13, op14, op15]
229 |
--------------------------------------------------------------------------------
/models/smp_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import Tensor
4 | from torch_geometric.nn import MessagePassing
5 | from models.utils.layers import XtoX, UtoU, EntrywiseU, EntryWiseX
6 |
7 |
8 | class SimplifiedFastSMPLayer(MessagePassing):
9 | def __init__(self, in_features: int, num_towers: int, out_features: int, use_x: bool):
10 | super().__init__(aggr='add', node_dim=-3)
11 | self.use_x = use_x
12 | self.message_nn = (XtoX if use_x else UtoU)(in_features, out_features, bias=True)
13 | if self.use_x:
14 | self.alpha = nn.Parameter(torch.zeros(1, out_features), requires_grad=True)
15 | else:
16 | self.alpha = nn.Parameter(torch.zeros(1, 1, out_features), requires_grad=True)
17 |
18 | def reset_parameters(self):
19 | self.message_nn.reset_parameters()
20 | self.alpha.requires_grad_(False)
21 | self.alpha[...] = 0
22 | self.alpha.requires_grad_(True)
23 |
24 | def forward(self, u, edge_index, batch_info):
25 | """ x corresponds either to node features or to the local context, depending on use_x."""
26 | n = batch_info['num_nodes']
27 | if self.use_x and u.dim() == 1:
28 | u = u.unsqueeze(-1)
29 | u = self.message_nn(u, batch_info)
30 | new_u = self.propagate(edge_index, size=(n, n), u=u)
31 | # Normalization
32 | if len(new_u.shape) == 2:
33 | # node features are used
34 | new_u /= batch_info['average_edges'][:, :, 0]
35 | else:
36 | # local contexts are used
37 | new_u /= batch_info['average_edges']
38 | return new_u
39 |
40 | def message(self, u_j: Tensor):
41 | return u_j
42 |
43 | def update(self, aggr_u, u):
44 | return aggr_u + u + self.alpha * u * aggr_u
45 |
46 |
47 | class FastSMPLayer(MessagePassing):
48 | def __init__(self, in_features: int, num_towers: int, out_features: int, use_x: bool):
49 | super().__init__(aggr='add', node_dim=-2 if use_x else -3)
50 | self.use_x = use_x
51 | self.in_u, self.out_u = in_features, out_features
52 | if use_x:
53 | self.message_nn = XtoX(in_features, out_features, bias=True)
54 | self.linu_i = EntryWiseX(out_features, out_features, num_towers=out_features)
55 | self.linu_j = EntryWiseX(out_features, out_features, num_towers=out_features)
56 | else:
57 | self.message_nn = UtoU(in_features, out_features, n_groups=num_towers, residual=False)
58 | self.linu_i = EntrywiseU(out_features, out_features, num_towers=out_features)
59 | self.linu_j = EntrywiseU(out_features, out_features, num_towers=out_features)
60 |
61 | def forward(self, u, edge_index, batch_info):
62 | n = batch_info['num_nodes']
63 | u = self.message_nn(u, batch_info)
64 | new_u = self.propagate(edge_index, size=(n, n), u=u)
65 | new_u /= batch_info['average_edges']
66 | return new_u
67 |
68 | def message(self, u_j):
69 | return u_j
70 |
71 | def update(self, aggr_u, u):
72 | a_i = self.linu_i(u)
73 | a_j = self.linu_j(aggr_u)
74 | return aggr_u + u + a_i * a_j
75 |
76 |
77 | class SMPLayer(MessagePassing):
78 | def __init__(self, in_features: int, num_towers: int, out_features: int, use_x: bool):
79 | super().__init__(aggr='add', node_dim=-3)
80 | self.use_x = use_x
81 | self.in_u, self.out_u = in_features, out_features
82 | if use_x:
83 | self.message_nn = XtoX(in_features, out_features, bias=True)
84 | self.order2_i = EntryWiseX(out_features, out_features, num_towers)
85 | self.order2_j = EntryWiseX(out_features, out_features, num_towers)
86 | self.order2 = EntryWiseX(out_features, out_features, num_towers)
87 | else:
88 | self.message_nn = UtoU(in_features, out_features, n_groups=num_towers, residual=False)
89 | self.order2_i = EntrywiseU(out_features, out_features, num_towers)
90 | self.order2_j = EntrywiseU(out_features, out_features, num_towers)
91 | self.order2 = EntrywiseU(out_features, out_features, num_towers)
92 | self.update1 = nn.Linear(2 * out_features, out_features)
93 | self.update2 = nn.Linear(out_features, out_features)
94 |
95 | def forward(self, u, edge_index, batch_info):
96 | n = batch_info['num_nodes']
97 | u = self.message_nn(u, batch_info)
98 | u1 = self.order2_i(u)
99 | u2 = self.order2_j(u)
100 | new_u = self.propagate(edge_index, size=(n, n), u=u, u1=u1, u2=u2)
101 | new_u /= batch_info['average_edges']
102 | return new_u
103 |
104 | def message(self, u_j, u1_i, u2_j):
105 | order2 = self.order2(torch.relu(u1_i + u2_j))
106 | return order2
107 |
108 | def update(self, aggr_u, u):
109 | up1 = self.update1(torch.cat((u, aggr_u), dim=-1))
110 | up2 = up1 + self.update2(up1)
111 | return up2
112 |
113 |
114 | class ZincSMPLayer(MessagePassing):
115 | def __init__(self, in_features: int, num_towers: int, out_features: int, edge_features: int, use_x: bool,
116 | use_edge_features: bool):
117 | """ Use a MLP both for the update and message function + edge features. """
118 | super().__init__(aggr='add', node_dim=-2 if use_x else -3)
119 | self.use_x, self.use_edge_features = use_x, use_edge_features
120 | self.in_u, self.out_u, self.edge_features = in_features, out_features, edge_features
121 | self.edge_nn = nn.Linear(edge_features, out_features) if use_edge_features else None
122 |
123 | self.message_nn = (EntryWiseX if use_x else UtoU)(in_features, out_features,
124 | n_groups=num_towers, residual=False)
125 |
126 | args_order2 = [out_features, out_features, num_towers]
127 | entry_wise = EntryWiseX if use_x else EntrywiseU
128 | self.order2_i = entry_wise(*args_order2)
129 | self.order2_j = entry_wise(*args_order2)
130 | self.order2 = entry_wise(*args_order2)
131 |
132 | self.update1 = nn.Linear(2 * out_features, out_features)
133 | self.update2 = nn.Linear(out_features, out_features)
134 |
135 | def forward(self, u, edge_index, edge_attr, batch_info):
136 | n = batch_info['num_nodes']
137 | u = self.message_nn(u, batch_info)
138 | u1 = self.order2_i(u)
139 | u2 = self.order2_j(u)
140 | new_u = self.propagate(edge_index, size=(n, n), u=u, u1=u1, u2=u2, edge_attr=edge_attr)
141 | new_u /= batch_info['average_edges'][:, :, 0] if self.use_x else batch_info['average_edges']
142 | return new_u
143 |
144 | def message(self, u_j, u1_i, u2_j, edge_attr):
145 | edge_feat = self.edge_nn(edge_attr) if self.use_edge_features else 0
146 | if not self.use_x:
147 | edge_feat = edge_feat.unsqueeze(1)
148 | order2 = self.order2(torch.relu(u1_i + u2_j + edge_feat))
149 | u_j = u_j + order2
150 | return u_j
151 |
152 | def update(self, aggr_u, u):
153 | up1 = self.update1(torch.cat((u, aggr_u), dim=-1))
154 | up2 = up1 + self.update2(up1)
155 | return up2 + u
156 |
--------------------------------------------------------------------------------
/models/utils/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import Tensor as Tensor
4 | from torch.nn import Linear as Linear
5 | import torch.nn.init as init
6 | from torch.nn.init import _calculate_correct_fan, calculate_gain
7 | import torch.nn.functional as F
8 | from torch_geometric.nn import global_mean_pool, global_add_pool, global_max_pool, MessagePassing
9 | import math
10 |
11 | small_gain = 0.01
12 |
13 |
14 | def pooling(x: torch.Tensor, batch_info, method):
15 | if method == 'add':
16 | return global_add_pool(x, batch_info['batch'], batch_info['num_graphs'])
17 | elif method == 'mean':
18 | return global_mean_pool(x, batch_info['batch'], batch_info['num_graphs'])
19 | elif method == 'max':
20 | return global_max_pool(x, batch_info['batch'], batch_info['num_graphs'])
21 | else:
22 | raise ValueError("Pooling method not implemented")
23 |
24 |
25 | def kaiming_init_with_gain(x: Tensor, gain: float, a=0, mode='fan_in', nonlinearity='relu'):
26 | fan = _calculate_correct_fan(x, mode)
27 | non_linearity_gain = calculate_gain(nonlinearity, a)
28 | std = non_linearity_gain / math.sqrt(fan)
29 | bound = math.sqrt(3.0) * std * gain # Calculate uniform bounds from standard deviation
30 | with torch.no_grad():
31 | return x.uniform_(-bound, bound)
32 |
33 |
34 | class BatchNorm(nn.Module):
35 | def __init__(self, channels: int, use_x: bool):
36 | super().__init__()
37 | self.bn = nn.BatchNorm1d(channels)
38 | self.use_x = use_x
39 |
40 | def reset_parameters(self):
41 | self.bn.reset_parameters()
42 |
43 | def forward(self, u):
44 | if self.use_x:
45 | return self.bn(u)
46 | else:
47 | return self.bn(u.transpose(1, 2)).transpose(1, 2)
48 |
49 |
50 | class EdgeCounter(MessagePassing):
51 | def __init__(self):
52 | super().__init__(aggr='add')
53 |
54 | def forward(self, x, edge_index, batch, batch_size):
55 | n_edges = self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)
56 | return global_mean_pool(n_edges, batch, batch_size)[batch]
57 |
58 |
59 | class Linear(nn.Module):
60 | """ Linear layer with potentially smaller parameters at initialization. """
61 | __constants__ = ['bias', 'in_features', 'out_features']
62 |
63 | def __init__(self, in_features, out_features, bias=True, gain: float = 1.0):
64 | super().__init__()
65 | self.gain = gain
66 | self.lin = nn.Linear(in_features, out_features, bias)
67 |
68 | def reset_parameters(self):
69 | kaiming_init_with_gain(self.lin.weight, self.gain)
70 | if self.lin.bias is not None:
71 | nn.init.normal_(self.lin.bias, 0, self.gain / math.sqrt(self.lin.out_features))
72 |
73 | def forward(self, x):
74 | return self.lin.forward(x)
75 |
76 |
77 | class XtoX(Linear):
78 | def forward(self, x, batch_info: dict = None):
79 | return self.lin.forward(x)
80 |
81 |
82 | class XtoGlobal(Linear):
83 | def forward(self, x: Tensor, batch_info: dict, method='mean'):
84 | """ x: (num_nodes, in_features). """
85 | g = pooling(x, batch_info, method) # bs, N, in_feat or bs, in_feat
86 | return self.lin.forward(g)
87 |
88 |
89 | class EntrywiseU(nn.Module):
90 | def __init__(self, in_features: int, out_features: int, num_towers=None):
91 | super().__init__()
92 | if num_towers is None:
93 | num_towers = in_features
94 | self.lin1 = torch.nn.Conv1d(in_features, out_features, kernel_size=1, groups=num_towers, bias=False)
95 |
96 | def forward(self, u):
97 | """ u: N x colors x channels. """
98 | u = u.transpose(1, 2)
99 | u = self.lin1(u)
100 | return u.transpose(1, 2)
101 |
102 |
103 | class EntryWiseX(nn.Module):
104 | def __init__(self, in_features: int, out_features: int, n_groups=None, residual=False):
105 | super().__init__()
106 | self.residual = residual
107 | if n_groups is None:
108 | n_groups = in_features
109 | self.lin1 = torch.nn.Conv1d(in_features, out_features, kernel_size=1, groups=n_groups, bias=False)
110 |
111 | def forward(self, x, batch_info=None):
112 | """ x: N x channels. """
113 | new_x = self.lin1(x.unsqueeze(-1)).squeeze()
114 | return (new_x + x) if self.residual else new_x
115 |
116 | class UtoU(nn.Module):
117 | def __init__(self, in_features: int, out_features: int, residual=True, n_groups=None):
118 | super().__init__()
119 | if n_groups is None:
120 | n_groups = 1
121 | self.residual = residual
122 | self.lin1 = torch.nn.Conv1d(in_features, out_features, kernel_size=1, groups=n_groups, bias=True)
123 | self.lin2 = torch.nn.Conv1d(in_features, out_features, kernel_size=1, groups=n_groups, bias=False)
124 | self.lin3 = torch.nn.Conv1d(in_features, out_features, kernel_size=1, groups=n_groups, bias=False)
125 |
126 | def forward(self, u: Tensor, batch_info: dict = None):
127 | """ U: N x n_colors x channels"""
128 | old_u = u
129 | n = batch_info['num_nodes']
130 | num_colors = u.shape[1]
131 | out_feat = self.lin1.out_channels
132 |
133 | mask = batch_info['mask'][..., None].expand(n, num_colors, out_feat)
134 | normalizer = batch_info['n_batch']
135 | mean2 = torch.sum(u / normalizer, dim=1) # N, in_feat
136 | mean2 = mean2.unsqueeze(-1) # N, in_feat, 1
137 | # 1. Transform u element-wise
138 | u = u.permute(0, 2, 1) # In conv1d, channel dimension is second
139 | out = self.lin1(u).permute(0, 2, 1)
140 |
141 | # 2. Put in self of each line the sum over each line
142 | # The 0.1 factor is here to bias the network in favor of learning powers of the adjacency
143 | z2 = self.lin2(mean2) * 0.1 # N, out_feat, 1
144 | z2 = z2.transpose(1, 2) # N, 1, out_feat
145 | index_tensor = batch_info['coloring'][:, :, None].expand(out.shape[0], 1, out_feat)
146 | out.scatter_add_(1, index_tensor, z2) # n, n_colors, out_feat
147 |
148 | # 3. Put everywhere the sum over each line
149 | z3 = self.lin3(mean2) # N, out_feat, 1
150 | z3 = z3.transpose(1, 2) # N, 1, out_feat
151 | out3 = z3.expand(n, num_colors, out_feat)
152 | out += out3 * mask * 0.1 # Mask the extra colors
153 | if self.residual:
154 | return old_u + out
155 | return out
156 |
157 |
158 | class UtoGlobal(nn.Module):
159 | def __init__(self, in_features: int, out_features: int, bias: bool, gain: float):
160 | super().__init__()
161 | self.lin1 = Linear(in_features, out_features, bias, gain=gain)
162 | self.lin2 = Linear(in_features, out_features, bias, gain=gain)
163 |
164 | def reset_parameters(self):
165 | for layer in [self.lin1, self.lin2]:
166 | layer.reset_parameters()
167 |
168 | def forward(self, u, batch_info: dict, method='mean'):
169 | """ u: (num_nodes, colors, in_features)
170 | output: (batch_size, out_features). """
171 | coloring = batch_info['coloring']
172 | # Extract trace
173 | index_tensor = coloring[:, :, None].expand(u.shape[0], 1, u.shape[2])
174 | extended_diag = u.gather(1, index_tensor)[:, 0, :] # n_nodes, in_feat
175 | mean_batch_trace = pooling(extended_diag, batch_info, 'mean') # n_graphs, in_feat
176 | out1 = self.lin1(mean_batch_trace) # bs, out_feat
177 | # Extract sum of elements - trace
178 | mean = torch.sum(u / batch_info['n_batch'], dim=1) # num_nodes, in_feat
179 | batch_sum = pooling(mean, batch_info, 'mean') # n_graphs, in_feat
180 | batch_sum = batch_sum - mean_batch_trace # make the basis orthogonal
181 | out2 = self.lin2(batch_sum) # bs, out_feat
182 | return out1 + out2
183 |
184 |
185 | class NodeExtractor(nn.Module):
186 | def __init__(self, in_features_u: int, out_features_u: int):
187 | super().__init__()
188 | # Extract from U with a Deep set
189 | self.lin1_u = nn.Linear(in_features_u, in_features_u)
190 | self.lin2_u = nn.Linear(in_features_u, in_features_u)
191 | self.combine1 = nn.Linear(3 * in_features_u, out_features_u)
192 |
193 | def forward(self, x: Tensor, u: Tensor, batch_info: dict):
194 | """ u: (num_nodes, num_nodes, in_features).
195 | output: (num_nodes, out_feat).
196 | this method can probably be made more efficient.
197 | """
198 | # Extract u
199 | new_u = self.lin2_u(torch.relu(self.lin1_u(u)))
200 | # Aggregation
201 | # a. Extract the value in self
202 | index_tensor = batch_info['coloring'][:, :, None].expand(u.shape[0], 1, u.shape[-1])
203 | x1 = torch.gather(new_u, 1, index_tensor)
204 | x1 = x1[:, 0, :]
205 | # b. Mean over the line
206 | x2 = torch.sum(new_u / batch_info['n_batch'], dim=1) # num_nodes x in_feat
207 | # c. Max over the line
208 | x3 = torch.max(new_u, dim=1)[0] # num_nodes x in_feat
209 | # Combine
210 | x_full = torch.cat((x1, x2, x3), dim=1)
211 | out = self.combine1(x_full)
212 | return out
213 |
214 |
215 | class GraphExtractor(nn.Module):
216 | def __init__(self, in_features: int, out_features: int, use_x: bool, simplified=False):
217 | super().__init__()
218 | self.use_x, self.simplified = use_x, simplified
219 | self.extractor = (XtoGlobal if self.use_x else UtoGlobal)(in_features, out_features, True, 1)
220 | self.lin = nn.Linear(out_features, out_features)
221 |
222 | def reset_parameters(self):
223 | for layer in [self.extractor, self.lin]:
224 | layer.reset_parameters()
225 |
226 | def forward(self, u: Tensor, batch_info: dict):
227 | out = self.extractor(u, batch_info)
228 | if self.simplified:
229 | return out
230 | out = out + self.lin(F.relu(out))
231 | return out
232 |
--------------------------------------------------------------------------------
/models/utils/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def create_batch_info(data, edge_counter):
5 | """ Compute some information about the batch that will be used by SMP."""
6 | x, edge_index, batch, batch_size = data.x, data.edge_index, data.batch, data.num_graphs
7 |
8 | # Compute some information about the batch
9 | # Count the number of nodes in each graph
10 | unique, n_per_graph = torch.unique(data.batch, return_counts=True)
11 | n_batch = torch.zeros_like(batch, dtype=torch.float)
12 |
13 | for value, n in zip(unique, n_per_graph):
14 | n_batch[batch == value] = n.float()
15 |
16 | # Count the average number of edges per graph
17 | dummy = x.new_ones((data.num_nodes, 1))
18 | average_edges = edge_counter(dummy, edge_index, batch, batch_size)
19 |
20 | # Create the coloring if it does not exist yet
21 | if not hasattr(data, 'coloring'):
22 | data.coloring = data.x.new_zeros(data.num_nodes, dtype=torch.long)
23 | for i in range(data.num_graphs):
24 | data.coloring[data.batch == i] = torch.arange(n_per_graph[i], device=data.x.device)
25 | data.coloring = data.coloring[:, None]
26 | n_colors = torch.max(data.coloring) + 1 # Indexing starts at 0
27 |
28 | mask = torch.zeros(data.num_nodes, n_colors, dtype=torch.bool, device=x.device)
29 | for value, n in zip(unique, n_per_graph):
30 | mask[batch == value, :n] = True
31 |
32 | # Aggregate into a dict
33 | batch_info = {'num_nodes': data.num_nodes,
34 | 'num_graphs': data.num_graphs,
35 | 'batch': data.batch,
36 | 'n_per_graph': n_per_graph,
37 | 'n_batch': n_batch[:, None, None].float(),
38 | 'average_edges': average_edges[:, :, None],
39 | 'coloring': data.coloring,
40 | 'n_colors': n_colors,
41 | 'mask': mask # Used because of batching - it tells which entries of u are not used by the graph
42 | }
43 | return batch_info
44 |
45 |
46 | def map_x_to_u(data, batch_info):
47 | """ map the node features to the right row of the initial local context."""
48 | x = data.x
49 | u = x.new_zeros((data.num_nodes, batch_info['n_colors']))
50 | u.scatter_(1, data.coloring, 1)
51 | u = u[..., None]
52 |
53 | u_x = u.new_zeros((u.shape[0], u.shape[1], x.shape[1]))
54 |
55 | n_features = x.shape[1]
56 | coloring = batch_info['coloring'] # N x 1
57 | expanded_colors = coloring[..., None].expand(-1, -1, n_features)
58 |
59 | u_x = u_x.scatter_(dim=1, index=expanded_colors, src=x[:, None, :])
60 |
61 | u = torch.cat((u, u_x), dim=2)
62 | return u
63 |
64 |
--------------------------------------------------------------------------------
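
A self-contained toy sketch of the scatter logic in map_x_to_u above, assuming a single 3-node graph whose coloring is simply [0, 1, 2]: each node writes its features into its own color row of the local context.

    import torch

    coloring = torch.tensor([[0], [1], [2]])  # one color per node (toy values)
    x = torch.tensor([[5.0], [6.0], [7.0]])   # one scalar feature per node

    u = x.new_zeros((3, 3))
    u.scatter_(1, coloring, 1)                # one-hot identity channel
    u = u[..., None]                          # (num_nodes, n_colors, 1)

    u_x = u.new_zeros((3, 3, 1))
    u_x.scatter_(1, coloring[..., None].expand(-1, -1, 1), x[:, None, :])
    u = torch.cat((u, u_x), dim=2)            # (num_nodes, n_colors, 1 + num_features)
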
/models/utils/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.utils import to_dense_adj
3 | import networkx as nx
4 | import torch_geometric
5 | from torch_geometric.data import Data
6 | from networkx.algorithms.shortest_paths.unweighted import all_pairs_shortest_path_length
7 | from networkx.algorithms.coloring import greedy_color
8 | import numpy as np
9 | from sklearn.preprocessing import OneHotEncoder
10 |
11 |
12 | class EyeTransform(object):
13 | def __init__(self, max_num_nodes):
14 | self.max_num_nodes = max_num_nodes
15 |
16 | def __call__(self, data):
17 | n = data.x.shape[0]
18 | data.x = torch.eye(n, self.max_num_nodes, dtype=torch.float)
19 | return data
20 |
21 | def __repr__(self):
22 | return str(self.__class__.__name__)
23 |
24 |
25 | class RandomId(object):
26 | r"""Adds the node degree as one hot encodings to the node features.
27 |
28 | Args:
29 | max_degree (int): Maximum degree.
30 | in_degree (bool, optional): If set to :obj:`True`, will compute the
31 | in-degree of nodes instead of the out-degree.
32 | (default: :obj:`False`)
33 | cat (bool, optional): Concat node degrees to node features instead
34 | of replacing them. (default: :obj:`True`)
35 | """
36 | def __init__(self):
37 | pass
38 | def __call__(self, data):
39 | n = data.x.shape[0]
40 | data.x = torch.randint(0, 100, (n, 1), dtype=torch.float) / 100
41 | # data.x = torch.randn(n, self.embedding_size, dtype=torch.float)
42 | return data
43 |
44 | def __repr__(self):
45 | return str(self.__class__.__name__)
46 |
47 |
48 | class DenseAdjMatrix(object):
49 | def __init__(self, n: int):
50 | """ n: number of nodes in the graph (should be constant)"""
51 | self.n = n
52 |
53 | def __call__(self, data):
54 | batch = data.edge_index.new_zeros(self.n)
55 | data.A = to_dense_adj(data.edge_index, batch)
56 | return data
57 |
58 | def __repr__(self):
59 | return str(self.__class__.__name__)
60 |
61 |
62 | class KHopColoringTransform(object):
63 | def __init__(self, k: int):
64 | self.k = k
65 |
66 | def __call__(self, data):
67 | """ Compute a coloring such that no node sees twice the same color in its k-hop neighbourhood."""
68 | k = self.k
69 | g = torch_geometric.utils.to_networkx(data, to_undirected=True, remove_self_loops=True)
70 | lengths = all_pairs_shortest_path_length(g, cutoff=2 * k)
71 | lengths = list(lengths)
72 | # Graph where 2k hop neighbors are connected
73 | k_hop_graph = nx.Graph()
74 | for lengths_tuple in lengths:
75 | origin = lengths_tuple[0]
76 | edges = [(origin, dest) for dest in lengths_tuple[1].keys()]
77 | k_hop_graph.add_edges_from(edges)
78 | # Color the k-hop graph
79 | best_n_colors = np.inf
80 | best_color_dict = None
81 | # for strategy in ['largest_first', 'random_sequential', 'saturation_largest_first']:
82 | for strategy in ['largest_first']:
83 | color_dict = greedy_color(k_hop_graph, strategy)
84 | n_colors = np.max([color for color in color_dict.values()]) + 1
85 | if n_colors < best_n_colors:
86 | best_n_colors = n_colors
87 | best_color_dict = color_dict
88 | # Convert back to torch-geometric. The coloring is contained in data.x
89 | data.coloring = torch.zeros((data.num_nodes, 1), dtype=torch.long)
90 | for key, val in best_color_dict.items():
91 | data.coloring[key] = val
92 | print('Number of nodes: {} - Number of colors: {}'.format(data.num_nodes, data.coloring.max() + 1))
93 | return data
94 |
95 | def __repr__(self):
96 | return '{}({})'.format(self.__class__.__name__, self.k)
97 |
98 |
99 | class OneHotNodeEdgeFeatures(object):
100 | def __init__(self, node_types, edge_types):
101 | self.c = node_types
102 | self.d = edge_types
103 |
104 | def __call__(self, data):
105 | n = data.x.shape[0]
106 | node_encoded = torch.zeros((n, self.c), dtype=torch.float32)
107 | node_encoded.scatter_(1, data.x.long(), 1)
108 | data.x = node_encoded
109 | e = data.edge_attr.shape[0]
110 | edge_encoded = torch.zeros((e, self.d), dtype=torch.float32)
111 | edge_attr = (data.edge_attr - 1).long().unsqueeze(-1)
112 | edge_encoded.scatter_(1, edge_attr, 1)
113 | data.edge_attr = edge_encoded
114 | return data
115 |
116 | def __repr__(self):
117 | return str(self.__class__.__name__)
--------------------------------------------------------------------------------
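Note: a small sketch (not part of the repository) of applying OneHotNodeEdgeFeatures to a toy graph; the values are illustrative, and edge types are assumed to start at 1 as in ZINC.

    import torch
    from torch_geometric.data import Data
    from models.utils.transforms import OneHotNodeEdgeFeatures

    # Toy graph: 3 nodes with type labels in {0, ..., 3}, 4 directed edges with types in {1, 2, 3}
    data = Data(x=torch.tensor([[0], [2], [1]]),
                edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),
                edge_attr=torch.tensor([1.0, 1.0, 3.0, 3.0]))
    transform = OneHotNodeEdgeFeatures(node_types=4, edge_types=3)
    data = transform(data)
    print(data.x.shape, data.edge_attr.shape)   # torch.Size([3, 4]) torch.Size([4, 3])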
/multi_task_main.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | import yaml
4 | from multi_task_utils.train import execute_train, build_arg_parser
5 |
6 | # Training settings
7 | parser = build_arg_parser()
8 | parser.add_argument('--wandb', action='store_true')
9 | parser.add_argument('--batch-size', type=int, default=16)
10 | parser.add_argument('--clip', type=float, default=5)
11 | parser.add_argument('--name', type=str, help="name for weights and biases")
12 | parser.add_argument('--debug', action='store_true')
13 | parser.add_argument('--load-from-epoch', type=int, default=-1)
14 | args = parser.parse_args()
15 |
16 | yaml_file = 'config_multi_task.yaml'
17 | with open(yaml_file) as f:
18 | model_config = yaml.load(f, Loader=yaml.FullLoader)
19 | print(model_config)
20 |
21 | model_name = model_config['model_name']
22 | model_config.pop('model_name')
23 | print("Model name:", model_name)
24 |
25 | if args.wandb or args.name:
26 | import wandb
27 | args.wandb = True
28 | if args.name is None:
29 | args.name = model_name  # default run name when none is provided
30 | wandb.init(project="pna_v2", config=model_config, name=args.name)
31 | wandb.config.update(args)
32 |
33 | execute_train(gnn_args=model_config, args=args)
34 |
--------------------------------------------------------------------------------
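Note: a typical invocation of this script, for reference only (flag names come from this file and from build_arg_parser in multi_task_utils/train.py; the run name is an example):

    python multi_task_main.py --data ./data/multitask_dataset.pkl --batch-size 16 --name smp_multitask --wandb

Training can be resumed from a saved checkpoint with --load-from-epoch, e.g. --load-from-epoch 100, which loads ./saved_models/smp_multitask/100.pkl.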
/multi_task_utils/train.py:
--------------------------------------------------------------------------------
1 | # This file was adapted from https://github.com/lukecavabarrett/pna
2 |
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import argparse
7 | import os
8 | import sys
9 | import time
10 | from types import SimpleNamespace
11 | import wandb
12 | import numpy as np
13 | import torch
14 | import torch.optim as optim
15 | import numpy.random as npr
16 | from torch_geometric.data import DataLoader
17 | from models.model_multi_task import SMP
18 | from multi_task_utils.util import load_dataset, to_torch_geom, specific_loss_torch_geom
19 |
20 | log_loss_tasks = ["log_shortest_path", "log_eccentricity", "log_laplacian",
21 | "log_connected", "log_diameter", "log_radius"]
22 |
23 |
24 | def build_arg_parser():
25 | """
26 | :return: argparse.ArgumentParser() filled with the standard arguments for a training session.
27 | Might need to be enhanced for some models.
28 | """
29 | parser = argparse.ArgumentParser()
30 |
31 | parser.add_argument('--data', type=str, default='./data/multitask_dataset.pkl', help='Data path.')
32 | parser.add_argument('--gpu', type=int, help='Id of the GPU')
33 | parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA training.')
34 | parser.add_argument('--only_nodes', action='store_true', default=False, help='Evaluate only node labels.')
35 | parser.add_argument('--only_graph', action='store_true', default=False, help='Evaluate only graph labels.')
36 | parser.add_argument('--seed', type=int, default=42, help='Random seed.')
37 | parser.add_argument('--epochs', type=int, default=3000, help='Number of epochs to train.')
38 | parser.add_argument('--lr', type=float, default=0.001, help='Initial learning rate.')
39 | parser.add_argument('--weight_decay', type=float, default=1e-6, help='Weight decay (L2 loss on parameters).')
40 | parser.add_argument('--patience', type=int, default=1000, help='Patience')
41 | parser.add_argument('--loss', type=str, default='mse', help='Loss function to use.')
42 | parser.add_argument('--print_every', type=int, default=5, help='Print training results every N epochs')
43 | return parser
44 |
45 |
46 | def execute_train(gnn_args, args):
47 | """
48 | :param gnn_args: the description of the model to be trained (expressed as arguments for GNN.__init__)
49 | :param args: the parameters of the training session
50 | """
51 | if not os.path.isdir('./saved_models'):
52 | os.mkdir('./saved_models')
53 | if args.name is not None:
54 | save_dir = f'./saved_models/{args.name}'
55 | else:
56 | save_dir = './saved_models/'
57 | if args.name is not None and not os.path.isdir(save_dir):
58 | os.mkdir(save_dir)
59 |
60 | use_cuda = args.gpu is not None and torch.cuda.is_available() and not args.no_cuda
61 | if use_cuda:
62 | device = torch.device("cuda:" + str(args.gpu))
63 | torch.cuda.set_device(args.gpu)
64 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
65 | else:
66 | device = "cpu"
67 | args.device = device
68 | args.kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
69 | print('Using device:', device)
70 |
71 | np.random.seed(args.seed)
72 | torch.manual_seed(args.seed)
73 | if use_cuda:
74 | torch.cuda.manual_seed(args.seed)
75 |
76 | # load data
77 | adj, features, node_labels, graph_labels = load_dataset(args.data, args.loss, args.only_nodes, args.only_graph,
78 | print_baseline=True)
79 | print("Processing torch geometric data")
80 | graphs = to_torch_geom(adj, features, node_labels, graph_labels, device, args.debug)
81 | train_loaders = [DataLoader(given_size, args.batch_size, shuffle=True) for given_size in graphs['train']]
82 | batch_sizes = {'train': args.batch_size, 'val': 128, 'test': 256}
83 | val_loaders = [DataLoader(given_size, 128) for given_size in graphs['val']]
84 | test_loaders = [DataLoader(given_size, 256) for given_size in graphs['test']]
85 | print("Data loaders created")
86 | # model and optimizer
87 | gnn_args = SimpleNamespace(**gnn_args)
88 |
89 | gnn_args.num_input_features = features['train'][0].shape[2]
90 | gnn_args.nodes_out = 3
91 | gnn_args.graph_out = 3
92 | model = SMP(**vars(gnn_args)).to(device)
93 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
94 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
95 | step_size=50,
96 | gamma=0.92)
97 |
98 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
99 | print("Total params", pytorch_total_params)
100 |
101 | if args.load_from_epoch != -1:
102 | checkpoint = torch.load(os.path.join(save_dir, f'{args.load_from_epoch}.pkl'))
103 | model.load_state_dict(checkpoint['model_state_dict'])
104 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
105 | epoch = checkpoint['epoch']
106 | else:
107 | epoch = 0
108 |
109 |
110 | def train(epoch):
111 | """ Execute a single epoch of the training loop
112 | epoch (int): the number of the epoch being performed (0-indexed)."""
113 | t = time.time()
114 |
115 | # 1. Train
116 | nan_counts = 0
117 | model.train()
118 | total_train_loss_per_task = 0
119 | npr.shuffle(train_loaders)
120 | for i, loader in enumerate(train_loaders):
121 | for j, data in enumerate(loader):
122 | # Optimization
123 | optimizer.zero_grad()
124 | output = model(data.to(device))
125 | train_loss_per_task = specific_loss_torch_geom(output, (data.pos, data.y), data.batch, args.batch_size)
126 | loss_train = torch.mean(train_loss_per_task)
127 | if torch.isnan(loss_train):
128 | print(f"Warning: loss was nan at epoch {epoch} and batch {i}{j}.")
129 | nan_counts += 1
130 | if nan_counts < 20:
131 | continue
132 | else:
133 | raise ValueError(f"Too many NaNs. Stopping training at epoch {epoch}. Best epoch: {best_epoch}")
134 | loss_train.backward()
135 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
136 | optimizer.step()
137 | # Compute metrics
138 | total_train_loss_per_task += train_loss_per_task.detach() / len(loader)
139 | total_train_loss_per_task /= len(train_loaders)
140 | train_log_loss_per_task = torch.log10(total_train_loss_per_task).data.cpu().numpy()
141 | train_loss = torch.mean(total_train_loss_per_task).data.item()
142 |
143 | # validation epoch
144 | model.eval()
145 | val_loss_per_task = 0
146 | for loader in val_loaders:
147 | for i, data in enumerate(loader):
148 | if i > 0:
149 | print("Warning: not all the batch was loaded at once. It will lead to incorrect results.")
150 | output = model(data.to(device))
151 | batch_loss_per_task = specific_loss_torch_geom(output, (data.pos, data.y), data.batch, batch_sizes['val'])
152 | val_loss_per_task += batch_loss_per_task.detach() / len(val_loaders)
153 |
154 | val_log_loss_per_task = torch.log10(val_loss_per_task).data.cpu().numpy()
155 | val_log_loss = torch.mean(val_loss_per_task).item()
156 |
157 | if epoch % args.print_every == 0:
158 | print('Epoch: {:04d}'.format(epoch + 1),
159 | 'loss.train: {:.4f}'.format(train_loss),
160 | 'log.loss.val: {:.4f}'.format(val_log_loss),
161 | 'time: {:.4f}s'.format(time.time() - t))
162 | print(f'train loss per task (log10 scale): {train_log_loss_per_task}')
163 | print(f'val loss per task (log10 scale): {val_log_loss_per_task}')
164 | sys.stdout.flush()
165 | if args.wandb:
166 | wandb_dict = {"Epoch": epoch, "Duration": time.time() - t, "Train loss": train_loss,
167 | "Val log loss": val_log_loss}
168 | for loss, tr, val in zip(log_loss_tasks, train_log_loss_per_task, val_log_loss_per_task):
169 | wandb_dict[loss + 'tr'] = tr
170 | wandb_dict[loss + 'val'] = val
171 | wandb.log(wandb_dict)
172 |
173 | return val_log_loss
174 |
175 | def compute_test():
176 | """
177 | Evaluate the current model on all the sets of the dataset, printing results.
178 | This procedure is destructive on datasets.
179 | """
180 | model.eval()
181 | sets = list(features.keys())
182 | for dset, loaders in zip(sets, [train_loaders, val_loaders, test_loaders]):
183 | final_specific_loss = 0
184 | final_total_loss = 0
185 | for loader in loaders:
186 | loader_total_loss = 0
187 | loader_specific_loss = 0
188 | for data in loader:
189 | output = model(data.to(device))
190 | specific_loss = specific_loss_torch_geom(output, (data.pos, data.y),
191 | data.batch, batch_sizes[dset]).detach()
192 | loader_specific_loss += specific_loss
193 | loader_total_loss += torch.mean(specific_loss)
194 | # Average the loss over each loader
195 | loader_specific_loss /= len(loader)
196 | loader_total_loss /= len(loader)
197 | # Average the loss over the different loaders
198 | final_specific_loss += loader_specific_loss / len(loaders)
199 | final_total_loss += loader_total_loss / len(loaders)
200 | del output, loader_specific_loss
201 |
202 | print("Test set results ", dset, ": loss= {:.4f}".format(final_total_loss))
203 | print(dset, ": ", final_specific_loss)
204 | print("Results in log scale", np.log10(final_specific_loss.detach().cpu()),
205 | np.log10(final_total_loss.detach().cpu().numpy()))
206 | if args.wandb:
207 | wandb.run.summary["test results"] = np.log10(final_specific_loss.detach().cpu())
208 | # free unnecessary data
209 |
210 |
211 | final_specific_numpy = np.log10(final_specific_loss.detach().cpu())
212 | del final_total_loss, final_specific_loss
213 | torch.cuda.empty_cache()
214 | return final_specific_numpy
215 |
216 | sys.stdout.flush()
217 | # Train model
218 | t_total = time.time()
219 | loss_values = []
220 | bad_counter = 0
221 | best = args.epochs + 1
222 | best_epoch = -1
223 |
224 | sys.stdout.flush()
225 |
226 | while epoch < args.epochs:
227 | epoch += 1
228 |
229 | loss_values.append(train(epoch))
230 | scheduler.step()
231 | if epoch % 100 == 0:
232 | print("Results on the test set:")
233 | results_test = compute_test()
234 | print('Test set results', results_test)
235 | print(f"Saving checkpoint at epoch {epoch}")
236 | torch.save({
237 | 'epoch': epoch,
238 | 'model_state_dict': model.state_dict(),
239 | 'optimizer_state_dict': optimizer.state_dict(),
240 | }, os.path.join(save_dir, f'{epoch}.pkl'))
241 |
242 | if loss_values[-1] < best:
243 | # New best validation error: save the current model
244 | print(f"New best validation error at epoch {epoch}")
245 | print(f"Saving checkpoint at epoch {epoch}")
246 | # The checkpoint file is named after its epoch number
247 | torch.save({
248 | 'epoch': epoch,
249 | 'model_state_dict': model.state_dict(),
250 | 'optimizer_state_dict': optimizer.state_dict(),
251 | },
252 | os.path.join(save_dir, f'{epoch}.pkl'))
253 | # remove previous model
254 | if best_epoch >= 0:
255 | f_name = os.path.join(save_dir, f'{best_epoch}.pkl')
256 | if os.path.isfile(f_name):
257 | os.remove(f_name)
258 | # update training variables
259 | best = loss_values[-1]
260 | best_epoch = epoch
261 | bad_counter = 0
262 | else:
263 | bad_counter += 1
264 |
265 | if bad_counter == args.patience:
266 | print('Early stop at epoch {} (no improvement in last {} epochs)'.format(epoch + 1, bad_counter))
267 | break
268 |
269 | print("Optimization Finished!")
270 | print("Total time elapsed: {:.4f}s".format(time.time() - t_total))
271 |
272 | # Restore best model
273 | print('Loading {}th epoch'.format(best_epoch + 1))
274 | checkpoint = torch.load(os.path.join(save_dir, f'{best_epoch}.pkl'))
275 | model.load_state_dict(checkpoint['model_state_dict'])
276 |
277 | # Testing
278 | print("Results on the test set:")
279 | results_test = compute_test()
280 | print('Test set results', results_test)
281 |
--------------------------------------------------------------------------------
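Note: the checkpoints written by execute_train are plain dictionaries. A minimal sketch of inspecting one offline (the path is an example):

    import torch

    # Keys mirror those written by execute_train:
    # 'epoch', 'model_state_dict', 'optimizer_state_dict'
    checkpoint = torch.load('./saved_models/smp_multitask/100.pkl', map_location='cpu')
    print(checkpoint['epoch'])
    print(list(checkpoint['model_state_dict'].keys())[:5])   # first few parameter names
    # model.load_state_dict(checkpoint['model_state_dict']) restores the weights once
    # an SMP model has been instantiated as in execute_train.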
/multi_task_utils/util.py:
--------------------------------------------------------------------------------
1 | # This file was adapted from https://github.com/lukecavabarrett/pna
2 |
3 | from __future__ import division
4 | from __future__ import print_function
5 | from torch_geometric import data
6 | from torch_geometric.utils import dense_to_sparse
7 | import pickle
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | from torch_geometric.nn import global_add_pool
12 |
13 |
14 | def to_torch_geom(adj, features, node_labels, graph_labels, device, debug):
15 | graphs = {}
16 | for key in adj.keys(): # train, val, test
17 | graphs[key] = []
18 | for i in range(len(adj[key])):  # Graphs of a given size
19 | batch_i = []
20 | for j in range(adj[key][i].shape[0]): # Number of graphs
21 | graph_adj = adj[key][i][j]
22 | graph = data.Data(x=features[key][i][j],
23 | edge_index=dense_to_sparse(graph_adj)[0],
24 | y=graph_labels[key][i][j].unsqueeze(0),
25 | pos=node_labels[key][i][j])
26 | if not debug:
27 | batch_i.append(graph)
28 | if debug:
29 | batch_i.append(graph)
30 | graphs[key].append(batch_i)
31 | return graphs
32 |
33 |
34 | def load_dataset(data_path, loss, only_nodes, only_graph, print_baseline=True):
35 | with open(data_path, 'rb') as f:
36 | (adj, features, node_labels, graph_labels) = pickle.load(f)
37 |
38 | # normalize labels
39 | max_node_labels = torch.cat([nls.max(0)[0].max(0)[0].unsqueeze(0) for nls in node_labels['train']]).max(0)[0]
40 | max_graph_labels = torch.cat([gls.max(0)[0].unsqueeze(0) for gls in graph_labels['train']]).max(0)[0]
41 | for dset in node_labels.keys():
42 | node_labels[dset] = [nls / max_node_labels for nls in node_labels[dset]]
43 | graph_labels[dset] = [gls / max_graph_labels for gls in graph_labels[dset]]
44 |
45 | if print_baseline:
46 | # calculate baseline
47 | mean_node_labels = torch.cat([nls.mean(0).mean(0).unsqueeze(0) for nls in node_labels['train']]).mean(0)
48 | mean_graph_labels = torch.cat([gls.mean(0).unsqueeze(0) for gls in graph_labels['train']]).mean(0)
49 |
50 | for dset in node_labels.keys():
51 | if dset not in ['train', 'val']:
52 | baseline_nodes = [mean_node_labels.repeat(list(nls.shape[0:-1]) + [1]) for nls in node_labels[dset]]
53 | baseline_graph = [mean_graph_labels.repeat([gls.shape[0], 1]) for gls in graph_labels[dset]]
54 |
55 | print("Baseline loss ", dset,
56 | np.log10(specific_loss_multiple_batches((baseline_nodes, baseline_graph),
57 | (node_labels[dset], graph_labels[dset]),
58 | loss=loss, only_nodes=only_nodes, only_graph=only_graph)))
59 |
60 | return adj, features, node_labels, graph_labels
61 |
62 |
63 | SUPPORTED_ACTIVATION_MAP = {'ReLU', 'Sigmoid', 'Tanh', 'ELU', 'SELU', 'GLU', 'LeakyReLU', 'Softplus', 'None'}
64 |
65 |
66 | def get_activation(activation):
67 | """ returns the activation function represented by the input string """
68 | if activation and callable(activation):
69 | # activation is already a function
70 | return activation
71 | # search in SUPPORTED_ACTIVATION_MAP a torch.nn.modules.activation
72 | activation = [x for x in SUPPORTED_ACTIVATION_MAP if activation.lower() == x.lower()]
73 | assert len(activation) == 1 and isinstance(activation[0], str), 'Unhandled activation function'
74 | activation = activation[0]
75 | if activation.lower() == 'none':
76 | return None
77 | return vars(torch.nn.modules.activation)[activation]()
78 |
79 |
80 | def get_loss(loss, output, target):
81 | if loss == "mse":
82 | return F.mse_loss(output, target)
83 | elif loss == "cross_entropy":
84 | if len(output.shape) > 2:
85 | (B, N, _) = output.shape
86 | output = output.reshape((B * N, -1))
87 | target = target.reshape((B * N, -1))
88 | _, target = target.max(dim=1)
89 | return F.cross_entropy(output, target)
90 | else:
91 | print("Error: loss function not supported")
92 |
93 |
94 | def specific_loss_torch_geom(output, target, batch, batch_size):
95 | """ output: list of len 2 containing node and graph outputs
96 | returns the average losses of each task """
97 | average_nodes = output[0].shape[0] / batch_size # Average nb nodes in each graph
98 | # Node loss
99 | node_out = output[0] # N x 3
100 | loss = (node_out - target[0]) ** 2
101 | error = global_add_pool(loss, batch, batch_size) / average_nodes # N graphs x 3
102 | nodes_loss = torch.mean(error, dim=0) # 3
103 | graph_loss = torch.mean((output[1] - target[1]) ** 2, dim=0) # 3
104 | specific_loss = torch.cat((nodes_loss, graph_loss))
105 | return specific_loss
106 |
107 |
108 | def total_loss_torch_geom(output, target, batch, batch_size):
109 | """ returns the average of the average losses of each task """
110 | specific_loss = specific_loss_torch_geom(output, target, batch, batch_size)
111 | weighted_average = torch.mean(specific_loss)
112 | return weighted_average
113 |
114 |
115 | def total_loss(output, target, loss='mse', only_nodes=False, only_graph=False):
116 | """ returns the average of the average losses of each task """
117 | assert not (only_nodes and only_graph)
118 |
119 | if only_nodes:
120 | nodes_loss = get_loss(loss, output[0], target[0])
121 | return nodes_loss
122 | elif only_graph:
123 | graph_loss = get_loss(loss, output[1], target[1])
124 | return graph_loss
125 |
126 | nodes_loss = get_loss(loss, output[0], target[0])
127 | graph_loss = get_loss(loss, output[1], target[1])
128 | weighted_average = (nodes_loss * output[0].shape[-1] + graph_loss * output[1].shape[-1]) / (
129 | output[0].shape[-1] + output[1].shape[-1])
130 | return weighted_average
131 |
132 |
133 | def total_loss_multiple_batches(output, target, loss='mse', only_nodes=False, only_graph=False):
134 | """ returns the average of the average losses of each task over all batches,
135 | batches are weighted equally regardless of their cardinality or graph size """
136 | return sum([total_loss((output[0][batch], output[1][batch]), (target[0][batch], target[1][batch]),
137 | loss, only_nodes, only_graph).data.item()
138 | for batch in range(len(output[0]))]) / len(output[0])
139 |
140 |
141 | def specific_loss(output, target, loss='mse', only_nodes=False, only_graph=False):
142 | """ returns the average loss for each task """
143 | assert not (only_nodes and only_graph)
144 |
145 | if only_nodes:
146 | nodes_losses = [get_loss(loss, output[0][:, :, k], target[0][:, :, k]).item() for k in
147 | range(output[0].shape[-1])]
148 | return nodes_losses
149 | elif only_graph:
150 | graph_loss = [get_loss(loss, output[1][:, k], target[1][:, k]).item() for k in range(output[1].shape[-1])]
151 | return graph_loss
152 |
153 | nodes_losses = [get_loss(loss, output[0][:, :, k], target[0][:, :, k]).item() for k in range(output[0].shape[-1])]
154 | graph_loss = [get_loss(loss, output[1][:, k], target[1][:, k]).item() for k in range(output[1].shape[-1])]
155 | return nodes_losses + graph_loss
156 |
157 |
158 | def specific_loss_multiple_batches(output, target, loss='mse', only_nodes=False, only_graph=False):
159 | """ returns the average loss over all batches for each task,
160 | batches are weighted equally regardless of their cardinality or graph size """
161 | assert not (only_nodes and only_graph)
162 |
163 | n_batches = len(output[0])
164 | classes = (output[0][0].shape[-1] if not only_graph else 0) + (output[1][0].shape[-1] if not only_nodes else 0)
165 |
166 | sum_losses = [0] * classes
167 | for batch in range(n_batches):
168 | spec_loss = specific_loss((output[0][batch], output[1][batch]), (target[0][batch], target[1][batch]), loss,
169 | only_nodes, only_graph)
170 | for par in range(classes):
171 | sum_losses[par] += spec_loss[par]
172 |
173 | return [sum_loss / n_batches for sum_loss in sum_losses]
174 |
175 |
176 | def save_checkpoint(path, model, optimizer, epoch):
177 | torch.save({
178 | 'epoch': epoch,
179 | 'model_state_dict': model.state_dict(),
180 | 'optimizer_state_dict': optimizer.state_dict()}, path)
181 |
--------------------------------------------------------------------------------
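Note: a toy call (not part of the repository) that makes the shapes used by specific_loss_torch_geom concrete: two graphs, three node-level tasks and three graph-level tasks.

    import torch
    from multi_task_utils.util import specific_loss_torch_geom

    # 5 nodes split over 2 graphs
    node_out, node_target = torch.randn(5, 3), torch.randn(5, 3)
    graph_out, graph_target = torch.randn(2, 3), torch.randn(2, 3)
    batch = torch.tensor([0, 0, 0, 1, 1])

    per_task = specific_loss_torch_geom((node_out, graph_out), (node_target, graph_target), batch, batch_size=2)
    print(per_task.shape)   # torch.Size([6]): one mean squared error per task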
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torch-geometric
3 | pyyaml
4 | numpy
5 | networkx
6 | scikit-learn
7 | easydict
8 | wandb
--------------------------------------------------------------------------------
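Note: torch-geometric may require wheels matching the installed torch and CUDA versions (see the PyTorch Geometric installation instructions). The remaining dependencies can be installed with:

    pip install -r requirements.txt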
/saved_models/PPGN_4/epoch0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvignac/SMP/95b55a880d0fc9149ddf32e8c2fdf5eac5b474b3/saved_models/PPGN_4/epoch0.pkl
--------------------------------------------------------------------------------
/saved_models/ZINC/Zinc_SMP.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvignac/SMP/95b55a880d0fc9149ddf32e8c2fdf5eac5b474b3/saved_models/ZINC/Zinc_SMP.pkl
--------------------------------------------------------------------------------
/zinc_main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | import os
5 | import torch
6 | import torch.nn as nn
7 | from torch_geometric.data import DataLoader
8 | from torch_geometric.datasets import ZINC
9 | import argparse
10 | import numpy as np
11 | import time
12 | import yaml
13 | from models.model_zinc import SMPZinc
14 | from models.utils.transforms import OneHotNodeEdgeFeatures
15 |
16 | # Change the following to point to the folder where the datasets are stored
17 | if os.path.isdir('/datasets2/'):
18 | rootdir = '/datasets2/ZINC/'
19 | else:
20 | rootdir = './data/ZINC/'
21 | yaml_file = './config_zinc.yaml'
22 |
23 | torch.manual_seed(0)
24 | np.random.seed(0)
25 |
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument('--epochs', type=int, default=3000)
28 | parser.add_argument('--wandb', action='store_true',
29 | help="Use weights and biases library")
30 | parser.add_argument('--gpu', type=int, help='Id of gpu device. By default use cpu')
31 | parser.add_argument('--lr', type=float, default=0.001, help="Initial learning rate")
32 | parser.add_argument('--batch-size', type=int, default=128)
33 | parser.add_argument('--weight-decay', type=float, default=1e-6)
34 | parser.add_argument('--clip', type=float, default=10, help="Gradient clipping")
35 | parser.add_argument('--name', type=str, help="Name for weights and biases")
36 | parser.add_argument('--full', action='store_true')
37 | parser.add_argument('--lr-reduce-factor', type=float, default=0.5)
38 | parser.add_argument('--lr_schedule_patience', type=int, default=100)
39 | parser.add_argument('--save-model', action='store_true', help='Save the model after training')
40 | parser.add_argument('--load-model', action='store_true', help='Evaluate a pretrained model')
41 | parser.add_argument('--lr-limit', type=float, default=5e-6, help='Stop training once it is reached')
42 | args = parser.parse_args()
43 |
44 | args.subset = not args.full # Train either on the full dataset or the subset of 10k samples
45 |
46 | # Handle the device
47 | use_cuda = args.gpu is not None and torch.cuda.is_available()
48 | if use_cuda:
49 | device = torch.device("cuda:" + str(args.gpu))
50 | torch.cuda.set_device(args.gpu)
51 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
52 | else:
53 | device = "cpu"
54 | args.device = device
55 | args.kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
56 | print('Device used:', device)
57 |
58 | # Load the config file of the model
59 | with open(yaml_file) as f:
60 | model_config = yaml.load(f, Loader=yaml.FullLoader)
61 | print(model_config)
62 | model_config['num_input_features'] = 28 if model_config['use_x'] else 29
63 | model_config['num_edge_features'] = 3
64 | model_config['num_classes'] = 1
65 |
66 |
67 | # Create a name for weights and biases
68 | model_name = 'Zinc_SMP'
69 | if args.name:
70 | args.wandb = True
71 | if args.wandb:
72 | import wandb
73 | if args.name is None:
74 | args.name = model_name + \
75 | f"_{model_config['num_layers']}_{model_config['hidden']}_{model_config['hidden_final']}"
76 | wandb.init(project="smp-zinc-subset" if args.subset else "smp-zinc", config=model_config, name=args.name)
77 | wandb.config.update(args)
78 |
79 |
80 | # The paths can be changed here
81 | if args.save_model or args.load_model:
82 | if os.path.isdir('/SCRATCH2/'):
83 | savedir = '/SCRATCH2/vignac/SMP/saved_models/ZINC/'
84 | else:
85 | savedir = './saved_models/ZINC/'
86 | if not os.path.isdir(savedir):
87 | os.makedirs(savedir)
88 |
89 |
90 | def train():
91 | """ Train for one epoch. """
92 | model.train()
93 | loss_all = 0
94 | for batch_idx, data in enumerate(train_loader):
95 | data = data.to(device)
96 | optimizer.zero_grad()
97 | output = model(data)
98 | loss = loss_fct(output, data.y)
99 | loss.backward()
100 | loss_all += loss.item()
101 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
102 | optimizer.step()
103 | return loss_all / len(train_loader.dataset)
104 |
105 |
106 | def test(loader):
107 | model.eval()
108 | total_mae = 0.0
109 | for data in loader:
110 | data = data.to(device)
111 | output = model(data)
112 | total_mae += loss_fct(output, data.y).item()
113 | average_mae = total_mae / len(loader.dataset)
114 | return average_mae
115 |
116 |
117 | start = time.time()
118 |
119 | model = SMPZinc(**model_config).to(device)
120 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
121 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
122 | factor=args.lr_reduce_factor,
123 | patience=args.lr_schedule_patience,
124 | verbose=True)
125 | lr_limit = args.lr_limit
126 |
127 | if args.load_model:
128 | model = torch.load(savedir + model_name + '.pkl')
129 |
130 | pytorch_total_params = sum(p.numel() for p in model.parameters())
131 | print("Total number of parameters", pytorch_total_params)
132 |
133 | loss_fct = nn.L1Loss(reduction='sum')
134 |
135 | # Load the data
136 | batch_size = args.batch_size
137 | transform = OneHotNodeEdgeFeatures(model_config['num_input_features'] - 1, model_config['num_edge_features'])
138 |
139 | train_data = ZINC(rootdir, subset=args.subset, split='train', pre_transform=transform)
140 | val_data = ZINC(rootdir, subset=args.subset, split='val', pre_transform=transform)
141 | test_data = ZINC(rootdir, subset=args.subset, split='test', pre_transform=transform)
142 |
143 | train_loader = DataLoader(train_data, batch_size, shuffle=True)
144 | val_loader = DataLoader(val_data, batch_size, shuffle=False)
145 | test_loader = DataLoader(test_data, batch_size, shuffle=False)
146 |
147 | print("Starting to train")
148 | for epoch in range(args.epochs):
149 | if args.load_model:
150 | break
151 | epoch_start = time.time()
152 | tr_loss = train()
153 | current_lr = optimizer.param_groups[0]["lr"]
154 | if current_lr < lr_limit:
155 | break
156 | duration = time.time() - epoch_start
157 | print(f'Time: {duration:2.2f}s | Epoch: {epoch:5d} | Train MAE: {tr_loss:2.5f} | LR: {current_lr:.6f}')
158 | mae_val = test(val_loader)
159 | scheduler.step(mae_val)
160 | print(f'MAE on the validation set: {mae_val:2.5f}')
161 | if args.wandb:
162 | wandb.log({"Epoch": epoch, "Duration": duration, "Train MAE": tr_loss,
163 | "Val MAE": mae_val})
164 |
165 | if not args.load_model:
166 | cur_lr = optimizer.param_groups[0]["lr"]
167 | print(f'Epoch: {epoch:5d} | Loss: {tr_loss:2.5f} | LR: {cur_lr:.6f} | Val MAE: {mae_val:2.5f}')
168 | print(f'Elapsed time: {(time.time() - start) / 60:.1f} minutes')
169 | print('done!')
170 |
171 | test_mae = test(test_loader)
172 | print(f"Final MAE on the test set: {test_mae}")
173 | print("Done.")
174 |
175 | if args.wandb:
176 | wandb.run.summary['Final test MAE'] = test_mae
177 |
178 | if args.save_model:
179 | torch.save(model, savedir + model_name + '.pkl')
--------------------------------------------------------------------------------
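Note: typical invocations of this script, for reference only (flag names come from the argument parser above):

    python zinc_main.py --gpu 0 --batch-size 128   # train on the 10k-sample ZINC subset (default)
    python zinc_main.py --full --save-model        # train on the full dataset and save the final model
    python zinc_main.py --load-model               # skip training and evaluate a saved model, e.g. saved_models/ZINC/Zinc_SMP.pkl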