├── .github └── workflows │ └── rust.yml ├── .gitignore ├── .gitmodules ├── .travis.yml ├── CONTRIBUTING.md ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── build.rs ├── csrc ├── stl_wrapper.hpp └── tflite_wrapper.hpp ├── data ├── MNISTnet_uint8_quant.tflite ├── MNISTnet_v2_uint8_quant.tflite ├── builtin_options_impl.rs.template ├── memory_basic_impl.rs.template ├── mnist10.bin ├── vector_basic_impl.rs.template └── vector_primitive_impl.rs.template ├── examples └── minimal.rs ├── rustfmt.toml ├── src ├── bindings.rs ├── error.rs ├── interpreter │ ├── builder.rs │ ├── context.rs │ ├── fbmodel.rs │ ├── mod.rs │ ├── op_resolver.rs │ └── ops │ │ ├── builtin │ │ ├── mod.rs │ │ └── resolver.rs │ │ └── mod.rs ├── lib.rs └── model │ ├── builtin_options.rs │ ├── builtin_options_impl.rs │ ├── mod.rs │ └── stl │ ├── memory.rs │ ├── memory_impl.rs │ ├── mod.rs │ ├── string.rs │ ├── vector.rs │ └── vector_impl.rs ├── submodules ├── make-NativeTable-configurable-as-polymorphic.patch └── update-downloads.sh └── tests └── integration_test.rs /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | env: 9 | CARGO_TERM_COLOR: always 10 | 11 | jobs: 12 | test: 13 | name: test 14 | runs-on: ubuntu-20.04 15 | strategy: 16 | fail-fast: true 17 | matrix: 18 | rust: [stable, beta, nightly] 19 | steps: 20 | - uses: actions/checkout@v2 21 | with: 22 | submodules: recursive 23 | - uses: actions-rs/toolchain@v1 24 | with: 25 | profile: minimal 26 | toolchain: ${{matrix.rust}} 27 | - run: cargo fmt -- --check 28 | - run: cargo clippy --all-targets -- --deny warnings --allow unknown-lints 29 | - run: cargo test --features generate_model_apis -- --nocapture 30 | - run: cargo test --features debug_tflite,no_micro -- --nocapture 31 | - run: cargo fmt 32 | - run: cargo publish --dry-run 33 | # Make sure package size is 
under 10 MB 34 | - run: cargo package --verbose && [ $(stat -c %s target/package/tflite-*.crate) -le 10485760 ] 35 | 36 | macos-build-test: 37 | name: macos build test 38 | runs-on: macOS-latest 39 | strategy: 40 | fail-fast: true 41 | steps: 42 | - uses: actions/checkout@v2 43 | with: 44 | submodules: recursive 45 | - uses: actions-rs/cargo@v1 46 | with: 47 | command: build 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | 3 | /target 4 | **/*.rs.bk 5 | /Cargo.lock 6 | .idea/ 7 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/tensorflow"] 2 | path = submodules/tensorflow 3 | url = https://github.com/tensorflow/tensorflow.git 4 | [submodule "submodules/downloads"] 5 | path = submodules/downloads 6 | url = https://github.com/boncheolgu/tflite-rs-downloads 7 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | os: 2 | - osx 3 | - linux 4 | language: rust 5 | rust: 6 | - stable 7 | - beta 8 | - nightly 9 | script: 10 | - if [ "$TRAVIS_OS_NAME" = "linux" ]; then sudo ln -s -f $(which clang) /usr/bin/gcc ; fi 11 | - if [ "$TRAVIS_OS_NAME" = "linux" ]; then sudo ln -s -f $(which clang++) /usr/bin/g++ ; fi 12 | - gcc -v && g++ -v 13 | - cargo build --verbose 14 | - if [ "$TRAVIS_OS_NAME" = "linux" ]; then cargo test -- --nocapture ; fi 15 | - cargo build --verbose --features debug_tflite,no_micro 16 | - if [ "$TRAVIS_OS_NAME" = "linux" ]; then cargo test --features debug_tflite,no_micro -- --nocapture ; fi 17 | # Make sure package size is under 10 MB 18 | - if [ "$TRAVIS_OS_NAME" = "linux" ]; then cargo package --verbose && [ $(stat -c %s 
target/package/tflite-*.crate) -le 10485760 ] ; fi 19 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | tflite now includes the tensorflow git repository as a submodule 2 | along with the results of calling `download_dependencies.sh`. 3 | This reduces the number of build steps from ~260 to fewer than 100. 4 | If the version of tensorflow is ever updated, `submodules/update-downloads.sh` 5 | should also be updated if necessary and then run. Since the downloads all get committed, 6 | the script removes most of the files that are obviously not necessary. 7 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "tflite" 3 | version = "0.9.8" 4 | authors = ["Boncheol Gu "] 5 | description = "Rust bindings for TensorFlow Lite" 6 | keywords = ["tensorflow", "tflite", "bindings"] 7 | license = "MIT/Apache-2.0" 8 | repository = "https://github.com/boncheolgu/tflite-rs" 9 | readme = "README.md" 10 | edition = "2021" 11 | documentation = "https://docs.rs/crate/tflite" 12 | 13 | # filter in only necessary files 14 | # tar tvaf target/package/tflite-*.crate | awk '{print $3 "\t" $6}' | sort -n 15 | include = [ 16 | "build.rs", 17 | "Cargo.toml", 18 | "*.md", 19 | "LICENSE-*", 20 | "data/", 21 | "csrc/", 22 | "src/", 23 | "submodules/downloads", 24 | "submodules/tensorflow/tensorflow/lite/c", 25 | "submodules/tensorflow/tensorflow/lite/core", 26 | "submodules/tensorflow/tensorflow/lite/delegates/nnapi", 27 | "submodules/tensorflow/tensorflow/lite/experimental/resource_variable", 28 | "submodules/tensorflow/tensorflow/lite/experimental/ruy", 29 | "submodules/tensorflow/tensorflow/lite/kernels", 30 | "submodules/tensorflow/tensorflow/lite/nnapi", 31 | "submodules/tensorflow/tensorflow/lite/profiling", 32 |
"submodules/tensorflow/tensorflow/lite/schema/schema_generated.h", 33 | "submodules/tensorflow/tensorflow/lite/*.cc", 34 | "submodules/tensorflow/tensorflow/lite/*.h", 35 | "submodules/tensorflow/tensorflow/lite/tools/make", 36 | "submodules/tensorflow/third_party/eigen3", 37 | "submodules/tensorflow/third_party/fft2d", 38 | "submodules/tensorflow/tensorflow/core/kernels/eigen_convolution_helpers.h", 39 | "submodules/tensorflow/tensorflow/core/kernels/eigen_spatial_convolutions-inl.h", 40 | "submodules/tensorflow/tensorflow/core/public/version.h", 41 | ] 42 | 43 | [dependencies] 44 | cpp = "0.5.7" 45 | libc = "0.2.139" 46 | maybe-owned = "0.3.4" 47 | thiserror = "1.0.38" 48 | 49 | [build-dependencies] 50 | bart = { version = "0.1.6", optional = true } 51 | bart_derive = { version = "0.1.6", optional = true } 52 | bindgen = "0.69.4" 53 | cpp_build = "0.5.7" 54 | fs_extra = { version = "1.3.0", optional = true } 55 | 56 | [features] 57 | build = ["fs_extra"] 58 | default = ["build"] 59 | debug_tflite = ["build"] # use "libtensorflow-lite.a" built in debug mode 60 | generate_model_apis = ["bart", "bart_derive"] 61 | no_micro = ["build"] 62 | 63 | [package.metadata.docs.rs] 64 | all-features = false 65 | no-default-features = true 66 | default-target = "x86_64-unknown-linux-gnu" 67 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Boncheol Gu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Rust](https://github.com/boncheolgu/tflite-rs/actions/workflows/rust.yml/badge.svg)](https://github.com/boncheolgu/tflite-rs/actions/workflows/rust.yml) 2 | 3 | # Rust bindings for TensorFlow Lite 4 | 5 | This crate provides TensorFlow Lite APIs. 6 | Please read the [`API documentation on docs.rs`](https://docs.rs/crate/tflite). 7 | 8 | ### Using the interpreter from a model file 9 | 10 | The following example shows how to use the TensorFlow Lite interpreter when given a TensorFlow Lite FlatBuffer file. 11 | The example also demonstrates how to run inference on input data.
12 | 13 | ```rust 14 | use std::fs::{self, File}; 15 | use std::io::Read; 16 | 17 | use tflite::ops::builtin::BuiltinOpResolver; 18 | use tflite::{FlatBufferModel, InterpreterBuilder, Result}; 19 | 20 | fn test_mnist(model: &FlatBufferModel) -> Result<()> { 21 | let resolver = BuiltinOpResolver::default(); 22 | 23 | let builder = InterpreterBuilder::new(model, &resolver)?; 24 | let mut interpreter = builder.build()?; 25 | 26 | interpreter.allocate_tensors()?; 27 | 28 | let inputs = interpreter.inputs().to_vec(); 29 | assert_eq!(inputs.len(), 1); 30 | 31 | let input_index = inputs[0]; 32 | 33 | let outputs = interpreter.outputs().to_vec(); 34 | assert_eq!(outputs.len(), 1); 35 | 36 | let output_index = outputs[0]; 37 | 38 | let input_tensor = interpreter.tensor_info(input_index).unwrap(); 39 | assert_eq!(input_tensor.dims, vec![1, 28, 28, 1]); 40 | 41 | let output_tensor = interpreter.tensor_info(output_index).unwrap(); 42 | assert_eq!(output_tensor.dims, vec![1, 10]); 43 | 44 | let mut input_file = File::open("data/mnist10.bin")?; 45 | for i in 0..10 { 46 | input_file.read_exact(interpreter.tensor_data_mut(input_index)?)?; 47 | 48 | interpreter.invoke()?; 49 | 50 | let output: &[u8] = interpreter.tensor_data(output_index)?; 51 | let guess = output.iter().enumerate().max_by(|x, y| x.1.cmp(y.1)).unwrap().0; 52 | 53 | println!("{}: {:?}", i, output); 54 | assert_eq!(i, guess); 55 | } 56 | Ok(()) 57 | } 58 | 59 | #[test] 60 | fn mobilenetv1_mnist() -> Result<()> { 61 | test_mnist(&FlatBufferModel::build_from_file("data/MNISTnet_uint8_quant.tflite")?)?; 62 | 63 | let buf = fs::read("data/MNISTnet_uint8_quant.tflite")?; 64 | test_mnist(&FlatBufferModel::build_from_buffer(buf)?) 65 | } 66 | 67 | #[test] 68 | fn mobilenetv2_mnist() -> Result<()> { 69 | test_mnist(&FlatBufferModel::build_from_file("data/MNISTnet_v2_uint8_quant.tflite")?)?; 70 | 71 | let buf = fs::read("data/MNISTnet_v2_uint8_quant.tflite")?; 72 | test_mnist(&FlatBufferModel::build_from_buffer(buf)?) 
73 | } 74 | ``` 75 | 76 | ### Using the FlatBuffers model APIs 77 | 78 | This crate also provides a limited set of FlatBuffers model APIs. 79 | 80 | ```rust 81 | use std::ffi::CString; 82 | use tflite::model::stl::vector::{VectorInsert, VectorErase, VectorSlice}; 83 | use tflite::model::{BuiltinOperator, BuiltinOptions, Model, SoftmaxOptionsT}; 84 | #[test] 85 | fn flatbuffer_model_apis_inspect() { 86 | let model = Model::from_file("data/MNISTnet_uint8_quant.tflite").unwrap(); 87 | assert_eq!(model.version, 3); 88 | assert_eq!(model.operator_codes.size(), 5); 89 | assert_eq!(model.subgraphs.size(), 1); 90 | assert_eq!(model.buffers.size(), 24); 91 | assert_eq!( 92 | model.description.c_str().to_string_lossy(), 93 | "TOCO Converted." 94 | ); 95 | 96 | assert_eq!( 97 | model.operator_codes[0].builtin_code, 98 | BuiltinOperator::BuiltinOperator_AVERAGE_POOL_2D 99 | ); 100 | 101 | assert_eq!( 102 | model 103 | .operator_codes 104 | .iter() 105 | .map(|oc| oc.builtin_code) 106 | .collect::<Vec<_>>(), 107 | vec![ 108 | BuiltinOperator::BuiltinOperator_AVERAGE_POOL_2D, 109 | BuiltinOperator::BuiltinOperator_CONV_2D, 110 | BuiltinOperator::BuiltinOperator_DEPTHWISE_CONV_2D, 111 | BuiltinOperator::BuiltinOperator_SOFTMAX, 112 | BuiltinOperator::BuiltinOperator_RESHAPE 113 | ] 114 | ); 115 | 116 | let subgraph = &model.subgraphs[0]; 117 | assert_eq!(subgraph.tensors.size(), 23); 118 | assert_eq!(subgraph.operators.size(), 9); 119 | assert_eq!(subgraph.inputs.as_slice(), &[22]); 120 | assert_eq!(subgraph.outputs.as_slice(), &[21]); 121 | 122 | let softmax = subgraph 123 | .operators 124 | .iter() 125 | .position(|op| { 126 | model.operator_codes[op.opcode_index as usize].builtin_code 127 | == BuiltinOperator::BuiltinOperator_SOFTMAX 128 | }) 129 | .unwrap(); 130 | 131 | assert_eq!(subgraph.operators[softmax].inputs.as_slice(), &[4]); 132 | assert_eq!(subgraph.operators[softmax].outputs.as_slice(), &[21]); 133 | assert_eq!( 134 | subgraph.operators[softmax].builtin_options.type_, 135 |
BuiltinOptions::BuiltinOptions_SoftmaxOptions 136 | ); 137 | 138 | let softmax_options: &SoftmaxOptionsT = subgraph.operators[softmax].builtin_options.as_ref(); 139 | assert_eq!(softmax_options.beta, 1.); 140 | } 141 | 142 | #[test] 143 | fn flatbuffer_model_apis_mutate() { 144 | let mut model = Model::from_file("data/MNISTnet_uint8_quant.tflite").unwrap(); 145 | model.version = 2; 146 | model.operator_codes.erase(4); 147 | model.buffers.erase(22); 148 | model.buffers.erase(23); 149 | model 150 | .description 151 | .assign(CString::new("flatbuffer").unwrap()); 152 | 153 | { 154 | let subgraph = &mut model.subgraphs[0]; 155 | subgraph.inputs.erase(0); 156 | subgraph.outputs.assign(vec![1, 2, 3, 4]); 157 | } 158 | 159 | let model_buffer = model.to_buffer(); 160 | let model = Model::from_buffer(&model_buffer); 161 | assert_eq!(model.version, 2); 162 | assert_eq!(model.operator_codes.size(), 4); 163 | assert_eq!(model.subgraphs.size(), 1); 164 | assert_eq!(model.buffers.size(), 22); 165 | assert_eq!(model.description.c_str().to_string_lossy(), "flatbuffer"); 166 | 167 | let subgraph = &model.subgraphs[0]; 168 | assert_eq!(subgraph.tensors.size(), 23); 169 | assert_eq!(subgraph.operators.size(), 9); 170 | assert!(subgraph.inputs.as_slice().is_empty()); 171 | assert_eq!(subgraph.outputs.as_slice(), &[1, 2, 3, 4]); 172 | } 173 | ``` 174 | -------------------------------------------------------------------------------- /build.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "generate_model_apis")] 2 | #[macro_use] 3 | extern crate bart_derive; 4 | 5 | use std::env; 6 | use std::env::VarError; 7 | use std::path::{Path, PathBuf}; 8 | #[cfg(feature = "build")] 9 | use std::time::Instant; 10 | /// Directory containing this crate's `Cargo.toml` (`CARGO_MANIFEST_DIR`). 11 | fn manifest_dir() -> PathBuf { 12 | PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()) 13 | } 14 | /// The vendored `submodules/` directory (tensorflow checkout and pre-fetched downloads). 15 | fn submodules() -> PathBuf { 16 | manifest_dir().join("submodules") 17 | } 18 | /// Copies tensorflow and the downloads into `OUT_DIR` once; returns `OUT_DIR/tensorflow/tensorflow`. 19 | #[cfg(feature = "build")] 20 | fn
prepare_tensorflow_source() -> PathBuf { 21 | println!("Moving tflite source"); 22 | let start = Instant::now(); 23 | let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); 24 | let tf_src_dir = out_dir.join("tensorflow/tensorflow"); 25 | let submodules = submodules(); 26 | 27 | let mut copy_dir = fs_extra::dir::CopyOptions::new(); 28 | copy_dir.overwrite = true; 29 | copy_dir.buffer_size = 65536; 30 | 31 | if !tf_src_dir.exists() { 32 | fs_extra::dir::copy(submodules.join("tensorflow"), &out_dir, &copy_dir) 33 | .expect("Unable to copy tensorflow"); 34 | } 35 | 36 | let download_dir = tf_src_dir.join("lite/tools/make/downloads"); 37 | if !download_dir.exists() { 38 | fs_extra::dir::copy( 39 | submodules.join("downloads"), 40 | download_dir.parent().unwrap(), 41 | &copy_dir, 42 | ) 43 | .expect("Unable to copy download dir"); 44 | } 45 | 46 | println!("Moving source took {:?}", start.elapsed()); 47 | 48 | tf_src_dir 49 | } 50 | /// Suffix for the cached library name, built from features that change how tflite is compiled. 51 | fn binary_changing_features() -> String { 52 | let mut features = String::new(); 53 | if cfg!(feature = "debug_tflite") { 54 | features.push_str("-debug"); 55 | } 56 | if cfg!(feature = "no_micro") { 57 | features.push_str("-no_micro"); 58 | } 59 | features 60 | } 61 | /// Builds libtensorflow-lite with `make` (feature `build`) or uses a prebuilt one from `TFLITE_<ARCH>_LIB_DIR`/`TFLITE_LIB_DIR`, then emits link directives. 62 | fn prepare_tensorflow_library() { 63 | let arch = env::var("CARGO_CFG_TARGET_ARCH").expect("Unable to get TARGET_ARCH"); 64 | 65 | #[cfg(feature = "build")] 66 | { 67 | let tflite = prepare_tensorflow_source(); 68 | let out_dir = env::var("OUT_DIR").unwrap(); 69 | // append tf_lib_name with features that can change how it is built 70 | // so a cached version that doesn't match expectations isn't used 71 | let binary_changing_features = binary_changing_features(); 72 | let tf_lib_name = 73 | Path::new(&out_dir).join(format!("libtensorflow-lite{binary_changing_features}.a")); 74 | let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS"); 75 | if !tf_lib_name.exists() { 76 | println!("Building tflite"); 77 | let start = Instant::now(); 78 | let mut make =
std::process::Command::new("make"); 79 | if let Ok(prefix) = env::var("TARGET_TOOLCHAIN_PREFIX") { 80 | make.arg(format!("TARGET_TOOLCHAIN_PREFIX={prefix}")); 81 | } else { 82 | let target_triple = env::var("TARGET").unwrap(); 83 | let host_triple = env::var("HOST").unwrap(); 84 | let kind = if host_triple == target_triple { "HOST" } else { "TARGET" }; 85 | let target_u = target_triple.replace('-', "_"); 86 | for name in ["CC", "CXX", "AR", "CFLAGS", "CXXFLAGS", "ARFLAGS"] { 87 | if let Ok(value) = env::var(format!("{name}_{target_triple}")) 88 | .or_else(|_| env::var(format!("{name}_{target_u}"))) 89 | .or_else(|_| env::var(format!("{kind}_{name}"))) 90 | .or_else(|_| env::var(name)) 91 | { 92 | make.arg(format!("{name}={value}")); 93 | println!("inherited: {name}={value}") 94 | } 95 | } 96 | } 97 | 98 | // Use cargo's cross-compilation information while building tensorflow 99 | // Now that tensorflow has an aarch64_makefile.inc use theirs 100 | let target = if arch == "aarch64" { &arch } else { &os }; 101 | 102 | #[cfg(feature = "debug_tflite")] 103 | { 104 | println!("Feature debug_tflite enabled. Changing optimization to 0"); 105 | let makefile = tflite.join("lite/tools/make/Makefile"); 106 | let makefile_contents = 107 | std::fs::read_to_string(&makefile).expect("Unable to read Makefile"); 108 | let replaced = makefile_contents.replace("-O3", "-Og -g").replace("-DNDEBUG", ""); 109 | std::fs::write(&makefile, &replaced).expect("Unable to write Makefile"); 110 | if !replaced.contains("-Og") { 111 | panic!("Unable to change optimization settings"); 112 | } 113 | } 114 | 115 | let make_dir = tflite.parent().unwrap(); 116 | 117 | // allow parallelism to be overridden...
118 | let num_jobs = env::var("TFLITE_RS_MAKE_PARALLELISM").ok().or_else(|| { 119 | // but prefer jobserver if not explicitly given 120 | if !env::var("MAKEFLAGS").unwrap_or_default().contains("--jobserver") { 121 | env::var("NUM_JOBS").ok() 122 | } else { 123 | None 124 | } 125 | }); 126 | if let Some(num_jobs) = num_jobs { 127 | make.arg("-j").arg(num_jobs); 128 | } 129 | 130 | make.arg("BUILD_WITH_NNAPI=false").arg("-f").arg("tensorflow/lite/tools/make/Makefile"); 131 | 132 | for (make_var, default) in &[ 133 | ("TARGET", Some(target.as_str())), 134 | ("TARGET_ARCH", Some(arch.as_str())), 135 | ("TARGET_TOOLCHAIN_PREFIX", None), 136 | ("EXTRA_CFLAGS", None), 137 | ("EXTRA_CXXFLAGS", None), 138 | ] { 139 | let env_var = format!("TFLITE_RS_MAKE_{make_var}"); 140 | println!("cargo:rerun-if-env-changed={env_var}"); 141 | 142 | match env::var(&env_var) { 143 | Ok(result) => { 144 | make.arg(format!("{make_var}={result}")); 145 | } 146 | Err(VarError::NotPresent) => { 147 | // Try and set some reasonable default values 148 | if let Some(result) = default { 149 | make.arg(format!("{make_var}={result}")); 150 | } 151 | } 152 | Err(VarError::NotUnicode(_)) => { 153 | panic!("Provided a non-unicode value for {env_var}") 154 | } 155 | } 156 | } 157 | 158 | if cfg!(feature = "no_micro") { 159 | println!("Building lib but no micro"); 160 | make.arg("lib"); 161 | } else { 162 | make.arg("micro"); 163 | } 164 | make.current_dir(make_dir); 165 | eprintln!("make command = {make:?} in dir {make_dir:?}"); 166 | if !make.status().expect("failed to run make command").success() { 167 | panic!("Failed to build tensorflow"); 168 | } 169 | 170 | // find library 171 | let library = std::fs::read_dir(tflite.join("lite/tools/make/gen")) 172 | .expect("Make gen file should exist") 173 | .filter_map(|de| Some(de.ok()?.path().join("lib/libtensorflow-lite.a"))) 174 | .find(|p| p.exists()) 175 | .expect("Unable to find libtensorflow-lite.a"); 176 | std::fs::copy(library,
&tf_lib_name).unwrap_or_else(|_| { 177 | panic!("Unable to copy libtensorflow-lite.a to {}", tf_lib_name.display()) 178 | }); 179 | 180 | println!("Building tflite from source took {:?}", start.elapsed()); 181 | } 182 | println!("cargo:rustc-link-search=native={out_dir}"); 183 | println!("cargo:rustc-link-lib=static=tensorflow-lite{binary_changing_features}"); 184 | } 185 | #[cfg(not(feature = "build"))] 186 | { 187 | let arch_var = format!("TFLITE_{}_LIB_DIR", arch.replace('-', "_").to_uppercase()); 188 | let all_var = "TFLITE_LIB_DIR".to_string(); 189 | // Cargo does not track env vars read by a build script unless told to; register both. 190 | println!("cargo:rerun-if-env-changed={arch_var}"); 191 | println!("cargo:rerun-if-env-changed={all_var}"); 192 | let lib_dir = env::var(&arch_var).or_else(|_| env::var(&all_var)).unwrap_or_else(|_| { 193 | panic!( 194 | "[feature = build] not set and environment variables {} and {} are not set", 195 | arch_var, all_var 196 | ) 197 | }); 198 | println!("cargo:rustc-link-search=native={lib_dir}"); 199 | let is_static = Path::new(&lib_dir).join("libtensorflow-lite.a").exists(); 200 | let static_dynamic = if is_static { "static" } else { "dylib" }; 201 | println!("cargo:rustc-link-lib={static_dynamic}=tensorflow-lite"); 202 | println!("cargo:rerun-if-changed={lib_dir}"); 203 | } 204 | println!("cargo:rustc-link-lib=dylib=pthread"); 205 | println!("cargo:rustc-link-lib=dylib=dl"); 206 | } 207 | 208 | // This generates "tflite_types.rs" containing structs and enums which are inter-operable with Glow.
209 | fn import_tflite_types() { 210 | use bindgen::*; 211 | 212 | let submodules = submodules(); 213 | let submodules_str = submodules.to_string_lossy(); 214 | let bindings = Builder::default() 215 | .allowlist_recursively(true) 216 | .prepend_enum_name(false) 217 | .impl_debug(true) 218 | .with_codegen_config(CodegenConfig::TYPES) 219 | .layout_tests(false) 220 | .enable_cxx_namespaces() 221 | .derive_default(true) 222 | .size_t_is_usize(true) 223 | // for model APIs 224 | .allowlist_type("tflite::ModelT") 225 | .allowlist_type(".+OptionsT") 226 | .blocklist_type(".+_TableType") 227 | // for interpreter 228 | .allowlist_type("tflite::FlatBufferModel") 229 | .opaque_type("tflite::FlatBufferModel") 230 | .allowlist_type("tflite::InterpreterBuilder") 231 | .opaque_type("tflite::InterpreterBuilder") 232 | .allowlist_type("tflite::Interpreter") 233 | .opaque_type("tflite::Interpreter") 234 | .allowlist_type("tflite::ops::builtin::BuiltinOpResolver") 235 | .opaque_type("tflite::ops::builtin::BuiltinOpResolver") 236 | .allowlist_type("tflite::OpResolver") 237 | .opaque_type("tflite::OpResolver") 238 | .allowlist_type("TfLiteTensor") 239 | .opaque_type("std::string") 240 | .opaque_type("flatbuffers::NativeTable") 241 | .blocklist_type("std") 242 | .blocklist_type("tflite::Interpreter_TfLiteDelegatePtr") 243 | .blocklist_type("tflite::Interpreter_State") 244 | .default_enum_style(EnumVariation::Rust { non_exhaustive: false }) 245 | .derive_partialeq(true) 246 | .derive_eq(true) 247 | .header("csrc/tflite_wrapper.hpp") 248 | .clang_arg(format!("-I{submodules_str}/tensorflow")) 249 | .clang_arg(format!("-I{submodules_str}/downloads/flatbuffers/include")) 250 | .clang_arg("-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK") 251 | .clang_arg("-DFLATBUFFERS_POLYMORPHIC_NATIVETABLE") 252 | .clang_arg("-x") 253 | .clang_arg("c++") 254 | .clang_arg("-std=c++11") 255 | // required to get cross compilation for aarch64 to work because of an issue in flatbuffers 256 | 
.clang_arg("-fms-extensions") 257 | .no_copy("_Tp"); 258 | 259 | let bindings = bindings.generate().expect("Unable to generate bindings"); 260 | 261 | // Write the bindings to the $OUT_DIR/tflite_types.rs file. 262 | let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("tflite_types.rs"); 263 | bindings.write_to_file(out_path).expect("Couldn't write bindings!"); 264 | } 265 | 266 | fn build_inline_cpp() { 267 | let submodules = submodules(); 268 | 269 | cpp_build::Config::new() 270 | .include(submodules.join("tensorflow")) 271 | .include(submodules.join("downloads/flatbuffers/include")) 272 | .flag("-fPIC") 273 | .flag("-std=c++14") 274 | .flag("-Wno-sign-compare") 275 | .define("GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK", None) 276 | .define("FLATBUFFERS_POLYMORPHIC_NATIVETABLE", None) 277 | .debug(true) 278 | .opt_level(if cfg!(debug_assertions) { 0 } else { 2 }) 279 | .build("src/lib.rs"); 280 | } 281 | 282 | fn import_stl_types() { 283 | use bindgen::*; 284 | 285 | let bindings = Builder::default() 286 | .enable_cxx_namespaces() 287 | .allowlist_type("std::string") 288 | .opaque_type("std::string") 289 | .allowlist_type("rust::.+") 290 | .opaque_type("rust::.+") 291 | .blocklist_type("std") 292 | .header("csrc/stl_wrapper.hpp") 293 | .layout_tests(false) 294 | .derive_partialeq(true) 295 | .derive_eq(true) 296 | .clang_arg("-x") 297 | .clang_arg("c++") 298 | .clang_arg("-std=c++14") 299 | .clang_arg("-fms-extensions") 300 | .formatter(Formatter::Rustfmt) 301 | .generate() 302 | .expect("Unable to generate STL bindings"); 303 | 304 | // Write the bindings to the $OUT_DIR/tflite_types.rs file. 
305 | let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("stl_types.rs"); 306 | bindings.write_to_file(out_path).expect("Couldn't write bindings!"); 307 | } 308 | 309 | #[cfg(feature = "generate_model_apis")] 310 | fn generate_memory_impl() -> Result<(), Box> { 311 | use std::io::Write; 312 | let mut file = std::fs::File::create("src/model/stl/memory_impl.rs")?; 313 | writeln!( 314 | &mut file, 315 | r#" 316 | #![allow(clippy::transmute_num_to_bytes)] 317 | use std::{{fmt, mem}}; 318 | use std::ops::{{Deref, DerefMut}}; 319 | 320 | use crate::model::stl::memory::UniquePtr; 321 | "# 322 | )?; 323 | 324 | #[derive(BartDisplay)] 325 | #[template = "data/memory_basic_impl.rs.template"] 326 | struct MemoryBasicImpl<'a> { 327 | cpp_type: &'a str, 328 | rust_type: &'a str, 329 | } 330 | 331 | let memory_types = vec![ 332 | ("OperatorCodeT", "crate::model::OperatorCodeT"), 333 | ("TensorT", "crate::model::TensorT"), 334 | ("OperatorT", "crate::model::OperatorT"), 335 | ("SubGraphT", "crate::model::SubGraphT"), 336 | ("BufferT", "crate::model::BufferT"), 337 | ("QuantizationParametersT", "crate::model::QuantizationParametersT"), 338 | ("ModelT", "crate::model::ModelT"), 339 | ("MetadataT", "crate::model::MetadataT"), 340 | ]; 341 | 342 | for (cpp_type, rust_type) in memory_types { 343 | writeln!(&mut file, "{}\n", &MemoryBasicImpl { cpp_type, rust_type },)?; 344 | } 345 | Ok(()) 346 | } 347 | 348 | #[cfg(feature = "generate_model_apis")] 349 | fn generate_vector_impl() -> Result<(), Box> { 350 | use std::io::Write; 351 | let mut file = std::fs::File::create("src/model/stl/vector_impl.rs")?; 352 | writeln!( 353 | &mut file, 354 | r#" 355 | #![allow(clippy::transmute_num_to_bytes)] 356 | use std::{{fmt, mem, slice}}; 357 | use std::ops::{{Deref, DerefMut, Index, IndexMut}}; 358 | 359 | use libc::size_t; 360 | 361 | use super::memory::UniquePtr; 362 | use super::vector::{{VectorOfUniquePtr, VectorErase, VectorExtract, VectorInsert, VectorSlice}}; 363 | use 
crate::model::stl::bindings::root::rust::dummy_vector; 364 | 365 | cpp! {{{{ 366 | #include 367 | }}}} 368 | "# 369 | )?; 370 | 371 | #[derive(BartDisplay)] 372 | #[template = "data/vector_primitive_impl.rs.template"] 373 | #[allow(non_snake_case)] 374 | struct VectorPrimitiveImpl<'a> { 375 | cpp_type: &'a str, 376 | rust_type: &'a str, 377 | RustType: &'a str, 378 | } 379 | 380 | let vector_types = vec![ 381 | ("uint8_t", "u8", "U8"), 382 | ("int32_t", "i32", "I32"), 383 | ("int64_t", "i64", "I64"), 384 | ("float", "f32", "F32"), 385 | ]; 386 | 387 | #[allow(non_snake_case)] 388 | for (cpp_type, rust_type, RustType) in vector_types { 389 | writeln!(&mut file, "{}\n", &VectorPrimitiveImpl { cpp_type, rust_type, RustType },)?; 390 | } 391 | 392 | #[derive(BartDisplay)] 393 | #[template = "data/vector_basic_impl.rs.template"] 394 | struct VectorBasicImpl<'a> { 395 | cpp_type: &'a str, 396 | rust_type: &'a str, 397 | } 398 | 399 | let vector_types = vec![ 400 | ("std::unique_ptr", "UniquePtr"), 401 | ("std::unique_ptr", "UniquePtr"), 402 | ("std::unique_ptr", "UniquePtr"), 403 | ("std::unique_ptr", "UniquePtr"), 404 | ("std::unique_ptr", "UniquePtr"), 405 | ("std::unique_ptr", "UniquePtr"), 406 | ]; 407 | 408 | for (cpp_type, rust_type) in vector_types { 409 | writeln!(&mut file, "{}\n", &VectorBasicImpl { cpp_type, rust_type },)?; 410 | } 411 | Ok(()) 412 | } 413 | 414 | #[cfg(feature = "generate_model_apis")] 415 | fn generate_builtin_options_impl() -> Result<(), Box> { 416 | use std::io::Write; 417 | let mut file = std::fs::File::create("src/model/builtin_options_impl.rs")?; 418 | writeln!( 419 | &mut file, 420 | r#" 421 | use super::{{BuiltinOptions, BuiltinOptionsUnion, NativeTable}}; 422 | "# 423 | )?; 424 | 425 | #[derive(BartDisplay)] 426 | #[template = "data/builtin_options_impl.rs.template"] 427 | struct BuiltinOptionsImpl<'a> { 428 | name: &'a str, 429 | } 430 | 431 | let option_names = vec![ 432 | "Conv2DOptions", 433 | "DepthwiseConv2DOptions", 434 | 
"ConcatEmbeddingsOptions", 435 | "LSHProjectionOptions", 436 | "Pool2DOptions", 437 | "SVDFOptions", 438 | "RNNOptions", 439 | "FullyConnectedOptions", 440 | "SoftmaxOptions", 441 | "ConcatenationOptions", 442 | "AddOptions", 443 | "L2NormOptions", 444 | "LocalResponseNormalizationOptions", 445 | "LSTMOptions", 446 | "ResizeBilinearOptions", 447 | "CallOptions", 448 | "ReshapeOptions", 449 | "SkipGramOptions", 450 | "SpaceToDepthOptions", 451 | "EmbeddingLookupSparseOptions", 452 | "MulOptions", 453 | "PadOptions", 454 | "GatherOptions", 455 | "BatchToSpaceNDOptions", 456 | "SpaceToBatchNDOptions", 457 | "TransposeOptions", 458 | "ReducerOptions", 459 | "SubOptions", 460 | "DivOptions", 461 | "SqueezeOptions", 462 | "SequenceRNNOptions", 463 | "StridedSliceOptions", 464 | "ExpOptions", 465 | "TopKV2Options", 466 | "SplitOptions", 467 | "LogSoftmaxOptions", 468 | "CastOptions", 469 | "DequantizeOptions", 470 | "MaximumMinimumOptions", 471 | "ArgMaxOptions", 472 | "LessOptions", 473 | "NegOptions", 474 | "PadV2Options", 475 | "GreaterOptions", 476 | "GreaterEqualOptions", 477 | "LessEqualOptions", 478 | "SelectOptions", 479 | "SliceOptions", 480 | "TransposeConvOptions", 481 | "SparseToDenseOptions", 482 | "TileOptions", 483 | "ExpandDimsOptions", 484 | "EqualOptions", 485 | "NotEqualOptions", 486 | "ShapeOptions", 487 | "PowOptions", 488 | "ArgMinOptions", 489 | "FakeQuantOptions", 490 | "PackOptions", 491 | "LogicalOrOptions", 492 | "OneHotOptions", 493 | "LogicalAndOptions", 494 | "LogicalNotOptions", 495 | "UnpackOptions", 496 | "FloorDivOptions", 497 | "SquareOptions", 498 | "ZerosLikeOptions", 499 | "FillOptions", 500 | "BidirectionalSequenceLSTMOptions", 501 | "BidirectionalSequenceRNNOptions", 502 | "UnidirectionalSequenceLSTMOptions", 503 | "FloorModOptions", 504 | "RangeOptions", 505 | "ResizeNearestNeighborOptions", 506 | "LeakyReluOptions", 507 | "SquaredDifferenceOptions", 508 | "MirrorPadOptions", 509 | "AbsOptions", 510 | "SplitVOptions", 511 | 
"UniqueOptions", 512 | "ReverseV2Options", 513 | "AddNOptions", 514 | "GatherNdOptions", 515 | "CosOptions", 516 | "WhereOptions", 517 | "RankOptions", 518 | "ReverseSequenceOptions", 519 | "MatrixDiagOptions", 520 | "QuantizeOptions", 521 | "MatrixSetDiagOptions", 522 | "HardSwishOptions", 523 | "IfOptions", 524 | "WhileOptions", 525 | "DepthToSpaceOptions", 526 | ]; 527 | 528 | for name in option_names { 529 | writeln!(&mut file, "{}\n", &BuiltinOptionsImpl { name },)?; 530 | } 531 | Ok(()) 532 | } 533 | 534 | fn main() { 535 | import_stl_types(); 536 | #[cfg(feature = "generate_model_apis")] 537 | { 538 | generate_memory_impl().unwrap(); 539 | generate_vector_impl().unwrap(); 540 | generate_builtin_options_impl().unwrap(); 541 | } 542 | import_tflite_types(); 543 | build_inline_cpp(); 544 | if env::var("DOCS_RS").is_err() { 545 | prepare_tensorflow_library(); 546 | } 547 | } 548 | -------------------------------------------------------------------------------- /csrc/stl_wrapper.hpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | namespace rust { 6 | 7 | struct alignas(alignof(std::vector>)) dummy_vector { 8 | uint8_t payload[sizeof(std::vector>)]; 9 | }; 10 | 11 | struct alignas(alignof(std::vector)) vector_of_bool { 12 | uint8_t payload[sizeof(std::vector)]; 13 | }; 14 | 15 | struct alignas(alignof(std::unique_ptr)) unique_ptr_of_void { 16 | uint8_t payload[sizeof(std::unique_ptr)]; 17 | }; 18 | 19 | } // namespace rust 20 | -------------------------------------------------------------------------------- /csrc/tflite_wrapper.hpp: -------------------------------------------------------------------------------- 1 | #include "tensorflow/lite/interpreter.h" 2 | #include "tensorflow/lite/kernels/register.h" 3 | #include "tensorflow/lite/model.h" 4 | #include "tensorflow/lite/optional_debug_tools.h" 5 | -------------------------------------------------------------------------------- 
/data/MNISTnet_uint8_quant.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boncheolgu/tflite-rs/c8490ec0420107860f40d429ba88c2807e8633e9/data/MNISTnet_uint8_quant.tflite -------------------------------------------------------------------------------- /data/MNISTnet_v2_uint8_quant.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boncheolgu/tflite-rs/c8490ec0420107860f40d429ba88c2807e8633e9/data/MNISTnet_v2_uint8_quant.tflite -------------------------------------------------------------------------------- /data/builtin_options_impl.rs.template: -------------------------------------------------------------------------------- 1 | impl BuiltinOptionsUnion { 2 | #[allow(non_snake_case, deprecated)] 3 | pub fn {{{name}}}() -> Self { 4 | let value = unsafe { 5 | cpp!([] -> *mut NativeTable as "flatbuffers::NativeTable*" { 6 | return new {{{name}}}T; 7 | }) 8 | }; 9 | 10 | Self { 11 | typ: BuiltinOptions::BuiltinOptions_{{{name}}}, 12 | value, 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /data/memory_basic_impl.rs.template: -------------------------------------------------------------------------------- 1 | #[allow(deprecated)] 2 | impl Default for UniquePtr<{{{rust_type}}}> { 3 | fn default() -> Self { 4 | let mut this: Self = unsafe { mem::zeroed() }; 5 | let this_ref = &mut this; 6 | unsafe { 7 | cpp!([this_ref as "std::unique_ptr<{{{cpp_type}}}>*"] { 8 | new (this_ref) std::unique_ptr<{{{cpp_type}}}>(new {{{cpp_type}}}); 9 | }) 10 | } 11 | this 12 | } 13 | } 14 | 15 | #[allow(deprecated)] 16 | impl Deref for UniquePtr<{{{rust_type}}}> { 17 | type Target = {{{rust_type}}}; 18 | 19 | fn deref(&self) -> &Self::Target { 20 | unsafe { 21 | let ptr = cpp!([self as "const std::unique_ptr<{{{cpp_type}}}>*"] -> *const {{{rust_type}}} as "const {{{cpp_type}}}*" { 22 | 
return self->get(); 23 | }); 24 | 25 | ptr.as_ref().unwrap() 26 | } 27 | } 28 | } 29 | 30 | #[allow(deprecated)] 31 | impl DerefMut for UniquePtr<{{{rust_type}}}> { 32 | fn deref_mut(&mut self) -> &mut Self::Target { 33 | unsafe { 34 | let ptr = cpp!([self as "std::unique_ptr<{{{cpp_type}}}>*"] -> *mut {{{rust_type}}} as "{{{cpp_type}}}*" { 35 | return self->get(); 36 | }); 37 | 38 | ptr.as_mut().unwrap() 39 | } 40 | } 41 | } 42 | 43 | #[allow(deprecated)] 44 | impl fmt::Debug for UniquePtr<{{{rust_type}}}> 45 | { 46 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 47 | write!(f, "({:?})", self.deref()) 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /data/mnist10.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boncheolgu/tflite-rs/c8490ec0420107860f40d429ba88c2807e8633e9/data/mnist10.bin -------------------------------------------------------------------------------- /data/vector_basic_impl.rs.template: -------------------------------------------------------------------------------- 1 | #[allow(deprecated)] 2 | impl Default for VectorOf{{{rust_type}}} { 3 | fn default() -> Self { 4 | let mut this = unsafe{ mem::zeroed() }; 5 | let this_ref = &mut this; 6 | unsafe { 7 | cpp!([this_ref as "std::vector<{{{cpp_type}}}>*"] { 8 | new (this_ref) const std::vector<{{{cpp_type}}}>; 9 | }) 10 | } 11 | this 12 | } 13 | } 14 | 15 | #[allow(deprecated)] 16 | impl VectorSlice for VectorOf{{{rust_type}}} { 17 | type Item = {{{rust_type}}}; 18 | 19 | fn get_ptr(&self) -> *const Self::Item { 20 | unsafe { 21 | cpp!([self as "const std::vector<{{{cpp_type}}}>*"] 22 | -> *const {{{rust_type}}} as "const {{{cpp_type}}}*" { 23 | return self->data(); 24 | }) 25 | } 26 | } 27 | 28 | fn get_mut_ptr(&mut self) -> *mut Self::Item { 29 | unsafe { 30 | cpp!([self as "std::vector<{{{cpp_type}}}>*"] 31 | -> *mut {{{rust_type}}} as "{{{cpp_type}}}*" { 32 | 
return self->data(); 33 | }) 34 | } 35 | } 36 | 37 | fn size(&self) -> usize { 38 | unsafe { 39 | cpp!([self as "const std::vector<{{{cpp_type}}}>*"] -> size_t as "size_t" { 40 | return self->size(); 41 | }) 42 | } 43 | } 44 | } 45 | 46 | #[allow(deprecated)] 47 | impl VectorErase for VectorOf{{{rust_type}}} { 48 | fn erase_range(&mut self, offset: usize, size: usize) { 49 | let begin = offset as size_t; 50 | let end = offset + size as size_t; 51 | unsafe { 52 | cpp!([self as "std::vector<{{{cpp_type}}}>*", begin as "size_t", end as "size_t"] { 53 | self->erase(self->begin() + begin, self->begin() + end); 54 | }); 55 | } 56 | } 57 | } 58 | 59 | #[allow(deprecated)] 60 | impl VectorInsert<{{{rust_type}}}> for VectorOf{{{rust_type}}} { 61 | fn push_back(&mut self, mut v: Self::Item) { 62 | let vref = &mut v; 63 | unsafe { 64 | cpp!([self as "std::vector<{{{cpp_type}}}>*", vref as "{{{cpp_type}}}*"] { 65 | self->push_back(std::move(*vref)); 66 | }) 67 | } 68 | mem::forget(v); 69 | } 70 | } 71 | 72 | #[allow(deprecated)] 73 | impl VectorExtract<{{{rust_type}}}> for VectorOf{{{rust_type}}} { 74 | fn extract(&mut self, index: usize) -> {{{rust_type}}} { 75 | assert!(index < self.size()); 76 | let mut v: {{{rust_type}}} = unsafe { mem::zeroed() }; 77 | let vref = &mut v; 78 | unsafe { 79 | cpp!([self as "std::vector<{{{cpp_type}}}>*", index as "size_t", vref as "{{{cpp_type}}}*"] { 80 | *vref = std::move((*self)[index]); 81 | }) 82 | } 83 | v 84 | } 85 | } 86 | 87 | add_impl!(VectorOf{{{rust_type}}}); 88 | -------------------------------------------------------------------------------- /data/vector_primitive_impl.rs.template: -------------------------------------------------------------------------------- 1 | #[repr(C)] 2 | pub struct VectorOf{{{RustType}}}(dummy_vector); 3 | 4 | #[allow(deprecated)] 5 | impl Default for VectorOf{{{RustType}}} { 6 | fn default() -> Self { 7 | let mut this = unsafe{ mem::zeroed() }; 8 | let this_ref = &mut this; 9 | unsafe { 10 | 
cpp!([this_ref as "std::vector<{{{cpp_type}}}>*"] { 11 | new (this_ref) const std::vector<{{{cpp_type}}}>; 12 | }) 13 | } 14 | this 15 | } 16 | } 17 | 18 | #[allow(deprecated)] 19 | impl Drop for VectorOf{{{RustType}}} { 20 | fn drop(&mut self) { 21 | unsafe { 22 | cpp!([self as "const std::vector<{{{cpp_type}}}>*"] { 23 | self->~vector<{{{cpp_type}}}>(); 24 | }) 25 | } 26 | } 27 | } 28 | 29 | #[allow(deprecated)] 30 | impl Clone for VectorOf{{{RustType}}} { 31 | fn clone(&self) -> Self { 32 | let mut cloned = unsafe { mem::zeroed() }; 33 | let cloned_ref = &mut cloned; 34 | unsafe { 35 | cpp!([self as "const std::vector<{{{cpp_type}}}>*", cloned_ref as "std::vector<{{{cpp_type}}}>*"] { 36 | new (cloned_ref) std::vector<{{{cpp_type}}}>(*self); 37 | }); 38 | } 39 | cloned 40 | } 41 | } 42 | 43 | impl PartialEq for VectorOf{{{RustType}}} { 44 | fn eq(&self, other: &Self) -> bool { 45 | self.as_slice() == other.as_slice() 46 | } 47 | } 48 | 49 | impl Eq for VectorOf{{{RustType}}} {} 50 | 51 | #[allow(deprecated)] 52 | impl VectorSlice for VectorOf{{{RustType}}} { 53 | type Item = {{{rust_type}}}; 54 | 55 | fn get_ptr(&self) -> *const Self::Item { 56 | unsafe { 57 | cpp!([self as "const std::vector<{{{cpp_type}}}>*"] 58 | -> *const {{{rust_type}}} as "const {{{cpp_type}}}*" { 59 | return self->data(); 60 | }) 61 | } 62 | } 63 | 64 | fn get_mut_ptr(&mut self) -> *mut Self::Item { 65 | unsafe { 66 | cpp!([self as "std::vector<{{{cpp_type}}}>*"] 67 | -> *mut {{{rust_type}}} as "{{{cpp_type}}}*" { 68 | return self->data(); 69 | }) 70 | } 71 | } 72 | 73 | fn size(&self) -> usize { 74 | unsafe { 75 | cpp!([self as "const std::vector<{{{cpp_type}}}>*"] -> size_t as "size_t" { 76 | return self->size(); 77 | }) 78 | } 79 | } 80 | } 81 | 82 | #[allow(deprecated)] 83 | impl VectorErase for VectorOf{{{RustType}}} { 84 | fn erase_range(&mut self, offset: usize, size: usize) { 85 | let begin = offset as size_t; 86 | let end = offset + size as size_t; 87 | unsafe { 88 | cpp!([self as 
"std::vector<{{{cpp_type}}}>*", begin as "size_t", end as "size_t"] { 89 | self->erase(self->begin() + begin, self->begin() + end); 90 | }); 91 | } 92 | } 93 | } 94 | 95 | #[allow(deprecated)] 96 | impl VectorInsert<{{{rust_type}}}> for VectorOf{{{RustType}}} { 97 | fn push_back(&mut self, mut v: Self::Item) { 98 | let vref = &mut v; 99 | unsafe { 100 | cpp!([self as "std::vector<{{{cpp_type}}}>*", vref as "{{{cpp_type}}}*"] { 101 | self->push_back(std::move(*vref)); 102 | }) 103 | } 104 | } 105 | } 106 | 107 | #[allow(deprecated)] 108 | impl VectorExtract<{{{rust_type}}}> for VectorOf{{{RustType}}} { 109 | fn extract(&mut self, index: usize) -> {{{rust_type}}} { 110 | assert!(index < self.size()); 111 | let mut v: {{{rust_type}}} = unsafe { mem::zeroed() }; 112 | let vref = &mut v; 113 | unsafe { 114 | cpp!([self as "std::vector<{{{cpp_type}}}>*", index as "size_t", vref as "{{{cpp_type}}}*"] { 115 | *vref = std::move((*self)[index]); 116 | }) 117 | } 118 | v 119 | } 120 | } 121 | 122 | add_impl!(VectorOf{{{RustType}}}); 123 | -------------------------------------------------------------------------------- /examples/minimal.rs: -------------------------------------------------------------------------------- 1 | use std::env::args; 2 | 3 | use tflite::ops::builtin::BuiltinOpResolver; 4 | use tflite::{FlatBufferModel, InterpreterBuilder, Result}; 5 | 6 | pub fn main() -> Result<()> { 7 | assert_eq!(args().len(), 2, "minimal "); 8 | 9 | let filename = args().nth(1).unwrap(); 10 | 11 | let model = FlatBufferModel::build_from_file(filename)?; 12 | let resolver = BuiltinOpResolver::default(); 13 | 14 | let builder = InterpreterBuilder::new(&model, &resolver)?; 15 | let mut interpreter = builder.build()?; 16 | 17 | interpreter.allocate_tensors()?; 18 | 19 | println!("=== Pre-invoke Interpreter State ==="); 20 | interpreter.print_state(); 21 | 22 | interpreter.invoke()?; 23 | 24 | println!("\n\n=== Post-invoke Interpreter State ==="); 25 | interpreter.print_state(); 26 | 
Ok(()) 27 | } 28 | -------------------------------------------------------------------------------- /rustfmt.toml: -------------------------------------------------------------------------------- 1 | edition = "2018" 2 | 3 | # https://github.com/rust-lang/rustfmt/blob/master/Configurations.md#use_small_heuristics 4 | use_small_heuristics = "Max" 5 | -------------------------------------------------------------------------------- /src/bindings.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code, clippy::all)] 2 | 3 | pub use self::root::*; 4 | 5 | include!(concat!(env!("OUT_DIR"), "/tflite_types.rs")); 6 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | use std::io::Error as IoError; 2 | 3 | use thiserror::Error; 4 | 5 | #[derive(Error, Debug)] 6 | pub enum Error { 7 | #[error(transparent)] 8 | IoError(#[from] IoError), 9 | #[error("`{0}`")] 10 | InternalError(String), 11 | } 12 | 13 | impl Error { 14 | pub fn internal_error>(s: T) -> Self { 15 | Self::InternalError(s.into()) 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/interpreter/builder.rs: -------------------------------------------------------------------------------- 1 | use maybe_owned::MaybeOwned; 2 | 3 | use super::op_resolver::OpResolver; 4 | use super::FlatBufferModel; 5 | use super::Interpreter; 6 | use crate::bindings::tflite as bindings; 7 | use crate::{Error, Result}; 8 | 9 | cpp! 
{{ 10 | #include "tensorflow/lite/model.h" 11 | #include "tensorflow/lite/kernels/register.h" 12 | 13 | using namespace tflite; 14 | }} 15 | 16 | pub struct InterpreterBuilder<'a, Op> 17 | where 18 | Op: OpResolver, 19 | { 20 | handle: Box, 21 | _model: MaybeOwned<'a, FlatBufferModel>, 22 | _resolver: Op, 23 | } 24 | 25 | impl<'a, Op> Drop for InterpreterBuilder<'a, Op> 26 | where 27 | Op: OpResolver, 28 | { 29 | fn drop(&mut self) { 30 | let handle = Box::into_raw(std::mem::take(&mut self.handle)); 31 | #[allow(clippy::forgetting_copy_types, clippy::useless_transmute, deprecated)] 32 | unsafe { 33 | cpp!([handle as "InterpreterBuilder*"] { 34 | delete handle; 35 | }); 36 | } 37 | } 38 | } 39 | 40 | impl<'a, Op> InterpreterBuilder<'a, Op> 41 | where 42 | Op: OpResolver, 43 | { 44 | #[allow(clippy::new_ret_no_self)] 45 | pub fn new>>(model: M, resolver: Op) -> Result { 46 | use std::ops::Deref; 47 | let model = model.into(); 48 | let handle = { 49 | let model_handle = model.as_ref().handle.deref() as *const _; 50 | let resolver_handle = resolver.get_resolver_handle() as *const _; 51 | 52 | #[allow(clippy::forgetting_copy_types, deprecated)] 53 | unsafe { 54 | cpp!([model_handle as "const FlatBufferModel*", 55 | resolver_handle as "const OpResolver*" 56 | ] -> *mut bindings::InterpreterBuilder as "InterpreterBuilder*" { 57 | return new InterpreterBuilder(*model_handle, *resolver_handle); 58 | }) 59 | } 60 | }; 61 | if handle.is_null() { 62 | return Err(Error::InternalError("failed to create InterpreterBuilder".to_string())); 63 | } 64 | let handle = unsafe { Box::from_raw(handle) }; 65 | Ok(Self { handle, _model: model, _resolver: resolver }) 66 | } 67 | 68 | pub fn build(mut self) -> Result> { 69 | #[allow(clippy::forgetting_copy_types, deprecated)] 70 | let handle = { 71 | let builder = (&mut *self.handle) as *mut _; 72 | unsafe { 73 | cpp!([builder as "InterpreterBuilder*"] -> *mut bindings::Interpreter as "Interpreter*" { 74 | std::unique_ptr interpreter; 75 | 
(*builder)(&interpreter); 76 | return interpreter.release(); 77 | }) 78 | } 79 | }; 80 | if handle.is_null() { 81 | return Err(Error::InternalError("failed to build".to_string())); 82 | } 83 | Interpreter::new(handle, self) 84 | } 85 | 86 | pub fn build_with_threads( 87 | mut self, 88 | threads: std::os::raw::c_int, 89 | ) -> Result> { 90 | #[allow(clippy::forgetting_copy_types, deprecated)] 91 | let handle = { 92 | let builder = (&mut *self.handle) as *mut _; 93 | #[allow(clippy::transmute_num_to_bytes)] 94 | unsafe { 95 | cpp!([builder as "InterpreterBuilder*", threads as "int"] -> *mut bindings::Interpreter as "Interpreter*" { 96 | std::unique_ptr interpreter; 97 | (*builder)(&interpreter, threads); 98 | return interpreter.release(); 99 | }) 100 | } 101 | }; 102 | if handle.is_null() { 103 | return Err(Error::InternalError("failed to build with threads".to_string())); 104 | } 105 | Interpreter::new(handle, self) 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /src/interpreter/context.rs: -------------------------------------------------------------------------------- 1 | use std::ffi::CStr; 2 | use std::fmt; 3 | 4 | use crate::bindings; 5 | 6 | pub type ElementKind = bindings::TfLiteType; 7 | pub type QuantizationParams = bindings::TfLiteQuantizationParams; 8 | 9 | pub trait ElemKindOf { 10 | fn elem_kind_of() -> ElementKind; 11 | } 12 | 13 | impl ElemKindOf for f32 { 14 | fn elem_kind_of() -> ElementKind { 15 | bindings::TfLiteType::kTfLiteFloat32 16 | } 17 | } 18 | 19 | impl ElemKindOf for u8 { 20 | fn elem_kind_of() -> ElementKind { 21 | bindings::TfLiteType::kTfLiteUInt8 22 | } 23 | } 24 | 25 | impl ElemKindOf for i32 { 26 | fn elem_kind_of() -> ElementKind { 27 | bindings::TfLiteType::kTfLiteInt32 28 | } 29 | } 30 | 31 | pub struct TensorInfo { 32 | pub name: String, 33 | pub element_kind: ElementKind, 34 | pub dims: Vec, 35 | } 36 | 37 | impl fmt::Debug for TensorInfo { 38 | fn fmt(&self, f: &mut 
fmt::Formatter<'_>) -> fmt::Result { 39 | f.debug_struct("TensorInfo") 40 | .field("name", &self.name) 41 | .field("element_kind", &self.element_kind) 42 | .field("dims", &self.dims) 43 | .finish() 44 | } 45 | } 46 | 47 | impl<'a> From<&'a bindings::TfLiteTensor> for TensorInfo { 48 | fn from(t: &'a bindings::TfLiteTensor) -> Self { 49 | Self { 50 | name: unsafe { CStr::from_ptr(t.name) }.to_str().unwrap().to_string(), 51 | element_kind: t.type_, 52 | dims: { 53 | let slice = unsafe { 54 | let dims = &*t.dims; 55 | dims.data.as_slice(dims.size as usize) 56 | }; 57 | slice.iter().map(|n| *n as usize).collect() 58 | }, 59 | } 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/interpreter/fbmodel.rs: -------------------------------------------------------------------------------- 1 | use std::path::Path; 2 | use std::{fs, mem}; 3 | 4 | use crate::bindings::tflite as bindings; 5 | use crate::model::Model; 6 | use crate::{Error, Result}; 7 | 8 | cpp! {{ 9 | #include "tensorflow/lite/model.h" 10 | #include "tensorflow/lite/kernels/register.h" 11 | 12 | using namespace tflite; 13 | }} 14 | 15 | #[derive(Default)] 16 | pub struct FlatBufferModel { 17 | pub(crate) handle: Box, 18 | model_buffer: Vec, 19 | } 20 | 21 | impl Drop for FlatBufferModel { 22 | fn drop(&mut self) { 23 | let handle = Box::into_raw(mem::take(&mut self.handle)); 24 | 25 | #[allow(clippy::forgetting_copy_types, clippy::useless_transmute, deprecated)] 26 | unsafe { 27 | cpp!([handle as "FlatBufferModel*"] { 28 | delete handle; 29 | }); 30 | } 31 | } 32 | } 33 | 34 | impl FlatBufferModel { 35 | pub fn build_from_file>(path: P) -> Result { 36 | Self::build_from_buffer(fs::read(path)?) 
37 | } 38 | 39 | pub fn build_from_buffer(model_buffer: Vec) -> Result { 40 | let ptr = model_buffer.as_ptr(); 41 | let size = model_buffer.len(); 42 | 43 | #[allow(clippy::forgetting_copy_types, deprecated, clippy::transmute_num_to_bytes)] 44 | let handle = unsafe { 45 | cpp!([ptr as "const char*", size as "size_t"] 46 | -> *mut bindings::FlatBufferModel as "FlatBufferModel*" { 47 | return FlatBufferModel::BuildFromBuffer(ptr, size).release(); 48 | }) 49 | }; 50 | if handle.is_null() { 51 | return Err(Error::internal_error("failed to build model")); 52 | } 53 | let handle = unsafe { Box::from_raw(handle) }; 54 | Ok(Self { handle, model_buffer }) 55 | } 56 | 57 | pub fn build_from_model(model: &Model) -> Result { 58 | FlatBufferModel::build_from_buffer(model.to_buffer()) 59 | } 60 | 61 | pub fn buffer(&self) -> &[u8] { 62 | &self.model_buffer 63 | } 64 | 65 | pub fn release_buffer(mut self) -> Vec { 66 | mem::take(&mut self.model_buffer) 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /src/interpreter/mod.rs: -------------------------------------------------------------------------------- 1 | mod builder; 2 | pub mod context; 3 | mod fbmodel; 4 | pub mod op_resolver; 5 | pub mod ops; 6 | 7 | use std::mem; 8 | use std::slice; 9 | 10 | use libc::{c_int, size_t}; 11 | 12 | use crate::{bindings, Error, Result}; 13 | pub use builder::InterpreterBuilder; 14 | use context::{ElemKindOf, ElementKind, QuantizationParams, TensorInfo}; 15 | pub use fbmodel::FlatBufferModel; 16 | use op_resolver::OpResolver; 17 | 18 | cpp! 
{{ 19 | #include "tensorflow/lite/interpreter.h" 20 | #include "tensorflow/lite/optional_debug_tools.h" 21 | 22 | using namespace tflite; 23 | }} 24 | 25 | pub type TensorIndex = c_int; 26 | 27 | pub struct Interpreter<'a, Op> 28 | where 29 | Op: OpResolver, 30 | { 31 | handle: Box, 32 | _builder: InterpreterBuilder<'a, Op>, 33 | } 34 | 35 | impl<'a, Op> Drop for Interpreter<'a, Op> 36 | where 37 | Op: OpResolver, 38 | { 39 | fn drop(&mut self) { 40 | let handle = Box::into_raw(mem::take(&mut self.handle)); 41 | #[allow(clippy::forgetting_copy_types, clippy::useless_transmute, deprecated)] 42 | unsafe { 43 | cpp!([handle as "Interpreter*"] { 44 | delete handle; 45 | }); 46 | } 47 | } 48 | } 49 | 50 | impl<'a, Op> Interpreter<'a, Op> 51 | where 52 | Op: OpResolver, 53 | { 54 | fn handle(&self) -> &bindings::tflite::Interpreter { 55 | use std::ops::Deref; 56 | self.handle.deref() 57 | } 58 | fn handle_mut(&mut self) -> &mut bindings::tflite::Interpreter { 59 | use std::ops::DerefMut; 60 | self.handle.deref_mut() 61 | } 62 | pub(crate) fn new( 63 | handle: *mut bindings::tflite::Interpreter, 64 | builder: InterpreterBuilder<'a, Op>, 65 | ) -> Result { 66 | if handle.is_null() { 67 | return Err(Error::internal_error("failed to create interpreter")); 68 | } 69 | let handle = unsafe { Box::from_raw(handle) }; 70 | let mut interpreter = Self { handle, _builder: builder }; 71 | // # Safety 72 | // Always allocate tensors so we don't get into a state 73 | // where we try to read from or write to unallocated memory 74 | // without doing this it is possible to have undefined behavior 75 | // outside of an unsafe block 76 | interpreter.allocate_tensors()?; 77 | Ok(interpreter) 78 | } 79 | /// Update allocations for all tensors. This will redim dependent tensors using 80 | /// the input tensor dimensionality as given. This is relatively expensive. 81 | /// If you know that your sizes are not changing, you need not call this. 
82 | pub fn allocate_tensors(&mut self) -> Result<()> { 83 | let interpreter = self.handle_mut() as *mut _; 84 | 85 | #[allow(clippy::forgetting_copy_types, deprecated)] 86 | let r = unsafe { 87 | cpp!([interpreter as "Interpreter*"] -> bool as "bool" { 88 | return interpreter->AllocateTensors() == kTfLiteOk; 89 | }) 90 | }; 91 | if r { 92 | Ok(()) 93 | } else { 94 | Err(Error::internal_error("failed to allocate tensors")) 95 | } 96 | } 97 | 98 | /// Gets model input details 99 | pub fn get_input_details(&self) -> Result> { 100 | self.inputs() 101 | .iter() 102 | .map(|index| { 103 | self.tensor_info(*index).ok_or_else(|| Error::internal_error("tensor not found")) 104 | }) 105 | .collect() 106 | } 107 | 108 | /// Gets model output details 109 | pub fn get_output_details(&self) -> Result> { 110 | self.outputs() 111 | .iter() 112 | .map(|index| { 113 | self.tensor_info(*index).ok_or_else(|| Error::internal_error("tensor not found")) 114 | }) 115 | .collect() 116 | } 117 | 118 | /// Prints a dump of what tensors and what nodes are in the interpreter. 119 | pub fn print_state(&self) { 120 | let interpreter = self.handle() as *const _; 121 | 122 | #[allow(clippy::forgetting_copy_types, clippy::useless_transmute, deprecated)] 123 | unsafe { 124 | cpp!([interpreter as "Interpreter*"] { 125 | PrintInterpreterState(interpreter); 126 | }) 127 | }; 128 | } 129 | 130 | /// Invoke the interpreter (run the whole graph in dependency order). 
131 | pub fn invoke(&mut self) -> Result<()> { 132 | let interpreter = self.handle_mut() as *mut _; 133 | 134 | #[allow(deprecated)] 135 | let r = unsafe { 136 | cpp!([interpreter as "Interpreter*"] -> bool as "bool" { 137 | return interpreter->Invoke() == kTfLiteOk; 138 | }) 139 | }; 140 | if r { 141 | Ok(()) 142 | } else { 143 | Err(Error::internal_error("failed to invoke interpreter")) 144 | } 145 | } 146 | 147 | /// Sets the number of threads available to the interpreter 148 | /// `threads` should be >= -1 149 | /// Passing in a value of -1 will let the interpreter set the number 150 | /// of threads available to itself. 151 | /// 152 | /// Note that increasing the number of threads does not always speed up inference 153 | pub fn set_num_threads(&mut self, threads: c_int) { 154 | let interpreter = self.handle_mut() as *mut _; 155 | 156 | #[allow(clippy::forgetting_copy_types, deprecated, clippy::transmute_num_to_bytes)] 157 | unsafe { 158 | cpp!([interpreter as "Interpreter*", threads as "int"] { 159 | interpreter->SetNumThreads(threads); 160 | }) 161 | }; 162 | println!("Set num threads to {threads}"); 163 | } 164 | 165 | /// Read only access to list of inputs. 166 | pub fn inputs(&self) -> &[TensorIndex] { 167 | let interpreter = self.handle() as *const _; 168 | let mut count: size_t = 0; 169 | 170 | #[allow(clippy::forgetting_copy_types, deprecated, clippy::transmute_num_to_bytes)] 171 | let ptr = unsafe { 172 | cpp!([ 173 | interpreter as "const Interpreter*", 174 | mut count as "size_t" 175 | ] -> *const TensorIndex as "const int*" { 176 | const auto& inputs = interpreter->inputs(); 177 | count = inputs.size(); 178 | return inputs.data(); 179 | }) 180 | }; 181 | unsafe { slice::from_raw_parts(ptr, count) } 182 | } 183 | 184 | /// Read only access to list of outputs. 
185 | pub fn outputs(&self) -> &[TensorIndex] { 186 | let interpreter = self.handle() as *const _; 187 | let mut count: size_t = 0; 188 | 189 | #[allow(clippy::forgetting_copy_types, deprecated, clippy::transmute_num_to_bytes)] 190 | let ptr = unsafe { 191 | cpp!([ 192 | interpreter as "const Interpreter*", 193 | mut count as "size_t" 194 | ] -> *const TensorIndex as "const int*" { 195 | const auto& outputs = interpreter->outputs(); 196 | count = outputs.size(); 197 | return outputs.data(); 198 | }) 199 | }; 200 | unsafe { slice::from_raw_parts(ptr, count) } 201 | } 202 | 203 | /// Read only access to list of variable tensors. 204 | pub fn variables(&self) -> &[TensorIndex] { 205 | let interpreter = self.handle() as *const _; 206 | let mut count: size_t = 0; 207 | 208 | #[allow(clippy::forgetting_copy_types, deprecated, clippy::transmute_num_to_bytes)] 209 | let ptr = unsafe { 210 | cpp!([ 211 | interpreter as "const Interpreter*", 212 | mut count as "size_t" 213 | ] -> *const TensorIndex as "const int*" { 214 | const auto& variables = interpreter->variables(); 215 | count = variables.size(); 216 | return variables.data(); 217 | }) 218 | }; 219 | unsafe { slice::from_raw_parts(ptr, count) } 220 | } 221 | 222 | /// Return the number of tensors in the model. 223 | pub fn tensors_size(&self) -> size_t { 224 | let interpreter = self.handle() as *const _; 225 | 226 | #[allow(clippy::forgetting_copy_types, deprecated)] 227 | unsafe { 228 | cpp!([interpreter as "const Interpreter*"] -> size_t as "size_t" { 229 | return interpreter->tensors_size(); 230 | }) 231 | } 232 | } 233 | 234 | /// Return the number of ops in the model. 
235 | pub fn nodes_size(&self) -> size_t { 236 | let interpreter = self.handle() as *const _; 237 | 238 | #[allow(clippy::forgetting_copy_types, deprecated)] 239 | unsafe { 240 | cpp!([interpreter as "const Interpreter*"] -> size_t as "size_t" { 241 | return interpreter->nodes_size(); 242 | }) 243 | } 244 | } 245 | 246 | /// Adds `count` tensors, preserving pre-existing Tensor entries. 247 | /// Return the index of the first new tensor. 248 | pub fn add_tensors(&mut self, count: size_t) -> Result { 249 | let interpreter = self.handle_mut() as *mut _; 250 | let mut index: TensorIndex = 0; 251 | 252 | #[allow(clippy::forgetting_copy_types, deprecated, clippy::transmute_num_to_bytes)] 253 | let result = unsafe { 254 | cpp!([ 255 | interpreter as "Interpreter*", 256 | count as "size_t", 257 | mut index as "int" 258 | ] -> bindings::TfLiteStatus as "TfLiteStatus" { 259 | return interpreter->AddTensors(count, &index); 260 | }) 261 | }; 262 | if result == bindings::TfLiteStatus::kTfLiteOk { 263 | Ok(index) 264 | } else { 265 | Err(Error::internal_error("failed to add tensors")) 266 | } 267 | } 268 | 269 | /// Provide a list of tensor indexes that are inputs to the model. 270 | /// Each index is bound check and this modifies the consistent_ flag of the 271 | /// interpreter. 
272 | pub fn set_inputs(&mut self, inputs: &[TensorIndex]) -> Result<()> { 273 | let interpreter = self.handle_mut() as *mut _; 274 | let ptr = inputs.as_ptr(); 275 | let len = inputs.len() as size_t; 276 | 277 | #[allow(clippy::forgetting_copy_types, deprecated, clippy::transmute_num_to_bytes)] 278 | let result = unsafe { 279 | cpp!([ 280 | interpreter as "Interpreter*", 281 | ptr as "const int*", 282 | len as "size_t" 283 | ] -> bindings::TfLiteStatus as "TfLiteStatus" { 284 | std::vector inputs(ptr, ptr + len); 285 | return interpreter->SetInputs(inputs); 286 | }) 287 | }; 288 | if result == bindings::TfLiteStatus::kTfLiteOk { 289 | Ok(()) 290 | } else { 291 | Err(Error::internal_error("failed to set inputs")) 292 | } 293 | } 294 | 295 | /// Provide a list of tensor indexes that are outputs to the model 296 | /// Each index is bound check and this modifies the consistent_ flag of the 297 | /// interpreter. 298 | pub fn set_outputs(&mut self, outputs: &[TensorIndex]) -> Result<()> { 299 | let interpreter = self.handle_mut() as *mut _; 300 | let ptr = outputs.as_ptr(); 301 | let len = outputs.len() as size_t; 302 | 303 | #[allow(clippy::forgetting_copy_types, deprecated, clippy::transmute_num_to_bytes)] 304 | let result = unsafe { 305 | cpp!([ 306 | interpreter as "Interpreter*", 307 | ptr as "const int*", 308 | len as "size_t" 309 | ] -> bindings::TfLiteStatus as "TfLiteStatus" { 310 | std::vector outputs(ptr, ptr + len); 311 | return interpreter->SetOutputs(outputs); 312 | }) 313 | }; 314 | if result == bindings::TfLiteStatus::kTfLiteOk { 315 | Ok(()) 316 | } else { 317 | Err(Error::internal_error("failed to set outputs")) 318 | } 319 | } 320 | 321 | /// Provide a list of tensor indexes that are variable tensors. 322 | /// Each index is bound check and this modifies the consistent_ flag of the 323 | /// interpreter. 
324 | pub fn set_variables(&mut self, variables: &[TensorIndex]) -> Result<()> { 325 | let interpreter = self.handle_mut() as *mut _; 326 | let ptr = variables.as_ptr(); 327 | let len = variables.len() as size_t; 328 | 329 | #[allow(clippy::forgetting_copy_types, deprecated, clippy::transmute_num_to_bytes)] 330 | let result = unsafe { 331 | cpp!([ 332 | interpreter as "Interpreter*", 333 | ptr as "const int*", 334 | len as "size_t" 335 | ] -> bindings::TfLiteStatus as "TfLiteStatus" { 336 | std::vector variables(ptr, ptr + len); 337 | return interpreter->SetVariables(variables); 338 | }) 339 | }; 340 | if result == bindings::TfLiteStatus::kTfLiteOk { 341 | Ok(()) 342 | } else { 343 | Err(Error::internal_error("failed to set variables")) 344 | } 345 | } 346 | 347 | #[allow(clippy::cognitive_complexity)] 348 | pub fn set_tensor_parameters_read_write( 349 | &mut self, 350 | tensor_index: TensorIndex, 351 | element_type: ElementKind, 352 | name: &str, 353 | dims: &[usize], 354 | quantization: QuantizationParams, 355 | is_variable: bool, 356 | ) -> Result<()> { 357 | let interpreter = self.handle_mut() as *mut _; 358 | 359 | let name_ptr = name.as_ptr(); 360 | let name_len = name.len() as size_t; 361 | 362 | let dims: Vec = dims.iter().map(|x| *x as i32).collect(); 363 | let dims_ptr = dims.as_ptr(); 364 | let dims_len = dims.len() as size_t; 365 | 366 | #[allow(clippy::forgetting_copy_types, deprecated, clippy::transmute_num_to_bytes)] 367 | let result = unsafe { 368 | cpp!([ 369 | interpreter as "Interpreter*", 370 | tensor_index as "int", 371 | element_type as "TfLiteType", 372 | name_ptr as "const char*", 373 | name_len as "size_t", 374 | dims_ptr as "const int*", 375 | dims_len as "size_t", 376 | quantization as "TfLiteQuantizationParams", 377 | is_variable as "bool" 378 | ] -> bindings::TfLiteStatus as "TfLiteStatus" { 379 | return interpreter->SetTensorParametersReadWrite( 380 | tensor_index, element_type, std::string(name_ptr, name_len).c_str(), 381 | dims_len, 
dims_ptr, quantization, is_variable); 382 | }) 383 | }; 384 | if result == bindings::TfLiteStatus::kTfLiteOk { 385 | Ok(()) 386 | } else { 387 | Err(Error::internal_error("failed to set tensor parameters")) 388 | } 389 | } 390 | 391 | fn tensor_inner(&self, tensor_index: TensorIndex) -> Option<&bindings::TfLiteTensor> { 392 | let interpreter = self.handle() as *const _; 393 | 394 | #[allow(clippy::forgetting_copy_types, deprecated, clippy::transmute_num_to_bytes)] 395 | let ptr = unsafe { 396 | cpp!([ 397 | interpreter as "const Interpreter*", 398 | tensor_index as "int" 399 | ] -> *const bindings::TfLiteTensor as "const TfLiteTensor*" { 400 | return interpreter->tensor(tensor_index); 401 | }) 402 | }; 403 | if ptr.is_null() { 404 | None 405 | } else { 406 | Some(unsafe { &*ptr }) 407 | } 408 | } 409 | 410 | pub fn tensor_info(&self, tensor_index: TensorIndex) -> Option { 411 | Some(self.tensor_inner(tensor_index)?.into()) 412 | } 413 | 414 | pub fn tensor_data(&self, tensor_index: TensorIndex) -> Result<&[T]> 415 | where 416 | T: ElemKindOf, 417 | { 418 | let inner = self 419 | .tensor_inner(tensor_index) 420 | .ok_or_else(|| Error::internal_error("invalid tensor index"))?; 421 | let tensor_info: TensorInfo = inner.into(); 422 | 423 | if tensor_info.element_kind != T::elem_kind_of() { 424 | return Err(Error::InternalError(format!( 425 | "Invalid type reference of `{:?}` to the original type `{:?}`", 426 | T::elem_kind_of(), 427 | tensor_info.element_kind 428 | ))); 429 | } 430 | 431 | Ok(unsafe { 432 | slice::from_raw_parts( 433 | inner.data.raw_const as *const T, 434 | inner.bytes / mem::size_of::(), 435 | ) 436 | }) 437 | } 438 | 439 | pub fn tensor_data_mut(&mut self, tensor_index: TensorIndex) -> Result<&mut [T]> 440 | where 441 | T: ElemKindOf, 442 | { 443 | let inner = self 444 | .tensor_inner(tensor_index) 445 | .ok_or_else(|| Error::internal_error("invalid tensor index"))?; 446 | let tensor_info: TensorInfo = inner.into(); 447 | 448 | if 
tensor_info.element_kind != T::elem_kind_of() { 449 | return Err(Error::InternalError(format!( 450 | "Invalid type reference of `{:?}` to the original type `{:?}`", 451 | T::elem_kind_of(), 452 | tensor_info.element_kind 453 | ))); 454 | } 455 | 456 | Ok(unsafe { 457 | slice::from_raw_parts_mut(inner.data.raw as *mut T, inner.bytes / mem::size_of::()) 458 | }) 459 | } 460 | 461 | pub fn tensor_buffer(&self, tensor_index: TensorIndex) -> Option<&[u8]> { 462 | let inner = self.tensor_inner(tensor_index)?; 463 | 464 | Some(unsafe { slice::from_raw_parts(inner.data.raw_const as *mut u8, inner.bytes) }) 465 | } 466 | 467 | pub fn tensor_buffer_mut(&mut self, tensor_index: TensorIndex) -> Option<&mut [u8]> { 468 | let inner = self.tensor_inner(tensor_index)?; 469 | 470 | Some(unsafe { slice::from_raw_parts_mut(inner.data.raw as *mut u8, inner.bytes) }) 471 | } 472 | } 473 | 474 | #[cfg(test)] 475 | mod tests { 476 | use super::*; 477 | use std::sync::Arc; 478 | 479 | use crate::ops::builtin::BuiltinOpResolver; 480 | 481 | #[test] 482 | fn threadsafe_types() { 483 | fn send_sync(_t: &T) {} 484 | let model = FlatBufferModel::build_from_file("data/MNISTnet_uint8_quant.tflite") 485 | .expect("Unable to build flatbuffer model"); 486 | send_sync(&model); 487 | let resolver = Arc::new(BuiltinOpResolver::default()); 488 | send_sync(&resolver); 489 | let builder = InterpreterBuilder::new(model, resolver).expect("Not able to build builder"); 490 | send_sync(&builder); 491 | let interpreter = builder.build().expect("Not able to build model"); 492 | send_sync(&interpreter); 493 | } 494 | } 495 | -------------------------------------------------------------------------------- /src/interpreter/op_resolver.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use crate::bindings::tflite::OpResolver as SysOpResolver; 4 | 5 | pub trait OpResolver: Send + Sync { 6 | fn get_resolver_handle(&self) -> &SysOpResolver; 7 | } 8 | 9 
| impl OpResolver for Arc { 10 | fn get_resolver_handle(&self) -> &SysOpResolver { 11 | self.as_ref().get_resolver_handle() 12 | } 13 | } 14 | 15 | impl<'a, T: OpResolver> OpResolver for &'a T { 16 | fn get_resolver_handle(&self) -> &SysOpResolver { 17 | (*self).get_resolver_handle() 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /src/interpreter/ops/builtin/mod.rs: -------------------------------------------------------------------------------- 1 | mod resolver; 2 | 3 | pub use resolver::Resolver as BuiltinOpResolver; 4 | -------------------------------------------------------------------------------- /src/interpreter/ops/builtin/resolver.rs: -------------------------------------------------------------------------------- 1 | use std::mem; 2 | 3 | use crate::bindings::tflite as bindings; 4 | use crate::interpreter::op_resolver::OpResolver; 5 | 6 | cpp! {{ 7 | #include "tensorflow/lite/kernels/register.h" 8 | 9 | using namespace tflite::ops::builtin; 10 | }} 11 | 12 | pub struct Resolver { 13 | handle: Box, 14 | } 15 | 16 | impl Drop for Resolver { 17 | #[allow(clippy::useless_transmute, clippy::forgetting_copy_types, deprecated)] 18 | fn drop(&mut self) { 19 | let handle = Box::into_raw(mem::take(&mut self.handle)); 20 | unsafe { 21 | cpp!([handle as "BuiltinOpResolver*"] { 22 | delete handle; 23 | }); 24 | } 25 | } 26 | } 27 | 28 | impl OpResolver for Resolver { 29 | fn get_resolver_handle(&self) -> &bindings::OpResolver { 30 | self.handle.as_ref() 31 | } 32 | } 33 | 34 | impl Default for Resolver { 35 | #[allow(clippy::forgetting_copy_types, deprecated)] 36 | fn default() -> Self { 37 | let handle = unsafe { 38 | cpp!([] -> *mut bindings::OpResolver as "OpResolver*" { 39 | return new BuiltinOpResolver(); 40 | }) 41 | }; 42 | let handle = unsafe { Box::from_raw(handle) }; 43 | Self { handle } 44 | } 45 | } 46 | -------------------------------------------------------------------------------- 
/src/interpreter/ops/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod builtin; 2 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![recursion_limit = "128"] 2 | 3 | #[macro_use] 4 | extern crate cpp; 5 | 6 | mod bindings; 7 | mod error; 8 | mod interpreter; 9 | pub mod model; 10 | 11 | pub use error::Error; 12 | pub use interpreter::*; 13 | 14 | pub type Result = ::std::result::Result; 15 | -------------------------------------------------------------------------------- /src/model/builtin_options.rs: -------------------------------------------------------------------------------- 1 | use std::ptr; 2 | 3 | use super::stl::vector::VectorOfI32; 4 | use crate::bindings::flatbuffers::NativeTable; 5 | use crate::bindings::tflite::*; 6 | 7 | #[repr(C)] 8 | #[derive(Debug)] 9 | pub struct BuiltinOptionsUnion { 10 | pub typ: BuiltinOptions, 11 | pub value: *mut NativeTable, 12 | } 13 | 14 | impl Default for BuiltinOptionsUnion { 15 | fn default() -> Self { 16 | BuiltinOptionsUnion { typ: BuiltinOptions::BuiltinOptions_NONE, value: ptr::null_mut() } 17 | } 18 | } 19 | 20 | impl Drop for BuiltinOptionsUnion { 21 | fn drop(&mut self) { 22 | let ptr = self.value; 23 | #[allow(deprecated)] 24 | unsafe { 25 | cpp!([ptr as "flatbuffers::NativeTable*"] { 26 | delete ptr; 27 | }); 28 | } 29 | } 30 | } 31 | 32 | #[repr(C)] 33 | #[derive(Debug, PartialEq, Eq)] 34 | pub struct ConcatEmbeddingsOptionsT { 35 | _vtable: NativeTable, 36 | pub num_channels: i32, 37 | pub num_columns_per_channel: VectorOfI32, 38 | pub embedding_dim_per_channel: VectorOfI32, 39 | } 40 | 41 | #[repr(C)] 42 | #[derive(Debug, PartialEq, Eq)] 43 | pub struct ReshapeOptionsT { 44 | _vtable: NativeTable, 45 | pub new_shape: VectorOfI32, 46 | } 47 | 48 | #[repr(C)] 49 | #[derive(Debug, PartialEq, Eq)] 50 | pub struct SqueezeOptionsT { 
51 | _vtable: NativeTable, 52 | pub squeeze_dims: VectorOfI32, 53 | } 54 | 55 | impl PartialEq for BuiltinOptionsUnion { 56 | #[allow(clippy::cognitive_complexity)] 57 | fn eq(&self, other: &Self) -> bool { 58 | macro_rules! compare { 59 | ($e:expr, $t:ty) => { 60 | if self.typ == $e 61 | && other.typ == $e 62 | && AsRef::<$t>::as_ref(self) == AsRef::<$t>::as_ref(other) 63 | { 64 | return true; 65 | } 66 | }; 67 | } 68 | 69 | if self.typ == BuiltinOptions::BuiltinOptions_NONE 70 | && other.typ == BuiltinOptions::BuiltinOptions_NONE 71 | { 72 | return true; 73 | } 74 | compare!(BuiltinOptions::BuiltinOptions_Conv2DOptions, Conv2DOptionsT); 75 | compare!(BuiltinOptions::BuiltinOptions_DepthwiseConv2DOptions, DepthwiseConv2DOptionsT); 76 | compare!(BuiltinOptions::BuiltinOptions_ConcatEmbeddingsOptions, ConcatEmbeddingsOptionsT); 77 | compare!(BuiltinOptions::BuiltinOptions_LSHProjectionOptions, LSHProjectionOptionsT); 78 | compare!(BuiltinOptions::BuiltinOptions_Pool2DOptions, Pool2DOptionsT); 79 | compare!(BuiltinOptions::BuiltinOptions_SVDFOptions, SVDFOptionsT); 80 | compare!(BuiltinOptions::BuiltinOptions_RNNOptions, RNNOptionsT); 81 | compare!(BuiltinOptions::BuiltinOptions_FullyConnectedOptions, FullyConnectedOptionsT); 82 | compare!(BuiltinOptions::BuiltinOptions_SoftmaxOptions, SoftmaxOptionsT); 83 | compare!(BuiltinOptions::BuiltinOptions_ConcatenationOptions, ConcatenationOptionsT); 84 | compare!(BuiltinOptions::BuiltinOptions_AddOptions, AddOptionsT); 85 | compare!(BuiltinOptions::BuiltinOptions_L2NormOptions, L2NormOptionsT); 86 | compare!( 87 | BuiltinOptions::BuiltinOptions_LocalResponseNormalizationOptions, 88 | LocalResponseNormalizationOptionsT 89 | ); 90 | compare!(BuiltinOptions::BuiltinOptions_LSTMOptions, LSTMOptionsT); 91 | compare!(BuiltinOptions::BuiltinOptions_ResizeBilinearOptions, ResizeBilinearOptionsT); 92 | compare!(BuiltinOptions::BuiltinOptions_CallOptions, CallOptionsT); 93 | compare!(BuiltinOptions::BuiltinOptions_ReshapeOptions, 
ReshapeOptionsT); 94 | compare!(BuiltinOptions::BuiltinOptions_SkipGramOptions, SkipGramOptionsT); 95 | compare!(BuiltinOptions::BuiltinOptions_SpaceToDepthOptions, SpaceToDepthOptionsT); 96 | compare!( 97 | BuiltinOptions::BuiltinOptions_EmbeddingLookupSparseOptions, 98 | EmbeddingLookupSparseOptionsT 99 | ); 100 | compare!(BuiltinOptions::BuiltinOptions_MulOptions, MulOptionsT); 101 | compare!(BuiltinOptions::BuiltinOptions_PadOptions, PadOptionsT); 102 | compare!(BuiltinOptions::BuiltinOptions_GatherOptions, GatherOptionsT); 103 | compare!(BuiltinOptions::BuiltinOptions_BatchToSpaceNDOptions, BatchToSpaceNDOptionsT); 104 | compare!(BuiltinOptions::BuiltinOptions_SpaceToBatchNDOptions, SpaceToBatchNDOptionsT); 105 | compare!(BuiltinOptions::BuiltinOptions_TransposeOptions, TransposeOptionsT); 106 | compare!(BuiltinOptions::BuiltinOptions_ReducerOptions, ReducerOptionsT); 107 | compare!(BuiltinOptions::BuiltinOptions_SubOptions, SubOptionsT); 108 | compare!(BuiltinOptions::BuiltinOptions_DivOptions, DivOptionsT); 109 | compare!(BuiltinOptions::BuiltinOptions_SqueezeOptions, SqueezeOptionsT); 110 | compare!(BuiltinOptions::BuiltinOptions_SequenceRNNOptions, SequenceRNNOptionsT); 111 | compare!(BuiltinOptions::BuiltinOptions_StridedSliceOptions, StridedSliceOptionsT); 112 | compare!(BuiltinOptions::BuiltinOptions_ExpOptions, ExpOptionsT); 113 | compare!(BuiltinOptions::BuiltinOptions_TopKV2Options, TopKV2OptionsT); 114 | compare!(BuiltinOptions::BuiltinOptions_SplitOptions, SplitOptionsT); 115 | compare!(BuiltinOptions::BuiltinOptions_LogSoftmaxOptions, LogSoftmaxOptionsT); 116 | compare!(BuiltinOptions::BuiltinOptions_CastOptions, CastOptionsT); 117 | compare!(BuiltinOptions::BuiltinOptions_DequantizeOptions, DequantizeOptionsT); 118 | compare!(BuiltinOptions::BuiltinOptions_MaximumMinimumOptions, MaximumMinimumOptionsT); 119 | compare!(BuiltinOptions::BuiltinOptions_ArgMaxOptions, ArgMaxOptionsT); 120 | compare!(BuiltinOptions::BuiltinOptions_LessOptions, 
LessOptionsT); 121 | compare!(BuiltinOptions::BuiltinOptions_NegOptions, NegOptionsT); 122 | compare!(BuiltinOptions::BuiltinOptions_PadV2Options, PadV2OptionsT); 123 | compare!(BuiltinOptions::BuiltinOptions_GreaterOptions, GreaterOptionsT); 124 | compare!(BuiltinOptions::BuiltinOptions_GreaterEqualOptions, GreaterEqualOptionsT); 125 | compare!(BuiltinOptions::BuiltinOptions_LessEqualOptions, LessEqualOptionsT); 126 | compare!(BuiltinOptions::BuiltinOptions_SelectOptions, SelectOptionsT); 127 | compare!(BuiltinOptions::BuiltinOptions_SliceOptions, SliceOptionsT); 128 | compare!(BuiltinOptions::BuiltinOptions_TransposeConvOptions, TransposeConvOptionsT); 129 | compare!(BuiltinOptions::BuiltinOptions_SparseToDenseOptions, SparseToDenseOptionsT); 130 | compare!(BuiltinOptions::BuiltinOptions_TileOptions, TileOptionsT); 131 | compare!(BuiltinOptions::BuiltinOptions_ExpandDimsOptions, ExpandDimsOptionsT); 132 | compare!(BuiltinOptions::BuiltinOptions_EqualOptions, EqualOptionsT); 133 | compare!(BuiltinOptions::BuiltinOptions_NotEqualOptions, NotEqualOptionsT); 134 | compare!(BuiltinOptions::BuiltinOptions_ShapeOptions, ShapeOptionsT); 135 | compare!(BuiltinOptions::BuiltinOptions_PowOptions, PowOptionsT); 136 | compare!(BuiltinOptions::BuiltinOptions_ArgMinOptions, ArgMinOptionsT); 137 | compare!(BuiltinOptions::BuiltinOptions_FakeQuantOptions, FakeQuantOptionsT); 138 | compare!(BuiltinOptions::BuiltinOptions_PackOptions, PackOptionsT); 139 | compare!(BuiltinOptions::BuiltinOptions_LogicalOrOptions, LogicalOrOptionsT); 140 | compare!(BuiltinOptions::BuiltinOptions_OneHotOptions, OneHotOptionsT); 141 | compare!(BuiltinOptions::BuiltinOptions_LogicalAndOptions, LogicalAndOptionsT); 142 | compare!(BuiltinOptions::BuiltinOptions_LogicalNotOptions, LogicalNotOptionsT); 143 | compare!(BuiltinOptions::BuiltinOptions_UnpackOptions, UnpackOptionsT); 144 | compare!(BuiltinOptions::BuiltinOptions_FloorDivOptions, FloorDivOptionsT); 145 | 
compare!(BuiltinOptions::BuiltinOptions_SquareOptions, SquareOptionsT); 146 | compare!(BuiltinOptions::BuiltinOptions_ZerosLikeOptions, ZerosLikeOptionsT); 147 | compare!(BuiltinOptions::BuiltinOptions_FillOptions, FillOptionsT); 148 | compare!( 149 | BuiltinOptions::BuiltinOptions_BidirectionalSequenceLSTMOptions, 150 | BidirectionalSequenceLSTMOptionsT 151 | ); 152 | compare!( 153 | BuiltinOptions::BuiltinOptions_BidirectionalSequenceRNNOptions, 154 | BidirectionalSequenceRNNOptionsT 155 | ); 156 | compare!( 157 | BuiltinOptions::BuiltinOptions_UnidirectionalSequenceLSTMOptions, 158 | UnidirectionalSequenceLSTMOptionsT 159 | ); 160 | compare!(BuiltinOptions::BuiltinOptions_FloorModOptions, FloorModOptionsT); 161 | compare!(BuiltinOptions::BuiltinOptions_RangeOptions, RangeOptionsT); 162 | compare!( 163 | BuiltinOptions::BuiltinOptions_ResizeNearestNeighborOptions, 164 | ResizeNearestNeighborOptionsT 165 | ); 166 | compare!(BuiltinOptions::BuiltinOptions_LeakyReluOptions, LeakyReluOptionsT); 167 | compare!( 168 | BuiltinOptions::BuiltinOptions_SquaredDifferenceOptions, 169 | SquaredDifferenceOptionsT 170 | ); 171 | compare!(BuiltinOptions::BuiltinOptions_MirrorPadOptions, MirrorPadOptionsT); 172 | compare!(BuiltinOptions::BuiltinOptions_AbsOptions, AbsOptionsT); 173 | compare!(BuiltinOptions::BuiltinOptions_SplitVOptions, SplitVOptionsT); 174 | compare!(BuiltinOptions::BuiltinOptions_UniqueOptions, UniqueOptionsT); 175 | compare!(BuiltinOptions::BuiltinOptions_ReverseV2Options, ReverseV2OptionsT); 176 | compare!(BuiltinOptions::BuiltinOptions_AddNOptions, AddNOptionsT); 177 | compare!(BuiltinOptions::BuiltinOptions_GatherNdOptions, GatherNdOptionsT); 178 | compare!(BuiltinOptions::BuiltinOptions_CosOptions, CosOptionsT); 179 | compare!(BuiltinOptions::BuiltinOptions_WhereOptions, WhereOptionsT); 180 | compare!(BuiltinOptions::BuiltinOptions_RankOptions, RankOptionsT); 181 | compare!(BuiltinOptions::BuiltinOptions_ReverseSequenceOptions, ReverseSequenceOptionsT); 
182 | compare!(BuiltinOptions::BuiltinOptions_MatrixDiagOptions, MatrixDiagOptionsT); 183 | compare!(BuiltinOptions::BuiltinOptions_QuantizeOptions, QuantizeOptionsT); 184 | compare!(BuiltinOptions::BuiltinOptions_MatrixSetDiagOptions, MatrixSetDiagOptionsT); 185 | compare!(BuiltinOptions::BuiltinOptions_HardSwishOptions, HardSwishOptionsT); 186 | compare!(BuiltinOptions::BuiltinOptions_IfOptions, IfOptionsT); 187 | compare!(BuiltinOptions::BuiltinOptions_WhileOptions, WhileOptionsT); 188 | compare!(BuiltinOptions::BuiltinOptions_DepthToSpaceOptions, DepthToSpaceOptionsT); 189 | false 190 | } 191 | } 192 | 193 | impl Eq for BuiltinOptionsUnion {} 194 | 195 | macro_rules! add_impl_options { 196 | ($($t:ty,)*) => ($( 197 | impl AsRef<$t> for BuiltinOptionsUnion { 198 | fn as_ref(&self) -> & $t { 199 | unsafe { (self.value as *const $t).as_ref().unwrap() } 200 | } 201 | } 202 | 203 | impl AsMut<$t> for BuiltinOptionsUnion { 204 | fn as_mut(&mut self) -> &mut $t { 205 | unsafe { (self.value as *mut $t).as_mut().unwrap() } 206 | } 207 | } 208 | )*) 209 | } 210 | 211 | add_impl_options! 
{ 212 | Conv2DOptionsT, 213 | DepthwiseConv2DOptionsT, 214 | ConcatEmbeddingsOptionsT, 215 | LSHProjectionOptionsT, 216 | Pool2DOptionsT, 217 | SVDFOptionsT, 218 | RNNOptionsT, 219 | FullyConnectedOptionsT, 220 | SoftmaxOptionsT, 221 | ConcatenationOptionsT, 222 | AddOptionsT, 223 | L2NormOptionsT, 224 | LocalResponseNormalizationOptionsT, 225 | LSTMOptionsT, 226 | ResizeBilinearOptionsT, 227 | CallOptionsT, 228 | ReshapeOptionsT, 229 | SkipGramOptionsT, 230 | SpaceToDepthOptionsT, 231 | EmbeddingLookupSparseOptionsT, 232 | MulOptionsT, 233 | PadOptionsT, 234 | GatherOptionsT, 235 | BatchToSpaceNDOptionsT, 236 | SpaceToBatchNDOptionsT, 237 | TransposeOptionsT, 238 | ReducerOptionsT, 239 | SubOptionsT, 240 | DivOptionsT, 241 | SqueezeOptionsT, 242 | SequenceRNNOptionsT, 243 | StridedSliceOptionsT, 244 | ExpOptionsT, 245 | TopKV2OptionsT, 246 | SplitOptionsT, 247 | LogSoftmaxOptionsT, 248 | CastOptionsT, 249 | DequantizeOptionsT, 250 | MaximumMinimumOptionsT, 251 | ArgMaxOptionsT, 252 | LessOptionsT, 253 | NegOptionsT, 254 | PadV2OptionsT, 255 | GreaterOptionsT, 256 | GreaterEqualOptionsT, 257 | LessEqualOptionsT, 258 | SelectOptionsT, 259 | SliceOptionsT, 260 | TransposeConvOptionsT, 261 | SparseToDenseOptionsT, 262 | TileOptionsT, 263 | ExpandDimsOptionsT, 264 | EqualOptionsT, 265 | NotEqualOptionsT, 266 | ShapeOptionsT, 267 | PowOptionsT, 268 | ArgMinOptionsT, 269 | FakeQuantOptionsT, 270 | PackOptionsT, 271 | LogicalOrOptionsT, 272 | OneHotOptionsT, 273 | LogicalAndOptionsT, 274 | LogicalNotOptionsT, 275 | UnpackOptionsT, 276 | FloorDivOptionsT, 277 | SquareOptionsT, 278 | ZerosLikeOptionsT, 279 | FillOptionsT, 280 | BidirectionalSequenceLSTMOptionsT, 281 | BidirectionalSequenceRNNOptionsT, 282 | UnidirectionalSequenceLSTMOptionsT, 283 | FloorModOptionsT, 284 | RangeOptionsT, 285 | ResizeNearestNeighborOptionsT, 286 | LeakyReluOptionsT, 287 | SquaredDifferenceOptionsT, 288 | MirrorPadOptionsT, 289 | AbsOptionsT, 290 | SplitVOptionsT, 291 | UniqueOptionsT, 292 | 
ReverseV2OptionsT,
    AddNOptionsT,
    GatherNdOptionsT,
    CosOptionsT,
    WhereOptionsT,
    RankOptionsT,
    ReverseSequenceOptionsT,
    MatrixDiagOptionsT,
    QuantizeOptionsT,
    MatrixSetDiagOptionsT,
    HardSwishOptionsT,
    IfOptionsT,
    WhileOptionsT,
    DepthToSpaceOptionsT,
}

--------------------------------------------------------------------------------
/src/model/mod.rs:
--------------------------------------------------------------------------------
#![allow(clippy::field_reassign_with_default)]
#![allow(clippy::size_of_ref)]

//! Owned, mutable access to TFLite models through the FlatBuffers C++ object API
//! (`UnPack` to read a model into `*T` structs, `Pack` to serialize it back).

mod builtin_options;
mod builtin_options_impl;
pub mod stl;

use std::ffi::c_void;
use std::fs;
use std::ops::{Deref, DerefMut};
use std::path::Path;
use std::{fmt, mem, slice};

use libc::size_t;
use stl::memory::UniquePtr;
use stl::string::String as StlString;
use stl::vector::{
    VectorInsert, VectorOfBool, VectorOfF32, VectorOfI32, VectorOfI64, VectorOfU8,
    VectorOfUniquePtr,
};

pub use crate::bindings::flatbuffers::NativeTable;
pub use crate::bindings::tflite::*;
use crate::{Error, Result};
pub use builtin_options::{
    BuiltinOptionsUnion, ConcatEmbeddingsOptionsT, ReshapeOptionsT, SqueezeOptionsT,
};

// The `#[repr(C)]` structs below are handed to C++ by pointer inside `cpp!` closures,
// so their field order and types must stay in sync with the C++ object-API types.

/// Tagged union holding a tensor's optional quantization details.
///
/// `value` is an untyped pointer; presumably its pointee type is selected by `typ`
/// (TODO confirm against the C++ union definition).
#[repr(C)]
#[derive(Debug)]
pub struct QuantizationDetailsUnion {
    pub typ: QuantizationDetails,
    pub value: *mut c_void,
}

impl PartialEq for QuantizationDetailsUnion {
    /// Two unions compare equal only when both are `QuantizationDetails_NONE`.
    ///
    /// The payload behind `value` is never inspected, so a union that carries details
    /// is not equal to anything, itself included.
    fn eq(&self, other: &Self) -> bool {
        self.typ == QuantizationDetails::QuantizationDetails_NONE
            && other.typ == QuantizationDetails::QuantizationDetails_NONE
    }
}

/// Raw data buffer; tensors refer to it through their `buffer` index into `ModelT::buffers`.
#[repr(C)]
#[derive(Debug, PartialEq, Eq)]
pub struct BufferT {
    _vtable: NativeTable,
    pub data: VectorOfU8,
}

/// Quantization parameters attached to a tensor.
#[repr(C)]
#[derive(Debug, PartialEq)]
pub struct QuantizationParametersT {
    _vtable: NativeTable,
    pub min: VectorOfF32,
    pub max: VectorOfF32,
    pub scale: VectorOfF32,
    pub zero_point: VectorOfI64,
    pub details: QuantizationDetailsUnion,
    pub quantized_dimension: i32,
}

/// A tensor of a subgraph; `buffer` indexes into `ModelT::buffers`.
#[repr(C)]
#[derive(Debug, PartialEq)]
pub struct TensorT {
    _vtable: NativeTable,
    pub shape: VectorOfI32,
    pub typ: TensorType,
    pub buffer: u32,
    pub name: StlString,
    pub quantization: UniquePtr<QuantizationParametersT>,
    pub is_variable: bool,
}

/// One operator invocation.
///
/// `opcode_index` indexes into `ModelT::operator_codes`; `inputs` and `outputs`
/// index into the owning subgraph's `tensors`.
#[repr(C)]
#[derive(Debug, PartialEq, Eq)]
pub struct OperatorT {
    _vtable: NativeTable,
    pub opcode_index: u32,
    pub inputs: VectorOfI32,
    pub outputs: VectorOfI32,
    pub builtin_options: BuiltinOptionsUnion,
    pub custom_options: VectorOfU8,
    pub custom_options_format: CustomOptionsFormat,
    pub mutating_variable_inputs: VectorOfBool,
    pub intermediates: VectorOfI32,
}

/// Identifies the kind of an operator (builtin code or custom name) and its version.
#[repr(C)]
#[derive(Debug, PartialEq, Eq)]
pub struct OperatorCodeT {
    _vtable: NativeTable,
    pub builtin_code: BuiltinOperator,
    pub custom_code: StlString,
    pub version: i32,
}

/// A graph of operators over tensors; `inputs`/`outputs` are indices into `tensors`.
#[repr(C)]
#[derive(Debug)]
pub struct SubGraphT {
    _vtable: NativeTable,
    pub tensors: VectorOfUniquePtr<TensorT>,
    pub inputs: VectorOfI32,
    pub outputs: VectorOfI32,
    pub operators: VectorOfUniquePtr<OperatorT>,
    pub name: StlString,
}

/// Named metadata entry referring to a buffer by index.
#[repr(C)]
#[derive(Debug)]
pub struct MetadataT {
    _vtable: NativeTable,
    pub name: StlString,
    pub buffer: u32,
}

/// Root of an unpacked model.
#[repr(C)]
#[derive(Debug)]
pub struct ModelT {
    _vtable: NativeTable,
    pub version: u32,
    pub operator_codes: VectorOfUniquePtr<OperatorCodeT>,
    pub subgraphs: VectorOfUniquePtr<SubGraphT>,
    pub description: StlString,
    pub buffers: VectorOfUniquePtr<BufferT>,
    pub metadata_buffer: VectorOfI32,
    pub metadata: VectorOfUniquePtr<MetadataT>,
}

impl Clone for BuiltinOptionsUnion {
    /// Copy-constructs the union with its C++ copy constructor into zeroed storage.
    fn clone(&self) -> Self {
        // SAFETY: the zeroed value is only storage; C++ placement-constructs into it
        // before it is returned.
        let mut cloned = unsafe { mem::zeroed() };
        let cloned_ref = &mut cloned;
        #[allow(deprecated)]
        unsafe {
            cpp!([self as "const BuiltinOptionsUnion*", cloned_ref as "BuiltinOptionsUnion*"] {
                new (cloned_ref) BuiltinOptionsUnion(*self);
            });
        }
        cloned
    }
}

impl Clone for UniquePtr<BufferT> {
    /// Copy-constructs the pointee in C++; a null pointer clones to a null pointer.
    fn clone(&self) -> Self {
        // SAFETY: zeroed storage is overwritten by placement-new in both C++ branches.
        let mut cloned = unsafe { mem::zeroed() };
        let cloned_ref = &mut cloned;
        #[allow(deprecated)]
        unsafe {
            cpp!([self as "const std::unique_ptr<BufferT>*", cloned_ref as "std::unique_ptr<BufferT>*"] {
                if(*self) {
                    new (cloned_ref) std::unique_ptr<BufferT>(new BufferT(**self));
                }
                else {
                    new (cloned_ref) std::unique_ptr<BufferT>();
                }
            });
        }
        cloned
    }
}

impl Clone for UniquePtr<OperatorCodeT> {
    /// Field-by-field copy into a freshly allocated `OperatorCodeT`.
    fn clone(&self) -> Self {
        let mut cloned: UniquePtr<OperatorCodeT> = Default::default();
        cloned.builtin_code = self.builtin_code;
        cloned.custom_code.assign(&self.custom_code);
        cloned.version = self.version;
        cloned
    }
}

impl Clone for UniquePtr<QuantizationParametersT> {
    /// Copy-constructs the pointee in C++; a null pointer clones to a null pointer.
    fn clone(&self) -> Self {
        // SAFETY: zeroed storage is overwritten by placement-new in both C++ branches.
        let mut cloned = unsafe { mem::zeroed() };
        let cloned_ref = &mut cloned;
        #[allow(deprecated)]
        unsafe {
            cpp!([self as "const std::unique_ptr<QuantizationParametersT>*", cloned_ref as "std::unique_ptr<QuantizationParametersT>*"] {
                if(*self) {
                    new (cloned_ref) std::unique_ptr<QuantizationParametersT>(new QuantizationParametersT(**self));
                }
                else {
                    new (cloned_ref) std::unique_ptr<QuantizationParametersT>();
                }
            });
        }
        cloned
    }
}

impl Clone for UniquePtr<TensorT> {
    /// Field-by-field copy into a freshly allocated `TensorT`; `quantization` is
    /// deep-copied through its own `Clone` impl.
    fn clone(&self) -> Self {
        let mut cloned: UniquePtr<TensorT> = Default::default();
        cloned.shape.assign(self.shape.iter().cloned());
        cloned.typ = self.typ;
        cloned.buffer = self.buffer;
        cloned.name.assign(&self.name);
        cloned.quantization = self.quantization.clone();
        cloned.is_variable = self.is_variable;
        cloned
    }
}

impl Clone for UniquePtr<OperatorT> {
    /// Field-by-field copy into a freshly allocated `OperatorT`.
    ///
    /// Every field of `OperatorT` is copied, including `intermediates`, so the clone
    /// compares equal to the source.
    fn clone(&self) -> Self {
        let mut cloned: UniquePtr<OperatorT> = Default::default();
        cloned.opcode_index = self.opcode_index;
        cloned.inputs.assign(self.inputs.iter().cloned());
        cloned.outputs.assign(self.outputs.iter().cloned());
        cloned.builtin_options = self.builtin_options.clone();
        cloned.custom_options.assign(self.custom_options.iter().cloned());
        cloned.custom_options_format = self.custom_options_format;
        cloned.mutating_variable_inputs = self.mutating_variable_inputs.clone();
        cloned.intermediates.assign(self.intermediates.iter().cloned());
        cloned
    }
}

/// An owned, mutable TFLite model backed by a C++ `std::unique_ptr<ModelT>`.
#[repr(transparent)]
#[derive(Default)]
pub struct Model(UniquePtr<ModelT>);

impl Clone for Model {
    /// Clones by serializing to a flatbuffer and unpacking that buffer again.
    ///
    /// # Panics
    /// Panics if the serialized buffer cannot be verified and unpacked again.
    fn clone(&self) -> Self {
        Self::from_buffer(&self.to_buffer())
            .expect("a model serialized by `to_buffer` must unpack again")
    }
}

impl Deref for Model {
    type Target = UniquePtr<ModelT>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for Model {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl fmt::Debug for Model {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:?}", self.0)
    }
}

impl Model {
    /// Verifies `buffer` as a TFLite flatbuffer (`VerifyModelBuffer`) and unpacks it
    /// into an owned `ModelT`.
    ///
    /// Returns `None` when verification fails or the unpacked model pointer is null.
    pub fn from_buffer(buffer: &[u8]) -> Option<Self> {
        let len = buffer.len();
        let buffer = buffer.as_ptr();
        // SAFETY: a zeroed `UniquePtr` is only storage here; C++ placement-constructs into
        // it on success, and on failure it stays zeroed and is reported invalid/dropped.
        let mut model: UniquePtr<ModelT> = unsafe { mem::zeroed() };
        let model_ref = &mut model;
        #[allow(deprecated, clippy::transmute_num_to_bytes)]
        let r = unsafe {
            cpp!([buffer as "const void*", len as "size_t", model_ref as "std::unique_ptr<ModelT>*"]
                -> bool as "bool" {
                auto verifier = flatbuffers::Verifier((const uint8_t *)buffer, len);
                if (!VerifyModelBuffer(verifier)) {
                    return false;
                }

                auto model = tflite::GetModel(buffer)->UnPack();
                new (model_ref) std::unique_ptr<ModelT>(model);
                return true;
            })
        };
        if r && model.is_valid() {
            Some(Self(model))
        } else {
            None
        }
    }

    /// Reads `filepath` and unpacks it with [`Model::from_buffer`].
    ///
    /// # Errors
    /// Returns the I/O error if the file cannot be read, or an internal error if its
    /// contents do not unpack as a model.
    pub fn from_file<P: AsRef<Path>>(filepath: P) -> Result<Self> {
        Self::from_buffer(&fs::read(filepath)?)
            .ok_or_else(|| Error::internal_error("failed to unpack the flatbuffer model"))
    }

    /// Packs the model into a finished TFLite flatbuffer and returns a copy of its bytes.
    pub fn to_buffer(&self) -> Vec<u8> {
        let mut buffer = Vec::new();
        let buffer_ptr = &mut buffer;
        let model_ref = &self.0;
        #[allow(deprecated)]
        unsafe {
            cpp!([model_ref as "const std::unique_ptr<ModelT>*", buffer_ptr as "void*"] {
                flatbuffers::FlatBufferBuilder fbb;
                auto model = Model::Pack(fbb, model_ref->get());
                FinishModelBuffer(fbb, model);
                uint8_t* ptr = fbb.GetBufferPointer();
                size_t size = fbb.GetSize();
                rust!(ModelT_to_file [ptr: *const u8 as "const uint8_t*", size: size_t as "size_t", buffer_ptr: &mut Vec<u8> as "void*"] {
                    // SAFETY: `ptr`/`size` describe the builder's finished buffer, which
                    // stays alive until this callback returns.
                    unsafe { buffer_ptr.extend_from_slice(slice::from_raw_parts(ptr, size)) };
                });
            })
        }
        buffer
    }

    /// Writes the serialized model ([`Model::to_buffer`]) to `filepath`.
    ///
    /// # Errors
    /// Returns the I/O error if the file cannot be written.
    pub fn to_file<P: AsRef<Path>>(&self, filepath: P) -> Result<()> {
        fs::write(filepath, self.to_buffer())?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::CString;

    use crate::model::stl::vector::{VectorErase, VectorExtract, VectorInsert, VectorSlice};
    use crate::ops::builtin::BuiltinOpResolver;
    use crate::{FlatBufferModel, InterpreterBuilder};

    #[test]
    fn flatbuffer_model_apis_inspect() {
        // An existing file that is not a `.tflite` model must be rejected by verification
        // (not merely by a missing-file I/O error).
        assert!(Model::from_file("data/mnist10.bin").is_err());

        let model = Model::from_file("data/MNISTnet_uint8_quant.tflite").unwrap();
        assert_eq!(model.version, 3);
        assert_eq!(model.operator_codes.size(), 5);
        assert_eq!(model.subgraphs.size(), 1);
        assert_eq!(model.buffers.size(), 24);
        assert_eq!(model.description.c_str().to_string_lossy(), "TOCO Converted.");

        assert_eq!(
            model.operator_codes[0].builtin_code,
            BuiltinOperator::BuiltinOperator_AVERAGE_POOL_2D
        );

        assert_eq!(
            model.operator_codes.iter().map(|oc| oc.builtin_code).collect::<Vec<_>>(),
            vec![
                BuiltinOperator::BuiltinOperator_AVERAGE_POOL_2D,
                BuiltinOperator::BuiltinOperator_CONV_2D,
                BuiltinOperator::BuiltinOperator_DEPTHWISE_CONV_2D,
                BuiltinOperator::BuiltinOperator_SOFTMAX,
                BuiltinOperator::BuiltinOperator_RESHAPE
            ]
        );

        let subgraph = &model.subgraphs[0];
        assert_eq!(subgraph.tensors.size(), 23);
        assert_eq!(subgraph.operators.size(), 9);
        assert_eq!(subgraph.inputs.as_slice(), &[22]);
        assert_eq!(subgraph.outputs.as_slice(), &[21]);

        let softmax = subgraph
            .operators
            .iter()
            .position(|op| {
                model.operator_codes[op.opcode_index as usize].builtin_code
                    == BuiltinOperator::BuiltinOperator_SOFTMAX
            })
            .unwrap();

        assert_eq!(subgraph.operators[softmax].inputs.as_slice(), &[4]);
        assert_eq!(subgraph.operators[softmax].outputs.as_slice(), &[21]);
        assert_eq!(
            subgraph.operators[softmax].builtin_options.typ,
            BuiltinOptions::BuiltinOptions_SoftmaxOptions
        );

        let softmax_options: &SoftmaxOptionsT =
            subgraph.operators[softmax].builtin_options.as_ref();

        #[allow(clippy::float_cmp)]
        {
            assert_eq!(softmax_options.beta, 1.);
        }
    }

    #[test]
    fn flatbuffer_model_apis_mutate() {
        let mut model = Model::from_file("data/MNISTnet_uint8_quant.tflite").unwrap();
        model.version = 2;
        model.operator_codes.erase(4);
        model.buffers.erase(22);
        model.buffers.erase(23);
        model.description.assign(&CString::new("flatbuffer").unwrap());

        {
            let subgraph = &mut model.subgraphs[0];
            subgraph.inputs.erase(0);
            subgraph.outputs.assign(vec![1, 2, 3, 4]);
        }

        let model_buffer = model.to_buffer();
        let model = Model::from_buffer(&model_buffer).unwrap();
        assert_eq!(model.version, 2);
        assert_eq!(model.operator_codes.size(), 4);
        assert_eq!(model.subgraphs.size(), 1);
        assert_eq!(model.buffers.size(), 22);
        assert_eq!(model.description.c_str().to_string_lossy(), "flatbuffer");

        let subgraph = &model.subgraphs[0];
        assert_eq!(subgraph.tensors.size(), 23);
        assert_eq!(subgraph.operators.size(), 9);
        assert!(subgraph.inputs.as_slice().is_empty());
        assert_eq!(subgraph.outputs.as_slice(), &[1, 2, 3, 4]);
    }

    #[test]
    fn flatbuffer_model_apis_insert() {
        let mut model1 = Model::from_file("data/MNISTnet_uint8_quant.tflite").unwrap();
        let mut model2 = Model::from_file("data/MNISTnet_uint8_quant.tflite").unwrap();

        let num_buffers = model1.buffers.size();

        let data = model1.buffers[0].data.to_vec();
        let buffer = model1.buffers.extract(0);
        model2.buffers.push_back(buffer);
        assert_eq!(model2.buffers.size(), num_buffers + 1);

        assert_eq!(model2.buffers[num_buffers].data.to_vec(), data);
    }

    #[test]
    fn flatbuffer_model_apis_extract() {
        let source_model = Model::from_file("data/MNISTnet_uint8_quant.tflite").unwrap();
        let source_subgraph = &source_model.subgraphs[0];
        let source_operator = &source_subgraph.operators[0];

        let tensors = source_operator
            .inputs
            .iter()
            .chain(source_operator.outputs.iter())
            .map(|&tensor_index| source_subgraph.tensors[tensor_index as usize].clone());

        let model_buffer = {
            let mut model = Model::default();
            model.version = source_model.version;
            model.description.assign(&source_model.description);
            model.buffers.assign(
                tensors.clone().map(|tensor| source_model.buffers[tensor.buffer as usize].clone()),
            );
            model.operator_codes.push_back(
                source_model.operator_codes[source_operator.opcode_index as usize].clone(),
            );

            let mut subgraph: UniquePtr<SubGraphT> = Default::default();
            subgraph.tensors.assign(tensors);
            for (i, tensor) in subgraph.tensors.iter_mut().enumerate() {
                tensor.buffer = i as u32;
            }
            let mut operator = source_operator.clone();
            operator.opcode_index = 0;
            let num_inputs = operator.inputs.len() as i32;
            let num_outputs = operator.outputs.len() as i32;
            operator.inputs.assign(0..num_inputs);
            operator.outputs.assign(num_inputs..num_inputs + num_outputs);
            subgraph.operators.push_back(operator);
            subgraph
                .inputs
                .assign((0..num_inputs).filter(|&i| model.buffers[i as usize].data.is_empty()));
            subgraph.outputs.assign(num_inputs..num_inputs + num_outputs);
            model.subgraphs.push_back(subgraph);

            let subgraph = &model.subgraphs[0];
            println!("{:?}", subgraph.inputs);
            println!("{:?}", subgraph.outputs);

            for operator in &subgraph.operators {
                println!("{operator:?}");
            }

            for tensor in &subgraph.tensors {
                println!("{tensor:?}");
            }

            for buffer in &model.buffers {
                println!("{buffer:?}");
            }

            for operator_code in &model.operator_codes {
                println!("{operator_code:?}");
            }
            model.to_buffer()
        };

        let model = Model::from_buffer(&model_buffer).unwrap();
        let subgraph = &model.subgraphs[0];
        let operator = &subgraph.operators[0];
        assert_eq!(model.version, 3);
        assert_eq!(model.description, source_model.description);
        assert_eq!(subgraph.inputs.as_slice(), &[0i32]);
        assert_eq!(subgraph.outputs.as_slice(), &[3i32]);
        assert_eq!(operator.inputs.as_slice(), &[0i32, 1, 2]);
        assert_eq!(operator.outputs.as_slice(), &[3i32]);
        assert_eq!(
            model.operator_codes[operator.opcode_index as usize],
            source_model.operator_codes[source_operator.opcode_index as usize]
        );
        assert_eq!(operator.builtin_options, source_operator.builtin_options);
        assert_eq!(operator.custom_options, source_operator.custom_options);
        assert_eq!(operator.custom_options_format, source_operator.custom_options_format);
        assert_eq!(operator.mutating_variable_inputs, source_operator.mutating_variable_inputs);

        let tensors: Vec<_> = operator
            .inputs
            .iter()
            .chain(operator.outputs.iter())
            .map(|&tensor_index| &subgraph.tensors[tensor_index as usize])
            .collect();

        let source_tensors: Vec<_> = source_operator
            .inputs
            .iter()
            .chain(source_operator.outputs.iter())
            .map(|&tensor_index| &source_subgraph.tensors[tensor_index as usize])
            .collect();

        assert_eq!(tensors.len(), source_tensors.len());
        for (tensor, source_tensor) in tensors.into_iter().zip(source_tensors.into_iter()) {
            assert_eq!(tensor.shape, source_tensor.shape);
            assert_eq!(tensor.typ, source_tensor.typ);
            assert_eq!(tensor.name, source_tensor.name);
            assert_eq!(tensor.quantization, source_tensor.quantization);
            assert_eq!(tensor.is_variable, source_tensor.is_variable);
            assert_eq!(
                model.buffers[tensor.buffer as usize],
                source_model.buffers[source_tensor.buffer as usize]
            );
        }
    }

    #[test]
    fn unittest_buffer_clone() {
        let (buffer1, buffer2) = {
            let model = Model::from_file("data/MNISTnet_uint8_quant.tflite").unwrap();
            let buffer = &model.buffers[0];
            (buffer.clone(), buffer.clone())
        };
        assert_eq!(buffer1.data.as_slice(), buffer2.data.as_slice());
    }

    #[test]
    fn unittest_tensor_clone() {
        let (tensor1, tensor2) = {
            let model = Model::from_file("data/MNISTnet_uint8_quant.tflite").unwrap();
            let tensor = &model.subgraphs[0].tensors[0];
            (tensor.clone(), tensor.clone())
        };

        assert_eq!(tensor1.shape.as_slice(), tensor2.shape.as_slice());
        assert_eq!(tensor1.typ, tensor2.typ);
        assert_eq!(tensor1.buffer, tensor2.buffer);
        assert_eq!(tensor1.name.c_str(), tensor2.name.c_str());
        assert_eq!(tensor1.is_variable, tensor2.is_variable);
    }

    #[test]
    fn unittest_operator_clone() {
        let (operator1, operator2) = {
            let model = Model::from_file("data/MNISTnet_uint8_quant.tflite").unwrap();
            let operator = &model.subgraphs[0].operators[0];
            (operator.clone(), operator.clone())
        };

        assert_eq!(operator1.opcode_index, operator2.opcode_index);
        assert_eq!(operator1.inputs.as_slice(), operator2.inputs.as_slice());
        assert_eq!(operator1.outputs.as_slice(), operator2.outputs.as_slice());
        assert_eq!(operator1.builtin_options.typ, operator2.builtin_options.typ);
        assert_eq!(operator1.custom_options.as_slice(), operator2.custom_options.as_slice());
        assert_eq!(operator1.custom_options_format, operator2.custom_options_format);
        assert_eq!(
            operator1.mutating_variable_inputs.iter().collect::<Vec<_>>(),
            operator2.mutating_variable_inputs.iter().collect::<Vec<_>>()
        );
        assert_eq!(operator1.intermediates.as_slice(), operator2.intermediates.as_slice());
    }

    #[test]
    fn unittest_operator_clone_keeps_intermediates() {
        let mut operator: UniquePtr<OperatorT> = Default::default();
        operator.intermediates.assign(vec![3, 5, 7]);
        let cloned = operator.clone();
        assert_eq!(cloned.intermediates.as_slice(), &[3, 5, 7]);
    }

    #[test]
    fn unittest_build_model() {
        let mut model = Model::default();
        model.version = 3;
        model.description.assign(&CString::new("model pad").unwrap());

        {
            let mut pad: UniquePtr<OperatorCodeT> = Default::default();
            pad.builtin_code = BuiltinOperator::BuiltinOperator_PAD;
            pad.version = 1;
            model.operator_codes.push_back(pad);
        }

        model.buffers.assign(vec![UniquePtr::<BufferT>::default(); 3]);

        let mut subgraph: UniquePtr<SubGraphT> = Default::default();

        let mut quantization: UniquePtr<QuantizationParametersT> = Default::default();
        quantization.min.push_back(-1.110_645_3);
        quantization.max.push_back(1.274_200_2);
        quantization.scale.push_back(0.009_352_336);
        quantization.zero_point.push_back(119);

        let mut pad: UniquePtr<OperatorT> = Default::default();
        pad.builtin_options = BuiltinOptionsUnion::PadOptions();
        pad.opcode_index = 0;
        pad.inputs.assign(vec![0, 1]);
        pad.outputs.assign(vec![2]);
        subgraph.operators.push_back(pad);
        subgraph.inputs.assign(vec![0]);
        subgraph.outputs.assign(vec![2]);

        let mut tensor: UniquePtr<TensorT> = Default::default();
        tensor.shape.assign(vec![1, 1]);
        tensor.typ = TensorType::TensorType_UINT8;
        tensor.buffer = 0;
        tensor.name.assign(&CString::new("input_tensor").unwrap());
        tensor.quantization = quantization.clone();
        subgraph.tensors.push_back(tensor);

        let mut tensor: UniquePtr<TensorT> = Default::default();
        tensor.shape.assign(vec![2, 2]);
        tensor.typ = TensorType::TensorType_INT32;
        tensor.buffer = 1;
        tensor.name.assign(&CString::new("shape_tensor").unwrap());
        subgraph.tensors.push_back(tensor);
        model.buffers[1].data.assign(vec![0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]);

        let mut tensor: UniquePtr<TensorT> = Default::default();
        tensor.shape.assign(vec![2, 2]);
        tensor.typ = TensorType::TensorType_UINT8;
        tensor.buffer = 2;
        tensor.name.assign(&CString::new("output_tensor").unwrap());
        tensor.quantization = quantization;
        subgraph.tensors.push_back(tensor);

        model.subgraphs.push_back(subgraph);

        let builder = InterpreterBuilder::new(
            FlatBufferModel::build_from_model(&model).unwrap(),
            BuiltinOpResolver::default(),
        )
        .unwrap();
        let mut interpreter = builder.build().unwrap();

        interpreter.allocate_tensors().unwrap();
        interpreter.tensor_data_mut(0).unwrap().copy_from_slice(&[0u8]);

        interpreter.invoke().unwrap();
        assert_eq!(interpreter.tensor_data::<u8>(2).unwrap(), &[119u8, 0, 119, 119]);
    }
}

--------------------------------------------------------------------------------
/src/model/stl/memory.rs:
-------------------------------------------------------------------------------- 1 | use std::marker::PhantomData; 2 | use std::ops::Deref; 3 | 4 | use super::bindings::root::rust::*; 5 | /* Rust-side handle for a C++ `std::unique_ptr`: the pointer storage (`unique_ptr_of_void`) plus a zero-sized `PhantomData` marker for the pointee type. Construction, access and destruction all go through `cpp!` closures. */ 6 | #[repr(C)] 7 | pub struct UniquePtr(unique_ptr_of_void, PhantomData); 8 | /* NOTE(review): `Send` and `Sync` are asserted for every pointee type without bounds -- confirm this is sound for all wrapped types. */ 9 | unsafe impl Sync for UniquePtr {} 10 | unsafe impl Send for UniquePtr {} 11 | /* Equality compares the pointees through `Deref`, not the pointer addresses. */ 12 | impl PartialEq for UniquePtr 13 | where 14 | T: PartialEq, 15 | UniquePtr: Deref, 16 | { 17 | fn eq(&self, other: &Self) -> bool { 18 | self.deref() == other.deref() 19 | } 20 | } 21 | 22 | impl Eq for UniquePtr 23 | where 24 | T: Eq, 25 | UniquePtr: Deref, 26 | { 27 | } 28 | /* Dropping releases the owned object via `reset()`. NOTE(review): the C++ type strings in this copy read `std::unique_ptr` with no template argument -- confirm the exact type against the repository source. */ 29 | impl Drop for UniquePtr { 30 | fn drop(&mut self) { 31 | #[allow(deprecated)] 32 | unsafe { 33 | cpp!([self as "std::unique_ptr*"] { 34 | self->reset(); 35 | }); 36 | } 37 | } 38 | } 39 | 40 | impl UniquePtr { 41 | /* Returns `true` when the wrapped pointer is non-null. */ pub fn is_valid(&self) -> bool { 42 | #[allow(deprecated)] 43 | unsafe { 44 | cpp!([self as "const std::unique_ptr*"] -> bool as "bool" { 45 | return static_cast(*self); 46 | }) 47 | } 48 | } 49 | } 50 | 51 | #[cfg(test)] 52 | mod tests { 53 | use crate::model::stl::vector::VectorExtract; 54 | use crate::model::Model; 55 | 56 | #[test] 57 | fn unittest_unique_ptr_drop() { 58 | let mut model = Model::from_file("data/MNISTnet_uint8_quant.tflite").unwrap(); 59 | let _subgraph = model.subgraphs.extract(0); 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/model/stl/memory_impl.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::transmute_num_to_bytes)] 2 | use std::ops::{Deref, DerefMut}; 3 | use std::{fmt, mem}; 4 | 5 | use crate::model::stl::memory::UniquePtr; 6 | /* Per-type impls for each wrapped object-API type: `Default` placement-constructs a `std::unique_ptr` that owns a freshly `new`-ed object, `Deref`/`DerefMut` fetch the raw pointer and panic (via `unwrap`) if it is null, and `Debug` prints the pointee wrapped in parentheses. */ 7 | #[allow(deprecated)] 8 | impl Default for UniquePtr { 9 | fn default() -> Self { 10 | let mut this: Self = unsafe { mem::zeroed() }; 11 | let this_ref = &mut this; 12 | unsafe { 13 | cpp!([this_ref as "std::unique_ptr*"] { 14 | new (this_ref) std::unique_ptr(new
OperatorCodeT); 15 | }) 16 | } 17 | this 18 | } 19 | } 20 | 21 | #[allow(deprecated)] 22 | impl Deref for UniquePtr { 23 | type Target = crate::model::OperatorCodeT; 24 | 25 | fn deref(&self) -> &Self::Target { 26 | unsafe { 27 | let ptr = cpp!([self as "const std::unique_ptr*"] -> *const crate::model::OperatorCodeT as "const OperatorCodeT*" { 28 | return self->get(); 29 | }); 30 | 31 | ptr.as_ref().unwrap() 32 | } 33 | } 34 | } 35 | 36 | #[allow(deprecated)] 37 | impl DerefMut for UniquePtr { 38 | fn deref_mut(&mut self) -> &mut Self::Target { 39 | unsafe { 40 | let ptr = cpp!([self as "std::unique_ptr*"] -> *mut crate::model::OperatorCodeT as "OperatorCodeT*" { 41 | return self->get(); 42 | }); 43 | 44 | ptr.as_mut().unwrap() 45 | } 46 | } 47 | } 48 | 49 | #[allow(deprecated)] 50 | impl fmt::Debug for UniquePtr { 51 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 52 | write!(f, "({:?})", self.deref()) 53 | } 54 | } 55 | 56 | #[allow(deprecated)] 57 | impl Default for UniquePtr { 58 | fn default() -> Self { 59 | let mut this: Self = unsafe { mem::zeroed() }; 60 | let this_ref = &mut this; 61 | unsafe { 62 | cpp!([this_ref as "std::unique_ptr*"] { 63 | new (this_ref) std::unique_ptr(new TensorT); 64 | }) 65 | } 66 | this 67 | } 68 | } 69 | 70 | #[allow(deprecated)] 71 | impl Deref for UniquePtr { 72 | type Target = crate::model::TensorT; 73 | 74 | fn deref(&self) -> &Self::Target { 75 | unsafe { 76 | let ptr = cpp!([self as "const std::unique_ptr*"] -> *const crate::model::TensorT as "const TensorT*" { 77 | return self->get(); 78 | }); 79 | 80 | ptr.as_ref().unwrap() 81 | } 82 | } 83 | } 84 | 85 | #[allow(deprecated)] 86 | impl DerefMut for UniquePtr { 87 | fn deref_mut(&mut self) -> &mut Self::Target { 88 | unsafe { 89 | let ptr = cpp!([self as "std::unique_ptr*"] -> *mut crate::model::TensorT as "TensorT*" { 90 | return self->get(); 91 | }); 92 | 93 | ptr.as_mut().unwrap() 94 | } 95 | } 96 | } 97 | 98 | #[allow(deprecated)] 99 | impl fmt::Debug for
UniquePtr { 100 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 101 | write!(f, "({:?})", self.deref()) 102 | } 103 | } 104 | 105 | #[allow(deprecated)] 106 | impl Default for UniquePtr { 107 | fn default() -> Self { 108 | let mut this: Self = unsafe { mem::zeroed() }; 109 | let this_ref = &mut this; 110 | unsafe { 111 | cpp!([this_ref as "std::unique_ptr*"] { 112 | new (this_ref) std::unique_ptr(new OperatorT); 113 | }) 114 | } 115 | this 116 | } 117 | } 118 | 119 | #[allow(deprecated)] 120 | impl Deref for UniquePtr { 121 | type Target = crate::model::OperatorT; 122 | 123 | fn deref(&self) -> &Self::Target { 124 | unsafe { 125 | let ptr = cpp!([self as "const std::unique_ptr*"] -> *const crate::model::OperatorT as "const OperatorT*" { 126 | return self->get(); 127 | }); 128 | 129 | ptr.as_ref().unwrap() 130 | } 131 | } 132 | } 133 | 134 | #[allow(deprecated)] 135 | impl DerefMut for UniquePtr { 136 | fn deref_mut(&mut self) -> &mut Self::Target { 137 | unsafe { 138 | let ptr = cpp!([self as "std::unique_ptr*"] -> *mut crate::model::OperatorT as "OperatorT*" { 139 | return self->get(); 140 | }); 141 | 142 | ptr.as_mut().unwrap() 143 | } 144 | } 145 | } 146 | 147 | #[allow(deprecated)] 148 | impl fmt::Debug for UniquePtr { 149 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 150 | write!(f, "({:?})", self.deref()) 151 | } 152 | } 153 | 154 | #[allow(deprecated)] 155 | impl Default for UniquePtr { 156 | fn default() -> Self { 157 | let mut this: Self = unsafe { mem::zeroed() }; 158 | let this_ref = &mut this; 159 | unsafe { 160 | cpp!([this_ref as "std::unique_ptr*"] { 161 | new (this_ref) std::unique_ptr(new SubGraphT); 162 | }) 163 | } 164 | this 165 | } 166 | } 167 | 168 | #[allow(deprecated)] 169 | impl Deref for UniquePtr { 170 | type Target = crate::model::SubGraphT; 171 | 172 | fn deref(&self) -> &Self::Target { 173 | unsafe { 174 | let ptr = cpp!([self as "const std::unique_ptr*"] -> *const crate::model::SubGraphT as "const
SubGraphT*" { 175 | return self->get(); 176 | }); 177 | 178 | ptr.as_ref().unwrap() 179 | } 180 | } 181 | } 182 | 183 | #[allow(deprecated)] 184 | impl DerefMut for UniquePtr { 185 | fn deref_mut(&mut self) -> &mut Self::Target { 186 | unsafe { 187 | let ptr = cpp!([self as "std::unique_ptr*"] -> *mut crate::model::SubGraphT as "SubGraphT*" { 188 | return self->get(); 189 | }); 190 | 191 | ptr.as_mut().unwrap() 192 | } 193 | } 194 | } 195 | 196 | #[allow(deprecated)] 197 | impl fmt::Debug for UniquePtr { 198 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 199 | write!(f, "({:?})", self.deref()) 200 | } 201 | } 202 | 203 | #[allow(deprecated)] 204 | impl Default for UniquePtr { 205 | fn default() -> Self { 206 | let mut this: Self = unsafe { mem::zeroed() }; 207 | let this_ref = &mut this; 208 | unsafe { 209 | cpp!([this_ref as "std::unique_ptr*"] { 210 | new (this_ref) std::unique_ptr(new BufferT); 211 | }) 212 | } 213 | this 214 | } 215 | } 216 | 217 | #[allow(deprecated)] 218 | impl Deref for UniquePtr { 219 | type Target = crate::model::BufferT; 220 | 221 | fn deref(&self) -> &Self::Target { 222 | unsafe { 223 | let ptr = cpp!([self as "const std::unique_ptr*"] -> *const crate::model::BufferT as "const BufferT*" { 224 | return self->get(); 225 | }); 226 | 227 | ptr.as_ref().unwrap() 228 | } 229 | } 230 | } 231 | 232 | #[allow(deprecated)] 233 | impl DerefMut for UniquePtr { 234 | fn deref_mut(&mut self) -> &mut Self::Target { 235 | unsafe { 236 | let ptr = cpp!([self as "std::unique_ptr*"] -> *mut crate::model::BufferT as "BufferT*" { 237 | return self->get(); 238 | }); 239 | 240 | ptr.as_mut().unwrap() 241 | } 242 | } 243 | } 244 | 245 | #[allow(deprecated)] 246 | impl fmt::Debug for UniquePtr { 247 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 248 | write!(f, "({:?})", self.deref()) 249 | } 250 | } 251 | 252 | #[allow(deprecated)] 253 | impl Default for UniquePtr { 254 | fn default() -> Self { 255 | let mut this: Self = unsafe {
mem::zeroed() }; 256 | let this_ref = &mut this; 257 | unsafe { 258 | cpp!([this_ref as "std::unique_ptr*"] { 259 | new (this_ref) std::unique_ptr(new QuantizationParametersT); 260 | }) 261 | } 262 | this 263 | } 264 | } 265 | 266 | #[allow(deprecated)] 267 | impl Deref for UniquePtr { 268 | type Target = crate::model::QuantizationParametersT; 269 | 270 | fn deref(&self) -> &Self::Target { 271 | unsafe { 272 | let ptr = cpp!([self as "const std::unique_ptr*"] -> *const crate::model::QuantizationParametersT as "const QuantizationParametersT*" { 273 | return self->get(); 274 | }); 275 | 276 | ptr.as_ref().unwrap() 277 | } 278 | } 279 | } 280 | 281 | #[allow(deprecated)] 282 | impl DerefMut for UniquePtr { 283 | fn deref_mut(&mut self) -> &mut Self::Target { 284 | unsafe { 285 | let ptr = cpp!([self as "std::unique_ptr*"] -> *mut crate::model::QuantizationParametersT as "QuantizationParametersT*" { 286 | return self->get(); 287 | }); 288 | 289 | ptr.as_mut().unwrap() 290 | } 291 | } 292 | } 293 | 294 | #[allow(deprecated)] 295 | impl fmt::Debug for UniquePtr { 296 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 297 | write!(f, "({:?})", self.deref()) 298 | } 299 | } 300 | 301 | #[allow(deprecated)] 302 | impl Default for UniquePtr { 303 | fn default() -> Self { 304 | let mut this: Self = unsafe { mem::zeroed() }; 305 | let this_ref = &mut this; 306 | unsafe { 307 | cpp!([this_ref as "std::unique_ptr*"] { 308 | new (this_ref) std::unique_ptr(new ModelT); 309 | }) 310 | } 311 | this 312 | } 313 | } 314 | 315 | #[allow(deprecated)] 316 | impl Deref for UniquePtr { 317 | type Target = crate::model::ModelT; 318 | 319 | fn deref(&self) -> &Self::Target { 320 | unsafe { 321 | let ptr = cpp!([self as "const std::unique_ptr*"] -> *const crate::model::ModelT as "const ModelT*" { 322 | return self->get(); 323 | }); 324 | 325 | ptr.as_ref().unwrap() 326 | } 327 | } 328 | } 329 | 330 | #[allow(deprecated)] 331 | impl DerefMut for UniquePtr { 332 | fn deref_mut(&mut
self) -> &mut Self::Target { 333 | unsafe { 334 | let ptr = cpp!([self as "std::unique_ptr*"] -> *mut crate::model::ModelT as "ModelT*" { 335 | return self->get(); 336 | }); 337 | 338 | ptr.as_mut().unwrap() 339 | } 340 | } 341 | } 342 | 343 | #[allow(deprecated)] 344 | impl fmt::Debug for UniquePtr { 345 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 346 | write!(f, "({:?})", self.deref()) 347 | } 348 | } 349 | 350 | #[allow(deprecated)] 351 | impl Default for UniquePtr { 352 | fn default() -> Self { 353 | let mut this: Self = unsafe { mem::zeroed() }; 354 | let this_ref = &mut this; 355 | unsafe { 356 | cpp!([this_ref as "std::unique_ptr*"] { 357 | new (this_ref) std::unique_ptr(new MetadataT); 358 | }) 359 | } 360 | this 361 | } 362 | } 363 | 364 | #[allow(deprecated)] 365 | impl Deref for UniquePtr { 366 | type Target = crate::model::MetadataT; 367 | 368 | fn deref(&self) -> &Self::Target { 369 | unsafe { 370 | let ptr = cpp!([self as "const std::unique_ptr*"] -> *const crate::model::MetadataT as "const MetadataT*" { 371 | return self->get(); 372 | }); 373 | 374 | ptr.as_ref().unwrap() 375 | } 376 | } 377 | } 378 | 379 | #[allow(deprecated)] 380 | impl DerefMut for UniquePtr { 381 | fn deref_mut(&mut self) -> &mut Self::Target { 382 | unsafe { 383 | let ptr = cpp!([self as "std::unique_ptr*"] -> *mut crate::model::MetadataT as "MetadataT*" { 384 | return self->get(); 385 | }); 386 | 387 | ptr.as_mut().unwrap() 388 | } 389 | } 390 | } 391 | 392 | #[allow(deprecated)] 393 | impl fmt::Debug for UniquePtr { 394 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 395 | write!(f, "({:?})", self.deref()) 396 | } 397 | } 398 | -------------------------------------------------------------------------------- /src/model/stl/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod memory; 2 | pub mod memory_impl; 3 | pub mod string; 4 | #[macro_use] 5 | pub mod vector; 6 | pub mod vector_impl; 7 | 8 |
pub(crate) mod bindings { 9 | include!(concat!(env!("OUT_DIR"), "/stl_types.rs")); 10 | } 11 | -------------------------------------------------------------------------------- /src/model/stl/string.rs: -------------------------------------------------------------------------------- 1 | use std::ffi::CStr; 2 | use std::fmt; 3 | use std::os::raw::c_char; 4 | 5 | use libc::size_t; 6 | 7 | use super::bindings::root::std::string; 8 | /* C++ fixture used only by the unit test at the bottom of this file. NOTE(review): the `#include` target is missing in this copy -- confirm it against the repository source. */ 9 | cpp! {{ 10 | #include 11 | 12 | struct struct_with_strings { 13 | int32_t index; 14 | std::string first_name; 15 | std::string last_name; 16 | }; 17 | }} 18 | 19 | #[repr(C)] 20 | /// This should be used as only (mutable) references. 21 | /// Small string optimization makes unsafe to move `String` instances. 22 | /// `String::drop` is also prohibited for this reason. 23 | pub struct String(string); 24 | /* Equality compares the C-string contents. */ 25 | impl PartialEq for String { 26 | fn eq(&self, other: &Self) -> bool { 27 | self.c_str() == other.c_str() 28 | } 29 | } 30 | 31 | impl Eq for String {} 32 | /* Deliberately panics: per the doc comment on `String`, instances must never be dropped from Rust. */ 33 | impl Drop for String { 34 | fn drop(&mut self) { 35 | panic!("Do not drop `String`!"); 36 | } 37 | } 38 | 39 | impl fmt::Display for String { 40 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 41 | write!(f, "{}", self.c_str().to_string_lossy()) 42 | } 43 | } 44 | 45 | impl fmt::Debug for String { 46 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 47 | write!(f, "{:?}", self.c_str().to_string_lossy()) 48 | } 49 | } 50 | 51 | impl AsRef for String { 52 | fn as_ref(&self) -> &CStr { 53 | self.c_str() 54 | } 55 | } 56 | 57 | impl String { 58 | pub fn is_empty(&self) -> bool { 59 | self.len() == 0 60 | } 61 | 62 | pub fn len(&self) -> size_t { 63 | #[allow(deprecated)] 64 | unsafe { 65 | cpp!([self as "const std::string*"] -> size_t as "size_t" { 66 | return self->size(); 67 | }) 68 | } 69 | } 70 | 71 | pub fn c_str(&self) -> &CStr { 72 | #[allow(deprecated)] 73 | unsafe { 74 | CStr::from_ptr(cpp!([self as "const std::string*"] 75 | -> *const c_char as
"const char*" { 76 | return self->c_str(); 77 | })) 78 | } 79 | } 80 | /* Replaces the contents with the bytes of `s` up to (not including) its NUL terminator. */ 81 | pub fn assign>(&mut self, s: &S) { 82 | let s = s.as_ref(); 83 | let ptr = s.as_ptr(); 84 | #[allow(deprecated)] 85 | unsafe { 86 | cpp!([self as "std::string*", ptr as "const char*"] { 87 | self->assign(ptr); 88 | }) 89 | } 90 | } 91 | } 92 | 93 | #[cfg(test)] 94 | mod tests { 95 | use super::*; 96 | use std::ffi::CString; 97 | /* `#[repr(C)]` mirror of the C++ `struct_with_strings` fixture; field order and types must match it. */ 98 | #[repr(C)] 99 | struct StructWithStrings { 100 | index: i32, 101 | first_name: String, 102 | last_name: String, 103 | } 104 | 105 | #[test] 106 | fn unittest_struct_with_strings() { 107 | #[allow(deprecated)] 108 | let x = unsafe { 109 | cpp!([] -> &mut StructWithStrings as "struct_with_strings*" { 110 | static struct_with_strings x{23, "boncheol", "gu"}; 111 | return &x; 112 | }) 113 | }; 114 | assert_eq!(x.index, 23); 115 | assert_eq!(x.first_name.c_str().to_string_lossy(), "boncheol"); 116 | assert_eq!(x.last_name.c_str().to_string_lossy(), "gu"); 117 | 118 | x.first_name.assign(&CString::new("junmo").unwrap()); 119 | assert_eq!(x.first_name.c_str().to_string_lossy(), "junmo"); 120 | 121 | x.last_name.assign(&x.first_name); 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /src/model/stl/vector.rs: -------------------------------------------------------------------------------- 1 | use std::marker::PhantomData; 2 | use std::{mem, slice}; 3 | 4 | use super::bindings::root::rust::*; 5 | use super::memory::UniquePtr; 6 | pub use super::vector_impl::{VectorOfF32, VectorOfI32, VectorOfI64, VectorOfU8}; 7 | 8 | #[repr(C)] 9 | pub struct Vector(dummy_vector, PhantomData); 10 | /* The default methods are defined in terms of one another (`size` and `get_ptr` via `as_slice`, `as_slice` via `size` and `get_ptr`, and likewise `get_mut_ptr` <-> `as_mut_slice`), so an implementor must override enough of them to break each cycle or calls will recurse forever. */ 11 | pub trait VectorSlice { 12 | type Item; 13 | 14 | fn get_ptr(&self) -> *const Self::Item { 15 | self.as_slice().as_ptr() 16 | } 17 | 18 | fn get_mut_ptr(&mut self) -> *mut Self::Item { 19 | self.as_mut_slice().as_mut_ptr() 20 | } 21 | 22 | fn size(&self) -> usize { 23 | self.as_slice().len() 24 | } 25 | 26 | fn as_slice(&self) ->
&[Self::Item] { 27 | let size = self.size(); 28 | 29 | if size == 0 { 30 | &[] 31 | } else { 32 | unsafe { slice::from_raw_parts(self.get_ptr(), size) } 33 | } 34 | } 35 | 36 | fn as_mut_slice(&mut self) -> &mut [Self::Item] { 37 | let size = self.size(); 38 | 39 | if size == 0 { 40 | &mut [] 41 | } else { 42 | unsafe { slice::from_raw_parts_mut(self.get_mut_ptr(), size) } 43 | } 44 | } 45 | } 46 | /* Generates `Debug`, `Deref`/`DerefMut` to a slice, `Index`/`IndexMut` by `usize`, and by-reference `IntoIterator` impls for each listed type. Paths such as `fmt`, `Deref`, `DerefMut`, `Index`, `IndexMut` and `slice` are resolved where the macro is invoked, so they must be in scope there. */ 47 | macro_rules! add_impl { 48 | ($($t:ty)*) => ($( 49 | impl fmt::Debug for $t 50 | where 51 | <$t as VectorSlice>::Item: fmt::Debug, 52 | { 53 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 54 | f.debug_list().entries(self.as_slice().iter()).finish() 55 | } 56 | } 57 | 58 | impl Deref for $t { 59 | type Target = [<$t as VectorSlice>::Item]; 60 | 61 | fn deref(&self) -> &Self::Target { 62 | self.as_slice() 63 | } 64 | } 65 | 66 | impl DerefMut for $t { 67 | fn deref_mut(&mut self) -> &mut Self::Target { 68 | self.as_mut_slice() 69 | } 70 | } 71 | 72 | impl Index for $t { 73 | type Output = <$t as VectorSlice>::Item; 74 | 75 | fn index(&self, index: usize) -> &Self::Output { 76 | &self.as_slice()[index] 77 | } 78 | } 79 | 80 | impl IndexMut for $t { 81 | fn index_mut(&mut self, index: usize) -> &mut Self::Output { 82 | &mut self.as_mut_slice()[index] 83 | } 84 | } 85 | 86 | impl<'a> IntoIterator for &'a $t { 87 | type Item = &'a <$t as VectorSlice>::Item; 88 | type IntoIter = slice::Iter<'a, <$t as VectorSlice>::Item>; 89 | 90 | fn into_iter(self) -> Self::IntoIter { 91 | self.iter() 92 | } 93 | } 94 | 95 | impl<'a> IntoIterator for &'a mut $t { 96 | type Item = &'a mut <$t as VectorSlice>::Item; 97 | type IntoIter = slice::IterMut<'a, <$t as VectorSlice>::Item>; 98 | 99 | fn into_iter(self) -> Self::IntoIter { 100 | self.iter_mut() 101 | } 102 | } 103 | )*) 104 | } 105 | 106 | pub trait VectorErase: VectorSlice { 107 | /* Erases `len` elements starting at `offset`, highest index first, so each removal leaves the lower indices still to be erased unaffected. */ fn erase_range(&mut self, offset: usize, len: usize) { 108 | for i in (offset..offset + len).rev() { 109 | self.erase(i); 110 | } 111 | }
112 | 113 | fn pop_back(&mut self) { 114 | assert!(self.size() > 0); 115 | self.erase(self.size() - 1); 116 | } 117 | 118 | fn erase(&mut self, index: usize) { 119 | self.erase_range(index, 1); 120 | } 121 | 122 | fn clear(&mut self) { 123 | self.erase_range(0, self.size()); 124 | } 125 | 126 | fn retain(&mut self, pred: F) -> usize 127 | where 128 | F: Fn(usize, &Self::Item) -> bool, 129 | { 130 | let removed: Vec<_> = self 131 | .as_slice() 132 | .iter() 133 | .enumerate() 134 | .filter_map(|(i, op)| if pred(i, op) { None } else { Some(i) }) 135 | .collect(); 136 | 137 | for &i in removed.iter().rev() { 138 | self.erase(i); 139 | } 140 | removed.len() 141 | } 142 | 143 | fn truncate(&mut self, size: usize) { 144 | assert!(size <= self.size()); 145 | self.erase_range(size, self.size() - size); 146 | } 147 | } 148 | 149 | pub trait VectorInsert: VectorErase { 150 | fn push_back(&mut self, v: T); 151 | 152 | fn assign>(&mut self, vs: I) { 153 | self.clear(); 154 | for v in vs { 155 | self.push_back(v); 156 | } 157 | } 158 | 159 | fn append>(&mut self, items: I) { 160 | for item in items.into_iter() { 161 | self.push_back(item); 162 | } 163 | } 164 | } 165 | 166 | pub trait VectorExtract: VectorErase { 167 | fn extract(&mut self, index: usize) -> T; 168 | 169 | fn extract_remove(&mut self, index: usize) -> T { 170 | let item = self.extract(index); 171 | self.erase(index); 172 | item 173 | } 174 | } 175 | 176 | #[repr(C)] 177 | #[derive(Debug)] 178 | pub struct VectorOfBool(vector_of_bool); 179 | 180 | impl Default for VectorOfBool { 181 | fn default() -> Self { 182 | let mut this = unsafe { mem::zeroed() }; 183 | let this_ref = &mut this; 184 | #[allow(deprecated)] 185 | unsafe { 186 | cpp!([this_ref as "std::vector*"] { 187 | new (this_ref) const std::vector; 188 | }) 189 | } 190 | this 191 | } 192 | } 193 | 194 | impl Drop for VectorOfBool { 195 | fn drop(&mut self) { 196 | #[allow(deprecated)] 197 | unsafe { 198 | cpp!([self as "const std::vector*"] { 199 | 
self->~vector(); 200 | }) 201 | } 202 | } 203 | } 204 | 205 | impl Clone for VectorOfBool { 206 | fn clone(&self) -> Self { 207 | let mut cloned = unsafe { mem::zeroed() }; 208 | let cloned_ref = &mut cloned; 209 | #[allow(deprecated)] 210 | unsafe { 211 | cpp!([self as "const std::vector*", cloned_ref as "std::vector*"] { 212 | new (cloned_ref) std::vector(*self); 213 | }); 214 | } 215 | cloned 216 | } 217 | } 218 | 219 | impl PartialEq for VectorOfBool { 220 | fn eq(&self, other: &Self) -> bool { 221 | if self.size() != other.size() { 222 | return false; 223 | } 224 | self.iter().zip(other.iter()).all(|(x, y)| x == y) 225 | } 226 | } 227 | 228 | impl Eq for VectorOfBool {} 229 | 230 | impl VectorOfBool { 231 | pub fn get(&self, index: usize) -> bool { 232 | #[allow(deprecated, clippy::transmute_num_to_bytes)] 233 | unsafe { 234 | cpp!([self as "const std::vector*", index as "size_t"] -> bool as "bool" { 235 | return (*self)[index]; 236 | }) 237 | } 238 | } 239 | 240 | pub fn set(&mut self, index: usize, v: bool) { 241 | #[allow(deprecated, clippy::transmute_num_to_bytes)] 242 | unsafe { 243 | cpp!([self as "std::vector*", index as "size_t", v as "bool"] { 244 | (*self)[index] = v; 245 | }) 246 | } 247 | } 248 | 249 | pub fn size(&self) -> usize { 250 | #[allow(deprecated)] 251 | unsafe { 252 | cpp!([self as "const std::vector*"] -> usize as "size_t" { 253 | return self->size(); 254 | }) 255 | } 256 | } 257 | 258 | pub fn iter(&self) -> impl Iterator + '_ { 259 | (0..self.size()).map(move |i| self.get(i)) 260 | } 261 | } 262 | 263 | #[repr(C)] 264 | pub struct VectorOfUniquePtr(dummy_vector, PhantomData); 265 | 266 | impl PartialEq for VectorOfUniquePtr 267 | where 268 | Self: VectorSlice>, 269 | UniquePtr: PartialEq, 270 | { 271 | fn eq(&self, other: &Self) -> bool { 272 | self.as_slice() == other.as_slice() 273 | } 274 | } 275 | 276 | impl Eq for VectorOfUniquePtr 277 | where 278 | Self: VectorSlice>, 279 | UniquePtr: Eq, 280 | { 281 | } 282 | 283 | impl Drop 
for VectorOfUniquePtr { 284 | fn drop(&mut self) { 285 | #[allow(deprecated)] 286 | unsafe { 287 | cpp!([self as "const std::vector>*"] { 288 | self->~vector>(); 289 | }) 290 | } 291 | } 292 | } 293 | 294 | #[cfg(test)] 295 | mod tests { 296 | use super::*; 297 | 298 | use crate::model::stl::memory::UniquePtr; 299 | use crate::model::BufferT; 300 | 301 | #[test] 302 | fn unittest_vector_default() { 303 | let mut vs = VectorOfU8::default(); 304 | assert_eq!(vs.size(), 0); 305 | 306 | vs.push_back(9); 307 | vs.push_back(10); 308 | assert_eq!(vs.size(), 2); 309 | assert_eq!(vs.as_slice(), &[9u8, 10]); 310 | 311 | let mut vs: VectorOfUniquePtr = VectorOfUniquePtr::default(); 312 | vs.push_back(UniquePtr::default()); 313 | vs.push_back(UniquePtr::default()); 314 | vs.push_back(UniquePtr::default()); 315 | assert_eq!(vs.size(), 3); 316 | } 317 | 318 | #[test] 319 | fn unittest_vector_clone() { 320 | let mut vs = VectorOfU8::default(); 321 | vs.assign(0u8..6); 322 | assert_eq!(vs.size(), 6); 323 | 324 | let cloned = vs.clone(); 325 | assert_eq!(vs.as_slice(), cloned.as_slice()); 326 | } 327 | } 328 | -------------------------------------------------------------------------------- /src/model/stl/vector_impl.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::transmute_num_to_bytes)] 2 | use std::ops::{Deref, DerefMut, Index, IndexMut}; 3 | use std::{fmt, mem, slice}; 4 | 5 | use libc::size_t; 6 | 7 | use super::memory::UniquePtr; 8 | use super::vector::{VectorErase, VectorExtract, VectorInsert, VectorOfUniquePtr, VectorSlice}; 9 | use crate::model::stl::bindings::root::rust::dummy_vector; 10 | 11 | cpp! 
{{ 12 | #include 13 | }} 14 | 15 | #[repr(C)] 16 | pub struct VectorOfU8(dummy_vector); 17 | 18 | #[allow(deprecated)] 19 | impl Default for VectorOfU8 { 20 | fn default() -> Self { 21 | let mut this = unsafe { mem::zeroed() }; 22 | let this_ref = &mut this; 23 | unsafe { 24 | cpp!([this_ref as "std::vector*"] { 25 | new (this_ref) const std::vector; 26 | }) 27 | } 28 | this 29 | } 30 | } 31 | 32 | #[allow(deprecated)] 33 | impl Drop for VectorOfU8 { 34 | fn drop(&mut self) { 35 | unsafe { 36 | cpp!([self as "const std::vector*"] { 37 | self->~vector(); 38 | }) 39 | } 40 | } 41 | } 42 | 43 | #[allow(deprecated)] 44 | impl Clone for VectorOfU8 { 45 | fn clone(&self) -> Self { 46 | let mut cloned = unsafe { mem::zeroed() }; 47 | let cloned_ref = &mut cloned; 48 | unsafe { 49 | cpp!([self as "const std::vector*", cloned_ref as "std::vector*"] { 50 | new (cloned_ref) std::vector(*self); 51 | }); 52 | } 53 | cloned 54 | } 55 | } 56 | 57 | impl PartialEq for VectorOfU8 { 58 | fn eq(&self, other: &Self) -> bool { 59 | self.as_slice() == other.as_slice() 60 | } 61 | } 62 | 63 | impl Eq for VectorOfU8 {} 64 | 65 | #[allow(deprecated)] 66 | impl VectorSlice for VectorOfU8 { 67 | type Item = u8; 68 | 69 | fn get_ptr(&self) -> *const Self::Item { 70 | unsafe { 71 | cpp!([self as "const std::vector*"] 72 | -> *const u8 as "const uint8_t*" { 73 | return self->data(); 74 | }) 75 | } 76 | } 77 | 78 | fn get_mut_ptr(&mut self) -> *mut Self::Item { 79 | unsafe { 80 | cpp!([self as "std::vector*"] 81 | -> *mut u8 as "uint8_t*" { 82 | return self->data(); 83 | }) 84 | } 85 | } 86 | 87 | fn size(&self) -> usize { 88 | unsafe { 89 | cpp!([self as "const std::vector*"] -> size_t as "size_t" { 90 | return self->size(); 91 | }) 92 | } 93 | } 94 | } 95 | 96 | #[allow(deprecated)] 97 | impl VectorErase for VectorOfU8 { 98 | fn erase_range(&mut self, offset: usize, size: usize) { 99 | let begin = offset as size_t; 100 | let end = offset + size as size_t; 101 | unsafe { 102 | cpp!([self as 
"std::vector*", begin as "size_t", end as "size_t"] { 103 | self->erase(self->begin() + begin, self->begin() + end); 104 | }); 105 | } 106 | } 107 | } 108 | 109 | #[allow(deprecated)] 110 | impl VectorInsert for VectorOfU8 { 111 | fn push_back(&mut self, mut v: Self::Item) { 112 | let vref = &mut v; 113 | unsafe { 114 | cpp!([self as "std::vector*", vref as "uint8_t*"] { 115 | self->push_back(std::move(*vref)); 116 | }) 117 | } 118 | } 119 | } 120 | 121 | #[allow(deprecated)] 122 | impl VectorExtract for VectorOfU8 { 123 | fn extract(&mut self, index: usize) -> u8 { 124 | assert!(index < self.size()); 125 | let mut v: u8 = unsafe { mem::zeroed() }; 126 | let vref = &mut v; 127 | unsafe { 128 | cpp!([self as "std::vector*", index as "size_t", vref as "uint8_t*"] { 129 | *vref = std::move((*self)[index]); 130 | }) 131 | } 132 | v 133 | } 134 | } 135 | 136 | add_impl!(VectorOfU8); 137 | 138 | #[repr(C)] 139 | pub struct VectorOfI32(dummy_vector); 140 | 141 | #[allow(deprecated)] 142 | impl Default for VectorOfI32 { 143 | fn default() -> Self { 144 | let mut this = unsafe { mem::zeroed() }; 145 | let this_ref = &mut this; 146 | unsafe { 147 | cpp!([this_ref as "std::vector*"] { 148 | new (this_ref) const std::vector; 149 | }) 150 | } 151 | this 152 | } 153 | } 154 | 155 | #[allow(deprecated)] 156 | impl Drop for VectorOfI32 { 157 | fn drop(&mut self) { 158 | unsafe { 159 | cpp!([self as "const std::vector*"] { 160 | self->~vector(); 161 | }) 162 | } 163 | } 164 | } 165 | 166 | #[allow(deprecated)] 167 | impl Clone for VectorOfI32 { 168 | fn clone(&self) -> Self { 169 | let mut cloned = unsafe { mem::zeroed() }; 170 | let cloned_ref = &mut cloned; 171 | unsafe { 172 | cpp!([self as "const std::vector*", cloned_ref as "std::vector*"] { 173 | new (cloned_ref) std::vector(*self); 174 | }); 175 | } 176 | cloned 177 | } 178 | } 179 | 180 | impl PartialEq for VectorOfI32 { 181 | fn eq(&self, other: &Self) -> bool { 182 | self.as_slice() == other.as_slice() 183 | } 184 | } 185 
| 186 | impl Eq for VectorOfI32 {} 187 | 188 | #[allow(deprecated)] 189 | impl VectorSlice for VectorOfI32 { 190 | type Item = i32; 191 | 192 | fn get_ptr(&self) -> *const Self::Item { 193 | unsafe { 194 | cpp!([self as "const std::vector*"] 195 | -> *const i32 as "const int32_t*" { 196 | return self->data(); 197 | }) 198 | } 199 | } 200 | 201 | fn get_mut_ptr(&mut self) -> *mut Self::Item { 202 | unsafe { 203 | cpp!([self as "std::vector*"] 204 | -> *mut i32 as "int32_t*" { 205 | return self->data(); 206 | }) 207 | } 208 | } 209 | 210 | fn size(&self) -> usize { 211 | unsafe { 212 | cpp!([self as "const std::vector*"] -> size_t as "size_t" { 213 | return self->size(); 214 | }) 215 | } 216 | } 217 | } 218 | 219 | #[allow(deprecated)] 220 | impl VectorErase for VectorOfI32 { 221 | fn erase_range(&mut self, offset: usize, size: usize) { 222 | let begin = offset as size_t; 223 | let end = offset + size as size_t; 224 | unsafe { 225 | cpp!([self as "std::vector*", begin as "size_t", end as "size_t"] { 226 | self->erase(self->begin() + begin, self->begin() + end); 227 | }); 228 | } 229 | } 230 | } 231 | 232 | #[allow(deprecated)] 233 | impl VectorInsert for VectorOfI32 { 234 | fn push_back(&mut self, mut v: Self::Item) { 235 | let vref = &mut v; 236 | unsafe { 237 | cpp!([self as "std::vector*", vref as "int32_t*"] { 238 | self->push_back(std::move(*vref)); 239 | }) 240 | } 241 | } 242 | } 243 | 244 | #[allow(deprecated)] 245 | impl VectorExtract for VectorOfI32 { 246 | fn extract(&mut self, index: usize) -> i32 { 247 | assert!(index < self.size()); 248 | let mut v: i32 = unsafe { mem::zeroed() }; 249 | let vref = &mut v; 250 | unsafe { 251 | cpp!([self as "std::vector*", index as "size_t", vref as "int32_t*"] { 252 | *vref = std::move((*self)[index]); 253 | }) 254 | } 255 | v 256 | } 257 | } 258 | 259 | add_impl!(VectorOfI32); 260 | 261 | #[repr(C)] 262 | pub struct VectorOfI64(dummy_vector); 263 | 264 | #[allow(deprecated)] 265 | impl Default for VectorOfI64 { 266 | 
fn default() -> Self { 267 | let mut this = unsafe { mem::zeroed() }; 268 | let this_ref = &mut this; 269 | unsafe { 270 | cpp!([this_ref as "std::vector*"] { 271 | new (this_ref) const std::vector; 272 | }) 273 | } 274 | this 275 | } 276 | } 277 | 278 | #[allow(deprecated)] 279 | impl Drop for VectorOfI64 { 280 | fn drop(&mut self) { 281 | unsafe { 282 | cpp!([self as "const std::vector*"] { 283 | self->~vector(); 284 | }) 285 | } 286 | } 287 | } 288 | 289 | #[allow(deprecated)] 290 | impl Clone for VectorOfI64 { 291 | fn clone(&self) -> Self { 292 | let mut cloned = unsafe { mem::zeroed() }; 293 | let cloned_ref = &mut cloned; 294 | unsafe { 295 | cpp!([self as "const std::vector*", cloned_ref as "std::vector*"] { 296 | new (cloned_ref) std::vector(*self); 297 | }); 298 | } 299 | cloned 300 | } 301 | } 302 | 303 | impl PartialEq for VectorOfI64 { 304 | fn eq(&self, other: &Self) -> bool { 305 | self.as_slice() == other.as_slice() 306 | } 307 | } 308 | 309 | impl Eq for VectorOfI64 {} 310 | 311 | #[allow(deprecated)] 312 | impl VectorSlice for VectorOfI64 { 313 | type Item = i64; 314 | 315 | fn get_ptr(&self) -> *const Self::Item { 316 | unsafe { 317 | cpp!([self as "const std::vector*"] 318 | -> *const i64 as "const int64_t*" { 319 | return self->data(); 320 | }) 321 | } 322 | } 323 | 324 | fn get_mut_ptr(&mut self) -> *mut Self::Item { 325 | unsafe { 326 | cpp!([self as "std::vector*"] 327 | -> *mut i64 as "int64_t*" { 328 | return self->data(); 329 | }) 330 | } 331 | } 332 | 333 | fn size(&self) -> usize { 334 | unsafe { 335 | cpp!([self as "const std::vector*"] -> size_t as "size_t" { 336 | return self->size(); 337 | }) 338 | } 339 | } 340 | } 341 | 342 | #[allow(deprecated)] 343 | impl VectorErase for VectorOfI64 { 344 | fn erase_range(&mut self, offset: usize, size: usize) { 345 | let begin = offset as size_t; 346 | let end = offset + size as size_t; 347 | unsafe { 348 | cpp!([self as "std::vector*", begin as "size_t", end as "size_t"] { 349 | 
self->erase(self->begin() + begin, self->begin() + end); 350 | }); 351 | } 352 | } 353 | } 354 | 355 | #[allow(deprecated)] 356 | impl VectorInsert for VectorOfI64 { 357 | fn push_back(&mut self, mut v: Self::Item) { 358 | let vref = &mut v; 359 | unsafe { 360 | cpp!([self as "std::vector*", vref as "int64_t*"] { 361 | self->push_back(std::move(*vref)); 362 | }) 363 | } 364 | } 365 | } 366 | 367 | #[allow(deprecated)] 368 | impl VectorExtract for VectorOfI64 { 369 | fn extract(&mut self, index: usize) -> i64 { 370 | assert!(index < self.size()); 371 | let mut v: i64 = unsafe { mem::zeroed() }; 372 | let vref = &mut v; 373 | unsafe { 374 | cpp!([self as "std::vector*", index as "size_t", vref as "int64_t*"] { 375 | *vref = std::move((*self)[index]); 376 | }) 377 | } 378 | v 379 | } 380 | } 381 | 382 | add_impl!(VectorOfI64); 383 | 384 | #[repr(C)] 385 | pub struct VectorOfF32(dummy_vector); 386 | 387 | #[allow(deprecated)] 388 | impl Default for VectorOfF32 { 389 | fn default() -> Self { 390 | let mut this = unsafe { mem::zeroed() }; 391 | let this_ref = &mut this; 392 | unsafe { 393 | cpp!([this_ref as "std::vector*"] { 394 | new (this_ref) const std::vector; 395 | }) 396 | } 397 | this 398 | } 399 | } 400 | 401 | #[allow(deprecated)] 402 | impl Drop for VectorOfF32 { 403 | fn drop(&mut self) { 404 | unsafe { 405 | cpp!([self as "const std::vector*"] { 406 | self->~vector(); 407 | }) 408 | } 409 | } 410 | } 411 | 412 | #[allow(deprecated)] 413 | impl Clone for VectorOfF32 { 414 | fn clone(&self) -> Self { 415 | let mut cloned = unsafe { mem::zeroed() }; 416 | let cloned_ref = &mut cloned; 417 | unsafe { 418 | cpp!([self as "const std::vector*", cloned_ref as "std::vector*"] { 419 | new (cloned_ref) std::vector(*self); 420 | }); 421 | } 422 | cloned 423 | } 424 | } 425 | 426 | impl PartialEq for VectorOfF32 { 427 | fn eq(&self, other: &Self) -> bool { 428 | self.as_slice() == other.as_slice() 429 | } 430 | } 431 | 432 | impl Eq for VectorOfF32 {} 433 | 434 | 
#[allow(deprecated)] 435 | impl VectorSlice for VectorOfF32 { 436 | type Item = f32; 437 | 438 | fn get_ptr(&self) -> *const Self::Item { 439 | unsafe { 440 | cpp!([self as "const std::vector*"] 441 | -> *const f32 as "const float*" { 442 | return self->data(); 443 | }) 444 | } 445 | } 446 | 447 | fn get_mut_ptr(&mut self) -> *mut Self::Item { 448 | unsafe { 449 | cpp!([self as "std::vector*"] 450 | -> *mut f32 as "float*" { 451 | return self->data(); 452 | }) 453 | } 454 | } 455 | 456 | fn size(&self) -> usize { 457 | unsafe { 458 | cpp!([self as "const std::vector*"] -> size_t as "size_t" { 459 | return self->size(); 460 | }) 461 | } 462 | } 463 | } 464 | 465 | #[allow(deprecated)] 466 | impl VectorErase for VectorOfF32 { 467 | fn erase_range(&mut self, offset: usize, size: usize) { 468 | let begin = offset as size_t; 469 | let end = offset + size as size_t; 470 | unsafe { 471 | cpp!([self as "std::vector*", begin as "size_t", end as "size_t"] { 472 | self->erase(self->begin() + begin, self->begin() + end); 473 | }); 474 | } 475 | } 476 | } 477 | 478 | #[allow(deprecated)] 479 | impl VectorInsert for VectorOfF32 { 480 | fn push_back(&mut self, mut v: Self::Item) { 481 | let vref = &mut v; 482 | unsafe { 483 | cpp!([self as "std::vector*", vref as "float*"] { 484 | self->push_back(std::move(*vref)); 485 | }) 486 | } 487 | } 488 | } 489 | 490 | #[allow(deprecated)] 491 | impl VectorExtract for VectorOfF32 { 492 | fn extract(&mut self, index: usize) -> f32 { 493 | assert!(index < self.size()); 494 | let mut v: f32 = unsafe { mem::zeroed() }; 495 | let vref = &mut v; 496 | unsafe { 497 | cpp!([self as "std::vector*", index as "size_t", vref as "float*"] { 498 | *vref = std::move((*self)[index]); 499 | }) 500 | } 501 | v 502 | } 503 | } 504 | 505 | add_impl!(VectorOfF32); 506 | 507 | #[allow(deprecated)] 508 | impl Default for VectorOfUniquePtr { 509 | fn default() -> Self { 510 | let mut this = unsafe { mem::zeroed() }; 511 | let this_ref = &mut this; 512 | unsafe { 
513 | cpp!([this_ref as "std::vector>*"] { 514 | new (this_ref) const std::vector>; 515 | }) 516 | } 517 | this 518 | } 519 | } 520 | 521 | #[allow(deprecated)] 522 | impl VectorSlice for VectorOfUniquePtr { 523 | type Item = UniquePtr; 524 | 525 | fn get_ptr(&self) -> *const Self::Item { 526 | unsafe { 527 | cpp!([self as "const std::vector>*"] 528 | -> *const UniquePtr as "const std::unique_ptr*" { 529 | return self->data(); 530 | }) 531 | } 532 | } 533 | 534 | fn get_mut_ptr(&mut self) -> *mut Self::Item { 535 | unsafe { 536 | cpp!([self as "std::vector>*"] 537 | -> *mut UniquePtr as "std::unique_ptr*" { 538 | return self->data(); 539 | }) 540 | } 541 | } 542 | 543 | fn size(&self) -> usize { 544 | unsafe { 545 | cpp!([self as "const std::vector>*"] -> size_t as "size_t" { 546 | return self->size(); 547 | }) 548 | } 549 | } 550 | } 551 | 552 | #[allow(deprecated)] 553 | impl VectorErase for VectorOfUniquePtr { 554 | fn erase_range(&mut self, offset: usize, size: usize) { 555 | let begin = offset as size_t; 556 | let end = offset + size as size_t; 557 | unsafe { 558 | cpp!([self as "std::vector>*", begin as "size_t", end as "size_t"] { 559 | self->erase(self->begin() + begin, self->begin() + end); 560 | }); 561 | } 562 | } 563 | } 564 | 565 | #[allow(deprecated)] 566 | impl VectorInsert> 567 | for VectorOfUniquePtr 568 | { 569 | fn push_back(&mut self, mut v: Self::Item) { 570 | let vref = &mut v; 571 | unsafe { 572 | cpp!([self as "std::vector>*", vref as "std::unique_ptr*"] { 573 | self->push_back(std::move(*vref)); 574 | }) 575 | } 576 | mem::forget(v); 577 | } 578 | } 579 | 580 | #[allow(deprecated)] 581 | impl VectorExtract> 582 | for VectorOfUniquePtr 583 | { 584 | fn extract(&mut self, index: usize) -> UniquePtr { 585 | assert!(index < self.size()); 586 | let mut v: UniquePtr = unsafe { mem::zeroed() }; 587 | let vref = &mut v; 588 | unsafe { 589 | cpp!([self as "std::vector>*", index as "size_t", vref as "std::unique_ptr*"] { 590 | *vref = 
std::move((*self)[index]); 591 | }) 592 | } 593 | v 594 | } 595 | } 596 | 597 | add_impl!(VectorOfUniquePtr); 598 | 599 | #[allow(deprecated)] 600 | impl Default for VectorOfUniquePtr { 601 | fn default() -> Self { 602 | let mut this = unsafe { mem::zeroed() }; 603 | let this_ref = &mut this; 604 | unsafe { 605 | cpp!([this_ref as "std::vector>*"] { 606 | new (this_ref) const std::vector>; 607 | }) 608 | } 609 | this 610 | } 611 | } 612 | 613 | #[allow(deprecated)] 614 | impl VectorSlice for VectorOfUniquePtr { 615 | type Item = UniquePtr; 616 | 617 | fn get_ptr(&self) -> *const Self::Item { 618 | unsafe { 619 | cpp!([self as "const std::vector>*"] 620 | -> *const UniquePtr as "const std::unique_ptr*" { 621 | return self->data(); 622 | }) 623 | } 624 | } 625 | 626 | fn get_mut_ptr(&mut self) -> *mut Self::Item { 627 | unsafe { 628 | cpp!([self as "std::vector>*"] 629 | -> *mut UniquePtr as "std::unique_ptr*" { 630 | return self->data(); 631 | }) 632 | } 633 | } 634 | 635 | fn size(&self) -> usize { 636 | unsafe { 637 | cpp!([self as "const std::vector>*"] -> size_t as "size_t" { 638 | return self->size(); 639 | }) 640 | } 641 | } 642 | } 643 | 644 | #[allow(deprecated)] 645 | impl VectorErase for VectorOfUniquePtr { 646 | fn erase_range(&mut self, offset: usize, size: usize) { 647 | let begin = offset as size_t; 648 | let end = offset + size as size_t; 649 | unsafe { 650 | cpp!([self as "std::vector>*", begin as "size_t", end as "size_t"] { 651 | self->erase(self->begin() + begin, self->begin() + end); 652 | }); 653 | } 654 | } 655 | } 656 | 657 | #[allow(deprecated)] 658 | impl VectorInsert> for VectorOfUniquePtr { 659 | fn push_back(&mut self, mut v: Self::Item) { 660 | let vref = &mut v; 661 | unsafe { 662 | cpp!([self as "std::vector>*", vref as "std::unique_ptr*"] { 663 | self->push_back(std::move(*vref)); 664 | }) 665 | } 666 | mem::forget(v); 667 | } 668 | } 669 | 670 | #[allow(deprecated)] 671 | impl VectorExtract> for VectorOfUniquePtr { 672 | fn 
extract(&mut self, index: usize) -> UniquePtr { 673 | assert!(index < self.size()); 674 | let mut v: UniquePtr = unsafe { mem::zeroed() }; 675 | let vref = &mut v; 676 | unsafe { 677 | cpp!([self as "std::vector>*", index as "size_t", vref as "std::unique_ptr*"] { 678 | *vref = std::move((*self)[index]); 679 | }) 680 | } 681 | v 682 | } 683 | } 684 | 685 | add_impl!(VectorOfUniquePtr); 686 | 687 | #[allow(deprecated)] 688 | impl Default for VectorOfUniquePtr { 689 | fn default() -> Self { 690 | let mut this = unsafe { mem::zeroed() }; 691 | let this_ref = &mut this; 692 | unsafe { 693 | cpp!([this_ref as "std::vector>*"] { 694 | new (this_ref) const std::vector>; 695 | }) 696 | } 697 | this 698 | } 699 | } 700 | 701 | #[allow(deprecated)] 702 | impl VectorSlice for VectorOfUniquePtr { 703 | type Item = UniquePtr; 704 | 705 | fn get_ptr(&self) -> *const Self::Item { 706 | unsafe { 707 | cpp!([self as "const std::vector>*"] 708 | -> *const UniquePtr as "const std::unique_ptr*" { 709 | return self->data(); 710 | }) 711 | } 712 | } 713 | 714 | fn get_mut_ptr(&mut self) -> *mut Self::Item { 715 | unsafe { 716 | cpp!([self as "std::vector>*"] 717 | -> *mut UniquePtr as "std::unique_ptr*" { 718 | return self->data(); 719 | }) 720 | } 721 | } 722 | 723 | fn size(&self) -> usize { 724 | unsafe { 725 | cpp!([self as "const std::vector>*"] -> size_t as "size_t" { 726 | return self->size(); 727 | }) 728 | } 729 | } 730 | } 731 | 732 | #[allow(deprecated)] 733 | impl VectorErase for VectorOfUniquePtr { 734 | fn erase_range(&mut self, offset: usize, size: usize) { 735 | let begin = offset as size_t; 736 | let end = offset + size as size_t; 737 | unsafe { 738 | cpp!([self as "std::vector>*", begin as "size_t", end as "size_t"] { 739 | self->erase(self->begin() + begin, self->begin() + end); 740 | }); 741 | } 742 | } 743 | } 744 | 745 | #[allow(deprecated)] 746 | impl VectorInsert> 747 | for VectorOfUniquePtr 748 | { 749 | fn push_back(&mut self, mut v: Self::Item) { 750 | let 
vref = &mut v; 751 | unsafe { 752 | cpp!([self as "std::vector>*", vref as "std::unique_ptr*"] { 753 | self->push_back(std::move(*vref)); 754 | }) 755 | } 756 | mem::forget(v); 757 | } 758 | } 759 | 760 | #[allow(deprecated)] 761 | impl VectorExtract> 762 | for VectorOfUniquePtr 763 | { 764 | fn extract(&mut self, index: usize) -> UniquePtr { 765 | assert!(index < self.size()); 766 | let mut v: UniquePtr = unsafe { mem::zeroed() }; 767 | let vref = &mut v; 768 | unsafe { 769 | cpp!([self as "std::vector>*", index as "size_t", vref as "std::unique_ptr*"] { 770 | *vref = std::move((*self)[index]); 771 | }) 772 | } 773 | v 774 | } 775 | } 776 | 777 | add_impl!(VectorOfUniquePtr); 778 | 779 | #[allow(deprecated)] 780 | impl Default for VectorOfUniquePtr { 781 | fn default() -> Self { 782 | let mut this = unsafe { mem::zeroed() }; 783 | let this_ref = &mut this; 784 | unsafe { 785 | cpp!([this_ref as "std::vector>*"] { 786 | new (this_ref) const std::vector>; 787 | }) 788 | } 789 | this 790 | } 791 | } 792 | 793 | #[allow(deprecated)] 794 | impl VectorSlice for VectorOfUniquePtr { 795 | type Item = UniquePtr; 796 | 797 | fn get_ptr(&self) -> *const Self::Item { 798 | unsafe { 799 | cpp!([self as "const std::vector>*"] 800 | -> *const UniquePtr as "const std::unique_ptr*" { 801 | return self->data(); 802 | }) 803 | } 804 | } 805 | 806 | fn get_mut_ptr(&mut self) -> *mut Self::Item { 807 | unsafe { 808 | cpp!([self as "std::vector>*"] 809 | -> *mut UniquePtr as "std::unique_ptr*" { 810 | return self->data(); 811 | }) 812 | } 813 | } 814 | 815 | fn size(&self) -> usize { 816 | unsafe { 817 | cpp!([self as "const std::vector>*"] -> size_t as "size_t" { 818 | return self->size(); 819 | }) 820 | } 821 | } 822 | } 823 | 824 | #[allow(deprecated)] 825 | impl VectorErase for VectorOfUniquePtr { 826 | fn erase_range(&mut self, offset: usize, size: usize) { 827 | let begin = offset as size_t; 828 | let end = offset + size as size_t; 829 | unsafe { 830 | cpp!([self as 
"std::vector>*", begin as "size_t", end as "size_t"] { 831 | self->erase(self->begin() + begin, self->begin() + end); 832 | }); 833 | } 834 | } 835 | } 836 | 837 | #[allow(deprecated)] 838 | impl VectorInsert> 839 | for VectorOfUniquePtr 840 | { 841 | fn push_back(&mut self, mut v: Self::Item) { 842 | let vref = &mut v; 843 | unsafe { 844 | cpp!([self as "std::vector>*", vref as "std::unique_ptr*"] { 845 | self->push_back(std::move(*vref)); 846 | }) 847 | } 848 | mem::forget(v); 849 | } 850 | } 851 | 852 | #[allow(deprecated)] 853 | impl VectorExtract> 854 | for VectorOfUniquePtr 855 | { 856 | fn extract(&mut self, index: usize) -> UniquePtr { 857 | assert!(index < self.size()); 858 | let mut v: UniquePtr = unsafe { mem::zeroed() }; 859 | let vref = &mut v; 860 | unsafe { 861 | cpp!([self as "std::vector>*", index as "size_t", vref as "std::unique_ptr*"] { 862 | *vref = std::move((*self)[index]); 863 | }) 864 | } 865 | v 866 | } 867 | } 868 | 869 | add_impl!(VectorOfUniquePtr); 870 | 871 | #[allow(deprecated)] 872 | impl Default for VectorOfUniquePtr { 873 | fn default() -> Self { 874 | let mut this = unsafe { mem::zeroed() }; 875 | let this_ref = &mut this; 876 | unsafe { 877 | cpp!([this_ref as "std::vector>*"] { 878 | new (this_ref) const std::vector>; 879 | }) 880 | } 881 | this 882 | } 883 | } 884 | 885 | #[allow(deprecated)] 886 | impl VectorSlice for VectorOfUniquePtr { 887 | type Item = UniquePtr; 888 | 889 | fn get_ptr(&self) -> *const Self::Item { 890 | unsafe { 891 | cpp!([self as "const std::vector>*"] 892 | -> *const UniquePtr as "const std::unique_ptr*" { 893 | return self->data(); 894 | }) 895 | } 896 | } 897 | 898 | fn get_mut_ptr(&mut self) -> *mut Self::Item { 899 | unsafe { 900 | cpp!([self as "std::vector>*"] 901 | -> *mut UniquePtr as "std::unique_ptr*" { 902 | return self->data(); 903 | }) 904 | } 905 | } 906 | 907 | fn size(&self) -> usize { 908 | unsafe { 909 | cpp!([self as "const std::vector>*"] -> size_t as "size_t" { 910 | return 
self->size(); 911 | }) 912 | } 913 | } 914 | } 915 | 916 | #[allow(deprecated)] 917 | impl VectorErase for VectorOfUniquePtr { 918 | fn erase_range(&mut self, offset: usize, size: usize) { 919 | let begin = offset as size_t; 920 | let end = offset + size as size_t; 921 | unsafe { 922 | cpp!([self as "std::vector>*", begin as "size_t", end as "size_t"] { 923 | self->erase(self->begin() + begin, self->begin() + end); 924 | }); 925 | } 926 | } 927 | } 928 | 929 | #[allow(deprecated)] 930 | impl VectorInsert> for VectorOfUniquePtr { 931 | fn push_back(&mut self, mut v: Self::Item) { 932 | let vref = &mut v; 933 | unsafe { 934 | cpp!([self as "std::vector>*", vref as "std::unique_ptr*"] { 935 | self->push_back(std::move(*vref)); 936 | }) 937 | } 938 | mem::forget(v); 939 | } 940 | } 941 | 942 | #[allow(deprecated)] 943 | impl VectorExtract> for VectorOfUniquePtr { 944 | fn extract(&mut self, index: usize) -> UniquePtr { 945 | assert!(index < self.size()); 946 | let mut v: UniquePtr = unsafe { mem::zeroed() }; 947 | let vref = &mut v; 948 | unsafe { 949 | cpp!([self as "std::vector>*", index as "size_t", vref as "std::unique_ptr*"] { 950 | *vref = std::move((*self)[index]); 951 | }) 952 | } 953 | v 954 | } 955 | } 956 | 957 | add_impl!(VectorOfUniquePtr); 958 | 959 | #[allow(deprecated)] 960 | impl Default for VectorOfUniquePtr { 961 | fn default() -> Self { 962 | let mut this = unsafe { mem::zeroed() }; 963 | let this_ref = &mut this; 964 | unsafe { 965 | cpp!([this_ref as "std::vector>*"] { 966 | new (this_ref) const std::vector>; 967 | }) 968 | } 969 | this 970 | } 971 | } 972 | 973 | #[allow(deprecated)] 974 | impl VectorSlice for VectorOfUniquePtr { 975 | type Item = UniquePtr; 976 | 977 | fn get_ptr(&self) -> *const Self::Item { 978 | unsafe { 979 | cpp!([self as "const std::vector>*"] 980 | -> *const UniquePtr as "const std::unique_ptr*" { 981 | return self->data(); 982 | }) 983 | } 984 | } 985 | 986 | fn get_mut_ptr(&mut self) -> *mut Self::Item { 987 | unsafe { 
988 | cpp!([self as "std::vector>*"] 989 | -> *mut UniquePtr as "std::unique_ptr*" { 990 | return self->data(); 991 | }) 992 | } 993 | } 994 | 995 | fn size(&self) -> usize { 996 | unsafe { 997 | cpp!([self as "const std::vector>*"] -> size_t as "size_t" { 998 | return self->size(); 999 | }) 1000 | } 1001 | } 1002 | } 1003 | 1004 | #[allow(deprecated)] 1005 | impl VectorErase for VectorOfUniquePtr { 1006 | fn erase_range(&mut self, offset: usize, size: usize) { 1007 | let begin = offset as size_t; 1008 | let end = offset + size as size_t; 1009 | unsafe { 1010 | cpp!([self as "std::vector>*", begin as "size_t", end as "size_t"] { 1011 | self->erase(self->begin() + begin, self->begin() + end); 1012 | }); 1013 | } 1014 | } 1015 | } 1016 | 1017 | #[allow(deprecated)] 1018 | impl VectorInsert> 1019 | for VectorOfUniquePtr 1020 | { 1021 | fn push_back(&mut self, mut v: Self::Item) { 1022 | let vref = &mut v; 1023 | unsafe { 1024 | cpp!([self as "std::vector>*", vref as "std::unique_ptr*"] { 1025 | self->push_back(std::move(*vref)); 1026 | }) 1027 | } 1028 | mem::forget(v); 1029 | } 1030 | } 1031 | 1032 | #[allow(deprecated)] 1033 | impl VectorExtract> 1034 | for VectorOfUniquePtr 1035 | { 1036 | fn extract(&mut self, index: usize) -> UniquePtr { 1037 | assert!(index < self.size()); 1038 | let mut v: UniquePtr = unsafe { mem::zeroed() }; 1039 | let vref = &mut v; 1040 | unsafe { 1041 | cpp!([self as "std::vector>*", index as "size_t", vref as "std::unique_ptr*"] { 1042 | *vref = std::move((*self)[index]); 1043 | }) 1044 | } 1045 | v 1046 | } 1047 | } 1048 | 1049 | add_impl!(VectorOfUniquePtr); 1050 | -------------------------------------------------------------------------------- /submodules/make-NativeTable-configurable-as-polymorphic.patch: -------------------------------------------------------------------------------- 1 | From 645cce1b2ef7ea7d68d42fc6739cba8f1c767c68 Mon Sep 17 00:00:00 2001 2 | From: Martin Schwaighofer 3 | Date: Thu, 6 May 2021 23:02:40 +0200 4 | 
Subject: [PATCH 1/1] make NativeTable configurable as polymorphic 5 | 6 | This is required for import_tflite_types and build_inline_cpp but 7 | not for building libtensorflow-lite.a. It's done via a #define here. 8 | Previously this was done by rewriting the code from inside in build.rs 9 | of tflite-rs before the relevant build step. 10 | Doing it with a define seems much cleaner. 11 | 12 | See: 13 | https://github.com/boncheolgu/tflite-rs/pull/20#discussion_r343465525 14 | --- 15 | flatbuffers/include/flatbuffers/flatbuffers.h | 4 ++++ 16 | 1 file changed, 4 insertions(+) 17 | 18 | diff --git a/flatbuffers/include/flatbuffers/flatbuffers.h b/flatbuffers/include/flatbuffers/flatbuffers.h 19 | index a1a95f0..6480764 100644 20 | --- a/flatbuffers/include/flatbuffers/flatbuffers.h 21 | +++ b/flatbuffers/include/flatbuffers/flatbuffers.h 22 | @@ -2410,7 +2410,11 @@ inline uoffset_t GetPrefixedSize(const uint8_t* buf){ return ReadScalar Result<()> { 8 | let resolver = BuiltinOpResolver::default(); 9 | 10 | let builder = InterpreterBuilder::new(model, &resolver)?; 11 | let mut interpreter = builder.build()?; 12 | 13 | interpreter.allocate_tensors()?; 14 | 15 | let inputs = interpreter.inputs().to_vec(); 16 | assert_eq!(inputs.len(), 1); 17 | 18 | let input_index = inputs[0]; 19 | 20 | let outputs = interpreter.outputs().to_vec(); 21 | assert_eq!(outputs.len(), 1); 22 | 23 | let output_index = outputs[0]; 24 | 25 | let input_tensor = interpreter.tensor_info(input_index).unwrap(); 26 | assert_eq!(input_tensor.dims, vec![1, 28, 28, 1]); 27 | 28 | let output_tensor = interpreter.tensor_info(output_index).unwrap(); 29 | assert_eq!(output_tensor.dims, vec![1, 10]); 30 | 31 | let mut input_file = File::open("data/mnist10.bin")?; 32 | for i in 0..10 { 33 | input_file.read_exact(interpreter.tensor_data_mut(input_index)?)?; 34 | 35 | interpreter.invoke()?; 36 | 37 | let output: &[u8] = interpreter.tensor_data(output_index)?; 38 | let guess = output.iter().enumerate().max_by(|x, 
y| x.1.cmp(y.1)).unwrap().0; 39 | 40 | println!("{i}: {output:?}"); 41 | assert_eq!(i, guess); 42 | } 43 | Ok(()) 44 | } 45 | 46 | #[test] 47 | fn mobilenetv1_mnist() -> Result<()> { 48 | test_mnist(&FlatBufferModel::build_from_file("data/MNISTnet_uint8_quant.tflite")?)?; 49 | 50 | let buf = fs::read("data/MNISTnet_uint8_quant.tflite")?; 51 | test_mnist(&FlatBufferModel::build_from_buffer(buf)?) 52 | } 53 | 54 | #[test] 55 | fn mobilenetv2_mnist() -> Result<()> { 56 | test_mnist(&FlatBufferModel::build_from_file("data/MNISTnet_v2_uint8_quant.tflite")?)?; 57 | 58 | let buf = fs::read("data/MNISTnet_v2_uint8_quant.tflite")?; 59 | test_mnist(&FlatBufferModel::build_from_buffer(buf)?) 60 | } 61 | --------------------------------------------------------------------------------