├── .gitignore
├── LICENSE
├── README.md
├── mlp.py
├── neurosat
│   ├── README.md
│   ├── cnf.py
│   ├── generator.py
│   ├── instance_loader.py
│   ├── logutil.py
│   ├── model.py
│   ├── neurosat.py
│   └── sat-test.py
├── tgn.py
└── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 Machine Reasoning and Learning Research Group
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # typed-graph-network
2 | 
3 | Companion code to the Typed Graph Networks paper https://arxiv.org/abs/1901.07984
4 | A model-builder helper for creating graph networks akin to the ones described in https://arxiv.org/abs/1806.01261 and to graph neural networks https://ieeexplore.ieee.org/document/4700287.
5 | 
6 | To run the example model, copy the tgn.py, mlp.py and util.py files into your repository.
7 | 
--------------------------------------------------------------------------------
/mlp.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | class Mlp(object):
4 |     def __init__(
5 |         self,
6 |         layer_sizes,
7 |         output_size = None,
8 |         activations = None,
9 |         output_activation = None,
10 |         use_bias = True,
11 |         kernel_initializer = None,
12 |         bias_initializer = tf.zeros_initializer(),
13 |         kernel_regularizer = None,
14 |         bias_regularizer = None,
15 |         activity_regularizer = None,
16 |         kernel_constraint = None,
17 |         bias_constraint = None,
18 |         trainable = True,
19 |         name = None,
20 |         name_internal_layers = True
21 |     ):
22 |         """Stacks len(layer_sizes) dense layers on top of each other, with an additional layer with output_size neurons, if specified."""
23 |         self.layers = []
24 |         internal_name = None
25 |         # If activations isn't a list, assume it is a single value that will be repeated for every layer
26 |         if not isinstance( activations, list ):
27 |             activations = [ activations for _ in layer_sizes ]
28 |         #end if
29 |         # If there is an activation specifically for the output, add it to the list of layers to be built
30 |         if output_size is not None:
31 |             layer_sizes = layer_sizes + [output_size]
32 |             activations = activations + [output_activation]
33 |         #end if
34 |         for i, params in enumerate( zip( layer_sizes, activations ) ):
35 |             size, activation = params
36 |             if name_internal_layers and name is not None:
37 |                 internal_name = name + "_MLP_layer_{}".format( i + 1 )
38 |             #end if
39 |             new_layer = tf.layers.Dense(
40 |                 size,
41 |                 activation = activation,
42 |                 use_bias = use_bias,
43 |                 kernel_initializer = kernel_initializer,
44 |                 bias_initializer = bias_initializer,
45 |                 kernel_regularizer = kernel_regularizer,
46 |                 bias_regularizer = bias_regularizer,
47 |                 activity_regularizer = activity_regularizer,
48 |                 kernel_constraint = kernel_constraint,
49 |                 bias_constraint = bias_constraint,
50 |                 trainable = trainable,
51 |                 name = internal_name
52 |             )
53 |             self.layers.append( new_layer )
54 |         #end for
55 |     #end __init__
56 | 
57 |     def __call__( self, inputs, *args, **kwargs ):
58 |         outputs = [ inputs ]
59 |         for layer in self.layers:
60 |             outputs.append( layer( outputs[-1] ) )
61 |         #end for
62 |         return outputs[-1]
63 |     #end __call__
64 | #end Mlp
65 | 
--------------------------------------------------------------------------------
/neurosat/README.md:
--------------------------------------------------------------------------------
1 | # NeuroSAT
2 | NeuroSAT-specific code, using the generic graph NN
3 | 
4 | # Dependencies
5 | 
6 | Pycosat 0.6.3
7 | 
--------------------------------------------------------------------------------
/neurosat/cnf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import multiprocessing as mp
3 | import pycosat
4 | import copy
5 | import numpy as np
6 | 
7 | class CNF(object):
8 | 
9 |     def __init__(self,n,m=0):
10 |         self.n = n
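        # n: number of variables; m: number of clauses (zero until the clause
        # list is filled in); sat caches the instance's satisfiability.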
11 |         self.m = m
12 |         self.clauses = []
13 |         self.sat = None
14 |         self.filename = ""
15 |     #end
16 | 
17 |     def SR(n):
18 | 
19 |         cnf = CNF(n)
20 |         sat = True
21 | 
22 |         while sat:
23 |             # Select a random clause size k ~ 1 + Bernoulli(0.3) + Geo(0.4)
24 |             k = 1 + np.random.binomial(1,0.3) + np.random.geometric(0.4)
25 |             # Create a clause with k randomly selected variables
26 |             clause = [ int(np.random.randint(1,n+1) * np.random.choice([-1,+1])) for i in range(k) ]
27 |             # Append clause to cnf
28 |             cnf.clauses.append(clause)
29 |             # Check for satisfiability
30 |             if pycosat.solve(cnf.clauses) == "UNSAT":
31 |                 sat = False
32 |                 # Create an identical copy of cnf
33 |                 cnf2 = copy.deepcopy(cnf)
34 |                 # Flip the polarity of a single literal in the last clause of cnf2
35 |                 cnf2.clauses[-1][np.random.randint(0,len(cnf2.clauses[-1]))] *= -1
36 |             #end
37 |         #end
38 | 
39 |         cnf.sat = False
40 |         cnf2.sat = True
41 | 
42 |         cnf.m = cnf2.m = len(cnf.clauses)
43 | 
44 |         return cnf,cnf2
45 |     #end
46 | 
47 |     def SRU(n0,n1):
48 |         n = np.random.randint(n0,n1+1)
49 |         return CNF.SR(n)
50 |     #end
51 | 
52 |     def random_3SAT_critical(n):
53 |         m = int(4.26 * n)
54 | 
55 |         cnf = CNF(n,m)
56 | 
57 |         for i in range(m):
58 |             clause = [ int(np.random.randint(1,n+1) * np.random.choice([-1,+1])) for k in range(3) ]
59 |             cnf.clauses.append( clause )
60 |         #end
61 | 
62 |         cnf.sat = pycosat.solve(cnf.clauses) != "UNSAT"
63 | 
64 |         return cnf
65 |     #end
66 | 
67 |     def write_dimacs(self,path):
68 |         with open(path,"w") as out:
69 |             out.write("p cnf {} {} {}\n".format(self.n, self.m, int(self.sat) ))
70 | 
71 |             for clause in self.clauses:
72 |                 out.write( ' '.join([ str(x) for x in clause]) + ' 0\n')
73 |             #end
74 |         #end
75 |     #end
76 | 
77 |     def read_dimacs(path):
78 |         with open(path,"r") as f:
79 |             n, m, sat = [ int(x) for x in f.readline().split()[2:]]
80 |             cnf = CNF(n,m)
81 |             cnf.sat = bool(sat)
82 |             for i in range(m):
83 |                 cnf.clauses.append( [ int(x) for x in f.readline().split()[:-1]] )
84 |             #end
85 |             cnf.filename = path
86 |         #end
87 |         return cnf
88 |     #end
89 | #end
90 | 
91 | class BatchCNF(object):
92 | 
93 |     def __init__(self,n,m,clauses,sat,filenames=None):
94 |         """
95 |         batch_size: number of instances in this batch
96 |         n: number of variables for each instance
97 |         total_n: total number of variables among all instances
98 |         m: number of clauses for each instance
99 |         total_m: total number of clauses among all instances
100 |         clauses: concatenated list of clauses among all instances
101 |         sat: satisfiability of each instance
102 |         """
103 |         self.batch_size = len(n)
104 |         self.n = n
105 |         self.total_n = sum(n)
106 |         self.m = m
107 |         self.total_m = sum(m)
108 |         self.clauses = clauses
109 |         self.sat = sat
110 |         self.filenames = [] if filenames is None else filenames
111 |     #end
112 | 
113 |     def get_dense_matrix(self):
114 |         M = np.zeros( (2*self.total_n, self.total_m), dtype=np.float32 )
115 |         n_cells = sum([ len(clause) for clause in self.clauses ])
116 |         cell = 0
117 |         for (j,clause) in enumerate(self.clauses):
118 |             for literal in clause:
119 |                 i = int(abs(literal) - 1)
120 |                 p = np.sign(literal)
121 |                 if p == +1:
122 |                     M[i,j] = 1
123 |                 elif p == -1:
124 |                     M[self.total_n + i, j] = 1
125 |                 #end
126 |                 cell += 1
127 |             #end for literal
128 |         #end for j,clause
129 |         return M
130 |     #end get_dense_matrix
131 | 
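    # Note: the (indices, values, shape) triple returned by get_sparse_matrix
    # below matches util.sparse_to_dense and the arguments of tf.SparseTensor,
    # although logutil.run_and_log_batch feeds the model via get_dense_matrix.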
132 |     def get_sparse_matrix(self):
133 |         """
134 |         First we need to count the number of non-null cells in our
135 |         adjacency matrix. This can be computed as the sum of all clause
136 |         sizes Σ|c| ∀c ∈ F
137 |         """
138 |         n_cells = sum([ len(clause) for clause in self.clauses ])
139 | 
140 |         # Define sparse_M with shape (n_cells,2)
141 |         sparse_M = np.zeros((n_cells,2), dtype=np.int64)
142 | 
143 |         cell = 0
144 |         for (j,clause) in enumerate(self.clauses):
145 |             for literal in clause:
146 |                 i = int(abs(literal) - 1)
147 |                 p = np.sign(literal)
148 | 
149 |                 if p == +1:
150 |                     sparse_M[cell] = [i,j]
151 |                 elif p == -1:
152 |                     sparse_M[cell] = [self.total_n + i, j]
153 |                 #end
154 | 
155 |                 cell += 1
156 |             #end
157 |         #end
158 |         return sparse_M, np.ones( n_cells, dtype = np.float32 ), (2*self.total_n, self.total_m)
159 |     #end
160 | 
161 | #end
162 | 
163 | def create_batchCNF(instances):
164 |     """
165 |     Create a BatchCNF object from a list of cnf instances
166 |     """
167 |     n = []
168 |     m = []
169 |     clauses = []
170 |     sat = []
171 |     filenames = []
172 |     offset = 0
173 |     for cnf in instances:
174 |         n.append(cnf.n)
175 |         m.append(cnf.m)
176 |         clauses.extend( [ [ np.sign(literal) * (abs(literal) + offset) for literal in clause ] for clause in cnf.clauses ] )
177 |         sat.append(cnf.sat)
178 |         filenames.append(cnf.filename)
179 |         offset += cnf.n
180 |     #end
181 | 
182 |     return BatchCNF(n,m,clauses,sat,filenames)
183 | #end
184 | 
185 | def create_dataset( n_min = 10, n_max = 40, samples = 1000, path = "instances", MP=True ):
186 |     if not MP:
187 |         for i in range(samples):
188 | 
189 |             cnf1, cnf2 = CNF.SRU( n_min, n_max )
190 | 
191 |             cnf1.write_dimacs("{}/unsat/{:09d}-{:09d}-{}.cnf".format(path,cnf1.n,cnf1.m,i))
192 |             cnf2.write_dimacs("{}/sat/{:09d}-{:09d}-{}.cnf".format(path,cnf2.n,cnf2.m,i))
193 |         #end for
194 |     else:
195 |         ns = [ np.random.randint(n_min,n_max+1) for i in range(samples) ]
196 |         ns.sort()
197 |         print( "Starting multiprocessing with {} cores".format(mp.cpu_count()) )
198 |         with mp.Pool(mp.cpu_count()) as p:
199 |             for i, (cnf1, cnf2) in enumerate( p.imap( CNF.SR, ns ) ):
200 |                 print( "{pct:0.2f}% -- {instance}".format( pct = 100.0 * i / samples, instance = i ) )
201 |                 cnf1.write_dimacs("{}/unsat/{:09d}-{:09d}-{}.cnf".format(path,cnf1.n,cnf1.m,i))
202 |                 cnf2.write_dimacs("{}/sat/{:09d}-{:09d}-{}.cnf".format(path,cnf2.n,cnf2.m,i))
203 |             #end for
204 |         #end with
205 |     #end if
206 | #end
207 | 
208 | def create_critical_dataset( n = 40, samples = 512, path = "critical_instances", MP=True ):
209 |     if not MP:
210 |         for i in range( samples ):
211 |             cnf = CNF.random_3SAT_critical( n )
212 |             cnf.write_dimacs( "{}/{:09d}-{}.cnf".format( path, cnf.n, i ) )
213 |         #end for
214 |     else:
215 |         ns = [ n for i in range(samples) ]
216 |         with mp.Pool(mp.cpu_count()) as p:
217 |             for i, cnf in enumerate( p.imap( CNF.random_3SAT_critical, ns ) ):
218 |                 cnf.write_dimacs( "{}/{:09d}-{}.cnf".format( path, cnf.n, i ) )
219 |             #end for
220 |         #end with
221 |     #end if
222 | #end create_critical_dataset
223 | 
224 | def ensure_datasets( make_critical = True ):
225 |     idirs = [ "instances", "instances/sat", "instances/unsat" ]
226 |     if not all( map( os.path.isdir, idirs ) ):
227 |         for d in idirs:
228 |             os.makedirs( d )
229 |         #end for
230 |         create_dataset( 10, 40, 2**15, path = idirs[0] )
231 |     #end if
232 |     tdirs = [ "test-instances", "test-instances/sat", "test-instances/unsat" ]
233 |     if not all( map( os.path.isdir, tdirs ) ):
234 |         for d in tdirs:
235 |             os.makedirs( d )
236 |         #end for
237 |         create_dataset( 40, 40, 512, path = tdirs[0])
238 |     #end if
239 |     if make_critical:
240 |         c40dir = "critical-instances-40"
241 |         if not os.path.isdir( c40dir ):
242 |             os.makedirs( c40dir )
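            # Generate 512 critical (phase-transition) 3-SAT instances with n = 40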
243 |             create_critical_dataset( 40, 512, c40dir )
244 |         #end if
245 |         c80dir = "critical-instances-80"
246 |         if not os.path.isdir( c80dir ):
247 |             os.makedirs( c80dir )
248 |             create_critical_dataset( 80, 512, c80dir )
249 |         #end if
250 |     #end if
251 | #end ensure_datasets
252 | 
253 | if __name__ == '__main__':
254 |     ensure_datasets()
--------------------------------------------------------------------------------
/neurosat/generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import copy
3 | import itertools
4 | import pycosat
5 | import time
6 | 
7 | def SR(n):
8 |     # Add clauses while satisfiable
9 |     F1 = []
10 |     # Current solutions
11 |     buffer_size = 1
12 |     solutions = list(itertools.islice(pycosat.itersolve(F1,n),buffer_size))
13 |     while True:
14 |         """
15 |         Randomly choose 'k', the number of literals in this clause.
16 |         k has mean ~5 and is described by the distribution:
17 |         2 + Bernoulli(.3) + Geo(.4)
18 |         """
19 |         k = 2 + np.random.binomial(1,.3) + np.random.geometric(.4)
20 |         """
21 |         Now we sample k variables uniformly at random
22 |         and negate each one with equal probability to
23 |         create a clause
24 |         """
25 |         C = [ int(np.random.choice([-1,+1]) * np.random.randint(1,n+1)) for i in range(k) ]
26 |         # Add C to the formula
27 |         F1.append(C)
28 | 
29 |         # Check whether F1 is still satisfiable
30 | 
31 |         # Filter the solution buffer, keeping only solutions that satisfy C
32 |         solutions = [ X for X in solutions if any([ np.sign(l) * np.sign(X[int(abs(l))-1]) == 1 for l in C ]) ]
33 | 
34 |         # If no buffered solution survives, run PycoSAT to look for new solutions
35 |         if len(solutions) == 0:
36 |             solutions = list(itertools.islice(pycosat.itersolve(F1,n),buffer_size))
37 |             # If PycoSAT couldn't find any new solutions, the algorithm is done
38 |             if solutions == []:
39 |                 break
40 |             #end if
41 |         elif len(solutions) < buffer_size:
42 |             solutions += list(itertools.islice(pycosat.itersolve(F1,n),buffer_size-len(solutions)))
43 |         #end if
44 | 
45 |     #end while
46 | 
47 |     """
48 |     By now, F1 is unsatisfiable. But because it was
49 |     satisfiable up until the penultimate clause, we can
50 |     create a satisfiable variant F2 by flipping the polarity
51 |     of a single literal in the last clause
52 |     """
53 |     F2 = copy.deepcopy(F1)
54 |     F2[-1][ np.random.randint(0,len(F2[-1])) ] *= -1
55 | 
56 |     return F1, F2
57 | #end def
58 | 
59 | def to_matrix(n,m,F):
60 |     """
61 |     Converts a SAT instance from list format:
62 |     i.e. a formula is a list of clauses, each of
63 |     which is a list of literals, each of which is
64 |     a signed integer whose sign gives a polarity p ϵ {-1,+1}
65 |     and whose magnitude gives a variable index i ϵ {0,…,n-1} (as i+1)
66 |     into adjacency matrix format:
67 |     i.e.
a formula is a binary matrix M ϵ {0,1}²ⁿˣᵐ 68 | where 69 | M(i,j) = 1 iff xi ϵ Cj else 0, 70 | M(n+i,j) = 1 iff ¬xi ϵ Cj else 0 71 | """ 72 | M = np.zeros((2*n,m)) 73 | for (j,C) in enumerate(F): 74 | for (p,i) in [ (np.sign(l), int(abs(l))-1) for l in C ]: 75 | if p == +1: 76 | M[i,j] = 1 77 | else: 78 | M[n+i,j] = 1 79 | #end if 80 | #end for 81 | #end for 82 | return M 83 | #end def 84 | 85 | def generate(n, m, batch_size = 32): 86 | while True: 87 | 88 | """ 89 | First we create a list of (batch_size//2) pairs of SAT formulas, 90 | filtering those which exceed m clauses 91 | """ 92 | unsat_formulas = [] 93 | sat_formulas = [] 94 | while len(unsat_formulas) < batch_size//2: 95 | unsat, sat = SR(n) 96 | if len(unsat) > m: 97 | continue 98 | else: 99 | unsat_formulas.append(unsat) 100 | sat_formulas.append(sat) 101 | #end if 102 | #end while 103 | 104 | """ 105 | Features are adjacency matrices with the following structure: 106 | M ϵ {0,1}²ⁿˣᵐ, 107 | M(i,j) = 1 iff xi ϵ Cj else 0, 108 | M(n+i,j) = 1 iff ¬xi ϵ Cj else 0 109 | where n is the number of variables and m the number of clauses 110 | 111 | Labels are binary scalars (+1 for satisfiable instances and -1 otherwise) 112 | """ 113 | features = np.zeros((batch_size, 2*n, m)) 114 | labels = np.zeros((batch_size,)) 115 | 116 | # We populate the batch in pairs, so we need just batch_size//2 iterations 117 | for i in range(batch_size//2): 118 | M1, M2 = to_matrix(n,m,unsat_formulas[i]), to_matrix(n,m,sat_formulas[i]) 119 | 120 | features[2*i, :] = M1 121 | features[2*i+1, :] = M2 122 | 123 | labels[2*i] = -1 124 | labels[2*i+1] = +1 125 | #end for 126 | yield features, labels 127 | #end while 128 | #end def 129 | 130 | if __name__ == '__main__': 131 | 132 | n = 5 133 | m = 50 134 | 135 | generator = generate(n,m,batch_size=32) 136 | 137 | last_time = time.time() 138 | for batch in generator: 139 | print("Created batch in {} seconds".format(time.time()-last_time)) 140 | last_time = time.time() 141 | #end for 142 | 143 | #end if 144 | -------------------------------------------------------------------------------- /neurosat/instance_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from cnf import CNF, BatchCNF, create_batchCNF 4 | from functools import reduce 5 | 6 | 7 | class InstanceLoaderSequential(object): 8 | 9 | def __init__(self,path): 10 | assert os.path.isdir( path ), "Path is not a directory. Path {}".format( path ) 11 | if path[-1] == "/": 12 | path = path[0:-1] 13 | #end if 14 | 15 | sat_folder = path + '/sat/' 16 | unsat_folder = path + '/unsat/' 17 | 18 | self.filenames = reduce(lambda x,y: x + y, sorted([ (sat.path,unsat.path) for (sat,unsat) in zip(os.scandir(sat_folder),os.scandir(unsat_folder)) ]),[]) 19 | 20 | self.reset() 21 | #end 22 | 23 | def get_instances(self, n_instances): 24 | for i in range(n_instances): 25 | yield CNF.read_dimacs(self.filenames[self.index]) 26 | self.index += 1 27 | #end 28 | #end 29 | 30 | def get_batches(self, batch_size): 31 | for i in range( len(self.filenames) // batch_size ): 32 | yield create_batchCNF(self.get_instances(batch_size)) 33 | #end 34 | #end 35 | 36 | def reset(self): 37 | self.index = 0 38 | #end 39 | #end 40 | 41 | 42 | class InstanceLoaderRandomPaired(object): 43 | 44 | def __init__(self,path): 45 | assert os.path.isdir( path ), "Path is not a directory. 
Path {}".format( path ) 46 | if path[-1] == "/": 47 | path = path[0:-1] 48 | #end if 49 | 50 | sat_folder = path + '/sat/' 51 | unsat_folder = path + '/unsat/' 52 | 53 | self.filenames = [ (sat.path,unsat.path) for (sat,unsat) in zip(os.scandir(sat_folder),os.scandir(unsat_folder)) ] 54 | print( self.filenames ) 55 | 56 | self.reset() 57 | #end 58 | 59 | def get_instances(self, n_instances): 60 | i = 0 61 | while i < n_instances: 62 | if i%2 == 0: 63 | yield CNF.read_dimacs(self.filenames[self.index][0]) 64 | else: 65 | yield CNF.read_dimacs(self.filenames[self.index][1]) 66 | self.index += 1 67 | #end if-else 68 | i += 1 69 | #end 70 | #end 71 | 72 | def get_batches(self, batch_size): 73 | for i in range( len(self.filenames) // batch_size ): 74 | yield create_batchCNF(self.get_instances(batch_size)) 75 | #end 76 | #end 77 | 78 | def reset(self): 79 | random.shuffle( self.filenames ) 80 | self.index = 0 81 | #end 82 | #end InstanceLoaderRandom 83 | 84 | InstanceLoader = InstanceLoaderRandomPaired 85 | 86 | if __name__ == '__main__': 87 | 88 | instance_loader = InstanceLoaderRandomPaired("test-instances") 89 | 90 | #end 91 | -------------------------------------------------------------------------------- /neurosat/logutil.py: -------------------------------------------------------------------------------- 1 | import sys, os, time 2 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 3 | import instance_loader 4 | import numpy as np 5 | from util import timestamp, memory_usage 6 | 7 | def sigmoid( x, derivative = False ): 8 | return x*(1-x) if derivative else 1/(1+np.exp(-x)) 9 | #end sigmoid 10 | 11 | def run_and_log_batch( sess, solver, epoch, b, batch, time_steps, train = True ): 12 | sat = list( 1 if sat else 0 for sat in batch.sat ) 13 | # Build feed_dict 14 | feed_dict = { 15 | solver["time_steps"]: time_steps, 16 | solver["M"]: batch.get_dense_matrix(), 17 | solver["instance_SAT"]: np.array( sat ), 18 | solver["num_vars_on_instance"]: batch.n 19 | } 20 | # Run session 21 | if train: 22 | _, pred_SAT, loss_val, accuracy_val = sess.run( 23 | [ solver["train_step"], solver["predicted_SAT"], solver["loss"], solver["accuracy"] ], 24 | feed_dict = feed_dict 25 | ) 26 | else: 27 | pred_SAT, loss_val, accuracy_val = sess.run( 28 | [ solver["predicted_SAT"], solver["loss"], solver["accuracy"] ], 29 | feed_dict = feed_dict 30 | ) 31 | #end if 32 | avg_pred = np.mean( np.round( sigmoid( pred_SAT ) ) ) 33 | # Print train step loss and accuracy, as well as predicted sat values compared with the normal ones 34 | print( 35 | "{timestamp}\t{memory}\tEpoch {epoch} Batch {batch} (n,m) ({n},{m}) Loss: {loss:.4f} Accuracy: {accuracy:.4f} Average Prediction: {avg_pred:.4f}".format( 36 | timestamp = timestamp(), 37 | memory = memory_usage(), 38 | epoch = epoch, 39 | batch = b, 40 | loss = loss_val, 41 | accuracy = accuracy_val, 42 | avg_pred = avg_pred, 43 | n = batch.total_n, 44 | m = batch.total_m 45 | ), 46 | flush = True 47 | ) 48 | return loss_val, accuracy_val, avg_pred 49 | #end run_and_log_batch 50 | 51 | def test_with( sess, solver, path, name, time_steps = 26, batch_size = 1 ): 52 | # Load test instances 53 | print( "{timestamp}\t{memory}\tLoading test {name} instances ...".format( timestamp = timestamp(), memory = memory_usage(), name = name ) ) 54 | test_generator = instance_loader.InstanceLoader( path ) 55 | test_loss = 0.0 56 | test_accuracy = 0.0 57 | test_avg_pred = 0.0 58 | test_batches = 0 59 | # Run with the test instances 60 | print( "{timestamp}\t{memory}\t{name} TEST SET 
BEGIN".format( timestamp = timestamp(), memory = memory_usage(), name = name ) ) 61 | for b, batch in enumerate( test_generator.get_batches( batch_size ) ): 62 | l, a, p = run_and_log_batch( sess, solver, name, b, batch, time_steps, train = False ) 63 | test_loss += l 64 | test_accuracy += a 65 | test_avg_pred += p 66 | test_batches += 1 67 | #end for 68 | # Summarize results and print test summary 69 | test_loss /= test_batches 70 | test_accuracy /= test_batches 71 | test_avg_pred /= test_batches 72 | print( "{timestamp}\t{memory}\t{name} TEST SET END Mean loss: {loss:.4f} Mean Accuracy = {accuracy} Mean prediction {avg_pred:.4f}".format( 73 | loss = test_loss, 74 | accuracy = test_accuracy, 75 | avg_pred = test_avg_pred, 76 | timestamp = timestamp(), 77 | memory = memory_usage(), 78 | name = name 79 | ) 80 | ) 81 | #end test_with 82 | -------------------------------------------------------------------------------- /neurosat/model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | # Import model builder 4 | from tgn import TGN 5 | from mlp import Mlp 6 | from cnf import CNF 7 | 8 | def build_neurosat(d): 9 | 10 | # Hyperparameters 11 | learning_rate = 2e-5 12 | parameter_l2norm_scaling = 1e-10 13 | global_norm_gradient_clipping_ratio = 0.65 14 | 15 | # Define placeholder for satisfiability statuses (one per problem) 16 | instance_SAT = tf.placeholder( tf.float32, [ None ], name = "instance_SAT" ) 17 | time_steps = tf.placeholder(tf.int32, shape=(), name='time_steps') 18 | matrix_placeholder = tf.placeholder( tf.float32, [ None, None ], name = "M" ) 19 | num_vars_on_instance = tf.placeholder( tf.int32, [ None ], name = "instance_n" ) 20 | 21 | # Literals 22 | s = tf.shape( matrix_placeholder ) 23 | l = s[0] 24 | m = s[1] 25 | n = tf.floordiv( l, tf.constant( 2 ) ) 26 | # Compute number of problems 27 | p = tf.shape( instance_SAT )[0] 28 | 29 | # Define INV, a tf function to exchange positive and negative literal embeddings 30 | def INV(Lh): 31 | l = tf.shape(Lh)[0] 32 | n = tf.div(l,tf.constant(2)) 33 | # Send messages from negated literals to positive ones, and vice-versa 34 | Lh_pos = tf.gather( Lh, tf.range( tf.constant( 0 ), n ) ) 35 | Lh_neg = tf.gather( Lh, tf.range( n, l ) ) 36 | Lh_inverted = tf.concat( [ Lh_neg, Lh_pos ], axis = 0 ) 37 | return Lh_inverted 38 | #end 39 | 40 | var = { "L": d, "C": d } 41 | s = tf.shape( matrix_placeholder ) 42 | num_vars = { "L": l, "C": m } 43 | initial_embeddings = { v:tf.get_variable(initializer=tf.random_normal((1,d)), dtype=tf.float32, name='{v}_init'.format(v=v)) for (v,d) in var.items() } 44 | tiled_and_normalized_initial_embeddings = { 45 | v: tf.tile( 46 | tf.div( 47 | init, 48 | tf.sqrt( tf.cast( var[v], tf.float32 ) ) 49 | ), 50 | [ num_vars[v], 1 ] 51 | ) for v, init in initial_embeddings.items() 52 | } 53 | 54 | # Define Typed Graph Network 55 | gnn = TGN( 56 | var, 57 | { 58 | "M": ("L","C") 59 | }, 60 | { 61 | "Lmsg": ("L","C"), 62 | "Cmsg": ("C","L") 63 | }, 64 | { 65 | "L": [ 66 | { 67 | "fun": INV, 68 | "var": "L" 69 | }, 70 | { 71 | "mat": "M", 72 | "msg": "Cmsg", 73 | "var": "C" 74 | } 75 | ], 76 | "C": [ 77 | { 78 | "mat": "M", 79 | "transpose?": True, 80 | "msg": "Lmsg", 81 | "var": "L" 82 | } 83 | ] 84 | }, 85 | name="NeuroSAT" 86 | ) 87 | 88 | # Define L_vote 89 | L_vote_MLP = Mlp( 90 | layer_sizes = [ d for _ in range(3) ], 91 | activations = [ tf.nn.relu for _ in range(3) ], 92 | output_size = 1, 93 | name = "L_vote", 94 | 
name_internal_layers = True, 95 | kernel_initializer = tf.contrib.layers.xavier_initializer(), 96 | bias_initializer = tf.zeros_initializer() 97 | ) 98 | 99 | # Get the last embeddings 100 | L_n = gnn( 101 | { "M": matrix_placeholder }, 102 | tiled_and_normalized_initial_embeddings, 103 | time_steps 104 | )["L"].h 105 | L_vote = L_vote_MLP( L_n ) 106 | 107 | # Reorganize votes' result to obtain a prediction for each problem instance 108 | def _vote_while_cond(i, p, n_acc, n, n_var_list, predicted_sat, L_vote): 109 | return tf.less( i, p ) 110 | #end _vote_while_cond 111 | 112 | def _vote_while_body(i, p, n_acc, n, n_var_list, predicted_SAT, L_vote): 113 | # Helper for the amount of variables in this problem 114 | i_n = n_var_list[i] 115 | # Gather the positive and negative literals for that problem 116 | pos_lits = tf.gather( L_vote, tf.range( n_acc, tf.add( n_acc, i_n ) ) ) 117 | neg_lits = tf.gather( L_vote, tf.range( tf.add( n, n_acc ), tf.add( n, tf.add( n_acc, i_n ) ) ) ) 118 | # Concatenate positive and negative literals and average their vote values 119 | problem_predicted_SAT = tf.reduce_mean( tf.concat( [pos_lits, neg_lits], axis = 1 ) ) 120 | # Update TensorArray 121 | predicted_SAT = predicted_SAT.write( i, problem_predicted_SAT ) 122 | return tf.add( i, tf.constant( 1 ) ), p, tf.add( n_acc, i_n ), n, n_var_list, predicted_SAT, L_vote 123 | #end _vote_while_body 124 | 125 | predicted_SAT = tf.TensorArray( size = p, dtype = tf.float32 ) 126 | _, _, _, _, _, predicted_SAT, _ = tf.while_loop( 127 | _vote_while_cond, 128 | _vote_while_body, 129 | [ tf.constant( 0, dtype = tf.int32 ), p, tf.constant( 0, dtype = tf.int32 ), n, num_vars_on_instance, predicted_SAT, L_vote ] 130 | ) 131 | predicted_SAT = predicted_SAT.stack() 132 | 133 | # Define loss, accuracy 134 | predict_costs = tf.nn.sigmoid_cross_entropy_with_logits( labels = instance_SAT, logits = predicted_SAT ) 135 | predict_cost = tf.reduce_mean( predict_costs ) 136 | vars_cost = tf.zeros([]) 137 | tvars = tf.trainable_variables() 138 | for var in tvars: 139 | vars_cost = tf.add( vars_cost, tf.nn.l2_loss( var ) ) 140 | #end for 141 | loss = tf.add( predict_cost, tf.multiply( vars_cost, parameter_l2norm_scaling ) ) 142 | optimizer = tf.train.AdamOptimizer( name = "Adam", learning_rate = learning_rate ) 143 | grads, _ = tf.clip_by_global_norm( tf.gradients( loss, tvars ), global_norm_gradient_clipping_ratio ) 144 | train_step = optimizer.apply_gradients( zip( grads, tvars ) ) 145 | 146 | accuracy = tf.reduce_mean( 147 | tf.cast( 148 | tf.equal( 149 | tf.cast( instance_SAT, tf.bool ), 150 | tf.cast( tf.round( tf.nn.sigmoid( predicted_SAT ) ), tf.bool ) 151 | ) 152 | , tf.float32 153 | ) 154 | ) 155 | 156 | # Define neurosat dictionary 157 | neurosat = {} 158 | neurosat["M"] = matrix_placeholder 159 | neurosat["time_steps"] = time_steps 160 | neurosat["gnn"] = gnn 161 | neurosat["instance_SAT"] = instance_SAT 162 | neurosat["predicted_SAT"] = predicted_SAT 163 | neurosat["num_vars_on_instance"] = num_vars_on_instance 164 | neurosat["loss"] = loss 165 | neurosat["accuracy"] = accuracy 166 | neurosat["train_step"] = train_step 167 | 168 | return neurosat 169 | #end build_neurosat 170 | -------------------------------------------------------------------------------- /neurosat/neurosat.py: -------------------------------------------------------------------------------- 1 | import sys, os, time 2 | os.environ['TF_CPP_MIN_LOG_LEVEL']='2' 3 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 4 | import tensorflow as tf 5 | # Import 
model builder
6 | from model import build_neurosat
7 | # Import tools
8 | from cnf import ensure_datasets
9 | import instance_loader
10 | import itertools
11 | from util import timestamp, memory_usage
12 | from logutil import run_and_log_batch
13 | 
14 | if __name__ == '__main__':
15 |     print( "{timestamp}\t{memory}\tMaking sure the datasets exist ...".format( timestamp = timestamp(), memory = memory_usage() ) )
16 |     ensure_datasets()
17 |     if not os.path.isdir( "tmp" ):
18 |         os.makedirs( "tmp" )
19 |     #end if
20 |     epochs = 2**10
21 |     d = 128
22 | 
23 |     time_steps = 26
24 |     batch_size = 128
25 |     batches_per_epoch = 128
26 | 
27 |     early_stopping_window = [ 0 for _ in range(3) ]
28 |     early_stopping_threshold = 0.85
29 | 
30 |     # Build model
31 |     print( "{timestamp}\t{memory}\tBuilding model ...".format( timestamp = timestamp(), memory = memory_usage() ) )
32 |     solver = build_neurosat( d )
33 | 
34 |     # Create batch loader
35 |     print( "{timestamp}\t{memory}\tLoading instances ...".format( timestamp = timestamp(), memory = memory_usage() ) )
36 |     generator = instance_loader.InstanceLoader( "./instances" )
37 |     # If you want to use the entire dataset on each epoch, use:
38 |     # batches_per_epoch = len(generator.filenames) // batch_size
39 | 
40 |     test_generator = instance_loader.InstanceLoader( "./test-instances" )
41 | 
42 |     # Create model saver
43 |     saver = tf.train.Saver()
44 | 
45 |     # GPU configuration (uncomment device_count to disallow GPU use)
46 |     config = tf.ConfigProto(
47 |         #device_count = {"GPU":0},
48 |         gpu_options = tf.GPUOptions( allow_growth = True ),
49 |     )
50 |     with tf.Session(config=config) as sess:
51 | 
52 |         # Initialize global variables
53 |         print( "{timestamp}\t{memory}\tInitializing global variables ... ".format( timestamp = timestamp(), memory = memory_usage() ) )
54 |         sess.run( tf.global_variables_initializer() )
55 | 
56 |         if os.path.exists( "./tmp/neurosat.ckpt.index" ):
57 |             # Restore saved weights
58 |             print( "{timestamp}\t{memory}\tRestoring saved model ... ".format( timestamp = timestamp(), memory = memory_usage() ) )
".format( timestamp = timestamp(), memory = memory_usage() ) ) 59 | saver.restore(sess, "./tmp/neurosat.ckpt") 60 | #end if 61 | 62 | # Run for a number of epochs 63 | print( "{timestamp}\t{memory}\tRunning for {} epochs".format( epochs, timestamp = timestamp(), memory = memory_usage() ) ) 64 | for epoch in range( epochs ): 65 | 66 | # Save current weights 67 | save_path = saver.save(sess, "./tmp/neurosat.ckpt") 68 | print( "{timestamp}\t{memory}\tMODEL SAVED IN PATH: {save_path}".format( timestamp = timestamp(), memory = memory_usage(), save_path=save_path ) ) 69 | 70 | if all( [ early_stopping_threshold < v for v in early_stopping_window ] ): 71 | print( "{timestamp}\t{memory}\tEARLY STOPPING because the test accuracy on the last {epochs} epochs were above {threshold:.2f}% accuracy.".format( timestamp = timestamp(), memory = memory_usage(), epochs = len( early_stopping_window ), threshold = early_stopping_threshold * 100 ) ) 72 | break 73 | #end if 74 | 75 | # Reset training generator and run with a sample of the training instances 76 | print( "{timestamp}\t{memory}\tTRAINING SET BEGIN".format( timestamp = timestamp(), memory = memory_usage() ) ) 77 | generator.reset() 78 | epoch_loss = 0.0 79 | epoch_accuracy = 0.0 80 | for b, batch in itertools.islice( enumerate( generator.get_batches( batch_size ) ), batches_per_epoch ): 81 | l, a, p = run_and_log_batch( sess, solver, epoch, b, batch, time_steps ) 82 | epoch_loss += l 83 | epoch_accuracy += a 84 | #end for 85 | epoch_loss /= batches_per_epoch 86 | epoch_accuracy /= batches_per_epoch 87 | print( "{timestamp}\t{memory}\tTRAINING SET END Mean loss: {loss:.4f} Mean Accuracy = {accuracy:.4f}".format( 88 | loss = epoch_loss, 89 | accuracy = epoch_accuracy, 90 | timestamp = timestamp(), 91 | memory = memory_usage() 92 | ) 93 | ) 94 | # Summarize results and print epoch summary 95 | test_loss = 0.0 96 | test_accuracy = 0.0 97 | test_batches = 0 98 | # Reset test generator and run with the test instances 99 | print( "{timestamp}\t{memory}\tTEST SET BEGIN".format( timestamp = timestamp(), memory = memory_usage() ) ) 100 | test_generator.reset() 101 | for b, batch in enumerate( test_generator.get_batches( batch_size ) ): 102 | l, a, p = run_and_log_batch( sess, solver, epoch, b, batch, time_steps, train = False ) 103 | test_loss += l 104 | test_accuracy += a 105 | test_batches += 1 106 | #end for 107 | # Summarize results and print test summary 108 | test_loss /= test_batches 109 | test_accuracy /= test_batches 110 | print( "{timestamp}\t{memory}\tTEST SET END Mean loss: {loss:.4f} Mean Accuracy = {accuracy:.4f}".format( 111 | loss = test_loss, 112 | accuracy = test_accuracy, 113 | timestamp = timestamp(), 114 | memory = memory_usage() 115 | ) 116 | ) 117 | 118 | early_stopping_window = early_stopping_window[1:] + [ test_accuracy ] 119 | #end for 120 | #end Session 121 | -------------------------------------------------------------------------------- /neurosat/sat-test.py: -------------------------------------------------------------------------------- 1 | import sys, os, time 2 | os.environ['TF_CPP_MIN_LOG_LEVEL']='2' 3 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 4 | import tensorflow as tf 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from model import build_neurosat 8 | import instance_loader 9 | import itertools 10 | from logutil import test_with 11 | from util import timestamp, memory_usage 12 | from cnf import ensure_datasets 13 | 14 | if __name__ == "__main__": 15 | print( "{timestamp}\t{memory}\tMaking sure 
--------------------------------------------------------------------------------
/neurosat/sat-test.py:
--------------------------------------------------------------------------------
1 | import sys, os, time
2 | os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
3 | sys.path.insert(1, os.path.join(sys.path[0], '..'))
4 | import tensorflow as tf
5 | import numpy as np
6 | from model import build_neurosat
7 | import instance_loader
8 | import itertools
9 | from logutil import test_with
10 | from util import timestamp, memory_usage
11 | from cnf import ensure_datasets
12 | 
13 | if __name__ == "__main__":
14 |     print( "{timestamp}\t{memory}\tMaking sure the datasets exist ...".format( timestamp = timestamp(), memory = memory_usage() ) )
15 |     ensure_datasets( make_critical = True )
16 |     if not os.path.isdir( "tmp" ):
17 |         sys.exit(1)
18 |     #end if
19 |     d = 128
20 |     batch_size = 64
21 |     if 1 < len( sys.argv ):
22 |         test_time_steps = int( sys.argv[1] ) # Optionally use a much bigger number of time steps
23 |     else:
24 |         test_time_steps = 28
25 |     #end if
26 |     test_batch_size = batch_size
27 | 
28 |     # Build model
29 |     print( "{timestamp}\t{memory}\tBuilding model, testing with {time_steps} time_steps ...".format( timestamp = timestamp(), memory = memory_usage(), time_steps = test_time_steps ) )
30 |     solver = build_neurosat( d )
31 | 
32 |     # Create model saver
33 |     saver = tf.train.Saver()
34 | 
35 |     with tf.Session() as sess:
36 | 
37 |         # Initialize global variables
38 |         print( "{timestamp}\t{memory}\tInitializing global variables ... ".format( timestamp = timestamp(), memory = memory_usage() ) )
39 |         sess.run( tf.global_variables_initializer() )
40 | 
41 |         # Restore saved weights
42 |         print( "{timestamp}\t{memory}\tRestoring saved model ... ".format( timestamp = timestamp(), memory = memory_usage() ) )
43 |         saver.restore(sess, "./tmp/neurosat.ckpt")
44 | 
45 |         # Test SR distribution
46 |         test_with(
47 |             sess,
48 |             solver,
49 |             "./test-instances",
50 |             "SR",
51 |             time_steps = test_time_steps
52 |         )
53 |         # Test Phase Transition distribution
54 |         test_with(
55 |             sess,
56 |             solver,
57 |             "./critical-instances-40",
58 |             "PT40",
59 |             time_steps = test_time_steps
60 |         )
61 |         test_with(
62 |             sess,
63 |             solver,
64 |             "./critical-instances-80",
65 |             "PT80",
66 |             time_steps = test_time_steps
67 |         )
--------------------------------------------------------------------------------
/tgn.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from mlp import Mlp
3 | 
4 | class TGN(object):
5 |     def __init__(
6 |         self,
7 |         var,
8 |         mat,
9 |         msg,
10 |         loop,
11 |         MLP_depth = 3,
12 |         MLP_weight_initializer = tf.contrib.layers.xavier_initializer,
13 |         MLP_bias_initializer = tf.zeros_initializer,
14 |         RNN_cell = tf.contrib.rnn.LayerNormBasicLSTMCell,
15 |         Cell_activation = tf.nn.relu,
16 |         Msg_activation = tf.nn.relu,
17 |         Msg_last_activation = None,
18 |         float_dtype = tf.float32,
19 |         name = 'TGN'
20 |     ):
21 |         """
22 |         Receives four dictionaries: var, mat, msg and loop.
23 | 
24 |         ○ var is a dictionary from variable names to embedding sizes.
25 |         That is: an entry var["V1"] = 10 means that the variable "V1" will have an embedding size of 10.
26 | 
27 |         ○ mat is a dictionary from matrix names to variable pairs.
28 |         That is: an entry mat["M"] = ("V1","V2") means that the matrix "M" can be used to mask messages from "V1" to "V2".
29 | 
30 |         ○ msg is a dictionary from function names to variable pairs.
31 |         That is: an entry msg["cast"] = ("V1","V2") means that one can apply "cast" to convert messages from "V1" to "V2".
32 | 
33 |         ○ loop is a dictionary from variable names to lists of dictionaries:
34 |         {
35 |             "mat": the matrix name which will be used,
36 |             "transpose?": if true then the matrix M will be transposed,
37 |             "fun": transfer function (a python function built using tensorflow operations),
38 |             "msg": message name,
39 |             "var": variable name
40 |         }
41 |         If "mat" is omitted, it behaves as the identity matrix,
42 |         If "transpose?"
is omitted, it will default to false,
43 |         If "fun" is omitted, no transfer function will be applied,
44 |         If "msg" is omitted, no message conversion function will be applied,
45 |         If "var" is omitted, the matrix itself is supplied as the input.
46 | 
47 |         That is: an entry loop["V2"] = [ {"fun":f,"var":"V2"}, {"mat":"M","transpose?":True,"msg":"cast","var":"V1"} ] enforces the following update rule for every timestep:
48 |         V2 ← LSTM( tf.concat( [ f(V2), Mᵀ × cast(V1) ], axis = 1 ) )
49 |         """
50 |         self.var, self.mat, self.msg, self.loop, self.name = var, mat, msg, loop, name
51 | 
52 |         self.MLP_depth = MLP_depth
53 |         self.MLP_weight_initializer = MLP_weight_initializer
54 |         self.MLP_bias_initializer = MLP_bias_initializer
55 |         self.RNN_cell = RNN_cell
56 |         self.Cell_activation = Cell_activation
57 |         self.Msg_activation = Msg_activation
58 |         self.Msg_last_activation = Msg_last_activation
59 |         self.float_dtype = float_dtype
60 | 
61 |         # Check model for inconsistencies
62 |         self.check_model()
63 | 
64 |         # Initialize the parameters
65 |         with tf.variable_scope(self.name):
66 |             with tf.variable_scope('parameters'):
67 |                 self._init_parameters()
68 |             #end parameter scope
69 |         #end TGN scope
70 |     #end __init__
71 | 
72 |     def check_model(self):
73 |         # Procedure to check model for inconsistencies
74 |         for v in self.var:
75 |             if v not in self.loop:
76 |                 tf.logging.warning('Variable {v} is not updated anywhere! Consider removing it from the model'.format(v=v))
77 |             #end if
78 |         #end for
79 | 
80 |         for v in self.loop:
81 |             if v not in self.var:
82 |                 raise Exception('Updating variable {v}, which has not been declared!'.format(v=v))
83 |             #end if
84 |         #end for
85 | 
86 |         for mat, (v1,v2) in self.mat.items():
87 |             if v1 not in self.var:
88 |                 raise Exception('Matrix {mat} definition depends on undeclared variable {v}'.format(mat=mat, v=v1))
89 |             #end if
90 |             if v2 not in self.var and type(v2) is not int:
91 |                 raise Exception('Matrix {mat} definition depends on undeclared variable {v}'.format(mat=mat, v=v2))
92 |             #end if
93 |         #end for
94 | 
95 |         for msg, (v1,v2) in self.msg.items():
96 |             if v1 not in self.var:
97 |                 raise Exception('Message {msg} maps from undeclared variable {v}'.format(msg=msg, v=v1))
98 |             #end if
99 |             if v2 not in self.var:
100 |                 raise Exception('Message {msg} maps to undeclared variable {v}'.format(msg=msg, v=v2))
101 |             #end if
102 |         #end for
103 |     #end check_model
104 | 
105 |     def _init_parameters(self):
106 |         # Init LSTM cells
107 |         self._RNN_cells = {
108 |             v: self.RNN_cell(
109 |                 d,
110 |                 activation = self.Cell_activation
111 |             ) for (v,d) in self.var.items()
112 |         }
113 |         # Init message-computing MLPs
114 |         self._msg_MLPs = {
115 |             msg: Mlp(
116 |                 layer_sizes = [ self.var[vin] for _ in range( self.MLP_depth ) ],
117 |                 output_size = self.var[vout],
118 |                 activations = [ self.Msg_activation for _ in range( self.MLP_depth ) ],
119 |                 output_activation = self.Msg_last_activation,
120 |                 kernel_initializer = self.MLP_weight_initializer(),
121 |                 bias_initializer = self.MLP_bias_initializer(),
122 |                 name = msg,
123 |                 name_internal_layers = True
124 |             ) for msg, (vin,vout) in self.msg.items()
125 |         }
126 |     #end _init_parameters
127 | 
128 |     def __call__( self, adjacency_matrices, initial_embeddings, time_steps, LSTM_initial_states = {} ):
129 |         with tf.variable_scope(self.name):
130 |             with tf.variable_scope( "assertions" ):
131 |                 assertions = self.check_run( adjacency_matrices, initial_embeddings, time_steps, LSTM_initial_states )
132 |             #end assertion variable scope
133 |             with tf.control_dependencies( assertions ):
134 |                 states = {}
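                # Wrap each node type's initial embedding in an LSTMStateTuple:
                # h starts as the embedding itself; c starts at zero unless an
                # initial cell state was supplied in LSTM_initial_states.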
135 |                 for v, init in initial_embeddings.items():
136 |                     h0 = init
137 |                     c0 = tf.zeros_like(h0, dtype=self.float_dtype) if v not in LSTM_initial_states else LSTM_initial_states[v]
138 |                     states[v] = tf.contrib.rnn.LSTMStateTuple(h=h0, c=c0)
139 |                 #end
140 | 
141 |                 # Build while loop body function
142 |                 def while_body( t, states ):
143 |                     new_states = {}
144 |                     for v in self.var:
145 |                         inputs = []
146 |                         for update in self.loop[v]:
147 |                             if 'var' in update:
148 |                                 y = states[update['var']].h
149 |                                 if 'fun' in update:
150 |                                     y = update['fun'](y)
151 |                                 #end if
152 |                                 if 'msg' in update:
153 |                                     y = self._msg_MLPs[update['msg']](y)
154 |                                 #end if
155 |                                 if 'mat' in update:
156 |                                     y = tf.matmul(
157 |                                         adjacency_matrices[update['mat']],
158 |                                         y,
159 |                                         adjoint_a = update['transpose?'] if 'transpose?' in update else False
160 |                                     )
161 |                                 #end if
162 |                                 inputs.append( y )
163 |                             else:
164 |                                 inputs.append( adjacency_matrices[update['mat']] )
165 |                             #end if var in update
166 |                         #end for update in loop
167 |                         inputs = tf.concat( inputs, axis = 1 )
168 |                         with tf.variable_scope( '{v}_cell'.format( v = v ) ):
169 |                             _, new_states[v] = self._RNN_cells[v]( inputs = inputs, state = states[v] )
170 |                         #end cell scope
171 |                     #end for v in var
172 |                     return (t+1), new_states
173 |                 #end while_body
174 | 
175 |                 _, last_states = tf.while_loop(
176 |                     lambda t, states: tf.less( t, time_steps ),
177 |                     while_body,
178 |                     [0,states]
179 |                 )
180 |             #end assertions
181 |         #end Graph scope
182 |         return last_states
183 |     #end __call__
184 | 
185 |     def check_run( self, adjacency_matrices, initial_embeddings, time_steps, LSTM_initial_states ):
186 |         assertions = []
187 |         # Procedure to check the run inputs for inconsistencies
188 |         num_vars = {}
189 |         for v, d in self.var.items():
190 |             init_shape = tf.shape( initial_embeddings[v] )
191 |             num_vars[v] = init_shape[0]
192 |             assertions.append(
193 |                 tf.assert_equal(
194 |                     init_shape[1],
195 |                     d,
196 |                     data = [ init_shape[1] ],
197 |                     message = "Initial embedding of variable {v} doesn't have the same dimensionality {d} as declared".format(
198 |                         v = v,
199 |                         d = d
200 |                     )
201 |                 )
202 |             )
203 |             if v in LSTM_initial_states:
204 |                 lstm_init_shape = tf.shape( LSTM_initial_states[v] )
205 |                 assertions.append(
206 |                     tf.assert_equal(
207 |                         lstm_init_shape[1],
208 |                         d,
209 |                         data = [ lstm_init_shape[1] ],
210 |                         message = "Initial cell state of variable {v}'s LSTM doesn't have the same dimensionality {d} as declared".format(
211 |                             v = v,
212 |                             d = d
213 |                         )
214 |                     )
215 |                 )
216 | 
217 |                 assertions.append(
218 |                     tf.assert_equal(
219 |                         lstm_init_shape,
220 |                         init_shape,
221 |                         data = [ init_shape, lstm_init_shape ],
222 |                         message = "Initial embeddings of variable {v} don't have the same shape as its LSTM's initial cell state".format(
223 |                             v = v,
224 |                             d = d
225 |                         )
226 |                     )
227 |                 )
228 |             #end if
229 |         #end for v
230 | 
231 |         for mat, (v1,v2) in self.mat.items():
232 |             mat_shape = tf.shape( adjacency_matrices[mat] )
233 |             assertions.append(
234 |                 tf.assert_equal(
235 |                     mat_shape[0],
236 |                     num_vars[v1],
237 |                     data = [ mat_shape[0], num_vars[v1] ],
238 |                     message = "Matrix {m} doesn't have the same number of nodes as the initial embeddings of its variable {v}".format(
239 |                         v = v1,
240 |                         m = mat
241 |                     )
242 |                 )
243 |             )
244 |             if type(v2) is int:
245 |                 assertions.append(
246 |                     tf.assert_equal(
247 |                         mat_shape[1],
248 |                         v2,
249 |                         data = [ mat_shape[1], v2 ],
250 |                         message = "Matrix {m} doesn't have the same dimensionality {d} on the second variable as declared".format(
251 |                             m = mat,
252 |                             d = v2
253 |                         )
254 |                     )
255 |                 )
256 |             else:
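                # Here v2 names another node type, so the matrix's second
                # dimension must match that type's node count.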
257 | assertions.append( 258 | tf.assert_equal( 259 | mat_shape[1], 260 | num_vars[v2], 261 | data = [ mat_shape[1], num_vars[v2] ], 262 | message = "Matrix {m} doesn't have the same number of nodes as the initial embeddings of its variable {v}".format( 263 | v = v2, 264 | m = mat 265 | ) 266 | ) 267 | ) 268 | #end if-else 269 | #end for mat, (v1,v2) 270 | return assertions 271 | #end check_run 272 | #end TGN 273 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import time, sys, os, random 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | def timestamp(): 6 | return time.strftime( "%Y%m%d%H%M%S", time.gmtime() ) 7 | #end timestamp 8 | 9 | def memory_usage(): 10 | pid=os.getpid() 11 | s = next( line for line in open( '/proc/{}/status'.format( pid ) ).read().splitlines() if line.startswith( 'VmSize' ) ).split() 12 | return "{} {}".format( s[-2], s[-1] ) 13 | #end memory_usage 14 | 15 | def sparse_to_dense( M_sparse, default = 0.0 ): 16 | M_i, M_v, M_shape = M_sparse 17 | n, m = M_shape 18 | M = np.ones( (n, m), dtype = np.float32 ) * default 19 | for indexes, value in zip( M_i, M_v ): 20 | i,j = indexes 21 | M[i,j] = value 22 | #end for 23 | return M 24 | #end sparse_to_dense 25 | 26 | def dense_to_sparse( M, check = lambda x: x != 0, val = lambda x: x ): 27 | n, m = M.shape 28 | M_i = [] 29 | M_v = [] 30 | M_shape = (n,m) 31 | for i in range( n ): 32 | for j in range( m ): 33 | if check( M[i,j] ): 34 | M_i.append( (i,j ) ) 35 | M_v.append( val( M[i,j] ) ) 36 | #end if 37 | #end for 38 | #end for 39 | return (M_i,M_v,M_shape) 40 | #end dense_to_sparse 41 | 42 | def reindex_matrix( n, m, M ): 43 | new_index = [] 44 | new_value = [] 45 | for i, v in zip( M[0], M[1] ): 46 | s, t = i 47 | new_index.append( (n + s, m + t) ) 48 | new_value.append( v ) 49 | #end for 50 | return zip( new_index, new_value ) 51 | #end reindex_matrix 52 | 53 | def load_weights(sess,path,scope=None): 54 | if os.path.exists(path): 55 | # Restore saved weights 56 | print( "{timestamp}\t{memory}\tRestoring saved model ... ".format( timestamp = timestamp(), memory = memory_usage() ) ) 57 | # Create model saver 58 | if scope is None: 59 | saver = tf.train.Saver() 60 | else: 61 | saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)) 62 | #end 63 | saver.restore(sess, "%s/model.ckpt" % path) 64 | #end if 65 | #end 66 | 67 | def save_weights(sess,path,scope=None): 68 | # Create /tmp/ directory to save weights 69 | if not os.path.exists(path): 70 | os.makedirs(path) 71 | #end if 72 | # Create model saver 73 | if scope is None: 74 | saver = tf.train.Saver() 75 | else: 76 | saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)) 77 | #end 78 | saver.save(sess, "%s/model.ckpt" % path) 79 | print( "{timestamp}\t{memory}\tMODEL SAVED IN PATH: {path}".format( timestamp = timestamp(), memory = memory_usage(), path=path ) ) 80 | #end 81 | --------------------------------------------------------------------------------
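
Appendix (added for illustration, not a file from the repository): a minimal sketch of driving the TGN builder directly, assuming only the tgn.py and mlp.py APIs above. The node-type name "V", the matrix name "A", the message name "Vmsg" and all sizes are invented for this example; the NeuroSAT model in neurosat/model.py follows the same pattern with two node types ("L" and "C").

import numpy as np
import tensorflow as tf
from tgn import TGN

d = 8  # embedding size for the single node type "V"
gnn = TGN(
    { "V": d },                                              # one node type, embedding size d
    { "A": ("V","V") },                                      # one adjacency matrix between "V" nodes
    { "Vmsg": ("V","V") },                                   # one message MLP from "V" to "V"
    { "V": [ { "mat": "A", "msg": "Vmsg", "var": "V" } ] },  # update rule: V ← LSTM( A × Vmsg(V) )
    name = "Example"
)

A = tf.placeholder( tf.float32, [ None, None ], name = "A" )
V0 = tf.placeholder( tf.float32, [ None, d ], name = "V_init" )
# Run four message-passing timesteps and read the final hidden state
V_final = gnn( { "A": A }, { "V": V0 }, 4 )["V"].h

with tf.Session() as sess:
    sess.run( tf.global_variables_initializer() )
    n = 5  # a random 5-node graph with random initial embeddings
    feed = {
        A: np.random.binomial( 1, 0.5, (n,n) ).astype( np.float32 ),
        V0: np.random.normal( size = (n,d) ).astype( np.float32 )
    }
    print( sess.run( V_final, feed_dict = feed ).shape )  # prints (5, 8)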