# ├── requirements.txt
# ├── README.rst
# ├── pi.py
# ├── test_assembler.py
# ├── blank_jit.py
# ├── test_jit.py
# ├── assembler.py
# ├── ast2png.py
# └── jit.py
#
# /requirements.txt:
# --------------------------------------------------------------------------------
# pytest
# cffi
# astpretty
# graphviz
# git+https://github.com/Maratyszcza/PeachPy.git@01d1515#egg=peachpy
# --------------------------------------------------------------------------------
#
# /README.rst:
# --------------------------------------------------------------------------------
# ==========================================
# HOW TO WRITE A JIT COMPILER IN 30 MINUTES
# ==========================================
#
# * Antonio Cuni, PyPy core dev
#
# * http://github.com/antocuni/jit30min
#
# * *VERY* simple JIT compiler for x86_64
#
# * Subset of Python
#
# * All variables are assumed to be of type `float`
#
# * We use PeachPy to encode ASM instructions
#
# * *DISCLAIMER*
#
#   - I have never written any assembly before
#
# ================================
# EP2019 "Write your own JIT" game
# ================================
#
# * http://github.com/antocuni/jit30min
#
# * Extend the repo with your favorite feature
#
# * Send me a PR
#
# * I won't merge the PR but will keep a list of interesting PRs
# --------------------------------------------------------------------------------
#
# /pi.py:
# --------------------------------------------------------------------------------
import time
import inspect
import platform
import jit


# Estimate pi by walking an iterations x iterations grid over the unit square
# and counting the points that fall inside the quarter circle x*x + y*y < 1.
# NOTE: deliberately no docstring: main() feeds this function's source to
# jit.compile(), whose AstCompiler has no handler for expression statements.
def compute_pi(iterations):
    delta = 1.0 / iterations
    inside = 0.0
    x = 0.0
    while x < 1:
        y = 0.0
        while y < 1:
            if x*x + y*y < 1:
                inside = inside + 1
            y = y + delta
        x = x + delta
    total = iterations * iterations
    return inside / total * 4


def run(name, fn, iterations):
    """Call ``fn(iterations)`` and print its result and the wall-clock time."""
    # perf_counter() is monotonic and high-resolution, unlike time.time(),
    # so it is the right clock for measuring elapsed time.
    start = time.perf_counter()
    pi = fn(iterations)
    elapsed = time.perf_counter() - start
    print('%10s pi = %.6f t = %.2f secs' % (name, pi, elapsed))


N = 3000


def main():
    """Time compute_pi on the running interpreter and, except on PyPy,
    again after compiling it with jit.compile()."""
    run(platform.python_implementation(), compute_pi, N)
    if platform.python_implementation() != 'PyPy':
        jitted = jit.compile(compute_pi)
        run('JIT', jitted, N)


if __name__ == '__main__':
    main()
# --------------------------------------------------------------------------------
#
# /test_assembler.py:
# --------------------------------------------------------------------------------
from assembler import FunctionAssembler
from peachpy import x86_64


class TestFunctionAssembler:
    """Tests for FunctionAssembler.

    Every test turns the assembled code into a callable through load(), so a
    subclass can override load() to rerun the same tests with another loader.
    """

    def load(self, asm):
        """Return a Python callable for *asm* using PeachPy's own loader."""
        encoded_func = asm._encode()
        return encoded_func.load()

    def test_getattr(self):
        asm = FunctionAssembler('foo', [])
        assert asm.xmm0 is x86_64.xmm0
        assert asm.rsp is x86_64.rsp
        assert asm.qword is x86_64.qword

    def test_opcode(self):
        asm = FunctionAssembler('foo', [])
        asm.ADDSD(asm.xmm0, asm.xmm1)
        assert len(asm._peachpy_fn._instructions) == 1
        assert asm._peachpy_fn._instructions[0].__class__.__name__ == 'ADDSD'

    def test_encode(self):
        asm = FunctionAssembler('foo', ['a', 'b'])
        asm.ADDSD(asm.xmm0, asm.xmm1)
        asm.RET()
        pyfn = self.load(asm)
        assert pyfn(3, 4) == 7

    def test_const(self):
        asm = FunctionAssembler('foo', ['a', 'b'])
        asm.ADDSD(asm.xmm0, asm.xmm1)
        asm.ADDSD(asm.xmm0, asm.const(100))
        asm.RET()
        pyfn = self.load(asm)
        assert pyfn(3, 4) == 107

    def test_pushsd_popsd(self):
        asm = FunctionAssembler('foo', ['a', 'b', 'c'])
        asm.pushsd(asm.xmm0)
        asm.pushsd(asm.xmm1)
        asm.pushsd(asm.xmm2)
        asm.PXOR(asm.xmm0, asm.xmm0)            # xmm0 = 0
        asm.popsd(asm.xmm1)
        asm.MULSD(asm.xmm1, asm.const(100))     # xmm0 += (xmm1*100)
        asm.ADDSD(asm.xmm0, asm.xmm1)
        asm.popsd(asm.xmm1)
        asm.MULSD(asm.xmm1, asm.const(10))      # xmm0 += (xmm1*10)
        asm.ADDSD(asm.xmm0, asm.xmm1)
        asm.popsd(asm.xmm1)
        asm.ADDSD(asm.xmm0, asm.xmm1)
        asm.RET()
        pyfn = self.load(asm)
        assert pyfn(1, 2, 3) == 321

    def test_jump(self):
        asm = FunctionAssembler('foo', [])
        label = asm.Label()
        asm.MOVSD(asm.xmm0, asm.const(42))
        asm.JMP(label)
        asm.MOVSD(asm.xmm0, asm.const(123))     # this is not executed
        asm.LABEL(label)
        asm.RET()
        pyfn = self.load(asm)
        assert pyfn() == 42
# --------------------------------------------------------------------------------
#
# /blank_jit.py:
# --------------------------------------------------------------------------------
import mmap
import ast
import textwrap
import inspect
from collections import defaultdict
from cffi import FFI
from assembler import FunctionAssembler as FA

## code = b'\xf2\x0f\x58\xc1\xc3' # addsd xmm0,xmm1 ; ret
## buf = mmap.mmap(-1, len(code), mmap.MAP_PRIVATE,
##                 mmap.PROT_READ | mmap.PROT_WRITE |
##                 mmap.PROT_EXEC)
## buf[:] = code

## ffi = FFI()
## ffi.cdef("""
##     typedef double (*fn0)(void);
##     typedef double (*fn1)(double);
##     typedef double (*fn2)(double, double);
##     typedef double (*fn3)(double, double, double);
## """)

## fptr = ffi.cast("fn2", ffi.from_buffer(buf))
## print(fptr(5, 2))


## class CompiledFunction:

##     def __init__(self, nargs, code):
##         ...
##         fntype = 'fn%d' % nargs
##         self.fptr = ffi.cast(fntype, ffi.from_buffer(self.buf))

##     def __call__(self, *args):
##         return self.fptr(*args)




## class RegAllocator:

##     REGISTERS = (FA.xmm0, FA.xmm1, FA.xmm2, FA.xmm3, FA.xmm4,
##                  FA.xmm5, FA.xmm6, FA.xmm7, FA.xmm8, FA.xmm9,
##                  FA.xmm10, FA.xmm11, FA.xmm12, FA.xmm13,
##                  FA.xmm14, FA.xmm15)




## class AstCompiler:

##     def __init__(self, src):
##         self.tree = ast.parse(textwrap.dedent(src))
##         self.asm = None

##     def show(self, node):
##         import astpretty
##         from ast2png import ast2png
##         astpretty.pprint(node)
##         ast2png(self.tree, highlight_node=node, filename='ast.png')
##
##     def compile(self):
##         self.visit(self.tree)
##         assert self.asm is not None, 'No function found?'
##         code = self.asm.assemble_and_relocate()
##         return CompiledFunction(self.asm.nargs, code)
##
##     def visit(self, node):
##         methname = node.__class__.__name__
##         meth = getattr(self, methname, None)
##         if meth is None:
##             raise NotImplementedError(methname)
##         return meth(node)
# --------------------------------------------------------------------------------
#
# /test_jit.py:
# --------------------------------------------------------------------------------
import pytest
import jit
from assembler import FunctionAssembler as FA
from test_assembler import TestFunctionAssembler as AssemblerTest


class TestCompiledFunction(AssemblerTest):
    """Rerun the assembler tests, loading code through jit.CompiledFunction."""

    def test_basic(self):
        code = b'\xf2\x0f\x58\xc1\xc3'  # addsd xmm0,xmm1 ; ret
        p = jit.CompiledFunction(2, code)
        assert p.fptr(12.34, 56.78) == 12.34 + 56.78
        assert p(12.34, 56.78) == 12.34 + 56.78

    def load(self, asm):
        """Load *asm* via assemble_and_relocate() + jit.CompiledFunction."""
        code = asm.assemble_and_relocate()
        return jit.CompiledFunction(asm.nargs, code)


class TestRegAllocator:

    def test_allocate(self):
        regs = jit.RegAllocator()
        assert regs.get('a') == FA.xmm0
        assert regs.get('b') == FA.xmm1
        assert regs.get('a') == FA.xmm0
        assert regs.get('c') == FA.xmm2

    def test_too_many_vars(self):
        regs = jit.RegAllocator()
        # force allocation of xmm0..xmm14
        for i in range(15):
            regs.get('var%d' % i)
        assert regs.get('var15') == FA.xmm15
        # the string form pytest.raises(exc, "code") was removed in pytest 5
        with pytest.raises(NotImplementedError):
            regs.get('var16')


class TestAstCompiler:

    def test_empty(self):
        comp = jit.AstCompiler("""
        def foo():
            pass
        """)
        fn = comp.compile()
        assert fn() == 0

    def test_simple(self):
        comp = jit.AstCompiler("""
        def foo():
            return 100
        """)
        fn = comp.compile()
        assert fn() == 100

    def test_arguments(self):
        comp = jit.AstCompiler("""
        def foo(a, b):
            return b
        """)
        fn = comp.compile()
        assert fn(3, 4) == 4

    def test_add(self):
        comp = jit.AstCompiler("""
        def foo(a, b):
            return a+b
        """)
        fn = comp.compile()
        assert fn(3, 4) == 7

    def test_binops(self):
        comp = jit.AstCompiler("""
        def foo(a, b):
            return (a-b) + (a*b) - (a/b)
        """)
        fn = comp.compile()
        res = (3-4) + (3*4) - (3.0/4)
        assert fn(3, 4) == res

    def test_assign(self):
        comp = jit.AstCompiler("""
        def foo(a):
            b = a + 1
            return b
        """)
        fn = comp.compile()
        assert fn(41) == 42

    def test_if(self):
        comp = jit.AstCompiler("""
        def foo(a):
            if a < 0:
                return 0-a
            return a
        """)
        fn = comp.compile()
        assert fn(-42) == 42
        assert fn(42) == 42

    def test_while(self):
        comp = jit.AstCompiler("""
        def foo(a):
            tot = 0
            i = 0
            while i < a:
                tot = tot + i
                i = i + 1
            return tot
        """)
        fn = comp.compile()
        assert fn(5) == 1+2+3+4


class TestDecorator:

    def test_simple(self):
        @jit.compile
        def foo(a, b):
            return a+b
        assert type(foo) is jit.CompiledFunction
        assert foo(39, 3) == 42.0
# --------------------------------------------------------------------------------
#
# /assembler.py:
# --------------------------------------------------------------------------------
import peachpy
from peachpy import Argument, double_, Constant
from peachpy import x86_64
# workaround because peachpy forgets to expose rsp
x86_64.rsp = peachpy.x86_64.registers.rsp


class FunctionAssembler:
    """Convenience wrapper around a PeachPy ``x86_64.Function``.

    Every argument and the return value are ``double``.  Unknown attributes
    are looked up in ``peachpy.x86_64`` (see __getattr__), so instructions
    can be emitted directly, e.g. ``asm.ADDSD(asm.xmm0, asm.xmm1)``.
    """

    from peachpy.x86_64 import (xmm0, xmm1, xmm2, xmm3, xmm4,
                                xmm5, xmm6, xmm7, xmm8, xmm9,
                                xmm10, xmm11, xmm12, xmm13,
                                xmm14, xmm15)

    def __init__(self, name, argnames):
        self.name = name
        self.nargs = len(argnames)
        # "argname", not "name": don't shadow the function-name parameter
        args = [Argument(double_, name=argname) for argname in argnames]
        self._peachpy_fn = x86_64.Function(name, args, double_)

    def __getattr__(self, name):
        """Resolve *name* in ``peachpy.x86_64``.

        Instruction classes are wrapped: calling the result instantiates the
        instruction and appends it to this function.  Anything else
        (registers, ``qword``, ``Label``, ...) is returned unchanged.
        """
        obj = getattr(x86_64, name)
        if isinstance(obj, type) and issubclass(obj, x86_64.instructions.Instruction):
            instr = obj

            def emit(*args):
                self._peachpy_fn.add_instruction(instr(*args))
            return emit
        return obj

    def const(self, val):
        """Return a float64 constant operand holding ``float(val)``."""
        return Constant.float64(float(val))

    def pushsd(self, reg):
        """Push the low double of *reg*; each stack slot is 16 bytes wide."""
        self.SUB(self.rsp, 16)
        self.MOVSD(self.qword[self.rsp], reg)

    def popsd(self, reg):
        """Pop a slot pushed by pushsd() into the low double of *reg*."""
        self.MOVSD(reg, self.qword[self.rsp])
        self.ADD(self.rsp, 16)

    def _encode(self):
        """Finalize the function for the detected host ABI and encode it."""
        abi_func = self._peachpy_fn.finalize(x86_64.abi.detect())
        return abi_func.encode()

    def assemble_and_relocate(self):
        """Encode the function and return ``code + padding + constants``.

        RIP-relative references to constants are patched to point into the
        constant data appended after the code, so the returned bytes can be
        copied to any executable buffer and called from their first byte.
        """
        # this code has been copied and adapted from PeachPy: the goal is to
        # put the data section immediately contiguous to the code section:
        # this is suboptimal and dangerous because it mixes code and data and
        # puts the data inside an mmap-ed region with the executable flag, but
        # I don't want/don't have time to talk and explain this stuff during
        # the talk. Please forgive me :)

        encoded_func = self._encode()
        #print(); print(encoded_func.format())
        code_segment = bytearray(encoded_func.code_section.content)
        const_segment = bytearray(encoded_func.const_section.content)

        # Pad the code with INT3 (0xCC) so the constants start at an aligned
        # offset from the start of the buffer.  The padding should never be
        # executed; INT3 traps instead of running garbage if it is.
        # NOTE(review): 32 bytes is assumed to cover any constant alignment
        # PeachPy may need (SSE/AVX) -- confirm if wider constants are used.
        const_alignment = 32
        code_segment += b'\xcc' * (-len(code_segment) % const_alignment)
        const_offset = len(code_segment)

        # Apply relocations
        from peachpy.x86_64.meta import RelocationType
        from peachpy.util import is_sint32
        for relocation in encoded_func.code_section.relocations:
            assert relocation.type == RelocationType.rip_disp32
            assert relocation.symbol in encoded_func.const_section.symbols
            start = relocation.offset
            end = start + 4
            # the displacement field is a little-endian *signed* int32
            old_value = int.from_bytes(code_segment[start:end], 'little',
                                       signed=True)

            # this is the biggest difference wrt peachpy: the constants live
            # right after the code instead of at an unrelated address
            new_value = (old_value + const_offset + relocation.symbol.offset
                         - relocation.program_counter)
            assert is_sint32(new_value)
            code_segment[start:end] = new_value.to_bytes(4, 'little',
                                                         signed=True)
        assert not encoded_func.const_section.relocations

        return code_segment + const_segment
# --------------------------------------------------------------------------------
#
# /ast2png.py:
# --------------------------------------------------------------------------------
"""
Render an ast to png. Stolen&adapted from:
https://github.com/hchasestevens/show_ast
"""

import ast
import itertools
from functools import partial
import graphviz

try:
    reduce
except NameError:
    from functools import reduce

try:
    _basestring = basestring
except NameError:
    _basestring = str


SETTINGS = dict(
    # Styling options:
    scale=2,
    font='courier',
    shape='none',
    terminal_color='#008040',
    nonterminal_color='#004080',

    # AST display options:
    omit_module=True,
    omit_docstrings=True,
)


def _is_docstring(node):
    """Return True if *node* is an expression statement that is a str literal."""
    if not isinstance(node, ast.Expr):
        return False
    value = node.value
    # Python >= 3.8 parses every literal as ast.Constant
    if isinstance(value, getattr(ast, 'Constant', ())):
        return isinstance(value.value, str)
    # Older Pythons use ast.Str; compare by name because ast.Str is
    # deprecated since Python 3.8 and removed in newer releases.
    return type(value).__name__ == 'Str'


def _strip_docstring(body):
    """Return *body* without its leading docstring statement, if any."""
    if body and _is_docstring(body[0]):
        return body[1:]
    return body


# Node types whose first body statement may be a docstring.
_DOCSTRING_OWNERS = (ast.FunctionDef,
                     getattr(ast, 'AsyncFunctionDef', ast.FunctionDef),
                     ast.ClassDef, ast.Module)


def recurse_through_ast(node, handle_ast, handle_terminal, handle_fields,
                        handle_no_fields, omit_docstrings):
    """Dispatch every field of *node* to the given callbacks.

    Child AST nodes go to *handle_ast*, scalar values to *handle_terminal*
    (strings are shown quoted; ``None`` and ``ctx`` fields are skipped).
    Returns ``handle_no_fields(node)`` if nothing was dispatched, otherwise
    ``handle_fields(node, results)``.
    """
    possible_docstring = isinstance(node, _DOCSTRING_OWNERS)

    field_results = []
    for field_name, field_value in ast.iter_fields(node):
        if field_name == 'ctx':
            continue
        if isinstance(field_value, ast.AST):
            field_results.append(handle_ast(field_value))

        elif isinstance(field_value, list):
            if possible_docstring and omit_docstrings and field_name == 'body':
                field_value = _strip_docstring(field_value)
            field_results.extend(
                handle_ast(item)
                if isinstance(item, ast.AST) else
                handle_terminal(item)
                for item in field_value
            )

        elif isinstance(field_value, _basestring):
            field_results.append(handle_terminal('"{}"'.format(field_value)))

        elif field_value is not None:
            field_results.append(handle_terminal(field_value))

    if not field_results:
        return handle_no_fields(node)

    return handle_fields(node, field_results)
def _bold(label): 82 | return '<{}>'.format(label) 83 | 84 | 85 | def _attach_to_parent(parent, graph, names, label, name=None, **style): 86 | node_name = next(names) if name is None else name 87 | node = graph.node(node_name, label=label, **style) 88 | if parent is not None: 89 | graph.edge(parent, node_name, sametail='t{}'.format(parent)) 90 | 91 | 92 | def handle_ast(node, parent_node, graph, names, omit_docstrings, terminal_color, 93 | nonterminal_color, highlight_node): 94 | if node is highlight_node: 95 | nonterminal_color = '#CC0000' 96 | 97 | attach_to_parent = partial( 98 | _attach_to_parent, 99 | graph=graph, 100 | names=names, 101 | ) 102 | node_name = next(names) 103 | attach_to_parent( 104 | parent=parent_node, 105 | label=_bold(node.__class__.__name__), 106 | name=node_name, 107 | fontcolor=nonterminal_color, 108 | ) 109 | recurse_through_ast( 110 | node, 111 | partial( 112 | handle_ast, 113 | parent_node=node_name, 114 | graph=graph, 115 | names=names, 116 | omit_docstrings=omit_docstrings, 117 | terminal_color=terminal_color, 118 | nonterminal_color=nonterminal_color, 119 | highlight_node=highlight_node, 120 | ), 121 | partial( 122 | handle_terminal, 123 | attach_to_parent=partial( 124 | attach_to_parent, 125 | parent=node_name, 126 | fontcolor=terminal_color, 127 | ), 128 | ), 129 | handle_fields, 130 | partial( 131 | handle_no_fields, 132 | parent_node=node_name, 133 | graph=graph, 134 | terminal_color=terminal_color, 135 | nonterminal_color=nonterminal_color, 136 | ), 137 | omit_docstrings, 138 | ) 139 | 140 | 141 | def handle_terminal(terminal, attach_to_parent): 142 | attach_to_parent(label=str(terminal)) 143 | 144 | 145 | def handle_fields(*__): 146 | pass 147 | 148 | 149 | def handle_no_fields(__, parent_node, graph, terminal_color, nonterminal_color): 150 | parent_node_beginning = '{} '.format(parent_node) 151 | parent_node_num = int(parent_node) 152 | for i, node in enumerate(graph.body[parent_node_num:]): 153 | if 
node.strip().startswith(parent_node_beginning): 154 | break 155 | else: 156 | raise KeyError("Could not find parent in graph.") 157 | replacements = { 158 | nonterminal_color: terminal_color, 159 | '<': '', 160 | '>': '', 161 | } 162 | graph.body[i + parent_node_num] = reduce( 163 | lambda s, replacement: s.replace(*replacement), 164 | replacements.items(), 165 | node, 166 | ) 167 | 168 | 169 | def ast2png(root, highlight_node=None, filename='ast.png', settings=SETTINGS): 170 | graph = graphviz.Graph(format='png') 171 | names = (str(x) for x in itertools.count()) 172 | 173 | handle_ast( 174 | root, 175 | parent_node=None, 176 | graph=graph, 177 | names=names, 178 | omit_docstrings=settings['omit_docstrings'], 179 | terminal_color=settings['terminal_color'], 180 | nonterminal_color=settings['nonterminal_color'], 181 | highlight_node=highlight_node, 182 | ) 183 | 184 | graph.node_attr.update(dict( 185 | fontname=settings['font'], 186 | shape=settings['shape'], 187 | #height='0.25', # TODO: how to incorporate with scale param? 
188 | #fixedsize='true', 189 | )) 190 | 191 | data = graph.pipe() 192 | with open(filename, 'wb') as f: 193 | f.write(data) 194 | 195 | -------------------------------------------------------------------------------- /jit.py: -------------------------------------------------------------------------------- 1 | import mmap 2 | import ast 3 | import textwrap 4 | import inspect 5 | from collections import defaultdict 6 | from cffi import FFI 7 | from assembler import FunctionAssembler as FA 8 | 9 | ffi = FFI() 10 | ffi.cdef(""" 11 | typedef double (*fn0)(void); 12 | typedef double (*fn1)(double); 13 | typedef double (*fn2)(double, double); 14 | typedef double (*fn3)(double, double, double); 15 | """) 16 | 17 | class CompiledFunction: 18 | 19 | def __init__(self, nargs, code): 20 | self.buf = mmap.mmap(-1, len(code), mmap.MAP_PRIVATE, 21 | mmap.PROT_READ | mmap.PROT_WRITE | 22 | mmap.PROT_EXEC) 23 | self.buf[:len(code)] = code 24 | fntype = 'fn%d' % nargs 25 | self.fptr = ffi.cast(fntype, ffi.from_buffer(self.buf)) 26 | 27 | def __call__(self, *args): 28 | return self.fptr(*args) 29 | 30 | 31 | class RegAllocator: 32 | 33 | REGISTERS = (FA.xmm0, FA.xmm1, FA.xmm2, FA.xmm3, FA.xmm4, 34 | FA.xmm5, FA.xmm6, FA.xmm7, FA.xmm8, FA.xmm9, 35 | FA.xmm10, FA.xmm11, FA.xmm12, FA.xmm13, 36 | FA.xmm14, FA.xmm15) 37 | 38 | def __init__(self): 39 | self._registers = list(reversed(self.REGISTERS)) 40 | self.vars = defaultdict(self._allocate) # name -> reg 41 | 42 | def _allocate(self): 43 | try: 44 | return self._registers.pop() 45 | except IndexError: 46 | raise NotImplementedError("Too many variables: register spilling not implemented") 47 | 48 | def get(self, varname): 49 | return self.vars[varname] 50 | 51 | class AstCompiler: 52 | 53 | def __init__(self, src): 54 | self.tree = ast.parse(textwrap.dedent(src)) 55 | self.asm = None 56 | 57 | def show(self, node): 58 | import astpretty 59 | from ast2png import ast2png 60 | astpretty.pprint(node) 61 | ast2png(self.tree, 
highlight_node=node, filename='ast.png') 62 | 63 | def _newfunc(self, name, argnames): 64 | self.asm = FA(name, argnames) 65 | self.regs = RegAllocator() 66 | for argname in argnames: 67 | self.regs.get(argname) 68 | self.tmp0 = self.regs.get('__scratch_register_0__') 69 | self.tmp1 = self.regs.get('__scratch_register_1__') 70 | 71 | def compile(self): 72 | self.visit(self.tree) 73 | assert self.asm is not None, 'No function found?' 74 | code = self.asm.assemble_and_relocate() 75 | return CompiledFunction(self.asm.nargs, code) 76 | 77 | def visit(self, node): 78 | methname = node.__class__.__name__ 79 | meth = getattr(self, methname, None) 80 | if meth is None: 81 | raise NotImplementedError(methname) 82 | return meth(node) 83 | 84 | def Module(self, node): 85 | for child in node.body: 86 | self.visit(child) 87 | 88 | def FunctionDef(self, node): 89 | assert not self.asm, 'cannot compile more than one function' 90 | argnames = [arg.arg for arg in node.args.args] 91 | self._newfunc(node.name, argnames) 92 | for child in node.body: 93 | self.visit(child) 94 | # return 0 by default 95 | self.asm.PXOR(self.asm.xmm0, self.asm.xmm0) 96 | self.asm.RET() 97 | 98 | def Pass(self, node): 99 | pass 100 | 101 | def Return(self, node): 102 | self.visit(node.value) 103 | self.asm.popsd(self.asm.xmm0) 104 | self.asm.RET() 105 | 106 | def Num(self, node): 107 | self.asm.MOVSD(self.tmp0, self.asm.const(node.n)) 108 | self.asm.pushsd(self.tmp0) 109 | 110 | def BinOp(self, node): 111 | OPS = { 112 | 'ADD': self.asm.ADDSD, 113 | 'SUB': self.asm.SUBSD, 114 | 'MULT': self.asm.MULSD, 115 | 'DIV': self.asm.DIVSD, 116 | } 117 | opname = node.op.__class__.__name__.upper() 118 | self.visit(node.left) 119 | self.visit(node.right) 120 | self.asm.popsd(self.tmp1) 121 | self.asm.popsd(self.tmp0) 122 | OPS[opname](self.tmp0, self.tmp1) 123 | self.asm.pushsd(self.tmp0) 124 | 125 | def Name(self, node): 126 | reg = self.regs.get(node.id) 127 | self.asm.pushsd(reg) 128 | 129 | def Assign(self, 
node): 130 | assert len(node.targets) == 1 131 | varname = node.targets[0].id 132 | reg = self.regs.get(varname) 133 | self.visit(node.value) 134 | self.asm.popsd(self.tmp0) 135 | self.asm.MOVSD(reg, self.tmp0) 136 | 137 | def If(self, node): 138 | """ 139 | IF GOTO then_label 140 | GOTO end_label 141 | then_label: 142 | 143 | end_label: 144 | ... 145 | """ 146 | CMP = { 147 | 'LT': self.asm.JB 148 | } 149 | assert not node.orelse 150 | then_label = self.asm.Label() 151 | end_label = self.asm.Label() 152 | op = self.visit(node.test) 153 | CMP[op](then_label) 154 | self.asm.JMP(end_label) 155 | self.asm.LABEL(then_label) 156 | for child in node.body: 157 | self.visit(child) 158 | self.asm.LABEL(end_label) 159 | 160 | def Compare(self, node): 161 | assert len(node.ops) == 1 162 | cmp_op = node.ops[0].__class__.__name__.upper() 163 | self.visit(node.left) 164 | self.visit(node.comparators[0]) 165 | self.asm.popsd(self.tmp1) 166 | self.asm.popsd(self.tmp0) 167 | self.asm.UCOMISD(self.tmp0, self.tmp1) 168 | return cmp_op 169 | 170 | def While(self, node): 171 | """ 172 | begin_label: 173 | IF GOTO body_label 174 | GOTO end_label 175 | body_label: 176 | 177 | GOTO begin_label 178 | end_label: 179 | ... 180 | """ 181 | CMP = { 182 | 'LT': self.asm.JB 183 | } 184 | begin_label = self.asm.Label() 185 | body_label = self.asm.Label() 186 | end_label = self.asm.Label() 187 | # 188 | self.asm.LABEL(begin_label) 189 | op = self.visit(node.test) 190 | CMP[op](body_label) 191 | self.asm.JMP(end_label) 192 | self.asm.LABEL(body_label) 193 | for child in node.body: 194 | self.visit(child) 195 | self.asm.JMP(begin_label) 196 | self.asm.LABEL(end_label) 197 | 198 | 199 | def compile(fn): 200 | src = inspect.getsource(fn) 201 | comp = AstCompiler(src) 202 | return comp.compile() 203 | --------------------------------------------------------------------------------