├── .gitignore ├── README.md ├── example_fib.py ├── example_find_primes.py ├── example_generic.py ├── example_is_prime.py ├── example_loop.py ├── example_while.py ├── pyjiting ├── __init__.py ├── ast.py ├── codegen.py ├── infer.py ├── ll_types.py ├── main.py ├── parser.py ├── types.py └── utils.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,pycharm 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode,pycharm 4 | 5 | ### PyCharm ### 6 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 7 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 8 | 9 | # User-specific stuff 10 | .idea/**/workspace.xml 11 | .idea/**/tasks.xml 12 | .idea/**/usage.statistics.xml 13 | .idea/**/dictionaries 14 | .idea/**/shelf 15 | 16 | # AWS User-specific 17 | .idea/**/aws.xml 18 | 19 | # Generated files 20 | .idea/**/contentModel.xml 21 | 22 | # Sensitive or high-churn files 23 | .idea/**/dataSources/ 24 | .idea/**/dataSources.ids 25 | .idea/**/dataSources.local.xml 26 | .idea/**/sqlDataSources.xml 27 | .idea/**/dynamic.xml 28 | .idea/**/uiDesigner.xml 29 | .idea/**/dbnavigator.xml 30 | 31 | # Gradle 32 | .idea/**/gradle.xml 33 | .idea/**/libraries 34 | 35 | # Gradle and Maven with auto-import 36 | # When using Gradle or Maven with auto-import, you should exclude module files, 37 | # since they will be recreated, and may cause churn. Uncomment if using 38 | # auto-import. 
39 | # .idea/artifacts 40 | # .idea/compiler.xml 41 | # .idea/jarRepositories.xml 42 | # .idea/modules.xml 43 | # .idea/*.iml 44 | # .idea/modules 45 | # *.iml 46 | # *.ipr 47 | 48 | # CMake 49 | cmake-build-*/ 50 | 51 | # Mongo Explorer plugin 52 | .idea/**/mongoSettings.xml 53 | 54 | # File-based project format 55 | *.iws 56 | 57 | # IntelliJ 58 | out/ 59 | 60 | # mpeltonen/sbt-idea plugin 61 | .idea_modules/ 62 | 63 | # JIRA plugin 64 | atlassian-ide-plugin.xml 65 | 66 | # Cursive Clojure plugin 67 | .idea/replstate.xml 68 | 69 | # Crashlytics plugin (for Android Studio and IntelliJ) 70 | com_crashlytics_export_strings.xml 71 | crashlytics.properties 72 | crashlytics-build.properties 73 | fabric.properties 74 | 75 | # Editor-based Rest Client 76 | .idea/httpRequests 77 | 78 | # Android studio 3.1+ serialized cache file 79 | .idea/caches/build_file_checksums.ser 80 | 81 | ### PyCharm Patch ### 82 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 83 | 84 | # *.iml 85 | # modules.xml 86 | # .idea/misc.xml 87 | # *.ipr 88 | 89 | # Sonarlint plugin 90 | # https://plugins.jetbrains.com/plugin/7973-sonarlint 91 | .idea/**/sonarlint/ 92 | 93 | # SonarQube Plugin 94 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 95 | .idea/**/sonarIssues.xml 96 | 97 | # Markdown Navigator plugin 98 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 99 | .idea/**/markdown-navigator.xml 100 | .idea/**/markdown-navigator-enh.xml 101 | .idea/**/markdown-navigator/ 102 | 103 | # Cache file creation bug 104 | # See https://youtrack.jetbrains.com/issue/JBR-2257 105 | .idea/$CACHE_FILE$ 106 | 107 | # CodeStream plugin 108 | # https://plugins.jetbrains.com/plugin/12206-codestream 109 | .idea/codestream.xml 110 | 111 | ### Python ### 112 | # Byte-compiled / optimized / DLL files 113 | __pycache__/ 114 | *.py[cod] 115 | *$py.class 116 | 117 | # C extensions 118 | *.so 119 | 120 | # Distribution / packaging 
121 | .Python 122 | build/ 123 | develop-eggs/ 124 | dist/ 125 | downloads/ 126 | eggs/ 127 | .eggs/ 128 | lib/ 129 | lib64/ 130 | parts/ 131 | sdist/ 132 | var/ 133 | wheels/ 134 | share/python-wheels/ 135 | *.egg-info/ 136 | .installed.cfg 137 | *.egg 138 | MANIFEST 139 | 140 | # PyInstaller 141 | # Usually these files are written by a python script from a template 142 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 143 | *.manifest 144 | *.spec 145 | 146 | # Installer logs 147 | pip-log.txt 148 | pip-delete-this-directory.txt 149 | 150 | # Unit test / coverage reports 151 | htmlcov/ 152 | .tox/ 153 | .nox/ 154 | .coverage 155 | .coverage.* 156 | .cache 157 | nosetests.xml 158 | coverage.xml 159 | *.cover 160 | *.py,cover 161 | .hypothesis/ 162 | .pytest_cache/ 163 | cover/ 164 | 165 | # Translations 166 | *.mo 167 | *.pot 168 | 169 | # Django stuff: 170 | *.log 171 | local_settings.py 172 | db.sqlite3 173 | db.sqlite3-journal 174 | 175 | # Flask stuff: 176 | instance/ 177 | .webassets-cache 178 | 179 | # Scrapy stuff: 180 | .scrapy 181 | 182 | # Sphinx documentation 183 | docs/_build/ 184 | 185 | # PyBuilder 186 | .pybuilder/ 187 | target/ 188 | 189 | # Jupyter Notebook 190 | .ipynb_checkpoints 191 | 192 | # IPython 193 | profile_default/ 194 | ipython_config.py 195 | 196 | # pyenv 197 | # For a library or package, you might want to ignore these files since the code is 198 | # intended to run in multiple environments; otherwise, check them in: 199 | # .python-version 200 | 201 | # pipenv 202 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 203 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 204 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 205 | # install all needed dependencies. 206 | #Pipfile.lock 207 | 208 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 209 | __pypackages__/ 210 | 211 | # Celery stuff 212 | celerybeat-schedule 213 | celerybeat.pid 214 | 215 | # SageMath parsed files 216 | *.sage.py 217 | 218 | # Environments 219 | .env 220 | .venv 221 | env/ 222 | venv/ 223 | ENV/ 224 | env.bak/ 225 | venv.bak/ 226 | 227 | # Spyder project settings 228 | .spyderproject 229 | .spyproject 230 | 231 | # Rope project settings 232 | .ropeproject 233 | 234 | # mkdocs documentation 235 | /site 236 | 237 | # mypy 238 | .mypy_cache/ 239 | .dmypy.json 240 | dmypy.json 241 | 242 | # Pyre type checker 243 | .pyre/ 244 | 245 | # pytype static type analyzer 246 | .pytype/ 247 | 248 | # Cython debug symbols 249 | cython_debug/ 250 | 251 | ### VisualStudioCode ### 252 | .vscode/* 253 | !.vscode/settings.json 254 | !.vscode/tasks.json 255 | !.vscode/launch.json 256 | !.vscode/extensions.json 257 | *.code-workspace 258 | 259 | # Local History for Visual Studio Code 260 | .history/ 261 | 262 | ### VisualStudioCode Patch ### 263 | # Ignore all local history of files 264 | .history 265 | .ionide 266 | 267 | # Support for Project snippet scope 268 | !.vscode/*.code-snippets 269 | 270 | # End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,pycharm -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pyjiting 2 | 3 | Pyjiting is an experimental Python JIT compiler developed as the product of my undergraduate thesis. Its goal is to be a lightweight, minimal, general-purpose JIT compiler for Python. 4 | 5 | ## Features implemented so far 6 | 7 | 1. Use llvmlite as the backend, support Python 3.9 and later, and fix some errors in the numpile tutorial. 8 | 2. Support calling ordinary Python functions (defined in `.py` files; they must be registered manually by decorating them with `@reg`) from JIT-compiled functions (LLVM machine code), based on their type hints. (See `example_find_primes.py`.) 9 | 3.
Implement basic operations, such as comparison and arithmetic operators. 10 | 4. Implement `if` statements. 11 | 5. Implement `for` loops, including `break`. 12 | 6. Implement `while` loops. 13 | 7. Implement recursion. 14 | 15 | 16 | ## Performance 17 | 18 | The source code of these benchmarks can be found in the repository root directory (`example_*.py`). 19 | 20 | ``` 21 | My test environment: 22 | CPU: i7-8700K@4.8GHz 23 | Memory: 32GB DDR4 3200MHz 24 | OS: Windows 10 21H2 64bit 25 | Python 3.9.9 64bit 26 | LLVMLite 0.37.0 27 | ``` 28 | 29 | ``` 30 | fib_jit(40) = 102334155 (cost time: 169.3246364593506 ms) 31 | fib_nojit(40) = 102334155 (cost time: 26464.3042087554929 ms) 32 | rate: 156.29328821921743 33 | 34 | find_primes_jit(100000) = 0 (cost time: 3358.8650226593018 ms) 35 | find_primes_nojit(100000) = 0 (cost time: 37986.81926727295 ms) 36 | rate: 11.309421191685097 37 | 38 | is_prime_jit(169941229) = True (cost time: 958.0211639404297 ms) 39 | is_prime_nojit(169941229) = True (cost time: 13279.908418655396 ms) 40 | rate: 13.861811114938112 41 | 42 | loop_jit(100000000) = 200000000 (cost time: 0.0 ms) 43 | loop_nojit(100000000) = 200000000 (cost time: 7256.179332733154 ms) 44 | rate: Infinite 45 | 46 | test_while_jit(100000000) = 100000000 (cost time: 0.0 ms) 47 | test_while_nojit(100000000) = 100000000 (cost time: 7198.246292114258 ms) 48 | rate: Infinite 49 | ``` 50 | 51 | ## Special thanks 52 | 53 | Inspired by the [numpile](https://dev.stephendiehl.com/numpile/) tutorial, on which this project builds. 54 | 55 | I am deeply grateful to Professor Takeshi Ogasawara, who gave me much inspiration and valuable advice. 56 | -------------------------------------------------------------------------------- /example_fib.py: -------------------------------------------------------------------------------- 1 | # Fibonacci is mainly used to show the performance of recursion.
2 | from pyjiting import jit 3 | import time 4 | 5 | 6 | @jit 7 | def fib_jit(x): 8 | if x < 3: 9 | return 1 10 | return fib_jit(x-1) + fib_jit(x-2) 11 | 12 | 13 | def fib_nojit(x): 14 | if x < 3: 15 | return 1 16 | return fib_nojit(x-1) + fib_nojit(x-2) 17 | 18 | 19 | fib_jit(0) 20 | fib_nojit(0) 21 | 22 | 23 | start_time = time.time() 24 | result = fib_jit(40) 25 | cost_time_ms = (time.time() - start_time) * 1000 26 | print('fib_jit(40) =', result, f'(cost time: {cost_time_ms} ms)') 27 | 28 | start_time = time.time() 29 | result = fib_nojit(40) 30 | cost_time_ms_nojit = (time.time() - start_time) * 1000 31 | print('fib_nojit(40) =', result, f'(cost time: {cost_time_ms_nojit} ms)') 32 | print('rate:', 'Infinite' if cost_time_ms == 0 else cost_time_ms_nojit / cost_time_ms) 33 | -------------------------------------------------------------------------------- /example_find_primes.py: -------------------------------------------------------------------------------- 1 | # Find primes tests the performance of calling functions that are not been jited (Registered python code) in jited functions 2 | from pyjiting import jit, reg 3 | import time 4 | 5 | all_primes = [] 6 | all_primes_nojit = [] 7 | 8 | # jit 9 | @reg 10 | def do_something_jit(x: int) -> int: 11 | print(f'{x} is prime!') 12 | all_primes.append(x) 13 | return 0 14 | 15 | 16 | @jit 17 | def find_primes_jit(n): 18 | for i in range(2, n): 19 | is_prime = True 20 | for j in range(2, i): 21 | if i % j == 0: 22 | is_prime = False 23 | break 24 | if is_prime == True: 25 | do_something_jit(i) 26 | return 0 27 | 28 | # nojit 29 | def do_something_nojit(x: int) -> int: 30 | print(f'{x} is prime!') 31 | all_primes_nojit.append(x) 32 | return 0 33 | 34 | 35 | def find_primes_nojit(n): 36 | for i in range(2, n): 37 | is_prime = True 38 | for j in range(2, i): 39 | if i % j == 0: 40 | is_prime = False 41 | break 42 | if is_prime == True: 43 | do_something_nojit(i) 44 | return 0 45 | 46 | 47 | find_primes_nojit(0) 48 | 
find_primes_jit(0) 49 | 50 | start_time = time.time() 51 | result = find_primes_jit(100000) 52 | cost_time_ms = (time.time() - start_time) * 1000 53 | jit_cost = cost_time_ms 54 | 55 | 56 | start_time = time.time() 57 | result = find_primes_nojit(100000) 58 | cost_time_ms = (time.time() - start_time) * 1000 59 | nojit_cost = cost_time_ms 60 | 61 | print('find_primes_jit(100000) =', result, f'(cost time: {jit_cost} ms)') 62 | print('find_primes_nojit(100000) =', result, f'(cost time: {nojit_cost} ms)') 63 | print('rate:', nojit_cost / jit_cost) -------------------------------------------------------------------------------- /example_generic.py: -------------------------------------------------------------------------------- 1 | # This test is used to test pyjiting support for dynamic types. 2 | from pyjiting import jit 3 | 4 | 5 | @jit 6 | def test(a, b): 7 | return a + b 8 | 9 | 10 | print(test(114, 514)) 11 | print(test(11.4, 51.4)) 12 | -------------------------------------------------------------------------------- /example_is_prime.py: -------------------------------------------------------------------------------- 1 | # Is prime number test including loop and if statement test, math and compare operation test. 
2 | from pyjiting import jit 3 | import time 4 | 5 | 6 | @jit 7 | def is_prime_jit(x: int) -> int: 8 | for i in range(2, x): 9 | if x % i == 0: 10 | return 0 11 | return 1 12 | 13 | 14 | def is_prime_nojit(x: int) -> int: 15 | for i in range(2, x): 16 | if x % i == 0: 17 | return 0 18 | return 1 19 | 20 | 21 | is_prime_jit(1) 22 | is_prime_nojit(1) 23 | 24 | start_time = time.time() 25 | result = bool(is_prime_jit(169941229)) 26 | cost_time_ms = (time.time() - start_time) * 1000 27 | print('is_prime_jit(169941229) =', result, f'(cost time: {cost_time_ms} ms)') 28 | 29 | start_time = time.time() 30 | result = bool(is_prime_nojit(169941229)) 31 | cost_time_ms_nojit = (time.time() - start_time) * 1000 32 | print('is_prime_nojit(169941229) =', result, f'(cost time: {cost_time_ms_nojit} ms)') 33 | print('rate:', 'Infinite' if cost_time_ms == 0 else cost_time_ms_nojit / cost_time_ms) 34 | -------------------------------------------------------------------------------- /example_loop.py: -------------------------------------------------------------------------------- 1 | # This is a pure loop performance test 2 | from pyjiting import jit 3 | import time 4 | 5 | 6 | @jit 7 | def loop_jit(n): 8 | for _ in range(n): 9 | n += 1 10 | return n 11 | 12 | 13 | def loop_nojit(n): 14 | for _ in range(n): 15 | n += 1 16 | return n 17 | 18 | 19 | loop_jit(0) 20 | loop_nojit(0) 21 | 22 | 23 | start_time = time.time() 24 | result = loop_jit(100000000) 25 | cost_time_ms = (time.time() - start_time) * 1000 26 | print('loop_jit(100000000) =', result, f'(cost time: {cost_time_ms} ms)') 27 | 28 | start_time = time.time() 29 | result = loop_nojit(100000000) 30 | cost_time_ms_nojit = (time.time() - start_time) * 1000 31 | print('loop_nojit(100000000) =', result, f'(cost time: {cost_time_ms_nojit} ms)') 32 | print('rate:', 'Infinite' if cost_time_ms == 0 else cost_time_ms_nojit / cost_time_ms) 33 | -------------------------------------------------------------------------------- 
/example_while.py: -------------------------------------------------------------------------------- 1 | # Test while 2 | from pyjiting import jit 3 | 4 | import time 5 | 6 | 7 | def test_while(x: int) -> int: 8 | res = 0 9 | while res < x: 10 | res = res + 1 11 | return res 12 | 13 | 14 | test_while_jit = jit(test_while) 15 | 16 | test_while_jit(0) 17 | test_while(0) 18 | 19 | start_time = time.time() 20 | result = test_while_jit(100000000) 21 | cost_time_ms = (time.time() - start_time) * 1000 22 | print('test_while_jit(100000000) =', result, 23 | f'(cost time: {cost_time_ms} ms)') 24 | 25 | start_time = time.time() 26 | result = test_while(100000000) 27 | cost_time_ms_nojit = (time.time() - start_time) * 1000 28 | print('test_while_nojit(100000000) =', result, 29 | f'(cost time: {cost_time_ms_nojit} ms)') 30 | print('rate:', 'Infinite' if cost_time_ms == 0 else cost_time_ms_nojit / cost_time_ms) 31 | -------------------------------------------------------------------------------- /pyjiting/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import jit, reg 2 | -------------------------------------------------------------------------------- /pyjiting/ast.py: -------------------------------------------------------------------------------- 1 | import ast 2 | 3 | ''' 4 | Core AST Define 5 | ''' 6 | 7 | 8 | class Var(ast.AST): 9 | _fields = ['id', 'type'] 10 | 11 | def __init__(self, id, type=None): 12 | self.id = id 13 | self.type = type 14 | 15 | 16 | class Assign(ast.AST): 17 | _fields = ['ref', 'value', 'type'] 18 | 19 | def __init__(self, ref, value, type=None): 20 | self.ref = ref 21 | self.value = value 22 | self.type = type 23 | 24 | 25 | class Return(ast.AST): 26 | _fields = ['value'] 27 | 28 | def __init__(self, value): 29 | self.value = value 30 | 31 | 32 | class Loop(ast.AST): 33 | _fields = ['var', 'begin', 'end', 'body', 'step'] 34 | 35 | def __init__(self, var, begin, end, body, step): 36 | 
self.var = var 37 | self.begin = begin 38 | self.end = end 39 | self.body = body 40 | self.step = step 41 | 42 | 43 | class If(ast.AST): 44 | _fields = ['test', 'body', 'orelse'] 45 | 46 | def __init__(self, test, body, orelse): 47 | self.test = test 48 | self.body = body 49 | self.orelse = orelse 50 | 51 | 52 | class While(ast.AST): 53 | _fields = ['test', 'body', 'orelse'] 54 | 55 | def __init__(self, test, body, orelse): 56 | self.test = test 57 | self.body = body 58 | self.orelse = orelse 59 | 60 | 61 | class Compare(ast.AST): 62 | _fields = ['left', 'ops', 'comparators'] 63 | 64 | def __init__(self, left, ops, comparators): 65 | self.left = left 66 | self.ops = ops 67 | self.comparators = comparators 68 | 69 | 70 | class CallFunc(ast.AST): 71 | _fields = ['fn', 'args'] 72 | 73 | def __init__(self, fn, args): 74 | self.fn = fn 75 | self.args = args 76 | 77 | 78 | class Fun(ast.AST): 79 | _fields = ['fname', 'args', 'body'] 80 | 81 | def __init__(self, fname, args, body): 82 | self.fname = fname 83 | self.args = args 84 | self.body = body 85 | 86 | 87 | class LitInt(ast.AST): 88 | _fields = ['n'] 89 | 90 | def __init__(self, n, type=None): 91 | self.n = int(n) 92 | self.type = type 93 | 94 | 95 | class LitFloat(ast.AST): 96 | _fields = ['n'] 97 | 98 | def __init__(self, n, type=None): 99 | self.n = float(n) 100 | self.type = None 101 | 102 | 103 | class LitBool(ast.AST): 104 | _fields = ['n'] 105 | 106 | def __init__(self, n): 107 | self.n = bool(n) 108 | 109 | 110 | class Prim(ast.AST): 111 | _fields = ['fn', 'args'] 112 | 113 | def __init__(self, fn, args): 114 | self.fn = fn 115 | self.args = args 116 | 117 | 118 | class Const(ast.AST): 119 | _fields = ['value'] 120 | 121 | def __init__(self, value): 122 | self.value = value 123 | 124 | 125 | class Index(ast.AST): 126 | _fields = ['value', 'ix'] 127 | 128 | def __init__(self, value, ix): 129 | self.value = value 130 | self.ix = ix 131 | 132 | 133 | class Expr(ast.AST): 134 | _fields = ['value'] 135 | 136 | 
def __init__(self, value): 137 | self.value = value 138 | 139 | 140 | class Noop(ast.AST): 141 | _fields = [] 142 | 143 | 144 | class Break(ast.AST): 145 | _fields = [] 146 | 147 | 148 | PRIM_OPS = { 149 | ast.Add: 'add#', 150 | ast.Mult: 'mult#', 151 | ast.Sub: 'sub#', 152 | ast.Div: 'div#', 153 | ast.Pow: 'pow#', 154 | ast.Mod: 'mod#', 155 | ast.And: 'and#', 156 | ast.Or: 'or#', 157 | ast.Eq: 'eq#', 158 | ast.NotEq: 'ne#', 159 | ast.Lt: 'lt#', 160 | ast.LtE: 'le#', 161 | ast.Gt: 'gt#', 162 | ast.GtE: 'ge#' 163 | } 164 | 165 | LLVM_PRIM_OPS = list(PRIM_OPS.values()) 166 | -------------------------------------------------------------------------------- /pyjiting/codegen.py: -------------------------------------------------------------------------------- 1 | import ast 2 | from collections import defaultdict 3 | from ctypes import c_int64, c_void_p 4 | import ctypes 5 | 6 | import llvmlite.llvmpy.core as lc 7 | from llvmlite import ir 8 | 9 | from pyjiting.ll_types import mangler 10 | 11 | from .ast import (LLVM_PRIM_OPS, Assign, Break, CallFunc, Compare, Const, Expr, Fun, If, Index, 12 | LitFloat, LitInt, Loop, Noop, Prim, Return, Var, While) 13 | from .types import * 14 | 15 | ''' 16 | Codegen is a module that takes an AST and generates LLVM IR. 17 | ''' 18 | 19 | reg_known_func = {} 20 | 21 | ir_ptr_t = ir.PointerType 22 | ir_int32_t = ir.IntType(32) 23 | ir_int64_t = ir.IntType(64) 24 | ir_float_t = ir.FloatType() 25 | ir_double_t = ir.DoubleType() 26 | ir_bool_t = ir.IntType(64) 27 | ir_void_t = ir.VoidType() 28 | ir_void_ptr_t = ir_ptr_t(ir.IntType(8)) 29 | 30 | 31 | def array_type(elt_type): 32 | struct_type = ir.global_context.get_identified_type( 33 | 'ndarray_' + str(elt_type)) 34 | 35 | # The type can already exist. 36 | if struct_type.elements: 37 | return struct_type 38 | 39 | # If not, initialize it. 
40 | struct_type.set_body( 41 | ir_ptr_t(elt_type), # data 42 | ir_int32_t, # dimensions 43 | ir_ptr_t(ir_int32_t), # shape 44 | ) 45 | return struct_type 46 | 47 | 48 | ir_int32_array_t = ir_ptr_t(array_type(ir_int32_t)) 49 | ir_int64_array_t = ir_ptr_t(array_type(ir_int64_t)) 50 | ir_double_array_t = ir_ptr_t(array_type(ir_double_t)) 51 | 52 | lltypes_map = { 53 | int32_t: ir_int32_t, 54 | int64_t: ir_int64_t, 55 | bool_t: ir_bool_t, 56 | float32_t: ir_float_t, 57 | double64_t: ir_double_t, 58 | int32_array_t: ir_int32_array_t, 59 | int64_array_t: ir_int64_array_t, 60 | double64_array_t: ir_double_array_t, 61 | void_t: ir_void_t, 62 | } 63 | 64 | 65 | def to_lltype(ptype): 66 | return lltypes_map[ptype] 67 | 68 | 69 | def determined(ty): 70 | return len(ftv(ty)) == 0 71 | 72 | 73 | def reg_func(func_name, func): 74 | reg_known_func[func_name] = func 75 | 76 | 77 | def get_reg_func(func_name): 78 | return reg_known_func.get(func_name, None) 79 | 80 | 81 | def arg_ctype(arg): 82 | if arg == ir_int64_t: 83 | return c_int64 84 | raise RuntimeError('Unsupported type:', arg) 85 | 86 | 87 | def arg_classtype(arg): 88 | if arg is int: 89 | return int64_t 90 | elif arg is float: 91 | return double64_t 92 | else: 93 | raise RuntimeError('Unsupported type:', arg) 94 | 95 | 96 | class LLVMCodeGen(object): 97 | def __init__(self, module, spec_types, return_type, args): 98 | self.module = module # LLVM Module 99 | self.function = None # LLVM Function 100 | self.builder = None # LLVM Builder 101 | self.locals = {} # Local variables 102 | self.arrays = defaultdict(dict) # Array metadata 103 | self.exit_block = None # Exit block 104 | self.spec_types = spec_types # Type specialization 105 | self.return_type = return_type # Return type 106 | self.args = args # Argument types 107 | self.org_func_name = None # Original function name 108 | self.break_block_stack = [] # Break block stack 109 | 110 | def start_function(self, name, module, ir_ret_type, argtypes): 111 | func_type = 
ir.FunctionType(ir_ret_type, argtypes, False) 112 | function = ir.Function(module, func_type, name) 113 | entry_block = function.append_basic_block('entry') 114 | builder = ir.IRBuilder(entry_block) 115 | self.exit_block = function.append_basic_block('exit') 116 | self.function = function 117 | self.builder = builder 118 | 119 | def end_function(self): 120 | self.builder.position_at_end(self.exit_block) 121 | 122 | if 'retval' in self.locals: 123 | retval = self.builder.load(self.locals['retval']) 124 | self.builder.ret(retval) 125 | else: 126 | self.builder.ret_void() 127 | 128 | def add_block(self, name): 129 | return self.function.append_basic_block(name) 130 | 131 | def set_block(self, block): 132 | self.block = block 133 | self.builder.position_at_end(block) 134 | 135 | def cbranch(self, cond, true_block, false_block): 136 | self.builder.cbranch(cond, true_block, false_block) 137 | 138 | def branch(self, next_block): 139 | self.builder.branch(next_block) 140 | 141 | def specialize(self, value): 142 | if isinstance(value.type, VarType): 143 | return to_lltype(self.spec_types[value.type.s]) 144 | if isinstance(value.type, BaseType): 145 | return to_lltype(value.type) 146 | return to_lltype(value.type) 147 | 148 | def const(self, value): 149 | if value is None: 150 | return ir.Constant(ir_void_t, None) 151 | elif isinstance(value, ir.Constant): 152 | return value 153 | elif isinstance(value, bool): 154 | return ir.Constant(ir_bool_t, int(value)) 155 | elif isinstance(value, int): 156 | return ir.Constant(ir_int64_t, value) 157 | elif isinstance(value, float): 158 | return ir.Constant(ir_double_t, value) 159 | elif isinstance(value, str): 160 | # raise NotImplementedError 161 | return lc.Constant.stringz(value) 162 | else: 163 | print(value, type(value)) 164 | raise NotImplementedError 165 | 166 | def visit_Const(self, node: Const): 167 | return self.const(node.value) 168 | 169 | def visit_LitInt(self, node: LitInt): 170 | ty = self.specialize(node) 171 | if ty is 
ir_double_t: 172 | return ir.Constant(ir_double_t, node.n) 173 | elif ty == ir_int64_t: 174 | return ir.Constant(ir_int64_t, node.n) 175 | elif ty == ir_int32_t: 176 | return ir.Constant(ir_int32_t, node.n) 177 | 178 | def visit_LitFloat(self, node: LitFloat): 179 | ty = self.specialize(node) 180 | if ty is ir_double_t: 181 | return ir.Constant(ir_double_t, node.n) 182 | elif ty == ir_int64_t: 183 | return ir.Constant(ir_int64_t, node.n) 184 | elif ty == ir_int32_t: 185 | return ir.Constant(ir_int32_t, node.n) 186 | 187 | def visit_Noop(self, node: Noop): 188 | pass 189 | 190 | def visit_Fun(self, node: Fun): 191 | ir_ret_type = to_lltype(self.return_type) 192 | argtypes = list(map(to_lltype, self.args)) 193 | # Create a unique specialized name 194 | func_name = mangler(node.fname, self.args) 195 | self.org_func_name = node.fname 196 | self.start_function(func_name, self.module, ir_ret_type, argtypes) 197 | 198 | for (ar, llarg, argty) in list(zip(node.args, self.function.args, self.args)): 199 | name = ar.id 200 | llarg.name = name 201 | 202 | if is_array(argty): 203 | zero = self.const(0) 204 | one = self.const(1) 205 | two = self.const(2) 206 | 207 | data = self.builder.gep(llarg, [ 208 | zero, zero], name=(name + '_data')) 209 | dims = self.builder.gep(llarg, [ 210 | zero, one], name=(name + '_dims')) 211 | shape = self.builder.gep(llarg, [ 212 | zero, two], name=(name + '_strides')) 213 | 214 | self.arrays[name]['data'] = self.builder.load(data) 215 | self.arrays[name]['dims'] = self.builder.load(dims) 216 | self.arrays[name]['shape'] = self.builder.load(shape) 217 | self.locals[name] = llarg 218 | else: 219 | argref = self.builder.alloca(to_lltype(argty)) 220 | self.builder.store(llarg, argref) 221 | self.locals[name] = argref 222 | 223 | # Setup the register for return type. 
224 | if ir_ret_type is not ir_void_t: 225 | self.locals['retval'] = self.builder.alloca( 226 | ir_ret_type, name='retval') 227 | 228 | list(map(self.visit, node.body)) 229 | self.end_function() 230 | 231 | def visit_Index(self, node: Index): 232 | if isinstance(node.value, Var) and node.value.id in self.arrays: 233 | value = self.visit(node.value) 234 | ix = self.visit(node.ix) 235 | dataptr = self.arrays[node.value.id]['data'] 236 | ret = self.builder.gep(dataptr, [ix]) 237 | return self.builder.load(ret) 238 | else: 239 | value = self.visit(node.value) 240 | ix = self.visit(node.ix) 241 | ret = self.builder.gep(value, [ix]) 242 | return self.builder.load(ret) 243 | 244 | def visit_Var(self, node: Var): 245 | return self.builder.load(self.locals[node.id]) 246 | 247 | def visit_Return(self, node: Return): 248 | value = self.visit(node.value) 249 | if value.type != ir_void_t: 250 | self.builder.store(value, self.locals['retval']) 251 | self.builder.branch(self.exit_block) 252 | 253 | def visit_Loop(self, node: Loop): 254 | if not hasattr(self, '_for_counter'): 255 | self._for_counter = 0 256 | self._for_counter += 1 257 | init_block = self.add_block(f'for_init_{self._for_counter}') 258 | test_block = self.add_block(f'for_cond_{self._for_counter}') 259 | body_block = self.add_block(f'for_body_{self._for_counter}') 260 | end_block = self.add_block(f'for_after_{self._for_counter}') 261 | self.break_block_stack.append(end_block) 262 | 263 | self.branch(init_block) 264 | self.set_block(init_block) 265 | 266 | start = self.visit(node.begin) 267 | stop = self.visit(node.end) 268 | step = self.visit(node.step) 269 | 270 | # Setup the increment variable 271 | varname = node.var.id 272 | inc = self.builder.alloca(ir_int64_t, name=varname) 273 | self.builder.store(start, inc) 274 | self.locals[varname] = inc 275 | 276 | # Setup the loop condition 277 | self.branch(test_block) 278 | self.set_block(test_block) 279 | cond = self.builder.icmp_signed('<', self.builder.load(inc), 
stop) 280 | self.builder.cbranch(cond, body_block, end_block) 281 | 282 | # Generate the loop body 283 | self.set_block(body_block) 284 | list(map(self.visit, node.body)) 285 | 286 | if self.block.terminator is None: 287 | # Increment the counter 288 | succ = self.builder.add(self.const(step), self.builder.load(inc)) 289 | self.builder.store(succ, inc) 290 | 291 | # Exit the loop 292 | self.builder.branch(test_block) 293 | self.set_block(end_block) 294 | 295 | # Pop the break block 296 | self.break_block_stack.pop() 297 | 298 | def visit_Break(self, node: Break): 299 | if self.block.terminator is None: 300 | self.branch(self.break_block_stack[-1]) 301 | 302 | def visit_Prim(self, node: Prim): 303 | if node.fn == 'shape#': 304 | ref = node.args[0] 305 | shape = self.arrays[ref.id]['shape'] 306 | return shape 307 | elif node.fn not in LLVM_PRIM_OPS: 308 | raise NotImplementedError(ast.dump(node)) 309 | if node.fn == 'mult#': 310 | a = self.visit(node.args[0]) 311 | b = self.visit(node.args[1]) 312 | if a.type == ir_double_t: 313 | return self.builder.fmul(a, b) 314 | else: 315 | return self.builder.mul(a, b) 316 | elif node.fn == 'add#': 317 | a = self.visit(node.args[0]) 318 | b = self.visit(node.args[1]) 319 | if a.type == ir_double_t: 320 | return self.builder.fadd(a, b) 321 | else: 322 | return self.builder.add(a, b) 323 | elif node.fn == 'sub#': 324 | a = self.visit(node.args[0]) 325 | b = self.visit(node.args[1]) 326 | if a.type == ir_double_t: 327 | return self.builder.fsub(a, b) 328 | else: 329 | return self.builder.sub(a, b) 330 | elif node.fn == 'div#': 331 | a = self.visit(node.args[0]) 332 | b = self.visit(node.args[1]) 333 | if a.type == ir_double_t: 334 | return self.builder.fdiv(a, b) 335 | else: 336 | return self.builder.sdiv(a, b) 337 | elif node.fn == 'mod#': 338 | a = self.visit(node.args[0]) 339 | b = self.visit(node.args[1]) 340 | if a.type == ir_double_t: 341 | return self.builder.frem(a, b) 342 | else: 343 | return self.builder.srem(a, b) 344 | 
elif node.fn == 'lt#': 345 | a = self.visit(node.args[0]) 346 | b = self.visit(node.args[1]) 347 | if a.type == ir_double_t: 348 | return self.builder.fcmp_unordered('<', a, b) 349 | else: 350 | return self.builder.icmp_signed('<', a, b) 351 | elif node.fn == 'gt#': 352 | a = self.visit(node.args[0]) 353 | b = self.visit(node.args[1]) 354 | if a.type == ir_double_t: 355 | return self.builder.fcmp_unordered('>', a, b) 356 | else: 357 | return self.builder.icmp_signed('>', a, b) 358 | elif node.fn == 'le#': 359 | a = self.visit(node.args[0]) 360 | b = self.visit(node.args[1]) 361 | if a.type == ir_double_t: 362 | return self.builder.fcmp_unordered('<=', a, b) 363 | else: 364 | return self.builder.icmp_signed('<=', a, b) 365 | elif node.fn == 'ge#': 366 | a = self.visit(node.args[0]) 367 | b = self.visit(node.args[1]) 368 | if a.type == ir_double_t: 369 | return self.builder.fcmp_unordered('>=', a, b) 370 | else: 371 | return self.builder.icmp_signed('>=', a, b) 372 | elif node.fn == 'eq#': 373 | a = self.visit(node.args[0]) 374 | b = self.visit(node.args[1]) 375 | if a.type == ir_double_t: 376 | return self.builder.fcmp_unordered('==', a, b) 377 | else: 378 | return self.builder.icmp_signed('==', a, b) 379 | elif node.fn == 'ne#': 380 | a = self.visit(node.args[0]) 381 | b = self.visit(node.args[1]) 382 | if a.type == ir_double_t: 383 | return self.builder.fcmp_unordered('!=', a, b) 384 | else: 385 | return self.builder.icmp_signed('!=', a, b) 386 | elif node.fn == 'and#': 387 | a = self.visit(node.args[0]) 388 | b = self.visit(node.args[1]) 389 | return self.builder.and_(a, b) 390 | elif node.fn == 'or#': 391 | a = self.visit(node.args[0]) 392 | b = self.visit(node.args[1]) 393 | return self.builder.or_(a, b) 394 | elif node.fn == 'not#': 395 | a = self.visit(node.args[0]) 396 | return self.builder.not_(a) 397 | elif node.fn == 'neg#': 398 | a = self.visit(node.args[0]) 399 | if a.type == ir_double_t: 400 | return self.builder.fsub(self.const(0), a) 401 | else: 402 
| return self.builder.sub(self.const(0), a) 403 | elif node.fn == 'pow#': 404 | # a = self.visit(node.args[0]) 405 | # b = self.visit(node.args[1]) 406 | # return self.builder.call(pow_func, [a, b]) 407 | raise NotImplementedError('pow#', ast.dump(node)) 408 | 409 | def visit_Assign(self, node: Assign): 410 | # Subsequent assignment 411 | if node.ref in self.locals: 412 | name = node.ref 413 | ptr = self.locals[name] 414 | value = self.visit(node.value) 415 | self.builder.store(value, ptr) 416 | self.locals[name] = ptr 417 | return ptr 418 | 419 | # First assignment 420 | else: 421 | name = node.ref 422 | value = self.visit(node.value) 423 | ty = self.specialize(node) 424 | ptr = self.builder.alloca(ty, name=name) 425 | self.builder.store(value, ptr) 426 | self.locals[name] = ptr 427 | return ptr 428 | 429 | def visit_NoneType(self, node: None): 430 | return None 431 | 432 | def visit_If(self, node: If): 433 | if not hasattr(self, '_if_counter'): 434 | self._if_counter = 0 435 | self._if_counter += 1 436 | test_block = self.add_block(f'if_cond_{self._if_counter}') 437 | then_block = self.add_block(f'if_then_{self._if_counter}') 438 | if has_else := len(node.orelse) > 0: 439 | else_block = self.add_block(f'if_orelse_{self._if_counter}') 440 | end_block = self.add_block(f'if_after_{self._if_counter}') 441 | 442 | self.branch(test_block) 443 | self.set_block(test_block) 444 | test = self.visit(node.test) 445 | self.builder.cbranch( 446 | test, then_block, else_block if has_else else end_block) 447 | 448 | self.set_block(then_block) 449 | list(map(self.visit, node.body)) 450 | if self.block.terminator is None: 451 | self.branch(end_block) 452 | 453 | if has_else: 454 | self.set_block(else_block) 455 | list(map(self.visit, node.orelse)) 456 | if self.block.terminator is None: 457 | self.branch(end_block) 458 | 459 | self.set_block(end_block) 460 | 461 | def visit_While(self, node: While): 462 | if not hasattr(self, '_while_counter'): 463 | self._while_counter = 0 464 | 
self._while_counter += 1 465 | test_block = self.add_block(f'while_cond_{self._while_counter}') 466 | then_block = self.add_block(f'while_then_{self._while_counter}') 467 | if has_else := len(node.orelse) > 0: 468 | else_block = self.add_block(f'while_orelse_{self._while_counter}') 469 | end_block = self.add_block(f'while_after_{self._while_counter}') 470 | 471 | self.branch(test_block) 472 | self.set_block(test_block) 473 | test = self.visit(node.test) 474 | self.builder.cbranch( 475 | test, then_block, else_block if has_else else end_block) 476 | 477 | self.set_block(then_block) 478 | list(map(self.visit, node.body)) 479 | if self.block.terminator is None: 480 | self.branch(test_block) 481 | 482 | if has_else: 483 | self.set_block(else_block) 484 | list(map(self.visit, node.orelse)) 485 | if self.block.terminator is None: 486 | self.branch(test_block) 487 | 488 | self.set_block(end_block) 489 | 490 | def visit_Compare(self, node: Compare): 491 | # Setup the increment variable 492 | lf = self.visit(node.left) 493 | rt = self.visit(node.comparators[0]) 494 | op = { 495 | 'eq#': '==', 496 | 'ne#': '!=', 497 | 'lt#': '<', 498 | 'gt#': '>', 499 | 'le#': '<=', 500 | 'ge#': '>=', 501 | }.get(node.ops[0], None) 502 | if op is None: 503 | raise NotImplementedError(node.ops[0]) 504 | cond = self.builder.icmp_signed(op, lf, rt) 505 | return cond 506 | 507 | def visit_CallFunc(self, node: CallFunc): 508 | func_name = node.fn.id 509 | args = [self.visit(arg) for arg in node.args] 510 | func = None 511 | if func_name == self.org_func_name: 512 | # Implement recursion 513 | func = self.function 514 | return self.builder.call(func, args) 515 | 516 | # Call registered functions 517 | fn = get_reg_func(func_name) 518 | if fn is None: 519 | raise NotImplementedError( 520 | f'CallFunc: {func_name} function is not registered!') 521 | 522 | fname = fn.__name__ 523 | types = fn.__annotations__ 524 | arg_names = fn.__code__.co_varnames 525 | ir_return_type = 
to_lltype(arg_classtype(types['return'])) 526 | ir_args = [to_lltype(arg_classtype(types[arg_name])) 527 | for arg_name in arg_names] 528 | wrap_caller_func_t = ir.FunctionType(ir_return_type, ir_args) 529 | wrap_caller_func_t_ptr = wrap_caller_func_t.as_pointer() 530 | 531 | # Get source function addr 532 | c_args = list(map(arg_ctype, ir_args)) 533 | FUNC_T = ctypes.CFUNCTYPE(arg_ctype(ir_return_type), *c_args) 534 | pyfunc_ptr = ctypes.cast(FUNC_T(fn), c_void_p).value 535 | 536 | # Make llvm ir wrapper function 537 | func_ptr = self.builder.inttoptr( 538 | ir.Constant(ir_int64_t, pyfunc_ptr), 539 | wrap_caller_func_t_ptr, name=f'{mangler(fname, args)}_ptr' 540 | ) 541 | return self.builder.call(func_ptr, args) 542 | 543 | def visit_Expr(self, node: Expr): 544 | return self.visit(node.value) 545 | 546 | def visit(self, node): 547 | name = f'visit_{type(node).__name__}' 548 | if hasattr(self, name): 549 | return getattr(self, name)(node) 550 | else: 551 | return self.generic_visit(node) 552 | 553 | def generic_visit(self, node): 554 | raise NotImplementedError(ast.dump(node)) 555 | -------------------------------------------------------------------------------- /pyjiting/infer.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import string 3 | 4 | from llvmlite.ir.types import VoidType 5 | 6 | from .ast import LLVM_PRIM_OPS, Fun 7 | from .types import * 8 | 9 | ''' 10 | Type Inference 11 | ''' 12 | 13 | 14 | def naming(): 15 | k = 0 16 | while True: 17 | for a in string.ascii_lowercase: 18 | yield f'\{a}{str(k)}' if (k > 0) else a 19 | k = k+1 20 | 21 | 22 | class TypeInferencer: 23 | 24 | def __init__(self, env=None): 25 | self.constraints = [] 26 | self.env = {} if env is None else env 27 | self.names = naming() 28 | self.org_func_name = None 29 | 30 | def fresh(self): 31 | return VarType('$' + next(self.names)) # New meta type variable. 
32 | 33 | def visit(self, node): 34 | name = f'visit_{type(node).__name__}' 35 | if hasattr(self, name): 36 | return getattr(self, name)(node) 37 | else: 38 | return self.generic_visit(node) 39 | 40 | def visit_Fun(self, node): 41 | self.org_func_name = node.fname 42 | self.args = [self.fresh() for v in node.args] 43 | self.return_type = VarType('$return_type') 44 | for (arg, ty) in list(zip(node.args, self.args)): 45 | arg.type = ty 46 | self.env[arg.id] = ty 47 | list(map(self.visit, node.body)) 48 | return FuncType(args=self.args, return_type=self.return_type) 49 | 50 | def visit_NoneType(self, node): 51 | return VoidType() 52 | 53 | def visit_Noop(self, node): 54 | pass 55 | 56 | def visit_break(self, node): 57 | return None 58 | 59 | def visit_If(self, node): 60 | self.visit(node.test) 61 | list(map(self.visit, node.body)) 62 | list(map(self.visit, node.orelse)) 63 | 64 | def visit_While(self, node): 65 | self.visit(node.test) 66 | list(map(self.visit, node.body)) 67 | list(map(self.visit, node.orelse)) 68 | 69 | def visit_Compare(self, node): 70 | ty = self.visit(node.left) 71 | return ty 72 | 73 | def visit_LitInt(self, node): 74 | tv = self.fresh() 75 | node.type = tv 76 | return tv 77 | 78 | def visit_LitFloat(self, node): 79 | tv = self.fresh() 80 | node.type = tv 81 | return tv 82 | 83 | def visit_Assign(self, node): 84 | ty = self.visit(node.value) 85 | if node.ref in self.env: 86 | # Subsequent uses of a variable must have the same type. 
87 | self.constraints += [(ty, self.env[node.ref])] 88 | self.env[node.ref] = ty 89 | node.type = ty 90 | return None 91 | 92 | def visit_Index(self, node): 93 | tv = self.fresh() 94 | ty = self.visit(node.value) 95 | ixty = self.visit(node.ix) 96 | self.constraints += [(ty, make_array_type(tv)), (ixty, int64_t)] 97 | return tv 98 | 99 | def visit_Prim(self, node): 100 | if node.fn == 'shape#': 101 | return make_array_type(int64_t) 102 | elif node.fn in LLVM_PRIM_OPS: 103 | tya = self.visit(node.args[0]) 104 | tyb = self.visit(node.args[1]) 105 | self.constraints += [(tya, tyb)] 106 | return tyb 107 | else: 108 | raise NotImplementedError(ast.dump(node)) 109 | 110 | def visit_Var(self, node): 111 | ty = self.env[node.id] 112 | node.type = ty 113 | return ty 114 | 115 | def visit_Return(self, node): 116 | ty = self.visit(node.value) 117 | self.constraints += [(ty, self.return_type)] 118 | 119 | def visit_Const(self, node): 120 | if isinstance(node.value, int): 121 | ty = int64_t 122 | elif isinstance(node.value, float): 123 | ty = double64_t 124 | elif isinstance(node.value, bool): 125 | ty = bool_t 126 | elif node.value is None: 127 | ty = void_t 128 | else: 129 | raise NotImplementedError(node.value) 130 | return ty 131 | 132 | def visit_Loop(self, node): 133 | self.env[node.var.id] = int64_t 134 | varty = self.visit(node.var) 135 | begin = self.visit(node.begin) 136 | end = self.visit(node.end) 137 | self.constraints += [(varty, int64_t), ( 138 | begin, int64_t), (end, int64_t)] 139 | list(map(self.visit, node.body)) 140 | 141 | def visit_Break(self, node): 142 | return None 143 | 144 | def visit_CallFunc(self, node): 145 | list(map(self.visit, node.args)) 146 | # Implement recursion 147 | if node.fn.id == self.org_func_name: 148 | return self.return_type 149 | # TODO: type inference for function calls return value. 
150 | return int64_t 151 | 152 | def visit_Expr(self, node): 153 | self.visit(node.value) 154 | 155 | def generic_visit(self, node): 156 | raise NotImplementedError(ast.dump(node)) 157 | 158 | 159 | class UnderDetermined(Exception): 160 | def __str__(self): 161 | return 'The types in the function are not fully determined by the input types. Add annotations.' 162 | 163 | 164 | class InferError(Exception): 165 | def __init__(self, ty1, ty2): 166 | self.ty1 = ty1 167 | self.ty2 = ty2 168 | 169 | def __str__(self): 170 | return '\n'.join([ 171 | 'Type mismatch: ', 172 | 'Given: ', '\t' + str(self.ty1), 173 | 'Expected: ', '\t' + str(self.ty2) 174 | ]) 175 | 176 | 177 | class InfiniteType(Exception): 178 | def __init__(self, ty1, ty2): 179 | self.ty1 = ty1 180 | self.ty2 = ty2 181 | 182 | def __str__(self): 183 | return '\n'.join([ 184 | 'Type mismatch: ', 185 | 'Given: ', '\t' + str(self.ty1), 186 | 'Expected: ', '\t' + str(self.ty2) 187 | ]) 188 | -------------------------------------------------------------------------------- /pyjiting/ll_types.py: -------------------------------------------------------------------------------- 1 | 2 | import ctypes 3 | 4 | import numpy as np 5 | from llvmlite import ir 6 | 7 | ''' 8 | Define a mapping from NumPy dtype to C types. 9 | ''' 10 | 11 | 12 | # Adapt the LLVM types to use libffi/ctypes wrapper so we can dynamically create 13 | # the appropriate C types for our JIT'd function at runtime. 
# NumPy ``dtype.char`` -> ctypes scalar type used for array data pointers.
# 'q' (C long long) is how int64 arrays report themselves on platforms where
# C long is 32-bit.
_nptypemap = {
    'i': ctypes.c_int,
    'l': ctypes.c_long,
    'q': ctypes.c_longlong,
    'f': ctypes.c_float,
    'd': ctypes.c_double,
}


def mangler(fname: str, sig) -> str:
    """Return a per-signature symbol name: *fname* followed by ``hash(tuple(sig))``."""
    return fname + str(hash(tuple(sig)))


def wrap_module(sig, llfunc, engine):
    """Wrap the JIT-compiled *llfunc* as a Python callable (``sig`` is unused)."""
    pfunc = wrap_function(llfunc, engine)
    return dispatcher(pfunc)


def wrap_function(func, engine):
    """Build a ctypes function object for the compiled LLVM function *func*.

    The ctypes prototype is derived from the LLVM signature via wrap_type, and
    the address is looked up by name in the execution *engine*.
    """
    args = func.type.pointee.args
    ret_type = func.type.pointee.return_type
    ret_ctype = wrap_type(ret_type)
    args_ctypes = [wrap_type(arg) for arg in args]

    functype = ctypes.CFUNCTYPE(ret_ctype, *args_ctypes)
    fptr = engine.get_function_address(func.name)

    cfunc = functype(fptr)
    cfunc.__name__ = func.name
    return cfunc


def wrap_type(llvm_type):
    """Translate an llvmlite IR type into the matching ctypes type.

    Raises:
        RuntimeError: for LLVM types with no ctypes mapping.
    """
    if isinstance(llvm_type, ir.IntType):
        if llvm_type.width == 1:
            # i1 (comparison results) has no ``ctypes.c_int1``.
            ctype = ctypes.c_bool
        else:
            ctype = getattr(ctypes, 'c_int' + str(llvm_type.width))
    elif isinstance(llvm_type, ir.DoubleType):
        ctype = ctypes.c_double
    elif isinstance(llvm_type, ir.FloatType):
        ctype = ctypes.c_float
    elif isinstance(llvm_type, ir.VoidType):
        ctype = None
    elif isinstance(llvm_type, ir.PointerType):
        pointee = llvm_type.pointee
        if isinstance(pointee, ir.IntType) and pointee.width == 8:
            ctype = ctypes.c_char_p
        elif isinstance(pointee, ir.VoidType):
            ctype = ctypes.c_void_p
        else:
            ctype = ctypes.POINTER(wrap_type(pointee))
    elif isinstance(llvm_type, ir.IdentifiedStructType):
        # Build a ctypes.Structure with positional field names field0..fieldN.
        struct_name = llvm_type.name.split('.')[-1]
        names = ['field' + str(n) for n in range(len(llvm_type.elements))]
        ctype = type(ctypes.Structure)(struct_name, (ctypes.Structure,),
                                       {'__module__': 'numpile'})
        fields = [(name, wrap_type(elem))
                  for name, elem in zip(names, llvm_type.elements)]
        setattr(ctype, '_fields_', fields)
    else:
        raise RuntimeError(f'Unknown LLVM type {llvm_type}')
    return ctype


def wrap_ndarray(na):
    """Return ``(data_ptr, ndim, shape_array)`` for NumPy array *na* without copying."""
    ctype = _nptypemap[na.dtype.char]
    _shape = list(na.shape)
    data = na.ctypes.data_as(ctypes.POINTER(ctype))
    dims = len(na.strides)
    shape = (ctypes.c_int * dims)(*_shape)
    return (data, dims, shape)


def wrap_arg(arg, value):
    """Convert *value* for ctypes argtype *arg*; ndarrays become the pointee struct."""
    if isinstance(value, np.ndarray):
        ndarray = arg._type_
        data, dims, shape = wrap_ndarray(value)
        return ndarray(data, dims, shape)
    return value


def dispatcher(fn):
    """Return a closure that converts arguments with wrap_arg and calls *fn*."""
    def _call_closure(*args):
        rargs = [wrap_arg(ctype, value)
                 for ctype, value in zip(fn._argtypes_, args)]
        return fn(*rargs)
    _call_closure.__name__ = fn.__name__
    return _call_closure
# ---- /pyjiting/main.py ----
import ctypes
import inspect
import sys
from ast import dump as ast_dump
from ast import parse as ast_parse
from ctypes import c_int64, c_void_p
from textwrap import dedent

import llvmlite.binding as llvm
import numpy as np
from llvmlite import ir

from .codegen import LLVMCodeGen, determined, ir_int64_t, reg_func, to_lltype
from .infer import TypeInferencer, UnderDetermined
from .ll_types import mangler, wrap_module
from .parser import ASTVisitor
from .types import *
from .utils import apply, compose, solve, unify

# Output debug info
DEBUG = False


def debug(fmt, *args):
    """Print a separator and the given values when DEBUG is enabled."""
    if not DEBUG:
        return
    print('=' * 80)
    print(fmt, *args)


llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()

module = ir.Module('pyjiting.module')
function_cache = {}

target_machine = llvm.Target.from_default_triple().create_target_machine()
backing_mod = llvm.parse_assembly('')
engine = llvm.create_mcjit_compiler(backing_mod, target_machine)


def reg(fn):
    """Decorator: register *fn* under its ``__name__`` so JIT code can call it."""
    reg_func(fn.__name__, fn)
    return fn


def jit(fn):
    """Decorator: parse and type-infer *fn*, returning a lazily specializing wrapper."""
    # Guarded so the (costly) dumps are not computed when debugging is off.
    if DEBUG:
        debug(ast_dump(ast_parse(dedent(inspect.getsource(fn))), indent=4))
    transformer = ASTVisitor()
    core_ast = transformer(fn)
    (ty, mgu) = typeinfer(core_ast)
    if DEBUG:
        debug(ast_dump(core_ast, indent=4))
    return specialize(core_ast, ty, mgu)


def arg_pytype(arg):
    """Map a runtime argument to its pyjiting core type.

    Supports NumPy arrays of int64 / float64 / float32, Python ints that fit
    in a signed machine word, and Python floats.

    Raises:
        RuntimeError: for unsupported argument types or array dtypes.
    """
    if isinstance(arg, np.ndarray):
        if arg.dtype == np.dtype('int64'):
            return make_array_type(int64_t)
        elif arg.dtype == np.dtype('double'):
            return make_array_type(double64_t)
        elif arg.dtype == np.dtype('float32'):
            # np.dtype('float') is float64, so the old check never matched.
            return make_array_type(float32_t)
        # Previously fell through and returned None.
        raise RuntimeError('Unsupported array dtype:', arg.dtype)
    elif isinstance(arg, int) and -sys.maxsize - 1 <= arg <= sys.maxsize:
        return int64_t
    elif isinstance(arg, float):
        return double64_t
    else:
        raise RuntimeError('Unsupported type:', type(arg))


def specialize(ast, infer_ty, mgu):
    """Return a wrapper that compiles and caches one specialization per argument types.

    On each call the inferred signature *infer_ty* is unified with the runtime
    argument types. Then the function is compiled (or fetched from
    ``function_cache``) and invoked.

    Raises (from the wrapper):
        UnderDetermined: if the return or argument types remain unresolved.
    """
    def _wrapper(*func_args):
        arg_types = [arg_pytype(value) for value in func_args]
        spec_ty = FuncType(args=arg_types, return_type=VarType('$return_type'))
        unifier = unify(infer_ty, spec_ty)
        specializer = compose(unifier, mgu)
        debug('specializer:', specializer)

        return_type = apply(specializer, VarType('$return_type'))
        args = [apply(specializer, ty) for ty in arg_types]
        debug('Specialized Function:', FuncType(
            args=args, return_type=return_type))

        if not (determined(return_type) and all(map(determined, args))):
            raise UnderDetermined()

        key = mangler(ast.fname, args)
        # Don't recompile after we've specialized.
        if key not in function_cache:
            llfunc = codegen(module, ast, specializer, return_type, args)
            function_cache[key] = wrap_module(args, llfunc, engine)
        return function_cache[key](*func_args)
    return _wrapper


def typeinfer(ast):
    """Run type inference on *ast*; return ``(inferred_type, most_general_unifier)``."""
    infer = TypeInferencer()
    ty = infer.visit(ast)
    mgu = solve(infer.constraints)
    infer_ty = apply(mgu, ty)
    debug('infer_ty', infer_ty)
    debug('mgu', mgu)
    debug('infer.constraints', infer.constraints)
    return (infer_ty, mgu)


def codegen(module, ast, specializer, return_type, args):
    """Emit LLVM IR for *ast* into *module*, optimize it (O3), and add it to the engine.

    Returns:
        The generated llvmlite function.
    """
    cgen = LLVMCodeGen(module, specializer, return_type, args)
    cgen.visit(ast)

    # NOTE(review): ``module`` is the shared global module, so this re-parses
    # and re-adds every previously generated function as well; confirm MCJIT
    # tolerates the duplicate definitions.
    mod = llvm.parse_assembly(str(module))
    mod.verify()

    pmb = llvm.PassManagerBuilder()
    pmb.opt_level = 3
    pmb.loop_vectorize = True

    pm = llvm.ModulePassManager()
    pmb.populate(pm)

    pm.run(mod)

    engine.add_module(mod)

    debug(cgen.function)
    debug(target_machine.emit_assembly(mod))
    return cgen.function
# ---- /pyjiting/parser.py ----

import ast
import inspect
import types
from textwrap import dedent

from .ast import (PRIM_OPS, CallFunc, Assign, Break, Compare, Const, Expr, Fun, If, Index,
                  LitBool, LitFloat, LitInt, Loop, Noop, Prim, Return, Var, While)
from .types import *

'''
Parse a Python function into pyjiting CoreAST.
'''

# Annotation name -> core type. ``bool`` is deliberately treated as Int64.
_TYPE_HINTS = {
    'int64': int64_t,
    'float': double64_t,
    'bool': int64_t,
}


def get_type_hint(var):
    """Return the core type named by ``var.annotation.id``, or None if absent/unknown."""
    annotation = getattr(var, 'annotation', None)
    return _TYPE_HINTS.get(getattr(annotation, 'id', None))


class ASTVisitor(ast.NodeVisitor):
    """Translate a Python ``ast`` tree into pyjiting CoreAST nodes."""

    def __init__(self):
        pass

    def __call__(self, source):
        """Parse *source* (module, function/lambda or source text) into CoreAST.

        Returns the translation of the first top-level statement.

        Raises:
            NotImplementedError: for unsupported source objects.
        """
        # types.LambdaType is types.FunctionType, so one check covers both.
        if isinstance(source, (types.ModuleType, types.FunctionType)):
            source = inspect.getsource(source)
        if not isinstance(source, str):
            raise NotImplementedError(ast.dump(source))
        source = dedent(source)

        self._source = source
        self._ast = ast.parse(source)
        return self.visit(self._ast)

    def visit_Module(self, node):
        """Translate every statement but return only the first one."""
        body = [self.visit(stmt) for stmt in node.body]
        return body[0]

    def visit_Name(self, node):
        return Var(node.id)

    def visit_Num(self, node):
        """Legacy ``ast.Num`` support (pre-3.8 parsers)."""
        if isinstance(node.n, float):
            return LitFloat(node.n)
        return LitInt(node.n)

    def visit_Bool(self, node):
        return LitBool(node.n)

    def visit_Call(self, node):
        name = self.visit(node.func)
        args = [self.visit(arg) for arg in node.args]
        return CallFunc(name, args)

    def visit_BinOp(self, node):
        """``a <op> b`` -> ``Prim(PRIM_OPS[op class], [a, b])``."""
        a = self.visit(node.left)
        b = self.visit(node.right)
        opname = PRIM_OPS[node.op.__class__]
        return Prim(opname, [a, b])

    def visit_Assign(self, node):
        """Single-target assignment ``name = value``."""
        assert len(node.targets) == 1
        target = node.targets[0]
        value = self.visit(node.value)
        # Plain assignments carry no annotation, so the hint is always None here.
        return Assign(target.id, value, get_type_hint(target))

    def visit_FunctionDef(self, node):
        """Translate a function: body statements and (optionally annotated) args."""
        stmts = [self.visit(stmt) for stmt in node.body]
        args = [Var(a.arg, get_type_hint(a)) for a in node.args.args]
        return Fun(node.name, args, stmts)

    def visit_Pass(self, node):
        return Noop()

    def visit_Break(self, node):
        return Break()

    def visit_Return(self, node):
        value = self.visit(node.value)
        return Return(value)

    def visit_Constant(self, node):
        """Wrap a literal in ``Const``.

        An earlier duplicate definition that produced LitBool/LitInt/LitFloat
        was shadowed by this one and never ran; it has been removed.
        """
        return Const(node.value)

    def visit_Attribute(self, node):
        """Only ``x.shape`` is supported (as the ``shape#`` primitive)."""
        if node.attr == 'shape':
            value = self.visit(node.value)
            return Prim('shape#', [value])
        raise NotImplementedError(ast.dump(node))

    def visit_Subscript(self, node):
        """Load-context ``value[ix]`` -> Index; stores are unsupported."""
        if isinstance(node.ctx, ast.Load):
            if node.slice:
                value = self.visit(node.value)
                ix = self.visit(node.slice)
                return Index(value, ix)
        elif isinstance(node.ctx, ast.Store):
            raise NotImplementedError(ast.dump(node))

    def visit_int(self, node):
        return LitInt(node, type=int64_t)

    def visit_For(self, node):
        """``for x in range(...)`` -> Loop(target, start, stop, body, step).

        Raises:
            RuntimeError: if the iterable is not a 1-3 argument range/xrange call.
        """
        target = self.visit(node.target)
        stmts = [self.visit(stmt) for stmt in node.body]
        it = node.iter
        # Previously a non-call iterable crashed with AttributeError.
        if not (isinstance(it, ast.Call) and isinstance(it.func, ast.Name)
                and it.func.id in ('xrange', 'range')):
            raise RuntimeError('Loop must be over range')
        args = [self.visit(arg) for arg in it.args]

        step = Const(1)
        if len(args) == 1:  # range(stop)
            start, stop = Const(0), args[0]
        elif len(args) == 2:  # range(start, stop)
            start, stop = args
        elif len(args) == 3:  # range(start, stop, step)
            start, stop, step = args
        else:
            # Previously produced raw int bounds that later stages cannot type.
            raise RuntimeError('range expects 1 to 3 arguments')
        return Loop(target, start, stop, stmts, step)

    def visit_If(self, node):
        test = self.visit(node.test)
        body = [self.visit(stmt) for stmt in node.body]
        orelse = [self.visit(stmt) for stmt in node.orelse]
        return If(test, body, orelse)

    def visit_While(self, node):
        test = self.visit(node.test)
        body = [self.visit(stmt) for stmt in node.body]
        orelse = [self.visit(stmt) for stmt in node.orelse]
        return While(test, body, orelse)

    def visit_Compare(self, node):
        """Translate comparison operators through PRIM_OPS."""
        left = self.visit(node.left)
        ops = [PRIM_OPS[op.__class__] for op in node.ops]
        comparators = [self.visit(c) for c in node.comparators]
        return Compare(left, ops, comparators)

    def visit_AugAssign(self, node):
        """``name op= value`` -> ``Assign(name, Prim(op, [Var(name), value]))``.

        Raises:
            NotImplementedError: for operators without a primitive.
        """
        if isinstance(node.op, ast.Add):
            opname = 'add#'
        elif isinstance(node.op, ast.Mult):
            # NOTE(review): hard-coded; confirm it matches PRIM_OPS[ast.Mult].
            opname = 'mult#'
        elif node.op.__class__ in PRIM_OPS:
            # Same op -> primitive mapping that visit_BinOp uses.
            opname = PRIM_OPS[node.op.__class__]
        else:
            raise NotImplementedError(ast.dump(node))
        ref = node.target.id
        value = self.visit(node.value)
        return Assign(ref, Prim(opname, [Var(ref), value]))

    def visit_Expr(self, node):
        return Expr(self.visit(node.value))

    def generic_visit(self, node):
        raise NotImplementedError(ast.dump(node))
# ---- /pyjiting/types.py ----
from functools import reduce
from typing import Any, Union

from llvmlite.ir.types import FunctionType, PointerType, Type

'''
Basic types and type constructors
'''


class VarType(Type):
    """A type variable, identified (and hashed/compared) by its name ``s``."""

    def __init__(self, s):
        self.s = s

    def __hash__(self):
        return hash(self.s)

    def __eq__(self, other):
        return isinstance(other, VarType) and self.s == other.s

    def __str__(self):
        return self.s


class BaseType(Type):
    """A concrete, non-parameterized type such as Int64 or Double."""

    def __init__(self, s):
        self.s = s

    def __eq__(self, other):
        return isinstance(other, BaseType) and self.s == other.s

    def __hash__(self):
        return hash(self.s)

    def __str__(self):
        return self.s


class GenericType(Type):
    """A type constructor ``a`` applied to argument ``b`` (e.g. Array Int64)."""

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __eq__(self, other):
        return (isinstance(other, GenericType)
                and self.a == other.a and self.b == other.b)

    def __hash__(self):
        return hash((self.a, self.b))

    def __str__(self):
        return str(self.a) + ' ' + str(self.b)


class FuncType(FunctionType):
    """A function type rendered as ``(arg, ...) -> ret``."""

    def __str__(self):
        # str() of the args tuple used the verbose llvmlite reprs of each type.
        return '(' + ', '.join(map(str, self.args)) + ') -> ' + str(self.return_type)


CoreType = Union[GenericType, BaseType, FuncType, VarType]

int32_t = BaseType('Int32')
int64_t = BaseType('Int64')
bool_t = BaseType('Bool')
float32_t = BaseType('Float')
double64_t = BaseType('Double')
void_t = BaseType('Void')
array_t = BaseType('Array')

ptr_t = PointerType


def make_array_type(t):
    """Return the ``Array t`` generic type."""
    return GenericType(array_t, t)


int32_array_t = make_array_type(int32_t)
int64_array_t = make_array_type(int64_t)
double64_array_t = make_array_type(double64_t)


def ftv(x) -> set:
    """Return the set of free type variables (VarType) occurring in *x*.

    Types outside the core type family yield None.
    """
    if isinstance(x, BaseType):
        return set()
    elif isinstance(x, GenericType):
        return ftv(x.a) | ftv(x.b)
    elif isinstance(x, FuncType):
        # The old ``set(map(ftv, ...))`` built a set of sets (TypeError) and
        # reduce() failed on zero-argument functions.
        return set().union(*(ftv(arg) for arg in x.args), ftv(x.return_type))
    elif isinstance(x, VarType):
        return {x}
    return None


def is_array(ty: Union[GenericType, Any]) -> bool:
    """True when *ty* is an ``Array <elem>`` generic type."""
    return isinstance(ty, GenericType) and ty.a == array_t
# ---- /pyjiting/utils.py ----
from collections import deque

from .infer import InferError, InfiniteType
from .types import BaseType, CoreType, FuncType, GenericType, VarType, ftv

'''
Lang utils and constraint solver
'''


def empty() -> dict:
    """Return a new, empty substitution (type-variable name -> type)."""
    return {}


def apply(s: dict, t: CoreType) -> CoreType:
    """Apply substitution *s* (keyed by variable name) to type *t*.

    Base types are returned unchanged; generic and function types are rebuilt
    with *s* applied to their components; a type variable is replaced by its
    binding in *s* if present. Other (non-core) types yield None.
    """
    if isinstance(t, BaseType):
        return t
    elif isinstance(t, GenericType):
        return GenericType(apply(s, t.a), apply(s, t.b))
    elif isinstance(t, FuncType):
        args = [apply(s, a) for a in t.args]
        return_type = apply(s, t.return_type)
        return FuncType(args=args, return_type=return_type)
    elif isinstance(t, VarType):
        return s.get(t.s, t)


def apply_list(s: dict, xs: list) -> list:
    """Apply *s* to both sides of every ``(lhs, rhs)`` constraint in *xs*."""
    return [(apply(s, x), apply(s, y)) for (x, y) in xs]


def unify(x: CoreType, y: CoreType) -> dict:
    """Return the most general substitution making *x* and *y* equal.

    Raises:
        RuntimeError: for function types with different arity.
        InferError: when the types cannot be unified.
        InfiniteType: when a variable would be bound to a type containing it.
    """
    if isinstance(x, GenericType) and isinstance(y, GenericType):
        s1 = unify(x.a, y.a)
        s2 = unify(apply(s1, x.b), apply(s1, y.b))
        return compose(s2, s1)
    elif isinstance(x, BaseType) and isinstance(y, BaseType) and (x == y):
        return empty()
    elif isinstance(x, FuncType) and isinstance(y, FuncType):
        if len(x.args) != len(y.args):
            raise RuntimeError('Wrong number of arguments')
        s1 = solve(list(zip(x.args, y.args)))
        s2 = unify(apply(s1, x.return_type), apply(s1, y.return_type))
        return compose(s2, s1)
    elif isinstance(x, VarType):
        return bind(x.s, y)
    elif isinstance(y, VarType):
        return bind(y.s, x)
    else:
        raise InferError(x, y)


def solve(xs: list):
    """Solve the constraint list *xs*; return the composed most general unifier."""
    mgu = empty()
    cs = deque(xs)
    while cs:
        (a, b) = cs.pop()
        s = unify(a, b)
        mgu = compose(s, mgu)
        # Propagate the new bindings into the remaining constraints.
        cs = deque(apply_list(s, list(cs)))
    return mgu


def bind(n, x):
    """Bind the type variable named *n* to type *x*.

    Returns an empty substitution when *x* is that same variable.

    Raises:
        InfiniteType: if *n* occurs free inside *x*.
    """
    # *n* is a variable *name*; the old ``x == n`` compared a type object with
    # a string and was never true.
    if isinstance(x, VarType) and x.s == n:
        return empty()
    elif occurs_check(n, x):
        raise InfiniteType(n, x)
    return {n: x}


def occurs_check(n, x) -> bool:
    """True if the type variable named *n* occurs free in *x*."""
    free = ftv(x)
    # ftv() yields VarType objects (or None for non-core types); compare by
    # name, since testing the str *n* against VarType members never matched.
    return free is not None and any(v.s == n for v in free)


def union(s1: dict, s2: dict) -> dict:
    """Merge two substitutions; bindings in *s2* win on conflicts."""
    nenv = s1.copy()
    nenv.update(s2)
    return nenv


def compose(s1: dict, s2: dict) -> dict:
    """Return the substitution equivalent to applying *s2* and then *s1*."""
    s3 = {t: apply(s1, u) for t, u in s2.items()}
    return union(s1, s3)
# ---- /requirements.txt ----
# llvmlite
# numpy