├── .gitignore
├── Cargo.toml
├── README.md
├── build-res
│   ├── docsrs_res.zip
│   └── tflitec_with_xnnpack_BUILD.bazel
├── build.rs
├── src
│   ├── error.rs
│   ├── interpreter.rs
│   ├── lib.rs
│   ├── model.rs
│   └── tensor.rs
└── tests
    └── add.bin

/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | .DS_Store
3 | /target
4 | Cargo.lock
5 | 
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "tflitec"
3 | version = "0.6.0"
4 | authors = ["ebraraktas "]
5 | edition = "2018"
6 | license = "MIT"
7 | description = "A safe Rust wrapper of TensorFlow Lite C API supporting x86_64 and ARM (iOS, Android)"
8 | repository = "https://github.com/VoysAI/tflitec-rs"
9 | keywords = ["bindings", "tensorflow", "machine-learning", "ffi", "cross-platform"]
10 | categories = ["api-bindings", "science"]
11 | 
12 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
13 | [lib]
14 | name = "tflitec"
15 | 
16 | [dependencies]
17 | 
18 | [build-dependencies]
19 | bindgen = "0.65"
20 | fs_extra = "1.3"
21 | curl = "0.4"
22 | 
23 | [features]
24 | xnnpack = []
25 | xnnpack_qu8 = ["xnnpack"]
26 | xnnpack_qs8 = ["xnnpack"]
27 | 
28 | # docs.rs-specific configuration
29 | [package.metadata.docs.rs]
30 | # document all features
31 | all-features = true
32 | # defines the configuration attribute `docsrs`
33 | rustdoc-args = ["--cfg", "docsrs"]
34 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![Crates.io](https://img.shields.io/crates/v/tflitec.svg)](https://crates.io/crates/tflitec)
2 | [![Docs.rs](https://img.shields.io/docsrs/tflitec)](https://docs.rs/tflitec/latest/tflitec/)
3 | 
4 | This crate is a safe Rust wrapper of [TensorFlow Lite C API].
5 | Its API is very similar to that of [TensorFlow Lite Swift API].
6 | 
7 | # Supported Targets
8 | 
9 | The targets below are tested; however, others may work, too.
10 | * iOS: `aarch64-apple-ios` and `x86_64-apple-ios`
11 | * macOS: `x86_64-apple-darwin`
12 | * Linux: `x86_64-unknown-linux-gnu`
13 | * Android: `aarch64-linux-android` and `armv7-linux-androideabi`
14 | * Windows ([see details](#windows))
15 | 
16 | See the [compilation](#compilation) section for build instructions for your target. Please
17 | read the [Optimized Build](#optimized-build) section carefully.
18 | 
19 | # Features
20 | 
21 | * `xnnpack` - Compiles XNNPACK and allows you to use the XNNPACK delegate. See details of XNNPACK
22 |   [here][XNNPACK_blog].
23 | * `xnnpack_qs8` - Compiles XNNPACK with additional build flags to accelerate inference of
24 |   operators with symmetric quantization. See details in [this blog post][XNNPACK_quant_blog].
25 |   Implies `xnnpack`.
26 | * `xnnpack_qu8` - Similar to `xnnpack_qs8`, but accelerates a few operators with
27 |   asymmetric quantization. Implies `xnnpack`.
28 | 
29 | *Note:* `xnnpack` is already enabled for iOS, but `xnnpack_qs8` and `xnnpack_qu8`
30 | should be enabled manually.
31 | 
32 | # Examples
33 | 
34 | The example below shows running inference on a TensorFlow Lite model.
35 | 
36 | ```rust
37 | use tflitec::interpreter::{Interpreter, Options};
38 | use tflitec::tensor;
39 | use tflitec::model::Model;
40 | 
41 | // Create interpreter options
42 | let mut options = Options::default();
43 | options.thread_count = 1;
44 | 
45 | // Load the example model, which outputs y = 3 * x
46 | let model = Model::new("tests/add.bin")?;
47 | // Or initialize with model bytes if it is not available as a file
48 | // let model_data = std::fs::read("tests/add.bin")?;
49 | // let model = Model::from_bytes(&model_data)?;
50 | 
51 | // Create the interpreter
52 | let interpreter = Interpreter::new(&model, Some(options))?;
53 | // Resize the input
54 | let input_shape = tensor::Shape::new(vec![10, 8, 8, 3]);
55 | let input_element_count = input_shape.dimensions().iter().copied().reduce(std::ops::Mul::mul).unwrap();
56 | interpreter.resize_input(0, input_shape)?;
57 | // Allocate tensors if you just created the Interpreter or resized its inputs
58 | interpreter.allocate_tensors()?;
59 | 
60 | // Create dummy input
61 | let data = (0..input_element_count).map(|x| x as f32).collect::<Vec<f32>>();
62 | 
63 | let input_tensor = interpreter.input(0)?;
64 | assert_eq!(input_tensor.data_type(), tensor::DataType::Float32);
65 | 
66 | // Copy input into the buffer of the first tensor (index 0).
67 | // You have 2 options:
68 | // Set data using the Tensor handle if you already have it
69 | assert!(input_tensor.set_data(&data[..]).is_ok());
70 | // Or set data using the Interpreter:
71 | assert!(interpreter.copy(&data[..], 0).is_ok());
72 | 
73 | // Invoke the interpreter
74 | assert!(interpreter.invoke().is_ok());
75 | 
76 | // Get the output tensor
77 | let output_tensor = interpreter.output(0)?;
78 | 
79 | assert_eq!(output_tensor.shape().dimensions(), &vec![10, 8, 8, 3]);
80 | let output_vector = output_tensor.data::<f32>().to_vec();
81 | let expected: Vec<f32> = data.iter().map(|e| e * 3.0).collect();
82 | assert_eq!(expected, output_vector);
83 | # // The line below is needed for the doctest, please ignore it
84 | # Ok::<(), Box<dyn std::error::Error>>(())
85 | ```
86 | 
87 | # Prebuilt Library Support
88 | 
89 | As described in the [compilation section](#compilation), `libtensorflowlite_c` is built during compilation, and
90 | this step may take a few minutes. To allow reusing a prebuilt library, one can set the `TFLITEC_PREBUILT_PATH` or
91 | `TFLITEC_PREBUILT_PATH_<NORMALIZED_TARGET>` environment variables (the latter has precedence).
92 | `NORMALIZED_TARGET` is the target triple which is [converted to uppercase and underscores][triple_normalization],
93 | as in the cargo configuration environment variables. Below you can find example values for different `TARGET`s:
94 | 
95 | * `TFLITEC_PREBUILT_PATH_AARCH64_APPLE_IOS=/path/to/TensorFlowLiteC.framework`
96 | * `TFLITEC_PREBUILT_PATH_ARMV7_LINUX_ANDROIDEABI=/path/to/libtensorflowlite_c.so`
97 | * `TFLITEC_PREBUILT_PATH_X86_64_APPLE_DARWIN=/path/to/libtensorflowlite_c.dylib`
98 | * `TFLITEC_PREBUILT_PATH_X86_64_PC_WINDOWS_MSVC=/path/to/tensorflowlite_c.dll`. **Note that** the prebuilt `.dll`
99 |   file must have the corresponding `.lib` file under the same directory.
100 | 
101 | You can find these files under the [`OUT_DIR`][cargo documentation] after you compile the library for the first time;
102 | you can then copy them to a persistent path and set the environment variable.
103 | 
104 | ## XNNPACK support
105 | 
106 | You can activate the `xnnpack` features with a prebuilt library, too.
107 | However, that library must have been built with XNNPACK; otherwise, you will see a linking error.
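
For reference, a minimal sketch of enabling the delegate at runtime through `Options` is shown below. It assumes the crate (and the `libtensorflowlite_c` it links against) was built with the `xnnpack` feature, and it reuses the bundled `tests/add.bin` model only as a placeholder; the block is marked `ignore` because it does not compile without that feature.

```rust,ignore
use tflitec::interpreter::{Interpreter, Options};
use tflitec::model::Model;

// `is_xnnpack_enabled` is only available when the `xnnpack` feature is enabled.
let options = Options {
    thread_count: 2,
    is_xnnpack_enabled: true,
};
let model = Model::new("tests/add.bin")?;
let interpreter = Interpreter::new(&model, Some(options))?;
# Ok::<(), Box<dyn std::error::Error>>(())
```

A positive `thread_count` is also forwarded to the XNNPACK delegate options, so the delegate uses the same number of threads as the interpreter.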
108 | 
109 | ## Local Header Directory Support
110 | 
111 | Some TensorFlow header files are downloaded from GitHub during compilation, with or without a prebuilt binary.
112 | However, some users may have difficulty accessing GitHub. Hence, one can pass a local header directory with the
113 | `TFLITEC_HEADER_DIR` or `TFLITEC_HEADER_DIR_<NORMALIZED_TARGET>` environment variables
114 | (the latter has precedence). See the example command below:
115 | 
116 | ```shell
117 | TFLITEC_HEADER_DIR=/path/to/tensorflow_v2.9.1_headers cargo build --release
118 | # Structure of /path/to/tensorflow_v2.9.1_headers is given below:
119 | # tensorflow_v2.9.1_headers
120 | # └── tensorflow
121 | #     └── lite
122 | #         ├── c
123 | #         │   ├── c_api.h            # Required
124 | #         │   ├── c_api_types.h      # Required
125 | #         │   └── common.h           # Required if xnnpack enabled
126 | #         └── delegates
127 | #             └── xnnpack
128 | #                 └── xnnpack_delegate.h # Required if xnnpack enabled
129 | ```
130 | 
131 | # Linking
132 | 
133 | This library builds the `libtensorflowlite_c` **dynamic library** and links against it. This is not an issue
134 | if you build and run a binary target with `cargo run`. However, if you run your binary directly, you **must**
135 | have the `libtensorflowlite_c` dynamic library in your library search path. You can either copy the built library under
136 | `target/{release,debug}/build/tflitec-*/out` to one of the system library search paths
137 | (such as `/usr/lib` or `/usr/local/lib`) or add that directory (`out`) to the search path.
138 | 
139 | Similarly, if you distribute a prebuilt library that depends on this crate, you must distribute `libtensorflowlite_c`, too,
140 | or document this requirement in your library to instruct your users.
141 | 
142 | # Compilation
143 | 
144 | The current version of the crate builds tag `v2.9.1` of the [tensorflow project].
145 | The compiled dynamic library or Framework will be available under the `OUT_DIR`
146 | (see [cargo documentation]) of `tflitec`.
147 | You won't need it most of the time, because the crate output is linked appropriately.
148 | In addition, it may be worth reading the [prebuilt library support](#prebuilt-library-support) section
149 | to make your builds faster.
150 | For all environments and targets you will need:
151 | 
152 | * The `git` CLI to fetch [TensorFlow]
153 | * [Bazel] to build [TensorFlow]; it is recommended to use [bazelisk]
154 | * Python 3 to build [TensorFlow]
155 | 
156 | ## Optimized Build
157 | To build [TensorFlow] for your machine with native optimizations,
158 | or to pass other `--copt` flags to [Bazel], set the environment variable below:
159 | ```sh
160 | TFLITEC_BAZEL_COPTS="OPT1 OPT2 ..." # space-separated values will be passed as `--copt=OPTN` to bazel
161 | TFLITEC_BAZEL_COPTS="-march=native" # for a native optimized build
162 | # You can set target-specific opts by appending the normalized target to the variable name
163 | TFLITEC_BAZEL_COPTS_X86_64_APPLE_DARWIN="-march=native"
164 | ```
165 | ---
166 | Some OSs or targets may require additional steps.
167 | 
168 | ## Android
169 | * [Android NDK]
170 | * The following environment variables should be set appropriately
171 |   to build [TensorFlow] for Android:
172 |   * `ANDROID_NDK_HOME`
173 |   * `ANDROID_NDK_API_LEVEL`
174 |   * `ANDROID_SDK_HOME`
175 |   * `ANDROID_API_LEVEL`
176 |   * `ANDROID_BUILD_TOOLS_VERSION`
177 | * [Bindgen] needs extra arguments, so set the environment variable below:
178 | ```sh
179 | # Set the appropriate host tag and target name.
180 | # see https://developer.android.com/ndk/guides/other_build_systems 181 | HOST_TAG=darwin-x86_64 # as example 182 | TARGET_TRIPLE=arm-linux-androideabi # as example 183 | BINDGEN_EXTRA_CLANG_ARGS="\ 184 | -I${ANDROID_NDK_HOME}/sources/cxx-stl/llvm-libc++/include/ \ 185 | -I${ANDROID_NDK_HOME}/sysroot/usr/include/ \ 186 | -I${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/${HOST_TAG}/sysroot/usr/include/${TARGET_TRIPLE}/" 187 | ``` 188 | * (Recommended) [cargo-ndk] simplifies `cargo build` process. Recent version of the tool has `--bindgen` flag 189 | which sets `BINDGEN_EXTRA_CLANG_ARGS` variable appropriately. Hence, you can skip the step above. 190 | 191 | ## Windows 192 | 193 | Windows support is experimental. It is tested on Windows 10. You should follow instructions in 194 | the `Setup for Windows` section on [TensorFlow Build Instructions for Windows]. In other words, 195 | you should install following before build: 196 | * Python 3.8.x 64 bit (the instructions suggest 3.6.x but this package is tested with 3.8.x) 197 | * [Bazel] 198 | * [MSYS2] 199 | * Visual C++ Build Tools 2019 200 | 201 | Do not forget to add relevant paths to `%PATH%` environment variable by following the 202 | [TensorFlow Build Instructions for Windows] **carefully** (the only exception is the Python version). 203 | 204 | [TensorFlow]: https://www.tensorflow.org/ 205 | [Bazel]: https://bazel.build/ 206 | [bazelisk]: https://github.com/bazelbuild/bazelisk 207 | [Bindgen]: https://github.com/rust-lang/rust-bindgen 208 | [tensorflow project]: https://github.com/tensorflow/tensorflow 209 | [TensorFlow Lite Swift API]: https://www.tensorflow.org/lite/guide/ios 210 | [TensorFlow Lite C API]: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/c 211 | [XNNPACK_blog]: https://blog.tensorflow.org/2020/07/accelerating-tensorflow-lite-xnnpack-integration.html 212 | [XNNPACK_quant_blog]: https://blog.tensorflow.org/2021/09/faster-quantized-inference-with-xnnpack.html 213 | [Android NDK]: https://developer.android.com/ndk/guides 214 | [cargo documentation]: https://doc.rust-lang.org/cargo/reference/environment-variables.html#environment-variables-cargo-sets-for-crates 215 | [cargo-ndk]: https://github.com/bbqsrc/cargo-ndk 216 | [TensorFlow Build Instructions for Windows]: https://www.tensorflow.org/install/source_windows 217 | [MSYS2]: https://www.msys2.org/ 218 | [triple_normalization]: https://doc.rust-lang.org/cargo/reference/config.html#environment-variables -------------------------------------------------------------------------------- /build-res/docsrs_res.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebraraktas/tflitec-rs/d3f4cd8623436b0083f715c5439c96511f5c0c77/build-res/docsrs_res.zip -------------------------------------------------------------------------------- /build-res/tflitec_with_xnnpack_BUILD.bazel: -------------------------------------------------------------------------------- 1 | load( 2 | "//tensorflow/lite:build_def.bzl", 3 | "tflite_cc_shared_object", 4 | "tflite_copts", 5 | ) 6 | 7 | tflite_cc_shared_object( 8 | name = "tensorflowlite_c", 9 | linkopts = select({ 10 | "//tensorflow:ios": [ 11 | "-Wl,-exported_symbols_list,$(location //tensorflow/lite/c:exported_symbols.lds)", 12 | ], 13 | "//tensorflow:macos": [ 14 | "-Wl,-exported_symbols_list,$(location //tensorflow/lite/c:exported_symbols.lds)", 15 | ], 16 | "//tensorflow:windows": [], 17 | "//conditions:default": [ 18 | "-z defs", 19 | 
"-Wl,--version-script,$(location //tensorflow/lite/c:version_script.lds)", 20 | ], 21 | }), 22 | per_os_targets = True, 23 | deps = [ 24 | ":c_api_with_xnn_pack", 25 | "//tensorflow/lite/c:c_api_experimental", 26 | "//tensorflow/lite/c:exported_symbols.lds", 27 | "//tensorflow/lite/c:version_script.lds", 28 | ], 29 | ) 30 | 31 | cc_library( 32 | name = "c_api_with_xnn_pack", 33 | hdrs = ["//tensorflow/lite/c:c_api.h", 34 | "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate.h"], 35 | copts = tflite_copts(), 36 | deps = [ 37 | "//tensorflow/lite/c:c_api", 38 | "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate" 39 | ], 40 | alwayslink = 1, 41 | ) -------------------------------------------------------------------------------- /build.rs: -------------------------------------------------------------------------------- 1 | extern crate bindgen; 2 | 3 | use std::env; 4 | use std::fmt::Debug; 5 | use std::io::Write; 6 | use std::path::{Path, PathBuf}; 7 | use std::time::Instant; 8 | 9 | const TAG: &str = "v2.9.1"; 10 | const TF_GIT_URL: &str = "https://github.com/tensorflow/tensorflow.git"; 11 | const BAZEL_COPTS_ENV_VAR: &str = "TFLITEC_BAZEL_COPTS"; 12 | const PREBUILT_PATH_ENV_VAR: &str = "TFLITEC_PREBUILT_PATH"; 13 | const HEADER_DIR_ENV_VAR: &str = "TFLITEC_HEADER_DIR"; 14 | 15 | fn target_os() -> String { 16 | env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS") 17 | } 18 | 19 | fn dll_extension() -> &'static str { 20 | match target_os().as_str() { 21 | "macos" => "dylib", 22 | "windows" => "dll", 23 | _ => "so", 24 | } 25 | } 26 | 27 | fn dll_prefix() -> &'static str { 28 | match target_os().as_str() { 29 | "windows" => "", 30 | _ => "lib", 31 | } 32 | } 33 | 34 | fn copy_or_overwrite + Debug, Q: AsRef + Debug>(src: P, dest: Q) { 35 | let src_path: &Path = src.as_ref(); 36 | let dest_path: &Path = dest.as_ref(); 37 | if dest_path.exists() { 38 | if dest_path.is_file() { 39 | std::fs::remove_file(dest_path).expect("Cannot remove file"); 40 | } else { 41 | std::fs::remove_dir_all(dest_path).expect("Cannot remove directory"); 42 | } 43 | } 44 | if src_path.is_dir() { 45 | let options = fs_extra::dir::CopyOptions { 46 | copy_inside: true, 47 | ..fs_extra::dir::CopyOptions::new() 48 | }; 49 | fs_extra::dir::copy(src_path, dest_path, &options).unwrap_or_else(|e| { 50 | panic!( 51 | "Cannot copy directory from {:?} to {:?}. Error: {}", 52 | src, dest, e 53 | ) 54 | }); 55 | } else { 56 | std::fs::copy(src_path, dest_path).unwrap_or_else(|e| { 57 | panic!( 58 | "Cannot copy file from {:?} to {:?}. Error: {}", 59 | src, dest, e 60 | ) 61 | }); 62 | } 63 | } 64 | 65 | fn normalized_target() -> Option { 66 | env::var("TARGET") 67 | .ok() 68 | .map(|t| t.to_uppercase().replace('-', "_")) 69 | } 70 | 71 | /// Looks for the env var `var_${NORMALIZED_TARGET}`, and falls back to just `var` when 72 | /// it is not found. 73 | /// 74 | /// `NORMALIZED_TARGET` is the target triple which is converted to uppercase and underscores. 
75 | fn get_target_dependent_env_var(var: &str) -> Option { 76 | if let Some(target) = normalized_target() { 77 | if let Ok(v) = env::var(format!("{var}_{target}")) { 78 | return Some(v); 79 | } 80 | } 81 | env::var(var).ok() 82 | } 83 | 84 | fn test_python_bin(python_bin_path: &str) -> bool { 85 | println!("Testing Python at {}", python_bin_path); 86 | let success = std::process::Command::new(python_bin_path) 87 | .args(["-c", "import numpy, importlib.util"]) 88 | .status() 89 | .map(|s| s.success()) 90 | .unwrap_or_default(); 91 | if success { 92 | println!("Using Python at {}", python_bin_path); 93 | } 94 | success 95 | } 96 | 97 | fn get_python_bin_path() -> Option { 98 | if let Ok(val) = env::var("PYTHON_BIN_PATH") { 99 | if !test_python_bin(&val) { 100 | panic!("Given Python binary failed in test!") 101 | } 102 | Some(PathBuf::from(val)) 103 | } else { 104 | let bin = if target_os() == "windows" { 105 | "where" 106 | } else { 107 | "which" 108 | }; 109 | if let Ok(x) = std::process::Command::new(bin).arg("python3").output() { 110 | for path in String::from_utf8(x.stdout).unwrap().lines() { 111 | if test_python_bin(path) { 112 | return Some(PathBuf::from(path)); 113 | } 114 | println!("cargo:warning={:?} failed import test", path) 115 | } 116 | } 117 | if let Ok(x) = std::process::Command::new(bin).arg("python").output() { 118 | for path in String::from_utf8(x.stdout).unwrap().lines() { 119 | if test_python_bin(path) { 120 | return Some(PathBuf::from(path)); 121 | } 122 | println!("cargo:warning={:?} failed import test", path) 123 | } 124 | None 125 | } else { 126 | None 127 | } 128 | } 129 | } 130 | 131 | fn prepare_tensorflow_source(tf_src_path: &Path) { 132 | let complete_clone_hint_file = tf_src_path.join(".complete_clone"); 133 | if !complete_clone_hint_file.exists() { 134 | if tf_src_path.exists() { 135 | std::fs::remove_dir_all(tf_src_path).expect("Cannot clean tf_src_path"); 136 | } 137 | let mut git = std::process::Command::new("git"); 138 | git.arg("clone") 139 | .args(["--depth", "1"]) 140 | .arg("--shallow-submodules") 141 | .args(["--branch", TAG]) 142 | .arg("--single-branch") 143 | .arg(TF_GIT_URL) 144 | .arg(tf_src_path.to_str().unwrap()); 145 | println!("Git clone started"); 146 | let start = Instant::now(); 147 | if !git.status().expect("Cannot execute `git clone`").success() { 148 | panic!("git clone failed"); 149 | } 150 | std::fs::File::create(complete_clone_hint_file).expect("Cannot create clone hint file!"); 151 | println!("Clone took {:?}", Instant::now() - start); 152 | } 153 | 154 | #[cfg(feature = "xnnpack")] 155 | { 156 | let root = std::path::PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); 157 | let bazel_build_path = root.join("build-res/tflitec_with_xnnpack_BUILD.bazel"); 158 | let target = tf_src_path.join("tensorflow/lite/c/tmp/BUILD"); 159 | std::fs::create_dir_all(target.parent().unwrap()).expect("Cannot create tmp directory"); 160 | std::fs::copy(bazel_build_path, target).expect("Cannot copy temporary BUILD file"); 161 | } 162 | } 163 | 164 | fn check_and_set_envs() { 165 | let python_bin_path = get_python_bin_path().expect( 166 | "Cannot find Python binary having required packages. \ 167 | Make sure that `which python3` or `which python` points to a Python3 binary having numpy \ 168 | installed. 
Or set PYTHON_BIN_PATH to the path of that binary.", 169 | ); 170 | let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS"); 171 | let default_envs = [ 172 | ["PYTHON_BIN_PATH", python_bin_path.to_str().unwrap()], 173 | ["USE_DEFAULT_PYTHON_LIB_PATH", "1"], 174 | ["TF_NEED_OPENCL", "0"], 175 | ["TF_CUDA_CLANG", "0"], 176 | ["TF_NEED_TENSORRT", "0"], 177 | ["TF_DOWNLOAD_CLANG", "0"], 178 | ["TF_NEED_MPI", "0"], 179 | ["TF_NEED_ROCM", "0"], 180 | ["TF_NEED_CUDA", "0"], 181 | ["TF_OVERRIDE_EIGEN_STRONG_INLINE", "1"], // Windows only 182 | ["CC_OPT_FLAGS", "-Wno-sign-compare"], 183 | [ 184 | "TF_SET_ANDROID_WORKSPACE", 185 | if os == "android" { "1" } else { "0" }, 186 | ], 187 | ["TF_CONFIGURE_IOS", if os == "ios" { "1" } else { "0" }], 188 | ]; 189 | for kv in default_envs { 190 | let name = kv[0]; 191 | let val = kv[1]; 192 | if env::var(name).is_err() { 193 | env::set_var(name, val); 194 | } 195 | } 196 | let true_vals = ["1", "t", "true", "y", "yes"]; 197 | if true_vals.contains(&env::var("TF_SET_ANDROID_WORKSPACE").unwrap().as_str()) { 198 | let android_env_vars = [ 199 | "ANDROID_NDK_HOME", 200 | "ANDROID_NDK_API_LEVEL", 201 | "ANDROID_SDK_HOME", 202 | "ANDROID_API_LEVEL", 203 | "ANDROID_BUILD_TOOLS_VERSION", 204 | ]; 205 | for name in android_env_vars { 206 | env::var(name) 207 | .unwrap_or_else(|_| panic!("{} should be set to build for Android", name)); 208 | } 209 | } 210 | } 211 | 212 | fn lib_output_path() -> PathBuf { 213 | if target_os() != "ios" { 214 | let ext = dll_extension(); 215 | let lib_prefix = dll_prefix(); 216 | out_dir().join(format!("{}tensorflowlite_c.{}", lib_prefix, ext)) 217 | } else { 218 | out_dir().join("TensorFlowLiteC.framework") 219 | } 220 | } 221 | 222 | fn build_tensorflow_with_bazel(tf_src_path: &str, config: &str, lib_output_path: &Path) { 223 | let target_os = target_os(); 224 | let bazel_output_path_buf; 225 | let bazel_target; 226 | if target_os != "ios" { 227 | let ext = dll_extension(); 228 | let sub_directory = if cfg!(feature = "xnnpack") { 229 | "/tmp" 230 | } else { 231 | "" 232 | }; 233 | let mut lib_out_dir = PathBuf::from(tf_src_path) 234 | .join("bazel-bin") 235 | .join("tensorflow") 236 | .join("lite") 237 | .join("c"); 238 | if !sub_directory.is_empty() { 239 | lib_out_dir = lib_out_dir.join(&sub_directory[1..]); 240 | } 241 | let lib_prefix = dll_prefix(); 242 | bazel_output_path_buf = lib_out_dir.join(format!("{}tensorflowlite_c.{}", lib_prefix, ext)); 243 | bazel_target = format!("//tensorflow/lite/c{}:tensorflowlite_c", sub_directory); 244 | } else { 245 | bazel_output_path_buf = PathBuf::from(tf_src_path) 246 | .join("bazel-bin") 247 | .join("tensorflow") 248 | .join("lite") 249 | .join("ios") 250 | .join("TensorFlowLiteC_framework.zip"); 251 | bazel_target = String::from("//tensorflow/lite/ios:TensorFlowLiteC_framework"); 252 | }; 253 | 254 | let python_bin_path = env::var("PYTHON_BIN_PATH").expect("Cannot read PYTHON_BIN_PATH"); 255 | if !std::process::Command::new(&python_bin_path) 256 | .arg("configure.py") 257 | .current_dir(tf_src_path) 258 | .status() 259 | .unwrap_or_else(|_| panic!("Cannot execute python at {}", &python_bin_path)) 260 | .success() 261 | { 262 | panic!("Cannot configure tensorflow") 263 | } 264 | let mut bazel = std::process::Command::new("bazel"); 265 | { 266 | // Set bazel outputBase under OUT_DIR 267 | let bazel_output_base_path = out_dir().join(format!("tensorflow_{}_output_base", TAG)); 268 | bazel.arg(format!( 269 | "--output_base={}", 270 | bazel_output_base_path.to_str().unwrap() 
271 | )); 272 | } 273 | bazel.arg("build").arg("-c").arg("opt"); 274 | 275 | // Configure XNNPACK flags 276 | // In r2.6, it is enabled for some OS such as Windows by default. 277 | // To enable it by feature flag, we disable it by default on all platforms. 278 | #[cfg(not(feature = "xnnpack"))] 279 | bazel.arg("--define").arg("tflite_with_xnnpack=false"); 280 | #[cfg(any(feature = "xnnpack_qu8", feature = "xnnpack_qs8"))] 281 | bazel.arg("--define").arg("tflite_with_xnnpack=true"); 282 | #[cfg(feature = "xnnpack_qs8")] 283 | bazel.arg("--define").arg("xnn_enable_qs8=true"); 284 | #[cfg(feature = "xnnpack_qu8")] 285 | bazel.arg("--define").arg("xnn_enable_qu8=true"); 286 | 287 | bazel 288 | .arg(format!("--config={}", config)) 289 | .arg(bazel_target) 290 | .current_dir(tf_src_path); 291 | 292 | if let Some(copts) = get_target_dependent_env_var(BAZEL_COPTS_ENV_VAR) { 293 | let copts = copts.split_ascii_whitespace(); 294 | for opt in copts { 295 | bazel.arg(format!("--copt={}", opt)); 296 | } 297 | } 298 | 299 | if target_os == "ios" { 300 | bazel.args(["--apple_bitcode=embedded", "--copt=-fembed-bitcode"]); 301 | } 302 | println!("Bazel Build Command: {:?}", bazel); 303 | if !bazel.status().expect("Cannot execute bazel").success() { 304 | panic!("Cannot build TensorFlowLiteC"); 305 | } 306 | if !bazel_output_path_buf.exists() { 307 | panic!( 308 | "Library/Framework not found in {}", 309 | bazel_output_path_buf.display() 310 | ) 311 | } 312 | if target_os != "ios" { 313 | copy_or_overwrite(&bazel_output_path_buf, lib_output_path); 314 | if target_os == "windows" { 315 | let mut bazel_output_winlib_path_buf = bazel_output_path_buf; 316 | bazel_output_winlib_path_buf.set_extension("dll.if.lib"); 317 | let winlib_output_path_buf = out_dir().join("tensorflowlite_c.lib"); 318 | copy_or_overwrite(bazel_output_winlib_path_buf, winlib_output_path_buf); 319 | } 320 | } else { 321 | if lib_output_path.exists() { 322 | std::fs::remove_dir_all(lib_output_path).unwrap(); 323 | } 324 | let mut unzip = std::process::Command::new("unzip"); 325 | unzip.args([ 326 | "-q", 327 | bazel_output_path_buf.to_str().unwrap(), 328 | "-d", 329 | out_dir().to_str().unwrap(), 330 | ]); 331 | unzip.status().expect("Cannot execute unzip"); 332 | } 333 | } 334 | 335 | fn out_dir() -> PathBuf { 336 | PathBuf::from(env::var("OUT_DIR").unwrap()) 337 | } 338 | 339 | fn prepare_for_docsrs() { 340 | // Docs.rs cannot access to network, use resource files 341 | let library_path = out_dir().join("libtensorflowlite_c.so"); 342 | let bindings_path = out_dir().join("bindings.rs"); 343 | 344 | let mut unzip = std::process::Command::new("unzip"); 345 | let root = std::path::PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); 346 | unzip 347 | .arg(root.join("build-res/docsrs_res.zip")) 348 | .arg("-d") 349 | .arg(out_dir()); 350 | if !(unzip 351 | .status() 352 | .unwrap_or_else(|_| panic!("Cannot execute unzip")) 353 | .success() 354 | && library_path.exists() 355 | && bindings_path.exists()) 356 | { 357 | panic!("Cannot extract docs.rs resources") 358 | } 359 | } 360 | 361 | fn generate_bindings(tf_src_path: PathBuf) { 362 | let mut builder = bindgen::Builder::default().header( 363 | tf_src_path 364 | .join("tensorflow/lite/c/c_api.h") 365 | .to_str() 366 | .unwrap(), 367 | ); 368 | if cfg!(feature = "xnnpack") { 369 | builder = builder.header( 370 | tf_src_path 371 | .join("tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h") 372 | .to_str() 373 | .unwrap(), 374 | ); 375 | } 376 | 377 | let bindings = builder 378 | 
.clang_arg(format!("-I{}", tf_src_path.to_str().unwrap())) 379 | // Tell cargo to invalidate the built crate whenever any of the 380 | // included header files changed. 381 | .parse_callbacks(Box::new(bindgen::CargoCallbacks)) 382 | // Finish the builder and generate the bindings. 383 | .generate() 384 | // Unwrap the Result and panic on failure. 385 | .expect("Unable to generate bindings"); 386 | 387 | // Write the bindings to the $OUT_DIR/bindings.rs file. 388 | bindings 389 | .write_to_file(out_dir().join("bindings.rs")) 390 | .expect("Couldn't write bindings!"); 391 | } 392 | 393 | fn install_prebuilt(prebuilt_tflitec_path: &str, tf_src_path: &Path, lib_output_path: &PathBuf) { 394 | // Copy prebuilt library to given path 395 | { 396 | let prebuilt_tflitec_path = PathBuf::from(prebuilt_tflitec_path); 397 | // Copy .{so,dylib,dll,Framework} file 398 | copy_or_overwrite(&prebuilt_tflitec_path, lib_output_path); 399 | 400 | if target_os() == "windows" { 401 | // Copy .lib file 402 | let mut prebuilt_lib_path = prebuilt_tflitec_path; 403 | prebuilt_lib_path.set_extension("lib"); 404 | if !prebuilt_lib_path.exists() { 405 | panic!("A prebuilt windows .dll file must have the corresponding .lib file under the same directory!") 406 | } 407 | let mut lib_file_path = lib_output_path.clone(); 408 | lib_file_path.set_extension("lib"); 409 | copy_or_overwrite(prebuilt_lib_path, lib_file_path); 410 | } 411 | } 412 | 413 | copy_or_download_headers( 414 | tf_src_path, 415 | &[ 416 | "tensorflow/lite/c/c_api.h", 417 | "tensorflow/lite/c/c_api_types.h", 418 | ], 419 | ); 420 | if cfg!(feature = "xnnpack") { 421 | copy_or_download_headers( 422 | tf_src_path, 423 | &[ 424 | "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h", 425 | "tensorflow/lite/c/common.h", 426 | ], 427 | ); 428 | } 429 | } 430 | 431 | fn copy_or_download_headers(tf_src_path: &Path, file_paths: &[&str]) { 432 | if let Some(header_src_dir) = get_target_dependent_env_var(HEADER_DIR_ENV_VAR) { 433 | copy_headers(Path::new(&header_src_dir), tf_src_path, file_paths) 434 | } else { 435 | download_headers(tf_src_path, file_paths) 436 | } 437 | } 438 | 439 | fn copy_headers(header_src_dir: &Path, tf_src_path: &Path, file_paths: &[&str]) { 440 | // Download header files from Github 441 | for file_path in file_paths { 442 | let dst_path = tf_src_path.join(file_path); 443 | if dst_path.exists() { 444 | continue; 445 | } 446 | if let Some(p) = dst_path.parent() { 447 | std::fs::create_dir_all(p).expect("Cannot generate header dir"); 448 | } 449 | copy_or_overwrite(header_src_dir.join(file_path), dst_path); 450 | } 451 | } 452 | 453 | fn download_headers(tf_src_path: &Path, file_paths: &[&str]) { 454 | // Download header files from Github 455 | for file_path in file_paths { 456 | let download_path = tf_src_path.join(file_path); 457 | if download_path.exists() { 458 | continue; 459 | } 460 | if let Some(p) = download_path.parent() { 461 | std::fs::create_dir_all(p).expect("Cannot generate header dir"); 462 | } 463 | let url = format!( 464 | "https://raw.githubusercontent.com/tensorflow/tensorflow/{}/{}", 465 | TAG, file_path 466 | ); 467 | download_file(&url, download_path.as_path()); 468 | } 469 | } 470 | 471 | fn download_file(url: &str, path: &Path) { 472 | let mut easy = curl::easy::Easy::new(); 473 | let output_file = std::fs::File::create(path).unwrap(); 474 | let mut writer = std::io::BufWriter::new(output_file); 475 | easy.url(url).unwrap(); 476 | easy.write_function(move |data| Ok(writer.write(data).unwrap())) 477 | .unwrap(); 478 | 
easy.perform().unwrap_or_else(|e| { 479 | std::fs::remove_file(path).unwrap(); // Delete corrupted or empty file 480 | panic!("Error occurred while downloading from {}: {:?}", url, e); 481 | }); 482 | } 483 | 484 | fn main() { 485 | { 486 | let env_vars = [ 487 | BAZEL_COPTS_ENV_VAR, 488 | PREBUILT_PATH_ENV_VAR, 489 | HEADER_DIR_ENV_VAR, 490 | ]; 491 | for env_var in env_vars { 492 | println!("cargo:rerun-if-env-changed={env_var}"); 493 | if let Some(target) = normalized_target() { 494 | println!("cargo:rerun-if-env-changed={env_var}_{target}"); 495 | } 496 | } 497 | } 498 | 499 | let out_path = out_dir(); 500 | let os = target_os(); 501 | let arch = env::var("CARGO_CFG_TARGET_ARCH").expect("Unable to get TARGET_ARCH"); 502 | let arch = match arch.as_str() { 503 | "aarch64" => String::from("arm64"), 504 | "armv7" => { 505 | if os == "android" { 506 | String::from("arm") 507 | } else { 508 | arch 509 | } 510 | } 511 | _ => arch, 512 | }; 513 | if os != "ios" { 514 | println!("cargo:rustc-link-search=native={}", out_path.display()); 515 | println!("cargo:rustc-link-lib=dylib=tensorflowlite_c"); 516 | } else { 517 | println!("cargo:rustc-link-search=framework={}", out_path.display()); 518 | println!("cargo:rustc-link-lib=framework=TensorFlowLiteC"); 519 | } 520 | if env::var("DOCS_RS") == Ok(String::from("1")) { 521 | // docs.rs cannot access to network, use resource files 522 | prepare_for_docsrs(); 523 | } else { 524 | let tf_src_path = out_path.join(format!("tensorflow_{}", TAG)); 525 | let lib_output_path = lib_output_path(); 526 | 527 | if let Some(prebuilt_tflitec_path) = get_target_dependent_env_var(PREBUILT_PATH_ENV_VAR) { 528 | install_prebuilt(&prebuilt_tflitec_path, &tf_src_path, &lib_output_path); 529 | } else { 530 | // Build from source 531 | check_and_set_envs(); 532 | prepare_tensorflow_source(tf_src_path.as_path()); 533 | let config = if os == "android" || os == "ios" || (os == "macos" && arch == "arm64") { 534 | format!("{}_{}", os, arch) 535 | } else { 536 | os 537 | }; 538 | build_tensorflow_with_bazel( 539 | tf_src_path.to_str().unwrap(), 540 | &config, 541 | lib_output_path.as_path(), 542 | ); 543 | } 544 | 545 | // Generate bindings using headers 546 | generate_bindings(tf_src_path); 547 | } 548 | } 549 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | //! Definitions of `Error` type and `ErrorKind`s of the crate. 2 | use core::fmt::{Display, Formatter}; 3 | 4 | /// A list specifying general categories of TensorFlow Lite errors. 5 | #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] 6 | pub enum ErrorKind { 7 | /// Indicates given tensor index (first value) is larger than maximum index (second value). 8 | InvalidTensorIndex(/* index: */ usize, /* max_index: */ usize), 9 | /// Indicates given data length (first value) is not equal to required length (second value). 10 | InvalidTensorDataCount(/* provided: */ usize, /* required: */ usize), 11 | /// Indicates failure to resize tensor with index (first value). 
12 | FailedToResizeInputTensor(/* index: */ usize), 13 | AllocateTensorsRequired, 14 | InvalidTensorDataType, 15 | FailedToAllocateTensors, 16 | FailedToCopyDataToInputTensor, 17 | FailedToLoadModel, 18 | FailedToCreateInterpreter, 19 | ReadTensorError, 20 | InvokeInterpreterRequired, 21 | } 22 | 23 | impl ErrorKind { 24 | pub(crate) fn as_string(&self) -> String { 25 | match *self { 26 | ErrorKind::InvalidTensorIndex(index, max_index) => { 27 | format!("invalid tensor index {}, max index is {}", index, max_index) 28 | } 29 | ErrorKind::InvalidTensorDataCount(provided, required) => format!( 30 | "provided data count {} must match the required count {}", 31 | provided, required 32 | ), 33 | ErrorKind::InvalidTensorDataType => { 34 | "tensor data type is unsupported or could not be determined due to a model error" 35 | .to_string() 36 | } 37 | ErrorKind::FailedToResizeInputTensor(index) => { 38 | format!("failed to resize input tensor at index {}", index) 39 | } 40 | ErrorKind::AllocateTensorsRequired => "must call allocate_tensors()".to_string(), 41 | ErrorKind::FailedToAllocateTensors => { 42 | "failed to allocate memory for input tensors".to_string() 43 | } 44 | ErrorKind::FailedToCopyDataToInputTensor => { 45 | "failed to copy data to input tensor".to_string() 46 | } 47 | ErrorKind::FailedToLoadModel => "failed to load the given model".to_string(), 48 | ErrorKind::FailedToCreateInterpreter => "failed to create the interpreter".to_string(), 49 | ErrorKind::ReadTensorError => "failed to read tensor".to_string(), 50 | ErrorKind::InvokeInterpreterRequired => "must call invoke()".to_string(), 51 | } 52 | } 53 | } 54 | 55 | impl Display for ErrorKind { 56 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 57 | f.write_str(&self.as_string()) 58 | } 59 | } 60 | 61 | /// The error type for TensorFlow Lite operations. 62 | #[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)] 63 | pub struct Error { 64 | kind: ErrorKind, 65 | } 66 | 67 | impl Display for Error { 68 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 69 | write!(f, "{}", self.kind) 70 | } 71 | } 72 | 73 | impl std::error::Error for Error {} 74 | 75 | impl Error { 76 | pub(crate) fn new(kind: ErrorKind) -> Error { 77 | Error { kind } 78 | } 79 | pub fn kind(&self) -> ErrorKind { 80 | self.kind 81 | } 82 | } 83 | 84 | /// A specialized [`Result`] type for API operations. 85 | pub type Result = std::result::Result; 86 | -------------------------------------------------------------------------------- /src/interpreter.rs: -------------------------------------------------------------------------------- 1 | //! API of TensorFlow Lite [`Interpreter`] that performs inference. 2 | use std::ffi::c_void; 3 | use std::os::raw::c_int; 4 | 5 | use crate::bindings::*; 6 | use crate::model::Model; 7 | use crate::tensor; 8 | use crate::tensor::Tensor; 9 | use crate::{Error, ErrorKind, Result}; 10 | use std::fmt::{Debug, Formatter}; 11 | 12 | /// Options for configuring the [`Interpreter`]. 13 | #[derive(Debug, Eq, PartialEq, Copy, Clone, Hash, Ord, PartialOrd)] 14 | pub struct Options { 15 | /// The maximum number of CPU threads that the interpreter should run on. 16 | /// 17 | /// The default is -1 indicating that the [`Interpreter`] will decide 18 | /// the number of threads to use. `thread_count` should be >= -1. 19 | /// Setting `thread_count` to 0 has the effect to disable multithreading, 20 | /// which is equivalent to setting `thread_count` to 1. 
21 | /// If set to the value -1, the number of threads used will be 22 | /// implementation-defined and platform-dependent. 23 | pub thread_count: i32, 24 | 25 | /// Indicates whether an optimized set of floating point CPU kernels, provided by XNNPACK, is 26 | /// enabled. 27 | /// 28 | /// Enabling this flag will enable use of a new, highly optimized set of CPU kernels provided 29 | /// via the XNNPACK delegate. ~~Currently, this is restricted to a subset of floating point 30 | /// operations.~~ 31 | /// See [official README](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md) for more details. 32 | /// 33 | /// ## Important: 34 | /// Things to keep in mind when enabling this flag: 35 | /// 36 | /// * Startup time and resize time may increase. 37 | /// * Baseline memory consumption may increase. 38 | /// * Quantized models will not see any benefit unless features `xnnpack_qu8` or `xnnpack_qs8` 39 | /// are enabled. 40 | #[cfg(feature = "xnnpack")] 41 | #[cfg_attr(docsrs, doc(cfg(feature = "xnnpack")))] 42 | pub is_xnnpack_enabled: bool, 43 | } 44 | 45 | impl Default for Options { 46 | fn default() -> Self { 47 | Self { 48 | thread_count: -1, 49 | #[cfg(feature = "xnnpack")] 50 | is_xnnpack_enabled: false, 51 | } 52 | } 53 | } 54 | 55 | /// A TensorFlow Lite interpreter that performs inference from a given model. 56 | /// 57 | /// - Note: Interpreter instances are *not* thread-safe. 58 | pub struct Interpreter<'a> { 59 | /// The configuration options for the [`Interpreter`]. 60 | options: Option, 61 | 62 | /// The underlying [`TfLiteInterpreter`] C pointer. 63 | interpreter_ptr: *mut TfLiteInterpreter, 64 | 65 | /// The underlying [`TfLiteDelegate`] C pointer for XNNPACK delegate. 66 | #[cfg(feature = "xnnpack")] 67 | xnnpack_delegate_ptr: Option<*mut TfLiteDelegate>, 68 | 69 | /// The underlying `Model` to limit lifetime of the interpreter. 70 | /// See this issue for details: 71 | /// 72 | #[allow(dead_code)] 73 | model: &'a Model<'a>, 74 | } 75 | 76 | impl Debug for Interpreter<'_> { 77 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 78 | f.debug_struct("Interpreter") 79 | .field("options", &self.options) 80 | .finish() 81 | } 82 | } 83 | unsafe impl Send for Interpreter<'_> {} 84 | 85 | impl<'a> Interpreter<'a> { 86 | /// Creates new [`Interpreter`] 87 | /// 88 | /// # Arguments 89 | /// 90 | /// * `model`: TensorFlow Lite [model][`Model`] 91 | /// * `options`: Interpreter [options][`Options`] 92 | /// 93 | /// # Examples 94 | /// 95 | /// ``` 96 | /// use tflitec::model::Model; 97 | /// use tflitec::interpreter::Interpreter; 98 | /// let model = Model::new("tests/add.bin")?; 99 | /// let interpreter = Interpreter::new(&model, None)?; 100 | /// # Ok::<(), tflitec::Error>(()) 101 | /// ``` 102 | /// 103 | /// # Errors 104 | /// 105 | /// Returns error if TensorFlow Lite C fails internally. 
106 | pub fn new(model: &'a Model<'a>, options: Option) -> Result> { 107 | unsafe { 108 | let options_ptr = TfLiteInterpreterOptionsCreate(); 109 | if options_ptr.is_null() { 110 | return Err(Error::new(ErrorKind::FailedToCreateInterpreter)); 111 | } 112 | if let Some(thread_count) = options.as_ref().map(|s| s.thread_count) { 113 | TfLiteInterpreterOptionsSetNumThreads(options_ptr, thread_count); 114 | } 115 | 116 | #[cfg(feature = "xnnpack")] 117 | let mut xnnpack_delegate_ptr: Option<*mut TfLiteDelegate> = None; 118 | #[cfg(feature = "xnnpack")] 119 | { 120 | if let Some(options) = options.as_ref() { 121 | if options.is_xnnpack_enabled { 122 | xnnpack_delegate_ptr = 123 | Some(Interpreter::configure_xnnpack(options, options_ptr)); 124 | } 125 | } 126 | } 127 | 128 | // TODO(ebraraktas): TfLiteInterpreterOptionsSetErrorReporter 129 | let model_ptr = model.model_ptr as *const TfLiteModel; 130 | let interpreter_ptr = TfLiteInterpreterCreate(model_ptr, options_ptr); 131 | TfLiteInterpreterOptionsDelete(options_ptr); 132 | if interpreter_ptr.is_null() { 133 | Err(Error::new(ErrorKind::FailedToCreateInterpreter)) 134 | } else { 135 | Ok(Interpreter { 136 | options, 137 | interpreter_ptr, 138 | #[cfg(feature = "xnnpack")] 139 | xnnpack_delegate_ptr, 140 | model, 141 | }) 142 | } 143 | } 144 | } 145 | 146 | /// Returns the total number of input [`Tensor`]s associated with the model. 147 | pub fn input_tensor_count(&self) -> usize { 148 | unsafe { TfLiteInterpreterGetInputTensorCount(self.interpreter_ptr) as usize } 149 | } 150 | 151 | /// Returns the total number of output `Tensor`s associated with the model. 152 | pub fn output_tensor_count(&self) -> usize { 153 | unsafe { TfLiteInterpreterGetOutputTensorCount(self.interpreter_ptr) as usize } 154 | } 155 | 156 | /// Invokes the interpreter to perform inference from the loaded graph. 157 | /// 158 | /// # Errors 159 | /// 160 | /// Returns error if TensorFlow Lite C fails to invoke. 161 | pub fn invoke(&self) -> Result<()> { 162 | if TfLiteStatus_kTfLiteOk == unsafe { TfLiteInterpreterInvoke(self.interpreter_ptr) } { 163 | Ok(()) 164 | } else { 165 | Err(Error::new(ErrorKind::AllocateTensorsRequired)) 166 | } 167 | } 168 | 169 | /// Returns the input [`Tensor`] at the given `index`. 170 | /// 171 | /// # Arguments 172 | /// 173 | /// * `index`: The index for the input [`Tensor`]. 174 | /// 175 | /// # Errors 176 | /// 177 | /// Returns error if [`Interpreter::allocate_tensors()`] was not called before calling this 178 | /// or given index is not a valid input tensor index in 179 | /// [0, [`Interpreter::input_tensor_count()`]). 180 | pub fn input(&self, index: usize) -> Result { 181 | let max_index = self.input_tensor_count() - 1; 182 | if index > max_index { 183 | return Err(Error::new(ErrorKind::InvalidTensorIndex(index, max_index))); 184 | } 185 | unsafe { 186 | let tensor_ptr = TfLiteInterpreterGetInputTensor(self.interpreter_ptr, index as i32); 187 | Tensor::from_raw(tensor_ptr as *mut TfLiteTensor).map_err(|error| { 188 | if error.kind() == ErrorKind::ReadTensorError { 189 | Error::new(ErrorKind::AllocateTensorsRequired) 190 | } else { 191 | error 192 | } 193 | }) 194 | } 195 | } 196 | 197 | /// Returns the output [`Tensor`] at the given `index`. 198 | /// 199 | /// # Arguments 200 | /// 201 | /// * `index`: The index for the output [`Tensor`]. 202 | /// 203 | /// # Errors 204 | /// 205 | /// Returns error if given index is not a valid output tensor index in 206 | /// [0, [`Interpreter::output_tensor_count()`]). 
And, it may return error 207 | /// unless the output tensor has been both sized and allocated. In general, 208 | /// best practice is to call this *after* calling [`Interpreter::invoke()`]. 209 | pub fn output(&self, index: usize) -> Result { 210 | let max_index = self.output_tensor_count() - 1; 211 | if index > max_index { 212 | return Err(Error::new(ErrorKind::InvalidTensorIndex(index, max_index))); 213 | } 214 | unsafe { 215 | let tensor_ptr = TfLiteInterpreterGetOutputTensor(self.interpreter_ptr, index as i32); 216 | Tensor::from_raw(tensor_ptr as *mut TfLiteTensor).map_err(|error| { 217 | if error.kind() == ErrorKind::ReadTensorError { 218 | Error::new(ErrorKind::InvokeInterpreterRequired) 219 | } else { 220 | error 221 | } 222 | }) 223 | } 224 | } 225 | 226 | /// Resizes the input [`Tensor`] at the given index to the 227 | /// specified [`Shape`][tensor::Shape]. 228 | /// 229 | /// - Note: After resizing an input tensor, the client **must** explicitly call 230 | /// [`Interpreter::allocate_tensors()`] before attempting to access the resized tensor data 231 | /// or invoking the interpreter to perform inference. 232 | /// 233 | /// # Arguments 234 | /// 235 | /// * `index`: The index for the input [`Tensor`]. 236 | /// * `shape`: The shape to resize the input [`Tensor`] to. 237 | /// 238 | /// # Errors 239 | /// 240 | /// Returns error if given index is not a valid input tensor index in 241 | /// [0, [`Interpreter::input_tensor_count()`]) or TensorFlow Lite C fails internally. 242 | pub fn resize_input(&self, index: usize, shape: tensor::Shape) -> Result<()> { 243 | let max_index = self.input_tensor_count() - 1; 244 | if index > max_index { 245 | return Err(Error::new(ErrorKind::InvalidTensorIndex(index, max_index))); 246 | } 247 | let dims = shape 248 | .dimensions() 249 | .iter() 250 | .map(|v| *v as i32) 251 | .collect::>(); 252 | 253 | unsafe { 254 | if TfLiteStatus_kTfLiteOk 255 | == TfLiteInterpreterResizeInputTensor( 256 | self.interpreter_ptr, 257 | index as i32, 258 | dims.as_ptr() as *const c_int, 259 | dims.len() as i32, 260 | ) 261 | { 262 | Ok(()) 263 | } else { 264 | Err(Error::new(ErrorKind::FailedToResizeInputTensor(index))) 265 | } 266 | } 267 | } 268 | 269 | /// Allocates memory for all input [`Tensor`]s and dependent tensors based on 270 | /// their [`Shape`][tensor::Shape]s. 271 | /// 272 | /// - Note: This is a relatively expensive operation and should only be called 273 | /// after creating the interpreter and resizing any input tensors. 274 | /// 275 | /// # Error 276 | /// 277 | /// Returns error if TensorFlow Lite C fails to allocate memory 278 | /// for the input tensors. 279 | pub fn allocate_tensors(&self) -> Result<()> { 280 | if TfLiteStatus_kTfLiteOk 281 | == unsafe { TfLiteInterpreterAllocateTensors(self.interpreter_ptr) } 282 | { 283 | Ok(()) 284 | } else { 285 | Err(Error::new(ErrorKind::FailedToAllocateTensors)) 286 | } 287 | } 288 | 289 | /// Copies the given `data` to the input [`Tensor`] at the given `index`. 290 | /// 291 | /// # Arguments 292 | /// 293 | /// * `data`: The data to be copied to the input `Tensor`'s data buffer 294 | /// * `index`: The index for the input `Tensor` 295 | /// 296 | /// # Errors 297 | /// 298 | /// Return error if the data length does not match the buffer size of the input tensor or 299 | /// the given index is not a valid input tensor index in 300 | /// [0, [`Interpreter::input_tensor_count()`]) or TensorFlow Lite C fails internally. 
301 | fn copy_bytes(&self, data: &[u8], index: usize) -> Result<()> { 302 | let max_index = self.input_tensor_count() - 1; 303 | if index > max_index { 304 | return Err(Error::new(ErrorKind::InvalidTensorIndex(index, max_index))); 305 | } 306 | unsafe { 307 | let tensor_ptr = TfLiteInterpreterGetInputTensor(self.interpreter_ptr, index as i32); 308 | let byte_count = TfLiteTensorByteSize(tensor_ptr); 309 | if data.len() != byte_count { 310 | return Err(Error::new(ErrorKind::InvalidTensorDataCount( 311 | data.len(), 312 | byte_count, 313 | ))); 314 | } 315 | let status = 316 | TfLiteTensorCopyFromBuffer(tensor_ptr, data.as_ptr() as *const c_void, data.len()); 317 | if status != TfLiteStatus_kTfLiteOk { 318 | Err(Error::new(ErrorKind::FailedToCopyDataToInputTensor)) 319 | } else { 320 | Ok(()) 321 | } 322 | } 323 | } 324 | 325 | /// Copies the given `data` to the input [`Tensor`] at the given `index`. 326 | /// 327 | /// # Arguments 328 | /// 329 | /// * `data`: The data to be copied to the input `Tensor`'s data buffer. 330 | /// * `index`: The index for the input [`Tensor`]. 331 | /// 332 | /// # Errors 333 | /// 334 | /// Returns error if byte count of the data does not match the buffer size of the 335 | /// input tensor or the given index is not a valid input tensor index in 336 | /// [0, [`Interpreter::input_tensor_count()`]) or TensorFlow Lite C fails internally. 337 | pub fn copy(&self, data: &[T], index: usize) -> Result<()> { 338 | let element_size = std::mem::size_of::(); 339 | let d = unsafe { 340 | std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * element_size) 341 | }; 342 | self.copy_bytes(d, index) 343 | } 344 | 345 | /// Returns optional reference of [`Options`]. 346 | pub fn options(&self) -> Option<&Options> { 347 | self.options.as_ref() 348 | } 349 | 350 | #[cfg(feature = "xnnpack")] 351 | unsafe fn configure_xnnpack( 352 | options: &Options, 353 | interpreter_options_ptr: *mut TfLiteInterpreterOptions, 354 | ) -> *mut TfLiteDelegate { 355 | let mut xnnpack_options = TfLiteXNNPackDelegateOptionsDefault(); 356 | if options.thread_count > 0 { 357 | xnnpack_options.num_threads = options.thread_count 358 | } 359 | 360 | let xnnpack_delegate_ptr = TfLiteXNNPackDelegateCreate(&xnnpack_options); 361 | TfLiteInterpreterOptionsAddDelegate(interpreter_options_ptr, xnnpack_delegate_ptr); 362 | xnnpack_delegate_ptr 363 | } 364 | } 365 | 366 | impl Drop for Interpreter<'_> { 367 | fn drop(&mut self) { 368 | unsafe { 369 | TfLiteInterpreterDelete(self.interpreter_ptr); 370 | 371 | #[cfg(feature = "xnnpack")] 372 | { 373 | if let Some(delegate_ptr) = self.xnnpack_delegate_ptr { 374 | TfLiteXNNPackDelegateDelete(delegate_ptr) 375 | } 376 | } 377 | } 378 | } 379 | } 380 | 381 | #[cfg(test)] 382 | mod tests { 383 | use crate::interpreter::Interpreter; 384 | use crate::model::Model; 385 | use crate::tensor; 386 | use crate::ErrorKind; 387 | 388 | #[cfg(target_os = "windows")] 389 | const MODEL_PATH: &str = "tests\\add.bin"; 390 | #[cfg(not(target_os = "windows"))] 391 | const MODEL_PATH: &str = "tests/add.bin"; 392 | 393 | #[test] 394 | fn test_interpreter_input_output_count() { 395 | let bytes = std::fs::read(MODEL_PATH).expect("Cannot read model data!"); 396 | let model = Model::from_bytes(&bytes).expect("Cannot load model from bytes"); 397 | let interpreter = Interpreter::new(&model, None).expect("Cannot create interpreter"); 398 | assert_eq!(interpreter.input_tensor_count(), 1); 399 | assert_eq!(interpreter.output_tensor_count(), 1); 400 | } 401 | 402 | #[test] 403 | fn 
test_interpreter_get_input_tensor() { 404 | let bytes = std::fs::read(MODEL_PATH).expect("Cannot read model data!"); 405 | let model = Model::from_bytes(&bytes).expect("Cannot load model from bytes!"); 406 | let interpreter = Interpreter::new(&model, None).expect("Cannot create interpreter!"); 407 | 408 | let invalid_tensor = interpreter.input(1); 409 | assert!(invalid_tensor.is_err()); 410 | let err = invalid_tensor.err().unwrap(); 411 | assert_eq!(ErrorKind::InvalidTensorIndex(1, 0), err.kind()); 412 | 413 | let invalid_tensor = interpreter.input(0); 414 | assert!(invalid_tensor.is_err()); 415 | let err = invalid_tensor.err().unwrap(); 416 | assert_eq!(ErrorKind::AllocateTensorsRequired, err.kind()); 417 | 418 | interpreter.allocate_tensors().unwrap(); 419 | let valid_tensor = interpreter.input(0); 420 | assert!(valid_tensor.is_ok()); 421 | let tensor = valid_tensor.ok().unwrap(); 422 | assert_eq!(tensor.shape().dimensions(), &vec![1, 8, 8, 3]) 423 | } 424 | 425 | #[test] 426 | fn test_interpreter_allocate_tensors() { 427 | let bytes = std::fs::read(MODEL_PATH).expect("Cannot read model data!"); 428 | let model = Model::from_bytes(&bytes).expect("Cannot load model from bytes!"); 429 | let interpreter = Interpreter::new(&model, None).expect("Cannot create interpreter!"); 430 | 431 | interpreter 432 | .resize_input(0, tensor::Shape::new(vec![10, 8, 8, 3])) 433 | .expect("Resize failed"); 434 | interpreter 435 | .allocate_tensors() 436 | .expect("Cannot allocate tensors"); 437 | let tensor = interpreter.input(0).unwrap(); 438 | assert_eq!(tensor.shape().dimensions(), &vec![10, 8, 8, 3]) 439 | } 440 | 441 | #[test] 442 | fn test_interpreter_copy_input() { 443 | let bytes = std::fs::read(MODEL_PATH).expect("Cannot read model data!"); 444 | let model = Model::from_bytes(&bytes).expect("Cannot load model from bytes!"); 445 | let interpreter = Interpreter::new(&model, None).expect("Cannot create interpreter!"); 446 | 447 | interpreter 448 | .resize_input(0, tensor::Shape::new(vec![10, 8, 8, 3])) 449 | .expect("Resize failed"); 450 | interpreter 451 | .allocate_tensors() 452 | .expect("Cannot allocate tensors"); 453 | let tensor = interpreter.input(0).unwrap(); 454 | let data = (0..1920).map(|x| x as f32).collect::>(); 455 | assert!(interpreter.copy(&data[..], 0).is_ok()); 456 | assert_eq!(data, tensor.data()); 457 | } 458 | 459 | #[test] 460 | fn test_interpreter_invoke() { 461 | let model = Model::new(MODEL_PATH).expect("Cannot load model from file!"); 462 | let interpreter = Interpreter::new(&model, None).expect("Cannot create interpreter!"); 463 | 464 | interpreter 465 | .resize_input(0, tensor::Shape::new(vec![10, 8, 8, 3])) 466 | .expect("Resize failed"); 467 | interpreter 468 | .allocate_tensors() 469 | .expect("Cannot allocate tensors"); 470 | 471 | let data = (0..1920).map(|x| x as f32).collect::>(); 472 | assert!(interpreter.copy(&data[..], 0).is_ok()); 473 | assert!(interpreter.invoke().is_ok()); 474 | let expected: Vec = data.iter().map(|e| e * 3.0).collect(); 475 | let output_tensor = interpreter.output(0).unwrap(); 476 | assert_eq!(output_tensor.shape().dimensions(), &vec![10, 8, 8, 3]); 477 | let output_vector = output_tensor.data::().to_vec(); 478 | assert_eq!(expected, output_vector); 479 | } 480 | 481 | #[cfg(feature = "xnnpack")] 482 | #[test] 483 | fn test_interpreter_invoke_xnnpack() { 484 | use crate::interpreter::Options; 485 | let options = Some(Options { 486 | thread_count: 2, 487 | is_xnnpack_enabled: true, 488 | }); 489 | let model = 
Model::new(MODEL_PATH).expect("Cannot load model from file!"); 490 | let interpreter = Interpreter::new(&model, options).expect("Cannot create interpreter!"); 491 | 492 | interpreter 493 | .resize_input(0, tensor::Shape::new(vec![10, 8, 8, 3])) 494 | .expect("Resize failed"); 495 | interpreter 496 | .allocate_tensors() 497 | .expect("Cannot allocate tensors"); 498 | 499 | let data = (0..1920).map(|x| x as f32).collect::>(); 500 | assert!(interpreter.copy(&data[..], 0).is_ok()); 501 | assert!(interpreter.invoke().is_ok()); 502 | let expected: Vec = data.iter().map(|e| e * 3.0).collect(); 503 | let output_tensor = interpreter.output(0).unwrap(); 504 | assert_eq!(output_tensor.shape().dimensions(), &vec![10, 8, 8, 3]); 505 | let output_vector = output_tensor.data::().to_vec(); 506 | assert_eq!(expected, output_vector); 507 | } 508 | } 509 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | // only enables the `doc_cfg` feature when 2 | // the `docsrs` configuration attribute is defined 3 | #![cfg_attr(docsrs, feature(doc_cfg))] 4 | #![doc = include_str!("../README.md")] 5 | 6 | mod error; 7 | pub mod interpreter; 8 | pub mod model; 9 | pub mod tensor; 10 | 11 | pub(crate) mod bindings { 12 | #![allow(clippy::all)] 13 | #![allow(non_upper_case_globals)] 14 | #![allow(non_camel_case_types)] 15 | #![allow(non_snake_case)] 16 | #![allow(unused)] 17 | #![allow(improper_ctypes)] 18 | include!(concat!(env!("OUT_DIR"), "/bindings.rs")); 19 | } 20 | 21 | pub use self::error::{Error, ErrorKind, Result}; 22 | -------------------------------------------------------------------------------- /src/model.rs: -------------------------------------------------------------------------------- 1 | //! TensorFlow Lite [`Model`] loader. 2 | //! 3 | //! # Examples 4 | //! 5 | //! ``` 6 | //! use tflitec::model::Model; 7 | //! let model = Model::new("tests/add.bin")?; 8 | //! # Ok::<(), tflitec::Error>(()) 9 | //! ``` 10 | use crate::bindings::{ 11 | TfLiteModel, TfLiteModelCreate, TfLiteModelCreateFromFile, TfLiteModelDelete, 12 | }; 13 | use crate::{Error, ErrorKind, Result}; 14 | use std::ffi::{c_void, CString}; 15 | use std::fmt::{Debug, Formatter}; 16 | 17 | /// A TensorFlow Lite model used by the [`Interpreter`][crate::interpreter::Interpreter] to perform inference. 18 | pub struct Model<'a> { 19 | /// The underlying [`TfLiteModel`] C pointer. 20 | pub(crate) model_ptr: *mut TfLiteModel, 21 | 22 | #[allow(dead_code)] 23 | /// The model data if initialized with bytes. 24 | /// 25 | /// This reference is taken to guarantee that bytes 26 | /// must be immutable and outlive the model 27 | pub(crate) bytes: Option<&'a [u8]>, 28 | } 29 | 30 | impl Debug for Model<'_> { 31 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 32 | f.debug_struct("Model").finish() 33 | } 34 | } 35 | 36 | unsafe impl Send for Model<'_> {} 37 | unsafe impl Sync for Model<'_> {} 38 | 39 | impl Model<'_> { 40 | /// Creates a new instance with the given `filepath`. 41 | /// 42 | /// # Arguments 43 | /// 44 | /// * `filepath`: The local file path to a TensorFlow Lite model. 45 | /// 46 | /// # Errors 47 | /// 48 | /// Returns error if TensorFlow Lite C fails to read model from file. 
49 |     pub fn new<'a>(filepath: &str) -> Result<Model<'a>> {
50 |         let model_ptr = unsafe {
51 |             let path = CString::new(filepath).unwrap();
52 |             TfLiteModelCreateFromFile(path.as_ptr())
53 |         };
54 |         if model_ptr.is_null() {
55 |             Err(Error::new(ErrorKind::FailedToLoadModel))
56 |         } else {
57 |             Ok(Model {
58 |                 model_ptr,
59 |                 bytes: None,
60 |             })
61 |         }
62 |     }
63 | 
64 |     /// Creates a new instance from the given `bytes`.
65 |     ///
66 |     /// # Arguments
67 |     ///
68 |     /// * `bytes`: TensorFlow Lite model data.
69 |     ///
70 |     /// # Errors
71 |     ///
72 |     /// Returns an error if TensorFlow Lite C fails to load the model from the buffer.
73 |     pub fn from_bytes(bytes: &[u8]) -> std::result::Result<Model, Error> {
74 |         let model_ptr = unsafe { TfLiteModelCreate(bytes.as_ptr() as *const c_void, bytes.len()) };
75 |         if model_ptr.is_null() {
76 |             Err(Error::new(ErrorKind::FailedToLoadModel))
77 |         } else {
78 |             Ok(Model {
79 |                 model_ptr,
80 |                 bytes: Some(bytes),
81 |             })
82 |         }
83 |     }
84 | }
85 | 
86 | impl Drop for Model<'_> {
87 |     fn drop(&mut self) {
88 |         unsafe { TfLiteModelDelete(self.model_ptr) }
89 |     }
90 | }
91 | 
92 | #[cfg(test)]
93 | mod tests {
94 |     use crate::model::Model;
95 | 
96 |     const MODEL_PATH: &str = "tests/add.bin";
97 | 
98 |     #[test]
99 |     fn test_model_from_bytes() {
100 |         let mut bytes = std::fs::read("tests/add.bin").unwrap();
101 |         // If we assigned the model to a named variable, this test would not compile, because
102 |         // the model borrows `bytes` and would be dropped at the end of the test (after the mutation).
103 |         let _ = Model::from_bytes(&bytes).expect("Cannot load model from bytes");
104 |         bytes[0] = 1;
105 |     }
106 | 
107 |     #[test]
108 |     fn test_model_from_path() {
109 |         let mut filepath = String::from(MODEL_PATH);
110 |         let _model = Model::new(&filepath).expect("Cannot load model from file");
111 |         // We can mutate `filepath` here, because it is not borrowed.
112 |         filepath.push('/');
113 |     }
114 | }
115 | 
--------------------------------------------------------------------------------
/src/tensor.rs:
--------------------------------------------------------------------------------
1 | //! TensorFlow Lite input or output [`Tensor`] associated with an interpreter.
2 | use std::ffi::{c_void, CStr};
3 | 
4 | use crate::bindings;
5 | use crate::bindings::*;
6 | use crate::{Error, ErrorKind, Result};
7 | use std::fmt::{Debug, Formatter};
8 | use std::marker::PhantomData;
9 | 
10 | /// Parameters that determine the mapping of quantized values to real values.
11 | ///
12 | /// Quantized values can be mapped to float values using the following conversion:
13 | /// `realValue = scale * (quantizedValue - zeroPoint)`.
14 | #[derive(Copy, Clone, PartialEq, Debug, PartialOrd)]
15 | pub struct QuantizationParameters {
16 |     /// The difference between real values corresponding to consecutive quantized
17 |     /// values differing by 1. For example, the range of quantized values for the `u8`
18 |     /// data type is [0, 255].
19 |     pub scale: f32,
20 | 
21 |     /// The quantized value that corresponds to the real 0 value.
22 |     pub zero_point: i32,
23 | }
24 | 
25 | /// The supported [`Tensor`] data types.
26 | #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
27 | pub enum DataType {
28 |     /// A boolean.
29 |     Bool,
30 |     /// An 8-bit unsigned integer.
31 |     Uint8,
32 |     /// An 8-bit signed integer.
33 |     Int8,
34 |     /// A 16-bit signed integer.
35 |     Int16,
36 |     /// A 32-bit signed integer.
37 |     Int32,
38 |     /// A 64-bit signed integer.
39 |     Int64,
40 |     /// A 16-bit half precision floating point.
41 |     Float16,
42 |     /// A 32-bit single precision floating point.
43 |     Float32,
44 |     /// A 64-bit double precision floating point.
45 |     Float64,
46 | }
47 | 
48 | impl DataType {
49 |     /// Creates a new instance from the given [`TfLiteType`].
50 |     ///
51 |     /// # Arguments
52 |     ///
53 |     /// * `tflite_type`: A data type for a tensor.
54 |     ///
55 |     /// returns: [`None`] if the data type is unsupported or could not
56 |     /// be determined because there was an error, otherwise returns
57 |     /// the corresponding enum variant wrapped in [`Some`].
58 |     pub(crate) fn new(tflite_type: TfLiteType) -> Option<DataType> {
59 |         match tflite_type {
60 |             bindings::TfLiteType_kTfLiteBool => Some(DataType::Bool),
61 |             bindings::TfLiteType_kTfLiteUInt8 => Some(DataType::Uint8),
62 |             bindings::TfLiteType_kTfLiteInt8 => Some(DataType::Int8),
63 |             bindings::TfLiteType_kTfLiteInt16 => Some(DataType::Int16),
64 |             bindings::TfLiteType_kTfLiteInt32 => Some(DataType::Int32),
65 |             bindings::TfLiteType_kTfLiteInt64 => Some(DataType::Int64),
66 |             bindings::TfLiteType_kTfLiteFloat16 => Some(DataType::Float16),
67 |             bindings::TfLiteType_kTfLiteFloat32 => Some(DataType::Float32),
68 |             bindings::TfLiteType_kTfLiteFloat64 => Some(DataType::Float64),
69 |             _ => None,
70 |         }
71 |     }
72 | }
73 | 
74 | #[derive(Clone, Eq, PartialEq, Debug, Hash)]
75 | /// The shape of a [`Tensor`].
76 | pub struct Shape {
77 |     /// The number of dimensions of the [`Tensor`].
78 |     rank: usize,
79 | 
80 |     /// An array of dimensions for the [`Tensor`].
81 |     dimensions: Vec<usize>,
82 | }
83 | 
84 | impl Shape {
85 |     /// Creates a new instance with the given `dimensions`.
86 |     ///
87 |     /// # Arguments
88 |     ///
89 |     /// * `dimensions`: Dimensions for the [`Tensor`].
90 |     ///
91 |     /// returns: Shape
92 |     ///
93 |     /// # Examples
94 |     ///
95 |     /// ```
96 |     /// use tflitec::tensor;
97 |     /// let shape = tensor::Shape::new(vec![8, 16, 16]);
98 |     /// assert_eq!(shape.rank(), 3);
99 |     /// assert_eq!(shape.dimensions(), &vec![8, 16, 16]);
100 |     /// ```
101 |     pub fn new(dimensions: Vec<usize>) -> Shape {
102 |         Shape {
103 |             rank: dimensions.len(),
104 |             dimensions,
105 |         }
106 |     }
107 | 
108 |     /// Returns dimensions of the [`Tensor`].
109 |     pub fn dimensions(&self) -> &Vec<usize> {
110 |         &self.dimensions
111 |     }
112 | 
113 |     /// Returns the rank (number of dimensions) of the [`Tensor`].
114 |     pub fn rank(&self) -> usize {
115 |         self.rank
116 |     }
117 | }
118 | 
119 | pub(crate) struct TensorData {
120 |     data_ptr: *mut u8,
121 |     data_length: usize,
122 | }
123 | 
124 | /// An input or output tensor in a TensorFlow Lite graph.
125 | pub struct Tensor<'a> {
126 |     /// The name of the `Tensor`.
127 |     name: String,
128 | 
129 |     /// The data type of the `Tensor`.
130 |     data_type: DataType,
131 | 
132 |     /// The shape of the `Tensor`.
133 |     shape: Shape,
134 | 
135 |     /// The data in the input or output `Tensor`.
136 |     data: TensorData,
137 | 
138 |     /// The quantization parameters for the `Tensor` if using a quantized model.
139 |     quantization_parameters: Option<QuantizationParameters>,
140 | 
141 |     /// The underlying [`TfLiteTensor`] C pointer.
142 |     tensor_ptr: *mut TfLiteTensor,
143 | 
144 |     // To set lifetime of the Tensor
145 |     phantom: PhantomData<&'a TfLiteTensor>,
146 | }
147 | 
148 | impl Debug for Tensor<'_> {
149 |     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
150 |         f.debug_struct("Tensor")
151 |             .field("name", &self.name)
152 |             .field("shape", &self.shape)
153 |             .field("data_type", &self.data_type)
154 |             .field("quantization_parameters", &self.quantization_parameters)
155 |             .finish()
156 |     }
157 | }
158 | 
159 | impl<'a> Tensor<'a> {
160 |     pub(crate) fn from_raw(tensor_ptr: *mut TfLiteTensor) -> Result<Tensor<'a>> {
161 |         unsafe {
162 |             if tensor_ptr.is_null() {
163 |                 return Err(Error::new(ErrorKind::ReadTensorError));
164 |             }
165 | 
166 |             let name_ptr = TfLiteTensorName(tensor_ptr);
167 |             if name_ptr.is_null() {
168 |                 return Err(Error::new(ErrorKind::ReadTensorError));
169 |             }
170 |             let data_ptr = TfLiteTensorData(tensor_ptr) as *mut u8;
171 |             if data_ptr.is_null() {
172 |                 return Err(Error::new(ErrorKind::ReadTensorError));
173 |             }
174 |             let name = CStr::from_ptr(name_ptr).to_str().unwrap().to_owned();
175 | 
176 |             let data_length = TfLiteTensorByteSize(tensor_ptr);
177 |             let data_type = DataType::new(TfLiteTensorType(tensor_ptr))
178 |                 .ok_or_else(|| Error::new(ErrorKind::InvalidTensorDataType))?;
179 | 
180 |             let rank = TfLiteTensorNumDims(tensor_ptr);
181 |             let dimensions = (0..rank)
182 |                 .map(|i| TfLiteTensorDim(tensor_ptr, i) as usize)
183 |                 .collect();
184 |             let shape = Shape::new(dimensions);
185 |             let data = TensorData {
186 |                 data_ptr,
187 |                 data_length,
188 |             };
189 |             let quantization_parameters_ptr = TfLiteTensorQuantizationParams(tensor_ptr);
190 |             let scale = quantization_parameters_ptr.scale;
191 |             let quantization_parameters =
192 |                 if scale == 0.0 || (data_type != DataType::Uint8 && data_type != DataType::Int8) {
193 |                     None
194 |                 } else {
195 |                     Some(QuantizationParameters {
196 |                         scale: quantization_parameters_ptr.scale,
197 |                         zero_point: quantization_parameters_ptr.zero_point,
198 |                     })
199 |                 };
200 |             Ok(Tensor {
201 |                 name,
202 |                 data_type,
203 |                 shape,
204 |                 data,
205 |                 quantization_parameters,
206 |                 tensor_ptr,
207 |                 phantom: PhantomData,
208 |             })
209 |         }
210 |     }
211 | 
212 |     /// Returns the [`Shape`] of the tensor.
213 |     pub fn shape(&self) -> &Shape {
214 |         &self.shape
215 |     }
216 | 
217 |     /// Returns the data of the tensor as a slice of the given type `T`.
218 |     ///
219 |     /// # Panics
220 |     ///
221 |     /// * If the number of bytes in the buffer of the [`Tensor`] is not an integer
222 |     ///   multiple of the byte count of a single `T` (see [`std::mem::size_of`]).
223 |     pub fn data<T>(&self) -> &[T] {
224 |         let element_size = std::mem::size_of::<T>();
225 |         if self.data.data_length % element_size != 0 {
226 |             panic!(
227 |                 "data length {} should be divisible by size of type {}",
228 |                 self.data.data_length, element_size
229 |             )
230 |         }
231 |         unsafe {
232 |             std::slice::from_raw_parts(
233 |                 self.data.data_ptr as *const T,
234 |                 self.data.data_length / element_size,
235 |             )
236 |         }
237 |     }
238 | 
239 |     /// Sets the data of the tensor by copying the given data slice.
240 |     ///
241 |     /// # Arguments
242 |     ///
243 |     /// * `data`: Data to be copied.
244 |     ///
245 |     /// # Errors
246 |     ///
247 |     /// Returns an error if the byte count of the data does not match the buffer size of the
248 |     /// input tensor or TensorFlow Lite C fails internally.
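    ///
    /// # Examples
    ///
    /// A minimal sketch mirroring the crate's tests (assumes the bundled
    /// `tests/add.bin` model, whose single input holds `f32` values):
    ///
    /// ```
    /// use tflitec::interpreter::Interpreter;
    /// use tflitec::model::Model;
    ///
    /// let model = Model::new("tests/add.bin")?;
    /// let interpreter = Interpreter::new(&model, None)?;
    /// interpreter.allocate_tensors()?;
    ///
    /// let input = interpreter.input(0)?;
    /// // Build a buffer with exactly as many elements as the input tensor holds,
    /// // then copy it into the tensor.
    /// let values = vec![0.5f32; input.data::<f32>().len()];
    /// input.set_data(&values[..])?;
    /// # Ok::<(), tflitec::Error>(())
    /// ```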
249 |     pub fn set_data<T>(&self, data: &[T]) -> Result<()> {
250 |         let element_size = std::mem::size_of::<T>();
251 |         let input_byte_count = element_size * data.len();
252 |         if self.data.data_length != input_byte_count {
253 |             return Err(Error::new(ErrorKind::InvalidTensorDataCount(
254 |                 data.len(),
255 |                 input_byte_count,
256 |             )));
257 |         }
258 |         let status = unsafe {
259 |             TfLiteTensorCopyFromBuffer(
260 |                 self.tensor_ptr,
261 |                 data.as_ptr() as *const c_void,
262 |                 input_byte_count,
263 |             )
264 |         };
265 |         if status != TfLiteStatus_kTfLiteOk {
266 |             Err(Error::new(ErrorKind::FailedToCopyDataToInputTensor))
267 |         } else {
268 |             Ok(())
269 |         }
270 |     }
271 | 
272 |     /// Returns [data type][`DataType`] of the [`Tensor`].
273 |     pub fn data_type(&self) -> DataType {
274 |         self.data_type
275 |     }
276 | 
277 |     /// Returns optional [`QuantizationParameters`] of the [`Tensor`].
278 |     pub fn quantization_parameters(&self) -> Option<QuantizationParameters> {
279 |         self.quantization_parameters
280 |     }
281 | 
282 |     /// Returns name of the [`Tensor`].
283 |     pub fn name(&self) -> &str {
284 |         self.name.as_str()
285 |     }
286 | }
287 | 
--------------------------------------------------------------------------------
/tests/add.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebraraktas/tflitec-rs/d3f4cd8623436b0083f715c5439c96511f5c0c77/tests/add.bin
--------------------------------------------------------------------------------