├── .gitignore ├── .travis.yml ├── Cargo.toml ├── LICENSE ├── README.md ├── examples └── resnet │ ├── Cargo.toml │ ├── README.md │ ├── build.rs │ └── src │ ├── build_resnet.py │ └── main.rs ├── rustfmt.toml ├── src ├── bytearray.rs ├── context.rs ├── errors.rs ├── function.rs ├── internal_api.rs ├── lib.rs ├── module.rs ├── ndarray.rs ├── ty.rs └── value.rs ├── tests ├── basics │ ├── .gitignore │ ├── Cargo.toml │ ├── build.rs │ └── src │ │ ├── main.rs │ │ ├── tvm_add_cpu.py │ │ └── tvm_add_gpu.py └── callback │ ├── Cargo.toml │ └── src │ └── bin │ ├── array.rs │ ├── error.rs │ ├── float.rs │ ├── int.rs │ └── string.rs └── tvm-sys ├── Cargo.toml ├── LICENSE ├── build.rs └── src └── lib.rs /.gitignore: -------------------------------------------------------------------------------- 1 | target 2 | **/*.rs.bk 3 | Cargo.lock 4 | /tvm-sys/src/bindgen.rs 5 | /tests/basics/add_* 6 | /examples/resnet/deploy_* 7 | /examples/resnet/*.png 8 | /examples/resnet/synset.* 9 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: rust 2 | rust: 3 | - nightly 4 | matrix: 5 | fast_finish: true 6 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "tvm-frontend" 3 | version = "0.1.0" 4 | license = "Apache-2.0" 5 | description = "Rust frontend support for TVM" 6 | repository = "https://github.com/dmlc/tvm" 7 | homepage = "https://github.com/dmlc/tvm" 8 | readme = "README.md" 9 | keywords = ["rust", "tvm", "nnvm"] 10 | categories = ["api-bindings", "science"] 11 | authors = ["Ehsan M.Kermani "] 12 | 13 | [lib] 14 | name = "tvm_frontend" 15 | crate-type = ["dylib"] 16 | 17 | [dependencies] 18 | tvm-sys = { version = "0.1.0", path = "tvm-sys" } 19 | ndarray = "0.12.1" 20 | lazy_static = "1.1.0" 21 | 
num-traits = "0.2" 22 | error-chain = "0.12.0" 23 | 24 | [features] 25 | blas = ["ndarray/blas"] 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DEPRECATED 2 | 3 | **The RFC is closed and this has been merged into [TVM](https://github.com/dmlc/tvm/tree/master/rust).** 4 | 5 | # TVM Runtime Frontend Support 6 | 7 | This crate provides an idiomatic Rust API for the [TVM](https://github.com/dmlc/tvm) runtime frontend as part of the ~~[ongoing RFC](https://github.com/dmlc/tvm/issues/1601)~~. Currently this requires **Nightly Rust**. 8 | 9 | Check out the [docs](https://ehsanmok.github.io/tvm_frontend/tvm_frontend/index.html). 10 | 11 | ## What Does This Crate Offer? 12 | 13 | Here is the main workflow: 14 | 15 | 1. Train your **Deep Learning** model using any major framework such as [PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.incubator.apache.org/) or [TensorFlow](https://www.tensorflow.org/) 16 | 2. Use **TVM** to build optimized model artifacts on a supported context such as CPU, GPU, OpenCL, Vulkan, VPI, ROCm, etc. 17 | 3.
Deploy your models using **Rust** :heart: 18 | 19 | ### Example: Deploy Image Classification from Pretrained Resnet18 on ImageNet1k 20 | 21 | Please check out [examples/resnet](https://github.com/ehsanmok/tvm-rust/tree/master/examples/resnet) for the complete end-to-end example. 22 | 23 | Here's a Python snippet for downloading and building a pretrained Resnet18 via MXNet and TVM: 24 | 25 | ```python 26 | block = get_model('resnet18_v1', pretrained=True) 27 | 28 | sym, params = nnvm.frontend.from_mxnet(block) 29 | # add the softmax layer for prediction 30 | net = nnvm.sym.softmax(sym) 31 | # compile the model 32 | with nnvm.compiler.build_config(opt_level=opt_level): 33 | graph, lib, params = nnvm.compiler.build( 34 | net, target, shape={"data": data_shape}, params=params) 35 | # save the model artifacts 36 | lib.save(os.path.join(target_dir, "deploy_lib.o")) 37 | cc.create_shared(os.path.join(target_dir, "deploy_lib.so"), 38 | [os.path.join(target_dir, "deploy_lib.o")]) 39 | 40 | with open(os.path.join(target_dir, "deploy_graph.json"), "w") as fo: 41 | fo.write(graph.json()) 42 | with open(os.path.join(target_dir,"deploy_param.params"), "wb") as fo: 43 | fo.write(nnvm.compiler.save_param_dict(params)) 44 | ``` 45 | 46 | Now we use these artifacts to create and run the *Graph Runtime* and classify our input cat image 47 | 48 | ![cat](https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true) 49 | 50 | as demonstrated in the following Rust snippet: 51 | 52 | ```rust 53 | 54 | let graph = fs::read_to_string("deploy_graph.json")?; 55 | // load module 56 | let lib = Module::load(&Path::new("deploy_lib.so"))?; 57 | // get the global TVM graph runtime function 58 | let runtime_create_fn = Function::get_function("tvm.graph_runtime.create", true).unwrap(); 59 | 60 | let runtime_create_fn_ret = call_packed!( 61 | runtime_create_fn, 62 | &graph, 63 | &lib, 64 | &ctx.device_type, 65 | &ctx.device_id 66 | )?; 67 | // get graph runtime module 68 | let
graph_runtime_module = runtime_create_fn_ret.to_module(); 69 | // get the registered `load_params` from runtime module 70 | let load_param_fn = graph_runtime_module 71 | .get_function("load_params", false) 72 | .unwrap(); 73 | // parse parameters and convert to TVMByteArray 74 | let params: Vec<u8> = fs::read("deploy_param.params")?; 75 | let barr = TVMByteArray::from(&params); 76 | // load the parameters 77 | call_packed!(load_param_fn, &barr)?; 78 | // get the set_input function 79 | let set_input_fn = graph_runtime_module 80 | .get_function("set_input", false) 81 | .unwrap(); 82 | 83 | call_packed!(set_input_fn, "data", &input)?; 84 | // get `run` function from runtime module 85 | let run_fn = graph_runtime_module.get_function("run", false).unwrap(); 86 | // execute the run function. Note that it has no argument. 87 | call_packed!(run_fn,)?; 88 | // prepare to get the output 89 | let output_shape = &mut [1, 1000]; 90 | let output = empty(output_shape, TVMContext::cpu(0), TVMType::from("float")); 91 | // get the `get_output` function from runtime module 92 | let get_output_fn = graph_runtime_module 93 | .get_function("get_output", false) 94 | .unwrap(); 95 | // execute the get output function 96 | call_packed!(get_output_fn, &0, &output)?; 97 | // flatten the output as Vec<f32> 98 | let output = output.to_vec::<f32>()?; 99 | ``` 100 | 101 | ## Installation 102 | 103 | Please follow the TVM [installation guide](https://docs.tvm.ai/install/index.html), `export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`. 104 | 105 | *Note:* To run the end-to-end examples and tests, `tvm`, `nnvm` and `topi` need to be added to your `PYTHONPATH`, or this happens automatically in an Anaconda environment when they are installed individually. 106 | 107 | ## Supported TVM Functionalities 108 | 109 | ### Use TVM to Generate Shared Library 110 | 111 | One can use the following Python snippet to generate `add_gpu.so`, which adds two vectors on the GPU.
112 | 113 | ```python 114 | import os 115 | import tvm 116 | from tvm.contrib import cc 117 | 118 | def test_add(target_dir): 119 | if not tvm.module.enabled("cuda"): 120 | print(f"skip {__file__} because cuda is not enabled...") 121 | return 122 | n = tvm.var("n") 123 | A = tvm.placeholder((n,), name='A') 124 | B = tvm.placeholder((n,), name='B') 125 | C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C") 126 | s = tvm.create_schedule(C.op) 127 | bx, tx = s[C].split(C.op.axis[0], factor=64) 128 | s[C].bind(bx, tvm.thread_axis("blockIdx.x")) 129 | s[C].bind(tx, tvm.thread_axis("threadIdx.x")) 130 | fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd") 131 | 132 | fadd_cuda.save(os.path.join(target_dir, "add_gpu.o")) 133 | fadd_cuda.imported_modules[0].save(os.path.join(target_dir, "add_gpu.ptx")) 134 | cc.create_shared(os.path.join(target_dir, "add_gpu.so"), 135 | [os.path.join(target_dir, "add_gpu.o")]) 136 | 137 | 138 | if __name__ == "__main__": 139 | import sys 140 | if len(sys.argv) != 2: 141 | sys.exit(-1) 142 | test_add(sys.argv[1]) 143 | ``` 144 | 145 | ### Run the Generated Shared Library 146 | 147 | The following code snippet demonstrates how to load and test the generated shared library (`add_gpu.so`) in Rust. 
148 | 149 | ```rust 150 | extern crate tvm_frontend as tvm; 151 | 152 | use tvm::*; 153 | 154 | fn main() { 155 | let shape = &mut [2]; 156 | let mut data = vec![3f32, 4.0]; 157 | let mut arr = empty(shape, TVMContext::gpu(0), TVMType::from("float")); 158 | arr.copy_from_buffer(data.as_mut_slice()); 159 | let mut ret = empty(shape, TVMContext::gpu(0), TVMType::from("float")); 160 | let path = Path::new("add_gpu.so"); 161 | let ptx = Path::new("add_gpu.ptx"); 162 | let mut fadd = Module::load(path).unwrap(); 163 | let fadd_dep = Module::load(ptx).unwrap(); 164 | assert!(fadd.enabled("gpu")); 165 | fadd.import_module(fadd_dep); 166 | fadd.entry_func(); 167 | function::Builder::from(&mut fadd) 168 | .arg(&arr) 169 | .arg(&arr) 170 | .set_output(&mut ret) 171 | .invoke() 172 | .unwrap(); 173 | 174 | assert_eq!(ret.to_vec::().unwrap(), vec![6f32, 8.0]); 175 | } 176 | ``` 177 | 178 | **Note:** it is required to instruct the `rustc` to link to the generated `add_gpu.so` in runtime, for example by 179 | `cargo:rustc-link-search=native=add_gpu`. 180 | 181 | See the tests and examples custom `build.rs` for more details. 182 | 183 | ### Convert and Register a Rust Function as a TVM Packed Function 184 | 185 | One can use `register_global_func!` macro to convert and register a Rust 186 | function of type `fn(&[TVMArgValue]) -> Result` to a global TVM **packed function** as follows 187 | 188 | ```rust 189 | #[macro_use] 190 | extern crate tvm_frontend as tvm; 191 | 192 | use tvm::*; 193 | 194 | fn main() { 195 | register_global_func! 
{ 196 | fn sum(args: &[TVMArgValue]) -> Result { 197 | let mut ret = 0f32; 198 | let shape = &mut [2]; 199 | for arg in args.iter() { 200 | let e = empty(shape, TVMContext::cpu(0), TVMType::from("float")); 201 | let arr = arg.to_ndarray().copy_to_ndarray(e).unwrap(); 202 | let rnd: ArrayD = ArrayD::try_from(&arr).unwrap(); 203 | ret += rnd.scalar_sum(); 204 | } 205 | let ret_val = TVMRetValue::from(&ret); 206 | Ok(ret_val) 207 | } 208 | } 209 | 210 | let shape = &mut [2]; 211 | let mut data = vec![3f32, 4.0]; 212 | let mut arr = empty(shape, TVMContext::cpu(0), TVMType::from("float")); 213 | arr.copy_from_buffer(data.as_mut_slice()); 214 | let mut registered = function::Builder::default(); 215 | registered 216 | .get_function("sum", true) 217 | .arg(&arr) 218 | .arg(&arr); 219 | 220 | assert_eq!(registered.invoke().unwrap().to_float(), 14f64); 221 | } 222 | ``` 223 | -------------------------------------------------------------------------------- /examples/resnet/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "resnet" 3 | version = "0.0.0" 4 | authors = ["Ehsan M.Kermani "] 5 | license = "Apache-2.0" 6 | build = "build.rs" 7 | 8 | [dependencies] 9 | ndarray = "0.12.1" 10 | tvm-frontend = { path = "../../" } 11 | image = "0.20.1" 12 | csv = "1" 13 | -------------------------------------------------------------------------------- /examples/resnet/README.md: -------------------------------------------------------------------------------- 1 | ## Resnet example 2 | 3 | This end-to-end example shows how to: 4 | * build `Resnet 18` with `tvm` and `nnvm` from Python 5 | * use the provided Rust frontend API to test for an input image 6 | 7 | To run the example, first `tvm`, `nnvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet` 8 | and to install `tvm` and `nnvm` with `llvm` follow the [TVM installation guide](https://docs.tvm.ai/install/index.html). 
9 | 10 | * **Build the example**: `cargo build` 11 | 12 | To have a successful build, note that it is required to instruct Rust compiler to link to the compiled shared library, for example with 13 | `println!("cargo:rustc-link-search=native={}", build_path)`. See the `build.rs` for more details. 14 | 15 | * **Run the example**: `cargo run` 16 | -------------------------------------------------------------------------------- /examples/resnet/build.rs: -------------------------------------------------------------------------------- 1 | use std::process::Command; 2 | 3 | fn main() { 4 | let script_path = concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py"); 5 | let output = Command::new("python") 6 | .arg(script_path) 7 | .output() 8 | .expect("Failed to execute command"); 9 | 10 | println!("stdout: {}", String::from_utf8_lossy(&output.stdout)); 11 | println!("stderr: {}", String::from_utf8_lossy(&output.stderr)); 12 | println!( 13 | "cargo:rustc-link-search=native={}", 14 | env!("CARGO_MANIFEST_DIR") 15 | ); 16 | } 17 | -------------------------------------------------------------------------------- /examples/resnet/src/build_resnet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import logging 4 | from os import path as osp 5 | import sys 6 | 7 | import numpy as np 8 | 9 | import mxnet as mx 10 | from mxnet.gluon.model_zoo.vision import get_model 11 | from mxnet.gluon.utils import download 12 | 13 | import tvm 14 | from tvm.contrib import graph_runtime, cc 15 | import nnvm 16 | 17 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 18 | logger = logging.getLogger(__name__) 19 | 20 | parser = argparse.ArgumentParser(description='Resnet build example') 21 | aa = parser.add_argument 22 | aa('--batch-size', type=int, default=1, help='input image batch size') 23 | aa('--opt-level', type=int, default=3, 24 | help='level of optimization. 
0 is unoptimized and 3 is the highest level') 25 | aa('--target', type=str, default='llvm', help='target context for compilation') 26 | aa('--image-shape', type=str, default='3,224,224', help='input image dimensions') 27 | aa('--image-name', type=str, default='cat.png', help='name of input image to download') 28 | args = parser.parse_args() 29 | 30 | target_dir = osp.dirname(osp.dirname(osp.realpath(__file__))) 31 | batch_size = args.batch_size 32 | opt_level = args.opt_level 33 | target = tvm.target.create(args.target) 34 | image_shape = tuple(map(int, args.image_shape.split(","))) 35 | data_shape = (batch_size,) + image_shape 36 | 37 | def build(target_dir): 38 | """ Compiles resnet18 with TVM""" 39 | # download the pretrained resnet18 trained on imagenet1k dataset for 40 | # image classification task 41 | block = get_model('resnet18_v1', pretrained=True) 42 | 43 | sym, params = nnvm.frontend.from_mxnet(block) 44 | # add the softmax layer for prediction 45 | net = nnvm.sym.softmax(sym) 46 | # compile the model 47 | with nnvm.compiler.build_config(opt_level=opt_level): 48 | graph, lib, params = nnvm.compiler.build( 49 | net, target, shape={"data": data_shape}, params=params) 50 | # save the model artifacts 51 | lib.save(osp.join(target_dir, "deploy_lib.o")) 52 | cc.create_shared(osp.join(target_dir, "deploy_lib.so"), 53 | [osp.join(target_dir, "deploy_lib.o")]) 54 | 55 | with open(osp.join(target_dir, "deploy_graph.json"), "w") as fo: 56 | fo.write(graph.json()) 57 | 58 | with open(osp.join(target_dir,"deploy_param.params"), "wb") as fo: 59 | fo.write(nnvm.compiler.save_param_dict(params)) 60 | 61 | def download_img_labels(): 62 | """ Download an image and imagenet1k class labels for test""" 63 | img_name = 'cat.png' 64 | synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/', 65 | '4d0b62f3d01426887599d4f7ede23ee5/raw/', 66 | '596b27d23537e5a1b5751d2b0481ef172f58b539/', 67 | 'imagenet1000_clsid_to_human.txt']) 68 | synset_name = 'synset.txt' 69 | 
download('https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', img_name) 70 | download(synset_url, synset_name) 71 | 72 | with open(synset_name) as fin: 73 | synset = eval(fin.read()) 74 | 75 | with open("synset.csv", "w") as fout: 76 | w = csv.writer(fout) 77 | w.writerows(synset.items()) 78 | 79 | def test_build(target_dir): 80 | """ Sanity check with random input""" 81 | graph = open(osp.join(target_dir, "deploy_graph.json")).read() 82 | lib = tvm.module.load(osp.join(target_dir, "deploy_lib.so")) 83 | params = bytearray(open(osp.join(target_dir,"deploy_param.params"), "rb").read()) 84 | input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32")) 85 | ctx = tvm.cpu() 86 | module = graph_runtime.create(graph, lib, ctx) 87 | module.load_params(params) 88 | module.run(data=input_data) 89 | out = module.get_output(0).asnumpy() 90 | 91 | 92 | if __name__ == '__main__': 93 | logger.info("building the model") 94 | build(target_dir) 95 | logger.info("build was successful") 96 | logger.info("test the build artifacts") 97 | test_build(target_dir) 98 | logger.info("test was successful") 99 | download_img_labels() 100 | logger.info("image and synset downloads are successful") 101 | -------------------------------------------------------------------------------- /examples/resnet/src/main.rs: -------------------------------------------------------------------------------- 1 | extern crate csv; 2 | extern crate image; 3 | extern crate ndarray; 4 | extern crate tvm_frontend as tvm; 5 | 6 | use std::{ 7 | collections::HashMap, 8 | error::Error, 9 | fs::{self, File}, 10 | path::Path, 11 | result::Result, 12 | }; 13 | 14 | use image::{FilterType, GenericImageView}; 15 | use ndarray::{Array, ArrayD, Axis}; 16 | 17 | use tvm::*; 18 | 19 | fn main() -> Result<(), Box> { 20 | let ctx = TVMContext::cpu(0); 21 | let img = image::open("cat.png")?; 22 | println!("original image dimensions: {:?}", img.dimensions()); 23 | // for bigger size images, one 
needs to first resize to 256x256 24 | // with `img.resize_exact` method and then `image.crop` to 224x224 25 | let img = img.resize(224, 224, FilterType::Nearest).to_rgb(); 26 | println!("resized image dimensions: {:?}", img.dimensions()); 27 | let mut pixels: Vec = vec![]; 28 | for pixel in img.pixels() { 29 | let tmp = pixel.data; 30 | // normalize the RGB channels using mean, std of imagenet1k 31 | let tmp = [ 32 | (tmp[0] as f32 - 123.0) / 58.395, // R 33 | (tmp[1] as f32 - 117.0) / 57.12, // G 34 | (tmp[2] as f32 - 104.0) / 57.375, // B 35 | ]; 36 | for e in &tmp { 37 | pixels.push(*e); 38 | } 39 | } 40 | 41 | let arr = Array::from_shape_vec((224, 224, 3), pixels)?; 42 | let arr: ArrayD = arr.permuted_axes([2, 0, 1]).into_dyn(); 43 | // make arr shape as [1, 3, 224, 224] acceptable to resnet 44 | let arr = arr.insert_axis(Axis(0)); 45 | // create input tensor from rust's ndarray 46 | let input = NDArray::from_rust_ndarray(&arr, TVMContext::cpu(0), TVMType::from("float"))?; 47 | println!( 48 | "input size is {:?}", 49 | input.shape().expect("cannot get the input shape") 50 | ); 51 | let graph = fs::read_to_string("deploy_graph.json")?; 52 | // load the built module 53 | let lib = Module::load(&Path::new("deploy_lib.so"))?; 54 | // get the global TVM graph runtime function 55 | let runtime_create_fn = Function::get_function("tvm.graph_runtime.create", true).unwrap(); 56 | let runtime_create_fn_ret = call_packed!( 57 | runtime_create_fn, 58 | &graph, 59 | &lib, 60 | &ctx.device_type, 61 | &ctx.device_id 62 | )?; 63 | // get graph runtime module 64 | let graph_runtime_module = runtime_create_fn_ret.to_module(); 65 | // get the registered `load_params` from runtime module 66 | let load_param_fn = graph_runtime_module 67 | .get_function("load_params", false) 68 | .unwrap(); 69 | // parse parameters and convert to TVMByteArray 70 | let params: Vec = fs::read("deploy_param.params")?; 71 | let barr = TVMByteArray::from(¶ms); 72 | // load the parameters 73 | 
call_packed!(load_param_fn, &barr)?; 74 | // get the set_input function 75 | let set_input_fn = graph_runtime_module 76 | .get_function("set_input", false) 77 | .unwrap(); 78 | 79 | call_packed!(set_input_fn, "data", &input)?; 80 | // get `run` function from runtime module 81 | let run_fn = graph_runtime_module.get_function("run", false).unwrap(); 82 | // execute the run function. Note that it has no argument 83 | call_packed!(run_fn,)?; 84 | // prepare to get the output 85 | let output_shape = &mut [1, 1000]; 86 | let output = empty(output_shape, TVMContext::cpu(0), TVMType::from("float")); 87 | // get the `get_output` function from runtime module 88 | let get_output_fn = graph_runtime_module 89 | .get_function("get_output", false) 90 | .unwrap(); 91 | // execute the get output function 92 | call_packed!(get_output_fn, &0, &output)?; 93 | // flatten the output as Vec 94 | let output = output.to_vec::()?; 95 | // find the maximum entry in the output and its index 96 | let mut argmax = -1; 97 | let mut max_prob = 0.; 98 | for i in 0..output.len() { 99 | if output[i] > max_prob { 100 | max_prob = output[i]; 101 | argmax = i as i32; 102 | } 103 | } 104 | // create a hash map of (class id, class name) 105 | let mut synset: HashMap = HashMap::new(); 106 | let file = File::open("synset.csv")?; 107 | let mut rdr = csv::ReaderBuilder::new() 108 | .has_headers(true) 109 | .from_reader(file); 110 | 111 | for result in rdr.records() { 112 | let record = result?; 113 | let id: i32 = record[0].parse()?; 114 | let cls = record[1].to_string(); 115 | synset.insert(id, cls); 116 | } 117 | 118 | println!( 119 | "input image belongs to the class `{}` with probability {}", 120 | synset 121 | .get(&argmax) 122 | .expect("cannot find the class id for argmax"), 123 | max_prob 124 | ); 125 | 126 | Ok(()) 127 | } 128 | -------------------------------------------------------------------------------- /rustfmt.toml: 
-------------------------------------------------------------------------------- 1 | max_width = 100 2 | hard_tabs = false 3 | tab_spaces = 4 4 | newline_style = "Auto" 5 | use_small_heuristics = "Default" 6 | indent_style = "Block" 7 | wrap_comments = false 8 | comment_width = 80 9 | normalize_comments = false 10 | format_strings = false 11 | format_macro_matchers = false 12 | format_macro_bodies = true 13 | empty_item_single_line = true 14 | struct_lit_single_line = true 15 | fn_single_line = false 16 | where_single_line = false 17 | imports_indent = "Block" 18 | imports_layout = "Mixed" 19 | merge_imports = true 20 | reorder_imports = true 21 | reorder_modules = true 22 | reorder_impl_items = false 23 | type_punctuation_density = "Wide" 24 | space_before_colon = false 25 | space_after_colon = true 26 | spaces_around_ranges = false 27 | binop_separator = "Front" 28 | remove_nested_parens = true 29 | combine_control_expr = true 30 | struct_field_align_threshold = 0 31 | match_arm_blocks = true 32 | force_multiline_blocks = false 33 | fn_args_density = "Tall" 34 | brace_style = "SameLineWhere" 35 | control_brace_style = "AlwaysSameLine" 36 | trailing_semicolon = true 37 | trailing_comma = "Vertical" 38 | match_block_trailing_comma = false 39 | blank_lines_upper_bound = 1 40 | blank_lines_lower_bound = 0 41 | edition = "2015" 42 | merge_derives = true 43 | use_try_shorthand = true 44 | use_field_init_shorthand = false 45 | force_explicit_abi = true 46 | condense_wildcard_suffixes = false 47 | color = "Auto" 48 | required_version = "0.99.5" 49 | unstable_features = false 50 | disable_all_formatting = false 51 | skip_children = false 52 | hide_parse_errors = false 53 | error_on_line_overflow = false 54 | error_on_unformatted = false 55 | report_todo = "Never" 56 | report_fixme = "Never" 57 | ignore = [] 58 | emit_mode = "Files" 59 | make_backup = false 60 | -------------------------------------------------------------------------------- /src/bytearray.rs: 
-------------------------------------------------------------------------------- 1 | //! Provides [`TVMByteArray`] used for passing the model parameters 2 | //! (stored as byte-array) to a runtime module. 3 | //! 4 | //! For more detail, please see the example `resnet` in `examples` repository. 5 | 6 | use std::os::raw::c_char; 7 | 8 | use ts; 9 | 10 | /// A struct holding TVM byte-array. 11 | /// 12 | /// ## Example 13 | /// 14 | /// ``` 15 | /// let v = b"hello".to_vec(); 16 | /// let barr = TVMByteArray::from(&v); 17 | /// assert_eq!(barr.len(), v.len()); 18 | /// assert_eq!(barr.data(), vec![104i8, 101, 108, 108, 111]); 19 | /// ``` 20 | #[derive(Debug, Clone)] 21 | pub struct TVMByteArray { 22 | pub(crate) inner: ts::TVMByteArray, 23 | } 24 | 25 | impl TVMByteArray { 26 | pub(crate) fn new(barr: ts::TVMByteArray) -> TVMByteArray { 27 | TVMByteArray { inner: barr } 28 | } 29 | 30 | /// Gets the length of the underlying byte-array 31 | pub fn len(&self) -> usize { 32 | self.inner.size 33 | } 34 | 35 | /// Gets the underlying byte-array as `Vec` 36 | pub fn data(&self) -> Vec { 37 | unsafe { 38 | let sz = self.len(); 39 | let mut ret_buf = Vec::with_capacity(sz); 40 | ret_buf.set_len(sz); 41 | self.inner.data.copy_to(ret_buf.as_mut_ptr(), sz); 42 | ret_buf 43 | } 44 | } 45 | } 46 | 47 | impl<'a> From<&'a Vec> for TVMByteArray { 48 | fn from(arg: &Vec) -> Self { 49 | let barr = ts::TVMByteArray { 50 | data: arg.as_ptr() as *const c_char, 51 | size: arg.len(), 52 | }; 53 | TVMByteArray::new(barr) 54 | } 55 | } 56 | 57 | #[cfg(test)] 58 | mod tests { 59 | use super::*; 60 | 61 | #[test] 62 | fn convert() { 63 | let v = vec![1u8, 2, 3]; 64 | let barr = TVMByteArray::from(&v); 65 | assert_eq!(barr.len(), v.len()); 66 | assert_eq!(barr.data(), vec![1i8, 2, 3]); 67 | let v = b"hello".to_vec(); 68 | let barr = TVMByteArray::from(&v); 69 | assert_eq!(barr.len(), v.len()); 70 | assert_eq!(barr.data(), vec![104i8, 101, 108, 108, 111]); 71 | } 72 | } 73 | 
-------------------------------------------------------------------------------- /src/context.rs: -------------------------------------------------------------------------------- 1 | //! Provides [`TVMContext`] and related device specific queries. 2 | //! 3 | //! Create a new context by device type (cpu is 1) and device id. 4 | //! 5 | //! # Example 6 | //! 7 | //! ``` 8 | //! let ctx = TVMContext::new(1, 0); 9 | //! let cpu0 = TVMContext::cpu(0); 10 | //! assert_eq!(ctx, cpu0); 11 | //! ``` 12 | //! 13 | //! Or from a supported device name. 14 | //! 15 | //! ``` 16 | //! let cpu0 = TVMContext::from("cpu"); 17 | //! println!("{}", cpu0); 18 | //! ``` 19 | 20 | use std::{ 21 | fmt::{self, Display, Formatter}, 22 | os::raw::c_void, 23 | ptr, 24 | }; 25 | 26 | use function; 27 | use internal_api; 28 | use ts; 29 | use Result; 30 | 31 | /// Device type can be from a supported device name. See the supported devices 32 | /// in [TVM](https://github.com/dmlc/tvm). 33 | /// 34 | /// ## Example 35 | /// 36 | /// ``` 37 | /// let cpu = TVMDeviceType::from("cpu"); 38 | /// println!("device is: {}", cpu); 39 | ///``` 40 | 41 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 42 | pub struct TVMDeviceType(pub usize); 43 | 44 | impl Default for TVMDeviceType { 45 | /// default device is cpu. 
46 | fn default() -> Self { 47 | TVMDeviceType(1) 48 | } 49 | } 50 | 51 | impl From for ts::DLDeviceType { 52 | fn from(device_type: TVMDeviceType) -> Self { 53 | match device_type.0 { 54 | 1 => ts::DLDeviceType_kDLCPU, 55 | 2 => ts::DLDeviceType_kDLGPU, 56 | 3 => ts::DLDeviceType_kDLCPUPinned, 57 | 4 => ts::DLDeviceType_kDLOpenCL, 58 | 7 => ts::DLDeviceType_kDLVulkan, 59 | 8 => ts::DLDeviceType_kDLMetal, 60 | 9 => ts::DLDeviceType_kDLVPI, 61 | 10 => ts::DLDeviceType_kDLROCM, 62 | 12 => ts::DLDeviceType_kDLExtDev, 63 | _ => panic!("device type not found!"), 64 | } 65 | } 66 | } 67 | 68 | impl From for TVMDeviceType { 69 | fn from(device_type: ts::DLDeviceType) -> Self { 70 | match device_type { 71 | ts::DLDeviceType_kDLCPU => TVMDeviceType(1), 72 | ts::DLDeviceType_kDLGPU => TVMDeviceType(2), 73 | ts::DLDeviceType_kDLCPUPinned => TVMDeviceType(3), 74 | ts::DLDeviceType_kDLOpenCL => TVMDeviceType(4), 75 | ts::DLDeviceType_kDLVulkan => TVMDeviceType(7), 76 | ts::DLDeviceType_kDLMetal => TVMDeviceType(8), 77 | ts::DLDeviceType_kDLVPI => TVMDeviceType(9), 78 | ts::DLDeviceType_kDLROCM => TVMDeviceType(10), 79 | ts::DLDeviceType_kDLExtDev => TVMDeviceType(12), 80 | _ => panic!("device type not found!"), 81 | } 82 | } 83 | } 84 | 85 | impl Display for TVMDeviceType { 86 | fn fmt(&self, f: &mut Formatter) -> fmt::Result { 87 | write!( 88 | f, 89 | "{}", 90 | match self { 91 | TVMDeviceType(1) => "cpu", 92 | TVMDeviceType(2) => "gpu", 93 | TVMDeviceType(3) => "cpu_pinned", 94 | TVMDeviceType(4) => "opencl", 95 | TVMDeviceType(8) => "meta", 96 | TVMDeviceType(9) => "vpi", 97 | TVMDeviceType(10) => "rocm", 98 | TVMDeviceType(_) => "rpc", 99 | } 100 | ) 101 | } 102 | } 103 | 104 | impl<'a> From<&'a str> for TVMDeviceType { 105 | fn from(type_str: &'a str) -> Self { 106 | match type_str { 107 | "cpu" => TVMDeviceType(1), 108 | "llvm" => TVMDeviceType(1), 109 | "stackvm" => TVMDeviceType(1), 110 | "gpu" => TVMDeviceType(2), 111 | "cuda" => TVMDeviceType(2), 112 | "nvptx" => 
TVMDeviceType(2), 113 | "cl" => TVMDeviceType(4), 114 | "opencl" => TVMDeviceType(4), 115 | "metal" => TVMDeviceType(8), 116 | "vpi" => TVMDeviceType(9), 117 | "rocm" => TVMDeviceType(10), 118 | _ => panic!("{:?} not supported!", type_str), 119 | } 120 | } 121 | } 122 | 123 | /// Represents the underlying device context. Default is cpu. 124 | /// 125 | /// ## Examples 126 | /// 127 | /// ``` 128 | /// let ctx = TVMContext::from("gpu"); 129 | /// assert!(ctx.exist()); 130 | /// 131 | /// ``` 132 | /// 133 | /// It is possible to query the underlying context as follows 134 | /// 135 | /// ``` 136 | /// println!("maximun threads per block: {}", ctx.max_threads_per_block()); 137 | /// println!("compute version: {}", ctx.compute_version()); 138 | /// ``` 139 | 140 | #[derive(Debug, Default, Clone, Hash, PartialEq, Eq)] 141 | pub struct TVMContext { 142 | /// Supported device types 143 | pub device_type: TVMDeviceType, 144 | /// Device id 145 | pub device_id: usize, 146 | } 147 | 148 | impl TVMContext { 149 | /// Creates context from device type and id. 150 | pub fn new(device_type: TVMDeviceType, device_id: usize) -> Self { 151 | TVMContext { 152 | device_type: device_type, 153 | device_id: device_id, 154 | } 155 | } 156 | } 157 | 158 | macro_rules! impl_ctxs { 159 | ($(($ctx:ident, $dldevt:expr));+) => { 160 | $( 161 | impl TVMContext { 162 | pub fn $ctx(device_id: usize) -> Self { 163 | Self::new(TVMDeviceType($dldevt), device_id) 164 | } 165 | } 166 | )+ 167 | }; 168 | } 169 | 170 | impl_ctxs!((cpu, 1); 171 | (gpu, 2); 172 | (nvptx, 2); 173 | (cuda, 2); 174 | (cpu_pinned, 3); 175 | (cl, 4); 176 | (opencl, 4); 177 | (metal, 8); 178 | (vpi, 9); 179 | (rocm, 10)); 180 | 181 | impl<'a> From<&'a str> for TVMContext { 182 | fn from(target: &str) -> Self { 183 | TVMContext::new(TVMDeviceType::from(target), 0) 184 | } 185 | } 186 | 187 | impl TVMContext { 188 | /// Checks whether the context exists or not. 
189 | pub fn exist(&self) -> bool { 190 | let func = internal_api::get_api("_GetDeviceAttr".to_owned()); 191 | let dt = self.device_type.0 as usize; 192 | // `unwrap` is ok here because if there is any error, 193 | // if would occure inside `call_packed!` 194 | let ret = call_packed!(func, &dt, &self.device_id, &0).unwrap(); 195 | ret.to_int() != 0 196 | } 197 | 198 | /// Synchronize the context stream. 199 | pub fn sync(&self) -> Result<()> { 200 | check_call!(ts::TVMSynchronize( 201 | self.device_type.0 as i32, 202 | self.device_id as i32, 203 | ptr::null_mut() as *mut c_void 204 | )); 205 | Ok(()) 206 | } 207 | } 208 | 209 | macro_rules! impl_dev_attrs { 210 | ($attr_name:ident, $attr_kind:expr) => { 211 | impl TVMContext { 212 | pub fn $attr_name(&self) -> usize { 213 | let func = ::internal_api::get_api("_GetDeviceAttr".to_owned()); 214 | let dt = self.device_type.0 as usize; 215 | // `unwrap` is ok here because if there is any error, 216 | // if would occur in function call. 217 | let ret = function::Builder::from(func) 218 | .args(&[dt, self.device_id, $attr_kind]) 219 | .invoke() 220 | .unwrap(); 221 | ret.to_int() as usize 222 | } 223 | } 224 | }; 225 | } 226 | 227 | impl_dev_attrs!(max_threads_per_block, 1); 228 | impl_dev_attrs!(warp_size, 2); 229 | impl_dev_attrs!(max_shared_memory_per_block, 3); 230 | impl_dev_attrs!(compute_version, 4); 231 | impl_dev_attrs!(device_name, 5); 232 | impl_dev_attrs!(max_clock_rate, 6); 233 | impl_dev_attrs!(multi_processor_count, 7); 234 | impl_dev_attrs!(max_thread_dimensions, 8); 235 | 236 | impl From for TVMContext { 237 | fn from(ctx: ts::DLContext) -> Self { 238 | TVMContext { 239 | device_type: TVMDeviceType::from(ctx.device_type), 240 | device_id: ctx.device_id as usize, 241 | } 242 | } 243 | } 244 | 245 | impl From for ts::DLContext { 246 | fn from(ctx: TVMContext) -> Self { 247 | ts::DLContext { 248 | device_type: ctx.device_type.into(), 249 | device_id: ctx.device_id as i32, 250 | } 251 | } 252 | } 253 | 254 | 
impl Display for TVMContext { 255 | fn fmt(&self, f: &mut Formatter) -> fmt::Result { 256 | write!(f, "{}({})", self.device_type, self.device_id) 257 | } 258 | } 259 | 260 | #[cfg(test)] 261 | mod tests { 262 | use super::*; 263 | 264 | #[test] 265 | fn context() { 266 | let ctx = TVMContext::cpu(0); 267 | println!("ctx: {}", ctx); 268 | let default_ctx = TVMContext::new(TVMDeviceType(1), 0); 269 | assert_eq!(ctx.clone(), default_ctx); 270 | assert_ne!(ctx, TVMContext::gpu(0)); 271 | 272 | let str_ctx = TVMContext::new(TVMDeviceType::from("gpu"), 0); 273 | assert_eq!(str_ctx.clone(), str_ctx); 274 | assert_ne!(str_ctx, TVMContext::new(TVMDeviceType::from("cpu"), 0)); 275 | } 276 | 277 | #[test] 278 | fn sync() { 279 | let ctx = TVMContext::cpu(0); 280 | assert!(ctx.sync().is_ok()) 281 | } 282 | } 283 | -------------------------------------------------------------------------------- /src/errors.rs: -------------------------------------------------------------------------------- 1 | //! This module implements TVM custom [`Error`], [`ErrorKind`] and [`Result`] types. 
2 | 3 | use std::{ffi, option}; 4 | 5 | use rust_ndarray; 6 | 7 | error_chain!{ 8 | errors { 9 | EmptyArray { 10 | description("cannot convert from an empty array") 11 | } 12 | 13 | NullHandle(name: String) { 14 | description("null handle") 15 | display("requested `{}` handle is null", name) 16 | } 17 | 18 | FunctionNotFound { 19 | description("function not found") 20 | display("function was not set in `function::Builder`") 21 | } 22 | 23 | TypeMismatch(expected: String, found: String) { 24 | description("type mismatch!") 25 | display("expected type `{}`, but found `{}`", expected, found) 26 | } 27 | 28 | MissingShapeError { 29 | description("ndarray `shape()` returns `None`") 30 | display("called `Option::unwrap()` on a `None` value") 31 | } 32 | 33 | } 34 | 35 | foreign_links { 36 | ShapeError(rust_ndarray::ShapeError); 37 | NulError(ffi::NulError); 38 | IntoStringError(ffi::IntoStringError); 39 | } 40 | } 41 | 42 | impl From for Error { 43 | fn from(_err: option::NoneError) -> Self { 44 | ErrorKind::MissingShapeError.into() 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/function.rs: -------------------------------------------------------------------------------- 1 | //! This module provides an idiomatic Rust API for creating and working with TVM functions. 2 | //! 3 | //! For calling an already registered TVM function use [`function::Builder`] 4 | //! To register a TVM packed function from Rust side either 5 | //! use [`function::register`] or the macro [`register_global_func`]. 6 | //! 7 | //! See the tests and examples repository for more examples. 8 | 9 | use std::{ 10 | ffi::{CStr, CString}, 11 | mem, 12 | os::raw::{c_char, c_int, c_void}, 13 | ptr, slice, str, 14 | sync::Mutex, 15 | }; 16 | 17 | use ts; 18 | 19 | use ty::TypeCode; 20 | use value::{TVMValue, ValueKind}; 21 | use ErrorKind; 22 | use Module; 23 | use Result; 24 | use TVMArgValue; 25 | use TVMRetValue; 26 | 27 | lazy_static! 
{ 28 | static ref GLOBAL_FUNCTION_NAMES: Mutex> = { 29 | let mut out_size = 0 as c_int; 30 | let name = ptr::null_mut() as *mut c_char; 31 | let mut out_array = name as *mut _; 32 | check_call!(ts::TVMFuncListGlobalNames( 33 | &mut out_size as *mut _, 34 | &mut out_array 35 | )); 36 | let names_list = unsafe { slice::from_raw_parts(out_array, out_size as usize) }; 37 | Mutex::new( 38 | names_list 39 | .into_iter() 40 | .map(|&p| unsafe { CStr::from_ptr(p).to_str().unwrap() }) 41 | .collect(), 42 | ) 43 | }; 44 | } 45 | 46 | /// Returns a registered TVM function by name. 47 | pub fn get_global_func(name: &str, is_global: bool) -> Option { 48 | let name = CString::new(name).expect("function name should not contain any `0` byte"); 49 | let mut handle = ptr::null_mut() as ts::TVMFunctionHandle; 50 | check_call!(ts::TVMFuncGetGlobal( 51 | name.as_ptr() as *const c_char, 52 | &mut handle as *mut _ 53 | )); 54 | if !(handle.is_null()) { 55 | mem::forget(name); 56 | return Some(Function::new(handle, is_global, false)); 57 | } else { 58 | None 59 | } 60 | } 61 | 62 | /// Wrapper around TVM function handle which includes `is_global` 63 | /// indicating whether the function is global or not, `is_released` 64 | /// to hint dropping the function handle and `is_cloned` showing 65 | /// not to drop a cloned function from Rust side. 66 | /// The value of these fields can be accessed through their respective methods. 67 | #[derive(Debug, Hash)] 68 | pub struct Function { 69 | pub(crate) handle: ts::TVMFunctionHandle, 70 | // whether the registered function is global or not. 71 | is_global: bool, 72 | // whether the function has been dropped from frontend or not. 73 | is_released: bool, 74 | // whether the function has been cloned from frontend or not. 
75 | is_cloned: bool, 76 | } 77 | 78 | impl Function { 79 | pub(crate) fn new(handle: ts::TVMFunctionHandle, is_global: bool, is_released: bool) -> Self { 80 | Function { 81 | handle: handle, 82 | is_global: is_global, 83 | is_released: is_released, 84 | is_cloned: false, 85 | } 86 | } 87 | 88 | /// For a given function, it returns a function by name. 89 | pub fn get_function(name: &str, is_global: bool) -> Option { 90 | let gnames = GLOBAL_FUNCTION_NAMES.lock().unwrap(); 91 | let fn_name = gnames.iter().find(|&&s| s == name)?; 92 | get_global_func(fn_name, is_global) 93 | } 94 | 95 | /// Returns the underlying TVM function handle. 96 | pub fn handle(&self) -> ts::TVMFunctionHandle { 97 | self.handle 98 | } 99 | 100 | /// Returns `true` if the underlying TVM function is global and `false` otherwise. 101 | pub fn is_global(&self) -> bool { 102 | self.is_global 103 | } 104 | 105 | /// Returns `true` if the underlying TVM function has been released 106 | /// from the frontend and `false` otherwise. 107 | pub fn is_released(&self) -> bool { 108 | self.is_released 109 | } 110 | 111 | /// Returns `true` if the underlying TVM function has been cloned 112 | /// from the frontend and `false` otherwise. 113 | pub fn is_cloned(&self) -> bool { 114 | self.is_cloned 115 | } 116 | } 117 | 118 | impl Clone for Function { 119 | fn clone(&self) -> Function { 120 | if !self.is_released && !self.is_cloned { 121 | Self { 122 | handle: self.handle, 123 | is_global: self.is_global, 124 | is_released: self.is_released, 125 | is_cloned: true, 126 | } 127 | } else { 128 | Function::new(self.handle, self.is_global, self.is_released) 129 | } 130 | } 131 | } 132 | 133 | impl Drop for Function { 134 | fn drop(&mut self) { 135 | if !self.is_released && !self.is_global && !self.is_cloned { 136 | check_call!(ts::TVMFuncFree(self.handle)); 137 | self.is_released = true; 138 | } 139 | } 140 | } 141 | 142 | /// Function builder in order to create and call functions. 
143 | /// 144 | /// *Note:* Currently TVM functions accept *at most* one return value. 145 | #[derive(Debug, Clone, Default)] 146 | pub struct Builder<'a> { 147 | pub func: Option, 148 | pub arg_buf: Option]>>, 149 | pub ret_buf: Option>, 150 | } 151 | 152 | impl<'a> Builder<'a> { 153 | pub fn new( 154 | func: Option, 155 | arg_buf: Option]>>, 156 | ret_buf: Option>, 157 | ) -> Self { 158 | Self { 159 | func, 160 | arg_buf, 161 | ret_buf, 162 | } 163 | } 164 | 165 | pub fn get_function(&mut self, name: &str, is_global: bool) -> &mut Self { 166 | self.func = Function::get_function(name, is_global); 167 | self 168 | } 169 | 170 | /// Pushes a [`TVMArgValue`] into the function argument buffer. 171 | pub fn arg<'b, T: ?Sized>(&mut self, arg: &'b T) -> &mut Self 172 | where 173 | TVMValue: From<&'b T>, 174 | TypeCode: From<&'b T>, 175 | { 176 | let tvm_arg = TVMArgValue::from(arg); 177 | if self.arg_buf.is_none() { 178 | self.arg_buf = Some(Box::new([tvm_arg])); 179 | } else { 180 | let new_arg_buf = self.arg_buf.take().map(|bbuf| { 181 | let mut new_arg_buf = Vec::from(bbuf); 182 | new_arg_buf.push(tvm_arg); 183 | let new_len = new_arg_buf.len(); 184 | new_arg_buf.truncate(new_len); 185 | new_arg_buf.into_boxed_slice() 186 | }); 187 | self.arg_buf = new_arg_buf; 188 | } 189 | self 190 | } 191 | 192 | /// Pushes multiple [`TVMArgValue`]s into the function argument buffer. 193 | pub fn args<'b, T: 'b + ?Sized, I>(&mut self, args: I) -> &mut Self 194 | where 195 | I: IntoIterator, 196 | TVMValue: From<&'b T>, 197 | TypeCode: From<&'b T>, 198 | { 199 | for arg in args { 200 | self.arg(&arg); 201 | } 202 | self 203 | } 204 | 205 | /// Sets an output for a function that requirs a mutable output to be provided. 206 | /// See the `basics` in tests for an example. 
207 | pub fn set_output<'b, T: 'b + ?Sized>(&mut self, arg: &'b mut T) -> &mut Self 208 | where 209 | TVMValue: From<&'b T>, 210 | TypeCode: From<&'b T>, 211 | { 212 | let tvm_ret = TVMRetValue::new(TVMValue::from(arg), TypeCode::from(arg)); 213 | if self.ret_buf.is_none() { 214 | self.ret_buf = Some(Box::new([tvm_ret])); 215 | } else { 216 | let new_ret_buf = self.ret_buf.take().map(|_| { 217 | let mut new_buf = Vec::with_capacity(1); 218 | new_buf.push(tvm_ret); 219 | new_buf.into_boxed_slice() 220 | }); 221 | self.ret_buf = new_ret_buf; 222 | } 223 | self 224 | } 225 | 226 | /// Calls the function that created from `Builder`. 227 | pub fn invoke(&mut self) -> Result { 228 | self.clone()(()) 229 | } 230 | } 231 | 232 | impl<'a> FnOnce<((),)> for Builder<'a> { 233 | type Output = Result; 234 | extern "rust-call" fn call_once(self, _: ((),)) -> Self::Output { 235 | if self.func.is_none() { 236 | bail!("{}", ErrorKind::FunctionNotFound); 237 | } 238 | let mut ret_val = unsafe { mem::uninitialized::() }; 239 | let mut ret_type_code = 0 as c_int; 240 | if self.arg_buf.is_some() { 241 | let arg_buf = self.arg_buf?; 242 | let mut num_args = arg_buf.len(); 243 | let mut values = arg_buf 244 | .iter() 245 | .map(|tav| tav.value.inner) 246 | .collect::>(); 247 | let mut tcodes = arg_buf 248 | .iter() 249 | .map(|tav| tav.type_code as c_int) 250 | .collect::>(); 251 | if self.ret_buf.is_some() { 252 | num_args = num_args + 1; 253 | ret_val = *self.ret_buf.clone()?[0].value; 254 | ret_type_code = self.ret_buf.clone()?[0].type_code as c_int; 255 | values.append(&mut vec![ret_val]); 256 | tcodes.append(&mut vec![ret_type_code]); 257 | } 258 | values.truncate(num_args); 259 | tcodes.truncate(num_args); 260 | check_call!(ts::TVMFuncCall( 261 | self.func?.handle, 262 | values.as_mut_ptr(), 263 | tcodes.as_mut_ptr(), 264 | num_args as c_int, 265 | &mut ret_val as *mut _, 266 | &mut ret_type_code as *mut _ 267 | )); 268 | } else { 269 | check_call!(ts::TVMFuncCall( 270 | 
self.func?.handle, 271 | ptr::null_mut(), 272 | ptr::null_mut(), 273 | 0 as c_int, 274 | &mut ret_val as *mut _, 275 | &mut ret_type_code as *mut _ 276 | )); 277 | } 278 | let ret = TVMRetValue::new( 279 | TVMValue::new(ValueKind::Return, ret_val), 280 | ret_type_code.into(), 281 | ); 282 | Ok(ret) 283 | } 284 | } 285 | 286 | /// Converts a [`Function`] to builder. Currently, this is the best way to work with 287 | /// TVM functions. 288 | impl<'a> From for Builder<'a> { 289 | fn from(func: Function) -> Self { 290 | Builder::new(Some(func), None, None) 291 | } 292 | } 293 | 294 | /// Converts a mutable reference of a [`Module`] to [`Builder`]. 295 | impl<'a: 'b, 'b> From<&'b mut Module> for Builder<'a> { 296 | fn from(module: &mut Module) -> Self { 297 | Builder::new(module.entry.take(), None, None) 298 | } 299 | } 300 | 301 | unsafe extern "C" fn tvm_callback( 302 | args: *mut ts::TVMValue, 303 | type_codes: *mut c_int, 304 | num_args: c_int, 305 | ret: ts::TVMRetValueHandle, 306 | fhandle: *mut c_void, 307 | ) -> c_int { 308 | let len = num_args as usize; 309 | let args_list = slice::from_raw_parts_mut(args, len); 310 | let type_codes_list = slice::from_raw_parts_mut(type_codes, len); 311 | let mut local_args: Vec = Vec::new(); 312 | // due to unsafe mem::uninitialized rustc warning about unused `value` and `tcode`. 
313 | let mut _value = mem::uninitialized::(); 314 | let mut _tcode = mem::uninitialized::(); 315 | let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result>(fhandle); 316 | for i in 0..len { 317 | _value = args_list[i]; 318 | _tcode = type_codes_list[i]; 319 | if _tcode == TypeCode::kNodeHandle as c_int 320 | || _tcode == TypeCode::kFuncHandle as c_int 321 | || _tcode == TypeCode::kModuleHandle as c_int 322 | { 323 | check_call!(ts::TVMCbArgToReturn(&mut _value as *mut _, _tcode)); 324 | } 325 | local_args.push(TVMArgValue::new( 326 | TVMValue::new(ValueKind::Handle, _value), 327 | _tcode.into(), 328 | )); 329 | } 330 | 331 | let rv = match rust_fn(local_args.as_slice()) { 332 | Ok(v) => v, 333 | Err(msg) => { 334 | ::set_last_error(&msg); 335 | return -1; 336 | } 337 | }; 338 | let mut ret_val = *rv.value; 339 | let mut ret_type_code = rv.type_code as c_int; 340 | check_call!(ts::TVMCFuncSetReturn( 341 | ret, 342 | &mut ret_val as *mut _, 343 | &mut ret_type_code as *mut _, 344 | 1 as c_int 345 | )); 346 | 0 347 | } 348 | 349 | unsafe extern "C" fn tvm_callback_finalizer(fhandle: *mut c_void) { 350 | let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result>(fhandle); 351 | mem::drop(rust_fn); 352 | } 353 | 354 | fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result) -> Function { 355 | let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle; 356 | let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result; 357 | check_call!(ts::TVMFuncCreateFromCFunc( 358 | Some(tvm_callback), 359 | resource_handle as *mut c_void, 360 | Some(tvm_callback_finalizer), 361 | &mut fhandle as *mut _ 362 | )); 363 | Function::new(fhandle, false, false) 364 | } 365 | 366 | /// Registers a Rust function with signature 367 | /// `fn(&[TVMArgValue]) -> Result` 368 | /// as a **global TVM packed function** from frontend to TVM backend. 
369 | /// 370 | /// Use [`register_global_func`] if overriding an existing global TVM function 371 | /// is not required. 372 | /// 373 | /// ## Example 374 | /// 375 | /// ``` 376 | /// fn sum(args: &[TVMArgValue]) -> Result { 377 | /// let mut ret = 0; 378 | /// for arg in args.iter() { 379 | /// ret += arg.to_int(); 380 | /// } 381 | /// let ret_val = TVMRetValue::from(&ret); 382 | /// Ok(ret_val) 383 | /// } 384 | /// 385 | /// tvm::function::register(sum, "mysum".to_owned(), false).unwrap(); 386 | /// let mut registered = function::Builder::default(); 387 | /// registered.get_function("mysum", true); 388 | /// assert!(registered.func.is_some()); 389 | /// registered.args(&[10, 20, 30]); 390 | /// assert_eq!(registered.invoke().unwrap().to_int(), 60); 391 | /// ``` 392 | pub fn register( 393 | f: fn(&[TVMArgValue]) -> Result, 394 | name: String, 395 | override_: bool, 396 | ) -> Result<()> { 397 | let func = convert_to_tvm_func(f); 398 | let name = CString::new(name)?; 399 | check_call!(ts::TVMFuncRegisterGlobal( 400 | name.as_ptr() as *const c_char, 401 | func.handle(), 402 | override_ as c_int 403 | )); 404 | mem::forget(name); 405 | Ok(()) 406 | } 407 | 408 | /// Convenient macro for registering functions from frontend to backend as global 409 | /// TVM packed functions without overriding. If overriding an existing function is needed 410 | /// use the [`function::register`] function instead. 411 | /// 412 | /// ## Example 413 | /// 414 | /// ``` 415 | /// register_global_func! 
{ 416 | /// fn sum(args: &[TVMArgValue]) -> Result { 417 | /// let mut ret = 0f64; 418 | /// for arg in args.iter() { 419 | /// ret += arg.to_float(); 420 | /// } 421 | /// let ret_val = TVMRetValue::from(&ret); 422 | /// Ok(ret_val) 423 | /// } 424 | /// } 425 | /// 426 | /// let mut registered = function::Builder::default(); 427 | /// registered.get_function("sum", true); 428 | /// assert!(registered.func.is_some()); 429 | /// registered.args(&[10f64, 20f64, 30f64]); 430 | /// assert_eq!(registered.invoke().unwrap().to_float(), 60f64); 431 | /// ``` 432 | #[macro_export] 433 | macro_rules! register_global_func { 434 | { 435 | $(#[$m:meta])* 436 | fn $fn_name:ident($args:ident : &[TVMArgValue]) -> Result { 437 | $($code:tt)* 438 | } 439 | } => {{ 440 | $(#[$m])* 441 | fn $fn_name($args: &[TVMArgValue]) -> Result { 442 | $($code)* 443 | } 444 | 445 | $crate::function::register($fn_name, stringify!($fn_name).to_owned(), false).unwrap(); 446 | }} 447 | } 448 | 449 | /// Convenient macro for calling TVM packed functions by providing a 450 | /// function identifier and some arguments. This macro outputs a `Result` type 451 | /// and let user to perform proper error handling. 452 | /// 453 | /// **Note**: this macro does *not* expect an outside mutable output. To 454 | /// set mutable output use [`set_output`] directly in the builder pattern. 455 | /// 456 | /// [`set_output`]:function/struct.Builder.html#method.set_output 457 | /// 458 | /// ## Example 459 | /// 460 | /// Instead of 461 | /// 462 | /// ``` 463 | /// function::Builder::from(func).arg(&a).arg(&b).invoke(); 464 | /// ``` 465 | /// 466 | /// one can use 467 | /// 468 | /// ``` 469 | /// call_packed!(func, &a, &b); 470 | /// ``` 471 | #[macro_export] 472 | macro_rules! 
call_packed { 473 | ($fn_name:ident, $($arg:expr),*) => {{ 474 | let mut builder = $crate::function::Builder::from($fn_name); 475 | $( 476 | builder.arg($arg); 477 | )* 478 | builder.invoke() 479 | }} 480 | } 481 | 482 | #[cfg(test)] 483 | mod tests { 484 | use super::*; 485 | 486 | #[test] 487 | fn list_global_func() { 488 | assert!( 489 | GLOBAL_FUNCTION_NAMES 490 | .lock() 491 | .unwrap() 492 | .iter() 493 | .find(|ref s| ***s == "tvm.graph_runtime.create") 494 | .is_some() 495 | ); 496 | } 497 | 498 | #[test] 499 | fn get_fn() { 500 | assert!(Function::get_function("tvm.graph_runtime.remote_create", true).is_some()); 501 | assert!(Function::get_function("does not exists!", false).is_none()); 502 | } 503 | 504 | #[test] 505 | fn provide_args() { 506 | let mut func = Builder::default(); 507 | func.get_function("tvm.graph_runtime.remote_create", true) 508 | .args(&[10, 20]) 509 | .arg(&"test".to_owned()); 510 | assert!(func.arg_buf.is_some()); 511 | assert_eq!(func.arg_buf.take().map(|bv| Vec::from(bv).len()), Some(3)); 512 | } 513 | } 514 | -------------------------------------------------------------------------------- /src/internal_api.rs: -------------------------------------------------------------------------------- 1 | use std::{cell::RefCell, collections::HashMap}; 2 | 3 | use Function; 4 | 5 | // access TVM internal API 6 | thread_local! 
{ 7 | pub(crate) static API: RefCell> = RefCell::new(HashMap::new()); 8 | } 9 | 10 | pub(crate) fn get(name: String) -> Option { 11 | API.with(|hm| hm.borrow().get(&name).map(|f| f.clone())) 12 | } 13 | 14 | pub(crate) fn set(name: String, func: Function) { 15 | API.with(|hm| { 16 | (*hm.borrow_mut()).insert(name, func); 17 | }) 18 | } 19 | 20 | pub(crate) fn get_api(name: String) -> Function { 21 | let mut func = get(name.clone()); 22 | if func.is_none() { 23 | func = Function::get_function(&name, true); 24 | set( 25 | name, 26 | func.clone().expect("access to `internal_api` never panics"), 27 | ); 28 | } 29 | func.expect("access to `internal_api` never panics") 30 | } 31 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! [TVM](https://github.com/dmlc/tvm) is a compiler stack for deep learning systems. 2 | //! 3 | //! This crate provides an idiomatic Rust API for TVM runtime frontend. 4 | //! 5 | //! One particular use case is that given optimized deep learning model artifacts, 6 | //! already compiled with TVM, which include a shared library 7 | //! `lib.so`, `graph.json` and a byte-array `param.params`, one can load them 8 | //! in Rust to create a graph runtime. Then, run the model for some inputs and get the 9 | //! desired predictions all in Rust. 10 | //! 11 | //! Checkout the `examples` repository for more details. 
12 | 13 | #![crate_name = "tvm_frontend"] 14 | #![recursion_limit = "1024"] 15 | #![allow(non_camel_case_types, unused_unsafe)] 16 | #![feature(try_from, try_trait, fn_traits, unboxed_closures, box_syntax)] 17 | 18 | #[macro_use] 19 | extern crate error_chain; 20 | extern crate tvm_sys as ts; 21 | #[macro_use] 22 | extern crate lazy_static; 23 | extern crate ndarray as rust_ndarray; 24 | extern crate num_traits; 25 | 26 | use std::{ 27 | ffi::{CStr, CString}, 28 | str, 29 | }; 30 | 31 | // Macro to check the return call to TVM runtime shared library 32 | macro_rules! check_call { 33 | ($e:expr) => {{ 34 | if unsafe { $e } != 0 { 35 | panic!("{}", $crate::get_last_error()); 36 | } 37 | }}; 38 | } 39 | 40 | /// Gets the last error message 41 | pub fn get_last_error() -> &'static str { 42 | unsafe { 43 | match CStr::from_ptr(ts::TVMGetLastError()).to_str() { 44 | Ok(s) => s, 45 | Err(_) => "Invalid UTF-8 message", 46 | } 47 | } 48 | } 49 | 50 | pub(crate) fn set_last_error(err: &Error) { 51 | let c_string = CString::new(err.to_string()).unwrap(); 52 | unsafe { 53 | ts::TVMAPISetLastError(c_string.as_ptr()); 54 | } 55 | } 56 | 57 | #[macro_use] 58 | pub mod function; 59 | pub mod bytearray; 60 | pub mod context; 61 | pub mod errors; 62 | mod internal_api; 63 | pub mod module; 64 | pub mod ndarray; 65 | pub mod ty; 66 | pub mod value; 67 | 68 | pub use bytearray::TVMByteArray; 69 | pub use context::{TVMContext, TVMDeviceType}; 70 | pub use errors::*; 71 | pub use function::Function; 72 | pub use module::Module; 73 | pub use ndarray::{empty, NDArray}; 74 | pub use ty::TVMType; 75 | pub use value::{TVMArgValue, TVMRetValue}; 76 | 77 | /// Outputs the current TVM version 78 | pub fn version() -> &'static str { 79 | match str::from_utf8(ts::TVM_VERSION) { 80 | Ok(s) => s, 81 | Err(_) => "Invalid UTF-8 string", 82 | } 83 | } 84 | 85 | #[cfg(test)] 86 | mod tests { 87 | use super::*; 88 | 89 | #[test] 90 | fn print_version() { 91 | println!("TVM version: {}", version()); 92 | 
} 93 | 94 | #[test] 95 | fn set_error() { 96 | let err = ErrorKind::EmptyArray; 97 | set_last_error(&err.into()); 98 | assert_eq!(get_last_error().trim(), ErrorKind::EmptyArray.to_string()); 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /src/module.rs: -------------------------------------------------------------------------------- 1 | //! Provides the [`Module`] type and methods for working with runtime TVM modules. 2 | 3 | use std::{ 4 | ffi::CString, 5 | mem, 6 | os::raw::{c_char, c_int}, 7 | path::Path, 8 | ptr, 9 | }; 10 | 11 | use ts; 12 | 13 | use function::Function; 14 | use internal_api; 15 | use ErrorKind; 16 | use Result; 17 | 18 | const ENTRY_FUNC: &'static str = "__tvm_main__"; 19 | 20 | /// Wrapper around TVM module handle which contains an entry function. 21 | /// The entry function can be applied to an imported module through [`entry_func`]. 22 | /// Also [`is_released`] shows whether the module is dropped or not. 23 | /// 24 | /// [`entry_func`]:struct.Module.html#method.entry_func 25 | /// [`is_released`]:struct.Module.html#method.is_released 26 | #[derive(Debug, Clone)] 27 | pub struct Module { 28 | pub(crate) handle: ts::TVMModuleHandle, 29 | is_released: bool, 30 | pub(crate) entry: Option, 31 | } 32 | 33 | impl Module { 34 | pub(crate) fn new( 35 | handle: ts::TVMModuleHandle, 36 | is_released: bool, 37 | entry: Option, 38 | ) -> Self { 39 | Self { 40 | handle, 41 | is_released, 42 | entry, 43 | } 44 | } 45 | 46 | /// Sets the entry function of a module. 47 | pub fn entry_func(&mut self) { 48 | if self.entry.is_none() { 49 | self.entry = self.get_function(ENTRY_FUNC, false).ok(); 50 | } 51 | } 52 | 53 | /// Gets a function by name from a registered module. 
54 | pub fn get_function(&self, name: &str, query_import: bool) -> Result { 55 | let name = CString::new(name)?; 56 | let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle; 57 | check_call!(ts::TVMModGetFunction( 58 | self.handle, 59 | name.as_ptr() as *const c_char, 60 | query_import as c_int, 61 | &mut fhandle as *mut _ 62 | )); 63 | if fhandle.is_null() { 64 | bail!(ErrorKind::NullHandle(format!("{}", name.into_string()?))) 65 | } else { 66 | mem::forget(name); 67 | Ok(Function::new(fhandle, false, false)) 68 | } 69 | } 70 | 71 | /// Imports a dependent module such as `.ptx` for gpu. 72 | pub fn import_module(&self, dependent_module: Module) { 73 | check_call!(ts::TVMModImport(self.handle, dependent_module.handle)) 74 | } 75 | 76 | /// Loads a module shared library from path. 77 | pub fn load(path: &Path) -> Result { 78 | let path = path.to_owned(); 79 | let path_str = path.to_str()?.to_owned(); 80 | let ext = path.extension()?.to_str()?.to_owned(); 81 | let func = internal_api::get_api("module._LoadFromFile".to_owned()); 82 | let ret = call_packed!(func, &path_str, &ext)?; 83 | mem::forget(path); 84 | Ok(ret.to_module()) 85 | } 86 | 87 | /// Checks if a target device is enabled for a module. 88 | pub fn enabled(&self, target: &str) -> bool { 89 | let func = internal_api::get_api("module._Enabled".to_owned()); 90 | // `unwrap` is safe here because if there is any error during the 91 | // function call, it would occur in `call_packed!`. 92 | let ret = call_packed!(func, target).unwrap(); 93 | ret.to_int() != 0 94 | } 95 | 96 | /// Returns the underlying module handle. 97 | pub fn handle(&self) -> ts::TVMModuleHandle { 98 | self.handle 99 | } 100 | 101 | /// Returns true if the underlying module has been dropped and false otherwise. 
102 | pub fn is_released(&self) -> bool { 103 | self.is_released 104 | } 105 | } 106 | 107 | impl Drop for Module { 108 | fn drop(&mut self) { 109 | if !self.is_released { 110 | check_call!(ts::TVMModFree(self.handle)); 111 | self.is_released = true; 112 | } 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /src/ndarray.rs: -------------------------------------------------------------------------------- 1 | //! This module implements the [`NDArray`] type for working with *TVM tensors* or 2 | //! coverting from a Rust's ndarray to TVM `NDArray`. 3 | //! 4 | //! One can create an empty NDArray given the shape, context and dtype using [`empty`]. 5 | //! To create an NDArray from a mutable buffer in cpu use [`copy_from_buffer`]. 6 | //! To copy an NDArray to different context use [`copy_to_ctx`]. 7 | //! 8 | //! Given a [`Rust's dynamic ndarray`], one can convert it to TVM NDArray as follows: 9 | //! 10 | //! # Example 11 | //! 12 | //! ``` 13 | //! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) 14 | //! .unwrap() 15 | //! .into_dyn(); // Rust's ndarray 16 | //! let nd = NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), TVMType::from("float")).unwrap(); 17 | //! assert_eq!(nd.shape(), Some(&mut [2, 2])); 18 | //! let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); 19 | //! assert!(rnd.all_close(&a, 1e-8f32)); 20 | //! ``` 21 | //! 22 | //! [`Rust's dynamic ndarray`]:https://docs.rs/ndarray/0.12.1/ndarray/ 23 | //! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer 24 | //! 
[`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx 25 | 26 | use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice}; 27 | 28 | use num_traits::Num; 29 | use rust_ndarray::{Array, ArrayD}; 30 | 31 | use ts; 32 | 33 | use Error; 34 | use ErrorKind; 35 | use Result; 36 | use TVMByteArray; 37 | use TVMContext; 38 | use TVMType; 39 | 40 | /// See the [`module-level documentation`](../ndarray/index.html) for more details. 41 | /// 42 | /// Wrapper around TVM array handle. 43 | #[derive(Debug)] 44 | pub struct NDArray { 45 | pub(crate) handle: ts::TVMArrayHandle, 46 | is_view: bool, 47 | } 48 | 49 | impl NDArray { 50 | pub(crate) fn new(handle: ts::TVMArrayHandle, is_view: bool) -> Self { 51 | NDArray { 52 | handle: handle, 53 | is_view: is_view, 54 | } 55 | } 56 | 57 | /// Returns the underlying array handle. 58 | pub fn handle(&self) -> ts::TVMArrayHandle { 59 | self.handle 60 | } 61 | 62 | pub fn is_view(&self) -> bool { 63 | self.is_view 64 | } 65 | 66 | /// Returns the shape of the NDArray. 67 | pub fn shape(&self) -> Option<&mut [usize]> { 68 | let arr = unsafe { *(self.handle) }; 69 | if arr.shape.is_null() || arr.data.is_null() { 70 | return None; 71 | }; 72 | let slc = unsafe { slice::from_raw_parts_mut(arr.shape as *mut usize, arr.ndim as usize) }; 73 | Some(slc) 74 | } 75 | 76 | /// Returns the total number of entries of the NDArray. 77 | pub fn size(&self) -> Option { 78 | self.shape() 79 | .map(|v| v.into_iter().fold(1, |acc, &mut e| acc * e)) 80 | } 81 | 82 | /// Returns the context which the NDArray was defined. 83 | pub fn ctx(&self) -> TVMContext { 84 | unsafe { (*self.handle).ctx.into() } 85 | } 86 | 87 | /// Returns the type of the entries of the NDArray. 88 | pub fn dtype(&self) -> TVMType { 89 | unsafe { (*self.handle).dtype.into() } 90 | } 91 | 92 | /// Returns the number of dimensions of the NDArray. 
93 | pub fn ndim(&self) -> usize { 94 | unsafe { (*self.handle).ndim as usize } 95 | } 96 | 97 | /// Returns the strides of the underlying NDArray. 98 | pub fn strides(&self) -> Option<&[usize]> { 99 | unsafe { 100 | let sz = self.ndim() * mem::size_of::(); 101 | let slc = slice::from_raw_parts((*self.handle).strides as *const usize, sz); 102 | Some(slc) 103 | } 104 | } 105 | 106 | /// Shows whether the underlying ndarray is contiguous in memory or not. 107 | pub fn is_contiguous(&self) -> Result { 108 | Ok(match self.strides() { 109 | None => true, 110 | Some(strides) => { 111 | // MissingShapeError in case shape is not determined 112 | self.shape()? 113 | .iter() 114 | .zip(strides) 115 | .rfold( 116 | (true, 1), 117 | |(is_contig, expected_stride), (shape, stride)| { 118 | ( 119 | is_contig && *stride == expected_stride, 120 | expected_stride * (*shape as usize), 121 | ) 122 | }, 123 | ) 124 | .0 125 | } 126 | }) 127 | } 128 | 129 | pub fn byte_offset(&self) -> isize { 130 | unsafe { (*self.handle).byte_offset as isize } 131 | } 132 | 133 | /// Flattens the NDArray to a `Vec` of the same type in cpu. 134 | /// 135 | /// ## Example 136 | /// 137 | /// ``` 138 | /// let shape = &mut [4]; 139 | /// let mut data = vec![1i32, 2, 3, 4]; 140 | /// let ctx = TVMContext::cpu(0); 141 | /// let mut ndarray = empty(shape, ctx, TVMType::from("int")); 142 | /// ndarray.copy_from_buffer(&mut data); 143 | /// assert_eq!(ndarray.shape(), Some(shape)); 144 | /// assert_eq!(ndarray.to_vec::().unwrap(), data); 145 | /// ``` 146 | pub fn to_vec(&self) -> Result> { 147 | if self.shape().is_none() { 148 | bail!("{}", ErrorKind::EmptyArray); 149 | } 150 | let earr = empty(self.shape()?, TVMContext::cpu(0), self.dtype()); 151 | let target = self.copy_to_ndarray(earr)?; 152 | let arr = unsafe { *(target.handle) }; 153 | let sz = self.size()? 
as usize; 154 | let mut v: Vec = Vec::with_capacity(sz * mem::size_of::()); 155 | unsafe { 156 | v.as_mut_ptr() 157 | .copy_from_nonoverlapping(arr.data as *const T, sz); 158 | v.set_len(sz); 159 | } 160 | Ok(v) 161 | } 162 | 163 | /// Converts the NDArray to [`TVMByteArray`]. 164 | pub fn to_bytearray(&self) -> Result { 165 | let v = self.to_vec::()?; 166 | Ok(TVMByteArray::from(&v)) 167 | } 168 | 169 | /// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu. 170 | /// 171 | /// ## Example 172 | /// 173 | /// ``` 174 | /// let shape = &mut [2]; 175 | /// let mut data = vec![1f32, 2]; 176 | /// let ctx = TVMContext::gpu(0); 177 | /// let mut ndarray = empty(shape, ctx, TVMType::from("int")); 178 | /// ndarray.copy_from_buffer(&mut data); 179 | /// ``` 180 | /// 181 | /// *Note*: if something goes wrong during the copy, it will panic 182 | /// from TVM side. See `TVMArrayCopyFromBytes` in `c_runtime_api.h`. 183 | pub fn copy_from_buffer(&mut self, data: &mut [T]) { 184 | check_call!(ts::TVMArrayCopyFromBytes( 185 | self.handle, 186 | data.as_ptr() as *mut _, 187 | data.len() * mem::size_of::() 188 | )); 189 | } 190 | 191 | /// Copies the NDArray to another target NDArray. 192 | pub fn copy_to_ndarray(&self, target: NDArray) -> Result { 193 | if self.dtype() != target.dtype() { 194 | bail!( 195 | "{}", 196 | ErrorKind::TypeMismatch( 197 | format!("{}", self.dtype().to_string()), 198 | format!("{}", target.dtype().to_string()), 199 | ) 200 | ); 201 | } 202 | check_call!(ts::TVMArrayCopyFromTo( 203 | self.handle, 204 | target.handle, 205 | ptr::null_mut() as ts::TVMStreamHandle 206 | )); 207 | Ok(target) 208 | } 209 | 210 | /// Copies the NDArray to a target context. 211 | pub fn copy_to_ctx(&self, target: &TVMContext) -> Result { 212 | let tmp = empty(self.shape()?, target.clone(), self.dtype()); 213 | let copy = self.copy_to_ndarray(tmp)?; 214 | Ok(copy) 215 | } 216 | 217 | /// Converts a Rust's ndarray to TVM NDArray. 
218 | pub fn from_rust_ndarray( 219 | rnd: &ArrayD, 220 | ctx: TVMContext, 221 | dtype: TVMType, 222 | ) -> Result { 223 | let mut shape = rnd.shape().to_vec(); 224 | let mut nd = empty(&mut shape, ctx, dtype); 225 | let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T)); 226 | nd.copy_from_buffer(buf.as_slice_mut()?); 227 | Ok(nd) 228 | } 229 | } 230 | 231 | /// Allocates and creates an empty NDArray given the shape, context and dtype. 232 | pub fn empty(shape: &mut [usize], ctx: TVMContext, dtype: TVMType) -> NDArray { 233 | let mut handle = ptr::null_mut() as ts::TVMArrayHandle; 234 | check_call!(ts::TVMArrayAlloc( 235 | shape.as_ptr() as *const i64, 236 | shape.len() as c_int, 237 | dtype.inner.code as c_int, 238 | dtype.inner.bits as c_int, 239 | dtype.inner.lanes as c_int, 240 | ctx.device_type.0 as c_int, 241 | ctx.device_id as c_int, 242 | &mut handle as *mut _, 243 | )); 244 | NDArray::new(handle, false) 245 | } 246 | 247 | macro_rules! impl_from_ndarray_rustndarray { 248 | ($type:ty, $type_name:tt) => { 249 | impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> { 250 | type Error = Error; 251 | fn try_from(nd: &NDArray) -> Result> { 252 | if nd.shape().is_none() { 253 | bail!("{}", ErrorKind::EmptyArray); 254 | } 255 | assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch"); 256 | Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?) 257 | } 258 | } 259 | 260 | impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> { 261 | type Error = Error; 262 | fn try_from(nd: &mut NDArray) -> Result> { 263 | if nd.shape().is_none() { 264 | bail!("{}", ErrorKind::EmptyArray); 265 | } 266 | assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch"); 267 | Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?) 
268 | } 269 | } 270 | }; 271 | } 272 | 273 | impl_from_ndarray_rustndarray!(i32, "int"); 274 | impl_from_ndarray_rustndarray!(u32, "uint"); 275 | impl_from_ndarray_rustndarray!(f32, "float"); 276 | 277 | impl Drop for NDArray { 278 | fn drop(&mut self) { 279 | if !self.is_view { 280 | check_call!(ts::TVMArrayFree(self.handle)); 281 | } 282 | } 283 | } 284 | 285 | /// A trait for the supported 32bits numerical types in frontend. 286 | pub trait Num32: Num { 287 | const BITS: u8 = 32; 288 | } 289 | 290 | macro_rules! impl_num32 { 291 | ($($type:ty),+) => { 292 | $( 293 | impl Num32 for $type {} 294 | )+ 295 | }; 296 | } 297 | 298 | impl_num32!(i32, u32, f32); 299 | 300 | #[cfg(test)] 301 | mod tests { 302 | use super::*; 303 | 304 | #[test] 305 | fn basics() { 306 | let shape = &mut [1, 2, 3]; 307 | let ctx = TVMContext::cpu(0); 308 | let ndarray = empty(shape, ctx, TVMType::from("int")); 309 | assert_eq!(ndarray.shape().unwrap(), shape); 310 | assert_eq!( 311 | ndarray.size().unwrap(), 312 | shape.to_vec().into_iter().product() 313 | ); 314 | assert_eq!(ndarray.ndim(), 3); 315 | assert!(ndarray.strides().is_none()); 316 | assert_eq!(ndarray.byte_offset(), 0); 317 | } 318 | 319 | #[test] 320 | fn copy() { 321 | let shape = &mut [4]; 322 | let mut data = vec![1i32, 2, 3, 4]; 323 | let ctx = TVMContext::cpu(0); 324 | let mut ndarray = empty(shape, ctx, TVMType::from("int")); 325 | assert!(ndarray.to_vec::().is_ok()); 326 | ndarray.copy_from_buffer(&mut data); 327 | assert_eq!(ndarray.shape().unwrap(), shape); 328 | assert_eq!(ndarray.to_vec::().unwrap(), data); 329 | assert_eq!(ndarray.ndim(), 1); 330 | assert!(ndarray.is_contiguous().is_ok()); 331 | assert_eq!(ndarray.byte_offset(), 0); 332 | let mut shape = vec![4]; 333 | let e = empty(&mut shape, TVMContext::cpu(0), TVMType::from("int")); 334 | let nd = ndarray.copy_to_ndarray(e); 335 | assert!(nd.is_ok()); 336 | assert_eq!(nd.unwrap().to_vec::().unwrap(), data); 337 | } 338 | 339 | #[test] 340 | 
#[should_panic(expected = "called `Result::unwrap()` on an `Err`")] 341 | fn copy_wrong_dtype() { 342 | let mut shape = vec![4]; 343 | let mut data = vec![1f32, 2., 3., 4.]; 344 | let ctx = TVMContext::cpu(0); 345 | let mut nd_float = empty(&mut shape, ctx.clone(), TVMType::from("float")); 346 | nd_float.copy_from_buffer(&mut data); 347 | let empty_int = empty(&mut shape, ctx, TVMType::from("int")); 348 | nd_float.copy_to_ndarray(empty_int).unwrap(); 349 | } 350 | 351 | #[test] 352 | fn rust_ndarray() { 353 | let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) 354 | .unwrap() 355 | .into_dyn(); 356 | let nd = 357 | NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), TVMType::from("float")).unwrap(); 358 | assert_eq!(nd.shape().unwrap(), &mut [2, 2]); 359 | let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); 360 | assert!(rnd.all_close(&a, 1e-8f32)); 361 | } 362 | } 363 | -------------------------------------------------------------------------------- /src/ty.rs: -------------------------------------------------------------------------------- 1 | //! This module implements the required conversions from Rust types to TVM types. 2 | //! 3 | //! In TVM frontend only conversions from 32bits numeric types (i32, u32, f32) 4 | //! and 64bits pointers are supported. 5 | //! 6 | //! # Example 7 | //! 8 | //! ``` 9 | //! let dtype = TVMType::from("float"); 10 | //! println!("dtype is: {}", dtype); 11 | //! 
``` 12 | 13 | use std::{ 14 | ffi::{CStr, CString}, 15 | fmt::{self, Display, Formatter}, 16 | ops::{Deref, DerefMut}, 17 | }; 18 | 19 | use ts; 20 | 21 | use function::Function; 22 | use module::Module; 23 | use ndarray::NDArray; 24 | use TVMByteArray; 25 | use TVMContext; 26 | use TVMDeviceType; 27 | 28 | /// TVM Type codes 29 | #[repr(u32)] 30 | #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] 31 | pub enum TypeCode { 32 | kDLInt = 0, 33 | kDLUInt = 1, 34 | kDLFloat = 2, 35 | kHandle = 3, 36 | kNull = 4, 37 | kTVMType = 5, 38 | kTVMContext = 6, 39 | kArrayHandle = 7, 40 | kNodeHandle = 8, 41 | kModuleHandle = 9, 42 | kFuncHandle = 10, 43 | kStr = 11, 44 | kBytes = 12, 45 | } 46 | 47 | impl Default for TypeCode { 48 | fn default() -> Self { 49 | TypeCode::kDLInt 50 | } 51 | } 52 | 53 | impl<'a> Into for i32 { 54 | fn into(self) -> TypeCode { 55 | match self { 56 | 0 => TypeCode::kDLInt, 57 | 1 => TypeCode::kDLUInt, 58 | 2 => TypeCode::kDLFloat, 59 | 3 => TypeCode::kHandle, 60 | 4 => TypeCode::kNull, 61 | 5 => TypeCode::kTVMType, 62 | 6 => TypeCode::kTVMContext, 63 | 7 => TypeCode::kArrayHandle, 64 | 8 => TypeCode::kNodeHandle, 65 | 9 => TypeCode::kModuleHandle, 66 | 10 => TypeCode::kFuncHandle, 67 | 11 => TypeCode::kStr, 68 | 12 => TypeCode::kBytes, 69 | _ => unreachable!(), 70 | } 71 | } 72 | } 73 | 74 | impl Display for TypeCode { 75 | fn fmt(&self, f: &mut Formatter) -> fmt::Result { 76 | write!( 77 | f, 78 | "{}", 79 | match self { 80 | TypeCode::kDLInt => "int", 81 | TypeCode::kDLUInt => "uint", 82 | TypeCode::kDLFloat => "float", 83 | TypeCode::kHandle => "handle", 84 | TypeCode::kNull => "null", 85 | TypeCode::kTVMType => "TVM type", 86 | TypeCode::kTVMContext => "TVM context", 87 | TypeCode::kArrayHandle => "Array handle", 88 | TypeCode::kNodeHandle => "Node handle", 89 | TypeCode::kModuleHandle => "Module handle", 90 | TypeCode::kFuncHandle => "Function handle", 91 | TypeCode::kStr => "string", 92 | TypeCode::kBytes => "bytes", 93 | } 94 | ) 95 | } 96 
| } 97 | 98 | macro_rules! impl_prim_type { 99 | ($type:ty, $variant:ident) => { 100 | impl<'a> From<&'a $type> for TypeCode { 101 | fn from(_arg: &$type) -> Self { 102 | TypeCode::$variant 103 | } 104 | } 105 | 106 | impl<'a> From<&'a mut $type> for TypeCode { 107 | fn from(_arg: &mut $type) -> Self { 108 | TypeCode::$variant 109 | } 110 | } 111 | }; 112 | } 113 | 114 | impl_prim_type!(usize, kDLInt); 115 | impl_prim_type!(i64, kDLInt); 116 | impl_prim_type!(i32, kDLInt); 117 | impl_prim_type!(i16, kDLInt); 118 | impl_prim_type!(i8, kDLInt); 119 | impl_prim_type!(TVMDeviceType, kDLInt); 120 | 121 | impl_prim_type!(u64, kDLUInt); 122 | impl_prim_type!(u32, kDLUInt); 123 | impl_prim_type!(u16, kDLUInt); 124 | impl_prim_type!(u8, kDLUInt); 125 | 126 | impl_prim_type!(f64, kDLFloat); 127 | impl_prim_type!(f32, kDLFloat); 128 | 129 | impl_prim_type!(str, kStr); 130 | impl_prim_type!(CStr, kStr); 131 | impl_prim_type!(String, kStr); 132 | impl_prim_type!(CString, kStr); 133 | 134 | impl_prim_type!(TVMContext, kTVMContext); 135 | 136 | impl_prim_type!(TVMType, kTVMType); 137 | 138 | impl_prim_type!(Function, kFuncHandle); 139 | 140 | impl_prim_type!(Module, kModuleHandle); 141 | 142 | impl_prim_type!(NDArray, kArrayHandle); 143 | 144 | impl_prim_type!([u8], kBytes); 145 | impl_prim_type!(TVMByteArray, kBytes); 146 | 147 | /// See the [module-level documentation](../ty/index.html) for more details. 
148 | /// 149 | /// Wrapper around underlying TVMType 150 | #[derive(Debug, Copy, Clone, PartialEq, Eq)] 151 | pub struct TVMType { 152 | pub inner: ts::TVMType, // fields are code: u8, bits: u8, lanes: u16 153 | } 154 | 155 | impl TVMType { 156 | pub(crate) fn new(type_code: u8, bits: u8, lanes: u16) -> Self { 157 | TVMType { 158 | inner: ts::TVMType { 159 | code: type_code, 160 | bits: bits, 161 | lanes: lanes, 162 | }, 163 | } 164 | } 165 | } 166 | 167 | impl<'a> From<&'a str> for TVMType { 168 | fn from(type_str: &'a str) -> Self { 169 | match type_str { 170 | "int" => TVMType::new(0, 32, 1), 171 | "uint" => TVMType::new(1, 32, 1), 172 | "float" => TVMType::new(2, 32, 1), 173 | "handle" => TVMType::new(4, 64, 1), 174 | _ => panic!("Unsupported type {:?}", type_str), 175 | } 176 | } 177 | } 178 | 179 | impl Display for TVMType { 180 | fn fmt(&self, f: &mut Formatter) -> fmt::Result { 181 | match self.inner { 182 | ts::TVMType { 183 | code: 0, 184 | bits: 32, 185 | lanes: 1, 186 | } => write!(f, "int"), 187 | ts::TVMType { 188 | code: 1, 189 | bits: 32, 190 | lanes: 1, 191 | } => write!(f, "uint"), 192 | ts::TVMType { 193 | code: 2, 194 | bits: 32, 195 | lanes: 1, 196 | } => write!(f, "float"), 197 | ts::TVMType { 198 | code: 4, 199 | bits: 64, 200 | lanes: 1, 201 | } => write!(f, "handle"), 202 | _ => write!(f, "unknown type"), 203 | } 204 | } 205 | } 206 | 207 | impl From for ts::DLDataType { 208 | fn from(dtype: TVMType) -> Self { 209 | dtype.inner 210 | } 211 | } 212 | 213 | impl From for TVMType { 214 | fn from(dtype: ts::DLDataType) -> Self { 215 | Self::new(dtype.code, dtype.bits, dtype.lanes) 216 | } 217 | } 218 | 219 | impl Deref for TVMType { 220 | type Target = ts::TVMType; 221 | fn deref(&self) -> &Self::Target { 222 | &self.inner 223 | } 224 | } 225 | 226 | impl DerefMut for TVMType { 227 | fn deref_mut(&mut self) -> &mut Self::Target { 228 | &mut self.inner 229 | } 230 | } 231 | 232 | impl<'a, 'b> From<&'b TVMType> for &'a str { 233 | fn from(ty: 
&TVMType) -> Self { 234 | match **ty { 235 | ts::TVMType { 236 | code: 0, 237 | bits: 32, 238 | lanes: 1, 239 | } => "int", 240 | ts::TVMType { 241 | code: 1, 242 | bits: 32, 243 | lanes: 1, 244 | } => "uint", 245 | ts::TVMType { 246 | code: 2, 247 | bits: 32, 248 | lanes: 1, 249 | } => "float", 250 | ts::TVMType { 251 | code: 4, 252 | bits: 64, 253 | lanes: 1, 254 | } => "handle", 255 | _ => panic!("undefined type"), 256 | } 257 | } 258 | } 259 | -------------------------------------------------------------------------------- /src/value.rs: -------------------------------------------------------------------------------- 1 | //! This module implements [`TVMArgValue`] and [`TVMRetValue`] types and their conversions 2 | //! to other supported `TVMValues`. `TVMRetValue` is the owned version of `TVMArgValue`. 3 | //! 4 | //! # Examples 5 | //! 6 | //! ``` 7 | //! let a = 42i8; 8 | //! let arg = TVMArgValue::from(&a); 9 | //! assert_eq!(arg.to_int() as i8, a); 10 | //! let ret = TVMRetValue::from(&a); 11 | //! assert_eq!(ret.to_int() as i8, a); 12 | //! ``` 13 | 14 | use std::{ 15 | any::Any, 16 | ffi::{CStr, CString}, 17 | fmt::{self, Debug, Formatter}, 18 | marker::PhantomData, 19 | mem, 20 | ops::{Deref, DerefMut}, 21 | os::raw::{c_char, c_void}, 22 | }; 23 | 24 | use ts; 25 | 26 | use ty::TypeCode; 27 | use Function; 28 | use Module; 29 | use NDArray; 30 | use TVMByteArray; 31 | use TVMContext; 32 | use TVMDeviceType; 33 | use TVMType; 34 | 35 | #[derive(Debug, Clone, PartialOrd, Ord, PartialEq, Eq)] 36 | pub(crate) enum ValueKind { 37 | Int, 38 | Float, 39 | Handle, 40 | Str, 41 | Bytes, 42 | Type, 43 | Context, 44 | Return, 45 | } 46 | 47 | /// Wrapper around the underlying `TVMValue`. 
48 | #[derive(Clone)] 49 | pub struct TVMValue { 50 | pub(crate) kind: ValueKind, 51 | pub(crate) inner: ts::TVMValue, 52 | } 53 | 54 | impl TVMValue { 55 | pub(crate) fn new(kind: ValueKind, inner: ts::TVMValue) -> Self { 56 | TVMValue { kind, inner } 57 | } 58 | 59 | pub fn to_int(&self) -> i64 { 60 | unsafe { self.inner.v_int64 } 61 | } 62 | } 63 | 64 | macro_rules! impl_prim_val { 65 | ($type:ty, $kind:expr, $field:ident, $cast:ty) => { 66 | impl<'a> From<&'a $type> for TVMValue { 67 | fn from(arg: &$type) -> Self { 68 | let inner = ts::TVMValue { 69 | $field: *arg as $cast, 70 | }; 71 | Self::new($kind, inner) 72 | } 73 | } 74 | 75 | impl<'a> From<&'a mut $type> for TVMValue { 76 | fn from(arg: &mut $type) -> Self { 77 | let inner = ts::TVMValue { 78 | $field: *arg as $cast, 79 | }; 80 | Self::new($kind, inner) 81 | } 82 | } 83 | }; 84 | } 85 | 86 | impl_prim_val!(usize, ValueKind::Int, v_int64, i64); 87 | impl_prim_val!(i64, ValueKind::Int, v_int64, i64); 88 | impl_prim_val!(i32, ValueKind::Int, v_int64, i64); 89 | impl_prim_val!(i16, ValueKind::Int, v_int64, i64); 90 | impl_prim_val!(i8, ValueKind::Int, v_int64, i64); 91 | impl_prim_val!(u64, ValueKind::Int, v_int64, i64); 92 | impl_prim_val!(u32, ValueKind::Int, v_int64, i64); 93 | impl_prim_val!(u16, ValueKind::Int, v_int64, i64); 94 | impl_prim_val!(u8, ValueKind::Int, v_int64, i64); 95 | impl_prim_val!(bool, ValueKind::Int, v_int64, i64); 96 | impl_prim_val!(f64, ValueKind::Float, v_float64, f64); 97 | impl_prim_val!(f32, ValueKind::Float, v_float64, f64); 98 | 99 | impl<'a> From<&'a str> for TVMValue { 100 | fn from(arg: &str) -> TVMValue { 101 | let arg = CString::new(arg).unwrap(); 102 | let inner = ts::TVMValue { 103 | v_str: arg.as_ptr() as *const c_char, 104 | }; 105 | mem::forget(arg); 106 | Self::new(ValueKind::Str, inner) 107 | } 108 | } 109 | 110 | impl<'a> From<&'a String> for TVMValue { 111 | fn from(arg: &String) -> TVMValue { 112 | let arg = CString::new(arg.as_bytes()).unwrap(); 113 | let 
inner = ts::TVMValue { 114 | v_str: arg.as_ptr() as *const c_char, 115 | }; 116 | mem::forget(arg); 117 | Self::new(ValueKind::Str, inner) 118 | } 119 | } 120 | 121 | impl<'a> From<&'a CString> for TVMValue { 122 | fn from(arg: &CString) -> TVMValue { 123 | let arg = arg.to_owned(); 124 | let inner = ts::TVMValue { 125 | v_str: arg.as_ptr() as *const c_char, 126 | }; 127 | mem::forget(arg); 128 | Self::new(ValueKind::Str, inner) 129 | } 130 | } 131 | 132 | impl<'a> From<&'a [u8]> for TVMValue { 133 | fn from(arg: &[u8]) -> TVMValue { 134 | let arg = arg.to_owned(); 135 | let inner = ts::TVMValue { 136 | v_handle: &arg as *const _ as *mut c_void, 137 | }; 138 | mem::forget(arg); 139 | Self::new(ValueKind::Handle, inner) 140 | } 141 | } 142 | 143 | macro_rules! impl_tvm_val_from_handle { 144 | ($($ty:ty),+) => { 145 | $( 146 | impl<'a> From<&'a $ty> for TVMValue { 147 | fn from(arg: &$ty) -> Self { 148 | let inner = ts::TVMValue { 149 | v_handle: arg.handle as *mut _ as *mut c_void, 150 | }; 151 | Self::new(ValueKind::Handle, inner) 152 | } 153 | } 154 | )+ 155 | } 156 | } 157 | 158 | impl_tvm_val_from_handle!(Module, Function, NDArray); 159 | 160 | impl<'a> From<&'a TVMType> for TVMValue { 161 | fn from(ty: &TVMType) -> Self { 162 | let inner = ts::TVMValue { v_type: ty.inner }; 163 | Self::new(ValueKind::Type, inner) 164 | } 165 | } 166 | 167 | impl<'a> From<&'a TVMContext> for TVMValue { 168 | fn from(ctx: &TVMContext) -> Self { 169 | let inner = ts::TVMValue { 170 | v_ctx: ctx.clone().into(), 171 | }; 172 | Self::new(ValueKind::Context, inner) 173 | } 174 | } 175 | 176 | impl<'a> From<&'a TVMDeviceType> for TVMValue { 177 | fn from(dev: &TVMDeviceType) -> Self { 178 | let inner = ts::TVMValue { 179 | v_int64: dev.0 as i64, 180 | }; 181 | Self::new(ValueKind::Int, inner) 182 | } 183 | } 184 | 185 | impl<'a> From<&'a TVMByteArray> for TVMValue { 186 | fn from(barr: &TVMByteArray) -> Self { 187 | let inner = ts::TVMValue { 188 | v_handle: &barr.inner as *const 
ts::TVMByteArray as *mut c_void, 189 | }; 190 | Self::new(ValueKind::Bytes, inner) 191 | } 192 | } 193 | 194 | impl Debug for TVMValue { 195 | fn fmt(&self, f: &mut Formatter) -> fmt::Result { 196 | unsafe { 197 | write!( 198 | f, 199 | "TVMValue: [v_int64: {:?}], [v_float64: {:?}], [v_handle: {:?}],\ 200 | [v_str: {:?}]", 201 | self.inner.v_int64, self.inner.v_float64, self.inner.v_handle, self.inner.v_str 202 | ) 203 | } 204 | } 205 | } 206 | 207 | impl Deref for TVMValue { 208 | type Target = ts::TVMValue; 209 | fn deref(&self) -> &Self::Target { 210 | &self.inner 211 | } 212 | } 213 | 214 | impl DerefMut for TVMValue { 215 | fn deref_mut(&mut self) -> &mut Self::Target { 216 | &mut self.inner 217 | } 218 | } 219 | 220 | /// This type is needed for passing the supported values as arguments to [`call_packed!`] 221 | /// or [`function::Builder`]. Checkout the methods for conversions. 222 | /// 223 | /// ## Example 224 | /// 225 | /// ``` 226 | /// let ctx = TVMContext::gpu(0); 227 | /// let arg = TVMArgValue::from(&ctx); 228 | /// assert_eq!(arg.to_ctx(), ctx); 229 | /// ``` 230 | /// 231 | /// [`function::Builder`]:../function/struct.Builder.html 232 | #[derive(Debug, Clone)] 233 | pub struct TVMArgValue<'a> { 234 | pub value: TVMValue, 235 | pub type_code: TypeCode, 236 | _lifetime: PhantomData<&'a ()>, 237 | } 238 | 239 | impl<'a> TVMArgValue<'a> { 240 | pub fn new(value: TVMValue, type_code: TypeCode) -> Self { 241 | TVMArgValue { 242 | value: value, 243 | type_code: type_code, 244 | _lifetime: PhantomData, 245 | } 246 | } 247 | } 248 | 249 | /// Main way to create a TVMArgValue from suported Rust values. 250 | impl<'b, 'a: 'b, T: 'b + ?Sized> From<&'b T> for TVMArgValue<'a> 251 | where 252 | TVMValue: From<&'b T>, 253 | TypeCode: From<&'b T>, 254 | { 255 | fn from(arg: &'b T) -> Self { 256 | TVMArgValue::new(TVMValue::from(arg), TypeCode::from(arg)) 257 | } 258 | } 259 | 260 | /// TVMRetValue is an owned TVMArgValue. 
261 | /// 262 | /// ## Example 263 | /// 264 | /// ``` 265 | /// let ctx = TVMContext::gpu(0); 266 | /// let arg = TVMRetValue::from(&ctx); 267 | /// assert_eq!(arg.to_ctx(), ctx); 268 | /// ``` 269 | // TODO: WIP for unification of runtime and frontend 270 | #[derive(Debug)] 271 | pub struct TVMRetValue { 272 | pub value: TVMValue, 273 | box_value: Box, 274 | pub type_code: TypeCode, 275 | } 276 | 277 | impl TVMRetValue { 278 | pub(crate) fn new(value: TVMValue, type_code: TypeCode) -> Self { 279 | Self { 280 | value, 281 | box_value: box (), // starting the unification 282 | type_code, 283 | } 284 | } 285 | } 286 | 287 | impl Clone for TVMRetValue { 288 | fn clone(&self) -> TVMRetValue { 289 | TVMRetValue { 290 | value: self.value.clone(), 291 | box_value: box (), 292 | type_code: self.type_code, 293 | } 294 | } 295 | } 296 | 297 | impl<'b, T: 'b + ?Sized> From<&'b T> for TVMRetValue 298 | where 299 | TVMValue: From<&'b T>, 300 | TypeCode: From<&'b T>, 301 | { 302 | fn from(arg: &'b T) -> Self { 303 | TVMRetValue::new(TVMValue::from(arg), TypeCode::from(arg)) 304 | } 305 | } 306 | 307 | macro_rules! 
impl_to_methods { 308 | ($ty:ty) => { 309 | pub fn to_int(&self) -> i64 { 310 | if self.type_code != TypeCode::kDLInt && self.type_code != TypeCode::kNull { 311 | panic!("Requires i64 or NULL, but found {:?}", self.type_code); 312 | } 313 | 314 | unsafe { self.value.inner.v_int64 } 315 | } 316 | 317 | pub fn to_float(&self) -> f64 { 318 | assert_eq!( 319 | self.type_code, 320 | TypeCode::kDLFloat, 321 | "Requires f64, but found {:?}", 322 | self.type_code 323 | ); 324 | unsafe { self.value.inner.v_float64 } 325 | } 326 | 327 | pub fn to_bytearray(&self) -> TVMByteArray { 328 | assert_eq!( 329 | self.type_code, 330 | TypeCode::kBytes, 331 | "Requires byte array, but found {:?}", 332 | self.type_code 333 | ); 334 | unsafe { 335 | let barr_ptr = 336 | mem::transmute::<*mut c_void, *mut ts::TVMByteArray>(self.value.inner.v_handle); 337 | TVMByteArray::new(*barr_ptr) 338 | } 339 | } 340 | 341 | pub fn to_module(&self) -> Module { 342 | assert_eq!( 343 | self.type_code, 344 | TypeCode::kModuleHandle, 345 | "Requires module handle, but found {:?}", 346 | self.type_code 347 | ); 348 | let module_handle = unsafe { self.value.inner.v_handle }; 349 | Module::new(module_handle, false, None) 350 | } 351 | 352 | pub fn to_string(&self) -> String { 353 | assert_eq!( 354 | self.type_code, 355 | TypeCode::kStr, 356 | "Requires string, but found {:?}", 357 | self.type_code 358 | ); 359 | let ret_str = unsafe { 360 | match CStr::from_ptr(self.value.inner.v_str).to_str() { 361 | Ok(s) => s, 362 | Err(_) => "Invalid UTF-8 message", 363 | } 364 | }; 365 | ret_str.to_string() 366 | } 367 | 368 | pub fn to_ndarray(&self) -> NDArray { 369 | assert_eq!( 370 | self.type_code, 371 | TypeCode::kArrayHandle, 372 | "Requires Array handle, but found {:?}", 373 | self.type_code 374 | ); 375 | let handle = unsafe { self.value.inner.v_handle }; 376 | let arr_handle = unsafe { mem::transmute::<*mut c_void, ts::TVMArrayHandle>(handle) }; 377 | NDArray::new(arr_handle, true) 378 | } 379 | 380 | pub fn 
to_type(&self) -> TVMType { 381 | assert_eq!( 382 | self.type_code, 383 | TypeCode::kTVMType, 384 | "Requires TVMType, but found {:?}", 385 | self.type_code 386 | ); 387 | let ty = unsafe { self.value.inner.v_type }; 388 | TVMType::from(ty) 389 | } 390 | 391 | pub fn to_ctx(&self) -> TVMContext { 392 | assert_eq!( 393 | self.type_code, 394 | TypeCode::kTVMContext, 395 | "Requires TVMContext, but found {:?}", 396 | self.type_code 397 | ); 398 | let ctx = unsafe { self.value.inner.v_ctx }; 399 | TVMContext::from(ctx) 400 | } 401 | }; 402 | 403 | (refnc $ty:ty) => { 404 | impl<'a> $ty { 405 | impl_to_methods!($ty); 406 | } 407 | }; 408 | 409 | (owned $ty:ty) => { 410 | impl $ty { 411 | impl_to_methods!($ty); 412 | } 413 | } 414 | } 415 | 416 | impl_to_methods!(refnc TVMArgValue<'a>); 417 | impl_to_methods!(owned TVMRetValue); 418 | 419 | #[cfg(test)] 420 | mod tests { 421 | use super::*; 422 | 423 | #[test] 424 | fn numeric() { 425 | macro_rules! arg_ret_tests { 426 | ($v:expr; ints $($ty:ty),+) => {{ 427 | $( 428 | let v = $v as $ty; 429 | let a = TVMArgValue::from(&v); 430 | assert_eq!(a.to_int() as $ty, v); 431 | let b = TVMRetValue::from(&v); 432 | assert_eq!(b.to_int() as $ty, v); 433 | )+ 434 | }}; 435 | ($v:expr; floats $($ty:ty),+) => {{ 436 | $( 437 | let v = $v as $ty; 438 | let a = TVMArgValue::from(&v); 439 | assert_eq!(a.to_float() as $ty, v); 440 | let b = TVMRetValue::from(&v); 441 | assert_eq!(b.to_float() as $ty, v); 442 | )+ 443 | }}; 444 | } 445 | 446 | arg_ret_tests!(42; ints i8, i16, i32, i64); 447 | arg_ret_tests!(42; floats f32, f64); 448 | } 449 | 450 | #[test] 451 | fn bytearray() { 452 | let v = CString::new(b"hello".to_vec()).unwrap(); 453 | let v = v.into_bytes(); 454 | let tvm = TVMRetValue::from(&v[..]); 455 | assert_eq!( 456 | tvm.to_bytearray().data(), 457 | v.iter().map(|e| *e as i8).collect::>() 458 | ); 459 | let w = vec![1u8, 2, 3, 4, 5]; 460 | let tvm = TVMRetValue::from(&w[..]); 461 | assert_eq!( 462 | tvm.to_bytearray().data(), 
463 | w.iter().map(|e| *e as i8).collect::>() 464 | ); 465 | } 466 | 467 | #[test] 468 | fn string() { 469 | let s = "hello"; 470 | let tvm_arg = TVMRetValue::from(s); 471 | assert_eq!(tvm_arg.to_string(), s.to_string()); 472 | let s = "hello".to_string(); 473 | let tvm_arg = TVMRetValue::from(&s); 474 | assert_eq!(tvm_arg.to_string(), s); 475 | } 476 | 477 | #[test] 478 | fn ty() { 479 | let t = TVMType::from("int"); 480 | let tvm = TVMRetValue::from(&t); 481 | assert_eq!(tvm.to_type(), t); 482 | } 483 | 484 | #[test] 485 | fn ctx() { 486 | let c = TVMContext::from("gpu"); 487 | let tvm = TVMRetValue::from(&c); 488 | assert_eq!(tvm.to_ctx(), c); 489 | } 490 | } 491 | -------------------------------------------------------------------------------- /tests/basics/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | **/*.rs.bk 3 | Cargo.lock 4 | *.o 5 | *.so 6 | *.ptx 7 | *.json 8 | -------------------------------------------------------------------------------- /tests/basics/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "basics" 3 | version = "0.0.0" 4 | authors = ["Ehsan M.Kermani "] 5 | license = "Apache-2.0" 6 | build = "build.rs" 7 | 8 | [dependencies] 9 | ndarray = "0.12.1" 10 | tvm-frontend = { path = "../../" } 11 | 12 | [features] 13 | cpu = [] 14 | gpu = [] 15 | -------------------------------------------------------------------------------- /tests/basics/build.rs: -------------------------------------------------------------------------------- 1 | use std::process::Command; 2 | 3 | fn main() { 4 | #[cfg(feature = "cpu")] 5 | let script_path = concat!(env!("CARGO_MANIFEST_DIR"), "/src/tvm_add_cpu.py"); 6 | #[cfg(feature = "gpu")] 7 | let script_path = concat!(env!("CARGO_MANIFEST_DIR"), "/src/tvm_add_gpu.py"); 8 | let output = Command::new("python") 9 | .args(&[&script_path, env!("CARGO_MANIFEST_DIR")]) 10 | .output() 11 | 
.expect("Failed to execute command"); 12 | if output.stderr.len() > 0 { 13 | panic!(String::from_utf8(output.stderr).unwrap()); 14 | } 15 | println!( 16 | "cargo:rustc-link-search=native={}", 17 | env!("CARGO_MANIFEST_DIR") 18 | ); 19 | } 20 | -------------------------------------------------------------------------------- /tests/basics/src/main.rs: -------------------------------------------------------------------------------- 1 | extern crate ndarray as rust_ndarray; 2 | extern crate tvm_frontend as tvm; 3 | 4 | use std::path::Path; 5 | 6 | use tvm::*; 7 | 8 | fn main() { 9 | println!("start integration test"); 10 | let shape = &mut [2]; 11 | let mut data = vec![3f32, 4.0]; 12 | 13 | if cfg!(feature = "cpu") { 14 | println!("cpu test"); 15 | let mut arr = empty(shape, TVMContext::cpu(0), TVMType::from("float")); 16 | 17 | arr.copy_from_buffer(data.as_mut_slice()); 18 | 19 | let mut ret = empty(shape, TVMContext::cpu(0), TVMType::from("float")); 20 | let path = Path::new("add_cpu.so"); 21 | let mut fadd = Module::load(&path).unwrap(); 22 | assert!(fadd.enabled("cpu")); 23 | fadd.entry_func(); 24 | function::Builder::from(&mut fadd) 25 | .arg(&arr) 26 | .arg(&arr) 27 | .set_output(&mut ret) 28 | .invoke() 29 | .unwrap(); 30 | 31 | assert_eq!(ret.to_vec::().unwrap(), vec![6f32, 8.0]); 32 | println!("success!") 33 | } 34 | 35 | if cfg!(feature = "gpu") { 36 | println!("gpu test"); 37 | let mut arr = empty(shape, TVMContext::gpu(0), TVMType::from("float")); 38 | 39 | arr.copy_from_buffer(data.as_mut_slice()); 40 | 41 | let mut ret = empty(shape, TVMContext::gpu(0), TVMType::from("float")); 42 | let path = Path::new("add_gpu.so"); 43 | let ptx = Path::new("add_gpu.ptx"); 44 | let mut fadd = Module::load(path).unwrap(); 45 | let fadd_dep = Module::load(ptx).unwrap(); 46 | assert!(fadd.enabled("gpu")); 47 | fadd.import_module(fadd_dep); 48 | fadd.entry_func(); 49 | function::Builder::from(&mut fadd) 50 | .arg(&arr) 51 | .arg(&arr) 52 | .set_output(&mut ret) 53 | 
.invoke() 54 | .unwrap(); 55 | 56 | assert_eq!(ret.to_vec::().unwrap(), vec![6f32, 8.0]); 57 | println!("success!") 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /tests/basics/src/tvm_add_cpu.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tvm 4 | from tvm.contrib import cc 5 | 6 | 7 | def test_add(target_dir): 8 | n = tvm.var("n") 9 | A = tvm.placeholder((n,), name='A') 10 | B = tvm.placeholder((n,), name='B') 11 | C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C") 12 | s = tvm.create_schedule(C.op) 13 | fadd = tvm.build(s, [A, B, C], "llvm", target_host="llvm", name="myadd") 14 | 15 | fadd.save(os.path.join(target_dir, "add_cpu.o")) 16 | cc.create_shared(os.path.join(target_dir, "add_cpu.so"), 17 | [os.path.join(target_dir, "add_cpu.o")]) 18 | 19 | 20 | if __name__ == "__main__": 21 | import sys 22 | if len(sys.argv) != 2: 23 | sys.exit(-1) 24 | test_add(sys.argv[1]) 25 | -------------------------------------------------------------------------------- /tests/basics/src/tvm_add_gpu.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tvm 4 | from tvm.contrib import cc 5 | 6 | 7 | def test_add(target_dir): 8 | if not tvm.module.enabled("cuda"): 9 | print(f"skip {__file__} because cuda is not enabled...") 10 | return 11 | n = tvm.var("n") 12 | A = tvm.placeholder((n,), name='A') 13 | B = tvm.placeholder((n,), name='B') 14 | C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C") 15 | 16 | s = tvm.create_schedule(C.op) 17 | 18 | bx, tx = s[C].split(C.op.axis[0], factor=64) 19 | s[C].bind(bx, tvm.thread_axis("blockIdx.x")) 20 | s[C].bind(tx, tvm.thread_axis("threadIdx.x")) 21 | fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd") 22 | 23 | fadd_cuda.save(os.path.join(target_dir, "add_gpu.o")) 24 | fadd_cuda.imported_modules[0].save(os.path.join(target_dir, 
"add_gpu.ptx")) 25 | cc.create_shared(os.path.join(target_dir, "add_gpu.so"), 26 | [os.path.join(target_dir, "add_gpu.o")]) 27 | 28 | 29 | if __name__ == "__main__": 30 | import sys 31 | if len(sys.argv) != 2: 32 | sys.exit(-1) 33 | test_add(sys.argv[1]) 34 | -------------------------------------------------------------------------------- /tests/callback/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "callback" 3 | version = "0.0.0" 4 | authors = ["Ehsan M.Kermani "] 5 | 6 | [dependencies] 7 | ndarray = "0.12.1" 8 | tvm-frontend = { path = "../../" } 9 | -------------------------------------------------------------------------------- /tests/callback/src/bin/array.rs: -------------------------------------------------------------------------------- 1 | #![feature(extern_crate_item_prelude, try_from)] 2 | #![allow(unused_imports)] 3 | 4 | extern crate ndarray as rust_ndarray; 5 | #[macro_use] 6 | extern crate tvm_frontend as tvm; 7 | 8 | use std::convert::TryFrom; 9 | 10 | use rust_ndarray::ArrayD; 11 | 12 | use tvm::*; 13 | 14 | fn main() { 15 | register_global_func! 
{ 16 | fn sum(args: &[TVMArgValue]) -> Result { 17 | let mut ret = 0f32; 18 | let shape = &mut [2]; 19 | for arg in args.iter() { 20 | let e = empty(shape, TVMContext::cpu(0), TVMType::from("float")); 21 | let arr = arg.to_ndarray().copy_to_ndarray(e).unwrap(); 22 | let rnd: ArrayD = ArrayD::try_from(&arr).unwrap(); 23 | ret += rnd.scalar_sum(); 24 | } 25 | let ret_val = TVMRetValue::from(&ret); 26 | Ok(ret_val) 27 | } 28 | } 29 | 30 | let shape = &mut [2]; 31 | let mut data = vec![3f32, 4.0]; 32 | let mut arr = empty(shape, TVMContext::cpu(0), TVMType::from("float")); 33 | arr.copy_from_buffer(data.as_mut_slice()); 34 | 35 | let mut registered = function::Builder::default(); 36 | registered.get_function("sum", true).arg(&arr).arg(&arr); 37 | assert_eq!(registered.invoke().unwrap().to_float(), 14f64); 38 | } 39 | -------------------------------------------------------------------------------- /tests/callback/src/bin/error.rs: -------------------------------------------------------------------------------- 1 | #![feature(extern_crate_item_prelude, panic_info_message)] 2 | #![allow(unused_imports, unused_must_use)] 3 | 4 | use std::panic; 5 | 6 | extern crate tvm_frontend as tvm; 7 | 8 | use tvm::*; 9 | 10 | fn main() { 11 | register_global_func! 
{ 12 | fn error(_args: &[TVMArgValue]) -> Result { 13 | Err(ErrorKind::TypeMismatch( 14 | format!("{}", "i64".to_string()), 15 | format!("{}", "f64".to_string()), 16 | ).into()) 17 | } 18 | } 19 | 20 | let mut registered = function::Builder::default(); 21 | registered.get_function("error", true); 22 | assert!(registered.func.is_some()); 23 | registered.args(&[10, 20]); 24 | 25 | println!("expected error message is:"); 26 | panic::set_hook(Box::new(|panic_info| { 27 | if let Some(msg) = panic_info.message() { 28 | println!("{:?}", msg); 29 | } 30 | if let Some(location) = panic_info.location() { 31 | println!( 32 | "panic occurred in file '{}' at line {}", 33 | location.file(), 34 | location.line() 35 | ); 36 | } else { 37 | println!("panic occurred but can't get location information"); 38 | } 39 | })); 40 | 41 | let _result = registered.invoke(); 42 | } 43 | -------------------------------------------------------------------------------- /tests/callback/src/bin/float.rs: -------------------------------------------------------------------------------- 1 | #![feature(extern_crate_item_prelude)] 2 | #![allow(unused_imports)] 3 | 4 | #[macro_use] 5 | extern crate tvm_frontend as tvm; 6 | 7 | use tvm::*; 8 | 9 | fn main() { 10 | register_global_func! 
{ 11 | fn sum(args: &[TVMArgValue]) -> Result { 12 | let mut ret = 0f64; 13 | for arg in args.iter() { 14 | ret += arg.to_float(); 15 | } 16 | let ret_val = TVMRetValue::from(&ret); 17 | Ok(ret_val) 18 | } 19 | } 20 | 21 | let mut registered = function::Builder::default(); 22 | registered.get_function("sum", true); 23 | assert!(registered.func.is_some()); 24 | registered.args(&[10f64, 20f64, 30f64]); 25 | assert_eq!(registered.invoke().unwrap().to_float(), 60f64); 26 | } 27 | -------------------------------------------------------------------------------- /tests/callback/src/bin/int.rs: -------------------------------------------------------------------------------- 1 | #![feature(extern_crate_item_prelude)] 2 | #![allow(unused_imports)] 3 | 4 | extern crate tvm_frontend as tvm; 5 | 6 | use tvm::*; 7 | 8 | fn main() { 9 | fn sum(args: &[TVMArgValue]) -> Result { 10 | let mut ret = 0; 11 | for arg in args.iter() { 12 | ret += arg.to_int(); 13 | } 14 | let ret_val = TVMRetValue::from(&ret); 15 | Ok(ret_val) 16 | } 17 | tvm::function::register(sum, "mysum".to_owned(), false).unwrap(); 18 | 19 | let mut registered = function::Builder::default(); 20 | registered.get_function("mysum", true); 21 | assert!(registered.func.is_some()); 22 | registered.args(&[10, 20, 30]); 23 | assert_eq!(registered.invoke().unwrap().to_int(), 60); 24 | } 25 | -------------------------------------------------------------------------------- /tests/callback/src/bin/string.rs: -------------------------------------------------------------------------------- 1 | #![feature(extern_crate_item_prelude, try_from)] 2 | #![allow(unused_imports)] 3 | 4 | #[macro_use] 5 | extern crate tvm_frontend as tvm; 6 | 7 | use tvm::*; 8 | 9 | fn main() { 10 | register_global_func! 
{ 11 | fn concate_str(args: &[TVMArgValue]) -> Result { 12 | let mut ret = "".to_owned(); 13 | for arg in args.iter() { 14 | ret += arg.to_string().as_str(); 15 | } 16 | let ret_val = TVMRetValue::from(&ret); 17 | Ok(ret_val) 18 | } 19 | } 20 | let mut registered = function::Builder::default(); 21 | registered.get_function("concate_str", true); 22 | assert!(registered.func.is_some()); 23 | let a = "a".to_string(); 24 | let b = "b".to_string(); 25 | let c = "c".to_string(); 26 | registered.arg(&a).arg(&b).arg(&c); 27 | assert_eq!(registered.invoke().unwrap().to_string(), "abc".to_owned()); 28 | } 29 | -------------------------------------------------------------------------------- /tvm-sys/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "tvm-sys" 3 | version = "0.1.0" 4 | authors = ["Ehsan M.Kermani "] 5 | license = "Apache-2.0" 6 | description = "Raw C API" 7 | 8 | [build-dependencies] 9 | bindgen = "0.37.4" 10 | -------------------------------------------------------------------------------- /tvm-sys/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /tvm-sys/build.rs: -------------------------------------------------------------------------------- 1 | extern crate bindgen; 2 | 3 | use std::{env, error::Error, path::PathBuf, process, result::Result}; 4 | 5 | const TVM_RUNTIME: &'static str = "tvm_runtime"; 6 | 7 | fn main() { 8 | match run() { 9 | Ok(_) => (), 10 | Err(err) => { 11 | eprintln!("error occured during the build: {:?}", err); 12 | process::exit(1) 13 | } 14 | } 15 | } 16 | 17 | fn run() -> Result<(), Box> { 18 | println!("cargo:rustc-link-lib=dylib={}", TVM_RUNTIME); 19 | let lib = format!("lib{}", TVM_RUNTIME); 20 | println!("cargo:rustc-link-search=native={}", lib); 21 | let tvm_home = env::var("TVM_HOME").expect("TVM_HOME not found!"); 22 | let bindings = bindgen::Builder::default() 23 | .header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home)) 24 | .clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home)) 25 | .blacklist_type("max_align_t") // https://github.com/rust-lang-nursery/rust-bindgen/issues/550 26 | .layout_tests(false) 27 | .derive_partialeq(true) 28 | 
.derive_eq(true) 29 | .generate() 30 | .expect("unable to generate bindings"); 31 | 32 | bindings 33 | .write_to_file(PathBuf::from("src/bindgen.rs")) 34 | .expect("can not write the bindings!"); 35 | 36 | Ok(()) 37 | } 38 | -------------------------------------------------------------------------------- /tvm-sys/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow( 2 | non_camel_case_types, 3 | non_snake_case, 4 | non_upper_case_globals, 5 | dead_code, 6 | improper_ctypes 7 | )] 8 | 9 | include!("bindgen.rs"); 10 | --------------------------------------------------------------------------------