├── .gitignore ├── LICENSE ├── Pipfile ├── Pipfile.lock ├── README.md ├── aot ├── __init__.py ├── aot.py ├── convert.py ├── little_cpp.py └── to_source.py ├── examples └── readme_ex.py └── test └── test_aot.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # IPython 78 | profile_default/ 79 | ipython_config.py 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project 
settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | .dmypy.json 112 | dmypy.json 113 | 114 | # Pyre type checker 115 | .pyre/ 116 | 117 | # Generated Source 118 | source.cc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2019 UW SAMPL 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | url = "https://pypi.org/simple" 3 | verify_ssl = true 4 | name = "pypi" 5 | 6 | [packages] 7 | numpy = "*" 8 | 9 | [dev-packages] 10 | 11 | [requires] 12 | python_version = "3.6" 13 | -------------------------------------------------------------------------------- /Pipfile.lock: -------------------------------------------------------------------------------- 1 | { 2 | "_meta": { 3 | "hash": { 4 | "sha256": "415dfdcb118dd9bdfef17671cb7dcd78dbd69b6ae7d4f39e8b44e71d60ca72e7" 5 | }, 6 | "pipfile-spec": 6, 7 | "requires": { 8 | "python_version": "3.6" 9 | }, 10 | "sources": [ 11 | { 12 | "name": "pypi", 13 | "url": "https://pypi.org/simple", 14 | "verify_ssl": true 15 | } 16 | ] 17 | }, 18 | "default": {}, 19 | "develop": {} 20 | } 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # relay-aot 2 | 3 | An experimental ahead of time 
compiler for Relay. 4 | 5 | The ahead of time compiler enables the execution of Relay code without 6 | requiring a framework's interpreter written in C++ or Python. 7 | 8 | The removal of framework and interpretation overhead combined 9 | with optimized operators produced by TVM dramatically 10 | reduces execution time. Additionally, the compiler produces a single 11 | binary which only depends on TVM, simplifying the deployment story. 12 | This repository contains an external library which implements ahead of time 13 | compilation for Relay. 14 | The current approach is a proof of concept which lowers Relay to C++, and relies on a 15 | C++ compiler such as `gcc` or `clang` to produce an executable. 16 | 17 | The ahead of time compiler comes as a standalone library which exposes a 18 | primitive `compile` function which compiles a `relay.Function` 19 | into a Python closure which wraps the compiled native code. 20 | 21 | The compiler's design is straightforward. It lowers functions into a 22 | small C++-like IR, and generates a C++ program which can be compiled 23 | and dynamically linked. We extract the corresponding symbol from the 24 | dynamic library, and wrap it as a Python closure. 25 | 26 | You can see an example below: 27 | 28 | ```python 29 | import numpy as np 30 | import tvm 31 | from tvm.relay import Module, GlobalVar, Function, var 32 | from aot import compile 33 | 34 | def double_example(): 35 | # Declare a Relay module. 36 | mod = Module() 37 | 38 | # Implement the double function. 39 | x = var('x', shape=()) 40 | double = GlobalVar('double') 41 | mod[double] = Function([x], x + x) 42 | 43 | # Generate a function which calls double twice. 44 | x = var('x', shape=()) 45 | f = Function([x], double(double(x))) 46 | # Compile the function. 47 | cfunc = compile(f, mod) 48 | 49 | a = tvm.nd.array(np.array(1.5, dtype='float32')) 50 | output = cfunc(a).asnumpy() # array(6.)
51 | ``` 52 | 53 | Currently there is no Python package due to the lack of a package for `tvm` 54 | itself. You can use the ahead of time compiler by adding it to your `PYTHONPATH`. 55 | 56 | ```shell 57 | export PYTHONPATH="THE_AOT_PATH:${PYTHONPATH}" 58 | ``` 59 | 60 | You must set the variable `TVM_HOME` in order to use the native compiler 61 | currently. 62 | 63 | 64 | You can test that you set up the ahead-of-time compiler correctly by running: 65 | 66 | ```shell 67 | TVM_HOME=~/Git/tvm python3 examples/readme_ex.py 68 | ``` 69 | 70 | *Note: this only sets `TVM_HOME` for this single command; export it in your shell to make the setting persistent.* 71 | -------------------------------------------------------------------------------- /aot/__init__.py: -------------------------------------------------------------------------------- 1 | from .aot import compile 2 | 3 | compile = compile 4 | -------------------------------------------------------------------------------- /aot/aot.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import numpy as np 3 | import os 4 | import subprocess 5 | import tempfile 6 | import tvm 7 | from tvm import relay, get_global_func, target, register_func 8 | from tvm.relay.function import Function 9 | from tvm.relay.expr import Expr, Let, GlobalVar 10 | from tvm.relay.adt import Constructor 11 | from tvm.relay.expr_functor import ExprFunctor, ExprVisitor 12 | from tvm.relay.backend import compile_engine 13 | from .little_cpp import PackedCall, CPPFunction, Invoke, Decl, CPPIf, CPPTuple, CPPMatch, CPPConstructor, CPPTupleGetItem 14 | from .little_cpp import CPPRefCreate, CPPRefRead, CPPRefWrite 15 | from .
import to_source 16 | from .convert import convert 17 | 18 | TVM_PATH = os.environ['TVM_HOME'] 19 | 20 | def must_run_process(args): 21 | proc = subprocess.run(args) 22 | assert proc.returncode == 0 23 | 24 | def compile_cpp(source, lib_name, flags=None, lib_path=None): 25 | if flags is None: 26 | flags = [] 27 | 28 | if lib_path is None: 29 | lib_path = os.curdir 30 | 31 | debug_source_path = os.path.join(lib_path, 'source.cc') 32 | # Write out the file for debugging. 33 | with open(debug_source_path, 'w') as source_file: 34 | source_file.write(source) 35 | 36 | # with tempfile.TmporaryDirectory() as tmpdir: 37 | tmpdir = tempfile.mkdtemp(prefix="relay_aot_compiler") 38 | lib_path = os.path.join(tmpdir, lib_name) 39 | source_path = os.path.join(tmpdir, 'source.cc') 40 | with open(source_path, 'w') as source_file: 41 | source_file.write(source) 42 | 43 | must_run_process(["clang-format", "-i", debug_source_path]) 44 | 45 | system = os.uname()[0] 46 | if system == 'Darwin': 47 | command = [ 48 | "clang", 49 | "-std=c++14", 50 | "-shared", 51 | "-undefined", 52 | "dynamic_lookup", 53 | "-o", 54 | lib_path, 55 | source_path, 56 | f"-I{TVM_PATH}/3rdparty/dmlc-core/include", 57 | f"-I{TVM_PATH}/3rdparty/dlpack/include", 58 | f"-I{TVM_PATH}/3rdparty/HalideIR/src", 59 | f"-I{TVM_PATH}/include", 60 | f"-L{TVM_PATH}/build", 61 | "-ltvm" 62 | ] + flags 63 | else: 64 | command = [ 65 | "clang", 66 | "-std=c++14", 67 | "-shared", 68 | "-fPIC", 69 | "-o", 70 | lib_path, 71 | source_path, 72 | f"-I{TVM_PATH}/3rdparty/dmlc-core/include", 73 | f"-I{TVM_PATH}/3rdparty/dlpack/include", 74 | f"-I{TVM_PATH}/3rdparty/HalideIR/src", 75 | f"-I{TVM_PATH}/include", 76 | f"-L{TVM_PATH}/build", 77 | "-ltvm" 78 | ] + flags 79 | 80 | must_run_process(command) 81 | return lib_path 82 | 83 | def load_lib(name): 84 | return ctypes.CDLL(name, ctypes.RTLD_GLOBAL) 85 | 86 | def is_primitive(e: relay.Expr): 87 | return isinstance(e, relay.Function) and e.attrs and e.attrs.Primitive.value == 1 88 | 89 
| class AoTCompiler(ExprFunctor): 90 | def __init__(self, mod, tgt) -> None: 91 | super().__init__() 92 | self.mod = mod 93 | self.tgt = tgt 94 | self.engine = compile_engine.get() 95 | self.bindings = [[]] 96 | self.gv_map = {} 97 | 98 | def add_binding(self, var, value): 99 | self.bindings[-1].append((var, value)) 100 | 101 | def optimize(self, expr: Function) -> Function: 102 | opts = tvm.transform.Sequential([relay.transform.FuseOps(), 103 | relay.transform.ToANormalForm()]) 104 | self.mod['main'] = expr 105 | self.mod = opts(self.mod) 106 | ret = self.mod['main'] 107 | return ret 108 | 109 | def mk_primitive_op(self, func: Expr, args, output_type) -> Expr: 110 | cc_key = compile_engine.CCacheKey(func, self.tgt) 111 | hash = tvm.ir.structural_hash(func) 112 | name = f"op_{hash}" 113 | if not get_global_func(name, allow_missing=True): 114 | jit_func = self.engine.jit(cc_key, self.tgt) 115 | register_func(name, jit_func) 116 | return PackedCall(name, args, [x.checked_type for x in args], output_type) 117 | 118 | def visit_call(self, call: Expr) -> Expr: 119 | if is_primitive(call.op): 120 | return self.mk_primitive_op(call.op, call.args, call.checked_type) 121 | elif isinstance(call.op, Constructor): 122 | return CPPConstructor(call.op.tag, [self.visit(arg) for arg in call.args]) 123 | else: 124 | assert(call.attrs == None) 125 | args = [self.visit(arg) for arg in call.args] 126 | fn = self.visit(call.op) 127 | return Invoke(fn, args) 128 | 129 | def visit_let(self, let: Expr) -> Expr: 130 | self.bindings.append([]) 131 | 132 | while isinstance(let, Let): 133 | cpp_value = self.visit(let.value) 134 | self.add_binding(let.var, cpp_value) 135 | let = let.body 136 | 137 | bindings = self.bindings.pop() 138 | body = self.visit(let) 139 | 140 | return Decl(bindings, body) 141 | 142 | def visit_var(self, var): 143 | return var 144 | 145 | def visit_global_var(self, gv): 146 | if gv not in self.gv_map: 147 | self.gv_map[gv] = "to be updated" 148 | self.gv_map[gv] = 
self.visit(self.mod[gv]) 149 | return gv 150 | 151 | def visit_function(self, func): 152 | if is_primitive(func): 153 | body = self.mk_primitive_op(func, func.params, func.ret_type) 154 | return CPPFunction(func.params, body, func.checked_type.ret_type) 155 | else: 156 | return CPPFunction(func.params, self.visit(func.body), func.checked_type.ret_type) 157 | 158 | def visit_constant(self, const): 159 | return const 160 | 161 | def visit_if(self, i): 162 | return CPPIf(self.visit(i.cond), 163 | self.visit(i.true_branch), 164 | self.visit(i.false_branch), 165 | i.checked_type) 166 | 167 | def visit_tuple(self, t): 168 | return CPPTuple([self.visit(f) for f in t.fields], t.checked_type) 169 | 170 | def visit_match(self, m): 171 | return CPPMatch(self.visit(m.data), 172 | [(c.lhs, self.visit(c.rhs)) for c in m.clauses], 173 | m.checked_type) 174 | 175 | def visit_op(self, op): 176 | raise Exception(f'op outside of primitive: {op}') 177 | 178 | def visit_tuple_getitem(self, t): 179 | return CPPTupleGetItem(self.visit(t.tuple_value), t.index, t.checked_type) 180 | 181 | def visit_ref_create(self, r): 182 | return CPPRefCreate(self.visit(r.value), r.checked_type) 183 | 184 | def visit_ref_read(self, r): 185 | return CPPRefRead(self.visit(r.ref), r.checked_type) 186 | 187 | def visit_ref_write(self, r): 188 | return CPPRefWrite(self.visit(r.ref), self.visit(r.value)) 189 | 190 | _LIB_COUNTER = 1 191 | _LIB = [] 192 | 193 | def lib_and_func_name(name): 194 | global _LIB_COUNTER 195 | packed_name = f'relay.aot.{name}.{_LIB_COUNTER}' 196 | lib_name = f"librelay_aot_{_LIB_COUNTER}.so" 197 | _LIB_COUNTER += 1 198 | return lib_name, packed_name 199 | 200 | import time 201 | 202 | def _mk_wrapper(fn, ctx, constants, record_time): 203 | def _wrapper(*args): 204 | new_constants = [convert(a, ctx) for a in constants] 205 | new_args = [convert(a, ctx) for a in args] 206 | begin = time.perf_counter() 207 | res = fn(*new_constants, *new_args) 208 | end = time.perf_counter() 209 | 
return res if not record_time else (res, end - begin) 210 | return _wrapper 211 | 212 | import sys 213 | sys.setrecursionlimit(10000) 214 | 215 | def compile(func, mod, ctx, tgt, name='default', record_time=False): 216 | """Compile a relay function into a native library function. 217 | 218 | Parameters 219 | ---------- 220 | func: Expr 221 | The function. 222 | 223 | mod: Module 224 | The Module. 225 | 226 | ctx: Context 227 | The Context. 228 | 229 | tgt: Target 230 | The target 231 | 232 | name: String 233 | The name of the target binary library. 234 | 235 | record_time: Bool 236 | Time cost to call f? 237 | 238 | Returns 239 | ------- 240 | result: Function 241 | A function that, when pass in some values, 242 | will convert them to the right format and call the compiled func. 243 | """ 244 | global _LIB 245 | if isinstance(func, GlobalVar): 246 | func = mod[func] 247 | assert isinstance(func, Function) 248 | compiler = AoTCompiler(mod, tgt) 249 | func = compiler.optimize(func) 250 | func = compiler.visit(func) 251 | lib_name, packed_name = lib_and_func_name(name) 252 | constants, source_code = to_source.to_source(mod, func, compiler.gv_map, ctx, packed_name) 253 | lib_name = f"librelay_aot_{_LIB_COUNTER}.so" 254 | library_path = compile_cpp(source_code, lib_name, flags=["-O3"]) 255 | _LIB.append(load_lib(library_path)) 256 | fn = get_global_func(packed_name) 257 | return _mk_wrapper(fn, ctx, constants, record_time) 258 | -------------------------------------------------------------------------------- /aot/convert.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tvm 3 | from tvm import relay 4 | 5 | # convert(convert(a, tg), tg) = convert(a, tg) 6 | def convert(a, ctx): 7 | while True: 8 | if isinstance(a, int): 9 | a = np.array(a, dtype='int32') 10 | elif isinstance(a, np.ndarray): 11 | a = tvm.nd.array(a, ctx) 12 | elif isinstance(a, tvm.runtime.NDArray): 13 | return a 14 | elif isinstance(a, 
relay.Call): 15 | assert isinstance(a.op, relay.Constructor) 16 | a = (a.op, *a.args) 17 | elif isinstance(a, tuple): 18 | assert isinstance(a[0], relay.Constructor) 19 | a = relay.backend.interpreter.ConstructorValue(a[0].tag, [convert(arg, ctx) for arg in a[1:]], a[0]) 20 | elif isinstance(a, relay.backend.interpreter.ConstructorValue): 21 | return a 22 | else: 23 | raise Exception(a, type(a)) 24 | -------------------------------------------------------------------------------- /aot/little_cpp.py: -------------------------------------------------------------------------------- 1 | from tvm.relay import Var, TypeVar 2 | from typing import Any, Optional, List, Tuple 3 | import attr 4 | 5 | class LittleCppNode: 6 | pass 7 | 8 | @attr.s(auto_attribs=True) 9 | class Decl(LittleCppNode): 10 | bindings: List[Tuple[Var, LittleCppNode]] 11 | body: LittleCppNode 12 | 13 | @attr.s(auto_attribs=True) 14 | class PackedCall(LittleCppNode): 15 | name: str 16 | args: Any 17 | args_type: Any 18 | ret_type: Any 19 | 20 | @attr.s(auto_attribs=True) 21 | class Invoke(LittleCppNode): 22 | call: Any 23 | args: Any 24 | 25 | @attr.s(auto_attribs=True) 26 | class CPPFunction(LittleCppNode): 27 | params: List[Var] 28 | body: Any 29 | ret_type: Any 30 | name: Optional[str] = None 31 | 32 | @attr.s(auto_attribs=True) 33 | class CPPIf(LittleCppNode): 34 | cond: Any 35 | true_branch: Any 36 | false_branch: Any 37 | relay_type: Any 38 | 39 | @attr.s(auto_attribs=True) 40 | class CPPTuple(LittleCppNode): 41 | fields: List[Any] 42 | relay_type: Any 43 | 44 | @attr.s(auto_attribs=True) 45 | class CPPMatch(LittleCppNode): 46 | data: Any 47 | clause: List[Tuple[Any, Any]] 48 | relay_type: Any 49 | 50 | @attr.s(auto_attribs=True) 51 | class CPPConstructor(LittleCppNode): 52 | tag: int 53 | fields: List[Any] 54 | 55 | @attr.s(auto_attribs=True) 56 | class CPPTupleGetItem(LittleCppNode): 57 | tuple_value: Any 58 | index: int 59 | relay_type: Any 60 | 61 | @attr.s(auto_attribs=True) 62 | class 
CPPRefCreate(LittleCppNode): 63 | value: Any 64 | relay_type: Any 65 | 66 | @attr.s(auto_attribs=True) 67 | class CPPRefRead(LittleCppNode): 68 | ref: Any 69 | relay_type: Any 70 | 71 | @attr.s(auto_attribs=True) 72 | class CPPRefWrite(LittleCppNode): 73 | ref: Any 74 | value: Any 75 | -------------------------------------------------------------------------------- /aot/to_source.py: -------------------------------------------------------------------------------- 1 | from . import little_cpp 2 | from tvm import relay 3 | from tvm.relay.prelude import Prelude 4 | 5 | class ExprWithStmt: 6 | def __init__(self, expr, stmt=""): 7 | assert isinstance(expr, str) 8 | assert isinstance(stmt, str) 9 | assert "ExprWithStmt" not in expr 10 | assert "ExprWithStmt" not in stmt 11 | self.expr = expr 12 | self.stmt = stmt 13 | 14 | def __str__(self): 15 | return f"ExprWithStmt({self.expr}, {self.stmt})" 16 | 17 | def __repr__(self): 18 | return self.__str__() 19 | 20 | class ToSource: 21 | def __init__(self, gv_map): 22 | self.gv_map = gv_map 23 | self.name_counter = 0 24 | self.source_content = "" 25 | self.name_map = {} 26 | self.local = True 27 | self.declare = "" 28 | self.declare_map = {} 29 | self.input_const = [] 30 | 31 | def fresh_global_name(self): 32 | name = f"global{self.name_counter}" 33 | self.name_counter += 1 34 | return name 35 | 36 | def sanitize(self, str): 37 | return str.replace("-", "_").replace("/", "_") 38 | 39 | def fresh_local_name(self, var=None): 40 | if var is not None: 41 | name = f"local_{self.sanitize(var.name_hint)}_{self.name_counter}" 42 | else: 43 | name = f"local_{self.name_counter}" 44 | self.name_counter += 1 45 | return name 46 | 47 | def fresh_label_name(self): 48 | name = f"label_{self.name_counter}" 49 | self.name_counter += 1 50 | return name 51 | 52 | # return (str, str) with lhs being stmts, and rhs being expression 53 | def visit(self, node, *, local=True, name=None): 54 | if isinstance(node, little_cpp.PackedCall): 55 | res = 
self.visit_packed_call(node) 56 | elif isinstance(node, little_cpp.CPPFunction): 57 | res = self.visit_cpp_function(node, local, name) 58 | elif isinstance(node, little_cpp.Decl): 59 | res = self.visit_decl(node) 60 | elif isinstance(node, little_cpp.Invoke): 61 | res = self.visit_invoke(node) 62 | elif isinstance(node, relay.Var): 63 | res = ExprWithStmt(self.name_map[node]) 64 | elif isinstance(node, relay.GlobalVar): 65 | res = self.visit_global_var(node) 66 | elif isinstance(node, relay.Constant): 67 | res = self.visit_constant(node) 68 | elif isinstance(node, little_cpp.CPPIf): 69 | res = self.visit_if(node) 70 | elif isinstance(node, little_cpp.CPPTuple): 71 | res = self.visit_tuple(node) 72 | elif isinstance(node, little_cpp.CPPConstructor): 73 | res = self.visit_constructor(node) 74 | elif isinstance(node, little_cpp.CPPMatch): 75 | res = self.visit_match(node) 76 | elif isinstance(node, little_cpp.CPPTupleGetItem): 77 | res = self.visit_tuple_getitem(node) 78 | elif isinstance(node, little_cpp.CPPRefCreate): 79 | res = self.visit_ref_create(node) 80 | elif isinstance(node, little_cpp.CPPRefRead): 81 | res = self.visit_ref_read(node) 82 | elif isinstance(node, little_cpp.CPPRefWrite): 83 | res = self.visit_ref_write(node) 84 | else: 85 | raise Exception(str(node)) 86 | assert isinstance(res, ExprWithStmt) 87 | return res 88 | 89 | def visit_ref_create(self, node): 90 | vv = self.visit(node.value) 91 | return ExprWithStmt(f"RefValue({vv.expr})", vv.stmt) 92 | 93 | def visit_ref_read(self, node): 94 | vr = self.visit(node.ref) 95 | return ExprWithStmt(f"Downcast({vr.expr})->value", vr.stmt) 96 | 97 | def visit_ref_write(self, node): 98 | vr = self.visit(node.ref) 99 | vv = self.visit(node.value) 100 | stmt = vr.stmt + vv.stmt + f"Downcast({vr.expr})->value={vv.expr};\n" 101 | return ExprWithStmt("runtime::ADT::Tuple()", stmt) 102 | 103 | def visit_tuple_getitem(self, node): 104 | vt = self.visit(node.tuple_value) 105 | return 
ExprWithStmt(f"Downcast({vt.expr})[{node.index}]", vt.stmt) 106 | 107 | def visit_constructor(self, node): 108 | args_str, stmt_str = self.visit_args(node.fields) 109 | return ExprWithStmt(f"TagToCV({node.tag}, {{{args_str}}})") 110 | 111 | def pattern_var(self, pat, var_set): 112 | if isinstance(pat, relay.PatternConstructor): 113 | for x in pat.patterns: 114 | self.pattern_var(x, var_set) 115 | elif isinstance(pat, relay.PatternVar): 116 | assert pat.var not in var_set 117 | var_set.add(pat.var) 118 | else: 119 | raise Exception(str(pat)) 120 | 121 | def visit_match(self, node): 122 | vd = self.visit(node.data) 123 | stmt_str = vd.stmt 124 | 125 | pattern_var_set = set() 126 | for c in node.clause: 127 | self.pattern_var(c[0], pattern_var_set) 128 | 129 | for v in pattern_var_set: 130 | bind_name = self.fresh_local_name() 131 | self.name_map[v] = bind_name 132 | stmt_str += f"ObjectRef {bind_name};\n" 133 | 134 | # match data_name to pat, and fill the var accordingly. 135 | # go to fail_label or ok_label base on failure/success. 
136 | def visit_pattern(pat, data_name, fail_label, ok_label): 137 | if isinstance(pat, relay.PatternConstructor): 138 | data_name = f"Downcast({data_name})" 139 | ok_case = "" 140 | bind_names = [] 141 | assert len(pat.constructor.inputs) == len(pat.patterns) 142 | for i, input_type in enumerate(pat.constructor.inputs): 143 | bind_name = self.fresh_local_name() 144 | bind_names.append(bind_name) 145 | ok_case += f"ObjectRef {bind_name} = {data_name}->fields[{i}];\n" 146 | for bind_name, p in zip(bind_names, pat.patterns): 147 | next_label = self.fresh_label_name() 148 | ok_case += visit_pattern(p, bind_name, fail_label, next_label) 149 | ok_case += f"{next_label}:\n" 150 | ok_case += f"goto {ok_label};" 151 | return f""" 152 | CHECK({data_name}->tag != -1); 153 | if ({data_name}->tag == {pat.constructor.tag}) {{ 154 | {ok_case} 155 | }} else {{ 156 | goto {fail_label}; 157 | }} 158 | """ 159 | elif isinstance(pat, relay.PatternVar): 160 | return f""" 161 | {self.name_map[pat.var]} = {data_name}; 162 | """ 163 | else: 164 | raise Exception(str(pat)) 165 | 166 | in_name = self.fresh_local_name() 167 | out_name = self.fresh_local_name() 168 | stmt_str += f"ObjectRef {in_name} = {vd.expr};\n" 169 | stmt_str += f"ObjectRef {out_name};\n" 170 | match_finish_label = self.fresh_label_name() 171 | for c in node.clause: 172 | vc = self.visit(c[1]) 173 | fail_label = self.fresh_label_name() 174 | ok_label = self.fresh_label_name() 175 | stmt_str += f"""{{ 176 | {visit_pattern(c[0], in_name, fail_label, ok_label)} 177 | }} 178 | """ 179 | stmt_str += f"""{{ 180 | {ok_label}: 181 | {vc.stmt} 182 | {out_name} = {vc.expr}; 183 | goto {match_finish_label}; 184 | }} 185 | """ 186 | stmt_str += f"{fail_label}:\n" 187 | stmt_str += """CHECK(false) << "does not match any";\n""" 188 | stmt_str += f"{match_finish_label}: ;" 189 | return ExprWithStmt(out_name, stmt_str) 190 | 191 | def visit_tuple(self, node): 192 | expr = [] 193 | stmt_str = "" 194 | for x in node.fields: 195 | vx = 
self.visit(x) 196 | expr.append(vx.expr) 197 | stmt_str += vx.stmt 198 | list_name = self.fresh_local_name() 199 | stmt_str += f"std::vector {list_name} = {{{inter(expr)}}};" 200 | return ExprWithStmt(f"runtime::ADT::Tuple({list_name})", stmt_str) 201 | 202 | def visit_if(self, node): 203 | vc = self.visit(node.cond) 204 | vt = self.visit(node.true_branch) 205 | vf = self.visit(node.false_branch) 206 | ret_name = self.fresh_local_name() 207 | stmt = f"ObjectRef {ret_name};" 208 | stmt += f""" 209 | {vc.stmt} 210 | if (NDToBool(ObjectRefToND({vc.expr}))) {{ 211 | {vt.stmt} 212 | {ret_name} = {vt.expr}; 213 | }} else {{ 214 | {vf.stmt} 215 | {ret_name} = {vf.expr}; 216 | }} 217 | """ 218 | return ExprWithStmt(ret_name, stmt) 219 | 220 | def visit_constant(self, const): 221 | if const not in self.declare_map: 222 | name = self.fresh_global_name() 223 | self.declare_map[const] = name 224 | self.declare += f"ObjectRef {name};\n" 225 | self.input_const.append((name, const.data.asnumpy())) 226 | return ExprWithStmt(self.declare_map[const]) 227 | 228 | def visit_global_var(self, gv): 229 | if gv not in self.declare_map: 230 | name = self.fresh_global_name() 231 | self.declare_map[gv] = f"{name}" 232 | vgv = self.visit(self.gv_map[gv], local=False, name=name) 233 | assert vgv.stmt == "" 234 | assert vgv.expr == f"{name}" 235 | return ExprWithStmt(self.declare_map[gv]) 236 | 237 | def visit_args(self, args): 238 | args_str = "" 239 | stmt_str = "" 240 | for i, arg in enumerate(args): 241 | va = self.visit(arg) 242 | args_str += va.expr 243 | stmt_str += va.stmt 244 | if i != len(args) - 1: 245 | args_str += ", " 246 | return args_str, stmt_str 247 | 248 | def visit_invoke(self, invoke): 249 | args_str, stmt_str = self.visit_args(invoke.args) 250 | func = self.visit(invoke.call) 251 | return ExprWithStmt(f"Apply({func.expr}, std::vector({{{args_str}}}))", stmt_str + func.stmt) 252 | 253 | def visit_decl(self, decl): 254 | source = "" 255 | for var, value in decl.bindings: 256 
| local_name = self.fresh_local_name(var) 257 | self.name_map[var] = local_name 258 | vv = self.visit(value, name=local_name) 259 | source += vv.stmt 260 | source += f"""ObjectRef {local_name} = {vv.expr};""" 261 | vb = self.visit(decl.body) 262 | source += vb.stmt 263 | return ExprWithStmt(vb.expr, source) 264 | 265 | def nd_dtype(self, tt): 266 | assert isinstance(tt, relay.ty.TensorType) 267 | if tt.dtype == 'int32': 268 | return 'dtype_i32' 269 | elif tt.dtype == 'int8': 270 | return 'dtype_i8' 271 | elif tt.dtype == 'float32': 272 | return 'dtype_f32' 273 | elif tt.dtype == 'bool': 274 | return 'dtype_u1' 275 | raise Exception("unknown tensor dtype: " + str(tt)) 276 | 277 | def nd_shape(self, tt): 278 | return f"{{{inter([str(s) for s in tt.shape])}}}" 279 | 280 | def visit_packed_call(self, call): 281 | decl_str = "" 282 | args = [] 283 | for arg in call.args: 284 | va = self.visit(arg) 285 | decl_str += va.stmt 286 | args.append(va.expr) 287 | args_str = [] 288 | def convert_input(ty, arg): 289 | if isinstance(ty, relay.ty.TensorType): 290 | args_str.append(f"{arg}") 291 | else: 292 | assert isinstance(ty, relay.ty.TupleType) 293 | tuple_name = self.fresh_local_name() 294 | nonlocal decl_str 295 | decl_str += f"runtime::ADT {tuple_name} = Downcast({arg});\n" 296 | for i, t in enumerate(ty.fields): 297 | convert_input(t, f"{tuple_name}[{i}]") 298 | assert len(call.args_type) == len(call.args) 299 | for i in range(len(call.args_type)): 300 | convert_input(call.args_type[i], args[i]) 301 | 302 | def convert_output(ty): 303 | nonlocal decl_str 304 | if isinstance(ty, relay.ty.TensorType): 305 | tensor_name = self.fresh_local_name() 306 | decl_str += f"NDArray {tensor_name} = NDArray::Empty({self.nd_shape(ty)}, {self.nd_dtype(ty)}, context);\n" 307 | args_str.append(f"{tensor_name}") 308 | return tensor_name 309 | else: 310 | assert isinstance(ty, relay.ty.TupleType) 311 | list_name = self.fresh_local_name() 312 | list_members = inter([convert_output(t) for t in 
ty.fields]) 313 | decl_str += f"std::vector {list_name} = {{{list_members}}};" 314 | return f"runtime::ADT::Tuple({list_name})" 315 | out = convert_output(call.ret_type) 316 | return ExprWithStmt(out, f""" 317 | {decl_str} 318 | const PackedFunc *pf = runtime::Registry::Get("{call.name}"); 319 | CHECK(pf); 320 | (*pf)({inter(args_str)}); 321 | """) 322 | 323 | def visit_cpp_function(self, func, local, name): 324 | vec = self.fresh_local_name() 325 | body = "" 326 | 327 | end = len(func.params) - 1 328 | for i, param in enumerate(func.params): 329 | pname = self.fresh_local_name(param) 330 | self.name_map[param] = pname 331 | body += f"ObjectRef {pname} = {vec}.at({i});\n" 332 | 333 | body += f"ObjectRef {name} = self;\n" 334 | vb = self.visit(func.body) 335 | body = body + vb.stmt + f"""return {vb.expr};""" 336 | expr = f"""FunctionValueNode::make([=](const std::vector& {vec}, const ObjectRef& self) {{ 337 | {body} 338 | }}); 339 | """ 340 | 341 | if local: 342 | return ExprWithStmt(expr) 343 | else: 344 | if name is None: 345 | name = self.fresh_global_name() 346 | self.declare += f""" 347 | static ObjectRef {name}_func() {{ 348 | static ObjectRef ret = {expr}; 349 | return ret; 350 | }} 351 | ObjectRef {name} = {name}_func(); 352 | """ 353 | return ExprWithStmt(f"{name}") 354 | 355 | def mk_register_api(self, name: str, func) -> str: 356 | vf = self.visit(func, local=False) 357 | assert vf.stmt == "" 358 | source = self.declare 359 | 360 | args = "" 361 | if isinstance(func, relay.GlobalVar): 362 | func = self.gv_map[func] 363 | end = len(func.params) - 1 364 | init = "" 365 | for i, (input_name, _) in enumerate(self.input_const): 366 | init += f"{input_name} = args[{i}];\n" 367 | for i in range(len(func.params)): 368 | args += f"args[{i+len(self.input_const)}]" 369 | if i != end: 370 | args += ", " 371 | 372 | source += f""" 373 | TVM_REGISTER_GLOBAL("{name}") 374 | .set_body([](TVMArgs args, TVMRetValue* ret) {{ 375 | {init} 376 | std::initializer_list ilist = 
def inter(strs, sep=", "):
    """Join string fragments with `sep`.

    Equivalent to ``sep.join(strs)``; the helper is kept because generated
    code and several visitors already call it. The previous hand-rolled
    loop concatenated with ``+=`` (quadratic); ``str.join`` is linear.
    """
    return sep.join(strs)
def to_source(mod, program, gv_map, ctx, name) -> tuple:
    """Generate C++ AOT source for `program`.

    Returns a pair ``(constants, source)``: the numpy values of the input
    constants that must be passed to the registered function before user
    arguments, and the complete generated C++ source string.

    Fixes: the return annotation previously claimed ``-> str`` although a
    2-tuple is returned, and the comprehension variable shadowed the
    ``name`` parameter.
    """
    # NOTE(review): `mod` is unused here — presumably kept for interface
    # parity with callers; confirm before removing.
    convert = ToSource(gv_map)
    source = mk_file(convert.mk_register_api(name, program), ctx)
    constants = [value for _name, value in convert.input_const]
    return constants, source
def test_mult_op():
    """exp(x + y), a chain of two ops, compiles and matches numpy."""
    mod = Module()
    x, y = var('x', shape=()), var('y', shape=())
    func = Function([x, y], op.exp(x + y))
    cfunc = compile(func, mod)
    lhs = tvm.nd.array(np.array(1.0, dtype='float32'))
    rhs = tvm.nd.array(np.array(1.0, dtype='float32'))
    expected = np.exp(lhs.asnumpy() + rhs.asnumpy())
    np.testing.assert_allclose(cfunc(lhs, rhs).asnumpy(), expected)
def test_abs():
    """abs built from relay.If handles both the positive and negative branch."""
    mod = Module()
    x = var('x', shape=())
    body = relay.If(op.less(x, relay.const(0.0)), relay.const(-1.0) * x, x)
    cfunc = compile(Function([x], body), mod)
    # Same two probe values as before: one per branch of the conditional.
    for value in (12.0, -34.0):
        arg = tvm.nd.array(np.array(value, dtype='float32'))
        expected = np.array(abs(value), dtype='float32')
        np.testing.assert_allclose(cfunc(arg).asnumpy(), expected)
def nat_to_int(n):
    """Convert a Peano-encoded nat value (s/z constructors) to a Python int.

    Iterative form: follow the successor chain, counting links, until the
    zero constructor (low tag byte 0) is reached.
    """
    count = 0
    while n.constructor.tag & 0xff == 1:
        count += 1
        n = n.fields[0]
    assert n.constructor.tag & 0xff == 0
    return count
def test_tuple():
    """Projecting index 1 of an (int32, float32) tuple yields the float."""
    mod = Module()
    elements = relay.Tuple([relay.const(3, dtype='int32'),
                            relay.const(4.0, dtype='float32')])
    body = relay.TupleGetItem(elements, 1)
    cfunc = compile(Function([], body), mod)
    expected = np.array(4.0, dtype='float32')
    np.testing.assert_allclose(cfunc().asnumpy(), expected)
def test_local_local_rec_outer_scope():
    """A locally recursive closure returned from another function's scope
    still works after compilation (sums 10..1 = 55)."""
    mod = Module()
    x = var('x', dtype='int32', shape=())
    int_ty = relay.TensorType(dtype='int32', shape=())
    sum_var = relay.Var('sum', type_annotation=relay.FuncType([int_ty], int_ty))
    zero = relay.const(0)

    # Define a locally recursive function inside another function's scope
    # and have the outer function return the inner closure itself.
    recursive_body = relay.If(op.less(x, zero), zero,
                              x + sum_var(x - relay.const(1)))
    inner_func = Function([x], recursive_body, int_ty)
    outer_func = Function([], relay.Let(sum_var, inner_func, sum_var))
    f = relay.Var('f')
    program = relay.Let(f, outer_func(), f(relay.const(10)))
    cfunc = compile(Function([], program), mod)
    np.testing.assert_allclose(cfunc().asnumpy(), np.array(55, dtype='int32'))