├── images ├── siamese.png ├── 01_graph1.png ├── 02_graph2.png ├── 05_result.png ├── 07_match.png ├── 09_preds.png ├── download.png ├── 08_mismatch.png ├── 03_result_graph2.png └── 04_result_graph1.png ├── requirements.txt ├── .gitignore ├── loaders ├── test_data_generator.py ├── superpixels.py ├── benchmark.py ├── loaders.py └── data_generator.py ├── maskedtensors ├── test_metrics.py ├── test_losses.py ├── test_hierarchy.py ├── test_maskedtensor.py └── maskedtensor.py ├── data_benchmarking_gnns ├── generating_MNIST.py ├── generating_CIFAR10.py ├── get_data.py ├── data_generator.py └── data_helper.py ├── models ├── __init__.py ├── blocks_emb.py ├── utils.py ├── trainers.py └── layers.py ├── toolbox ├── losses.py ├── metrics.py └── utils.py ├── default_config.yaml ├── README.md ├── LICENSE ├── commander_explore.py └── plot_accuracy_regular.ipynb /images/siamese.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlelarge/graph_neural_net/HEAD/images/siamese.png -------------------------------------------------------------------------------- /images/01_graph1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlelarge/graph_neural_net/HEAD/images/01_graph1.png -------------------------------------------------------------------------------- /images/02_graph2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlelarge/graph_neural_net/HEAD/images/02_graph2.png -------------------------------------------------------------------------------- /images/05_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlelarge/graph_neural_net/HEAD/images/05_result.png -------------------------------------------------------------------------------- /images/07_match.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlelarge/graph_neural_net/HEAD/images/07_match.png -------------------------------------------------------------------------------- /images/09_preds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlelarge/graph_neural_net/HEAD/images/09_preds.png -------------------------------------------------------------------------------- /images/download.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlelarge/graph_neural_net/HEAD/images/download.png -------------------------------------------------------------------------------- /images/08_mismatch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlelarge/graph_neural_net/HEAD/images/08_mismatch.png -------------------------------------------------------------------------------- /images/03_result_graph2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlelarge/graph_neural_net/HEAD/images/03_result_graph2.png -------------------------------------------------------------------------------- /images/04_result_graph1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlelarge/graph_neural_net/HEAD/images/04_result_graph1.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.5 2 | numpy 3 | networkx 4 | pytest 5 | PyYAML 6 | scikit-learn 7 | pytorch-lightning 8 | 9 | -------------------------------------------------------------------------------- /.gitignore: 
# ==============================================================================
# /.gitignore (continued; its header precedes this span). Non-Python content,
# reproduced verbatim as comment lines:
#   #Custom
#   *.np
#   *.npz
#   *.png
#   *.pt
#   *.py.swp
#   .ipynb_checkpoints/
#   dataset*/
#   runs/
#   interpretation/
#   temp/
#   .neptune/
#   *.json
#   *.pkl
#
#   #pyconcorde residual files
#   *.res
#   *.sol
#
#   # Byte-compiled / optimized / DLL files
#   __pycache__/
#   *.py[cod]
#   *$py.class
# ==============================================================================


# ==============================================================================
# /loaders/test_data_generator.py
# ==============================================================================
import pytest
import networkx
import torch
import data_generator

N_VERTICES = 50
NOISE = 0.05
EDGE_DENSITY = 0.2


@pytest.fixture
def regular_graph():
    """Build ``(g, W)`` with ``data_generator.generate_regular_graph_netx``."""
    g, W = data_generator.generate_regular_graph_netx(EDGE_DENSITY, N_VERTICES)
    return g, W


def test_edge_swap_on_regular(regular_graph):
    """Edge-swap noise must leave every column sum (vertex degree) unchanged."""
    g, W = regular_graph
    W_noise = data_generator.noise_edge_swap(g, W, NOISE, EDGE_DENSITY)
    degrees = torch.sum(W, 0)
    degrees_noise = torch.sum(W_noise, 0)
    assert torch.equal(degrees, degrees_noise)


# ==============================================================================
# /maskedtensors/test_metrics.py
# ==============================================================================
import pytest
import torch
from toolbox.metrics import accuracy_linear_assignment, accuracy_max
from maskedtensor import from_list

N_VERTICES_RANGE = range(40, 50)


@pytest.fixture
def correct_batch():
    """Masked batch of identity score matrices of sizes 40..49 (ideal scores)."""
    tensor_lst = [torch.eye(n_vertices) for n_vertices in N_VERTICES_RANGE]
    return from_list(tensor_lst, dims=(0, 1))


# accuracy_linear_assignment is imported but its case is currently disabled.
TEST_ACCURACY_FUNCS = [
    #(accuracy_linear_assignment, 'accuracy_linear_assignment'),
    (accuracy_max, 'accuracy_max')]


@pytest.mark.parametrize('func_data', TEST_ACCURACY_FUNCS, ids=lambda func_data: func_data[1])
def test_perfect_accuracy(correct_batch, func_data):
    """Every accuracy function must report ``correct == total`` on identity scores."""
    func, _ = func_data
    correct, total = func(correct_batch)
    assert correct == total, (correct, total)


@pytest.fixture
def batch():
    """Masked batch of Gaussian random square matrices (not used by this file's tests)."""
    tensor_lst = [torch.empty(n_vertices, n_vertices).normal_() for n_vertices in N_VERTICES_RANGE]
    return from_list(tensor_lst, dims=(0, 1))


# ==============================================================================
# /data_benchmarking_gnns/generating_MNIST.py
# ==============================================================================
# This file should be run in the env provided in
# https://github.com/graphdeeplearning/benchmarking-gnns/blob/master/docs/01_benchmark_installation.md
import torch
import torch.nn as nn
import torch.nn.functional as F

from data.data import LoadData

DATASET_NAME = "MNIST"
dataset = LoadData(DATASET_NAME)

# Each split: collate every sample on its own with collate_dense_gnn, drop the
# leading size-1 dimension (presumably the batch axis), then save the
# (graph, label) list. NOTE(review): output paths are hard-coded to the author's machine.
testset_dense = [dataset.collate_dense_gnn([d]) for d in dataset.test]
testset_dense = [(g.squeeze(0), l) for (g, l) in testset_dense]

torch.save(testset_dense, '/home/mlelarge/data/superpixels/MNIST/mnist_test.pt')

valset = [dataset.collate_dense_gnn([d]) for d in dataset.val]
valset = [(g.squeeze(0), l) for (g, l) in valset]
torch.save(valset, '/home/mlelarge/data/superpixels/MNIST/mnist_val.pt')

trainset = [dataset.collate_dense_gnn([d]) for d in dataset.train]
trainset = [(g.squeeze(0), l) for (g, l) in trainset]
torch.save(trainset, '/home/mlelarge/data/superpixels/MNIST/mnist_train.pt')


# ==============================================================================
# /data_benchmarking_gnns/generating_CIFAR10.py
# ==============================================================================
# This file should be run in the env provided in
# https://github.com/graphdeeplearning/benchmarking-gnns/blob/master/docs/01_benchmark_installation.md
import torch
import torch.nn as nn
import torch.nn.functional as F

from data.data import LoadData

DATASET_NAME = "CIFAR10"
dataset = LoadData(DATASET_NAME)

# Same conversion as generating_MNIST.py; output paths are hard-coded as well.
testset_dense = [dataset.collate_dense_gnn([d]) for d in dataset.test]
testset_dense = [(g.squeeze(0), l) for (g, l) in testset_dense]

torch.save(testset_dense, '/home/mlelarge/data/superpixels/CIFAR10/cifar_test.pt')

valset = [dataset.collate_dense_gnn([d]) for d in dataset.val]
valset = [(g.squeeze(0), l) for (g, l) in valset]
torch.save(valset, '/home/mlelarge/data/superpixels/CIFAR10/cifar_val.pt')

trainset = [dataset.collate_dense_gnn([d]) for d in dataset.train]
trainset = [(g.squeeze(0), l) for (g, l) in trainset]
torch.save(trainset, '/home/mlelarge/data/superpixels/CIFAR10/cifar_train.pt')


# ==============================================================================
# /models/__init__.py
# ==============================================================================
from models.trainers import Siamese_Node_Exp#, Graph_Classif_Exp #, scaled_block, block, block_sym, Graph_Classif_Exp
from toolbox.utils import load_json

from data_benchmarking_gnns.data_helper import NUM_LABELS, NUM_CLASSES


def get_siamese_model_exp(args, config_optim):
    """Build a fresh ``Siamese_Node_Exp`` from the architecture and optimizer configs.

    Args:
        args: architecture config; must provide ``original_features_num`` and a
            ``node_emb`` dict with at least ``type``, ``num_blocks``,
            ``block_init`` and ``block_inside``.
        config_optim: optimizer config; must provide ``lr``,
            ``scheduler_decay`` and ``scheduler_step``. An optional ``lr_stop``
            is forwarded to the model (it becomes the LR scheduler's ``min_lr``).

    Returns:
        An untrained ``Siamese_Node_Exp``.
    """
    args_dict = {'lr': config_optim['lr'],
                 'scheduler_decay': config_optim['scheduler_decay'],
                 'scheduler_step': config_optim['scheduler_step']
                 }
    # lr_stop used to be dropped silently, so the model always used its default
    # minimum learning rate even when the config set one.
    if 'lr_stop' in config_optim:
        args_dict['lr_stop'] = config_optim['lr_stop']
    original_features_num = args['original_features_num']
    node_emb = args['node_emb']
    print('Fetching model %s with (total = %s ) init %s and inside %s' % (node_emb['type'], node_emb['num_blocks'],
          node_emb['block_init'], node_emb['block_inside']))
    model = Siamese_Node_Exp(original_features_num, node_emb, **args_dict)
    return model


def get_siamese_model_test(name, config=None):
    """Load a trained ``Siamese_Node_Exp`` from the checkpoint file ``name``.

    If ``config`` is None, ``config.json`` is read from the directory four path
    components above the checkpoint. For a path like
    ``<run_dir>/<project>/<run_id>/checkpoints/<file>.ckpt`` that directory is ``<run_dir>``.

    Raises:
        ValueError: if ``config`` is None and ``name`` has fewer than four
            '/'-separated components.
    """
    if config is None:
        parts = name.split("/")
        if len(parts) < 4:
            raise ValueError(f"cannot locate config.json for checkpoint path {name!r}")
        # Keep everything before the 4th-from-last component, including its
        # trailing '/'. Splitting the *string* on that component name (as done
        # before) picked the wrong prefix whenever the name also occurred
        # earlier in the path.
        cname = "/".join(parts[:-4] + [""]) if len(parts) > 4 else ""
        config = load_json(cname + 'config.json')
    arch = config['arch']
    # Use the trained model's feature count instead of always assuming 2.
    return Siamese_Node_Exp.load_from_checkpoint(
        name,
        original_features_num=arch.get('original_features_num', 2),
        node_emb=arch['node_emb'])


# ==============================================================================
# /maskedtensors/test_losses.py
# ==============================================================================
import pytest
import torch
from toolbox.losses import triplet_loss
import maskedtensor

BATCH_SIZE = 32
N_VERTICES = 50
C = 10


@pytest.fixture
def std_batch():
    """Dense (BATCH_SIZE, N_VERTICES, N_VERTICES) tensor of Gaussian scores."""
    tensor = torch.empty((BATCH_SIZE, N_VERTICES, N_VERTICES)).normal_()
    return tensor


@pytest.fixture
def masked_batch():
    """Gaussian scores of equal-size graphs packed as a masked tensor."""
    lst = [torch.empty((N_VERTICES, N_VERTICES)).normal_()
           for _ in range(BATCH_SIZE)]
    mtensor = maskedtensor.from_list(lst, dims=(0, 1))
    return mtensor


@pytest.fixture
def batch(request):
    """Resolve the fixture whose name is given as the indirect parameter."""
    return request.getfixturevalue(request.param)


@pytest.mark.parametrize('batch', ['std_batch', 'masked_batch'], indirect=True)
def test_loss_fixed_size(batch):
    """With equal-size graphs the 'mean' and 'mean_of_mean' reductions must agree."""
    loss_func_mean = triplet_loss(loss_reduction='mean')
    loss_func_mean_of_mean = triplet_loss(loss_reduction='mean_of_mean')
    loss_mean = loss_func_mean(batch)
    loss_mean_of_mean = loss_func_mean_of_mean(batch)
    assert loss_mean.size() == loss_mean_of_mean.size()
    assert torch.allclose(loss_mean, loss_mean_of_mean), loss_mean - loss_mean_of_mean


@pytest.fixture
def rand_labels():
    """Random class labels in [0, C), shape (BATCH_SIZE, 1)."""
    return torch.empty(BATCH_SIZE, 1, dtype=torch.long).random_(0, C)


# ==============================================================================
# /loaders/superpixels.py
# ==============================================================================
import torch
#import pickle
import os
import time


class SuperPixDataset(torch.utils.data.Dataset):
    """Holder for the pre-generated superpixel MNIST / CIFAR10 splits.

    Loads the ``test``, ``val`` and ``train`` lists written by the
    ``generating_*.py`` scripts from ``<main_data_dir>/superpixels/<name>/``.
    """

    # File-name prefix of the saved split files for each supported dataset.
    _FILE_PREFIXES = {'MNIST': 'mnist', 'CIFAR10': 'cifar'}

    def __init__(self, name, main_data_dir):
        """
        Loading Superpixels datasets

        Raises:
            ValueError: if ``name`` is neither 'MNIST' nor 'CIFAR10'.
        """
        start = time.time()
        print("[I] Loading dataset %s..." % (name))
        self.name = name
        data_dir = os.path.join(main_data_dir, 'superpixels/', name)
        try:
            prefix = self._FILE_PREFIXES[name]
        except KeyError:
            # Previously this only printed a message and then crashed with an
            # AttributeError when the missing splits were used below.
            raise ValueError(f"Unknown superpixel dataset {name!r}: only MNIST and CIFAR10 available") from None
        self.test = torch.load(data_dir + '/' + prefix + '_test.pt')
        self.val = torch.load(data_dir + '/' + prefix + '_val.pt')
        self.train = torch.load(data_dir + '/' + prefix + '_train.pt')
        print('train, test, val sizes :', len(self.train), len(self.test), len(self.val))
        print("[I] Finished loading.")
        print("[I] Data load time: {:.4f}s".format(time.time() - start))


# ==============================================================================
# /toolbox/losses.py
# ==============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.activation import Sigmoid
from toolbox.utils import get_device


class triplet_loss(nn.Module):
    """Matching loss on siamese node-similarity scores.

    For each score matrix ``out`` in the batch, the target of row ``i`` is
    column ``i`` (the identity matching), scored with ``loss`` (by default a
    summed cross-entropy).

    Args:
        loss_reduction: 'mean' divides the summed loss by the total number of
            vertices; 'mean_of_mean' averages each graph's loss over its own
            vertices, then averages over graphs.
        loss: criterion called as ``loss(out, target)``; None (the default)
            creates a new ``nn.CrossEntropyLoss(reduction='sum')`` for this instance.

    Raises:
        ValueError: on an unknown ``loss_reduction``.
    """

    def __init__(self, loss_reduction='mean', loss=None):
        super(triplet_loss, self).__init__()
        # Build the default criterion per instance; a module default argument is
        # evaluated once and was shared by every triplet_loss ever created.
        self.loss = nn.CrossEntropyLoss(reduction='sum') if loss is None else loss
        if loss_reduction == 'mean':
            self.increments = lambda new_loss, n_vertices: (new_loss, n_vertices)
        elif loss_reduction == 'mean_of_mean':
            self.increments = lambda new_loss, n_vertices: (new_loss / n_vertices, 1)
        else:
            raise ValueError('Unknown loss_reduction parameters {}'.format(loss_reduction))

    # NOTE(review): original author's note "to be checked: only working with
    # graphs same size?" -- behaviour on variable-size masked batches is unverified.
    def forward(self, raw_scores):
        """
        raw_scores is the output of siamese network (bs,n_vertices,n_vertices)
        """
        device = get_device(raw_scores)
        loss = 0
        total = 0
        for out in raw_scores:
            n_vertices = out.shape[0]
            target = torch.arange(n_vertices, device=device)
            new_loss, count = self.increments(self.loss(out, target), n_vertices)
            loss += new_loss
            total += count
        return loss / total


# ==============================================================================
# /maskedtensors/test_hierarchy.py
# ==============================================================================
import pytest
import torch
from toolbox.metrics import accuracy_max

from toolbox.losses import triplet_loss
import maskedtensor
import math
import scipy.optimize

N_VERTICES_RANGE = range(40, 50)
DEVICE = torch.device('cpu')
OPT_SCALE = True


def perturb(target):
    """Overwrite row 0 of ``target`` with 2s in place and return it."""
    target[0, :] = 2
    return target


@pytest.fixture
def batch(request):
    """Masked batch of perturbed identities, transposed when the param is True."""
    transpose = request.param
    if transpose:
        tensor_lst = [torch.t(perturb(torch.eye(n_vertices, n_vertices)))
                      for n_vertices in N_VERTICES_RANGE]
    else:
        tensor_lst = [perturb(torch.eye(n_vertices, n_vertices)) for n_vertices in N_VERTICES_RANGE]
    return maskedtensor.from_list(tensor_lst, dims=(0, 1))


@pytest.mark.parametrize('batch', [False, True], indirect=['batch'])
def test_hierarchy(batch):
    """At the best score scale the loss is at least (1 - accuracy) * log(2)."""
    correct, total = accuracy_max(batch)
    acc = correct / total
    loss_func = triplet_loss(loss_reduction='mean')
    if OPT_SCALE:
        # Find the scale factor that minimises the loss.
        res = scipy.optimize.minimize_scalar(lambda x: loss_func(torch.mul(batch, x)), bracket=(1e-1, 1e2))
        scale = res.x
        if scale <= 0:
            raise RuntimeError("Something went wrong during the optimization process")
    else:
        scale = 216
    loss = loss_func(torch.mul(batch, scale))
    assert loss >= (1 - acc) * math.log(2)
# ==============================================================================
# /loaders/benchmark.py
# ==============================================================================
import torch
import data_benchmarking_gnns.data_helper as helper


class BenchmarkDataset(torch.utils.data.Dataset):
    """Train/validation split of a graph-classification benchmark dataset.

    Graphs and labels come from ``helper.load_dataset(dataset_name)``.
    ``num_fold`` selects the split:

    * ``None``: the first ``len(graphs) // 10`` graphs are used for
      validation and the rest for training;
    * ``0``: the indices returned by ``helper.get_parameter_split``;
    * any other value: fold ``num_fold`` from ``helper.get_train_val_indexes``.

    After construction, ``self.train`` and ``self.val`` are lists of
    ``(graph, label)`` pairs, with float graphs and long labels.
    """

    def __init__(self, dataset_name, num_fold):
        self.dataset_name = dataset_name
        self.num_fold = num_fold
        self.load_data()
        self.make_dataset()

    def load_data(self):
        """Load the dataset and fill ``train_*``/``val_*`` lists and their sizes."""
        graphs, labels = helper.load_dataset(self.dataset_name)
        if self.num_fold is None:
            # The original spread a four-target assignment over two lines
            # without parentheses, which is an IndentationError (the module
            # could not be imported).
            idx = len(graphs) // 10
            train_idx, val_idx = range(idx, len(graphs)), range(idx)
        elif self.num_fold == 0:
            train_idx, val_idx = helper.get_parameter_split(self.dataset_name)
        else:
            train_idx, val_idx = helper.get_train_val_indexes(self.num_fold, self.dataset_name)
        self.train_graphs = [graphs[i] for i in train_idx]
        self.train_labels = [labels[i] for i in train_idx]
        self.val_graphs = [graphs[i] for i in val_idx]
        self.val_labels = [labels[i] for i in val_idx]
        self.train_size = len(self.train_graphs)
        self.val_size = len(self.val_graphs)

    def make_dataset(self):
        """Convert the split lists into ``(float graph, long label)`` tensor pairs."""
        self.train = [(torch.as_tensor(g, dtype=torch.float), torch.tensor(l, dtype=torch.long)) for (g, l) in zip(self.train_graphs, self.train_labels)]
        self.val = [(torch.as_tensor(g, dtype=torch.float), torch.tensor(l, dtype=torch.long)) for (g, l) in zip(self.val_graphs, self.val_labels)]
# ==============================================================================
# /data_benchmarking_gnns/get_data.py
# ==============================================================================
"""
This code is used to download the data used to train and test our model
adapted from https://github.com/hadarser/ProvablyPowerfulGraphNetworks_torch
"""

import os
from six.moves import urllib
import zipfile
from pathlib import Path
ROOT_DIR = Path.home()
DATA_DIR = os.path.join(ROOT_DIR, 'data/')
#raw_dir = os.path.join(os.getcwd(), 'data')


def download_url(url, folder, filename):
    r"""Download the content of ``url`` into ``folder/filename``.

    Args:
        url (string): The url.
        folder (string): The folder; created if it does not exist.
        filename (string): Name of the file written inside ``folder``.

    Returns:
        The path of the written file.
    """
    print('Downloading', url)

    os.makedirs(folder, exist_ok=True)
    path = os.path.join(folder, filename)

    # Close the HTTP response as well as the file; the response used to be leaked.
    with urllib.request.urlopen(url) as data, open(path, 'wb') as f:
        f.write(data.read())

    return path


def download_benchmarks(raw_dir):
    """Download the benchmark-graphs archive, extract it into ``raw_dir``, delete the zip."""
    url = 'https://www.dropbox.com/s/vjd6wy5nemg2gh6/benchmark_graphs.zip?dl=1'
    file_path = download_url(url, raw_dir, 'benchmark_graphs.zip')
    # Close the archive before unlinking it (the handle used to stay open,
    # which also makes the unlink fail on Windows).
    with zipfile.ZipFile(file_path, 'r') as archive:
        archive.extractall(raw_dir)
    os.unlink(file_path)


def download_QM9(raw_dir):
    """Download the QM9 test/val/train pickles into ``raw_dir/QM9``."""
    urls = [('https://www.dropbox.com/sh/acvh0sqgnvra53d/AAAxhVewejSl7gVMACa1tBUda/QM9_test.p?dl=1', 'QM9_test.p'),
            ('https://www.dropbox.com/sh/acvh0sqgnvra53d/AAAOfEx-jGC6vvi43fh0tOq6a/QM9_val.p?dl=1', 'QM9_val.p'),
            ('https://www.dropbox.com/sh/acvh0sqgnvra53d/AADtx0EMRz5fhUNXaHFipkrza/QM9_train.p?dl=1', 'QM9_train.p')]
    data_dir = os.path.join(raw_dir, 'QM9')
    for url, filename in urls:
        _ = download_url(url, data_dir, filename)


def main():
    os.makedirs(DATA_DIR, exist_ok=True)
    download_benchmarks(DATA_DIR)
    #download_QM9()


if __name__ == '__main__':
    main()


# ==============================================================================
# /models/blocks_emb.py
# ==============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
#from collections import namedtuple, defaultdict
from models.utils import *
from models.layers import MlpBlock_Real, ColumnMaxPooling, Concat, Identity, Matmul#,ColumnSumPooling, MlpBlock_vec, AttentionBlock_vec, Permute, Matmul_zerodiag, Add, GraphAttentionLayer, GraphNorm, Diag, Rec_block, Recall_block


def block_emb(in_features, out_features, depth_of_mlp, constant_n_vertices=True):
    """Graph spec of a single MLP block.

    'mlp3' lists no inputs, so ``models.utils.pipeline`` wires it to the node before it ('in').
    """
    return {
        'in': Identity(),
        'mlp3': MlpBlock_Real(in_features, out_features, depth_of_mlp,
                              constant_n_vertices=constant_n_vertices)
    }


def block(in_features, out_features, depth_of_mlp, constant_n_vertices=True):
    """Graph spec of one block: two MLPs on the input, their Matmul product,
    concatenated with the input, then a final MLP (``mlp3`` takes the preceding 'cat')."""
    return {
        'in': Identity(),
        'mlp1': (MlpBlock_Real(in_features, out_features, depth_of_mlp,
                               constant_n_vertices=constant_n_vertices), ['in']),
        'mlp2': (MlpBlock_Real(in_features, out_features, depth_of_mlp,
                               constant_n_vertices=constant_n_vertices), ['in']),
        'mult': (Matmul(), ['mlp1', 'mlp2']),
        'cat': (Concat(), ['mult', 'in']),
        # Concat of 'mult' (out_features) and 'in' (in_features) feeds this MLP.
        'mlp3': MlpBlock_Real(in_features + out_features, out_features, depth_of_mlp,
                              constant_n_vertices=constant_n_vertices)
    }


def base_model(original_features_num, num_blocks, in_features, out_features, depth_of_mlp, block=block, constant_n_vertices=True):
    """Chain ``num_blocks`` blocks: all but the last map to ``in_features``, the last to ``out_features``."""
    d = {'in': Identity()}
    last_layer_features = original_features_num
    for i in range(num_blocks - 1):
        d['block' + str(i + 1)] = block(last_layer_features, in_features, depth_of_mlp, constant_n_vertices=constant_n_vertices)
        last_layer_features = in_features
    d['block' + str(num_blocks)] = block(last_layer_features, out_features, depth_of_mlp, constant_n_vertices=constant_n_vertices)
    return d


def node_embedding(original_features_num, num_blocks, in_features, out_features, depth_of_mlp,
                   block=block, constant_n_vertices=True, **kwargs):
    """Graph spec: ``base_model`` followed by ``ColumnMaxPooling`` (node 'suffix').

    NOTE(review): extra config keys (e.g. ``type``, ``block_init``,
    ``block_inside``, ``num_heads`` when called with ``**node_emb``) land in
    ``kwargs`` and are ignored; only ``block`` selects the block factory.
    """
    d = {'in': Identity()}
    d['bm'] = base_model(original_features_num, num_blocks, in_features, out_features, depth_of_mlp, block, constant_n_vertices=constant_n_vertices)
    d['suffix'] = ColumnMaxPooling()
    return d


# ==============================================================================
# /models/utils.py
# ==============================================================================
import torch
import torch.nn as nn
from collections import defaultdict

#####################
## dict utils
#####################

union = lambda *dicts: {k: v for d in dicts for (k, v) in d.items()}


def path_iter(nested_dict, pfx=()):
    """Yield ``(path_tuple, leaf)`` for every non-dict value of a nested dict."""
    for name, val in nested_dict.items():
        if isinstance(val, dict):
            yield from path_iter(val, (*pfx, name))
        else:
            yield ((*pfx, name), val)


def map_nested(func, nested_dict):
    """Apply ``func`` to every leaf of a nested dict, keeping its structure."""
    return {k: map_nested(func, v) if isinstance(v, dict) else func(v) for k, v in nested_dict.items()}


def group_by_key(items):
    """Group ``(key, value)`` pairs into ``{key: [values...]}``."""
    res = defaultdict(list)
    for k, v in items:
        res[k].append(v)
    return res

#####################
## graph building
#####################
sep = '/'


def split(path):
    """Split ``'a/b/c'`` into ``('a/b', 'c')``."""
    i = path.rfind(sep) + 1
    return path[:i].rstrip(sep), path[i:]


def normpath(path):
    #simplified os.path.normpath
    # NOTE(review): the ``startswith(sep)`` branch can never fire, because the
    # parts come from ``split(sep)`` and therefore contain no separator.
    parts = []
    for p in path.split(sep):
        if p == '..':
            parts.pop()
        elif p.startswith(sep):
            parts = [p]
        else:
            parts.append(p)
    return sep.join(parts)


has_inputs = lambda node: type(node) is tuple


def pipeline(net):
    """Flatten a nested net spec to ``[(path, (node, inputs))]``; nodes given
    without inputs get ``[-1]``, i.e. the previous node."""
    return [(sep.join(path), (node if has_inputs(node) else (node, [-1]))) for (path, node) in path_iter(net)]


def build_graph(net):
    """Resolve every node's inputs to absolute paths.

    A string input is resolved relative to the node's own directory (a
    sibling); an int input is an offset in the flattened order.
    """
    flattened = pipeline(net)
    resolve_input = lambda rel_path, path, idx: normpath(sep.join((path, '..', rel_path))) if isinstance(rel_path, str) else flattened[idx + rel_path][0]
    return {path: (node[0], [resolve_input(rel_path, path, idx) for rel_path in node[1]]) for idx, (path, node) in enumerate(flattened)}


class Network(nn.Module):
    """Run a nested dict spec of modules as a DAG (see ``build_graph``)."""

    def __init__(self, net):
        super().__init__()
        self.graph = build_graph(net)
        # Register each node under its path (with '/' -> '_') so nn.Module sees submodules.
        for path, (val, _) in self.graph.items():
            setattr(self, path.replace('/', '_'), val)

    def nodes(self):
        return (node for node, _ in self.graph.values())

    def forward(self, inputs):
        """Return a dict with the output of every node, keyed by path."""
        outputs = dict(inputs)
        for k, (node, ins) in self.graph.items():
            #only compute nodes that are not supplied as inputs.
            if k not in outputs:
                outputs[k] = node(*[outputs[x] for x in ins])
        return outputs

    def half(self):
        """Convert all module nodes except BatchNorm2d to half precision."""
        for node in self.nodes():
            if isinstance(node, nn.Module) and not isinstance(node, nn.BatchNorm2d):
                node.half()
        return self


# ==============================================================================
# /loaders/loaders.py
# ==============================================================================
import maskedtensors.maskedtensor as maskedtensor
from torch.utils.data import DataLoader#, default_collate
import torch


def collate_fn_pair(samples_list):
    """Collate ``(graph1, graph2)`` pairs into two masked tensors (dims 1 and 2 masked).

    NOTE(review): ``Network.forward`` expects a dict with an 'input' key;
    confirm this masked-tensor path is still used with the siamese model.
    """
    input1_list = [input1 for input1, _ in samples_list]
    input2_list = [input2 for _, input2 in samples_list]
    input1 = maskedtensor.from_list(input1_list, dims=(1, 2), base_name='N')
    input2 = maskedtensor.from_list(input2_list, dims=(1, 2), base_name='M')
    return input1, input2


def collate_fn_pair_explore(samples_list):
    """Collate equal-size ``(graph1, graph2)`` pairs into ``{'input': stacked}`` dicts."""
    input1_list = [input1 for input1, _ in samples_list]
    input2_list = [input2 for _, input2 in samples_list]
    return {'input': torch.stack(input1_list)}, {'input': torch.stack(input2_list)}


def siamese_loader(data, batch_size, constant_n_vertices, shuffle=True):
    """DataLoader over graph pairs; stacks tensors when every graph has the same size."""
    assert len(data) > 0
    if constant_n_vertices:
        return DataLoader(data, batch_size=batch_size, shuffle=shuffle,
                          num_workers=4, collate_fn=collate_fn_pair_explore)
    return DataLoader(data, batch_size=batch_size, shuffle=shuffle,
                      num_workers=4, collate_fn=collate_fn_pair)


def collate_fn(samples_list):
    """Collate ``(graph, label)`` samples into a masked tensor and a label tensor."""
    inputs = [inp for inp, _ in samples_list]
    labels = [lab for _, lab in samples_list]
    return maskedtensor.from_list(inputs, dims=(1, 2), base_name='N'), torch.tensor(labels)


def collate_fn_explore(samples_list):
    """Split each ``(inp, label)`` sample into graph channel 0 and node features.

    Node features are the diagonals of channels ``1:``. Returns masked tensors
    under 'graphs' / 'nodes_f' and the labels under 'target'.
    """
    graphs = [inp[0, :, :].unsqueeze(0) for inp, _ in samples_list]
    nodes_f = [torch.diagonal(inp[1:, :, :], dim1=1, dim2=2) for inp, _ in samples_list]
    labels = [lab for _, lab in samples_list]
    return {'graphs': maskedtensor.from_list(graphs, dims=(1, 2), base_name='N'),
            'nodes_f': maskedtensor.from_list(nodes_f, dims=(1,), base_name='N'),
            'target': torch.tensor(labels)}


def simple_loader(data, batch_size, constant_n_vertices, shuffle=True):
    """DataLoader over ``(graph, label)`` samples; default collation for equal sizes."""
    assert len(data) > 0
    if constant_n_vertices:
        return DataLoader(data, batch_size=batch_size, shuffle=shuffle,
                          num_workers=4)
    return DataLoader(data, batch_size=batch_size, shuffle=shuffle,
                      num_workers=0, collate_fn=collate_fn_explore)


def collate_fn_benchmark(list):
    """Same collation as ``collate_fn_explore`` (kept under this name for callers)."""
    # The body used to be a verbatim copy of collate_fn_explore; delegate instead.
    return collate_fn_explore(list)


def benchmark_loader(data, batch_size, constant_n_vertices=False, shuffle=True):
    """DataLoader for benchmark datasets.

    ``constant_n_vertices=True`` has no dedicated path: a message is printed and
    the variable-size collation is used anyway.
    """
    assert len(data) > 0
    if constant_n_vertices:
        print('Not implemented')
    return DataLoader(data, batch_size=batch_size, shuffle=shuffle,
                      num_workers=0, collate_fn=collate_fn_benchmark)


# ==============================================================================
# /default_config.yaml (non-Python; reproduced verbatim as comment lines).
# NOTE: this dump collapsed the file's whitespace, so the YAML nesting
# indentation cannot be recovered here; the lines are shown flat.
#   ---
#   problem: qap # PB_DIR = experiments-gnn/$problem
#   name: expe_norm # results will be stored in PB_DIR/$name
#   cpu: No
#   #root_dir: 'experiments-gnn' # not used...
#   #test_enabled: Yes
#   #use_dgl: No
#   #path_dataset: data # Path where datasets are stored, default data/
#
#   data:
#   train: # Train/Val data generation parameters
#   num_examples_train: 20000
#   num_examples_val: 1000
#   n_vertices: 50
#   sparsify: None #Only works for not fgnns. Put to None if you don't want sparsifying
#   generative_model: Regular #Seed # so far ErdosRenyi, Regular or BarabasiAlbert
#   noise_model: ErdosRenyi
#   edge_density: 0.2 #0.05 #0.015 #0.025
#   vertex_proba: 1. # Parameter of the binomial distribution of vertices
#   noise: 0.1 #0.3 #0.32 #0.2 #0.2 0.4 0.6 0.8 0.9
#
#   test: #Test data generation parameters not used yet...
#   num_examples_test: 1000
#   n_vertices: 50
#   #sparsify: None #Only works for not fgnns. Put to None if you don't want sparsifying
#   #custom: No #If No, keeps the data_generation from train, just a failsafe so people consciously have to activate custom test
#   generative_model: Regular #Seed # so far ErdosRenyi, Regular or BarabasiAlbert
#   noise_model: ErdosRenyi
#   edge_density: 0.2 #0.0125
#   vertex_proba: 1. # Parameter of the binomial distribution of vertices
#   noise: 0.1
#   path_model: '/home/mlelarge/experiments-gnn/qap/expe_norm/node_embedding_Regular_100_0.05/07-27-23-14-45/qap_expe_norm/prges07j/checkpoints/epoch=9-step=6250.ckpt'
#   #path_model: '/home/mlelarge/experiments-gnn/qap/expe_norm/node_embedding_RegularSeed_100_0.05/07-25-23-11-30/qap_expe_norm/mvki2vap/checkpoints/epoch=9-step=6250.ckpt' #'/home/mlelarge/experiments-gnn/qap/expe_norm/node_embedding_Regular_100_0.05/07-19-23-11-54/qap_expe_norm/qye55q7e/checkpoints/epoch=7-step=5000.ckpt' #'/home/mlelarge/experiments-gnn/qap/expe_norm/node_embedding_rec_Regular_100_0.05/01-12-23-14-18/qap_expe_norm/262h3uh7/checkpoints/epoch=4-step=3125.ckpt'
#
#
#   train: # Training parameters
#   epochs: 100
#   batch_size: 256 #32 #10 #8 #32 #16 #64
#   lr: !!float 1e-3 #1e-3
#   scheduler_step: 3
#   scheduler_decay: 0.5
#   lr_stop: !!float 1e-5
#   log_freq: 50
#   anew: Yes
#   start_model: '/home/mlelarge/experiments-gnn/qap/qap_res/gatedgcn_8_ErdosRenyi_64_0.09375/02-11-22-20-55/model_best.pth.tar' #'/home/mlelarge/experiments-gnn/qap/qap_res/fgnn_4_ErdosRenyi_64_0.09375/02-11-22-09-31/model_best.pth.tar'
#
#   arch: # Architecture and model
#   original_features_num: 2 # 2 for fgnn 1 for mgnn
#   node_emb:
#   type: node_embedding
#   block_init: block_emb
#   block_inside: block
#   num_blocks: 4
#   in_features: 32
#   out_features: 32
#   depth_of_mlp: 3
#   num_heads: 16
#
#   #arch_gnn: fgnn #fgnn, gcn, gatedgcn
#   #arch_load: siamese #siamese or simple(to be done)
#   #embedding: node #node or edge, rs_node
#   #num_blocks: 4 #4
#
#   #dim_features: 64 #64
#   #depth_of_mlp: 3
#   #input_embed: No # No
#
#   observers:
#   wandb: Yes
# ==============================================================================


# ==============================================================================
# /models/trainers.py
# ==============================================================================
import pytorch_lightning as pl
from toolbox.losses import triplet_loss
from toolbox.metrics import accuracy_max, accuracy_linear_assignment
from models.blocks_emb import *
from models.utils import *
#from toolbox.utils import schedule

get_node_emb = {
    'node_embedding': node_embedding,
}

get_block_init = {
    'block_emb': block_emb
}

get_block_inside = {
    'block': block
}


class Siamese_Node_Exp(pl.LightningModule):
    def __init__(self, original_features_num, node_emb, lr=1e-3, scheduler_decay=0.5, scheduler_step=3, lr_stop = 1e-5):
        """
        take a batch of pair of graphs as
        (bs, original_features, n_vertices, n_vertices)
        and return a batch of "node similarities (i.e. dot product)"
        with shape (bs, n_vertices, n_vertices)
        graphs presumably need not have the same size inside the batch when
        maskedtensors are used

        Args:
            node_emb: config dict with 'type', 'block_init', 'block_inside',
                'out_features' and the keyword arguments of the embedding
                factory. It is copied, never modified.
            lr, scheduler_decay, scheduler_step, lr_stop: Adam learning rate
                and ReduceLROnPlateau factor / patience / min_lr.

        Raises:
            NotImplementedError: for an unknown embedding type or block name.
        """
        super().__init__()
        # Work on a copy: the block names are replaced by factories below, and
        # mutating the caller's dict made a second model built from the same
        # config fail with NotImplementedError.
        node_emb = dict(node_emb)
        try:
            node_emb_type = get_node_emb[node_emb['type']]
        except KeyError:
            raise NotImplementedError(f"node embedding {node_emb['type']} is not implemented") from None
        try:
            block_inside = get_block_inside[node_emb['block_inside']]
        except KeyError:
            raise NotImplementedError(f"block inside {node_emb['block_inside']} is not implemented") from None
        node_emb['block_inside'] = block_inside
        try:
            block_init = get_block_init[node_emb['block_init']]
        except KeyError:
            raise NotImplementedError(f"block init {node_emb['block_init']} is not implemented") from None
        node_emb['block_init'] = block_init

        self.out_features = node_emb['out_features']
        self.node_embedder_dic = {
            'input': (None, []),
            'ne': node_emb_type(original_features_num, **node_emb)
        }
        self.node_embedder = Network(self.node_embedder_dic)

        self.loss = triplet_loss()
        self.metric = accuracy_linear_assignment#accuracy_max
        self.lr = lr
        self.scheduler_decay = scheduler_decay
        self.scheduler_step = scheduler_step
        self.lr_stop = lr_stop

    def forward(self, x1, x2):
        """Return pairwise node similarity scores between the two graphs.

        ``x1`` and ``x2`` are the inputs for ``Network`` and must contain the
        'input' key (e.g. the dicts built by ``collate_fn_pair_explore``).
        The embeddings at 'ne/suffix' are presumably (bs, features, n) (TODO:
        confirm); the result is ``x1^T @ x2`` per batch element.
        """
        x1 = self.node_embedder(x1)['ne/suffix']
        x2 = self.node_embedder(x2)['ne/suffix']
        raw_scores = torch.matmul(torch.transpose(x1, 1, 2), x2)
        return raw_scores

    def _shared_step(self, batch, stage):
        """Score a batch of pairs, log ``<stage>_loss`` / ``<stage>_acc``, return the loss."""
        raw_scores = self(batch[0], batch[1])
        loss = self.loss(raw_scores)
        self.log(f'{stage}_loss', loss)
        (acc, n) = self.metric(raw_scores)
        self.log(f'{stage}_acc', acc / n)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        self._shared_step(batch, 'test')

    def configure_optimizers(self):
        """Adam, with ReduceLROnPlateau on 'val_loss' (min_lr = lr_stop)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr,
                                     amsgrad=False)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=self.scheduler_decay, patience=self.scheduler_step, verbose=True, min_lr=self.lr_stop),
                "monitor": "val_loss",
                "frequency": 1
                # If "monitor" references validation metrics, then "frequency" should be set to a
                # multiple of "trainer.check_val_every_n_epoch".
            },
        }


# ==============================================================================
# /data_benchmarking_gnns/data_generator.py
# NOTE: this span ends part-way through the file (load_data_benchmark and
# split_val_test_to_batches are referenced but not defined here); the visible
# part is reproduced unchanged.
# ==============================================================================
"""
This code is
adapted from https://github.com/hadarser/ProvablyPowerfulGraphNetworks_torch
"""
import data_benchmarking_gnns.data_helper as helper
#import utils.config
import torch


class DataGenerator:
    def __init__(self, config):
        self.config = config
        # load data here
        self.batch_size = self.config.hyperparams.batch_size
        self.is_qm9 = self.config.dataset_name == 'QM9'
        self.labels_dtype = torch.float32 if self.is_qm9 else torch.long

        self.load_data()

    # load the specified dataset in the config to the data_generator instance
    def load_data(self):
        if self.is_qm9:
            self.load_qm9_data()
        else:
            self.load_data_benchmark()

        self.split_val_test_to_batches()

    # load QM9 data set
    def load_qm9_data(self):
        train_graphs, train_labels, val_graphs, val_labels, test_graphs, test_labels = \
            helper.load_qm9(self.config.target_param)

        # preprocess all labels by train set mean and std
        train_labels_mean = train_labels.mean(axis=0)
        train_labels_std = train_labels.std(axis=0)
        train_labels = (train_labels - train_labels_mean) / train_labels_std
        val_labels = (val_labels - train_labels_mean) / train_labels_std
        test_labels = (test_labels - train_labels_mean) / train_labels_std

        self.train_graphs, self.train_labels = train_graphs, train_labels
        self.val_graphs, self.val_labels = val_graphs, val_labels
        self.test_graphs, self.test_labels = test_graphs, test_labels

        self.train_size = len(self.train_graphs)
        self.val_size = len(self.val_graphs)
        self.test_size = len(self.test_graphs)
        self.labels_std = train_labels_std  # Needed for postprocess, multiply mean abs
distance by this std 49 | 50 | # load data for a benchmark graph (COLLAB, NCI1, NCI109, MUTAG, PTC, IMDBBINARY, IMDBMULTI, PROTEINS) 51 | def load_data_benchmark(self): 52 | graphs, labels = helper.load_dataset(self.config.dataset_name) 53 | # if no fold specify creates random split to train and validation 54 | if self.config.num_fold is None: 55 | graphs, labels = helper.shuffle(graphs, labels) 56 | idx = len(graphs) // 10 57 | self.train_graphs, self.train_labels, self.val_graphs, self.val_labels = graphs[idx:], labels[idx:], graphs[:idx], labels[:idx] 58 | elif self.config.num_fold == 0: 59 | train_idx, test_idx = helper.get_parameter_split(self.config.dataset_name) 60 | self.train_graphs, self.train_labels, self.val_graphs, self.val_labels = graphs[train_idx], labels[ 61 | train_idx], graphs[test_idx], labels[test_idx] 62 | else: 63 | train_idx, test_idx = helper.get_train_val_indexes(self.config.num_fold, self.config.dataset_name) 64 | self.train_graphs, self.train_labels, self.val_graphs, self.val_labels = graphs[train_idx], labels[train_idx], graphs[test_idx], labels[ 65 | test_idx] 66 | # change validation graphs to the right shape 67 | self.train_size = len(self.train_graphs) 68 | self.val_size = len(self.val_graphs) 69 | 70 | def next_batch(self): 71 | graphs, labels = next(self.iter) 72 | graphs, labels = torch.cuda.FloatTensor(graphs), torch.tensor(labels, device='cuda', dtype=self.labels_dtype) 73 | return graphs, labels 74 | 75 | # initialize an iterator from the data for one training epoch 76 | def initialize(self, what_set): 77 | if what_set == 'train': 78 | self.reshuffle_data() 79 | elif what_set == 'val' or what_set == 'validation': 80 | self.iter = zip(self.val_graphs_batches, self.val_labels_batches) 81 | elif what_set == 'test': 82 | self.iter = zip(self.test_graphs_batches, self.test_labels_batches) 83 | else: 84 | raise ValueError("what_set should be either 'train', 'val' or 'test'") 85 | 86 | def reshuffle_data(self): 87 | """ 88 | 
Reshuffle train data between epochs 89 | """ 90 | graphs, labels = helper.group_same_size(self.train_graphs, self.train_labels) 91 | graphs, labels = helper.shuffle_same_size(graphs, labels) 92 | graphs, labels = helper.split_to_batches(graphs, labels, self.batch_size) 93 | self.num_iterations_train = len(graphs) 94 | graphs, labels = helper.shuffle(graphs, labels) 95 | self.iter = zip(graphs, labels) 96 | 97 | def split_val_test_to_batches(self): 98 | # Split the val and test sets to batchs, no shuffling is needed 99 | graphs, labels = helper.group_same_size(self.val_graphs, self.val_labels) 100 | graphs, labels = helper.split_to_batches(graphs, labels, self.batch_size) 101 | self.num_iterations_val = len(graphs) 102 | self.val_graphs_batches, self.val_labels_batches = graphs, labels 103 | 104 | if self.is_qm9: 105 | # Benchmark graphs have no test sets 106 | graphs, labels = helper.group_same_size(self.test_graphs, self.test_labels) 107 | graphs, labels = helper.split_to_batches(graphs, labels, self.batch_size) 108 | self.num_iterations_test = len(graphs) 109 | self.test_graphs_batches, self.test_labels_batches = graphs, labels 110 | 111 | 112 | if __name__ == '__main__': 113 | config = utils.config.process_config('../configs/10fold_config.json') 114 | data = DataGenerator(config) 115 | data.initialize('train') 116 | 117 | 118 | -------------------------------------------------------------------------------- /toolbox/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.lib.arraysetops import isin 3 | import torch 4 | from scipy.optimize import linear_sum_assignment 5 | from torch.nn.modules.activation import Sigmoid, Softmax 6 | from toolbox.utils import get_device, greedy_qap, perm_matrix 7 | import torch.nn.functional as F 8 | from sklearn.cluster import KMeans 9 | import sklearn.metrics as skmetrics 10 | import toolbox.utils as utils 11 | 12 | #from toolbox.searches import 
mcp_beam_method 13 | 14 | class Meter(object): 15 | """Computes and stores the sum, average and current value""" 16 | def __init__(self): 17 | self.reset() 18 | 19 | def reset(self): 20 | self.val = 0 21 | self.avg = 0 22 | self.sum = 0 23 | self.count = 0 24 | 25 | def update(self, val, n=1): 26 | self.val = val 27 | self.sum += val 28 | self.count += n 29 | self.avg = self.sum / self.count 30 | 31 | def get_avg(self): 32 | return self.avg 33 | 34 | def get_sum(self): 35 | return self.sum 36 | 37 | def value(self): 38 | """ Returns the value over one epoch """ 39 | return self.avg 40 | 41 | def is_active(self): 42 | return self.count > 0 43 | 44 | class ValueMeter(object): 45 | """Computes and stores the average and current value""" 46 | def __init__(self): 47 | self.reset() 48 | 49 | def reset(self): 50 | self.val = 0 51 | 52 | def update(self, val): 53 | self.val = val 54 | 55 | def value(self): 56 | return self.val 57 | 58 | def make_meter_loss(): 59 | meters_dict = { 60 | 'loss': Meter(), 61 | 'loss_ref': Meter(), 62 | 'batch_time': Meter(), 63 | 'data_time': Meter(), 64 | 'epoch_time': Meter(), 65 | } 66 | return meters_dict 67 | 68 | def make_meter_acc(): 69 | meters_dict = { 70 | 'loss': Meter(), 71 | 'acc': Meter(), 72 | 'batch_time': Meter(), 73 | 'data_time': Meter(), 74 | 'epoch_time': Meter(), 75 | } 76 | return meters_dict 77 | 78 | def make_meter_f1(): 79 | meters_dict = { 80 | 'loss': Meter(), 81 | 'f1': Meter(), 82 | 'precision': Meter(), 83 | 'recall': Meter(), 84 | 'batch_time': Meter(), 85 | 'data_time': Meter(), 86 | 'epoch_time': Meter(), 87 | } 88 | return meters_dict 89 | 90 | #QAP 91 | 92 | def accuracy_linear_assignment(rawscores, labels=None, aggregate_score=True): 93 | """ 94 | weights should be (bs,n,n) and labels (bs,n) numpy arrays 95 | """ 96 | total_n_vertices = 0 97 | acc = 0 98 | all_acc = [] 99 | weights = torch.log_softmax(rawscores,-1) 100 | for i, weight in enumerate(weights): 101 | if labels: 102 | label = labels[i] 103 | 
else: 104 | label = np.arange(len(weight)) 105 | cost = -weight.cpu().detach().numpy() 106 | _, preds = linear_sum_assignment(cost) 107 | if aggregate_score: 108 | acc += np.sum(preds == label) 109 | total_n_vertices += len(weight) 110 | else: 111 | all_acc += [np.sum(preds == label) / len(weight)] 112 | 113 | if aggregate_score: 114 | return acc, total_n_vertices 115 | else: 116 | return all_acc 117 | 118 | def accuracy_max(weights, labels=None, aggregate_score=True): 119 | """ 120 | weights should be (bs,n,n) and labels (bs,n) numpy arrays 121 | """ 122 | acc = 0 123 | all_acc = [] 124 | total_n_vertices = 0 125 | for i, weight in enumerate(weights): 126 | if labels is not None: 127 | label = labels[i] 128 | else: 129 | label = np.arange(len(weight)) 130 | weight = weight.cpu().detach().numpy() 131 | preds = np.argmax(weight, 1) 132 | if aggregate_score: 133 | acc += np.sum(preds == label) 134 | total_n_vertices += len(weight) 135 | else: 136 | all_acc += [np.sum(preds == label) / len(weight)] 137 | 138 | if aggregate_score: 139 | return acc, total_n_vertices 140 | else: 141 | return all_acc 142 | 143 | 144 | def all_losses_acc(val_loader,model,criterion, 145 | device,eval_score=None): 146 | #model.eval() 147 | all_losses =[] 148 | all_acc = [] 149 | model = model.to(device) 150 | 151 | for (data1, data2) in val_loader: 152 | data1['input'] = data1['input'].to(device) 153 | data2['input'] = data2['input'].to(device) 154 | rawscores = model(data1, data2) 155 | #n_vertices = output.shape[0] 156 | #ide = torch.arange(n_vertices) 157 | #target = ide.to(device) 158 | 159 | loss = criterion(rawscores) 160 | 161 | all_losses.append(loss.item()) 162 | 163 | if eval_score is not None: 164 | acc = eval_score(rawscores,aggregate_score=False) 165 | all_acc += acc 166 | return np.array(all_losses), np.array(all_acc) 167 | 168 | def all_acc_qap(val_loader,model,device): 169 | #model.eval() 170 | all_qap = [] 171 | all_acc = [] 172 | all_planted = [] 173 | model = 
model.to(device) 174 | 175 | for (data1, data2) in val_loader: 176 | data1['input'] = data1['input'].to(device) 177 | data2['input'] = data2['input'].to(device) 178 | rawscores = model(data1, data2) 179 | weights = torch.log_softmax(rawscores,-1) 180 | g1 = data1['input'][:,0,:].cpu().detach().numpy() 181 | g2 = data2['input'][:,0,:].cpu().detach().numpy() 182 | for i, weight in enumerate(weights): 183 | cost = -weight.cpu().detach().numpy() 184 | row_ind, col_ind = linear_sum_assignment(cost) 185 | qap = (g1[i]*(g2[i][col_ind,:][:,col_ind])).sum() 186 | planted = (g1[i]*g2[i]).sum() 187 | label = np.arange(len(weight)) 188 | acc = np.sum(col_ind == label) 189 | all_qap.append(qap) 190 | all_acc.append(acc) 191 | all_planted.append(planted) 192 | 193 | return np.array(all_acc), np.array(all_qap), np.array(all_planted) 194 | 195 | # code below should be corrected/refactored... 196 | 197 | def all_greedy_losses_acc(val_loader,model,criterion, 198 | device,T=10): 199 | # only tested with batch size = 1 200 | model.eval() 201 | all_losses =[] 202 | all_acc = [] 203 | 204 | for (data, target) in val_loader: 205 | data = data.to(device) 206 | target_deviced = target.to(device) 207 | output = model(data) 208 | rawscores = output.squeeze(-1) 209 | raw_scores = torch.softmax(rawscores,-1) 210 | 211 | loss = criterion(raw_scores,target_deviced) 212 | 213 | all_losses.append(loss.item()) 214 | 215 | A = data[0,0,:,:,1].data.cpu().detach().numpy() 216 | B = data[0,1,:,:,1].data.cpu().detach().numpy() 217 | cost = -raw_scores.cpu().detach().numpy().squeeze() 218 | #print(i, " | ", cost) 219 | row, preds = linear_sum_assignment(cost) 220 | a, na,nb, acc, _ = greedy_qap(A,B,perm_matrix(row,preds),T) 221 | all_acc.append(acc) 222 | 223 | return np.array(all_losses), np.array(all_acc) 224 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Expressive 
Power of Invariant and Equivariant Graph Neural Networks 2 | 3 | In this repository, we show how to use a powerful GNN (2-FGNN) to solve a graph alignment problem. This code was used to derive the practical results in the following paper: 4 | 5 | Waiss Azizian, Marc Lelarge. Expressive Power of Invariant and Equivariant Graph Neural Networks, ICLR 2021. 6 | 7 | [arXiv](https://arxiv.org/abs/2006.15646) [OpenReview](https://openreview.net/forum?id=lxHgXYN4bwl) 8 | ## Problem: alignment of graphs 9 | The graph isomorphism problem is the computational problem of determining whether two finite graphs are isomorphic. Here we consider a noisy version of this problem: the two graphs below are noisy versions of a parent graph. There is no strict isomorphism between them. Can we still match the vertices of graph 1 with the corresponding vertices of graph 2? 10 | 11 | graph 1 | graph 2 12 | :---:|:---: 13 | ![](images/01_graph1.png) | ![](images/02_graph2.png) 14 | 15 | With our GNN, we obtain the following results: green vertices are well paired vertices and red vertices are errors. Both graphs are now drawn using the layout of graph 2 (right, above), and each vertex has the same color on both sides. At inference, our GNN builds node embeddings for the vertices of graphs 1 and 2. Finally, a node of graph 1 is matched to its most similar node of graph 2 in this embedding space. 16 | 17 | graph 1 | graph 2 18 | :---:|:---: 19 | ![](images/04_result_graph1.png) | ![](images/03_result_graph2.png) 20 | 21 | Below, on the left, we plot the errors made by our GNN: errors made on red vertices are represented by links corresponding to a wrong matching or cycle; on the right, we superpose the two graphs: green edges are in both graphs (they correspond to the parent graph), orange edges are in graph 1 only and blue edges are in graph 2 only.
We clearly see the impact of the noisy edges (orange and blue) as each red vertex (corresponding to an error) is connected to such edges (except the isolated red vertex). 22 | 23 | Wrong matchings/cycles | Superposing the 2 graphs 24 | :---:|:---: 25 | ![](images/09_preds.png) | ![](images/05_result.png) 26 | 27 | To measure the performance of our GNN, instead of looking at vertices, we can look at edges. On the left below, we see that our GNN recovers most of the green edges present in graphs 1 and 2 (edges from the parent graph). On the right, mismatched edges correspond mostly to noisy (orange and blue) edges (present in only one of the graphs 1 or 2). 28 | 29 | Matched edges | Mismatched edges 30 | :---:|:---: 31 | ![](images/07_match.png) | ![](images/08_mismatch.png) 32 | 33 | ## Training GNN for the graph alignment problem 34 | 35 | For the training of our GNN, we generate synthetic datasets as follows: first sample the parent graph and then add edges to construct graphs 1 and 2. We obtain a dataset made of pairs of graphs for which we know the true matching of vertices. We then use a siamese encoder as shown below where the same GNN (i.e. shared weights) is used for both graphs. The node embeddings constructed for each graph are then used to predict the corresponding permutation index by taking the outer product and a softmax along each row. The GNN is trained with a standard cross-entropy loss. 36 | At inference, we can add a LAP (linear assignment problem) solver to get a permutation from the similarity matrix given by this outer product. 37 | 38 | ![](images/siamese.png) 39 | 40 | Various architectures can be used for the GNN and we find that FGNNs (first introduced by Maron et al. in [Provably Powerful Graph Networks](https://papers.nips.cc/paper/2019/hash/bb04af0f7ecaee4aae62035497da1387-Abstract.html) NeurIPS 2019) perform best for our task.
In our paper [Expressive Power of Invariant and Equivariant Graph Neural Networks](https://openreview.net/forum?id=lxHgXYN4bwl), we substantiate these empirical findings by **proving that FGNN has a better power of approximation than all other equivariant architectures working with tensors of order 2 presented so far** (this includes message passing GNNs and linear GNNs). 41 | 42 | ## Results 43 | 44 | ![](images/download.png) 45 | 46 | Each line corresponds to a model trained at a given noise level and shows 47 | its accuracy across all noise levels. We see that pretrained models generalize very well to noise levels unseen during the training. 48 | 49 | We provide a simple [notebook](https://github.com/mlelarge/graph_neural_net/blob/master/plot_accuracy_regular.ipynb) to reproduce this result for the pretrained model released with this repository (to run the notebook, create an `ipykernel` named gnn with the required dependencies as described below). 50 | 51 | We refer to our [paper](https://openreview.net/forum?id=lxHgXYN4bwl) for comparisons with other algorithms (message passing GNN, spectral or SDP algorithms). 52 | 53 | To cite our paper: 54 | ``` 55 | @inproceedings{azizian2020characterizing, 56 | title={Expressive power of invariant and equivariant graph neural networks}, 57 | author={Azizian, Wa{\"\i}ss and Lelarge, Marc}, 58 | booktitle={International Conference on Learning Representations}, 59 | year={2021}, 60 | url={https://openreview.net/forum?id=lxHgXYN4bwl} 61 | } 62 | ``` 63 | 64 | ## Overview of the code 65 | ### Project structure 66 | 67 | ```bash 68 | .
69 | ├── cpp_code # C++ code for exact solving to be compiled 70 | ├── loaders 71 | | └── dataset selector 72 | | └── data_generator.py # generating random graphs 73 | | └── test_data_generator.py 74 | | └── siamese_loader.py # loading pairs 75 | ├── models 76 | | └── architecture selector 77 | | └── layers.py # equivariant block 78 | | └── base_model.py # powerful GNN Graph -> Graph 79 | | └── siamese_net.py # GNN to match graphs 80 | ├── toolbox 81 | | └── optimizer and losses selectors 82 | | └── data_handler.py # class handling the io of data and task-planning 83 | | └── helper.py # base class for helping the selection of experiments when training a model 84 | | └── logger.py # keeping track of most results during training 85 | | └── losses.py # computing losses 86 | | └── maskedtensor.py # Tensor-like class to handle batches of graphs of different sizes 87 | | └── metrics.py # computing scores 88 | | └── mcp_solver.py # class handling the multi-threaded exact solving of MCP problems 89 | | └── minb_solver.py # class handling the multi-threaded exact solving of Min Bisection problems 90 | | └── optimizer.py # optimizers 91 | | └── searches.py # contains beam searches and exact solving functions 92 | | └── utility.py 93 | | └── vision.py # functions for visualization 94 | ├── article_commander.py # main file for computing the data needed for the figures 95 | ├── commander.py # main file from the project serving for calling all necessary functions for training and testing 96 | ├── trainer.py # pipelines for training and validation 97 | ├── eval.py # testing models 98 | 99 | ``` 100 | 101 | ### Dependencies 102 | Dependencies are listed in `requirements.txt`. 
To install, run 103 | ``` 104 | pip install -r requirements.txt 105 | ``` 106 | DGL is not included in the `requirements.txt`; please follow the [dgl specific instructions](https://www.dgl.ai/pages/start.html) 107 | 108 | ## Training 109 | Run the main file ```commander.py``` with the command ```train``` 110 | ``` 111 | python commander.py train 112 | ``` 113 | To change options, use the [Sacred](https://github.com/IDSIA/sacred) command-line interface and see ```default_config.yaml``` for the configuration structure. For instance, 114 | ``` 115 | python commander.py train with cpu=No data.generative_model=Regular train.epochs=10 116 | ``` 117 | You can also copy ```default_config.yaml``` and modify the configuration parameters there. 118 | 119 | See [Sacred documentation](http://sacred.readthedocs.org/) for an exhaustive reference. 120 | 121 | To save logs to [Neptune](https://neptune.ai/), you need to provide your own API key via the dedicated environment variable. 122 | 123 | ## Evaluating 124 | 125 | There are two ways of evaluating the models. If you just ran the training with a configuration ```conf.yaml```, you can simply do, 126 | ``` 127 | python commander.py eval with conf.yaml 128 | ``` 129 | You can omit ```with conf.yaml``` if you are using the default configuration. 130 | 131 | If you downloaded a model with a config file from here, you can edit the section ```test_data``` of this config if you wish and then run, 132 | ``` 133 | python commander.py eval with /path/to/config model_path=/path/to/model.pth.tar 134 | ``` 135 | 136 | This will retrieve the trained model and evaluate it on a test dataset. More options are available in `eval.py`.
137 | 138 | -------------------------------------------------------------------------------- /toolbox/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import json 4 | from typing import Tuple 5 | #from matplotlib.pyplot import isinteractive 6 | from numpy.lib.arraysetops import isin 7 | import torch 8 | import numpy as np 9 | from scipy.spatial.distance import cdist 10 | from scipy.optimize import linear_sum_assignment 11 | from networkx import to_numpy_array as nx_to_numpy_array 12 | #import dgl as dgl 13 | import torch.backends.cudnn as cudnn 14 | 15 | def schedule(k, max_epochs=8): 16 | return max(max_epochs-k,0)/max_epochs #torch.tensor(max(max_epochs-k,0)/max_epochs, dtype=torch.float) 17 | 18 | def load_json(json_file): 19 | # Load the JSON file into a variable 20 | with open(json_file) as f: 21 | json_data = json.load(f) 22 | 23 | # Return the data as a dictionary 24 | return json_data 25 | 26 | # create directory if it does not exist 27 | def check_dir(dir_path): 28 | dir_path = dir_path.replace('//','/') 29 | os.makedirs(dir_path, exist_ok=True) 30 | 31 | def check_file(file_path): 32 | file_path = file_path.replace('//','/') 33 | dir_path = os.path.dirname(file_path) 34 | check_dir(dir_path) 35 | if not os.path.exists(file_path): 36 | with open(file_path,'w') as f: 37 | pass 38 | 39 | def setup_env(cpu): 40 | # Randomness is already controlled by Sacred 41 | # See https://sacred.readthedocs.io/en/stable/randomness.html 42 | if not cpu: 43 | cudnn.benchmark = True 44 | 45 | def save_checkpoint(state, is_best, log_dir, filename='checkpoint.pth.tar'): 46 | #check_dir(log_dir) 47 | filename = os.path.join(log_dir, filename) 48 | torch.save(state, filename) 49 | if is_best: 50 | shutil.copyfile(filename, os.path.join(log_dir, 'model_best.pth.tar')) 51 | #shutil.copyfile(filename, model_path) 52 | print(f"Best Model yet : saving at {log_dir+'/model_best.pth.tar'}") 53 | 54 | fn = 
os.path.join(log_dir, 'checkpoint_epoch{}.pth.tar') 55 | torch.save(state, fn.format(state['epoch'])) 56 | 57 | if (state['epoch'] - 1 ) % 5 != 0: 58 | #remove intermediate saved models, e.g. non-modulo 5 ones 59 | if os.path.exists(fn.format(state['epoch'] - 1 )): 60 | os.remove(fn.format(state['epoch'] - 1 )) 61 | 62 | state['exp_logger'].to_json(log_dir=log_dir,filename='logger.json') 63 | 64 | # move in utils 65 | def load_model(model, device, model_path): 66 | """ Load model. Note that the model_path argument is captured """ 67 | if os.path.exists(model_path): 68 | print("Reading model from ", model_path) 69 | checkpoint = torch.load(model_path, map_location=torch.device(device)) 70 | model.load_state_dict(checkpoint['state_dict']) 71 | return model 72 | else: 73 | raise RuntimeError('Model does not exist!') 74 | 75 | def save_to_json(jsonkey, loss, relevant_metric_dict, filename): 76 | if os.path.exists(filename): 77 | with open(filename, "r") as jsonFile: 78 | data = json.load(jsonFile) 79 | else: 80 | data = {} 81 | data[jsonkey] = {'loss':loss} 82 | for dkey, value in relevant_metric_dict.items(): 83 | data[jsonkey][dkey] = value 84 | with open(filename, 'w') as jsonFile: 85 | json.dump(data, jsonFile) 86 | 87 | # from https://stackoverflow.com/questions/50916422/python-typeerror-object-of-type-int64-is-not-json-serializable/50916741 88 | class NpEncoder(json.JSONEncoder): 89 | def default(self, obj): 90 | if isinstance(obj, np.integer): 91 | return int(obj) 92 | elif isinstance(obj, np.floating): 93 | return float(obj) 94 | elif isinstance(obj, np.ndarray): 95 | return obj.tolist() 96 | else: 97 | return super(NpEncoder, self).default(obj) 98 | 99 | 100 | def get_lr(optimizer): 101 | for param_group in optimizer.param_groups: 102 | return param_group['lr'] 103 | 104 | def get_device(t): 105 | if t.is_cuda: 106 | return t.get_device() 107 | return 'cpu' 108 | 109 | #Matrix operation 110 | 111 | def symmetrize_matrix(A): 112 | """ 113 | Symmetrizes a matrix 
: 114 | If shape is (a,b,c) will symmetrize by considering a is batch size 115 | """ 116 | Af = A.triu(0) + A.triu(1).transpose(-2,-1) 117 | return Af 118 | 119 | def list_to_tensor(liste) -> torch.Tensor: 120 | """Transforms a list of same shaped tensors""" 121 | if isinstance(liste,torch.Tensor): 122 | return liste 123 | bs = len(liste) 124 | shape = liste[0].shape 125 | final_shape = (bs,*shape) 126 | tensor_eq = torch.empty(final_shape) 127 | for k in range(bs): 128 | tensor_eq[k] = liste[k] 129 | return tensor_eq 130 | 131 | #Graph operations 132 | 133 | """ def edge_features_to_dense_tensor(graph, features, device='cpu'): 134 | N = graph.number_of_nodes() 135 | resqueeze = False 136 | if len(features.shape)==1: 137 | features.unsqueeze(-1) 138 | resqueeze = True 139 | n_feats = features.shape[1] 140 | t = torch.zeros((N,N,n_feats)).to(device) 141 | #adj = torch.tensor(nx_to_numpy_array(graph.to_networkx())).to(device)#edges = np.array(graph.edges().cpu()).T #Transpose for the right shape (2,n_edges) 142 | adj = graph.adj(ctx=device).to_dense() 143 | ix,iy = torch.where(adj==1) 144 | t[ix,iy] = features 145 | if resqueeze: 146 | t.squeeze(-1) 147 | return t 148 | 149 | def edge_features_to_dense_sym_tensor(graph,features,device='cpu'): 150 | t = edge_features_to_dense_tensor(graph,features,device) 151 | if torch.all(t.transpose(0,1)+t==2*t): #Matrix already symmetric 152 | return t 153 | 154 | N = graph.number_of_nodes() 155 | tril = torch.tril(torch.ones((N,N)),-1) 156 | tril = tril.unsqueeze(-1).to(device) #For the multiplication, we need to add the dimension 157 | if torch.all(t*tril==0): #Only zeros in the lower triangle features 158 | return t + t.transpose(0,1) * tril #Here we remove the diagonal with '* tril' 159 | 160 | tbool = (t!=0) 161 | tbool = tbool.sum(-1)!=0 #Here we have True where the feature vectors are not 0 162 | ix,iy = torch.where(tbool!=0) 163 | for i,j in zip(ix,iy): 164 | if i==j or torch.all(t[j,i]==t[i,j]): 165 | continue 166 | elif 
torch.all(t[j,i]==0): 167 | t[j,i] = t[i,j] 168 | else: 169 | raise AssertionError(f"Feature values are asymmetric, should not have used the symetric function.") 170 | return t 171 | 172 | def edge_features_to_dense_features(graph, features, device='cpu'): 173 | t = edge_features_to_dense_tensor(graph, features, device) 174 | if len(features.shape)==1: 175 | return t.flatten() 176 | n_features = features.shape[1] 177 | N = graph.number_of_nodes() 178 | t_features = t.reshape((N**2,n_features)) 179 | return t_features 180 | 181 | def edge_features_to_dense_sym_features(graph, features, device='cpu'): 182 | t = edge_features_to_dense_sym_tensor(graph, features, device) 183 | if len(features.shape)==1: 184 | return t.flatten() 185 | n_features = features.shape[1] 186 | N = graph.number_of_nodes() 187 | t_features = t.reshape((N**2,n_features)) 188 | return t_features 189 | 190 | def edge_tensor_to_features(graph: dgl.DGLGraph, features: torch.Tensor, device='cpu'): 191 | n_edges = graph.number_of_edges() 192 | resqueeze = False 193 | if len(features.shape)==3: 194 | resqueeze=True 195 | features = features.unsqueeze(-1) 196 | bs,N,_,n_features = features.shape 197 | 198 | ix,iy = graph.edges() 199 | bsx,bsy = ix//N,iy//N 200 | Nx,Ny = ix%N,iy%N 201 | assert torch.all(bsx==bsy), "Edges between graphs, should not be allowed !" 
#Sanity check 202 | final_features = features[(bsx,Nx,Ny)] #Here, shape will be (n_edges,n_features) 203 | if resqueeze: 204 | final_features = final_features.squeeze(-1) 205 | return final_features 206 | 207 | def temp_sym(t): 208 | if torch.all(t.transpose(0,1)+t==2*t): 209 | return t 210 | elif torch.all(torch.tril(t,-1)==0): 211 | return t + torch.triu(t,1).transpose(0,1) 212 | else: 213 | ix,iy = torch.where(t!=0) 214 | for i,j in zip(ix,iy): 215 | if t[j,i]==0: 216 | t[j,i] = t[i,j] 217 | elif t[j,i]==t[i,j]: 218 | continue 219 | else: 220 | raise AssertionError(f"Feature values are asymmetric, should not have used the symetric function.") 221 | return t 222 | """ 223 | #QAP 224 | 225 | def perm_matrix(row,preds): 226 | n = len(row) 227 | permutation_matrix = np.zeros((n, n)) 228 | permutation_matrix[row, preds] = 1 229 | return permutation_matrix 230 | 231 | def score(A,B,perm): 232 | return np.trace(A @ perm @ B @ np.transpose(perm))/2, np.sum(A)/2, np.sum(B)/2 233 | 234 | def improve(A,B,perm): 235 | label = np.arange(A.shape[0]) 236 | cost_adj = - A @ perm @ B 237 | r, p = linear_sum_assignment(cost_adj) 238 | acc = np.sum(p == label) 239 | return perm_matrix(r,p), acc 240 | 241 | def greedy_qap(A,B,perm,T,verbose=False): 242 | #perm_p = perm 243 | s_best, na, nb = score(A,B,perm) 244 | perm_p, acc_best = improve(A,B,perm) 245 | T_best = 0 246 | for i in range(T): 247 | perm_n, acc = improve(A,B,perm_p) 248 | perm_p = perm_n 249 | s,na,nb = score(A,B,perm_p) 250 | if s > s_best: 251 | acc_best = acc 252 | s_best = s 253 | T_best = i 254 | if verbose: 255 | print(s,na,nb,acc) 256 | return s_best, na, nb, acc_best, T_best 257 | 258 | -------------------------------------------------------------------------------- /data_benchmarking_gnns/data_helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is 3 | adapted from https://github.com/hadarser/ProvablyPowerfulGraphNetworks_torch 4 | """ 5 | 6 | 
import numpy as np
import os
import pickle
from pathlib import Path
ROOT_DIR = Path.home()
DATA_DIR = os.path.join(ROOT_DIR,'data/')

# Vertex-label count per dataset; 0 means load_dataset ignores vertex labels.
NUM_LABELS = {'ENZYMES': 3, 'COLLAB': 0, 'IMDBBINARY': 0, 'IMDBMULTI': 0, 'MUTAG': 7, 'NCI1': 37, 'NCI109': 38,
              'PROTEINS': 3, 'PTC': 22, 'DD': 89}
#BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
NUM_CLASSES = {'COLLAB':3, 'IMDBBINARY':2, 'IMDBMULTI':3, 'MUTAG':2, 'NCI1':2, 'NCI109':2, 'PROTEINS':2, 'PTC':2, 'QM9': 12}


def load_dataset(ds_name):
    """
    Construct graphs and labels from the dataset text file in the data folder.

    :param ds_name: name of the data set to load (a key of NUM_LABELS)
    :return: two lists of length num_of_graphs.
        graphs[i] is a float32 ndarray of shape
        (NUM_LABELS[ds_name] + 2, num_vertex, num_vertex): channel 0 holds the
        adjacency matrix and, when the dataset has vertex labels, vertex j with
        label l gets a 1 at position [l + 1, j, j].
        labels[i] is the integer class of graphs[i].
    """
    path = DATA_DIR + "benchmark_graphs/{0}/{0}.txt".format(ds_name)
    graphs = []
    labels = []
    with open(path, "r") as data:
        n_labels = NUM_LABELS[ds_name]
        num_graphs = int(data.readline().rstrip().split(" ")[0])
        for _ in range(num_graphs):
            # Graph header: "<num_vertex> <graph_label>"
            graph_meta = data.readline().rstrip().split(" ")
            num_vertex = int(graph_meta[0])
            curr_graph = np.zeros(shape=(n_labels + 2, num_vertex, num_vertex), dtype=np.float32)
            labels.append(int(graph_meta[1]))
            for j in range(num_vertex):
                # Vertex line: "<vertex_label> <n_neighbours> <neighbour> ..."
                vertex = data.readline().rstrip().split(" ")
                if n_labels != 0:
                    curr_graph[int(vertex[0]) + 1, j, j] = 1.
                for k in range(2, len(vertex)):
                    curr_graph[0, j, int(vertex[k])] = 1.
            graphs.append(curr_graph)
    return graphs, labels


def load_qm9(target_param):
    """
    Constructs the graphs and labels of QM9 data set, already split to train, val and test sets
    :return: 6 numpy arrays:
                 train_graphs: N_train,
                 train_labels: N_train x 12, (or N_train x 1 if target_param is not False)
                 val_graphs: N_val,
                 val_labels: N_val x 12, (or N_val x 1 if target_param is not False)
                 test_graphs: N_test,
                 test_labels: N_test x 12, (or N_test x 1 if target_param is not False)
                 each graph of shape: 19 x Nodes x Nodes (CHW representation)
    """
    train_graphs, train_labels = load_qm9_aux('train', target_param)
    val_graphs, val_labels = load_qm9_aux('val', target_param)
    test_graphs, test_labels = load_qm9_aux('test', target_param)
    return train_graphs, train_labels, val_graphs, val_labels, test_graphs, test_labels


def load_qm9_aux(which_set, target_param):
    """
    Read and construct the graphs and labels of QM9 data set, already split to train, val and test sets
    :param which_set: 'test', 'train' or 'val'
    :param target_param: if not False, return the labels for this specific param only
    :return: graphs: 1-D object array of length N; graphs do not share a node count,
                     so each entry is its own ndarray of shape 19 x Nodes x Nodes (CHW)
             labels: N x 12, (or N x 1 if target_param is not False)
    """
    # NOTE(review): BASE_DIR is not defined (its definition above is commented
    # out), so this used to raise NameError. QM9 is now looked up under DATA_DIR
    # like the other datasets -- confirm this matches where QM9 is stored.
    base_path = os.path.join(DATA_DIR, "QM9", "QM9_{}.p".format(which_set))
    graphs, labels = [], []
    with open(base_path, 'rb') as f:
        # pickle can execute arbitrary code: only load trusted, locally produced files.
        data = pickle.load(f)
    for instance in data:
        labels.append(instance['y'])
        features = instance['usable_features']
        nodes_num = features['x'].shape[0]
        graph = np.empty((nodes_num, nodes_num, 19))
        for i in range(13):
            # 13 features per node - for each, create a diag matrix of it as a feature
            graph[:, :, i] = np.diag(features['x'][:, i])
        graph[:, :, 13] = features['distance_mat']
        graph[:, :, 14] = features['affinity']
        graph[:, :, 15:] = features['edge_features']  # shape n x n x 4
        # Store channels-first (CHW)
        graphs.append(np.transpose(graph, [2, 0, 1]))
    # np.array() on differently-sized graphs raises on NumPy >= 1.24, so build
    # the ragged 1-D object array explicitly (element by element, so equal
    # shapes are never broadcast into a dense 4-D array).
    graphs_arr = np.empty(len(graphs), dtype=object)
    for i, graph in enumerate(graphs):
        graphs_arr[i] = graph
    labels = np.array(labels).squeeze()  # shape N x 12
    if target_param is not False:  # regression over a specific target, not all 12 elements
        labels = labels[:, target_param].reshape(-1, 1)  # shape N x 1

    return graphs_arr, labels


def _read_indexes(path):
    """Read one integer index per line from the text file at ``path``."""
    with open(path, 'r') as file:
        return [int(line.rstrip()) for line in file]


def get_train_val_indexes(num_val, ds_name):
    """
    reads the indexes of a specific split to train and validation sets from data folder
    :param num_val: number of the split
    :param ds_name: name of data set
    :return: indexes of the train and test graphs
    """
    directory = DATA_DIR + "benchmark_graphs/{0}/10fold_idx".format(ds_name)
    train_idx = _read_indexes(os.path.join(directory, "train_idx-{0}.txt".format(num_val)))
    test_idx = _read_indexes(os.path.join(directory, "test_idx-{0}.txt".format(num_val)))
    return train_idx, test_idx


def get_parameter_split(ds_name):
    """
    reads the fixed train/validation split stored in tests_train_split.txt and
    tests_val_split.txt of the dataset folder
    :param ds_name: name of data set
    :return: indexes of the train and validation graphs
    """
    directory = DATA_DIR + "benchmark_graphs/{0}/".format(ds_name)
    train_idx = _read_indexes(os.path.join(directory, "tests_train_split.txt"))
    test_idx = _read_indexes(os.path.join(directory, "tests_val_split.txt"))
    return train_idx, test_idx


# ------------------------------------------------------------------------------
# /loaders/data_generator.py
# ------------------------------------------------------------------------------
import os
import random
import itertools
import networkx
#from networkx.algorithms.approximation.clique import max_clique
from numpy import diag_indices
import numpy as np
import torch
import torch.utils
import toolbox.utils as utils
#from toolbox.searches import mcp_beam_method
from sklearn.decomposition import PCA
#from numpy import pi,angle,cos,sin
from numpy.random import default_rng
import tqdm
from numpy import mgrid as npmgrid

from numpy import indices as npindices, argpartition as npargpartition, array as nparray


rng = default_rng(41)

# Registry: distribution name -> function (p, N) -> (networkx graph, adjacency tensor)
GENERATOR_FUNCTIONS = {}
ADJ_UNIQUE_TENSOR = torch.Tensor([0.,1.])

def is_adj(matrix):
    """Return a boolean tensor telling whether every entry of ``matrix`` is 0 or 1."""
    return torch.all((matrix==0) + (matrix==1))

class TimeOutException(Exception):
    pass

def generates(name):
    """ Register a generator function for a graph distribution """
    def decorator(func):
        GENERATOR_FUNCTIONS[name] = func
        return func
    return decorator

@generates("ErdosRenyi")
def generate_erdos_renyi_netx(p, N):
    """ Generate random Erdos Renyi graph with N vertices and edge probability p """
    g = networkx.erdos_renyi_graph(N, p)
    W = networkx.adjacency_matrix(g).todense()
    return g, torch.as_tensor(W, dtype=torch.float)

@generates("BarabasiAlbert")
def generate_barabasi_albert_netx(p, N):
    """ Generate random Barabasi Albert graph; each new vertex attaches to int(p*(N-1)/2) vertices """
    m = int(p*(N -1)/2)
    g = networkx.barabasi_albert_graph(N, m)
    W = networkx.adjacency_matrix(g).todense()
    return g, torch.as_tensor(W, dtype=torch.float)

@generates("Regular")
def generate_regular_graph_netx(p, N):
    """ Generate random regular graph of degree int(p*N) (bumped by one if N*d is odd) """
    d = p * N
    d = int(d)
    # Make sure N * d is even: a d-regular graph on N vertices needs it
    if N * d % 2 == 1:
        d += 1
    g = networkx.random_regular_graph(d, N)
    W = networkx.adjacency_matrix(g).todense()
    return g, torch.as_tensor(W, dtype=torch.float)

# Registry: noise model name -> function (g, W, noise, edge_density) -> noisy adjacency tensor
NOISE_FUNCTIONS = {}

def noise(name):
    """ Register a noise function """
    def decorator(func):
        NOISE_FUNCTIONS[name] = func
        return func
    return decorator

@noise("ErdosRenyi")
def noise_erdos_renyi(g, W, noise, edge_density):
    """
    Perturb adjacency W with independent edge flips.

    Each existing edge is removed with probability ``noise``, and each non-edge
    is added with probability edge_density*noise/(1-edge_density). When W has
    density ``edge_density`` this keeps the expected edge count unchanged.
    """
    n_vertices = len(W)
    pe1 = noise
    # A complete graph (edge_density == 1) has no non-edges to add, so the
    # addition probability is irrelevant; avoid dividing by zero.
    pe2 = (edge_density*noise)/(1-edge_density) if edge_density < 1 else 0.
    _,noise1 = generate_erdos_renyi_netx(pe1, n_vertices)
    _,noise2 = generate_erdos_renyi_netx(pe2, n_vertices)
    W_noise = W*(1-noise1) + (1-W)*noise2
    return W_noise

def is_swappable(g, u, v, s, t):
    """
    Check whether we can swap
    the edges u,v and s,t
    to get u,t and s,v
    """
    actual_edges = g.has_edge(u, v) and g.has_edge(s, t)
    no_self_loop = (u != t) and (s != v)
    no_parallel_edge = not (g.has_edge(u, t) or g.has_edge(s, v))
    return actual_edges and no_self_loop and no_parallel_edge

def do_swap(g, u, v, s, t):
    """Replace edges (u, v) and (s, t) of g by (u, t) and (s, v), in place."""
    g.remove_edge(u, v)
    g.remove_edge(s, t)
    g.add_edge(u, t)
    g.add_edge(s, v)

@noise("EdgeSwap")
def noise_edge_swap(g, W, noise, edge_density): # Keeps the graph regular (swaps preserve every degree)
    """
    Perturb a copy of g by random degree-preserving edge swaps.

    Both orientations of every edge of g are enumerated. Each one is selected
    with probability ``noise``; for a selected (u, v), every enumerated (s, t)
    is tried with probability ``noise`` and swapped when is_swappable allows.
    W and edge_density are unused (kept for the NOISE_FUNCTIONS signature).
    """
    g_noise = g.copy()
    edges_iter = list(itertools.chain(iter(g.edges), ((v, u) for (u, v) in g.edges)))
    for u,v in edges_iter:
        if random.random() < noise:
            for s, t in edges_iter:
                if random.random() < noise and is_swappable(g_noise, u, v, s, t):
                    do_swap(g_noise, u, v, s, t)
    W_noise = networkx.adjacency_matrix(g_noise).todense()
    return torch.as_tensor(W_noise, dtype=torch.float)

def adjacency_matrix_to_tensor_representation(W):
    """ Create a tensor B[0,:,:] = W and B[1,i,i] = deg(i)"""
    degrees = W.sum(1)
    B = torch.zeros((2,len(W), len(W)))
    B[0, :, :] = W
    indices = torch.arange(len(W))
    B[1, indices, indices] = degrees
    return B

class Base_Generator(torch.utils.data.Dataset):
    """Dataset whose examples are cached on disk in ``<path_dataset>/<name>.pkl``."""
    def __init__(self, name, path_dataset, num_examples):
        self.path_dataset = path_dataset
        self.name = name
        self.num_examples = num_examples

    def load_dataset(self, use_dgl= False):
        """
        Look for required dataset in files and create it if
        it does not exist
        """
        filename = self.name + '.pkl'
        filename_dgl = self.name + '_dgl.pkl'
        path = os.path.join(self.path_dataset, filename)
        path_dgl = os.path.join(self.path_dataset, filename_dgl)
        # NOTE(review): with use_dgl=True the existence check is done on the
        # non-DGL file but the DGL file is read, and a freshly created dataset is
        # only saved in non-DGL form -- confirm the DGL file is produced elsewhere.
        # torch.load unpickles: only point this at trusted, locally produced files.
        if os.path.exists(path):
            if use_dgl:
                print('Reading dataset at {}'.format(path_dgl))
                data = torch.load(path_dgl)
            else:
                print('Reading dataset at {}'.format(path))
                data = torch.load(path)
            self.data = list(data)
        else:
            print('Creating dataset at {}'.format(path))
            l_data = self.create_dataset()
            print('Saving dataset at {}'.format(path))
            torch.save(l_data, path)
            self.data = l_data

    def remove_file(self):
        """Delete the cached non-DGL dataset file."""
        os.remove(os.path.join(self.path_dataset, self.name + '.pkl'))

    def create_dataset(self):
        """Build ``num_examples`` examples with ``compute_example`` (defined by subclasses)."""
        l_data = []
        for _ in tqdm.tqdm(range(self.num_examples)):
            example = self.compute_example()
            l_data.append(example)
        return l_data

    def __getitem__(self, i):
        """ Fetch sample at index i """
        return self.data[i]

    def __len__(self):
        """ Get dataset length """
        return len(self.data)

def _sample_graph_pair(generative_model, noise_model, edge_density, noise_level, n_vertices):
    """
    Sample a graph from ``generative_model`` and a noisy copy via ``noise_model``.

    :return: the adjacency tensors (W, W_noise)
    :raises ValueError: if either model name is not registered
    """
    # Resolve the registry entries before calling them, so that a KeyError
    # raised *inside* a generator or noise function is not misreported as an
    # unsupported model name.
    try:
        generator = GENERATOR_FUNCTIONS[generative_model]
    except KeyError:
        raise ValueError('Generative model {} not supported'
                         .format(generative_model)) from None
    try:
        add_noise = NOISE_FUNCTIONS[noise_model]
    except KeyError:
        raise ValueError('Noise model {} not supported'
                         .format(noise_model)) from None
    g, W = generator(edge_density, n_vertices)
    W_noise = add_noise(g, W, noise_level, edge_density)
    return W, W_noise

class QAP_Generator(Base_Generator):
    """
    Dataset of pairs (B, B_noise): the 2-channel tensor representations
    (adjacency, degree diagonal) of a random graph and of its noisy copy.
    Cached under a sub-folder named after the generation parameters.
    """
    def __init__(self, name, args, path_dataset):
        self.generative_model = args['generative_model']
        self.noise_model = args['noise_model']
        self.edge_density = args['edge_density']
        self.noise = args['noise']
        num_examples = args['num_examples_' + name]
        n_vertices = args['n_vertices']
        vertex_proba = args['vertex_proba']
        subfolder_name = 'QAP_{}_{}_{}_{}_{}_{}_{}'.format(self.generative_model,
                                                     self.noise_model,
                                                     num_examples,
                                                     n_vertices, vertex_proba,
                                                     self.noise, self.edge_density)
        path_dataset = os.path.join(path_dataset, subfolder_name)
        super().__init__(name, path_dataset, num_examples)
        self.data = []
        self.constant_n_vertices = (vertex_proba == 1.)
        # Number of vertices of each example ~ Binomial(n_vertices, vertex_proba)
        self.n_vertices_sampler = torch.distributions.Binomial(n_vertices, vertex_proba)


        utils.check_dir(self.path_dataset)

    def compute_example(self):
        """
        Compute pairs (Adjacency, noisy Adjacency)
        """
        n_vertices = int(self.n_vertices_sampler.sample().item())
        W, W_noise = _sample_graph_pair(self.generative_model, self.noise_model,
                                        self.edge_density, self.noise, n_vertices)
        B = adjacency_matrix_to_tensor_representation(W)
        B_noise = adjacency_matrix_to_tensor_representation(W_noise)
        return (B, B_noise)

def make_laplacian(W):
    """
    Return the symmetrically normalized adjacency D^{-1/2} W D^{-1/2}.

    Despite the name this is the normalized adjacency, not a Laplacian.
    Isolated vertices (degree 0) get zero rows/columns instead of the
    inf/NaN that 1/sqrt(0) would spread through the whole matrix.
    """
    D = W @ torch.ones(W.shape[-1])
    d_inv_sqrt = torch.zeros_like(D)
    nonzero = D > 0
    d_inv_sqrt[nonzero] = 1/torch.sqrt(D[nonzero])
    return torch.diag(d_inv_sqrt) @ W @ torch.diag(d_inv_sqrt)

def make_spectral_feature(L,n=4):
    """Stack the matrix powers L, L^2, ..., L^n into a tensor of shape (n, *L.shape)."""
    out = torch.zeros((n,*L.shape))
    scale = 1#L.shape[-1]
    L_prev = torch.eye(L.shape[-1])
    for i in range(n):
        L_prev = L_prev @ L
        out[i,:,:] = scale*L_prev
    return out

class QAP_spectralGenerator(Base_Generator):
    """
    Dataset of pairs (F, F_noise): stacked powers of the normalized adjacency
    (see make_spectral_feature) of a random graph and of its noisy copy.
    Cached under a sub-folder named after the generation parameters.
    """
    def __init__(self, name, args, path_dataset):
        self.generative_model = args['generative_model']
        self.noise_model = args['noise_model']
        self.edge_density = args['edge_density']
        self.noise = args['noise']
        num_examples = args['num_examples_' + name]
        n_vertices = args['n_vertices']
        vertex_proba = args['vertex_proba']
        subfolder_name = 'QAPspectral_{}_{}_{}_{}_{}_{}_{}'.format(self.generative_model,
                                                     self.noise_model,
                                                     num_examples,
                                                     n_vertices, vertex_proba,
                                                     self.noise, self.edge_density)
        path_dataset = os.path.join(path_dataset, subfolder_name)
        super().__init__(name, path_dataset, num_examples)
        self.data = []
        self.constant_n_vertices = (vertex_proba == 1.)
        # Number of vertices of each example ~ Binomial(n_vertices, vertex_proba)
        self.n_vertices_sampler = torch.distributions.Binomial(n_vertices, vertex_proba)
        utils.check_dir(self.path_dataset)

    def compute_example(self):
        """
        Compute pairs (spectral features, noisy spectral features)
        """
        n_vertices = int(self.n_vertices_sampler.sample().item())
        W, W_noise = _sample_graph_pair(self.generative_model, self.noise_model,
                                        self.edge_density, self.noise, n_vertices)
        L = make_laplacian(W)
        L_noise = make_laplacian(W_noise)
        F = make_spectral_feature(L)
        F_noise = make_spectral_feature(L_noise)
        return (F, F_noise)

# ------------------------------------------------------------------------------
# /maskedtensors/test_maskedtensor.py
# ------------------------------------------------------------------------------
"""
Tests for maskedtensor module
To execute, run python -m pytest at the root of the project

Recommended: install pytest-repeat to repeat tests with e.g.
python -m pytest . --count 10
"""

import functools
import pytest
import torch
import torch.nn as nn
import maskedtensor
#from models.graph_classif import Graph_Classif
#from models.layers import MlpBlock, RegularBlock, MlpBlock_Real, Scaled_Block, MlpBlock_vec
from models.layers import MlpBlock_Real, MlpBlock_vec, Matmul, normalize, GraphNorm, Concat, Add, Diag
#from models.base_model_old import Node_Embedding, Graph_Embedding
#from models.siamese_net import Siamese_Node
from toolbox.metrics import accuracy_linear_assignment, accuracy_max
from toolbox.losses import triplet_loss

def apply_list_tensors(lst, func):
    """ Apply func on each tensor (with batch dim) """
    batched_lst = [tens.unsqueeze(0) for tens in lst]
    batched_res_lst = [func(tens) for tens in batched_lst]
    res_lst = [tens.squeeze(0) for tens in batched_res_lst]
    return res_lst

def apply_binary_list_tensors(lst, func):
    """ Apply func on each pair of tensors (with batch dim) """
    batched_lst = [(tens.unsqueeze(0), other.unsqueeze(0)) for tens, other in lst]
    batched_res_lst = [func(*tpl) for tpl in batched_lst]
    res_lst = [tens.squeeze(0) for tens in batched_res_lst]
    return res_lst

N_FEATURES = 16
N_VERTICES_RANGE = range(40,50)
FIXED_N_VERTICES = 50
ATOL = 1e-5
DEVICE = torch.device('cpu')

@pytest.fixture
def tensor_list():
    """ Generate list of tensors (graphs)"""
    lst = [torch.empty((N_FEATURES, n_vertices, n_vertices)).normal_()
           for n_vertices in N_VERTICES_RANGE]
    return lst

@pytest.fixture
def tensor_listvec():
    """ Generate list of tensors (vector) """
    lst = [torch.empty((N_FEATURES, n_vertices)).normal_()
           for n_vertices in N_VERTICES_RANGE]
    return lst

@pytest.fixture
def score_list():
    """ Generate list of tensors with no features and fixed n_vertices"""
    lst = [torch.empty((FIXED_N_VERTICES, FIXED_N_VERTICES)).normal_()
           for _ in N_VERTICES_RANGE]
    return lst

other_tensor_list = tensor_list

""" def graph_conv_wrapper(func):
    Applies graph convention to a function using pytorch convention
    @functools.wraps(func)
    def wrapped_func(*args, **kwargs):
        new_args = [x.permute(0, 3, 1, 2) for x in args]
        ret = func(*new_args, **kwargs)
        return ret.permute(0, 2, 3, 1)
    return wrapped_func
"""

def accuracy_wrapper(func):
    """ Wraps accuracy funcs so that they behave like the other funcs """
    @functools.wraps(func)
    def wrapped_func(weights, *args, **kwargs):
        # remove features
        new_weights = torch.sum(weights, -1)
        ret = func(new_weights, *args, **kwargs)
        return torch.Tensor(ret)
    return wrapped_func

# the third parameter specifies whether the base_name of the second maskedtensor
# should match the first one
TEST_BINARY_FUNCS = [
    # when pytorch issue is fixed, change this
    #(lambda t1, t2: maskedtensor.dispatch_cat((t1, t2), dim=-1), 'torch.cat', True),
    (lambda t1, t2: torch.cat((t1, t2), dim=1), 'torch.cat', True),
    (lambda t1, t2: torch.stack((t1, t2), dim=1), 'torch.stack', True),
    (torch.matmul, 'torch.matmul', True),
    (torch.matmul, 'torch.matmul', False),
    #(Siamese_Node(N_FEATURES, 2, 32, 32, 3), 'Siamese_Node', False),
    (Matmul(), 'Matmul', False),
    (Concat(), 'Concat', True),
    (Add(), 'Add', True)]
# embedding is not working yet...

@pytest.mark.parametrize('func_data', TEST_BINARY_FUNCS, ids=lambda func_data: func_data[1])
def test_binary_torch_func(tensor_list, other_tensor_list, func_data):
    """ Test torch functions which use two tensors: the masked-tensor result must
    match applying the function graph by graph """
    func, _, same_base_name = func_data
    masked_tensor = maskedtensor.from_list(tensor_list, dims=(1, 2))
    other_base_name = 'N' if same_base_name else 'M'
    other_masked_tensor = maskedtensor.from_list(other_tensor_list, dims=(1, 2),
                                                 base_name=other_base_name)
    res_mt = list(func(masked_tensor, other_masked_tensor))
    binary_list = zip(tensor_list, other_tensor_list)
    res_lst = apply_binary_list_tensors(binary_list, func)
    for t_mt, t_lst in zip(res_mt, res_lst):
        assert t_mt.size() == t_lst.size()
        assert torch.allclose(t_mt, t_lst, atol=ATOL), torch.norm(t_mt - t_lst, p=float('inf'))

# Shared LayerNorm so the masked and per-graph paths use the same parameters
ln = nn.LayerNorm(N_FEATURES)

TEST_FUNCS = [
    (lambda t: torch.add(t, 1), 'torch.add'),
    (lambda t: torch.mul(t, 2), 'torch.mul'),
    (lambda t: torch.sum(t, 2), 'torch.sum'),
    (lambda t: torch.max(t, 2)[0], 'torch.max(dim=2)'),
    (lambda t: torch.mean(t, dim=(-2,-1)), 'torch.mean'),
    (lambda t: torch.var(t, unbiased=False, dim=(-2,-1)), 'torch.var'),
    # keep first dim not to perturb apply_list_tensors
    (lambda t: t.permute(0, 3, 2, 1), 'permute'),
    (nn.Conv2d(N_FEATURES, 2*N_FEATURES, 1), 'nn.Conv2d'),
    (lambda t: ln(t.permute(0,3,2,1)), 'nn.LayerNorm'),
    (nn.InstanceNorm2d(N_FEATURES, affine=False, track_running_stats=False), 'InstanceNorm2d'),
    (nn.InstanceNorm2d(N_FEATURES, affine=True, track_running_stats=False), 'InstanceNorm2d_affine'),
    (lambda t: torch.diag_embed(t,dim1=-2,dim2=-1), 'torch.diag_embed'),
    (Diag(), 'Diag')
    #(MlpBlock(N_FEATURES, 2*N_FEATURES, 2), 'MlpBlock'),
    #(RegularBlock(N_FEATURES, 2*N_FEATURES, 2), 'RegularBlock'),
    #(Node_Embedding(N_FEATURES, 2, 32, 32, 3), 'Simple_Node_Embedding'),
    #(MlpBlock_Real(N_FEATURES, 2*N_FEATURES, 3), 'MlpBlock_Real'),
    #(Scaled_Block(N_FEATURES, 2*N_FEATURES, 2), 'Scaled_Block'),
    #(Graph_Embedding(N_FEATURES, 2, 32, 32, 3), 'Graph_Embedding'),
    #(Graph_Classif(N_FEATURES, 2, 32, 32, 3), 'Graph_Classif')
    ]

@pytest.mark.parametrize('func_data', TEST_FUNCS, ids=lambda func_data: func_data[1])
def test_torch_func(tensor_list, func_data):
    """ Test single-tensor torch functions/modules against per-graph application """
    func, _ = func_data
    masked_tensor = maskedtensor.from_list(tensor_list, dims=(1, 2))
    res_mt = list(func(masked_tensor))
    res_lst = apply_list_tensors(tensor_list, func)
    for t_mt, t_lst in zip(res_mt, res_lst):
        assert t_mt.size() == t_lst.size()
        assert torch.allclose(t_mt, t_lst, atol=ATOL), torch.norm(t_mt - t_lst, p=float('inf'))

TEST_CUST_FUNCS = [
    (normalize, 'normalize')
]

@pytest.mark.parametrize('func_data', TEST_CUST_FUNCS, ids=lambda func_data: func_data[1])
def test_custom_func(tensor_list, func_data):
    """ Test custom function: the masked call passes constant_n_vertices=False,
    the per-graph call uses the function's default """
    func, _ = func_data
    masked_tensor = maskedtensor.from_list(tensor_list, dims=(1, 2))
    res_mt = list(func(masked_tensor, constant_n_vertices=False))
    res_lst = apply_list_tensors(tensor_list, func)
    for t_mt, t_lst in zip(res_mt, res_lst):
        assert t_mt.size() == t_lst.size()
        assert torch.allclose(t_mt, t_lst, atol=ATOL), torch.norm(t_mt - t_lst, p=float('inf'))

# Masked-input and plain-input variants of each layer; the MLP pair shares its
# convolutions so both variants compute with identical weights.
mlp_mt = MlpBlock_Real(N_FEATURES, 2*N_FEATURES, 2, constant_n_vertices=False)
mlp = MlpBlock_Real(N_FEATURES, 2*N_FEATURES, 2)
mlp.convs = mlp_mt.convs
gn_mt = GraphNorm(N_FEATURES, constant_n_vertices=False)
gn = GraphNorm(N_FEATURES)

TEST_LAYERS = [
    (mlp_mt, mlp, 'MlpBlock_Real'),
    (gn_mt, gn, 'GraphNorm')
]

@pytest.mark.parametrize('func_data', TEST_LAYERS,
                         ids=lambda func_data: func_data[2])
def test_layers(tensor_list, func_data):
    """ Test layer: masked variant on the batch vs plain variant graph by graph """
    func_mt, func, _ = func_data
    masked_tensor = maskedtensor.from_list(tensor_list, dims=(1, 2))
    res_mt = list(func_mt(masked_tensor))
    res_lst = apply_list_tensors(tensor_list, func)
    for t_mt, t_lst in zip(res_mt, res_lst):
        assert t_mt.size() == t_lst.size()
        assert torch.allclose(t_mt, t_lst, atol=ATOL), torch.norm(t_mt - t_lst, p=float('inf'))


TEST_MAX =[
    (lambda t: torch.max(t), 'torch.max')
]

@pytest.mark.parametrize('fun_max', TEST_MAX, ids=lambda fun_max: fun_max[1])
def test_max(tensor_list, fun_max):
    """ Test global max: max over the masked batch equals the max of per-graph maxima """
    f_max , _ = fun_max
    masked_tensor = maskedtensor.from_list(tensor_list, dims=(1, 2))
    res_mt = f_max(masked_tensor)
    res_lst = torch.max(torch.tensor(apply_list_tensors(tensor_list, f_max)))
    assert torch.allclose(res_mt, res_lst, atol=ATOL)

TEST_VEC = [
    (lambda t: torch.max(t, dim=1)[0], 'torch.max_vec'),
    (MlpBlock_vec(N_FEATURES, 2*N_FEATURES, 2), 'MlpBlock_vec'),
    (lambda t: ln(t.permute(0,2,1)), 'nn.LayerNorm_vec')
]

@pytest.mark.parametrize('func_data', TEST_VEC, ids=lambda func_data: func_data[1])
def test_vec_func(tensor_listvec, func_data):
    """ Test vec function (single masked dim) against per-vector application """
    func, _ = func_data
    masked_tensor = maskedtensor.from_list(tensor_listvec, dims=(1,))
    res_mt = list(func(masked_tensor))
    res_lst = apply_list_tensors(tensor_listvec, func)
    for t_mt, t_lst in zip(res_mt, res_lst):
        assert t_mt.size() == t_lst.size()
        assert torch.allclose(t_mt, t_lst, atol=ATOL), torch.norm(t_mt - t_lst, p=float('inf'))


TEST_LOSS_FUNCS = [
    (triplet_loss(loss_reduction='mean'), 'loss_mean'),
    (triplet_loss(loss_reduction='mean_of_mean'), 'loss_mean_of_mean')]

@pytest.mark.parametrize('func_data', TEST_LOSS_FUNCS, ids=lambda func_data: func_data[1])
def test_loss_func(score_list, func_data):
    """ Test loss functions: masked batch of fixed-size score matrices vs stacked tensor """
    func, _ = func_data
    masked_tensor = maskedtensor.from_list(score_list, dims=(0, 1))
    res_mt = func(masked_tensor)
    res_lst = func(torch.stack(score_list))
    assert torch.allclose(res_mt, res_lst, atol=ATOL), torch.norm(res_mt - res_lst, p=float('inf'))

TEST_ACCURACY_FUNCS = [
    #(accuracy_wrapper(accuracy_linear_assignment), 'accuracy_linear_assignment'),
    (accuracy_wrapper(accuracy_max), 'accuracy_max')]

@pytest.mark.parametrize('func_data', TEST_ACCURACY_FUNCS, ids=lambda func_data: func_data[1])
def test_accuracy_func(tensor_list, func_data):
    """ Test accuracy functions: the masked-batch result equals the sum of per-graph results """
    func, _ = func_data
    masked_tensor = maskedtensor.from_list(tensor_list, dims=(1, 2))
    res_mt = func(masked_tensor)
    res_lst = sum(apply_list_tensors(tensor_list, func))
    assert torch.allclose(res_mt, res_lst, atol=ATOL), torch.norm(res_mt - res_lst, p=float('inf'))
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity.
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | #from maskedtensors.maskedtensor import dispatch_cat 5 | import math 6 | from collections import namedtuple 7 | from torch.nn.parameter import Parameter 8 | from models.utils import * 9 | 10 | # class GraphNorm(nn.Module): 11 | # def __init__(self, features=0, constant_n_vertices=True, eps =1e-05): 12 | # super().__init__() 13 | # self.constant_n_vertices = constant_n_vertices 14 | # self.eps = eps 15 | 16 | # def forward(self, b): 17 | # return normalize(b, self.constant_n_vertices, self.eps) 18 | 19 | class Rec_block(nn.Module): 20 | def __init__(self, block, n_iter): 21 | super().__init__() 22 | self.block_dic = {'input': (None, []) , 'block': block, 'output' : Identity()} 23 | self.n_iter = n_iter 24 | self.net = Network(self.block_dic) 25 | 26 | def forward(self, x): 27 | x_input = {'input': x} 28 | for i in range(self.n_iter): 29 | x = self.net(x_input) 30 | 
x_input = {'input': x['output']} 31 | return x['output'] 32 | 33 | class Recall_block(nn.Module): 34 | def __init__(self, block, n_iter): 35 | super().__init__() 36 | self.block_dic = {'input': (None, []) , 'block': block, 'output' : Identity()} 37 | self.n_iter = n_iter 38 | self.net = Network(self.block_dic) 39 | 40 | def forward(self, x0, x1): 41 | x_input = {'input': torch.cat((x0,x1), dim=1)} 42 | for i in range(self.n_iter): 43 | x = self.net(x_input) 44 | x_input = {'input': torch.cat((x0,x['output']), dim=1)} 45 | return x['output'] 46 | 47 | class GraphNorm(nn.Module): 48 | def __init__(self, features, constant_n_vertices=True, elementwise_affine=True, eps =1e-05,device=None, dtype=None): 49 | super().__init__() 50 | factory_kwargs = {'device': device, 'dtype': dtype} 51 | self.constant_n_vertices = constant_n_vertices 52 | self.eps = eps 53 | self.elementwise_affine = elementwise_affine 54 | self.features = (1,features,1,1) 55 | if self.elementwise_affine: 56 | self.weight = Parameter(torch.empty(self.features, **factory_kwargs)) 57 | self.bias = Parameter(torch.empty(self.features, **factory_kwargs)) 58 | else: 59 | self.register_parameter('weight', None) 60 | self.register_parameter('bias', None) 61 | self.reset_parameters() 62 | 63 | def reset_parameters(self) -> None: 64 | if self.elementwise_affine: 65 | nn.init.ones_(self.weight) 66 | nn.init.zeros_(self.bias) 67 | 68 | def forward(self, b): 69 | return self.weight*normalize(b, constant_n_vertices=self.constant_n_vertices, eps=self.eps)+self.bias 70 | 71 | def normalize(b, constant_n_vertices=True, eps =1e-05): 72 | means = torch.mean(b, dim = (-1,-2), keepdim=True) 73 | vars = torch.var(b, unbiased=False,dim = (-1,-2), keepdim=True) 74 | #(b,f,n1,n2) = b.shape 75 | #assert n1 == n2 76 | if constant_n_vertices: 77 | n = b.size(-1) 78 | else: 79 | n = torch.sum(b.mask_dict['N'], dim=1).align_as(vars) 80 | return (b-means)/(2*torch.sqrt(n*(vars+eps))) 81 | 82 | import numpy as np 83 | 84 | class 
MlpBlock_Node(nn.Module): 85 | """ 86 | Block of MLP layers with activation function after each (1x1 conv layers) except last one 87 | """ 88 | def __init__(self, in_features, out_features, depth_of_mlp, activation_fn = F.relu, constant_n_vertices=True): 89 | super().__init__() 90 | self.activation = activation_fn 91 | self.depth_mlp = depth_of_mlp 92 | self.cst_vertices = constant_n_vertices 93 | self.convs = nn.ModuleList() 94 | for _ in range(depth_of_mlp): 95 | self.convs.append(nn.Conv1d(in_features, out_features, kernel_size=1, padding=0, bias=True)) 96 | _init_weights(self.convs[-1]) 97 | in_features = out_features 98 | self.gn = GraphNorm(out_features, constant_n_vertices=constant_n_vertices) 99 | #self.gn = nn.InstanceNorm2d(out_features, affine=False) 100 | 101 | def forward(self, inputs): 102 | n = inputs.size(-1) 103 | out = inputs 104 | for conv_layer in self.convs[:-1]: 105 | out = self.activation(conv_layer(out)) 106 | return self.gn(self.convs[-1](out))#normalize(self.convs[-1](out), constant_n_vertices=self.cst_vertices) 107 | 108 | 109 | class MlpBlock_Real(nn.Module): 110 | """ 111 | Block of MLP layers with activation function after each (1x1 conv layers) except last one 112 | """ 113 | def __init__(self, in_features, out_features, depth_of_mlp, activation_fn = F.relu, constant_n_vertices=True): 114 | super().__init__() 115 | self.activation = activation_fn 116 | self.depth_mlp = depth_of_mlp 117 | self.cst_vertices = constant_n_vertices 118 | self.convs = nn.ModuleList() 119 | for _ in range(depth_of_mlp): 120 | self.convs.append(nn.Conv2d(in_features, out_features, kernel_size=1, padding=0, bias=True)) 121 | _init_weights(self.convs[-1]) 122 | in_features = out_features 123 | self.gn = GraphNorm(out_features, constant_n_vertices=constant_n_vertices) 124 | #self.gn = nn.InstanceNorm2d(out_features, affine=False) 125 | 126 | def forward(self, inputs): 127 | n = inputs.size(-1) 128 | out = inputs 129 | for conv_layer in self.convs[:-1]: 130 | out 
= self.activation(conv_layer(out)) 131 | return self.gn(self.convs[-1](out))#normalize(self.convs[-1](out), constant_n_vertices=self.cst_vertices) 132 | 133 | 134 | def _init_weights(layer): 135 | """ 136 | Init weights of the layer 137 | :param layer: 138 | :return: 139 | """ 140 | nn.init.xavier_uniform_(layer.weight) 141 | if layer.bias is not None: 142 | nn.init.zeros_(layer.bias) 143 | 144 | 145 | class Concat(nn.Module): 146 | def forward(self, *xs): return torch.cat(xs, dim=1) 147 | 148 | class Diag(nn.Module): 149 | def forward(self, xs): return torch.diag_embed(xs) 150 | 151 | class Identity(namedtuple('Identity', [])): 152 | def __call__(self, x): return x 153 | 154 | 155 | class Permute(namedtuple('Permute', [])): 156 | def __call__(self, x): return x.permute(0,2,1) 157 | 158 | class Add(nn.Module): 159 | def forward(self, xs1, xs2): return torch.add(xs1, xs2) 160 | 161 | class Matmul(nn.Module): 162 | def forward(self, xs1, xs2): return torch.matmul(xs1, xs2) 163 | 164 | class Matmul_zerodiag(nn.Module): 165 | def forward(self, xs1, xs2): 166 | (bs,f,n1,n2) = xs1.shape 167 | device = xs1.device 168 | assert n1 == n2 169 | mask = torch.ones(n1,n1) - torch.eye(n1,n1) 170 | mask = mask.reshape((1,1,n1,n1)).to(device) 171 | mask_b = mask.repeat(bs,f,1,1) 172 | return torch.matmul(torch.mul(xs1,mask_b), torch.mul(xs2,mask_b))#torch.mul(torch.matmul(xs1, xs2), mask_b) 173 | #return torch.matmul(zero_diag_fun(xs1), zero_diag_fun(xs2)) 174 | 175 | def zero_diag_fun(x): 176 | (bs,f,n1,n2) = x.shape 177 | device = x.device 178 | assert n1 == n2 179 | mask = torch.ones(n1,n1, dtype =x.dtype, device=x.device).fill_diagonal_(0) 180 | y = x 181 | for s in x: 182 | for f in s: 183 | y *= mask 184 | return y 185 | 186 | class Layernorm(nn.Module): 187 | def __init__(self, n_features): 188 | super().__init__() 189 | self.layer_norm = nn.LayerNorm([100,100,n_features], elementwise_affine=False) 190 | 191 | def forward(self, x): 192 | return 
self.layer_norm(x.permute(0,2,3,1)).permute(0,3,1,2) 193 | 194 | class ColumnMaxPooling(nn.Module): 195 | """ 196 | take a batch (bs, in_features, n_vertices, n_vertices) 197 | and returns (bs, in_features, n_vertices) 198 | """ 199 | def __init__(self): 200 | super().__init__() 201 | 202 | def forward(self, x): 203 | return torch.max(x, -1)[0] 204 | 205 | class ColumnSumPooling(nn.Module): 206 | """ 207 | take a batch (bs, in_features, n_vertices, n_vertices) 208 | and returns (bs, in_features, n_vertices) 209 | """ 210 | def __init__(self): 211 | super().__init__() 212 | 213 | def forward(self, x): 214 | return torch.sum(x, -1) 215 | 216 | 217 | class MlpBlock_vec(nn.Module): 218 | """ 219 | Block of MLP layers acting on vectors (bs, features, n) 220 | """ 221 | def __init__(self, in_features, out_features, depth_of_mlp, activation_fn = F.relu): 222 | super().__init__() 223 | self.activation = activation_fn 224 | self.depth_mlp = depth_of_mlp 225 | self.mlp = nn.ModuleList() 226 | for _ in range(depth_of_mlp): 227 | self.mlp.append(nn.Linear(in_features, out_features)) 228 | _init_weights(self.mlp[-1]) 229 | in_features = out_features 230 | 231 | def forward(self, inputs): 232 | out = inputs.permute(0,2,1) 233 | for fc in self.mlp[:-1]: 234 | out = self.activation(fc(out)) 235 | return self.mlp[-1](out).permute(0,2,1) 236 | 237 | class SelfAttentionLayer(nn.Module): 238 | def __init__(self, config, dmt=False): 239 | super().__init__() 240 | assert config.n_embd % config.n_heads == 0 241 | self.n_embd = config.n_embd 242 | self.n_heads = config.n_heads 243 | self.emb_hea = self.n_embd//self.n_heads 244 | #print(self.emb_hea) 245 | self.dmt = dmt 246 | self.Query = nn.Linear(self.n_embd, self.n_embd) 247 | self.Key = nn.Linear(self.n_embd, self.n_embd) 248 | self.Value = nn.Linear(self.n_embd, self.n_embd) 249 | 250 | def forward(self, x): # x (bs, T, ne) 251 | b,t,n = x.size() 252 | #x = x.permute(0,2,1) 253 | Q = self.Query(x) # (bs, T, ne) 254 | Q = 
Q.view(b,t,self.n_heads,self.emb_hea).permute(0,2,1,3) # (bs, nh, T, nne) 255 | if self.dmt: 256 | Q.tensor.rename_(N='N_') 257 | K = torch.div(self.Key(x),math.sqrt(self.emb_hea)) # (bs, T, ne) 258 | K = K.view(b,t,self.n_heads,self.emb_hea).permute(0,2,1,3) # (bs, nh, T, nne) 259 | V = self.Value(x) # (bs, T, ne) 260 | V = V.view(b,t,self.n_heads,self.emb_hea).permute(0,2,1,3) # (bs, nh, T, nne) 261 | #A = torch.einsum('ntk,nsk->nst', Q, K) # (bs, T, kc), (bs, T, kc) -> (bs , T, T) 262 | A = torch.matmul(K, Q.permute(0,1,3,2)) #(bs, nh, T, T) 263 | A = F.softmax(A, dim=-1) 264 | y = torch.matmul(A, V)#torch.bmm(A, V) # (bs, nh, T, nne) 265 | y = y.permute(0,2,1,3).contiguous().view(b, t, n) # (bs, T, ne) 266 | return y, A 267 | 268 | class AttentionBlock_vec(nn.Module): 269 | """ 270 | Attention Block of MLP layers acting on vectors (bs, features, n) 271 | """ 272 | def __init__(self, nb_features, depth_of_mlp, nb_heads=1 ,dmt=True, activation_fn = F.relu): 273 | super().__init__() 274 | self.activation = activation_fn 275 | self.depth_mlp = depth_of_mlp 276 | self.mlp = nn.ModuleList() 277 | for _ in range(depth_of_mlp): 278 | self.mlp.append(nn.Linear(nb_features, nb_features)) 279 | _init_weights(self.mlp[-1]) 280 | #in_features = out_features 281 | 282 | class config: 283 | n_embd = nb_features 284 | n_heads = nb_heads 285 | self.attn = SelfAttentionLayer(config, dmt) 286 | self.ln_1 = nn.LayerNorm(config.n_embd) 287 | self.ln_2 = nn.LayerNorm(config.n_embd) 288 | 289 | def mlpf(self, x): 290 | out = x 291 | for fc in self.mlp[:-1]: 292 | out = self.activation(fc(out)) 293 | return self.mlp[-1](out) 294 | 295 | def forward(self, x): 296 | y, A = self.attn(self.ln_1(x)) 297 | x = torch.add(x, y) 298 | return torch.add(x,self.mlpf(self.ln_2(x))) 299 | 300 | class GraphAttentionLayer(nn.Module): 301 | def __init__(self, config, dmt=False): 302 | super().__init__() 303 | assert config.n_embd % config.n_heads == 0 304 | self.n_embd = config.n_embd 305 | 
self.n_heads = config.n_heads 306 | self.depth_of_mlp = config.d_of_mlp 307 | self.emb_hea = self.n_embd//self.n_heads 308 | #print(self.emb_hea) 309 | self.dmt = dmt 310 | self.Query = nn.Linear(self.n_embd, self.n_embd) 311 | self.Key = nn.Linear(self.n_embd, self.n_embd) 312 | self.Value = MlpBlock_Real(self.n_embd,self.n_embd,self.depth_of_mlp)#nn.Linear(self.n_embd, self.n_embd) 313 | 314 | def forward(self, x): # x (bs, ne, T, T) 315 | b,n,t,t = x.size() 316 | V = self.Value(x) # (bs, ne, T, T) 317 | x = x.permute(0,2,3,1) # (bs, T, T, ne) 318 | Q = normalize(self.Query(x)) # (bs, ne, T, T) 319 | K = normalize(self.Key(x)) # (bs, ne, T, T) 320 | Q = Q.view(b,t,t,self.n_heads,self.emb_hea).permute(0,3,4,1,2) # (bs, nh, nne, T, T) 321 | K = K.view(b,t,t,self.n_heads,self.emb_hea).permute(0,3,4,1,2) # (bs, nh, nne, T, T) 322 | V = V.view(b,self.n_heads,self.emb_hea,t,t)#.permute(0,3,4,1,2) # (bs, nh, nne, T, T) 323 | A = torch.einsum('nhftu,nhfru->nhtr', Q, K) # (bs,nh,nne,T, T), (bs,nh,nne,T,T) -> (bs,nh, T, T) 324 | A = F.softmax(A, dim=-1) 325 | y = torch.einsum('bhst,bhfst -> bhfs', A, V) # (bs, nh,nne T, T) 326 | #y = torch.matmul(A.unsqueeze(2), V) # (bs, nh,nne T, T) 327 | y = y.contiguous().view(b, n, t) 328 | return y#normalize(y) 329 | 330 | class GraphAttentionLayer_mlp(nn.Module): 331 | def __init__(self, config, dmt=False): 332 | super().__init__() 333 | assert config.n_embd % config.n_heads == 0 334 | self.n_embd = config.n_embd 335 | self.n_heads = config.n_heads 336 | self.depth_of_mlp = config.d_of_mlp 337 | self.emb_hea = self.n_embd//self.n_heads 338 | #print(self.emb_hea) 339 | self.dmt = dmt 340 | self.Query = MlpBlock_Real(self.n_embd,self.n_embd,self.depth_of_mlp)#nn.Linear(self.n_embd, self.n_embd) 341 | self.Key = MlpBlock_Real(self.n_embd,self.n_embd,self.depth_of_mlp)#nn.Linear(self.n_embd, self.n_embd) 342 | self.Value = MlpBlock_Real(self.n_embd,self.n_embd,self.depth_of_mlp)#nn.Linear(self.n_embd, self.n_embd) 343 | 344 | def 
forward(self, x): # x (bs, ne, T, T) 345 | b,n,t,t = x.size() 346 | V = self.Value(x) # (bs, ne, T, T) 347 | #print(x.shape) 348 | #x = x.permute(0,2,3,1) # (bs, T, T, ne) 349 | Q = self.Query(x) # (bs,ne, T, T) 350 | Q = Q.view(b,self.n_heads,self.emb_hea,t,t) # (bs, nh, nne, T, T) 351 | K = self.Key(x)#torch.div(self.Key(x),math.sqrt(self.emb_hea)) # (bs,ne, T, T) 352 | K = K.view(b,self.n_heads,self.emb_hea,t,t) # (bs, nh, nne, T, T) 353 | V = V.view(b,self.n_heads,self.emb_hea,t,t) # (bs, nh, nne, T, T) 354 | A = torch.einsum('nhftu,nhfru->nhtr', Q, K) # (bs,nh,nne,T, T), (bs,nh,nne,T,T) -> (bs,nh, T, T) 355 | A = F.softmax(A, dim=-1) 356 | #print(A.unsqueeze(2).shape) 357 | #print(V.shape) 358 | #y = torch.matmul(A.unsqueeze(2), V) # (bs, nh,nne T, T) 359 | y = torch.einsum('bhst,bhfst -> bhfst', A, V) # (bs, nh,nne T, T) 360 | y = y.contiguous().view(b, n, t, t) 361 | return normalize(y) -------------------------------------------------------------------------------- /commander_explore.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import yaml 4 | import argparse 5 | 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | from models import get_siamese_model_exp, get_siamese_model_test 9 | import loaders.data_generator as dg 10 | from loaders.loaders import siamese_loader 11 | #from toolbox.optimizer import get_optimizer 12 | import toolbox.utils as utils 13 | from datetime import datetime 14 | from pathlib import Path 15 | import pytorch_lightning as pl 16 | from pytorch_lightning.loggers import WandbLogger 17 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 18 | 19 | 20 | def get_config(filename='default_config.yaml') -> dict: 21 | with open(filename, 'r') as f: 22 | config = yaml.safe_load(f) 23 | return config 24 | 25 | def custom_name(config): 26 | l_name = [config['arch']['node_emb']['type'], 27 | config['data']['train']['generative_model'], 
config['data']['train']['n_vertices'], 28 | config['data']['train']['edge_density']] 29 | name = "_".join([str(e) for e in l_name]) 30 | return name 31 | 32 | def check_paths_update(config, name): 33 | """ 34 | add to the configuration: 35 | 'path_log' = root/experiments/$problem/$name/ 36 | (arch_gnn)_(numb_locks)_(generative_model)_(n_vertices)_(edge_density)/date_time 37 | 'date_time' 38 | ['data']['path_dataset'] = root/experiments/qap/data 39 | save the new configuration at path_log/config.json 40 | """ 41 | now = datetime.now() # current date and time 42 | date_time = now.strftime("%m-%d-%y-%H-%M") 43 | dic = {'date_time' : date_time} 44 | #expe_runs_qap = os.path.join(QAP_DIR,config['name'],'runs/') 45 | name = custom_name(config) 46 | name = os.path.join(name, str(date_time)) 47 | path_log = os.path.join(PB_DIR,config['name'], name) 48 | utils.check_dir(path_log) 49 | dic['path_log'] = path_log 50 | 51 | utils.check_dir(DATA_PB_DIR) 52 | config['data'].update({'path_dataset' : DATA_PB_DIR}) 53 | 54 | #print(path_log) 55 | config.update(dic) 56 | with open(os.path.join(path_log, 'config.json'), 'w') as f: 57 | json.dump(config, f) 58 | return config 59 | 60 | def train(config): 61 | """ Main func. 
62 | """ 63 | cpu = config['cpu'] 64 | #train, 65 | problem =config['problem'] 66 | config_arch = config['arch'] 67 | #test_enabled, 68 | path_log = config['path_log'] 69 | data = config['data'] 70 | max_epochs = config['train']['epochs'] 71 | batch_size = config['train']['batch_size'] 72 | config_optim = config['train'] 73 | log_freq = config_optim['log_freq'] 74 | 75 | print("Heading to Training.") 76 | global best_score, best_epoch 77 | best_score, best_epoch = -1, -1 78 | print("Current problem : ", problem) 79 | 80 | use_cuda = not cpu and torch.cuda.is_available() 81 | device = 'cuda' if use_cuda else 'cpu' 82 | print('Using device:', device) 83 | 84 | # init random seeds 85 | utils.setup_env(cpu) 86 | 87 | print("Models saved in ", path_log) 88 | #exp_helper = init_helper(problem) 89 | model_pl = get_siamese_model_exp(config_arch, config_optim) 90 | 91 | generator = dg.QAP_Generator 92 | #generator = dg.QAP_spectralGenerator 93 | gene_train = generator('train', data['train'], data['path_dataset']) 94 | gene_train.load_dataset() 95 | gene_val = generator('val', data['train'], data['path_dataset']) 96 | gene_val.load_dataset() 97 | train_loader = siamese_loader(gene_train, batch_size, 98 | gene_train.constant_n_vertices) 99 | val_loader = siamese_loader(gene_val, batch_size, 100 | gene_val.constant_n_vertices, shuffle=False) 101 | 102 | 103 | #optimizer, scheduler = get_optimizer(train,model) 104 | #print("Model #parameters : ", sum(p.numel() for p in model.parameters() if p.requires_grad)) 105 | 106 | """ if not train['anew']: 107 | try: 108 | utils.load_model(model,device,train['start_model']) 109 | print("Model found, using it.") 110 | except RuntimeError: 111 | print("Model not existing. 
Starting from scratch.") 112 | """ 113 | #model.to(device) 114 | # train model 115 | checkpoint_callback = ModelCheckpoint(save_top_k=1, mode='max', monitor="val_acc") 116 | lr_monitor = LearningRateMonitor(logging_interval='epoch') 117 | if config['observers']['wandb']: 118 | logger = WandbLogger(project=f"{config['problem']}_{config['name']}", log_model="all", save_dir=path_log) 119 | logger.experiment.config.update(config) 120 | trainer = pl.Trainer(accelerator=device,max_epochs=max_epochs,logger=logger,log_every_n_steps=log_freq,callbacks=[lr_monitor, checkpoint_callback],precision=16) 121 | else: 122 | trainer = pl.Trainer(accelerator=device,max_epochs=max_epochs,log_every_n_steps=log_freq,callbacks=[lr_monitor, checkpoint_callback],precision=16) 123 | trainer.fit(model_pl, train_loader, val_loader) 124 | 125 | return trainer 126 | 127 | 128 | def tune(config): 129 | """ Main func. 130 | """ 131 | cpu = config['cpu'] 132 | #train, 133 | problem =config['problem'] 134 | config_arch = config['arch'] 135 | #test_enabled, 136 | path_log = config['path_log'] 137 | data = config['data'] 138 | max_epochs = config['train']['epochs'] 139 | batch_size = config['train']['batch_size'] 140 | config_optim = config['train'] 141 | log_freq = config_optim['log_freq'] 142 | 143 | print("Heading to Tuning.") 144 | global best_score, best_epoch 145 | best_score, best_epoch = -1, -1 146 | print("Current problem : ", problem) 147 | 148 | use_cuda = not cpu and torch.cuda.is_available() 149 | device = 'cuda' if use_cuda else 'cpu' 150 | print('Using device:', device) 151 | 152 | # init random seeds 153 | utils.setup_env(cpu) 154 | 155 | print("Models saved in ", path_log) 156 | #exp_helper = init_helper(problem) 157 | model_pl = get_siamese_model_test(data['test']['path_model']) 158 | 159 | generator = dg.QAP_Generator 160 | #generator = dg.QAP_spectralGenerator 161 | gene_train = generator('train', data['train'], data['path_dataset']) 162 | gene_train.load_dataset() 163 | gene_val 
= generator('val', data['train'], data['path_dataset']) 164 | gene_val.load_dataset() 165 | train_loader = siamese_loader(gene_train, batch_size, 166 | gene_train.constant_n_vertices) 167 | val_loader = siamese_loader(gene_val, batch_size, 168 | gene_val.constant_n_vertices, shuffle=False) 169 | 170 | 171 | #optimizer, scheduler = get_optimizer(train,model) 172 | #print("Model #parameters : ", sum(p.numel() for p in model.parameters() if p.requires_grad)) 173 | 174 | """ if not train['anew']: 175 | try: 176 | utils.load_model(model,device,train['start_model']) 177 | print("Model found, using it.") 178 | except RuntimeError: 179 | print("Model not existing. Starting from scratch.") 180 | """ 181 | #model.to(device) 182 | # train model 183 | if config['observers']['wandb']: 184 | logger = WandbLogger(project=f"{config['problem']}_{config['name']}", log_model="all", save_dir=path_log) 185 | logger.experiment.config.update(config) 186 | lr_monitor = LearningRateMonitor(logging_interval='epoch') 187 | #sc_cb = ScheduleCallback() 188 | trainer = pl.Trainer(accelerator=device,max_epochs=max_epochs,logger=logger,log_every_n_steps=log_freq,callbacks=[lr_monitor,],precision=16) 189 | else: 190 | trainer = pl.Trainer(accelerator=device,max_epochs=max_epochs,log_every_n_steps=log_freq,precision=16) 191 | trainer.fit(model_pl, train_loader, val_loader) 192 | 193 | return trainer 194 | 195 | """ is_best = True 196 | try: 197 | for epoch in range(train['epoch']): 198 | print('Current epoch: ', epoch) 199 | if not use_dgl: 200 | trainer.train_triplet(train_loader,model,optimizer,exp_helper,device,epoch,eval_score=True,print_freq=train['print_freq']) 201 | else: 202 | trainer.train_triplet_dgl(train_loader,model,optimizer,exp_helper,device,epoch,uncollate_function, 203 | sym_problem=symmetric_problem,eval_score=True,print_freq=train['print_freq']) 204 | 205 | 206 | if not use_dgl: 207 | relevant_metric, loss = 
trainer.val_triplet(val_loader,model,exp_helper,device,epoch,eval_score=True) 208 | else: 209 | relevant_metric, loss = trainer.val_triplet_dgl(val_loader,model,exp_helper,device,epoch,uncollate_function,eval_score=True) 210 | scheduler.step(loss) 211 | # remember best acc and save checkpoint 212 | is_best = (relevant_metric > best_score) 213 | best_score = max(relevant_metric, best_score) 214 | if True == is_best: 215 | best_epoch = epoch 216 | utils.save_checkpoint({ 217 | 'epoch': epoch + 1, 218 | 'state_dict': model.state_dict(), 219 | 'best_score': best_score, 220 | 'best_epoch': best_epoch, 221 | 'exp_logger': exp_helper.get_logger(), 222 | }, is_best,path_log) 223 | 224 | cur_lr = utils.get_lr(optimizer) 225 | if exp_helper.stop_condition(cur_lr): 226 | print(f"Learning rate ({cur_lr}) under stopping threshold, ending training.") 227 | break 228 | except KeyboardInterrupt: 229 | print('-' * 89) 230 | print('Exiting from training early because of KeyboardInterrupt') 231 | if test_enabled: 232 | eval(use_model=model) 233 | """ 234 | def test(config): 235 | """ Main func. 
236 | """ 237 | cpu = config['cpu'] 238 | #train, 239 | #problem =config['problem'] 240 | #config_arch = config['arch'] 241 | #test_enabled, 242 | #path_log = config['path_log'] 243 | data = config['data'] 244 | #max_epochs = config['train']['epochs'] 245 | batch_size = 1#config['train']['batch_size'] 246 | #config_optim = config['train'] 247 | #log_freq = config_optim['log_freq'] 248 | 249 | #print("Heading to Test.") 250 | #global best_score, best_epoch 251 | #best_score, best_epoch = -1, -1 252 | #print("Current problem : ", problem) 253 | 254 | use_cuda = not cpu and torch.cuda.is_available() 255 | device = 'cuda' if use_cuda else 'cpu' 256 | print('Using device:', device) 257 | 258 | # init random seeds 259 | utils.setup_env(cpu) 260 | 261 | #print("Models saved in ", path_log) 262 | #exp_helper = init_helper(problem) 263 | #model_pl = get_siamese_model_exp(config_arch, config_optim) 264 | model = get_siamese_model_test(data['test']['path_model']) 265 | 266 | path_data_test = os.path.join(data['path_dataset'], 'test/') 267 | utils.check_dir(path_data_test) 268 | generator = dg.QAP_Generator 269 | #generator = dg.QAP_spectralGenerator 270 | gene_test = generator('test', data['test'], path_data_test) 271 | gene_test.load_dataset() 272 | #gene_val = generator('val', data['train'], data['path_dataset']) 273 | #gene_val.load_dataset() 274 | test_loader = siamese_loader(gene_test, batch_size, 275 | gene_test.constant_n_vertices, shuffle=False) 276 | #val_loader = siamese_loader(gene_val, batch_size, 277 | # gene_val.constant_n_vertices, shuffle=False) 278 | 279 | 280 | #optimizer, scheduler = get_optimizer(train,model) 281 | #print("Model #parameters : ", sum(p.numel() for p in model.parameters() if p.requires_grad)) 282 | 283 | """ if not train['anew']: 284 | try: 285 | utils.load_model(model,device,train['start_model']) 286 | print("Model found, using it.") 287 | except RuntimeError: 288 | print("Model not existing. 
Starting from scratch.") 289 | """ 290 | #model.to(device) 291 | 292 | trainer = pl.Trainer(accelerator=device,precision=16) 293 | res_test = trainer.test(model, test_loader) 294 | return res_test 295 | #return trainer 296 | #@ex.command 297 | """ def eval(cpu, train, arch, data, use_dgl, problem, use_model=None): 298 | print("Heading to evaluation.") 299 | 300 | use_cuda = not cpu and torch.cuda.is_available() 301 | device = 'cuda' if use_cuda else 'cpu' 302 | print('Using device:', device) 303 | 304 | if use_model is None: 305 | model = get_model_gen(arch) 306 | model.to(device) 307 | model = utils.load_model(model, device, train['start_model']) 308 | else: 309 | model = use_model 310 | 311 | helper = init_helper(problem) 312 | 313 | if use_dgl: 314 | print(f"Arch : {arch['arch_gnn']}") 315 | from loaders.siamese_loaders import get_uncollate_function 316 | uncollate_function = get_uncollate_function(data['test']['n_vertices'],problem) 317 | cur_crit = helper.criterion 318 | cur_eval = helper.eval_function 319 | helper.criterion = lambda output, target : cur_crit(uncollate_function(output), target) 320 | helper.eval_function = lambda output, target : cur_eval(uncollate_function(output), target) 321 | 322 | 323 | gene_test = helper.generator('test', data['test'], data['path_dataset']) 324 | gene_test.load_dataset(use_dgl) 325 | test_loader = get_loader(use_dgl,gene_test, train['batch_size'], 326 | gene_test.constant_n_vertices,problem=problem) 327 | 328 | relevant_metric, loss = trainer.val_triplet(test_loader, model, helper, device, 329 | epoch=0, eval_score=True, 330 | val_test='test') 331 | """ 332 | #key = create_key() 333 | #filename_test = os.path.join(log_dir, output_filename) 334 | #print('Saving result at: ',filename_test) 335 | #metric_to_save = helper.get_relevant_metric_with_name('test') 336 | #utils.save_to_json(key, loss, metric_to_save, filename_test) 337 | 338 | #@ex.automain 339 | def main(): 340 | parser = argparse.ArgumentParser(description='Main 
file for creating experiments.') 341 | parser.add_argument('command', metavar='c', choices=['train','test', 'tune'], 342 | help='Command to execute : train or test') 343 | parser.add_argument('--n_vertices', type=int, default=0) 344 | parser.add_argument('--noise', type=float, default=0) 345 | parser.add_argument('--edge_density', type=float, default=0) 346 | parser.add_argument('--block_init', type=str, default='block') 347 | parser.add_argument('--block_inside', type=str, default='block_inside') 348 | parser.add_argument('--node_emb', type=str, default='node_embedding_block') 349 | args = parser.parse_args() 350 | if args.command=='train': 351 | training=True 352 | default_test = False 353 | tuning = False 354 | elif args.command=='test': 355 | training=False 356 | default_test = True 357 | tuning = False 358 | elif args.command=='tune': 359 | training=False 360 | default_test=False 361 | tuning = True 362 | 363 | config = get_config() 364 | if args.n_vertices != 0: 365 | config['data']['train']['n_vertices'] = args.n_vertices 366 | if args.noise != 0: 367 | config['data']['train']['noise'] = args.noise 368 | if args.edge_density != 0: 369 | config['data']['train']['edge_density'] = args.edge_density 370 | if args.block_init != 'block': 371 | config['arch']['node_emb']['block_init'] = args.block_init 372 | print(f"block_init override: {args.block_init}") 373 | if args.block_inside != 'block_inside': 374 | config['arch']['node_emb']['block_inside'] = args.block_inside 375 | print(f"block_inside override: {args.block_inside}") 376 | if args.node_emb != 'node_embedding_block': 377 | config['arch']['node_emb']['type'] = args.node_emb 378 | print(f"node_embedding override: {args.node_emb}") 379 | 380 | 381 | global ROOT_DIR 382 | ROOT_DIR = Path.home() 383 | global EXPE_DIR #= os.path.join(ROOT_DIR,'experiments-gnn/') 384 | EXPE_DIR = os.path.join(ROOT_DIR,'experiments-gnn/') 385 | global PB_DIR #= os.path.join(EXPE_DIR, config['problem']) 386 | PB_DIR = 
os.path.join(EXPE_DIR, config['problem']) 387 | global DATA_PB_DIR #= os.path.join(PB_DIR,'data/') 388 | DATA_PB_DIR = os.path.join(PB_DIR,'data/') 389 | name = custom_name(config) 390 | config = check_paths_update(config, name) 391 | trainer=None 392 | if training: 393 | trainer = train(config) 394 | if default_test: #or config['test_enabled']: 395 | res_test = test(config) 396 | if tuning: 397 | trainer = tune(config) 398 | 399 | if __name__=="__main__": 400 | pl.seed_everything(3787, workers=True) 401 | main() -------------------------------------------------------------------------------- /maskedtensors/maskedtensor.py: -------------------------------------------------------------------------------- 1 | """ Masked tensors to handle batches with mixed node numbers """ 2 | 3 | import itertools 4 | import functools 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | def from_list(tensor_list, dims, batch_name='B', base_name='N'): 9 | """ 10 | Build a masked tensor from a list of tensors 11 | Dims is a tuple of dimensions which should be masked 12 | The tensors are supposed to agree on the other dimensions (and dtype) 13 | """ 14 | dims = list(dims) 15 | n_dim = tensor_list[0].dim() 16 | batch_size = len(tensor_list) 17 | 18 | # Create names 19 | data_names = [None] * (n_dim + 1) 20 | data_names[0] = batch_name 21 | for i, dim in enumerate(dims): 22 | data_names[dim+1] = base_name + i * '_' 23 | 24 | # Compute sizes of data and mask 25 | data_size = [0] * (n_dim + 1) 26 | data_size[0] = batch_size 27 | for dim in range(n_dim): 28 | data_size[dim+1] = max((tens.size(dim) for tens in tensor_list)) 29 | 30 | # Fill data using padding 31 | data = torch.zeros(data_size, names=data_names, dtype=tensor_list[0].dtype) 32 | for i, tens in enumerate(tensor_list): 33 | # caution: dims for pad are specified from last to first 34 | data_padding = [[0, data_size[dim+1] - tens.size(dim)] for dim in range(n_dim)] 35 | data_padding = reversed(data_padding) 36 | 
data_padding = list(itertools.chain.from_iterable(data_padding)) 37 | data[i] = F.pad(tens, data_padding) 38 | 39 | # Build mask 40 | mask = {} 41 | for dim, name in enumerate(data.names): 42 | if dim >= 1 and name: 43 | mask[name] = torch.zeros((batch_size, data.size(name)), 44 | names=(batch_name, name), dtype=data.dtype) 45 | for i, tens in enumerate(tensor_list): 46 | mask[name][i, :tens.size(dim-1)] = 1 47 | 48 | return MaskedTensor(data, mask, adjust_mask=False, apply_mask=False) 49 | 50 | class MaskedTensor: 51 | """ 52 | Masked tensor class 53 | - Unless you know what you are doing, should not be created with __init__, 54 | use from_list instead 55 | - Mask is always copied; data is copied iff copy is set to True 56 | - Individual tensors of a masked tensor mt can be retrived using list(mt), 57 | iterating with for tensor in mt or with indexing mt[i] 58 | """ 59 | def __init__(self, data, mask, adjust_mask=True, apply_mask=False, copy=False, batch_name='B'): 60 | self.tensor = torch.tensor(data) if copy else data 61 | self.mask_dict = mask.copy() 62 | self._batch_name = batch_name 63 | #self._is_cuda = self.tensor.is_cuda 64 | self.dtype = self.tensor.dtype 65 | self.device = self.tensor.device 66 | if adjust_mask: 67 | self._adjust_mask_() 68 | if apply_mask: 69 | self.mask_() 70 | 71 | def __repr__(self): 72 | return "Data:\n{}\nMask:\n{}".format(self.tensor, self.mask_dict) 73 | 74 | ## Mask methods 75 | def _adjust_mask_(self): 76 | """ Check compatibily and remove unecessary masked dims """ 77 | # To prevent changing the iterator during iteration 78 | mask_keys = list(self.mask_dict.keys()) 79 | for name in mask_keys: 80 | mask_size = self.mask_dict[name].size(name) 81 | try: 82 | data_size = self.tensor.size(name) 83 | assert mask_size == data_size 84 | except RuntimeError: 85 | del self.mask_dict[name] 86 | 87 | def mask_(self): 88 | """ Mask data in place""" 89 | for mask in self.mask_dict.values(): 90 | self.tensor = self.tensor * 
mask.align_as(self.tensor) 91 | 92 | def mask(self): 93 | """ Return new MaskedTensor with masked adata """ 94 | return MaskedTensor(self.tensor, self.mask_dict, adjust_mask=False, 95 | apply_mask=True, copy=True) 96 | 97 | ## Torch function override 98 | @classmethod 99 | def __torch_function__(self, func, types, args=(), kwargs=None): 100 | """ 101 | Support torch.* functions, derived from pytorch doc 102 | See https://pytorch.org/docs/master/notes/extending.html 103 | """ 104 | if kwargs is None: 105 | kwargs = {} 106 | if func in SPECIAL_FUNCTIONS: 107 | return SPECIAL_FUNCTIONS[func](*args, **kwargs) 108 | new_args = [a.tensor if isinstance(a, MaskedTensor) else a for a in args] 109 | masks = (a.mask_dict for a in args if isinstance(a, MaskedTensor)) 110 | new_mask = dict(item for mask_dict in masks for item in mask_dict.items()) 111 | ret = func(*new_args, **kwargs) 112 | return MaskedTensor(ret, new_mask, adjust_mask=True, apply_mask=True) 113 | 114 | ## Iterator methods 115 | def __getitem__(self, index): 116 | item = self.tensor[index] 117 | names = item.names 118 | for dim, name in enumerate(names): 119 | if name: 120 | length = int(torch.sum(self.mask_dict[name][index]).item()) 121 | item = torch.narrow(item, dim, 0, length) 122 | return item.rename(None) 123 | 124 | def __len__(self): 125 | return self.tensor.size(self._batch_name) 126 | 127 | def __iter__(self): 128 | return (self.__getitem__(index) for index in range(self.__len__())) 129 | 130 | ## Tensor methods 131 | def size(self, *args): 132 | """ Return size of the underlying tensor """ 133 | return self.tensor.size(*args) 134 | 135 | def dim(self): 136 | return self.tensor.dim() 137 | 138 | def contiguous(self, *args): 139 | self.tensor = self.tensor.contiguous(*args) 140 | return self 141 | 142 | def view(self, *dims): 143 | """ only acting on named dim None which should be at the end 144 | i.e. 
not acting on masked dimensions or batch dimension""" 145 | names = self.tensor.names 146 | nameless_tensor = self.tensor.rename(None).view(*dims) 147 | new_names = [None] * len(dims) 148 | for (i,n) in enumerate(names): 149 | if i < len(dims): 150 | new_names[i] = n 151 | res_tensor = nameless_tensor.rename(*new_names) 152 | return MaskedTensor(res_tensor, self.mask_dict, adjust_mask=False, apply_mask=False) 153 | 154 | @property 155 | def shape(self): 156 | """ Return shape of the underlying tensor """ 157 | return self.tensor.size() 158 | 159 | @property 160 | def is_cuda(self): 161 | return self.tensor.is_cuda 162 | 163 | @property 164 | def get_device(self): 165 | return self.tensor.get_device() 166 | 167 | def permute(self, *dims): 168 | """ Permute the tensor """ 169 | # Unfortunately, permute is not yet implemented for named tensors 170 | # So we do it by hand 171 | if len(dims) != len(self.tensor.size()): 172 | raise ValueError 173 | names = self.tensor.names 174 | nameless_tensor = self.tensor.rename(None).permute(*dims) 175 | permuted_names = [names[dim] for dim in dims] 176 | res_tensor = nameless_tensor.rename(*permuted_names) 177 | return MaskedTensor(res_tensor, self.mask_dict, adjust_mask=False, apply_mask=False) 178 | 179 | 180 | def to(self, *args, **kwargs): 181 | """ Apply the method .to() to both tensor and mask """ 182 | new_dict = {name:mask.to(*args, **kwargs) for name, mask in self.mask_dict.items()} 183 | new_tensor = self.tensor.to(*args, **kwargs) 184 | return MaskedTensor(new_tensor, new_dict, adjust_mask=False, apply_mask=False) 185 | 186 | 187 | 188 | ### Torch function overrides 189 | SPECIAL_FUNCTIONS = {} 190 | 191 | def implements(torch_function): 192 | """ 193 | Register a torch function override for MaskedTensor 194 | See https://pytorch.org/docs/master/notes/extending.html 195 | """ 196 | @functools.wraps(torch_function) 197 | def decorator(func): 198 | SPECIAL_FUNCTIONS[torch_function] = func 199 | return func 200 | return 
decorator 201 | 202 | def get_dtype_min_value(dtype): 203 | """ Get the min value of given dtype, whether int or float """ 204 | try: 205 | return torch.finfo(dtype).min 206 | except TypeError: 207 | pass 208 | try: 209 | return torch.iinfo(dtype).min 210 | except TypeError: 211 | raise TypeError("dtype is neither float nor int") 212 | 213 | @implements(torch.max) 214 | def torch_max(masked_tensor, dim=None): 215 | """ Implements torch.max """ 216 | if dim is None: 217 | # !!! taking the max over the whole batch !!! 218 | return torch.max(masked_tensor.tensor) 219 | else: 220 | tensor = masked_tensor.tensor 221 | min_value = get_dtype_min_value(tensor.dtype) 222 | for mask in masked_tensor.mask_dict.values(): 223 | aligned_mask = mask.align_as(tensor) 224 | tensor = tensor * aligned_mask + min_value * (1 - aligned_mask) 225 | max_tensor, indices = torch.max(tensor, dim) 226 | new_masked_tensor = MaskedTensor(max_tensor, masked_tensor.mask_dict, 227 | adjust_mask=True, apply_mask=True) 228 | return new_masked_tensor, indices 229 | 230 | @implements(F.conv2d) 231 | def torch_conv2d(inp, *args, **kwargs): 232 | """ Implements conv2d on masked tensors """ 233 | # Unfortunately, conv2d does not support named tensors yet 234 | names = inp.tensor.names 235 | nameless_tensor = inp.tensor.rename(None) 236 | nameless_res_tensor = F.conv2d(nameless_tensor, *args, **kwargs) 237 | res_tensor = nameless_res_tensor.rename(*names) 238 | return MaskedTensor(res_tensor, inp.mask_dict, adjust_mask=False, apply_mask=True) 239 | 240 | @implements(F.linear) 241 | def torch_linear(inp, *args, **kwargs): 242 | """ Implements linear on masked tensors """ 243 | # Unfortunately, linear does not support named tensors yet 244 | names = inp.tensor.names 245 | nameless_tensor = inp.tensor.rename(None) 246 | nameless_res_tensor = F.linear(nameless_tensor, *args, **kwargs) 247 | res_tensor = nameless_res_tensor.rename(*names) 248 | return MaskedTensor(res_tensor, inp.mask_dict, adjust_mask=False, 
apply_mask=True) 249 | 250 | @implements(torch.cat) 251 | def torch_cat(tensors, dim=0): 252 | """ 253 | Implements torch.cat for masked tensors 254 | We have to implement it manually for the same reason as the issue 255 | mentionned below 256 | """ 257 | # Improvement: find a more elegant way when pytorch finds an elegant way 258 | # for the issues mentionned below 259 | new_args = [a.tensor if isinstance(a, MaskedTensor) else a for a in tensors] 260 | masks = (a.mask_dict for a in tensors if isinstance(a, MaskedTensor)) 261 | new_mask = dict(item for mask_dict in masks for item in mask_dict.items()) 262 | ret = torch.cat(new_args, dim=dim) 263 | return MaskedTensor(ret, new_mask, adjust_mask=False, apply_mask=False) 264 | 265 | def dispatch_cat(tensors, dim=0): 266 | """ 267 | Temporary workaround to dispatch issue with torch.cat 268 | See https://github.com/pytorch/pytorch/issues/34294 269 | """ 270 | tensor = tensors[0] 271 | if isinstance(tensor, torch.Tensor): 272 | return torch.cat(tensors, dim=dim) 273 | return tensor.__torch_function__(torch.cat, [type(t) for t in tensors], (tensors,), {'dim':dim}) 274 | 275 | @implements(torch.stack) 276 | def torch_stack(tensors, dim=0): 277 | """ 278 | same pb as above 279 | """ 280 | # Unfortunately, does not support named tensors yet... 
281 | new_args = [a.tensor.rename(None) if isinstance(a, MaskedTensor) else a for a in tensors] 282 | names = [a.tensor.names if isinstance(a, MaskedTensor) else None for a in tensors] 283 | try: 284 | assert names[0] == names[1] 285 | except: 286 | print('trying to stack uncompatible masked tensors') 287 | masks = (a.mask_dict for a in tensors if isinstance(a, MaskedTensor)) 288 | new_mask = dict(item for mask_dict in masks for item in mask_dict.items()) 289 | ret = torch.stack(new_args, dim=dim) 290 | new_names = names[0][0:dim] + (None,) + names[0][dim:] 291 | return MaskedTensor(ret.refine_names(*new_names), new_mask, adjust_mask=True, apply_mask=False) 292 | 293 | def dispatch_stack(tensors, dim=0): 294 | tensor = tensors[0] 295 | if isinstance(tensor, torch.Tensor): 296 | return torch.stack(tensors, dim=dim) 297 | return tensor.__torch_function__(torch.stack, [type(t) for t in tensors], (tensors,), {'dim':dim}) 298 | 299 | 300 | @implements(torch.flatten) 301 | def torch_flatten(inp, start_dim=0, end_dim=-1): 302 | """ Implements torch.flatten """ 303 | # Unfortunately, does not support named tensors yet... 
304 | names = inp.tensor.names 305 | new_names = names[0:start_dim] + (None,) + names[end_dim+1:] 306 | res_tensor = torch.flatten(inp.tensor.rename(None), start_dim=start_dim, end_dim=end_dim) 307 | res_tensor = res_tensor.refine_names(*new_names) 308 | return MaskedTensor(res_tensor, inp.mask_dict, adjust_mask=True, apply_mask=False) 309 | 310 | def get_sizes(masked_tensor, keepdim=False): 311 | # returns the number of non-masked entries 312 | full_mask = torch.ones_like(masked_tensor.tensor) 313 | names = tuple(masked_tensor.mask_dict.keys()) 314 | for mask in masked_tensor.mask_dict.values(): 315 | aligned_mask = mask.align_as(full_mask) 316 | full_mask = full_mask * aligned_mask 317 | return torch.sum(full_mask, dim=names, keepdim=keepdim) 318 | 319 | @implements(torch.mean) 320 | def torch_mean(masked_tensor, keepdim = False, *args, **kwargs): 321 | # returns a tensor 322 | # args are not taken into account! 323 | # computing the mean over the masked dimensions 324 | sizes = get_sizes(masked_tensor, keepdim=keepdim) 325 | names = tuple(masked_tensor.mask_dict.keys()) 326 | return torch.sum(masked_tensor.tensor, dim = names, keepdim=keepdim)/sizes 327 | 328 | @implements(torch.var) 329 | def torch_var(masked_tensor, keepdim = False, *args, **kwargs): 330 | # same restriction as above! 
331 | sizes = get_sizes(masked_tensor, keepdim=keepdim) 332 | means = torch_mean(masked_tensor, keepdim=True) 333 | vars = MaskedTensor((masked_tensor.tensor - means)**2, masked_tensor.mask_dict, adjust_mask=False, apply_mask=True) 334 | names = tuple(vars.mask_dict.keys()) 335 | return torch.sum(vars.tensor, dim = names, keepdim=keepdim)/sizes 336 | 337 | @implements(F.instance_norm) 338 | def torch_instance_norm(masked_tensor, eps=1e-05, weight=None, bias =None, *args, **kwargs): 339 | """ Implements instance_norm on masked tensors 340 | only works for shape (b,f,n,n) when normalization is taken on (n,n) (InstanceNorm2d) 341 | for each feature f with track_running_stats=False 342 | """ 343 | # Unfortunately, InstanceNorm2d does not support named tensors yet 344 | means = torch_mean(masked_tensor, keepdim=True) 345 | var_s = torch_var(masked_tensor, keepdim=True) 346 | res_tensor = (masked_tensor.tensor - means)/torch.sqrt(var_s+eps) 347 | if (weight is not None) and (bias is not None): 348 | res_tensor = weight.reshape(1,weight.shape[0],1,1)*res_tensor+bias.reshape(1,bias.shape[0],1,1) 349 | return MaskedTensor(res_tensor, masked_tensor.mask_dict, adjust_mask=False, apply_mask=True) 350 | 351 | @implements(F.layer_norm) 352 | def torch_layer_norm(masked_tensor, *args, **kwargs): 353 | """ 354 | Implements layer_norm on masked tensors 355 | when applied accross channel direction (not accross masked dim!) 
356 | https://github.com/pytorch/pytorch/issues/81985#issuecomment-1236143883 357 | """ 358 | names = masked_tensor.tensor.names 359 | nameless_tensor = masked_tensor.tensor.rename(None) 360 | nameless_res_tensor = F.layer_norm(nameless_tensor, *args, **kwargs) 361 | res_tensor = nameless_res_tensor.rename(*names) 362 | return MaskedTensor(res_tensor, masked_tensor.mask_dict, adjust_mask=False, apply_mask=True) 363 | 364 | @implements(torch.diag_embed) 365 | def torch_diag_embed(inp, offset=0, dim1=-2, dim2=-1, *args, **kwargs): 366 | names = inp.tensor.names 367 | new_name = names[-1]+'_' 368 | new_names = names + (new_name,) 369 | nameless_tensor = inp.tensor.rename(None) 370 | nameless_res_tensor = torch.diag_embed(nameless_tensor, offset=offset, dim1=dim1, dim2=dim2, *args, **kwargs) 371 | res_tensor = nameless_res_tensor.rename(*new_names) 372 | new_dict = inp.mask_dict 373 | new_mask = inp.mask_dict[names[-1]].rename(None) 374 | names_mask = inp.mask_dict[names[-1]].names[:-1] +(new_name,) 375 | new_dict[new_name] = new_mask.rename(*names_mask) 376 | return MaskedTensor(res_tensor, new_dict, adjust_mask=False, apply_mask=True) 377 | 378 | @implements(F.nll_loss) 379 | def torch_nll_loss(masked_tensor, target, *args, **kwargs): 380 | return F.nll_loss(masked_tensor.tensor.rename(None), target, *args, **kwargs) 381 | 382 | @implements(F.cross_entropy) 383 | def torch_cross_entropy(masked_tensor, target, *args, **kwargs): 384 | return F.cross_entropy(masked_tensor.tensor.rename(None), target, *args, **kwargs) -------------------------------------------------------------------------------- /plot_accuracy_regular.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "source": [ 16 | "# If running on Colab, uncomment 
the code below:\n", 17 | "#!git clone https://github.com/mlelarge/graph_neural_net.git\n", 18 | "#!pip install lightning\n", 19 | "#%cd graph_neural_net" 20 | ], 21 | "metadata": { 22 | "id": "7pu5pDuUHako", 23 | "outputId": "c6b3849d-ad8d-4da1-d219-c68c3f4da9aa", 24 | "colab": { 25 | "base_uri": "https://localhost:8080/" 26 | } 27 | }, 28 | "id": "7pu5pDuUHako", 29 | "execution_count": 1, 30 | "outputs": [ 31 | { 32 | "output_type": "stream", 33 | "name": "stdout", 34 | "text": [ 35 | "/content/graph_neural_net\n" 36 | ] 37 | } 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "id": "245b4687", 44 | "metadata": { 45 | "id": "245b4687" 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "import os\n", 50 | "from pathlib import Path\n", 51 | "import json\n", 52 | "import numpy as np\n", 53 | "import torch\n", 54 | "import matplotlib.pyplot as plt\n", 55 | "\n", 56 | "import pytorch_lightning as pl\n", 57 | "from models import get_siamese_model_test\n", 58 | "import loaders.data_generator as dg\n", 59 | "from loaders.loaders import siamese_loader\n", 60 | "from toolbox.utils import check_dir\n", 61 | "import loaders.data_generator as dg\n", 62 | "from loaders.loaders import siamese_loader\n", 63 | "\n", 64 | "ROOT_DIR = Path.home()\n", 65 | "path_dataset = os.path.join(ROOT_DIR, 'data/')\n", 66 | "\n", 67 | "from toolbox.metrics import all_losses_acc, accuracy_linear_assignment\n", 68 | "from toolbox.losses import triplet_loss\n", 69 | "\n", 70 | "criterion = triplet_loss()" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "id": "2c7b53c3", 76 | "metadata": { 77 | "id": "2c7b53c3" 78 | }, 79 | "source": [ 80 | "# Downloading the pretrained model\n", 81 | "\n", 82 | "The cell below should only be run once: it creates a folder `downloads/` and then downloads the pretrained model and the configuration file into this folder."
83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 3, 88 | "id": "d43f0179", 89 | "metadata": { 90 | "id": "d43f0179" 91 | }, 92 | "outputs": [], 93 | "source": [ 94 | "import requests\n", 95 | "config_url = 'https://github.com/mlelarge/graph_neural_net/releases/download/v1-qap/config.json'\n", 96 | "model_url = 'https://github.com/mlelarge/graph_neural_net/releases/download/v1-qap/epoch.98-step.7821.ckpt'\n", 97 | "cwd = os.getcwd()\n", 98 | "downloads = os.path.join(cwd, 'downloads')\n", 99 | "check_dir(downloads)\n", 100 | "\n", 101 | "r = requests.get(config_url)\n", 102 | "with open(cwd+'/downloads/config.json', 'wb') as f:\n", 103 | " f.write(r.content)\n", 104 | "\n", 105 | "r = requests.get(model_url)\n", 106 | "with open(cwd+'/downloads/epoch.98-step.7821.ckpt', 'wb') as f:\n", 107 | " f.write(r.content)\n", 108 | "" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 4, 114 | "id": "4a1497fa", 115 | "metadata": { 116 | "id": "4a1497fa" 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "name = cwd+'/downloads/epoch.98-step.7821.ckpt'\n", 121 | "path = cwd+'/downloads/'" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 5, 127 | "id": "fe9083af", 128 | "metadata": { 129 | "id": "fe9083af" 130 | }, 131 | "outputs": [], 132 | "source": [ 133 | "def get_device_config(path):\n", 134 | " config_file = os.path.join(path,'config.json')\n", 135 | " with open(config_file) as json_file:\n", 136 | " config_model = json.load(json_file)\n", 137 | " use_cuda = not config_model['cpu'] and torch.cuda.is_available()\n", 138 | " device = 'cuda' if use_cuda else 'cpu'\n", 139 | " return config_model, device" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 6, 145 | "id": "c95fc627", 146 | "metadata": { 147 | "id": "c95fc627" 148 | }, 149 | "outputs": [], 150 | "source": [ 151 | "config_model, device = get_device_config(path)" 152 | ] 153 | }, 154 | { 155 | "cell_type": 
"code", 156 | "execution_count": 7, 157 | "id": "650423d8", 158 | "metadata": { 159 | "id": "650423d8", 160 | "outputId": "57edce5c-f30c-4a9c-9429-deeeb776f27f", 161 | "colab": { 162 | "base_uri": "https://localhost:8080/" 163 | } 164 | }, 165 | "outputs": [ 166 | { 167 | "output_type": "stream", 168 | "name": "stderr", 169 | "text": [ 170 | "INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.9.0 to v2.0.7. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file downloads/epoch.98-step.7821.ckpt`\n" 171 | ] 172 | } 173 | ], 174 | "source": [ 175 | "model = get_siamese_model_test(name, config_model)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 8, 181 | "id": "ee2981c1", 182 | "metadata": { 183 | "id": "ee2981c1" 184 | }, 185 | "outputs": [], 186 | "source": [ 187 | "config_data = config_model['data']" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 9, 193 | "id": "92709df1", 194 | "metadata": { 195 | "id": "92709df1" 196 | }, 197 | "outputs": [], 198 | "source": [ 199 | "num = 23\n", 200 | "list_noise = np.linspace(0, 0.22, num=num)" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": 10, 206 | "id": "964599b4", 207 | "metadata": { 208 | "id": "964599b4" 209 | }, 210 | "outputs": [], 211 | "source": [ 212 | "def compute_all(list_noise,config_data,path_dataset,model,device,bs=32):\n", 213 | " all_acc = np.zeros((len(list_noise),config_data['test']['num_examples_test']))\n", 214 | " for i,noise in enumerate(list_noise):\n", 215 | " config_data['test']['noise']=noise\n", 216 | " gene_test = dg.QAP_Generator('test', config_data['test'], path_dataset)\n", 217 | " gene_test.load_dataset()\n", 218 | " test_loader = siamese_loader(gene_test, bs, gene_test.constant_n_vertices, shuffle=False)\n", 219 | " _, all_acc[i,:] = 
all_losses_acc(test_loader,model,criterion,device,eval_score=accuracy_linear_assignment)\n", 220 | " return all_acc" 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "id": "2edc01ce", 226 | "metadata": { 227 | "id": "2edc01ce" 228 | }, 229 | "source": [ 230 | "# Inference\n", 231 | "\n", 232 | "The cell below will create datasets of graphs with various levels of noise if they do not exist; otherwise, it will only read them." 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 11, 238 | "id": "9c9bd8cb", 239 | "metadata": { 240 | "id": "9c9bd8cb", 241 | "outputId": "6dc9930a-6068-410f-cbe2-211c0222ac31", 242 | "colab": { 243 | "base_uri": "https://localhost:8080/" 244 | } 245 | }, 246 | "outputs": [ 247 | { 248 | "output_type": "stream", 249 | "name": "stdout", 250 | "text": [ 251 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.0_0.2/test.pkl\n" 252 | ] 253 | }, 254 | { 255 | "output_type": "stream", 256 | "name": "stderr", 257 | "text": [ 258 | "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create.
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", 259 | " warnings.warn(_create_warning_msg(\n" 260 | ] 261 | }, 262 | { 263 | "output_type": "stream", 264 | "name": "stdout", 265 | "text": [ 266 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.01_0.2/test.pkl\n", 267 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.02_0.2/test.pkl\n", 268 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.03_0.2/test.pkl\n", 269 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.04_0.2/test.pkl\n", 270 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.05_0.2/test.pkl\n", 271 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.06_0.2/test.pkl\n", 272 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.07_0.2/test.pkl\n", 273 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.08_0.2/test.pkl\n", 274 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.09_0.2/test.pkl\n", 275 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.1_0.2/test.pkl\n", 276 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.11_0.2/test.pkl\n", 277 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.12_0.2/test.pkl\n", 278 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.13_0.2/test.pkl\n", 279 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.14_0.2/test.pkl\n", 280 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.15_0.2/test.pkl\n", 281 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.16_0.2/test.pkl\n", 282 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.17_0.2/test.pkl\n", 283 | "Reading dataset at 
/root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.18_0.2/test.pkl\n", 284 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.19_0.2/test.pkl\n", 285 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.2_0.2/test.pkl\n", 286 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.21_0.2/test.pkl\n", 287 | "Reading dataset at /root/data/QAP_Regular_ErdosRenyi_1000_50_1.0_0.22_0.2/test.pkl\n" 288 | ] 289 | } 290 | ], 291 | "source": [ 292 | "all_a = compute_all(list_noise,config_data,path_dataset,model,device);" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 12, 298 | "id": "f8000455", 299 | "metadata": { 300 | "id": "f8000455" 301 | }, 302 | "outputs": [], 303 | "source": [ 304 | "def compute_quant(all_acc,quant_low=0.1,quant_up=0.9):\n", 305 | " median_acc = np.median(all_acc,1)\n", 306 | " num = len(median_acc)\n", 307 | " q_acc = np.zeros((num,2))\n", 308 | " for i in range(num):\n", 309 | " q_acc[i,:] = np.quantile(all_acc[i,:],[quant_low, quant_up])\n", 310 | " return median_acc, q_acc\n", 311 | "\n", 312 | "def acc_2_error(median_acc, q_acc):\n", 313 | " error = q_acc-median_acc[:,np.newaxis]\n", 314 | " error[:,0] = -error[:,0]\n", 315 | " return error" 316 | ] 317 | }, 318 | { 319 | "cell_type": "markdown", 320 | "id": "343eac94", 321 | "metadata": { 322 | "id": "343eac94" 323 | }, 324 | "source": [ 325 | "# Results\n", 326 | "\n", 327 | "The FGNN was trained on regular graphs with $50$ vertices, average degree $10$, and noise level $0.1$. The accuracy below is the fraction of matched vertices between two noisy versions of a given graph at various levels of noise." 
328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": 13, 333 | "id": "e498d6e5", 334 | "metadata": { 335 | "id": "e498d6e5" 336 | }, 337 | "outputs": [], 338 | "source": [ 339 | "quant_low=0.1\n", 340 | "quant_up=0.9\n", 341 | "mc_50, q50 = compute_quant(all_a,quant_low=quant_low,quant_up=quant_up)" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": 14, 347 | "id": "1cebcdf9", 348 | "metadata": { 349 | "id": "1cebcdf9", 350 | "outputId": "88428e6c-9710-479c-a613-05c40223f72f", 351 | "colab": { 352 | "base_uri": "https://localhost:8080/", 353 | "height": 472 354 | } 355 | }, 356 | "outputs": [ 357 | { 358 | "output_type": "display_data", 359 | "data": { 360 | "text/plain": [ 361 | "
" 362 | ], 363 | "image/png": "\n" 364 | }, 365 | "metadata": {} 366 | } 367 | ], 368 | "source": [ 369 | "error_50 = acc_2_error(mc_50,q50)\n", 370 | "\n", 371 | "plt.errorbar(list_noise,mc_50,yerr=error_50.T,label='FGNN (median)');\n", 372 | "plt.xlabel('noise (fraction of noisy edges)')\n", 373 | "plt.ylabel('accuracy')\n", 374 | "plt.title(f'Graphs with avg. degree 10 (quantiles {int(100*quant_low)}-{int(100*quant_up)}%)')\n", 375 | "plt.legend()\n", 376 | "plt.show()" 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": 14, 382 | "id": "41802f3f", 383 | "metadata": { 384 | "id": "41802f3f" 385 | }, 386 | "outputs": [], 387 | "source": [] 388 | } 389 | ], 390 | "metadata": { 391 | "kernelspec": { 392 | "display_name": "gnn12", 393 | "language": "python", 394 | "name": "gnn12" 395 | }, 396 | "language_info": { 397 | "codemirror_mode": { 398 | "name": "ipython", 399 | "version": 3 400 | }, 401 | "file_extension": ".py", 402 | "mimetype": "text/x-python", 403 | "name": "python", 404 | "nbconvert_exporter": "python", 405 | "pygments_lexer": "ipython3", 406 | "version": "3.8.2" 407 | }, 408 | "colab": { 409 | "provenance": [], 410 | "include_colab_link": true 411 | } 412 | }, 413 | "nbformat": 4, 414 | "nbformat_minor": 5 415 | } --------------------------------------------------------------------------------