├── API.py ├── CAD.py ├── CNN.py ├── ForwardSample.py ├── MCTS.py ├── MHDPA.py ├── README.md ├── SMC.py ├── pointerNetwork.py ├── programGraph.py ├── randomSolver.py └── utilities.py /API.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | 4 | class Solver: 5 | def __init__(self, dsl): 6 | self.dsl = dsl 7 | pass 8 | 9 | def _report(self, program): 10 | l = self.loss(program) 11 | if len(self.reportedSolutions) == 0 or self.reportedSolutions[-1].loss > l: 12 | self.reportedSolutions.append(SearchResult(program, l, time.time() - self.startTime)) 13 | 14 | def infer(self, spec, loss, timeout): 15 | """ 16 | spec: specification of goal 17 | loss: function from (spec, program) to real 18 | timeout: maximum time to run solver, measured in seconds 19 | returns: list of `SearchResult`s 20 | Should take no longer than timeout seconds.""" 21 | self.reportedSolutions = [] 22 | self.startTime = time.time() 23 | self.loss = lambda p: loss(spec, p) 24 | 25 | with torch.no_grad(): 26 | self._infer(spec, loss, timeout) 27 | 28 | self.loss = None # in case we need to serialize this object and loss is a lambda 29 | 30 | return self.reportedSolutions 31 | 32 | def _infer(self, spec, loss, timeout): 33 | assert False, "not implemented" 34 | 35 | class SearchResult: 36 | def __init__(self, program, loss, time): 37 | self.program = program 38 | self.loss = loss 39 | self.time = time 40 | 41 | 42 | class ParseFailure(Exception): 43 | """Objects of type Program should throw this exception in their constructor if their arguments are bad""" 44 | 45 | class DSL: 46 | def __init__(self, operators, lexicon=None): 47 | """ 48 | operators: a list of classes that inherit from Program 49 | lexicon: (optionally) a list of symbols in the serialization of programs built from those operators 50 | """ 51 | self.lexicon = lexicon 52 | self.operators = operators 53 | 54 | self.tokenToOperator = {o.token: o 55 | for o in operators} 56 | 57 | def __str__(self): 58 | return "DSL({%s})"%(", ".join( f"{o.__name__} : {str(o.type)}" 59 | for o in self.operators )) 60 | 61 | def parseLine(self, tokens): 62 | """ 63 | Parses a serialized line of code into a Program object. 64 | Returns None if the DSL cannot parse the serialized code. 
65 | """ 66 | if len(tokens) == 0 or tokens[0] not in self.tokenToOperator: return None 67 | 68 | f = self.tokenToOperator[tokens[0]] 69 | ft = f.type 70 | 71 | if ft.isArrow: # Expects arguments 72 | # Make sure we have the right number of arguments 73 | tokens = tokens[1:] 74 | if len(tokens) != len(ft.arguments): return None 75 | # Make sure that each token is an instance of the correct type 76 | for token, argument_type in zip(tokens, ft.arguments): 77 | if not argument_type.instance(token): return None 78 | # Type checking succeeded - try building the object 79 | try: 80 | return f(*tokens) 81 | except ParseFailure: return None 82 | else: # Does not expect any arguments - just call the constructor with no arguments 83 | if len(tokens) > 1: return None # got arguments when we were not expecting any 84 | return f() 85 | 86 | class Program: 87 | 88 | # TODO: implement type property 89 | 90 | def execute(self, context): 91 | assert False, "not implemented" 92 | 93 | def children(self): 94 | assert False, "not implemented" 95 | 96 | 97 | class Type(): 98 | @property 99 | def isArrow(self): return False 100 | 101 | @property 102 | def isInteger(self): return False 103 | 104 | @property 105 | def isBase(self): return False 106 | 107 | def returnType(self): 108 | """What this type indicates the expression should return. For arrows this is the right-hand side. Otherwise it is just the type.""" 109 | return self 110 | 111 | class BaseType(Type): 112 | def __init__(self, thing): 113 | self.constructor = thing 114 | 115 | def __str__(self): 116 | return self.constructor.__name__ 117 | 118 | @property 119 | def isBase(self): return False 120 | 121 | def instance(self, x): 122 | return isinstance(x, self.constructor) 123 | 124 | class arrow(Type): 125 | def __init__(self, *args): 126 | assert len(args) > 1 127 | for a in args: 128 | assert isinstance(a, Type) 129 | self.out = args[-1] 130 | self.arguments = args[:-1] 131 | 132 | def __str__(self): 133 | return " -> ".join( str(t) for t in list(self.arguments) + [self.out] ) 134 | 135 | @property 136 | def isArrow(self): return True 137 | 138 | def instance(self, x): 139 | assert False, "Cannot check whether a object is an instance of a arrow type" 140 | 141 | def returnType(self): return self.out 142 | 143 | class integer(Type): 144 | def __init__(self,lower,upper): 145 | assert type(lower) is int 146 | assert type(upper) is int 147 | self.upper = upper 148 | self.lower = lower 149 | 150 | def __str__(self): 151 | return f"int({self.lower}, {self.upper})" 152 | 153 | @property 154 | def isInteger(self): return True 155 | 156 | def instance(self, x): 157 | return isinstance(x, int) and x >= self.lower and x <= self.upper 158 | 159 | 160 | -------------------------------------------------------------------------------- /CAD.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | 4 | from API import * 5 | 6 | from randomSolver import * 7 | from pointerNetwork import * 8 | from programGraph import * 9 | from SMC import * 10 | from ForwardSample import * 11 | from MCTS import MCTS 12 | from CNN import * 13 | 14 | import time 15 | import random 16 | 17 | 18 | RESOLUTION = 32 19 | 20 | import torch 21 | import torch.nn as nn 22 | 23 | class CSG(Program): 24 | lexicon = ['+','-','t','c','r'] + list(range(RESOLUTION)) 25 | 26 | def __init__(self): 27 | self._rendering = None 28 | 29 | def __repr__(self): 30 | return str(self) 31 | 32 | def __ne__(self, o): return not (self == o) 33 | 
34 | def execute(self): 35 | if self._rendering is None: self._rendering = self.render() 36 | return self._rendering 37 | 38 | def IoU(self, other): 39 | if isinstance(other, CSG): other = other.execute() 40 | return (self.execute()*other).sum()/(self.execute() + other - self.execute()*other).sum() 41 | 42 | def render(self, w=None, h=None): 43 | w = w or RESOLUTION 44 | h = h or RESOLUTION 45 | 46 | a = np.zeros((w,h)) 47 | for x in range(w): 48 | for y in range(h): 49 | if (x,y) in self: 50 | a[x,y] = 1 51 | return a 52 | 53 | # The type of CSG's 54 | tCSG = BaseType(CSG) 55 | 56 | class Rectangle(CSG): 57 | token = 'r' 58 | type = arrow(integer(0, RESOLUTION - 1), integer(0, RESOLUTION - 1), tCSG) 59 | 60 | def __init__(self, w, h): 61 | super(Rectangle, self).__init__() 62 | self.w = w 63 | self.h = h 64 | 65 | def toTrace(self): return [self] 66 | 67 | def __str__(self): 68 | return f"(r {self.w} {self.h})" 69 | 70 | def children(self): return [] 71 | 72 | def __eq__(self, o): 73 | return isinstance(o, Rectangle) and o.w == self.w and o.h == self.h 74 | 75 | def __hash__(self): 76 | return hash(('r',self.w,self.h)) 77 | 78 | def serialize(self): 79 | return (self.__class__.token, self.w, self.h) 80 | 81 | def __contains__(self, p): 82 | return p[0] >= 0 and p[1] >= 0 and \ 83 | p[0] < self.w and p[1] < self.h 84 | 85 | class Circle(CSG): 86 | token = 'c' 87 | type = arrow(integer(0, RESOLUTION - 1), tCSG) 88 | 89 | def __init__(self, r): 90 | super(Circle, self).__init__() 91 | self.r = r 92 | 93 | def toTrace(self): return [self] 94 | 95 | def __str__(self): 96 | return f"(c {self.r})" 97 | 98 | def children(self): return [] 99 | 100 | def __eq__(self, o): 101 | return isinstance(o, Circle) and o.r == self.r 102 | def __hash__(self): 103 | return hash(('c', str(self.r))) 104 | 105 | def serialize(self): 106 | return (self.__class__.token, self.r) 107 | 108 | def __contains__(self, p): 109 | return p[0]*p[0] + p[1]*p[1] <= self.r*self.r 110 | 111 | class Translation(CSG): 112 | token = 't' 113 | type = arrow(integer(0, RESOLUTION - 1), integer(0, RESOLUTION - 1), tCSG, tCSG) 114 | 115 | def __init__(self, x, y, child): 116 | super(Translation, self).__init__() 117 | self.v = (x, y) 118 | self.child = child 119 | 120 | def toTrace(self): return self.child.toTrace() + [self] 121 | 122 | def __str__(self): 123 | return f"(t {self.v} {self.child})" 124 | 125 | def children(self): return [self.child] 126 | 127 | def serialize(self): 128 | return ('t', self.v[0], self.v[1], self.child) 129 | 130 | def __eq__(self, o): 131 | return isinstance(o, Translation) and o.v == self.v and self.child == o.child 132 | 133 | def __hash__(self): 134 | return hash(('t', self.v, self.child)) 135 | 136 | def __contains__(self, p): 137 | p = (p[0] - self.v[0], 138 | p[1] - self.v[1]) 139 | return p in self.child 140 | 141 | class Union(CSG): 142 | token = '+' 143 | type = arrow(tCSG, tCSG, tCSG) 144 | 145 | def __init__(self, a, b): 146 | super(Union, self).__init__() 147 | self.elements = [a,b] 148 | 149 | def toTrace(self): 150 | return self.elements[0].toTrace() + self.elements[1].toTrace() + [self] 151 | 152 | def __str__(self): 153 | return f"(+ {str(self.elements[0])} {str(self.elements[1])})" 154 | 155 | def children(self): return self.elements 156 | 157 | def serialize(self): 158 | return ('+',list(self.elements)[0],list(self.elements)[1]) 159 | 160 | def __eq__(self, o): 161 | return isinstance(o, Union) and tuple(o.elements) == tuple(self.elements) 162 | 163 | def __hash__(self): 164 | return 
hash(('u', tuple(self.elements))) 165 | 166 | def __contains__(self, p): 167 | return any( p in e for e in self.elements ) 168 | 169 | class Difference(CSG): 170 | token = '-' 171 | type = arrow(tCSG, tCSG, tCSG) 172 | 173 | def __init__(self, a, b): 174 | super(Difference, self).__init__() 175 | self.a, self.b = a, b 176 | 177 | def toTrace(self): 178 | return self.a.toTrace() + self.b.toTrace() + [self] 179 | 180 | def __str__(self): 181 | return f"(- {self.a} {self.b})" 182 | 183 | def children(self): return [self.a, self.b] 184 | 185 | def serialize(self): 186 | return ('-',self.a,self.b) 187 | 188 | def __eq__(self, o): 189 | return isinstance(o, Difference) and self.a == o.a and self.b == o.b 190 | 191 | def __hash__(self): 192 | return hash(('-', hash(self.a), hash(self.b))) 193 | 194 | def __contains__(self, p): 195 | return p in self.a and (not (p in self.b)) 196 | 197 | dsl = DSL([Rectangle, Circle, Translation, Union, Difference], 198 | lexicon=CSG.lexicon) 199 | 200 | """Neural networks""" 201 | class ObjectEncoder(CNN): 202 | def __init__(self): 203 | super(ObjectEncoder, self).__init__(channels=2, 204 | inputImageDimension=RESOLUTION) 205 | 206 | def forward(self, spec, obj): 207 | if isinstance(obj, list): # batched - expect a single spec and multiple objects 208 | spec = np.repeat(spec[np.newaxis,:,:],len(obj),axis=0) 209 | obj = np.stack(obj) 210 | return super(ObjectEncoder, self).forward(np.stack([spec, obj],1)) 211 | else: # not batched 212 | return super(ObjectEncoder, self).forward(np.stack([spec, obj])) 213 | 214 | 215 | class SpecEncoder(CNN): 216 | def __init__(self): 217 | super(SpecEncoder, self).__init__(channels=1, 218 | inputImageDimension=RESOLUTION) 219 | 220 | 221 | """Training""" 222 | def randomScene(resolution=32, maxShapes=3, minShapes=1, verbose=False, export=None): 223 | dc = 8 # number of distinct coordinates 224 | def quadrilateral(): 225 | choices = [c 226 | for c in range(resolution//(dc*2), resolution, resolution//dc) ] 227 | w = random.choice([2,5]) 228 | h = random.choice([2,5]) 229 | x = random.choice(choices) 230 | y = random.choice(choices) 231 | return Translation(x,y, 232 | Rectangle(w,h)) 233 | 234 | def circular(): 235 | r = random.choice([2,4]) 236 | choices = [c 237 | for c in range(resolution//(dc*2), resolution, resolution//dc) ] 238 | x = random.choice(choices) 239 | y = random.choice(choices) 240 | return Translation(x,y, 241 | Circle(r)) 242 | s = None 243 | numberOfShapes = 0 244 | desiredShapes = random.choice(range(minShapes, 1 + maxShapes)) 245 | for _ in range(desiredShapes): 246 | o = quadrilateral() if random.choice([True,False]) else circular() 247 | if s is None: s = o 248 | else: 249 | if (s.execute()*o.execute()).sum() > 0.5: continue 250 | s = Union(s,o) 251 | numberOfShapes += 1 252 | if verbose: 253 | print(s) 254 | print(ProgramGraph.fromRoot(s, oneParent=True).prettyPrint()) 255 | import matplotlib.pyplot as plot 256 | plot.imshow(s.execute()) 257 | plot.show() 258 | if export: 259 | import matplotlib.pyplot as plot 260 | plot.imshow(s.execute()) 261 | plot.savefig(export) 262 | 263 | return s 264 | 265 | def trainCSG(m, getProgram, trainTime=None, checkpoint=None): 266 | print("cuda?",m.use_cuda) 267 | assert checkpoint is not None, "must provide a checkpoint path to export to" 268 | 269 | optimizer = torch.optim.Adam(m.parameters(), lr=0.001, eps=1e-3, amsgrad=True) 270 | 271 | startTime = time.time() 272 | reportingFrequency = 100 273 | totalLosses = [] 274 | movedLosses = [] 275 | iteration = 0 276 | 277 | 
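# Each iteration of the loop below draws a fresh random scene and takes one
# gradient step on its trace log-likelihood. A hedged sketch of what a single
# step amounts to (all names as defined in this file):
# >>> s = getProgram()  # e.g. randomScene(maxShapes=2)
# >>> perMoveLosses = m.gradientStepTrace(optimizer, s.execute(), s.toTrace())
# perMoveLosses holds one policy loss per move in the trace plus one for the
# final RETURN move, so len(perMoveLosses) == len(s.toTrace()) + 1.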
while trainTime is None or time.time() - startTime < trainTime: 278 | s = getProgram() 279 | l = m.gradientStepTrace(optimizer, s.execute(), s.toTrace()) 280 | totalLosses.append(sum(l)) 281 | movedLosses.append(sum(l)/len(l)) 282 | 283 | if iteration%reportingFrequency == 0: 284 | print(f"\n\nAfter {iteration} gradient steps...\n\tTrace loss {sum(totalLosses)/len(totalLosses)}\t\tMove loss {sum(movedLosses)/len(movedLosses)}\n{iteration/(time.time() - startTime)} grad steps/sec") 285 | totalLosses = [] 286 | movedLosses = [] 287 | with open(checkpoint,"wb") as handle: 288 | pickle.dump(m, handle) 289 | 290 | iteration += 1 291 | 292 | def testCSG(m, getProgram, timeout, export): 293 | oneParent = m.oneParent 294 | solvers = [# RandomSolver(dsl), 295 | # MCTS(m, reward=lambda l: 1. - l), 296 | # SMC(m), 297 | ForwardSample(m, maximumLength=18)] 298 | loss = lambda spec, program: 1-max( o.IoU(spec) for o in program.objects() ) if len(program) > 0 else 1. 299 | 300 | testResults = [[] for _ in solvers] 301 | 302 | for _ in range(30): 303 | spec = getProgram() 304 | print("Trying to explain the program:") 305 | print(ProgramGraph.fromRoot(spec, oneParent=oneParent).prettyPrint()) 306 | print() 307 | for n, solver in enumerate(solvers): 308 | testSequence = solver.infer(spec.execute(), loss, timeout) 309 | testResults[n].append(testSequence) 310 | for result in testSequence: 311 | print(f"After time {result.time}, achieved loss {result.loss} w/") 312 | print(result.program.prettyPrint()) 313 | print() 314 | 315 | plotTestResults(testResults, timeout, 316 | defaultLoss=1., 317 | names=[# "MCTS","SMC", 318 | "FS"], 319 | export=export) 320 | 321 | def plotTestResults(testResults, timeout, defaultLoss=None, 322 | names=None, export=None): 323 | import matplotlib.pyplot as plot 324 | 325 | def averageLoss(n, T): 326 | results = testResults[n] # list of list of results, one for each test case 327 | # Filter out results that occurred after time T 328 | results = [ [r for r in rs if r.time <= T] 329 | for rs in results ] 330 | losses = [ min([defaultLoss] + [r.loss for r in rs]) for rs in results ] 331 | return sum(losses)/len(losses) 332 | 333 | plot.figure() 334 | plot.xlabel('Time') 335 | plot.ylabel('Average Loss') 336 | 337 | for n in range(len(testResults)): 338 | xs = list(np.arange(0,timeout,0.1)) 339 | plot.plot(xs, [averageLoss(n,x) for x in xs], 340 | label=names[n]) 341 | plot.legend() 342 | if export: 343 | plot.savefig(export) 344 | else: 345 | plot.show() 346 | 347 | 348 | 349 | 350 | 351 | 352 | if __name__ == "__main__": 353 | import argparse 354 | parser = argparse.ArgumentParser(description = "") 355 | parser.add_argument("mode", choices=["train","test","demo"]) 356 | parser.add_argument("--checkpoint", default="checkpoints/CSG.pickle") 357 | parser.add_argument("--maxShapes", default=2, 358 | type=int) 359 | parser.add_argument("--trainTime", default=None, type=float, 360 | help="Time in hours to train the network") 361 | parser.add_argument("--attention", default=1, type=int, 362 | help="Number of rounds of self attention to perform upon objects in scope") 363 | parser.add_argument("--heads", default=2, type=int, 364 | help="Number of attention heads") 365 | parser.add_argument("--hidden", "-H", type=int, default=256, 366 | help="Size of hidden layers") 367 | parser.add_argument("--timeout", default=5, type=float, 368 | help="Test time maximum timeout") 369 | parser.add_argument("--oneParent", default=False, action='store_true') 370 | arguments = parser.parse_args() 371 | 372 | 
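# Example invocations (hedged; the checkpoint and figure paths are the defaults declared above):
#   python CAD.py demo --maxShapes 3     # render random scenes to /tmp/CAD_<n>.png
#   python CAD.py train --trainTime 1    # train for one hour, pickling checkpoints as it goes
#   python CAD.py test --timeout 5       # evaluate solvers, plotting average loss vs. time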
if arguments.mode == "demo": 373 | for n in range(100): 374 | randomScene(export=f"/tmp/CAD_{n}.png",maxShapes=arguments.maxShapes) 375 | import sys 376 | sys.exit(0) 377 | 378 | 379 | 380 | if arguments.mode == "train": 381 | m = ProgramPointerNetwork(ObjectEncoder(), SpecEncoder(), dsl, 382 | oneParent=arguments.oneParent, 383 | attentionRounds=arguments.attention, 384 | heads=arguments.heads, 385 | H=arguments.hidden) 386 | trainCSG(m, lambda: randomScene(maxShapes=arguments.maxShapes), 387 | trainTime=arguments.trainTime*60*60 if arguments.trainTime else None, 388 | checkpoint=arguments.checkpoint) 389 | elif arguments.mode == "test": 390 | with open(arguments.checkpoint,"rb") as handle: 391 | m = pickle.load(handle) 392 | testCSG(m, 393 | lambda: randomScene(maxShapes=arguments.maxShapes, minShapes=arguments.maxShapes), arguments.timeout, 394 | export=f"figures/CAD_{arguments.maxShapes}_shapes.png") 395 | -------------------------------------------------------------------------------- /CNN.py: -------------------------------------------------------------------------------- 1 | from utilities import * 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | class Flatten(nn.Module): 7 | def __init__(self): 8 | super(Flatten, self).__init__() 9 | 10 | def forward(self, x): 11 | return x.view(x.size(0), -1) 12 | 13 | 14 | class CNN(Module): 15 | def __init__(self, _=None, channels=1, layers=4, 16 | inputImageDimension=None, hiddenChannels=64, outputChannels=64): 17 | super(CNN, self).__init__() 18 | assert inputImageDimension is not None 19 | assert layers > 1 20 | def conv_block(in_channels, out_channels, p=True): 21 | return nn.Sequential( 22 | nn.Conv2d(in_channels, out_channels, 3, padding=1), 23 | nn.ReLU(), 24 | nn.Conv2d(out_channels, out_channels, 3, padding=1), 25 | nn.ReLU(), 26 | nn.MaxPool2d(2)) 27 | 28 | self.inputImageDimension = inputImageDimension 29 | 30 | # channels for hidden 31 | hid_dim = hiddenChannels 32 | z_dim = outputChannels 33 | 34 | self.encoder = nn.Sequential(*([conv_block(channels, hid_dim)] + \ 35 | [conv_block(hid_dim, hid_dim) for _ in range(layers - 2) ] + \ 36 | [conv_block(hid_dim, z_dim)] + \ 37 | [Flatten()])) 38 | 39 | self.outputDimensionality = int(outputChannels*inputImageDimension*inputImageDimension/(4**layers)) 40 | self.channels = channels 41 | 42 | self.finalize() 43 | 44 | def forward(self, v): 45 | if isinstance(v, list): v = np.array(v) 46 | if self.channels == 1: # input is either BxWxH or WxH 47 | if len(v.shape) == 2: squeeze = 2 48 | elif len(v.shape) == 3: squeeze = 1 49 | else: assert False 50 | else: # either [b,c,w,h] or [c,w,h] 51 | if len(v.shape) == 3: squeeze = 1 52 | elif len(v.shape) == 4: squeeze = 0 53 | 54 | v = self.tensor(v) 55 | for _ in range(squeeze): v = v.unsqueeze(0) 56 | v = self.encoder(v.float()) 57 | for _ in range(squeeze): v = v.squeeze(0) 58 | return v 59 | -------------------------------------------------------------------------------- /ForwardSample.py: -------------------------------------------------------------------------------- 1 | from programGraph import * 2 | from API import * 3 | from pointerNetwork import * 4 | 5 | import time 6 | 7 | 8 | class ForwardSample(Solver): 9 | def __init__(self, model, _=None, maximumLength=8): 10 | self.maximumLength = maximumLength 11 | self.model = model 12 | 13 | def _infer(self, spec, loss, timeout): 14 | t0 = time.time() 15 | specEncoding = self.model.specEncoder(spec) 16 | 17 | # Maps from an object to its embedding 18 | objectEncodings = 
ScopeEncoding(self.model, spec) 19 | 20 | while time.time() - t0 < timeout: 21 | g = ProgramGraph([]) 22 | for _ in range(self.maximumLength): 23 | newObjects = self.model.repeatedlySample(specEncoding, g, objectEncodings, 1) 24 | if len(newObjects) == 0 or newObjects[0] is None: break 25 | g = g.extend(newObjects[0]) 26 | self._report(g) 27 | -------------------------------------------------------------------------------- /MCTS.py: -------------------------------------------------------------------------------- 1 | from API import * 2 | from programGraph import * 3 | from pointerNetwork import * 4 | import time 5 | 6 | 7 | class MCTS(Solver): 8 | """ 9 | AlphaZero-style Monte Carlo tree search 10 | Currently ignores learned distance / value, but is biased by learned policy 11 | """ 12 | def __init__(self, model, _=None, reward=None, 13 | c_puct=5, rolloutDepth=None): 14 | """ 15 | c_puct: Trades off exploration and exploitation. Larger values favor exploration, guided by policy. 16 | reward: function from loss to reward. 17 | """ 18 | assert reward is not None, "reward must be specified. This function converts loss into reward." 19 | self.reward = reward 20 | self.c_puct = c_puct 21 | self.model = model 22 | self.rolloutDepth = rolloutDepth 23 | 24 | self.beamTime = 0. 25 | self.rollingTime = 0. 26 | 27 | def __str__(self): 28 | return f"MCTS(puct={self.c_puct})" 29 | 30 | def _infer(self, spec, loss, timeout): 31 | startTime = time.time() 32 | owner = self 33 | 34 | class Node: 35 | def __init__(self, graph): 36 | self.graph = graph 37 | self.visits = 0 38 | self.edges = [] 39 | self.generator = owner.model.bestFirstEnumeration(specEncoding, graph, objectEncodings) 40 | 41 | class Edge: 42 | def __init__(self, parent, child, logLikelihood): 43 | self.logLikelihood = logLikelihood 44 | self.parent = parent 45 | self.child = child 46 | self.traversals = 0 47 | self.totalReward = 0 48 | 49 | specEncoding = self.model.specEncoder(spec) 50 | objectEncodings = ScopeEncoding(self.model, spec) 51 | 52 | 53 | def expand(n): 54 | """Adds a single child to a node""" 55 | if n.generator is None: return 56 | t0 = time.time() 57 | try: o, ll = next(n.generator) 58 | except StopIteration: 59 | n.generator = None 60 | o, ll = None, None 61 | self.beamTime += time.time() - t0 62 | 63 | if o is None or o in n.graph.nodes: return 64 | newGraph = n.graph.extend(o) 65 | if newGraph in graph2node: 66 | child = graph2node[newGraph] 67 | else: 68 | self._report(newGraph) 69 | child = graph2node[newGraph] = Node(newGraph) # cache the new node so revisits of this graph share statistics 70 | e = Edge(n, child, ll) 71 | n.edges.append(e) 72 | 73 | def rollout(g): 74 | t0 = time.time() 75 | depth = 0 76 | while True: 77 | samples = self.model.repeatedlySample(specEncoding, g, objectEncodings, 1) 78 | assert len(samples) <= 1 79 | depth += 1 80 | if len(samples) == 0 or samples[0] is None: break 81 | g = g.extend(samples[0]) 82 | if self.rolloutDepth is not None and depth >= self.rolloutDepth: break 83 | 84 | self.rollingTime += time.time() - t0 85 | self._report(g) 86 | 87 | return g 88 | 89 | def uct(e): 90 | # Exploit: rewards Q(s,a) 91 | if e.traversals == 0: q = 0. 92 | else: q = e.totalReward/e.traversals 93 | 94 | # Explore, biased by policy 95 | exploration_bonus = math.exp(e.logLikelihood) * (e.parent.visits**0.5) / (1. 
+ e.traversals) 96 | 97 | # Trade-off of exploration and exploitation 98 | return q + self.c_puct*exploration_bonus 99 | 100 | rootNode = Node(ProgramGraph([])) 101 | graph2node = {ProgramGraph([]): rootNode} 102 | 103 | while time.time() - startTime < timeout: 104 | n = rootNode 105 | trajectory = [] # list of traversed edges 106 | 107 | while len(n.edges) > 0: 108 | e = max(n.edges, key=uct) 109 | trajectory.append(e) 110 | n = e.child 111 | 112 | r = self.reward(self.loss(rollout(n.graph))) 113 | 114 | # Expand nodes if their single visit-0 child was visited 115 | for e in trajectory: 116 | if e.child.visits == 0: 117 | expand(e.parent) 118 | 119 | # back up the reward 120 | for e in trajectory: 121 | e.totalReward += r 122 | e.traversals += 1 123 | e.parent.visits += 1 124 | 125 | expand(n) 126 | n.visits += 1 127 | 128 | 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /MHDPA.py: -------------------------------------------------------------------------------- 1 | from utilities import * 2 | import torch.nn.functional as F 3 | 4 | class MultiHeadAttention(Module): 5 | def __init__(self, heads, entity_dimensionality, rounds=1, residual=True, 6 | layers=2): 7 | super().__init__() 8 | self.entity_dimensionality = entity_dimensionality 9 | self.heads = heads 10 | 11 | assert entity_dimensionality%heads == 0,\ 12 | "dimensionality of entities must be divisible by number of heads" 13 | 14 | # Dimensionality of each head 15 | self.d = entity_dimensionality//heads 16 | 17 | self.Q = nn.Linear(entity_dimensionality, entity_dimensionality) 18 | self.V = nn.Linear(entity_dimensionality, entity_dimensionality) 19 | self.K = nn.Linear(entity_dimensionality, entity_dimensionality) 20 | self.output = nn.Sequential(*[ layer 21 | for _ in range(layers) 22 | for layer in [nn.Linear(entity_dimensionality, entity_dimensionality), 23 | nn.ReLU()]]) 24 | 25 | self.rounds = rounds 26 | self.residual = residual 27 | 28 | self.finalize() 29 | 30 | def forward(self, entities): 31 | """ 32 | entities: (# entities)x(entity_dimensionality) 33 | returns: (# entities)x(entity_dimensionality) 34 | """ 35 | for _ in range(self.rounds): 36 | # query, values, and keys should all be of size HxExD 37 | q = self.Q(entities).view(entities.size(0), self.heads, self.d).permute(1,0,2) 38 | v = self.V(entities).view(entities.size(0), self.heads, self.d).permute(1,0,2) 39 | k = self.K(entities).view(entities.size(0), self.heads, self.d).permute(1,0,2) 40 | 41 | # attention[i,j] = q_i . k_j 42 | # i.e., amount that object I is attending to object J 43 | attention = F.softmax(q@(k.permute(0,2,1))/(self.d**0.5), dim=-1) 44 | 45 | # Mix together values 46 | o = (attention@v).transpose(0,1).contiguous().view(entities.size(0), self.entity_dimensionality) 47 | 48 | # Apply output transformation 49 | o = self.output(o) 50 | 51 | # residual connection 52 | if self.residual: 53 | entities = entities + o 54 | else: 55 | entities = o 56 | 57 | return entities 58 | 59 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Write, Execute, Assess: Program Synthesis With A REPL 2 | 3 | https://papers.nips.cc/paper/9116-write-execute-assess-program-synthesis-with-a-repl 4 | 5 | We present a neural program synthesis approach integrating components which write, execute, and assess code to navigate the search space of possible programs. 
We equip the search process with an interpreter or a read-eval-print-loop (REPL), which immediately executes partially written programs, exposing their semantics. The REPL addresses a basic challenge of program synthesis: tiny changes in syntax can lead to huge changes in semantics. We train a pair of models, a policy that proposes the new piece of code to write, and a value function that assesses the prospects of the code written so-far. At test time we can combine these models with a Sequential Monte Carlo algorithm. We apply our approach to two domains: synthesizing text editing programs and inferring 2D and 3D graphics programs. 6 | -------------------------------------------------------------------------------- /SMC.py: -------------------------------------------------------------------------------- 1 | from programGraph import * 2 | from API import * 3 | from pointerNetwork import * 4 | 5 | import time 6 | 7 | class SMC(Solver): 8 | def __init__(self, model, _=None, 9 | maximumLength=8, 10 | initialParticles=100, exponentialGrowthFactor=2, 11 | fitnessWeight=2.): 12 | self.maximumLength = maximumLength 13 | self.initialParticles = initialParticles 14 | self.exponentialGrowthFactor = exponentialGrowthFactor 15 | self.fitnessWeight = fitnessWeight 16 | self.model = model 17 | 18 | def _infer(self, spec, loss, timeout): 19 | startTime = time.time() 20 | numberOfParticles = self.initialParticles 21 | 22 | specEncoding = self.model.specEncoder(spec) 23 | 24 | # Maps from an object to its embedding 25 | objectEncodings = ScopeEncoding(self.model, spec) 26 | 27 | # Maps from a graph to its distance 28 | _distance = {} 29 | def distance(g): 30 | if g in _distance: return _distance[g] 31 | se = objectEncodings.encoding(list(g.objects())) 32 | d = self.model.distance(se, specEncoding) 33 | _distance[g] = d 34 | return d 35 | 36 | class Particle(): 37 | def __init__(self, graph, frequency): 38 | self.frequency = frequency 39 | self.graph = graph 40 | self.distance = distance(graph) 41 | 42 | 43 | while True: 44 | population = [Particle(ProgramGraph([]), numberOfParticles)] 45 | for _ in range(self.maximumLength): 46 | sampleFrequency = {} 47 | for p in population: 48 | for newObject in self.model.repeatedlySample(specEncoding, p.graph, 49 | objectEncodings, p.frequency): 50 | if newObject is None: newGraph = p.graph 51 | else: newGraph = p.graph.extend(newObject) 52 | sampleFrequency[newGraph] = sampleFrequency.get(newGraph, 0) + 1 53 | 54 | if time.time() - startTime >= timeout: return 55 | 56 | for g in sampleFrequency: self._report(g) 57 | 58 | # Convert graphs to particles 59 | samples = [Particle(g, f) 60 | for g, f in sampleFrequency.items() ] 61 | 62 | # Resample 63 | logWeights = [math.log(p.frequency) - p.distance 64 | for p in samples] 65 | ps = [ math.exp(lw - max(logWeights)) for lw in logWeights ] 66 | ps = [p/sum(ps) for p in ps] 67 | sampleFrequencies = np.random.multinomial(numberOfParticles, ps) 68 | 69 | population = [] 70 | for particle, frequency in zip(samples, sampleFrequencies): 71 | if frequency > 0: 72 | particle.frequency = frequency 73 | population.append(particle) 74 | 75 | numberOfParticles *= self.exponentialGrowthFactor 76 | -------------------------------------------------------------------------------- /pointerNetwork.py: -------------------------------------------------------------------------------- 1 | from utilities import * 2 | 3 | import random 4 | import math 5 | import torch.nn as nn 6 | import torch 7 | import torch.nn.functional as F 8 | import 
torch.optim as optimization 9 | from torch.autograd import Variable 10 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 11 | from MHDPA import * 12 | from programGraph import ProgramGraph 13 | from API import Program 14 | import numpy as np 15 | 16 | class Pointer(): 17 | def __init__(self, i, m=None): 18 | self.i = i 19 | self.m = m 20 | def __str__(self): return f"P({self.i}, max={self.m})" 21 | def __repr__(self): return str(self) 22 | 23 | class SymbolEncoder(Module): 24 | def __init__(self, lexicon, H=256): 25 | super(SymbolEncoder, self).__init__() 26 | 27 | self.encoder = nn.Embedding(len(lexicon), H) 28 | self.lexicon = lexicon 29 | self.wordToIndex = {w: j for j,w in enumerate(self.lexicon) } 30 | 31 | self.finalize() 32 | 33 | def forward(self, objects): 34 | return self.encoder(self.device(torch.tensor([self.wordToIndex[o] for o in objects]))) 35 | 36 | class LineDecoder(Module): 37 | def __init__(self, lexicon, H=256, encoderDimensionality=256, layers=1): 38 | """ 39 | H: Hidden size for GRU & size of embedding of output tokens 40 | encoderDimensionality: Dimensionality of objects we are attending over (objects we can point to) 41 | lexicon: list of symbols that can occur in a line of code. STARTING, ENDING, & POINTER are reserved symbols. 42 | """ 43 | super(LineDecoder, self).__init__() 44 | 45 | self.encoderDimensionality = encoderDimensionality 46 | 47 | self.model = nn.GRU(H + encoderDimensionality, H, layers) 48 | 49 | self.specialSymbols = [ 50 | "STARTING", "ENDING", "POINTER" 51 | ] 52 | 53 | self.lexicon = lexicon + self.specialSymbols 54 | self.wordToIndex = {w: j for j,w in enumerate(self.lexicon) } 55 | self.embedding = nn.Embedding(len(self.lexicon), H) 56 | 57 | self.output = nn.Sequential(nn.Linear(H, len(self.lexicon)), 58 | nn.LogSoftmax(dim=-1)) 59 | 60 | self.decoderToPointer = nn.Linear(H, H, bias=False) 61 | self.encoderToPointer = nn.Linear(encoderDimensionality, H, bias=False) 62 | self.attentionSelector = nn.Linear(H, 1, bias=False) 63 | 64 | self.pointerIndex = self.wordToIndex["POINTER"] 65 | 66 | self.finalize() 67 | 68 | def pointerAttention(self, hiddenStates, objectEncodings, _=None, 69 | pointerBounds=[], objectKeys=None): 70 | """ 71 | hiddenStates: BxH 72 | objectEncodings: (# objects)x(encoder dimensionality); if this is set to None, expects: 73 | objectKeys: (# objects)x(key dimensionality; this is H passed to constructor) 74 | OUTPUT: Bx(# objects) attention matrix 75 | """ 76 | hiddenStates = self.decoderToPointer(hiddenStates) 77 | if objectKeys is None: 78 | objectKeys = self.encoderToPointer(objectEncodings) 79 | else: 80 | assert objectEncodings is None, "You either provide object encodings or object keys but not both" 81 | 82 | _h = hiddenStates.unsqueeze(1).repeat(1, objectKeys.size(0), 1) 83 | _o = objectKeys.unsqueeze(0).repeat(hiddenStates.size(0), 1, 1) 84 | attention = self.attentionSelector(torch.tanh(_h + _o)).squeeze(2) 85 | #attention = self.attentionSelector(torch.tanh(_h * some_bilinear * _o)).squeeze(2) 86 | 87 | mask = np.zeros((hiddenStates.size(0), objectKeys.size(0))) 88 | 89 | for p,b in enumerate(pointerBounds): 90 | if b is not None: 91 | mask[p, b:] = NEGATIVEINFINITY 92 | 93 | return F.log_softmax(attention + self.device(torch.tensor(mask).float()), dim=1) 94 | 
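# A hedged shape sketch for pointerAttention above (doctest-style; `decoder` is
# a hypothetical LineDecoder instance with H = encoderDimensionality = 256).
# B decoder states attend over O objects; each row is a log-probability vector:
# >>> hidden = torch.zeros(2, 256)    # B = 2 decoder hidden states
# >>> objects = torch.zeros(5, 256)   # O = 5 objects in scope
# >>> decoder.pointerAttention(hidden, objects).shape
# torch.Size([2, 5])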
95 | def logLikelihood_hidden(self, initialState, target, encodedInputs): 96 | symbolSequence = [self.wordToIndex[t if not isinstance(t,Pointer) else "POINTER"] 97 | for t in ["STARTING"] + target + ["ENDING"] ] 98 | 99 | # inputSequence : L x H 100 | inputSequence = self.tensor(symbolSequence[:-1]) 101 | outputSequence = self.tensor(symbolSequence[1:]) 102 | inputSequence = self.embedding(inputSequence) 103 | 104 | # Concatenate the object encodings w/ the inputs 105 | objectInputs = self.device(torch.zeros(len(symbolSequence) - 1, self.encoderDimensionality)) 106 | for t, p in enumerate(target): 107 | if isinstance(p, Pointer): 108 | objectInputs[t + 1] = encodedInputs[p.i] 109 | # (each POINTER input token is accompanied by the encoding of the object it points to) 110 | 111 | inputSequence = torch.cat([inputSequence, objectInputs], 1).unsqueeze(1) 112 | 113 | if initialState is not None: initialState = initialState.unsqueeze(0).unsqueeze(0) 114 | 115 | o, h = self.model(inputSequence, initialState) 116 | 117 | # output sequence log likelihood, ignoring pointer values 118 | sll = -F.nll_loss(self.output(o.squeeze(1)), outputSequence, reduction='sum') 119 | 120 | 121 | # pointer value log likelihood 122 | pointerTimes = [t - 1 for t,s in enumerate(symbolSequence) if self.pointerIndex == s ] 123 | if len(pointerTimes) == 0: 124 | pll = 0. 125 | else: 126 | assert encodedInputs is not None 127 | pointerValues = [v.i for v in target if isinstance(v, Pointer) ] 128 | pointerBounds = [v.m for v in target if isinstance(v, Pointer) ] 129 | pointerHiddens = o[self.tensor(pointerTimes),:,:].squeeze(1) 130 | 131 | attention = self.pointerAttention(pointerHiddens, encodedInputs, 132 | pointerBounds=pointerBounds) 133 | pll = -F.nll_loss(attention, self.tensor(pointerValues), 134 | reduction='sum') 135 | return sll + pll, h 136 | 137 | def logLikelihood(self, initialState, target, encodedInputs): 138 | return self.logLikelihood_hidden(initialState, target, encodedInputs)[0] 139 | 140 | def sample(self, initialState, encodedInputs): 141 | sequence = ["STARTING"] 142 | h = initialState 143 | while len(sequence) < 100: 144 | lastWord = sequence[-1] 145 | if isinstance(lastWord, Pointer): 146 | latestPointer = encodedInputs[lastWord.i] 147 | lastWord = "POINTER" 148 | else: 149 | latestPointer = self.device(torch.zeros(self.encoderDimensionality)) 150 | i = self.embedding(self.tensor(self.wordToIndex[lastWord])) 151 | i = torch.cat([i, latestPointer]) 152 | if h is not None: h = h.unsqueeze(0).unsqueeze(0) 153 | o,h = self.model(i.unsqueeze(0).unsqueeze(0), h) 154 | o = o.squeeze(0).squeeze(0) 155 | h = h.squeeze(0).squeeze(0) 156 | 157 | # Sample the next symbol 158 | distribution = self.output(o) 159 | next_symbol = self.lexicon[torch.multinomial(distribution.exp(), 1)[0].data.item()] 160 | if next_symbol == "ENDING": 161 | break 162 | if next_symbol == "POINTER": 163 | if encodedInputs is not None: 164 | # Sample the next pointer 165 | a = self.pointerAttention(h.unsqueeze(0), encodedInputs, []).squeeze(0) 166 | next_symbol = Pointer(torch.multinomial(a.exp(),1)[0].data.item()) 167 | else: 168 | return None 169 | 170 | sequence.append(next_symbol) 171 | 172 | return sequence[1:] 173 | 174 | def beam(self, initialState, encodedObjects, B, 175 | maximumLength=50): 176 | """Given an initial hidden state, of size H, and the encodings of the 177 | objects in scope, of size Ox(self.encoderDimensionality), do a beam 178 | search with beam width B. 
Returns a list of (log likelihood, sequence of tokens)""" 179 | master = self 180 | class Particle(): 181 | def __init__(self, h, ll, sequence): 182 | self.h = h 183 | self.ll = ll 184 | self.sequence = sequence 185 | def input(self): 186 | lastWord = self.sequence[-1] 187 | if isinstance(lastWord, Pointer): 188 | latestPointer = encodedObjects[lastWord.i] 189 | lastWord = "POINTER" 190 | else: 191 | latestPointer = master.device(torch.zeros(master.encoderDimensionality)) 192 | return torch.cat([master.embedding(master.tensor(master.wordToIndex[lastWord])), latestPointer]) 193 | @property 194 | def finished(self): return self.sequence[-1] == "ENDING" 195 | def children(self, outputDistribution, pointerDistribution, newHidden): 196 | if self.finished: return [self] 197 | def tokenLikelihood(token): 198 | if isinstance(token, Pointer): 199 | return outputDistribution[master.pointerIndex] + pointerDistribution[token.i] 200 | return outputDistribution[master.wordToIndex[token]] 201 | bestTokens = list(sorted([ t for t in master.lexicon if t not in ["STARTING","POINTER"] ] + \ 202 | [Pointer(i) for i in range(numberOfObjects) ], 203 | key=tokenLikelihood, reverse=True))[:B] 204 | return [Particle(newHidden, self.ll + tokenLikelihood(t), 205 | self.sequence + [t]) 206 | for t in bestTokens ] 207 | def trimmed(self): 208 | if self.sequence[-1] == "ENDING": return self.sequence[1:-1] 209 | return self.sequence[1:] 210 | 211 | particles = [Particle(initialState, 0., ["STARTING"])] 212 | if encodedObjects is not None: 213 | objectKeys = self.encoderToPointer(encodedObjects) 214 | numberOfObjects = objectKeys.size(0) 215 | else: 216 | numberOfObjects = 0 217 | 218 | for _ in range(maximumLength): 219 | unfinishedParticles = [p for p in particles if not p.finished ] 220 | inputs = torch.stack([p.input() for p in unfinishedParticles]).unsqueeze(0) 221 | if any( p.h is not None for p in unfinishedParticles ): 222 | hs = torch.stack([p.h for p in unfinishedParticles]).unsqueeze(0) 223 | else: 224 | hs = None 225 | o, h = self.model(inputs, hs) 226 | o = o.squeeze(0) 227 | h = h.squeeze(0) 228 | 229 | outputDistributions = self.output(o).detach().cpu().numpy() 230 | if encodedObjects is not None: 231 | attention = self.pointerAttention(h, None, objectKeys=objectKeys).detach().cpu().numpy() 232 | else: 233 | attention = [None]*len(unfinishedParticles) 234 | 235 | particles = [child 236 | for j,p in enumerate(unfinishedParticles) 237 | for child in p.children(outputDistributions[j], attention[j], h[j]) ] + \ 238 | [p for p in particles if p.finished ] 239 | particles.sort(key=lambda p: p.ll, reverse=True) 240 | particles = particles[:B] 241 | 242 | if all( p.finished for p in particles ): break 243 | return [(p.ll, p.trimmed()) for p in particles if p.finished] 244 | 245 | 246 | def bestFirstEnumeration(self, initialState, encodedObjects): 247 | """Given an initial hidden state of size H and the encodings of objects in scope, 248 | do a best first search and yield a stream of (log likelihood, sequence of tokens)""" 249 | if encodedObjects is not None: 250 | objectKeys = self.encoderToPointer(encodedObjects) 251 | numberOfObjects = objectKeys.size(0) 252 | else: 253 | numberOfObjects = 0 254 | 255 | class State(): 256 | def __init__(self, h, ll, sequence): 257 | self.h = h 258 | self.ll = ll 259 | self.sequence = sequence 260 | 261 | @property 262 | def finished(self): 263 | return self.sequence[-1] == "ENDING" 264 | def trimmed(self): 265 | return self.sequence[1:-1] 266 | 267 | frontier = PQ() 268 | def 
addToFrontier(s): 269 | frontier.push(s.ll, s) 270 | addToFrontier(State(initialState, 0., ["STARTING"])) 271 | 272 | while len(frontier) > 0: 273 | best = frontier.popMaximum() 274 | if best.finished: 275 | yield (best.ll, best.trimmed()) 276 | continue 277 | 278 | 279 | # Calculate the input vector 280 | lastWord = best.sequence[-1] 281 | if isinstance(lastWord, Pointer): 282 | latestPointer = encodedObjects[lastWord.i] 283 | lastWord = "POINTER" 284 | else: 285 | latestPointer = self.device(torch.zeros(self.encoderDimensionality)) 286 | i = torch.cat([self.embedding(self.tensor(self.wordToIndex[lastWord])), latestPointer]) 287 | 288 | # Run the RNN forward 289 | i = i.unsqueeze(0).unsqueeze(0) 290 | o,h = self.model(i,best.h.unsqueeze(0).unsqueeze(0) if best.h is not None else None) 291 | 292 | # incorporate successors into heap 293 | o = self.output(o.squeeze(0).squeeze(0)).cpu().detach().numpy() 294 | h = h.squeeze(0) 295 | if numberOfObjects > 0: 296 | a = self.pointerAttention(h, None, objectKeys=objectKeys).squeeze(0).cpu().detach().numpy() 297 | h = h.squeeze(0) 298 | for j,w in enumerate(self.lexicon): 299 | ll = o[j] 300 | if w == "POINTER": 301 | for objectIndex in range(numberOfObjects): 302 | pointer_ll = ll + a[objectIndex] 303 | successor = State(h, best.ll + pointer_ll, best.sequence + [Pointer(objectIndex)]) 304 | addToFrontier(successor) 305 | else: 306 | addToFrontier(State(h, best.ll + ll, best.sequence + [w])) 307 | 308 | class PointerNetwork(Module): 309 | def __init__(self, encoder, lexicon, H=256): 310 | super(PointerNetwork, self).__init__() 311 | self.encoder = encoder 312 | self.decoder = LineDecoder(lexicon, H=H) 313 | 314 | self.finalize() 315 | 316 | def gradientStep(self, optimizer, inputObjects, outputSequence, 317 | verbose=False): 318 | self.zero_grad() 319 | l = -self.decoder.logLikelihood(None, outputSequence, 320 | self.encoder(inputObjects) if inputObjects else None) 321 | l.backward() 322 | optimizer.step() 323 | if verbose: 324 | print("loss",l.data.item()) 325 | 326 | def sample(self, inputObjects): 327 | return [ inputObjects[s.i] if isinstance(s,Pointer) else s 328 | for s in self.decoder.sample(None, 329 | self.encoder(inputObjects)) ] 330 | 331 | def beam(self, inputObjects, B, maximumLength=10): 332 | return [ (ll, [ inputObjects[s.i] if isinstance(s,Pointer) else s 333 | for s in sequence ]) 334 | for ll, sequence in self.decoder.beam(None, self.encoder(inputObjects), B, 335 | maximumLength=maximumLength)] 336 | 337 | def bestFirstEnumeration(self, inputObjects): 338 | for ll, sequence in self.decoder.bestFirstEnumeration(None, self.encoder(inputObjects)): 339 | yield ll, [inputObjects[p.i] if isinstance(p, Pointer) else p 340 | for p in sequence] 341 | 342 | class ScopeEncoding(): 343 | """A cache of the encodings of objects in scope""" 344 | def __init__(self, owner, spec): 345 | """owner: a ProgramPointerNetwork that "owns" this scope encoding""" 346 | self.spec = spec 347 | self.owner = owner 348 | self.object2index = {} 349 | self.objectEncoding = None 350 | 351 | def registerObject(self, o): 352 | if o in self.object2index: return self 353 | oe = self.owner.objectEncoder(self.spec, o.execute()) 354 | if self.objectEncoding is None: 355 | self.objectEncoding = oe.view(1,-1) 356 | else: 357 | self.objectEncoding = torch.cat([self.objectEncoding, oe.view(1,-1)]) 358 | self.object2index[o] = len(self.object2index) 359 | return self 360 | 361 | def registerObjects(self, os): 362 | os = [o for o in os if o not in self.object2index ] 363 | if 
len(os) == 0: return self 364 | encodings = self.owner.objectEncoder(self.spec, [o.execute() for o in os]) 365 | if self.objectEncoding is None: 366 | self.objectEncoding = encodings 367 | else: 368 | self.objectEncoding = torch.cat([self.objectEncoding, encodings]) 369 | for o in os: 370 | self.object2index[o] = len(self.object2index) 371 | return self 372 | 373 | def encoding(self, objects): 374 | """Takes as input O objects (as a list) and returns a OxE tensor of their encodings. 375 | If the owner has a self attention module, also applies the attention module. 376 | If objects is the empty list then return None""" 377 | if len(objects) == 0: return None 378 | self.registerObjects(objects) 379 | preAttention = self.objectEncoding[self.owner.device(torch.tensor([self.object2index[o] 380 | for o in objects ]))] 381 | return self.owner.selfAttention(preAttention) 382 | 383 | class ProgramPointerNetwork(Module): 384 | """A network that looks at the objects in a ProgramGraph and then predicts what to add to the graph""" 385 | def __init__(self, objectEncoder, specEncoder, DSL, oneParent=False, 386 | H=256, attentionRounds=1, heads=4): 387 | """ 388 | specEncoder: Module that encodes spec to initial hidden state of RNN 389 | objectEncoder: Module that encodes (spec, object) to features we attend over 390 | oneParent: Whether each node in the program graph is constrained to have no more than one parent 391 | """ 392 | super(ProgramPointerNetwork, self).__init__() 393 | 394 | self.DSL = DSL 395 | self.oneParent = oneParent 396 | self.objectEncoder = objectEncoder 397 | self.specEncoder = specEncoder 398 | self.decoder = LineDecoder(DSL.lexicon + ["RETURN"], 399 | encoderDimensionality=H, # self attention outputs size H 400 | H=H) 401 | self._initialHidden = nn.Sequential( 402 | nn.Linear(H + specEncoder.outputDimensionality, H), 403 | nn.ReLU()) 404 | 405 | self._distance = nn.Sequential( 406 | nn.Linear(H + specEncoder.outputDimensionality, H), 407 | nn.ReLU(), 408 | nn.Linear(H, 1), 409 | nn.ReLU()) 410 | 411 | self.selfAttention = nn.Sequential( 412 | nn.Linear(objectEncoder.outputDimensionality, H), 413 | MultiHeadAttention(heads, H, rounds=attentionRounds, residual=True)) 414 | 415 | self.H = H 416 | 417 | self.finalize() 418 | 419 | def initialHidden(self, objectEncodings, specEncoding): 420 | if objectEncodings is None: 421 | objectEncodings = self.device(torch.zeros(self.H)) 422 | else: 423 | objectEncodings = objectEncodings.sum(0) 424 | return self._initialHidden(torch.cat([specEncoding, objectEncodings])) 425 | 426 | def distance(self, objectEncodings, specEncoding): 427 | """Returns a 1-dimensional tensor which should be the sum of (# objects to create) + (# spurious objects created)""" 428 | if objectEncodings is None: 429 | objectEncodings = self.device(torch.zeros(self.H)) 430 | else: 431 | objectEncodings = objectEncodings.sum(0) 432 | 433 | return self._distance(torch.cat([specEncoding, objectEncodings])) 434 | 435 | def traceLogLikelihood(self, spec, trace, scopeEncoding=None): 436 | scopeEncoding = scopeEncoding or ScopeEncoding(self, spec).registerObjects(set(trace)) 437 | currentGraph = ProgramGraph([]) 438 | specEncoding = self.specEncoder(spec) 439 | lls = [] 440 | for obj in trace + [['RETURN']]: 441 | finalMove = obj == ['RETURN'] 442 | 443 | # Gather together objects in scope 444 | objectsInScope = list(currentGraph.objects(oneParent=self.oneParent)) 445 | scope = scopeEncoding.encoding(objectsInScope) 446 | object2pointer = {o: Pointer(i) 447 | for i, o in 
enumerate(objectsInScope)} 448 | 449 | h0 = self.initialHidden(scope, specEncoding) 450 | def substitutePointers(serialization): 451 | return [object2pointer.get(token, token) 452 | for token in serialization] 453 | lls.append(self.decoder.logLikelihood(h0, 454 | substitutePointers(obj.serialize()) if not finalMove else obj, 455 | scope)) 456 | if not finalMove: 457 | currentGraph = currentGraph.extend(obj) 458 | return sum(lls), lls 459 | 460 | def gradientStepTrace(self, optimizer, spec, trace): 461 | """Returns [policy losses]""" 462 | self.zero_grad() 463 | 464 | ll, lls = self.traceLogLikelihood(spec, trace) 465 | 466 | (-ll).backward() 467 | optimizer.step() 468 | return [-l.data.item() for l in lls] 469 | 470 | def sample(self, spec, maxMoves=None): 471 | specEncoding = self.specEncoder(spec) 472 | objectEncodings = ScopeEncoding(self, spec) 473 | 474 | graph = ProgramGraph([]) 475 | 476 | while True: 477 | # Make the encoding matrix 478 | objectsInScope = list(graph.objects(oneParent=self.oneParent)) 479 | oe = objectEncodings.encoding(objectsInScope) 480 | h0 = self.initialHidden(oe, specEncoding) 481 | 482 | nextLineOfCode = self.decoder.sample(h0, oe) 483 | if nextLineOfCode is None: return None 484 | nextLineOfCode = [objectsInScope[t.i] if isinstance(t, Pointer) else t 485 | for t in nextLineOfCode ] 486 | 487 | if 'RETURN' in nextLineOfCode or (maxMoves is not None and len(graph) >= maxMoves): return graph 488 | 489 | nextObject = self.DSL.parseLine(nextLineOfCode) 490 | if nextObject is None: return None 491 | 492 | graph = graph.extend(nextObject) 493 | 494 | def repeatedlySample(self, specEncoding, graph, objectEncodings, n_samples): 495 | """Repeatedly samples a single line of code. 496 | specEncoding: Encoding of the spec 497 | objectEncodings: a ScopeEncoding 498 | graph: the current graph 499 | n_samples: how many samples to draw 500 | returns: list of sampled DSL objects. If the sample is `RETURN` then that entry in the list is None. 501 | """ 502 | objectsInScope = list(graph.objects(oneParent=self.oneParent)) 503 | oe = objectEncodings.encoding(objectsInScope) 504 | h0 = self.initialHidden(oe, specEncoding) 505 | 506 | samples = [] 507 | for _ in range(n_samples): 508 | nextLineOfCode = self.decoder.sample(h0, oe) 509 | if nextLineOfCode is None: continue 510 | nextLineOfCode = [objectsInScope[t.i] if isinstance(t, Pointer) else t 511 | for t in nextLineOfCode ] 512 | if 'RETURN' in nextLineOfCode: 513 | samples.append(None) 514 | else: 515 | nextObject = self.DSL.parseLine(nextLineOfCode) 516 | if nextObject is not None: 517 | samples.append(nextObject) 518 | 519 | return samples 520 | 521 | 
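# A hedged usage sketch for the proposers in this class (doctest-style; `model`
# and `spec` are hypothetical, with `model` a trained ProgramPointerNetwork):
# >>> se = model.specEncoder(spec)
# >>> oe = ScopeEncoding(model, spec)
# >>> proposals = model.repeatedlySample(se, ProgramGraph([]), oe, 10)
# `proposals` mixes parsed DSL objects with None entries, one per sampled RETURN.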
522 | def beamNextLine(self, specEncoding, graph, objectEncodings, B): 523 | """Does a beam search for a single line of code. 524 | specEncoding: Encoding of the spec 525 | objectEncodings: a ScopeEncoding 526 | graph: the current graph 527 | B: beam size 528 | returns: list of (at most B) beamed (DSL object, log likelihood). None denotes `RETURN` 529 | """ 530 | objectsInScope = list(graph.objects(oneParent=self.oneParent)) 531 | oe = objectEncodings.encoding(objectsInScope) 532 | h0 = self.initialHidden(oe, specEncoding) 533 | lines = [] 534 | for ll, tokens in self.decoder.beam(h0, oe, B, maximumLength=10): 535 | tokens = [objectsInScope[t.i] if isinstance(t, Pointer) else t 536 | for t in tokens] 537 | if 'RETURN' in tokens: 538 | lines.append((None, ll)) 539 | else: 540 | line = self.DSL.parseLine(tokens) 541 | if line is None: continue 542 | lines.append((line, ll)) 543 | return lines 544 | 545 | def bestFirstEnumeration(self, specEncoding, graph, objectEncodings): 546 | """Does a best first search for a single line of code. 547 | specEncoding: Encoding of the spec 548 | objectEncodings: a ScopeEncoding 549 | graph: current graph 550 | yields: stream of (DSL object, log likelihood). None denotes `RETURN`""" 551 | objectsInScope = list(graph.objects(oneParent=self.oneParent)) 552 | oe = objectEncodings.encoding(objectsInScope) 553 | h0 = self.initialHidden(oe, specEncoding) 554 | for ll, tokens in self.decoder.bestFirstEnumeration(h0, oe): 555 | tokens = [objectsInScope[t.i] if isinstance(t, Pointer) else t 556 | for t in tokens] 557 | if 'RETURN' in tokens: # (was `'RETURN' in tokens and len(tokens) == 0`, which can never hold; matches beamNextLine above) 558 | yield (None, ll) 559 | else: 560 | line = self.DSL.parseLine(tokens) 561 | if line is None: continue 562 | yield (line, ll) 563 | 564 | if __name__ == "__main__": 565 | m = PointerNetwork(SymbolEncoder([str(n) for n in range(10) ]), ["large","small"]) 566 | optimizer = torch.optim.Adam(m.parameters(), lr=0.001, eps=1e-3, amsgrad=True) 567 | for n in range(90000): 568 | x = str(random.choice(range(10))) 569 | y = str(random.choice(range(10))) 570 | if x == y: continue 571 | large = max(x,y) 572 | small = min(x,y) 573 | if random.choice([False,True]): 574 | sequence = ["large", Pointer(int(large == y)), Pointer(int(large == y)), 575 | "small", Pointer(int(small == y))] 576 | else: 577 | sequence = ["small", Pointer(int(small == y)), 578 | "large", Pointer(int(large == y))] 579 | verbose = n%50 == 0 580 | if random.choice([False,True]): 581 | m.gradientStep(optimizer, [x,y], sequence, verbose=verbose) 582 | else: 583 | m.gradientStep(optimizer, [], ["small","small"], verbose=verbose) 584 | if verbose: 585 | print([x,y],"goes to",m.sample([x,y])) 586 | print([x,y],"beams into:") 587 | for ll, s in m.beam([x,y],10): 588 | print(f"{s}\t(w/ ll={ll})") 589 | print() 590 | print([x,y],"best first into") 591 | lines = 0 592 | for ll, s in m.bestFirstEnumeration([x,y]): 593 | print(f"{s}\t(w/ ll={ll})") 594 | lines += 1 595 | if lines > 5: break 596 | 597 | print() 598 | 599 | -------------------------------------------------------------------------------- /programGraph.py: -------------------------------------------------------------------------------- 1 | from API import * 2 | from utilities import * 3 | 4 | class ProgramGraph: 5 | """A program graph is a state in the search space""" 6 | def __init__(self, nodes): 7 | self.nodes = nodes if isinstance(nodes, tuple) else tuple(nodes) 8 | 9 | @staticmethod 10 | def fromRoot(r, oneParent=False): 11 | if not oneParent: 12 | ns = set() 13 | def reachable(n): 14 | if n in ns: return 15 | ns.add(n) 16 | for c in n.children(): 17 | reachable(c) 18 | reachable(r) 19 | return ProgramGraph(ns) 20 | else: 21 | ns = [] 22 | def visit(n): 23 | ns.append(n) 24 | for c in n.children(): visit(c) 25 | visit(r) 26 | return ProgramGraph(ns) 27 | 28 | 29 | def __len__(self): return len(self.nodes) 30 | 
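# A hedged sketch of prettyPrint's output (doctest-style; the CSG operators are
# those defined in CAD.py). Children are always assigned variables before their
# parents:
# >>> g = ProgramGraph.fromRoot(Translation(2, 3, Rectangle(4, 5)))
# >>> print(g.prettyPrint())
# $0 <- (r 4 5)
# $1 <- (t 2 3 $0)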
31 | def prettyPrint(self): 32 | variableOfNode = [None for _ in self.nodes] 33 | nameOfNode = [None for _ in self.nodes] # pp of node 34 | 35 | lines = [] 36 | 37 | def getIndex(p): 38 | for i, pp in enumerate(self.nodes): 39 | if p is pp: return i 40 | assert False 41 | 42 | def pp(j): 43 | if variableOfNode[j] is not None: return variableOfNode[j] 44 | serialization = [t if not isinstance(t,Program) else pp(getIndex(t)) 45 | for t in self.nodes[j].serialize()] 46 | expression = f"({' '.join(map(str, serialization))})" 47 | variableOfNode[j] = f"${len(lines)}" 48 | lines.append(f"{variableOfNode[j]} <- {expression}") 49 | return variableOfNode[j] 50 | 51 | for j in range(len(self.nodes)): 52 | pp(j) 53 | return "\n".join(lines) 54 | 55 | def extend(self, newNode): 56 | return ProgramGraph([newNode] + list(self.nodes)) 57 | 58 | def objects(self, oneParent=False): 59 | return [o for o in self.nodes 60 | if not oneParent or not any( any( c is o for c in op.children() ) for op in self.nodes )] 61 | -------------------------------------------------------------------------------- /randomSolver.py: -------------------------------------------------------------------------------- 1 | from API import * 2 | from programGraph import * 3 | 4 | import time 5 | import random 6 | 7 | class RandomSolver(Solver): 8 | def __init__(self, DSL): 9 | self.DSL = DSL 10 | 11 | def _infer(self, spec, loss, timeout): 12 | t0 = time.time() 13 | 14 | g = ProgramGraph([]) 15 | 16 | def getArgument(requestedType): 17 | if requestedType.isInteger: 18 | return random.choice(range(requestedType.lower, requestedType.upper + 1)) 19 | 20 | choices = [o for o in g.objects() if requestedType.instance(o)] 21 | if choices: return random.choice(choices) 22 | else: return None 23 | 24 | while time.time() - t0 < timeout: 25 | 26 | # Pick a random DSL production 27 | operator = random.choice(self.DSL.operators) 28 | tp = operator.type 29 | 30 | if not tp.isArrow: 31 | object = operator() 32 | else: 33 | # Sample random arguments 34 | arguments = [getArgument(t) for t in tp.arguments] 35 | if any( a is None for a in arguments ): continue 36 | try: 37 | object = operator(*arguments) 38 | except ParseFailure: continue # Program constructors signal bad arguments with ParseFailure 39 | 40 | if object not in g.objects(): 41 | g = g.extend(object) 42 | self._report(ProgramGraph.fromRoot(object)) 43 | 44 | -------------------------------------------------------------------------------- /utilities.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | NEGATIVEINFINITY = float('-inf') # referenced by pointerNetwork's attention masking 3 | import torch 4 | import torch.nn as nn 5 | 6 | class Module(nn.Module): 7 | """Wrapper over torch Module class that handles GPUs elegantly""" 8 | def __init__(self): 9 | super(Module, self).__init__() 10 | self.use_cuda = torch.cuda.is_available() 11 | 12 | def tensor(self, array): 13 | return self.device(torch.tensor(array)) 14 | def device(self, t): 15 | if self.use_cuda: return t.cuda() 16 | else: return t 17 | def finalize(self): 18 | if self.use_cuda: self.cuda() 19 | 20 | 21 | class PQ(object): 22 | """A priority queue, a.k.a. a max-heap. 23 | Wraps Python's heapq, which only offers module-level functions, in a class.""" 24 | 25 | def __init__(self): 26 | self.h = [] 27 | self.index2value = {} 28 | self.nextIndex = 0 29 | 30 | def push(self, priority, v): 31 | self.index2value[self.nextIndex] = v 32 | heapq.heappush(self.h, (-priority, self.nextIndex)) 33 | self.nextIndex += 1 34 | 35 | def popMaximum(self): 36 | i = heapq.heappop(self.h)[1] 37 | v = self.index2value[i] 38 | del self.index2value[i] 39 | return v 40 | 41 | def __iter__(self): 42 | for _, v in self.h: 43 | yield self.index2value[v] 44 | 45 | def __len__(self): return len(self.h) 46 | --------------------------------------------------------------------------------