├── README.md
├── ast_pretty_printer.py
├── ast_trampoline.py
├── stack_heights.py
├── tailbytes.py
├── test_rec_fact.py
├── trampoline.py
├── trampolines_in_Python.ipynb
├── translator.py
└── translator_tests.py
/README.md:
--------------------------------------------------------------------------------
1 | Recursion.
2 | ==================
3 | This project explores a variety of ways to optimize recursion in Python.
4 | So far, the following approaches have been explored:
5 | Trampolines.
6 | -----------------
7 | Tail-call optimization (TCO) using only standard Python. Given a tail-recursive function that returns a thunk in place of each recursive call, the trampoline function keeps calling the returned thunks until a non-callable value is reached.
8 | >Associated files:
9 | >trampoline.py - all source code for trampolines
10 | >trampolines_in_Python.ipynb
11 |
12 | AST trampolines.
13 | --------------------------------
14 | TCO by manipulating the AST of the function (with Paul Tagliamonte).
15 | The AST of the tail-recursive function is altered so that every return value is "thunked" (wrapped in a `lambda`), and a trampolining decorator is injected into the tree.
16 | >Associated files:
17 | >ast_trampoline.py
18 |
19 | Tailbytes.
20 | ---------
21 | TCO using direct bytecode manipulation (with Allison Kaptur).
22 | By rewriting the bytecode of a tail-recursive function, the recursive call is replaced by reassigning the function's arguments and jumping back to the beginning of the function.
23 | Problems we've faced and resolved so far:
24 |
25 | * _Problem:_ Deleting and inserting bytes into the bytecode changes the location of the original bytes.
26 | _Solution:_ Absolute jump targets are updated after the bytecode has been altered.
27 | * _Problem:_ Our initial algorithm removed all calls to *any* function. This caused a problem if the function involved non-recursive function calls.
28 | _Solution:_ Remove a CALL_FUNCTION instruction only if it is a part of the recursive call (to determine that, we keep track of the stack size). We target the CALL_FUNCTION instructions which return the stack to the same size that it had when the recursive function was loaded onto it.
29 |
30 | >Associated files:
31 | >tailbytes.py
32 | >stack_heights.py
33 | >See also [presentation slides](http://www.slideshare.net/lnikolaeva/tailbytes-pygotham)
34 |
35 | Transform regular recursion into tail recursion.
36 | -------------------------------
37 | A translator that reads the AST of a regular (non-tail) recursive function and transforms it into a tree that compiles to an equivalent tail-recursive function.
38 | >Associated files:
39 | >translator.py
40 |
41 | Further goals:
42 | --------
43 | * Transform binary recursion (i.e. fibonacci-like functions) into tail-recursion.
44 | * Find a way to optimize mutual recursion.
45 |
46 | Tools:
47 | -----------
48 | * ast_pretty_printer.py
49 | print_ast takes an AST and pretty-prints it, building on the output of ast.dump.
50 |
--------------------------------------------------------------------------------
/ast_pretty_printer.py:
--------------------------------------------------------------------------------
1 | import ast, inspect
2 |
class AltDump(ast.NodeVisitor):
    """Render an AST as a one-line string in the style of ``ast.dump``.

    Every field is written as ``name=value``: child nodes recursively, lists as
    ``[...]``, strings single-quoted, and all other values via ``str``.
    """

    def visit(self, node):
        """Return the dump string for *node* and, recursively, its children."""
        field_values = [field + '=' + self._render(value)
                        for field, value in ast.iter_fields(node)]
        # Children are rendered by the recursive calls in _render. (An extra
        # generic_visit here would re-walk every subtree and discard the result,
        # making the cost exponential in the depth of the tree.)
        return type(node).__name__ + '(' + ', '.join(field_values) + ')'

    def _render(self, value):
        """Render a single field value (node, list, string or scalar)."""
        if isinstance(value, ast.AST):
            return self.visit(value)
        if isinstance(value, list):
            # list items may be nodes or plain values (e.g. Global.names, None defaults)
            return '[' + ', '.join(self._render(item) for item in value) + ']'
        if isinstance(value, str):
            return "'{}'".format(value)
        return str(value)
23 |
class Pretty(ast.NodeVisitor):
    """Build an indented, multi-line dump of an AST in ``self.out``.

    ``''.join(self.out)`` gives the text. Each node is written as ``Type(`` on its
    own line, its fields one per line indented by one more ``filler``, then a
    closing ``)``. Node types with a ``visit_<Type>`` method (currently only
    ``Name``) are rendered by that method instead.
    """
    def __init__(self, filler=" "):
        self.res = ''
        self.level = -1   # visit() enters the root node at level 0
        self.filler = filler
        self.out = []

    @property
    def prefix(self):
        return self.level * self.filler

    def visit(self, node):
        """Append the dump of *node*, one indentation level deeper than its parent."""
        self.level += 1
        handler = getattr(self, 'visit_' + type(node).__name__, None)
        if handler is None:
            self._dump_fields(node)
        else:
            handler(node)
        self.level -= 1

    def _dump_fields(self, node):
        """Write ``Type(``, then each field on its own line, then ``)``."""
        self.out.append(self.prefix + type(node).__name__ + '(\n')
        self.level += 1
        for field, value in ast.iter_fields(node):
            if isinstance(value, ast.AST):
                self.out.append(self.prefix + field + '=\n')
                self.visit(value)
            elif isinstance(value, list) and value:
                self.out.append(self.prefix + field + '=[\n')
                for item in value:
                    if isinstance(item, ast.AST):
                        self.visit(item)
                    else:
                        # plain list items (e.g. Global.names) get their own line
                        self.out.append(self.prefix + self.filler + repr(item) + '\n')
                self.out.append(self.prefix + ']\n')
            else:
                self.out.append(self.prefix + field + '=' + repr(value) + '\n')
        self.level -= 1
        self.out.append(self.prefix + ')\n')

    def visit_Name(self, node):
        """Names are short, so write them on a single line using ``ast.dump``."""
        self.out.append(self.prefix + ast.dump(node) + '\n')
63 |
class PrettyPrinter(ast.NodeVisitor):
    """Print an indented outline of an AST to stdout, one node per line.

    ``self.level`` is the current nesting depth; every line is indented with
    that many tab characters.
    """
    def __init__(self):
        self.res = ''
        self.level = -1   # visit() enters the root node at level 0

    def generic_visit(self, node):
        print(self.prefix + type(node).__name__)
        ast.NodeVisitor.generic_visit(self, node)

    @property
    def prefix(self):
        return self.level * '\t'

    def visit(self, node):
        """Dispatch to ``visit_<NodeType>`` (or generic_visit) one level deeper."""
        method = getattr(self, "visit_{}".format(type(node).__name__), self.generic_visit)
        self.level += 1
        try:
            method(node)
        finally:
            self.level -= 1

    def visit_FunctionDef(self, node):
        print(self.prefix + type(node).__name__ + ': ' + node.name)
        ast.NodeVisitor.generic_visit(self, node)

    def visit_arguments(self, node):
        print(self.prefix + type(node).__name__ + ': ')
        for field, value in ast.iter_fields(node):
            if not value:
                continue
            # vararg/kwarg hold a single node (Python 3) or a bare name (Python 2);
            # the other fields are lists, which may contain None (kw_defaults).
            items = value if isinstance(value, list) else [value]
            print(self.prefix + field + ': ')
            self.level += 1
            for item in items:
                if isinstance(item, ast.AST):
                    # dispatches straight to visit_<Type>, so the level is managed here
                    ast.NodeVisitor.visit(self, item)
                elif item is not None:
                    print(self.prefix + str(item))
            self.level -= 1

    def visit_arg(self, node):
        print(self.prefix + node.arg)

    def visit_Name(self, node):
        print(self.prefix + node.id)

    def visit_Num(self, node):
        print(self.prefix + 'Num:' + str(node.n))

    def visit_Str(self, node):
        print(self.prefix + 'Str:' + node.s)

    def visit_Constant(self, node):
        # Python 3.8+ represents all literals as Constant; print them like Num/Str.
        value = node.value
        if isinstance(value, str):
            print(self.prefix + 'Str:' + value)
        elif isinstance(value, (int, float, complex)) and not isinstance(value, bool):
            print(self.prefix + 'Num:' + str(value))
        else:
            print(self.prefix + 'Const:' + repr(value))

    def visit_Print(self, node):
        print(self.prefix + 'Print:')
        ast.NodeVisitor.generic_visit(self, node)

    def visit_Assign(self, node):
        print(self.prefix + 'Assign:')
        # generic_visit, not visit: NodeVisitor.visit would dispatch straight back
        # to visit_Assign and recurse forever.
        ast.NodeVisitor.generic_visit(self, node)

    def visit_Expr(self, node):
        print(self.prefix + 'Expr:')
        ast.NodeVisitor.generic_visit(self, node)
130 |
def polish(dump, prefix='\t'):
    """Re-indent a one-line ``ast.dump`` string, one field / list element per line.

    Every ``(`` or ``[`` opens a new indentation level (unless it is immediately
    closed), every ``,`` starts a new line, and spaces between tokens are dropped.
    Text inside quoted string literals is copied verbatim, so spaces, commas and
    brackets inside strings are preserved.

    :param dump: output of ``ast.dump``
    :param prefix: indentation unit repeated once per nesting level
    :return: the re-indented string
    """
    closers = {'(': ')', '[': ']'}
    out = []
    level = 0
    quote = None        # quote character of the string literal we are inside, if any
    escaped = False     # previous character inside the literal was a backslash
    last = len(dump) - 1
    for i, ch in enumerate(dump):
        if quote is not None:
            out.append(ch)
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == quote:
                quote = None
            continue
        if ch in '\'"':
            quote = ch
            out.append(ch)
        elif ch in closers:
            level += 1
            closer = closers[ch]
            nxt = dump[i + 1] if i < last else closer
            # empty brackets such as "Load()" or "[]" stay on one line
            out.append(ch + ('\n' + prefix * level if nxt != closer else ''))
        elif ch == ',':
            out.append(',\n' + prefix * level)
        elif ch in ')]':
            level -= 1
            out.append(ch)
        elif ch == ' ':
            continue
        else:
            out.append(ch)
    return ''.join(out)
155 |
def print_ast(tree, prefix=' '):
    """Pretty-print *tree* to stdout, one field per line, indenting with *prefix*."""
    formatted = polish(ast.dump(tree), prefix)
    print(formatted)
158 |
# Sample tail-recursive factorial: returns accum * n!.
# Kept as input for the AST printers (see the commented-out demo lines below);
# only '#' comments are added so the parsed tree stays unchanged.
def tail_fact(n, accum=1):
    if n <= 1:
        return accum
    else:
        return tail_fact(n - 1, accum * n)
164 |
# Sample non-tail recursive factorial. Its source is parsed with
# inspect.getsource by the demo at the bottom of this file, so only '#'
# comments (which do not appear in the AST) are added here.
def rec_fact(n):
    if n <= 1:
        return 1
    else:
        return n * rec_fact(n-1)
170 |
if __name__ == '__main__':
    # Demo: compare AltDump's rendering of rec_fact with the built-in ast.dump.
    # Guarded so that importing this module (translator.py imports print_ast)
    # does not print anything.
    tree = ast.parse(inspect.getsource(rec_fact))
    a = AltDump()
    my_dump = a.visit(tree)
    py_dump = ast.dump(tree)
    print(my_dump)
    print(py_dump)
    # Other views of the same tree:
    # print_ast(tree, ' ')          # one field per line
    # PrettyPrinter().visit(tree)   # indented outline
193 |
--------------------------------------------------------------------------------
/ast_trampoline.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import imp
3 | import ast
4 | import inspect
5 |
6 |
# Usage: python ast_trampoline.py <file.py>
(filename,) = sys.argv[1:]

# Close the file promptly; passing filename gives better SyntaxError messages.
with open(filename, 'r') as source_file:
    tree = ast.parse(source_file.read(), filename)
10 |
11 |
def tramp_deco(f):
    """Wrap *f* in a trampoline: keep calling returned thunks until a non-callable appears.

    NOTE: this function's source is re-parsed with inspect.getsource and injected
    into the transformed module, so it must keep using builtins only.
    """
    def trampo(*args, **kwargs):
        result = f(*args, **kwargs)
        while True:
            if not callable(result):
                return result
            result = result()
    return trampo
19 |
# The trampoline decorator's own definition, re-parsed from its source so it can
# be injected verbatim at the top of the transformed module.
tramp_source = inspect.getsource(tramp_deco)
tramp_ast = ast.parse(tramp_source).body[0]
21 |
22 |
class NameChanger(ast.NodeTransformer):
    """Prefix every ``Name`` node that refers to *name* with an underscore, in place."""

    def __init__(self, name):
        self.name = name

    def visit_Name(self, el):
        if el.id != self.name:
            return el
        el.id = "_" + self.name
        return el
31 |
32 |
class AutoLambda(ast.NodeTransformer):
    """Prepare every function in a tree for trampolining.

    Each ``def name`` is renamed to ``_name`` (together with every ``Name``
    reference to it inside the function), and every ``return <expr>`` becomes
    ``return lambda: <expr>`` so the function hands back a thunk instead of
    recursing.
    """

    def visit_FunctionDef(self, el):
        nc = NameChanger(el.name)
        nc.visit(el)
        el.name = "_{}".format(el.name)

        self.generic_visit(el)
        return el

    def visit_Return(self, el):
        if el.value is None:
            # a bare ``return`` already yields None, which ends the trampoline
            return el
        # Take the empty argument list from a parsed ``lambda`` so the node has
        # exactly the fields this Python version's ast.arguments requires (the
        # hand-built node used the misspelled names varargs/kwargs and lacked
        # posonlyargs on Python 3.8+).
        no_args = ast.parse("lambda: None", mode="eval").body.args
        el.value = ast.Lambda(args=no_args, body=el.value)
        return el
56 |
57 |
import types

# Rename every function (and its self-references) and thunk its return values.
traverse = AutoLambda()
traverse.visit(tree)

# The injected tramp_deco definition must precede the wrappers that use it.
tree.body.insert(0, tramp_ast)

# After each renamed ``def _name``, rebind ``name = tramp_deco(_name)`` so that
# callers get the trampolined version under the original name.
newbody = []
for statement in tree.body:
    newbody.append(statement)

    if isinstance(statement, ast.FunctionDef):
        name = statement.name
        if name.startswith("_"):
            newbody.append(ast.parse(
                "{} = tramp_deco({})".format(name[1:], name)
            ).body[0])

tree.body = newbody

ast.fix_missing_locations(tree)
code = compile(tree, filename, "exec")

# Run the transformed code in a fresh module namespace (types.ModuleType replaces
# the deprecated imp.new_module). NOTE: this executes the user-supplied file;
# only run it on trusted code.
namespace = types.ModuleType(filename)
exec(code, namespace.__dict__)
83 |
--------------------------------------------------------------------------------
/stack_heights.py:
--------------------------------------------------------------------------------
# Net change in value-stack depth caused by executing each (CPython 2) opcode once.
# Used by tailbytes.advanced_recurse to find the CALL_FUNCTION that completes a
# recursive call.
# Conventions:
#   * conditional jumps record the effect on the fall-through (no-jump) path;
#   * 'bad' marks opcodes whose effect depends on the argument or on run-time
#     control flow; the comment gives the formula, and callers must compute it.
height_map = {
    'STOP_CODE': 0,
    'POP_TOP': -1,
    'ROT_TWO': 0,
    'ROT_THREE': 0,
    'DUP_TOP': 1,
    'ROT_FOUR': 0,

    'NOP': 0,
    'UNARY_POSITIVE': 0,
    'UNARY_NEGATIVE': 0,
    'UNARY_NOT': 0,
    'UNARY_CONVERT': 0,
    'UNARY_INVERT': 0,

    'BINARY_POWER': -1,
    'BINARY_MULTIPLY': -1,
    'BINARY_DIVIDE': -1,
    'BINARY_MODULO': -1,
    'BINARY_ADD': -1,
    'BINARY_SUBTRACT': -1,
    'BINARY_SUBSCR': -1,
    'BINARY_FLOOR_DIVIDE': -1,
    'BINARY_TRUE_DIVIDE': -1,
    'INPLACE_FLOOR_DIVIDE': -1,
    'INPLACE_TRUE_DIVIDE': -1,

    # SLICE+n pops the sequence plus 0-2 bounds and pushes the slice.
    'SLICE+0': 0,
    'SLICE+1': -1,
    'SLICE+2': -1,
    'SLICE+3': -2,

    # STORE_SLICE+n pops the value, the sequence and 0-2 bounds.
    'STORE_SLICE+0': -2,
    'STORE_SLICE+1': -3,
    'STORE_SLICE+2': -3,
    'STORE_SLICE+3': -4,

    # DELETE_SLICE+n pops the sequence and 0-2 bounds.
    'DELETE_SLICE+0': -1,
    'DELETE_SLICE+1': -2,
    'DELETE_SLICE+2': -2,
    'DELETE_SLICE+3': -3,

    'STORE_MAP': -2,
    'INPLACE_ADD': -1,
    'INPLACE_SUBTRACT': -1,
    'INPLACE_MULTIPLY': -1,
    'INPLACE_DIVIDE': -1,
    'INPLACE_MODULO': -1,
    'STORE_SUBSCR': -3,
    'DELETE_SUBSCR': -2,
    'BINARY_LSHIFT': -1,
    'BINARY_RSHIFT': -1,
    'BINARY_AND': -1,
    'BINARY_XOR': -1,
    'BINARY_OR': -1,
    'INPLACE_POWER': -1,
    'GET_ITER': 0,

    'PRINT_EXPR': -1,
    'PRINT_ITEM': -1,
    'PRINT_NEWLINE': 0,
    'PRINT_ITEM_TO': -2,
    'PRINT_NEWLINE_TO': -1,
    'INPLACE_LSHIFT': -1,
    'INPLACE_RSHIFT': -1,
    'INPLACE_AND': -1,
    'INPLACE_XOR': -1,
    'INPLACE_OR': -1,
    'BREAK_LOOP': 0,
    'WITH_CLEANUP': 'bad',  # cannot say, nondeterministic
    'LOAD_LOCALS': 1,
    'RETURN_VALUE': -1,
    'IMPORT_STAR': -1,
    'EXEC_STMT': -3,
    'YIELD_VALUE': 0,  # pops the yielded value; the value sent back in is pushed on resume
    'POP_BLOCK': 0,
    'END_FINALLY': 'bad',  # cannot say, nondeterministic
    'BUILD_CLASS': -2,  # pops methods dict, bases tuple and name; pushes the class
    'STORE_NAME': -1,
    'DELETE_NAME': 0,
    'UNPACK_SEQUENCE': 'bad',  # argument - 1
    'FOR_ITER': 'bad',  # either +1 during loop or -1 when loop is finished
    'LIST_APPEND': -1,
    'STORE_ATTR': -2,
    'DELETE_ATTR': -1,
    'STORE_GLOBAL': -1,
    'DELETE_GLOBAL': 0,
    'DUP_TOPX': 'bad',  # +argument
    'LOAD_CONST': +1,
    'LOAD_NAME': +1,
    'BUILD_TUPLE': 'bad',  # 1 - argument
    'BUILD_LIST': 'bad',  # 1 - argument
    'BUILD_SET': 'bad',  # 1 - argument
    'BUILD_MAP': +1,
    'LOAD_ATTR': 0,
    'COMPARE_OP': -1,
    'IMPORT_NAME': -1,
    'IMPORT_FROM': +1,
    'JUMP_FORWARD': 0,
    'JUMP_IF_FALSE_OR_POP': -1,  # fall-through pops; the jump keeps TOS
    'JUMP_IF_TRUE_OR_POP': -1,  # fall-through pops; the jump keeps TOS
    'JUMP_ABSOLUTE': 0,
    'POP_JUMP_IF_FALSE': -1,
    'POP_JUMP_IF_TRUE': -1,

    'LOAD_GLOBAL': +1,

    'CONTINUE_LOOP': 0,
    'SETUP_LOOP': 0,
    'SETUP_EXCEPT': 0,
    'SETUP_FINALLY': 0,

    'LOAD_FAST': +1,
    'STORE_FAST': -1,
    'DELETE_FAST': 0,

    'RAISE_VARARGS': 'bad',  # -argument
    'CALL_FUNCTION': 'bad',  # -(low byte of argument) - 2 * (high byte of argument)
    'MAKE_FUNCTION': 'bad',  # -argument (pops code + defaults, pushes the function)
    'BUILD_SLICE': 'bad',  # 1 - argument
    'MAKE_CLOSURE': 'bad',  # -argument - 1 (also pops the closure tuple)
    'LOAD_CLOSURE': +1,
    'LOAD_DEREF': +1,
    'STORE_DEREF': -1,

    'CALL_FUNCTION_VAR': 'bad',  # like CALL_FUNCTION, minus one more for *args
    'CALL_FUNCTION_KW': 'bad',  # like CALL_FUNCTION, minus one more for **kwargs
    'CALL_FUNCTION_VAR_KW': 'bad',  # like CALL_FUNCTION, minus two more

    'SETUP_WITH': 'bad',  # not tracked here

    'EXTENDED_ARG': 0,
    'SET_ADD': -1,
    'MAP_ADD': -2,
}
135 |
--------------------------------------------------------------------------------
/tailbytes.py:
--------------------------------------------------------------------------------
1 | from types import CodeType
2 | import opcode
3 | import dis
4 | from stack_heights import height_map
5 | from opcode import opname, opmap
6 |
def make_tail_recursive(fn):
    """Replace *fn*'s code object with one whose self tail call is a jump.

    The bytecode comes from ``advanced_recurse``; every other code attribute is
    copied from the original code object (Python 2 ``CodeType`` signature).
    NOTE: co_lnotab is copied unchanged, so line numbers in tracebacks may be off
    after the rewrite. Mutates *fn* and returns it, so it works as a decorator.
    """
    old_code = fn.__code__
    bytecode = advanced_recurse(fn)
    # freevars/cellvars must be carried over as well; otherwise assigning the new
    # code to a closure fails and closures created inside fn break.
    new_code = CodeType(old_code.co_argcount, old_code.co_nlocals, old_code.co_stacksize,
                        old_code.co_flags, bytecode, old_code.co_consts, old_code.co_names,
                        old_code.co_varnames, old_code.co_filename, old_code.co_name,
                        old_code.co_firstlineno, old_code.co_lnotab,
                        old_code.co_freevars, old_code.co_cellvars)
    fn.__code__ = new_code
    return fn
16 |
def tail_recurse(fn):
    """First, naive version of the bytecode rewrite (superseded by advanced_recurse).

    Every ``LOAD_GLOBAL <fn name>`` becomes three NOPs, and *every*
    ``CALL_FUNCTION`` becomes ``STORE_FAST`` of its arguments into the first
    positional parameters followed by ``JUMP_ABSOLUTE 0``. It is therefore only
    correct for functions whose only call is the recursive one. Absolute jump
    targets are shifted by the number of bytes inserted before them.

    Returns the new bytecode as a (Python 2) byte string.
    """
    new_bytecode = []
    jump_displacement = 0   # bytes inserted so far
    jump_list = []          # jump_list[i]: displacement in effect at original offset i
    code_obj = fn.__code__
    for byte, arg in consume(code_obj.co_code):
        name = opcode.opname[byte]
        if name == "LOAD_GLOBAL" and code_obj.co_names[arg] == fn.__name__:
            # same width as the removed instruction, so offsets do not move
            new_bytecode.extend([opmap["NOP"]] * 3)
        elif name == "CALL_FUNCTION":
            # arguments are on the stack in order; the last one is on top
            for i in range(arg):
                new_bytecode.append(opmap["STORE_FAST"])
                new_bytecode += split(arg - i - 1)
            new_bytecode.append(opmap["JUMP_ABSOLUTE"])
            new_bytecode += split(0)  # jump to beginning of bytecode
            jump_displacement += 3 * arg
        else:
            new_bytecode.append(byte)
            if arg is not None:
                new_bytecode += split(arg)
        # one entry per original byte of this instruction
        jump_list.extend([jump_displacement] * (1 if arg is None else 3))

    assert len(jump_list) == len(code_obj.co_code)

    newer_bytecode = []
    for byte, arg in consume(new_bytecode):
        if byte in opcode.hasjabs:
            arg = arg + jump_list[arg]

        newer_bytecode.append(byte)
        if arg is not None:
            newer_bytecode += split(arg)
    # the rewritten bytecode used to be built and then dropped (no return)
    return "".join([chr(b) for b in newer_bytecode])
56 |
def advanced_recurse(fn):
    """Return *fn*'s bytecode with each self tail call turned into a jump.

    Walks the (Python 2) bytecode of *fn*. A ``LOAD_GLOBAL <fn name>`` is replaced
    by NOPs of the same width, and from there the stack height is tracked with
    ``height_map``. The ``CALL_FUNCTION`` that brings the height back to zero is
    the recursive call. It is replaced by ``STORE_FAST`` instructions assigning
    the call's arguments to the first positional parameters, followed by
    ``JUMP_ABSOLUTE 0``. Absolute jump targets are then shifted by the bytes
    inserted before them.
    NOTE: relative jumps are not adjusted, so a relative jump spanning a
    rewritten call would land in the wrong place.

    Returns the new bytecode as a (Python 2) byte string.
    Raises ValueError when an instruction inside the recursive call has an
    untracked stack effect, or when the recursive call uses keyword arguments.
    """
    new_bytecode = []
    jump_displacement = 0   # bytes inserted so far
    jump_list = []          # jump_list[i]: displacement in effect at original offset i
    code_obj = fn.__code__
    inside_recur = False
    stack_size = 0

    for byte, arg in consume(code_obj.co_code):
        name = opcode.opname[byte]
        if not inside_recur:
            # on LOAD_GLOBAL(self), drop the instruction and start tracking the stack
            if name == 'LOAD_GLOBAL' and code_obj.co_names[arg] == fn.__name__:
                new_bytecode.extend([opmap["NOP"]] * 3)
                stack_size = 0
                inside_recur = True
            else:
                new_bytecode.append(byte)
                if arg is not None:
                    new_bytecode += split(arg)
        else:
            if name == 'CALL_FUNCTION':
                # low byte: positional args; high byte: keyword args (2 stack slots each)
                height_change = -((arg & 0xFF) + 2 * (arg >> 8))
            else:
                height_change = height_map[name]
                if not isinstance(height_change, int):
                    raise ValueError(
                        "cannot track the stack effect of {} inside a recursive call".format(name))
            stack_size += height_change
            if name == "CALL_FUNCTION" and stack_size == 0:
                # this is the recursive call: reassign the parameters and restart
                if arg >> 8:
                    raise ValueError("keyword arguments in the recursive call are not supported")
                for i in range(arg):
                    new_bytecode.append(opmap["STORE_FAST"])
                    new_bytecode += split(arg - i - 1)
                new_bytecode.append(opmap["JUMP_ABSOLUTE"])
                # Target offset 0; the fix-up pass below adds jump_list[0] == 0.
                # (The old split(-jump_displacement) sent every recursive call after
                # the first into the middle of the function.)
                new_bytecode += split(0)
                jump_displacement += 3 * arg
                # leave tracking mode so later calls are examined afresh
                inside_recur = False
            else:
                new_bytecode.append(byte)
                if arg is not None:
                    new_bytecode += split(arg)

        # one entry per original byte of this instruction
        jump_list.extend([jump_displacement] * (1 if arg is None else 3))

    assert len(jump_list) == len(code_obj.co_code)

    # Second pass: shift absolute jump targets by the bytes inserted before them.
    newer_bytecode = []
    for byte, arg in consume(new_bytecode):
        if byte in opcode.hasjabs:
            arg = arg + jump_list[arg]

        newer_bytecode.append(byte)
        if arg is not None:
            newer_bytecode += split(arg)
    return "".join([chr(b) for b in newer_bytecode])
125 |
126 |
def split(num):
    """Return *num* as the (low byte, high byte) pair of a 2-byte bytecode argument.

    Raises ValueError if *num* does not fit in an unsigned 16-bit argument.
    """
    if not 0 <= num <= 0xFFFF:
        raise ValueError("bytecode argument out of range: {}".format(num))
    # bytes hold 256 values; the old divmod(num, 255) mis-encoded arguments >= 255
    high, low = divmod(num, 256)
    return low, high
130 |
131 |
def consume(bytecode):
    """Yield ``(opcode, arg)`` pairs from Python 2-style bytecode.

    *bytecode* may be a byte string (Python 2 ``co_code``) or a sequence of ints.
    Opcodes >= ``opcode.HAVE_ARGUMENT`` are followed by a 2-byte little-endian
    argument; all other opcodes yield ``arg=None``.
    """
    if bytecode and isinstance(bytecode[0], str):
        bytecode = [ord(b) for b in bytecode]
    i = 0
    while i < len(bytecode):
        op = bytecode[i]
        # HAVE_ARGUMENT itself is the first opcode with an argument, hence >=
        if op >= opcode.HAVE_ARGUMENT:
            low, high = bytecode[i + 1], bytecode[i + 2]
            yield op, low + (high << 8)
            i += 3
        else:
            yield op, None
            i += 1
146 |
# Tail-recursive factorial: fact(n, accum) returns accum * n!.
# make_tail_recursive rewrites this function's bytecode when it is defined,
# turning the self-call into a jump; only '#' comments are added here so the
# code being rewritten is unchanged.
@make_tail_recursive
def fact(n, accum):
    if n <= 1:
        return accum
    else:
        return fact(n-1, accum*n)
153 |
# Same as fact, but with the recursive branch first, so the rewritten call is
# followed by more code (the base-case return).
@make_tail_recursive
def fact2(n, accum):
    if n > 1:
        return fact2(n-1, accum*n)
    else:
        return accum
160 |
def identity(x):
    """Return *x* unchanged (used as a non-recursive call inside fact3)."""
    result = x
    return result
163 |
# Factorial whose recursive call contains a non-recursive call (identity(n)) in
# its arguments: the stack-height tracking in advanced_recurse must tell the
# two CALL_FUNCTION instructions apart. Comments only (bytecode is rewritten).
@make_tail_recursive
def fact3(n, accum):
    if n <= 1:
        return accum
    else:
        return fact3(n-1, accum*identity(n))
170 |
def sq(x):
    """Return x multiplied by itself (non-recursive helper for sum_squares)."""
    return x * x
172 |
# Tail-recursive sum of squares: sum_squares(n, accum) returns
# accum + 1**2 + ... + n**2. The recursive call's arguments include the
# non-recursive call sq(n). Comments only (bytecode is rewritten).
@make_tail_recursive
def sum_squares(n, accum):
    if n < 1:
        return accum
    else:
        return sum_squares(n-1, accum+sq(n))
179 |
if __name__ == '__main__':
    # dis.dis prints the disassembly itself and returns None, so it is not
    # wrapped in print (which used to print a stray "None").
    dis.dis(fact)
    dis.dis(fact3)
    print(fact3(5, 1))
    print(fact3(1000, 1))
    print(sum_squares(1000, 0))
--------------------------------------------------------------------------------
/test_rec_fact.py:
--------------------------------------------------------------------------------
# Input fixture for translator.py: a plain (non-tail) recursive factorial.
# Only '#' comments are used so the translator sees the same AST.
def fact(n):
    if n <= 1:
        return 1
    else:
        return n * fact(n - 1)
6 |
7 | print(fact(10))
8 |
# Non-recursive helper called from sum_squares.
def sq(x):
    return x * x
11 |
# NOTE: despite its name, this multiplies the squares: for n >= 1 it returns
# sq(n) * sq(n - 1) * ... * sq(2) * 1, i.e. (n!) ** 2.
def sum_squares(n):
    if n <= 1:
        return 1
    else:
        return sq(n) * sum_squares(n - 1)
17 |
print(sum_squares(10))  # 13168189440000 == (10!) ** 2
19 |
# Recursion with two recursive branches: adds 2 * n for even n and 3 * n for
# odd n, on top of the base cases f(1) = 3 and f(2) = 4.
def two_even_three_odd(n):
    if n == 1:
        return 3
    elif n == 2:
        return 4
    elif n % 2 == 0:
        return 2 * n + two_even_three_odd(n-1)
    else:
        return 3 * n + two_even_three_odd(n-1)
29 |
30 | print(two_even_three_odd(10)) #132
--------------------------------------------------------------------------------
/trampoline.py:
--------------------------------------------------------------------------------
1 | # normal recursive factorial
def fact_recursive(n):
    """Return n! with plain (non-tail) recursion: one stack frame per step."""
    if n <= 1:
        return 1
    return n * fact_recursive(n - 1)
4 |
5 | # normal tail-recursive factorial
def fact_normal_tail_rec(n, accum):
    """Return accum * n! via tail recursion (CPython still grows the stack)."""
    if n <= 1:
        return accum
    return fact_normal_tail_rec(n - 1, accum * n)
8 |
9 | # both blow up the stack for large n (CPython's default recursion limit is 1000)
10 | # print fact_recursive(100)
11 | # print fact_normal_tail_rec(100, 1)
12 | # print fact_recursive(1000) #raises RuntimeError: maximum recursion depth exceeded
13 | # print fact_normal_tail_rec(1000, 1)
14 |
15 | # normal mutual recursion
def even_recursive(n):
    """Return True if n is even, by plain mutual recursion with odd_recursive."""
    if n == 0:
        return True
    return odd_recursive(n - 1)
18 |
def odd_recursive(n):
    """Return True if n is odd, by plain mutual recursion with even_recursive."""
    if n == 0:
        return False
    return even_recursive(n - 1)
21 |
22 | # normal mutual recursion blows the stack
23 | # print even_recursive(100)
24 | # print odd_recursive(50)
25 | # print even_recursive(1000)
26 | # print odd_recursive(1000)
27 |
28 | # trampoline-ready recursive functions
def fact(n, accum):
    """Trampoline-ready factorial: returns accum, or a thunk for the next step."""
    if n <= 1:
        return accum

    # hand back a thunk instead of recursing; the trampoline will call it
    def next_step():
        return fact(n - 1, accum * n)
    return next_step
34 |
35 | # trampoline-ready mutual recursion
def even(n):
    """Trampoline-ready even test: True at 0, otherwise a thunk for odd(n - 1)."""
    return True if n == 0 else (lambda: odd(n - 1))
41 |
def odd(n):
    """Trampoline-ready odd test: False at 0, otherwise a thunk for even(n - 1)."""
    return False if n == 0 else (lambda: even(n - 1))
47 |
48 | # finally, the trampoline function
def tramp(f, *args, **kwargs):
    """Call f(*args, **kwargs), then keep calling the result while it is callable."""
    result = f(*args, **kwargs)
    while True:
        if not callable(result):
            return result
        result = result()
54 |
55 | # trampolining recursive functions prevents stack overflows
56 | # print tramp(fact, 1000, 1)
57 | # print tramp(even, 1000)
58 |
59 | # there is another way to define the trampoline function: it can accept the _result_ of calling f on *args:
def tramp_result(f):
    """Trampoline an already-computed result: call it while it is callable."""
    current = f
    while True:
        if not callable(current):
            break
        current = current()
    return current
65 |
66 | # this also works
67 | # print tramp_result(fact(1000, 1))
68 | # print tramp_result(odd(1000))
69 |
70 | # now attempting to package the trampoline as a decorator
def bad_tramp(f):
    """
    takes a tail-recursive function f
    returns a trampolined version of this function

    Works as long as f's thunks do not call the decorated name itself; when they
    do, every step starts a nested trampoline (see the demo below).
    """
    def tramp_f(*args, **kwargs):
        outcome = f(*args, **kwargs)
        while callable(outcome):
            outcome = outcome()
        return outcome
    return tramp_f
82 |
83 | # there is nothing inherently bad with this decorator: it returns a working trampolined function
84 | # tramped_fact = bad_tramp(fact)
85 | # print tramped_fact(1000, 1)
86 |
87 | #BUT if we use it to change the original function, the trouble begins...
88 | #there are two ways of doing this: either using decorator syntax directly:
89 |
@bad_tramp
def bad_factorial(n, accum):
    """Demonstrates the failure mode: the thunk calls the decorated wrapper."""
    if n <= 1:
        return accum
    return lambda: bad_factorial(n-1, accum*n)  # same definition as fact
93 | # print bad_factorial(1000, 1)
94 |
95 | #or by doing what decorators are doing under the hood:
96 | # fact = bad_tramp(fact)
97 | # print fact(1000, 1) #raises RuntimeError: maximum recursion depth exceeded while calling a Python object
98 | # different error from before (cf.: RuntimeError: maximum recursion depth exceeded)
99 |
100 | # this behavior is caused by the fact that the thunk looks up the name `fact` as a global when it is called
101 | # (it closes over n and accum, not over the function object), so if we rebind fact to its trampolined version,
102 | # the thunk will call the "new" wrapper, which starts a nested trampoline and creates new frames. Cf.:
103 |
104 | # x = [lambda: i+1 for i in range(3)]
105 | # print [j() for j in x] #[3, 3, 3] When this line is called, the value of i is read from the global environment
106 | ##way to go around it is to force the closure over i:
107 | # x = [(lambda real_i: lambda: real_i+1)(i) for i in range(3)]
108 | # print [j() for j in x] #[1, 2, 3]
109 |
110 | #how can we take this idea and improve our decorator?
def better_tramp(f):
    """
    takes a tail-recursive function f
    returns a trampolined version of this function

    Unlike bad_tramp, this wrapper still works when f's thunks call the decorated
    name (i.e. this wrapper) again. A call made while the trampoline loop is
    already running just runs one step of f and hands its result back to that
    loop, instead of starting a nested loop that would grow the stack.
    Only valid for functions that call the decorated name in tail position.
    NOTE: the "running" flag is shared by all callers, so concurrent use from
    several threads is not supported.
    """
    running = [False]  # a list so tramp_f can update it (no nonlocal in Python 2)

    def tramp_f(*args, **kwargs):
        if running[0]:
            # re-entered from a thunk: let the outer loop drive the next step
            return f(*args, **kwargs)
        running[0] = True
        try:
            res = f(*args, **kwargs)
            while callable(res):
                res = res()
            return res
        finally:
            running[0] = False
    return tramp_f
122 |
@better_tramp
def better_factorial(n, accum):
    """Trampoline-ready factorial; the thunk binds the current value of better_factorial."""
    if n <= 1:
        return accum
    real_f = better_factorial  # looked up when this call runs, as in the closure trick
    return lambda: real_f(n-1, accum*n)
print(better_factorial(1000, 1))
127 |
128 | # a possible workaround to avoid name collision. Won't work with mutual recursion.
129 | # def deco_tramp(f):
130 | # """
131 | # takes a tail-recursive function f
132 | # return a trampolined version of this function
133 | # """
134 | # f_itself = lambda *a,**kw: f(f_itself, *a, **kw)
135 | # def tramp_f(*args, **kwargs):
136 | # res = f_itself(*args, **kwargs)
137 | # while callable(res):
138 | # res = res()
139 | # return res
140 | # return tramp_f
141 | #
142 | # @deco_tramp
143 | # def recur_deco_fact(recur, n, accum):
144 | # if n <= 1:
145 | # return accum
146 | # else:
147 | # return lambda: recur(n-1, accum*n)
148 |
149 | # import sys
150 | # sys.setrecursionlimit(4)
--------------------------------------------------------------------------------
/translator.py:
--------------------------------------------------------------------------------
1 | __author__ = 'Liuda'
2 | import ast, inspect, imp
3 | from ast_pretty_printer import print_ast
4 | import sys
5 |
6 | """
7 | High level algorithm:
8 | takes in a file, finds all recursive functions
9 | calls translator for each of the recursive functions found.
10 | """
11 | #TODO: how do you distinguish between a normal-recursive and tail-recursive function?
12 | #TODO: deal with base case that is not a value, but is an expression.
13 |
14 | # def is_recursive_function(tree):
15 | # try:
16 | # #get function name and function definition subtree
17 | # for node in ast.walk(tree):
18 | # if type(node).__name__ == 'FunctionDef':
19 | # name = node.name
20 | # func_def = node
21 | # for node in ast.walk(func_def): #find whether this func_def contains recursive calls
22 | # if type(node).__name__ == 'Call' and node.func.id == name:
23 | # return True
24 | # return False
25 | # except UnboundLocalError:
26 | # print('No functions found')
27 | # return False
28 |
29 |
def is_recursive(tree, name):
    """
    Return True if *tree* contains a direct call ``name(...)``, otherwise False.

    Calls made through attributes, subscripts or other expressions
    (``obj.name()``, ``fs[0]()``) are not direct calls and are ignored.
    """
    return any(
        isinstance(node, ast.Call)
        # only plain-name calls have .func.id; obj.m() has an Attribute func
        and isinstance(node.func, ast.Name)
        and node.func.id == name
        for node in ast.walk(tree)
    )
38 |
class InfoGatherer(ast.NodeVisitor):
    """
    Walk a function definition and record what Mangler needs.

    Attributes:
        name: name of the (last) visited FunctionDef.
        default_value: value expression of the last non-recursive ``return``.
        recursive_returns: ``Return`` nodes whose subtree calls ``name``.
    """
    def __init__(self):
        self.name = None
        self.default_value = None
        self.recursive_returns = []

    def visit_FunctionDef(self, node):
        self.name = node.name
        ast.NodeVisitor.generic_visit(self, node)

    def visit_Return(self, node):
        if not is_recursive(node, self.name):
            # the base case: its value becomes the accumulator's default
            self.default_value = node.value
            return
        self.recursive_returns.append(node)
58 |
class Mangler(ast.NodeTransformer):
    """
    Rewrite a simply recursive function into accumulator-passing tail recursion.

    Requirements to the ast tree passed:
    tree should be a result of ast.parse on a recursive function that has one base case,
    one recursive call per recursive return and no internal function definitions.

    Example: ``return n * f(n - 1)`` becomes ``return f(n - 1, n * accum)``, every
    non-recursive ``return`` becomes ``return accum``, and a parameter ``accum``
    whose default is the base-case value is appended to the signature.
    """
    def __init__(self, tree):
        ig = InfoGatherer()
        ig.visit(tree)
        self.name = ig.name
        self.default_value = ig.default_value
        self.recursive_returns = ig.recursive_returns
        self.outer = None  # the recursive Call most recently found by visit_Call

    def make_Name(self, name, ctx_val):
        return ast.Name(
            id=name,
            ctx=ctx_val
        )

    def make_Num(self, x):
        return ast.Num(n=x)

    def make_param(self, name):
        """Return a parameter node for *name* in the form this Python version expects."""
        if hasattr(ast, 'arg'):
            # Python 3: parameters are ast.arg nodes
            return ast.arg(arg=name, annotation=None)
        # Python 2: parameters are Name nodes in Param context
        return self.make_Name(name, ast.Param())

    def update_recursive_calls(self):
        for node in self.recursive_returns:
            self.visit_Return(node)

    def visit_Return(self, node):
        if is_recursive(node, self.name):
            # Replace the recursive call with `accum` (visit_Call stores the call in
            # self.outer), then pass the rewritten expression as the new accumulator.
            self.outer = None
            ast.NodeTransformer.generic_visit(self, node)
            if self.outer is None:
                # e.g. the call is nested inside another call: refuse rather than
                # reuse a call recorded for a different return
                raise ValueError(
                    "could not locate the recursive call to {} in a return".format(self.name))
            self.outer.args.append(node.value)
            return ast.Return(value=self.outer)
        else:
            return ast.Return(value=self.make_Name('accum', ast.Load()))

    def visit_Call(self, node):
        if isinstance(node.func, ast.Name) and node.func.id == self.name:
            self.outer = node
            return self.make_Name('accum', ast.Load())
        else:
            return node

    def visit_FunctionDef(self, node):
        if self.default_value is None:
            raise ValueError(
                "{}: no non-recursive return to use as the accumulator default".format(node.name))
        # update arguments
        node.args.args.append(self.make_param('accum'))
        node.args.defaults.append(self.default_value)
        ast.NodeTransformer.generic_visit(self, node)
        return node

    def visit_Print(self, node):
        return node
110 |
class ModuleCrawler(ast.NodeTransformer):
    """Apply Mangler, in place, to every top-level recursive function of a module."""

    def visit_Module(self, tree):
        for child in ast.iter_child_nodes(tree):
            if not isinstance(child, ast.FunctionDef):
                continue
            if is_recursive(child, child.name):
                Mangler(child).visit(child)
        return tree
118 |
def tail_fact(n, accum=1):
    """Tail-recursive factorial: return accum * n!."""
    if n <= 1:
        return accum
    return tail_fact(n - 1, accum * n)
122 |
def rec_fact(n):
    """Non-tail recursive factorial (the shape Mangler rewrites)."""
    return 1 if n <= 1 else n * rec_fact(n - 1)
128 |
129 |
if __name__ == '__main__':
    # Usage: python translator.py <file.py>
    # Rewrites every top-level recursive function of the file into tail-recursive
    # form, prints the resulting AST, then executes the transformed module.
    # Guarded so that importing this module (e.g. from translator_tests.py) works.
    import types

    (filename,) = sys.argv[1:]
    with open(filename, 'r') as source_file:
        tree = ast.parse(source_file.read(), filename)

    m = ModuleCrawler()
    m.visit(tree)
    print_ast(tree)

    ast.fix_missing_locations(tree)
    code = compile(tree, filename, "exec")

    # NOTE: this executes the user-supplied file; only run it on trusted code.
    namespace = types.ModuleType(filename)
    exec(code, namespace.__dict__)
--------------------------------------------------------------------------------
/translator_tests.py:
--------------------------------------------------------------------------------
1 | __author__ = 'Liuda'
2 | import ast, inspect
3 | import translator
4 |
5 |
# Sample inputs for test() below: each function's source is parsed with
# inspect.getsource, so only '#' comments (absent from the AST) are added.
# Plain (non-tail) recursion: calls itself.
def rec_fact(n):
    if n <= 1:
        return 1
    else:
        return n * rec_fact(n-1)
11 |
# Tail recursion: also calls itself.
def tail_fact(n, accum=1):
    if n <= 1:
        return accum
    else:
        return tail_fact(n-1, accum * n)
17 |
# Not recursive: the name f is only rebound as a local variable, never called.
def f(n):
    print(n)
    f = 5
    return f
22 |
# Not recursive: no calls at all.
def sq(x):
    return x*x
25 |
# Recursive, and also calls the non-recursive helper sq.
def sum_squares(x):
    if x <= 1:
        return sq(x)
    else:
        return sq(x) + sum_squares(x-1)
31 |
# Recursive call as a statement (not in a return), and no return value.
def print_n_squares(x):
    print(sq(x))
    if x > 1:
        print_n_squares(x-1)
36 |
# translator.is_recursive takes the tree and the function name to look for;
# code without any calls is never recursive (prints False).
print(translator.is_recursive(ast.parse('a = 5 + 9'), 'a'))
38 |
def test():
    """Print, for each sample function, whether translator.is_recursive finds a self-call."""
    for func in [rec_fact, tail_fact, f, sq, sum_squares, print_n_squares]:
        tree = ast.parse(inspect.getsource(func))
        # is_recursive needs the name of the function whose calls it looks for
        print(func.__name__, translator.is_recursive(tree, func.__name__))
43 |
if __name__ == '__main__':
    test()
--------------------------------------------------------------------------------