├── .gitignore
├── ChangeLog.md
├── LICENSE
├── README.md
├── README.zh_CN.md
├── benchmarks
├── append3.py
├── brainfuck.py
├── const_fib.py
├── dna_read.py
├── fib.py
├── hypot.py
├── selection_sort.py
├── trans.py
└── union_type.py
├── diojit
├── __init__.py
├── absint
│ ├── __init__.py
│ ├── abs.py
│ ├── intrinsics.py
│ └── prescr.py
├── codegen
│ ├── __init__.py
│ └── julia.py
├── runtime
│ ├── __init__.py
│ └── julia_rt.py
├── stack2reg
│ ├── __init__.py
│ ├── cflags.py
│ ├── opcodes.py
│ └── translate.py
└── user
│ ├── __init__.py
│ └── client.py
├── genopname.py
├── prepub.sh
├── requirements.txt
├── runtests
├── doil.py
├── load.py
└── tutorial.py
├── setup.py
└── test.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | _test.py
2 | _test.ll
3 | .vscode/
4 | bin/
5 | obj/
6 | .ionide/
7 | node_modules/
8 | .fake
9 | .idea/
10 | .ci/
11 | typeshed/
12 |
13 | **__pycache__/
14 | dev_test/
15 | Digraph.gv*
16 | .spyproject/
17 | **.egg-info/
18 | dist/
19 | **~
20 | build/
21 |
22 |
23 | # Created by https://www.gitignore.io/api/macos,python
24 | # Edit at https://www.gitignore.io/?templates=macos,python
25 |
26 | ### macOS ###
27 | # General
28 | .DS_Store
29 | .AppleDouble
30 | .LSOverride
31 |
32 | # Icon must end with two \r
33 | Icon
34 |
35 | # Thumbnails
36 | ._*
37 |
38 | # Files that might appear in the root of a volume
39 | .DocumentRevisions-V100
40 | .fseventsd
41 | .Spotlight-V100
42 | .TemporaryItems
43 | .Trashes
44 | .VolumeIcon.icns
45 | .com.apple.timemachine.donotpresent
46 |
47 | # Directories potentially created on remote AFP share
48 | .AppleDB
49 | .AppleDesktop
50 | Network Trash Folder
51 | Temporary Items
52 | .apdisk
53 |
54 | ### Python ###
55 | # Byte-compiled / optimized / DLL files
56 | __pycache__/
57 | *.py[cod]
58 | *$py.class
59 |
60 | # C extensions
61 | *.so
62 |
63 | # Distribution / packaging
64 | .Python
65 | build/
66 | develop-eggs/
67 | dist/
68 | downloads/
69 | eggs/
70 | .eggs/
71 | lib/
72 | lib64/
73 | parts/
74 | sdist/
75 | var/
76 | wheels/
77 | pip-wheel-metadata/
78 | share/python-wheels/
79 | *.egg-info/
80 | .installed.cfg
81 | *.egg
82 | MANIFEST
83 |
84 | # PyInstaller
85 | # Usually these files are written by a python script from a template
86 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
87 | *.manifest
88 | *.spec
89 |
90 | # Installer logs
91 | pip-log.txt
92 | pip-delete-this-directory.txt
93 |
94 | # Unit test / coverage reports
95 | htmlcov/
96 | .tox/
97 | .nox/
98 | .coverage
99 | .coverage.*
100 | .cache
101 | nosetests.xml
102 | coverage.xml
103 | *.cover
104 | .hypothesis/
105 | .pytest_cache/
106 |
107 | # Translations
108 | *.mo
109 | *.pot
110 |
111 | # Django stuff:
112 | *.log
113 | local_settings.py
114 | db.sqlite3
115 | db.sqlite3-journal
116 |
117 | # Flask stuff:
118 | instance/
119 | .webassets-cache
120 |
121 | # Scrapy stuff:
122 | .scrapy
123 |
124 | # Sphinx documentation
125 | docs/_build/
126 |
127 | # PyBuilder
128 | target/
129 |
130 | # Jupyter Notebook
131 | .ipynb_checkpoints
132 |
133 | # IPython
134 | profile_default/
135 | ipython_config.py
136 |
137 | # pyenv
138 | .python-version
139 |
140 | # pipenv
141 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
142 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
143 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
144 | # install all needed dependencies.
145 | #Pipfile.lock
146 |
147 | # celery beat schedule file
148 | celerybeat-schedule
149 |
150 | # SageMath parsed files
151 | *.sage.py
152 |
153 | # Environments
154 | .env
155 | .venv
156 | env/
157 | venv/
158 | ENV/
159 | env.bak/
160 | venv.bak/
161 |
162 | # Spyder project settings
163 | .spyderproject
164 | .spyproject
165 |
166 | # Rope project settings
167 | .ropeproject
168 |
169 | # mkdocs documentation
170 | /site
171 |
172 | # mypy
173 | .mypy_cache/
174 | .dmypy.json
175 | dmypy.json
176 |
177 | # Pyre type checker
178 | .pyre/
179 |
180 | # End of https://www.gitignore.io/api/macos,python
181 |
182 | slide-examples/
--------------------------------------------------------------------------------
/ChangeLog.md:
--------------------------------------------------------------------------------
1 | ## 0.2a(2/8/2021)
2 |
3 | - Rename `jit_spec_call` to `spec_call`.
4 |
5 | ## 0.1.5(2/8/2021)
6 |
7 | - Add experimental features: `@eagerjit` and `@conservativejit`.
8 |
9 | The first one assumes field types according to annotations of fields, and tries to make all
10 | methods jit-able.
11 |
12 | The second one needs manually specifying jit-able methods, and does not totally believe users'
13 | annotations to fields: a runtime type check will be generated when accessing fields.
14 |
15 | ## 0.1.4.1(2/7/2021)
16 |
17 | - RC analysis.
18 |
19 | Previously, `def f(x): x = g(x)` could cause a segmentation fault due to unexpected deallocation.
20 |
21 | We have now added reference-counting analysis on the Python side, greatly reducing redundant RC
22 | operations at runtime.
23 |
24 |
25 | ## 0.1.2(2/5/2021)
26 |
27 | - The JIT compiler is now able to optimise selection sort using lists(40% speed up).
28 |
29 | **Experiments have shown that if we can have type-parameterised lists, we can get
30 | a performance gain of 600%.**
31 |
32 |
33 |
34 |
35 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Copyright (c) 2021, thautwarm
3 |
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without modification,
7 | are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice,
10 | this list of conditions and the following disclaimer.
11 | * Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER
19 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
27 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## DIO-JIT: General-purpose Python JIT
2 |
3 | [](https://github.com/thautwarm/diojit/blob/master/README.zh_CN.md) [](https://pypi.python.org/pypi/diojit/)
4 | [](https://pypi.python.org/pypi/diojit/)
5 |
6 | Important:
7 |
8 | 1. DIO-JIT now works for Python >= 3.8. We heavily rely on the `LOAD_METHOD` bytecode instruction.
9 | 2. DIO-JIT is not production-ready. A large number of specialisation rules are required to make DIO-JIT batteries-included.
10 | 3. This document is mainly provided for prospective developers. Users are not required to write any specialisation rules, which means that users need to learn nothing but `@jit.jit` and `jit.spec_call`.
11 |
12 | ### Benchmark
13 |
14 | | Item | PY38 | JIT PY38 | PY39 | JIT PY39 |
15 | | ------------------------------------------------------------------------------------------ | ------ | -------- | ------ | -------- |
16 | | [BF](https://github.com/thautwarm/diojit/blob/master/benchmarks/brainfuck.py) | 265.74 | 134.23 | 244.50 | 140.34 |
17 | | [append3](https://github.com/thautwarm/diojit/blob/master/benchmarks/append3.py) | 23.94 | 10.70 | 22.29 | 11.21 |
18 | | [DNA READ](https://github.com/thautwarm/diojit/blob/master/benchmarks/dna_read.py) | 16.96 | 14.82 | 15.03 | 14.38 |
19 | | [fib(15)](https://github.com/thautwarm/diojit/blob/master/benchmarks/fib.py) | 11.63 | 1.54 | 10.41 | 1.51 |
20 | | [hypot(str, str)](https://github.com/thautwarm/diojit/blob/master/benchmarks/hypot.py) | 6.19 | 3.87 | 6.53 | 4.29 |
21 | | [selectsort](https://github.com/thautwarm/diojit/blob/master/benchmarks/selection_sort.py) | 46.95 | 33.88 | 38.71 | 29.49 |
22 | | [trans](https://github.com/thautwarm/diojit/blob/master/benchmarks/trans.py) | 24.22 | 7.79 | 23.23 | 7.71 |
23 |
24 | The benchmark item "DNA READ" does not show a significant performance gain. This is because "DNA READ" heavily uses `bytearray` and `bytes`, whose specialised C-APIs
25 | are not exposed. In this case, although the JIT can infer the types, we have to fall back to CPython's default behaviour, or even worse: after all, the interpreter can access internal things, while we cannot.
26 |
27 | P.S:
28 | DIO-JIT can do very powerful partial evaluation, which is disabled by default, but you can
29 | leverage it in your domain specific tasks. Here is an example of achieving **500x** speed up against pure Python: [fibs.py](https://github.com/thautwarm/diojit/blob/master/benchmarks/const_fib.py)
30 |
31 | ## Install Instructions
32 |
33 | Step 1: Install Julia as an in-process native code compiler for DIO-JIT
34 |
35 |
36 | There are several options for you to install Julia:
37 |
38 | - [scoop](http://scoop.sh/) (Windows)
39 | - [julialang.org](https://julialang.org/downloads) (recommended for Windows users)
40 | - [jill.py](https://github.com/johnnychen94/jill.py):
41 |
42 | ```bash
43 | $ pip install jill && jill install 1.6 --upstream Official
44 | ```
45 |
46 | - [jill](https://github.com/abelsiqueira/jill) (Mac and Linux only!):
47 |
48 | ```bash
49 | $ bash -ci "$(curl -fsSL https://raw.githubusercontent.com/abelsiqueira/jill/master/jill.sh)"
50 | ```
51 |
52 |
53 |
54 |
55 | Step 2: Install DIO.jl in Julia
56 |
57 |
58 | Type `julia` and open the REPL, then
59 |
60 | ```julia
61 | julia>
62 | # press ]
63 | pkg> add https://github.com/thautwarm/DIO.jl
64 | # press backspace
65 | julia> using DIO # precompile
66 | ```
67 |
68 |
69 |
70 |
71 | Step 3: Install Python Package
72 |
73 |
74 | ```bash
75 | $ pip install git+https://github.com/thautwarm/diojit
76 | ```
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 | How to fetch the latest DIO-JIT? (if you have already installed DIO)
85 |
86 |
87 |
88 | ```bash
89 | $ pip install -U diojit
90 | $ julia -e "using Pkg; Pkg.update(string(:DIO));using DIO"
91 | ```
92 |
93 |
94 |
95 |
96 | Usage from Python side is quite similar to that from Numba.
97 |
98 | ```python
99 | import diojit
100 | from math import sqrt
101 | # eagerjit: assuming all global references are fixed
102 | @diojit.eagerjit
103 | def fib(a):
104 | if a <= 2:
105 | return 1
106 | return fib(a + -1) + fib(a + -2)
107 |
108 | jit_fib = diojit.spec_call(fib, diojit.oftype(int), diojit.oftype(int))
109 | jit_fib(15) # 600% faster than pure python
110 | ```
111 |
112 | It might look strange to you that we use `a + -1` and `a + -2` here.
113 |
114 | Clever observation! And that's the point!
115 |
116 | DIO-JIT relies on specialisation rules. We have written one for additions, more specifically, `operator.__add__`: [specialisation for `operator.__add__`](https://github.com/thautwarm/diojit/blob/175aab5f4cb65fee923b9f6cb97c256252fc49f5/diojit/absint/prescr.py#L226).
117 |
118 | However, due to limited bandwidth, rules for `operator.__sub__` are not implemented yet.
119 |
120 | (P.S: [why `operator.__add__`](https://github.com/thautwarm/diojit/blob/3ceb9513377234f476566f70792632ce08c13373/diojit/stack2reg/translate.py#L30).)
121 |
122 | Although specialisation is common in the scope of optimisation, unlike many other JIT attempts, DIO-JIT doesn't need to
123 | hard-code rules at compiler level. The DIO-JIT compiler implements the skeleton of abstract interpretation, but concrete
124 | rules for specialisation and other inferences can be added within Python itself in an extensible way!
125 |
126 | See an example below.
127 |
128 | ## Contribution Example: Add a specialisation rule for `list.append`
129 |
130 | 1. Python Side:
131 |
132 | ```python
133 | import diojit as jit
134 | import timeit
135 | jit.create_shape(list, oop=True)
136 | @jit.register(list, attr="append")
137 | def list_append_analysis(self: jit.Judge, *args: jit.AbsVal):
138 | if len(args) != 2:
139 | # rollback to CPython's default code
140 | return NotImplemented
141 | lst, elt = args
142 |
143 | return jit.CallSpec(
144 | instance=None, # return value is not static
145 | e_call=jit.S(jit.intrinsic("PyList_Append"))(lst, elt),
146 | possibly_return_types=tuple({jit.S(type(None))}),
147 | )
148 | ```
149 |
150 |
151 | `jit.intrinsic("PyList_Append")` mentioned in above code means the intrinsic provided by the Julia codegen backend.
152 | Usually it's calling a CPython C API, but sometimes may not.
153 |
154 | No matter if it is an existing CPython C API, we can implement intrinsics in Julia.
155 |
156 | - [import PyList_Append symbol](https://github.com/thautwarm/DIO.jl/blob/c3ec304645437da6bb02c9e5acb0c91e5e3800a8/src/symbols.jl#L53)
157 |
158 | - [generate PyList_Append calling convention](https://github.com/thautwarm/DIO.jl/blob/5fa79357798ff3eaee561d14d4f04a271213282c/src/dynamic.jl#L120):
159 |
160 | ```julia
161 | @autoapi PyList_Append(PyPtr, PyPtr)::Cint != Cint(-1) cast(_cint2none) nocastexc
162 | ```
163 |
164 | As a consequence, we automatically generate an intrinsic function for DIO-JIT. This intrinsic function
165 | is capable of handling CPython exception and reference counting.
166 |
167 | You can also do step 2) on the Python side. It might look more intuitive.
168 |
169 | ```python
170 | import diojit as jit
171 | from diojit.runtime.julia_rt import jl_eval
172 | jl_implemented_intrinsic = """
173 | function PyList_Append(lst::Ptr, elt::PyPtr)
174 | if ccall(PyAPI.PyList_Append, Cint, (PyPtr, PyPtr), lst, elt) == -1
175 | return Py_NULL
176 | end
177 | nothing # automatically maps to a Python None
178 | end
179 | DIO.DIO_ExceptCode(::typeof(PyList_Append)) != Py_NULL
180 | """
181 | jl_eval(jl_implemented_intrinsic)
182 | ```
183 |
184 | You immediately get a >**100%** speed-up:
185 |
186 | ```python
187 | @jit.jit
188 | def append3(xs, x):
189 | xs.append(x)
190 | xs.append(x)
191 | xs.append(x)
192 |
193 | jit_append3 = jit.spec_call(append3, jit.oftype(list), jit.Top) # 'Top' means 'Any'
194 | xs = [1]
195 | jit_append3(xs, 3)
196 |
197 | print("test jit_append3, [1] append 3 for 3 times:", xs)
198 | # test jit func, [1] append 3 for 3 times: [1, 3, 3, 3]
199 |
200 | xs = []
201 | %timeit append3(xs, 1)
202 | # 293 ns ± 26.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
203 |
204 | xs = []
205 | %timeit jit_append3(xs, 1)
206 | # 142 ns ± 14.9 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
207 | ```
208 |
209 | ## Why Julia?
210 |
211 | We don't want to maintain a C compiler, and calling `gcc` or others will introduce cross-process IO, which is slow.
212 | We prefer compiling JITed code with LLVM, and **Julia is quite a killer tool for this use case**.
213 |
214 | ## Current Limitations
215 |
216 | 1. Support for `*varargs` and `**kwargs` is not ready: we could support them immediately with a very small JIT performance gain, but considering backward compatibility we have decided not to do this yet.
217 |
218 | 2. Exception handling is not yet supported inside JIT functions.
219 |
220 | Why?
221 |
222 |
223 | We haven't implemented the translation from exception handling bytecode to untyped DIO IR (`jit.absint.abs.In_Stmt`).
224 |
225 |
226 |
227 |
228 | Will support?
229 |
230 |
231 | Yes.
232 |
233 | In fact, now a callsite in any JIT function can raise an exception. It will not be handled by JIT functions, instead, it is lifted up to the root call, which is a pure Python call.
234 |
235 | Exception handling will be supported when we have efforts on translating CPython bytecode about exception handling into untyped DIO IR (`jit.absint.abs.In_Stmt`).
236 |
237 | P.S: This will be finished simultaneously with the support for `for` loop.
238 |
239 |
240 |
241 |
242 | 3. Support for `for` loop is missing.
243 |
244 | Why?
245 |
246 |
247 | Firstly, in CPython, `for` loop relies on exception handling, which is not supported yet.
248 |
249 | Secondly, we're considering a fast path for `for` loop, maybe proposing a `__citer__` protocol for faster iteration for JIT functions, which requires communications with Python developers.
250 |
251 |
252 |
253 |
254 | Will support?
255 |
256 |
257 | Yes.
258 |
259 | This will be finished simultaneously with support for exception handling (faster `for` loop might come later).
260 |
261 |
262 |
263 |
264 | 4. Closure support is missing.
265 |
266 | Why?
267 |
268 |
269 | In imperative languages, closures use *cell* structures to achieve mutable free/cell variables.
270 |
271 | However, a writable cell makes it hard to optimise in a dynamic language.
272 |
273 | We recommend using `types.MethodType` to create immutable closures, which can be highly optimised in DIO-JIT (near future).
274 |
275 | ```python
276 | import types
277 | def f(freevars, z):
278 | x, y = freevars
279 | return x + y + z
280 |
281 | def hof(x, y):
282 | return types.MethodType(f, (x, y))
283 | ```
284 |
285 |
286 |
287 |
288 | Will support?
289 |
290 |
291 | Still yes. However, don't expect much about the performance gain for Python's vanilla closures.
292 |
293 |
294 |
295 |
296 | 5. Is specifying fixed global references (`@diojit.jit(fixed_references=['isinstance', 'str', ...])`) too annoying?
297 |
298 | Sorry, you have to. We are thinking about the possibility of automatic JIT covering all existing CPython code, but the biggest impediment is the volatile global variables.
299 |
300 | You might use `@eagerjit`, and in this case you should be careful to keep global variables unchanged.
301 |
302 | Possibility?
303 |
304 |
305 | Recently we found that CPython's newly(`:)`) added feature `Dict.ma_version_tag` might be used to automatically notify JITed functions to re-compile when the global references change.
306 |
307 | More research is required.
308 |
309 |
310 |
311 |
312 | ## Contributions
313 |
314 | 1. Add more prescribed specialisation rules at `jit.absint.prescr`.
315 | 2. TODO
316 |
317 | ## Benchmarks
318 |
319 | Check `benchmarks` directory.
320 |
--------------------------------------------------------------------------------
/README.zh_CN.md:
--------------------------------------------------------------------------------
1 | ## DIO-JIT: Python的泛用jit
2 |
3 | [](https://github.com/thautwarm/diojit/blob/master/README.zh_CN.md) [](https://pypi.python.org/pypi/diojit/)
4 | [](https://pypi.python.org/pypi/diojit/)
5 |
6 | DIO-JIT是一种 method JIT, 在抽象解释和调用点特化下成为可能。抽象解释由编译器实现,而特化规则可以扩展式地注册(例见`jit.absint.prescr`)。
7 |
8 | Important:
9 |
10 | 1. 注意, DIO-JIT目前只在Python>=3.8时工作。我们高度依赖Python 3.8之后的`LOAD_METHOD`字节码指令。
11 | 2. 在多数情况下来看,目前DIO-JIT不适合生产环境。我们还需要提供更多的特化规则,来让DIO-JIT变得开箱即用。
12 | 3. 这个文档主要是为开发者提供的。用户不需要了解如何写特化规则,只需要使用`jit.jit(func_obj)`和`jit.spec_call(func_obj, arg_specs...)`。
13 |
14 | ## Benchmark
15 |
16 | | Item | PY38 | JIT PY38 | PY39 | JIT PY39 |
17 | | ------------------------------------------------------------------------------------------ | ------ | -------- | ------ | -------- |
18 | | [BF](https://github.com/thautwarm/diojit/blob/master/benchmarks/brainfuck.py) | 265.74 | 134.23 | 244.50 | 140.34 |
19 | | [append3](https://github.com/thautwarm/diojit/blob/master/benchmarks/append3.py) | 23.94 | 10.70 | 22.29 | 11.21 |
20 | | [DNA READ](https://github.com/thautwarm/diojit/blob/master/benchmarks/dna_read.py) | 16.96 | 14.82 | 15.03 | 14.38 |
21 | | [fib(15)](https://github.com/thautwarm/diojit/blob/master/benchmarks/fib.py) | 11.63 | 1.54 | 10.41 | 1.51 |
22 | | [hypot(str, str)](https://github.com/thautwarm/diojit/blob/master/benchmarks/hypot.py) | 6.19 | 3.87 | 6.53 | 4.29 |
23 | | [selectsort](https://github.com/thautwarm/diojit/blob/master/benchmarks/selection_sort.py) | 46.95 | 33.88 | 38.71 | 29.49 |
24 | | [trans](https://github.com/thautwarm/diojit/blob/master/benchmarks/trans.py) | 24.22 | 7.79 | 23.23 | 7.71 |
25 |
26 | The benchmark item "DNA READ" does not show a significant performance gain. This is because "DNA READ" heavily uses `bytearray` and `bytes`, whose specialised C-APIs
27 | are not exposed. In this case, although the JIT can infer the types, we have to fall back to CPython's default behaviour, or even worse: after all, the interpreter can access internal things, while we cannot.
28 |
29 | P.S:
30 | DIO-JIT可以做聪明的部分求值, 但为了编译器的快速收敛,online常量折叠默认是关闭的。
31 | 你可以在领域特定任务中使用这个能力。 这里有一个对cpython提速 **500倍**的例子: [fibs.py](https://github.com/thautwarm/diojit/blob/master/benchmarks/const_fib.py)
32 |
33 | ## 安装
34 |
35 | 1: 安装Julia(我们的"底层代码编译服务"提供者)
36 |
37 |
38 | 推荐以如下方式安装Julia:
39 |
40 | - [scoop](http://scoop.sh/) (Windows)
41 | - [julialang.org](https://cn.julialang.org/downloads/) (Windows)
42 | - [jill.py](https://github.com/johnnychen94/jill.py) (跨平台,但安装路径不符合Windows上Unix用户习惯):
43 |
44 | ```bash
45 | $ pip install jill && jill install 1.6
46 | ```
47 |
48 | - [jill](https://github.com/abelsiqueira/jill) (Mac and Linux):
49 |
50 | ```bash
51 | $ bash -ci "$(curl -fsSL https://raw.githubusercontent.com/abelsiqueira/jill/master/jill.sh)"
52 | ```
53 |
54 |
55 |
56 |
57 | 2: 在Julia中安装 DIO.jl
58 |
59 |
60 | 输入 `julia` 打开REPL
61 |
62 | ```julia
63 | julia>
64 | # 按下 ]
65 | pkg> add https://github.com/thautwarm/DIO.jl
66 | # 按下 backspace 键
67 | julia> using DIO # 预编译
68 | ```
69 |
70 |
71 |
72 |
73 | 3: 安装Python
74 |
75 |
76 | ```bash
77 | $ pip install git+https://github.com/thautwarm/diojit
78 | ```
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 | 如何获取最新的DIO-JIT?(需安装过DIO-JIT)
87 |
88 |
89 | ```bash
90 | $ pip install -U diojit
91 | $ julia -e "using Pkg; Pkg.update(string(:DIO));using DIO"
92 | ```
93 |
94 |
95 |
96 |
97 | 从Python端使用DIO-JIT和使用Numba类似:
98 |
99 | ```python
100 | import diojit
101 | from math import sqrt
102 | # eagerjit: 假设所有全局引用不变
103 | @diojit.eagerjit
104 | def fib(a):
105 | if a <= 2:
106 | return 1
107 | return fib(a + -1) + fib(a + -2)
108 |
109 | jit_fib = diojit.spec_call(fib, diojit.oftype(int), diojit.oftype(int))
110 | jit_fib(15) # 比原生Python快600%以上
111 | ```
112 |
113 | 你可能会问, 为什么上面的代码要使用`a + -1`和`a + -2`这么迷惑的写法?
114 |
115 | 你get了重点!
116 |
117 | 我们的jit依赖于已有的特化规则。我们已经为加法,具体的说是`operator.__add__`实现了特化规则: [`operator.__add__`的特化规则](https://github.com/thautwarm/diojit/blob/175aab5f4cb65fee923b9f6cb97c256252fc49f5/diojit/absint/prescr.py#L226).
118 |
119 | (P.S: [为啥是 `operator.__add__`](https://github.com/thautwarm/diojit/blob/3ceb9513377234f476566f70792632ce08c13373/diojit/stack2reg/translate.py#L30).)
120 |
121 | 但因为个人精力有限,目前还没有为`operator.__sub__`实现对应的规则。
122 |
123 | 虽然特化是非常常见的优化技术,但与很多的Python JIT不同的是,DIO-JIT并不需要在编译器层面内建特别的优化。DIO-JIT编译器只负责实现一个抽象解释的算法,
124 | 而更具体的推导、特化规则,在Python里就可以扩展式地添加!
125 |
126 | 下面是一个例子。
127 |
128 | ## 代码贡献案例: 为`list.append`注册特化规则
129 |
130 | 步骤1:Python端如下代码
131 |
132 | ```python
133 | import diojit as jit
134 | import timeit
135 | jit.create_shape(list, oop=True)
136 | @jit.register(list, attr="append")
137 | def list_append_analysis(self: jit.Judge, *args: jit.AbsVal):
138 | if len(args) != 2:
139 | # rollback to CPython's default code
140 | return NotImplemented
141 | lst, elt = args
142 |
143 | return jit.CallSpec(
144 | instance=None, # return value is not static
145 | e_call=jit.S(jit.intrinsic("PyList_Append"))(lst, elt),
146 | possibly_return_types=tuple({jit.S(type(None))}),
147 | )
148 | ```
149 |
150 | 上面的`jit.intrinsic("PyList_Append")`指的是JIT后端提供的底层原语,它通常是调用CPython的C API。
151 |
152 | 我们可以在Julia里面实现这些底层原语。
153 |
154 | 步骤2: Julia端
155 |
156 | - [导入PyList_Append符号](https://github.com/thautwarm/DIO.jl/blob/c3ec304645437da6bb02c9e5acb0c91e5e3800a8/src/symbols.jl#L53)
157 |
158 | - [生成PyList_Append的调用约定](https://github.com/thautwarm/DIO.jl/blob/5fa79357798ff3eaee561d14d4f04a271213282c/src/dynamic.jl#L120):
159 |
160 | ```julia
161 | @autoapi PyList_Append(PyPtr, PyPtr)::Cint != Cint(-1) cast(_cint2none) nocastexc
162 | ```
163 |
164 | 这样一来,我们就自动生成了一个能够处理CPython错误处理和引用计数的原语函数。
165 |
166 | 实际上,你也可以在Python端手动实现步骤2,没有宏看起来可能会更直观一些:
167 |
168 | ```python
169 | import diojit as jit
170 | from diojit.runtime.julia_rt import jl_eval
171 | jl_implemented_intrinsic = b"""
172 | function PyList_Append(lst::Ptr, elt::PyPtr)
173 | if ccall(PyAPI.PyList_Append, Cint, (PyPtr, PyPtr), lst, elt) == -1
174 | return Py_NULL
175 | end
176 | nothing # automatically maps to a Python None
177 | end
178 | DIO.DIO_ExceptCode(::typeof(PyList_Append)) != Py_NULL
179 | """
180 | jl_eval(jl_implemented_intrinsic)
181 | ```
182 |
183 | 我们立即得到大于**100%**的性能提升。
184 |
185 | ```python
186 | @jit.jit
187 | def append3(xs, x):
188 | xs.append(x)
189 | xs.append(x)
190 | xs.append(x)
191 |
192 | jit_append3 = jit.spec_call(append3, jit.oftype(list), jit.Top) # 'Top' means 'Any'
193 | xs = [1]
194 | jit_append3(xs, 3)
195 |
196 | print("test jit_append3, [1] append 3 for 3 times:", xs)
197 | # test jit func, [1] append 3 for 3 times: [1, 3, 3, 3]
198 |
199 | xs = []
200 | %timeit append3(xs, 1)
201 | # 293 ns ± 26.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
202 |
203 | xs = []
204 | %timeit jit_append3(xs, 1)
205 | # 142 ns ± 14.9 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
206 | ```
207 |
208 | ## 为什么要用Julia?
209 |
210 | 我们不想维护一个C编译器,并且调用`gcc`这样的操作会引入跨进程的IO。这好吗?这不好。
211 |
212 | 我们倾向于使用LLVM在运行时编译底层代码,而基于其上的Julia提供了对其底层架构的自由访问,从而**成为了运行时编译的杀手级应用**。
213 |
214 | ## 现状和限制
215 |
216 | 1. **暂**未支持不定参数和关键字参数。我们可以立刻支持它们,并提供较小的JIT性能提升。但由此可能引来后向兼容的问题,权衡之下还是暂且搁置。
217 |
218 | 2. **暂**不支持在JIT函数中处理异常.
219 |
220 | ???
221 |
222 |
223 | 还未实现从相关的CPython字节码到无类型DIO IR的转译(`jit.absint.abs.In_Stmt`)
224 |
225 |
226 |
227 |
228 | 会支持吗?
229 |
230 |
231 | 会的。
232 |
233 | 实际上,目前JIT函数内部的调用可以正常抛错。这样的错误无法被JIT函数处理,而是被交给更上层。
234 |
235 | 在我们有精力实现对应的(错误处理)字节码到无类型 DIO IR的转译后,JIT函数中将可以做错误处理。
236 |
237 | P.S: 这会和`for`循环的支持同时实现。
238 |
239 |
240 |
241 |
242 | 3. **暂**不支持`for`循环.
243 |
244 | ???
245 |
246 |
247 | 首先,在CPython中,`for`循环的实现依赖错误处理,而这目前还未支持。
248 |
249 | 其次,我们在考虑一个更高效的`for`循环实现,可能会提议一个`__citer__`协议用以JIT函数的优化。而这需要和Python开发者进一步探讨。
250 |
251 |
252 |
253 |
254 | 会支持吗?
255 |
256 |
257 | 嗯。
258 |
259 | 这会和错误处理同时实现。快速`for`可能会引入得更晚一些。
260 |
261 |
262 |
263 |
264 | 4. 未支持Python的原生闭包
265 |
266 | ???
267 |
268 |
269 | 在命令式语言中,闭包使用一种叫`cell`的数据结构来实现可变(mutable)的自由变量(free variables)。
270 |
271 | 然而,在动态语言里边,优化可写的闭包是一个相当困难的问题。
272 |
273 | 我们建议你使用`types.MethodType`创建自由变量不可变的闭包,这是DIO-JIT(很快就能)高效优化的写法。
274 |
275 | ```python
276 | import types
277 | def f(freevars, z):
278 | x, y = freevars
279 | return x + y + z
280 |
281 | def hof(x, y):
282 | return types.MethodType(f, (x, y))
283 | ```
284 |
285 |
286 |
287 |
288 | 会支持吗?
289 |
290 |
291 | 会是会的,毕竟我们的目的是覆盖所有的CPython代码。
292 |
293 | 但对此不要期待很大性能提升。
294 |
295 |
296 |
297 |
298 | 5. 手动指定不可变全局变量太啰嗦了?(`@diojit.jit(fixed_references=['isinstance', 'str', ...]`)
299 |
300 | 很遗憾,你得这样。我们在考虑全自动JIT的可能性,但Python任人随意修改的全局变量是这个目标最大的阻碍。
301 |
302 | 不写行不行?
303 |
304 |
305 | 可能会可以的。
306 |
307 | 近期CPython优化了存储全局变量的字典。字典的内存布局多了一个名为`ma_version_tag`的数字,用以指示字典最近被写入过。这个改动可能可以用来触发JIT函数的重编译。
308 |
309 | 这还需要更多的研究。
310 |
311 |
312 |
313 |
314 | ## Contributions
315 |
316 | 1. 发挥你的聪明才智,展示你对Python语义的了解,为DIO-JIT添加更多的特化规则吧!
317 |
318 | 例见`jit.absint.prescr`。
319 | 2. TODO
320 |
321 | ## Benchmarks
322 |
323 | 康康 `benchmarks` 文件夹。
324 |
--------------------------------------------------------------------------------
/benchmarks/append3.py:
--------------------------------------------------------------------------------
1 | """
2 | append3 (jit) bench time: 1.4360911
3 | append3 (pure py) bench time: 2.9313146
4 | """
5 | import diojit as jit
6 | from inspect import getsource
7 | from diojit.runtime.julia_rt import jl_eval, splice
8 | import timeit
9 |
10 | # Benchmark subject: appends `x` to `xs` three times via `xs.append`. Wrapped with `@jit.jit` so `jit.spec_call` below can specialise it.
11 | @jit.jit
12 | def append3(xs, x):
13 | xs.append(x)
14 | xs.append(x)
15 | xs.append(x)
16 |
17 | # Print the benchmarked source, specialise `append3` for (list, Top), then call the specialised function once and print the resulting list.
18 | print("append3".center(70, "="))
19 | print(getsource(append3))
20 |
21 | # jit.In_Def.UserCodeDyn[append3].show()
22 | jit_append3 = jit.spec_call(
23 | append3,
24 | jit.oftype(list),
25 | jit.Top,  # the README describes `Top` as meaning "Any"
26 | # print_dio_ir=print,
27 | )
28 | x = [1]  # `x` and `y` are only referenced by the commented-out debug line below
29 | y = 2
30 | # jl_eval(f"println(J_append3_0({splice(x)}, {splice(y)}))")
31 | # raise
32 | xs = [1]
33 | jit_append3(xs, 3)
34 | print("test jit func: [1] append 3 for 3 times =", xs)
35 |
36 | # Time 100,000,000 calls of the specialised function, then of the `@jit.jit`-decorated `append3` itself (reported as "pure py"); each run starts from its own empty list.
37 | xs = []
38 | print(
39 | "append3 (jit) bench time:",
40 | timeit.timeit(
41 | "f(xs, 1)", globals=dict(f=jit_append3, xs=xs), number=100000000
42 | ),
43 | )
44 | xs = []  # fresh list so the second measurement does not reuse the list grown by the first
45 | print(
46 | "append3 (pure py) bench time:",
47 | timeit.timeit(
48 | "f(xs, 1)", globals=dict(f=append3, xs=xs), number=100000000
49 | ),
50 | )
51 |
--------------------------------------------------------------------------------
/benchmarks/brainfuck.py:
--------------------------------------------------------------------------------
1 | """
2 | jit time 140.34073400497437
3 | pure py time 244.5001037120819
4 | 23280 ==? 23280
5 | """
6 | import platform
7 | import socket
8 | import sys
9 | import time
10 |
11 | import os
12 | import itertools
13 | from pathlib import Path
14 | import diojit as jit
15 | import typing
16 | # Raise the recursion limit and print a banner. NOTE(review): the limit is presumably needed by the recursive `parse`/`_run` on nested loops — confirm.
17 | sys.setrecursionlimit(2200)
18 | print("brainfuck".center(50, "="))
19 |
20 | # Opcode tags stored in `Op.op`: `parse` assigns them and `_run` dispatches on them.
21 | INC = 1
22 | MOVE = 2
23 | LOOP = 3
24 | PRINT = 4
25 |
26 | OP = 0  # NOTE(review): `OP` and `VAL` are not referenced in the visible code of this file — possibly leftovers; confirm before removing
27 | VAL = 1
28 |
29 | # One parsed instruction: `op` is an opcode tag (INC/MOVE/LOOP/PRINT, as built by `parse`); `val` is an int operand or, for LOOP, the list of nested instructions.
30 | @jit.eagerjit
31 | class Op(object):
32 | op: int
33 | val: typing.Union[int, list]
34 |
35 | def __init__(self, op, val):
36 | assert isinstance(op, int) and isinstance(val, (int, list))  # reject values outside the annotated field types
37 | self.op = op
38 | self.val = val
39 |
40 | # Interpreter memory: a list of cells (initially `[0]`) plus the index of the current cell.
41 | @jit.eagerjit
42 | class Tape(object):
43 | tape: list
44 | pos: int
45 |
46 | def __init__(self):
47 | self.tape = [0]
48 | self.pos = 0
49 | # Return the value of the current cell.
50 | def get(self):
51 | return self.tape[self.pos]
52 | # Add `x` to the current cell in place.
53 | def inc(self, x):
54 | self.tape[self.pos] += x
55 | # Shift the position by `x`; while it is at or past the end, extend the tape with as many zeros as it currently holds. A negative position is not checked here.
56 | def move(self, x):
57 | self.pos += x
58 | while self.pos >= len(self.tape):
59 | self.tape.extend(itertools.repeat(0, len(self.tape)))
60 |
61 | # Output sink: in quiet mode it folds every printed value into two running sums instead of writing characters to stdout.
62 | @jit.eagerjit
63 | class Printer(object):
64 | sum1: int
65 | sum2: int
66 | quiet: bool
67 |
68 | def __init__(self, quiet):
69 | self.sum1 = 0
70 | self.sum2 = 0
71 | self.quiet = quiet
72 | # Quiet: update `sum1` with `n` and `sum2` with the new `sum1`, both modulo 255. Otherwise write `chr(n)` to stdout and flush.
73 | def print(self, n):
74 | if self.quiet:
75 | self.sum1 = (self.sum1 + n) % 255
76 | self.sum2 = (self.sum2 + self.sum1) % 255
77 | else:
78 | sys.stdout.write(chr(n))
79 | sys.stdout.flush()
80 | # Combine both sums into one int: `sum2` shifted left by 8 bits, OR-ed with `sum1`.
81 | @property
82 | def checksum(self):
83 | return (self.sum2 << 8) | self.sum1
84 |
85 | # Parse Brainfuck source from a character iterator into a list of `Op`. Stops at "]" or when the iterator is exhausted; characters other than + - > < . [ ] are ignored.
86 | def parse(iterator):
87 | res = []
88 | while True:
89 | try:
90 | c = iterator.__next__()
91 | except StopIteration:
92 | break
93 |
94 | if c == "+":
95 | res.append(Op(INC, 1))
96 | elif c == "-":
97 | res.append(Op(INC, -1))
98 | elif c == ">":
99 | res.append(Op(MOVE, 1))
100 | elif c == "<":
101 | res.append(Op(MOVE, -1))
102 | elif c == ".":
103 | res.append(Op(PRINT, 0))
104 | elif c == "[":
105 | res.append(Op(LOOP, parse(iterator)))  # the nested call consumes the same iterator until it sees "]"
106 | elif c == "]":
107 | break
108 |
109 | return res
110 |
111 | # Execute the list of `Op` in `program` against `tape`, sending PRINT output to `p`. A LOOP re-runs its nested program while the current cell is > 0.
112 | @jit.eagerjit
113 | def _run(program, tape, p):
114 | i = 0
115 | n = len(program)
116 | while i < n:  # NOTE(review): index-based `while`, presumably because the README lists `for` loops as unsupported in JIT functions
117 | op = program[i]
118 | i += 1
119 | if op.op == INC:
120 | tape.inc(op.val)
121 | elif op.op == MOVE:
122 | tape.move(op.val)
123 | elif op.op == LOOP:
124 | while tape.get() > 0:
125 | _run(op.val, tape, p)
126 | elif op.op == PRINT:
127 | p.print(tape.get())
128 |
129 |
class Program(object):
    """A parsed program that can be run interpreted or JIT-specialised."""

    def __init__(self, code):
        # Parse once; ``parse`` needs an iterator so nested loops can share it.
        self.ops = parse(iter(code))

    def run(self, p, use_jit: bool):
        """Execute the program on a fresh ``Tape``, printing through ``p``.

        With ``use_jit`` set, ``_run`` is first specialised through
        ``jit.spec_call`` for (list, Tape, Printer) arguments, and the
        resulting callable is invoked instead of ``_run`` itself.
        """
        if not use_jit:
            _run(self.ops, Tape(), p)
            return

        specialised = jit.spec_call(
            _run,
            jit.oftype(list),
            jit.oftype(Tape),
            jit.oftype(Printer),
            print_dio_ir=print,
        )
        specialised(self.ops, Tape(), p)
146 |
147 |
# Both printers run in quiet mode so the two runs can be compared through
# their checksums instead of through console output.
p1 = Printer(True)
p2 = Printer(True)

prog = Program(
    """>++[<+++++++++++++>-]<[[>+>+<<-]>[<+>-]++++++++
[>++++++++<-]>.[-]<<>++++++++++[>++++++++++[>++
++++++++[>++++++++++[>++++++++++[>++++++++++[>+
+++++++++[-]<-]<-]<-]<-]<-]<-]<-]++++++++++."""
)
# JIT run first; the measured time includes the jit.spec_call inside
# Program.run.
n = time.time()
prog.run(p2, use_jit=True)
print("jit time", time.time() - n)

# Same program through the plain Python path.
n = time.time()
prog.run(p1, use_jit=False)
print("pure py time", time.time() - n)


# The two checksums are printed side by side for a manual equality check.
print(p1.checksum, "==?", p2.checksum)
167 |
--------------------------------------------------------------------------------
/benchmarks/const_fib.py:
--------------------------------------------------------------------------------
1 | import diojit as jit
2 | import operator
3 | import timeit
4 |
5 |
def try_const_add(a, b):
    """Return ``a + b``; serves as the key for the ``call_const_add`` rule."""
    return operator.add(a, b)
8 |
9 |
def try_const_le(a, b):
    """Return ``a <= b``; serves as the key for the ``call_const_le`` rule."""
    return operator.le(a, b)
12 |
13 |
def try_const_eq(a, b):
    """Return ``a == b``; serves as the key for the ``call_const_eq`` rule."""
    return operator.eq(a, b)
16 |
17 |
@jit.register(try_const_add, create_shape=True)
def call_const_add(self: jit.Judge, *args: jit.AbsVal):
    """Specialisation rule for ``try_const_add``.

    When both arguments are literal static values the sum is computed
    immediately and returned as a constant ``CallSpec``. Otherwise the call
    is delegated to the specialisation of ``operator.__add__``.

    Returns ``NotImplemented`` unless exactly two arguments are given.
    """
    if len(args) != 2:
        return NotImplemented
    a, b = args
    # Fold only literals, matching call_const_le / call_const_eq. Folding on
    # any static value (``is_s``) would also try to add types or functions
    # and raise while specialising instead of falling back.
    if a.is_literal() and b.is_literal():
        const = jit.S(operator.add(a.base, b.base))
        ret_types = (jit.S(type(const.base)),)
        return jit.CallSpec(const, const, ret_types)

    return self.spec(jit.S(operator.__add__), "__call__", list(args))
29 |
30 |
@jit.register(try_const_le, create_shape=True)
def call_const_le(self: jit.Judge, *args: jit.AbsVal):
    """Specialisation rule for ``try_const_le``.

    Two literal arguments are compared immediately and the result is
    returned as a constant ``CallSpec``; anything else is delegated to the
    specialisation of ``operator.__le__``. ``NotImplemented`` is returned
    unless exactly two arguments are given.
    """
    if len(args) != 2:
        return NotImplemented

    lhs, rhs = args
    if not (lhs.is_literal() and rhs.is_literal()):
        return self.spec(jit.S(operator.__le__), "__call__", list(args))

    folded = jit.S(operator.__le__(lhs.base, rhs.base))
    return jit.CallSpec(folded, folded, (jit.S(type(folded.base)),))
42 |
43 |
@jit.register(try_const_eq, create_shape=True)
def call_const_eq(self: jit.Judge, *args: jit.AbsVal):
    """Specialisation rule for ``try_const_eq``.

    Two literal arguments are compared immediately and the result is
    returned as a constant ``CallSpec``; anything else is delegated to the
    specialisation of ``operator.__eq__``. ``NotImplemented`` is returned
    unless exactly two arguments are given.
    """
    if len(args) != 2:
        return NotImplemented

    lhs, rhs = args
    if not (lhs.is_literal() and rhs.is_literal()):
        return self.spec(jit.S(operator.__eq__), "__call__", list(args))

    folded = jit.S(operator.__eq__(lhs.base, rhs.base))
    return jit.CallSpec(folded, folded, (jit.S(type(folded.base)),))
55 |
56 |
@jit.eagerjit
def fib(x):
    # Naive doubly-recursive Fibonacci; returns 1 for every x <= 2.
    if x <= 2:
        return 1
    # Written as ``x + -1`` rather than ``x - 1``.
    # NOTE(review): presumably to stay within the operations the JIT
    # specialises; confirm before simplifying.
    return fib(x + -1) + fib(x + -2)
62 |
63 |
# Specialise ``fib`` for an argument described by ``jit.oftype(int)``.
jit_fib = jit.spec_call(fib, jit.oftype(int))
65 |
66 |
@jit.eagerjit
def try_const_fib(x):
    # Same recursion as ``fib`` but spelled with the try_const_* helpers,
    # whose registered rules can fold literal arguments during
    # specialisation.
    if try_const_le(x, 2):
        return 1
    return try_const_add(
        try_const_fib(
            try_const_add(x, -1),
        ),
        try_const_fib(
            try_const_add(x, -2),
        ),
    )
79 |
80 |
# Rebind the name to the specialisation for the concrete value 20. From
# here on ``try_const_fib`` is the specialised callable, not the function
# defined above.
try_const_fib = jit.spec_call(try_const_fib, jit.ofval(20))
82 |
83 |
def bench(kind, f, number=2000):
    """Time ``number`` calls of ``f(20)`` and print ``kind`` with the total.

    ``f`` is exposed to ``timeit`` under the name ``fib``, together with
    ``x = 20``; the statement timed is ``fib(x)``.
    """
    namespace = dict(fib=f, x=20)
    elapsed = timeit.timeit("fib(x)", globals=namespace, number=number)
    print(kind, elapsed)
93 |
94 |
# Print the three results first so they can be compared by eye.
print(fib(20))
print(jit_fib(20))
print(try_const_fib(20))

# Then time each variant with the default number of repetitions.
bench("pure py fib", fib)
bench("jit fib", jit_fib)
bench("jit fib_fast", try_const_fib)
102 |
--------------------------------------------------------------------------------
/benchmarks/dna_read.py:
--------------------------------------------------------------------------------
1 | """
2 | [Python39]
3 | jit: 7.4767213
4 | pure py: 7.8394832
5 |
6 | [Python38]
7 | 7.8734842
8 | 8.994623599999999
9 | """
10 | import sys, string
11 | import requests
12 | from io import BytesIO
13 | from tempfile import mktemp
14 | import timeit
15 |
16 | import diojit as jit
17 | from diojit.codegen.julia import splice
18 | from diojit.runtime.julia_rt import check_jl_err, get_libjulia
19 |
print('DNA READ'.center(50, '='))

libjl = get_libjulia()
# NOTE: the benchmark input is downloaded at import time, so this script
# needs network access. The body is kept in memory as bytes.
contents = requests.get(
    r"https://raw.githubusercontent.com/dundee/pybenchmarks/master/bencher/data/revcomp-input1000.txt"
).text.encode()
# Byte-to-byte translation table used by ``show``: each byte of the first
# argument maps to the byte at the same position in the second
# ("\n" maps to itself).
table = bytes.maketrans(
    b"ACBDGHK\nMNSRUTWVYacbdghkmnsrutwvy",
    b"TGVHCDM\nKNSYAAWBRTGVHCDMKNSYAAWBR",
)
30 |
31 |
def jl_eval(s: str):
    """Evaluate the Julia source ``s`` through libjulia, then check for errors.

    The string is encoded to bytes before it is passed to
    ``jl_eval_string``; ``check_jl_err`` is called afterwards on the same
    library handle.
    """
    source = s.encode()
    libjl.jl_eval_string(source)
    check_jl_err(libjl)
35 |
36 |
@jit.jit(fixed_references=["len"])
def show(out_io: bytearray, seq):
    # Join the collected chunks, translate them through ``table``, reverse
    # the result and append it to ``out_io``.
    # FIXME: optimisation point
    seq_ = (b"".join(seq)).translate(table)[::-1]
    i = 0
    n = len(seq_)
    # Append in slices of at most 60 bytes; nothing is inserted between
    # slices.
    while i < n:
        out_io.extend(seq_[i : i + 60])
        i += 60
47 |
# @jit.jit
def main(out_io: bytearray):
    # Pure-Python baseline: iterate the lines of ``contents`` with a for
    # loop. ``main2`` below is the variant that is actually JIT-compiled.
    in_io = BytesIO(contents)
    seq = []
    for line in in_io:
        # A line whose first byte is ">" or ";" is a header: flush the
        # sequence collected so far, copy the header through, start over.
        if line[0] in b">;":
            show(out_io, seq)
            out_io.extend(line)
            del seq[:]
        else:
            # Drop the last byte of the line (assumed to be the newline).
            seq.append(line[:-1])
    # Flush whatever followed the last header.
    show(out_io, seq)
    return out_io
61 |
62 |
@jit.jit(fixed_references=["next", "BytesIO", "show"])
def main2(out_io: bytearray):
    # Same algorithm as ``main``, but the line iteration is spelled as
    # ``while True`` + ``next(in_io, None)`` instead of a for loop.
    in_io = BytesIO(contents)
    seq = []
    while True:
        line = next(in_io, None)
        # ``None`` is the sentinel returned once the stream is exhausted.
        if line is None:
            break
        # Header line (first byte ">" or ";"): flush, copy through, reset.
        if line[0] in b">;":
            show(out_io, seq)
            out_io.extend(line)
            del seq[:]
        else:
            # Drop the last byte of the line (assumed to be the newline).
            seq.append(line[:-1])
    show(out_io, seq)
    return out_io
79 |
80 |
81 | # @jit.jit
82 | # def b():
83 | # return [print, set, list]
84 | # for i in range(1000):
85 | # print(jit.jit_spec_call(b)())
86 | # b()
87 |
# Specialise ``main2`` for a bytearray argument.
jit_main = jit.spec_call(
    main2,
    jit.oftype(bytearray),
    # print_dio_ir=print,
)
# raise
#
x = bytearray()
# Call the generated Julia function once by name.
# NOTE(review): "J_main2_0" is hard-coded and assumes ``main2`` is the
# first specialisation registered in this process; confirm.
jl_eval(f"J_main2_0({splice(x)})")
# raise
# Sanity check: both implementations must produce identical output.
print(main(bytearray()) == jit_main(bytearray()))
print(
    'jit',
    timeit.timeit(
        "main(bytearray())",
        number=200000,
        globals=dict(main=jit_main),
    ),
)

print(
    'pure py',
    timeit.timeit(
        "main(bytearray())",
        number=200000,
        globals=dict(main=main),
    ),
)
116 |
--------------------------------------------------------------------------------
/benchmarks/fib.py:
--------------------------------------------------------------------------------
1 | """
2 | fib(15) (py) bench time: 1.3318193000000003
3 | fib(15) (jit+untyped) bench time: 0.42067140000000025
4 | fib(15) (jit+inferred) bench time: 0.1776359000000003
5 | """
6 | import diojit as jit
7 | from inspect import getsource
8 | import timeit
9 | from diojit.runtime.julia_rt import splice, jl_eval
10 |
11 |
def fib(a):
    # Pure-Python baseline: naive recursive Fibonacci, 1 for every a <= 2.
    if a <= 2:
        return 1
    return fib(a - 1) + fib(a - 2)
16 |
17 |
@jit.jit(fixed_references=["fib_fix"])
def fib_fix(a):
    # JIT twin of ``fib``. Its own name is listed in ``fixed_references``,
    # and the decrements are written as ``a + -1`` / ``a + -2``.
    if a <= 2:
        return 1
    return fib_fix(a + -1) + fib_fix(a + -2)
23 |
24 |
# Two specialisations of the same function: one for an int argument, one
# for the unconstrained ``jit.Top``.
jit_fib_fix_typed = jit.spec_call(
    fib_fix,
    jit.oftype(int),
    # print_jl=print,
)
jit_fib_fix_untyped = jit.spec_call(fib_fix, jit.Top)
# Call a generated Julia function directly by its name.
# NOTE(review): "J_fib__fix_1" is hard-coded and depends on the order in
# which specialisations were created; confirm it still matches.
jl_eval(f"println(J_fib__fix_1({splice(20)}))")
# check_jl_err(libjl)
print("fib".center(70, "="))
print(getsource(fib))
# Show that all three variants agree before timing them.
print(
    "fib(15), jit_fib_fix_untyped(15), jit_fib_fix_typed(15) = ",
    (fib(15), jit_fib_fix_untyped(15), jit_fib_fix_typed(15)),
)
print(
    "fib(py) bench time:",
    timeit.timeit("f(15)", globals=dict(f=fib), number=100000),
)
print(
    "fib(jit+untyped) bench time:",
    timeit.timeit(
        "f(15)", globals=dict(f=jit_fib_fix_untyped), number=100000
    ),
)
print(
    "fib(jit+inferred) bench time:",
    timeit.timeit(
        "f(15)", globals=dict(f=jit_fib_fix_typed), number=100000
    ),
)
55 |
--------------------------------------------------------------------------------
/benchmarks/hypot.py:
--------------------------------------------------------------------------------
1 | """
2 | hypot (pure py) bench time: 0.7044676000000001
3 | hypot (jit) bench time: 0.4247455999999996
4 | """
5 | import diojit as jit
6 | from inspect import getsource
7 | import timeit
8 | from math import sqrt
9 |
10 |
@jit.jit(fixed_references=["sqrt", "str", "int", "isinstance"])
def hypot(x, y):
    # sqrt(x**2 + y**2); an argument given as ``str`` is converted with
    # ``int`` first.
    if isinstance(x, str):
        x = int(x)

    if isinstance(y, str):
        y = int(y)

    return sqrt(x ** 2 + y ** 2)
20 |
21 |
print(getsource(hypot))


# print("Direct Translation From Stack Instructions".center(70, "="))

# jit.absint.In_Def.UserCodeDyn[hypot].show()
# print("After JITing".center(70, "="))


# repr of the callee in the specialised call expression for (int, int).
# Only the commented-out @code_llvm snippet below refers to it.
jit_func_name = repr(
    jit.spec_call_ir(
        hypot, jit.S(int), jit.S(int)
    ).e_call.func
)


# Callable specialisation of ``hypot`` for two int arguments.
hypot_spec = jit.spec_call(
    hypot,
    jit.oftype(int),
    jit.oftype(int),
    # print_jl=print,
    # print_dio_ir=print,
)
# #
# libjl = jit.runtime.julia_rt.get_libjulia()
# libjl.jl_eval_string(f'using InteractiveUtils;@code_llvm {jit_func_name}(PyO.int, PyO.int)'.encode())
# jit.runtime.julia_rt.check_jl_err(libjl)

# Print both results for a visual comparison, then time each variant.
print("hypot(1, 2) (jit) = ", hypot_spec(1, 2))
print("hypot(1, 2) (pure py) = ", hypot(1, 2))
print(
    "hypot (pure py) bench time:",
    timeit.timeit("f(1, 2)", number=10000000, globals=dict(f=hypot)),
)
print(
    "hypot (jit) bench time:",
    timeit.timeit(
        "f(1, 2)", number=10000000, globals=dict(f=hypot_spec)
    ),
)
62 |
--------------------------------------------------------------------------------
/benchmarks/selection_sort.py:
--------------------------------------------------------------------------------
1 | """
2 | pure py: 5.2469079999999995
3 | jit: 3.6917899
4 | >40% performance gain.
5 | (but if we can have a strict generic list type in Python,
6 | we can have a 600% performance gain.)
7 | """
8 | import diojit as jit
9 | import timeit
10 | import numpy as np
11 | from diojit.runtime.julia_rt import check_jl_err
12 | from diojit.codegen.julia import splice
13 | import sys
14 |
print('selection sort'.center(50, '='))
# Raise the interpreter's recursion limit from its default to 2000.
sys.setrecursionlimit(2000)

# libjulia handle used by ``jl_eval`` below.
libjl = jit.runtime.julia_rt.get_libjulia()
19 |
20 |
def jl_eval(s: str):
    """Evaluate the Julia source ``s`` through libjulia, then check for errors.

    The string is encoded to bytes before it is passed to
    ``jl_eval_string``; ``check_jl_err`` is called afterwards on the same
    library handle.
    """
    source = s.encode()
    libjl.jl_eval_string(source)
    check_jl_err(libjl)
24 |
25 |
@jit.jit(fixed_references=["int"])
def lt(e, min_val):
    # Compare after converting both operands with ``int``. The only
    # reference to it in ``argmin`` is commented out.
    return int(e) < int(min_val)
29 |
30 |
@jit.jit(fixed_references=["lt", "int"])
def argmin(xs, i, n):
    # Index of the smallest element of xs[i:n]; the first one wins on ties
    # because the comparison is strict.
    min_val = xs[i]  # int(xs[i])
    min_i = i
    j = i + 1
    while j < n:
        e = xs[j]  # int(xs[j])
        # if lt(e, min_val):
        if e < min_val:
            min_i = j
            min_val = e
        j = j + 1
    return min_i
44 |
45 |
@jit.jit
def swap(xs, min_i, i):
    # Exchange the two elements in place.
    xs[min_i], xs[i] = xs[i], xs[min_i]
49 |
50 |
@jit.jit(fixed_references=["argmin", "swap", "len"])
def msort(xs):
    # Selection sort on a copy of ``xs``; the argument itself is left
    # untouched and the sorted copy is returned.
    xs = xs.copy()
    n = len(xs)
    i = 0
    while i < n:
        # Move the minimum of xs[i:n] into position i.
        min_i = argmin(xs, i, n)
        swap(xs, min_i, i)
        i = i + 1
    return xs
61 |
62 |
@jit.jit
def mwe(xs):
    # Compares elements 0 and 2 of ``xs``. Not referenced elsewhere in
    # this script; "mwe" presumably stands for "minimal working example".
    return xs[0] < xs[2]
66 |
67 |
# Specialise ``msort`` for a list argument.
jit_msort = jit.spec_call(
    msort,
    jit.oftype(list),
    # print_dio_ir=print,
)


# 100 random integers in [0, 10000); list() over the NumPy array keeps the
# elements as NumPy scalars rather than Python ints.
xs = list(np.random.randint(0, 10000, 100))
print(
    "pure py:",
    timeit.timeit("f(xs)", globals=dict(xs=xs, f=msort), number=100000),
)

print(
    "jit:",
    timeit.timeit(
        "f(xs)", globals=dict(xs=xs, f=jit_msort), number=100000
    ),
)
87 |
88 |
89 | ## This is the specialisation that produces 600% performance gain:
90 |
91 | # @register(operator.__getitem__, create_shape=True)
92 | # def call_getitem(self: Judge, *args: AbsVal):
93 | # if len(args) != 2:
94 | # # fall back to the default Python implementation
95 | # return NotImplemented
96 | # subject, item = args
97 | # ret_types = (Top,)
98 | #
99 | # if (
100 | # subject.type not in (Top, Bot)
101 | # and subject.type.base == list
102 | # and item.type not in (Top, Bot)
103 | # and issubclass(item.type.base, int)
104 | # ):
105 | # func = S(intrinsic("PyList_GetItem"))
106 | # # ret_types = tuple({Values.A_Int})
107 | # else:
108 | # func = S(intrinsic("PyObject_GetItem"))
109 | #
110 | # e_call = func(subject, item)
111 | # instance = None
112 | # return CallSpec(instance, e_call, ret_types)
113 |
--------------------------------------------------------------------------------
/benchmarks/trans.py:
--------------------------------------------------------------------------------
1 | """
2 | pure py: 2.6820188
3 | jit: 0.8471317999999997
4 | """
5 | import diojit as jit
6 | from inspect import getsource
7 | import timeit
8 | from diojit.runtime.julia_rt import splice, jl_eval
9 |
10 | print('trans'.center(50, '='))
11 |
@jit.jit
def f(x):
    # Straight-line chain of seven ``1 + ...`` additions; returns x + 7
    # for an int argument.
    x = 1 + x
    y = 1 + x
    z = 1 + y
    x = 1 + z
    y = 1 + x
    z = 1 + y
    x = 1 + z
    return x
22 |
23 |
# Specialise ``f`` for an int argument.
jit_f = jit.spec_call(f, jit.oftype(int))

print(jit_f(10))
# Time the plain function and the specialisation with the same repeat count.
print('pure py:', timeit.timeit("f(10)", globals=dict(f=f), number=111111111))
print('jit:', timeit.timeit("f(10)", globals=dict(f=jit_f), number=111111111))
29 |
--------------------------------------------------------------------------------
/benchmarks/union_type.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import diojit as jit
3 | import typing
4 | import timeit
5 |
6 |
@jit.eagerjit
class Node:
    # Singly linked list cell. ``next`` is either another Node or None,
    # which is the case this benchmark exercises.
    next: typing.Union[type(None), Node]
    val: int

    def __init__(self, n, val):
        self.next = n
        self.val = val
15 |
16 |
@jit.eagerjit
def sum_chain(n: Node):
    # Sum ``val`` over the chain starting at ``n``, following ``next``
    # until None is reached.
    a = 0
    while n is not None:
        a += n.val
        n = n.next

    return a
25 |
26 |
# Build a chain of 101 nodes: a terminal node with value 0, then 100 nodes
# prepended with values (i + 2) * 5.
n = Node(None, 0)
for i in range(100):
    n = Node(n, (i + 2) * 5)

# Specialise for a Node argument, passing ``print`` as ``print_dio_ir``.
jit_sum_chain = jit.spec_call(
    sum_chain, jit.oftype(Node), print_dio_ir=print
)
# Both results are printed side by side for a visual comparison.
print(jit_sum_chain(n), sum_chain(n))
35 |
36 |
def bench(kind, f, number=1000000):
    """Time ``number`` calls of ``f`` on the module-level chain ``n``.

    ``f`` and the chain are exposed to ``timeit`` as ``f`` and ``x``; the
    statement timed is ``f(x)``. Prints ``kind`` followed by the total time.
    """
    namespace = dict(f=f, x=n)
    elapsed = timeit.timeit("f(x)", globals=namespace, number=number)
    print(kind, elapsed)
46 |
47 |
48 | bench("jit", jit_sum_chain)
49 | bench("pure py", sum_chain)
50 |
--------------------------------------------------------------------------------
/diojit/__init__.py:
--------------------------------------------------------------------------------
1 | from .user.client import *
2 | from .absint import *
3 | from . import absint, stack2reg
4 | from . import codegen
5 | from . import runtime
6 |
--------------------------------------------------------------------------------
/diojit/absint/__init__.py:
--------------------------------------------------------------------------------
1 | from .intrinsics import *
2 | from .abs import *
3 | from .prescr import *
4 |
--------------------------------------------------------------------------------
/diojit/absint/abs.py:
--------------------------------------------------------------------------------
1 | """
2 | AbsVal ::= D_i^t | NonD
3 | NonD ::= S_o^ [*]
4 | | Top
5 | | Bot
6 | """
7 |
8 | from __future__ import annotations
9 | from typing import (
10 | Union,
11 | Callable,
12 | Optional,
13 | Iterable,
14 | Sequence,
15 | TypeVar,
16 | Type,
17 | NamedTuple,
18 | TYPE_CHECKING,
19 | cast,
20 | )
21 | from collections import defaultdict, OrderedDict
22 | from contextlib import contextmanager
23 | from functools import total_ordering
24 | import types
25 | import dataclasses
26 | import pyrsistent
27 | import builtins
28 | from .intrinsics import *
29 |
# The type of ``None``, obtained via ``type(None)``.
NoneType = type(None)

# awesome pycharm :(
# Re-exported with an explicit annotation and a ``cast``; at runtime this
# is simply ``types.FunctionType``.
FunctionType: Type[types.FunctionType] = cast(
    Type[types.FunctionType], types.FunctionType
)

# set False when not made for Python
PYTHON = True
39 |
40 | __all__ = [
41 | "Values",
42 | "AbsVal",
43 | "D",
44 | "S",
45 | "Top",
46 | "Bot",
47 | "Judge",
48 | "JITSpecInfo",
49 | "CallSpec",
50 | "CallRecord",
51 | "PreSpecMaps",
52 | "SpecMaps",
53 | "RecTraces",
54 | "In_Def",
55 | "In_Move",
56 | "In_Goto",
57 | "In_SetLineno",
58 | "In_Bind",
59 | "In_Stmt",
60 | "In_Blocks",
61 | "In_Cond",
62 | "In_Return",
63 | "Out_Def",
64 | "Out_Label",
65 | "Out_Call",
66 | "Out_Return",
67 | "Out_Goto",
68 | "Out_TypeCase",
69 | "Out_If",
70 | "Out_Assign",
71 | "Out_SetLineno",
72 | "Out_DecRef",
73 | "Out_Instr",
74 | "print_out",
75 | "print_in",
76 | "from_runtime",
77 | "AWARED_IMMUTABLES",
78 | "ShapeSystem",
79 | "Shape",
80 | "FunctionType", # TODO: put it elsewhere
81 | ]
82 |
83 |
class Out_Callable:
    # Mixin: "calling" an instance does not execute anything. It builds an
    # ``Out_Call`` node with the instance as callee and the positional
    # arguments as a tuple.
    def __call__(self, *args: AbsVal):
        # noinspection PyTypeChecker
        return Out_Call(self, args)
88 |
89 |
class AbsVal:
    # Common base class of the abstract values (D, S, Top, Bot).
    # The two members below exist for type checkers only; at runtime each
    # subclass supplies its own ``type`` and ``is_literal``.
    if TYPE_CHECKING:

        @property
        def type(self) -> NonD:
            raise NotImplementedError

        def is_literal(self) -> bool:
            raise NotImplementedError

    def is_s(self):
        # Default answer; ``S`` overrides this to return True.
        return False
102 |
103 |
@total_ordering
class D(Out_Callable, AbsVal):
    """
    dynamic abstract value: an index ``i`` paired with an abstract type
    """

    i: int
    type: NonD

    def __init__(self, i: int, type: NonD):
        self.i, self.type = i, type

    def __repr__(self):
        slot = f"D{self.i}"
        # The type is left out when it carries no information.
        if self.type is Top:
            return slot
        return f"{slot} : {self.type}"

    def is_literal(self):
        return False

    def __hash__(self):
        return 114514 ^ hash(self.i) ^ hash(self.type)

    def __eq__(self, other):
        if not isinstance(other, D):
            return False
        return self.i == other.i and self.type == other.type

    def __lt__(self, other):
        # noinspection PyTypeHints
        if not isinstance(other, AbsVal):
            return False
        # A dynamic value orders before Top, Bot and every static value.
        if other is Top or other is Bot or isinstance(other, S):
            return True
        # Among dynamic values: by index first, then by type.
        if self.i != other.i:
            return self.i < other.i
        return self.type < other.type
146 |
147 |
@total_ordering
class S(Out_Callable, AbsVal):
    """
    type abstract value

    Wraps a concrete Python object ``base`` that is known statically,
    optionally refined by a tuple of abstract ``params`` (``None`` when the
    value is unparameterised).
    """

    base: object
    params: Optional[tuple[NonD, ...]]

    def __init__(
        self, base: object, params: Optional[tuple[NonD, ...]] = None
    ):
        self.base = base
        self.params = params

    def __hash__(self):
        return 1919810 ^ hash(self.base) ^ hash(self.params)

    def __eq__(self, other):
        return (
            isinstance(other, S)
            and self.base == other.base
            and self.params == other.params
        )

    def __lt__(self, other):
        # noinspection PyTypeHints
        if not isinstance(other, AbsVal):
            return False
        # Static values order before Top/Bot and after dynamic values.
        if other is Top or other is Bot:
            return True
        if isinstance(other, D):
            return False
        if self.base == other.base:
            mine, theirs = self.params, other.params
            # ``None < None`` (and ``None < tuple``) raises TypeError, so
            # handle unparameterised values explicitly: ``None`` sorts
            # before any parameter tuple and is not less than itself.
            if mine is None or theirs is None:
                return mine is None and theirs is not None
            return mine < theirs
        return hash(self.base) < hash(other.base)

    def is_s(self):
        return True

    def oop(self):
        # Falsy when no shape is registered, else the shape's ``oop`` flag.
        return (shape := self.shape) and shape.oop

    @property
    def type(self):
        """Abstract type of the wrapped object."""
        base = self.base
        t = type(base)
        # ``tuple`` is mapped to None in the literal table, so it falls
        # through to the dedicated branch below.
        if abs_t := _literal_type_maps.get(t):
            return abs_t
        elif t is tuple:
            return tuple_type(base)
        a_t = from_runtime(t)
        assert not isinstance(a_t, D)
        return a_t

    @property
    def shape(self) -> Optional[Shape]:
        """Registered shape of ``base``; always ``None`` for literals."""
        if type(self.base) in _literal_type_maps:
            return
        return ShapeSystem.get(self.base)

    def is_literal(self):
        return type(self.base) in _literal_type_maps

    def __repr__(self):
        # Classes and Python functions print by bare name, builtin
        # functions by module-qualified name, everything else by repr.
        if isinstance(self.base, (type, FunctionType)):
            n = self.base.__name__
        elif isinstance(self.base, types.BuiltinFunctionType):
            n = f"{self.base.__module__}.{self.base.__name__}"
        else:
            n = repr(self.base)

        if self.params is None:
            return n

        return f"{n}<{', '.join(map(repr, self.params))}>"
226 |
227 |
class Values:
    # Namespace of pre-built static abstract values for commonly used
    # builtin types, plus the ``NotImplemented`` singleton and its type.
    A_Int = S(int)
    A_Float = S(float)
    A_Str = S(str)
    A_NoneType = S(NoneType)
    A_FuncType = S(FunctionType)
    A_MethType = S(types.MethodType)
    A_Complex = S(complex)
    A_Bool = S(bool)
    A_Intrinsic = S(Intrinsic)
    A_Type = S(type)
    # Wraps the value ``NotImplemented`` itself, not its type.
    A_NotImplemented = S(NotImplemented)
    A_NotImplementedType = S(type(NotImplemented))
241 |
242 |
# Module-level shortcuts for some of the values defined on ``Values``.
A_Int = Values.A_Int
A_Float = Values.A_Float
A_Str = Values.A_Str
A_NoneType = Values.A_NoneType
A_FuncType = Values.A_FuncType
A_Bool = Values.A_Bool
A_Intrinsic = Values.A_Intrinsic
A_Complex = Values.A_Complex


# Python types whose instances count as literals (see ``S.is_literal``),
# mapped to their abstract type. ``tuple`` maps to None because its
# abstract type depends on the elements; ``S.type`` computes it with
# ``tuple_type`` instead.
_literal_type_maps = {
    int: A_Int,
    float: A_Float,
    complex: A_Complex,
    str: A_Str,
    bool: A_Bool,
    NoneType: A_NoneType,
    tuple: None,
}

_T = TypeVar("_T")
264 |
265 |
class _Top(Out_Callable, AbsVal):
    # Class of the ``Top`` singleton created below.
    def is_literal(self) -> bool:
        return False

    @property
    def type(self):
        # Top has no abstract type; asking for one raises TypeError.
        raise TypeError

    def __repr__(self):
        return "Top"
276 |
277 |
class _Bot(Out_Callable, AbsVal):
    # Class of the ``Bot`` singleton created below.
    def is_literal(self) -> bool:
        return False

    @property
    def type(self):
        # Bot has no abstract type; asking for one raises TypeError.
        raise TypeError

    def __repr__(self):
        return "Bot"
288 |
289 |
# The single instances of _Top/_Bot; the rest of the module compares
# against them by identity (``is Top`` / ``is Bot``).
Top = _Top()
Bot = _Bot()

# Every abstract value that is not a dynamic ``D``.
NonD = Union[S, _Top, _Bot]
if TYPE_CHECKING:
    # For type checkers only: at runtime ``AbsVal`` stays the base class
    # defined above, so ``isinstance(x, AbsVal)`` keeps working.
    AbsVal = Union[D, NonD]


# Unique sentinel object.
undef = object()
299 |
300 |
@dataclasses.dataclass
class Shape:
    # Description of an object registered in ``ShapeSystem``.
    name: object
    oop: bool
    fields: dict[str, Union[S, types.FunctionType]]
    # some type has unique instance
    # None.__class__ has None only
    # Either that unique instance as an ``S``, or a callable taking the
    # type's params and returning the instance (or None); see
    # ``try_spec_val_then_decref`` for the consumer.
    instance: Union[
        None, S, Callable[[tuple[NonD, ...]], Optional[S]]
    ] = dataclasses.field(default=None)
    self_bound: bool = dataclasses.field(default=False)
312 |
313 |
# Registry from a Python object to its Shape.
# None: means No Shape
ShapeSystem: dict[object, Optional[Shape]] = {}

# The literal types plus ``type``, ``Intrinsic`` and ``FunctionType``.
AWARED_IMMUTABLES = {*_literal_type_maps, type, Intrinsic, FunctionType}
318 |
319 |
def from_runtime(o: object, rt_map: list[object] = None):
    """Lift a runtime object into an abstract value.

    Hashable objects become static values ``S(o)``. Unhashable objects
    become dynamic values ``D(i, type)`` where ``i`` is the next free index
    in ``rt_map``; the new value is appended to ``rt_map``.

    :param o: the runtime object.
    :param rt_map: list receiving the dynamic values created here. A fresh
        list is used when omitted.
    """
    # Test hashability, not the hash value: ``if hash(o)`` treated objects
    # hashing to 0 (0, 0.0, False, "") as dynamic and raised TypeError for
    # unhashable ones instead of reaching the dynamic branch.
    try:
        hash(o)
    except TypeError:
        pass
    else:
        return S(o)
    t = type(o)
    if t is tuple:
        type_abs = tuple_type(o)
    else:
        type_abs = from_runtime(t)
    # ``rt_map or []`` would replace a caller-supplied empty list and lose
    # the entry appended below.
    if rt_map is None:
        rt_map = []
    i = len(rt_map)
    abs_val = D(i, type_abs)
    rt_map.append(abs_val)
    return abs_val
333 |
334 |
def tuple_type(xs):
    """Abstract type of the tuple ``xs``.

    The result is ``S(tuple, params)`` where ``params`` holds
    ``from_runtime`` applied to each element in order.
    """
    element_values = tuple(map(from_runtime, xs))
    return S(tuple, element_values)
337 |
338 |
@dataclasses.dataclass(frozen=True)
class In_Move:
    """Input instruction: copy ``source`` into ``target``."""

    target: D
    source: AbsVal

    def __repr__(self):
        return "{} = {!r}".format(self.target, self.source)
346 |
347 |
@dataclasses.dataclass(frozen=True)
class In_Bind:
    """Input instruction: ``target = sub.attr(*args)``."""

    target: D
    sub: AbsVal
    attr: AbsVal
    args: tuple[AbsVal, ...]

    def __repr__(self):
        rendered_args = ",".join(map(repr, self.args))
        return "{} = {!r}.{}({})".format(
            self.target, self.sub, self.attr, rendered_args
        )
358 |
359 |
@dataclasses.dataclass(frozen=True)
class In_Goto:
    """Input instruction: unconditional jump to the block named ``label``."""

    label: str

    def __repr__(self):
        return "goto {}".format(self.label)
366 |
367 |
@dataclasses.dataclass(frozen=True)
class In_SetLineno:
    """Input instruction recording a source position (line and file)."""

    line: int
    filename: str

    def __repr__(self):
        return "# line {} at {}".format(self.line, self.filename)
375 |
376 |
@dataclasses.dataclass(frozen=True)
class In_Cond:
    """Input instruction: branch to ``then`` or ``otherwise`` on ``test``."""

    test: AbsVal
    then: str
    otherwise: str

    def __repr__(self):
        return "if {!r} then {} else {}".format(
            self.test, self.then, self.otherwise
        )
387 |
388 |
@dataclasses.dataclass(frozen=True)
class In_Return:
    """Input instruction: return ``value`` from the function."""

    value: AbsVal

    def __repr__(self):
        return "return {!r}".format(self.value)
395 |
396 |
# Any one of the input instructions defined above.
In_Stmt = Union[
    In_Cond, In_SetLineno, In_Goto, In_Move, In_Return, In_Bind
]
# Kept as a string (forward-reference style): block label -> list of
# input instructions.
In_Blocks = "dict[str, list[In_Stmt]]"
401 |
402 |
def print_in(b: In_Blocks, print=print):
    """Print every block of ``b``: a ``label :`` line, then its statements.

    The block labelled ``"entry"`` is printed first; all other blocks keep
    the mapping's own order. ``print`` may be replaced by any callable with
    the same calling convention.
    """
    items = list(b.items())
    entry_first = [kv for kv in items if kv[0] == "entry"]
    others = [kv for kv in items if kv[0] != "entry"]
    for label, stmts in entry_first + others:
        print(label, ":")
        for stmt in stmts:
            print(stmt)
408 |
409 |
@dataclasses.dataclass
class In_Def:
    """Input-side definition of a function: argument count, blocks, and the
    Python function it was built from."""

    narg: int
    blocks: In_Blocks
    _func: FunctionType
    static_glob: set[str]

    # Shared class-level registry (not a dataclass field).
    UserCodeDyn = {}  # type: dict[FunctionType, In_Def]

    def show(self):
        """Print the definition to stdout in a readable pseudo-code form."""
        params = ",".join(f"D[{k}]" for k in range(self.narg))
        print(f"def {self.name}({params})", "{")

        def indented(*parts):
            # A leading empty argument makes print emit one space first.
            print("", *parts)

        print_in(self.blocks, print=indented)
        print("}")

    @property
    def func(self) -> FunctionType:
        """this is for hacking PyCharm's type checker"""
        # noinspection PyTypeChecker
        return self._func

    @property
    def name(self) -> str:
        """``__name__`` of the underlying function."""
        return self.func.__name__

    @property
    def glob(self) -> dict:
        """``__globals__`` of the underlying function."""
        # noinspection PyUnresolvedReferences
        return self.func.__globals__
440 |
441 |
@dataclasses.dataclass(unsafe_hash=True)
class Out_Call(Out_Callable):
    """Output expression: ``func`` applied to ``args``."""

    func: AbsVal
    args: tuple[AbsVal, ...]

    def __repr__(self):
        # repr of the callee directly followed by repr of the args tuple.
        return repr(self.func) + repr(self.args)
449 |
450 |
@dataclasses.dataclass(frozen=True)
class Out_Assign:
    """Output instruction: ``target = expr``, with the slot indices to
    decref if the call fails."""

    target: D
    expr: Out_Call
    decrefs: tuple[int, ...]

    def show(self, prefix, print):
        """Emit two lines through ``print``, each starting with ``prefix``."""
        slots = ",".join("D{}".format(i) for i in self.decrefs)
        print(f"{prefix}{self.target} = {self.expr!r}")
        print(f"{prefix}when err: decref [{slots}] ")
461 |
462 |
@dataclasses.dataclass(frozen=True)
class Out_If:
    """Output instruction: jump to label ``t`` when ``test`` holds, else to
    label ``f``."""

    test: AbsVal

    t: str
    f: str

    def show(self, prefix, print):
        """Emit three lines through ``print``, each starting with ``prefix``."""
        lines = (
            f"if {self.test!r}",
            f"then goto {self.t}",
            f"else goto {self.f}",
        )
        for line in lines:
            print(f"{prefix}{line}")
474 |
475 |
@dataclasses.dataclass(frozen=True)
class Out_TypeCase:
    """Output instruction: dispatch on the type of ``obj``; ``cases`` maps
    each type to the instructions to run."""

    obj: AbsVal
    cases: pyrsistent.PMap[AbsVal, tuple[Out_Instr, ...]]

    def show(self, prefix, print):
        """Emit the header, then each case followed by its nested body."""
        print(f"{prefix}case typeof {self.obj!r}")
        nested_prefix = prefix + " "
        for case_type, body in self.cases.items():
            print(f"{prefix} {case_type!r} ->")
            print_out(body, nested_prefix, print)
486 |
487 |
@dataclasses.dataclass(frozen=True)
class Out_Label:
    """Output instruction marking the jump target ``label``."""

    label: str

    def show(self, prefix, print):
        """Emit the label line; ``prefix`` is accepted but not used."""
        print("label {}:".format(self.label))
494 |
495 |
@dataclasses.dataclass(frozen=True)
class Out_Goto:
    """Output instruction: unconditional jump to ``label``."""

    label: str

    def show(self, prefix, print):
        """Emit one line through ``print``, starting with ``prefix``."""
        print("{}goto {}".format(prefix, self.label))
502 |
503 |
@dataclasses.dataclass(frozen=True)
class Out_Return:
    """Output instruction: return ``value`` and decref the listed slots."""

    value: AbsVal
    decrefs: tuple[int, ...]

    def show(self, prefix, print):
        """Emit two lines through ``print``, each starting with ``prefix``."""
        slots = ",".join("D{}".format(i) for i in self.decrefs)
        print(f"{prefix}return {self.value!r}")
        print(f"{prefix}and decref [{slots}]")
513 |
514 |
@dataclasses.dataclass(frozen=True)
class Out_DecRef:
    """Output instruction: decref the dynamic slot with index ``i``."""

    i: int

    def show(self, prefix, print):
        """Emit one line through ``print``, starting with ``prefix``."""
        print("{}decref D{}".format(prefix, self.i))
521 |
522 |
@dataclasses.dataclass(frozen=True)
class Out_SetLineno:
    """Output instruction recording a source position (line and file)."""

    line: int
    filename: str

    def show(self, prefix, print):
        """Emit one comment-style line through ``print``."""
        print("{}# line {} at {}".format(prefix, self.line, self.filename))
530 |
531 |
# Any one of the output instructions.
# NOTE(review): ``Out_SetLineno`` is defined above with the same ``show``
# interface but is not part of this union; confirm whether that is
# intentional.
Out_Instr = Union[
    Out_DecRef,
    Out_Label,
    Out_TypeCase,
    Out_If,
    Out_Assign,
    Out_Return,
    Out_Goto,
]


# Kept as a string (forward-reference style): a function paired with the
# abstract values of its arguments. Used as the key type of the caches
# below.
CallRecord = "tuple[FunctionType, tuple[AbsVal, ...]]"
544 |
545 |
def print_out(xs: Iterable[Out_Instr], prefix, print):
    """Ask every instruction in ``xs`` to render itself, in order.

    Each instruction's ``show`` receives the same ``prefix`` and ``print``.
    """
    for instr in xs:
        instr.show(prefix, print)
549 |
550 |
@dataclasses.dataclass
class Out_Def:
    """Output-side definition of one specialised function."""

    spec: JITSpecInfo
    params: tuple[AbsVal, ...]
    instrs: tuple[Out_Instr, ...]
    start: str
    func: FunctionType

    # Shared class-level cache (not a dataclass field).
    GenerateCache = OrderedDict()  # type: dict[Intrinsic, Out_Def]

    @property
    def name(self) -> str:
        """``__name__`` of the underlying Python function."""
        return self.func.__name__

    def show(self, print=print):
        """Render the definition through ``print``: a header, every
        instruction, and a closing brace."""
        spec = self.spec
        returns = "|".join(repr(t) for t in spec.possibly_return_types)
        jit_name = spec.abs_jit_func
        instance = spec.instance
        rendered_params = ", ".join(repr(p) for p in self.params)
        # The constant instance is only mentioned when there is one.
        opener = f"-> {instance} {{" if instance else "{"
        print(returns, f"{jit_name!r}(", rendered_params, ")", opener)
        # print(f" START from {self.start}")
        for instr in self.instrs:
            instr.show(" ", print)
        print("}")
580 |
581 |
@dataclasses.dataclass
class JITSpecInfo:
    # Cached result of specialising one call (stored in ``SpecMaps``).
    instance: Optional[AbsVal]  # maybe return a constant instance
    # Abstract value of the generated function.
    abs_jit_func: AbsVal
    # Abstract types the specialised function may return.
    possibly_return_types: tuple[AbsVal, ...]
587 |
588 |
class CallSpec:
    """Result of specialising a single call site.

    ``instance`` is the constant the call is known to produce, if any;
    ``e_call`` is the expression to emit; ``possibly_return_types`` lists
    the abstract types the call may return and is always stored as a tuple.

    Note: the class defines ``__eq__`` without ``__hash__``, so instances
    are unhashable.
    """

    instance: Optional[AbsVal]  # maybe return a constant instance
    e_call: Union[Out_Call, AbsVal]
    possibly_return_types: tuple[AbsVal, ...]

    def __init__(
        self,
        instance: Optional[AbsVal],
        e_call: Union[Out_Call, AbsVal],
        # ``Iterable`` takes a single type parameter; the former
        # ``Iterable[AbsVal, ...]`` fails as soon as it is evaluated.
        possibly_return_types: Iterable[AbsVal],
    ):
        self.instance = instance
        self.e_call = e_call
        # Normalise any iterable to a tuple; an existing tuple is kept as is.
        if not isinstance(possibly_return_types, tuple):
            possibly_return_types = tuple(possibly_return_types)
        self.possibly_return_types = possibly_return_types

    def __eq__(self, other):
        return (
            isinstance(other, CallSpec)
            and self.instance == other.instance
            and self.e_call == other.e_call
            and self.possibly_return_types
            == other.possibly_return_types
        )

    def astuple(self):
        """Return ``(instance, e_call, possibly_return_types)``."""
        return self.instance, self.e_call, self.possibly_return_types
617 |
618 |
# user function calls recorded here cannot be reanalyzed next time
RecTraces: set[tuple[str, tuple[AbsVal, ...]]] = set()

# Specialisations in progress: maps a call record to the name of the
# compiled function and the set of return types inferred so far. The final
# return types are not known yet at this stage.
PreSpecMaps: dict[CallRecord, tuple[str, set[AbsVal, ...]]] = {}

# Finished specialisations: cached return types and function address.
SpecMaps: dict[CallRecord, JITSpecInfo] = {}
628 |
629 |
def mk_prespec_name(
    key: CallRecord, partial_returns: set[AbsVal], name=""
):
    """Return the generated function name for ``key``, creating it on first use.

    A new name has the form ``J_<name>_<n>``, where underscores in ``name``
    are doubled and ``n`` is the number of entries already in
    ``PreSpecMaps``. It is stored there together with ``partial_returns``.
    For a key that is already registered the stored name is returned and
    ``partial_returns`` is ignored.
    """
    existing = PreSpecMaps.get(key)
    if existing is not None:
        return existing[0]

    ordinal = len(PreSpecMaps)
    fresh = "J_{}_{}".format(name.replace("_", "__"), ordinal)
    PreSpecMaps[key] = (fresh, partial_returns)
    return fresh
640 |
641 |
class MemSlot(NamedTuple):
    """Bookkeeping for one slot: its reference count and allocation flag."""

    # reference count
    rc: int
    # is locally allocated
    ila: bool

    def __repr__(self):
        # "[rc]", prefixed with "!" for locally allocated slots.
        marker = "!" if self.ila else ""
        return f"{marker}[{self.rc}]"
653 |
654 |
@dataclasses.dataclass(frozen=True, eq=True, order=True)
class Local:
    """Immutable local state: the slot table ``mem`` and the ``store``
    mapping from slot index to abstract value."""

    mem: pyrsistent.PVector[MemSlot]
    store: pyrsistent.PMap[int, AbsVal]

    def up_mem(self, mem):
        """Return a copy with ``mem`` replaced and ``store`` unchanged."""
        return Local(mem=mem, store=self.store)

    def up_store(self, store):
        """Return a copy with ``store`` replaced and ``mem`` unchanged."""
        return Local(mem=self.mem, store=store)
665 |
666 |
def alloc(local: Local):
    """Return the index of the slot to use for a new value.

    This is the first slot whose reference count is zero, or the index one
    past the end of ``local.mem`` when every slot is in use.
    """
    slots = local.mem
    free = (index for index, slot in enumerate(slots) if slot.rc == 0)
    return next(free, len(slots))
674 |
675 |
def valid_value(x: AbsVal, mk_exc=None):
    """Return ``x`` unless it is ``Top`` or ``Bot``.

    :param x: abstract value to check.
    :param mk_exc: optional zero-argument factory for the exception to
        raise; ``TypeError`` is raised when it is omitted.
    :raises: whatever ``mk_exc()`` returns, or ``TypeError``.
    """
    if x is Bot or x is Top:
        # A real conditional instead of ``mk_exc and mk_exc() or TypeError``:
        # the and/or form silently substitutes TypeError whenever the
        # factory returns a falsy object.
        raise (mk_exc() if mk_exc else TypeError)
    return x
680 |
681 |
def judge_lit(local: Local, a: AbsVal):
    """Resolve ``a`` against the local store.

    Non-dynamic values are returned unchanged. A dynamic value is looked up
    by its index in ``local.store``; ``Bot`` is returned when nothing
    (truthy) is stored for it.
    """
    if not isinstance(a, D):
        return a
    stored = local.store.get(a.i)
    return stored if stored else Bot
687 |
688 |
def try_spec_val_then_decref(a, t: NonD) -> tuple[list, AbsVal]:
    """
    A dynamic abstract value might have a singleton type,
    in this case it's specialised into a static abstract value.
    e.g., a value whose type is 'type(None)' must be exactly 'None'.

    Returns a pair ``(instructions, value)``. When the value is specialised
    the instruction list holds one ``Out_DecRef`` for the replaced slot;
    otherwise it is empty.
    """
    if t is Bot:
        return [], Bot
    if not isinstance(a, D):
        return [], a

    # noinspection PyUnboundLocalVariable
    has_unique_instance = (
        isinstance(t, S)
        and (shape := t.shape)
        and (inst := shape.instance)
    )
    if not has_unique_instance:
        # No singleton available: keep the slot, just refine its type.
        return [], D(a.i, t)

    # ``instance`` is either the static value itself or a factory that
    # computes it from the type parameters.
    # noinspection PyTypeHints
    # noinspection PyUnboundLocalVariable
    if not isinstance(inst, S):
        inst = inst(t.params)
    return [Out_DecRef(a.i)], inst
715 |
716 |
def blocks_to_instrs(blocks: dict[str, list[Out_Instr]], start: str):
    """
    Flatten labelled blocks into a single instruction list.

    Blocks are visited in sorted label order. Blocks with identical
    instruction sequences are emitted only once, preceded by an
    'Out_Label' for each of their labels. 'start' is accepted but unused.
    """
    labels_by_body: dict[tuple, list[str]] = {}
    for label in sorted(blocks):
        body = tuple(blocks[label])
        labels_by_body.setdefault(body, []).append(label)

    flat = []
    for body, labels in labels_by_body.items():
        flat += [Out_Label(label) for label in labels]
        flat += body
    return flat
729 |
730 |
def decref(self: Judge, local: Local, a: AbsVal):
    """
    Drop one reference held on the slot of 'a' and return the new state.

    Values that are not 'D', and slot indices beyond 'local.mem', leave
    the state untouched. When the count reaches zero the slot is reset to
    'MemSlot(0, True)'; if it was locally allocated, an 'Out_DecRef' is
    emitted through the judge first.
    """
    if not isinstance(a, D):
        return local
    idx = a.i
    if idx >= len(local.mem):
        return local

    slots = local.mem
    old = slots[idx]
    remaining = old.rc - 1
    if remaining:
        updated = MemSlot(remaining, old.ila)
    else:
        if old.ila:
            self << Out_DecRef(idx)
        updated = MemSlot(0, True)
    return local.up_mem(slots.set(idx, updated))
747 |
748 |
def incref(local: Local, a: AbsVal):
    """
    Add one reference on the slot of 'a' and return the new state.

    Values that are not 'D' leave the state untouched. An index that is
    exactly one past the last slot appends a fresh, locally allocated
    slot with a count of one; any other out-of-range index trips the
    assertion.
    """
    if not isinstance(a, D):
        return local
    idx = a.i
    slots = local.mem
    try:
        current = slots[idx]
    except IndexError:
        assert len(slots) == idx
        slots = slots.append(MemSlot(1, True))
    else:
        slots = slots.set(idx, MemSlot(current.rc + 1, current.ila))
    return local.up_mem(slots)
762 |
763 |
class Judge:
    """
    Abstract interpreter for the blocks of one function.

    It walks the input statements ('In_*') under an abstract local state
    and appends the resulting output instructions ('Out_*') to
    'self.code'; finished blocks are collected in 'self.out_blocks'.
    """

    def __init__(self, blocks: In_Blocks, func: FunctionType, abs_glob):
        # States
        self.in_blocks = blocks
        # (input label, local state) -> label generated for that pair
        self.block_map: dict[tuple[str, Local], str] = {}
        # abstract values seen at 'In_Return' statements
        self.returns: set[AbsVal] = set()
        # instruction list that '<<' currently appends to
        self.code: list[Out_Instr] = []
        # generated label -> instructions of the generated block
        self.out_blocks: dict[str, list[Out_Instr]] = OrderedDict()
        self.abs_glob: dict[str, AbsVal] = abs_glob
        self.func = func
        # counter behind 'gen_label'
        self.label_cnt = 0

    @property
    def glob(self) -> dict:
        """The globals dictionary of the analysed function."""
        # noinspection PyUnresolvedReferences
        return self.func.__globals__

    @contextmanager
    def use_code(self, code: list[Out_Instr]):
        """Temporarily redirect '<<' into 'code'; restored on exit."""
        old_code = self.code
        self.code = code
        try:
            yield
        finally:
            self.code = old_code

    def gen_label(self, kind=""):
        """Return a fresh label of the form '<kind>_<counter>'."""
        self.label_cnt += 1
        return f"{kind}_{self.label_cnt}"

    def __lshift__(self, a):
        """Emit one instruction, or every instruction of a list."""
        if isinstance(a, list):
            self.code.extend(a)
        else:
            self.code.append(a)

    def jump(self, local: Local, label: str):
        """
        Return the generated label for entering input block 'label' under
        the state 'local', analysing the block first if this pair has not
        been seen before.
        """
        key = (label, local)
        if gen_label := self.block_map.get(key):
            return gen_label
        gen_label = self.gen_label()
        # Registered before the body is analysed, so a jump back to the
        # same (label, state) pair during the analysis reuses this label.
        self.block_map[key] = gen_label
        code = []
        with self.use_code(code):
            self.stmt(local, self.in_blocks[label], 0)
        self.out_blocks[gen_label] = code
        return gen_label

    def stmt(self, local: Local, xs: Sequence[In_Stmt], index: int):
        """
        Interpret the statements 'xs[index:]' under the state 'local'.

        Leading 'In_SetLineno' statements are forwarded as
        'Out_SetLineno'. The first other statement is then handled by
        kind ('In_Move', 'In_Goto', 'In_Return', 'In_Cond', 'In_Bind');
        moves and binds continue recursively with the next statement.
        """
        try:
            while (hd := xs[index]) and isinstance(hd, In_SetLineno):
                self << Out_SetLineno(hd.line, hd.filename)
                index += 1
        except IndexError:
            # TODO
            # the statement list ended without a terminator statement
            raise Exception("non-terminaor terminate")
        if isinstance(hd, In_Move):
            a_x = judge_lit(local, hd.target)
            # print(hd, judge_lit(local, hd.source), local.store)
            a_y = judge_lit(local, hd.source)
            if a_y is Top or a_y is Bot:
                self.error(local)
                return

            # release the old value of the target, then own the new one
            local = decref(self, local, a_x)
            local = incref(local, a_y)
            local = local.up_store(local.store.set(hd.target.i, a_y))
            self.stmt(local, xs, index + 1)
        elif isinstance(hd, In_Goto):
            label_gen = self.jump(local, hd.label)
            self << Out_Goto(label_gen)
        elif isinstance(hd, In_Return):
            a = valid_value(judge_lit(local, hd.value))
            self.returns.add(a)
            self << Out_Return(a, tuple(self.all_ownerships(local)))
        elif isinstance(hd, In_Cond):
            a = judge_lit(local, hd.test)
            if a is Top or a is Bot:
                self.error(local)
                return
            if a.type == A_Bool:
                a_cond = a
            else:
                # extract
                # the test is not a bool: specialise the call 'bool(a)'
                instance, e_call, union_types = self.spec(
                    A_Bool, "__call__", [a]
                ).astuple()
                if e_call is Top or e_call is Bot:
                    self.error(local)
                    return

                if not isinstance(e_call, (Out_Call, D)):
                    a_cond = e_call
                else:
                    j = alloc(local)
                    a_cond = D(j, A_Bool)
                    self << Out_Assign(
                        a_cond,
                        e_call,
                        tuple(self.all_ownerships(local)),
                    )

                    self << Out_DecRef(j)

                if instance:
                    a_cond = instance

            # a statically known condition selects its branch directly
            if a_cond.is_literal() and isinstance(a_cond.base, bool):
                direct_label = hd.then if a_cond.base else hd.otherwise
                label_generated = self.jump(local, direct_label)
                self << Out_Goto(label_generated)
                return

            l1 = self.jump(local, hd.then)
            l2 = self.jump(local, hd.otherwise)
            self << Out_If(a_cond, l1, l2)
        elif isinstance(hd, In_Bind):
            a_x = judge_lit(local, hd.target)
            a_subj = judge_lit(local, hd.sub)
            if a_subj is Top or a_subj is Bot:
                self.error(local)
                return
            attr = judge_lit(local, hd.attr)
            assert attr.is_literal() and isinstance(
                attr.base, str
            ), f"attr {attr} shall be a string"
            attr = attr.base
            a_args = []
            for a in hd.args:
                a_args.append(judge_lit(local, a))
                if a_args[-1] in (Top, Bot):
                    self.error(local)
                    return
            instance, e_call, union_types = self.spec(
                a_subj, attr, a_args
            ).astuple()

            if e_call in (Top, Bot):
                self.error(local)
                return

            # 1. no actual CALL happens
            if not isinstance(e_call, (Out_Call, D)):
                local = decref(self, local, a_x)
                rhs = e_call
                if instance:
                    rhs = instance
                local = local.up_store(
                    local.store.set(hd.target.i, rhs)
                )
                self.stmt(local, xs, index + 1)
                return

            j = alloc(local)
            # 2. CALL happens but not union-typed and the result might be a constant
            if len(union_types) == 1:
                a_t = union_types[0]
                a_spec = D(j, a_t)
                if isinstance(e_call, Out_Call):
                    self << Out_Assign(
                        a_spec,
                        e_call,
                        tuple(self.all_ownerships(local)),
                    )
                # TODO: documenting that 'D(i, t1)' means 'D(i, t2)'
                #  at a given program counter.
                local = decref(self, local, a_x)
                if a_t is Bot:
                    self.error(local)
                    return
                if instance:
                    a_spec = instance
                    code = []
                else:
                    a_t = a_spec.type
                    code, a_spec = try_spec_val_then_decref(
                        a_spec, a_t
                    )  # handle instance
                valid_value(a_spec)
                if code:
                    self << code[0]
                local = incref(local, a_spec)
                local = local.up_store(
                    local.store.set(hd.target.i, a_spec)
                )
                self.stmt(local, xs, index + 1)
                return
            # 3. CALL happens, union-typed;
            #    result might be a constant for each type
            a_union = D(j, Top)
            self << Out_Assign(
                a_union, e_call, tuple(self.all_ownerships(local))
            )
            local = decref(self, local, a_x)
            # TODO: Top(if any) should be put in the last of 'union_types'
            split = [
                try_spec_val_then_decref(a_union, t)
                for t in union_types
            ]
            cases: list[tuple[S, list[Out_Instr]]] = []
            for (code, a_spec), a_t in zip(split, union_types):
                # each case list starts with the code from the split and is
                # extended below with the continuation for this type
                cases.append((a_t, code))
                if a_t is Bot:
                    self.error(local)
                    continue

                valid_value(a_spec)
                local_i = incref(local, a_spec)
                local_i = local_i.up_store(
                    local_i.store.set(hd.target.i, a_spec)
                )
                with self.use_code(code):
                    self.stmt(local_i, xs, index + 1)
            self << Out_TypeCase(
                a_union,
                pyrsistent.pmap(
                    {case: tuple(code) for case, code in cases}
                ),
            )

    def no_spec(
        self, a_sub: AbsVal, attr: str, a_args: list[AbsVal]
    ) -> CallSpec:
        """
        Build the unspecialised call: 'Py_CallFunction' when 'attr' is
        '__call__', otherwise 'Py_CallMethod'; no constant result and a
        'Top' return type.
        """
        assert isinstance(attr, str)
        a_sub = valid_value(a_sub)
        if attr == "__call__":
            return CallSpec(
                None,
                S(Intrinsic.Py_CallFunction)(a_sub, *a_args),
                (Top,),
            )
        else:
            return CallSpec(
                None,
                S(Intrinsic.Py_CallMethod)(a_sub, S(attr), *a_args),
                (Top,),
            )

    def spec(
        self, a_sub: AbsVal, attr: str, a_args: list[AbsVal]
    ) -> CallSpec:
        """
        Specialise the call 'a_sub.attr(*a_args)'.

        The handler is resolved with 'judge_resolve' on the shape of the
        subject (for a non-literal 'S') or on the shape of its type.
        An abstract-value handler is specialised recursively as a call; a
        plain handler is invoked with this judge. The unspecialised call
        is used when nothing is found or the handler returns
        'NotImplemented'.
        """
        assert isinstance(attr, str)
        a_sub = valid_value(a_sub)
        if attr == "__call__":

            def default():
                return CallSpec(
                    None,
                    S(Intrinsic.Py_CallFunction)(a_sub, *a_args),
                    (Top,),
                )

        else:

            def default():
                return CallSpec(
                    None,
                    S(Intrinsic.Py_CallMethod)(a_sub, S(attr), *a_args),
                    (Top,),
                )

        a_t = a_sub.type
        if a_t is Top:
            return default()
        if a_t is Bot:
            raise TypeError
        if isinstance(a_sub, S):
            # python literal is not callable
            if a_sub.is_literal():
                a_t = cast(S, a_t)
                shape = a_t.shape
                if not shape:
                    return default()
                assert shape.oop
                meth = judge_resolve(shape, attr)
                if not meth:
                    return default()
                # noinspection PyTypeHints
                if isinstance(meth, AbsVal):
                    return self.spec(meth, "__call__", [a_sub, *a_args])
                r = meth(self, a_sub, *a_args)
                if r is NotImplemented:
                    return default()
                return r
            # a static non-literal object: try its own shape first
            shape = a_sub.shape
            if shape and (meth_ := judge_resolve(shape, attr)):
                # hack pycharm for type check
                # noinspection PyUnboundLocalVariable
                meth = meth_
                if shape.self_bound:
                    a_args = [a_sub, *a_args]
                # noinspection PyTypeHints
                if isinstance(meth, AbsVal):
                    return self.spec(meth, "__call__", a_args)

                r = meth(self, *a_args)
                if r is NotImplemented:
                    return default()
                return r

        # fall back to the shape of the subject's type
        a_t = cast(S, a_t)
        shape = a_t.shape
        if not shape:
            return default()
        meth = judge_resolve(shape, attr)
        if not meth:
            return default()
        if shape.oop:
            # bound-method style: the subject becomes the first argument
            a_args = [a_sub, *a_args]

        # noinspection PyTypeHints
        if isinstance(meth, AbsVal):
            return self.spec(meth, "__call__", a_args)

        r = meth(self, *a_args)
        if r is NotImplemented:
            return default()
        return r

    def all_ownerships(self, local: Local):
        """Yield the index of every slot that is referenced and locally allocated."""
        for i, each in enumerate(local.mem):
            if each.rc and each.ila:
                yield i

    def error(self, local):
        """Abort the analysis by raising 'ValueError' carrying the local state."""
        raise ValueError(local)
1090 |
1091 |
def ufunc_spec(self, a_func: AbsVal, *arguments: AbsVal) -> CallSpec:
    """
    Specialise a call to a user function.

    Falls back to a plain 'Py_CallFunction' call when the callee is
    dynamic or is not registered in 'In_Def.UserCodeDyn'. Otherwise the
    call record (function, abstract parameters) is looked up:

    - in 'SpecMaps': the finished specialisation is reused;
    - in 'PreSpecMaps': the specialisation is still being analysed (e.g. a
      recursive call); its name is used with the partially known return
      types plus 'Top';
    - nowhere: the function body is analysed with a fresh 'Judge' and the
      result is recorded in 'SpecMaps' and 'Out_Def.GenerateCache'.
    """
    a_func = valid_value(a_func)

    def default():
        return CallSpec(
            None,
            S(Intrinsic.Py_CallFunction)(a_func, *arguments),
            (Top,),
        )

    if isinstance(a_func, D):
        return default()
    # isinstance functiontype not handled in PyCharm
    assert isinstance(a_func.base, FunctionType)
    func = a_func.base
    in_def = In_Def.UserCodeDyn.get(func)
    if not in_def:
        # not registered as jit func, skip
        return default()
    parameters = []
    mem = []
    store = {}

    for i, a_arg in enumerate(arguments):
        if isinstance(a_arg, D):
            # dynamic arguments are renumbered densely and occupy slots
            # that are not locally allocated
            j = len(mem)
            mem.append(MemSlot(1, False))
            a_param = D(j, a_arg.type)
            parameters.append(a_param)

        else:
            a_param = a_arg
            parameters.append(a_arg)
        store[i] = a_param

    parameters = tuple(parameters)
    call_record = a_func.base, parameters
    if call_record in SpecMaps:
        spec = SpecMaps[call_record]
        e_call = spec.abs_jit_func(*arguments)
        ret_types = spec.possibly_return_types
        instance = spec.instance
    elif partial_spec := PreSpecMaps.get(call_record):
        jit_func_name, partial_returns = partial_spec
        abs_jit_func = S(intrinsic(jit_func_name))
        e_call = abs_jit_func(*arguments)
        partial_return_types = set(
            each.type for each in partial_returns
        )
        # the analysis is unfinished, so any other type is still possible
        partial_return_types.add(Top)
        ret_types = tuple(sorted(partial_return_types))
        instance = None
    else:
        partial_returns = set()
        jit_func_name = mk_prespec_name(
            call_record, partial_returns, name=in_def.name
        )

        abs_glob = {}
        for glob_name in in_def.static_glob:
            v = in_def.glob.get(glob_name, undef)
            if v is undef:
                v = getattr(builtins, glob_name, undef)
            if v is undef:
                continue
            a_v = from_runtime(v)
            # Only static abstract values may be recorded. The check must
            # look at the abstract value 'a_v'; the original tested the
            # runtime object 'v', which let dynamic values through.
            if isinstance(a_v, D):
                continue
            abs_glob[glob_name] = a_v

        sub_judge = Judge(
            blocks=in_def.blocks, func=in_def.func, abs_glob=abs_glob
        )
        # share the set registered in 'PreSpecMaps' so that nested lookups
        # of this call record see the returns found so far
        sub_judge.returns = partial_returns
        local = Local(pyrsistent.pvector(mem), pyrsistent.pmap(store))
        gen_start = sub_judge.jump(local, "entry")
        instrs = blocks_to_instrs(sub_judge.out_blocks, gen_start)

        ret_types: tuple[AbsVal, ...] = tuple(
            sorted({r.type for r in sub_judge.returns})
        )
        instance, *is_union = sub_judge.returns
        instance = valid_value(instance)

        # a constant result exists only for a single, static return value
        if isinstance(instance, D):
            instance = None
        elif is_union:
            instance = None
        elif instance is None:
            t = ret_types[0]
            if t.is_s():
                instance = t.shape.instance
                if isinstance(instance, FunctionType):
                    instance = instance(t.params)

        intrin = intrinsic(jit_func_name)
        spec_info = JITSpecInfo(instance, S(intrin), ret_types)
        SpecMaps[call_record] = spec_info
        out_def = Out_Def(
            spec_info, parameters, tuple(instrs), gen_start, in_def.func
        )
        Out_Def.GenerateCache[intrin] = out_def
        e_call = spec_info.abs_jit_func(*arguments)

    return CallSpec(instance, e_call, ret_types)
1197 |
1198 |
# Shape of plain Python functions: calling one is specialised by 'ufunc_spec'.
ShapeSystem[types.FunctionType] = Shape(
    types.FunctionType,
    oop=True,
    fields={"__call__": cast(FunctionType, ufunc_spec)},
)
1204 |
1205 |
def judge_resolve(shape: Shape, attr: str):
    """
    Look up the handler registered for 'attr'.

    The fields of 'shape' are consulted first. When 'shape.name' is a
    class, the shapes registered for its direct bases are tried in order.
    Returns 'None' when no (truthy) handler is found.
    """
    found = shape.fields.get(attr)
    if found:
        return found

    owner = shape.name
    if not isinstance(owner, type):
        return None

    for parent in owner.__bases__:
        parent_shape = ShapeSystem.get(parent)
        if not parent_shape:
            continue
        found = parent_shape.fields.get(attr)
        if found:
            return found
    return None
1218 |
--------------------------------------------------------------------------------
/diojit/absint/intrinsics.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import sys
3 | import operator
4 | from typing import Callable, Optional
5 |
6 | __all__ = ["intrinsic", "Intrinsic"]
7 |
8 |
def setup():
    """Placeholder that does nothing; '_mk' rebinds this name when it runs."""
11 |
12 |
def _mk(name, bases, ns):
    """
    Class factory used as a 'metaclass': build the class with 'type' and
    rebind the module-level 'setup' to a function that fills it in.

    The new 'setup' assigns 'intrinsic(n)' to the class attribute 'n' for
    every annotated name 'n' of the class body that does not start with
    an underscore.
    """
    global setup
    cls = type(name, bases, ns)

    def setup():
        for field in ns["__annotations__"]:
            if not field.startswith("_"):
                setattr(cls, field, intrinsic(field))

    return cls
24 |
25 |
class Intrinsic(metaclass=_mk):
    # NOTE: the annotations in this class body are data, not just hints:
    # the 'setup' function installed by '_mk' iterates over them and
    # assigns 'intrinsic(name)' to every name without a leading underscore.
    _callback: Optional[Callable] = None
    _name: str
    J: object

    @property
    def name(self):
        return self._name

    def __repr__(self):
        return f"{self._name}"

    def __eq__(self, other):
        # identity comparison of the names; 'intrinsic' interns every name
        # before storing it
        return (
            isinstance(other, Intrinsic) and self._name is other._name
        )

    def __hash__(self):
        # consistent with '__eq__': hash by the identity of the name
        return id(self._name)

    Py_TYPE: Intrinsic

    # hints: implemented by
    #  1. PyObject_CallFunctionObjArgs()
    #  2. PyObject_CallNoArgs()
    #  3. PyObject_CallOneArg()
    Py_CallFunction: Intrinsic

    #  1. PyObject_CallMethodNoArgs()
    #  2. PyObject_CallMethodOneArg()
    #  3. PyObject_CallMethodObjArgs()
    Py_CallMethod: Intrinsic

    Py_LoadGlobal: Intrinsic

    Py_StoreGlobal: Intrinsic

    Py_StoreAttr: Intrinsic

    Py_LoadAttr: Intrinsic

    Py_BuildTuple: Intrinsic

    Py_BuildList: Intrinsic

    Py_Raise: Intrinsic

    Py_AddressCompare: Intrinsic
    Py_Not: Intrinsic
75 |
76 | # Py_Pow = operator.__pow__
77 | # Py_Mul = operator.__mul__
78 | # Py_Matmul = operator.__matmul__
79 | # Py_Floordiv = operator.__floordiv__
80 | # Py_Truediv = operator.__truediv__
81 | # Py_Mod = operator.__mod__
82 | # Py_Add = operator.__add__
83 | # Py_Sub = operator.__sub__
84 | # Py_Getitem = operator.__getitem__
85 | # Py_Lshift = operator.__lshift__
86 | # PY_Rshift = operator.__rshift__
87 | # Py_And = operator.__and__
88 | # Py_Xor = operator.__xor__
89 | # Py_Or = operator.__or__
90 | #
91 | # Py_Lt = operator.__lt__
92 | # Py_Gt = operator.__gt__
93 | # Py_Le = operator.__le__
94 | # Py_Ge = operator.__ge__
95 | # Py_Ne = operator.__ne__
96 | # Py_Eq = operator.__eq__
97 |
98 |
# interned name -> the Intrinsic object created for it by 'intrinsic'
_cache = {}
100 |
101 |
def intrinsic(name: str) -> Intrinsic:
    """
    Return the unique 'Intrinsic' object for 'name'.

    The name is interned, and the object is created without calling
    '__init__' on first use and cached in '_cache' afterwards.
    """
    key = sys.intern(name)
    cached = _cache.get(key)
    if cached:
        return cached

    fresh = Intrinsic.__new__(Intrinsic)
    fresh._name = key
    _cache[key] = fresh
    return fresh
110 |
111 |
112 | setup()
113 |
--------------------------------------------------------------------------------
/diojit/absint/prescr.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from .abs import *
3 | from .intrinsics import *
4 | from collections.abc import Iterable
5 | from functools import lru_cache
6 | import warnings
7 | import typing
8 | import operator
9 | import math
10 | import builtins
11 | import io
12 |
13 | if typing.TYPE_CHECKING:
14 | from mypy_extensions import VarArg
15 |
# sentinel for "argument not given" (None is a meaningful value)
_undef = object()
__all__ = ["create_shape", "register"]

# abstract values that stand for "no inferred type"
_not_inferred_types = (Top, Bot)
20 |
21 |
def lit(x: str):
    """
    Directly pass this string to the backend.

    The string is returned unchanged; it is only presented to the type
    checker as an 'AbsVal'.
    """
    as_abs_val = typing.cast(AbsVal, x)
    return as_abs_val
25 |
26 |
def u64i(i: int):
    """Format 'i' as '0x'-prefixed hexadecimal, zero-padded to 18 characters (uint64 from integer)."""
    return format(i, "#018x")
30 |
31 |
# numeric type -> name of the comparison intrinsic that 'spec_cmp' uses when
# both operands have exactly this type
concrete_numeric_types = {
    int: "PyInt_Compare",
    float: "PyObject_RichCompare",
    complex: "PyObject_RichCompare",
    bool: "PyInt_Compare",
}
38 |
39 |
def create_shape(
    o: object, oop: bool = False, self_bound=False, instance=_undef
):
    """
    Return the shape registered for 'o', creating it if necessary.

    o: 'Special Object' in the paper. Must be immutable (hashable);
        otherwise 'TypeError' is raised.
    oop: whether attached methods will be
        used as bound methods in method resolution.
    instance: only works when 'o' is a 'type'.
        It is given only when the type has exactly one instance.
        The value is converted with 'from_runtime' and must not become
        a dynamic value.

    When a shape for 'o' already exists it is returned as is, and the
    other arguments do not modify it.
    """
    try:
        hash(o)
    except TypeError:
        raise TypeError(f"create shape for non immuatble object {o}")

    abs_instance = None
    if instance is not _undef:
        abs_instance = from_runtime(instance)
        assert not isinstance(abs_instance, D)

    existing = ShapeSystem.get(o)
    if existing:
        return existing

    created = Shape(o, oop, {}, abs_instance, self_bound)
    ShapeSystem[o] = created
    return created
65 |
66 |
# alias: 'register' has a parameter named 'create_shape' that hides the function
_create_shape = create_shape
68 |
69 |
def register(
    o: object,
    attr="__call__",
    create_shape: typing.Union[dict, None, typing.Literal[True]] = None,
):
    """
    Return a decorator that installs a handler as field 'attr' on the
    shape of 'o'.

    o: the object whose shape receives the handler.
    attr: name of the field to set.
    create_shape: what to do when 'o' has no shape yet. 'None' raises
        'ValueError'; 'True' creates a shape with default options; a dict
        is passed as keyword arguments to the module-level 'create_shape'.

    A warning is issued when the field already exists; the decorator then
    overwrites it.
    """

    if shape := ShapeSystem.get(o):
        pass
    else:
        if create_shape is not None:
            if create_shape is True:
                create_shape = {}
            shape = _create_shape(o, **create_shape)
        else:
            raise ValueError(
                f"No shape found for {{ o={o} }}.\n"
                f"    Maybe use {{ create_shape(o, oop_or_not) }} firstly?"
            )
    if attr in shape.fields:
        # the message used to read 'object f{o}': a stray 'f' inside the
        # f-string leaked into the text
        warnings.warn(
            Warning(
                f"field {attr} exists for the shape of object {o}."
            )
        )

    def ap(f: typing.Callable[[Judge, VarArg(AbsVal)], CallSpec]):
        shape.fields[attr] = f
        return f

    return ap
100 |
101 |
def shape_of(o):
    """Return the shape registered for 'o'; a missing entry propagates the lookup error of 'ShapeSystem'."""
    registered = ShapeSystem[o]
    return registered
104 |
105 |
# Shapes of the built-in objects that the handlers below attach to.
# 'oop=True' shapes receive the subject as first argument of their methods.
create_shape(Intrinsic, oop=True)
create_shape(list, oop=True)
create_shape(dict, oop=True)
create_shape(bytes, oop=True)
create_shape(bool, oop=True)
create_shape(bytearray, oop=True)
create_shape(int, oop=True)
# create_shape(type, self_bound=True, oop=True)
create_shape(next)
create_shape(io.BytesIO)
# 'type(None)' has exactly one instance, so a value of that type is 'None'
_NoneType_shape = create_shape(type(None))
_NoneType_shape.instance = S(None)
118 |
119 |
@register(Intrinsic, "__call__")
def py_call_intrinsic(
    self: Judge, f: AbsVal, *args: AbsVal
) -> CallSpec:
    """Calling an intrinsic: emit the call as is, with no constant result and a 'Top' return type."""
    e_call = f(*args)
    return CallSpec(None, e_call, (Top,))
125 |
126 |
@register(
    Intrinsic.Py_LoadGlobal, "__call__", create_shape=dict(oop=False)
)
def py_load_global(self: Judge, a_str: AbsVal) -> CallSpec:
    """
    Specialise a global-variable load.

    - A literal string name found in 'self.abs_glob' resolves statically
      to the recorded abstract value.
    - Any other literal name is loaded with 'PyDict_LoadGlobal_KnownHash',
      passing the precomputed hash of the name.
    - A non-literal name is loaded with 'PyDict_LoadGlobal'.

    TODO: passing 'S(self.func)' is UNSAFE here.
      Theoretically, 'S' shall not be used for
      mutable data. This is a special case.
    """
    if not a_str.is_literal():
        # slow path: the name is only known at runtime
        loader = S(intrinsic("PyDict_LoadGlobal"))
        e_call = loader(S(self.func), S(builtins), a_str)
        return CallSpec(None, e_call, (Top,))

    if isinstance(a_str.base, str):
        attr = a_str.base
        if attr in self.abs_glob:
            a = self.abs_glob[attr]
            return CallSpec(a, a, possibly_return_types=(a.type,))

    # constant key: hash the name once at analysis time
    hash_val = hash(a_str.base)
    loader = S(intrinsic("PyDict_LoadGlobal_KnownHash"))
    e_call = loader(
        S(self.func), S(builtins), a_str, lit(str(hash_val))
    )
    return CallSpec(None, e_call, (Top,))
163 |
164 |
@register(bool, attr="__call__")
def py_call_bool_type(self: Judge, *args: AbsVal):
    """
    Specialise a call to 'bool'.

    - 'bool()' is the constant 'False'.
    - 'bool(x)' with 'x' already of a bool type is 'x' itself (a constant
      when 'x' is static).
    - any other single argument goes through 'Py_CallBoolIfNecessary'.
    - more than one argument is not specialised ('NotImplemented').
    """
    if not args:
        # bool() = False
        constant_return = S(False)
        return CallSpec(
            constant_return, constant_return, (Values.A_Bool,)
        )
    if len(args) != 1:
        # bool(a, b, ...) is not specialised; use the default call path
        return NotImplemented
    # bool(a)
    arg = args[0]
    if isinstance(arg.type, S) and issubclass(arg.type.base, bool):
        # conditional expression instead of 'isinstance(...) and arg or None',
        # which would drop a static argument that happens to be falsy
        constant_return = arg if isinstance(arg, S) else None
        return CallSpec(constant_return, arg, (Values.A_Bool,))
    return CallSpec(
        None,
        S(intrinsic("Py_CallBoolIfNecessary"))(arg),
        (Values.A_Bool,),
    )
186 |
187 |
@register(isinstance, create_shape=True)
def spec_isinstance(self: Judge, *args: AbsVal):
    """
    Specialise 'isinstance(l, r)'.

    When the type of 'l' is statically known and 'r' is a static class,
    the answer is folded to a constant; otherwise 'PyObject_IsInstance'
    is called. Calls with other than two arguments are not specialised.
    """
    if len(args) != 2:
        return NotImplemented
    l, r = args
    return_types = tuple({Values.A_Bool})
    l_t = l.type
    if (
        isinstance(l_t, S)
        and isinstance(l_t.base, type)
        and isinstance(r, S)
        and isinstance(r.base, type)
    ):
        # The value is an instance of 'r' when its type is 'r' or a
        # subclass of it. The original test looked for the type among the
        # direct bases of 'r', i.e. in the wrong direction: it folded
        # 'isinstance(True, int)' to False and 'isinstance(5, bool)' to True.
        const = S(issubclass(l_t.base, r.base))
        return CallSpec(const, const, return_types)

    func = S(intrinsic("PyObject_IsInstance"))
    return CallSpec(None, func(l, r), return_types)
204 |
205 |
@register(operator.__pow__, create_shape=True)
def spec_pow(self: Judge, l: AbsVal, r: AbsVal):
    """Specialise 'l ** r' to 'Py_IntPowInt' when both operands have the int type; not specialised otherwise."""
    if not (l.type == Values.A_Int and r.type == Values.A_Int):
        return NotImplemented

    int_pow = S(intrinsic("Py_IntPowInt"))
    # no constant result is computed at analysis time
    return CallSpec(None, int_pow(l, r), (Values.A_Int,))
218 |
219 |
@register(len, create_shape=True)
def spec_len(self: Judge, *args: AbsVal):
    """Specialise 'len(x)' to 'PySequence_Length' with an int result; other argument counts are not specialised."""
    if len(args) != 1:
        return NotImplemented
    (subject,) = args
    length_of = S(intrinsic("PySequence_Length"))
    return CallSpec(None, length_of(subject), (Values.A_Int,))
229 |
230 |
@register(operator.__add__, create_shape=True)
def spec_add(self: Judge, l: AbsVal, r: AbsVal):
    """Specialise 'l + r' to 'Py_IntAddInt' when both operands have the int type; not specialised otherwise."""
    if not (l.type == Values.A_Int and r.type == Values.A_Int):
        return NotImplemented

    int_add = S(intrinsic("Py_IntAddInt"))
    # no constant result is computed at analysis time
    return CallSpec(None, int_add(l, r), (Values.A_Int,))
242 |
243 |
@register(operator.__iadd__, create_shape=True)
def spec_add(self: Judge, l: AbsVal, r: AbsVal):
    """Specialise 'l += r' to 'Py_IntAddInt' when both operands have the int type; not specialised otherwise."""
    both_int = l.type == Values.A_Int and r.type == Values.A_Int
    if not both_int:
        return NotImplemented

    int_add = S(intrinsic("Py_IntAddInt"))
    # no constant result is computed at analysis time
    return CallSpec(None, int_add(l, r), (Values.A_Int,))
255 |
256 |
def spec_cmp(op):
    """
    Build the handler for one rich-comparison operator.

    op: the comparison opcode name (e.g. "Py_LT"), passed to the backend
        verbatim as the last call argument.

    The handler calls 'PyObject_RichCompare' by default. When both operand
    types are static numeric types listed in 'concrete_numeric_types', the
    result type is bool, and for two operands of the same such type the
    intrinsic named in that table is used instead. In every other case
    the result type is 'Top'.
    """

    def spec_op(self: Judge, *args: AbsVal):
        if len(args) != 2:
            return NotImplemented
        l, r = args
        l_t = l.type
        r_t = r.type
        both_numeric = (
            l_t.is_s()
            and r_t.is_s()
            and l_t.base in concrete_numeric_types
            and r_t.base in concrete_numeric_types
        )

        func_name = "PyObject_RichCompare"
        if both_numeric:
            if l_t.base == r_t.base:
                func_name = concrete_numeric_types[l_t.base]
            ret_types = (Values.A_Bool,)
        else:
            ret_types = (Top,)

        compare = S(intrinsic(func_name))
        return CallSpec(None, compare(l, r, lit(op)), ret_types)

    return spec_op
281 |
282 |
283 | register(operator.__le__, create_shape=True)(spec_cmp("Py_LE"))
284 | register(operator.__lt__, create_shape=True)(spec_cmp("Py_LT"))
285 | register(operator.__ge__, create_shape=True)(spec_cmp("Py_GE"))
286 | register(operator.__gt__, create_shape=True)(spec_cmp("Py_GT"))
287 | register(operator.__eq__, create_shape=True)(spec_cmp("Py_EQ"))
288 | register(operator.__ne__, create_shape=True)(spec_cmp("Py_NE"))
289 |
290 |
@register(math.sqrt, create_shape=True)
def spec_sqrt(self: Judge, a: AbsVal):
    """Specialise 'math.sqrt' of an int-typed value to 'Py_IntSqrt' with a float result; not specialised otherwise."""
    if a.type != Values.A_Int:
        return NotImplemented
    int_sqrt = S(intrinsic("Py_IntSqrt"))
    return CallSpec(None, int_sqrt(a), (Values.A_Float,))
297 |
298 |
@register(setattr, create_shape=True)
def call_setattr(self: Judge, *args: AbsVal):
    """Specialise 'setattr(obj, name, value)' to 'PyObject_SetAttr'; the result is the constant 'None'."""
    if len(args) != 3:
        return NotImplemented

    set_attr = S(intrinsic("PyObject_SetAttr"))
    e_call = set_attr(*args)
    none_result = S(None)
    return CallSpec(none_result, e_call, (Values.A_NoneType,))
307 |
308 |
@register(getattr, create_shape=True)
def call_getattr(self: Judge, *args: AbsVal):
    """
    Specialise 'getattr(subject, attr)'.

    When the shape of the subject's type registers a '__getattr__'
    handler, that handler is tried first; if there is none, or it
    returns 'NotImplemented', 'PyObject_GetAttr' is called with a 'Top'
    result type. Calls with other than two arguments are not specialised.
    """
    if len(args) != 2:
        return NotImplemented
    ret_types = (Top,)
    subject, attr = args
    # noinspection PyUnboundLocalVariable
    if (
        subject.type.is_s()
        and (shape := subject.type.shape)
        and (handler := shape.fields.get("__getattr__"))
    ):
        # handlers of non-oop shapes do not receive the subject
        call_args = list(args) if shape.oop else [attr]
        if isinstance(handler, S):
            # 'Judge.spec' takes the arguments as ONE list parameter;
            # unpacking them (as the original did) breaks the call
            r = self.spec(handler, "__call__", call_args)
        else:
            r = handler(self, *call_args)
        if r is not NotImplemented:
            return r

    func = S(intrinsic("PyObject_GetAttr"))
    e_call = func(subject, attr)
    instance = None
    return CallSpec(instance, e_call, ret_types)
338 |
339 |
@register(operator.__getitem__, create_shape=True)
def call_getitem(self: Judge, *args: AbsVal):
    """
    Specialise 'subject[item]'.

    When the shape of the subject's type registers a '__getitem__'
    handler, the call is dispatched to it: a plain function handler is
    invoked directly, any other handler is specialised as a call.
    Otherwise 'PyObject_GetItem' is called with a 'Top' result type.
    Calls with other than two arguments are not specialised.
    """
    if len(args) != 2:
        return NotImplemented
    subject, item = args
    ret_types = (Top,)
    sub_t = subject.type
    if sub_t not in (Top, Bot):
        if sub_t.shape and (
            dispatched := sub_t.shape.fields.get("__getitem__")
        ):
            # noinspection PyUnboundLocalVariable
            if isinstance(dispatched, FunctionType):
                # noinspection PyUnboundLocalVariable
                return dispatched(self, *args)
            # Call the registered handler, with the arguments as ONE list
            # (the signature of 'Judge.spec'). The original unpacked the
            # arguments and targeted the subject, which cannot succeed.
            return self.spec(dispatched, "__call__", list(args))

    func = S(intrinsic("PyObject_GetItem"))
    e_call = func(subject, item)
    instance = None
    return CallSpec(instance, e_call, ret_types)
361 |
362 |
@register(list, attr="__getitem__")
def call_list_getitem(self: Judge, *args):
    """Specialise 'list.__getitem__' to 'PyList_GetItem' with a 'Top' result type."""
    if len(args) != 2:
        return NotImplemented
    subject, item = args
    get_item = S(intrinsic("PyList_GetItem"))
    return CallSpec(None, get_item(subject, item), (Top,))
373 |
374 |
@register(dict, attr="__getitem__")
def call_list_getitem(self: Judge, *args):
    """Specialise 'dict.__getitem__' to 'PyDict_GetItemWithError' with a 'Top' result type."""
    if len(args) != 2:
        return NotImplemented
    subject, item = args
    get_item = S(intrinsic("PyDict_GetItemWithError"))
    return CallSpec(None, get_item(subject, item), (Top,))
385 |
386 |
@register(bytearray, attr="__getitem__")
def call_bytearray_getitem(self: Judge, *args):
    """Specialise 'bytearray.__getitem__' to 'PyObject_GetItem' with an int result type."""
    if len(args) != 2:
        return NotImplemented
    subject, item = args
    get_item = S(intrinsic("PyObject_GetItem"))
    return CallSpec(None, get_item(subject, item), (Values.A_Int,))
398 |
399 |
@register(bytes, attr="join")
def call_bytearray_join(self: Judge, *args: AbsVal):
    """
    Specialise 'sep.join(iters)' for a bytes separator.

    With a static bytes separator and a statically typed iterable
    argument, '_PyBytes_Join' is called directly; otherwise the generic
    method call is used. Both paths have a bytes result type.
    Calls with other than two arguments are not specialised.
    """
    # The original condition 'not args != 2' compared the tuple with an int
    # and was never true, so a wrong argument count crashed on unpacking.
    if len(args) != 2:
        return NotImplemented
    sep, iters = args
    if (
        sep.type.is_s()
        and sep.type.base is bytes
        and iters.type.is_s()
        and issubclass(iters.type.base, Iterable)
    ):
        # https://sourcegraph.com/github.com/python/cpython@3.9/-/blob/Include/cpython/bytesobject.h#L36
        func = S(intrinsic("_PyBytes_Join"))
        e_call = func(sep, iters)
        return CallSpec(None, e_call, (S(bytes),))

    spec = self.no_spec(sep, "join", [iters])
    return CallSpec(None, spec.e_call, (S(bytes),))
418 |
419 |
@register(bytes, attr="__getitem__")
def call_bytearray_getitem(self: Judge, *args: AbsVal):
    """Specialise 'bytes.__getitem__' to 'PyObject_GetItem' with an int result type."""
    if len(args) != 2:
        return NotImplemented
    subject, item = args
    # Sequence protocol is slower:
    # if item.type.is_s() and issubclass(item.type.base, int):
    #     func = S(intrinsic("PySequence_GetItem"))
    # else:
    get_item = S(intrinsic("PyObject_GetItem"))
    return CallSpec(None, get_item(subject, item), (Values.A_Int,))
434 |
435 |
@register(operator.__setitem__, create_shape=True)
def call_getitem(self: Judge, *args: AbsVal):
    """Specialise 'subject[item] = value' to 'PyObject_SetItem' with a 'Top' result type."""
    if len(args) != 3:
        # default python impl
        return NotImplemented
    subject, item, value = args
    set_item = S(intrinsic("PyObject_SetItem"))
    return CallSpec(None, set_item(subject, item, value), (Top,))
447 |
448 |
@register(list, attr="copy")
def call_list_copy(self: Judge, *args: AbsVal):
    """Specialise 'lst.copy()': still a generic method call, but the result type is known to be list."""
    if len(args) != 1:
        return NotImplemented
    (lst,) = args
    call_method = S(Intrinsic.Py_CallMethod)
    return CallSpec(None, call_method(lst, S("copy")), (S(list),))
458 |
459 |
@register(list, attr="append")
def list_append_analysis(self: Judge, *args: AbsVal):
    """Specialise 'lst.append(elt)' to 'PyList_Append'; the result type is the type of 'None'."""
    if len(args) != 2:
        # rollback to CPython's default code
        return NotImplemented
    lst, elt = args
    append = S(intrinsic("PyList_Append"))
    none_type = S(type(None))
    return CallSpec(
        instance=None,  # return value is not static
        e_call=append(lst, elt),
        possibly_return_types=(none_type,),
    )
472 |
473 |
@register(Intrinsic.Py_BuildList, create_shape=True)
def call_build_list(self: Judge, *args: AbsVal):
    """Specialise list construction to 'PyList_Construct' with a list result type."""
    list_type = S(list)
    construct = S(intrinsic("PyList_Construct"))
    return CallSpec(None, construct(*args), (list_type,))
480 |
481 |
@register(io.BytesIO)
def call_bytes_io(self: Judge, *args: AbsVal):
    """Specialise ``io.BytesIO(arg)``: a plain call typed as ``io.BytesIO``.

    Only the single-argument form is handled; other arities are declined
    with ``NotImplemented``.
    """
    if len(args) == 1:
        (initial,) = args
        call_function = S(Intrinsic.Py_CallFunction)
        bytes_io_type = S(io.BytesIO)
        return CallSpec(
            None, call_function(bytes_io_type, initial), (bytes_io_type,)
        )
    return NotImplemented
491 |
492 |
# Result type of ``next(obj)`` keyed by the static type of ``obj``.
# ``call_next`` consults this table and falls back to ``Top`` for any
# type that is not listed here.
next_type_maps = {io.BytesIO: bytes}
494 |
495 |
@register(operator.is_, create_shape=True)
def call_is(self: Judge, *args: AbsVal):
    """Specialise ``a is b``, folding the answer when it is statically known.

    * equal abstract values            -> constant ``True``
    * both types static, bases differ  -> constant ``False``
    * otherwise                        -> runtime ``Py_AddressCompare``

    Declines with ``NotImplemented`` unless given exactly two operands.
    """
    if len(args) != 2:
        return NotImplemented
    lhs, rhs = args
    bool_only = (Values.A_Bool,)

    if lhs == rhs:
        return CallSpec(S(True), S(True), bool_only)

    both_static = lhs.type.is_s() and rhs.type.is_s()
    if both_static and lhs.type.base != rhs.type.base:
        return CallSpec(S(False), S(False), bool_only)

    compare = S(intrinsic("Py_AddressCompare"))
    return CallSpec(None, compare(*args), bool_only)
510 |
511 |
@register(operator.__not__, create_shape=True)
def call_not_(self: Judge, *args: AbsVal):
    """Specialise ``not x``; a literal operand is folded to a constant.

    Non-literal operands go through the unspecialised call path, with the
    result type narrowed to ``bool``.  Declines with ``NotImplemented``
    unless given exactly one operand.
    """
    if len(args) != 1:
        return NotImplemented
    (operand,) = args
    bool_only = (Values.A_Bool,)

    if not operand.is_literal():
        generic = self.no_spec(S(operator.__not__), "__call__", list(args))
        return CallSpec(generic.instance, generic.e_call, bool_only)

    folded = S(not operand.base)
    return CallSpec(folded, folded, bool_only)
523 |
524 |
@register(next, create_shape=True)
def call_next(self: Judge, *args: AbsVal):
    """Specialise ``next(it)`` / ``next(it, default)``.

    The element type comes from ``next_type_maps`` (``Top`` when the
    iterator type is not listed).  When a default is supplied its type is
    added to the possible results.  Declines with ``NotImplemented`` for
    other arities or when the iterator's type is ``Top``/``Bot``.
    """
    if not 1 <= len(args) <= 2:
        return NotImplemented

    iterator = args[0]
    if iterator.type in (Top, Bot):
        return NotImplemented

    element_type = next_type_maps.get(iterator.type.base)
    candidates = {S(element_type)} if element_type else {Top}
    if len(args) == 2:
        candidates.add(args[1].type)

    call_function = S(Intrinsic.Py_CallFunction)
    # sorted so the resulting tuple has a stable order
    ordered = tuple(sorted(candidates))
    return CallSpec(None, call_function(S(next), *args), ordered)
547 |
548 |
549 | # @register(int, "__init__")
550 | # def call_int(self: Judge, *args: AbsVal):
551 | # return CallSpec(S(None), S(None), (Values.A_NoneType, ))
552 |
553 |
@register(int)
def call_int(self: Judge, t: AbsVal, *args: AbsVal):
    """Specialise the ``int(...)`` constructor call.

    * ``int()``                      -> constant ``0``
    * ``int(x)`` with ``x`` typed int -> ``x`` itself, no call emitted
    * ``int(x)`` otherwise            -> ``PyNumber_Long(x)``

    More than one argument is declined with ``NotImplemented``.
    """
    int_type = S(int)
    int_only = (int_type,)

    if not args:
        return CallSpec(S(0), S(0), int_only)
    if len(args) > 1:
        return NotImplemented

    (value,) = args
    if value.type == int_type:
        return CallSpec(None, value, int_only)

    to_long = S(intrinsic("PyNumber_Long"))
    return CallSpec(None, to_long(value), int_only)
568 |
569 |
570 | # @lru_cache()
571 | # def mk_call_type_n(N):
572 | # args = ",".join(f"x{a}" for a in range(N))
573 | # name = f"call_type{N}"
574 | # f = f"""
575 | # def {name}(typ: AbsVal, {args}):
576 | # o = typ.__new__(typ, {args})
577 | # if not isinstance(o, typ):
578 | # return o
579 | # typ.__init__(o, {args})
580 | # return o
581 | # """
582 | # scope = {}
583 | # exec(f, scope)
584 | # func = scope[name]
585 | # from diojit.user.client import jit
586 | #
587 | # return jit(func, fixed_references=["isinstance"])
588 | #
589 | #
590 | # @register(type)
591 | # def call_type(self: Judge, typ: AbsVal, *args: AbsVal):
592 | # if typ == Values.A_Type and args:
593 | # if len(args) == 1:
594 | # arg = args[0]
595 | # if arg.type.is_s():
596 | # a_t = arg.type
597 | # return CallSpec(a_t, a_t, (Values.A_Type,))
598 | # func = S(intrinsic("PyObject_Type"))
599 | # return CallSpec(None, func(arg), (Values.A_Type,))
600 | # return NotImplemented
601 | # if typ.is_s():
602 | # return self.spec(
603 | # S(mk_call_type_n(len(args))), "__call__", [typ, *args]
604 | # )
605 | # return NotImplemented
606 |
--------------------------------------------------------------------------------
/diojit/codegen/__init__.py:
--------------------------------------------------------------------------------
1 | from . import julia
--------------------------------------------------------------------------------
/diojit/codegen/julia.py:
--------------------------------------------------------------------------------
1 | """
2 | why julia?
3 | 1. low latency incremental compilation(LLVM)
4 | 2. easier interface and ABI treatment(LLVM C API is hardo)
5 | 3. can do some zero-specialisation in julia side.
6 | For instance, manual analysis for Python RC's ownership transfer
7 | is verbose, but with Julia specialisation it is easily
8 | made automatic.
9 | 4. I love Julia
10 | """
11 | from __future__ import annotations
12 | from ..absint import *
13 | from io import StringIO
14 | from typing import Union
15 | from contextlib import contextmanager
16 | from itertools import repeat, chain
17 |
18 | import json
19 |
20 |
def splice(o):
    """Return the Julia ``@DIO_Obj(<address>)`` expression naming object *o*."""
    address = u64o(o)
    return "@DIO_Obj(" + address + ")"
23 |
24 |
def u64o(o: object):
    """Format ``id(o)`` (the object's address) as a uint64 hex literal."""
    address = id(o)
    return u64i(address)
30 |
31 |
def u64i(i: int):
    """Format integer *i* as a ``0x``-prefixed, zero-padded hex literal.

    The result is padded to 18 characters in total: ``0x`` plus 16 digits.
    """
    return format(i, "#018x")
35 |
36 |
class Codegen:
    """Render one specialised function (an ``Out_Def``) as Julia source text.

    Text is accumulated in ``self.io``.  Two operators are used as writing
    shorthands: ``self << line`` writes one line at the current indentation,
    and ``self @ text`` writes an unindented line preceded by a blank line
    (used for ``@label`` markers).
    """

    def __init__(self, out_def: Out_Def):
        self.out_def = out_def
        self.io = StringIO()
        self.indent = ""
        # Julia-side parameter names ("_" for parameters that are not D).
        self.params = set(map(self.param, self.out_def.params))

    def __lshift__(self, other: str):
        """Write *other* as one line, prefixed with the current indent."""
        self.io.write(self.indent)
        self.io.write(other)
        self.io.write("\n")

    def __matmul__(self, other):
        """Write *other* on its own unindented line after a blank line."""
        self.io.write("\n")
        self.io.write(other)
        self.io.write("\n")

    def var_i(self, i: int):
        """Name of the Julia variable holding slot number *i*."""
        return f"x{i}"

    def var(self, target: D):
        """Name of the Julia variable backing the dynamic value *target*."""
        return self.var_i(target.i)

    def param(self, p: Union[S, D]):
        """Parameter name for *p*: its slot variable for ``D``, else ``_``."""
        if isinstance(p, D):
            return f"x{p.i}"
        return "_"

    @staticmethod
    def uint64(i):
        """Format *i* as a uint64 hex literal (see ``u64i``)."""
        return u64i(i)

    def val(self, v: Union[str, S, D]):
        """Render an operand as a Julia expression.

        * ``str`` is emitted verbatim (see 'diojit.prescr.lit');
        * a static ``Intrinsic`` is emitted via its ``repr``;
        * any other static value becomes an object spliced in by address,
          followed by a ``#= repr =#`` comment for readability;
        * a dynamic value becomes its slot variable.
        """
        if isinstance(v, str):
            return v

        if isinstance(v, S):
            base = v.base
            if isinstance(base, Intrinsic):
                return repr(base)
            # Julia block comments nest, so an embedded "#=" would open a
            # comment that never closes, just as an embedded "=#" would
            # close ours early.  Neutralise both delimiters.
            note = repr(base).replace("=#", "//=//#").replace("#=", "#//=//")
            # get object from address
            return f"{splice(base)} #= {note} =#"

        return self.var(v)

    def call(self, x: Out_Call):
        """Return ``(callee, call_expression)`` for an intrinsic call *x*."""
        assert isinstance(x.func, S) and isinstance(
            x.func.base, Intrinsic
        ), x
        f = self.val(x.func)
        args = ", ".join(map(self.val, x.args))
        return f, f"{f}({args})"

    def get_jl_definitions(self):
        """Return the Julia definition of the function plus its ``DOC_`` constant.

        The body is rendered first into a scratch buffer, then wrapped in the
        function header, the shared ``except``/``ret`` epilogue and ``end``.
        """
        self.io = StringIO()
        spec_info = self.out_def.spec
        with self.indent_inc():
            self.visit_many(self.out_def.instrs)
        func_body = self.io.getvalue()
        self.io = StringIO()
        params = ", ".join(map(self.param, self.out_def.params))
        self << f"DIO.@codegen DIO.@q function {spec_info.abs_jit_func}({params})"
        self << func_body
        self @ "@label except"
        # TODO: add traceback
        self << " DIO_Return = Py_NULL"
        self @ "@label ret"
        self << " return DIO_Return"
        self << "end"
        doc_io = StringIO()
        self.out_def.show(lambda *args: print(*args, file=doc_io))
        # function documentation; "$" is escaped so that Julia does not
        # treat it as string interpolation
        doc = json.dumps(doc_io.getvalue()).replace("$", "\\$")
        self << f"const DOC_{spec_info.abs_jit_func} = " f"{doc}"
        return self.io.getvalue()

    def get_py_interfaces(self):
        """Return the ``@DIO_MakePtrCFunc`` line exposing this function."""
        narg = len(self.out_def.params)
        spec_info = self.out_def.spec
        return f"@DIO_MakePtrCFunc {narg} {spec_info.abs_jit_func} {self.out_def.name}\n"

    @contextmanager
    def indent_inc(self):
        """Context manager: indent one level deeper, restoring on exit."""
        old = self.indent
        try:
            self.indent = old + " "
            yield
        finally:
            self.indent = old

    def visit_many(self, instrs: tuple[Out_Instr, ...]):
        """Render each instruction of *instrs* in order."""
        for each in instrs:
            self.visit(each)

    def visit(self, instr: Out_Instr):
        """Render a single output instruction as Julia statements."""
        if isinstance(instr, Out_SetLineno):
            filename = json.dumps(instr.filename).replace("$", "\\$")
            self << f"@DIO_SetLineno {instr.line} {filename}"
        elif isinstance(instr, Out_Assign):
            var = self.var(instr.target)
            assert isinstance(instr.expr, Out_Call)
            f, val = self.call(instr.expr)
            self << f"__tmp__ = {val}"
            # On failure, release the listed slots and jump to the handler.
            self << f"if __tmp__ === DIO_ExceptCode({f})"
            for i in instr.decrefs:
                self << f" Py_DECREF({self.var_i(i)})"
            self << r" @goto except"
            self << f"elseif DIO_HasCast({f})"
            self << f" {var} = DIO_HasCast({f}, __tmp__)"
            self << f" if DIO_CastExc({f})"
            self << f" {var} === Py_NULL && return Py_NULL"
            self << r" end"
            self << r"elseif __tmp__ isa PyPtr"
            self << f" {var} = __tmp__"
            self << r"elseif __tmp__ isa Integer"
            self << f" {var} = DIO_WrapIntValue(__tmp__)"
            self << f" {var} === Py_NULL && return Py_NULL"
            self << r"else"
            self << f" {var} = DIO_NewNone()"
            self << r"end"
            return
        elif isinstance(instr, Out_Label):
            self @ f"@label {instr.label}"
            return
        elif isinstance(instr, Out_Goto):
            self << f"@goto {instr.label}"
            return
        elif isinstance(instr, Out_Return):
            val = self.val(instr.value)
            self << f"DIO_Return = {val}"
            # When the returned slot is itself due to be released, skip both
            # its DECREF and the INCREF of the return value: they cancel out.
            mini_opt = False
            if isinstance(instr.value, D):
                hold = instr.value.i
                if hold in instr.decrefs:
                    mini_opt = True
                    for i in instr.decrefs:
                        if i != hold:
                            self << f"Py_DECREF({self.var_i(i)})"
            if not mini_opt:
                self << f"Py_INCREF(DIO_Return)"
                for i in instr.decrefs:
                    self << f"Py_DECREF({self.var_i(i)})"
            self << f"@goto ret"
            return
        elif isinstance(instr, Out_If):
            val = self.val(instr.test)
            # The test is compared by address against the ``True`` singleton.
            self << f"if {u64o(True)} === reinterpret(UInt64, {val})"
            self << f" @goto {instr.t}"
            self << "else"
            self << f" @goto {instr.f}"
            self << "end"
            return

        elif isinstance(instr, Out_TypeCase):
            val = self.val(instr.obj)
            self << f"__type__ = reinterpret(UInt64, Py_TYPE({val}))"
            cases = instr.cases
            has_any_type = Top in cases
            if has_any_type and len(cases) == 1:
                # Only the catch-all case exists: no dispatch is needed.
                self.visit_many(cases[Top])
                return
            # The first concrete case opens with "if", later ones "elseif".
            headers = chain(["if"], repeat("elseif"))
            for typecase, block in cases.items():
                if typecase is Top:
                    continue
                head = next(headers)
                t = typecase.base
                self << f"# when type is {t}"
                self << f"{head} __type__ === {u64o(t)}"
                with self.indent_inc():
                    self.visit_many(block)
            if not has_any_type:
                self << "else"
                self << ' error("analyser produces incorrect return")'
            else:
                self << "else"
                with self.indent_inc():
                    self.visit_many(cases[Top])
            self << "end"
            return
        elif isinstance(instr, Out_DecRef):
            self << f"Py_DECREF({self.var_i(instr.i)})"
229 |
--------------------------------------------------------------------------------
/diojit/runtime/__init__.py:
--------------------------------------------------------------------------------
1 | from . import julia_rt
--------------------------------------------------------------------------------
/diojit/runtime/julia_rt.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import os
3 | import signal
4 | import warnings
5 | import ctypes
6 | import sys
7 | import posixpath
8 | import importlib
9 | import json
10 | import julia.libjulia as jl_libjulia
11 | from json import dumps
12 | from julia.libjulia import LibJulia
13 | from julia.juliainfo import JuliaInfo
14 | from typing import Union
15 | from ..absint.abs import Out_Def as _Out_Def
16 | from ..codegen.julia import Codegen, u64o, splice
17 | import ctypes
18 |
# Handle of the loaded Python library (``ctypes.pythonapi``) as a hex
# string; ``startup`` passes it to ``DIO.@setup``.
pydll_address = hex(ctypes.pythonapi._handle)
# Alias of ``Out_Def.GenerateCache``; ``code_gen`` iterates it and then
# clears it.
GenerateCache = _Out_Def.GenerateCache
21 |
22 |
def get_libjulia():
    """Return the module-wide ``libjl`` handle, running ``startup`` if it is falsy."""
    global libjl
    if libjl:
        return libjl
    libjl = startup()
    return libjl
28 |
29 |
class RichCallSubprocessError(subprocess.CalledProcessError):
    """``CalledProcessError`` whose message also includes the captured stderr.

    A negative return code is reported as death by signal, naming the
    signal when ``signal.Signals`` recognises it.
    """

    def __str__(self):
        code = self.returncode
        if not (code and code < 0):
            return (
                "Command '%s' returned non-zero exit status %d: %s"
                % (self.cmd, code, self.stderr)
            )
        try:
            sig = signal.Signals(-code)
        except ValueError:
            # a signal number this platform has no name for
            return "Command '%s' died with unknown signal %d." % (
                self.cmd,
                -code,
            )
        return "Command '%s' died with %r." % (self.cmd, sig)
48 |
49 |
def mk_libjulia(julia="julia", **popen_kwargs):
    """Return a ``LibJulia`` handle, reusing ``julia.libjulia._LIBJULIA`` if set.

    Otherwise the *julia* executable is run once to print
    ``DIO.PyJulia_INFO()``; each line of its stdout becomes one positional
    argument of ``JuliaInfo``.  The ctypes signatures of the C entry points
    used by this module are then declared on the new handle.

    Raises ``RichCallSubprocessError`` when the probe exits non-zero, and
    emits a warning when it succeeds but wrote to stderr.  Extra keyword
    arguments are forwarded to ``subprocess.Popen``.
    """
    cached = getattr(jl_libjulia, "_LIBJULIA")
    if cached:
        return cached

    command = [
        julia,
        "--startup-file=no",
        "-e",
        "using DIO; DIO.PyJulia_INFO()",
    ]
    proc = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=True,
        **popen_kwargs,
    )
    stdout, stderr = proc.communicate()
    retcode = proc.wait()
    if retcode != 0:
        raise RichCallSubprocessError(
            retcode, [julia, "-e", "..."], stdout, stderr
        )

    stderr = stderr.strip()
    if stderr:
        warnings.warn("{} warned:\n{}".format(julia, stderr))

    info_fields = stdout.rstrip().split("\n")
    libjl = LibJulia.from_juliainfo(JuliaInfo(julia, *info_fields))

    # (function name, argtypes, restype) for each C entry point used here
    signatures = (
        ("jl_string_ptr", [ctypes.c_void_p], ctypes.c_char_p),
        ("jl_call1", [ctypes.c_void_p, ctypes.c_void_p], ctypes.c_void_p),
        ("jl_eval_string", [ctypes.c_char_p], ctypes.c_void_p),
        ("jl_stderr_stream", [], ctypes.c_void_p),
        ("jl_printf", [ctypes.c_void_p, ctypes.c_char_p], ctypes.c_int),
    )
    for func_name, argtypes, restype in signatures:
        cfunc = getattr(libjl, func_name)
        cfunc.argtypes = argtypes
        cfunc.restype = restype
    return libjl
92 |
93 |
class JuliaException(Exception):
    """Error raised on the Julia side; ``msg`` holds its rendered text."""

    def __init__(self, msg):
        self.msg = msg

    def __repr__(self):
        return "from julia: {}".format(self.msg)
100 |
101 |
def check_jl_err(libjl: LibJulia):
    """Raise ``JuliaException`` if *libjl* reports a pending Julia exception.

    The exception is rendered to text by calling the Julia function
    ``error2str`` on it.  Returns ``None`` when nothing is pending.
    """
    pending = libjl.jl_exception_occurred()
    if not pending:
        return
    formatter = libjl.jl_eval_string(b"error2str")
    rendered = libjl.jl_call1(formatter, pending)
    text = libjl.jl_string_ptr(rendered).decode("utf-8")
    raise JuliaException(text)
108 |
109 |
def startup():
    """Initialise the embedded Julia runtime and return the ``libjl`` handle.

    Steps, in order: create and initialise libjulia, define the Julia
    helper ``error2str``, load ``DIO``, define the ``PyO`` constant holding
    the addresses of core Python objects, run ``DIO.@setup`` with the
    Python library handle, and define the Julia helper ``printerror``.
    The handle is also stored in the module global ``libjl``.
    """
    global libjl
    libjl = mk_libjulia()
    libjl.init_julia()
    # DIO package already checked when getting libjulia
    libjl.jl_eval_string(
        b"function error2str(e)\n"
        b" sprint(showerror, e; context=:color=>true)\n"
        b"end"
    )
    libjl.jl_eval_string(b"using DIO")
    check_jl_err(libjl)

    import builtins

    # (field name, Python object) pairs spliced into PyO by address
    spliced_fields = (
        ("builtins", builtins),
        ("print", print),
        ("bool", bool),
        ("int", int),
        ("float", float),
        ("str", str),
        ("type", type),
        ("True", True),
        ("False", False),
        ("None", None),
        ("complex", complex),
        ("tuple", tuple),
        ("list", list),
        ("set", set),
        ("dict", dict),
        ("import_module", importlib.import_module),
    )
    pyo_source = "".join(
        [
            "const PyO = PyOType(",
            f"PY_VERSION = Tuple({json.dumps(sys.version_info)}),",
            *(f"{name} = {splice(obj)}," for name, obj in spliced_fields),
            ")",
        ]
    )
    libjl.jl_eval_string(str.encode(pyo_source, encoding="utf-8"))
    check_jl_err(libjl)

    setup_source = b"DIO.@setup(%s)" % pydll_address.encode("utf-8")
    libjl.jl_eval_string(setup_source)
    check_jl_err(libjl)

    libjl.jl_eval_string(b"printerror(x) = println(showerror(x))")
    check_jl_err(libjl)
    return libjl
167 |
168 |
def as_py(res: ctypes.c_void_p):
    """Unbox the pointer returned by a JIT function into a Python object.

    This should be used on the return of a JIT func.
    No need to incref as it's already done by the JIT func.

    Returns ``None`` for a NULL result.  ctypes delivers a NULL
    ``c_void_p`` return value as ``None`` rather than ``0``, so the check
    is on falsiness and is made before the Julia runtime is touched.
    Raises ``JuliaException`` (via ``check_jl_err``) if unboxing leaves a
    Julia exception pending.
    """
    if not res:
        return None
    libjl = get_libjulia()
    pyobj = libjl.jl_unbox_voidpointer(res)
    check_jl_err(libjl)
    return pyobj
180 |
181 |
def code_gen(print_jl=None):
    """Compile every pending entry of ``GenerateCache`` and bind its callback.

    For each cached ``Out_Def`` the Julia definition is generated and
    evaluated.  The collected ``@DIO_MakePtrCFunc`` interface lines are then
    evaluated in one go, each key's ``_callback`` is set from the Julia
    value ``PyFunc_<repr(key)>``, and finally the cache is cleared.

    :param print_jl: optional callable that receives all generated Julia
        text, including the ``# interfaces`` header, so the whole listing
        goes to a single destination.
    """
    libjl = get_libjulia()
    interfaces = bytearray()
    for out_def in GenerateCache.values():
        cg = Codegen(out_def)
        interfaces.extend(cg.get_py_interfaces().encode("utf-8"))
        definition = cg.get_jl_definitions()
        if print_jl:
            print_jl(definition)
        definition = definition.encode("utf-8")
        libjl.jl_eval_string(definition)
        check_jl_err(libjl)
    if print_jl:
        # Sent through the callback like the rest of the listing; it used
        # to go to stdout regardless of where print_jl writes.
        print_jl("# interfaces")
        print_jl(interfaces.decode("utf-8"))
    libjl.jl_eval_string(bytes(interfaces))
    check_jl_err(libjl)
    for intrin in GenerateCache:
        v = libjl.jl_eval_string(
            b"PyFunc_%s" % repr(intrin).encode("utf-8")
        )
        check_jl_err(libjl)
        intrin._callback = as_py(v)

    GenerateCache.clear()
207 |
208 |
def jl_eval(s: Union[str, bytes]):
    """Evaluate Julia source *s* (``str`` is UTF-8 encoded first).

    Raises ``JuliaException`` via ``check_jl_err`` if evaluation fails.
    """
    source = s.encode("utf-8") if isinstance(s, str) else s
    libjl.jl_eval_string(source)
    check_jl_err(libjl)
214 |
215 |
216 | startup()
217 |
--------------------------------------------------------------------------------
/diojit/stack2reg/__init__.py:
--------------------------------------------------------------------------------
1 | from .translate import translate
2 |
--------------------------------------------------------------------------------
/diojit/stack2reg/cflags.py:
--------------------------------------------------------------------------------
1 | import dis
# ``dis.COMPILER_FLAG_NAMES`` maps bit value -> flag name; invert it so the
# bit for each code-object flag can be looked up by name.
_flags = {name: bit for bit, name in dis.COMPILER_FLAG_NAMES.items()}

OPTIMIZED = _flags["OPTIMIZED"]
NEWLOCALS = _flags["NEWLOCALS"]
VARARGS = _flags["VARARGS"]
VARKEYWORDS = _flags["VARKEYWORDS"]
NESTED = _flags["NESTED"]
GENERATOR = _flags["GENERATOR"]
NOFREE = _flags["NOFREE"]
COROUTINE = _flags["COROUTINE"]
ITERABLE_COROUTINE = _flags["ITERABLE_COROUTINE"]
ASYNC_GENERATOR = _flags["ASYNC_GENERATOR"]
--------------------------------------------------------------------------------
/diojit/stack2reg/opcodes.py:
--------------------------------------------------------------------------------
1 | import opcode
# Sentinel bound to every opcode name that the running interpreter does
# not define.
UNKNOWN_INSTR = object()


def _lookup(name):
    """Return the numeric opcode for *name*, or ``UNKNOWN_INSTR`` if absent."""
    return opcode.opmap.get(name, UNKNOWN_INSTR)


POP_TOP = _lookup("POP_TOP")
ROT_TWO = _lookup("ROT_TWO")
ROT_THREE = _lookup("ROT_THREE")
DUP_TOP = _lookup("DUP_TOP")
DUP_TOP_TWO = _lookup("DUP_TOP_TWO")
ROT_FOUR = _lookup("ROT_FOUR")
NOP = _lookup("NOP")
UNARY_POSITIVE = _lookup("UNARY_POSITIVE")
UNARY_NEGATIVE = _lookup("UNARY_NEGATIVE")
UNARY_NOT = _lookup("UNARY_NOT")
UNARY_INVERT = _lookup("UNARY_INVERT")
BINARY_MATRIX_MULTIPLY = _lookup("BINARY_MATRIX_MULTIPLY")
INPLACE_MATRIX_MULTIPLY = _lookup("INPLACE_MATRIX_MULTIPLY")
BINARY_POWER = _lookup("BINARY_POWER")
BINARY_MULTIPLY = _lookup("BINARY_MULTIPLY")
BINARY_MODULO = _lookup("BINARY_MODULO")
BINARY_ADD = _lookup("BINARY_ADD")
BINARY_SUBTRACT = _lookup("BINARY_SUBTRACT")
BINARY_SUBSCR = _lookup("BINARY_SUBSCR")
BINARY_FLOOR_DIVIDE = _lookup("BINARY_FLOOR_DIVIDE")
BINARY_TRUE_DIVIDE = _lookup("BINARY_TRUE_DIVIDE")
INPLACE_FLOOR_DIVIDE = _lookup("INPLACE_FLOOR_DIVIDE")
INPLACE_TRUE_DIVIDE = _lookup("INPLACE_TRUE_DIVIDE")
RERAISE = _lookup("RERAISE")
WITH_EXCEPT_START = _lookup("WITH_EXCEPT_START")
GET_AITER = _lookup("GET_AITER")
GET_ANEXT = _lookup("GET_ANEXT")
BEFORE_ASYNC_WITH = _lookup("BEFORE_ASYNC_WITH")
END_ASYNC_FOR = _lookup("END_ASYNC_FOR")
INPLACE_ADD = _lookup("INPLACE_ADD")
INPLACE_SUBTRACT = _lookup("INPLACE_SUBTRACT")
INPLACE_MULTIPLY = _lookup("INPLACE_MULTIPLY")
INPLACE_MODULO = _lookup("INPLACE_MODULO")
STORE_SUBSCR = _lookup("STORE_SUBSCR")
DELETE_SUBSCR = _lookup("DELETE_SUBSCR")
BINARY_LSHIFT = _lookup("BINARY_LSHIFT")
BINARY_RSHIFT = _lookup("BINARY_RSHIFT")
BINARY_AND = _lookup("BINARY_AND")
BINARY_XOR = _lookup("BINARY_XOR")
BINARY_OR = _lookup("BINARY_OR")
INPLACE_POWER = _lookup("INPLACE_POWER")
GET_ITER = _lookup("GET_ITER")
GET_YIELD_FROM_ITER = _lookup("GET_YIELD_FROM_ITER")
PRINT_EXPR = _lookup("PRINT_EXPR")
LOAD_BUILD_CLASS = _lookup("LOAD_BUILD_CLASS")
YIELD_FROM = _lookup("YIELD_FROM")
GET_AWAITABLE = _lookup("GET_AWAITABLE")
LOAD_ASSERTION_ERROR = _lookup("LOAD_ASSERTION_ERROR")
INPLACE_LSHIFT = _lookup("INPLACE_LSHIFT")
INPLACE_RSHIFT = _lookup("INPLACE_RSHIFT")
INPLACE_AND = _lookup("INPLACE_AND")
INPLACE_XOR = _lookup("INPLACE_XOR")
INPLACE_OR = _lookup("INPLACE_OR")
LIST_TO_TUPLE = _lookup("LIST_TO_TUPLE")
RETURN_VALUE = _lookup("RETURN_VALUE")
IMPORT_STAR = _lookup("IMPORT_STAR")
SETUP_ANNOTATIONS = _lookup("SETUP_ANNOTATIONS")
YIELD_VALUE = _lookup("YIELD_VALUE")
POP_BLOCK = _lookup("POP_BLOCK")
POP_EXCEPT = _lookup("POP_EXCEPT")
STORE_NAME = _lookup("STORE_NAME")
DELETE_NAME = _lookup("DELETE_NAME")
UNPACK_SEQUENCE = _lookup("UNPACK_SEQUENCE")
FOR_ITER = _lookup("FOR_ITER")
UNPACK_EX = _lookup("UNPACK_EX")
STORE_ATTR = _lookup("STORE_ATTR")
DELETE_ATTR = _lookup("DELETE_ATTR")
STORE_GLOBAL = _lookup("STORE_GLOBAL")
DELETE_GLOBAL = _lookup("DELETE_GLOBAL")
LOAD_CONST = _lookup("LOAD_CONST")
LOAD_NAME = _lookup("LOAD_NAME")
BUILD_TUPLE = _lookup("BUILD_TUPLE")
BUILD_LIST = _lookup("BUILD_LIST")
BUILD_SET = _lookup("BUILD_SET")
BUILD_MAP = _lookup("BUILD_MAP")
LOAD_ATTR = _lookup("LOAD_ATTR")
COMPARE_OP = _lookup("COMPARE_OP")
IMPORT_NAME = _lookup("IMPORT_NAME")
IMPORT_FROM = _lookup("IMPORT_FROM")
JUMP_FORWARD = _lookup("JUMP_FORWARD")
JUMP_IF_FALSE_OR_POP = _lookup("JUMP_IF_FALSE_OR_POP")
JUMP_IF_TRUE_OR_POP = _lookup("JUMP_IF_TRUE_OR_POP")
JUMP_ABSOLUTE = _lookup("JUMP_ABSOLUTE")
POP_JUMP_IF_FALSE = _lookup("POP_JUMP_IF_FALSE")
POP_JUMP_IF_TRUE = _lookup("POP_JUMP_IF_TRUE")
LOAD_GLOBAL = _lookup("LOAD_GLOBAL")
IS_OP = _lookup("IS_OP")
CONTAINS_OP = _lookup("CONTAINS_OP")
JUMP_IF_NOT_EXC_MATCH = _lookup("JUMP_IF_NOT_EXC_MATCH")
SETUP_FINALLY = _lookup("SETUP_FINALLY")
LOAD_FAST = _lookup("LOAD_FAST")
STORE_FAST = _lookup("STORE_FAST")
DELETE_FAST = _lookup("DELETE_FAST")
RAISE_VARARGS = _lookup("RAISE_VARARGS")
CALL_FUNCTION = _lookup("CALL_FUNCTION")
MAKE_FUNCTION = _lookup("MAKE_FUNCTION")
BUILD_SLICE = _lookup("BUILD_SLICE")
LOAD_CLOSURE = _lookup("LOAD_CLOSURE")
LOAD_DEREF = _lookup("LOAD_DEREF")
STORE_DEREF = _lookup("STORE_DEREF")
DELETE_DEREF = _lookup("DELETE_DEREF")
CALL_FUNCTION_KW = _lookup("CALL_FUNCTION_KW")
CALL_FUNCTION_EX = _lookup("CALL_FUNCTION_EX")
SETUP_WITH = _lookup("SETUP_WITH")
LIST_APPEND = _lookup("LIST_APPEND")
SET_ADD = _lookup("SET_ADD")
MAP_ADD = _lookup("MAP_ADD")
LOAD_CLASSDEREF = _lookup("LOAD_CLASSDEREF")
EXTENDED_ARG = _lookup("EXTENDED_ARG")
SETUP_ASYNC_WITH = _lookup("SETUP_ASYNC_WITH")
FORMAT_VALUE = _lookup("FORMAT_VALUE")
BUILD_CONST_KEY_MAP = _lookup("BUILD_CONST_KEY_MAP")
BUILD_STRING = _lookup("BUILD_STRING")
LOAD_METHOD = _lookup("LOAD_METHOD")
CALL_METHOD = _lookup("CALL_METHOD")
LIST_EXTEND = _lookup("LIST_EXTEND")
SET_UPDATE = _lookup("SET_UPDATE")
DICT_MERGE = _lookup("DICT_MERGE")
DICT_UPDATE = _lookup("DICT_UPDATE")
122 |
123 |
--------------------------------------------------------------------------------
/diojit/stack2reg/translate.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import typing as _t
3 | import dis
4 | import types as _types
5 | import operator
6 | from contextlib import contextmanager
7 | from . import opcodes
8 | from . import cflags
9 | from diojit.absint import *
10 |
__all__ = ["translate"]

# NOTE: every opcode the running interpreter lacks resolves to the single
# sentinel ``opcodes.UNKNOWN_INSTR``, so on such an interpreter several
# entries in the tables below share that one key.

# Conditional-jump opcodes.  ``PyC.__init__`` records the instruction that
# follows each of these as a block entry point (the fall-through path).
JUMP_NAMES: frozenset[int] = frozenset(
    {
        opcodes.POP_JUMP_IF_TRUE,
        opcodes.POP_JUMP_IF_FALSE,
        opcodes.JUMP_IF_TRUE_OR_POP,
        opcodes.JUMP_IF_FALSE_OR_POP,
    }
)


# Binary-operator opcode -> the equivalent function from ``operator``.
BIN_OPS: dict[int, FunctionType] = {
    opcodes.BINARY_POWER: operator.__pow__,
    opcodes.BINARY_MULTIPLY: operator.__mul__,
    opcodes.BINARY_MATRIX_MULTIPLY: operator.__matmul__,
    opcodes.BINARY_FLOOR_DIVIDE: operator.__floordiv__,
    opcodes.BINARY_TRUE_DIVIDE: operator.__truediv__,
    opcodes.BINARY_MODULO: operator.__mod__,
    opcodes.BINARY_ADD: operator.__add__,
    opcodes.BINARY_SUBTRACT: operator.__sub__,
    opcodes.BINARY_SUBSCR: operator.__getitem__,
    opcodes.BINARY_LSHIFT: operator.__lshift__,
    opcodes.BINARY_RSHIFT: operator.__rshift__,
    opcodes.BINARY_AND: operator.__and__,
    opcodes.BINARY_XOR: operator.__xor__,
    opcodes.BINARY_OR: operator.__or__,
}


# In-place operator opcode -> the equivalent in-place function from
# ``operator``.
INP_BIN_OPS: dict[int, FunctionType] = {
    opcodes.INPLACE_POWER: operator.__ipow__,
    opcodes.INPLACE_MULTIPLY: operator.__imul__,
    opcodes.INPLACE_MATRIX_MULTIPLY: operator.__imatmul__,
    opcodes.INPLACE_FLOOR_DIVIDE: operator.__ifloordiv__,
    opcodes.INPLACE_TRUE_DIVIDE: operator.__itruediv__,
    opcodes.INPLACE_MODULO: operator.__imod__,
    opcodes.INPLACE_ADD: operator.__iadd__,
    opcodes.INPLACE_SUBTRACT: operator.__isub__,
    opcodes.INPLACE_LSHIFT: operator.__ilshift__,
    opcodes.INPLACE_RSHIFT: operator.__irshift__,
    opcodes.INPLACE_AND: operator.__iand__,
    opcodes.INPLACE_XOR: operator.__ixor__,
    opcodes.INPLACE_OR: operator.__ior__,
}


# Comparison operator spelling -> the equivalent function from ``operator``.
# Presumably keyed by the ``argval`` of COMPARE_OP -- verify against interp.
CMP_OPS: dict[str, FunctionType] = _t.cast(
    _t.Dict[str, FunctionType],
    {
        "<": operator.__lt__,
        ">": operator.__gt__,
        "<=": operator.__le__,
        ">=": operator.__ge__,
        "!=": operator.__ne__,
        "==": operator.__eq__,
    },
)
69 |
70 |
def flags_check(flag, *patterns):
    """Return True if *flag* shares at least one set bit with any pattern."""
    for pattern in patterns:
        if flag & pattern:
            return True
    return False
73 |
74 |
75 | class PyC:
def __init__(self, f):
    """Disassemble *f* and prepare the state used to translate it.

    Raises ``ValueError`` for functions using varargs, keyword varargs,
    generators or coroutines, and fails an assertion for closures.
    """
    bytecode = dis.Bytecode(f)
    code: _types.CodeType = bytecode.codeobj
    unsupported = (
        cflags.VARARGS,
        cflags.ASYNC_GENERATOR,
        cflags.ITERABLE_COROUTINE,
        cflags.VARKEYWORDS,
        cflags.GENERATOR,
        cflags.COROUTINE,
    )
    if flags_check(code.co_flags, *unsupported):
        raise ValueError("varargs/coroutine/varkeywords/generator are not supported yet.")

    self.co = list(bytecode)
    # bytecode offset -> index into self.co, for every block entry point
    self.label_to_co_offsets = entry_points = {}
    for index, instr in enumerate(self.co):
        if instr.is_jump_target:
            entry_points[instr.offset] = index
        if instr.opcode in JUMP_NAMES:
            # the fall-through successor of a conditional jump also
            # starts a block
            successor = self.co[index + 1]
            entry_points[successor.offset] = index + 1

    self.glob_names = set()
    self.offset = 0
    self.codeobj = code
    # stack slots are numbered after the local variables
    self.stack_size = len(code.co_varnames)
    assert (
        not code.co_cellvars and not code.co_freevars
    ), "cannot handle closures so far"

    entry_block = []
    self.cur_block = entry_block
    self.blocks: In_Blocks = {"entry": entry_block}
    self.block_maps: _t.Dict[_t.Tuple[int, int], str] = {}
    self.label_cnt = 0
    self.lastlinenumber = None
    self.filename = code.co_filename
113 |
def make(self):
    """Translate from instruction 0 and return ``(blocks, glob_names)``."""
    self.interp(0)
    translated = (self.blocks, self.glob_names)
    return translated
117 |
def is_jump_target(self, bytecode_offset: int):
    """Tell whether *bytecode_offset* is a recorded block entry point."""
    entry_points = self.label_to_co_offsets
    return bytecode_offset in entry_points
120 |
def gen_label(self):
    """Advance the label counter and return a new label ``l<N>``."""
    fresh = self.label_cnt + 1
    self.label_cnt = fresh
    return "l" + str(fresh)
124 |
def codegen(self, *stmts: In_Stmt):
    """Append the given statements to the block currently being built."""
    block = self.cur_block
    block.extend(stmts)
127 |
def cur(self) -> dis.Instruction:
    """Return the instruction at the current position ``self.offset``."""
    instructions = self.co
    return instructions[self.offset]
130 |
def next(self) -> dis.Instruction:
    """Return the instruction right after the current position."""
    following = self.offset + 1
    return self.co[following]
133 |
def push(self, a: AbsVal):
    """Grow the stack by one slot and move *a* into the new top slot.

    No move is emitted when *a* already is that slot.
    """
    self.stack_size += 1
    slot = self.peek(0)
    if slot != a:
        move = In_Move(slot, a)
        self.codegen(move)
139 |
def pop(self):
    """Shrink the stack by one slot and return the slot that was on top."""
    top = self.peek(0)
    self.stack_size = self.stack_size - 1
    return top
144 |
def call(self, f: AbsVal, *args: AbsVal):
    """Emit ``f.__call__(*args)`` bound to a freshly pushed stack slot."""
    assert all(isinstance(a, AbsVal) for a in args)
    self.stack_size += 1
    result_slot = self.peek(0)
    binding = In_Bind(result_slot, f, S("__call__"), args)
    self.codegen(binding)
150 |
def call_method(self, m: AbsVal, attr: AbsVal, *args: AbsVal):
    """Emit ``m.<attr>(*args)`` bound to a freshly pushed stack slot."""
    self.stack_size += 1
    result_slot = self.peek(0)
    binding = In_Bind(result_slot, m, attr, args)
    self.codegen(binding)
155 |
def peek(self, i: int):
    """Return the slot *i* positions below the top of stack (0 is the top)."""
    slot_index = self.stack_size - i - 1
    return D(slot_index, Top)
158 |
def varlocal(self, i: int):
    """Return the slot holding local variable number *i*, typed ``Top``."""
    local_slot = D(i, Top)
    return local_slot
161 |
def find_offset(self, bytecode_offset: int):
    """Map a bytecode offset to its instruction index.

    Raises ``KeyError`` if the offset is not a recorded block entry point.
    """
    entry_points = self.label_to_co_offsets
    return entry_points[bytecode_offset]
164 |
@contextmanager
def generate_new_block(self, offset: int):
    """Context manager that redirects code generation into a new block.

    A fresh label is registered under ``(offset, current stack size)`` and
    yielded, with an empty block stored for it.  On exit the previous
    position, current block and stack size are restored, even on error.
    """
    saved = (self.offset, self.cur_block, self.stack_size)
    try:
        key = (offset, self.stack_size)
        self.offset = offset
        label = self.gen_label()
        self.block_maps[key] = label
        fresh_block = []
        self.cur_block = fresh_block
        self.blocks[label] = fresh_block
        yield label
    finally:
        self.offset, self.cur_block, self.stack_size = saved
180 |
def jump(self, offset: int) -> str:
    """Return the label of the block for ``(offset, current stack size)``.

    If a block is already registered for that key its label is reused.
    Otherwise a new block is opened, interpreted from ``offset``, and its
    label returned.
    """
    block_key = (offset, self.stack_size)
    try:
        return self.block_maps[block_key]
    except KeyError:
        pass

    with self.generate_new_block(offset) as new_label:
        self.interp(offset)
    return new_label
189 |
def get_nargs(self, n: int):
    """Pop ``n`` slots and return them as a list, deepest slot first.

    This restores the order in which the values were pushed.
    """
    popped = [self.pop() for _ in range(n)]
    return popped[::-1]
197 |
def interp(self, offset=0):
    """Symbolically execute instructions and emit register-style IR.

    Starting at the current instruction index ``self.offset``, each stack
    instruction is translated into statements appended to the current block,
    with ``self.stack_size`` tracking the virtual stack depth. Interpretation
    of the current block stops at a return, at any jump, or when a jump
    target other than the block's own first instruction is reached.

    Args:
        offset: instruction index at which the current block started; used
            only so the block is not split at its own first instruction.

    Raises:
        NotImplementedError: for ``COMPARE_OP`` with ``exception match``.
        ValueError: for an unsupported opcode, or a ``BUILD_SLICE`` whose
            argument is neither 2 nor 3.
    """
    while True:
        x: dis.Instruction = self.cur()
        if (
            x.starts_line is not None
            and x.starts_line != self.lastlinenumber
        ):
            # NOTE(review): ``self.lastlinenumber`` is not updated in this
            # method, so the second half of the check never suppresses
            # anything here -- confirm whether that is intended.
            line = x.starts_line
            self.codegen(In_SetLineno(line, self.filename))

        if self.is_jump_target(x.offset) and self.offset != offset:
            # Reaching another block's entry: end this block with a goto.
            label = self.jump(self.offset)
            self.codegen(In_Goto(label))
            return

        if x.opcode is opcodes.LOAD_CONST:
            argval = x.argval
            self.push(S(argval))
        elif x.opcode is opcodes.LOAD_FAST:
            var = self.varlocal(x.arg)
            self.push(var)
        elif x.opcode is opcodes.STORE_FAST:
            tos = self.pop()
            var = self.varlocal(x.arg)
            self.codegen(In_Move(var, tos))
        elif x.opcode is opcodes.LOAD_GLOBAL:
            self.require_global(x.argval)
            self.call(S(Intrinsic.Py_LoadGlobal), S(x.argval))
        elif x.opcode is opcodes.STORE_GLOBAL:
            # NOTE(review): the value being stored is neither popped nor
            # passed, and the call result stays on the virtual stack.
            # Looks unbalanced; confirm against the Py_StoreGlobal intrinsic.
            self.require_global(x.argval)
            self.call(S(Intrinsic.Py_StoreGlobal), S(x.argval))
        elif x.opcode is opcodes.STORE_ATTR:
            # | TOS1 (value) | TOS (object) |
            # TOS.name = TOS1
            a_base = self.pop()
            a_value = self.pop()
            self.call(
                S(setattr),
                a_base,
                S(x.argval),
                a_value,
            )
            # The store leaves nothing on the stack: drop the call result,
            # as STORE_SUBSCR and DELETE_SUBSCR do below.
            self.pop()
        elif (
            x.opcode is opcodes.JUMP_ABSOLUTE
            or x.opcode is opcodes.JUMP_FORWARD
        ):
            label = self.jump(self.find_offset(x.argval))
            self.codegen(In_Goto(label))
            return
        elif x.opcode is opcodes.JUMP_IF_TRUE_OR_POP:
            b_off_1 = self.find_offset(x.arg)
            b_off_2 = self.find_offset(self.next().offset)
            tos = self.peek(0)
            # The jump branch keeps TOS; the fall-through branch pops it.
            l1 = self.jump(b_off_1)
            self.pop()
            l2 = self.jump(b_off_2)
            self.codegen(In_Cond(tos, l1, l2))
            return
        elif x.opcode is opcodes.JUMP_IF_FALSE_OR_POP:
            b_off_1 = self.find_offset(x.arg)
            b_off_2 = self.find_offset(self.next().offset)
            tos = self.peek(0)
            l1 = self.jump(b_off_1)
            self.pop()
            l2 = self.jump(b_off_2)
            self.codegen(In_Cond(tos, l2, l1))
            return
        elif x.opcode is opcodes.POP_JUMP_IF_TRUE:
            b_off_1 = self.find_offset(x.arg)
            b_off_2 = self.find_offset(self.next().offset)
            tos = self.pop()
            l1 = self.jump(b_off_1)
            l2 = self.jump(b_off_2)
            self.codegen(In_Cond(tos, l1, l2))
            return
        elif x.opcode is opcodes.POP_JUMP_IF_FALSE:
            b_off_1 = self.find_offset(x.arg)
            b_off_2 = self.find_offset(self.next().offset)
            tos = self.pop()
            l1 = self.jump(b_off_1)
            l2 = self.jump(b_off_2)
            self.codegen(In_Cond(tos, l2, l1))
            return
        elif x.opcode is opcodes.LOAD_METHOD:
            # The method name is pushed above the receiver; CALL_METHOD
            # pops the arguments, then the name, then the receiver.
            self.push(S(x.argval))
        elif x.opcode is opcodes.LOAD_ATTR:
            tos = self.pop()
            self.call(S(getattr), tos, S(x.argval))
        elif x.opcode is opcodes.CALL_METHOD:
            args = self.get_nargs(x.argval)
            attr = self.pop()
            subj = self.pop()
            self.call_method(subj, attr, *args)
        elif x.opcode is opcodes.CALL_FUNCTION:
            args = self.get_nargs(x.argval)
            f = self.pop()
            self.call(f, *args)
        elif x.opcode is opcodes.ROT_TWO:
            # Swap the two top slots using one temporary slot above them.
            self.push(self.peek(0))
            self.codegen(In_Move(self.peek(1), self.peek(2)))
            subj = self.peek(2)
            self.codegen(In_Move(subj, self.pop()))
        elif x.opcode is opcodes.ROT_THREE:
            a1 = self.peek(0)
            a2 = self.peek(1)
            a3 = self.peek(2)

            # Copy the three slots into temporaries, then write them back
            # rotated: old TOS goes to the third position.
            self.push(a1)
            b1 = self.peek(0)
            self.push(a2)
            b2 = self.peek(0)
            self.push(a3)
            b3 = self.peek(0)

            self.codegen(In_Move(a3, b1))
            self.codegen(In_Move(a2, b3))
            self.codegen(In_Move(a1, b2))

            self.pop()
            self.pop()
            self.pop()
        elif x.opcode is opcodes.ROT_FOUR:
            a1 = self.peek(0)
            a2 = self.peek(1)
            a3 = self.peek(2)
            a4 = self.peek(3)

            # Same scheme as ROT_THREE with four slots.
            self.push(a1)
            b1 = self.peek(0)
            self.push(a2)
            b2 = self.peek(0)
            self.push(a3)
            b3 = self.peek(0)
            self.push(a4)
            b4 = self.peek(0)

            self.codegen(In_Move(a4, b1))
            self.codegen(In_Move(a3, b4))
            self.codegen(In_Move(a2, b3))
            self.codegen(In_Move(a1, b2))

            self.pop()
            self.pop()
            self.pop()
            self.pop()
        elif x.opcode is opcodes.DUP_TOP:
            self.push(self.peek(0))
        elif x.opcode is opcodes.DUP_TOP_TWO:
            a = self.peek(1)
            b = self.peek(0)
            self.push(a)
            self.push(b)
        elif x.opcode is opcodes.BINARY_SUBSCR:
            # | TOS1 | TOS |
            # TOS1[TOS]
            tos = self.pop()
            tos1 = self.pop()
            self.call(S(operator.__getitem__), tos1, tos)
        elif x.opname.startswith("BINARY_"):
            right, left = self.pop(), self.pop()
            self.call(S(BIN_OPS[x.opcode]), left, right)
        elif x.opname.startswith("INPLACE_"):
            right, left = self.pop(), self.pop()
            self.call(S(INP_BIN_OPS[x.opcode]), left, right)
        elif x.opcode is opcodes.BUILD_TUPLE:
            args = self.get_nargs(x.argval)
            self.call(S(Intrinsic.Py_BuildTuple), *args)
        elif x.opcode is opcodes.BUILD_LIST:
            args = self.get_nargs(x.argval)
            self.call(
                S(Intrinsic.Py_BuildList),
                *args,
            )
        elif x.opcode is opcodes.RETURN_VALUE:
            tos = self.pop()
            self.codegen(In_Return(tos))
            return
        # python 3.9
        elif x.opcode is opcodes.IS_OP:
            right, left = self.pop(), self.pop()
            self.call(S(operator.is_), left, right)
            if x.argval == 1:
                # argval 1 means "is not": negate the identity test.
                self.call(S(operator.__not__), self.pop())
        elif x.opcode is opcodes.CONTAINS_OP:
            right, left = self.pop(), self.pop()
            # ``left in right`` is ``right.__contains__(left)``.
            self.call(S(operator.__contains__), right, left)
            if x.argval == 1:
                # argval 1 means "not in".
                self.call(S(operator.__not__), self.pop())
        elif x.opcode is opcodes.COMPARE_OP:
            cmp_name = dis.cmp_op[x.arg]
            right, left = self.pop(), self.pop()
            if cmp_name == "not_in":
                self.call(S(operator.__contains__), right, left)
                self.call(S(operator.__not__), self.pop())
            elif cmp_name == "in":
                self.call(S(operator.__contains__), right, left)
            elif cmp_name == "exception match":
                # ``NotImplemented`` is not an exception class; raising it
                # would itself fail with a TypeError.
                raise NotImplementedError(
                    "COMPARE_OP 'exception match' is not supported"
                )
            elif cmp_name == "is":
                self.call(S(operator.is_), left, right)
            elif cmp_name == "is not":
                self.call(S(operator.is_), left, right)
                self.call(S(operator.__not__), self.pop())
            else:
                self.call(S(CMP_OPS[cmp_name]), left, right)

        elif x.opcode is opcodes.DELETE_SUBSCR:
            tos = self.pop()
            tos1 = self.pop()
            self.call(S(operator.__delitem__), tos1, tos)
            # Statement-like operation: discard the call result.
            self.pop()
        elif x.opcode is opcodes.STORE_SUBSCR:
            # | TOS2 | TOS1 | TOS |
            # TOS1[TOS] = TOS2
            tos = self.pop()
            tos1 = self.pop()
            tos2 = self.pop()
            self.call(S(operator.__setitem__), tos1, tos, tos2)
            self.pop()
        elif x.opcode is opcodes.BUILD_SLICE:
            # Pushes a slice object on the stack. argc must be 2 or 3.
            # If it is 2, slice(TOS1, TOS) is pushed; if it is 3,
            # slice(TOS2, TOS1, TOS) is pushed.
            tos = self.pop()
            tos1 = self.pop()
            if x.argval == 2:
                self.call(S(slice), tos1, tos)
            elif x.argval == 3:
                tos2 = self.pop()
                self.call(S(slice), tos2, tos1, tos)
            else:
                raise ValueError

        elif x.opcode is opcodes.POP_TOP:
            self.pop()
        else:
            raise ValueError(x.opname)
        self.offset += 1
439 |
def build_const_tuple(self, argval: tuple):
    """Push every non-tuple element of ``argval``, descending into nested tuples.

    Elements are pushed depth-first, left to right, so nesting is flattened.
    NOTE(review): elements are pushed as raw Python values, whereas
    LOAD_CONST wraps its constant in ``S(...)`` -- confirm which is intended.
    """
    for item in argval:
        if not isinstance(item, tuple):
            self.push(item)
            continue
        self.build_const_tuple(item)
446 |
def require_global(self, argval):
    """Record ``argval`` in the set of global names used by the code."""
    names = self.glob_names
    names.add(argval)
449 |
450 |
def translate(f):
    """Build a ``PyC`` translator for ``f`` and return the result of its ``make()``."""
    return PyC(f).make()
454 |
--------------------------------------------------------------------------------
/diojit/user/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thautwarm/diojit/87b738a2edc9242ac82b2c34744d173f8c5005c2/diojit/user/__init__.py
--------------------------------------------------------------------------------
/diojit/user/client.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import typing
3 | import dataclasses
4 | import sys
5 | from diojit.absint.intrinsics import intrinsic
6 | from diojit.absint.prescr import create_shape, register
7 |
8 | from ..absint.abs import (
9 | CallSpec,
10 | FunctionType,
11 | S,
12 | Top,
13 | AbsVal,
14 | )
15 | from .. import absint
16 | from ..stack2reg.translate import translate
17 | from typing import Iterable
18 |
19 | __all__ = [
20 | "eagerjit",
21 | "conservativejit",
22 | "jit",
23 | "eager_jitclass",
24 | "jitclass", "spec_call", "spec_call_ir",
25 | "oftype",
26 | "ofval",
27 | ]
28 |
29 |
30 | def _is_function(o):
31 | return isinstance(o, type(_is_function))
32 |
33 |
def jit(
    func: absint.FunctionType = None,
    fixed_references: Iterable[str] = None,
):
    """Register ``func`` for JIT compilation; usable with or without arguments.

    Supported forms::

        @jit
        @jit()
        @jit(fixed_references=["name", ...])
        jit(func, fixed_references=[...])

    Args:
        func: the function to register, or ``None`` when called with
            arguments only (a decorator is returned in that case).
        fixed_references: names of globals passed on to ``_jit``; ``None``
            is treated as the empty set.

    Returns:
        ``func`` itself (as returned by ``_jit``), or a decorator expecting
        the function when ``func`` is omitted.
    """
    glob = set() if fixed_references is None else set(fixed_references)
    if func is None:
        # ``@jit()`` and ``@jit(fixed_references=...)`` both arrive here;
        # previously the bare ``@jit()`` form failed an assertion.
        return lambda func: _jit(func, glob)
    return _jit(func, glob)
47 |
48 |
def _jit(func: absint.FunctionType, glob: set[str]):
    """Translate ``func`` and register the result in ``In_Def.UserCodeDyn``.

    Only the names in ``glob`` that the function actually references as
    globals (as reported by ``translate``) are passed to ``In_Def``.

    Returns:
        ``func`` unchanged, so this can back a decorator.
    """
    code = func.__code__
    blocks, referenced_globals = translate(func)
    definition = absint.In_Def(
        code.co_argcount,
        blocks,
        func,
        referenced_globals & glob,
    )
    absint.In_Def.UserCodeDyn[func] = definition
    return func
62 |
63 |
# Wrapper marking an argument as a concrete runtime value: ``spec_call`` and
# ``spec_call_ir`` pass the wrapped object (``.a``) to ``absint.from_runtime``
# instead of treating the argument as an abstract value.
@dataclasses.dataclass
class Val:
    a: object  # the wrapped runtime object
67 |
68 |
def spec_call_ir(
    f: FunctionType, *args, attr="__call__", glob=None
):
    """Specialise a call to ``f`` and return the resulting call spec.

    Args:
        f: a plain Python function; exactly ``co_argcount`` arguments must
            be given.
        *args: each is either a ``Val`` (converted with
            ``absint.from_runtime``) or an ``AbsVal`` used as the type of a
            dynamic argument.
        attr: attribute through which ``f`` is invoked.
        glob: third argument to ``absint.Judge``; an empty dict when
            ``None``.

    Returns:
        Whatever ``Judge.spec`` returns for the abstract call.
    """
    narg = f.__code__.co_argcount
    assert (
        _is_function(f) and len(args) == narg
    ), f"Function {f} takes exactly {narg} arguments."

    rt_map = []

    def lift(arg):
        if isinstance(arg, Val):
            return absint.from_runtime(arg.a)
        assert isinstance(arg, absint.AbsVal)
        # NOTE(review): ``rt_map`` is still empty here, so every dynamic
        # argument gets index 0 -- confirm this is intended.
        return absint.D(len(rt_map), arg)

    a_args = [lift(arg) for arg in args]
    a_f = absint.from_runtime(f, rt_map)
    judge_globals = {} if glob is None else glob
    judge = absint.Judge({}, f, judge_globals)
    return judge.spec(a_f, attr, a_args)
88 |
89 |
90 | _code_gen = None
91 |
92 |
def spec_call(
    f: absint.FunctionType,
    *args,
    attr="__call__",
    glob=None,
    print_jl=None,
    print_dio_ir=None,
):
    """Specialise a call to ``f``, run code generation, and return the callback.

    Args:
        f: a plain Python function; exactly ``co_argcount`` arguments must
            be given.
        *args: each is either a ``Val`` (converted with
            ``absint.from_runtime``) or an ``AbsVal`` used as the type of a
            dynamic argument.
        attr: attribute through which ``f`` is invoked.
        glob: third argument to ``absint.Judge``; an empty dict when
            ``None``.
        print_jl: forwarded to the Julia runtime's ``code_gen``.
        print_dio_ir: when truthy, every entry of ``GenerateCache`` is shown
            through it before code generation.

    Returns:
        The ``_callback`` attribute of the intrinsic the specialised call
        refers to.
    """
    global _code_gen
    if not _code_gen:
        # Imported lazily and remembered for later calls.
        from diojit.runtime.julia_rt import code_gen

        _code_gen = code_gen

    narg = f.__code__.co_argcount
    assert (
        _is_function(f) and len(args) == narg
    ), f"Function {f} takes exactly {narg} arguments."

    rt_map = []

    def lift(arg):
        if isinstance(arg, Val):
            return absint.from_runtime(arg.a)
        assert isinstance(arg, absint.AbsVal), arg
        # NOTE(review): ``rt_map`` is still empty here, so every dynamic
        # argument gets index 0 -- confirm this is intended.
        return absint.D(len(rt_map), arg)

    a_args = [lift(arg) for arg in args]
    a_f = absint.from_runtime(f, rt_map)
    judge_globals = {} if glob is None else glob
    judge = absint.Judge({}, f, judge_globals)
    spec = judge.spec(a_f, attr, a_args)
    jit_f = spec.e_call.func.base
    assert isinstance(jit_f, absint.Intrinsic)

    if print_dio_ir:
        from diojit.runtime.julia_rt import GenerateCache

        for generated in GenerateCache.values():
            generated.show(print_dio_ir)

    _code_gen(print_jl)
    return getattr(jit_f, "_callback")
132 |
133 |
def oftype(t: object):
    """
    Create an abstract value from a type object.

    ``t`` is converted with ``absint.from_runtime``; a ``TypeError`` is raised
    when the conversion yields an ``absint.D``.
    """
    a_val = absint.from_runtime(t)
    if isinstance(a_val, absint.D):
        raise TypeError(f"{t} is not a type")
    return a_val
142 |
143 |
def ofval(o: object):
    """Wrap the runtime object ``o`` in a ``Val``."""
    wrapped = Val(o)
    return wrapped
146 |
147 |
def eagerjit(func: typing.Union[FunctionType, type]):
    """Register a function or class, treating every referenced name as fixed.

    Classes are delegated to ``eager_jitclass``. For a function, all of its
    ``co_names`` are passed to ``_jit`` as the fixed references.
    """
    if not isinstance(func, type):
        assert isinstance(func, FunctionType)
        every_name = set(func.__code__.co_names)
        return _jit(func, every_name)
    return eager_jitclass(func)
155 |
156 |
def conservativejit(
    func: typing.Union[FunctionType, type], fixed_references=()
):
    """Register a function or class, fixing only the names explicitly listed.

    Unlike ``eagerjit``, no global is assumed fixed unless the caller names
    it in ``fixed_references``.

    Args:
        func: a plain function, or a class (delegated to ``jitclass``).
        fixed_references: iterable of global names to treat as fixed.

    Returns:
        The function or class that was passed in.
    """
    fixed_references = set(fixed_references)
    if isinstance(func, type):
        cls = func
        return jitclass(cls, fixed_references)
    assert isinstance(func, FunctionType)
    # Pass the caller's set through. The previous code replaced it with the
    # ``co_names`` tuple, which discarded the argument and made ``_jit``
    # fail on ``set & tuple``.
    return _jit(func, fixed_references)
167 |
168 |
169 | _cache = {}
170 |
171 | u_inst_type = type(typing.Union[int, float])
172 |
173 |
def process_annotations(anns: dict, glob: dict):
    """Convert a class's annotations into abstract type values.

    Args:
        anns: mapping of attribute name to annotation. Results are cached
            under ``id(anns)``, so the same mapping object should be passed
            on repeated calls.
        glob: namespace in which string annotations are evaluated.

    Returns:
        A dict mapping each name to ``S(type)``, or to a tuple of such
        values for a ``typing.Union``.

    Raises:
        AssertionError: if an annotation does not resolve to a type.
    """
    if ret := _cache.get(id(anns)):
        return ret

    def each(v):
        if isinstance(v, str):
            # String annotations are evaluated in the supplied namespace.
            v = eval(v, glob)
        # ``typing.get_origin`` covers both ``typing.List[int]`` and
        # ``list[int]``. The previous ``typing.GenericAlias`` check never
        # matched, because that attribute does not exist.
        origin = typing.get_origin(v)
        if origin is typing.Union:
            return tuple(each(a) for a in v.__args__)
        if origin is not None:
            v = origin

        assert isinstance(v, type), v
        return S(v)

    ret = {k: each(v) for k, v in anns.items()}
    _cache[id(anns)] = ret
    return ret
194 |
195 |
def eager_jitclass(cls: type):
    """Register ``cls`` for JIT use with the eager policy.

    A shape is created for the class. If the class has annotations, a
    ``__getattr__`` rule is registered that reports exactly the annotated
    types for literal attribute names. Every plain function in the class
    body whose name does not start with ``__`` is passed to ``eagerjit`` and
    recorded as a shape field.

    Returns:
        ``cls`` itself.
    """
    shape = create_shape(cls, oop=True)
    annotations = getattr(cls, "__annotations__", None)
    if annotations:

        def get_attr(self, *args: AbsVal):
            if len(args) != 2:
                return NotImplemented
            a_obj, a_attr = args
            if not (a_attr.is_literal() and a_attr.base in annotations):
                return NotImplemented
            module_globals = sys.modules[cls.__module__].__dict__
            ret_types = process_annotations(
                annotations, module_globals
            ).get(a_attr.base)
            if not ret_types:
                return NotImplemented
            if not isinstance(ret_types, tuple):
                assert isinstance(ret_types, AbsVal)
                ret_types = (ret_types,)
            func = S(intrinsic("PyObject_GetAttr"))
            return CallSpec(None, func(a_obj, a_attr), ret_types)

        register(cls, attr="__getattr__")(get_attr)

    for name, member in cls.__dict__.items():
        if name.startswith("__"):
            continue
        if not isinstance(member, FunctionType):
            continue
        eagerjit(member)
        shape.fields[name] = S(member)
    return cls
224 |
225 |
def jitclass(
    cls: type,
    fixed_references: Iterable[str] = (),
    meth_jit_policy=conservativejit,
    jit_methods: typing.Union[all, Iterable[str]] = all,
):
    """Register ``cls`` for JIT use with a configurable method policy.

    A shape is created for the class. If the class has annotations, a
    ``__getattr__`` rule is registered that reports the annotated types
    plus ``Top`` for literal attribute names.

    Args:
        cls: the class to register.
        fixed_references: global names forwarded to ``meth_jit_policy``.
        meth_jit_policy: callable ``(function, fixed_references)`` applied
            to each selected method.
        jit_methods: the builtin ``all`` to consider every entry of the
            class body, or an iterable of method names defined in the class
            body. Names starting with ``__`` and non-function members are
            skipped in both cases.

    Returns:
        ``cls`` itself.

    Raises:
        KeyError: if a name in ``jit_methods`` is not defined in the class
            body.
    """
    fixed_references = set(fixed_references)
    shape = create_shape(cls, oop=True)
    if annotations := getattr(cls, "__annotations__", None):

        def get_attr(self, *args: AbsVal):
            if len(args) != 2:
                return NotImplemented
            a_obj, a_attr = args
            if a_attr.is_literal() and a_attr.base in annotations:
                if ret_types := process_annotations(
                    annotations, sys.modules[cls.__module__].__dict__
                ).get(a_attr.base):
                    if not isinstance(ret_types, tuple):
                        assert isinstance(ret_types, AbsVal)
                        ret_types = (ret_types,)
                    func = S(intrinsic("PyObject_GetAttr"))
                    return CallSpec(
                        None, func(a_obj, a_attr), (*ret_types, Top)
                    )
            return NotImplemented

        register(cls, attr="__getattr__")(get_attr)

    if jit_methods is all:
        candidates = list(cls.__dict__.items())
    else:
        # Explicit names are resolved to (name, member) pairs. Previously
        # the names themselves were unpacked as pairs, which failed.
        candidates = [(name, cls.__dict__[name]) for name in jit_methods]

    for each, f in candidates:
        if not each.startswith("__") and isinstance(f, FunctionType):
            meth_jit_policy(f, fixed_references)
            shape.fields[each] = S(f)
    return cls
262 |
--------------------------------------------------------------------------------
/genopname.py:
--------------------------------------------------------------------------------
# Generator script: writes diojit/stack2reg/opcodes.py and
# diojit/stack2reg/cflags.py from the running interpreter's ``opcode`` and
# ``dis`` tables. Paths are relative to the current working directory.
import opcode
import os
import dis

dir = "diojit"
with open(os.path.join(dir, "stack2reg", "opcodes.py"), "w") as f:
    f.write("import opcode\n")
    # Sentinel used by the generated module for opcode names missing from
    # ``opcode.opmap`` at import time.
    f.write("UNKNOWN_INSTR = object()\n")
    # One module-level constant per opcode name known to this interpreter.
    for each in opcode.opmap:
        f.write(f"{each} = opcode.opmap.get({each!r}, UNKNOWN_INSTR)\n")
    f.write("\n")

with open(os.path.join(dir, "stack2reg", "cflags.py"), "w") as f:
    f.write("import dis\n")
    # The generated module inverts COMPILER_FLAG_NAMES (value -> key) ...
    f.write("_flags = {v: k for k, v in dis.COMPILER_FLAG_NAMES.items()}\n")

    # ... and exposes one constant per flag name, looked up in that mapping.
    for _, n in dis.COMPILER_FLAG_NAMES.items():
        f.write(f"{n} = _flags[{n!r}]\n")
    f.write("\n")
20 |
--------------------------------------------------------------------------------
/prepub.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | fd -H -E .git ... -0 | xargs -0 dos2unix
3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pyrsistent >= 0.17.0
2 | julia >= 0.5.0
--------------------------------------------------------------------------------
/runtests/doil.py:
--------------------------------------------------------------------------------
# Smoke script: registers two functions with diojit, specialises a call and
# prints the resulting call spec and the generated definitions.
import diojit


# Loop over a rotating 3-tuple; registered with no fixed references.
@diojit.jit
def test(a):
    t = (1, 2, 3)
    i = 0
    while i < a:
        t = (t[1], t[2], t[0])
        i = i + t[0]
    return i


# Wrapper whose reference to the global ``test`` is declared fixed.
@diojit.jit(fixed_references=["test"])
def f(x):
    return test(x)


# Looking up the registry entry confirms that ``test`` was registered.
in_def = diojit.absint.In_Def.UserCodeDyn[test]
# in_def.show()
# Specialise ``f`` for the concrete argument 500 (wrapped in ``Val``).
callspec = diojit.spec_call_ir(f, diojit.Val(500))
print("return types: ", *callspec.possibly_return_types)
print("instance : ", callspec.instance)
print("call expr : ", callspec.e_call)
# NOTE(review): assumes iterating GenerateCache yields objects with
# ``show()`` -- confirm against the cache's current type.
for each in reversed(diojit.absint.Out_Def.GenerateCache):
    each.show()
27 |
28 |
29 |
--------------------------------------------------------------------------------
/runtests/load.py:
--------------------------------------------------------------------------------
1 | import diojit.runtime.julia_rt
--------------------------------------------------------------------------------
/runtests/tutorial.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 | import operator
3 | import timeit
4 | import builtins
5 | import diojit
6 | from inspect import getsource
7 | from diojit.runtime.julia_rt import check_jl_err
8 | from diojit.codegen.julia import splice
9 |
10 | GenerateCache = diojit.Out_Def.GenerateCache
11 |
12 | import diojit as jit
13 | import timeit
14 | from operator import add
15 |
16 | libjl = jit.runtime.julia_rt.get_libjulia()
17 |
18 |
def jl_eval(s: str):
    """Evaluate the Julia source ``s`` through libjulia, then run ``check_jl_err``."""
    source_bytes = s.encode()
    libjl.jl_eval_string(source_bytes)
    check_jl_err(libjl)
22 |
23 |
# Pure-Python recursive Fibonacci used as the baseline in the timings below.
# (Comment kept outside the function: its source is printed via getsource.)
def fib(a):
    if a <= 2:
        return 1
    return fib(a - 1) + fib(a - 2)
28 |
29 |
# Same recursion as ``fib``, registered for JIT with its self-reference
# declared fixed. It is written with ``a + -1`` / ``a + -2`` rather than
# subtraction.
@jit.jit(fixed_references=["fib_fix"])
def fib_fix(a):
    if a <= 2:
        return 1
    return fib_fix(a + -1) + fib_fix(a + -2)
35 |
36 |
# Two specialisations of fib_fix: one with an unknown argument type (Top)
# and one with the argument typed as int.
jit_fib_fix_untyped = jit.spec_call(fib_fix, jit.Top)
jit_fib_fix_typed = jit.spec_call(
    fib_fix, jit.oftype(int)
)
# jl_eval(f"println(J_fib__fix_1({splice(20)}))")
# check_jl_err(libjl)
print("fib".center(70, "="))
print(getsource(fib))
# Print all three results side by side so they can be compared by eye.
print(
    "fib(15), jit_fib_fix_untyped(15), jit_fib_fix_typed(15) = ",
    (fib(15), jit_fib_fix_untyped(15), jit_fib_fix_typed(15)),
)
# Each variant is timed over 10000 calls of f(15).
print(
    "fib(py) bench time:",
    timeit.timeit("f(15)", globals=dict(f=fib), number=10000),
)
print(
    "fib(jit+untyped) bench time:",
    timeit.timeit(
        "f(15)", globals=dict(f=jit_fib_fix_untyped), number=10000
    ),
)
print(
    "fib(jit+inferred) bench time:",
    timeit.timeit(
        "f(15)", globals=dict(f=jit_fib_fix_typed), number=10000
    ),
)

print("hypot".center(70, "="))
67 |
68 |
# Hypotenuse that also accepts numeric strings; the globals it references
# are declared fixed for the JIT.
# (Comment kept outside the function: its source is printed via getsource.)
@diojit.jit(fixed_references=["sqrt", "str", "int", "isinstance"])
def hypot(x, y):
    if isinstance(x, str):
        x = int(x)

    if isinstance(y, str):
        y = int(y)

    return sqrt(x ** 2 + y ** 2)
78 |
79 |
print(getsource(hypot))


# print("Direct Translation From Stack Instructions".center(70, "="))

# diojit.absint.In_Def.UserCodeDyn[hypot].show()
# print("After JITing".center(70, "="))


# repr of the specialised callee for (int, int); only referenced by the
# commented-out @code_llvm snippet below.
jit_func_name = repr(
    diojit.spec_call_ir(
        hypot, diojit.S(int), diojit.S(int)
    ).e_call.func
)


# Callable specialisation of hypot for two int arguments.
hypot_spec = diojit.spec_call(
    hypot,
    diojit.oftype(int),
    diojit.oftype(int),
    # print_jl=print,
    # print_dio_ir=print,
)
# #
# libjl = diojit.runtime.julia_rt.get_libjulia()
# libjl.jl_eval_string(f'using InteractiveUtils;@code_llvm {jit_func_name}(PyO.int, PyO.int)'.encode())
# diojit.runtime.julia_rt.check_jl_err(libjl)

print("hypot(1, 2) (jit) = ", hypot_spec(1, 2))
print("hypot(1, 2) (pure py) = ", hypot(1, 2))
# Each variant is timed over 1000000 calls of f(1, 2).
print(
    "hypot (pure py) bench time:",
    timeit.timeit("f(1, 2)", number=1000000, globals=dict(f=hypot)),
)
print(
    "hypot (jit) bench time:",
    timeit.timeit(
        "f(1, 2)", number=1000000, globals=dict(f=hypot_spec)
    ),
)

# Create a shape for the built-in list so the ``append`` rule registered
# next has something to attach to.
diojit.create_shape(list, oop=True)
122 |
123 |
@diojit.register(list, attr="append")
def list_append_analysis(self: diojit.Judge, *args: diojit.AbsVal):
    """Specialisation rule for ``list.append``.

    With exactly two abstract arguments (the list and the element) the call
    is mapped to the ``PyList_Append`` intrinsic with ``NoneType`` as its
    only possible return type. Any other arity returns ``NotImplemented``
    to fall back to CPython's default code.
    """
    if len(args) == 2:
        lst, elt = args
        append_call = diojit.S(diojit.intrinsic("PyList_Append"))(lst, elt)
        none_type = diojit.S(type(None))
        return diojit.CallSpec(
            instance=None,  # return value is not static
            e_call=append_call,
            possibly_return_types=(none_type,),
        )
    # rollback to CPython's default code
    return NotImplemented
136 |
137 |
# Appends ``x`` to ``xs`` three times; exercises the ``list.append`` rule.
# (Comment kept outside the function: its source is printed via getsource.)
@diojit.jit
def append3(xs, x):
    xs.append(x)
    xs.append(x)
    xs.append(x)
143 |
144 |
print("append3".center(70, "="))
print(getsource(append3))

# diojit.In_Def.UserCodeDyn[append3].show()
# Specialise for a list first argument and an untyped second argument.
jit_append3 = diojit.spec_call(
    append3, diojit.oftype(list), diojit.Top
)
xs = [1]
jit_append3(xs, 3)
print("test jit func: [1] append 3 for 3 times =", xs)


# Each timing starts from a fresh empty list, which grows during the run.
xs = []
print(
    "append3 (jit) bench time:",
    timeit.timeit(
        "f(xs, 1)", globals=dict(f=jit_append3, xs=xs), number=10000000
    ),
)
xs = []
print(
    "append3 (pure py) bench time:",
    timeit.timeit(
        "f(xs, 1)", globals=dict(f=append3, xs=xs), number=10000000
    ),
)
171 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages
from datetime import datetime
from pathlib import Path


version = "0.2a"
# Decode the README explicitly so the build does not depend on the platform's
# default encoding.
readme = Path("README.md").read_text(encoding="utf-8")


setup(
    name="diojit",
    version=version,
    keywords="Just-In-Time, JIT, compiler",  # keywords of your project that separated by comma ","
    description="A general-purpose JIT for CPython.",  # a concise introduction of your project
    long_description=readme,
    long_description_content_type="text/markdown",
    license="mit",
    python_requires=">=3.8.0",
    url="https://github.com/thautwarm/diojit",
    author="thautwarm",
    author_email="twshere@outlook.com",
    packages=find_packages(),
    entry_points={"console_scripts": []},
    # above option specifies what commands to install,
    # e.g: entry_points={"console_scripts": ["yapypy=yapypy.cmd:compiler"]}
    install_requires=["pyrsistent", "julia"],  # dependencies
    platforms="any",
    # Kept in line with python_requires above.
    classifiers=[
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: Implementation :: CPython",
    ],
    zip_safe=False,
)
36 |
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Run the benchmark scripts under two conda environments:
#   base (py39) with diojit installed by name via pip,
#   py38        with the local checkout installed.
# Each section activates its environment BEFORE installing, so the package
# lands in the environment that the benchmarks run in.

echo "=====py39====="
source activate base # py39
pip install diojit
python benchmarks/append3.py
python benchmarks/brainfuck.py
python benchmarks/dna_read.py
python benchmarks/fib.py
python benchmarks/hypot.py
python benchmarks/selection_sort.py
python benchmarks/trans.py
# -y: do not block on the confirmation prompt.
pip uninstall -y diojit

echo "=====py38====="
source activate py38
pip install .
python benchmarks/append3.py
python benchmarks/brainfuck.py
python benchmarks/dna_read.py
python benchmarks/fib.py
python benchmarks/hypot.py
python benchmarks/selection_sort.py
python benchmarks/trans.py
pip uninstall -y diojit
--------------------------------------------------------------------------------