├── .gitignore ├── ChangeLog.md ├── LICENSE ├── README.md ├── README.zh_CN.md ├── benchmarks ├── append3.py ├── brainfuck.py ├── const_fib.py ├── dna_read.py ├── fib.py ├── hypot.py ├── selection_sort.py ├── trans.py └── union_type.py ├── diojit ├── __init__.py ├── absint │ ├── __init__.py │ ├── abs.py │ ├── intrinsics.py │ └── prescr.py ├── codegen │ ├── __init__.py │ └── julia.py ├── runtime │ ├── __init__.py │ └── julia_rt.py ├── stack2reg │ ├── __init__.py │ ├── cflags.py │ ├── opcodes.py │ └── translate.py └── user │ ├── __init__.py │ └── client.py ├── genopname.py ├── prepub.sh ├── requirements.txt ├── runtests ├── doil.py ├── load.py └── tutorial.py ├── setup.py └── test.sh /.gitignore: -------------------------------------------------------------------------------- 1 | _test.py 2 | _test.ll 3 | .vscode/ 4 | bin/ 5 | obj/ 6 | .ionide/ 7 | node_modules/ 8 | .fake 9 | .idea/ 10 | .ci/ 11 | typeshed/ 12 | 13 | **__pycache__/ 14 | dev_test/ 15 | Digraph.gv* 16 | .spyproject/ 17 | **.egg-info/ 18 | dist/ 19 | **~ 20 | build/ 21 | 22 | 23 | # Created by https://www.gitignore.io/api/macos,python 24 | # Edit at https://www.gitignore.io/?templates=macos,python 25 | 26 | ### macOS ### 27 | # General 28 | .DS_Store 29 | .AppleDouble 30 | .LSOverride 31 | 32 | # Icon must end with two \r 33 | Icon 34 | 35 | # Thumbnails 36 | ._* 37 | 38 | # Files that might appear in the root of a volume 39 | .DocumentRevisions-V100 40 | .fseventsd 41 | .Spotlight-V100 42 | .TemporaryItems 43 | .Trashes 44 | .VolumeIcon.icns 45 | .com.apple.timemachine.donotpresent 46 | 47 | # Directories potentially created on remote AFP share 48 | .AppleDB 49 | .AppleDesktop 50 | Network Trash Folder 51 | Temporary Items 52 | .apdisk 53 | 54 | ### Python ### 55 | # Byte-compiled / optimized / DLL files 56 | __pycache__/ 57 | *.py[cod] 58 | *$py.class 59 | 60 | # C extensions 61 | *.so 62 | 63 | # Distribution / packaging 64 | .Python 65 | build/ 66 | develop-eggs/ 67 | dist/ 68 | downloads/ 69 | eggs/ 
70 | .eggs/ 71 | lib/ 72 | lib64/ 73 | parts/ 74 | sdist/ 75 | var/ 76 | wheels/ 77 | pip-wheel-metadata/ 78 | share/python-wheels/ 79 | *.egg-info/ 80 | .installed.cfg 81 | *.egg 82 | MANIFEST 83 | 84 | # PyInstaller 85 | # Usually these files are written by a python script from a template 86 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 87 | *.manifest 88 | *.spec 89 | 90 | # Installer logs 91 | pip-log.txt 92 | pip-delete-this-directory.txt 93 | 94 | # Unit test / coverage reports 95 | htmlcov/ 96 | .tox/ 97 | .nox/ 98 | .coverage 99 | .coverage.* 100 | .cache 101 | nosetests.xml 102 | coverage.xml 103 | *.cover 104 | .hypothesis/ 105 | .pytest_cache/ 106 | 107 | # Translations 108 | *.mo 109 | *.pot 110 | 111 | # Django stuff: 112 | *.log 113 | local_settings.py 114 | db.sqlite3 115 | db.sqlite3-journal 116 | 117 | # Flask stuff: 118 | instance/ 119 | .webassets-cache 120 | 121 | # Scrapy stuff: 122 | .scrapy 123 | 124 | # Sphinx documentation 125 | docs/_build/ 126 | 127 | # PyBuilder 128 | target/ 129 | 130 | # Jupyter Notebook 131 | .ipynb_checkpoints 132 | 133 | # IPython 134 | profile_default/ 135 | ipython_config.py 136 | 137 | # pyenv 138 | .python-version 139 | 140 | # pipenv 141 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 142 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 143 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 144 | # install all needed dependencies. 
145 | #Pipfile.lock 146 | 147 | # celery beat schedule file 148 | celerybeat-schedule 149 | 150 | # SageMath parsed files 151 | *.sage.py 152 | 153 | # Environments 154 | .env 155 | .venv 156 | env/ 157 | venv/ 158 | ENV/ 159 | env.bak/ 160 | venv.bak/ 161 | 162 | # Spyder project settings 163 | .spyderproject 164 | .spyproject 165 | 166 | # Rope project settings 167 | .ropeproject 168 | 169 | # mkdocs documentation 170 | /site 171 | 172 | # mypy 173 | .mypy_cache/ 174 | .dmypy.json 175 | dmypy.json 176 | 177 | # Pyre type checker 178 | .pyre/ 179 | 180 | # End of https://www.gitignore.io/api/macos,python 181 | 182 | slide-examples/ -------------------------------------------------------------------------------- /ChangeLog.md: -------------------------------------------------------------------------------- 1 | ## 0.2a(2/8/2021) 2 | 3 | - Rename `jit_spec_call` to `spec_call`. 4 | 5 | ## 0.1.5(2/8/2021) 6 | 7 | - Add experimental features: `@eagerjit` and `@conservativejit`. 8 | 9 | The first one assumes field types according to annotations of fields, and tries to make all 10 | methods jit-able. 11 | 12 | The second one needs manually specifying jit-able methods, and does not totally believe users' 13 | annotations to fields: a runtime type check will be generated when accessing fields. 14 | 15 | ## 0.1.4.1(2/7/2021) 16 | 17 | - RC analysis. 18 | 19 | Previously, `def f(x): x = g(x)` could cause a segmentation fault due to unexpected deallocation. 20 | 21 | We have now added analysis for reference counting on the Python side, greatly reducing redundant RC 22 | operations at runtime. 23 | 24 | 25 | ## 0.1.2(2/5/2021) 26 | 27 | - The JIT compiler is now able to optimise selection sort using lists(40% speed up). 
28 | 29 | **Experiments have shown that if we can have type-parameterised lists, we can have 30 | a performance gain in a factor of 600%.** 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Copyright (c) 2021, thautwarm 3 | 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without modification, 7 | are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, 10 | this list of conditions and the following disclaimer. 11 | * Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER 19 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | 27 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## DIO-JIT: General-purpose Python JIT 2 | 3 | [![中文README](https://img.shields.io/badge/i18n-%E4%B8%AD%E6%96%87%E6%96%87%E6%A1%A3-teal)](https://github.com/thautwarm/diojit/blob/master/README.zh_CN.md) [![PyPI version shields.io](https://img.shields.io/pypi/v/diojit.svg)](https://pypi.python.org/pypi/diojit/) 4 | [![JIT](https://img.shields.io/badge/cpython-3.8|3.9-green.svg)](https://pypi.python.org/pypi/diojit/) 5 | 6 | Important: 7 | 8 | 1. DIO-JIT now works for Python >= 3.8. We heavily rely on the `LOAD_METHOD` bytecode instruction. 9 | 2. DIO-JIT is not production-ready. a large number of specialisation rules are required to make DIO-JIT batteries-included. 10 | 3. This document is mainly provided for prospective developers. Users are not required to write any specialisation rules, which means that users need to learn nothing but `@jit.jit` and `jit.spec_call`. 
11 | 12 | ### Benchmark 13 | 14 | | Item | PY38 | JIT PY38 | PY39 | JIT PY39 | 15 | | ------------------------------------------------------------------------------------------ | ------ | -------- | ------ | -------- | 16 | | [BF](https://github.com/thautwarm/diojit/blob/master/benchmarks/brainfuck.py) | 265.74 | 134.23 | 244.50 | 140.34 | 17 | | [append3](https://github.com/thautwarm/diojit/blob/master/benchmarks/append3.py) | 23.94 | 10.70 | 22.29 | 11.21 | 18 | | [DNA READ](https://github.com/thautwarm/diojit/blob/master/benchmarks/dna_read.py) | 16.96 | 14.82 | 15.03 | 14.38 | 19 | | [fib(15)](https://github.com/thautwarm/diojit/blob/master/benchmarks/fib.py) | 11.63 | 1.54 | 10.41 | 1.51 | 20 | | [hypot(str, str)](https://github.com/thautwarm/diojit/blob/master/benchmarks/hypot.py) | 6.19 | 3.87 | 6.53 | 4.29 | 21 | | [selectsort](https://github.com/thautwarm/diojit/blob/master/benchmarks/selection_sort.py) | 46.95 | 33.88 | 38.71 | 29.49 | 22 | | [trans](https://github.com/thautwarm/diojit/blob/master/benchmarks/trans.py) | 24.22 | 7.79 | 23.23 | 7.71 | 23 | 24 | The benchmark item "DNA READ" does not show a significant performance gain, this is because "DNA READ" heavily uses `bytearray` and `bytes`, whose specialised C-APIs 25 | are not exposed. In this case, although the JIT can infer the types, we have to fall back to CPython's default behaviour, or even worse: after all, the interpreter can access internal things, while we cannot. 26 | 27 | P.S: 28 | DIO-JIT can do very powerful partial evaluation, which is disabled by default but you can 29 | leverage it in your domain specific tasks. Here is an example of achieving **500x** speed up against pure Python: [fibs.py](https://github.com/thautwarm/diojit/blob/master/benchmarks/const_fib.py) 30 | 31 | ## Install Instructions 32 | 33 |
Step 1: Install Julia as an in-process native code compiler for DIO-JIT 34 |

35 | 36 | There are several options for you to install Julia: 37 | 38 | - [scoop](http://scoop.sh/) (Windows) 39 | - [julialang.org](https://julialang.org/downloads) (recommended for Windows users) 40 | - [jill.py](https://github.com/johnnychen94/jill.py): 41 | 42 | ```bash 43 | $ pip install jill && jill install 1.6 --upstream Official 44 | ``` 45 | 46 | - [jill](https://github.com/abelsiqueira/jill) (Mac and Linux only!): 47 | 48 | ```bash 49 | $ bash -ci "$(curl -fsSL https://raw.githubusercontent.com/abelsiqueira/jill/master/jill.sh)" 50 | ``` 51 | 52 |

53 |
54 | 55 |
Step 2: Install DIO.jl in Julia 56 |

57 | 58 | Type `julia` and open the REPL, then 59 | 60 | ```julia 61 | julia> 62 | # press ] 63 | pkg> add https://github.com/thautwarm/DIO.jl 64 | # press backspace 65 | julia> using DIO # precompile 66 | ``` 67 | 68 |

69 |
70 | 71 |
Step 3: Install Python Package 72 |

73 | 74 | ```bash 75 | $ pip install git+https://github.com/thautwarm/diojit 76 | ``` 77 | 78 |

79 |
80 | 81 |

82 | 83 | 84 |
How to fetch latest DIO-JIT?(if you have installed DIO) 85 | 86 |

87 | 88 | ```bash 89 | $ pip install -U diojit 90 | $ julia -e "using Pkg; Pkg.update(string(:DIO));using DIO" 91 | ``` 92 | 93 |

94 |
95 | 96 | Usage from Python side is quite similar to that from Numba. 97 | 98 | ```python 99 | import diojit 100 | from math import sqrt 101 | # eagerjit: assuming all global references are fixed 102 | @diojit.eagerjit 103 | def fib(a): 104 | if a <= 2: 105 | return 1 106 | return fib(a + -1) + fib(a + -2) 107 | 108 | jit_fib = diojit.spec_call(fib, diojit.oftype(int), diojit.oftype(int)) 109 | jit_fib(15) # 600% faster than pure python 110 | ``` 111 | 112 | It might look strange to you that we use `a + -1` and `a + -2` here. 113 | 114 | Clever observation! And that's the point! 115 | 116 | DIO-JIT relies on specialisation rules. We have written one for additions, more specifically, `operator.__add__`: [specialisation for `operator.__add__`](https://github.com/thautwarm/diojit/blob/175aab5f4cb65fee923b9f6cb97c256252fc49f5/diojit/absint/prescr.py#L226). 117 | 118 | However, due to the bandwidth limitation, rules for `operator.__sub__` are not implemented yet. 119 | 120 | (P.S: [why `operator.__add__`](https://github.com/thautwarm/diojit/blob/3ceb9513377234f476566f70792632ce08c13373/diojit/stack2reg/translate.py#L30).) 121 | 122 | Although specialisation is common in the scope of optimisation, unlike many other JIT attempts, DIO-JIT doesn't need to 123 | hard-code rules at compiler level. The DIO-JIT compiler implements the skeleton of abstract interpretation, but concrete 124 | rules for specialisation and other inferences can be added within Python itself in an extensible way! 125 | 126 | See an example below. 127 | 128 | ## Contribution Example: Add a specialisation rule for `list.append` 129 | 130 | 1. 
Python Side: 131 | 132 | ```python 133 | import diojit as jit 134 | import timeit 135 | jit.create_shape(list, oop=True) 136 | @jit.register(list, attr="append") 137 | def list_append_analysis(self: jit.Judge, *args: jit.AbsVal): 138 | if len(args) != 2: 139 | # rollback to CPython's default code 140 | return NotImplemented 141 | lst, elt = args 142 | 143 | return jit.CallSpec( 144 | instance=None, # return value is not static 145 | e_call=jit.S(jit.intrinsic("PyList_Append"))(lst, elt), 146 | possibly_return_types=tuple({jit.S(type(None))}), 147 | ) 148 | ``` 149 | 150 | 151 | `jit.intrinsic("PyList_Append")` mentioned in the above code means the intrinsic provided by the Julia codegen backend. 152 | Usually it's calling a CPython C API, but sometimes may not. 153 | 154 | No matter if it is an existing CPython C API, we can implement intrinsics in Julia. 155 | 156 | - [import PyList_Append symbol](https://github.com/thautwarm/DIO.jl/blob/c3ec304645437da6bb02c9e5acb0c91e5e3800a8/src/symbols.jl#L53) 157 | 158 | - [generate PyList_Append calling convention](https://github.com/thautwarm/DIO.jl/blob/5fa79357798ff3eaee561d14d4f04a271213282c/src/dynamic.jl#L120): 159 | 160 | ```julia 161 | @autoapi PyList_Append(PyPtr, PyPtr)::Cint != Cint(-1) cast(_cint2none) nocastexc 162 | ``` 163 | 164 | As a consequence, we automatically generate an intrinsic function for DIO-JIT. This intrinsic function 165 | is capable of handling CPython exceptions and reference counting. 166 | 167 | You can also do step 2) at Python side. It might look more intuitive. 
168 | 169 | ```python 170 | import diojit as jit 171 | from diojit.runtime.julia_rt import jl_eval 172 | jl_implemented_intrinsic = """ 173 | function PyList_Append(lst::Ptr, elt::PyPtr) 174 | if ccall(PyAPI.PyList_Append, Cint, (PyPtr, PyPtr), lst, elt) == -1 175 | return Py_NULL 176 | end 177 | nothing # automatically maps to a Python None 178 | end 179 | DIO.DIO_ExceptCode(::typeof(PyList_Append)) != Py_NULL 180 | """ 181 | jl_eval(jl_implemented_intrinsic) 182 | ``` 183 | 184 | You immediately get a >**100%** time speed up: 185 | 186 | ```python 187 | @jit.jit 188 | def append3(xs, x): 189 | xs.append(x) 190 | xs.append(x) 191 | xs.append(x) 192 | 193 | jit_append3 = jit.spec_call(append3, jit.oftype(list), jit.Top) # 'Top' means 'Any' 194 | xs = [1] 195 | jit_append3(xs, 3) 196 | 197 | print("test jit_append3, [1] append 3 for 3 times:", xs) 198 | # test jit_append3, [1] append 3 for 3 times: [1, 3, 3, 3] 199 | 200 | xs = [] 201 | %timeit append3(xs, 1) 202 | # 293 ns ± 26.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) 203 | 204 | xs = [] 205 | %timeit jit_append3(xs, 1) 206 | # 142 ns ± 14.9 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each) 207 | ``` 208 | 209 | ## Why Julia? 210 | 211 | We don't want to maintain a C compiler, and calling `gcc` or others will introduce cross-process IO, which is slow. 212 | We prefer compiling JITed code with LLVM, and **Julia is quite a killer tool for this use case**. 213 | 214 | ## Current Limitations 215 | 216 | 1. Support for `*varargs` and `**kwargs` is not ready: we could support them immediately with a very small JIT performance gain, but considering backward compatibility we have decided not to do this yet. 217 | 218 | 2. Exception handling is not yet supported inside JIT functions. 219 | 220 |
Why? 221 |

222 | 223 | We haven't implemented the translation from exception handling bytecode to untyped DIO IR (`jit.absint.abs.In_Stmt`). 224 | 225 |

226 |
227 | 228 |
Will support? 229 |

230 | 231 | Yes. 232 | 233 | In fact, now a callsite in any JIT function can raise an exception. It will not be handled by JIT functions, instead, it is lifted up to the root call, which is a pure Python call. 234 | 235 | Exception handling will be supported once we have put effort into translating CPython's exception-handling bytecode into untyped DIO IR (`jit.absint.abs.In_Stmt`). 236 | 237 | P.S: This will be finished simultaneously with the support for `for` loop. 238 | 239 |

240 |
241 | 242 | 3. Support for `for` loop is missing. 243 | 244 |
Why? 245 |

246 | 247 | Firstly, in CPython, `for` loop relies on exception handling, which is not supported yet. 248 | 249 | Secondly, we're considering a fast path for `for` loop, maybe proposing a `__citer__` protocol for faster iteration for JIT functions, which requires communications with Python developers. 250 | 251 |

252 |
253 | 254 |
Will support? 255 |

256 | 257 | Yes. 258 | 259 | This will be finished simultaneously with support for exception handling (faster `for` loop might come later). 260 | 261 |

262 |
263 | 264 | 4. Closure support is missing. 265 | 266 |
Why? 267 |

268 | 269 | In imperative languages, closures use *cell* structures to achieve mutable free/cell variables. 270 | 271 | However, a writable cell makes it hard to optimise in a dynamic language. 272 | 273 | We recommend using `types.MethodType` to create immutable closures, which can be highly optimised in DIO-JIT (in the near future). 274 | 275 | ```python 276 | import types 277 | def f(freevars, z): 278 | x, y = freevars 279 | return x + y + z 280 | 281 | def hof(x, y): 282 | return types.MethodType(f, (x, y)) 283 | ``` 284 | 285 |

286 |
287 | 288 |
Will support? 289 |

290 | 291 | Still yes. However, don't expect much about the performance gain for Python's vanilla closures. 292 | 293 |

294 |
295 | 296 | 5. Is specifying fixed global references (`@diojit.jit(fixed_references=['isinstance', 'str', ...])`) too annoying? 297 | 298 | Sorry, you have to. We are thinking about the possibility of an automatic JIT covering all existing CPython code, but the biggest impediment is the volatile global variables. 299 | 300 | You might use `@eagerjit`, but in this case you must take care to keep global variables unchanged. 301 | 302 |
Possibility? 303 |

304 | 305 | Recently we found CPython's newly (`:)`) added feature `Dict.ma_version_tag` might be used to automatically notify JITed functions to re-compile when the global references change. 306 | 307 | More research is required. 308 | 309 |

310 |
311 | 312 | ## Contributions 313 | 314 | 1. Add more prescribed specialisation rules at `jit.absint.prescr`. 315 | 2. TODO 316 | 317 | ## Benchmarks 318 | 319 | Check the `benchmarks` directory. 320 | -------------------------------------------------------------------------------- /README.zh_CN.md: -------------------------------------------------------------------------------- 1 | ## DIO-JIT: Python的泛用jit 2 | 3 | [![README](https://img.shields.io/badge/i18n-English-teal)](https://github.com/thautwarm/diojit/blob/master/README.md) [![PyPI version shields.io](https://img.shields.io/pypi/v/diojit.svg)](https://pypi.python.org/pypi/diojit/) 4 | [![JIT](https://img.shields.io/badge/cpython-3.8|3.9-green.svg)](https://pypi.python.org/pypi/diojit/) 5 | 6 | DIO-JIT是一种 method JIT, 在抽象解释和调用点特化下成为可能。抽象解释由编译器实现,而特化规则可以扩展式地注册(例见`jit.absint.prescr`)。 7 | 8 | Important: 9 | 10 | 1. 注意, DIO-JIT目前只在Python>=3.8时工作。我们高度依赖Python 3.8之后的`LOAD_METHOD`字节码指令。 11 | 2. 在多数情况下来看,目前DIO-JIT不适合生产环境。我们还需要提供更多的特化规则,来让DIO-JIT变得开箱即用。 12 | 3. 
这个文档主要是为开发者提供的。用户不需要了解如何写特化规则,只需要使用`jit.jit(func_obj)`和`jit.spec_call(func_obj, arg_specs...)`。 13 | 14 | ## Benchmark 15 | 16 | | Item | PY38 | JIT PY38 | PY39 | JIT PY39 | 17 | | ------------------------------------------------------------------------------------------ | ------ | -------- | ------ | -------- | 18 | | [BF](https://github.com/thautwarm/diojit/blob/master/benchmarks/brainfuck.py) | 265.74 | 134.23 | 244.50 | 140.34 | 19 | | [append3](https://github.com/thautwarm/diojit/blob/master/benchmarks/append3.py) | 23.94 | 10.70 | 22.29 | 11.21 | 20 | | [DNA READ](https://github.com/thautwarm/diojit/blob/master/benchmarks/dna_read.py) | 16.96 | 14.82 | 15.03 | 14.38 | 21 | | [fib(15)](https://github.com/thautwarm/diojit/blob/master/benchmarks/fib.py) | 11.63 | 1.54 | 10.41 | 1.51 | 22 | | [hypot(str, str)](https://github.com/thautwarm/diojit/blob/master/benchmarks/hypot.py) | 6.19 | 3.87 | 6.53 | 4.29 | 23 | | [selectsort](https://github.com/thautwarm/diojit/blob/master/benchmarks/selection_sort.py) | 46.95 | 33.88 | 38.71 | 29.49 | 24 | | [trans](https://github.com/thautwarm/diojit/blob/master/benchmarks/trans.py) | 24.22 | 7.79 | 23.23 | 7.71 | 25 | 26 | The benchmark item "DNA READ" does not show a significant performance gain, this is because "DNA READ" heavily uses `bytearray` and `bytes`, whose specialised C-APIs 27 | are not exposed. In this case, although the JIT can infer the types, we have to fall back to CPython's default behaviour, or even worse: after all, the interpreter can access internal things, while we cannot. 28 | 29 | P.S: 30 | DIO-JIT可以做聪明的部分求值, 但为了编译器的快速收敛,online常量折叠默认是关闭的。 31 | 你可以在领域特定任务中使用这个能力。 这里有一个对cpython提速 **500倍**的例子: [fibs.py](https://github.com/thautwarm/diojit/blob/master/benchmarks/const_fib.py) 32 | 33 | ## 安装 34 | 35 |
1: 安装Julia(我们的"底层代码编译服务"提供者) 36 |

37 | 38 | 推荐以如下方式安装Julia: 39 | 40 | - [scoop](http://scoop.sh/) (Windows) 41 | - [julialang.org](https://cn.julialang.org/downloads/) (Windows) 42 | - [jill.py](https://github.com/johnnychen94/jill.py) (跨平台,但安装路径不符合Windows上Unix用户习惯): 43 | 44 | ```bash 45 | $ pip install jill && jill install 1.6 46 | ``` 47 | 48 | - [jill](https://github.com/abelsiqueira/jill) (Mac and Linux): 49 | 50 | ```bash 51 | $ bash -ci "$(curl -fsSL https://raw.githubusercontent.com/abelsiqueira/jill/master/jill.sh)" 52 | ``` 53 | 54 |

55 |
56 | 57 |
2: 在Julia中安装 DIO.jl 58 |

59 | 60 | 输入 `julia` 打开REPL 61 | 62 | ```julia 63 | julia> 64 | # 按下 ] 65 | pkg> add https://github.com/thautwarm/DIO.jl 66 | # 按下 backspace 键 67 | julia> using DIO # 预编译 68 | ``` 69 | 70 |

71 |
72 | 73 |
3: 安装Python包 74 |

75 | 76 | ```bash 77 | $ pip install git+https://github.com/thautwarm/diojit 78 | ``` 79 | 80 |

81 |
82 | 83 |

84 | 85 | 86 |
如何获取最新的DIO-JIT?(需安装过DIO-JIT) 87 |

88 | 89 | ```bash 90 | $ pip install -U diojit 91 | $ julia -e "using Pkg; Pkg.update(string(:DIO));using DIO" 92 | ``` 93 | 94 |

95 |
96 | 97 | 从Python端使用DIO-JIT和使用Numba类似: 98 | 99 | ```python 100 | import diojit 101 | from math import sqrt 102 | # eagerjit: 假设所有全局引用不变 103 | @diojit.eagerjit 104 | def fib(a): 105 | if a <= 2: 106 | return 1 107 | return fib(a + -1) + fib(a + -2) 108 | 109 | jit_fib = diojit.spec_call(fib, diojit.oftype(int), diojit.oftype(int)) 110 | jit_fib(15) # 比原生Python快600%以上 111 | ``` 112 | 113 | 你可能会问, 为什么上面的代码要使用`a + -1`和`a + -2`这么迷惑的写法? 114 | 115 | 你get了重点! 116 | 117 | 我们的jit依赖于已有的特化规则。我们已经为加法,具体的说是`operator.__add__`实现了特化规则: [`operator.__add__`的特化规则](https://github.com/thautwarm/diojit/blob/175aab5f4cb65fee923b9f6cb97c256252fc49f5/diojit/absint/prescr.py#L226). 118 | 119 | (P.S: [为啥是 `operator.__add__`](https://github.com/thautwarm/diojit/blob/3ceb9513377234f476566f70792632ce08c13373/diojit/stack2reg/translate.py#L30).) 120 | 121 | 但因为个人精力有限,目前还没有为`operator.__sub__`实现对应的规则。 122 | 123 | 虽然特化是非常常见的优化技术,但与很多的Python JIT不同的是,DIO-JIT并不需要在编译器层面内建特别的优化。DIO-JIT编译器只负责实现一个抽象解释的算法, 124 | 而更具体的推导、特化规则,在Python里就可以扩展式地添加! 
125 | 126 | 下面是一个例子。 127 | 128 | ## 代码贡献案例: 为`list.append`注册特化规则 129 | 130 | 步骤1:Python端如下代码 131 | 132 | ```python 133 | import diojit as jit 134 | import timeit 135 | jit.create_shape(list, oop=True) 136 | @jit.register(list, attr="append") 137 | def list_append_analysis(self: jit.Judge, *args: jit.AbsVal): 138 | if len(args) != 2: 139 | # rollback to CPython's default code 140 | return NotImplemented 141 | lst, elt = args 142 | 143 | return jit.CallSpec( 144 | instance=None, # return value is not static 145 | e_call=jit.S(jit.intrinsic("PyList_Append"))(lst, elt), 146 | possibly_return_types=tuple({jit.S(type(None))}), 147 | ) 148 | ``` 149 | 150 | 上面的`jit.intrinsic("PyList_Append")`指的是JIT后端提供的底层原语,它通常是调用CPython的C API。 151 | 152 | 我们可以在Julia里面实现这些底层原语。 153 | 154 | 步骤2: Julia端 155 | 156 | - [导入PyList_Append符号](https://github.com/thautwarm/DIO.jl/blob/c3ec304645437da6bb02c9e5acb0c91e5e3800a8/src/symbols.jl#L53) 157 | 158 | - [生成PyList_Append的调用约定](https://github.com/thautwarm/DIO.jl/blob/5fa79357798ff3eaee561d14d4f04a271213282c/src/dynamic.jl#L120): 159 | 160 | ```julia 161 | @autoapi PyList_Append(PyPtr, PyPtr)::Cint != Cint(-1) cast(_cint2none) nocastexc 162 | ``` 163 | 164 | 这样一来,我们就自动生成了一个能够处理CPython错误处理和引用计数的原语函数。 165 | 166 | 实际上,你也可以在Python端手动实现步骤2,没有宏看起来可能会更直观一些: 167 | 168 | ```python 169 | import diojit as jit 170 | from diojit.runtime.julia_rt import jl_eval 171 | jl_implemented_intrinsic = b""" 172 | function PyList_Append(lst::Ptr, elt::PyPtr) 173 | if ccall(PyAPI.PyList_Append, Cint, (PyPtr, PyPtr), lst, elt) == -1 174 | return Py_NULL 175 | end 176 | nothing # automatically maps to a Python None 177 | end 178 | DIO.DIO_ExceptCode(::typeof(PyList_Append)) != Py_NULL 179 | """ 180 | jl_eval(jl_implemented_intrinsic) 181 | ``` 182 | 183 | 我们立即得到大于**100%**的性能提升。 184 | 185 | ```python 186 | @jit.jit 187 | def append3(xs, x): 188 | xs.append(x) 189 | xs.append(x) 190 | xs.append(x) 191 | 192 | jit_append3 = jit.spec_call(append3, jit.oftype(list), jit.Top) # 
'Top' means 'Any' 193 | xs = [1] 194 | jit_append3(xs, 3) 195 | 196 | print("test jit_append3, [1] append 3 for 3 times:", xs) 197 | # test jit func, [1] append 3 for 3 times: [1, 3, 3, 3] 198 | 199 | xs = [] 200 | %timeit append3(xs, 1) 201 | # 293 ns ± 26.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) 202 | 203 | xs = [] 204 | %timeit jit_append3(xs, 1) 205 | # 142 ns ± 14.9 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each) 206 | ``` 207 | 208 | ## 为什么要用Julia? 209 | 210 | 我们不想维护一个C编译器,并且调用`gcc`这样的操作会引入跨进程的IO。这好吗?这不好。 211 | 212 | 我们倾向于使用LLVM在运行时编译底层代码,而基于其上的Julia提供了对其底层架构的自由访问,从而**成为了运行时编译的杀手级应用**。 213 | 214 | ## 现状和限制 215 | 216 | 1. **暂**未支持不定参数和关键字参数。我们可以立刻支持它们,并提供较小的JIT性能提升。但由此可能引来后向兼容的问题,权衡之下还是暂且搁置。 217 | 218 | 2. **暂**不支持在JIT函数中处理异常. 219 | 220 |
??? 221 |

222 | 223 | 还未实现从相关的CPython字节码到无类型DIO IR的转译(`jit.absint.abs.In_Stmt`) 224 | 225 |

226 |
227 | 228 |
会支持吗? 229 |

230 | 231 | 会的。 232 | 233 | 实际上,目前JIT函数内部的调用可以正常抛错。这样的错误无法被JIT函数处理,而是被交给更上层。 234 | 235 | 在我们有精力实现对应的(错误处理)字节码到无类型 DIO IR的转译后,JIT函数中将可以做错误处理。 236 | 237 | P.S: 这会和`for`循环的支持同时实现。 238 | 239 |

240 |
241 | 242 | 3. **暂**不支持`for`循环. 243 | 244 |
??? 245 |

246 | 247 | 首先,在CPython中,`for`循环的实现依赖错误处理,而这目前还未支持。 248 | 249 | 其次,我们在考虑一个更高效的`for`循环实现,可能会提议一个`__citer__`协议用以JIT函数的优化。而这需要和Python开发者进一步探讨。 250 | 251 |

252 |
253 | 254 |
会支持吗? 255 |

256 | 257 | 嗯。 258 | 259 | 这会和错误处理同时实现。快速`for`可能会引入得更晚一些。 260 | 261 |

262 |
263 | 264 | 4. 未支持Python的原生闭包 265 | 266 |
??? 267 |

268 | 269 | 在命令式语言中,闭包使用一种叫`cell`的数据结构来实现可变(mutable)的自由变量(free variables)。 270 | 271 | 然而,在动态语言里边,优化可写的闭包是一个相当困难的问题。 272 | 273 | 我们建议你使用`types.MethodType`创建自由变量不可变的闭包,这是DIO-JIT(很快就能)高效优化的写法。 274 | 275 | ```python 276 | import types 277 | def f(freevars, z): 278 | x, y = freevars 279 | return x + y + z 280 | 281 | def hof(x, y): 282 | return types.MethodType(f, (x, y)) 283 | ``` 284 | 285 |

286 |
287 | 288 |
会支持吗? 289 |

290 | 291 | 会是会的,毕竟我们的目的是覆盖所有的CPython代码。 292 | 293 | 但对此不要期待很大性能提升。 294 | 295 |

296 |
297 | 298 | 5. 手动指定不可变全局变量太啰嗦了?(`@diojit.jit(fixed_references=['isinstance', 'str', ...])`) 299 | 300 | 很遗憾,你得这样。我们在考虑全自动JIT的可能性,但Python任人随意修改的全局变量是这个目标最大的阻碍。 301 | 302 |
不写行不行? 303 |

304 | 305 | 可能会可以的。 306 | 307 | 近期CPython优化了存储全局变量的字典。字典的内存布局多了一个名为`ma_version_tag`的数字,用以指示字典最近被写入过。这个改动可能可以用来触发JIT函数的重编译。 308 | 309 | 这还需要更多的研究。 310 | 311 |

312 |
313 | 314 | ## Contributions 315 | 316 | 1. 发挥你的聪明才智,展示你对Python语义的了解,为DIO-JIT添加更多的特化规则吧! 317 | 318 | 例见`jit.absint.prescr`。 319 | 2. TODO 320 | 321 | ## Benchmarks 322 | 323 | 康康 `benchmarks` 文件夹。 324 | -------------------------------------------------------------------------------- /benchmarks/append3.py: -------------------------------------------------------------------------------- 1 | """ 2 | append3 (jit) bench time: 1.4360911 3 | append3 (pure py) bench time: 2.9313146 4 | """ 5 | import diojit as jit 6 | from inspect import getsource 7 | from diojit.runtime.julia_rt import jl_eval, splice 8 | import timeit 9 | 10 | 11 | @jit.jit 12 | def append3(xs, x): 13 | xs.append(x) 14 | xs.append(x) 15 | xs.append(x) 16 | 17 | 18 | print("append3".center(70, "=")) 19 | print(getsource(append3)) 20 | 21 | # jit.In_Def.UserCodeDyn[append3].show() 22 | jit_append3 = jit.spec_call( 23 | append3, 24 | jit.oftype(list), 25 | jit.Top, 26 | # print_dio_ir=print, 27 | ) 28 | x = [1] 29 | y = 2 30 | # jl_eval(f"println(J_append3_0({splice(x)}, {splice(y)}))") 31 | # raise 32 | xs = [1] 33 | jit_append3(xs, 3) 34 | print("test jit func: [1] append 3 for 3 times =", xs) 35 | 36 | 37 | xs = [] 38 | print( 39 | "append3 (jit) bench time:", 40 | timeit.timeit( 41 | "f(xs, 1)", globals=dict(f=jit_append3, xs=xs), number=100000000 42 | ), 43 | ) 44 | xs = [] 45 | print( 46 | "append3 (pure py) bench time:", 47 | timeit.timeit( 48 | "f(xs, 1)", globals=dict(f=append3, xs=xs), number=100000000 49 | ), 50 | ) 51 | -------------------------------------------------------------------------------- /benchmarks/brainfuck.py: -------------------------------------------------------------------------------- 1 | """ 2 | jit time 140.34073400497437 3 | pure py time 244.5001037120819 4 | 23280 ==? 
# --- benchmarks/brainfuck.py ---

# Opcode tags stored in ``Op.op``.
INC = 1
MOVE = 2
LOOP = 3
PRINT = 4

OP = 0
VAL = 1


# One parsed instruction: ``op`` is one of the opcode tags above, ``val`` is
# either an integer operand or (for LOOP) the nested list of instructions.
@jit.eagerjit
class Op(object):
    op: int
    val: typing.Union[int, list]

    def __init__(self, op, val):
        assert isinstance(op, int) and isinstance(val, (int, list))
        self.op = op
        self.val = val


# The machine's memory: a growable list of cells plus the current position.
@jit.eagerjit
class Tape(object):
    tape: list
    pos: int

    def __init__(self):
        self.tape = [0]
        self.pos = 0

    def get(self):
        # Value of the cell under the head.
        return self.tape[self.pos]

    def inc(self, x):
        # Add ``x`` (may be negative) to the cell under the head.
        self.tape[self.pos] += x

    def move(self, x):
        # Shift the head by ``x``; the tape doubles in length (zero-filled)
        # until the new position is inside it.
        self.pos += x
        while self.pos >= len(self.tape):
            self.tape.extend(itertools.repeat(0, len(self.tape)))


# Output sink.  In quiet mode nothing is written; instead two running sums
# (both modulo 255) are maintained so that runs can be compared by checksum.
@jit.eagerjit
class Printer(object):
    sum1: int
    sum2: int
    quiet: bool

    def __init__(self, quiet):
        self.sum1 = 0
        self.sum2 = 0
        self.quiet = quiet

    def print(self, n):
        if self.quiet:
            self.sum1 = (self.sum1 + n) % 255
            self.sum2 = (self.sum2 + self.sum1) % 255
        else:
            sys.stdout.write(chr(n))
            sys.stdout.flush()

    @property
    def checksum(self):
        # sum2 in the high bits, sum1 in the low 8 bits.
        return (self.sum2 << 8) | self.sum1


def parse(iterator):
    """Parse characters from ``iterator`` into a list of ``Op``.

    Consumes the iterator until it is exhausted or a ``]`` is read; a ``[``
    recurses on the same iterator to build the loop body.  Characters other
    than ``+ - > < . [ ]`` are ignored.
    """
    res = []
    while True:
        try:
            c = iterator.__next__()
        except StopIteration:
            break

        if c == "+":
            res.append(Op(INC, 1))
        elif c == "-":
            res.append(Op(INC, -1))
        elif c == ">":
            res.append(Op(MOVE, 1))
        elif c == "<":
            res.append(Op(MOVE, -1))
        elif c == ".":
            res.append(Op(PRINT, 0))
        elif c == "[":
            res.append(Op(LOOP, parse(iterator)))
        elif c == "]":
            break

    return res


# Interpreter loop (the JIT-compiled kernel): executes ``program`` (a list of
# ``Op``) against ``tape``, sending PRINT output to ``p``.  LOOP bodies are
# executed by recursion while the current cell is positive.
@jit.eagerjit
def _run(program, tape, p):
    i = 0
    n = len(program)
    while i < n:
        op = program[i]
        i += 1
        if op.op == INC:
            tape.inc(op.val)
        elif op.op == MOVE:
            tape.move(op.val)
        elif op.op == LOOP:
            while tape.get() > 0:
                _run(op.val, tape, p)
        elif op.op == PRINT:
            p.print(tape.get())


class Program(object):
    """A parsed program that can be run with or without the JIT."""

    def __init__(self, code):
        self.ops = parse(iter(code))

    def run(self, p, use_jit: bool):
        """Execute the program on a fresh ``Tape``, printing through ``p``.

        With ``use_jit`` the kernel is first specialised for
        ``(list, Tape, Printer)`` arguments (and its IR is printed).
        """
        if use_jit:
            _run_jit = jit.spec_call(
                _run,
                jit.oftype(list),
                jit.oftype(Tape),
                jit.oftype(Printer),
                print_dio_ir=print,
            )
            _run_jit(self.ops, Tape(), p)
        else:
            _run(self.ops, Tape(), p)


# --- benchmarks/const_fib.py ---

# Plain-Python operators; the specialisation rules registered on them in this
# benchmark decide what the JIT emits when they are called.
def try_const_add(a, b):
    return a + b


def try_const_le(a, b):
    return a <= b


def try_const_eq(a, b):
    return a == b
# Specialisation rule for ``try_const_add``: when both abstract arguments are
# static, the sum is computed now and returned as a constant CallSpec;
# otherwise the call is specialised as ``operator.__add__``.
# NOTE(review): this rule tests ``is_s()`` while the ``le``/``eq`` rules below
# test ``is_literal()`` -- confirm whether the difference is intentional.
@jit.register(try_const_add, create_shape=True)
def call_const_add(self: jit.Judge, *args: jit.AbsVal):
    if len(args) != 2:
        return NotImplemented
    a, b = args
    if a.is_s() and b.is_s():
        const = jit.S(operator.add(a.base, b.base))
        ret_types = (jit.S(type(const.base)),)
        return jit.CallSpec(const, const, ret_types)

    return self.spec(jit.S(operator.__add__), "__call__", list(args))


# Specialisation rule for ``try_const_le``: literal operands are compared now
# and the result is returned as a constant; otherwise defer to
# ``operator.__le__``.
@jit.register(try_const_le, create_shape=True)
def call_const_le(self: jit.Judge, *args: jit.AbsVal):
    if len(args) != 2:
        return NotImplemented
    a, b = args
    if a.is_literal() and b.is_literal():
        const = jit.S(operator.__le__(a.base, b.base))
        ret_types = (jit.S(type(const.base)),)
        return jit.CallSpec(const, const, ret_types)

    return self.spec(jit.S(operator.__le__), "__call__", list(args))


# Specialisation rule for ``try_const_eq``; same structure as the ``le`` rule.
@jit.register(try_const_eq, create_shape=True)
def call_const_eq(self: jit.Judge, *args: jit.AbsVal):
    if len(args) != 2:
        return NotImplemented
    a, b = args
    if a.is_literal() and b.is_literal():
        const = jit.S(operator.__eq__(a.base, b.base))
        ret_types = (jit.S(type(const.base)),)
        return jit.CallSpec(const, const, ret_types)

    return self.spec(jit.S(operator.__eq__), "__call__", list(args))


# Baseline recursive Fibonacci written with ordinary operators
# (``x + -1`` rather than ``x - 1``).
@jit.eagerjit
def fib(x):
    if x <= 2:
        return 1
    return fib(x + -1) + fib(x + -2)


# Same recursion, but routed through the ``try_const_*`` helpers so that the
# rules registered above apply to every arithmetic step.
@jit.eagerjit
def try_const_fib(x):
    if try_const_le(x, 2):
        return 1
    return try_const_add(
        try_const_fib(
            try_const_add(x, -1),
        ),
        try_const_fib(
            try_const_add(x, -2),
        ),
    )


def bench(kind, f, number=2000):
    """Print ``kind`` and the ``timeit`` time of ``number`` calls of ``f(20)``."""
    print(
        kind,
        timeit.timeit(
            "fib(x)",
            globals=dict(fib=f, x=20),
            number=number,
        ),
    )
def jl_eval(s: str):
    """Evaluate the Julia source ``s`` through libjulia, then check for a Julia error."""
    libjl.jl_eval_string(s.encode())
    check_jl_err(libjl)


# Joins the collected byte lines, maps them through the module-level
# ``table`` translation, reverses the result and appends it to ``out_io`` in
# slices of at most 60 bytes.
@jit.jit(fixed_references=["len"])
def show(out_io: bytearray, seq):
    # FIXME: optimisation point
    seq_ = (b"".join(seq)).translate(table)[::-1]
    i = 0
    n = len(seq_)
    while i < n:
        out_io.extend(seq_[i : i + 60])
        i += 60


# Pure-Python reference driver (the decorator is intentionally commented out).
# Lines starting with ``>`` or ``;`` flush the pending sequence via ``show``
# and are copied through; other lines are buffered without their last byte.
# @jit.jit
def main(out_io: bytearray):
    in_io = BytesIO(contents)
    seq = []
    for line in in_io:
        if line[0] in b">;":
            show(out_io, seq)
            out_io.extend(line)
            del seq[:]
        else:
            seq.append(line[:-1])
    show(out_io, seq)
    return out_io


# JIT variant of ``main``: the ``for`` loop is spelled as
# ``while True`` + ``next(in_io, None)``.
@jit.jit(fixed_references=["next", "BytesIO", "show"])
def main2(out_io: bytearray):
    in_io = BytesIO(contents)
    seq = []
    while True:
        line = next(in_io, None)
        if line is None:
            break
        if line[0] in b">;":
            show(out_io, seq)
            out_io.extend(line)
            del seq[:]
        else:
            seq.append(line[:-1])
    show(out_io, seq)
    return out_io
def fib(a):
    """Pure-Python baseline: naive recursive Fibonacci (1 for ``a <= 2``)."""
    if a <= 2:
        return 1
    return fib(a - 1) + fib(a - 2)


# JIT version of ``fib``.  ``fib_fix`` is listed in ``fixed_references`` so
# the recursive name is passed to the JIT as a fixed reference; subtraction
# is written as ``a + -1`` / ``a + -2``.
@jit.jit(fixed_references=["fib_fix"])
def fib_fix(a):
    if a <= 2:
        return 1
    return fib_fix(a + -1) + fib_fix(a + -2)
# Benchmark kernel: sqrt(x**2 + y**2), converting ``str`` arguments to ``int``
# first.  The driver specialises it for ``(int, int)``, where both
# ``isinstance`` tests are false.
@jit.jit(fixed_references=["sqrt", "str", "int", "isinstance"])
def hypot(x, y):
    if isinstance(x, str):
        x = int(x)

    if isinstance(y, str):
        y = int(y)

    return sqrt(x ** 2 + y ** 2)
def jl_eval(s: str):
    """Evaluate the Julia source ``s`` through libjulia, then check for a Julia error."""
    libjl.jl_eval_string(s.encode())
    check_jl_err(libjl)


# Comparison after coercing both operands to ``int``.  The call to it in
# ``argmin`` is currently commented out.
@jit.jit(fixed_references=["int"])
def lt(e, min_val):
    return int(e) < int(min_val)


# Index of the smallest element of ``xs[i:n]`` (first occurrence wins because
# the comparison is strict).
@jit.jit(fixed_references=["lt", "int"])
def argmin(xs, i, n):
    min_val = xs[i]  # int(xs[i])
    min_i = i
    j = i + 1
    while j < n:
        e = xs[j]  # int(xs[j])
        # if lt(e, min_val):
        if e < min_val:
            min_i = j
            min_val = e
        j = j + 1
    return min_i


# In-place exchange of ``xs[min_i]`` and ``xs[i]``.
@jit.jit
def swap(xs, min_i, i):
    xs[min_i], xs[i] = xs[i], xs[min_i]


# Selection sort on a copy of ``xs``; the input is not modified.
@jit.jit(fixed_references=["argmin", "swap", "len"])
def msort(xs):
    xs = xs.copy()
    n = len(xs)
    i = 0
    while i < n:
        min_i = argmin(xs, i, n)
        swap(xs, min_i, i)
        i = i + 1
    return xs


# Minimal example: two subscript loads and one comparison.
@jit.jit
def mwe(xs):
    return xs[0] < xs[2]
# Benchmark kernel: a straight-line chain of seven ``1 + ...`` additions, so
# the result is ``x + 7``.  The driver specialises it for ``int``.
@jit.jit
def f(x):
    x = 1 + x
    y = 1 + x
    z = 1 + y
    x = 1 + z
    y = 1 + x
    z = 1 + y
    x = 1 + z
    return x
# Singly linked list cell.  ``next`` is annotated as a union of ``NoneType``
# and ``Node``, which is the union-typed field this benchmark exercises.
@jit.eagerjit
class Node:
    next: typing.Union[type(None), Node]
    val: int

    def __init__(self, n, val):
        self.next = n
        self.val = val


# Sums ``val`` over the chain starting at ``n`` until ``next`` is ``None``.
@jit.eagerjit
def sum_chain(n: Node):
    a = 0
    while n is not None:
        a += n.val
        n = n.next

    return a


def bench(kind, f, number=1000000):
    """Print ``kind`` and the ``timeit`` time of ``number`` calls of ``f`` on the module-level chain ``n``."""
    print(
        kind,
        timeit.timeit(
            "f(x)",
            globals=dict(f=f, x=n),
            number=number,
        ),
    )
import runtime 6 | -------------------------------------------------------------------------------- /diojit/absint/__init__.py: -------------------------------------------------------------------------------- 1 | from .intrinsics import * 2 | from .abs import * 3 | from .prescr import * 4 | -------------------------------------------------------------------------------- /diojit/absint/abs.py: -------------------------------------------------------------------------------- 1 | """ 2 | ::= D_i^t | NonD 3 | NonD ::= S_o^ [*] 4 | | Top 5 | | Bot 6 | """ 7 | 8 | from __future__ import annotations 9 | from typing import ( 10 | Union, 11 | Callable, 12 | Optional, 13 | Iterable, 14 | Sequence, 15 | TypeVar, 16 | Type, 17 | NamedTuple, 18 | TYPE_CHECKING, 19 | cast, 20 | ) 21 | from collections import defaultdict, OrderedDict 22 | from contextlib import contextmanager 23 | from functools import total_ordering 24 | import types 25 | import dataclasses 26 | import pyrsistent 27 | import builtins 28 | from .intrinsics import * 29 | 30 | NoneType = type(None) 31 | 32 | # awesome pycharm :( 33 | FunctionType: Type[types.FunctionType] = cast( 34 | Type[types.FunctionType], types.FunctionType 35 | ) 36 | 37 | # set False when not made for Python 38 | PYTHON = True 39 | 40 | __all__ = [ 41 | "Values", 42 | "AbsVal", 43 | "D", 44 | "S", 45 | "Top", 46 | "Bot", 47 | "Judge", 48 | "JITSpecInfo", 49 | "CallSpec", 50 | "CallRecord", 51 | "PreSpecMaps", 52 | "SpecMaps", 53 | "RecTraces", 54 | "In_Def", 55 | "In_Move", 56 | "In_Goto", 57 | "In_SetLineno", 58 | "In_Bind", 59 | "In_Stmt", 60 | "In_Blocks", 61 | "In_Cond", 62 | "In_Return", 63 | "Out_Def", 64 | "Out_Label", 65 | "Out_Call", 66 | "Out_Return", 67 | "Out_Goto", 68 | "Out_TypeCase", 69 | "Out_If", 70 | "Out_Assign", 71 | "Out_SetLineno", 72 | "Out_DecRef", 73 | "Out_Instr", 74 | "print_out", 75 | "print_in", 76 | "from_runtime", 77 | "AWARED_IMMUTABLES", 78 | "ShapeSystem", 79 | "Shape", 80 | "FunctionType", # TODO: put it 
class Out_Callable:
    """Mixin that makes an abstract value callable inside the output IR."""

    def __call__(self, *args: AbsVal):
        """Build an ``Out_Call`` expression applying ``self`` to ``args``."""
        # noinspection PyTypeChecker
        return Out_Call(self, args)


class AbsVal:
    """Common base class of all abstract values."""

    if TYPE_CHECKING:
        # Declared for the type checker only; every concrete subclass
        # provides its own implementation.

        @property
        def type(self) -> NonD:
            raise NotImplementedError

        def is_literal(self) -> bool:
            raise NotImplementedError

    def is_s(self):
        """Return True only for ``S`` values (overridden there)."""
        return False


@total_ordering
class D(Out_Callable, AbsVal):
    """
    dynamic abstract value

    ``i`` is the slot index and ``type`` the abstract type attached to it.
    Equality and hashing use both fields.
    """

    i: int
    type: NonD

    def __init__(self, i: int, type: NonD):
        self.i = i
        self.type = type

    def __repr__(self):
        if self.type is Top:
            return f"D{self.i}"
        return f"D{self.i} : {self.type}"

    def is_literal(self):
        return False

    def __hash__(self):
        return 114514 ^ hash(self.i) ^ hash(self.type)

    def __eq__(self, other):
        return (
            isinstance(other, D)
            and self.i == other.i
            and self.type == other.type
        )

    def __lt__(self, other):
        """Order ``D`` values by slot index, then by type.

        A ``D`` sorts before every other kind of abstract value and is never
        less than a non-``AbsVal`` object.
        """
        if isinstance(other, D):
            if self.i != other.i:
                return self.i < other.i
            # Equal values are never "less than" each other.  Checking this
            # first avoids ordering two identical types, which raised
            # TypeError when the type itself defines no ordering and thereby
            # broke ``<=``, ``>=`` and sorting with duplicates.
            if self.type == other.type:
                return False
            return self.type < other.type
        # noinspection PyTypeHints
        if not isinstance(other, AbsVal):
            return False
        # Every non-D abstract value (S, Top, Bot) sorts after a D.
        return True
@total_ordering
class S(Out_Callable, AbsVal):
    """
    type abstract value

    Wraps a concrete Python object ``base`` together with optional abstract
    ``params``.  Equality and hashing use both fields.
    """

    base: object
    params: Optional[tuple[NonD, ...]]

    def __init__(
        self, base: object, params: Optional[tuple[NonD, ...]] = None
    ):
        self.base = base
        self.params = params

    def __hash__(self):
        return 1919810 ^ hash(self.base) ^ hash(self.params)

    def __eq__(self, other):
        return (
            isinstance(other, S)
            and self.base == other.base
            and self.params == other.params
        )

    def __lt__(self, other):
        """Order ``S`` values after every ``D`` and before ``Top``/``Bot``.

        Two ``S`` values with different bases are ordered by the hash of the
        base; with equal bases they are ordered by ``params``, where a
        missing ``params`` sorts first.
        """
        # noinspection PyTypeHints
        if not isinstance(other, AbsVal):
            return False
        if other is Top or other is Bot:
            return True
        if isinstance(other, D):
            return False
        if self.base == other.base:
            # Equal params (including None on both sides) means the values
            # are equal; comparing ``None < None`` used to raise TypeError.
            if self.params == other.params:
                return False
            if self.params is None or other.params is None:
                return self.params is None
            return self.params < other.params
        return hash(self.base) < hash(other.base)

    def is_s(self):
        return True

    def oop(self):
        """Return the shape's ``oop`` flag, or a falsy value without a shape."""
        return (shape := self.shape) and shape.oop

    @property
    def type(self):
        """Abstract type of ``base``.

        Literal bases map through ``_literal_type_maps``; tuples get a
        parameterised tuple type; anything else is ``from_runtime(type(base))``.
        """
        base = self.base
        t = type(base)
        if abs_t := _literal_type_maps.get(t):
            return abs_t
        elif t is tuple:
            return tuple_type(base)
        a_t = from_runtime(t)
        assert not isinstance(a_t, D)
        return a_t

    @property
    def shape(self) -> Optional[Shape]:
        """Shape registered for ``base`` in ``ShapeSystem``; None for literals."""
        if type(self.base) in _literal_type_maps:
            return
        return ShapeSystem.get(self.base)

    def is_literal(self):
        return type(self.base) in _literal_type_maps

    def __repr__(self):
        if isinstance(self.base, type):
            n = self.base.__name__
        elif isinstance(self.base, FunctionType):
            n = self.base.__name__
        elif isinstance(self.base, types.BuiltinFunctionType):
            n = f"{self.base.__module__}.{self.base.__name__}"
        else:
            n = repr(self.base)

        if self.params is None:
            return n

        return f"{n}<{', '.join(map(repr, self.params))}>"


class Values:
    """Namespace of frequently used static abstract values."""

    A_Int = S(int)
    A_Float = S(float)
    A_Str = S(str)
    A_NoneType = S(NoneType)
    A_FuncType = S(FunctionType)
    A_MethType = S(types.MethodType)
    A_Complex = S(complex)
    A_Bool = S(bool)
    A_Intrinsic = S(Intrinsic)
    A_Type = S(type)
    A_NotImplemented = S(NotImplemented)
    A_NotImplementedType = S(type(NotImplemented))


# Module-level shortcuts for the values above.
A_Int = Values.A_Int
A_Float = Values.A_Float
A_Str = Values.A_Str
A_NoneType = Values.A_NoneType
A_FuncType = Values.A_FuncType
A_Bool = Values.A_Bool
A_Intrinsic = Values.A_Intrinsic
A_Complex = Values.A_Complex


# Python types whose instances are treated as literals, mapped to their
# abstract type.  ``tuple`` is present as a key only; its abstract type is
# computed per value by ``tuple_type``.
_literal_type_maps = {
    int: A_Int,
    float: A_Float,
    complex: A_Complex,
    str: A_Str,
    bool: A_Bool,
    NoneType: A_NoneType,
    tuple: None,
}

_T = TypeVar("_T")


class _Top(Out_Callable, AbsVal):
    """Singleton class of the ``Top`` abstract value; it has no ``type``."""

    def is_literal(self) -> bool:
        return False

    @property
    def type(self):
        raise TypeError

    def __repr__(self):
        return "Top"


class _Bot(Out_Callable, AbsVal):
    """Singleton class of the ``Bot`` abstract value; it has no ``type``."""

    def is_literal(self) -> bool:
        return False

    @property
    def type(self):
        raise TypeError

    def __repr__(self):
        return "Bot"


Top = _Top()
Bot = _Bot()

NonD = Union[S, _Top, _Bot]
if TYPE_CHECKING:
    AbsVal = Union[D, NonD]


undef = object()


@dataclasses.dataclass
class Shape:
    """Static description attached to an object through ``ShapeSystem``."""

    name: object
    oop: bool
    fields: dict[str, Union[S, types.FunctionType]]
    # some type has unique instance
    # None.__class__ has None only
    instance: Union[
        None, S, Callable[[tuple[NonD, ...]], Optional[S]]
    ] = dataclasses.field(default=None)
    self_bound: bool = dataclasses.field(default=False)


# None: means No Shape
ShapeSystem: dict[object, Optional[Shape]] = {}

AWARED_IMMUTABLES = {*_literal_type_maps, type, Intrinsic, FunctionType}


def from_runtime(o: object, rt_map: list[object] = None):
    """Lift the runtime object ``o`` to an abstract value.

    Hashable objects become the static value ``S(o)``.  Unhashable objects
    become a dynamic value ``D(i, type)`` where ``i`` is the next free index
    of ``rt_map``; the new value is appended to ``rt_map``.

    :param o: the runtime object.
    :param rt_map: list receiving the dynamic values created here; a fresh
        list is used when omitted.
    """
    # ``hash(o)`` is used as a hashability probe only.  Testing its
    # truthiness misclassified hashable values whose hash is 0 (``0``,
    # ``False``, ``""`` ...) as dynamic and let unhashable objects escape
    # with a TypeError instead of reaching the dynamic branch below.
    try:
        hash(o)
    except TypeError:
        pass
    else:
        return S(o)
    t = type(o)
    if t is tuple:
        type_abs = tuple_type(o)
    else:
        type_abs = from_runtime(t)
    # ``rt_map or []`` would replace a caller-supplied *empty* list and lose
    # the append below.
    if rt_map is None:
        rt_map = []
    i = len(rt_map)
    abs_val = D(i, type_abs)
    rt_map.append(abs_val)
    return abs_val


def tuple_type(xs):
    """Abstract type of the tuple ``xs``: ``tuple`` parameterised by each element's abstract value."""
    return S(tuple, tuple(from_runtime(x) for x in xs))
@dataclasses.dataclass(frozen=True)
class In_Move:
    """Input statement: copy ``source`` into the slot ``target``."""

    target: D
    source: AbsVal

    def __repr__(self):
        return "{} = {!r}".format(self.target, self.source)


@dataclasses.dataclass(frozen=True)
class In_Bind:
    """Input statement: ``target = sub.attr(*args)``."""

    target: D
    sub: AbsVal
    attr: AbsVal
    args: tuple[AbsVal, ...]

    def __repr__(self):
        rendered_args = ",".join(map(repr, self.args))
        return f"{self.target} = {self.sub!r}.{self.attr}({rendered_args})"


@dataclasses.dataclass(frozen=True)
class In_Goto:
    """Input statement: unconditional jump to ``label``."""

    label: str

    def __repr__(self):
        return "goto " + format(self.label)


@dataclasses.dataclass(frozen=True)
class In_SetLineno:
    """Input statement: source position marker."""

    line: int
    filename: str

    def __repr__(self):
        return "# line {} at {}".format(self.line, self.filename)


@dataclasses.dataclass(frozen=True)
class In_Cond:
    """Input statement: two-way branch on ``test``."""

    test: AbsVal
    then: str
    otherwise: str

    def __repr__(self):
        return "if {!r} then {} else {}".format(
            self.test, self.then, self.otherwise
        )


@dataclasses.dataclass(frozen=True)
class In_Return:
    """Input statement: return ``value``."""

    value: AbsVal

    def __repr__(self):
        return "return " + repr(self.value)


In_Stmt = Union[
    In_Cond, In_SetLineno, In_Goto, In_Move, In_Return, In_Bind
]
In_Blocks = "dict[str, list[In_Stmt]]"


def print_in(b: In_Blocks, print=print):
    """Print every block of ``b`` through ``print``, the ``entry`` block first.

    The remaining blocks keep their mapping order.  Each block is emitted as
    a ``label :`` header followed by one call per statement.
    """
    ordered_labels = sorted(b, key=lambda name: name != "entry")
    for name in ordered_labels:
        print(name, ":")
        for stmt in b[name]:
            print(stmt)


@dataclasses.dataclass
class In_Def:
    """A user function in input form: its blocks plus the originating function."""

    narg: int
    blocks: In_Blocks
    _func: FunctionType
    static_glob: set[str]

    UserCodeDyn = {}  # type: dict[FunctionType, In_Def]

    def show(self):
        """Print the definition to stdout with its blocks indented by one space."""
        params = ",".join(f"D[{k}]" for k in range(self.narg))

        print(f'def {self.name}({params})', "{")
        print_in(self.blocks, print=lambda *x: print("", *x))
        print("}")

    @property
    def func(self) -> FunctionType:
        """this is for hacking PyCharm's type checker"""
        # noinspection PyTypeChecker
        return self._func

    @property
    def name(self) -> str:
        """``__name__`` of the wrapped function."""
        return self.func.__name__

    @property
    def glob(self) -> dict:
        """``__globals__`` of the wrapped function."""
        # noinspection PyUnresolvedReferences
        return self.func.__globals__
@dataclasses.dataclass(unsafe_hash=True)
class Out_Call(Out_Callable):
    """Output expression: ``func`` applied to the positional ``args``."""

    func: AbsVal
    args: tuple[AbsVal, ...]

    def __repr__(self):
        return repr(self.func) + repr(self.args)


@dataclasses.dataclass(frozen=True)
class Out_Assign:
    """Output instruction: ``target = expr`` plus the slots listed for decref on error."""

    target: D
    expr: Out_Call
    decrefs: tuple[int, ...]

    def show(self, prefix, print):
        """Print the assignment and its error-path decref list."""
        released = ",".join("D{}".format(slot) for slot in self.decrefs)
        for text in (
            f"{self.target} = {self.expr!r}",
            f"when err: decref [{released}] ",
        ):
            print(f"{prefix}{text}")


@dataclasses.dataclass(frozen=True)
class Out_If:
    """Output instruction: branch to label ``t`` or ``f`` depending on ``test``."""

    test: AbsVal

    t: str
    f: str

    def show(self, prefix, print):
        """Print the branch as three lines."""
        for text in (
            f"if {self.test!r}",
            f"then goto {self.t}",
            f"else goto {self.f}",
        ):
            print(f"{prefix}{text}")


@dataclasses.dataclass(frozen=True)
class Out_TypeCase:
    """Output instruction: dispatch on the type of ``obj``."""

    obj: AbsVal
    cases: pyrsistent.PMap[AbsVal, tuple[Out_Instr, ...]]

    def show(self, prefix, print):
        """Print the header, then every case with its instructions one level deeper."""
        print(f"{prefix}case typeof {self.obj!r}")
        inner_prefix = prefix + " "
        for case_type, instrs in self.cases.items():
            print(f"{prefix} {case_type!r} ->")
            print_out(instrs, inner_prefix, print)


@dataclasses.dataclass(frozen=True)
class Out_Label:
    """Output instruction: a jump target."""

    label: str

    def show(self, prefix, print):
        """Print the label; ``prefix`` is not applied to labels."""
        print("label {}:".format(self.label))


@dataclasses.dataclass(frozen=True)
class Out_Goto:
    """Output instruction: unconditional jump to ``label``."""

    label: str

    def show(self, prefix, print):
        """Print the jump."""
        print("{}goto {}".format(prefix, self.label))
@dataclasses.dataclass(frozen=True)
class Out_Return:
    """Output instruction: return ``value`` and decref the listed slots."""

    value: AbsVal
    decrefs: tuple[int, ...]

    def show(self, prefix, print):
        """Print the return and its decref list."""
        released = ",".join("D{}".format(slot) for slot in self.decrefs)
        for text in (
            f"return {self.value!r}",
            f"and decref [{released}]",
        ):
            print(f"{prefix}{text}")


@dataclasses.dataclass(frozen=True)
class Out_DecRef:
    """Output instruction: decref slot ``i``."""

    i: int

    def show(self, prefix, print):
        """Print the decref."""
        print("{}decref D{}".format(prefix, self.i))


@dataclasses.dataclass(frozen=True)
class Out_SetLineno:
    """Output instruction: source position marker."""

    line: int
    filename: str

    def show(self, prefix, print):
        """Print the marker as a comment line."""
        print("{}# line {} at {}".format(prefix, self.line, self.filename))


Out_Instr = Union[
    Out_DecRef,
    Out_Label,
    Out_TypeCase,
    Out_If,
    Out_Assign,
    Out_Return,
    Out_Goto,
]


CallRecord = "tuple[FunctionType, tuple[AbsVal, ...]]"


def print_out(xs: Iterable[Out_Instr], prefix, print):
    """Ask every instruction in ``xs`` to print itself with ``prefix``."""
    for instr in xs:
        instr.show(prefix, print)
@dataclasses.dataclass
class Out_Def:
    """A specialised function in output form."""

    spec: JITSpecInfo
    params: tuple[AbsVal, ...]
    instrs: tuple[Out_Instr, ...]
    start: str
    func: FunctionType

    GenerateCache = OrderedDict()  # type: dict[Intrinsic, Out_Def]

    @property
    def name(self) -> str:
        """``__name__`` of the originating Python function."""
        return self.func.__name__

    def show(self, print=print):
        """Print the signature line, every instruction, and a closing brace.

        The signature shows the possible return types joined by ``|``, the
        specialised function, the parameters, and the constant instance when
        the spec carries a truthy one.
        """
        spec = self.spec
        returns = "|".join(repr(t) for t in spec.possibly_return_types)
        callee = spec.abs_jit_func
        const = spec.instance
        opening = f"-> {const} {{" if const else "{"
        print(
            returns,
            f"{callee!r}(",
            ", ".join(repr(p) for p in self.params),
            ")",
            opening,
        )
        for instr in self.instrs:
            instr.show(" ", print)
        print("}")


@dataclasses.dataclass
class JITSpecInfo:
    """Cached result of specialising one call."""

    instance: Optional[AbsVal]  # maybe return a constant instance
    abs_jit_func: AbsVal
    possibly_return_types: tuple[AbsVal, ...]
class CallSpec:
    """Result of specialising a call.

    ``instance`` is the constant result when one is known, ``e_call`` the
    expression (or abstract value) to emit, and ``possibly_return_types`` the
    return types, always stored as a tuple.
    """

    instance: Optional[AbsVal]  # maybe return a constant instance
    e_call: Union[Out_Call, AbsVal]
    possibly_return_types: tuple[AbsVal, ...]

    def __init__(
        self,
        instance: Optional[AbsVal],
        e_call: Union[Out_Call, AbsVal],
        possibly_return_types: Iterable[AbsVal],
    ):
        self.instance = instance
        self.e_call = e_call
        # Accept any iterable but normalise to a tuple so that equality and
        # hashing do not depend on the container type the caller used.
        if not isinstance(possibly_return_types, tuple):
            possibly_return_types = tuple(possibly_return_types)
        self.possibly_return_types = possibly_return_types

    def __eq__(self, other):
        return (
            isinstance(other, CallSpec)
            and self.astuple() == other.astuple()
        )

    def __hash__(self):
        # Defining ``__eq__`` alone sets ``__hash__`` to None; hash the same
        # fields that equality compares.
        return hash(self.astuple())

    def astuple(self):
        """Return ``(instance, e_call, possibly_return_types)``."""
        return self.instance, self.e_call, self.possibly_return_types


# user function calls recorded here cannot be reanalyzed next time
RecTraces: set[tuple[str, tuple[AbsVal, ...]]] = set()

# specialisations map
# return types not inferred but compiled function name, and partially inferenced types
PreSpecMaps: dict[CallRecord, tuple[str, set[AbsVal, ...]]] = {}

# cache return types and function address
SpecMaps: dict[CallRecord, JITSpecInfo] = {}


def mk_prespec_name(
    key: CallRecord, partial_returns: set[AbsVal], name=""
):
    """Return the generated function name registered for ``key``.

    On first sight of ``key`` a name ``J_<name>_<n>`` is created (underscores
    in ``name`` are doubled, ``n`` is the current size of ``PreSpecMaps``)
    and stored with ``partial_returns``; later calls return the stored name.
    """
    v = PreSpecMaps.get(key)
    if v is None:
        i = len(PreSpecMaps)
        n = f"J_{name.replace('_', '__')}_{i}"
        PreSpecMaps[key] = n, partial_returns
        return n
    return v[0]


class MemSlot(NamedTuple):
    """One abstract memory slot."""

    # reference count
    rc: int
    # is locally allocated
    ila: bool

    def __repr__(self):
        # ``[rc]``, prefixed with ``!`` for locally allocated slots.
        x = f"[{self.rc}]"
        if self.ila:
            x = f"!{x}"
        return x


@dataclasses.dataclass(frozen=True, eq=True, order=True)
class Local:
    """Immutable abstract state: slot table ``mem`` plus slot-index -> value ``store``."""

    mem: pyrsistent.PVector[MemSlot]
    store: pyrsistent.PMap[int, AbsVal]

    def up_mem(self, mem):
        """Return a copy with ``mem`` replaced."""
        return Local(mem, self.store)

    def up_store(self, store):
        """Return a copy with ``store`` replaced."""
        return Local(self.mem, store)
def alloc(local: Local):
    """Return the index of the first slot with ``rc == 0``, or the slot count if none is free."""
    index = 0
    for slot in local.mem:
        if slot.rc == 0:
            return index
        index += 1
    return index


def valid_value(x: AbsVal, mk_exc=None):
    """Return ``x`` unless it is ``Top`` or ``Bot``.

    For those two values the exception built by ``mk_exc`` is raised, or
    ``TypeError`` when no (truthy) exception is supplied.
    """
    if x is not Bot and x is not Top:
        return x
    exc = mk_exc and mk_exc()
    raise exc or TypeError


def judge_lit(local: Local, a: AbsVal):
    """Resolve ``a`` in ``local``.

    Non-``D`` values are returned unchanged; a ``D`` is looked up by slot
    index in ``local.store``, giving ``Bot`` when nothing (truthy) is stored.
    """
    if not isinstance(a, D):
        return a
    return local.store.get(a.i) or Bot


def try_spec_val_then_decref(a, t: NonD) -> tuple[list, AbsVal]:
    """
    A dynamic abstract value might have a singleton type,
    in this case it's specialised into a static abstract value.
    e.g., a value whose type is 'type(None)' must be exactly 'None'.

    Returns ``(instructions, value)``: the instructions hold a decref of
    ``a``'s slot when ``a`` was replaced by the singleton instance and are
    empty otherwise.
    """
    if t is Bot:
        return [], Bot
    if not isinstance(a, D):
        return [], a
    # noinspection PyUnboundLocalVariable
    if (
        isinstance(t, S)
        and (shape := t.shape)
        and (inst := shape.instance)
    ):
        # The shape stores either the instance itself or a factory taking
        # the type parameters.
        # noinspection PyTypeHints
        if not isinstance(inst, S):
            inst = inst(t.params)
        return [Out_DecRef(a.i)], inst
    return [], D(a.i, t)


def blocks_to_instrs(blocks: dict[str, list[Out_Instr]], start: str):
    """Flatten ``blocks`` into one instruction list.

    Blocks with identical instruction sequences are emitted once, preceded
    by the labels of all of them.  Blocks are visited in sorted label order.
    """
    labels_by_body = defaultdict(OrderedDict)
    for label, block in sorted(blocks.items()):
        labels_by_body[tuple(block)][label] = None
    instrs = []
    for body, labels in labels_by_body.items():
        instrs.extend(Out_Label(label) for label in labels)
        instrs.extend(body)
    return instrs
= a.i 735 | if i < len(local.mem): 736 | mem = local.mem 737 | slot = mem[i] 738 | if refcnt := slot.rc - 1: 739 | slot = MemSlot(refcnt, slot.ila) 740 | else: 741 | if slot.ila: 742 | self << Out_DecRef(i) 743 | slot = MemSlot(0, True) 744 | mem = mem.set(i, slot) 745 | return local.up_mem(mem) 746 | return local 747 | 748 | 749 | def incref(local: Local, a: AbsVal): 750 | if not isinstance(a, D): 751 | return local 752 | i = a.i 753 | mem = local.mem 754 | try: 755 | ref = mem[i] 756 | slot = MemSlot(ref.rc + 1, ref.ila) 757 | mem = mem.set(i, slot) 758 | except IndexError: 759 | assert len(mem) == i 760 | mem = mem.append(MemSlot(1, True)) 761 | return local.up_mem(mem) 762 | 763 | 764 | class Judge: 765 | def __init__(self, blocks: In_Blocks, func: FunctionType, abs_glob): 766 | # States 767 | self.in_blocks = blocks 768 | self.block_map: dict[tuple[str, Local], str] = {} 769 | self.returns: set[AbsVal] = set() 770 | self.code: list[Out_Instr] = [] 771 | self.out_blocks: dict[str, list[Out_Instr]] = OrderedDict() 772 | self.abs_glob: dict[str, AbsVal] = abs_glob 773 | self.func = func 774 | self.label_cnt = 0 775 | 776 | @property 777 | def glob(self) -> dict: 778 | # noinspection PyUnresolvedReferences 779 | return self.func.__globals__ 780 | 781 | @contextmanager 782 | def use_code(self, code: list[Out_Instr]): 783 | old_code = self.code 784 | self.code = code 785 | try: 786 | yield 787 | finally: 788 | self.code = old_code 789 | 790 | def gen_label(self, kind=""): 791 | self.label_cnt += 1 792 | return f"{kind}_{self.label_cnt}" 793 | 794 | def __lshift__(self, a): 795 | if isinstance(a, list): 796 | self.code.extend(a) 797 | else: 798 | self.code.append(a) 799 | 800 | def jump(self, local: Local, label: str): 801 | key = (label, local) 802 | if gen_label := self.block_map.get(key): 803 | return gen_label 804 | gen_label = self.gen_label() 805 | self.block_map[key] = gen_label 806 | code = [] 807 | with self.use_code(code): 808 | self.stmt(local, 
self.in_blocks[label], 0) 809 | self.out_blocks[gen_label] = code 810 | return gen_label 811 | 812 | def stmt(self, local: Local, xs: Sequence[In_Stmt], index: int): 813 | try: 814 | while (hd := xs[index]) and isinstance(hd, In_SetLineno): 815 | self << Out_SetLineno(hd.line, hd.filename) 816 | index += 1 817 | except IndexError: 818 | # TODO 819 | raise Exception("non-terminaor terminate") 820 | if isinstance(hd, In_Move): 821 | a_x = judge_lit(local, hd.target) 822 | # print(hd, judge_lit(local, hd.source), local.store) 823 | a_y = judge_lit(local, hd.source) 824 | if a_y is Top or a_y is Bot: 825 | self.error(local) 826 | return 827 | 828 | local = decref(self, local, a_x) 829 | local = incref(local, a_y) 830 | local = local.up_store(local.store.set(hd.target.i, a_y)) 831 | self.stmt(local, xs, index + 1) 832 | elif isinstance(hd, In_Goto): 833 | label_gen = self.jump(local, hd.label) 834 | self << Out_Goto(label_gen) 835 | elif isinstance(hd, In_Return): 836 | a = valid_value(judge_lit(local, hd.value)) 837 | self.returns.add(a) 838 | self << Out_Return(a, tuple(self.all_ownerships(local))) 839 | elif isinstance(hd, In_Cond): 840 | a = judge_lit(local, hd.test) 841 | if a is Top or a is Bot: 842 | self.error(local) 843 | return 844 | if a.type == A_Bool: 845 | a_cond = a 846 | else: 847 | # extract 848 | instance, e_call, union_types = self.spec( 849 | A_Bool, "__call__", [a] 850 | ).astuple() 851 | if e_call is Top or e_call is Bot: 852 | self.error(local) 853 | return 854 | 855 | if not isinstance(e_call, (Out_Call, D)): 856 | a_cond = e_call 857 | else: 858 | j = alloc(local) 859 | a_cond = D(j, A_Bool) 860 | self << Out_Assign( 861 | a_cond, 862 | e_call, 863 | tuple(self.all_ownerships(local)), 864 | ) 865 | 866 | self << Out_DecRef(j) 867 | 868 | if instance: 869 | a_cond = instance 870 | 871 | if a_cond.is_literal() and isinstance(a_cond.base, bool): 872 | direct_label = hd.then if a_cond.base else hd.otherwise 873 | label_generated = self.jump(local, 
direct_label) 874 | self << Out_Goto(label_generated) 875 | return 876 | 877 | l1 = self.jump(local, hd.then) 878 | l2 = self.jump(local, hd.otherwise) 879 | self << Out_If(a_cond, l1, l2) 880 | elif isinstance(hd, In_Bind): 881 | a_x = judge_lit(local, hd.target) 882 | a_subj = judge_lit(local, hd.sub) 883 | if a_subj is Top or a_subj is Bot: 884 | self.error(local) 885 | return 886 | attr = judge_lit(local, hd.attr) 887 | assert attr.is_literal() and isinstance( 888 | attr.base, str 889 | ), f"attr {attr} shall be a string" 890 | attr = attr.base 891 | a_args = [] 892 | for a in hd.args: 893 | a_args.append(judge_lit(local, a)) 894 | if a_args[-1] in (Top, Bot): 895 | self.error(local) 896 | return 897 | instance, e_call, union_types = self.spec( 898 | a_subj, attr, a_args 899 | ).astuple() 900 | 901 | if e_call in (Top, Bot): 902 | self.error(local) 903 | return 904 | 905 | # 1. no actual CALL happens 906 | if not isinstance(e_call, (Out_Call, D)): 907 | local = decref(self, local, a_x) 908 | rhs = e_call 909 | if instance: 910 | rhs = instance 911 | local = local.up_store( 912 | local.store.set(hd.target.i, rhs) 913 | ) 914 | self.stmt(local, xs, index + 1) 915 | return 916 | 917 | j = alloc(local) 918 | # 2. CALL happens but not union-typed and the result might be a constant 919 | if len(union_types) == 1: 920 | a_t = union_types[0] 921 | a_spec = D(j, a_t) 922 | if isinstance(e_call, Out_Call): 923 | self << Out_Assign( 924 | a_spec, 925 | e_call, 926 | tuple(self.all_ownerships(local)), 927 | ) 928 | # TODO: documenting that 'D(i, t1)' means 'D(i, t2)' 929 | # at a given program counter. 
930 | local = decref(self, local, a_x) 931 | if a_t is Bot: 932 | self.error(local) 933 | return 934 | if instance: 935 | a_spec = instance 936 | code = [] 937 | else: 938 | a_t = a_spec.type 939 | code, a_spec = try_spec_val_then_decref( 940 | a_spec, a_t 941 | ) # handle instance 942 | valid_value(a_spec) 943 | if code: 944 | self << code[0] 945 | local = incref(local, a_spec) 946 | local = local.up_store( 947 | local.store.set(hd.target.i, a_spec) 948 | ) 949 | self.stmt(local, xs, index + 1) 950 | return 951 | # 3. CALL happens, union-typed; 952 | # result might be a constant for each type 953 | a_union = D(j, Top) 954 | self << Out_Assign( 955 | a_union, e_call, tuple(self.all_ownerships(local)) 956 | ) 957 | local = decref(self, local, a_x) 958 | # TODO: Top(if any) should be put in the last of 'union_types' 959 | split = [ 960 | try_spec_val_then_decref(a_union, t) 961 | for t in union_types 962 | ] 963 | cases: list[tuple[S, list[Out_Instr]]] = [] 964 | for (code, a_spec), a_t in zip(split, union_types): 965 | cases.append((a_t, code)) 966 | if a_t is Bot: 967 | self.error(local) 968 | continue 969 | 970 | valid_value(a_spec) 971 | local_i = incref(local, a_spec) 972 | local_i = local_i.up_store( 973 | local_i.store.set(hd.target.i, a_spec) 974 | ) 975 | with self.use_code(code): 976 | self.stmt(local_i, xs, index + 1) 977 | self << Out_TypeCase( 978 | a_union, 979 | pyrsistent.pmap( 980 | {case: tuple(code) for case, code in cases} 981 | ), 982 | ) 983 | 984 | def no_spec( 985 | self, a_sub: AbsVal, attr: str, a_args: list[AbsVal] 986 | ) -> CallSpec: 987 | assert isinstance(attr, str) 988 | a_sub = valid_value(a_sub) 989 | if attr == "__call__": 990 | return CallSpec( 991 | None, 992 | S(Intrinsic.Py_CallFunction)(a_sub, *a_args), 993 | (Top,), 994 | ) 995 | else: 996 | return CallSpec( 997 | None, 998 | S(Intrinsic.Py_CallMethod)(a_sub, S(attr), *a_args), 999 | (Top,), 1000 | ) 1001 | 1002 | def spec( 1003 | self, a_sub: AbsVal, attr: str, a_args: 
list[AbsVal] 1004 | ) -> CallSpec: 1005 | assert isinstance(attr, str) 1006 | a_sub = valid_value(a_sub) 1007 | if attr == "__call__": 1008 | 1009 | def default(): 1010 | return CallSpec( 1011 | None, 1012 | S(Intrinsic.Py_CallFunction)(a_sub, *a_args), 1013 | (Top,), 1014 | ) 1015 | 1016 | else: 1017 | 1018 | def default(): 1019 | return CallSpec( 1020 | None, 1021 | S(Intrinsic.Py_CallMethod)(a_sub, S(attr), *a_args), 1022 | (Top,), 1023 | ) 1024 | 1025 | a_t = a_sub.type 1026 | if a_t is Top: 1027 | return default() 1028 | if a_t is Bot: 1029 | raise TypeError 1030 | if isinstance(a_sub, S): 1031 | # python literal is not callable 1032 | if a_sub.is_literal(): 1033 | a_t = cast(S, a_t) 1034 | shape = a_t.shape 1035 | if not shape: 1036 | return default() 1037 | assert shape.oop 1038 | meth = judge_resolve(shape, attr) 1039 | if not meth: 1040 | return default() 1041 | # noinspection PyTypeHints 1042 | if isinstance(meth, AbsVal): 1043 | return self.spec(meth, "__call__", [a_sub, *a_args]) 1044 | r = meth(self, a_sub, *a_args) 1045 | if r is NotImplemented: 1046 | return default() 1047 | return r 1048 | shape = a_sub.shape 1049 | if shape and (meth_ := judge_resolve(shape, attr)): 1050 | # hack pycharm for type check 1051 | # noinspection PyUnboundLocalVariable 1052 | meth = meth_ 1053 | if shape.self_bound: 1054 | a_args = [a_sub, *a_args] 1055 | # noinspection PyTypeHints 1056 | if isinstance(meth, AbsVal): 1057 | return self.spec(meth, "__call__", a_args) 1058 | 1059 | r = meth(self, *a_args) 1060 | if r is NotImplemented: 1061 | return default() 1062 | return r 1063 | 1064 | a_t = cast(S, a_t) 1065 | shape = a_t.shape 1066 | if not shape: 1067 | return default() 1068 | meth = judge_resolve(shape, attr) 1069 | if not meth: 1070 | return default() 1071 | if shape.oop: 1072 | a_args = [a_sub, *a_args] 1073 | 1074 | # noinspection PyTypeHints 1075 | if isinstance(meth, AbsVal): 1076 | return self.spec(meth, "__call__", a_args) 1077 | 1078 | r = meth(self, 
*a_args) 1079 | if r is NotImplemented: 1080 | return default() 1081 | return r 1082 | 1083 | def all_ownerships(self, local: Local): 1084 | for i, each in enumerate(local.mem): 1085 | if each.rc and each.ila: 1086 | yield i 1087 | 1088 | def error(self, local): 1089 | raise ValueError(local) 1090 | 1091 | 1092 | def ufunc_spec(self, a_func: AbsVal, *arguments: AbsVal) -> CallSpec: 1093 | a_func = valid_value(a_func) 1094 | 1095 | def default(): 1096 | return CallSpec( 1097 | None, 1098 | S(Intrinsic.Py_CallFunction)(a_func, *arguments), 1099 | (Top,), 1100 | ) 1101 | 1102 | if isinstance(a_func, D): 1103 | return default() 1104 | # isinstance functiontype not handled in PyCharm 1105 | assert isinstance(a_func.base, FunctionType) 1106 | func = a_func.base 1107 | in_def = In_Def.UserCodeDyn.get(func) 1108 | if not in_def: 1109 | # not registered as jit func, skip 1110 | return default() 1111 | parameters = [] 1112 | mem = [] 1113 | store = {} 1114 | 1115 | for i, a_arg in enumerate(arguments): 1116 | if isinstance(a_arg, D): 1117 | j = len(mem) 1118 | mem.append(MemSlot(1, False)) 1119 | a_param = D(j, a_arg.type) 1120 | parameters.append(a_param) 1121 | 1122 | else: 1123 | a_param = a_arg 1124 | parameters.append(a_arg) 1125 | store[i] = a_param 1126 | 1127 | parameters = tuple(parameters) 1128 | call_record = a_func.base, parameters 1129 | if call_record in SpecMaps: 1130 | spec = SpecMaps[call_record] 1131 | e_call = spec.abs_jit_func(*arguments) 1132 | ret_types = spec.possibly_return_types 1133 | instance = spec.instance 1134 | elif partial_spec := PreSpecMaps.get(call_record): 1135 | jit_func_name, partial_returns = partial_spec 1136 | abs_jit_func = S(intrinsic(jit_func_name)) 1137 | e_call = abs_jit_func(*arguments) 1138 | partial_return_types = set( 1139 | each.type for each in partial_returns 1140 | ) 1141 | partial_return_types.add(Top) 1142 | ret_types = tuple(sorted(partial_return_types)) 1143 | instance = None 1144 | else: 1145 | partial_returns = 
set() 1146 | jit_func_name = mk_prespec_name( 1147 | call_record, partial_returns, name=in_def.name 1148 | ) 1149 | 1150 | abs_glob = {} 1151 | for glob_name in in_def.static_glob: 1152 | v = in_def.glob.get(glob_name, undef) 1153 | if v is undef: 1154 | v = getattr(builtins, glob_name, undef) 1155 | if v is undef: 1156 | continue 1157 | a_v = from_runtime(v) 1158 | if isinstance(v, D): 1159 | continue 1160 | abs_glob[glob_name] = a_v 1161 | 1162 | sub_judge = Judge( 1163 | blocks=in_def.blocks, func=in_def.func, abs_glob=abs_glob 1164 | ) 1165 | sub_judge.returns = partial_returns 1166 | local = Local(pyrsistent.pvector(mem), pyrsistent.pmap(store)) 1167 | gen_start = sub_judge.jump(local, "entry") 1168 | instrs = blocks_to_instrs(sub_judge.out_blocks, gen_start) 1169 | 1170 | ret_types: tuple[AbsVal, ...] = tuple( 1171 | sorted({r.type for r in sub_judge.returns}) 1172 | ) 1173 | instance, *is_union = sub_judge.returns 1174 | instance = valid_value(instance) 1175 | 1176 | if isinstance(instance, D): 1177 | instance = None 1178 | elif is_union: 1179 | instance = None 1180 | elif instance is None: 1181 | t = ret_types[0] 1182 | if t.is_s(): 1183 | instance = t.shape.instance 1184 | if isinstance(instance, FunctionType): 1185 | instance = instance(t.params) 1186 | 1187 | intrin = intrinsic(jit_func_name) 1188 | spec_info = JITSpecInfo(instance, S(intrin), ret_types) 1189 | SpecMaps[call_record] = spec_info 1190 | out_def = Out_Def( 1191 | spec_info, parameters, tuple(instrs), gen_start, in_def.func 1192 | ) 1193 | Out_Def.GenerateCache[intrin] = out_def 1194 | e_call = spec_info.abs_jit_func(*arguments) 1195 | 1196 | return CallSpec(instance, e_call, ret_types) 1197 | 1198 | 1199 | ShapeSystem[types.FunctionType] = Shape( 1200 | types.FunctionType, 1201 | oop=True, 1202 | fields={"__call__": cast(FunctionType, ufunc_spec)}, 1203 | ) 1204 | 1205 | 1206 | def judge_resolve(shape: Shape, attr: str): 1207 | if meth := shape.fields.get(attr): 1208 | return meth 1209 | if 
not isinstance(shape.name, type): 1210 | return None 1211 | for base in shape.name.__bases__: 1212 | if (shape := ShapeSystem.get(base)) and ( 1213 | meth := shape.fields.get(attr) 1214 | ): 1215 | return meth 1216 | 1217 | return None 1218 | -------------------------------------------------------------------------------- /diojit/absint/intrinsics.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import sys 3 | import operator 4 | from typing import Callable, Optional 5 | 6 | __all__ = ["intrinsic", "Intrinsic"] 7 | 8 | 9 | def setup(): 10 | pass 11 | 12 | 13 | def _mk(name, bases, ns): 14 | global setup 15 | t = type(name, bases, ns) 16 | 17 | def setup(): 18 | for n, v in ns["__annotations__"].items(): 19 | if n.startswith("_"): 20 | continue 21 | setattr(t, n, intrinsic(n)) 22 | 23 | return t 24 | 25 | 26 | class Intrinsic(metaclass=_mk): 27 | _callback: Optional[Callable] = None 28 | _name: str 29 | J: object 30 | 31 | @property 32 | def name(self): 33 | return self._name 34 | 35 | def __repr__(self): 36 | return f"{self._name}" 37 | 38 | def __eq__(self, other): 39 | return ( 40 | isinstance(other, Intrinsic) and self._name is other._name 41 | ) 42 | 43 | def __hash__(self): 44 | return id(self._name) 45 | 46 | Py_TYPE: Intrinsic 47 | 48 | # hints: implemented by 49 | # 1. PyObject_CallFunctionObjArgs() 50 | # 2. PyObject_CallNoArgs() 51 | # 3. PyObject_CallOneArg() 52 | Py_CallFunction: Intrinsic 53 | 54 | # 1. PyObject_CallMethodNoArgs() 55 | # 2. PyObject_CallMethodOneArg() 56 | # 3. 
PyObject_CallMethodObjArgs() 57 | Py_CallMethod: Intrinsic 58 | 59 | Py_LoadGlobal: Intrinsic 60 | 61 | Py_StoreGlobal: Intrinsic 62 | 63 | Py_StoreAttr: Intrinsic 64 | 65 | Py_LoadAttr: Intrinsic 66 | 67 | Py_BuildTuple: Intrinsic 68 | 69 | Py_BuildList: Intrinsic 70 | 71 | Py_Raise: Intrinsic 72 | 73 | Py_AddressCompare: Intrinsic 74 | Py_Not: Intrinsic 75 | 76 | # Py_Pow = operator.__pow__ 77 | # Py_Mul = operator.__mul__ 78 | # Py_Matmul = operator.__matmul__ 79 | # Py_Floordiv = operator.__floordiv__ 80 | # Py_Truediv = operator.__truediv__ 81 | # Py_Mod = operator.__mod__ 82 | # Py_Add = operator.__add__ 83 | # Py_Sub = operator.__sub__ 84 | # Py_Getitem = operator.__getitem__ 85 | # Py_Lshift = operator.__lshift__ 86 | # PY_Rshift = operator.__rshift__ 87 | # Py_And = operator.__and__ 88 | # Py_Xor = operator.__xor__ 89 | # Py_Or = operator.__or__ 90 | # 91 | # Py_Lt = operator.__lt__ 92 | # Py_Gt = operator.__gt__ 93 | # Py_Le = operator.__le__ 94 | # Py_Ge = operator.__ge__ 95 | # Py_Ne = operator.__ne__ 96 | # Py_Eq = operator.__eq__ 97 | 98 | 99 | _cache = {} 100 | 101 | 102 | def intrinsic(name: str) -> Intrinsic: 103 | name = sys.intern(name) 104 | if o := _cache.get(name): 105 | return o 106 | o = Intrinsic.__new__(Intrinsic) 107 | o._name = name 108 | _cache[name] = o 109 | return o 110 | 111 | 112 | setup() 113 | -------------------------------------------------------------------------------- /diojit/absint/prescr.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from .abs import * 3 | from .intrinsics import * 4 | from collections.abc import Iterable 5 | from functools import lru_cache 6 | import warnings 7 | import typing 8 | import operator 9 | import math 10 | import builtins 11 | import io 12 | 13 | if typing.TYPE_CHECKING: 14 | from mypy_extensions import VarArg 15 | 16 | _undef = object() 17 | __all__ = ["create_shape", "register"] 18 | 19 | _not_inferred_types = 
(Top, Bot) 20 | 21 | 22 | def lit(x: str): 23 | """directly pass this string to the backend""" 24 | return typing.cast(AbsVal, x) 25 | 26 | 27 | def u64i(i: int): 28 | """uint64 from integer""" 29 | return f"{i:#0{18}x}" 30 | 31 | 32 | concrete_numeric_types = { 33 | int: "PyInt_Compare", 34 | float: "PyObject_RichCompare", 35 | complex: "PyObject_RichCompare", 36 | bool: "PyInt_Compare", 37 | } 38 | 39 | 40 | def create_shape( 41 | o: object, oop: bool = False, self_bound=False, instance=_undef 42 | ): 43 | """ 44 | o: 'Special Object' in the paper. Must be immutable. 45 | oop: whether attached methods will be 46 | used as bound method in method resolution. 47 | instance: only works when 'o' is a 'type'. 48 | It is defined only the type has only 49 | one instance. 50 | """ 51 | try: 52 | hash(o) 53 | except TypeError: 54 | raise TypeError(f"create shape for non immuatble object {o}") 55 | if instance is _undef: 56 | instance = None 57 | else: 58 | instance = from_runtime(instance) 59 | assert not isinstance(instance, D) 60 | 61 | if shape := ShapeSystem.get(o): 62 | return shape 63 | shape = ShapeSystem[o] = Shape(o, oop, {}, instance, self_bound) 64 | return shape 65 | 66 | 67 | _create_shape = create_shape 68 | 69 | 70 | def register( 71 | o: object, 72 | attr="__call__", 73 | create_shape: typing.Union[dict, None, typing.Literal[True]] = None, 74 | ): 75 | 76 | if shape := ShapeSystem.get(o): 77 | pass 78 | else: 79 | if create_shape is not None: 80 | if create_shape is True: 81 | create_shape = {} 82 | shape = _create_shape(o, **create_shape) 83 | else: 84 | raise ValueError( 85 | f"No shape found for {{ o={o} }}.\n" 86 | f" Maybe use {{ create_shape(o, oop_or_not) }} firstly?" 87 | ) 88 | if attr in shape.fields: 89 | warnings.warn( 90 | Warning( 91 | f"field {attr} exists for the shape of object f{o}." 
92 | ) 93 | ) 94 | 95 | def ap(f: typing.Callable[[Judge, VarArg(AbsVal)], CallSpec]): 96 | shape.fields[attr] = f 97 | return f 98 | 99 | return ap 100 | 101 | 102 | def shape_of(o): 103 | return ShapeSystem[o] 104 | 105 | 106 | create_shape(Intrinsic, oop=True) 107 | create_shape(list, oop=True) 108 | create_shape(dict, oop=True) 109 | create_shape(bytes, oop=True) 110 | create_shape(bool, oop=True) 111 | create_shape(bytearray, oop=True) 112 | create_shape(int, oop=True) 113 | # create_shape(type, self_bound=True, oop=True) 114 | create_shape(next) 115 | create_shape(io.BytesIO) 116 | _NoneType_shape = create_shape(type(None)) 117 | _NoneType_shape.instance = S(None) 118 | 119 | 120 | @register(Intrinsic, "__call__") 121 | def py_call_intrinsic( 122 | self: Judge, f: AbsVal, *args: AbsVal 123 | ) -> CallSpec: 124 | return CallSpec(None, f(*args), (Top,)) 125 | 126 | 127 | @register( 128 | Intrinsic.Py_LoadGlobal, "__call__", create_shape=dict(oop=False) 129 | ) 130 | def py_load_global(self: Judge, a_str: AbsVal) -> CallSpec: 131 | """ 132 | TODO: S(self.glob) is UNSAFE here. 133 | Theoretically, 'S' shall not be used for 134 | a mutable data. This is a special case. 
135 | """ 136 | 137 | def slow_path(): 138 | instance = None 139 | func = S(intrinsic("PyDict_LoadGlobal")) 140 | e_call = func(S(self.func), S(builtins), a_str) 141 | ret_types = (Top,) 142 | return CallSpec(instance, e_call, ret_types) 143 | 144 | def constant_key_path(): 145 | instance = None 146 | hash_val = hash(a_str.base) 147 | func = S(intrinsic("PyDict_LoadGlobal_KnownHash")) 148 | e_call = func( 149 | S(self.func), S(builtins), a_str, lit(str(hash_val)) 150 | ) 151 | ret_types = (Top,) 152 | return CallSpec(instance, e_call, ret_types) 153 | 154 | if a_str.is_literal(): 155 | if isinstance(a_str.base, str): 156 | attr = a_str.base 157 | if attr in self.abs_glob: 158 | a = self.abs_glob[attr] 159 | return CallSpec(a, a, possibly_return_types=(a.type,)) 160 | 161 | return constant_key_path() 162 | return slow_path() 163 | 164 | 165 | @register(bool, attr="__call__") 166 | def py_call_bool_type(self: Judge, *args: AbsVal): 167 | if not args: 168 | # bool() = False 169 | constant_return = S(False) 170 | return CallSpec( 171 | constant_return, constant_return, (Values.A_Bool,) 172 | ) 173 | if len(args) != 1: 174 | # bool(a, b, c) = False 175 | return NotImplemented 176 | # bool(a) 177 | arg = args[0] 178 | if isinstance(arg.type, S) and issubclass(arg.type.base, bool): 179 | constant_return = isinstance(arg, S) and arg or None 180 | return CallSpec(constant_return, arg, (Values.A_Bool,)) 181 | return CallSpec( 182 | None, 183 | S(intrinsic("Py_CallBoolIfNecessary"))(arg), 184 | (Values.A_Bool,), 185 | ) 186 | 187 | 188 | @register(isinstance, create_shape=True) 189 | def spec_isinstance(self: Judge, *args: AbsVal): 190 | if len(args) != 2: 191 | return NotImplemented 192 | l, r = args 193 | return_types = tuple({Values.A_Bool}) 194 | if ( 195 | isinstance(l.type, S) 196 | and isinstance(r, S) 197 | and isinstance(r.base, type) 198 | ): 199 | const = l.type == r or l.type.base in r.base.__bases__ 200 | return CallSpec(S(const), S(const), return_types) 201 | 
202 | func = S(intrinsic("PyObject_IsInstance")) 203 | return CallSpec(None, func(l, r), return_types) 204 | 205 | 206 | @register(operator.__pow__, create_shape=True) 207 | def spec_pow(self: Judge, l: AbsVal, r: AbsVal): 208 | if l.type == Values.A_Int: 209 | if r.type == Values.A_Int: 210 | py_int_power_int = S(intrinsic("Py_IntPowInt")) 211 | return_types = tuple({Values.A_Int}) 212 | constant_result = None # no constant result 213 | 214 | return CallSpec( 215 | constant_result, py_int_power_int(l, r), return_types 216 | ) 217 | return NotImplemented 218 | 219 | 220 | @register(len, create_shape=True) 221 | def spec_len(self: Judge, *args: AbsVal): 222 | if len(args) != 1: 223 | return NotImplemented 224 | arg = args[0] 225 | 226 | func = S(intrinsic("PySequence_Length")) 227 | e_call = func(arg) 228 | return CallSpec(None, e_call, tuple({Values.A_Int})) 229 | 230 | 231 | @register(operator.__add__, create_shape=True) 232 | def spec_add(self: Judge, l: AbsVal, r: AbsVal): 233 | if l.type == Values.A_Int: 234 | if r.type == Values.A_Int: 235 | py_int_add_int = S(intrinsic("Py_IntAddInt")) 236 | return_types = tuple({Values.A_Int}) 237 | constant_result = None # no constant result 238 | return CallSpec( 239 | constant_result, py_int_add_int(l, r), return_types 240 | ) 241 | return NotImplemented 242 | 243 | 244 | @register(operator.__iadd__, create_shape=True) 245 | def spec_add(self: Judge, l: AbsVal, r: AbsVal): 246 | if l.type == Values.A_Int: 247 | if r.type == Values.A_Int: 248 | py_int_add_int = S(intrinsic("Py_IntAddInt")) 249 | return_types = tuple({Values.A_Int}) 250 | constant_result = None # no constant result 251 | return CallSpec( 252 | constant_result, py_int_add_int(l, r), return_types 253 | ) 254 | return NotImplemented 255 | 256 | 257 | def spec_cmp(op): 258 | def spec_op(self: Judge, *args: AbsVal): 259 | if len(args) != 2: 260 | return NotImplemented 261 | l, r = args 262 | func_name = "PyObject_RichCompare" 263 | l_t = l.type 264 | r_t = 
r.type 265 | if ( 266 | l_t.is_s() 267 | and r_t.is_s() 268 | and l_t.base in concrete_numeric_types 269 | and r_t.base in concrete_numeric_types 270 | ): 271 | if l_t.base == r_t.base: 272 | func_name = concrete_numeric_types[l_t.base] 273 | ret_types = tuple({Values.A_Bool}) 274 | else: 275 | ret_types = tuple({Top}) 276 | 277 | func = S(intrinsic(func_name)) 278 | return CallSpec(None, func(l, r, lit(op)), ret_types) 279 | 280 | return spec_op 281 | 282 | 283 | register(operator.__le__, create_shape=True)(spec_cmp("Py_LE")) 284 | register(operator.__lt__, create_shape=True)(spec_cmp("Py_LT")) 285 | register(operator.__ge__, create_shape=True)(spec_cmp("Py_GE")) 286 | register(operator.__gt__, create_shape=True)(spec_cmp("Py_GT")) 287 | register(operator.__eq__, create_shape=True)(spec_cmp("Py_EQ")) 288 | register(operator.__ne__, create_shape=True)(spec_cmp("Py_NE")) 289 | 290 | 291 | @register(math.sqrt, create_shape=True) 292 | def spec_sqrt(self: Judge, a: AbsVal): 293 | if a.type == Values.A_Int: 294 | int_sqrt = S(intrinsic("Py_IntSqrt")) 295 | return CallSpec(None, int_sqrt(a), tuple({Values.A_Float})) 296 | return NotImplemented 297 | 298 | 299 | @register(setattr, create_shape=True) 300 | def call_setattr(self: Judge, *args: AbsVal): 301 | if len(args) != 3: 302 | return NotImplemented 303 | 304 | func = S(intrinsic("PyObject_SetAttr")) 305 | e_call = func(*args) 306 | return CallSpec(S(None), e_call, (Values.A_NoneType,)) 307 | 308 | 309 | @register(getattr, create_shape=True) 310 | def call_getattr(self: Judge, *args: AbsVal): 311 | if len(args) != 2: 312 | return NotImplemented 313 | ret_types = (Top,) 314 | subject, attr = args 315 | # noinspection PyUnboundLocalVariable 316 | if ( 317 | subject.type.is_s() 318 | and (shape := subject.type.shape) 319 | and (__getattr__ := shape.fields.get("__getattr__")) 320 | ): 321 | # noinspection PyUnboundLocalVariable 322 | __getattr__ = __getattr__ 323 | if not shape.oop: 324 | args = [attr] 325 | if 
isinstance(__getattr__, S): 326 | r = self.spec(__getattr__, "__call__", *args) 327 | if r is not NotImplemented: 328 | return r 329 | else: 330 | r = __getattr__(self, *args) 331 | if r is not NotImplemented: 332 | return r 333 | 334 | func = S(intrinsic("PyObject_GetAttr")) 335 | e_call = func(subject, attr) 336 | instance = None 337 | return CallSpec(instance, e_call, ret_types) 338 | 339 | 340 | @register(operator.__getitem__, create_shape=True) 341 | def call_getitem(self: Judge, *args: AbsVal): 342 | if len(args) != 2: 343 | return NotImplemented 344 | subject, item = args 345 | ret_types = (Top,) 346 | sub_t = subject.type 347 | if sub_t not in (Top, Bot): 348 | if sub_t.shape and ( 349 | dispatched := sub_t.shape.fields.get("__getitem__") 350 | ): 351 | # noinspection PyUnboundLocalVariable 352 | if isinstance(dispatched, FunctionType): 353 | # noinspection PyUnboundLocalVariable 354 | return dispatched(self, *args) 355 | return self.spec(subject, "__call__", *args) 356 | 357 | func = S(intrinsic("PyObject_GetItem")) 358 | e_call = func(subject, item) 359 | instance = None 360 | return CallSpec(instance, e_call, ret_types) 361 | 362 | 363 | @register(list, attr="__getitem__") 364 | def call_list_getitem(self: Judge, *args): 365 | if len(args) != 2: 366 | return NotImplemented 367 | ret_types = (Top,) 368 | subject, item = args 369 | func = S(intrinsic("PyList_GetItem")) 370 | e_call = func(subject, item) 371 | instance = None 372 | return CallSpec(instance, e_call, ret_types) 373 | 374 | 375 | @register(dict, attr="__getitem__") 376 | def call_list_getitem(self: Judge, *args): 377 | if len(args) != 2: 378 | return NotImplemented 379 | ret_types = (Top,) 380 | subject, item = args 381 | func = S(intrinsic("PyDict_GetItemWithError")) 382 | e_call = func(subject, item) 383 | instance = None 384 | return CallSpec(instance, e_call, ret_types) 385 | 386 | 387 | @register(bytearray, attr="__getitem__") 388 | def call_bytearray_getitem(self: Judge, *args): 389 | if 
len(args) != 2: 390 | 391 | return NotImplemented 392 | ret_types = (Values.A_Int,) 393 | subject, item = args 394 | func = S(intrinsic("PyObject_GetItem")) 395 | e_call = func(subject, item) 396 | instance = None 397 | return CallSpec(instance, e_call, ret_types) 398 | 399 | 400 | @register(bytes, attr="join") 401 | def call_bytearray_join(self: Judge, *args: AbsVal): 402 | if not args != 2: 403 | return NotImplemented 404 | sep, iters = args 405 | if ( 406 | sep.type.is_s() 407 | and sep.type.base is bytes 408 | and iters.type.is_s() 409 | and issubclass(iters.type.base, Iterable) 410 | ): 411 | # https://sourcegraph.com/github.com/python/cpython@3.9/-/blob/Include/cpython/bytesobject.h#L36 412 | func = S(intrinsic("_PyBytes_Join")) 413 | e_call = func(sep, iters) 414 | return CallSpec(None, e_call, (S(bytes),)) 415 | 416 | spec = self.no_spec(sep, "join", [iters]) 417 | return CallSpec(None, spec.e_call, (S(bytes),)) 418 | 419 | 420 | @register(bytes, attr="__getitem__") 421 | def call_bytearray_getitem(self: Judge, *args: AbsVal): 422 | if len(args) != 2: 423 | return NotImplemented 424 | ret_types = (Values.A_Int,) 425 | subject, item = args 426 | # Sequence protocol is slower: 427 | # if item.type.is_s() and issubclass(item.type.base, int): 428 | # func = S(intrinsic("PySequence_GetItem")) 429 | # else: 430 | func = S(intrinsic("PyObject_GetItem")) 431 | e_call = func(subject, item) 432 | instance = None 433 | return CallSpec(instance, e_call, ret_types) 434 | 435 | 436 | @register(operator.__setitem__, create_shape=True) 437 | def call_getitem(self: Judge, *args: AbsVal): 438 | if len(args) != 3: 439 | # default python impl 440 | return NotImplemented 441 | subject, item, value = args 442 | func = S(intrinsic("PyObject_SetItem")) 443 | e_call = func(subject, item, value) 444 | instance = None 445 | ret_types = (Top,) 446 | return CallSpec(instance, e_call, ret_types) 447 | 448 | 449 | @register(list, attr="copy") 450 | def call_list_copy(self: Judge, *args: 
AbsVal): 451 | if len(args) != 1: 452 | return NotImplemented 453 | return CallSpec( 454 | None, 455 | S(Intrinsic.Py_CallMethod)(args[0], S("copy")), 456 | tuple({S(list)}), 457 | ) 458 | 459 | 460 | @register(list, attr="append") 461 | def list_append_analysis(self: Judge, *args: AbsVal): 462 | if len(args) != 2: 463 | # rollback to CPython's default code 464 | return NotImplemented 465 | lst, elt = args 466 | 467 | return CallSpec( 468 | instance=None, # return value is not static 469 | e_call=S(intrinsic("PyList_Append"))(lst, elt), 470 | possibly_return_types=tuple({S(type(None))}), 471 | ) 472 | 473 | 474 | @register(Intrinsic.Py_BuildList, create_shape=True) 475 | def call_build_list(self: Judge, *args: AbsVal): 476 | ret_types = tuple({S(list)}) 477 | func = S(intrinsic("PyList_Construct")) 478 | 479 | return CallSpec(None, func(*args), ret_types) 480 | 481 | 482 | @register(io.BytesIO) 483 | def call_bytes_io(self: Judge, *args: AbsVal): 484 | if len(args) != 1: 485 | return NotImplemented 486 | 487 | arg = args[0] 488 | func = S(Intrinsic.Py_CallFunction) 489 | abs_bytes_io = S(io.BytesIO) 490 | return CallSpec(None, func(abs_bytes_io, arg), (abs_bytes_io,)) 491 | 492 | 493 | next_type_maps = {io.BytesIO: bytes} 494 | 495 | 496 | @register(operator.is_, create_shape=True) 497 | def call_is(self: Judge, *args: AbsVal): 498 | if len(args) != 2: 499 | return NotImplemented 500 | ret_types = (Values.A_Bool,) 501 | l, r = args 502 | if l == r: 503 | return CallSpec(S(True), S(True), ret_types) 504 | if l.type.is_s() and r.type.is_s() and l.type.base != r.type.base: 505 | 506 | return CallSpec(S(False), S(False), ret_types) 507 | 508 | func = S(intrinsic("Py_AddressCompare")) 509 | return CallSpec(None, func(*args), ret_types) 510 | 511 | 512 | @register(operator.__not__, create_shape=True) 513 | def call_not_(self: Judge, *args: AbsVal): 514 | if len(args) != 1: 515 | return NotImplemented 516 | ret_types = (Values.A_Bool,) 517 | arg = args[0] 518 | if 
@register(int)
def call_int(self: Judge, t: AbsVal, *args: AbsVal):
    """
    Specialise calling ``int``.

    :param t: abstract value in the first positional slot (not used here).
    :param args: abstract call arguments.
    :return: with no arguments, the constant ``0``; with one argument whose
        type is already ``S(int)``, that argument unchanged; with any other
        single argument, a ``PyNumber_Long`` intrinsic call.
        ``NotImplemented`` for more than one argument. The possible return
        type is always ``S(int)``.
    """
    int_type = S(int)
    result_types = (int_type,)

    if not args:
        return CallSpec(S(0), S(0), result_types)
    if len(args) > 1:
        return NotImplemented

    (value,) = args
    if value.type == int_type:
        # Already an int: pass the operand through.
        return CallSpec(None, value, result_types)

    to_long = S(intrinsic("PyNumber_Long"))
    return CallSpec(None, to_long(value), result_types)
# 589 | # 590 | # @register(type) 591 | # def call_type(self: Judge, typ: AbsVal, *args: AbsVal): 592 | # if typ == Values.A_Type and args: 593 | # if len(args) == 1: 594 | # arg = args[0] 595 | # if arg.type.is_s(): 596 | # a_t = arg.type 597 | # return CallSpec(a_t, a_t, (Values.A_Type,)) 598 | # func = S(intrinsic("PyObject_Type")) 599 | # return CallSpec(None, func(arg), (Values.A_Type,)) 600 | # return NotImplemented 601 | # if typ.is_s(): 602 | # return self.spec( 603 | # S(mk_call_type_n(len(args))), "__call__", [typ, *args] 604 | # ) 605 | # return NotImplemented 606 | -------------------------------------------------------------------------------- /diojit/codegen/__init__.py: -------------------------------------------------------------------------------- 1 | from . import julia -------------------------------------------------------------------------------- /diojit/codegen/julia.py: -------------------------------------------------------------------------------- 1 | """ 2 | why julia? 3 | 1. low latency incremental compilation(LLVM) 4 | 2. easier interface and ABI treatment(LLVM C API is hardo) 5 | 3. can do some zero-specialisation in julia side. 6 | For instance, manual analysis for Python RC's ownership transfer 7 | is verbose, but with Julia specialisation it is easily 8 | made automatic. 9 | 4. 
def u64i(i: int):
    """
    uint64 literal text from an integer.

    The value is rendered as lowercase hexadecimal with a ``0x`` prefix,
    zero-padded to 18 characters in total (16 hex digits); wider values are
    not truncated.
    """
    return format(i, "#018x")
def get_py_interfaces(self):
    """
    Return the ``@DIO_MakePtrCFunc`` line for this definition.

    The line carries, in order, the number of parameters of ``out_def``,
    ``out_def.spec.abs_jit_func`` and ``out_def.name``, and ends with a
    newline.
    """
    definition = self.out_def
    return "@DIO_MakePtrCFunc {} {} {}\n".format(
        len(definition.params),
        definition.spec.abs_jit_func,
        definition.name,
    )
__tmp__)" 153 | self << f" if DIO_CastExc({f})" 154 | self << f" {var} === Py_NULL && return Py_NULL" 155 | self << r" end" 156 | self << r"elseif __tmp__ isa PyPtr" 157 | self << f" {var} = __tmp__" 158 | self << r"elseif __tmp__ isa Integer" 159 | self << f" {var} = DIO_WrapIntValue(__tmp__)" 160 | self << f" {var} === Py_NULL && return Py_NULL" 161 | self << r"else" 162 | self << f" {var} = DIO_NewNone()" 163 | self << r"end" 164 | return 165 | elif isinstance(instr, Out_Label): 166 | self @ f"@label {instr.label}" 167 | return 168 | elif isinstance(instr, Out_Goto): 169 | self << f"@goto {instr.label}" 170 | return 171 | elif isinstance(instr, Out_Return): 172 | val = self.val(instr.value) 173 | self << f"DIO_Return = {val}" 174 | mini_opt = False 175 | if isinstance(instr.value, D): 176 | hold = instr.value.i 177 | if hold in instr.decrefs: 178 | mini_opt = True 179 | for i in instr.decrefs: 180 | if i != hold: 181 | self << f"Py_DECREF({self.var_i(i)})" 182 | if not mini_opt: 183 | self << f"Py_INCREF(DIO_Return)" 184 | for i in instr.decrefs: 185 | self << f"Py_DECREF({self.var_i(i)})" 186 | self << f"@goto ret" 187 | return 188 | elif isinstance(instr, Out_If): 189 | val = self.val(instr.test) 190 | self << f"if {u64o(True)} === reinterpret(UInt64, {val})" 191 | self << f" @goto {instr.t}" 192 | self << "else" 193 | self << f" @goto {instr.f}" 194 | self << "end" 195 | return 196 | 197 | elif isinstance(instr, Out_TypeCase): 198 | val = self.val(instr.obj) 199 | self << f"__type__ = reinterpret(UInt64, Py_TYPE({val}))" 200 | cases = instr.cases 201 | has_any_type = Top in cases 202 | if has_any_type and len(cases) == 1: 203 | self.visit_many(cases[Top]) 204 | return 205 | ts = [] 206 | headers = chain(["if"], repeat("elseif")) 207 | for i, (typecase, block) in enumerate(instr.cases.items()): 208 | if typecase is Top: 209 | continue 210 | head = next(headers) 211 | ts.append(typecase) 212 | t = ts[-1].base 213 | self << f"# when type is {t}" 214 | self << 
class RichCallSubprocessError(subprocess.CalledProcessError):
    """
    ``CalledProcessError`` whose message for a non-zero exit status also
    includes the captured stderr.
    """

    def __str__(self):
        code = self.returncode
        if not (code and code < 0):
            # Ordinary failure: report the status together with stderr.
            return "Command '%s' returned non-zero exit status %d: %s" % (
                self.cmd,
                code,
                self.stderr,
            )

        # A negative return code is reported as death by signal.
        try:
            sig = signal.Signals(-code)
        except ValueError:
            return "Command '%s' died with unknown signal %d." % (
                self.cmd,
                -code,
            )
        return "Command '%s' died with %r." % (self.cmd, sig)
def as_py(res: ctypes.c_void_p):
    """
    Unbox the result of a JIT function into a Python object.

    This should be used on the return of a JIT func.
    No need to incref as it's already done by the JIT func.

    :param res: address returned by a ``c_void_p``-typed libjulia call.
        ctypes reports a NULL pointer as ``None`` rather than ``0``.
    :return: ``None`` for a null result, otherwise whatever
        ``jl_unbox_voidpointer`` yields for ``res``.
    :raises JuliaException: through ``check_jl_err`` when Julia reports an
        error after unboxing.
    """
    # `res == 0` alone never matches a NULL c_void_p result (ctypes hands
    # back None), which would let a null pointer reach
    # jl_unbox_voidpointer. Treat both None and 0 as null, and do so before
    # touching libjulia.
    if not res:
        return None
    libjl = get_libjulia()
    pyobj = libjl.jl_unbox_voidpointer(res)
    check_jl_err(libjl)
    return pyobj
_flags['ITERABLE_COROUTINE'] 12 | ASYNC_GENERATOR = _flags['ASYNC_GENERATOR'] -------------------------------------------------------------------------------- /diojit/stack2reg/opcodes.py: -------------------------------------------------------------------------------- 1 | import opcode 2 | UNKNOWN_INSTR = object() 3 | POP_TOP = opcode.opmap.get('POP_TOP', UNKNOWN_INSTR) 4 | ROT_TWO = opcode.opmap.get('ROT_TWO', UNKNOWN_INSTR) 5 | ROT_THREE = opcode.opmap.get('ROT_THREE', UNKNOWN_INSTR) 6 | DUP_TOP = opcode.opmap.get('DUP_TOP', UNKNOWN_INSTR) 7 | DUP_TOP_TWO = opcode.opmap.get('DUP_TOP_TWO', UNKNOWN_INSTR) 8 | ROT_FOUR = opcode.opmap.get('ROT_FOUR', UNKNOWN_INSTR) 9 | NOP = opcode.opmap.get('NOP', UNKNOWN_INSTR) 10 | UNARY_POSITIVE = opcode.opmap.get('UNARY_POSITIVE', UNKNOWN_INSTR) 11 | UNARY_NEGATIVE = opcode.opmap.get('UNARY_NEGATIVE', UNKNOWN_INSTR) 12 | UNARY_NOT = opcode.opmap.get('UNARY_NOT', UNKNOWN_INSTR) 13 | UNARY_INVERT = opcode.opmap.get('UNARY_INVERT', UNKNOWN_INSTR) 14 | BINARY_MATRIX_MULTIPLY = opcode.opmap.get('BINARY_MATRIX_MULTIPLY', UNKNOWN_INSTR) 15 | INPLACE_MATRIX_MULTIPLY = opcode.opmap.get('INPLACE_MATRIX_MULTIPLY', UNKNOWN_INSTR) 16 | BINARY_POWER = opcode.opmap.get('BINARY_POWER', UNKNOWN_INSTR) 17 | BINARY_MULTIPLY = opcode.opmap.get('BINARY_MULTIPLY', UNKNOWN_INSTR) 18 | BINARY_MODULO = opcode.opmap.get('BINARY_MODULO', UNKNOWN_INSTR) 19 | BINARY_ADD = opcode.opmap.get('BINARY_ADD', UNKNOWN_INSTR) 20 | BINARY_SUBTRACT = opcode.opmap.get('BINARY_SUBTRACT', UNKNOWN_INSTR) 21 | BINARY_SUBSCR = opcode.opmap.get('BINARY_SUBSCR', UNKNOWN_INSTR) 22 | BINARY_FLOOR_DIVIDE = opcode.opmap.get('BINARY_FLOOR_DIVIDE', UNKNOWN_INSTR) 23 | BINARY_TRUE_DIVIDE = opcode.opmap.get('BINARY_TRUE_DIVIDE', UNKNOWN_INSTR) 24 | INPLACE_FLOOR_DIVIDE = opcode.opmap.get('INPLACE_FLOOR_DIVIDE', UNKNOWN_INSTR) 25 | INPLACE_TRUE_DIVIDE = opcode.opmap.get('INPLACE_TRUE_DIVIDE', UNKNOWN_INSTR) 26 | RERAISE = opcode.opmap.get('RERAISE', UNKNOWN_INSTR) 27 | 
WITH_EXCEPT_START = opcode.opmap.get('WITH_EXCEPT_START', UNKNOWN_INSTR) 28 | GET_AITER = opcode.opmap.get('GET_AITER', UNKNOWN_INSTR) 29 | GET_ANEXT = opcode.opmap.get('GET_ANEXT', UNKNOWN_INSTR) 30 | BEFORE_ASYNC_WITH = opcode.opmap.get('BEFORE_ASYNC_WITH', UNKNOWN_INSTR) 31 | END_ASYNC_FOR = opcode.opmap.get('END_ASYNC_FOR', UNKNOWN_INSTR) 32 | INPLACE_ADD = opcode.opmap.get('INPLACE_ADD', UNKNOWN_INSTR) 33 | INPLACE_SUBTRACT = opcode.opmap.get('INPLACE_SUBTRACT', UNKNOWN_INSTR) 34 | INPLACE_MULTIPLY = opcode.opmap.get('INPLACE_MULTIPLY', UNKNOWN_INSTR) 35 | INPLACE_MODULO = opcode.opmap.get('INPLACE_MODULO', UNKNOWN_INSTR) 36 | STORE_SUBSCR = opcode.opmap.get('STORE_SUBSCR', UNKNOWN_INSTR) 37 | DELETE_SUBSCR = opcode.opmap.get('DELETE_SUBSCR', UNKNOWN_INSTR) 38 | BINARY_LSHIFT = opcode.opmap.get('BINARY_LSHIFT', UNKNOWN_INSTR) 39 | BINARY_RSHIFT = opcode.opmap.get('BINARY_RSHIFT', UNKNOWN_INSTR) 40 | BINARY_AND = opcode.opmap.get('BINARY_AND', UNKNOWN_INSTR) 41 | BINARY_XOR = opcode.opmap.get('BINARY_XOR', UNKNOWN_INSTR) 42 | BINARY_OR = opcode.opmap.get('BINARY_OR', UNKNOWN_INSTR) 43 | INPLACE_POWER = opcode.opmap.get('INPLACE_POWER', UNKNOWN_INSTR) 44 | GET_ITER = opcode.opmap.get('GET_ITER', UNKNOWN_INSTR) 45 | GET_YIELD_FROM_ITER = opcode.opmap.get('GET_YIELD_FROM_ITER', UNKNOWN_INSTR) 46 | PRINT_EXPR = opcode.opmap.get('PRINT_EXPR', UNKNOWN_INSTR) 47 | LOAD_BUILD_CLASS = opcode.opmap.get('LOAD_BUILD_CLASS', UNKNOWN_INSTR) 48 | YIELD_FROM = opcode.opmap.get('YIELD_FROM', UNKNOWN_INSTR) 49 | GET_AWAITABLE = opcode.opmap.get('GET_AWAITABLE', UNKNOWN_INSTR) 50 | LOAD_ASSERTION_ERROR = opcode.opmap.get('LOAD_ASSERTION_ERROR', UNKNOWN_INSTR) 51 | INPLACE_LSHIFT = opcode.opmap.get('INPLACE_LSHIFT', UNKNOWN_INSTR) 52 | INPLACE_RSHIFT = opcode.opmap.get('INPLACE_RSHIFT', UNKNOWN_INSTR) 53 | INPLACE_AND = opcode.opmap.get('INPLACE_AND', UNKNOWN_INSTR) 54 | INPLACE_XOR = opcode.opmap.get('INPLACE_XOR', UNKNOWN_INSTR) 55 | INPLACE_OR = opcode.opmap.get('INPLACE_OR', 
UNKNOWN_INSTR) 56 | LIST_TO_TUPLE = opcode.opmap.get('LIST_TO_TUPLE', UNKNOWN_INSTR) 57 | RETURN_VALUE = opcode.opmap.get('RETURN_VALUE', UNKNOWN_INSTR) 58 | IMPORT_STAR = opcode.opmap.get('IMPORT_STAR', UNKNOWN_INSTR) 59 | SETUP_ANNOTATIONS = opcode.opmap.get('SETUP_ANNOTATIONS', UNKNOWN_INSTR) 60 | YIELD_VALUE = opcode.opmap.get('YIELD_VALUE', UNKNOWN_INSTR) 61 | POP_BLOCK = opcode.opmap.get('POP_BLOCK', UNKNOWN_INSTR) 62 | POP_EXCEPT = opcode.opmap.get('POP_EXCEPT', UNKNOWN_INSTR) 63 | STORE_NAME = opcode.opmap.get('STORE_NAME', UNKNOWN_INSTR) 64 | DELETE_NAME = opcode.opmap.get('DELETE_NAME', UNKNOWN_INSTR) 65 | UNPACK_SEQUENCE = opcode.opmap.get('UNPACK_SEQUENCE', UNKNOWN_INSTR) 66 | FOR_ITER = opcode.opmap.get('FOR_ITER', UNKNOWN_INSTR) 67 | UNPACK_EX = opcode.opmap.get('UNPACK_EX', UNKNOWN_INSTR) 68 | STORE_ATTR = opcode.opmap.get('STORE_ATTR', UNKNOWN_INSTR) 69 | DELETE_ATTR = opcode.opmap.get('DELETE_ATTR', UNKNOWN_INSTR) 70 | STORE_GLOBAL = opcode.opmap.get('STORE_GLOBAL', UNKNOWN_INSTR) 71 | DELETE_GLOBAL = opcode.opmap.get('DELETE_GLOBAL', UNKNOWN_INSTR) 72 | LOAD_CONST = opcode.opmap.get('LOAD_CONST', UNKNOWN_INSTR) 73 | LOAD_NAME = opcode.opmap.get('LOAD_NAME', UNKNOWN_INSTR) 74 | BUILD_TUPLE = opcode.opmap.get('BUILD_TUPLE', UNKNOWN_INSTR) 75 | BUILD_LIST = opcode.opmap.get('BUILD_LIST', UNKNOWN_INSTR) 76 | BUILD_SET = opcode.opmap.get('BUILD_SET', UNKNOWN_INSTR) 77 | BUILD_MAP = opcode.opmap.get('BUILD_MAP', UNKNOWN_INSTR) 78 | LOAD_ATTR = opcode.opmap.get('LOAD_ATTR', UNKNOWN_INSTR) 79 | COMPARE_OP = opcode.opmap.get('COMPARE_OP', UNKNOWN_INSTR) 80 | IMPORT_NAME = opcode.opmap.get('IMPORT_NAME', UNKNOWN_INSTR) 81 | IMPORT_FROM = opcode.opmap.get('IMPORT_FROM', UNKNOWN_INSTR) 82 | JUMP_FORWARD = opcode.opmap.get('JUMP_FORWARD', UNKNOWN_INSTR) 83 | JUMP_IF_FALSE_OR_POP = opcode.opmap.get('JUMP_IF_FALSE_OR_POP', UNKNOWN_INSTR) 84 | JUMP_IF_TRUE_OR_POP = opcode.opmap.get('JUMP_IF_TRUE_OR_POP', UNKNOWN_INSTR) 85 | JUMP_ABSOLUTE = 
opcode.opmap.get('JUMP_ABSOLUTE', UNKNOWN_INSTR) 86 | POP_JUMP_IF_FALSE = opcode.opmap.get('POP_JUMP_IF_FALSE', UNKNOWN_INSTR) 87 | POP_JUMP_IF_TRUE = opcode.opmap.get('POP_JUMP_IF_TRUE', UNKNOWN_INSTR) 88 | LOAD_GLOBAL = opcode.opmap.get('LOAD_GLOBAL', UNKNOWN_INSTR) 89 | IS_OP = opcode.opmap.get('IS_OP', UNKNOWN_INSTR) 90 | CONTAINS_OP = opcode.opmap.get('CONTAINS_OP', UNKNOWN_INSTR) 91 | JUMP_IF_NOT_EXC_MATCH = opcode.opmap.get('JUMP_IF_NOT_EXC_MATCH', UNKNOWN_INSTR) 92 | SETUP_FINALLY = opcode.opmap.get('SETUP_FINALLY', UNKNOWN_INSTR) 93 | LOAD_FAST = opcode.opmap.get('LOAD_FAST', UNKNOWN_INSTR) 94 | STORE_FAST = opcode.opmap.get('STORE_FAST', UNKNOWN_INSTR) 95 | DELETE_FAST = opcode.opmap.get('DELETE_FAST', UNKNOWN_INSTR) 96 | RAISE_VARARGS = opcode.opmap.get('RAISE_VARARGS', UNKNOWN_INSTR) 97 | CALL_FUNCTION = opcode.opmap.get('CALL_FUNCTION', UNKNOWN_INSTR) 98 | MAKE_FUNCTION = opcode.opmap.get('MAKE_FUNCTION', UNKNOWN_INSTR) 99 | BUILD_SLICE = opcode.opmap.get('BUILD_SLICE', UNKNOWN_INSTR) 100 | LOAD_CLOSURE = opcode.opmap.get('LOAD_CLOSURE', UNKNOWN_INSTR) 101 | LOAD_DEREF = opcode.opmap.get('LOAD_DEREF', UNKNOWN_INSTR) 102 | STORE_DEREF = opcode.opmap.get('STORE_DEREF', UNKNOWN_INSTR) 103 | DELETE_DEREF = opcode.opmap.get('DELETE_DEREF', UNKNOWN_INSTR) 104 | CALL_FUNCTION_KW = opcode.opmap.get('CALL_FUNCTION_KW', UNKNOWN_INSTR) 105 | CALL_FUNCTION_EX = opcode.opmap.get('CALL_FUNCTION_EX', UNKNOWN_INSTR) 106 | SETUP_WITH = opcode.opmap.get('SETUP_WITH', UNKNOWN_INSTR) 107 | LIST_APPEND = opcode.opmap.get('LIST_APPEND', UNKNOWN_INSTR) 108 | SET_ADD = opcode.opmap.get('SET_ADD', UNKNOWN_INSTR) 109 | MAP_ADD = opcode.opmap.get('MAP_ADD', UNKNOWN_INSTR) 110 | LOAD_CLASSDEREF = opcode.opmap.get('LOAD_CLASSDEREF', UNKNOWN_INSTR) 111 | EXTENDED_ARG = opcode.opmap.get('EXTENDED_ARG', UNKNOWN_INSTR) 112 | SETUP_ASYNC_WITH = opcode.opmap.get('SETUP_ASYNC_WITH', UNKNOWN_INSTR) 113 | FORMAT_VALUE = opcode.opmap.get('FORMAT_VALUE', UNKNOWN_INSTR) 114 | 
BUILD_CONST_KEY_MAP = opcode.opmap.get('BUILD_CONST_KEY_MAP', UNKNOWN_INSTR) 115 | BUILD_STRING = opcode.opmap.get('BUILD_STRING', UNKNOWN_INSTR) 116 | LOAD_METHOD = opcode.opmap.get('LOAD_METHOD', UNKNOWN_INSTR) 117 | CALL_METHOD = opcode.opmap.get('CALL_METHOD', UNKNOWN_INSTR) 118 | LIST_EXTEND = opcode.opmap.get('LIST_EXTEND', UNKNOWN_INSTR) 119 | SET_UPDATE = opcode.opmap.get('SET_UPDATE', UNKNOWN_INSTR) 120 | DICT_MERGE = opcode.opmap.get('DICT_MERGE', UNKNOWN_INSTR) 121 | DICT_UPDATE = opcode.opmap.get('DICT_UPDATE', UNKNOWN_INSTR) 122 | 123 | -------------------------------------------------------------------------------- /diojit/stack2reg/translate.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import typing as _t 3 | import dis 4 | import types as _types 5 | import operator 6 | from contextlib import contextmanager 7 | from . import opcodes 8 | from . import cflags 9 | from diojit.absint import * 10 | 11 | __all__ = ["translate"] 12 | 13 | JUMP_NAMES: frozenset[int] = frozenset( 14 | { 15 | opcodes.POP_JUMP_IF_TRUE, 16 | opcodes.POP_JUMP_IF_FALSE, 17 | opcodes.JUMP_IF_TRUE_OR_POP, 18 | opcodes.JUMP_IF_FALSE_OR_POP, 19 | } 20 | ) 21 | 22 | 23 | BIN_OPS: dict[int, FunctionType] = { 24 | opcodes.BINARY_POWER: operator.__pow__, 25 | opcodes.BINARY_MULTIPLY: operator.__mul__, 26 | opcodes.BINARY_MATRIX_MULTIPLY: operator.__matmul__, 27 | opcodes.BINARY_FLOOR_DIVIDE: operator.__floordiv__, 28 | opcodes.BINARY_TRUE_DIVIDE: operator.__truediv__, 29 | opcodes.BINARY_MODULO: operator.__mod__, 30 | opcodes.BINARY_ADD: operator.__add__, 31 | opcodes.BINARY_SUBTRACT: operator.__sub__, 32 | opcodes.BINARY_SUBSCR: operator.__getitem__, 33 | opcodes.BINARY_LSHIFT: operator.__lshift__, 34 | opcodes.BINARY_RSHIFT: operator.__rshift__, 35 | opcodes.BINARY_AND: operator.__and__, 36 | opcodes.BINARY_XOR: operator.__xor__, 37 | opcodes.BINARY_OR: operator.__or__, 38 | } 39 | 40 | 41 | INP_BIN_OPS: 
def flags_check(flag, *patterns):
    """Return True when ``flag`` shares at least one set bit with any of ``patterns``."""
    for pattern in patterns:
        if flag & pattern:
            return True
    return False
def gen_label(self):
    """Advance ``label_cnt`` by one and return the new label, ``l`` followed by the counter."""
    fresh = self.label_cnt + 1
    self.label_cnt = fresh
    return "l" + str(fresh)
def get_nargs(self, n: int):
    """
    Pop ``n`` values with ``self.pop()`` and return them as a list in the
    reverse of pop order (the value popped last comes first).
    """
    popped = [self.pop() for _ in range(n)]
    return popped[::-1]
self.jump(self.find_offset(x.argval)) 243 | self.codegen(In_Goto(label)) 244 | return 245 | elif x.opcode is opcodes.JUMP_IF_TRUE_OR_POP: 246 | b_off_1 = self.find_offset(x.arg) 247 | b_off_2 = self.find_offset(self.next().offset) 248 | tos = self.peek(0) 249 | l1 = self.jump(b_off_1) 250 | self.pop() 251 | l2 = self.jump(b_off_2) 252 | self.codegen(In_Cond(tos, l1, l2)) 253 | return 254 | elif x.opcode is opcodes.JUMP_IF_FALSE_OR_POP: 255 | b_off_1 = self.find_offset(x.arg) 256 | b_off_2 = self.find_offset(self.next().offset) 257 | tos = self.peek(0) 258 | l1 = self.jump(b_off_1) 259 | self.pop() 260 | l2 = self.jump(b_off_2) 261 | self.codegen(In_Cond(tos, l2, l1)) 262 | return 263 | elif x.opcode is opcodes.POP_JUMP_IF_TRUE: 264 | b_off_1 = self.find_offset(x.arg) 265 | b_off_2 = self.find_offset(self.next().offset) 266 | tos = self.pop() 267 | l1 = self.jump(b_off_1) 268 | l2 = self.jump(b_off_2) 269 | self.codegen(In_Cond(tos, l1, l2)) 270 | return 271 | elif x.opcode is opcodes.POP_JUMP_IF_FALSE: 272 | b_off_1 = self.find_offset(x.arg) 273 | b_off_2 = self.find_offset(self.next().offset) 274 | tos = self.pop() 275 | l1 = self.jump(b_off_1) 276 | l2 = self.jump(b_off_2) 277 | self.codegen(In_Cond(tos, l2, l1)) 278 | return 279 | elif x.opcode is opcodes.LOAD_METHOD: 280 | self.push(S(x.argval)) 281 | elif x.opcode is opcodes.LOAD_ATTR: 282 | tos = self.pop() 283 | self.call(S(getattr), tos, S(x.argval)) 284 | elif x.opcode is opcodes.CALL_METHOD: 285 | args = self.get_nargs(x.argval) 286 | attr = self.pop() 287 | subj = self.pop() 288 | self.call_method(subj, attr, *args) 289 | elif x.opcode is opcodes.CALL_FUNCTION: 290 | args = self.get_nargs(x.argval) 291 | f = self.pop() 292 | self.call(f, *args) 293 | elif x.opcode is opcodes.ROT_TWO: 294 | self.push(self.peek(0)) 295 | self.codegen(In_Move(self.peek(1), self.peek(2))) 296 | subj = self.peek(2) 297 | self.codegen(In_Move(subj, self.pop())) 298 | elif x.opcode is opcodes.ROT_THREE: 299 | a1 = self.peek(0) 
300 | a2 = self.peek(1) 301 | a3 = self.peek(2) 302 | 303 | self.push(a1) 304 | b1 = self.peek(0) 305 | self.push(a2) 306 | b2 = self.peek(0) 307 | self.push(a3) 308 | b3 = self.peek(0) 309 | 310 | self.codegen(In_Move(a3, b1)) 311 | self.codegen(In_Move(a2, b3)) 312 | self.codegen(In_Move(a1, b2)) 313 | 314 | self.pop() 315 | self.pop() 316 | self.pop() 317 | elif x.opcode is opcodes.ROT_FOUR: 318 | a1 = self.peek(0) 319 | a2 = self.peek(1) 320 | a3 = self.peek(2) 321 | a4 = self.peek(3) 322 | 323 | self.push(a1) 324 | b1 = self.peek(0) 325 | self.push(a2) 326 | b2 = self.peek(0) 327 | self.push(a3) 328 | b3 = self.peek(0) 329 | self.push(a4) 330 | b4 = self.peek(0) 331 | 332 | self.codegen(In_Move(a4, b1)) 333 | self.codegen(In_Move(a3, b4)) 334 | self.codegen(In_Move(a2, b3)) 335 | self.codegen(In_Move(a1, b2)) 336 | 337 | self.pop() 338 | self.pop() 339 | self.pop() 340 | self.pop() 341 | elif x.opcode is opcodes.DUP_TOP: 342 | self.push(self.peek(0)) 343 | elif x.opcode is opcodes.DUP_TOP_TWO: 344 | a = self.peek(1) 345 | b = self.peek(0) 346 | self.push(a) 347 | self.push(b) 348 | elif x.opcode is opcodes.BINARY_SUBSCR: 349 | # | TOS1 | TOS | 350 | # TOS1[TOS] 351 | tos = self.pop() 352 | tos1 = self.pop() 353 | self.call(S(operator.__getitem__), tos1, tos) 354 | elif x.opname.startswith("BINARY_"): 355 | right, left = self.pop(), self.pop() 356 | self.call(S(BIN_OPS[x.opcode]), left, right) 357 | elif x.opname.startswith("INPLACE_"): 358 | right, left = self.pop(), self.pop() 359 | self.call(S(INP_BIN_OPS[x.opcode]), left, right) 360 | elif x.opcode is opcodes.BUILD_TUPLE: 361 | args = self.get_nargs(x.argval) 362 | self.call(S(Intrinsic.Py_BuildTuple), *args) 363 | elif x.opcode is opcodes.BUILD_LIST: 364 | args = self.get_nargs(x.argval) 365 | self.call( 366 | S(Intrinsic.Py_BuildList), 367 | *args, 368 | ) 369 | elif x.opcode is opcodes.RETURN_VALUE: 370 | tos = self.pop() 371 | self.codegen(In_Return(tos)) 372 | return 373 | # python 3.9 374 | elif 
x.opcode is opcodes.IS_OP: 375 | right, left = self.pop(), self.pop() 376 | if x.argval != 1: 377 | self.call(S(operator.is_), left, right) 378 | else: 379 | self.call(S(operator.is_), left, right) 380 | self.call(S(operator.__not__), self.pop()) 381 | elif x.opcode is opcodes.CONTAINS_OP: 382 | right, left = self.pop(), self.pop() 383 | if x.argval != 1: 384 | self.call(S(operator.__contains__), right, left) 385 | else: 386 | self.call(S(operator.__contains__), right, left) 387 | self.call(S(operator.__not__), self.pop()) 388 | elif x.opcode is opcodes.COMPARE_OP: 389 | cmp_name = dis.cmp_op[x.arg] 390 | right, left = self.pop(), self.pop() 391 | if cmp_name == "not_in": 392 | self.call(S(operator.__contains__), right, left) 393 | self.call(S(operator.__not__), self.pop()) 394 | elif cmp_name == "in": 395 | self.call(S(operator.__contains__), right, left) 396 | elif cmp_name == "exception match": 397 | raise NotImplemented 398 | elif cmp_name == "is": 399 | self.call(S(operator.is_), left, right) 400 | elif cmp_name == "is not": 401 | self.call(S(operator.is_), left, right) 402 | self.call(S(operator.__not__), self.pop()) 403 | else: 404 | self.call(S(CMP_OPS[cmp_name]), left, right) 405 | 406 | elif x.opcode is opcodes.DELETE_SUBSCR: 407 | tos = self.pop() 408 | tos1 = self.pop() 409 | self.call(S(operator.__delitem__), tos1, tos) 410 | self.pop() 411 | elif x.opcode is opcodes.STORE_SUBSCR: 412 | # | TOS2 | TOS1 | TOS | 413 | # TOS1[TOS] = TOS2 414 | tos = self.pop() 415 | tos1 = self.pop() 416 | tos2 = self.pop() 417 | self.call(S(operator.__setitem__), tos1, tos, tos2) 418 | self.pop() 419 | elif x.opcode is opcodes.BUILD_SLICE: 420 | # """ 421 | # Pushes a slice object on the stack. argc must be 2 or 3. If it is 2, slice(TOS1, TOS) is pushed; 422 | # if it is 3, slice(TOS2, TOS1, TOS) is pushed. See the slice() built-in function for more information. 
423 | # """ 424 | tos = self.pop() 425 | tos1 = self.pop() 426 | if x.argval == 2: 427 | self.call(S(slice), tos1, tos) 428 | elif x.argval == 3: 429 | tos2 = self.pop() 430 | self.call(S(slice), tos2, tos1, tos) 431 | else: 432 | raise ValueError 433 | 434 | elif x.opcode is opcodes.POP_TOP: 435 | self.pop() 436 | else: 437 | raise ValueError(x.opname) 438 | self.offset += 1 439 | 440 | def build_const_tuple(self, argval: tuple): 441 | for each in argval: 442 | if isinstance(each, tuple): 443 | self.build_const_tuple(each) 444 | else: 445 | self.push(each) 446 | 447 | def require_global(self, argval): 448 | self.glob_names.add(argval) 449 | 450 | 451 | def translate(f): 452 | pyc = PyC(f) 453 | return pyc.make() 454 | -------------------------------------------------------------------------------- /diojit/user/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thautwarm/diojit/87b738a2edc9242ac82b2c34744d173f8c5005c2/diojit/user/__init__.py -------------------------------------------------------------------------------- /diojit/user/client.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import typing 3 | import dataclasses 4 | import sys 5 | from diojit.absint.intrinsics import intrinsic 6 | from diojit.absint.prescr import create_shape, register 7 | 8 | from ..absint.abs import ( 9 | CallSpec, 10 | FunctionType, 11 | S, 12 | Top, 13 | AbsVal, 14 | ) 15 | from .. 
import absint 16 | from ..stack2reg.translate import translate 17 | from typing import Iterable 18 | 19 | __all__ = [ 20 | "eagerjit", 21 | "conservativejit", 22 | "jit", 23 | "eager_jitclass", 24 | "jitclass", "spec_call", "spec_call_ir", 25 | "oftype", 26 | "ofval", 27 | ] 28 | 29 | 30 | def _is_function(o): 31 | return isinstance(o, type(_is_function)) 32 | 33 | 34 | def jit( 35 | func: absint.FunctionType = None, 36 | fixed_references: Iterable[str] = None, 37 | ): 38 | 39 | if fixed_references is not None: 40 | fixed_references = set(fixed_references) 41 | if not func: 42 | return lambda func: _jit(func, fixed_references) 43 | return _jit(func, fixed_references) 44 | else: 45 | assert func 46 | return _jit(func, set()) 47 | 48 | 49 | def _jit(func: absint.FunctionType, glob: set[str]): 50 | 51 | code = func.__code__ 52 | blocks, glob_names = translate(func) 53 | static_global = glob_names & glob 54 | in_def = absint.In_Def( 55 | code.co_argcount, 56 | blocks, 57 | func, 58 | static_global, 59 | ) 60 | absint.In_Def.UserCodeDyn[func] = in_def 61 | return func 62 | 63 | 64 | @dataclasses.dataclass 65 | class Val: 66 | a: object 67 | 68 | 69 | def spec_call_ir( 70 | f: FunctionType, *args, attr="__call__", glob=None 71 | ): 72 | narg = f.__code__.co_argcount 73 | assert ( 74 | _is_function(f) and len(args) == narg 75 | ), f"Function {f} takes exactly {narg} arguments." 
76 | 77 | rt_map = [] 78 | a_args = [] 79 | for arg in args: 80 | if isinstance(arg, Val): 81 | a_args.append(absint.from_runtime(arg.a)) 82 | else: 83 | assert isinstance(arg, absint.AbsVal) 84 | a_args.append(absint.D(len(rt_map), arg)) 85 | a_f = absint.from_runtime(f, rt_map) 86 | j = absint.Judge({}, f, {} if glob is None else glob) 87 | return j.spec(a_f, attr, a_args) 88 | 89 | 90 | _code_gen = None 91 | 92 | 93 | def spec_call( 94 | f: absint.FunctionType, 95 | *args, 96 | attr="__call__", 97 | glob=None, 98 | print_jl=None, 99 | print_dio_ir=None, 100 | ): 101 | global _code_gen 102 | if not _code_gen: 103 | from diojit.runtime.julia_rt import code_gen 104 | 105 | _code_gen = code_gen 106 | 107 | narg = f.__code__.co_argcount 108 | assert ( 109 | _is_function(f) and len(args) == narg 110 | ), f"Function {f} takes exactly {narg} arguments." 111 | rt_map = [] 112 | a_args = [] 113 | for arg in args: 114 | if isinstance(arg, Val): 115 | a_args.append(absint.from_runtime(arg.a)) 116 | else: 117 | assert isinstance(arg, absint.AbsVal), arg 118 | a_args.append(absint.D(len(rt_map), arg)) 119 | a_f = absint.from_runtime(f, rt_map) 120 | j = absint.Judge({}, f, {} if glob is None else glob) 121 | spec = j.spec(a_f, attr, a_args) 122 | jit_f = spec.e_call.func.base 123 | assert isinstance(jit_f, absint.Intrinsic) 124 | if print_dio_ir: 125 | from diojit.runtime.julia_rt import GenerateCache 126 | 127 | for each in GenerateCache.values(): 128 | each.show(print_dio_ir) 129 | 130 | _code_gen(print_jl) 131 | return getattr(jit_f, "_callback") 132 | 133 | 134 | def oftype(t: object): 135 | """ 136 | create an abstract value from type object 137 | """ 138 | abs = absint.from_runtime(t) 139 | if not isinstance(abs, absint.D): 140 | return abs 141 | raise TypeError(f"{t} is not a type") 142 | 143 | 144 | def ofval(o: object): 145 | return Val(o) 146 | 147 | 148 | def eagerjit(func: typing.Union[FunctionType, type]): 149 | if isinstance(func, type): 150 | cls = func 151 | 
return eager_jitclass(cls) 152 | assert isinstance(func, FunctionType) 153 | fixed_references = func.__code__.co_names 154 | return _jit(func, set(fixed_references)) 155 | 156 | 157 | def conservativejit( 158 | func: typing.Union[FunctionType, type], fixed_references=() 159 | ): 160 | fixed_references = set(fixed_references) 161 | if isinstance(func, type): 162 | cls = func 163 | return jitclass(cls, fixed_references) 164 | assert isinstance(func, FunctionType) 165 | fixed_references = func.__code__.co_names 166 | return _jit(func, fixed_references) 167 | 168 | 169 | _cache = {} 170 | 171 | u_inst_type = type(typing.Union[int, float]) 172 | 173 | 174 | def process_annotations(anns: dict, glob: dict): 175 | if ret := _cache.get(id(anns)): 176 | return ret 177 | 178 | def each(v): 179 | if isinstance(v, str): 180 | v = eval(v, glob) 181 | if hasattr(typing, "GenericAlias") and isinstance( 182 | v, typing.GenericAlias 183 | ): 184 | v = v.__origin__ 185 | elif type(v) is u_inst_type and v.__origin__ is typing.Union: 186 | return tuple(each(a) for a in v.__args__) 187 | 188 | assert isinstance(v, type), v 189 | return S(v) 190 | 191 | ret = {k: each(v) for k, v in anns.items()} 192 | _cache[id(anns)] = ret 193 | return ret 194 | 195 | 196 | def eager_jitclass(cls: type): 197 | shape = create_shape(cls, oop=True) 198 | if annotations := getattr(cls, "__annotations__", None): 199 | 200 | def get_attr(self, *args: AbsVal): 201 | if len(args) != 2: 202 | return NotImplemented 203 | a_obj, a_attr = args 204 | if a_attr.is_literal() and a_attr.base in annotations: 205 | if ret_types := process_annotations( 206 | annotations, sys.modules[cls.__module__].__dict__ 207 | ).get(a_attr.base): 208 | if not isinstance(ret_types, tuple): 209 | assert isinstance(ret_types, AbsVal) 210 | ret_types = (ret_types,) 211 | func = S(intrinsic("PyObject_GetAttr")) 212 | return CallSpec( 213 | None, func(a_obj, a_attr), ret_types 214 | ) 215 | return NotImplemented 216 | 217 | register(cls, 
attr="__getattr__")(get_attr) 218 | 219 | for each, f in cls.__dict__.items(): 220 | if not each.startswith("__") and isinstance(f, FunctionType): 221 | eagerjit(f) 222 | shape.fields[each] = S(f) 223 | return cls 224 | 225 | 226 | def jitclass( 227 | cls: type, 228 | fixed_references: Iterable[str] = (), 229 | meth_jit_policy=conservativejit, 230 | jit_methods: typing.Union[all, Iterable[str]] = all, 231 | ): 232 | fixed_references = set(fixed_references) 233 | shape = create_shape(cls, oop=True) 234 | if annotations := getattr(cls, "__annotations__", None): 235 | 236 | def get_attr(self, *args: AbsVal): 237 | if len(args) != 2: 238 | return NotImplemented 239 | a_obj, a_attr = args 240 | if a_attr.is_literal() and a_attr.base in annotations: 241 | if ret_types := process_annotations( 242 | annotations, sys.modules[cls.__module__].__dict__ 243 | ).get(a_attr.base): 244 | if not isinstance(ret_types, tuple): 245 | assert isinstance(ret_types, AbsVal) 246 | ret_types = (ret_types,) 247 | func = S(intrinsic("PyObject_GetAttr")) 248 | return CallSpec( 249 | None, func(a_obj, a_attr), (*ret_types, Top) 250 | ) 251 | return NotImplemented 252 | 253 | register(cls, attr="__getattr__")(get_attr) 254 | 255 | for each, f in ( 256 | jit_methods is all and cls.__dict__.items() or jit_methods 257 | ): 258 | if not each.startswith("__") and isinstance(f, FunctionType): 259 | meth_jit_policy(f, fixed_references) 260 | shape.fields[each] = S(f) 261 | return cls 262 | -------------------------------------------------------------------------------- /genopname.py: -------------------------------------------------------------------------------- 1 | import opcode 2 | import os 3 | import dis 4 | 5 | dir = "diojit" 6 | with open(os.path.join(dir, "stack2reg", "opcodes.py"), "w") as f: 7 | f.write("import opcode\n") 8 | f.write("UNKNOWN_INSTR = object()\n") 9 | for each in opcode.opmap: 10 | f.write(f"{each} = opcode.opmap.get({each!r}, UNKNOWN_INSTR)\n") 11 | f.write("\n") 12 | 13 | 
with open(os.path.join(dir, "stack2reg", "cflags.py"), "w") as f: 14 | f.write("import dis\n") 15 | f.write("_flags = {v: k for k, v in dis.COMPILER_FLAG_NAMES.items()}\n") 16 | 17 | for _, n in dis.COMPILER_FLAG_NAMES.items(): 18 | f.write(f"{n} = _flags[{n!r}]\n") 19 | f.write("\n") 20 | -------------------------------------------------------------------------------- /prepub.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | fd -H -E .git ... -0 | xargs -0 dos2unix 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pyrsistent >= 0.17.0 2 | julia >= 0.5.0 -------------------------------------------------------------------------------- /runtests/doil.py: -------------------------------------------------------------------------------- 1 | import diojit 2 | 3 | 4 | @diojit.jit 5 | def test(a): 6 | t = (1, 2, 3) 7 | i = 0 8 | while i < a: 9 | t = (t[1], t[2], t[0]) 10 | i = i + t[0] 11 | return i 12 | 13 | 14 | @diojit.jit(fixed_references=["test"]) 15 | def f(x): 16 | return test(x) 17 | 18 | 19 | in_def = diojit.absint.In_Def.UserCodeDyn[test] 20 | # in_def.show() 21 | callspec = diojit.spec_call_ir(f, diojit.Val(500)) 22 | print("return types: ", *callspec.possibly_return_types) 23 | print("instance : ", callspec.instance) 24 | print("call expr : ", callspec.e_call) 25 | for each in reversed(diojit.absint.Out_Def.GenerateCache): 26 | each.show() 27 | 28 | 29 | -------------------------------------------------------------------------------- /runtests/load.py: -------------------------------------------------------------------------------- 1 | import diojit.runtime.julia_rt -------------------------------------------------------------------------------- /runtests/tutorial.py: -------------------------------------------------------------------------------- 1 | from math 
import sqrt 2 | import operator 3 | import timeit 4 | import builtins 5 | import diojit 6 | from inspect import getsource 7 | from diojit.runtime.julia_rt import check_jl_err 8 | from diojit.codegen.julia import splice 9 | 10 | GenerateCache = diojit.Out_Def.GenerateCache 11 | 12 | import diojit as jit 13 | import timeit 14 | from operator import add 15 | 16 | libjl = jit.runtime.julia_rt.get_libjulia() 17 | 18 | 19 | def jl_eval(s: str): 20 | libjl.jl_eval_string(s.encode()) 21 | check_jl_err(libjl) 22 | 23 | 24 | def fib(a): 25 | if a <= 2: 26 | return 1 27 | return fib(a - 1) + fib(a - 2) 28 | 29 | 30 | @jit.jit(fixed_references=["fib_fix"]) 31 | def fib_fix(a): 32 | if a <= 2: 33 | return 1 34 | return fib_fix(a + -1) + fib_fix(a + -2) 35 | 36 | 37 | jit_fib_fix_untyped = jit.spec_call(fib_fix, jit.Top) 38 | jit_fib_fix_typed = jit.spec_call( 39 | fib_fix, jit.oftype(int) 40 | ) 41 | # jl_eval(f"println(J_fib__fix_1({splice(20)}))") 42 | # check_jl_err(libjl) 43 | print("fib".center(70, "=")) 44 | print(getsource(fib)) 45 | print( 46 | "fib(15), jit_fib_fix_untyped(15), jit_fib_fix_typed(15) = ", 47 | (fib(15), jit_fib_fix_untyped(15), jit_fib_fix_typed(15)), 48 | ) 49 | print( 50 | "fib(py) bench time:", 51 | timeit.timeit("f(15)", globals=dict(f=fib), number=10000), 52 | ) 53 | print( 54 | "fib(jit+untyped) bench time:", 55 | timeit.timeit( 56 | "f(15)", globals=dict(f=jit_fib_fix_untyped), number=10000 57 | ), 58 | ) 59 | print( 60 | "fib(jit+inferred) bench time:", 61 | timeit.timeit( 62 | "f(15)", globals=dict(f=jit_fib_fix_typed), number=10000 63 | ), 64 | ) 65 | 66 | print("hypot".center(70, "=")) 67 | 68 | 69 | @diojit.jit(fixed_references=["sqrt", "str", "int", "isinstance"]) 70 | def hypot(x, y): 71 | if isinstance(x, str): 72 | x = int(x) 73 | 74 | if isinstance(y, str): 75 | y = int(y) 76 | 77 | return sqrt(x ** 2 + y ** 2) 78 | 79 | 80 | print(getsource(hypot)) 81 | 82 | 83 | # print("Direct Translation From Stack Instructions".center(70, "=")) 84 
| 85 | # diojit.absint.In_Def.UserCodeDyn[hypot].show() 86 | # print("After JITing".center(70, "=")) 87 | 88 | 89 | jit_func_name = repr( 90 | diojit.spec_call_ir( 91 | hypot, diojit.S(int), diojit.S(int) 92 | ).e_call.func 93 | ) 94 | 95 | 96 | hypot_spec = diojit.spec_call( 97 | hypot, 98 | diojit.oftype(int), 99 | diojit.oftype(int), 100 | # print_jl=print, 101 | # print_dio_ir=print, 102 | ) 103 | # # 104 | # libjl = diojit.runtime.julia_rt.get_libjulia() 105 | # libjl.jl_eval_string(f'using InteractiveUtils;@code_llvm {jit_func_name}(PyO.int, PyO.int)'.encode()) 106 | # diojit.runtime.julia_rt.check_jl_err(libjl) 107 | 108 | print("hypot(1, 2) (jit) = ", hypot_spec(1, 2)) 109 | print("hypot(1, 2) (pure py) = ", hypot(1, 2)) 110 | print( 111 | "hypot (pure py) bench time:", 112 | timeit.timeit("f(1, 2)", number=1000000, globals=dict(f=hypot)), 113 | ) 114 | print( 115 | "hypot (jit) bench time:", 116 | timeit.timeit( 117 | "f(1, 2)", number=1000000, globals=dict(f=hypot_spec) 118 | ), 119 | ) 120 | 121 | diojit.create_shape(list, oop=True) 122 | 123 | 124 | @diojit.register(list, attr="append") 125 | def list_append_analysis(self: diojit.Judge, *args: diojit.AbsVal): 126 | if len(args) != 2: 127 | # rollback to CPython's default code 128 | return NotImplemented 129 | lst, elt = args 130 | 131 | return diojit.CallSpec( 132 | instance=None, # return value is not static 133 | e_call=diojit.S(diojit.intrinsic("PyList_Append"))(lst, elt), 134 | possibly_return_types=tuple({diojit.S(type(None))}), 135 | ) 136 | 137 | 138 | @diojit.jit 139 | def append3(xs, x): 140 | xs.append(x) 141 | xs.append(x) 142 | xs.append(x) 143 | 144 | 145 | print("append3".center(70, "=")) 146 | print(getsource(append3)) 147 | 148 | # diojit.In_Def.UserCodeDyn[append3].show() 149 | jit_append3 = diojit.spec_call( 150 | append3, diojit.oftype(list), diojit.Top 151 | ) 152 | xs = [1] 153 | jit_append3(xs, 3) 154 | print("test jit func: [1] append 3 for 3 times =", xs) 155 | 156 | 157 | xs = 
[] 158 | print( 159 | "append3 (jit) bench time:", 160 | timeit.timeit( 161 | "f(xs, 1)", globals=dict(f=jit_append3, xs=xs), number=10000000 162 | ), 163 | ) 164 | xs = [] 165 | print( 166 | "append3 (pure py) bench time:", 167 | timeit.timeit( 168 | "f(xs, 1)", globals=dict(f=append3, xs=xs), number=10000000 169 | ), 170 | ) 171 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from datetime import datetime 3 | from pathlib import Path 4 | 5 | 6 | version = "0.2a" 7 | with Path("README.md").open() as readme: 8 | readme = readme.read() 9 | 10 | 11 | setup( 12 | name="diojit", 13 | version=version if isinstance(version, str) else str(version), 14 | keywords="Just-In-Time, JIT, compiler", # keywords of your project that separated by comma "," 15 | description="A general-purpose JIT for CPython.", # a concise introduction of your project 16 | long_description=readme, 17 | long_description_content_type="text/markdown", 18 | license="mit", 19 | python_requires=">=3.8.0", 20 | url="https://github.com/thautwarm/diojit", 21 | author="thautwarm", 22 | author_email="twshere@outlook.com", 23 | packages=find_packages(), 24 | entry_points={"console_scripts": []}, 25 | # above option specifies what commands to install, 26 | # e.g: entry_points={"console_scripts": ["yapypy=yapypy.cmd:compiler"]} 27 | install_requires=["pyrsistent", "julia"], # dependencies 28 | platforms="any", 29 | classifiers=[ 30 | "Programming Language :: Python :: 3.6", 31 | "Programming Language :: Python :: 3.7", 32 | "Programming Language :: Python :: Implementation :: CPython", 33 | ], 34 | zip_safe=False, 35 | ) 36 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | echo "=====py39=====" 2 | pip install 
diojit 3 | source activate base # py39 4 | python benchmarks/append3.py 5 | python benchmarks/brainfuck.py 6 | python benchmarks/dna_read.py 7 | python benchmarks/fib.py 8 | python benchmarks/hypot.py 9 | python benchmarks/selection_sort.py 10 | python benchmarks/trans.py 11 | pip uninstall diojit 12 | 13 | echo "=====py38=====" 14 | source activate py38 15 | pip install . 16 | python benchmarks/append3.py 17 | python benchmarks/brainfuck.py 18 | python benchmarks/dna_read.py 19 | python benchmarks/fib.py 20 | python benchmarks/hypot.py 21 | python benchmarks/selection_sort.py 22 | python benchmarks/trans.py 23 | pip uninstall diojit --------------------------------------------------------------------------------