├── requirements.txt
├── README.rst
├── pi.py
├── test_assembler.py
├── blank_jit.py
├── test_jit.py
├── assembler.py
├── ast2png.py
└── jit.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytest
2 | cffi
3 | astpretty
4 | git+https://github.com/Maratyszcza/PeachPy.git@01d1515#egg=peachpy
5 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | ==========================================
2 | HOW TO WRITE A JIT COMPILER IN 30 MINUTES
3 | ==========================================
4 |
5 | * Antonio Cuni, PyPy core dev
6 |
7 | * http://github.com/antocuni/jit30min
8 |
9 | * *VERY* simple JIT compiler for x86_64
10 |
11 | * Subset of Python
12 |
13 | * All variables are assumed to be of type ``float``
14 |
15 | * We use PeachPy to encode ASM instructions
16 |
17 | * *DISCLAIMER*
18 |
19 | - I have never written any assembly before
20 |
21 | ================================
22 | EP2019 "Write your own JIT" game
23 | ================================
24 |
25 | * http://github.com/antocuni/jit30min
26 |
27 | * Extend the repo with your favorite feature
28 |
29 | * Send me a PR
30 |
31 | * I won't merge the PRs, but I will keep a list of the interesting ones
32 |
--------------------------------------------------------------------------------
/pi.py:
--------------------------------------------------------------------------------
1 | import time
2 | import inspect
3 | import platform
4 | import jit
5 |
def compute_pi(iterations):
    # Estimate pi from a regular iterations x iterations grid over the unit
    # square: count the grid points with x*x + y*y < 1 (the quarter circle)
    # and scale the hit ratio by 4.
    # NOTE: jit.compile() re-parses this source, so the body must stay inside
    # the subset jit.AstCompiler handles (assignments to plain names, + - * /,
    # `<` comparisons, if/while, return) and must NOT get a docstring.
    step = 1.0 / iterations
    x = 0.0
    hits = 0.0
    while x < 1:
        y = 0.0
        while y < 1:
            if x*x + y*y < 1:
                hits = hits + 1
            y = y + step
        x = x + step
    samples = iterations * iterations
    return hits / samples * 4
19 |
def run(name, fn, iterations):
    """Time a single ``fn(iterations)`` call and print the result.

    Args:
        name: label printed right-aligned in a 10-character column.
        fn: callable taking the iteration count and returning a pi estimate.
        iterations: value forwarded to ``fn``.
    """
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock, which can jump (e.g. NTP adjustments) during a benchmark.
    start = time.perf_counter()
    pi = fn(iterations)
    elapsed = time.perf_counter() - start
    print('%10s pi = %.6f t = %.2f secs' % (name, pi, elapsed))
26 |
N = 3000


def main():
    """Benchmark compute_pi on the host interpreter and, except on PyPy, the toy JIT."""
    impl = platform.python_implementation()
    run(impl, compute_pi, N)
    if impl == 'PyPy':
        # The toy JIT run is skipped on PyPy.
        return
    run('JIT', jit.compile(compute_pi), N)


if __name__ == '__main__':
    main()
36 |
--------------------------------------------------------------------------------
/test_assembler.py:
--------------------------------------------------------------------------------
1 | from assembler import FunctionAssembler
2 | from peachpy import x86_64
3 |
class TestFunctionAssembler:
    """Assemble tiny functions with FunctionAssembler, then call them.

    ``load`` turns an assembler into a Python callable. test_jit's
    TestCompiledFuntion subclasses this class and overrides ``load`` so the
    same scenarios also run through jit.CompiledFunction.
    """

    def load(self, asm):
        """Encode *asm* with PeachPy's own encoder/loader and return the callable."""
        encoded = asm._encode()
        return encoded.load()

    def test_getattr(self):
        asm = FunctionAssembler('foo', [])
        # Registers and operand-size specifiers are the very objects that
        # peachpy.x86_64 exposes.
        for attr in ('xmm0', 'rsp', 'qword'):
            assert getattr(asm, attr) is getattr(x86_64, attr)

    def test_opcode(self):
        asm = FunctionAssembler('foo', [])
        asm.ADDSD(asm.xmm0, asm.xmm1)
        emitted = asm._peachpy_fn._instructions
        assert len(emitted) == 1
        assert type(emitted[0]).__name__ == 'ADDSD'

    def test_encode(self):
        asm = FunctionAssembler('foo', ['a', 'b'])
        # a and b arrive in xmm0 and xmm1; the result is read back from xmm0.
        asm.ADDSD(asm.xmm0, asm.xmm1)
        asm.RET()
        add = self.load(asm)
        assert add(3, 4) == 7

    def test_const(self):
        asm = FunctionAssembler('foo', ['a', 'b'])
        asm.ADDSD(asm.xmm0, asm.xmm1)
        asm.ADDSD(asm.xmm0, asm.const(100))
        asm.RET()
        add_plus_100 = self.load(asm)
        assert add_plus_100(3, 4) == 107

    def test_pushsd_popsd(self):
        asm = FunctionAssembler('foo', ['a', 'b', 'c'])
        for reg in (asm.xmm0, asm.xmm1, asm.xmm2):
            asm.pushsd(reg)
        asm.PXOR(asm.xmm0, asm.xmm0)  # xmm0 = 0
        # Values come back in LIFO order: c first (weight 100), then b (10).
        for weight in (100, 10):
            asm.popsd(asm.xmm1)
            asm.MULSD(asm.xmm1, asm.const(weight))  # xmm0 += xmm1 * weight
            asm.ADDSD(asm.xmm0, asm.xmm1)
        # ...and finally a, with weight 1.
        asm.popsd(asm.xmm1)
        asm.ADDSD(asm.xmm0, asm.xmm1)
        asm.RET()
        digits = self.load(asm)
        assert digits(1, 2, 3) == 321

    def test_jump(self):
        asm = FunctionAssembler('foo', [])
        target = asm.Label()
        asm.MOVSD(asm.xmm0, asm.const(42))
        asm.JMP(target)
        asm.MOVSD(asm.xmm0, asm.const(123))  # skipped by the jump
        asm.LABEL(target)
        asm.RET()
        answer = self.load(asm)
        assert answer() == 42
65 |
--------------------------------------------------------------------------------
/blank_jit.py:
--------------------------------------------------------------------------------
1 | import mmap
2 | import ast
3 | import textwrap
4 | import inspect
5 | from collections import defaultdict
6 | from cffi import FFI
7 | from assembler import FunctionAssembler as FA
8 |
9 | ## code = b'\xf2\x0f\x58\xc1\xc3' # addsd xmm0,xmm1 ; ret
10 | ## buf = mmap.mmap(-1, len(code), mmap.MAP_PRIVATE,
11 | ## mmap.PROT_READ | mmap.PROT_WRITE |
12 | ## mmap.PROT_EXEC)
13 | ## buf[:] = code
14 |
15 | ## ffi = FFI()
16 | ## ffi.cdef("""
17 | ## typedef double (*fn0)(void);
18 | ## typedef double (*fn1)(double);
19 | ## typedef double (*fn2)(double, double);
20 | ## typedef double (*fn3)(double, double, double);
21 | ## """)
22 |
23 | ## fptr = ffi.cast("fn2", ffi.from_buffer(buf))
24 | ## print(fptr(5, 2))
25 |
26 |
27 | ## class CompiledFunction:
28 |
29 | ## def __init__(self, nargs, code):
30 | ## ...
31 | ## fntype = 'fn%d' % nargs
32 | ## self.fptr = ffi.cast(fntype, ffi.from_buffer(self.buf))
33 |
34 | ## def __call__(self, *args):
35 | ## return self.fptr(*args)
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 | ## class RegAllocator:
45 |
46 | ## REGISTERS = (FA.xmm0, FA.xmm1, FA.xmm2, FA.xmm3, FA.xmm4,
47 | ## FA.xmm5, FA.xmm6, FA.xmm7, FA.xmm8, FA.xmm9,
48 | ## FA.xmm10, FA.xmm11, FA.xmm12, FA.xmm13,
49 | ## FA.xmm14, FA.xmm15)
50 |
51 |
52 |
53 |
54 |
55 | ## class AstCompiler:
56 |
57 | ## def __init__(self, src):
58 | ## self.tree = ast.parse(textwrap.dedent(src))
59 | ## self.asm = None
60 |
61 | ## def show(self, node):
62 | ## import astpretty
63 | ## from ast2png import ast2png
64 | ## astpretty.pprint(node)
65 | ## ast2png(self.tree, highlight_node=node, filename='ast.png')
66 | ##
67 | ## def compile(self):
68 | ## self.visit(self.tree)
69 | ## assert self.asm is not None, 'No function found?'
70 | ## code = self.asm.assemble_and_relocate()
71 | ## return CompiledFunction(self.asm.nargs, code)
72 | ##
73 | ## def visit(self, node):
74 | ## methname = node.__class__.__name__
75 | ## meth = getattr(self, methname, None)
76 | ## if meth is None:
77 | ## raise NotImplementedError(methname)
78 | ## return meth(node)
79 |
--------------------------------------------------------------------------------
/test_jit.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import jit
3 | from assembler import FunctionAssembler as FA
4 | from test_assembler import TestFunctionAssembler as AssemblerTest
5 |
class TestCompiledFuntion(AssemblerTest):
    """Re-run the assembler scenarios, executing the code via jit.CompiledFunction."""

    def test_basic(self):
        # Hand-encoded machine code: addsd xmm0, xmm1 ; ret
        code = b'\xf2\x0f\x58\xc1\xc3'
        func = jit.CompiledFunction(2, code)
        expected = 12.34 + 56.78
        assert func.fptr(12.34, 56.78) == expected
        assert func(12.34, 56.78) == expected

    def load(self, asm):
        """Assemble *asm* into one code+data buffer and wrap it as a callable."""
        return jit.CompiledFunction(asm.nargs, asm.assemble_and_relocate())
17 |
class TestRegAllocator:
    """Tests for jit.RegAllocator's first-come-first-served register binding."""

    def test_allocate(self):
        regs = jit.RegAllocator()
        assert regs.get('a') == FA.xmm0
        assert regs.get('b') == FA.xmm1
        # Asking again for a known name returns its existing register.
        assert regs.get('a') == FA.xmm0
        assert regs.get('c') == FA.xmm2

    def test_too_many_vars(self):
        regs = jit.RegAllocator()
        # force allocation of xmm0..xmm14
        for i in range(15):
            regs.get('var%d' % i)
        assert regs.get('var15') == FA.xmm15
        # The string form pytest.raises(exc, "code") was removed in pytest
        # 5.0 (it now fails with TypeError), so use the context-manager form.
        with pytest.raises(NotImplementedError):
            regs.get('var16')
34 |
35 |
class TestAstCompiler:
    """End-to-end tests: Python source -> jit.AstCompiler -> machine code -> call."""

    @staticmethod
    def _compile(src):
        """Compile the single function defined in *src* into a callable."""
        return jit.AstCompiler(src).compile()

    def test_empty(self):
        # A function that falls off the end returns 0.
        fn = self._compile("""
            def foo():
                pass
        """)
        assert fn() == 0

    def test_simple(self):
        fn = self._compile("""
            def foo():
                return 100
        """)
        assert fn() == 100

    def test_arguments(self):
        fn = self._compile("""
            def foo(a, b):
                return b
        """)
        assert fn(3, 4) == 4

    def test_add(self):
        fn = self._compile("""
            def foo(a, b):
                return a+b
        """)
        assert fn(3, 4) == 7

    def test_binops(self):
        fn = self._compile("""
            def foo(a, b):
                return (a-b) + (a*b) - (a/b)
        """)
        expected = (3-4) + (3*4) - (3.0/4)
        assert fn(3, 4) == expected

    def test_assign(self):
        fn = self._compile("""
            def foo(a):
                b = a + 1
                return b
        """)
        assert fn(41) == 42

    def test_if(self):
        fn = self._compile("""
            def foo(a):
                if a < 0:
                    return 0-a
                return a
        """)
        for value in (-42, 42):
            assert fn(value) == 42

    def test_while(self):
        fn = self._compile("""
            def foo(a):
                tot = 0
                i = 0
                while i < a:
                    tot = tot + i
                    i = i + 1
                return tot
        """)
        assert fn(5) == 1+2+3+4
111 |
112 |
class TestDecorator:
    """jit.compile used as a decorator."""

    def test_simple(self):
        # jit.compile re-parses foo's source, so foo must stay inside the
        # subset AstCompiler supports (no docstring inside foo, for example).
        @jit.compile
        def foo(a, b):
            return a+b

        assert type(foo) is jit.CompiledFunction
        assert foo(39, 3) == 42.0
121 |
--------------------------------------------------------------------------------
/assembler.py:
--------------------------------------------------------------------------------
1 | import peachpy
2 | from peachpy import Argument, double_, Constant
3 | from peachpy import x86_64
4 | # workaround: peachpy.x86_64 forgets to expose rsp, so add it by hand
5 | x86_64.rsp = peachpy.x86_64.registers.rsp
6 |
7 |
class FunctionAssembler:
    """Convenience wrapper around a ``peachpy.x86_64.Function``.

    Every argument and the return value are ``double``. Attribute access
    that is not satisfied by the class itself is forwarded to
    ``peachpy.x86_64``. Instruction classes (``ADDSD``, ``JMP``, ...) come
    back as emitters that append the instruction to the function being built.
    Anything else (``rsp``, ``qword``, ``Label``, ...) is returned unchanged.
    """

    from peachpy.x86_64 import (xmm0, xmm1, xmm2, xmm3, xmm4,
                                xmm5, xmm6, xmm7, xmm8, xmm9,
                                xmm10, xmm11, xmm12, xmm13,
                                xmm14, xmm15)

    def __init__(self, name, argnames):
        """Start a function called *name* taking one double per entry of *argnames*."""
        self.name = name
        self.nargs = len(argnames)
        args = [Argument(double_, name=argname) for argname in argnames]
        self._peachpy_fn = x86_64.Function(name, args, double_)

    def __getattr__(self, name):
        # Only reached for names not found on the instance or class (so not
        # for the xmm registers above). getattr raises AttributeError if
        # peachpy.x86_64 has no such name either.
        obj = getattr(x86_64, name)
        if isinstance(obj, type) and issubclass(obj, x86_64.instructions.Instruction):
            instr = obj

            def emit(*args):
                self._peachpy_fn.add_instruction(instr(*args))
            return emit
        return obj

    def const(self, val):
        """Return *val* as a float64 constant operand (stored in the const section)."""
        return Constant.float64(float(val))

    def pushsd(self, reg):
        """Push the low double of XMM register *reg* onto the machine stack."""
        # Reserve 16 bytes rather than 8 so rsp's alignment modulo 16 is
        # unchanged.
        self.SUB(self.rsp, 16)
        self.MOVSD(self.qword[self.rsp], reg)

    def popsd(self, reg):
        """Pop a double pushed by :meth:`pushsd` into XMM register *reg*."""
        self.MOVSD(reg, self.qword[self.rsp])
        self.ADD(self.rsp, 16)

    def _encode(self):
        """Finalize for the host ABI and return PeachPy's encoded function."""
        abi_func = self._peachpy_fn.finalize(x86_64.abi.detect())
        return abi_func.encode()

    def assemble_and_relocate(self):
        """Return the machine code immediately followed by its constant data.

        This code has been copied and adapted from PeachPy. The goal is to
        put the data section immediately after the code section, so a single
        buffer can be mapped and executed. This is suboptimal and dangerous,
        because it mixes code and data and leaves the data inside an
        executable mapping, but it keeps the talk simple.

        Returns:
            bytearray: the code section bytes followed by the const section
            bytes. Every RIP-relative displacement in the code is patched
            to address its constant inside that same buffer.
        """
        encoded_func = self._encode()
        code_segment = bytearray(encoded_func.code_section.content)
        const_segment = bytearray(encoded_func.const_section.content)
        # The constants start right after the code.
        const_offset = len(code_segment)

        # Apply relocations
        from peachpy.x86_64.meta import RelocationType
        from peachpy.util import is_sint32
        for relocation in encoded_func.code_section.relocations:
            assert relocation.type == RelocationType.rip_disp32
            assert relocation.symbol in encoded_func.const_section.symbols
            start = relocation.offset
            end = start + 4
            # A disp32 field is a *signed* little-endian 32-bit value;
            # decoding it as unsigned would turn a negative addend into a
            # huge number that can never pass the sint32 check below.
            old_value = int.from_bytes(code_segment[start:end], 'little',
                                       signed=True)
            # This is the biggest difference with respect to PeachPy:
            # constants live at const_offset inside the same buffer.
            # program_counter is presumably the RIP value the CPU adds the
            # displacement to (PeachPy's own loader uses the same formula).
            new_value = (old_value + const_offset + relocation.symbol.offset
                         - relocation.program_counter)
            assert is_sint32(new_value)
            code_segment[start:end] = new_value.to_bytes(4, 'little',
                                                         signed=True)
        assert not encoded_func.const_section.relocations

        return code_segment + const_segment
83 |
--------------------------------------------------------------------------------
/ast2png.py:
--------------------------------------------------------------------------------
1 | """
2 | Render an ast to png. Stolen&adapted from:
3 | https://github.com/hchasestevens/show_ast
4 | """
5 |
6 | import ast
7 | import itertools
8 | from functools import partial
9 | import graphviz
10 |
11 | try:
12 | reduce
13 | except NameError:
14 | from functools import reduce
15 |
16 | try:
17 | _basestring = basestring
18 | except NameError:
19 | _basestring = str
20 |
21 |
# Default rendering options for ast2png().
SETTINGS = {
    # Styling options:
    'scale': 2,  # NOTE(review): not read anywhere in this module.
    'font': 'courier',
    'shape': 'none',
    'terminal_color': '#008040',
    'nonterminal_color': '#004080',

    # AST display options:
    'omit_module': True,  # NOTE(review): not read anywhere in this module.
    'omit_docstrings': True,
}
34 |
35 |
36 | def _strip_docstring(body):
37 | first = body[0]
38 | if isinstance(first, ast.Expr) and isinstance(first.value, ast.Str):
39 | return body[1:]
40 | return body
41 |
def recurse_through_ast(node, handle_ast, handle_terminal, handle_fields,
                        handle_no_fields, omit_docstrings):
    """Feed every field of *node* to the matching callback.

    - ``ctx`` fields and ``None`` values are skipped.
    - AST-valued fields go to ``handle_ast``.
    - List fields are flattened: AST items go to ``handle_ast``, any other
      item to ``handle_terminal``. For the ``body`` of a FunctionDef,
      ClassDef or Module, a leading docstring is dropped first when
      *omit_docstrings* is true.
    - String fields go to ``handle_terminal`` wrapped in double quotes.
    - Any other value goes to ``handle_terminal`` as-is.

    Returns ``handle_fields(node, results)`` if anything was collected,
    otherwise ``handle_no_fields(node)``.
    """
    may_have_docstring = isinstance(node, (ast.FunctionDef,
                                           ast.ClassDef, ast.Module))
    results = []
    for field_name in node._fields:
        value = getattr(node, field_name)
        if field_name == 'ctx':
            continue

        if isinstance(value, ast.AST):
            results.append(handle_ast(value))
        elif isinstance(value, list):
            if may_have_docstring and omit_docstrings and field_name == 'body':
                value = _strip_docstring(value)
            for item in value:
                if isinstance(item, ast.AST):
                    results.append(handle_ast(item))
                else:
                    results.append(handle_terminal(item))
        elif isinstance(value, _basestring):
            results.append(handle_terminal('"{}"'.format(value)))
        elif value is not None:
            results.append(handle_terminal(value))

    if results:
        return handle_fields(node, results)
    return handle_no_fields(node)
78 |
79 |
80 |
81 | def _bold(label):
82 | return '<{}>'.format(label)
83 |
84 |
85 | def _attach_to_parent(parent, graph, names, label, name=None, **style):
86 | node_name = next(names) if name is None else name
87 | node = graph.node(node_name, label=label, **style)
88 | if parent is not None:
89 | graph.edge(parent, node_name, sametail='t{}'.format(parent))
90 |
91 |
def handle_ast(node, parent_node, graph, names, omit_docstrings, terminal_color,
               nonterminal_color, highlight_node):
    """Draw *node* into *graph* and recurse into its fields.

    The node is labelled with its ast class name, gets the next id from
    *names*, and is linked to *parent_node* (None for the root).
    """
    if node is highlight_node:
        # This colour is also handed down to the recursive calls below, so it
        # carries over to all of this node's non-leaf descendants.
        nonterminal_color = '#CC0000'

    node_name = next(names)
    _attach_to_parent(
        parent_node,
        graph,
        names,
        _bold(node.__class__.__name__),
        name=node_name,
        fontcolor=nonterminal_color,
    )

    visit_child = partial(
        handle_ast,
        parent_node=node_name,
        graph=graph,
        names=names,
        omit_docstrings=omit_docstrings,
        terminal_color=terminal_color,
        nonterminal_color=nonterminal_color,
        highlight_node=highlight_node,
    )
    attach_leaf = partial(
        _attach_to_parent,
        graph=graph,
        names=names,
        parent=node_name,
        fontcolor=terminal_color,
    )
    visit_terminal = partial(handle_terminal, attach_to_parent=attach_leaf)
    mark_as_leaf = partial(
        handle_no_fields,
        parent_node=node_name,
        graph=graph,
        terminal_color=terminal_color,
        nonterminal_color=nonterminal_color,
    )
    recurse_through_ast(node, visit_child, visit_terminal, handle_fields,
                        mark_as_leaf, omit_docstrings)
139 |
140 |
def handle_terminal(terminal, attach_to_parent):
    """Attach a leaf value to the current node, labelled ``str(terminal)``."""
    label = str(terminal)
    attach_to_parent(label=label)
143 |
144 |
def handle_fields(*_unused):
    """No-op callback: nodes that do have fields need no post-processing."""
    return None
147 |
148 |
def handle_no_fields(__, parent_node, graph, terminal_color, nonterminal_color):
    """Restyle the already-emitted node *parent_node* as a leaf.

    The node's line in ``graph.body`` is located by its id prefix. The
    search starts at index ``int(parent_node)``. In that line,
    *nonterminal_color* is swapped for *terminal_color* and the ``<``/``>``
    label brackets are removed.

    Raises:
        KeyError: no line for *parent_node* was found.
    """
    prefix = '{} '.format(parent_node)
    start = int(parent_node)
    candidates = graph.body[start:]
    offset = next((i for i, line in enumerate(candidates)
                   if line.strip().startswith(prefix)), None)
    if offset is None:
        raise KeyError("Could not find parent in graph.")
    line = candidates[offset]
    replacements = {
        nonterminal_color: terminal_color,
        '<': '',
        '>': '',
    }
    for old, new in replacements.items():
        line = line.replace(old, new)
    graph.body[start + offset] = line
167 |
168 |
def ast2png(root, highlight_node=None, filename='ast.png', settings=SETTINGS):
    """Render the AST rooted at *root* to a PNG file at *filename*.

    If *highlight_node* is given, it and its non-leaf descendants are drawn
    in red. *settings* must provide the keys ``font``, ``shape``,
    ``terminal_color``, ``nonterminal_color`` and ``omit_docstrings`` (see
    SETTINGS).
    """
    graph = graphviz.Graph(format='png')

    handle_ast(
        root,
        parent_node=None,
        graph=graph,
        names=map(str, itertools.count()),
        omit_docstrings=settings['omit_docstrings'],
        terminal_color=settings['terminal_color'],
        nonterminal_color=settings['nonterminal_color'],
        highlight_node=highlight_node,
    )

    # TODO: height/fixedsize node attributes, and how they interact with the
    # 'scale' setting.
    graph.node_attr.update(fontname=settings['font'], shape=settings['shape'])

    png_bytes = graph.pipe()
    with open(filename, 'wb') as out:
        out.write(png_bytes)
194 |
195 |
--------------------------------------------------------------------------------
/jit.py:
--------------------------------------------------------------------------------
1 | import mmap
2 | import ast
3 | import textwrap
4 | import inspect
5 | from collections import defaultdict
6 | from cffi import FFI
7 | from assembler import FunctionAssembler as FA
8 |
# C function-pointer types for JIT-compiled code: fnN takes N doubles and
# returns a double.
_FN_TYPEDEFS = """
typedef double (*fn0)(void);
typedef double (*fn1)(double);
typedef double (*fn2)(double, double);
typedef double (*fn3)(double, double, double);
"""

ffi = FFI()
ffi.cdef(_FN_TYPEDEFS)
16 |
class CompiledFunction:
    """Callable wrapper around machine code copied into an executable mapping.

    The code is called as a C function taking *nargs* doubles and returning
    a double.
    """

    def __init__(self, nargs, code):
        self.nargs = nargs
        # Anonymous (fileno -1) mapping that is readable, writable and
        # executable, so the code can be copied in and run in place.
        self.buf = mmap.mmap(-1, len(code), mmap.MAP_PRIVATE,
                             mmap.PROT_READ | mmap.PROT_WRITE |
                             mmap.PROT_EXEC)
        self.buf[:len(code)] = code
        # Build the function-pointer type for any arity. The fn0..fn3
        # typedefs limited compiled functions to at most 3 arguments.
        # NOTE(review): AstCompiler expects argument i in xmm<i>; under the
        # System V x86-64 ABI only the first 8 double arguments are passed in
        # registers, so larger arities presumably won't work -- confirm.
        params = ', '.join(['double'] * nargs) or 'void'
        self.fptr = ffi.cast('double(*)(%s)' % params, ffi.from_buffer(self.buf))

    def __call__(self, *args):
        return self.fptr(*args)
29 |
30 |
class RegAllocator:
    """Bind variable names to XMM registers, first come first served.

    There is no register spilling. Once all 16 registers are taken, asking
    for a new name raises NotImplementedError.
    """

    REGISTERS = (FA.xmm0, FA.xmm1, FA.xmm2, FA.xmm3, FA.xmm4,
                 FA.xmm5, FA.xmm6, FA.xmm7, FA.xmm8, FA.xmm9,
                 FA.xmm10, FA.xmm11, FA.xmm12, FA.xmm13,
                 FA.xmm14, FA.xmm15)

    def __init__(self):
        # Stack of free registers with xmm0 on top, so pop() hands them out
        # in ascending order.
        self._registers = list(self.REGISTERS[::-1])
        # name -> reg; an unknown name gets a fresh register on first lookup.
        self.vars = defaultdict(self._allocate)

    def _allocate(self):
        if not self._registers:
            raise NotImplementedError("Too many variables: register spilling not implemented")
        return self._registers.pop()

    def get(self, varname):
        """Return the register bound to *varname*, allocating one on first use."""
        return self.vars[varname]
50 |
class AstCompiler:
    """Compile the single function defined in a piece of Python source.

    Only a small subset of Python is supported, and every value is treated
    as a double. Expression visitors leave their result on the machine stack
    (via ``pushsd``); statement visitors pop what they need. Variables,
    arguments included, live in XMM registers handed out by RegAllocator.
    Unsupported constructs raise NotImplementedError.
    """

    # Arithmetic operator (upper-cased ast class name) -> SSE2 scalar opcode.
    BINOPS = {
        'ADD': 'ADDSD',
        'SUB': 'SUBSD',
        'MULT': 'MULSD',
        'DIV': 'DIVSD',
    }

    # Comparison operator (upper-cased ast class name) ->
    # (swap_operands, conditional jump taken when the comparison is true).
    # UCOMISD sets ZF, PF and CF all to 1 when an operand is NaN
    # ("unordered"). JA (taken iff CF=0 and ZF=0) and JAE (taken iff CF=0)
    # are therefore never taken for NaN. This matches Python, where every
    # ordering comparison involving NaN is False. By contrast, JB for `<`
    # would be taken for NaN. So `a < b` is emitted as `b > a`, and
    # `a <= b` as `b >= a`.
    COMPARISONS = {
        'LT': (True, 'JA'),
        'LTE': (True, 'JAE'),
        'GT': (False, 'JA'),
        'GTE': (False, 'JAE'),
    }

    def __init__(self, src):
        self.tree = ast.parse(textwrap.dedent(src))
        self.asm = None

    def show(self, node):
        """Debug helper: pretty-print *node* and save ast.png with it highlighted."""
        import astpretty
        from ast2png import ast2png
        astpretty.pprint(node)
        ast2png(self.tree, highlight_node=node, filename='ast.png')

    def _newfunc(self, name, argnames):
        """Start assembling function *name* and bind its arguments to registers."""
        self.asm = FA(name, argnames)
        self.regs = RegAllocator()
        # Arguments are allocated first, in order, so argument i is bound to
        # xmm<i>, which is where the caller passes it.
        for argname in argnames:
            self.regs.get(argname)
        # Two registers reserved for evaluating expressions.
        self.tmp0 = self.regs.get('__scratch_register_0__')
        self.tmp1 = self.regs.get('__scratch_register_1__')

    def compile(self):
        """Compile the parsed source and return it as a CompiledFunction.

        Raises:
            ValueError: the source does not define a function.
            NotImplementedError: the function uses unsupported Python.
        """
        self.visit(self.tree)
        if self.asm is None:
            raise ValueError('No function found?')
        code = self.asm.assemble_and_relocate()
        return CompiledFunction(self.asm.nargs, code)

    def visit(self, node):
        """Dispatch *node* to the method named after its ast class."""
        methname = node.__class__.__name__
        meth = getattr(self, methname, None)
        if meth is None:
            raise NotImplementedError(methname)
        return meth(node)

    # -- helpers ---------------------------------------------------------

    def _visit_body(self, stmts):
        """Compile a list of statements in order."""
        for stmt in stmts:
            self.visit(stmt)

    def _visit_value(self, node):
        """Emit code that pushes the value of expression *node*."""
        if isinstance(node, ast.Compare):
            # Compare only sets CPU flags and pushes nothing; using it as a
            # value would pop garbage off the stack.
            raise NotImplementedError(
                'comparisons are only supported as if/while conditions')
        self.visit(node)

    def _visit_condition(self, test):
        """Emit the comparison *test* and return its key in COMPARISONS."""
        if not isinstance(test, ast.Compare):
            raise NotImplementedError('conditions must be a single comparison')
        return self.visit(test)

    def _jump_if(self, cmp_op, label):
        """Jump to *label* if the comparison just emitted for *cmp_op* holds."""
        _, jump = self.COMPARISONS[cmp_op]
        getattr(self.asm, jump)(label)

    def _push_const(self, value):
        self.asm.MOVSD(self.tmp0, self.asm.const(value))
        self.asm.pushsd(self.tmp0)

    def _return_zero(self):
        self.asm.PXOR(self.asm.xmm0, self.asm.xmm0)
        self.asm.RET()

    # -- statements --------------------------------------------------------

    def Module(self, node):
        self._visit_body(node.body)

    def FunctionDef(self, node):
        # Explicit check rather than assert: under `python -O` an assert
        # would vanish and a second function would silently replace the first.
        if self.asm is not None:
            raise NotImplementedError('cannot compile more than one function')
        argnames = [arg.arg for arg in node.args.args]
        self._newfunc(node.name, argnames)
        self._visit_body(node.body)
        # Falling off the end returns 0.
        self._return_zero()

    def Pass(self, node):
        pass

    def Return(self, node):
        if node.value is None:
            # Bare `return`: same result as falling off the end.
            self._return_zero()
            return
        self._visit_value(node.value)
        self.asm.popsd(self.asm.xmm0)
        self.asm.RET()

    def Assign(self, node):
        # Explicit check rather than assert: under `python -O` an assert
        # would vanish and `a = b = x` would silently assign only `a`.
        if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
            raise NotImplementedError(
                'only assignments to a single plain name are supported')
        reg = self.regs.get(node.targets[0].id)
        self._visit_value(node.value)
        self.asm.popsd(self.tmp0)
        self.asm.MOVSD(reg, self.tmp0)

    def If(self, node):
        """Compile an if statement (with optional else/elif) as::

                <compare>
                Jcc then_label
                <orelse>
                JMP end_label
            then_label:
                <body>
            end_label:
        """
        then_label = self.asm.Label()
        end_label = self.asm.Label()
        op = self._visit_condition(node.test)
        self._jump_if(op, then_label)
        self._visit_body(node.orelse)
        self.asm.JMP(end_label)
        self.asm.LABEL(then_label)
        self._visit_body(node.body)
        self.asm.LABEL(end_label)

    def While(self, node):
        """Compile a while loop as::

            begin_label:
                <compare>
                Jcc body_label
                JMP end_label
            body_label:
                <body>
                JMP begin_label
            end_label:
                <orelse>
        """
        begin_label = self.asm.Label()
        body_label = self.asm.Label()
        end_label = self.asm.Label()
        #
        self.asm.LABEL(begin_label)
        op = self._visit_condition(node.test)
        self._jump_if(op, body_label)
        self.asm.JMP(end_label)
        self.asm.LABEL(body_label)
        self._visit_body(node.body)
        self.asm.JMP(begin_label)
        self.asm.LABEL(end_label)
        # `break` has no visitor (so it raises NotImplementedError), which
        # means the loop always exits through its condition and a
        # `while ... else` clause always runs.
        self._visit_body(node.orelse)

    # -- expressions -------------------------------------------------------

    def Constant(self, node):
        # Python >= 3.8 parses every literal as ast.Constant.
        value = node.value
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise NotImplementedError('Constant %r' % (value,))
        self._push_const(value)

    def Num(self, node):
        # Python < 3.8 parses numeric literals as ast.Num.
        self._push_const(node.n)

    def Name(self, node):
        reg = self.regs.get(node.id)
        self.asm.pushsd(reg)

    def BinOp(self, node):
        opname = node.op.__class__.__name__.upper()
        if opname not in self.BINOPS:
            raise NotImplementedError(node.op.__class__.__name__)
        self._visit_value(node.left)
        self._visit_value(node.right)
        self.asm.popsd(self.tmp1)
        self.asm.popsd(self.tmp0)
        getattr(self.asm, self.BINOPS[opname])(self.tmp0, self.tmp1)
        self.asm.pushsd(self.tmp0)

    def Compare(self, node):
        """Emit UCOMISD for a single comparison; return its COMPARISONS key.

        Only sets CPU flags: nothing is pushed on the stack.
        """
        if len(node.ops) != 1:
            raise NotImplementedError('chained comparisons are not supported')
        cmp_op = node.ops[0].__class__.__name__.upper()
        if cmp_op not in self.COMPARISONS:
            raise NotImplementedError(node.ops[0].__class__.__name__)
        swap, _ = self.COMPARISONS[cmp_op]
        self._visit_value(node.left)
        self._visit_value(node.comparators[0])
        self.asm.popsd(self.tmp1)
        self.asm.popsd(self.tmp0)
        if swap:
            self.asm.UCOMISD(self.tmp1, self.tmp0)
        else:
            self.asm.UCOMISD(self.tmp0, self.tmp1)
        return cmp_op
197 |
198 |
def compile(fn):
    """JIT-compile the Python function *fn* from its source code.

    Usable as a decorator. Returns a CompiledFunction.
    """
    return AstCompiler(inspect.getsource(fn)).compile()
203 |
--------------------------------------------------------------------------------