├── .gitignore ├── README.md ├── jit ├── __init__.py ├── codegen.py ├── decorator.py ├── engine.py ├── errors.py ├── j_types.py ├── settings.py └── stdlib.py ├── life.py ├── requirements.txt ├── tests ├── __main__.py ├── test_add.py ├── test_and.py ├── test_arrays.py ├── test_bitshift.py ├── test_coerce_bool.py ├── test_conditionals.py ├── test_eq_neq.py ├── test_gt.py ├── test_inference.py ├── test_lt.py ├── test_modulo.py ├── test_mul_div.py ├── test_negation.py ├── test_print.py ├── test_sub.py └── test_void_zero.py └── todo.md /.gitignore: -------------------------------------------------------------------------------- 1 | .env* 2 | .vscode 3 | __pycache__ 4 | *.llvm 5 | debug.* 6 | *.jit -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is a prototype for a JIT/AOT framework that takes in a decorated Python function, written in a subset of Python's syntax, and generates a native binary using LLVM (specifically, with the `llvmlite` library). 2 | 3 | This is an *extremely* simple proof-of-concept. It doesn't do very much yet. It might eventually turn into something larger, but right now it's just an exploratory playground. 4 | 5 | If you're interested in making something of it, get in touch. 6 | 7 | # Rationale 8 | 9 | This project grew out of an earlier attempt to use Python to create a compiler for a toy language. I actually got fairly far along with that project, but left it behind for other things. 10 | 11 | When I came back to it, I decided it might be more interesting to use a subset of Python as the base language, instead of a language I'd reinvented from scratch. It made sense to leverage Python's own AST machinery and syntax, since Python gives us all of that for free.
12 | 13 | # Usage 14 | 15 | When you decorate a function with the `@jit` decorator, it is compiled to native machine code the first time the function is called. You can also use the `@jit_immediate` decorator to compile the function as soon as it is decorated (that is, when its module is loaded) rather than on its first call. 16 | 17 | In time this functionality could be expanded to AOT compilation as well. For instance, one could feed it a function and have a binary generated and deployed alongside your Python code, with convenience functions provided by the compiler to wrap the binary for use in your code. (It might also be possible to feed it an entire code tree and compile that, but that's a long way off.) 18 | 19 | Right now very few operations are supported. The JIT can only perform basic arithmetic on bytes, store them in variables, and return the results. It does not yet reliably trap type errors or perform other checking. 20 | 21 | # Quickstart 22 | 23 | > NOTE: Python 3.9+ is required, because the compiler uses `ast.unparse`, which was added to the `ast` module in Python 3.9. 24 | 25 | Clone the repo and install requirements from `requirements.txt`. 26 | 27 | The test suite (run `python .\tests\` on Windows, or `python tests` elsewhere) exercises the complete feature set. 28 | 29 | `life.py` provides a simple example, a JIT-accelerated rendition of Conway's Game of Life. The code used in this example will mature along with the rest of the project.
30 | 31 | # License 32 | 33 | MIT -------------------------------------------------------------------------------- /jit/__init__.py: -------------------------------------------------------------------------------- 1 | from .decorator import jit, jit_immediate, jit_m 2 | -------------------------------------------------------------------------------- /jit/codegen.py: -------------------------------------------------------------------------------- 1 | from llvmlite import ir, binding 2 | import sys 3 | import ast, inspect, builtins 4 | import pprint 5 | import pathlib 6 | 7 | from llvmlite.ir.types import VoidType 8 | from .j_types import * 9 | from .errors import JitTypeError, BaseJitError 10 | from collections import namedtuple 11 | from . import settings 12 | from . import stdlib 13 | 14 | from typing import Union 15 | 16 | 17 | class JitObj: 18 | def __init__(self, j_type, llvm, obj): 19 | self.j_type = j_type 20 | self.llvm = llvm 21 | self.obj = obj 22 | 23 | def reify_type(self, other): 24 | pass 25 | 26 | 27 | class Value(JitObj): 28 | pass 29 | 30 | 31 | class Variable(JitObj): 32 | pass 33 | 34 | 35 | class TypeTarget: 36 | def __init__(self, tt_list, type_target): 37 | self.tt_list = tt_list 38 | self.tt_list.append(type_target) 39 | 40 | def __enter__(self): 41 | pass 42 | 43 | def __exit__(self, *a): 44 | self.tt_list.pop() 45 | 46 | 47 | class Codegen: 48 | def __init__(self): 49 | self.modules = {} 50 | self.target_data = None 51 | 52 | def val(self, obj: JitObj): 53 | if isinstance(obj, Value): 54 | return obj.llvm 55 | elif isinstance(obj, Variable): 56 | return self.builder.load(obj.llvm) 57 | 58 | def _coerce_bool(self, expression, node): 59 | if node.j_type != u1: 60 | expression = node.j_type.to_bool(self, expression) 61 | return expression 62 | 63 | def generate_function_name(self, name): 64 | return f"jit.{name}" 65 | 66 | def codegen_all(self, code_obj): 67 | 68 | # String name of the module this function lives in 69 | self.py_module_name: str 
= code_obj.__module__ 70 | # Reference to module object that hosts function 71 | self.py_module = sys.modules.get(self.py_module_name) 72 | self.module: ir.Module = self.modules.get(self.py_module_name) 73 | 74 | if not self.module: 75 | self.module = ir.Module(self.py_module_name) 76 | self.modules[self.py_module_name] = self.module 77 | 78 | if not self.target_data: 79 | self.target_data = binding.create_target_data(self.module.data_layout) 80 | self.bitness = ( 81 | ir.PointerType(ir.IntType(8)).get_abi_size(self.target_data) * 8 82 | ) 83 | self.mem = ir.IntType(self.bitness) 84 | self.mem_ptr = ir.PointerType(self.mem) 85 | self.memv = lambda val: ir.Constant(ir.IntType(self.bitness), val) 86 | self.zero = self.memv(0) 87 | 88 | self.code_obj = code_obj 89 | self.instructions = ast.parse(inspect.getsource(code_obj)) 90 | self.var_counter = 0 91 | 92 | dump_file = "debug" 93 | 94 | if settings.DUMP_TO_DIR: 95 | module_file = codegen.py_module.__file__ 96 | module_path = pathlib.Path(module_file) 97 | dump_file = f"{module_path}.debug" 98 | 99 | if settings.DEBUG: 100 | with open(f"{dump_file}.debug.txt", "w") as self.output: 101 | self.codegen(self.instructions) 102 | 103 | if settings.DUMP: 104 | with open(f"{dump_file}.debug.llvm", "w") as self.output: 105 | self.output.write(str(self.module)) 106 | 107 | def codegen(self, instruction): 108 | # print(instruction) 109 | itype = instruction.__class__.__name__ 110 | self.output.write(f"\n\n>> {itype}\n\n") 111 | self.output.write(pprint.pformat(instruction.__dict__)) 112 | call = getattr(self, f"visit_{itype}", None) 113 | if not call: 114 | print(itype) 115 | return 116 | try: 117 | return call(instruction) 118 | except BaseJitError as e: 119 | raise e 120 | 121 | def visit_Module(self, node: ast.Module): 122 | for module_node in node.body: 123 | self.codegen(module_node) 124 | 125 | def type_target(self, tt): 126 | return TypeTarget(self.type_targets, tt) 127 | 128 | def get_annotation(self, annotation): 129 | 
item = ast.unparse(annotation) 130 | arg_type = eval(item, self.py_module.__dict__) 131 | if not arg_type: 132 | raise JitTypeError("annotation not found") 133 | 134 | if not isinstance(arg_type, PrimitiveType): 135 | converted_arg_type = type_conversions(arg_type) 136 | 137 | if converted_arg_type is None: 138 | raise JitTypeError(f"Type {arg_type} not supported") 139 | arg_type = converted_arg_type 140 | 141 | if isinstance(arg_type, ObjectType): 142 | arg_type = objectpointer(arg_type) 143 | 144 | return arg_type 145 | 146 | def visit_ClassDef(self, node: ast.ClassDef): 147 | print(node.__dict__) 148 | 149 | def visit_FunctionDef(self, node: ast.FunctionDef): 150 | 151 | self.type_targets = [] 152 | self.break_stack = [] 153 | self.loop_stack = [] 154 | self.argtypes = [] 155 | 156 | for argument in node.args.args: 157 | if not argument.annotation: 158 | raise JitTypeError(f"Arg {argument.arg} not annotated") 159 | 160 | arg_type = self.get_annotation(argument.annotation) 161 | self.argtypes.append(arg_type) 162 | 163 | self.type_data: dict = {} 164 | self.type_unset: bool = False 165 | 166 | # set return type from annotation if available 167 | 168 | self.return_type: PrimitiveType = self.type_data.get("return", None) 169 | if not self.return_type: 170 | self.type_unset = True 171 | self.return_type = void 172 | 173 | function_returntype = ( 174 | ir.VoidType() if self.return_type is void else self.return_type 175 | ) 176 | self.functiontype = ir.FunctionType( 177 | function_returntype, [x.llvm for x in self.argtypes], False 178 | ) 179 | self.function = ir.Function(self.module, self.functiontype, node.name) 180 | self.function.return_jtype = self.return_type 181 | self.builder = ir.IRBuilder() 182 | 183 | self.setup_block = self.function.append_basic_block("setup") 184 | self.entry_block = self.function.append_basic_block("entry") 185 | self.exit_block = self.function.append_basic_block("exit") 186 | 187 | self.builder = ir.IRBuilder(self.setup_block) 188 | # 
TODO: use type alloca 189 | 190 | if self.return_type is not void: 191 | self.return_value = Variable( 192 | self.return_type, self.return_type.alloca(self), None 193 | ) 194 | else: 195 | self.return_value = void 196 | 197 | self.vars = {} 198 | 199 | for func_argument, argument, argtype in zip( 200 | self.function.args, node.args.args, self.argtypes 201 | ): 202 | # if isinstance(argtype, ObjectPointer): 203 | # val = Value(argtype.pointee, self.builder.load(func_argument), argument) 204 | # else: 205 | val = Value(argtype, func_argument, argument) 206 | self.vars[argument.arg] = val 207 | 208 | self.setup_exit = self.builder.branch(self.entry_block) 209 | self.builder.position_at_start(self.entry_block) 210 | 211 | for instruction in node.body: 212 | self.codegen(instruction) 213 | 214 | if not self.builder._block.is_terminated: 215 | self.builder.branch(self.exit_block) 216 | 217 | self.builder.position_at_start(self.exit_block) 218 | 219 | if self.return_value is void: 220 | self.builder.ret_void() 221 | else: 222 | self.builder.ret(self.val(self.return_value)) 223 | 224 | def visit_Return(self, node: ast.Return): 225 | if node.value is None: 226 | return 227 | 228 | return_value: JitObj = self.codegen(node.value) 229 | if return_value is None: 230 | return 231 | 232 | llvm = self.val(return_value) 233 | 234 | if return_value.j_type != self.return_type: 235 | if not self.type_unset: 236 | raise JitTypeError("too many return type redefinitions") 237 | self.return_type = return_value.j_type 238 | self.functiontype = ir.FunctionType( 239 | llvm.type, [x.llvm for x in self.argtypes], False 240 | ) 241 | self.function.type = ir.PointerType(self.functiontype) 242 | self.function.ftype = self.functiontype 243 | self.function.return_value.type = llvm.type 244 | self.function.return_jtype = return_value.j_type 245 | self.type_unset = False 246 | 247 | with self.builder.goto_block(self.entry_block): 248 | self.builder.position_before(self.setup_exit) 249 | 
self.return_value = Variable( 250 | self.return_type, self.return_type.alloca(self), None 251 | ) 252 | 253 | self.builder.store(llvm, self.return_value.llvm) 254 | self.builder.branch(self.exit_block) 255 | 256 | def visit_Constant(self, node: ast.Constant): 257 | 258 | val = node.value 259 | default_type_to_use = type_conversions(val.__class__) 260 | 261 | if not default_type_to_use: 262 | raise JitTypeError("type not supported", type(node.value)) 263 | 264 | if self.type_targets and self.type_targets[0] is not None: 265 | if not isinstance(self.type_targets[0], PrimitiveType): 266 | raise JitTypeError( 267 | "can't coerce", type(node.value), self.type_targets[0] 268 | ) 269 | default_type_to_use = self.type_targets[0] 270 | 271 | result = Value( 272 | default_type_to_use, ir.Constant(default_type_to_use.llvm, val), node 273 | ) 274 | 275 | return result 276 | 277 | def visit_Name(self, node: ast.Name): 278 | 279 | var_ref = self.vars.get(node.id) 280 | if var_ref: 281 | # if isinstance(var_ref, ObjectPointer): 282 | # return self.builder.load(var_ref) 283 | return var_ref 284 | 285 | var_ref = self.py_module.__dict__.get(node.id) 286 | 287 | # attempt to capture value at compile time from surrounding module 288 | # TODO: pass silently as a variable? 
289 | 290 | if var_ref: 291 | new_node = ast.Constant(value=var_ref) 292 | new_value = self.codegen(new_node) 293 | self.create_name(node, new_value) 294 | return new_value 295 | 296 | return None 297 | 298 | def create_name(self, varname, value): 299 | with self.builder.goto_block(self.setup_block): 300 | alloc = value.j_type.alloca(self) 301 | ref = Variable(value.j_type, alloc, varname) 302 | self.vars[varname.id] = ref 303 | return ref 304 | 305 | def visit_Assign(self, node: Union[ast.Assign, ast.AnnAssign]): 306 | 307 | # TODO: allow multiple assign 308 | 309 | if isinstance(node, ast.AnnAssign): 310 | varname: ast.Name = node.target 311 | else: 312 | varname: ast.Name = node.targets[0] 313 | 314 | var_ref: JitObj = self.codegen(varname) 315 | 316 | if isinstance(node, ast.AnnAssign): 317 | tt = self.get_annotation(node.annotation) 318 | else: 319 | if var_ref is None: 320 | tt = None 321 | else: 322 | tt = var_ref.j_type 323 | 324 | with self.type_target(tt): 325 | value: JitObj = self.codegen(node.value) 326 | 327 | if not var_ref: 328 | ref = self.create_name(varname, value) 329 | else: 330 | ref = var_ref 331 | 332 | if isinstance(ref.j_type, PointerType): 333 | if ref.j_type.pointee != value.j_type: 334 | raise JitTypeError( 335 | "mismatched types:", ref.j_type.pointee, value.j_type 336 | ) 337 | elif ref.j_type != value.j_type: 338 | raise JitTypeError("mismatched types:", ref.j_type, value.j_type) 339 | 340 | ref_llvm = ref.llvm 341 | val = self.val(value) 342 | return self.builder.store(val, ref_llvm) 343 | 344 | def visit_AugAssign(self, node: ast.AugAssign): 345 | binop = ast.BinOp(left=node.target, right=node.value, op=node.op) 346 | assignment = ast.Assign(targets=[node.target], value=binop) 347 | return self.codegen(assignment) 348 | 349 | def visit_AnnAssign(self, node: ast.AnnAssign): 350 | self.visit_Assign(node) 351 | 352 | def visit_BinOp(self, node: ast.BinOp): 353 | lhs: JitObj = self.codegen(node.left) 354 | # lhs_v = self.val_node(lhs) 
355 | 356 | with self.type_target(lhs.j_type): 357 | rhs: JitObj = self.codegen(node.right) 358 | 359 | optype = node.op.__class__.__name__ 360 | op = getattr(lhs.j_type, f"impl_{optype}", None) 361 | if not op: 362 | raise Exception("Op not supported", optype) 363 | 364 | if lhs.j_type != rhs.j_type: 365 | raise JitTypeError("mismatched types for op:", lhs.j_type, rhs.j_type) 366 | 367 | lhs_v = self.val(lhs) 368 | rhs_v = self.val(rhs) 369 | result = op(self, lhs_v, rhs_v) 370 | 371 | return Value(lhs.j_type, result, node) 372 | 373 | def visit_UnaryOp(self, node: ast.UnaryOp): 374 | 375 | lhs: JitObj = self.codegen(node.operand) 376 | 377 | optype = node.op.__class__.__name__ 378 | op = getattr(lhs.j_type, f"impl_{optype}", None) 379 | if not op: 380 | raise Exception("Op not supported", optype) 381 | 382 | result = op(self, lhs) 383 | return Value(lhs.j_type, result, node) 384 | 385 | def visit_Compare(self, node: ast.Compare): 386 | 387 | # TODO: multi-comparison 388 | # maybe unpack that to if x==y and y==z 389 | 390 | lhs: JitObj = self.codegen(node.left) 391 | lhs_val = self.val(lhs) 392 | # lhs_n = self.val_node(lhs) 393 | # print(lhs_n.__dict__) 394 | 395 | with self.type_target(lhs.j_type): 396 | rhs: JitObj = self.codegen(node.comparators[0]) 397 | rhs_val = self.val(rhs) 398 | 399 | optype = node.ops[0].__class__.__name__ 400 | op = getattr(lhs.j_type, f"impl_{optype}", None) 401 | if not op: 402 | raise Exception("Op not supported", optype, lhs.j_type) 403 | 404 | if lhs.j_type != rhs.j_type: 405 | raise JitTypeError("mismatched types for op:", lhs.j_type, rhs.j_type) 406 | 407 | result = op(self, lhs_val, rhs_val) 408 | return Value(u1, result, node) 409 | 410 | def visit_If(self, node: ast.If): 411 | then_block = self.builder.append_basic_block("then") 412 | else_block = self.builder.append_basic_block("else") 413 | end_block = self.builder.append_basic_block("end") 414 | 415 | test_clause = self.codegen(node.test) 416 | test_clause_llvm = 
self._coerce_bool(self.val(test_clause), test_clause) 417 | 418 | self.builder.cbranch(test_clause_llvm, then_block, else_block) 419 | 420 | self.builder.position_at_start(then_block) 421 | for n in node.body: 422 | self.codegen(n) 423 | if not self.builder.block.is_terminated: 424 | self.builder.branch(end_block) 425 | 426 | self.builder.position_at_start(else_block) 427 | for n in node.orelse: 428 | self.codegen(n) 429 | if not self.builder.block.is_terminated: 430 | self.builder.branch(end_block) 431 | self.builder.position_at_start(end_block) 432 | 433 | def visit_While(self, node: ast.While): 434 | loop_block = self.builder.append_basic_block("while") 435 | end_block = self.builder.append_basic_block("end_while") 436 | self.loop_stack.append(loop_block) 437 | self.break_stack.append(end_block) 438 | 439 | self.builder.branch(loop_block) 440 | self.builder.position_at_start(loop_block) 441 | for n in node.body: 442 | self.codegen(n) 443 | test_clause = self.codegen(node.test) 444 | test_clause_llvm = self._coerce_bool(self.val(test_clause), test_clause) 445 | 446 | self.builder.cbranch(test_clause_llvm, loop_block, end_block) 447 | self.builder.position_at_start(end_block) 448 | 449 | def visit_Break(self, node: ast.Break): 450 | if not self.break_stack: 451 | raise Exception("break encountered outside of loop") 452 | break_target = self.break_stack.pop() 453 | self.builder.branch(break_target) 454 | 455 | def visit_Continue(self, node: ast.Continue): 456 | loop_target = self.loop_stack.pop() 457 | self.builder.branch(loop_target) 458 | 459 | def visit_Subscript(self, node: ast.Subscript): 460 | value = self.codegen(node.value) 461 | val_llvm = value.llvm 462 | 463 | slice = self.codegen(node.slice) 464 | index = self.val(slice) 465 | ptr = self.builder.gep(val_llvm, [self.zero, index]) 466 | 467 | # TODO: this feels wrong 468 | # we should have a method for j-type that extracts 469 | # the j_type of a subscript 470 | 471 | if isinstance(value.j_type, 
ObjectPointer): 472 | return Value(value.j_type.pointee.base_type, ptr, None) 473 | 474 | return Variable(value.j_type.base_type, ptr, None) 475 | 476 | def visit_BoolOp(self, node: ast.BoolOp): 477 | 478 | # TODO: multicomparisons 479 | 480 | lhs = self.codegen(node.values[0]) 481 | rhs = self.codegen(node.values[1]) 482 | 483 | lhs_val = self.val(lhs) 484 | rhs_val = self.val(rhs) 485 | 486 | if isinstance(node.op, ast.And): 487 | result = self.builder.and_(lhs_val, rhs_val) 488 | 489 | return Value(u1, result, None) 490 | 491 | def visit_Expr(self, node: ast.Expr): 492 | return self.codegen(node.value) 493 | 494 | def visit_Call(self, node: ast.Call): 495 | 496 | # only one argument supported so far 497 | args = [self.codegen(_) for _ in node.args] 498 | # a_val = self.val(arg) 499 | vals = [self.val(_) for _ in args] 500 | 501 | function_name = node.func.id 502 | 503 | # TODO: will val pass a pointer to an object? find out 504 | 505 | # first, find out if this is a function that already exists 506 | # call = self.module.globals.get(function_name, None) 507 | # if call: 508 | # return self.builder.call(call, vals) 509 | 510 | # next, find out if this is a function from the standard library 511 | call = stdlib.make(function_name) 512 | if call: 513 | return call(self, vals) 514 | 515 | raise Exception(f"no such function {function_name}") 516 | 517 | 518 | codegen = Codegen() 519 | -------------------------------------------------------------------------------- /jit/decorator.py: -------------------------------------------------------------------------------- 1 | from .engine import jitengine 2 | from .codegen import codegen as c 3 | from .j_types import JitType, ObjectType 4 | 5 | 6 | def jit_m(name=None): 7 | def fn(func): 8 | func._new_name = name 9 | return jit(func) 10 | 11 | return fn 12 | 13 | 14 | def jit_immediate(func): 15 | 16 | try: 17 | c.codegen_all(func) 18 | except Exception as e: 19 | raise e 20 | 21 | # TODO: separate compilation from extraction 
of function 22 | 23 | jitted_function = jitengine.compile(c, entry_point=func.__name__) 24 | func._jit = jitted_function 25 | 26 | def wrapper(*a, **ka): 27 | aa = [] 28 | for arg in a: 29 | if isinstance(arg, JitType): 30 | aa.append(arg.from_jtype(arg)) 31 | else: 32 | aa.append(arg) 33 | 34 | result = jitted_function(*aa, **ka) 35 | if hasattr(result, "contents"): 36 | return result.contents 37 | return result 38 | 39 | # wrapper.f = func 40 | wrapper._wrapped = func 41 | wrapper._jit = jitted_function 42 | 43 | return wrapper 44 | 45 | 46 | def jit(func): 47 | def wrapper(*a, **ka): 48 | aa = [] 49 | for arg in a: 50 | if isinstance(arg, JitType): 51 | aa.append(arg.from_jtype(arg)) 52 | else: 53 | aa.append(arg) 54 | 55 | try: 56 | return func._jit(*aa, **ka) 57 | except AttributeError: 58 | pass 59 | 60 | try: 61 | c.codegen_all(func) 62 | except Exception as e: 63 | raise e 64 | 65 | # TODO: separate compilation from extraction of function 66 | 67 | jitted_function = jitengine.compile(c, entry_point=func.__name__) 68 | func._jit = jitted_function 69 | result = jitted_function(*aa, **ka) 70 | if hasattr(result, "contents"): 71 | return result.contents 72 | return result 73 | 74 | wrapper._wrapped = func 75 | # wrapper._jit = func._jit 76 | 77 | return wrapper 78 | -------------------------------------------------------------------------------- /jit/engine.py: -------------------------------------------------------------------------------- 1 | import llvmlite.binding as llvm 2 | from ctypes import CFUNCTYPE, ArgumentError 3 | from .j_types import PrimitiveType 4 | import pathlib 5 | from .errors import JitTypeError 6 | 7 | from . 
import settings 8 | 9 | 10 | class JitEngine: 11 | def __init__(self): 12 | llvm.initialize() 13 | llvm.initialize_native_target() 14 | llvm.initialize_native_asmprinter() 15 | self.modules = {} 16 | self.engines = {} 17 | self.engine = None 18 | 19 | def create_execution_engine(self): 20 | # Create a target machine representing the host 21 | target = llvm.Target.from_default_triple() 22 | target_machine = target.create_target_machine() 23 | # And an execution engine with an empty backing module 24 | backing_mod = llvm.parse_assembly("") 25 | engine = llvm.create_mcjit_compiler(backing_mod, target_machine) 26 | 27 | self.pm = llvm.ModulePassManager() 28 | llvm.PassManagerBuilder().populate(self.pm) 29 | return engine 30 | 31 | def compile_ir(self, llvm_ir, opt_level=None): 32 | # Create a LLVM module object from the IR 33 | try: 34 | mod = llvm.parse_assembly(llvm_ir) 35 | if opt_level: 36 | if opt_level is True: 37 | opt_level = 3 38 | self.pm.opt_level = opt_level 39 | self.pm.run(mod) 40 | except RuntimeError as e: 41 | print(llvm_ir) 42 | raise e 43 | mod.verify() 44 | 45 | # Now add the module and make sure it is ready for execution 46 | self.engine.add_module(mod) 47 | self.engine.finalize_object() 48 | self.engine.run_static_constructors() 49 | 50 | with open("debug.opt.llvm", "w") as f: 51 | f.write(str(mod)) 52 | return mod 53 | 54 | def compile(self, codegen, opt_level=None, entry_point="main"): 55 | 56 | module_file = codegen.py_module.__file__ 57 | module_path = pathlib.Path(module_file) 58 | module_base_path = module_path.parent 59 | module_filename = module_path.parts[-1] 60 | jit_module_filename = f"{module_filename}.jit" 61 | obj_module_filename = f"{module_filename}.obj" 62 | jit_module_path = pathlib.Path(module_base_path, jit_module_filename) 63 | 64 | self.engine = self.engines.get(codegen.py_module_name, None) 65 | if self.engine is None: 66 | self.engine = self.create_execution_engine() 67 | 68 | if settings.ASM: 69 | 70 | def 
obj_write(module, buffer): 71 | with open(module_base_path / obj_module_filename, "wb") as f: 72 | f.write(buffer) 73 | 74 | self.engine.set_object_cache(obj_write) 75 | 76 | do_jit = False 77 | 78 | function_name = codegen.code_obj.__name__ 79 | if not self.engine.get_function_address(function_name): 80 | do_jit = True 81 | 82 | if jit_module_path.exists(): 83 | if jit_module_path.stat().st_mtime < module_path.stat().st_mtime: 84 | do_jit = True 85 | else: 86 | do_jit = True 87 | 88 | if do_jit: 89 | mod = self.compile_ir(str(codegen.module), opt_level) 90 | if settings.CACHE: 91 | with open(f"{module_file}.jit", "wb") as f: 92 | f.write(mod.as_bitcode()) 93 | else: 94 | with open(f"{module_file}.jit", "rb") as f: 95 | bitcode = f.read() 96 | mod = llvm.parse_bitcode(bitcode) 97 | mod.verify() 98 | self.engine.add_module(mod) 99 | self.engine.finalize_object() 100 | self.engine.run_static_constructors() 101 | 102 | self.modules[codegen.py_module_name] = mod 103 | 104 | arg_types = [_.to_ctype() for _ in codegen.argtypes] 105 | func = codegen.module.globals[entry_point] 106 | cfunctype = CFUNCTYPE(func.return_jtype.to_ctype(), *arg_types) 107 | 108 | eng = self.engine 109 | 110 | def ff(*a, **ka): 111 | func_ptr = eng.get_function_address(entry_point) 112 | cfunc = cfunctype(func_ptr) 113 | try: 114 | return cfunc(*a, **ka) 115 | except ArgumentError: 116 | raise JitTypeError 117 | 118 | ff.restype = func.return_jtype.to_ctype() 119 | 120 | return ff 121 | 122 | 123 | jitengine = JitEngine() 124 | -------------------------------------------------------------------------------- /jit/errors.py: -------------------------------------------------------------------------------- 1 | class BaseJitError(Exception): 2 | pass 3 | 4 | 5 | class JitTypeError(BaseJitError): 6 | pass 7 | -------------------------------------------------------------------------------- /jit/j_types.py: -------------------------------------------------------------------------------- 1 | import 
ctypes 2 | from llvmlite import ir 3 | import array as arr 4 | 5 | 6 | class JitType: 7 | llvm = None 8 | 9 | def alloca(self): 10 | raise NotImplementedError 11 | 12 | def slice(self): 13 | raise NotImplementedError 14 | 15 | 16 | class Void(JitType): 17 | llvm = ir.VoidType() 18 | signed = None 19 | 20 | def to_ctype(self): 21 | return lambda x: None 22 | 23 | 24 | class PrimitiveType(JitType): 25 | signed = None 26 | 27 | def __init__(self, size): 28 | self.size = size 29 | self.llvm = self.j_type(self.size) 30 | 31 | def __repr__(self): 32 | return f'<{"Un" if not self.signed else ""}signed i{self.size}>' 33 | 34 | def alloca(self, codegen): 35 | return codegen.builder.alloca(self.llvm) 36 | 37 | 38 | class BaseInteger(PrimitiveType): 39 | 40 | j_type = ir.IntType 41 | 42 | _from_ctype = { 43 | True: { 44 | 64: ctypes.c_int64, 45 | 32: ctypes.c_int32, 46 | 16: ctypes.c_int16, 47 | 8: ctypes.c_int8, 48 | }, 49 | False: { 50 | 64: ctypes.c_uint64, 51 | 32: ctypes.c_uint32, 52 | 16: ctypes.c_uint16, 53 | 8: ctypes.c_uint8, 54 | 1: ctypes.c_bool, 55 | }, 56 | } 57 | 58 | def to_ctype(self): 59 | return self._from_ctype[self.signed][self.size] 60 | 61 | def impl_Add(self, codegen, lhs, rhs): 62 | return codegen.builder.add(lhs, rhs) 63 | 64 | def impl_Sub(self, codegen, lhs, rhs): 65 | return codegen.builder.sub(lhs, rhs) 66 | 67 | def impl_Mult(self, codegen, lhs, rhs): 68 | return codegen.builder.mul(lhs, rhs) 69 | 70 | def impl_LShift(self, codegen, lhs, rhs): 71 | return codegen.builder.shl(lhs, rhs) 72 | 73 | def impl_RShift(self, codegen, lhs, rhs): 74 | return codegen.builder.ashr(lhs, rhs) 75 | 76 | 77 | class SignedInteger(BaseInteger): 78 | signed = True 79 | 80 | def impl_Div(self, codegen, lhs, rhs): 81 | return codegen.builder.sdiv(lhs, rhs) 82 | 83 | def impl_USub(self, codegen, lhs): 84 | lhs = codegen.val(lhs) 85 | return codegen.builder.sub(ir.Constant(lhs.type, 0), lhs) 86 | 87 | def impl_Mod(self, codegen, lhs, rhs): 88 | b: ir.IRBuilder = 
codegen.builder 89 | 90 | g1 = b.icmp_signed(">=", lhs, ir.Constant(lhs.type, 0)) 91 | 92 | g2 = b.urem(lhs, rhs) 93 | 94 | g3a = b.sub(ir.Constant(lhs.type, 0), lhs) 95 | g3b = b.urem(g3a, rhs) 96 | g3c = b.sub(rhs, g3b) 97 | g3d = b.urem(g3c, rhs) 98 | 99 | return b.select(g1, g2, g3d) 100 | 101 | def impl_Eq(self, codegen, lhs, rhs): 102 | return codegen.builder.icmp_signed("==", lhs, rhs) 103 | 104 | def impl_NotEq(self, codegen, lhs, rhs): 105 | return codegen.builder.icmp_signed("!=", lhs, rhs) 106 | 107 | def impl_Gt(self, codegen, lhs, rhs): 108 | return codegen.builder.icmp_signed(">", lhs, rhs) 109 | 110 | def impl_Lt(self, codegen, lhs, rhs): 111 | return codegen.builder.icmp_signed("<", lhs, rhs) 112 | 113 | def impl_GtE(self, codegen, lhs, rhs): 114 | return codegen.builder.icmp_signed(">=", lhs, rhs) 115 | 116 | def impl_LtE(self, codegen, lhs, rhs): 117 | return codegen.builder.icmp_signed("<=", lhs, rhs) 118 | 119 | def to_bool(self, codegen, value): 120 | return codegen.builder.icmp_signed("!=", value, ir.Constant(value.type, 0)) 121 | 122 | 123 | class UnsignedInteger(BaseInteger): 124 | signed = False 125 | 126 | def impl_Div(self, codegen, lhs, rhs): 127 | return codegen.builder.udiv(lhs, rhs) 128 | 129 | def impl_Mod(self, codegen, lhs, rhs): 130 | return codegen.builder.urem(lhs, rhs) 131 | 132 | def impl_Eq(self, codegen, lhs, rhs): 133 | return codegen.builder.icmp_unsigned("==", lhs, rhs) 134 | 135 | def impl_NotEq(self, codegen, lhs, rhs): 136 | return codegen.builder.icmp_unsigned("!=", lhs, rhs) 137 | 138 | def impl_Gt(self, codegen, lhs, rhs): 139 | return codegen.builder.icmp_unsigned(">", lhs, rhs) 140 | 141 | def impl_Lt(self, codegen, lhs, rhs): 142 | return codegen.builder.icmp_unsigned("<", lhs, rhs) 143 | 144 | def impl_GtE(self, codegen, lhs, rhs): 145 | return codegen.builder.icmp_unsigned(">=", lhs, rhs) 146 | 147 | def impl_LtE(self, codegen, lhs, rhs): 148 | return codegen.builder.icmp_unsigned("<=", lhs, rhs) 149 | 150 | 
def to_bool(self, codegen, value): 151 | return codegen.builder.icmp_unsigned("!=", value, ir.Constant(value.type, 0)) 152 | 153 | 154 | class BaseFloat(PrimitiveType): 155 | signed = True 156 | 157 | def __repr__(self): 158 | return f'<{"Un" if not self.signed else ""}signed u{self.size}>' 159 | 160 | def __init__(self): 161 | self.llvm = self.j_type() 162 | 163 | def impl_Add(self, codegen, lhs, rhs): 164 | return codegen.builder.fadd(lhs, rhs) 165 | 166 | def impl_Sub(self, codegen, lhs, rhs): 167 | return codegen.builder.fsub(lhs, rhs) 168 | 169 | def impl_Mult(self, codegen, lhs, rhs): 170 | return codegen.builder.fmul(lhs, rhs) 171 | 172 | def impl_Div(self, codegen, lhs, rhs): 173 | return codegen.builder.fdiv(lhs, rhs) 174 | 175 | def impl_USub(self, codegen, lhs): 176 | lhs = codegen.val(lhs) 177 | return codegen.builder.fsub(ir.Constant(lhs.type, 0.0), lhs) 178 | 179 | def impl_Eq(self, codegen, lhs, rhs): 180 | return codegen.builder.fcmp_unordered("==", lhs, rhs) 181 | 182 | def impl_NotEq(self, codegen, lhs, rhs): 183 | return codegen.builder.fcmp_unordered("!=", lhs, rhs) 184 | 185 | def impl_Gt(self, codegen, lhs, rhs): 186 | return codegen.builder.fcmp_unordered(">", lhs, rhs) 187 | 188 | def impl_Lt(self, codegen, lhs, rhs): 189 | return codegen.builder.fcmp_unordered("<", lhs, rhs) 190 | 191 | def impl_GtE(self, codegen, lhs, rhs): 192 | return codegen.builder.fcmp_unordered(">=", lhs, rhs) 193 | 194 | def impl_LtE(self, codegen, lhs, rhs): 195 | return codegen.builder.fcmp_unordered("<=", lhs, rhs) 196 | 197 | def to_bool(self, codegen, value): 198 | return codegen.builder.fcmp_unordered("!=", value, ir.Constant(value.type, 0.0)) 199 | 200 | 201 | class Float(BaseFloat): 202 | size = 32 203 | 204 | def to_ctype(self): 205 | return ctypes.c_float 206 | 207 | j_type = ir.FloatType 208 | 209 | 210 | class Double(BaseFloat): 211 | size = 64 212 | 213 | def to_ctype(self): 214 | return ctypes.c_double 215 | 216 | j_type = ir.DoubleType 217 | 218 | 219 
| i64 = SignedInteger(64) 220 | u64 = UnsignedInteger(64) 221 | 222 | i32 = SignedInteger(32) 223 | u32 = UnsignedInteger(32) 224 | 225 | i16 = SignedInteger(16) 226 | u16 = UnsignedInteger(16) 227 | 228 | i8 = SignedInteger(8) 229 | ubyte = UnsignedInteger(8) 230 | u8 = ubyte 231 | 232 | u1 = UnsignedInteger(1) 233 | 234 | f64 = Double() 235 | f32 = Float() 236 | 237 | void = Void() 238 | 239 | 240 | class ObjectType(JitType): 241 | pass 242 | 243 | 244 | import copy 245 | 246 | 247 | class ArrayType(ObjectType): 248 | def __init__(self, base_type, dimensions): 249 | 250 | if len(dimensions) > 1: 251 | self.base_type = ArrayType(base_type, dimensions[1:]) 252 | self.dimensions = dimensions[1:] 253 | self.size = dimensions[0] 254 | self.llvm = ir.ArrayType(self.base_type.llvm, dimensions[0]) 255 | else: 256 | self.base_type = base_type 257 | self.dimensions = dimensions 258 | self.size = dimensions[0] 259 | self.llvm = ir.ArrayType(self.base_type.llvm, dimensions[0]) 260 | 261 | self._a_type = self.base_type.to_ctype() * self.size 262 | 263 | def to_ctype(self): 264 | return self._a_type 265 | 266 | def __call__(self): 267 | new = Array(self._a_type) 268 | return new 269 | 270 | def slice(self): 271 | return self.base_type 272 | 273 | 274 | class Array(ArrayType): 275 | def __init__(self, array_type): 276 | self._a_type = array_type 277 | self._array = self._a_type() 278 | 279 | def from_jtype(self, value): 280 | return self._array 281 | 282 | def __getitem__(self, item): 283 | return self._array[item] 284 | 285 | def __setitem__(self, item, value): 286 | self._array[item] = value 287 | 288 | 289 | def array(base_type, dimensions): 290 | return ArrayType(base_type, dimensions) 291 | 292 | 293 | class PointerType(JitType): 294 | def __init__(self, pointee): 295 | self.pointee = pointee 296 | self.llvm = ir.PointerType(self.pointee.llvm) 297 | 298 | def from_jtype(self, value): 299 | return ctypes.POINTER(self.pointee.from_jtype()) 300 | 301 | def to_ctype(self): 302 
| return ctypes.POINTER(self.pointee.to_ctype()) 303 | 304 | def alloca(self, codegen): 305 | return codegen.builder.alloca(self.llvm) 306 | 307 | 308 | def pointer(pointee): 309 | return PointerType(pointee) 310 | 311 | 312 | class ObjectPointer(PointerType): 313 | pass 314 | 315 | 316 | def objectpointer(pointee): 317 | return ObjectPointer(pointee) 318 | 319 | 320 | def type_conversions(type_to_convert): 321 | if type_to_convert == int: 322 | return i64 323 | elif type_to_convert == float: 324 | return f64 325 | elif isinstance(type_to_convert, ArrayType): 326 | return type_to_convert 327 | 328 | return None 329 | -------------------------------------------------------------------------------- /jit/settings.py: -------------------------------------------------------------------------------- 1 | # Cache bitcode 2 | CACHE = False 3 | # Save object file for module 4 | ASM = False 5 | # Write LLVM dump to file 6 | DUMP = True 7 | # Enable AST debugging dump 8 | DEBUG = True 9 | # Write dump files to same directory as module? 
10 | # if False, this will write all dumps to one "debug" file in the main dir 11 | DUMP_TO_DIR = False 12 | -------------------------------------------------------------------------------- /jit/stdlib.py: -------------------------------------------------------------------------------- 1 | from llvmlite import ir 2 | 3 | 4 | def make(function_name): 5 | 6 | maker = globals().get(f"make_{function_name}", None) 7 | if not maker: 8 | return None 9 | 10 | return maker 11 | 12 | 13 | def _make_llvm( 14 | self, function_name, return_type=ir.VoidType, arguments=[], var_arg=False 15 | ): 16 | m: ir.Module = self.module 17 | 18 | p_func = m.globals.get("function_name", None) 19 | if p_func: 20 | return p_func 21 | 22 | p_func = ir.Function( 23 | self.module, 24 | ir.FunctionType( 25 | return_type, 26 | arguments, 27 | var_arg=var_arg, 28 | ), 29 | function_name, 30 | ) 31 | return p_func 32 | 33 | 34 | def make_print(self, args): 35 | 36 | from .codegen import Value 37 | from . import j_types as j 38 | 39 | p_func = _make_llvm( 40 | self, 41 | "printf", 42 | ir.IntType(64), 43 | [ir.PointerType(ir.IntType(8)), ir.IntType(64)], 44 | var_arg=True, 45 | ) 46 | 47 | s1 = ir.GlobalVariable(self.module, ir.ArrayType(ir.IntType(8), 6), "str_1") 48 | 49 | s1.initializer = ir.Constant( 50 | ir.ArrayType(ir.IntType(8), 6), bytearray("%lld\n\x00", encoding="utf8") 51 | ) 52 | 53 | s2 = self.builder.gep(s1, [self.zero, self.zero]) 54 | 55 | result = self.builder.call(p_func, [s2] + args) 56 | return Value(j.u64, result, None) 57 | -------------------------------------------------------------------------------- /life.py: -------------------------------------------------------------------------------- 1 | from jit import jit, j_types as j 2 | 3 | WIDTH = 80 4 | HEIGHT = 40 5 | 6 | arr_type = j.array(j.u8, (2, WIDTH, HEIGHT)) 7 | arr = arr_type() 8 | 9 | import random 10 | 11 | for n in range(HEIGHT): 12 | for m in range(WIDTH): 13 | arr[0][m][n] = random.random() > 0.8 14 | 15 | # TODO: 
variable capture for more than one instance of captured variable is broken 16 | 17 | @jit 18 | def life(a: arr_type, world: int): 19 | 20 | current = world 21 | target = 1 - world 22 | 23 | x = 0 24 | y = 0 25 | 26 | z: j.u8 = 0 27 | q: j.u8 = 0 28 | 29 | H = HEIGHT 30 | W = WIDTH 31 | 32 | while y < H: 33 | x = 0 34 | while x < W: 35 | 36 | y1 = y - 1 37 | z = a[current][x][y] 38 | total = 0 39 | 40 | while y1 < y + 2: 41 | x1 = x - 1 42 | while x1 < x + 2: 43 | if x1 == x and y1 == y: 44 | x1 += 1 45 | continue 46 | if a[current][x1 % W][y1 % H] > 0: 47 | total += 1 48 | x1 += 1 49 | y1 += 1 50 | 51 | q = 0 52 | 53 | if z: 54 | if total > 1 and total < 4: 55 | q = 1 56 | else: 57 | if total == 3: 58 | q = 1 59 | 60 | a[target][x][y] = q 61 | 62 | x += 1 63 | 64 | y += 1 65 | 66 | 67 | world = 0 68 | 69 | while True: 70 | 71 | for n in range(HEIGHT): 72 | for m in range(WIDTH): 73 | print("O" if arr[world][m][n] else " ", end="") 74 | print() 75 | 76 | input() 77 | 78 | life(arr, world) 79 | world = 1 - world 80 | 81 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | llvmlite 2 | black -------------------------------------------------------------------------------- /tests/__main__.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import sys 4 | 5 | sys.path.insert(0, ".\\") 6 | 7 | 8 | def main(): 9 | print("Discovering tests.") 10 | tests = unittest.TestLoader().discover(".\\tests", pattern="test_*.py") 11 | print("Starting.") 12 | unittest.TextTestRunner(failfast=True).run(tests) 13 | 14 | 15 | if __name__ == "__main__": 16 | main() 17 | -------------------------------------------------------------------------------- /tests/test_add.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from jit import jit, 
j_types as j, errors as err, jit_immediate as jl 3 | 4 | 5 | @jit 6 | def add0(a: j.i64, b: j.i64): 7 | return a + b 8 | 9 | 10 | @jit 11 | def add1(a: int): 12 | return a + 1 13 | 14 | 15 | @jit 16 | def add2(a: float): 17 | return a + 1.0 18 | 19 | 20 | @jit 21 | def add3(a: int, b: int): 22 | return a + b 23 | 24 | 25 | @jit 26 | def add4(a: float, b: float): 27 | return a + b 28 | 29 | 30 | @jit 31 | def add(a, b): 32 | return a + b 33 | 34 | 35 | class Test(unittest.TestCase): 36 | def test_add(self): 37 | self.assertEqual(add1(1), 2) 38 | self.assertEqual(add2(1.0), 2.0) 39 | self.assertEqual(add3(1, 1), 2) 40 | self.assertEqual(add4(1.0, 1.0), 2.0) 41 | 42 | def test_add_err(self): 43 | with self.assertRaises(err.JitTypeError): 44 | add(2, 2) 45 | with self.assertRaises(err.JitTypeError): 46 | add0(2.0, 2.0) 47 | -------------------------------------------------------------------------------- /tests/test_and.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from jit import jit 3 | 4 | 5 | @jit 6 | def andtest(): 7 | x = 1 8 | y = 2 9 | if x == 1 and y == 2: 10 | return 1 11 | 12 | 13 | class Test(unittest.TestCase): 14 | def test_and(self): 15 | self.assertEqual(andtest(), True) 16 | -------------------------------------------------------------------------------- /tests/test_arrays.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from jit import jit, j_types as j, errors as err, jit_immediate as jl 3 | 4 | # type definition 5 | arr = j.array(j.u8, (2, 80, 25)) 6 | # instance 7 | x = arr() 8 | 9 | x[0][0][0] = 5 10 | x[1][79][24] = 32 11 | 12 | # array is automatically passed by reference 13 | @jit 14 | def main(a: arr): 15 | a[0][0][0] = 100 16 | return a 17 | 18 | 19 | @jit 20 | def main2(a: arr): 21 | a[1][79][24] = 32 22 | 23 | 24 | @jit 25 | def main3(a: arr): 26 | xx1 = -1 % 80 27 | if a[1][xx1][24] > 0: 28 | return 1 29 | return 0 30 | 
31 | 32 | @jit 33 | def main4(a: arr): 34 | xx = 1 35 | a[1][79][24] = xx 36 | 37 | 38 | class Test(unittest.TestCase): 39 | def test_array(self): 40 | y = main(x) 41 | self.assertEqual(y[0][0][0], 100) 42 | self.assertEqual(y[0][0][1], 0) 43 | self.assertEqual(y[1][79][24], 32) 44 | 45 | def test_array2(self): 46 | x[0][0][0] = 5 47 | x[1][79][24] = 32 48 | main(x) 49 | self.assertEqual(x[0][0][0], 100) 50 | self.assertEqual(x[0][0][1], 0) 51 | self.assertEqual(x[1][79][24], 32) 52 | 53 | def test_array3(self): 54 | x[0][0][0] = 5 55 | # x[1][79][24] = 32 56 | main2(x) 57 | self.assertEqual(x[0][0][0], 5) 58 | self.assertEqual(x[0][0][1], 0) 59 | self.assertEqual(x[1][79][24], 32) 60 | 61 | def test_array4(self): 62 | x[1][79][24] = 0 63 | self.assertEqual(main3(x), False) 64 | x[1][79][24] = 1 65 | self.assertEqual(main3(x), True) 66 | with self.assertRaises(err.JitTypeError): 67 | main4(x) 68 | -------------------------------------------------------------------------------- /tests/test_bitshift.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from jit import jit 3 | 4 | 5 | @jit 6 | def sh1(): 7 | x = 2 8 | return x << 2 9 | 10 | 11 | @jit 12 | def sh2(): 13 | x = 8 14 | return x >> 2 15 | 16 | 17 | class Test(unittest.TestCase): 18 | def test_bitshift(self): 19 | self.assertEqual(sh1(), 8) 20 | self.assertEqual(sh2(), 2) 21 | -------------------------------------------------------------------------------- /tests/test_coerce_bool.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from jit import jit 3 | 4 | 5 | @jit 6 | def coerce_from_int(): 7 | x = 5 8 | if x: 9 | return 1 10 | else: 11 | return 0 12 | 13 | 14 | @jit 15 | def coerce_from_int2(): 16 | x = 0 17 | if x: 18 | return 1 19 | else: 20 | return 0 21 | 22 | 23 | @jit 24 | def coerce_from_float(): 25 | x = 1.0 26 | if x: 27 | return 1 28 | else: 29 | return 0 30 | 31 | 32 | @jit 33 | def 
coerce_from_float2(): 34 | x = 0.0 35 | if x: 36 | return 1 37 | else: 38 | return 0 39 | 40 | 41 | class Test(unittest.TestCase): 42 | def test_void_zero(self): 43 | self.assertEqual(coerce_from_int(), True) 44 | self.assertEqual(coerce_from_int2(), False) 45 | self.assertEqual(coerce_from_float(), True) 46 | self.assertEqual(coerce_from_float2(), False) 47 | -------------------------------------------------------------------------------- /tests/test_conditionals.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from jit import jit 3 | 4 | 5 | @jit 6 | def cond1(): 7 | x = 0 8 | while x < 10: 9 | x = x + 1 10 | if x == 5: 11 | break 12 | return x 13 | 14 | 15 | @jit 16 | def cond2(): 17 | x = 0 18 | while x < 10: 19 | x = x + 1 20 | return x 21 | 22 | 23 | @jit 24 | def cond3(): 25 | x = 0 26 | y = 0 27 | if x == 0: 28 | y = 1 29 | elif x > 0: 30 | y = 2 31 | else: 32 | y = 3 33 | return y 34 | 35 | 36 | @jit 37 | def cond4(): 38 | x = 5 39 | y = 0 40 | if x == 0: 41 | y = 1 42 | elif x > 0: 43 | y = 2 44 | else: 45 | y = 3 46 | return y 47 | 48 | 49 | @jit 50 | def cond5(): 51 | x = -1 52 | y = 0 53 | if x == 0: 54 | y = 1 55 | elif x > 0: 56 | y = 2 57 | else: 58 | y = 3 59 | return y 60 | 61 | 62 | class Test(unittest.TestCase): 63 | def test_conditionals(self): 64 | self.assertEqual(cond1(), 5) 65 | self.assertEqual(cond2(), 10) 66 | self.assertEqual(cond3(), 1) 67 | self.assertEqual(cond4(), 2) 68 | self.assertEqual(cond5(), 3) 69 | -------------------------------------------------------------------------------- /tests/test_eq_neq.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from jit import jit 3 | 4 | 5 | @jit 6 | def eq1(): 7 | x = 1 8 | return x == 1 9 | 10 | 11 | @jit 12 | def eq2(): 13 | x = 1 14 | return x == 2 15 | 16 | 17 | @jit 18 | def eq3(): 19 | x = 1.0 20 | return x == 1.0 21 | 22 | 23 | @jit 24 | def eq4(): 25 | x = 1.0 26 | 
return x == 2.0 27 | 28 | 29 | @jit 30 | def neq1(): 31 | x = 1 32 | return x != 0 33 | 34 | 35 | @jit 36 | def neq2(): 37 | x = 1 38 | return x != 1 39 | 40 | 41 | @jit 42 | def neq3(): 43 | x = 1.0 44 | return x != 2.0 45 | 46 | 47 | @jit 48 | def neq4(): 49 | x = 1.0 50 | return x != 1.0 51 | 52 | 53 | class Test(unittest.TestCase): 54 | def test_eq_neq(self): 55 | self.assertEqual(eq1(), True) 56 | self.assertEqual(eq2(), False) 57 | self.assertEqual(eq3(), True) 58 | self.assertEqual(eq4(), False) 59 | self.assertEqual(neq1(), True) 60 | self.assertEqual(neq2(), False) 61 | self.assertEqual(neq3(), True) 62 | self.assertEqual(neq4(), False) 63 | -------------------------------------------------------------------------------- /tests/test_gt.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from jit import jit 3 | 4 | 5 | @jit 6 | def gt1(): 7 | x = 1 8 | return x > 0 9 | 10 | 11 | @jit 12 | def gt2(): 13 | x = 1 14 | return x > 1 15 | 16 | 17 | @jit 18 | def gt3(): 19 | x = 1.0 20 | return x > 0.0 21 | 22 | 23 | @jit 24 | def gt4(): 25 | x = 1.0 26 | return x > 1.0 27 | 28 | 29 | @jit 30 | def gt5(): 31 | x = 0 32 | return x >= 0 33 | 34 | 35 | @jit 36 | def gt6(): 37 | x = -1 38 | return x >= 1 39 | 40 | 41 | @jit 42 | def gt7(): 43 | x = 1.0 44 | return x >= 1.0 45 | 46 | 47 | @jit 48 | def gt8(): 49 | x = 1.0 50 | return x >= 2.0 51 | 52 | 53 | class Test(unittest.TestCase): 54 | def test_gt(self): 55 | self.assertEqual(gt1(), True) 56 | self.assertEqual(gt2(), False) 57 | self.assertEqual(gt3(), True) 58 | self.assertEqual(gt4(), False) 59 | self.assertEqual(gt5(), True) 60 | self.assertEqual(gt6(), False) 61 | self.assertEqual(gt7(), True) 62 | self.assertEqual(gt8(), False) 63 | -------------------------------------------------------------------------------- /tests/test_inference.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from jit 
import jit, j_types as j 3 | import ctypes 4 | 5 | 6 | @jit 7 | def inf1(a: j.f64): 8 | return a + 2 9 | 10 | 11 | @jit 12 | def inf2(a: j.i32): 13 | return a + 2 14 | 15 | 16 | class Test(unittest.TestCase): 17 | def test_inference(self): 18 | 19 | self.assertEqual(inf1(2), 4.0) 20 | self.assertEqual(inf2(2), 4) 21 | 22 | self.assertEqual(inf1._wrapped._jit.restype, ctypes.c_double) 23 | self.assertEqual(inf2._wrapped._jit.restype, ctypes.c_int32) 24 | -------------------------------------------------------------------------------- /tests/test_lt.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from jit import jit 3 | 4 | 5 | @jit 6 | def lt1(): 7 | x = 1 8 | return x < 2 9 | 10 | 11 | @jit 12 | def lt2(): 13 | x = 1 14 | return x < 0 15 | 16 | 17 | @jit 18 | def lt3(): 19 | x = 1.0 20 | return x < 2.0 21 | 22 | 23 | @jit 24 | def lt4(): 25 | x = 1.0 26 | return x < 0.0 27 | 28 | 29 | @jit 30 | def lt5(): 31 | x = 2 32 | return x <= 2 33 | 34 | 35 | @jit 36 | def lt6(): 37 | x = 1 38 | return x <= 0 39 | 40 | 41 | @jit 42 | def lt7(): 43 | x = 1.0 44 | return x <= 2.0 45 | 46 | 47 | @jit 48 | def lt8(): 49 | x = 1.0 50 | return x <= 0.0 51 | 52 | 53 | class Test(unittest.TestCase): 54 | def test_lt(self): 55 | self.assertEqual(lt1(), True) 56 | self.assertEqual(lt2(), False) 57 | self.assertEqual(lt3(), True) 58 | self.assertEqual(lt4(), False) 59 | self.assertEqual(lt5(), True) 60 | self.assertEqual(lt6(), False) 61 | self.assertEqual(lt7(), True) 62 | self.assertEqual(lt8(), False) 63 | -------------------------------------------------------------------------------- /tests/test_modulo.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from jit import jit 3 | 4 | 5 | @jit 6 | def mod1(): 7 | x = 4 8 | return x % 4 9 | 10 | 11 | @jit 12 | def mod2(): 13 | x = 0 14 | return x % 4 15 | 16 | 17 | @jit 18 | def mod3(): 19 | x = 1 20 | return x % 4 21 | 
22 | 23 | @jit 24 | def mod4(): 25 | x = 2 26 | return x % 4 27 | 28 | 29 | @jit 30 | def mod5(): 31 | x = -1 32 | return x % 4 33 | 34 | 35 | @jit 36 | def mod6(): 37 | x = 7 38 | return x % 4 39 | 40 | 41 | class Test(unittest.TestCase): 42 | def test_modulo(self): 43 | self.assertEqual(mod1(), 0) 44 | self.assertEqual(mod2(), 0) 45 | self.assertEqual(mod3(), 1) 46 | self.assertEqual(mod4(), 2) 47 | self.assertEqual(mod5(), 3) 48 | self.assertEqual(mod6(), 3) 49 | -------------------------------------------------------------------------------- /tests/test_mul_div.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from jit import jit 3 | 4 | 5 | @jit 6 | def mul1(): 7 | x = 2 8 | return x * 2 9 | 10 | 11 | @jit 12 | def mul2(): 13 | x = 2.0 14 | return x * 2.5 15 | 16 | 17 | @jit 18 | def div1(): 19 | x = 4 20 | return x / 2 21 | 22 | 23 | @jit 24 | def div2(): 25 | x = 4.0 26 | return x / 2.5 27 | 28 | 29 | class Test(unittest.TestCase): 30 | def test_mul_div(self): 31 | self.assertEqual(mul1(), 4) 32 | self.assertEqual(mul2(), 5.0) 33 | self.assertEqual(div1(), 2) 34 | self.assertEqual(div2(), 1.6) 35 | -------------------------------------------------------------------------------- /tests/test_negation.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from jit import jit 3 | 4 | 5 | @jit 6 | def neg1(): 7 | x = 2 8 | return -x 9 | 10 | 11 | @jit 12 | def neg2(): 13 | x = 2.0 14 | return -x 15 | 16 | 17 | class Test(unittest.TestCase): 18 | def test_negation(self): 19 | self.assertEqual(neg1(), -2) 20 | self.assertEqual(neg2(), -2.0) 21 | -------------------------------------------------------------------------------- /tests/test_print.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from jit import jit 3 | from jit import j_types as j 4 | 5 | 6 | @jit 7 | def test_print(x: j.i64): 8 | 
return print(x) 9 | 10 | 11 | class Test(unittest.TestCase): 12 | def test_void_zero(self): 13 | self.assertEqual(test_print(8), 2) 14 | self.assertEqual(test_print(64), 3) 15 | -------------------------------------------------------------------------------- /tests/test_sub.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from jit import jit 3 | 4 | 5 | @jit 6 | def sub1(): 7 | x = 3 8 | return x - 1 9 | 10 | 11 | @jit 12 | def sub2(): 13 | x = 3.0 14 | return x - 1.5 15 | 16 | 17 | class Test(unittest.TestCase): 18 | def test_sub(self): 19 | self.assertEqual(sub1(), 2) 20 | self.assertEqual(sub2(), 1.5) 21 | -------------------------------------------------------------------------------- /tests/test_void_zero.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from jit import jit 3 | 4 | 5 | @jit 6 | def void(): 7 | return 8 | 9 | 10 | @jit 11 | def zero(): 12 | return 0 13 | 14 | 15 | class Test(unittest.TestCase): 16 | def test_void_zero(self): 17 | self.assertEqual(void(), None) 18 | self.assertEqual(zero(), 0) 19 | -------------------------------------------------------------------------------- /todo.md: -------------------------------------------------------------------------------- 1 | # TODO 2 | * signed/unsigned behaviors 3 | * iterator (see above) 4 | * * `range` function 5 | * iterator object 6 | * objects and bound methods 7 | * ability to mangle names for things 8 | * special name mangling decorator 9 | * store items to mangle in dict, reference by object 10 | * implement name mangling 11 | * `exec()` function calls a mangled function with arguments passed 12 | * `gep()` 13 | * `ref()`, `deref()` 14 | * the Python versions of these operate on their Py object counterparts if possible 15 | * convert to loadable C modules, obviate need for interface? 
16 | * programmatic generation of tests 17 | 18 | --------------------------------------------------------------------------------