├── README.md
├── dfs.py
├── dijkstra.py
├── enviroment_nobuilds.yml
├── generate_dataset.py
├── generate_dataset_2.py
├── graph_generation.py
├── kahn.py
├── mpnn.py
├── mpnn_2.py
├── papers
│   ├── graph_attention_networks.pdf
│   ├── neural_execution_engines.pdf
│   ├── neural_execution_of_graphs.pdf
│   └── neural_message_passing_networks.pdf
├── requirements.txt
├── tasks.txt
├── training.py
└── training_2.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Extending the Neural Graph Algorithm Executor

## Description

*Topic*: algorithms/data structures, graph neural networks, learning-to-execute

*Category*: implementation + research

Recent work [1, 2] has demonstrated the utility of *graph neural networks* as algorithm executors. Naturally, one might wish to extend these insights to a broader space of algorithms. In this project, you would be expected to teach neural networks to execute one or more additional parallel or sequential graph algorithms, previously unexplored by related work (some suggestions: depth-first search, Dijkstra's algorithm or topological sorting). The students will have the freedom to choose the depth/breadth of the project (e.g. whether to focus on one algorithm with in-depth studies, or to explore multiple algorithms at once).

## Setup

1. Clone the main repo

```sh
git clone https://github.com/gabrielbarcik/graphnets.git
```

2. Create the conda environment (`graphnet`) from `enviroment_nobuilds.yml` (the file name is spelled this way in the repo).

```sh
conda env create -f enviroment_nobuilds.yml
```

3. Activate the environment with `conda activate graphnet` before running any of the scripts.

## Resources

[1] Veličković, P., Ying, R., Padovano, M., Hadsell, R. and Blundell, C. (2019). Neural Execution of Graph Algorithms. arXiv preprint arXiv:1910.10593.

[2] Anonymous (2019). Neural Execution Engines. Submitted to ICLR 2020.

## Authors

* **Gabriel Fedrigo Barcik** - [mail](gbarcike@gmail.com)
* **Louis Dumont** - [mail](louis.dumont@eleves.enpc.fr)

--------------------------------------------------------------------------------
/dfs.py:
--------------------------------------------------------------------------------

# Implements the "classical" computation of DFS

import numpy as np
import networkx as nx
from matplotlib import pyplot as plt
from graph_generation import GraphGenerator


class DFS:

    def __init__(self):
        pass

    def decode_last_state(self, x):
        # Visited nodes carry negative values; sorting by -x recovers the
        # order in which they were visited, unseen nodes are set to -1
        nb_seen = np.sum(np.where(x < 0, 1, 0))
        sort = np.argsort(np.where(x < 0, -x, float('inf')))
        sort[nb_seen:] = -1
        return sort

    def run(self, graph, root=0):
        '''
        Parameters
        ----------
        graph: NetworkX Graph instance
            The graph on which the algorithm should be run
        root: index of the node that should be used as root for the DFS

        Returns:
        --------
        The history of x (states) when executing the DFS algorithm, and the
        DFS output
        '''

        E = nx.to_numpy_matrix(graph)
        x = self.initialize_x(graph, root)
        history = [x.copy()]

        while np.max(x) > 0:
            x = self.iter_DFS(graph, x, E)
            history.append(x.copy())

        return np.asarray(history), self.decode_last_state(x)


    def initialize_x(self, graph, root=0):
        '''
        Parameters
        ----------
        graph: NetworkX Graph instance
            The graph on which the algorithm should be run
        root: index of the node that should be used as a root for the DFS

        Returns:
        --------
        Initialized numpy representation of the graph, as used by our DFS
        implementation
        '''

        nb_nodes = graph.number_of_nodes()
        x = np.zeros((nb_nodes))
        x[root] = 1

        return x


    def iter_DFS(self, graph, x, E):
        '''
        Parameters
        ----------
        x: numpy array
            array of the node's features.
            At initialization, x[i] should be 1 for the source node and 0
            otherwise
        E: numpy array
            adjacency matrix. E[i,j]=1 indicates an edge from node i to node j

        Returns
        -------
        The modified x, after one iteration of our DFS algorithm
        '''

        i0 = np.argmax(x)       # Select the node with highest rank

        next_label = x[i0] + 1  # Detect rank to assign

        x[i0] = np.amin(x) - 1  # Mark node as seen; implicitly encodes the
                                # position in which the node was seen
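
        # Illustrative example of this encoding (the values are hypothetical):
        # in a 4-node graph where node 0 was visited first and node 2 second,
        # x could be [-1, 3, -2, 2]. Negative entries record the visit order
        # (x[i] = -k for the k-th visited node), positive entries are pending
        # frontier ranks, so argmax would pick node 1 as the next node.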

        neigh = np.argwhere(E[i0] > 0)[:,1]  # Select the neighbours of this node

        neigh = sorted(neigh, key=lambda id: graph.nodes[id]['priority'])

        #neigh.reverse()

        for ind in neigh:
            # If the child has not been explored yet, update it
            if x[ind] >= 0:  # Not == 0: we want to update a node's rank even
                             # if it is also the child of a shallower node
                x[ind] = next_label  # Give the child the highest rank so that
                                     # it is explored first
                next_label += 1      # Update the highest rank

        return x


if __name__=="__main__":
    # graph = nx.balanced_tree(2,3)
    root = 2
    generator = GraphGenerator()
    graph = generator.gen_graph_type(10, 'erdos_renyi')

    dfs = DFS()

    hs, output = dfs.run(graph, root=root)
    print('dfs output: {}'.format(output))
    print(hs)

    labels = dict((n, [n, np.around(d['priority'], decimals=2)]) for n, d in graph.nodes(data=True))
    nx.draw(graph, labels=labels)
    plt.show()

--------------------------------------------------------------------------------
/dijkstra.py:
--------------------------------------------------------------------------------

# Implements the "classical" computation of Dijkstra's algorithm

import numpy as np
import networkx as nx
from matplotlib import pyplot as plt
from graph_generation import GraphGenerator


class Dijkstra:

    def __init__(self):
        pass

    def decode_last_state(self, x):
        # TODO: recover the solution (shortest-path tree) from x
        return None

    def run(self, graph, root=0):
        '''
        Parameters
        ----------
        graph: NetworkX Graph instance
            The graph on which the algorithm should be run
        root: index of the node that should be used as the source for Dijkstra

        Returns:
        --------
        The history of (x, p) states when executing Dijkstra's algorithm
        '''

        E = nx.to_numpy_matrix(graph)

        # set infinity as sum(weights) + 1
        inf = sum([w[2] for w in graph.edges.data('weight')]) + 1
        print('inf set to be: {}'.format(inf))

        x, p = self.initialize_states(graph, inf, root)

        history = [(x.copy(), p.copy())]

        # stop when the smallest tentative distance among unvisited nodes is inf
        while np.min(x[x[:, 0] == 1][:, 1], initial=inf) != inf:
            x, p = self.iter_dijkstra(graph, x, p, E)
            history.append((x.copy(), p.copy()))
            print(x, p)

        return np.asarray(history)


    def initialize_states(self, graph, inf, root=0):
        '''
        Parameters
        ----------
        graph: NetworkX Graph instance
            The graph on which the algorithm should be run
        root: index of the node that should be used as the source for Dijkstra

        Returns:
        --------
        Initialized numpy representation of the graph, as used by our Dijkstra
        implementation
        x[i] contains two fields (unvisited (bool), distance to source (float))
        p[i] contains an integer that represents the previous node on the path
        to i
        '''

        nb_nodes = graph.number_of_nodes()
        x = np.ones((nb_nodes, 2))
        x[:, 1] = inf
        x[root] = [1, 0]

        p = -1 * np.ones(nb_nodes)
        p[root] = root

        return x, p
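
    # Worked example (hypothetical 3-node graph, root 0, every edge weight
    # 0.5): initially x = [[1, 0], [1, inf], [1, inf]] and p = [0, -1, -1].
    # After one iteration, node 0 is marked visited (x[0][0] = 0) and its
    # neighbours get tentative distance 0.5 with their p entry set to 0.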

    def iter_dijkstra(self, graph, x, p, E):
        '''
        Parameters
        ----------
        x: numpy array
            array of the node's features.
            x[i] = (unvisited flag, tentative distance to the source)
        p: numpy array
            p[i] is the current predecessor of node i
        E: numpy array
            adjacency matrix. E[i,j]>0 indicates an edge from node i to node j

        Returns
        -------
        The modified x and p, after one iteration of our Dijkstra algorithm
        '''

        # minimum distance of unvisited nodes
        min_dist = np.min(x[x[:, 0] == 1][:, 1])
        # restrict the argwhere to unvisited nodes, so that a visited node
        # with the same distance cannot be selected again
        i0 = np.argwhere((x[:, 0] == 1) & (x[:, 1] == min_dist))[0][0]

        # select the neighbours of this node
        neigh = np.argwhere(E[i0]>0)[:,1]

        for v in neigh:
            # update only unvisited nodes
            if x[v][0] == 1:
                if E[i0, v] + x[i0][1] < x[v][1]:
                    # x[v][1] = min(x[v][1], E[i0, v] + x[i0][1])
                    x[v][1] = E[i0, v] + x[i0][1]
                    p[v] = i0

        # mark current node as visited
        x[i0][0] = 0

        return x, p


if __name__=="__main__":
    root = 2
    generator = GraphGenerator()
    graph = generator.gen_graph_type(5, 'erdos_renyi', set_weights=True)

    dijkstra = Dijkstra()

    hs = dijkstra.run(graph, root=root)
    print(hs)

    labels = dict((n, [n, np.around(d['priority'], decimals=2)]) for n, d in graph.nodes(data=True))
    pos = nx.spring_layout(graph)
    nx.draw(graph, pos, labels=labels)
    edges = nx.get_edge_attributes(graph, 'weight')
    edges = {e: np.around(w, decimals=2) for e, w in edges.items()}
    nx.draw_networkx_edge_labels(graph, pos, edge_labels=edges)
    print(edges)

    plt.show()

--------------------------------------------------------------------------------
/enviroment_nobuilds.yml:
--------------------------------------------------------------------------------

name: graphnet
channels:
  - pytorch
  - anaconda
  - defaults
dependencies:
  - _libgcc_mutex=0.1
  - blas=1.0
  - ca-certificates=2019.10.16
  - certifi=2019.9.11
  - cffi=1.13.2
  - cudatoolkit=10.1.243
  - cycler=0.10.0
  - dbus=1.13.12
  - decorator=4.4.1
  - expat=2.2.6
  - fontconfig=2.13.0
  - freetype=2.9.1
  - glib=2.63.1
  - gst-plugins-base=1.14.0
  - gstreamer=1.14.0
  - icu=58.2
  - intel-openmp=2019.4
  - joblib=0.14.0
  - jpeg=9b
  - kiwisolver=1.1.0
  - libedit=3.1.20181209
  - libffi=3.2.1
  - libgcc-ng=9.1.0
  - libgfortran-ng=7.3.0
  - libpng=1.6.37
  - libstdcxx-ng=9.1.0
  - libuuid=1.0.3
  - libxcb=1.13
  - libxml2=2.9.9
  - matplotlib=3.1.1
  - mkl=2019.4
  - mkl-service=2.3.0
  - mkl_fft=1.0.15
  - mkl_random=1.1.0
  - ncurses=6.1
  - networkx=2.4
  - ninja=1.9.0
  - numpy=1.17.3
  - numpy-base=1.17.3
  - openssl=1.1.1
  - pandas=0.25.3
  - pcre=8.43
  - pip=19.3.1
  - pycparser=2.19
  - pyparsing=2.4.5
  - pyqt=5.9.2
  - python=3.6.9
  - python-dateutil=2.8.1
  - pytorch=1.3.1
  - pytz=2019.3
  - qt=5.9.7
  - readline=7.0
  - scikit-learn=0.21.3
  - scipy=1.3.1
  - setuptools=42.0.1
  - sip=4.19.8
  - six=1.13.0
  - sqlite=3.30.1
  - tk=8.6.8
  - tornado=6.0.3
  - wheel=0.33.6
  - xz=5.2.4
  - zlib=1.2.11
prefix: /home/barcik/miniconda3/envs/graphnet

--------------------------------------------------------------------------------
/generate_dataset.py:
--------------------------------------------------------------------------------

import networkx as nx
import numpy as np
from matplotlib import pyplot as plt
from dfs import DFS
from graph_generation import GraphGenerator
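
# Data format (as produced by DFS.run): for each graph, `history` has shape
# (nb_steps, nb_nodes) and stores the DFS state after every iteration; the
# "next node" targets are one-hot vectors marking the node whose value
# decreased between two consecutive states, i.e. the node visited at that
# step.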

class DatasetGenerator:

    def __init__(self):
        self.graph_generator = GraphGenerator()

    def run(self, graph_type, nb_graphs, nb_nodes, algorithm_type):
        graphs = []
        dataset = []
        next_nodes = []

        for _ in range(nb_graphs):

            if algorithm_type == 'DFS':
                dfs = DFS()
                graph = self.graph_generator.gen_graph_type(nb_nodes, graph_type)
                graphs.append(graph)
                history, _ = dfs.run(graph)
                dataset.append(history)
                # Generate the "next node" data: the node whose value
                # decreased between two consecutive states is the one that
                # was visited at that step
                next_nodes.append(np.asarray([np.where(history[i]-history[i+1]>0, 1, 0) for i in range(history.shape[0]-1)]))

        return graphs, np.asarray(dataset), np.asarray(next_nodes)


if __name__ == '__main__':
    graph_type = 'erdos_renyi'
    nb_graphs = 3
    nb_nodes = 8
    algorithm_type = 'DFS'

    data_gen = DatasetGenerator()
    graphs, dataset, next_nodes = data_gen.run(graph_type, nb_graphs, nb_nodes,
                                               algorithm_type)

    print(dataset, [np.argmax(next_node, axis=1) for next_node in next_nodes])

--------------------------------------------------------------------------------
/generate_dataset_2.py:
--------------------------------------------------------------------------------

import networkx as nx
import numpy as np
from matplotlib import pyplot as plt
from dfs import DFS
from graph_generation import GraphGenerator


class DatasetGenerator:

    def __init__(self):
        self.graph_generator = GraphGenerator()

    def run(self, graph_type, nb_graphs, nb_nodes, algorithm_type):
        graphs = []
        dataset = []
        next_nodes = []

        for _ in range(nb_graphs):

            if algorithm_type == 'DFS':
                dfs = DFS()
                graph = self.graph_generator.gen_graph_type(nb_nodes, graph_type)
                graphs.append(graph)
                history, _ = dfs.run(graph)
                dataset.append(history)

                # Generate the "next node" data: cumulative 0/1 masks of the
                # nodes reached so far, one per algorithm step
                states = []
                states.append(history[0])
                for i in range(1, history.shape[0] - 1):
                    idx = np.where(history[i] - history[i+1] > 0)[0][0]
                    s = states[i-1].copy()
                    s[idx] = 1
                    states.append(s)

                next_nodes.append(states)

        return graphs, np.asarray(next_nodes)


if __name__ == '__main__':
    graph_type = 'erdos_renyi'
    nb_graphs = 3
    nb_nodes = 8
    algorithm_type = 'DFS'

    data_gen = DatasetGenerator()
    # run() only returns the graphs and the per-step visited masks
    graphs, next_nodes = data_gen.run(graph_type, nb_graphs, nb_nodes, algorithm_type)

    print([np.argmax(next_node, axis=1) for next_node in next_nodes])

--------------------------------------------------------------------------------
/graph_generation.py:
--------------------------------------------------------------------------------

import networkx as nx
import numpy as np
from matplotlib import pyplot as plt


class GraphGenerator:
    def __init__(self):
        pass

    def gen_graph_type(self, nb_nodes, graph_type, set_weights=False):
        g = None

        if graph_type == 'gn_graph':
            g = nx.gn_graph(nb_nodes)

        if graph_type == 'ladder':
            g = nx.ladder_graph(nb_nodes)

        if graph_type == 'grid':
            g = nx.grid_2d_graph(nb_nodes, nb_nodes // 2 + 1)
            # grid_2d_graph labels nodes with (row, col) tuples; relabel them
            # to integers so that the priority/weight attributes below attach
            # correctly
            g = nx.convert_node_labels_to_integers(g)

        if graph_type == 'erdos_renyi':
            p = min(np.log(nb_nodes)/nb_nodes, 0.5)
            g = nx.erdos_renyi_graph(n=nb_nodes, p=p, directed=False)

        if graph_type == 'barabasi_albert':
            nb_neighs = 5
            g = nx.barabasi_albert_graph(n=nb_nodes, m=nb_neighs)

        if graph_type == '4_caveman':
            # l (int): number of groups
            # k (int): size of cliques
            # p (float): probability of rewiring each edge
            g = nx.relaxed_caveman_graph(l=4, k=5, p=0.3)

        if g is None:
            raise ValueError('Unknown graph type: {}'.format(graph_type))
        else:
            # give a priority to each node
            priorities = {i: p for i, p in
                          enumerate(np.random.uniform(0.2, 1, len(g.nodes)))}
            nx.set_node_attributes(g, priorities, name='priority')

            # give a weight to each edge
            if set_weights:
                weights = np.random.uniform(0.2, 1, len(g.edges))
                weights = {e: weights[i] for i, e in enumerate(g.edges)}
                nx.set_edge_attributes(g, weights, 'weight')
        return g


if __name__ == '__main__':
    gen = GraphGenerator()
    g = gen.gen_graph_type(10, 'ladder')
    nx.draw(g.to_directed())
    plt.show()
    print(nx.adjacency_matrix(g).todense())
    print(g.nodes(data=True))

--------------------------------------------------------------------------------
/kahn.py:
--------------------------------------------------------------------------------

# Implements the "classical" computation of Kahn's algorithm

import numpy as np
import networkx as nx
from matplotlib import pyplot as plt
from graph_generation import GraphGenerator


class Kahn:

    def __init__(self):
        pass

    def decode_last_state(self, x):
        # argsort of the last iteration, with handling of the unseen nodes
        nb_seen = np.sum(np.where(x[:,0] < 0, 1, 0))
        sort = np.argsort(np.where(x[:,0] < 0, -x[:,0], float('inf')))
        sort[nb_seen:] = -1
        return sort

    def run(self, graph):
        '''
        Parameters
        ----------
        graph: NetworkX Graph instance
            The graph on which the algorithm should be run

        Returns:
        --------
        The history of x (states) when executing the Kahn algorithm, and the
        execution list output
        '''

        E = nx.to_numpy_matrix(graph)
        x = self.initialize_x(graph)
        history = [x.copy()]

        # Stopping condition is when no node can be processed
        while np.any(np.isin(0, x[:,1])):
            x = self.iter_Kahn(graph, x, E)
            history.append(x.copy())

        return np.asarray(history), self.decode_last_state(x)


    def initialize_x(self, graph):
        '''
        Parameters
        ----------
        graph: NetworkX Graph instance
            The graph on which the algorithm should be run

        Returns:
        --------
        Initialized numpy representation of the graph, as used by our Kahn
        implementation
        '''
        E = nx.to_numpy_matrix(graph)
        nb_nodes = graph.number_of_nodes()

        def get_degree(E, ind):
            # in-degree: number of incoming edges of node ind
            return np.sum(E[:,ind])

        x = np.array([(get_degree(E,i), -1) for i in range(nb_nodes)])

        # Nodes with in-degree 0 are the ones that can be executed right away
        free_idx = np.argwhere(x[:,0]==0)
        free_nodes = [free_idx[i][0] for i in range(free_idx.size)]

        if not free_nodes:
            print('No free nodes')
            return x

        for i in free_nodes:
            x[i][1] = 0

        return x
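
    # State encoding used by iter_Kahn below (the values are illustrative):
    # x[i, 0] starts as the in-degree of node i and is overwritten with a
    # negative execution rank once the node is processed; x[i, 1] is a status
    # flag (-1 = blocked, 0 = ready, 1 = processed). For a 3-node chain
    # 0 -> 1 -> 2, x is initialized to [[0, 0], [1, -1], [1, -1]].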

    def iter_Kahn(self, graph, x, E):
        '''
        Parameters
        ----------
        x: numpy array
            array of the node's features.
            x[i] = (in-degree or negative execution rank, status flag)
        E: numpy array
            adjacency matrix. E[i,j]>0 indicates an edge from node i to node j

        Returns
        -------
        The modified x, after one iteration of our Kahn algorithm
        '''

        available_nodes = np.argwhere(x[:,1]==0)

        # Pick the available node with the highest priority
        i0 = sorted([available_nodes[i][0] for i in range(available_nodes.size)], key=lambda id: graph.nodes[id]['priority'])[-1]

        m = np.amin(x[:,0])

        # Set the node as seen
        x[i0, 1] = 1
        # Store its execution time in the labels
        x[i0, 0] = m-1

        # Get all nodes that depend on its execution
        neigh = np.argwhere(E[i0] > 0)

        for ind in neigh[:,1]:
            # Decrease the degree
            x[ind, 0] -= 1

            if x[ind, 0] == 0:
                # If the degree reaches zero, set the node as able to be
                # processed
                x[ind, 1] = 0

        return x


if __name__=="__main__":
    # graph = nx.balanced_tree(2,3)
    generator = GraphGenerator()

    # Ensure that we generate a directed acyclic graph
    graph = generator.gen_graph_type(10, 'gn_graph')
    print(nx.to_numpy_matrix(graph))


    kahn = Kahn()

    hs, output = kahn.run(graph)
    print(output)

    E = nx.to_numpy_matrix(graph)
    x = kahn.initialize_x(graph)
    print(x)

    while np.any(np.isin(0, x[:,1])):
        kahn.iter_Kahn(graph, x, E)
        print(x)

    print('Kahn output: {}'.format(kahn.decode_last_state(x)))

    labels = dict((n, [n, np.around(d['priority'], decimals=2)]) for n, d in graph.nodes(data=True))
    nx.draw(graph, labels=labels)
    plt.show()

--------------------------------------------------------------------------------
/mpnn.py:
--------------------------------------------------------------------------------

# Defines the MPNN model that will be trained

# Part of the code is based on https://github.com/timlacroix/nri_practical_session

import dgl
from generate_dataset import DatasetGenerator
import torch
import torch.nn as nn

# Setting the seeds for replicability
import random
random.seed(33)
torch.manual_seed(33)

# Do not use DGL in the end?

# A first NN for a single algorithm
# Define the MPNN module
# (Use of DFS for the first example)

class MPNN(nn.Module):
    # Expects dgl graphs as inputs
    def __init__(self, in_feats, hidden_feats, edge_feats, out_feats, useCuda=False):
        super(MPNN, self).__init__()
        self.n_hid = hidden_feats
        self.encoder = nn.Linear(in_feats + hidden_feats + 1, hidden_feats)  # +1 is for the weights (needed so far, might be removed later)
        # NB: the message size is hard-coded to 32 and must match hidden_feats
        # for the U layer below to receive inputs of the expected size
        self.M = nn.Linear(hidden_feats * 2 + edge_feats, 32)
        self.U = nn.Linear(hidden_feats * 2, hidden_feats)
        #self.decoder = nn.Linear(hidden_feats * 2, in_feats)  # "first" version, does not account for next node prediction
        self.decoder_nextnode = nn.Linear(hidden_feats * 2, 1)  # output "energy" will be soft-maxed to predict the next node
        self.decoder_update = nn.Linear(hidden_feats*2+1, in_feats)  # takes the same inputs + the next_node "energy" and computes updates
        self.termination = nn.Linear(hidden_feats, 1)  # Find a way to have only 1 output whatever the graph size is
        self.useCuda = useCuda
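
    # The processor below roughly follows the encode-process-decode template
    # of the "Neural Execution of Graph Algorithms" paper (see papers/):
    #   z_i  = encoder(x_i, h_i, priority_i)
    #   m_ij = M(z_i, z_j, e_ij)          messages along edges
    #   h'_i = U(z_i, max_j m_ij)         max aggregation over neighbours
    # with separate decoders for the state update, the next node and the
    # termination probability.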

    def compute_send_messages(self, edges):
        # The argument is a batch of edges.
        # This computes a (batch of) message called 'msg' from the source and
        # destination features 'z' and the edge features.
        z_src = edges.src['z']
        z_dst = edges.dst['z']

        msg = self.M(torch.cat([z_src, z_dst, edges.data['features'].view(-1,1)], 1))
        return {'msg' : msg}

    def max_reduce_messages(self, nodes):
        # The argument is a batch of nodes.
        # This computes the new 'h' features by taking the element-wise max
        # of the received 'msg' in each node's mailbox.
        #return {'u_input' : torch.sum(nodes.mailbox['msg'], dim=1)}  # for sum and mean aggregation
        return {'u_input' : torch.max(nodes.mailbox['msg'], dim=1).values}


    # A step corresponds to 1 iteration of the network:
    # giving the state of the graph after one iteration of the algorithm
    def step(self, graph, inputs, hidden):

        # Helpers to stack z and e conveniently (currently unused)
        n_atoms = inputs.size(0)
        id1 = torch.LongTensor(sum([[i] * n_atoms for i in range(n_atoms)], []))
        id2 = torch.LongTensor(sum([list(range(n_atoms)) for i in range(n_atoms)], []))

        # Encoding x^t and h^{t-1}
        inputs = inputs.view(-1,1)
        inp = torch.cat([inputs, hidden, graph.ndata['priority'].view(-1,1)], 1)
        z = self.encoder(inp)
        graph.ndata['z'] = z

        # Processor
        # without dgl this would be messages = self.M(stack), but passing
        # messages on is hard; extracting the aggregation (max) of messages
        # at each position is easier with DGL:
        graph.send(graph.edges(), self.compute_send_messages)
        # trigger aggregation at all nodes
        graph.recv(graph.nodes(), self.max_reduce_messages)

        u_input = graph.ndata.pop('u_input')

        new_hidden = self.U(torch.cat([z, u_input], 1))

        # Stopping criterion for the next step
        H_mean = torch.mean(new_hidden, dim=0, keepdim=True)

        # TODO: find the right way to broadcast
        loc_inp = torch.cat([new_hidden, H_mean])
        loc_out = self.termination(loc_inp)
        m = nn.Sigmoid()
        stop = m(loc_out)
        stop = torch.max(stop).view((1,1))

        # Decoder
        #new_state = self.decoder(torch.cat([new_hidden, z], 1))  # first version
        next_node_energy = self.decoder_nextnode(torch.cat([new_hidden, z], 1))
        # Add a message-passing between the two?
        new_state = self.decoder_update(torch.cat([new_hidden, z, next_node_energy], 1))

        return new_state, new_hidden, stop, next_node_energy
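
    # Shapes, for a graph with N nodes (as used in this implementation):
    # `inputs` is reshaped to (N, 1) and `hidden` is (N, n_hid); step()
    # returns the decoded state (N, in_feats), the new hidden state
    # (N, n_hid), a (1, 1) termination probability and (N, 1) next-node
    # energies.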

    # Iterate steps until completion
    def forward(self, graph, states, edges_mat):

        # Initialize hidden state at zero
        hidden = torch.zeros(states.size(1), self.n_hid).float()
        #print('Shape of hidden state:', hidden.size())

        # Store states and termination prediction
        pred_all = [states[0].view(-1,1).float()]
        pred_stop = [torch.tensor([[0]]).float()]
        pred_nextnode = []

        # set all edge features inside the graph (for easier message passing)
        edges_features = []
        for i in range(graph.edges()[0].size(0)):
            # Extract the features of each existing edge
            edges_features.append(edges_mat[graph.edges()[0][i], graph.edges()[1][i]])

        graph.edata['features'] = torch.FloatTensor(edges_features)

        if self.useCuda:
            graph.edata['features'] = graph.edata['features'].cuda()
            graph.ndata['priority'] = graph.ndata['priority'].cuda()
            hidden = hidden.cuda()
            pred_stop = [torch.tensor([[0]]).float().cuda()]

        # Iterate the algorithm for all steps
        for i in range(states.size(0)-1):
            new_state, hidden, stop, next_node_energy = self.step(graph, pred_all[i], hidden)

            next_node_pred = next_node_energy  # no Softmax here: it is already included in PyTorch's CrossEntropyLoss

            pred_all.append(new_state)
            pred_stop.append(stop)
            pred_nextnode.append(next_node_pred)

        # stack along dim 0 so that preds[t] is the predicted state at step t
        preds = torch.stack(pred_all, dim=0).view(states.size(0), states.size(1))
        preds_stop = torch.cat(pred_stop, dim=0).view(-1)
        preds_nextnode = torch.stack(pred_nextnode, dim=1)


        return preds, preds_stop, preds_nextnode

--------------------------------------------------------------------------------
/mpnn_2.py:
--------------------------------------------------------------------------------

# Defines the MPNN model that will be trained

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

# Setting the seeds for replicability
import random
random.seed(33)
torch.manual_seed(33)


class MPNN(nn.Module):
    # Expects dgl graphs as inputs
    def __init__(self, in_feats, hidden_feats, edge_feats, out_feats, useCuda=False):
        super(MPNN, self).__init__()
        self.n_hid = hidden_feats
        self.encoder = nn.Linear(in_feats + hidden_feats + 1, hidden_feats)
        self.M = nn.Linear(hidden_feats * 2 + edge_feats, 32)
        self.U = nn.Linear(hidden_feats * 2, hidden_feats)
        self.decoder = nn.Linear(hidden_feats * 2, in_feats)
        self.termination = nn.Linear(hidden_feats, 1)
        self.useCuda = useCuda
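
    # This variant differs from mpnn.py: the decoder produces one logit per
    # node, already-visited nodes are masked out with -inf, and a softmax
    # over the nodes directly predicts which node is visited next.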

    def compute_send_messages(self, edges):
        # The argument is a batch of edges.
        # This computes a (batch of) message called 'msg' from the source and
        # destination features 'z' and the edge features.
        z_src = edges.src['z']
        z_dst = edges.dst['z']

        msg = self.M(torch.cat([z_src, z_dst, edges.data['features'].view(-1,1)], 1))
        return {'msg' : msg}

    def max_reduce_messages(self, nodes):
        # The argument is a batch of nodes.
        # This computes the new 'h' features by taking the element-wise max
        # of the received 'msg' in each node's mailbox.
        return {'u_input' : torch.max(nodes.mailbox['msg'], dim=1).values}


    # A step corresponds to 1 iteration of the network:
    # giving the state of the graph after one iteration of the algorithm
    def step(self, graph, inputs, hidden):

        # Encoding x^t and h^{t-1}
        inputs = inputs.view(-1,1).float()
        inp = torch.cat([inputs, hidden, graph.ndata['priority'].view(-1,1)], 1)
        z = self.encoder(inp)
        graph.ndata['z'] = z

        # Processor
        graph.send(graph.edges(), self.compute_send_messages)
        # trigger aggregation at all nodes
        graph.recv(graph.nodes(), self.max_reduce_messages)

        u_input = graph.ndata.pop('u_input')

        new_hidden = self.U(torch.cat([z, u_input], 1))

        # Stopping criterion for the next step
        H_mean = torch.mean(new_hidden, dim=0, keepdim=True)
        loc_out = self.termination(H_mean)
        m = nn.Sigmoid()
        stop = m(loc_out)
        stop = torch.max(stop).view((1,1))

        # Decoder: mask already-visited nodes so that the softmax only
        # distributes probability over candidate next nodes
        new_state = self.decoder(torch.cat([new_hidden, z], 1))
        new_state = new_state.masked_fill(inputs.bool(), float('-inf'))
        softmax = F.softmax(new_state, 0)

        return softmax, new_hidden, stop


    # Iterate steps until completion
    def forward(self, graph, states, edges_mat):
        # Initialize hidden state at zero
        hidden = torch.zeros(states.size(1), self.n_hid).float()

        # Store states and termination prediction
        pred_states = []
        pred_stop = [torch.tensor([[0]]).float()]

        # set all edge features inside the graph (for easier message passing)
        edges_features = []
        for i in range(graph.edges()[0].size(0)):
            # Extract the features of each existing edge
            edges_features.append(edges_mat[graph.edges()[0][i], graph.edges()[1][i]])

        graph.edata['features'] = torch.FloatTensor(edges_features)

        if self.useCuda:
            graph.edata['features'] = graph.edata['features'].cuda()
            graph.ndata['priority'] = graph.ndata['priority'].cuda()
            hidden = hidden.cuda()
            pred_stop = [torch.tensor([[0]]).float().cuda()]

        # Iterate the algorithm for all steps (teacher forcing: each step is
        # fed the ground-truth state)
        for i in range(states.size(0)):
            new_state, hidden, stop = self.step(graph, states[i], hidden)

            pred_states.append(new_state)
            pred_stop.append(stop)

        if len(pred_states) == 1:
            preds = torch.empty(1)
            preds_stop = torch.stack([pred_stop[1]], dim=1)
        else:
            preds = torch.stack(pred_states[:-1], dim=0).view(-1, states.size(1))
            preds_stop = torch.stack(pred_stop[:-1], dim=1)

        return preds, preds_stop
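
    # predict() differs from forward() above: instead of teacher forcing on
    # the ground-truth states, it rolls the model out on its own predictions
    # and stops as soon as the termination head fires (or after N steps).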

    # Iterate steps until completion
    def predict(self, graph, states, edges_mat):
        # Initialize hidden state at zero
        hidden = torch.zeros(states.size(1), self.n_hid).float()

        # Store states and termination prediction
        pred_states = []
        pred_stop = []

        # set all edge features inside the graph (for easier message passing)
        edges_features = []
        for i in range(graph.edges()[0].size(0)):
            # Extract the features of each existing edge
            edges_features.append(edges_mat[graph.edges()[0][i], graph.edges()[1][i]])

        graph.edata['features'] = torch.FloatTensor(edges_features)

        if self.useCuda:
            graph.edata['features'] = graph.edata['features'].cuda()
            graph.ndata['priority'] = graph.ndata['priority'].cuda()
            hidden = hidden.cuda()

        # Iterate the algorithm until the termination flag fires
        # Clone to avoid modifying ground truth data
        new_state = states[0].clone()
        stop = 0
        for _ in range(graph.number_of_nodes()):
            if stop > 0.5:
                break

            softmax, hidden, stop = self.step(graph, new_state, hidden)

            idx = torch.argmax(softmax).item()

            # loc_res is the prediction, new_state the new state
            loc_res = torch.zeros(states.size(1), 1).float()
            if self.useCuda: loc_res = loc_res.cuda()
            new_state[idx] = 1
            loc_res[idx] = 1

            pred_states.append(loc_res)
            pred_stop.append(stop)

        if len(pred_states) == 1:
            # unlike forward(), pred_stop has no dummy initial entry here
            preds = None
            preds_stop = torch.stack([pred_stop[0]], dim=1)
        else:
            preds = torch.stack(pred_states[:], dim=0).view(-1, states.size(1))
            preds_stop = torch.stack(pred_stop[:], dim=1)

        return preds, preds_stop

--------------------------------------------------------------------------------
/papers/graph_attention_networks.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbarcik/graphnets/f1cbf3b6a9fe8beac3d8ee2674deb62fe361b15d/papers/graph_attention_networks.pdf

--------------------------------------------------------------------------------
/papers/neural_execution_engines.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbarcik/graphnets/f1cbf3b6a9fe8beac3d8ee2674deb62fe361b15d/papers/neural_execution_engines.pdf

--------------------------------------------------------------------------------
/papers/neural_execution_of_graphs.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbarcik/graphnets/f1cbf3b6a9fe8beac3d8ee2674deb62fe361b15d/papers/neural_execution_of_graphs.pdf

--------------------------------------------------------------------------------
/papers/neural_message_passing_networks.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbarcik/graphnets/f1cbf3b6a9fe8beac3d8ee2674deb62fe361b15d/papers/neural_message_passing_networks.pdf

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------

certifi==2019.9.11
cffi==1.13.2
cycler==0.10.0
decorator==4.4.1
joblib==0.14.0
kiwisolver==1.1.0
matplotlib==3.1.1
mkl-fft==1.0.15
mkl-random==1.1.0
mkl-service==2.3.0
networkx==2.4
numpy==1.17.3
pandas==0.25.3
pycparser==2.19
pyparsing==2.4.5
python-dateutil==2.8.1
pytz==2019.3
scikit-learn==0.21.3
scipy==1.3.1
six==1.13.0
torch==1.3.1
tornado==6.0.3
dgl==0.4.1

--------------------------------------------------------------------------------
/tasks.txt:
--------------------------------------------------------------------------------

graph generation:

[] ladder graphs
[] 2d grid graphs
[] trees, uniformly randomly generated from the Prüfer sequence
[] erdos-renyi graphs with edge probability min(log V / V, 0.5)
[] barabasi-albert graphs, attaching either four or five edges to every incoming node
[] 4-community graphs - first generating four disjoint erdos-renyi graphs with edge probability 0.7, followed by interconnecting their nodes with edge probability 0.01
[] 4-caveman graphs, having each of their intra-clique edges removed with probability 0.7, followed by inserting 0.025|V| additional shortcut edges between cliques

implementation of graph algorithms:

[] topological sort
[] dfs

dataset generation:

[] create dataset generation given (graph_type, nb_graphs, nb_nodes, algorithm_type)

--------------------------------------------------------------------------------
/training.py:
--------------------------------------------------------------------------------

# Performs a training of the current MPNN model
# TODO: store the generated graphs and data to make sure the experiments are replicable
# TODO: store the results of experiments in a file

import dgl
from generate_dataset import DatasetGenerator
import torch
import torch.nn as nn
from mpnn import MPNN
import time
import networkx as nx
import numpy as np

# Setting the seeds for replicability
import random
random.seed(33)
np.random.seed(33)

use_cuda = torch.cuda.is_available()

#####################################################
# --- Training parameters
#####################################################

# For the training hyperparameters, take inspiration from the paper
nb_epochs = 20
nb_features = 32
lr = 0.005

# Datasets parameters
graph_type = 'erdos_renyi'
nb_graphs = 200
nb_nodes = 20
algorithm_type = 'DFS'

max_steps = nb_nodes + 1  # maximum number of steps before stopping
# The +1 was added because this case happens experimentally; to investigate

####################################################
# --- Data generation
####################################################

start = time.time()
data_gen = DatasetGenerator()
graphs, history_dataset, next_nodes = data_gen.run(graph_type, nb_graphs, nb_nodes,
                                                   algorithm_type)

print('Dataset created in:', time.time()-start)
clock = time.time()

# Prepare the data in an easily exploitable format
# It could probably be optimised with DGL

terminations = []
edges_mats = []

# Do all the necessary transforms on the data
for i in range(nb_graphs):
    states = history_dataset[i]
    nextnode_data = next_nodes[i]
    termination = np.zeros(states.shape[0])
    termination[-1] = 1

    # Adding a self-edge to every node, as in the paper
    for j in range(nb_nodes):
        graphs[i].add_edge(j, j)

    assert max_steps >= states.shape[0]

    # pad states to a fixed length, with the termination boolean, to be able
    # to compute the termination error
    if states.shape[0] < max_steps:
        pad_idx = [(0,0) for i in range(states.ndim)]
        pad_idx[0] = (0, max_steps-states.shape[0])
        states = np.pad(states, pad_idx, 'edge')
        nextnode_data = np.pad(nextnode_data, pad_idx, 'edge')
        termination = np.pad(termination, (0, max_steps-termination.size), 'constant', constant_values=(1))
    history_dataset[i] = states
    # For nn.CrossEntropyLoss, the next node is expected to be an index and
    # not a one-hot vector
    nextnode_data = np.argmax(nextnode_data, axis=1)
    next_nodes[i] = nextnode_data
    terminations.append(termination)
    edges_mats.append(nx.to_numpy_matrix(graphs[i]))
    g = dgl.DGLGraph()
    g.from_networkx(graphs[i], node_attrs=['priority'])
    graphs[i] = g

# Take 10% of the graphs as validation
nb_val = int(0.1*nb_graphs)
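
# Each dataset item bundles everything one training step needs:
# (dgl_graph, adjacency matrix, padded state history of shape
# (max_steps, nb_nodes), per-step termination flags, next-node indices)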
train_data = [(graphs[i], edges_mats[i], history_dataset[i], terminations[i], next_nodes[i]) for i in range(nb_graphs-nb_val)]
test_data = [(graphs[i], edges_mats[i], history_dataset[i], terminations[i], next_nodes[i]) for i in range(nb_graphs-nb_val, nb_graphs)]

# Data loaders do not seem to be compatible with DGL
#train_loader = torch.utils.data.DataLoader(train_data)
#test_loader = torch.utils.data.DataLoader(test_data)

# the dataset is an array of size batch size
# each item is a history of shape (nb_steps, nb_nodes)
# We now have in our possession a full dataset on which to train


###################################################

model = MPNN(1, 32, 1, 1, useCuda=use_cuda)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

if use_cuda:
    print('Using GPU')
    model.cuda()
else:
    print('Using CPU')

model.train()

def nll_gaussian(preds, target):
    neg_log_p = ((preds - target) ** 2)
    return neg_log_p.sum() / (target.size(0) * target.size(1))

verbose = False

teacher_forcing = True

if teacher_forcing:
    print('Using teacher forcing!')

for epoch in range(nb_epochs):
    print('Epoch:', epoch)

    losses = []
    clock = time.time()

    for batch_idx, (graph, edges_mat, states, termination, nextnodes_mat) in enumerate(train_data):

        # convert to float tensors so the dtypes match the model parameters
        states = torch.from_numpy(states).float()
        edges_mat = torch.from_numpy(edges_mat)
        termination = torch.from_numpy(termination).float()
        nextnodes_mat = torch.from_numpy(nextnodes_mat)

        if use_cuda:
            states, edges_mat, termination, nextnodes_mat = states.cuda(), edges_mat.cuda(), termination.cuda(), nextnodes_mat.cuda()

        if verbose:
            print('--- Processing new graph! ---')
            print('edges_mat shape:', edges_mat.shape)
            print('states shape (after reshape):', states.shape)
            print('termination shape (after reshape):', termination.shape)

        if teacher_forcing:

            # one optimisation step per state transition
            for step in range(states.size()[0]-1):
                loc_states = states[step:step+2]
                loc_nextnode = nextnodes_mat[step]
                loc_termination = termination[step:step+2]
                preds, pred_stops, pred_nextnodes = model(graph, loc_states, edges_mat)
                pred_nextnodes = pred_nextnodes.view(-1, pred_nextnodes.size()[0])
                #print('size of nextnode pred in train', pred_nextnodes.size())

                loss = nll_gaussian(preds, loc_states)
                loss += 100 * nn.CrossEntropyLoss()(pred_nextnodes, loc_nextnode.unsqueeze(0))
                loss += ((pred_stops-loc_termination)**2).sum()/max_steps

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses.append(loss.item())


        else:

            # We do the optimizer call only after completely processing the
            # graph
            preds, pred_stops, pred_nextnodes = model(graph, states, edges_mat)

            pred_nextnodes = pred_nextnodes.view(-1, pred_nextnodes.size()[0])

            # Compare the components of the loss for tuning
            loss = nll_gaussian(preds, states)
            print('prediction loss:', nll_gaussian(preds, states))
            loss += 100 * nn.CrossEntropyLoss()(pred_nextnodes, nextnodes_mat)
            print('Next node prediction loss:', nn.CrossEntropyLoss()(pred_nextnodes, nextnodes_mat))
            print('pred_nextnodes', pred_nextnodes)
            print('nextnodes_mat', nextnodes_mat)
            loss += ((pred_stops-termination)**2).sum()/max_steps  # MSE of output and states + termination loss
            print('termination loss:', ((pred_stops-termination)**2).sum()/max_steps)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses.append(loss.item())

    print('Epoch run in:', time.time()-clock)
    clock = time.time()
    print('Loss:', np.mean(np.asarray(losses)))

print('states:', states)
print('pred:', preds)
print('termination:', termination)
print('pred_stops:', pred_stops)

--------------------------------------------------------------------------------
/training_2.py:
--------------------------------------------------------------------------------

# Performs a training of the current MPNN model

import dgl
from generate_dataset_2 import DatasetGenerator
import torch
import torch.nn as nn
from mpnn_2 import MPNN
import time
import networkx as nx
import numpy as np

# Setting the seeds for replicability
import random
random.seed(32)
np.random.seed(32)

use_cuda = torch.cuda.is_available()

#####################################################
# --- Training parameters
#####################################################

# For the training hyperparameters, take inspiration from the paper
nb_epochs = 50
nb_features = 32
lr = 0.0005

# Datasets parameters
algorithm_type = 'DFS'

nb_graphs = {}
nb_nodes = {}
graph_types = {}

# Allows generating graphs of different sizes in each dataset
nb_graphs['train'] = [40]
nb_nodes['train'] = [20]
graph_types['train'] = ['erdos_renyi']

nb_graphs['test'] = [20]
nb_nodes['test'] = [20]
graph_types['test'] = ['erdos_renyi']

# maximum number of steps before stopping
max_steps = max(max(nb_nodes['train']), max(nb_nodes['test'])) + 1

####################################################
# --- Data generation
####################################################

start = time.time()
data_gen = DatasetGenerator()

data = {
    'train': [],
    'test': []
}

# Prepare the data in an easily exploitable format
for phase in ['train', 'test']:
    # Generate the right number of graphs of each size
    for idx, nb_g in enumerate(nb_graphs[phase]):

        nb_n = nb_nodes[phase][idx]
        graph_type = graph_types[phase][idx]

        graphs, next_nodes = data_gen.run(graph_type, nb_g, nb_n, algorithm_type)

        terminations = []
        edges_mats = []
        next_nodes = list(next_nodes)

        # Do all the necessary transforms on the data
        for i in range(nb_g):
            states = np.asarray(next_nodes[i])
            termination = np.zeros(len(states))
            termination[-1] = 1

            # Adding a self-edge to every node, as in the paper
            for j in range(nb_n):
                graphs[i].add_edge(j, j)

            assert max_steps >= len(states)

            terminations.append(termination)
            edges_mats.append(nx.to_numpy_matrix(graphs[i]))
            g = dgl.DGLGraph()
            g.from_networkx(graphs[i], node_attrs=['priority'])
            graphs[i] = g

        data[phase] += [(graphs[i], edges_mats[i], next_nodes[i], terminations[i]) for i in range(nb_g)]

train_data = data['train']
test_data = data['test']

print('Dataset created in:', time.time()-start)
clock = time.time()

###################################################

model = MPNN(1, 32, 1, 1, useCuda=use_cuda)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

if use_cuda:
    print('Using GPU')
    model.cuda()
else:
    print('Using CPU')

model.train()

def nll_gaussian(preds, target):
    neg_log_p = ((preds - target) ** 2)
    return neg_log_p.sum() / (target.size(0) * target.size(1))


def next_state_accuracy(preds, targets):
    # Evaluates the average accuracy in predicting the next node
    next_node_pred = torch.argmax(preds, axis=-1)
    nb_false = torch.nonzero(targets-next_node_pred).size(0)
    return (targets.shape[0]-nb_false) / targets.shape[0]

verbose = False

teacher_forcing = True

if teacher_forcing:
    print('Using teacher forcing!')

for epoch in range(nb_epochs):
    print('Epoch:', epoch)

    train_losses = []
    test_losses = []

    train_accuracies = []
    test_accuracies = []
    test_exact_terminations = 0

    clock = time.time()

    model.train()

    for batch_idx, (graph, edges_mat, states, termination) in enumerate(train_data):

        states = torch.from_numpy(np.asarray(states))
        edges_mat = torch.from_numpy(edges_mat)
        termination = torch.from_numpy(termination)

        if states.shape[0] > 1:
            # if there is more than 1 state, prepare the target of the network
            target = []
            target.extend([np.where(states[i]-states[i-1])[0] for i in range(1, states.shape[0])])
            target = np.hstack(target)
            target = torch.LongTensor(target)
            if use_cuda: target = target.cuda()
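
        # Note: the target built above is, for each step i, the index of the
        # node that flips between states[i-1] and states[i] (the node visited
        # at that step); CrossEntropyLoss consumes class indices, not one-hot
        # vectors.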

        if use_cuda:
            states, edges_mat, termination = states.cuda(), edges_mat.cuda(), termination.cuda()

        if verbose:
            print('--- Processing new graph! ---')
            print('edges_mat shape:', edges_mat.shape)
            print('states shape (after reshape):', states.shape)
            print('termination shape (after reshape):', termination.shape)

        # We do the optimizer call only after completely processing the graph
        preds, pred_stops = model(graph, states, edges_mat)

        # Compare the components of the loss for tuning
        if states.shape[0] > 1:
            loss = nn.CrossEntropyLoss()
            output = loss(preds, target)
        else:
            # Sometimes the algorithm has already terminated when starting,
            # in which case there is nothing to compare
            output = torch.tensor([0]).type(torch.FloatTensor)
            if use_cuda: output = output.cuda()

        loss2 = nn.BCELoss()
        output += loss2(pred_stops.view(-1, 1), termination.float().view(-1, 1))

        optimizer.zero_grad()
        output.backward()
        optimizer.step()

        train_losses.append(output.item())
        if states.shape[0] > 1: train_accuracies.append(next_state_accuracy(preds, target))

    print('Train epoch run in:', time.time()-clock)
    clock = time.time()
    print('Training Loss:', np.mean(np.asarray(train_losses)))
    print('Train average accuracy:', np.mean(np.asarray(train_accuracies)))

    model.eval()

    for batch_idx, (graph, edges_mat, states, termination) in enumerate(test_data):

        states = torch.from_numpy(np.asarray(states))
        edges_mat = torch.from_numpy(edges_mat)
        termination = torch.from_numpy(termination)

        target = None
        if states.shape[0] > 1:
            # if there is more than 1 state, prepare the target of the network
            target = []
            target.extend([np.where(states[i]-states[i-1])[0] for i in range(1, states.shape[0])])
            target = np.hstack(target)
            target = torch.LongTensor(target)
            if use_cuda: target = target.cuda()

        if use_cuda:
            states, edges_mat, termination = states.cuda(), edges_mat.cuda(), termination.cuda()

        if verbose:
            print('--- Processing new graph! ---')
            print('edges_mat shape:', edges_mat.shape)
            print('states shape (after reshape):', states.shape)
            print('termination shape (after reshape):', termination.shape)


        # Evaluate on a full rollout (no optimizer call during evaluation)
        preds, pred_stops = model.predict(graph, states, edges_mat)

        # Compare the components of the loss for tuning

        # In tests, we need to compare sequences of the right lengths
        if preds is not None and states.shape[0] > 1:
            comparable_length = min(preds.size()[0], target.size()[0])
            test_exact_terminations += (preds.size()[0] == target.size()[0])
        else:
            comparable_length = 0
            test_exact_terminations += (states.shape[0] == 1)

        if states.shape[0] > 1 and preds is not None:
            loss = nn.CrossEntropyLoss()
            output = loss(preds[:comparable_length], target[:comparable_length])
        else:
            # Sometimes the algorithm has already terminated when starting,
            # in which case there is nothing to compare
            output = torch.tensor([0]).type(torch.FloatTensor)
            if use_cuda: output = output.cuda()

        loss2 = nn.BCELoss()
        output += loss2(pred_stops.view(-1, 1)[:comparable_length+1], termination.float().view(-1, 1)[:comparable_length+1])

        test_losses.append(output.item())
        if states.shape[0] > 1 and preds is not None: test_accuracies.append(next_state_accuracy(preds[:comparable_length], target[:comparable_length]))


    if epoch