├── README.md ├── ast_pretty_printer.py ├── ast_trampoline.py ├── stack_heights.py ├── tailbytes.py ├── test_rec_fact.py ├── trampoline.py ├── trampolines_in_Python.ipynb ├── translator.py └── translator_tests.py /README.md: -------------------------------------------------------------------------------- 1 | Recursion. 2 | ================== 3 | This project explores a variety of ways to optimize recursion in Python. 4 | So far, the following approaches have been explored: 5 | Trampolines. 6 | ----------------- 7 | TCO using Python standard means. Taking a tail-recursive function that returns a thunk in place of a recursive call, the trampoline function keeps calling the thunked result until a non-callable value is reached. 8 | >Associated files:
9 | >trampoline.py - all source code for trampolines
10 | >trampolines_in_Python.ipynb 11 | 12 | AST trampolines. 13 | -------------------------------- 14 | TCO by manipulating AST of the function (with Paul Tagliamonte). 15 | The ast of the tail-recursive function is altered so that all returns are "thunked", a trampolining decorator is injected into the tree. 16 | >Associated files:
17 | >ast_trampoline.py 18 | 19 | Tailbytes. 20 | --------- 21 | TCO using direct bytecode manipulation (with Allison Kaptur). 22 | By swapping the bytecode representation of a tail-recursive function, the recursive call is substituted by resetting the variables and jumping to the beginning of the function. 23 | Problems we've faced and resolved so far: 24 | 25 | * _Problem:_ Deleting and inserting bytes into the bytecode changes the location of the original bytes. 26 | _Solution:_ Absolute jumps are updated after bytecode alteration has been performed. 27 | * _Problem:_ Our initial algorithm was removing all calls to *any* function. This caused a problem if the function involves non-recursive function calls. 28 | _Solution:_ Remove a CALL_FUNCTION instruction only if it is a part of the recursive call (to determine that, we keep track of the stack size). We target the CALL_FUNCTION instructions which return the stack to the same size that it had when the recursive function was loaded onto it. 29 | 30 | >Associated files:
31 | >tailbytes.py
32 | >stack_heights.py
33 | >See also [presentation slides](http://www.slideshare.net/lnikolaeva/tailbytes-pygotham) 34 | 35 | Transform regular recursion into tail recursion. 36 | ------------------------------- 37 | Interpreter that reads an ast representation of a regular recursion and mutates it to a tree that compiles to the same function using tail recursion. 38 | >Associated files:
39 | >translator.py 40 | 41 | Further goals: 42 | -------- 43 | * Transform binary recursion (i.e. fibonacci-like functions) into tail-recursion. 44 | * Find a way to optimize mutual recursion. 45 | 46 | Tools: 47 | ----------- 48 | * ast_pretty_printer.py
# README (Tools): print_ast takes an ast tree and prettyprints it, working off of ast.dump
# --------------------------------------------------------------------------------
# /ast_pretty_printer.py:
# --------------------------------------------------------------------------------
"""Helpers for dumping and pretty-printing Python ASTs."""
import ast
import inspect


class AltDump(ast.NodeVisitor):
    """Hand-written counterpart of ``ast.dump``: ``visit`` returns the dump string."""

    def visit(self, node):
        """Return ``Type(field=value, ...)`` for *node*, recursing into children."""
        field_values = [
            field + '=' + self._format(value)
            for field, value in ast.iter_fields(node)
        ]
        # Children are rendered exactly once, by _format.  The earlier extra
        # generic_visit() call re-rendered every subtree and threw the result
        # away, which made the cost exponential in the tree depth.
        return type(node).__name__ + '(' + ', '.join(field_values) + ')'

    def _format(self, value):
        """Render one field value: a list, a child node, a string or a scalar."""
        if isinstance(value, list):
            # List items are not always nodes (e.g. Global.names holds
            # strings), so each item goes through _format, not visit.
            return '[' + ', '.join([self._format(item) for item in value]) + ']'
        if isinstance(value, ast.AST):
            return self.visit(value)
        if isinstance(value, str):
            return "'{}'".format(value)
        return str(value)


class Pretty(ast.NodeVisitor):
    """Experimental indenting dumper; pieces of output accumulate in ``self.out``.

    NOTE(review): this class looks unfinished (it still prints debugging
    output and its format is inconsistent); it is kept as it was apart from
    a guard for list items that are not AST nodes.
    """

    def __init__(self, filler=" "):
        self.res = ''
        self.level = -1
        self.filler = filler
        self.out = []

    @property
    def prefix(self):
        """Indentation string for the current nesting level."""
        return self.level * self.filler

    def visit(self, node):
        self.level += 1
        print(node)
        self.out.append(type(node).__name__ + '(')
        self.level += 1
        for field, value in ast.iter_fields(node):
            print(field, value)
            self.out.append(self.prefix + field)
            if isinstance(value, list) and len(value) > 0:
                self.out.append('=[\n')
                for n in value:
                    if isinstance(n, ast.AST):
                        self.visit(n)
                    else:
                        self.out.append(self.prefix + str(n) + '\n')
                self.out.append('=\n')
            elif isinstance(value, ast.AST):
                self.out.append('=')
                self.generic_visit(value)
            else:
                self.out.append('=' + str(value) + '\n')
        self.level -= 2

    def visit_Name(self, node):
        # Not reached through visit(), which does no per-type dispatch.
        print("I'm here!")
        self.level += 1
        self.out.append(self.prefix + ast.dump(node))
        self.level -= 1


class PrettyPrinter(ast.NodeVisitor):
    """Print an indented outline of a tree, one node per line, to stdout."""

    def __init__(self):
        self.res = ''
        self.level = -1

    def generic_visit(self, node):
        """Fallback: print the node type, then walk its children."""
        print(self.prefix + type(node).__name__)
        ast.NodeVisitor.generic_visit(self, node)

    @property
    def prefix(self):
        """One tab per nesting level."""
        return self.level * '\t'

    def visit(self, node):
        """Dispatch to ``visit_<Type>`` (or ``generic_visit``) one level deeper."""
        method = getattr(
            self, "visit_{}".format(type(node).__name__), self.generic_visit
        )
        self.level += 1
        try:
            method(node)
        finally:
            # Restore the level even when a handler raises.
            self.level -= 1

    def visit_FunctionDef(self, node):
        print(self.prefix + type(node).__name__ + ': ' + node.name)
        ast.NodeVisitor.generic_visit(self, node)

    def visit_arguments(self, node):
        print(self.prefix + type(node).__name__ + ': ')
        for field in node._fields:
            value = getattr(node, field)
            if not value:
                continue
            print(self.prefix + field + ': ')
            # vararg/kwarg are a single node (Python 3) or a plain string
            # (Python 2), not a list, so normalise before iterating.
            items = value if isinstance(value, list) else [value]
            for item in items:
                if isinstance(item, ast.AST):
                    self.visit(item)
                elif item is not None:
                    print(self.prefix + '\t' + str(item))

    def visit_arg(self, node):
        print(self.prefix + node.arg)

    def visit_Name(self, node):
        print(self.prefix + node.id)

    def visit_Num(self, node):
        print(self.prefix + 'Num:' + str(node.n))

    def visit_Str(self, node):
        print(self.prefix + 'Str:' + node.s)

    def visit_Constant(self, node):
        # Python 3.8+ parses literals as Constant rather than Num/Str.
        print(self.prefix + 'Constant:' + repr(node.value))

    def visit_Print(self, node):
        print(self.prefix + 'Print:')
        ast.NodeVisitor.generic_visit(self, node)

    def visit_Assign(self, node):
        print(self.prefix + 'Assign:')
        # generic_visit, not visit: NodeVisitor.visit would dispatch straight
        # back to visit_Assign and recurse forever.
        ast.NodeVisitor.generic_visit(self, node)

    def visit_Expr(self, node):
        print(self.prefix + 'Expr:')
        ast.NodeVisitor.generic_visit(self, node)


def polish(dump, prefix='\t'):
    """Re-indent an ``ast.dump`` string, one field or list item per line.

    Brackets raise the nesting level, commas start a new line at the current
    level, and spaces are dropped.  Text inside quoted string literals is
    copied verbatim, so brackets, commas and spaces in strings are not
    treated as structure.
    """
    closers = {'(': ')', '[': ']'}
    pieces = []
    level = 0
    quote = None      # the quote character while inside a string literal
    escaped = False   # True right after a backslash inside a literal
    for i, ch in enumerate(dump):
        if quote is not None:
            pieces.append(ch)
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == quote:
                quote = None
        elif ch in ('"', "'"):
            quote = ch
            pieces.append(ch)
        elif ch in closers:
            level += 1
            pieces.append(ch)
            # Keep empty "()" / "[]" on one line; slicing also avoids
            # indexing past the end of the string.
            if dump[i + 1:i + 2] != closers[ch]:
                pieces.append('\n' + prefix * level)
        elif ch == ',':
            pieces.append(',\n' + prefix * level)
        elif ch == ')' or ch == ']':
            level -= 1
            pieces.append(ch)
        elif ch != ' ':
            pieces.append(ch)
    return ''.join(pieces)


def print_ast(tree, prefix=' '):
    """Print *tree* as an indented ``ast.dump``."""
    print(polish(ast.dump(tree), prefix))


def tail_fact(n, accum=1):
    """Tail-recursive factorial, used as sample input."""
    if n <= 1:
        return accum
    else:
        return tail_fact(n - 1, accum * n)


def rec_fact(n):
    """Plain recursive factorial, used as sample input."""
    if n <= 1:
        return 1
    else:
        return n * rec_fact(n - 1)


if __name__ == '__main__':
    # Demo: compare AltDump with ast.dump.  Guarded so that importing
    # print_ast (as translator.py does) no longer prints these dumps.
    tree = ast.parse(inspect.getsource(rec_fact))
    a = AltDump()
    my_dump = a.visit(tree)
    py_dump = ast.dump(tree)
    print(my_dump)
    print(py_dump)
    # Other experiments:
    # print_ast(tree, ' ')
    # v = PrettyPrinter()
    # v.visit(tree)
# --------------------------------------------------------------------------------
# /ast_trampoline.py:
# --------------------------------------------------------------------------------
"""Run a Python file with its top-level functions trampolined.

Usage: python ast_trampoline.py <source-file>

Every ``return <expr>`` in a top-level function is rewritten to
``return lambda: <expr>`` and the function is wrapped in ``tramp_deco``,
which keeps calling the returned thunks until a non-callable value appears.
"""
import ast
import inspect
import sys
import types


def tramp_deco(f):
    def trampo(*args, **kwargs):
        res = f(*args, **kwargs)
        while callable(res):
            res = res()
        return res
    return trampo


def _empty_arguments():
    """Return a fresh ``arguments`` node for a zero-parameter lambda.

    Parsing a real lambda keeps this correct on every Python version; the
    fields of ``ast.arguments`` differ between releases.
    """
    return ast.parse("lambda: None", mode="eval").body.args


class NameChanger(ast.NodeTransformer):
    """Rename every reference to ``name`` into ``_name``."""

    def __init__(self, name):
        self.name = name

    def visit_Name(self, el):
        if el.id == self.name:
            el.id = "_{}".format(self.name)
        return el


class AutoLambda(ast.NodeTransformer):
    """Rename top-level functions to ``_name`` and thunk their returns.

    Nested functions, methods and classes are left untouched: only direct
    children of the module later get a ``name = tramp_deco(_name)`` alias,
    so renaming anything else would leave dangling references.
    """

    _depth = 0  # > 0 while inside a function that is being rewritten

    def visit_Module(self, el):
        for statement in el.body:
            if isinstance(statement, ast.FunctionDef):
                self.visit(statement)
        return el

    def visit_FunctionDef(self, el):
        if self._depth:
            return el  # nested def: keep as written
        NameChanger(el.name).visit(el)
        el.name = "_{}".format(el.name)
        self._depth += 1
        try:
            self.generic_visit(el)
        finally:
            self._depth -= 1
        return el

    def visit_ClassDef(self, el):
        return el  # do not descend: methods are not trampolined

    visit_AsyncFunctionDef = visit_ClassDef

    def visit_Return(self, el):
        # A bare ``return`` has no expression to delay; wrapping None would
        # produce an invalid Lambda node.
        if el.value is not None:
            el.value = ast.Lambda(args=_empty_arguments(), body=el.value)
        return el


def trampoline_tree(tree):
    """Rewrite the module *tree* in place and return it, ready for ``compile``.

    ``tramp_deco`` is injected after any module docstring and ``__future__``
    imports (both must stay first), and each rewritten function ``_name`` is
    followed by ``name = tramp_deco(_name)``.
    """
    AutoLambda().visit(tree)
    tramp_ast = ast.parse(inspect.getsource(tramp_deco)).body[0]

    start = 1 if ast.get_docstring(tree) is not None else 0
    while (start < len(tree.body)
           and isinstance(tree.body[start], ast.ImportFrom)
           and tree.body[start].module == "__future__"):
        start += 1

    newbody = tree.body[:start] + [tramp_ast]
    for statement in tree.body[start:]:
        newbody.append(statement)
        if isinstance(statement, ast.FunctionDef):
            name = statement.name
            if name.startswith("_"):
                newbody.append(ast.parse(
                    "{} = tramp_deco({})".format(name[1:], name)
                ).body[0])
    tree.body = newbody
    return ast.fix_missing_locations(tree)


def run_file(filename):
    """Trampoline *filename*, execute it, and return the module namespace."""
    with open(filename, 'r') as handle:
        tree = ast.parse(handle.read(), filename)
    code = compile(trampoline_tree(tree), filename, "exec")
    namespace = types.ModuleType(filename)  # same object imp.new_module built
    exec(code, namespace.__dict__)
    return namespace


def main(argv=None):
    """Command-line entry point: expects exactly one argument, the file to run."""
    (filename,) = sys.argv[1:] if argv is None else argv
    run_file(filename)


if __name__ == "__main__":
    main()
# --------------------------------------------------------------------------------
# /stack_heights.py:
# --------------------------------------------------------------------------------
# Net change in value-stack height caused by each (Python 2) opcode.
# 'bad' marks opcodes whose effect is not a fixed number (it depends on the
# argument or on run-time state).
height_map = {
    'STOP_CODE': 0,
    'POP_TOP': -1,
    'ROT_TWO': 0,
    'ROT_THREE': 0,
    'DUP_TOP': 1,
    'ROT_FOUR': 0,

    'NOP': 0,
    'UNARY_POSITIVE': 0,
    'UNARY_NEGATIVE': 0,
    'UNARY_NOT': 0,
    'UNARY_CONVERT': 0,
    'UNARY_INVERT': 0,

    'BINARY_POWER': -1,
    'BINARY_MULTIPLY': -1,
    'BINARY_DIVIDE': -1,
    'BINARY_MODULO': -1,
    'BINARY_ADD': -1,
    'BINARY_SUBTRACT': -1,
    'BINARY_SUBSCR': -1,
    'BINARY_FLOOR_DIVIDE': -1,
    'BINARY_TRUE_DIVIDE': -1,
    'INPLACE_FLOOR_DIVIDE': -1,
    'INPLACE_TRUE_DIVIDE': -1,
    # 'SLICE+0':
    # 'SLICE+1':
    # 'SLICE+2':
    # 'SLICE+3':

    # 'STORE_SLICE+0':
    # 'STORE_SLICE+1':
    # 'STORE_SLICE+2':
    # 'STORE_SLICE+3':

    # 'DELETE_SLICE+0':
    # 'DELETE_SLICE+1':
    # 'DELETE_SLICE+2':
    # 'DELETE_SLICE+3':

    'STORE_MAP': -2,
    'INPLACE_ADD': -1,
    'INPLACE_SUBTRACT': -1,
    'INPLACE_MULTIPLY': -1,
    'INPLACE_DIVIDE': -1,
    'INPLACE_MODULO': -1,
    'STORE_SUBSCR': -3,
    'DELETE_SUBSCR': -2,
    'BINARY_LSHIFT': -1,
    'BINARY_RSHIFT': -1,
    'BINARY_AND': -1,
    'BINARY_XOR': -1,
    'BINARY_OR': -1,
    'INPLACE_POWER': -1,
    'GET_ITER': 0,

    'PRINT_EXPR': -1,
    'PRINT_ITEM': -1,
    'PRINT_NEWLINE': 0,
    'PRINT_ITEM_TO': -2,
    'PRINT_NEWLINE_TO': -1,
    'INPLACE_LSHIFT': -1,
    'INPLACE_RSHIFT': -1,
    'INPLACE_AND': -1,
    'INPLACE_XOR': -1,
    'INPLACE_OR': -1,
    'BREAK_LOOP': 0,
    'WITH_CLEANUP': 'bad',  # cannot say, nondeterministic
    'LOAD_LOCALS': 1,
    'RETURN_VALUE': -1,
    'IMPORT_STAR': -1,
    'EXEC_STMT': -3,
    'YIELD_VALUE': -1,
    'POP_BLOCK': 0,
    'END_FINALLY': 'bad',  # cannot say, nondeterministic
    'BUILD_CLASS': -3,
    'STORE_NAME': -1,
    'DELETE_NAME': 0,
    'UNPACK_SEQUENCE': 'bad',  # argument - 1
    'FOR_ITER': 'bad',  # either +1 during loop or -1 when loop is finished
    'LIST_APPEND': -1,
    'STORE_ATTR': -2,
    'DELETE_ATTR': -1,
    'STORE_GLOBAL': -1,
    'DELETE_GLOBAL': 0,
    'DUP_TOPX': 'bad',  # +x
    'LOAD_CONST': +1,
    'LOAD_NAME': +1,
    'BUILD_TUPLE': 'bad',  # argument -1
    'BUILD_LIST': 'bad',  # argument -1
    'BUILD_SET': 'bad',  # argument -1
    'BUILD_MAP': +1,
    'LOAD_ATTR': 0,
    'COMPARE_OP': -1,
    'IMPORT_NAME': -1,
    'IMPORT_FROM': +1,
    'JUMP_FORWARD': 0,
    'JUMP_IF_FALSE_OR_POP': -1,
    'JUMP_IF_TRUE_OR_POP': -1,
    'JUMP_ABSOLUTE': 0,
    'POP_JUMP_IF_FALSE': -1,
    'POP_JUMP_IF_TRUE': -1,

    'LOAD_GLOBAL': +1,

    'CONTINUE_LOOP': 0,
    'SETUP_LOOP': 0,
    'SETUP_EXCEPT': 0,
    'SETUP_FINALLY': 0,

    'LOAD_FAST': +1,
    'STORE_FAST': -1,
    'DELETE_FAST': 0,

    'RAISE_VARARGS': 'bad',  # varargs!
    'CALL_FUNCTION': 'bad',  # -(argument)
    'MAKE_FUNCTION': 'bad',  # argument + 1
    # 'BUILD_SLICE':
    # 'MAKE_CLOSURE':
    # 'LOAD_CLOSURE':
    # 'LOAD_DEREF':
    # 'STORE_DEREF':

    # 'CALL_FUNCTION_VAR':
    # 'CALL_FUNCTION_KW':
    # 'CALL_FUNCTION_VAR_KW':

    # 'SETUP_WITH':

    # 'EXTENDED_ARG':
    # 'SET_ADD':
    # 'MAP_ADD':
}
# --------------------------------------------------------------------------------
# /tailbytes.py:
# --------------------------------------------------------------------------------
from types import CodeType
import opcode
import dis
from stack_heights import height_map
from opcode import opname, opmap

# NOTE: this module targets the Python 2 bytecode format (one opcode byte,
# optionally followed by a two-byte little-endian argument).


def make_tail_recursive(fn):
    """Decorator: swap *fn*'s bytecode for a version that loops instead of recursing."""
    old_code = fn.__code__
    bytecode = advanced_recurse(fn)
    c = CodeType(old_code.co_argcount, old_code.co_nlocals, old_code.co_stacksize,
                 old_code.co_flags, bytecode, old_code.co_consts, old_code.co_names,
                 old_code.co_varnames, old_code.co_filename, old_code.co_name,
                 old_code.co_firstlineno, old_code.co_lnotab)
    fn.__code__ = c
    return fn


def _encode(byte, arg):
    """Return one instruction as a list of bytes."""
    encoded = [byte]
    if arg is not None:
        encoded += split(arg)
    return encoded


def _restart_sequence(argc):
    """Bytes that store the *argc* stacked call arguments back into the
    parameters (last argument first) and jump to the start of the function."""
    rewritten = []
    for i in range(argc):
        rewritten.append(opmap["STORE_FAST"])
        rewritten += split(argc - i - 1)
    rewritten.append(opmap["JUMP_ABSOLUTE"])
    # Target 0 is the first instruction; the retargeting pass leaves it
    # there because nothing is inserted ahead of offset 0.
    rewritten += split(0)
    return rewritten


def _retarget_jumps(code_obj, new_bytecode, jump_list):
    """Shift absolute jump targets by the number of bytes inserted before them.

    ``jump_list[k]`` is the displacement that applies to original offset
    ``k``.  Returns the final bytecode as a string.
    """
    assert len(jump_list) == len(code_obj.co_code)

    newer_bytecode = []
    for byte, arg in consume(new_bytecode):
        if byte in opcode.hasjabs:
            arg = arg + jump_list[arg]
        newer_bytecode += _encode(byte, arg)
    return "".join([chr(b) for b in newer_bytecode])


def tail_recurse(fn):
    """First version of the rewrite: treats every CALL_FUNCTION as the recursive call.

    Kept for reference; see advanced_recurse for the variant that tells the
    recursive call apart from other calls.  Returns the new bytecode string
    (it used to build the result and then drop it).
    """
    new_bytecode = []
    jump_displacement = 0
    jump_list = []
    code_obj = fn.__code__
    for byte, arg in consume(code_obj.co_code):
        name = opname[byte]
        if name == "LOAD_GLOBAL" and code_obj.co_names[arg] == fn.__name__:
            new_bytecode += [opmap["NOP"]] * 3
        elif name == "CALL_FUNCTION":
            new_bytecode += _restart_sequence(arg)
            jump_displacement += 3 * arg
        else:
            new_bytecode += _encode(byte, arg)
        # One entry per original byte of this instruction.
        jump_list += [jump_displacement] * (1 if arg is None else 3)
    return _retarget_jumps(code_obj, new_bytecode, jump_list)


def advanced_recurse(fn):
    """Return *fn*'s bytecode with each recursive tail call turned into a jump.

    The ``LOAD_GLOBAL`` of the function itself is blanked with NOPs and the
    stack height is tracked from there; the ``CALL_FUNCTION`` that brings
    the height back to zero is the recursive call and is replaced by
    "store the arguments into the parameters, jump to the start".

    Raises ValueError when an opcode with no fixed stack effect occurs
    between the load and the call.
    """
    new_bytecode = []
    jump_displacement = 0
    jump_list = []
    code_obj = fn.__code__
    inside_recur = False
    stack_size = 0

    for byte, arg in consume(code_obj.co_code):
        name = opname[byte]
        if not inside_recur:
            # if hit LOAD_GLOBAL(self), remove this instruction and get into inside_recur mode
            if name == 'LOAD_GLOBAL' and code_obj.co_names[arg] == fn.__name__:
                new_bytecode += [opmap["NOP"]] * 3
                stack_size = 0
                inside_recur = True
            else:
                new_bytecode += _encode(byte, arg)
        else:
            # update stack count
            if name == 'CALL_FUNCTION':
                height_change = -arg
            else:
                height_change = height_map.get(name, 'bad')
            if height_change == 'bad':
                raise ValueError(
                    "cannot track stack height across {}".format(name))
            stack_size += height_change
            if name == "CALL_FUNCTION" and stack_size == 0:
                # we hit the recursive call, time to reset
                new_bytecode += _restart_sequence(arg)
                jump_displacement += 3 * arg
                # Leave tracking mode, otherwise a later unrelated call that
                # happens to balance the stack is rewritten as well.
                inside_recur = False
            else:
                new_bytecode += _encode(byte, arg)
        jump_list += [jump_displacement] * (1 if arg is None else 3)

    return _retarget_jumps(code_obj, new_bytecode, jump_list)


def split(num):
    """Return *num* as the two argument bytes of an instruction, low byte first.

    Raises ValueError when *num* does not fit in 16 bits.
    """
    if not 0 <= num <= 0xFFFF:
        raise ValueError("argument out of range: {}".format(num))
    return (num & 0xFF, num >> 8)


def consume(bytecode):
    """Yield ``(opcode, argument)`` pairs; the argument is None when absent."""
    if bytecode and isinstance(bytecode[0], str):
        bytecode = [ord(b) for b in bytecode]
    i = 0
    while i < len(bytecode):
        op = bytecode[i]
        # Opcodes equal to HAVE_ARGUMENT take an argument too.
        if op >= opcode.HAVE_ARGUMENT:
            arg = bytecode[i + 1] + (bytecode[i + 2] << 8)
            yield op, arg
            i += 3
        else:
            yield op, None
            i += 1


@make_tail_recursive
def fact(n, accum):
    if n <= 1:
        return accum
    else:
        return fact(n-1, accum*n)


@make_tail_recursive
def fact2(n, accum):
    if n > 1:
        return fact2(n-1, accum*n)
    else:
        return accum


def identity(x):
    return x


@make_tail_recursive
def fact3(n, accum):
    if n <= 1:
        return accum
    else:
        return fact3(n-1, accum*identity(n))


def sq(x): return x*x


@make_tail_recursive
def sum_squares(n, accum):
    if n < 1:
        return accum
    else:
        return sum_squares(n-1, accum+sq(n))


if __name__ == '__main__':
    # print(fact(1000, 1))
    # print(fact2(1000, 1))
    print(dis.dis(fact))
    print(dis.dis(fact3))
    # f = make_tail_recursive(fact3)
    # print(dis.dis(f))
    print(fact3(5, 1))
    print(fact3(1000, 1))
    print(sum_squares(1000, 0))
# --------------------------------------------------------------------------------
# /test_rec_fact.py:
# --------------------------------------------------------------------------------
def fact(n):
    if n <= 1:
        return 1
    else:
        return n * fact(n - 1)

print(fact(10))

def sq(x):
    return x * x

def sum_squares(n):
    if n <= 1:
        return 1
    else:
        return sq(n) * sum_squares(n - 1)

print(sum_squares(10))  # 13168189440000

def two_even_three_odd(n):
    if n == 1:
        return 3
    elif n == 2:
        return 4
    elif n % 2 == 0:
        return 2 * n + two_even_three_odd(n-1)
    else:
        return 3 * n + two_even_three_odd(n-1)

print(two_even_three_odd(10))  # 132
# --------------------------------------------------------------------------------
# /trampoline.py:
# --------------------------------------------------------------------------------
# normal recursive factorial
def fact_recursive(n):
    return 1 if n <= 1 else n * fact_recursive(n-1)

# normal tail-recursive factorial
def fact_normal_tail_rec(n, accum):
    return accum if n <= 1 else fact_normal_tail_rec(n-1, accum*n)

# both blow up the stack
# print fact_recursive(100)
# print fact_normal_tail_rec(100, 1)
# print fact_recursive(1000) #raises RuntimeError: maximum recursion depth exceeded
# print fact_normal_tail_rec(1000, 1)

# normal mutual recursion
def even_recursive(n):
    return True if n == 0 else odd_recursive(n-1)

def odd_recursive(n):
    return False if n == 0 else even_recursive(n-1)

# normal mutual recursion blows the stack
# print even_recursive(100)
# print odd_recursive(50)
# print even_recursive(1000)
# print odd_recursive(1000)

# trampoline-ready recursive functions
def fact(n, accum):
    if n <= 1:
        return accum
    else:
        return lambda: fact(n-1, accum*n)  # returns thunk instead of a call to itself

# trampoline-ready mutual recursion
def even(n):
    if n == 0:
        return True
    else:
        return lambda: odd(n-1)

def odd(n):
    if n == 0:
        return False
    else:
        return lambda: even(n-1)

# finally, the trampoline function
def tramp(f, *args, **kwargs):
    """Call f(*args, **kwargs), then keep calling the result while it is callable."""
    res = f(*args, **kwargs)
    while callable(res):
        res = res()
    return res

# trampolining recursive functions prevents stack overflows
# print tramp(fact, 1000, 1)
# print tramp(even, 1000)

# there is another way to define the trampoline function: it can accept the _result_ of calling f on *args:
def tramp_result(f):
    """Bounce an already-computed result: call it repeatedly while it is callable."""
    res = f
    while callable(res):
        res = res()
    return res

# this also works
# print tramp_result(fact(1000, 1))
# print tramp_result(odd(1000))

# now attempting to stick a trampoline into a decorator
def bad_tramp(f):
    """
    takes a tail-recursive function f
    returns a trampolined version of this function
    """
    import functools  # local import: the file has no import section

    @functools.wraps(f)  # keep f's name and docstring on the wrapper
    def tramp_f(*args, **kwargs):
        res = f(*args, **kwargs)
        while callable(res):
            res = res()
        return res
    return tramp_f

# there is nothing inherently bad with this decorator: it returns a working trampolined function
# tramped_fact = bad_tramp(fact)
# print tramped_fact(1000, 1)

# BUT if we use it to change the original function, the trouble begins...
# there are two ways of doing this: either using decorator syntax directly:

@bad_tramp
def bad_factorial(n, accum):
    return accum if n <= 1 else lambda: bad_factorial(n-1, accum*n)  # same definition as fact
# print bad_factorial(1000, 1)

# or by doing what decorators are doing under the hood:
# fact = bad_tramp(fact)
# print fact(1000, 1) #raises RuntimeError: maximum recursion depth exceeded while calling a Python object
# different error from before (cf.: RuntimeError: maximum recursion depth exceeded)

# this behavior is caused by the fact that the thunk does not capture the function object:
# the name inside `lambda: <...>` is looked up only when the thunk is called. So if we rebind
# the name fact to the trampolined version, every thunk calls the "new" version, which starts
# a nested trampoline and creates new frames. Cf.:

# x = [lambda: i+1 for i in range(3)]
# print [j() for j in x] #[3, 3, 3] When this line is called, the value of i is read from the global environment
## way to go around it is to force the closure over i:
# x = [(lambda real_i: lambda: real_i+1)(i) for i in range(3)]
# print [j() for j in x] #[1, 2, 3]

# how can we take this idea and improve our decorator?
def better_tramp(f):
    """
    takes a tail-recursive function f
    returns a trampolined version of this function

    Unlike bad_tramp, the result may be bound to the same name that f's
    thunks call: while the trampoline is running, a nested call only hands
    the next thunk (or the final value) back to the running loop instead of
    starting a second loop, so the stack stays flat.  The function must call
    itself in tail position only.
    """
    state = {"bouncing": False}  # dict, not nonlocal: works on Python 2 too

    def tramp_f(*args, **kwargs):
        if state["bouncing"]:
            # Re-entered from a thunk: let the outer loop keep bouncing.
            return f(*args, **kwargs)
        state["bouncing"] = True
        try:
            res = f(*args, **kwargs)
            while callable(res):
                res = res()
            return res
        finally:
            state["bouncing"] = False
    return tramp_f

@better_tramp
def better_factorial(n, accum):
    return accum if n <= 1 else (lambda real_f: lambda: real_f(n-1, accum*n))(better_factorial)
print(better_factorial(1000, 1))

# a possible workaround to avoid name collision. Won't work with mutual recursion.
# def deco_tramp(f):
#     """
#     takes a tail-recursive function f
#     return a trampolined version of this function
#     """
#     f_itself = lambda *a,**kw: f(f_itself, *a, **kw)
#     def tramp_f(*args, **kwargs):
#         res = f_itself(*args, **kwargs)
#         while callable(res):
#             res = res()
#         return res
#     return tramp_f
#
# @deco_tramp
# def recur_deco_fact(recur, n, accum):
#     if n <= 1:
#         return accum
#     else:
#         return lambda: recur(n-1, accum*n)

# import sys
# sys.setrecursionlimit(4)
# --------------------------------------------------------------------------------
# /translator.py:
# --------------------------------------------------------------------------------
__author__ = 'Liuda'
import ast, inspect, imp
from ast_pretty_printer import print_ast
import sys

"""
High level algorithm:
takes in a file, finds all recursive functions
calls translator for each of the recursive functions found.
"""
#TODO: how do you distinguish between a normal-recursive and tail-recursive function?
#TODO: deal with base case that is not a value, but is an expression.
# def is_recursive_function(tree):
#     try:
#         #get function name and function definition subtree
#         for node in ast.walk(tree):
#             if type(node).__name__ == 'FunctionDef':
#                 name = node.name
#                 func_def = node
#         for node in ast.walk(func_def): #find whether this func_def contains recursive calls
#             if type(node).__name__ == 'Call' and node.func.id == name:
#                 return True
#         return False
#     except UnboundLocalError:
#         print('No functions found')
#         return False


def is_recursive(tree, name):
    """
    Return True if tree contains a call to the plain name ``name``, otherwise False.

    Calls through attributes or other expressions (``obj.f()``, ``f()()``)
    have no ``id`` and are not counted.
    """
    return any(
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id == name
        for node in ast.walk(tree)
    )


class InfoGatherer(ast.NodeVisitor):
    """
    traverses the tree, storing the function name, default value and recursive calls.
    """
    def __init__(self):
        self.name = None
        self.default_value = None
        self.recursive_returns = []

    def visit_FunctionDef(self, node):
        self.name = node.name
        ast.NodeVisitor.generic_visit(self, node)

    def visit_Return(self, node):
        if is_recursive(node, self.name):
            self.recursive_returns.append(node)
        else:
            # The base case's value becomes the accumulator's default.
            self.default_value = node.value


class Mangler(ast.NodeTransformer):
    """
    Rewrites a regular recursive function into a tail-recursive one by
    adding an ``accum`` parameter.

    Requirements to the ast tree passed:
    tree should be a result of ast.parse on a recursive function that has one base case,
    one recursive call per recursive return and no internal function definitions.
    """
    def __init__(self, tree):
        ig = InfoGatherer()
        ig.visit(tree)
        self.name = ig.name
        self.default_value = ig.default_value
        self.recursive_returns = ig.recursive_returns
        self.outer = None  # recursive Call found in the return being rewritten

    def make_Name(self, name, ctx_val):
        return ast.Name(
            id=name,
            ctx=ctx_val
        )

    def make_Num(self, x):
        # ast.Num is deprecated on Python 3.8+, where literals are Constant nodes.
        if hasattr(ast, 'Constant'):
            return ast.Constant(value=x)
        return ast.Num(n=x)

    def update_recursive_calls(self):
        for node in self.recursive_returns:
            self.visit_Return(node)

    def visit_Return(self, node):
        if is_recursive(node, self.name):
            # substitute the recursive call with accum, store the call in self.outer
            self.outer = None  # never reuse the call found in an earlier return
            ast.NodeTransformer.generic_visit(self, node)
            if self.outer is None:
                raise ValueError(
                    "recursive call to {!r} is in a position that cannot be "
                    "rewritten".format(self.name))
            # The rest of the expression becomes the new accumulator argument.
            self.outer.args.append(node.value)
            return ast.Return(value=self.outer)
        else:
            return ast.Return(value=self.make_Name('accum', ast.Load()))

    def visit_Call(self, node):
        if isinstance(node.func, ast.Name) and node.func.id == self.name:
            self.outer = node
            return self.make_Name('accum', ast.Load())
        else:
            return node

    def visit_FunctionDef(self, node):
        # update arguments
        if hasattr(ast, 'arg'):
            # Python 3: parameters are ast.arg nodes
            accum_param = ast.arg(arg='accum', annotation=None)
        else:
            # Python 2: parameters are Name nodes in Param context
            accum_param = self.make_Name('accum', ast.Param())
        node.args.args.append(accum_param)
        node.args.defaults.append(self.default_value)
        ast.NodeTransformer.generic_visit(self, node)
        return node

    def visit_Print(self, node):
        return node


class ModuleCrawler(ast.NodeTransformer):
    """Applies Mangler to every recursive function defined at module level."""
    def visit_Module(self, tree):
        for node in ast.iter_child_nodes(tree):
            if isinstance(node, ast.FunctionDef) and is_recursive(node, node.name):
                m = Mangler(node)
                m.visit(node)
        return tree


def tail_fact(n, accum=1):
    if n <= 1: return accum
    else: return tail_fact(n - 1, accum * n)

def rec_fact(n):
    if n <= 1:
        return 1
    else:
        return n * rec_fact(n-1)


if __name__ == "__main__":
    # Script entry: translate the file named on the command line.  The names
    # bound here stay module globals, so the statements that follow in the
    # file can still use them when it is run as a script.
    (filename,) = sys.argv[1:]
    with open(filename, 'r') as source_file:
        tree = ast.parse(source_file.read())

    m = ModuleCrawler()
    m.visit(tree)
if __name__ == "__main__":
    # Tail of the translator script: show the rewritten tree, then compile
    # and execute it in a fresh module namespace.  Guarded so this does not
    # run when translator is imported.
    print_ast(tree)

    ast.fix_missing_locations(tree)
    code = compile(tree, filename, "exec")

    namespace = imp.new_module(filename)
    eval(code, namespace.__dict__)
# --------------------------------------------------------------------------------
# /translator_tests.py:
# --------------------------------------------------------------------------------
__author__ = 'Liuda'
import ast, inspect
import translator


def rec_fact(n):
    if n <= 1:
        return 1
    else:
        return n * rec_fact(n-1)

def tail_fact(n, accum=1):
    if n <= 1:
        return accum
    else:
        return tail_fact(n-1, accum * n)

def f(n):
    print(n)
    f = 5
    return f

def sq(x):
    return x*x

def sum_squares(x):
    if x <= 1:
        return sq(x)
    else:
        return sq(x) + sum_squares(x-1)

def print_n_squares(x):
    print(sq(x))
    if x > 1:
        print_n_squares(x-1)

# translator exposes is_recursive(tree, name); there is no isRecursive.
print(translator.is_recursive(ast.parse('a = 5 + 9'), 'a'))

def test():
    """Print, for each sample function, whether translator sees it as recursive."""
    for func in [rec_fact, tail_fact, f, sq, sum_squares, print_n_squares]:
        tree = ast.parse(inspect.getsource(func))
        print(func.__name__, translator.is_recursive(tree, func.__name__))

test()
# --------------------------------------------------------------------------------