├── .gitignore
├── CMakeLists.txt
├── LICENSE
├── README.md
├── cmake
│   ├── BuildFlatBuffers.cmake
│   ├── flatbuffers.cmake
│   └── utils.cmake
├── schemas
│   └── schema.fbs
└── src
    ├── cpp-gen.cc
    ├── cpp-gen.h
    ├── dump.cc
    ├── dump.h
    ├── exception.h
    ├── main.cc
    ├── model.cc
    ├── model.h
    └── templates
        ├── jni.tpl
        ├── top_nn_cc.tpl
        └── top_nn_h.tpl

/.gitignore:
--------------------------------------------------------------------------------
1 | CMakeCache.txt
2 | CMakeFiles
3 | CMakeScripts
4 | Testing
5 | Makefile
6 | cmake_install.cmake
7 | install_manifest.txt
8 | compile_commands.json
9 | CTestTestfile.cmake
10 | 
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.8)
2 | project(nnt)
3 | 
4 | set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake/" ${CMAKE_MODULE_PATH})
5 | 
6 | include(CheckCXXCompilerFlag)
7 | include(cmake/utils.cmake)
8 | 
9 | #
10 | # compile flags
11 | #
12 | CHECK_CXX_COMPILER_FLAG("-std=c++17" COMPILER_SUPPORTS_CXX17)
13 | if(COMPILER_SUPPORTS_CXX17)
14 |   set(CMAKE_CXX_STANDARD 17)
15 |   set(CMAKE_CXX_STANDARD_REQUIRED ON)
16 |   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
17 | else()
18 |   message(FATAL_ERROR "The compiler ${CMAKE_CXX_COMPILER} has no C++17 support. Please use a different C++ compiler.")
19 | endif()
20 | 
21 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra")
22 | set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g")
23 | set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
24 | 
25 | #
26 | # find packages
27 | #
28 | find_package(Threads)
29 | find_package(Boost COMPONENTS system filesystem program_options REQUIRED)
30 | 
31 | #
32 | # fetch FlatBuffers and build schemas
33 | #
34 | message(STATUS "FlatBuffers will be downloaded and built from source")
35 | include(cmake/flatbuffers.cmake)
36 | 
37 | file(GLOB_RECURSE FBS_FILES "schemas/*.fbs")
38 | 
39 | build_flatbuffers(
40 |   "${FBS_FILES}"
41 |   "${PROJECT_SOURCE_DIR}/schemas"
42 |   NntSchemas
43 |   ""  # No additional_dependencies
44 |   "${PROJECT_SOURCE_DIR}/schemas"
45 |   ""  # No binary_schemas_dir
46 |   ""  # No copy_text_schemas_dir
47 | )
48 | 
49 | set_property(TARGET NntSchemas APPEND PROPERTY SOURCES ${FBS_FILES})
50 | add_dependencies(NntSchemas flatbuffers)
51 | 
52 | #
53 | # include directories
54 | #
55 | include_directories(${FLATBUFFERS_INCLUDE_DIRS})
56 | include_directories(${CMAKE_SOURCE_DIR})
57 | 
58 | #
59 | # list cpp files
60 | #
61 | file(GLOB SRCS "${CMAKE_SOURCE_DIR}/src/*.cc")
62 | 
63 | #
64 | # create the nnt executable
65 | #
66 | add_executable(nnt ${SRCS})
67 | add_dependencies(nnt NntSchemas)
68 | 
69 | target_link_libraries(nnt
70 |   ${Boost_FILESYSTEM_LIBRARY}
71 |   ${Boost_SYSTEM_LIBRARY}
72 |   ${Boost_LIBRARIES}
73 |   ${FLATBUFFERS_LIBRARIES})
74 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural Network Transpiler
2 | Converts a TensorFlow Lite (tflite) model into C++ source code that uses the Android Neural Networks API.
3 | 
4 | ## Build
5 | In the project directory, create a build directory
6 | ```
7 | $ mkdir build
8 | $ cd build
9 | ```
10 | 
11 | In the build directory, run cmake to generate the Makefile, then run make
12 | ```
13 | $ cmake ..
14 | $ make
15 | ```
16 | 
17 | To verify that the build completed successfully
18 | ```
19 | $ ./nnt -h
20 | ```
21 | 
22 | ## How to use
23 | ```
24 | -h [ --help ]            Help screen
25 | -i [ --info ]            Info about model
26 | -d [ --dot ] arg         Generate dot file
27 | -m [ --model ] arg       flatbuffer neural network model
28 | -p [ --path ] arg        store generated files on this path
29 | -j [ --javapackage ] arg java package for JNI
30 | ```
31 | 
32 | All examples below assume a mobilenet_quant_v1_224.tflite model file in the build directory, the same directory from which the nnt executable is run.
33 | 
34 | ### Model info
35 | ```
36 | ./nnt -m mobilenet_quant_v1_224.tflite -i
37 | ```
38 | It generates the output:
39 | ```
40 | ::Inputs::
41 | Placeholder [1, 224, 224, 3] (quantized)
42 | └─ Quant: {min:[0], max:[1], scale: [0.00392157], zero_point:[0]}
43 | 
44 | ::Outputs::
45 | MobilenetV1/Predictions/Softmax [1, 1001] (quantized)
46 | ```
47 | 
48 | ### Generating a dot file
49 | ```
50 | ./nnt -m mobilenet_quant_v1_224.tflite -d mobnet.dot
51 | ```
52 | The file mobnet.dot is generated in the same directory.
53 | 
54 | ### Generating NNAPI files to use on Android
55 | ```
56 | ./nnt -m mobilenet_quant_v1_224.tflite -j com.nnt.nnexample -p mobnet_path
57 | ```
58 | It creates a directory named "mobnet_path" containing the files jni.cc, nn.h, nn.cc, and weights_biases.bin,
59 | with com.nnt.nnexample as the Java package.
--------------------------------------------------------------------------------
/cmake/BuildFlatBuffers.cmake:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | # General function to create FlatBuffer build rules for the given list of
16 | # schemas.
17 | #
18 | # flatbuffers_schemas: A list of flatbuffer schema files to process.
19 | # 20 | # schema_include_dirs: A list of schema file include directories, which will be 21 | # passed to flatc via the -I parameter. 22 | # 23 | # custom_target_name: The generated files will be added as dependencies for a 24 | # new custom target with this name. You should add that target as a dependency 25 | # for your main target to ensure these files are built. You can also retrieve 26 | # various properties from this target, such as GENERATED_INCLUDES_DIR, 27 | # BINARY_SCHEMAS_DIR, and COPY_TEXT_SCHEMAS_DIR. 28 | # 29 | # additional_dependencies: A list of additional dependencies that you'd like 30 | # all generated files to depend on. Pass in a blank string if you have none. 31 | # 32 | # generated_includes_dir: Where to generate the C++ header files for these 33 | # schemas. The generated includes directory will automatically be added to 34 | # CMake's include directories, and will be where generated header files are 35 | # placed. This parameter is optional; pass in empty string if you don't want to 36 | # generate include files for these schemas. 37 | # 38 | # binary_schemas_dir: If you specify an optional binary schema directory, binary 39 | # schemas will be generated for these schemas as well, and placed into the given 40 | # directory. 41 | # 42 | # copy_text_schemas_dir: If you want all text schemas (including schemas from 43 | # all schema include directories) copied into a directory (for example, if you 44 | # need them within your project to build JSON files), you can specify that 45 | # folder here. All text schemas will be copied to that folder. 46 | # 47 | # IMPORTANT: Make sure you quote all list arguments you pass to this function! 48 | # Otherwise CMake will only pass in the first element. 49 | # Example: build_flatbuffers("${fb_files}" "${include_dirs}" target_name ...) 50 | function(build_flatbuffers flatbuffers_schemas 51 | schema_include_dirs 52 | custom_target_name 53 | additional_dependencies 54 | generated_includes_dir 55 | binary_schemas_dir 56 | copy_text_schemas_dir) 57 | 58 | # Test if including from FindFlatBuffers 59 | if(FLATBUFFERS_FLATC_EXECUTABLE) 60 | set(FLATC_TARGET "") 61 | set(FLATC ${FLATBUFFERS_FLATC_EXECUTABLE}) 62 | else() 63 | set(FLATC_TARGET flatc) 64 | set(FLATC flatc) 65 | endif() 66 | set(FLATC_SCHEMA_ARGS --gen-mutable) 67 | if(FLATBUFFERS_FLATC_SCHEMA_EXTRA_ARGS) 68 | set(FLATC_SCHEMA_ARGS 69 | ${FLATBUFFERS_FLATC_SCHEMA_EXTRA_ARGS} 70 | ${FLATC_SCHEMA_ARGS} 71 | ) 72 | endif() 73 | 74 | set(schema_glob "*.fbs") 75 | # Generate the include files parameters. 76 | set(include_params "") 77 | set(all_generated_files "") 78 | foreach (include_dir ${schema_include_dirs}) 79 | set(include_params -I ${include_dir} ${include_params}) 80 | if (NOT ${copy_text_schemas_dir} STREQUAL "") 81 | # Copy text schemas from dependent folders. 82 | file(GLOB_RECURSE dependent_schemas ${include_dir}/${schema_glob}) 83 | foreach (dependent_schema ${dependent_schemas}) 84 | file(COPY ${dependent_schema} DESTINATION ${copy_text_schemas_dir}) 85 | endforeach() 86 | endif() 87 | endforeach() 88 | 89 | foreach(schema ${flatbuffers_schemas}) 90 | get_filename_component(filename ${schema} NAME_WE) 91 | # For each schema, do the things we requested. 
92 | if (NOT ${generated_includes_dir} STREQUAL "") 93 | set(generated_include ${generated_includes_dir}/${filename}_generated.h) 94 | add_custom_command( 95 | OUTPUT ${generated_include} 96 | COMMAND ${FLATC} ${FLATC_SCHEMA_ARGS} 97 | -o ${generated_includes_dir} 98 | ${include_params} 99 | -c ${schema} 100 | DEPENDS ${FLATC_TARGET} ${schema} ${additional_dependencies}) 101 | list(APPEND all_generated_files ${generated_include}) 102 | endif() 103 | 104 | if (NOT ${binary_schemas_dir} STREQUAL "") 105 | set(binary_schema ${binary_schemas_dir}/${filename}.bfbs) 106 | add_custom_command( 107 | OUTPUT ${binary_schema} 108 | COMMAND ${FLATC} -b --schema 109 | -o ${binary_schemas_dir} 110 | ${include_params} 111 | ${schema} 112 | DEPENDS ${FLATC_TARGET} ${schema} ${additional_dependencies}) 113 | list(APPEND all_generated_files ${binary_schema}) 114 | endif() 115 | 116 | if (NOT ${copy_text_schemas_dir} STREQUAL "") 117 | file(COPY ${schema} DESTINATION ${copy_text_schemas_dir}) 118 | endif() 119 | endforeach() 120 | 121 | # Create a custom target that depends on all the generated files. 122 | # This is the target that you can depend on to trigger all these 123 | # to be built. 124 | add_custom_target(${custom_target_name} 125 | DEPENDS ${all_generated_files} ${additional_dependencies}) 126 | 127 | # Register the include directory we are using. 128 | if (NOT ${generated_includes_dir} STREQUAL "") 129 | include_directories(${generated_includes_dir}) 130 | set_property(TARGET ${custom_target_name} 131 | PROPERTY GENERATED_INCLUDES_DIR 132 | ${generated_includes_dir}) 133 | endif() 134 | 135 | # Register the binary schemas dir we are using. 136 | if (NOT ${binary_schemas_dir} STREQUAL "") 137 | set_property(TARGET ${custom_target_name} 138 | PROPERTY BINARY_SCHEMAS_DIR 139 | ${binary_schemas_dir}) 140 | endif() 141 | 142 | # Register the text schema copy dir we are using. 
143 | if (NOT ${copy_text_schemas_dir} STREQUAL "") 144 | set_property(TARGET ${custom_target_name} 145 | PROPERTY COPY_TEXT_SCHEMAS_DIR 146 | ${copy_text_schemas_dir}) 147 | endif() 148 | endfunction() 149 | -------------------------------------------------------------------------------- /cmake/flatbuffers.cmake: -------------------------------------------------------------------------------- 1 | include(ExternalProject) 2 | 3 | set(flatbuffers_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/external/flatbuffers) 4 | 5 | ExternalProject_Add( 6 | flatbuffers 7 | PREFIX ${flatbuffers_PREFIX} 8 | URL "https://github.com/google/flatbuffers/archive/v1.7.1.tar.gz" 9 | URL_MD5 "81934736f31fbd2cfdb513e71b53b358" 10 | CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${flatbuffers_PREFIX} -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} 11 | LOG_UPDATE ON 12 | LOG_CONFIGURE ON 13 | LOG_BUILD ON 14 | ) 15 | 16 | set(FLATBUFFERS_INCLUDE_DIRS "${flatbuffers_PREFIX}/include") 17 | set(FLATBUFFERS_FLATC_EXECUTABLE "${flatbuffers_PREFIX}/bin/flatc") 18 | set(FLATBUFFERS_LIBRARIES "${flatbuffers_PREFIX}/lib/libflatbuffers.a") 19 | 20 | include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/BuildFlatBuffers.cmake) 21 | -------------------------------------------------------------------------------- /cmake/utils.cmake: -------------------------------------------------------------------------------- 1 | function(status text) 2 | set(status_cond) 3 | set(status_then) 4 | set(status_else) 5 | 6 | set(status_current_name "cond") 7 | foreach(arg ${ARGN}) 8 | if(arg STREQUAL "THEN") 9 | set(status_current_name "then") 10 | elseif(arg STREQUAL "ELSE") 11 | set(status_current_name "else") 12 | else() 13 | list(APPEND status_${status_current_name} ${arg}) 14 | endif() 15 | endforeach() 16 | 17 | if(DEFINED status_cond) 18 | set(status_placeholder_length 32) 19 | string(RANDOM LENGTH ${status_placeholder_length} ALPHABET " " status_placeholder) 20 | string(LENGTH "${text}" status_text_length) 21 | if(status_text_length LESS status_placeholder_length) 22 | string(SUBSTRING "${text}${status_placeholder}" 0 ${status_placeholder_length} status_text) 23 | elseif(DEFINED status_then OR DEFINED status_else) 24 | message(STATUS "${text}") 25 | set(status_text "${status_placeholder}") 26 | else() 27 | set(status_text "${text}") 28 | endif() 29 | 30 | if(DEFINED status_then OR DEFINED status_else) 31 | if(${status_cond}) 32 | string(REPLACE ";" " " status_then "${status_then}") 33 | string(REGEX REPLACE "^[ \t]+" "" status_then "${status_then}") 34 | message(STATUS "${status_text} ${status_then}") 35 | else() 36 | string(REPLACE ";" " " status_else "${status_else}") 37 | string(REGEX REPLACE "^[ \t]+" "" status_else "${status_else}") 38 | message(STATUS "${status_text} ${status_else}") 39 | endif() 40 | else() 41 | string(REPLACE ";" " " status_cond "${status_cond}") 42 | string(REGEX REPLACE "^[ \t]+" "" status_cond "${status_cond}") 43 | message(STATUS "${status_text} ${status_cond}") 44 | endif() 45 | else() 46 | message(STATUS "${text}") 47 | endif() 48 | endfunction() 49 | -------------------------------------------------------------------------------- /schemas/schema.fbs: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // Revision History 16 | // Version 0: Initial version. 17 | // Version 1: Add subgraphs to schema. 18 | // Version 2: Rename operators to conform to NN API. 19 | // Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers. 20 | 21 | namespace tflite; 22 | 23 | // This corresponds to the version. 24 | file_identifier "TFL3"; 25 | // File extension of any written files. 26 | file_extension "tflite"; 27 | 28 | // The type of data stored in a tensor. 29 | enum TensorType : byte { 30 | FLOAT32 = 0, 31 | FLOAT16 = 1, 32 | INT32 = 2, 33 | UINT8 = 3, 34 | INT64 = 4, 35 | STRING = 5, 36 | } 37 | 38 | // Parameters for converting a quantized tensor back to float. Given a 39 | // quantized value q, the corresponding float value f should be: 40 | // f = scale * (q - zero_point) 41 | table QuantizationParameters { 42 | min:[float]; // For importing back into tensorflow. 43 | max:[float]; // For importing back into tensorflow. 44 | scale:[float]; 45 | zero_point:[long]; 46 | } 47 | 48 | table Tensor { 49 | // The tensor shape. The meaning of each entry is operator-specific but 50 | // builtin ops use: [batch size, height, width, number of channels] (That's 51 | // Tensorflow's NHWC). 52 | shape:[int]; 53 | type:TensorType; 54 | // An index that refers to the buffers table at the root of the model. Or, 55 | // if there is no data buffer associated (i.e. intermediate results), then 56 | // this is 0 (which refers to an always existent empty buffer). 57 | // 58 | // The data_buffer itself is an opaque container, with the assumption that the 59 | // target device is little-endian. In addition, all builtin operators assume 60 | // the memory is ordered such that if `shape` is [4, 3, 2], then index 61 | // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k]. 62 | buffer:uint; 63 | name:string; // For debugging and importing back into tensorflow. 64 | quantization:QuantizationParameters; // Optional. 65 | } 66 | 67 | // A list of builtin operators. Builtin operators a slighlty faster than custom 68 | // ones, but not by much. Moreover, while custom operators accept an opaque 69 | // object containing configuration parameters, builtins have a predetermined 70 | // set of acceptable options. 71 | enum BuiltinOperator : byte { 72 | ADD = 0, 73 | AVERAGE_POOL_2D = 1, 74 | CONCATENATION = 2, 75 | CONV_2D = 3, 76 | DEPTHWISE_CONV_2D = 4, 77 | // DEPTH_TO_SPACE = 5, 78 | DEQUANTIZE = 6, 79 | EMBEDDING_LOOKUP = 7, 80 | // FLOOR = 8, 81 | FULLY_CONNECTED = 9, 82 | HASHTABLE_LOOKUP = 10, 83 | L2_NORMALIZATION = 11, 84 | L2_POOL_2D = 12, 85 | LOCAL_RESPONSE_NORMALIZATION = 13, 86 | LOGISTIC = 14, 87 | LSH_PROJECTION = 15, 88 | LSTM = 16, 89 | MAX_POOL_2D = 17, 90 | MUL = 18, 91 | RELU = 19, 92 | // NOTE(aselle): RELU_N1_TO_1 used to be called RELU1, but it was renamed 93 | // since different model developers use RELU1 in different ways. Never 94 | // create another op called RELU1. 
95 | RELU_N1_TO_1 = 20, 96 | RELU6 = 21, 97 | RESHAPE = 22, 98 | RESIZE_BILINEAR = 23, 99 | RNN = 24, 100 | SOFTMAX = 25, 101 | SPACE_TO_DEPTH = 26, 102 | SVDF = 27, 103 | TANH = 28, 104 | // TODO(aselle): Consider rename to CONCATENATE_EMBEDDINGS 105 | CONCAT_EMBEDDINGS = 29, 106 | SKIP_GRAM = 30, 107 | CALL = 31, 108 | CUSTOM = 32, 109 | EMBEDDING_LOOKUP_SPARSE = 33, 110 | PAD = 34, 111 | UNIDIRECTIONAL_SEQUENCE_RNN = 35, 112 | GATHER = 36, 113 | BATCH_TO_SPACE_ND = 37, 114 | SPACE_TO_BATCH_ND = 38, 115 | TRANSPOSE = 39, 116 | MEAN = 40, 117 | SUB = 41, 118 | DIV = 42, 119 | SQUEEZE = 43, 120 | UNIDIRECTIONAL_SEQUENCE_LSTM = 44, 121 | STRIDED_SLICE = 45, 122 | BIDIRECTIONAL_SEQUENCE_RNN = 46, 123 | EXP = 47, 124 | TOPK_V2 = 48, 125 | SPLIT = 49, 126 | LOG_SOFTMAX = 50, 127 | // DELEGATE is a special op type for the operations which are delegated to 128 | // other backends. 129 | // WARNING: Experimental interface, subject to change 130 | DELEGATE = 51, 131 | BIDIRECTIONAL_SEQUENCE_LSTM = 52, 132 | CAST = 53, 133 | PRELU = 54, 134 | MAXIMUM = 55, 135 | } 136 | 137 | // Options for the builtin operators. 138 | union BuiltinOptions { 139 | Conv2DOptions, 140 | DepthwiseConv2DOptions, 141 | ConcatEmbeddingsOptions, 142 | LSHProjectionOptions, 143 | Pool2DOptions, 144 | SVDFOptions, 145 | RNNOptions, 146 | FullyConnectedOptions, 147 | SoftmaxOptions, 148 | ConcatenationOptions, 149 | AddOptions, 150 | L2NormOptions, 151 | LocalResponseNormalizationOptions, 152 | LSTMOptions, 153 | ResizeBilinearOptions, 154 | CallOptions, 155 | ReshapeOptions, 156 | SkipGramOptions, 157 | SpaceToDepthOptions, 158 | EmbeddingLookupSparseOptions, 159 | MulOptions, 160 | PadOptions, 161 | GatherOptions, 162 | BatchToSpaceNDOptions, 163 | SpaceToBatchNDOptions, 164 | TransposeOptions, 165 | MeanOptions, 166 | SubOptions, 167 | DivOptions, 168 | SqueezeOptions, 169 | SequenceRNNOptions, 170 | StridedSliceOptions, 171 | ExpOptions, 172 | TopKV2Options, 173 | SplitOptions, 174 | LogSoftmaxOptions, 175 | CastOptions, 176 | DequantizeOptions, 177 | MaximumOptions, 178 | } 179 | 180 | enum Padding : byte { SAME, VALID } 181 | 182 | enum ActivationFunctionType : byte { 183 | NONE = 0, 184 | RELU = 1, 185 | RELU_N1_TO_1 = 2, 186 | RELU6 = 3, 187 | TANH = 4, 188 | SIGN_BIT = 5, 189 | } 190 | 191 | table Conv2DOptions { 192 | padding:Padding; 193 | stride_w:int; 194 | stride_h:int; 195 | fused_activation_function:ActivationFunctionType; 196 | } 197 | 198 | table Pool2DOptions { 199 | padding:Padding; 200 | stride_w:int; 201 | stride_h:int; 202 | filter_width:int; 203 | filter_height:int; 204 | fused_activation_function:ActivationFunctionType; 205 | } 206 | 207 | table DepthwiseConv2DOptions { 208 | padding:Padding; 209 | stride_w:int; 210 | stride_h:int; 211 | depth_multiplier:int; 212 | fused_activation_function:ActivationFunctionType; 213 | } 214 | 215 | table ConcatEmbeddingsOptions { 216 | num_channels:int; 217 | num_columns_per_channel:[int]; 218 | embedding_dim_per_channel:[int]; // This could be inferred from parameters. 219 | } 220 | 221 | enum LSHProjectionType: byte { 222 | UNKNOWN = 0, 223 | SPARSE = 1, 224 | DENSE = 2, 225 | } 226 | 227 | table LSHProjectionOptions { 228 | type: LSHProjectionType; 229 | } 230 | 231 | table SVDFOptions { 232 | rank:int; 233 | fused_activation_function:ActivationFunctionType; 234 | } 235 | 236 | // An implementation of TensorFlow RNNCell. 
237 | table RNNOptions { 238 | fused_activation_function:ActivationFunctionType; 239 | } 240 | 241 | // An implementation of TensorFlow dynamic_rnn with RNNCell. 242 | table SequenceRNNOptions { 243 | time_major:bool; 244 | fused_activation_function:ActivationFunctionType; 245 | } 246 | 247 | // An implementation of TensorFlow bidrectional_dynamic_rnn with RNNCell. 248 | table BidirectionalSequenceRNNOptions { 249 | time_major:bool; 250 | fused_activation_function:ActivationFunctionType; 251 | } 252 | 253 | // An implementation of TensorFlow fully_connected (a.k.a Dense) layer. 254 | table FullyConnectedOptions { 255 | fused_activation_function:ActivationFunctionType; 256 | } 257 | 258 | table SoftmaxOptions { 259 | beta: float; 260 | } 261 | 262 | // An implementation of TensorFlow concat. 263 | table ConcatenationOptions { 264 | axis:int; 265 | fused_activation_function:ActivationFunctionType; 266 | } 267 | 268 | table AddOptions { 269 | fused_activation_function:ActivationFunctionType; 270 | } 271 | 272 | table MulOptions { 273 | fused_activation_function:ActivationFunctionType; 274 | } 275 | 276 | table L2NormOptions { 277 | fused_activation_function:ActivationFunctionType; 278 | } 279 | 280 | table LocalResponseNormalizationOptions { 281 | radius:int; 282 | bias:float; 283 | alpha:float; 284 | beta:float; 285 | } 286 | 287 | // An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell 288 | table LSTMOptions { 289 | fused_activation_function:ActivationFunctionType; 290 | cell_clip: float; // Optional, 0.0 means no clipping 291 | proj_clip: float; // Optional, 0.0 means no clipping 292 | } 293 | 294 | table ResizeBilinearOptions { 295 | new_height: int (deprecated); 296 | new_width: int (deprecated); 297 | align_corners: bool; 298 | } 299 | 300 | // A call operation options 301 | table CallOptions { 302 | // The subgraph index that needs to be called. 
303 | subgraph:uint; 304 | } 305 | 306 | table PadOptions { 307 | } 308 | 309 | table ReshapeOptions { 310 | new_shape:[int]; 311 | } 312 | 313 | table SpaceToBatchNDOptions { 314 | } 315 | 316 | table BatchToSpaceNDOptions { 317 | } 318 | 319 | table SkipGramOptions { 320 | ngram_size: int; 321 | max_skip_size: int; 322 | include_all_ngrams: bool; 323 | } 324 | 325 | table SpaceToDepthOptions { 326 | block_size: int; 327 | } 328 | 329 | table SubOptions { 330 | fused_activation_function:ActivationFunctionType; 331 | } 332 | 333 | table DivOptions { 334 | fused_activation_function:ActivationFunctionType; 335 | } 336 | 337 | table TopKV2Options { 338 | } 339 | 340 | enum CombinerType : byte { 341 | SUM = 0, 342 | MEAN = 1, 343 | SQRTN = 2, 344 | } 345 | 346 | table EmbeddingLookupSparseOptions { 347 | combiner:CombinerType; 348 | } 349 | 350 | table GatherOptions { 351 | axis: int; 352 | } 353 | 354 | table TransposeOptions { 355 | } 356 | 357 | table ExpOptions { 358 | } 359 | 360 | table MeanOptions { 361 | keep_dims: bool; 362 | } 363 | 364 | table SqueezeOptions { 365 | squeeze_dims:[int]; 366 | } 367 | 368 | table SplitOptions { 369 | num_splits: int; 370 | } 371 | 372 | table StridedSliceOptions { 373 | begin_mask: int; 374 | end_mask: int; 375 | ellipsis_mask: int; 376 | new_axis_mask: int; 377 | shrink_axis_mask: int; 378 | } 379 | 380 | table LogSoftmaxOptions { 381 | } 382 | 383 | table CastOptions { 384 | in_data_type: TensorType; 385 | out_data_type: TensorType; 386 | } 387 | 388 | table DequantizeOptions { 389 | } 390 | 391 | table MaximumOptions { 392 | } 393 | 394 | // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a 395 | // builtin, or a string if the operator is custom. 396 | table OperatorCode { 397 | builtin_code:BuiltinOperator; 398 | custom_code:string; 399 | } 400 | 401 | enum CustomOptionsFormat : byte { 402 | FLEXBUFFERS = 0, 403 | } 404 | 405 | // An operator takes tensors as inputs and outputs. The type of operation being 406 | // performed is determined by an index into the list of valid OperatorCodes, 407 | // while the specifics of each operations is configured using builtin_options 408 | // or custom_options. 409 | table Operator { 410 | // Index into the operator_codes array. Using an integer here avoids 411 | // complicate map lookups. 412 | opcode_index:uint; 413 | 414 | // Optional input and output tensors are indicated by -1. 415 | inputs:[int]; 416 | outputs:[int]; 417 | 418 | builtin_options:BuiltinOptions; 419 | custom_options:[ubyte]; 420 | custom_options_format:CustomOptionsFormat; 421 | } 422 | 423 | // The root type, defining a model. 424 | table SubGraph { 425 | // A list of all tensors used in this model. 426 | tensors:[Tensor]; 427 | 428 | // Indices of the input tensors. 429 | inputs:[int]; 430 | 431 | // Indices of the output tensors. 432 | outputs:[int]; 433 | 434 | // All operators, in execution order. 435 | operators:[Operator]; 436 | 437 | // Name of subgraph (used for debugging). 438 | name:string; 439 | } 440 | 441 | // Table of raw data buffers (used for constant tensors). Referenced by tensors 442 | // by index. 443 | table Buffer { 444 | data:[ubyte]; 445 | } 446 | 447 | table Model { 448 | // Version of the schema. 449 | version:uint; 450 | 451 | // A list of all operator codes used in this model. This is 452 | // kept in order because operators carry an index into this 453 | // vector. 454 | operator_codes:[OperatorCode]; 455 | 456 | // All the subgraphs of the model. 
The 0th is assumed to be the main 457 | // model. 458 | subgraphs:[SubGraph]; 459 | 460 | // A description of the model. 461 | description:string; 462 | 463 | // Buffers of the model 464 | buffers:[Buffer]; 465 | 466 | } 467 | 468 | root_type Model; 469 | -------------------------------------------------------------------------------- /src/cpp-gen.cc: -------------------------------------------------------------------------------- 1 | #include "cpp-gen.h" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "exception.h" 10 | 11 | namespace nnt { 12 | 13 | std::string TensorsHeader::Generate() { 14 | const std::vector& buffers = model_.Buffers(); 15 | std::string str_buf; 16 | 17 | for (const auto& buf : buffers) { 18 | if (buf.Data().size() > 0) { 19 | std::string str_vec = ""; 20 | for (const auto& c : buf.Data()) { 21 | str_buf += c; 22 | } 23 | } 24 | } 25 | 26 | return str_buf; 27 | } 28 | 29 | std::string TensorsHeader::Assembler() { 30 | return Generate(); 31 | } 32 | 33 | std::string ModelGen::Generate() { 34 | std::string str_init = "Init"; 35 | return str_init; 36 | } 37 | 38 | std::string ModelGen::TensorTypeStr(TensorType type) { 39 | switch (type) { 40 | case TensorType::FLOAT32: 41 | return "ANEURALNETWORKS_TENSOR_FLOAT32"; 42 | break; 43 | 44 | case TensorType::INT32: 45 | return "ANEURALNETWORKS_TENSOR_INT32"; 46 | break; 47 | 48 | case TensorType::UINT8: 49 | return "ANEURALNETWORKS_TENSOR_QUANT8_ASYMM"; 50 | break; 51 | 52 | default: 53 | FATAL("Tensor type not valid for Android NNAPI") 54 | } 55 | } 56 | 57 | std::string ModelGen::TensorCppTypeStr(TensorType type) { 58 | switch (type) { 59 | case TensorType::FLOAT32: 60 | return "FLOAT16"; 61 | break; 62 | 63 | case TensorType::INT32: 64 | return "int32_t"; 65 | break; 66 | 67 | case TensorType::UINT8: 68 | return "char"; 69 | break; 70 | 71 | default: 72 | FATAL("Tensor type not valid for Android NNAPI") 73 | } 74 | } 75 | 76 | std::string ModelGen::TensorDim(const std::vector& dim) { 77 | std::string str_out = "{"; 78 | 79 | for (const auto& e : dim) { 80 | str_out += std::to_string(e) + ","; 81 | } 82 | 83 | str_out = str_out.substr(0, str_out.length() - 2); 84 | str_out += "}"; 85 | 86 | return str_out; 87 | } 88 | 89 | float ModelGen::TensorQuantizationScale(const QuantizationParameters& q) { 90 | if (q.scale.size() > 0) { 91 | return q.scale[0]; 92 | } else { 93 | return 0.0f; 94 | } 95 | } 96 | 97 | int ModelGen::TensorQuantizationZeroPoint( 98 | const QuantizationParameters& q) { 99 | if (q.zero_point.size() > 0) { 100 | return q.zero_point[0]; 101 | } else { 102 | return 0; 103 | } 104 | } 105 | 106 | std::string ModelGen::GenerateTensorType(const Tensor& tensor, int count) { 107 | std::stringstream ss; 108 | std::string dimensions = TensorDim(tensor.shape()); 109 | 110 | ss << "uint32_t dimensions_" << count << "[] = " << dimensions << ";\n"; 111 | ss << "ANeuralNetworksOperandType operand_type_" << count << " {\n"; 112 | 113 | std::string str_tensor_type = TensorTypeStr(tensor.tensor_type()); 114 | int dimension_count = tensor.shape().size(); 115 | 116 | float scale; 117 | int zero_point; 118 | 119 | if (tensor.HasQuantization()) { 120 | scale = TensorQuantizationScale(tensor.quantization()); 121 | zero_point = TensorQuantizationZeroPoint(tensor.quantization()); 122 | } else { 123 | scale = 0.0f; 124 | zero_point = 0; 125 | } 126 | 127 | ss << " .type = " << str_tensor_type << ",\n"; 128 | ss << " .dimensionCount = " << dimension_count << ",\n"; 129 | ss << " .dimensions = " << 
"dimensions_" << count << ",\n"; 130 | 131 | if (scale == 0) { 132 | ss << " .scale = ." << scale << "f,\n"; 133 | } else{ 134 | ss << " .scale = " << scale << "f,\n"; 135 | } 136 | 137 | ss << " .zeroPoint = " << zero_point << "\n"; 138 | ss << "};\n\n"; 139 | 140 | return ss.str(); 141 | } 142 | 143 | std::string ModelGen::CheckStatus(const boost::format& msg) { 144 | std::stringstream ss; 145 | 146 | // nnapi always check the result of operation 147 | ss << "if (status != ANEURALNETWORKS_NO_ERROR) {\n"; 148 | ss << " __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,\n"; 149 | ss << " \"" << boost::str(msg) << "\");\n"; 150 | ss << " return false;\n"; 151 | ss << "}\n\n"; 152 | 153 | return ss.str(); 154 | } 155 | 156 | std::string ModelGen::GenerateTensorsCode() { 157 | Graph& graph = model_.graph(); 158 | std::stringstream ss; 159 | 160 | int count = 0; 161 | for (const auto& tensor: graph.Tensors()) { 162 | // insert operand type 163 | ss << GenerateTensorType(tensor, count); 164 | 165 | // insert nnapi operand 166 | ss << "status = ANeuralNetworksModel_addOperand(model, "; 167 | ss << "&operand_type_" << count << ");\n"; 168 | ss << CheckStatus(boost::format("ANeuralNetworksModel_addOperand failed" 169 | "for operand %1%")%count); 170 | 171 | size_t buf_size = tensor.buffer().Data().size(); 172 | 173 | if (buf_size > 0) { 174 | // get tensor size 175 | ss << "tensor_size = " << buf_size << ";\n"; 176 | 177 | // insert operand value 178 | ss << "status = ANeuralNetworksModel_setOperandValueFromMemory(model, "; 179 | ss << count << ", mem, offset, tensor_size);\n\n"; 180 | ss << CheckStatus(boost::format( 181 | "ANeuralNetworksModel_setOperandValueFromMemory " 182 | "failed for operand %1%")%count); 183 | 184 | // calculates the offset 185 | ss << "offset += tensor_size;\n"; 186 | } 187 | 188 | ++count; 189 | } 190 | 191 | count_operands_ = count; 192 | tensor_pos_ = graph.Tensors().size(); 193 | 194 | return ss.str(); 195 | } 196 | 197 | std::string ModelGen::GenerateOpInputs(const std::vector& inputs, 198 | size_t num_params) { 199 | // inputs loop 200 | std::string str_in = ""; 201 | 202 | // insert data params like conv filters params 203 | for (const auto& in_value : inputs) { 204 | str_in += " " + std::to_string(in_value) + ","; 205 | } 206 | 207 | // insert hiperparams like conv stride 208 | size_t tensor_start_pos = tensor_pos_; 209 | for (; tensor_pos_ < (tensor_start_pos + num_params); tensor_pos_++) { 210 | str_in += " " + std::to_string(tensor_pos_) + ","; 211 | } 212 | 213 | str_in = str_in.substr(0, str_in.length() - 1); 214 | return str_in; 215 | } 216 | 217 | std::string ModelGen::GenerateOpOutputs(const std::vector& outputs) { 218 | // outputs loop 219 | std::string str_out = ""; 220 | 221 | for (const auto& out_value : outputs) { 222 | str_out += " " + std::to_string(out_value) + ","; 223 | } 224 | 225 | str_out = str_out.substr(0, str_out.length() - 1); 226 | return str_out; 227 | } 228 | 229 | std::string ModelGen::OpTypeStr(BuiltinOperator op_type) { 230 | switch (op_type) { 231 | case BuiltinOperator::ADD: 232 | return "ANEURALNETWORKS_ADD"; 233 | break; 234 | 235 | case BuiltinOperator::AVERAGE_POOL_2D: 236 | return "ANEURALNETWORKS_AVERAGE_POOL_2D"; 237 | break; 238 | 239 | case BuiltinOperator::MAX_POOL_2D: 240 | return "ANEURALNETWORKS_MAX_POOL_2D"; 241 | break; 242 | 243 | case BuiltinOperator::L2_POOL_2D: 244 | return "ANEURALNETWORKS_L2_POOL_2D"; 245 | break; 246 | 247 | case BuiltinOperator::CONV_2D: 248 | return "ANEURALNETWORKS_CONV_2D"; 249 | break; 250 | 
251 | case BuiltinOperator::RELU: 252 | return "ANEURALNETWORKS_RELU"; 253 | break; 254 | 255 | case BuiltinOperator::RELU6: 256 | return "ANEURALNETWORKS_RELU6"; 257 | break; 258 | 259 | case BuiltinOperator::TANH: 260 | return "ANEURALNETWORKS_TANH"; 261 | break; 262 | 263 | case BuiltinOperator::LOGISTIC: 264 | return "ANEURALNETWORKS_LOGISTIC"; 265 | break; 266 | 267 | case BuiltinOperator::DEPTHWISE_CONV_2D: 268 | return "ANEURALNETWORKS_DEPTHWISE_CONV_2D"; 269 | break; 270 | 271 | case BuiltinOperator::CONCATENATION: 272 | return "ANEURALNETWORKS_CONCATENATION"; 273 | break; 274 | 275 | case BuiltinOperator::SOFTMAX: 276 | return "ANEURALNETWORKS_SOFTMAX"; 277 | break; 278 | 279 | case BuiltinOperator::FULLY_CONNECTED: 280 | return "ANEURALNETWORKS_FULLY_CONNECTED"; 281 | break; 282 | 283 | case BuiltinOperator::RESHAPE: 284 | return "ANEURALNETWORKS_RESHAPE"; 285 | break; 286 | 287 | case BuiltinOperator::SPACE_TO_DEPTH: 288 | return "ANEURALNETWORKS_SPACE_TO_DEPTH"; 289 | break; 290 | 291 | case BuiltinOperator::LSTM: 292 | return "ANEURALNETWORKS_LSTM"; 293 | break; 294 | 295 | default: 296 | FATAL(boost::format("Not supported type on NNAPI")) 297 | } 298 | } 299 | 300 | std::string ModelGen::AddScalarInt32(int value) { 301 | std::stringstream ss; 302 | 303 | ss << "CHECK_ADD_SCALAR(AddScalarInt32(" << count_operands_ << ", " 304 | << value << "))\n"; 305 | 306 | ++count_operands_; 307 | return ss.str(); 308 | } 309 | 310 | std::string ModelGen::AddScalarFloat32(float value) { 311 | std::stringstream ss; 312 | 313 | ss << "CHECK_ADD_SCALAR(AddScalarFloat32(" << count_operands_ << ", " 314 | << value << "))\n"; 315 | 316 | ++count_operands_; 317 | return ss.str(); 318 | } 319 | 320 | std::tuple ModelGen::OpParams(const Operator& op) { 321 | std::stringstream ss; 322 | size_t num_params = 0; 323 | 324 | auto check = [&op](BuiltinOptionsType type) { 325 | if (op.builtin_op().type != type) { 326 | FATAL(boost::format("Operator node type wrong")); 327 | } 328 | }; 329 | 330 | switch (op.op_code().builtin_code) { 331 | case BuiltinOperator::ADD: 332 | ss << AddScalarInt32(0); 333 | num_params = 1; 334 | break; 335 | 336 | case BuiltinOperator::L2_POOL_2D: 337 | case BuiltinOperator::MAX_POOL_2D: 338 | case BuiltinOperator::AVERAGE_POOL_2D: { 339 | check(BuiltinOptionsType::Pool2DOptions); 340 | const Pool2DOptions& pool_options = static_cast( 341 | op.builtin_op()); 342 | 343 | ss << AddScalarInt32(static_cast(pool_options.padding)); 344 | ss << AddScalarInt32(pool_options.stride_w); 345 | ss << AddScalarInt32(pool_options.stride_h); 346 | ss << AddScalarInt32(pool_options.filter_width); 347 | ss << AddScalarInt32(pool_options.filter_height); 348 | ss << AddScalarInt32(static_cast( 349 | pool_options.fused_activation_function)); 350 | num_params = 6; 351 | break; 352 | } 353 | 354 | case BuiltinOperator::CONV_2D: { 355 | check(BuiltinOptionsType::Conv2DOptions); 356 | const Conv2DOptions& conv_options = static_cast( 357 | op.builtin_op()); 358 | 359 | ss << AddScalarInt32(static_cast(conv_options.padding)); 360 | ss << AddScalarInt32(conv_options.stride_w); 361 | ss << AddScalarInt32(conv_options.stride_h); 362 | ss << AddScalarInt32(static_cast( 363 | conv_options.fused_activation_function)); 364 | num_params = 4; 365 | break; 366 | } 367 | 368 | case BuiltinOperator::DEPTHWISE_CONV_2D: { 369 | check(BuiltinOptionsType::DepthwiseConv2DOptions); 370 | const DepthwiseConv2DOptions& dept_conv_options = 371 | static_cast(op.builtin_op()); 372 | 373 | ss << 
AddScalarInt32(static_cast(dept_conv_options.padding)); 374 | ss << AddScalarInt32(dept_conv_options.stride_w); 375 | ss << AddScalarInt32(dept_conv_options.stride_h); 376 | ss << AddScalarInt32(dept_conv_options.depth_multiplier); 377 | ss << AddScalarInt32(static_cast( 378 | dept_conv_options.fused_activation_function)); 379 | num_params = 5; 380 | break; 381 | } 382 | 383 | case BuiltinOperator::FULLY_CONNECTED: { 384 | check(BuiltinOptionsType::FullyConnectedOptions); 385 | const FullyConnectedOptions& fully_con_options = 386 | static_cast(op.builtin_op()); 387 | 388 | ss << AddScalarInt32(static_cast( 389 | fully_con_options.fused_activation_function)); 390 | num_params = 1; 391 | break; 392 | } 393 | 394 | case BuiltinOperator::CONCATENATION: { 395 | check(BuiltinOptionsType::ConcatenationOptions); 396 | const ConcatenationOptions& concat_options = 397 | static_cast(op.builtin_op()); 398 | 399 | ss << AddScalarInt32(concat_options.axis); 400 | ss << AddScalarInt32(static_cast( 401 | concat_options.fused_activation_function)); 402 | num_params = 2; 403 | break; 404 | } 405 | 406 | case BuiltinOperator::SOFTMAX: { 407 | check(BuiltinOptionsType::SoftmaxOptions); 408 | const SoftmaxOptions& softmax_options = 409 | static_cast(op.builtin_op()); 410 | 411 | ss << AddScalarFloat32(softmax_options.beta); 412 | num_params = 1; 413 | break; 414 | } 415 | 416 | case BuiltinOperator::SPACE_TO_DEPTH: { 417 | check(BuiltinOptionsType::SpaceToDepthOptions); 418 | const SpaceToDepthOptions& space2depth_options = 419 | static_cast(op.builtin_op()); 420 | 421 | ss << AddScalarInt32(space2depth_options.block_size); 422 | num_params = 1; 423 | break; 424 | } 425 | 426 | case BuiltinOperator::LSTM: { 427 | check(BuiltinOptionsType::LSTMOptions); 428 | // TODO: Check better on TfLite how lstm parametes is filled 429 | const LSTMOptions& lstm_options = static_cast( 430 | op.builtin_op()); 431 | 432 | ss << AddScalarInt32(static_cast( 433 | lstm_options.fused_activation_function)); 434 | ss << AddScalarInt32(lstm_options.cell_clip); 435 | ss << AddScalarInt32(lstm_options.proj_clip); 436 | num_params = 3; 437 | break; 438 | } 439 | 440 | default: 441 | num_params = 0; 442 | } 443 | 444 | return std::tuple(num_params, ss.str()); 445 | } 446 | 447 | std::string ModelGen::GenerateOpCode() { 448 | Graph& graph = model_.graph(); 449 | std::stringstream ss; 450 | 451 | int count = 0; 452 | for (const auto& op: graph.Operators()) { 453 | size_t num_params; 454 | std::string str_params; 455 | std::tie(num_params, str_params) = OpParams(op); 456 | ss << str_params << "\n"; 457 | ss << "uint32_t input_operands_" << count << "[] = { "; 458 | ss << GenerateOpInputs(op.inputs(), num_params) << " };\n"; 459 | 460 | ss << "uint32_t output_operands_" << count << "[] = {"; 461 | ss << GenerateOpOutputs(op.outputs()) << " };\n\n"; 462 | 463 | ss << "status = ANeuralNetworksModel_addOperation(model, "; 464 | ss << OpTypeStr(op.op_code().builtin_code) << ", sizeof(input_operands_" ; 465 | ss << count <<"), input_operands_" << count << ", "; 466 | ss << "sizeof(output_operands_" << count << "), "; 467 | ss << "output_operands_" << count << ");\n"; 468 | 469 | ss << CheckStatus(boost::format( 470 | "ANeuralNetworksModel_addOperation failed for operation %1%")%count); 471 | 472 | ++count; 473 | } 474 | 475 | return ss.str(); 476 | } 477 | 478 | std::string ModelGen::GenerateInputsAndOutputs() { 479 | Graph& graph = model_.graph(); 480 | std::stringstream ss; 481 | 482 | size_t num_inputs = graph.Inputs().size();; 483 | ss << 
"uint32_t input_indexes[" << num_inputs << "] = {"; 484 | 485 | std::string str_input; 486 | for (int i : graph.Inputs()) { 487 | str_input += " " + std::to_string(i) + ","; 488 | } 489 | 490 | str_input = str_input.substr(0, str_input.length() - 1); 491 | ss << str_input << " };\n"; 492 | 493 | size_t num_outputs = graph.Outputs().size(); 494 | ss << "uint32_t output_indexes[" << num_outputs << "] = {"; 495 | 496 | std::string str_output; 497 | for (int i : graph.Outputs()) { 498 | str_output += " " + std::to_string(i) + ","; 499 | } 500 | 501 | str_output = str_output.substr(0, str_output.length() - 1); 502 | ss << str_output << " };\n"; 503 | 504 | ss << "ANeuralNetworksModel_identifyInputsAndOutputs(model, " 505 | << num_inputs << ", input_indexes, " << num_outputs 506 | << ", output_indexes);\n"; 507 | 508 | return ss.str(); 509 | } 510 | 511 | std::string TensorType(const Tensor& tensor) { 512 | switch (tensor.tensor_type()) { 513 | case TensorType::FLOAT32: 514 | return "float32_t"; 515 | break; 516 | 517 | case TensorType::INT32: 518 | return "int32_t"; 519 | break; 520 | 521 | case TensorType::UINT8: 522 | return "int8_t"; 523 | break; 524 | 525 | default: 526 | FATAL("Tensor type not valid for Android NNAPI") 527 | } 528 | } 529 | 530 | int ModelGen::TensorSize(const Tensor& tensor) { 531 | int size = 1; 532 | for (int shape_i : tensor.shape()) { 533 | size *= shape_i; 534 | 535 | if (tensor.tensor_type() == TensorType::FLOAT32 || 536 | tensor.tensor_type() == TensorType::INT32) { 537 | size *= 4; 538 | } 539 | } 540 | 541 | return size; 542 | } 543 | 544 | std::string ModelGen::GenerateInputFunctions() { 545 | Graph& graph = model_.graph(); 546 | std::string str_input; 547 | 548 | str_input += "bool SetInput(const int8_t *buffer) {\n"; 549 | 550 | int start = 0; 551 | for (int i : graph.Inputs()) { 552 | const Tensor& tensor = graph.Tensors()[i]; 553 | int size = TensorSize(tensor); 554 | 555 | str_input += " int status = ANeuralNetworksExecution_setInput(run, " + 556 | std::to_string(i) + ", NULL, &buffer[" + std::to_string(start) + 557 | "], " + std::to_string(size) + ");\n"; 558 | 559 | str_input += CheckStatus(boost::format( 560 | "ANeuralNetworksExecution_setInput failed")); 561 | 562 | start += size; 563 | } 564 | 565 | str_input += " return true;\n}\n\n"; 566 | 567 | return str_input; 568 | } 569 | 570 | std::string ModelGen::GenerateOutputFunctions() { 571 | Graph& graph = model_.graph(); 572 | std::string str_output; 573 | 574 | str_output += "bool SetOutput(int8_t *buffer) {\n"; 575 | 576 | int start = 0; 577 | for (int i : graph.Outputs()) { 578 | const Tensor& tensor = graph.Tensors()[i]; 579 | int size = TensorSize(tensor); 580 | 581 | str_output += " int status = ANeuralNetworksExecution_setOutput(run, " + 582 | std::to_string(i) + ", NULL, &buffer[" + std::to_string(start) + 583 | "], " + std::to_string(size) + ");\n"; 584 | 585 | str_output += CheckStatus(boost::format( 586 | "ANeuralNetworksExecution_setOutput failed")); 587 | 588 | start += size; 589 | } 590 | 591 | str_output += " return true;\n}\n\n"; 592 | 593 | return str_output; 594 | } 595 | 596 | std::string ModelGen::GenerateHeader() { 597 | std::string str = 598 | #include "templates/top_nn_cc.tpl" 599 | ; 600 | return str; 601 | } 602 | 603 | std::string ModelGen::Assembler() { 604 | std::string code; 605 | code = GenerateHeader(); 606 | code += GenerateTensorsCode(); 607 | code += GenerateOpCode(); 608 | code += GenerateInputsAndOutputs(); 609 | 610 | // close model function 611 | code += "return 
true;\n}\n\n"; 612 | 613 | code += GenerateInputFunctions(); 614 | code += GenerateOutputFunctions(); 615 | 616 | // close namespace 617 | code += "\n}\n\n"; 618 | 619 | return code; 620 | } 621 | 622 | std::string ModelGenHeader::GenerateHeader() { 623 | std::string str = 624 | #include "templates/top_nn_h.tpl" 625 | ; 626 | return str; 627 | } 628 | 629 | std::string ModelGenHeader::Assembler() { 630 | std::string str = GenerateHeader(); 631 | str += "}"; 632 | 633 | return str; 634 | } 635 | 636 | template 637 | int ModelGenJni::TotalSize(Fn&& fn) { 638 | Graph& graph = model_.graph(); 639 | int total_size = 0; 640 | 641 | for (int i : fn()) { 642 | int size = 1; 643 | const Tensor& tensor = graph.Tensors()[i]; 644 | 645 | for (int shape_i : tensor.shape()) { 646 | size *= shape_i; 647 | 648 | if (tensor.tensor_type() == TensorType::FLOAT32 || 649 | tensor.tensor_type() == TensorType::INT32) { 650 | size *= 4; 651 | } 652 | } 653 | total_size += size; 654 | } 655 | 656 | return total_size; 657 | } 658 | 659 | std::string ModelGenJni::GenerateJni() { 660 | std::string str = 661 | #include "templates/jni.tpl" 662 | ; 663 | 664 | Graph& graph = model_.graph(); 665 | 666 | auto fn_in = std::bind(&Graph::Inputs, &graph); 667 | int total_input_size = TotalSize(fn_in); 668 | 669 | auto fn_out = std::bind(&Graph::Outputs, &graph); 670 | int total_output_size = TotalSize(fn_out); 671 | 672 | boost::replace_all(str, "@TOTAL_INPUT_SIZE", 673 | std::to_string(total_input_size)); 674 | boost::replace_all(str, "@TOTAL_OUTPUT_SIZE", 675 | std::to_string(total_output_size)); 676 | boost::replace_all(str, "@JAVA_PACKAGE", java_package_); 677 | 678 | return str; 679 | } 680 | 681 | std::string ModelGenJni::Assembler() { 682 | boost::replace_all(java_package_, ".", "_"); 683 | std::string str = GenerateJni(); 684 | return str; 685 | } 686 | 687 | void CppGen::GenFiles(const boost::filesystem::path& path, 688 | const std::string& java_path) { 689 | GenTensorsDataFile(path); 690 | GenCppFile(path); 691 | GenHFile(path); 692 | GenJniFile(path, java_path); 693 | } 694 | 695 | void CppGen::GenTensorsDataFile(const boost::filesystem::path& path) { 696 | const boost::filesystem::path& fname("weights_biases.bin"); 697 | std::string str_path = (path / fname).string(); 698 | std::ofstream tensors_file(str_path, 699 | std::ofstream::out | std::ofstream::binary); 700 | 701 | if (!tensors_file.is_open()) { 702 | FATAL(boost::format("Fail on create weights_biases.bin file on: %1%") 703 | %str_path) 704 | } 705 | 706 | TensorsHeader tensor_header(model_); 707 | std::string buf = tensor_header.Assembler(); 708 | tensors_file.write(buf.c_str(), buf.length()); 709 | tensors_file.close(); 710 | 711 | std::cout << "File: " << str_path << " generated\n"; 712 | } 713 | 714 | void CppGen::GenCppFile(const boost::filesystem::path& path) { 715 | const boost::filesystem::path& fname("nn.cc"); 716 | std::string str_path = (path / fname).string(); 717 | std::ofstream cc_file(str_path, std::ofstream::out | std::ofstream::binary); 718 | 719 | if (!cc_file.is_open()) { 720 | FATAL("Fail on create nn.cc file") 721 | } 722 | 723 | ModelGen model(model_); 724 | std::string code = model.Assembler(); 725 | cc_file.write(code.c_str(), code.length()); 726 | cc_file.close(); 727 | 728 | std::cout << "File: " << str_path << " generated\n"; 729 | } 730 | 731 | void CppGen::GenHFile(const boost::filesystem::path& path) { 732 | const boost::filesystem::path& fname("nn.h"); 733 | std::string str_path = (path / fname).string(); 734 | 735 | 
std::ofstream cc_file(str_path, std::ofstream::out | std::ofstream::binary); 736 | 737 | if (!cc_file.is_open()) { 738 | FATAL("Fail on create nn.h file") 739 | } 740 | 741 | ModelGenHeader model(model_); 742 | std::string code = model.Assembler(); 743 | cc_file.write(code.c_str(), code.length()); 744 | cc_file.close(); 745 | 746 | std::cout << "File: " << str_path << " generated\n"; 747 | } 748 | 749 | void CppGen::GenJniFile(const boost::filesystem::path& path, 750 | const std::string& java_package) { 751 | const boost::filesystem::path& fname("jni.cc"); 752 | std::string str_path = (path / fname).string(); 753 | std::ofstream jni_file(str_path, std::ofstream::out | std::ofstream::binary); 754 | 755 | if (!jni_file.is_open()) { 756 | FATAL("Fail on create nn.h file") 757 | } 758 | 759 | ModelGenJni model(model_, java_package); 760 | std::string code = model.Assembler(); 761 | jni_file.write(code.c_str(), code.length()); 762 | jni_file.close(); 763 | 764 | std::cout << "File: " << str_path << " generated\n"; 765 | } 766 | 767 | } 768 | -------------------------------------------------------------------------------- /src/cpp-gen.h: -------------------------------------------------------------------------------- 1 | #ifndef nnt_CCP_GEN_H 2 | #define nnt_CCP_GEN_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "model.h" 11 | 12 | namespace nnt { 13 | 14 | class TensorsHeader { 15 | public: 16 | TensorsHeader(Model& model): model_(model) {} 17 | 18 | std::string Assembler(); 19 | 20 | private: 21 | std::string Generate(); 22 | 23 | Model& model_; 24 | }; 25 | 26 | class ModelGen { 27 | public: 28 | ModelGen(Model& model): model_(model), tensor_pos_(0) {} 29 | 30 | std::string Assembler(); 31 | 32 | private: 33 | std::string Generate(); 34 | std::string GenerateTensorType(const Tensor& tensor, int count); 35 | std::string GenerateTensorsCode(); 36 | std::string TensorTypeStr(TensorType type); 37 | std::string TensorCppTypeStr(TensorType type); 38 | std::string TensorDim(const std::vector& dim); 39 | float TensorQuantizationScale(const QuantizationParameters& q); 40 | int TensorQuantizationZeroPoint(const QuantizationParameters& q); 41 | std::string CheckStatus(const boost::format& msg); 42 | 43 | std::string GenerateOpCode(); 44 | std::string GenerateOpInputs(const std::vector& inputs, 45 | size_t num_params); 46 | std::string GenerateOpOutputs(const std::vector& outputs); 47 | std::string OpTypeStr(BuiltinOperator op_type); 48 | std::tuple OpParams(const Operator& op); 49 | std::string GenerateInputsAndOutputs(); 50 | std::string GenerateInputFunctions(); 51 | std::string GenerateOutputFunctions(); 52 | std::string GenerateHeader(); 53 | std::string AddScalarInt32(int value); 54 | std::string AddScalarFloat32(float value); 55 | int TensorSize(const Tensor& tensor); 56 | 57 | Model& model_; 58 | size_t tensor_pos_; 59 | int count_operands_; 60 | }; 61 | 62 | class ModelGenHeader { 63 | public: 64 | ModelGenHeader(Model& model): model_(model) {} 65 | 66 | std::string Assembler(); 67 | private: 68 | std::string GenerateHeader(); 69 | Model& model_; 70 | }; 71 | 72 | class ModelGenJni { 73 | public: 74 | ModelGenJni(Model& model, const std::string& java_package) 75 | : model_(model) 76 | , java_package_(java_package) {} 77 | 78 | std::string Assembler(); 79 | private: 80 | std::string GenerateJni(); 81 | 82 | template 83 | int TotalSize(Fn&& fn); 84 | 85 | Model& model_; 86 | std::string java_package_; 87 | }; 88 | 89 | class CppGen { 90 | public: 91 | 
CppGen(Model& model): model_(model) {} 92 | 93 | void GenFiles(const boost::filesystem::path& path, 94 | const std::string& java_path); 95 | 96 | private: 97 | void GenTensorsDataFile(const boost::filesystem::path& path); 98 | void GenCppFile(const boost::filesystem::path& path); 99 | void GenHFile(const boost::filesystem::path& path); 100 | void GenJniFile(const boost::filesystem::path& path, 101 | const std::string& java_package); 102 | 103 | Model& model_; 104 | }; 105 | 106 | } 107 | 108 | #endif // nnt_CCP_GEN_H 109 | -------------------------------------------------------------------------------- /src/dump.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "dump.h" 5 | 6 | namespace nnt { 7 | 8 | void DumpGraph::Print() { 9 | Graph& graph = model_.graph(); 10 | 11 | std::cout << "Inputs:"; 12 | for (const auto& i : graph.Inputs()) { 13 | std::cout << " " << i; 14 | } 15 | std::cout << "\n"; 16 | 17 | std::cout << "Outputs:"; 18 | for (const auto& i : graph.Outputs()) { 19 | std::cout << " " << i; 20 | } 21 | std::cout << "\n"; 22 | 23 | std::cout << "\nTensors:\n"; 24 | int count = 0; 25 | for (const auto& tensor: graph.Tensors()) { 26 | std::cout << "[" << count++ << "] "; 27 | std::cout << "name: " << tensor.name() << " [ "; 28 | for (const auto&i : tensor.shape()) { 29 | std::cout << i << " "; 30 | } 31 | std::cout << "] "; 32 | std::cout << "buffer: " << tensor.buffer_index() << "\n"; 33 | } 34 | 35 | std::cout << "\nOperators:\n"; 36 | for (const auto& op: graph.Operators()) { 37 | std::cout << "index: " << op.index() << ", "; 38 | std::cout << "builtin_op: " << op.builtin_op_str() << ", "; 39 | 40 | std::cout << "inputs:"; 41 | for (const auto& i : op.inputs()) { 42 | std::cout << " " << i; 43 | } 44 | std::cout << ", "; 45 | 46 | std::cout << "outputs:"; 47 | for (const auto& i : op.outputs()) { 48 | std::cout << " " << i; 49 | } 50 | std::cout << "\n"; 51 | } 52 | } 53 | 54 | std::string DumpGraph::TensorType(const Tensor& tensor) { 55 | switch (tensor.tensor_type()) { 56 | case TensorType::FLOAT32: 57 | return std::string("FLOAT32"); 58 | break; 59 | 60 | case TensorType::FLOAT16: 61 | return std::string("FLOAT16"); 62 | break; 63 | 64 | case TensorType::INT32: 65 | return std::string("INT32"); 66 | break; 67 | 68 | case TensorType::UINT8: 69 | return std::string("UINT8"); 70 | break; 71 | 72 | case TensorType::INT64: 73 | return std::string("INT64"); 74 | break; 75 | 76 | case TensorType::STRING: 77 | return std::string("STRING"); 78 | break; 79 | #ifdef NEWER_TENSORFLOW 80 | case TensorType::BOOL: 81 | return std::string("BOOL"); 82 | break; 83 | #endif 84 | } 85 | 86 | return std::string(); 87 | } 88 | 89 | template 90 | std::string VectorToStr(const std::vector& vec) { 91 | std::stringstream ss; 92 | 93 | ss << "["; 94 | for (const auto&i : vec) { 95 | ss << i << ", "; 96 | } 97 | 98 | std::string str = ss.str(); 99 | str = str.substr(0, str.length() - 2); 100 | str += "]"; 101 | return str; 102 | } 103 | 104 | std::string DumpGraph::Info() { 105 | std::stringstream ss; 106 | 107 | Graph& graph = model_.graph(); 108 | const auto& tensors = graph.Tensors(); 109 | 110 | ss << "::Inputs::\n"; 111 | for (const auto& i : graph.Inputs()) { 112 | ss << " " << tensors[i].name() << "<" << TensorType(tensors[i]) << ">" 113 | << " [" << TensorShape(tensors[i]) << "]"; 114 | 115 | if (tensors[i].HasQuantization()) { 116 | ss << " (quantized)\n"; 117 | const QuantizationParameters& quant = 
tensors[i].quantization(); 118 | ss << " └─ Quant: {min:" << VectorToStr(quant.min) << ", max:" 119 | << VectorToStr(quant.max) << ", scale: " << VectorToStr(quant.scale) 120 | << ", zero_point:" << VectorToStr(quant.zero_point) << "}\n"; 121 | } else { 122 | ss << "\n"; 123 | } 124 | } 125 | 126 | ss << "\n"; 127 | ss << "::Outputs::\n"; 128 | for (const auto& i : graph.Outputs()) { 129 | ss << " " << tensors[i].name() << "<" << TensorType(tensors[i]) << ">" 130 | << " [" << TensorShape(tensors[i]) << "]"; 131 | 132 | if (tensors[i].HasQuantization()) { 133 | ss << " (quantized)\n"; 134 | } else { 135 | ss << "\n"; 136 | } 137 | } 138 | 139 | ss << "\n"; 140 | return ss.str(); 141 | } 142 | 143 | std::string DumpGraph::FormatTensorName(const std::string& name) { 144 | size_t pos = name.find_last_of('/'); 145 | 146 | if (pos != std::string::npos) { 147 | return name.substr(pos); 148 | } 149 | 150 | return name; 151 | } 152 | 153 | std::string DumpGraph::TensorShape(const Tensor& tensor) { 154 | std::stringstream ss; 155 | 156 | for (const auto&i : tensor.shape()) { 157 | ss << i << ", "; 158 | } 159 | 160 | std::string str = ss.str(); 161 | str = str.substr(0, str.length() - 2); 162 | return str; 163 | } 164 | 165 | std::string DumpGraph::Dot() { 166 | std::stringstream ss; 167 | 168 | Graph& graph = model_.graph(); 169 | 170 | ss << "digraph {\n"; 171 | 172 | int count = 0; 173 | for (const auto& tensor: graph.Tensors()) { 174 | ss << " T" << count++ << " ["; 175 | ss << "shape=box label=\"" << FormatTensorName(tensor.name()); 176 | ss << " [" << TensorShape(tensor) << "]\"]\n"; 177 | } 178 | 179 | count = 0; 180 | for (const auto& op: graph.Operators()) { 181 | ss << " O" << count++ << " ["; 182 | ss << "label=\"" << op.builtin_op_str() << "\"]\n"; 183 | } 184 | 185 | count = 0; 186 | for (const auto& op: graph.Operators()) { 187 | for (const auto& i : op.inputs()) { 188 | ss << " T" << i << " -> " << "O" << count << "\n"; 189 | } 190 | 191 | for (const auto& i : op.outputs()) { 192 | ss << " O" << count << " -> " << "T" << i << "\n"; 193 | } 194 | 195 | ++count; 196 | } 197 | 198 | ss << "}\n"; 199 | 200 | return ss.str(); 201 | } 202 | 203 | } 204 | -------------------------------------------------------------------------------- /src/dump.h: -------------------------------------------------------------------------------- 1 | #ifndef NNT_DUMP_H 2 | #define NNT_DUMP_H 3 | 4 | #include 5 | #include 6 | 7 | #include "model.h" 8 | 9 | namespace nnt { 10 | 11 | class DumpGraph { 12 | public: 13 | DumpGraph(Model& model): model_(model) {} 14 | 15 | void Print(); 16 | 17 | std::string TensorShape(const Tensor& tensor); 18 | 19 | std::string Dot(); 20 | 21 | std::string Info(); 22 | 23 | std::string TensorType(const Tensor& tensor); 24 | 25 | private: 26 | std::string FormatTensorName(const std::string& name); 27 | 28 | Model& model_; 29 | }; 30 | 31 | } 32 | 33 | #endif 34 | -------------------------------------------------------------------------------- /src/exception.h: -------------------------------------------------------------------------------- 1 | #ifndef NNC_EXCEPTION_H 2 | #define NNC_EXCEPTION_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | namespace nnt { 9 | 10 | class Exception : public std::exception { 11 | public: 12 | Exception(const boost::format& msg) 13 | : msg_(boost::str(msg)) {} 14 | 15 | Exception(const std::string& msg) 16 | : msg_(msg) {} 17 | 18 | virtual ~Exception() noexcept = default; 19 | 20 | Exception(const Exception& rt_err) 21 | : msg_(rt_err.msg_) {} 22 
| 23 | Exception& operator=(const Exception& rt_err) { 24 | msg_ = rt_err.msg_; 25 | 26 | return *this; 27 | } 28 | 29 | /** 30 | * @return the error description and the context as a text string. 31 | */ 32 | virtual const char* what() const noexcept { 33 | return msg_.c_str(); 34 | } 35 | 36 | const std::string& msg() const noexcept { 37 | return msg_; 38 | } 39 | 40 | std::string msg_; 41 | }; 42 | 43 | #define FATAL(msg_arg) \ 44 | throw Exception(msg_arg); 45 | 46 | } // nnt 47 | 48 | #endif // NNC_EXCEPTION_H 49 | -------------------------------------------------------------------------------- /src/main.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "model.h" 5 | #include "cpp-gen.h" 6 | #include "dump.h" 7 | #include "exception.h" 8 | 9 | void GenerateJniFiles(const std::string& str_model, const std::string& str_path, 10 | const std::string& java_package) { 11 | nnt::Model model(str_model); 12 | nnt::CppGen cpp(model); 13 | boost::filesystem::path path(str_path); 14 | cpp.GenFiles(path, java_package); 15 | std::cout << "Finish!\n"; 16 | } 17 | 18 | void GenerateDotFile(const std::string& filename, 19 | const std::string& str_model) { 20 | nnt::Model model(str_model); 21 | nnt::DumpGraph dump(model); 22 | 23 | std::ofstream dot_file(filename, std::ofstream::out); 24 | 25 | if (!dot_file.is_open()) { 26 | std::cerr << "Fail on create dot file: '" << filename << "'\n"; 27 | return; 28 | } 29 | 30 | std::string dot_src = dump.Dot(); 31 | dot_file.write(dot_src.c_str(), dot_src.length()); 32 | dot_file.close(); 33 | 34 | std::cout << "Dot file: '" << filename << "' generated.\n"; 35 | } 36 | 37 | void Info(const std::string& str_model) { 38 | nnt::Model model(str_model); 39 | nnt::DumpGraph dump(model); 40 | std::cout << dump.Info(); 41 | } 42 | 43 | int main(int argc, char **argv) { 44 | namespace po = boost::program_options; 45 | std::string str_path; 46 | std::string java_package; 47 | std::string str_model; 48 | std::string str_dot; 49 | bool flag_info; 50 | 51 | try { 52 | po::options_description desc{"Options"}; 53 | desc.add_options() 54 | ("help,h", "Help screen") 55 | ("info,i", po::bool_switch(&flag_info), "Info about model") 56 | ("dot,d", po::value(), "Generate dot file") 57 | ("model,m", po::value(), "flatbuffer neural network model") 58 | ("path,p", po::value(), "store generated files on this path") 59 | ("javapackage,j", po::value(), "java package for JNI"); 60 | 61 | po::variables_map vm; 62 | po::store(parse_command_line(argc, argv, desc), vm); 63 | po::notify(vm); 64 | 65 | if (vm.count("help")) { 66 | std::cout << desc << '\n'; 67 | return 0; 68 | } 69 | 70 | if (!vm.count("model")) { 71 | std::cerr << "--model must not be empty" << '\n'; 72 | std::cerr << desc << '\n'; 73 | return 0; 74 | } 75 | 76 | str_model = vm["model"].as(); 77 | 78 | if (flag_info) { 79 | Info(str_model); 80 | return 0; 81 | } 82 | 83 | if (vm.count("dot")) { 84 | str_dot = vm["dot"].as(); 85 | GenerateDotFile(str_dot, str_model); 86 | return 0; 87 | } 88 | 89 | if (vm.count("path")) { 90 | str_path = vm["path"].as(); 91 | } else { 92 | str_path = "./"; 93 | } 94 | 95 | if (!vm.count("javapackage")) { 96 | std::cerr << "--javapackage must not be empty" << '\n'; 97 | std::cerr << desc << '\n'; 98 | return 0; 99 | } 100 | 101 | java_package = vm["javapackage"].as(); 102 | 103 | GenerateJniFiles(str_model, str_path, java_package); 104 | } catch (const boost::program_options::error &e) { 105 | std::cerr << "Error: " << 
e.what() << '\n'; 106 | } catch (const nnt::Exception& e) { 107 | std::cerr << "Error: " << e.what() << '\n'; 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /src/model.cc: -------------------------------------------------------------------------------- 1 | #include "model.h" 2 | 3 | #include 4 | #include 5 | 6 | #include "exception.h" 7 | 8 | namespace nnt { 9 | 10 | FlatBufferModel::FlatBufferModel(const std::string& fname) { 11 | FILE* file = fopen(fname.c_str(), "rb"); 12 | fseek(file, 0L, SEEK_END); 13 | int length = ftell(file); 14 | fseek(file, 0L, SEEK_SET); 15 | data_ = new char[length]; 16 | fread(data_, sizeof(char), length, file); 17 | fclose(file); 18 | len_ = length; 19 | } 20 | 21 | FlatBufferModel::~FlatBufferModel() { 22 | delete data_; 23 | } 24 | 25 | char* FlatBufferModel::data() { 26 | return data_; 27 | } 28 | 29 | int FlatBufferModel::Length() const { 30 | return len_; 31 | } 32 | 33 | Model::Model(const std::string& fname) 34 | : flat_buffers_(fname) 35 | , fb_model_(tflite::GetModel(flat_buffers_.data())) { 36 | PopulateBuffers(); 37 | PopulateOperatorsCode(); 38 | PopulateGraph(); 39 | } 40 | 41 | void Model::PopulateGraphInputs(const tflite::SubGraph* graph) { 42 | std::vector inputs = AssignVector(graph->inputs()); 43 | graph_.SetInputs(std::move(inputs)); 44 | } 45 | 46 | void Model::PopulateGraphOutputs(const tflite::SubGraph* graph) { 47 | std::vector outputs = AssignVector(graph->outputs()); 48 | graph_.SetOutputs(std::move(outputs)); 49 | } 50 | 51 | TensorType Model::ConvertTensorType(tflite::TensorType type) { 52 | switch (type) { 53 | case tflite::TensorType_FLOAT32: 54 | return TensorType::FLOAT32; 55 | break; 56 | 57 | case tflite::TensorType_FLOAT16: 58 | return TensorType::FLOAT16; 59 | break; 60 | 61 | case tflite::TensorType_INT32: 62 | return TensorType::INT32; 63 | break; 64 | 65 | case tflite::TensorType_UINT8: 66 | return TensorType::UINT8; 67 | break; 68 | 69 | case tflite::TensorType_INT64: 70 | return TensorType::INT64; 71 | break; 72 | 73 | case tflite::TensorType_STRING: 74 | return TensorType::STRING; 75 | break; 76 | #ifdef NEWER_TENSORFLOW 77 | case tflite::TensorType_BOOL: 78 | return TensorType::BOOL; 79 | break; 80 | #endif 81 | default: 82 | FATAL("Tensor type not valid") 83 | } 84 | } 85 | 86 | void Model::PopulateGraphTensors(const tflite::SubGraph* graph) { 87 | auto tensors = graph->tensors(); 88 | 89 | // get tensors 90 | for (auto it = tensors->begin(); it != tensors->end(); ++it) { 91 | std::vector vec_shape = AssignVector(it->shape()); 92 | std::string name = it->name()->c_str(); 93 | uint buf_index = it->buffer(); 94 | const Buffer& buffer = buffers_[buf_index]; 95 | 96 | // get quantization 97 | auto quantization = it->quantization(); 98 | std::unique_ptr quantization_ptr( 99 | new QuantizationParameters); 100 | 101 | if (quantization) { 102 | quantization_ptr->min = AssignVector(quantization->min()); 103 | quantization_ptr->max = AssignVector(quantization->max()); 104 | quantization_ptr->scale = AssignVector(quantization->scale()); 105 | quantization_ptr->zero_point = 106 | AssignVector(quantization->zero_point()); 107 | } 108 | 109 | TensorType type = ConvertTensorType(it->type()); 110 | graph_.AddTensor(std::move(Tensor(std::move(vec_shape), type, name, buffer, 111 | buf_index, std::move(quantization_ptr)))); 112 | } 113 | } 114 | 115 | ActivationFunctionType Model::ConvertActivationFunction( 116 | tflite::ActivationFunctionType fn_activation_type) { 117 | switch 
(fn_activation_type) { 118 | case tflite::ActivationFunctionType_NONE: 119 | return ActivationFunctionType::NONE; 120 | break; 121 | 122 | case tflite::ActivationFunctionType_RELU: 123 | return ActivationFunctionType::RELU; 124 | break; 125 | 126 | case tflite::ActivationFunctionType_RELU_N1_TO_1: 127 | return ActivationFunctionType::NONE; 128 | break; 129 | 130 | case tflite::ActivationFunctionType_RELU6: 131 | return ActivationFunctionType::RELU6; 132 | break; 133 | 134 | case tflite::ActivationFunctionType_TANH: 135 | return ActivationFunctionType::TANH; 136 | break; 137 | 138 | case tflite::ActivationFunctionType_SIGN_BIT: 139 | return ActivationFunctionType::SIGN_BIT; 140 | break; 141 | 142 | default: 143 | return ActivationFunctionType::NONE; 144 | } 145 | } 146 | 147 | Padding Model::ConvertPadding(tflite::Padding padding) { 148 | if (padding == tflite::Padding_SAME) { 149 | return Padding::SAME; 150 | } else if (padding == tflite::Padding_VALID) { 151 | return Padding::VALID; 152 | } else { 153 | return Padding::UNKNOWN; 154 | } 155 | } 156 | 157 | std::unique_ptr Model::MakeNoneOptions( 158 | const tflite::Operator* /*op*/) { 159 | std::unique_ptr option = std::make_unique(); 160 | return option; 161 | } 162 | 163 | std::unique_ptr Model::MakeConv2DOptions( 164 | const tflite::Operator* op) { 165 | auto p = reinterpret_cast( 166 | op->builtin_options()); 167 | 168 | std::unique_ptr option = std::make_unique(); 169 | 170 | option->stride_w = p->stride_w(); 171 | option->stride_h = p->stride_h(); 172 | option->fused_activation_function = ConvertActivationFunction( 173 | p->fused_activation_function()); 174 | option->padding = ConvertPadding(p->padding()); 175 | #ifdef NEWER_TENSORFLOW 176 | option->dilation_w_factor = p->dilation_w_factor(); 177 | option->dilation_h_factor = p->dilation_h_factor(); 178 | #endif 179 | return option; 180 | } 181 | 182 | std::unique_ptr Model::MakePool2DOptions( 183 | const tflite::Operator* op) { 184 | auto p = reinterpret_cast( 185 | op->builtin_options()); 186 | 187 | std::unique_ptr option = std::make_unique(); 188 | 189 | option->stride_w = p->stride_w(); 190 | option->stride_h = p->stride_h(); 191 | option->filter_width = p->filter_width(); 192 | option->filter_height = p->filter_height(); 193 | option->fused_activation_function = ConvertActivationFunction( 194 | p->fused_activation_function()); 195 | option->padding = ConvertPadding(p->padding()); 196 | 197 | return option; 198 | } 199 | 200 | std::unique_ptr Model::MakeDepthwiseConv2DOptions( 201 | const tflite::Operator* op) { 202 | auto p = reinterpret_cast( 203 | op->builtin_options()); 204 | 205 | std::unique_ptr option = 206 | std::make_unique(); 207 | 208 | option->stride_w = p->stride_w(); 209 | option->stride_h = p->stride_h(); 210 | option->depth_multiplier = p->depth_multiplier(); 211 | option->fused_activation_function = ConvertActivationFunction( 212 | p->fused_activation_function()); 213 | option->padding = ConvertPadding(p->padding()); 214 | 215 | return option; 216 | } 217 | 218 | std::unique_ptr Model::MakeConcatEmbeddingsOptions( 219 | const tflite::Operator* op) { 220 | auto p = reinterpret_cast( 221 | op->builtin_options()); 222 | 223 | std::unique_ptr option = 224 | std::make_unique(); 225 | 226 | option->num_channels = p->num_channels(); 227 | 228 | auto num_columns = p->num_columns_per_channel(); 229 | option->num_columns_per_channel = AssignVector(num_columns); 230 | 231 | auto embedding_dim = p->embedding_dim_per_channel(); 232 | option->embedding_dim_per_channel = 
AssignVector(embedding_dim); 233 | 234 | return option; 235 | } 236 | 237 | std::unique_ptr Model::MakeLSHProjectionOptions( 238 | const tflite::Operator* op) { 239 | auto p = reinterpret_cast( 240 | op->builtin_options()); 241 | 242 | std::unique_ptr option = 243 | std::make_unique(); 244 | 245 | switch (p->type()) { 246 | case tflite::LSHProjectionType_UNKNOWN: 247 | option->type = LSHProjectionType::UNKNOWN; 248 | break; 249 | 250 | case tflite::LSHProjectionType_SPARSE: 251 | option->type = LSHProjectionType::SPARSE; 252 | break; 253 | 254 | case tflite::LSHProjectionType_DENSE: 255 | option->type = LSHProjectionType::DENSE; 256 | break; 257 | } 258 | 259 | return option; 260 | } 261 | 262 | std::unique_ptr Model::MakeSVDFOptions( 263 | const tflite::Operator* op) { 264 | auto p = reinterpret_cast(op->builtin_options()); 265 | 266 | std::unique_ptr option = std::make_unique(); 267 | 268 | option->rank = p->rank(); 269 | option->fused_activation_function = ConvertActivationFunction( 270 | p->fused_activation_function()); 271 | 272 | return option; 273 | } 274 | 275 | std::unique_ptr Model::MakeRNNOptions(const tflite::Operator* op) { 276 | auto p = reinterpret_cast(op->builtin_options()); 277 | 278 | std::unique_ptr option = std::make_unique(); 279 | 280 | option->fused_activation_function = ConvertActivationFunction( 281 | p->fused_activation_function()); 282 | 283 | return option; 284 | } 285 | 286 | std::unique_ptr Model::MakeSequenceRNNOptions( 287 | const tflite::Operator* op) { 288 | auto p = reinterpret_cast( 289 | op->builtin_options()); 290 | 291 | std::unique_ptr option = 292 | std::make_unique(); 293 | 294 | option->time_major = p->time_major(); 295 | option->fused_activation_function = ConvertActivationFunction( 296 | p->fused_activation_function()); 297 | 298 | return option; 299 | } 300 | 301 | std::unique_ptr Model::MakeFullyConnectedOptions( 302 | const tflite::Operator* op) { 303 | auto p = reinterpret_cast( 304 | op->builtin_options()); 305 | 306 | std::unique_ptr option = 307 | std::make_unique(); 308 | 309 | option->fused_activation_function = ConvertActivationFunction( 310 | p->fused_activation_function()); 311 | 312 | return option; 313 | } 314 | 315 | std::unique_ptr Model::MakeSoftmaxOptions( 316 | const tflite::Operator* op) { 317 | auto p = reinterpret_cast( 318 | op->builtin_options()); 319 | 320 | std::unique_ptr option = std::make_unique(); 321 | 322 | option->beta = p->beta(); 323 | 324 | return option; 325 | } 326 | 327 | std::unique_ptr Model::MakeConcatenationOptions( 328 | const tflite::Operator* op) { 329 | auto p = reinterpret_cast( 330 | op->builtin_options()); 331 | 332 | std::unique_ptr option = 333 | std::make_unique(); 334 | 335 | option->axis = p->axis(); 336 | option->fused_activation_function = ConvertActivationFunction( 337 | p->fused_activation_function()); 338 | 339 | return option; 340 | } 341 | 342 | std::unique_ptr Model::MakeAddOptions(const tflite::Operator* op) { 343 | auto p = reinterpret_cast(op->builtin_options()); 344 | 345 | std::unique_ptr option = std::make_unique(); 346 | 347 | option->fused_activation_function = ConvertActivationFunction( 348 | p->fused_activation_function()); 349 | 350 | return option; 351 | } 352 | 353 | std::unique_ptr Model::MakeMulOptions(const tflite::Operator* op) { 354 | auto p = reinterpret_cast(op->builtin_options()); 355 | 356 | std::unique_ptr option = std::make_unique(); 357 | 358 | option->fused_activation_function = ConvertActivationFunction( 359 | p->fused_activation_function()); 360 | 361 | 
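// Like AddOptions above, MulOptions only records the fused activation function.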
return option; 362 | } 363 | 364 | std::unique_ptr Model::MakeL2NormOptions( 365 | const tflite::Operator* op) { 366 | auto p = reinterpret_cast( 367 | op->builtin_options()); 368 | 369 | std::unique_ptr option = std::make_unique(); 370 | 371 | option->fused_activation_function = ConvertActivationFunction( 372 | p->fused_activation_function()); 373 | 374 | return option; 375 | } 376 | 377 | std::unique_ptr 378 | Model::MakeLocalResponseNormalizationOptions(const tflite::Operator* op) { 379 | auto p = reinterpret_cast( 380 | op->builtin_options()); 381 | 382 | std::unique_ptr option = 383 | std::make_unique(); 384 | 385 | option->radius = p->radius(); 386 | option->bias = p->bias(); 387 | option->alpha = p->alpha(); 388 | option->beta = p->beta(); 389 | 390 | return option; 391 | } 392 | 393 | std::unique_ptr Model::MakeLSTMOptions( 394 | const tflite::Operator* op) { 395 | auto p = reinterpret_cast(op->builtin_options()); 396 | 397 | std::unique_ptr option = std::make_unique(); 398 | 399 | option->cell_clip = p->cell_clip(); 400 | option->proj_clip = p->proj_clip(); 401 | option->fused_activation_function = ConvertActivationFunction( 402 | p->fused_activation_function()); 403 | 404 | return option; 405 | } 406 | 407 | std::unique_ptr Model::MakeResizeBilinearOptions( 408 | const tflite::Operator* op) { 409 | auto p = reinterpret_cast( 410 | op->builtin_options()); 411 | 412 | std::unique_ptr option = 413 | std::make_unique(); 414 | 415 | option->align_corners = p->align_corners(); 416 | 417 | return option; 418 | } 419 | 420 | std::unique_ptr Model::MakeCallOptions( 421 | const tflite::Operator* op) { 422 | auto p = reinterpret_cast(op->builtin_options()); 423 | 424 | std::unique_ptr option = std::make_unique(); 425 | 426 | option->subgraph = p->subgraph(); 427 | 428 | return option; 429 | } 430 | 431 | std::unique_ptr Model::MakePadOptions(const tflite::Operator*) { 432 | std::unique_ptr option = std::make_unique(); 433 | return option; 434 | } 435 | 436 | std::unique_ptr Model::MakeReshapeOptions( 437 | const tflite::Operator* op) { 438 | auto p = reinterpret_cast( 439 | op->builtin_options()); 440 | 441 | std::unique_ptr option = std::make_unique(); 442 | 443 | option->new_shape = AssignVector(p->new_shape()); 444 | 445 | return option; 446 | } 447 | 448 | std::unique_ptr Model::MakeSpaceToBatchNDOptions( 449 | const tflite::Operator*) { 450 | std::unique_ptr option = 451 | std::make_unique(); 452 | 453 | return option; 454 | } 455 | 456 | std::unique_ptr Model::MakeBatchToSpaceNDOptions( 457 | const tflite::Operator*) { 458 | std::unique_ptr option = 459 | std::make_unique(); 460 | 461 | return option; 462 | } 463 | 464 | std::unique_ptr Model::MakeSkipGramOptions( 465 | const tflite::Operator* op) { 466 | auto p = reinterpret_cast( 467 | op->builtin_options()); 468 | 469 | std::unique_ptr option = std::make_unique(); 470 | 471 | option->ngram_size = p->ngram_size(); 472 | option->max_skip_size = p->max_skip_size(); 473 | option->include_all_ngrams = p->include_all_ngrams(); 474 | 475 | return option; 476 | } 477 | 478 | std::unique_ptr Model::MakeSpaceToDepthOptions( 479 | const tflite::Operator* op) { 480 | auto p = reinterpret_cast( 481 | op->builtin_options()); 482 | 483 | std::unique_ptr option = 484 | std::make_unique(); 485 | 486 | option->block_size = p->block_size(); 487 | 488 | return option; 489 | } 490 | 491 | std::unique_ptr Model::MakeSubOptions(const tflite::Operator* op) { 492 | auto p = reinterpret_cast(op->builtin_options()); 493 | 494 | std::unique_ptr option = 
std::make_unique(); 495 | 496 | option->fused_activation_function = ConvertActivationFunction( 497 | p->fused_activation_function()); 498 | 499 | return option; 500 | } 501 | 502 | std::unique_ptr Model::MakeDivOptions(const tflite::Operator* op) { 503 | auto p = reinterpret_cast(op->builtin_options()); 504 | 505 | std::unique_ptr option = std::make_unique(); 506 | 507 | option->fused_activation_function = ConvertActivationFunction( 508 | p->fused_activation_function()); 509 | 510 | return option; 511 | } 512 | 513 | std::unique_ptr 514 | Model::MakeEmbeddingLookupSparseOptions(const tflite::Operator* op) { 515 | auto p = reinterpret_cast( 516 | op->builtin_options()); 517 | 518 | std::unique_ptr option = 519 | std::make_unique(); 520 | 521 | switch (p->combiner()) { 522 | case tflite::CombinerType_SUM: 523 | option->combiner = CombinerType::SUM; 524 | break; 525 | 526 | case tflite::CombinerType_MEAN: 527 | option->combiner = CombinerType::MEAN; 528 | break; 529 | 530 | case tflite::CombinerType_SQRTN: 531 | option->combiner = CombinerType::SQRTN; 532 | break; 533 | } 534 | 535 | return option; 536 | } 537 | 538 | std::unique_ptr Model::MakeGatherOptions( 539 | const tflite::Operator* op) { 540 | auto p = reinterpret_cast( 541 | op->builtin_options()); 542 | 543 | std::unique_ptr option = std::make_unique(); 544 | 545 | option->axis = p->axis(); 546 | 547 | return option; 548 | } 549 | 550 | std::unique_ptr Model::MakeTransposeOptions( 551 | const tflite::Operator*) { 552 | std::unique_ptr option = 553 | std::make_unique(); 554 | 555 | return option; 556 | } 557 | 558 | std::unique_ptr Model::MakeMeanOptions( 559 | const tflite::Operator* op) { 560 | auto p = reinterpret_cast( 561 | op->builtin_options()); 562 | 563 | std::unique_ptr option = std::make_unique(); 564 | 565 | option->keep_dims = p->keep_dims(); 566 | 567 | return option; 568 | } 569 | 570 | std::unique_ptr Model::MakeSqueezeOptions( 571 | const tflite::Operator* op) { 572 | auto p = reinterpret_cast( 573 | op->builtin_options()); 574 | 575 | std::unique_ptr option = std::make_unique(); 576 | 577 | option->squeeze_dims = AssignVector(p->squeeze_dims()); 578 | 579 | return option; 580 | } 581 | 582 | std::unique_ptr Model::MakeExpOptions( 583 | const tflite::Operator*) { 584 | std::unique_ptr option = std::make_unique(); 585 | 586 | return option; 587 | } 588 | 589 | std::unique_ptr Model::MakeTopKV2Options( 590 | const tflite::Operator*) { 591 | std::unique_ptr option = std::make_unique(); 592 | return option; 593 | } 594 | 595 | std::unique_ptr Model::MakeSplitOptions( 596 | const tflite::Operator* op) { 597 | auto p = reinterpret_cast( 598 | op->builtin_options()); 599 | 600 | std::unique_ptr option = std::make_unique(); 601 | 602 | option->num_splits = p->num_splits(); 603 | 604 | return option; 605 | } 606 | 607 | std::unique_ptr Model::MakeLogSoftmaxOptions( 608 | const tflite::Operator*) { 609 | std::unique_ptr option = 610 | std::make_unique(); 611 | 612 | return option; 613 | } 614 | 615 | std::unique_ptr Model::MakeCastOptions( 616 | const tflite::Operator* op) { 617 | auto p = reinterpret_cast( 618 | op->builtin_options()); 619 | 620 | std::unique_ptr option = std::make_unique(); 621 | 622 | option->in_data_type = ConvertTensorType(p->in_data_type()); 623 | option->out_data_type = ConvertTensorType(p->out_data_type()); 624 | 625 | return option; 626 | } 627 | 628 | std::unique_ptr Model::MakeDequantizeOptions( 629 | const tflite::Operator*) { 630 | std::unique_ptr option = 631 | std::make_unique(); 632 | 633 | 
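// DequantizeOptions carries no fields in the schema, so the bare struct is returned as-is.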
return option; 634 | } 635 | 636 | #ifdef NEWER_TENSORFLOW 637 | std::unique_ptr Model::MakeMaximumMinimumOptions( 638 | const tflite::Operator*) { 639 | std::unique_ptr option = 640 | std::make_unique(); 641 | 642 | return option; 643 | } 644 | 645 | std::unique_ptr Model::MakeArgMaxOptions( 646 | const tflite::Operator* op) { 647 | auto p = reinterpret_cast( 648 | op->builtin_options()); 649 | 650 | std::unique_ptr option = std::make_unique(); 651 | 652 | option->output_type = ConvertTensorType(p->output_type()); 653 | 654 | return option; 655 | } 656 | 657 | std::unique_ptr Model::MakeLessOptions( 658 | const tflite::Operator*) { 659 | std::unique_ptr option = std::make_unique(); 660 | return option; 661 | } 662 | 663 | std::unique_ptr Model::MakeNegOptions( 664 | const tflite::Operator*) { 665 | std::unique_ptr option = std::make_unique(); 666 | return option; 667 | } 668 | #else 669 | std::unique_ptr Model::MakeMaximumOptions( 670 | const tflite::Operator*) { 671 | std::unique_ptr option = std::make_unique(); 672 | 673 | return option; 674 | } 675 | #endif 676 | 677 | std::unique_ptr Model::HandleBuiltinOptions( 678 | const tflite::Operator* op) { 679 | auto op_type = op->builtin_options_type(); 680 | 681 | switch (op_type) { 682 | case tflite::BuiltinOptions_Conv2DOptions: 683 | return MakeConv2DOptions(op); 684 | break; 685 | 686 | case tflite::BuiltinOptions_DepthwiseConv2DOptions: 687 | return MakeDepthwiseConv2DOptions(op); 688 | break; 689 | 690 | case tflite::BuiltinOptions_ConcatEmbeddingsOptions: 691 | return MakeConcatEmbeddingsOptions(op); 692 | break; 693 | 694 | case tflite::BuiltinOptions_LSHProjectionOptions: 695 | return MakeLSHProjectionOptions(op); 696 | break; 697 | 698 | case tflite::BuiltinOptions_Pool2DOptions: 699 | return MakePool2DOptions(op); 700 | break; 701 | 702 | case tflite::BuiltinOptions_SVDFOptions: 703 | return MakeSVDFOptions(op); 704 | break; 705 | 706 | case tflite::BuiltinOptions_RNNOptions: 707 | return MakeRNNOptions(op); 708 | break; 709 | 710 | case tflite::BuiltinOptions_FullyConnectedOptions: 711 | return MakeFullyConnectedOptions(op); 712 | break; 713 | 714 | case tflite::BuiltinOptions_SoftmaxOptions: 715 | return MakeSoftmaxOptions(op); 716 | break; 717 | 718 | case tflite::BuiltinOptions_ConcatenationOptions: 719 | return MakeConcatenationOptions(op); 720 | break; 721 | 722 | case tflite::BuiltinOptions_AddOptions: 723 | return MakeAddOptions(op); 724 | break; 725 | 726 | case tflite::BuiltinOptions_L2NormOptions: 727 | return MakeL2NormOptions(op); 728 | break; 729 | 730 | case tflite::BuiltinOptions_LocalResponseNormalizationOptions: 731 | return MakeLocalResponseNormalizationOptions(op); 732 | break; 733 | 734 | case tflite::BuiltinOptions_LSTMOptions: 735 | return MakeLSTMOptions(op); 736 | break; 737 | 738 | case tflite::BuiltinOptions_ResizeBilinearOptions: 739 | return MakeResizeBilinearOptions(op); 740 | break; 741 | 742 | case tflite::BuiltinOptions_CallOptions: 743 | return MakeCallOptions(op); 744 | break; 745 | 746 | case tflite::BuiltinOptions_ReshapeOptions: 747 | return MakeReshapeOptions(op); 748 | break; 749 | 750 | case tflite::BuiltinOptions_SkipGramOptions: 751 | return MakeSkipGramOptions(op); 752 | break; 753 | 754 | case tflite::BuiltinOptions_SpaceToDepthOptions: 755 | return MakeSpaceToDepthOptions(op); 756 | break; 757 | 758 | case tflite::BuiltinOptions_EmbeddingLookupSparseOptions: 759 | return MakeEmbeddingLookupSparseOptions(op); 760 | break; 761 | 762 | case tflite::BuiltinOptions_MulOptions: 763 | return 
MakeMulOptions(op); 764 | break; 765 | 766 | case tflite::BuiltinOptions_PadOptions: 767 | return MakePadOptions(op); 768 | break; 769 | 770 | case tflite::BuiltinOptions_GatherOptions: 771 | return MakeGatherOptions(op); 772 | break; 773 | 774 | case tflite::BuiltinOptions_BatchToSpaceNDOptions: 775 | return MakeBatchToSpaceNDOptions(op); 776 | break; 777 | 778 | case tflite::BuiltinOptions_SpaceToBatchNDOptions: 779 | return MakeSpaceToBatchNDOptions(op); 780 | break; 781 | 782 | case tflite::BuiltinOptions_TransposeOptions: 783 | return MakeTransposeOptions(op); 784 | break; 785 | 786 | case tflite::BuiltinOptions_MeanOptions: 787 | return MakeMeanOptions(op); 788 | break; 789 | 790 | case tflite::BuiltinOptions_SubOptions: 791 | return MakeSubOptions(op); 792 | break; 793 | 794 | case tflite::BuiltinOptions_DivOptions: 795 | return MakeDivOptions(op); 796 | break; 797 | 798 | case tflite::BuiltinOptions_SqueezeOptions: 799 | return MakeSqueezeOptions(op); 800 | break; 801 | 802 | case tflite::BuiltinOptions_SequenceRNNOptions: 803 | return MakeSequenceRNNOptions(op); 804 | break; 805 | 806 | default: 807 | return MakeNoneOptions(op); 808 | } 809 | } 810 | 811 | void Model::PopulateGraphOperators(const tflite::SubGraph* graph) { 812 | auto operators = graph->operators(); 813 | std::vector vec_operators; 814 | 815 | // get operators 816 | for (auto it = operators->begin(); it != operators->end(); ++it) { 817 | std::vector vec_ins = AssignVector(it->inputs()); 818 | std::vector vec_outs = AssignVector(it->outputs()); 819 | 820 | std::string opt_str = tflite::EnumNamesBuiltinOptions()[static_cast( 821 | it->builtin_options_type())]; 822 | 823 | // get builtin options 824 | std::unique_ptr builtin_op(HandleBuiltinOptions(*it)); 825 | 826 | // get the operator code reference given the index o operator table 827 | size_t opcode_index = static_cast(it->opcode_index()); 828 | const OperatorCode& op_code = operators_code_[opcode_index]; 829 | 830 | graph_.AddOperator(Operator(opcode_index, op_code, std::move(builtin_op), 831 | opt_str, std::move(vec_ins), std::move(vec_outs))); 832 | } 833 | } 834 | 835 | void Model::PopulateGraph() { 836 | if (flat_buffers_.Length() == 0) { 837 | FATAL("Model file is empty") 838 | return; 839 | } 840 | 841 | auto subgraphs = fb_model_->subgraphs(); 842 | if (!subgraphs) { 843 | FATAL("No subgraph found") 844 | return; 845 | } 846 | 847 | auto graph = subgraphs->Get(0); 848 | 849 | PopulateGraphInputs(graph); 850 | PopulateGraphOutputs(graph); 851 | PopulateGraphTensors(graph); 852 | PopulateGraphOperators(graph); 853 | } 854 | 855 | void Model::PopulateBuffers() { 856 | auto buffer_vec = fb_model_->buffers(); 857 | 858 | // test if buffer_vec is null to avoid crash on flatbuffers 859 | if (!buffer_vec) { 860 | return; 861 | } 862 | 863 | for (auto it = buffer_vec->begin(); it != buffer_vec->end(); ++it) { 864 | std::vector buf = AssignVector(it->data()); 865 | buffers_.push_back(std::move(buf)); 866 | } 867 | } 868 | 869 | BuiltinOperator Model::ConvertOperatorCode(tflite::BuiltinOperator type) { 870 | switch (type) { 871 | case tflite::BuiltinOperator_ADD: 872 | return BuiltinOperator::ADD; 873 | break; 874 | 875 | case tflite::BuiltinOperator_AVERAGE_POOL_2D: 876 | return BuiltinOperator::AVERAGE_POOL_2D; 877 | break; 878 | 879 | case tflite::BuiltinOperator_CONCATENATION: 880 | return BuiltinOperator::CONCATENATION; 881 | break; 882 | 883 | case tflite::BuiltinOperator_CONV_2D: 884 | return BuiltinOperator::CONV_2D; 885 | break; 886 | 887 | case 
tflite::BuiltinOperator_DEPTHWISE_CONV_2D: 888 | return BuiltinOperator::DEPTHWISE_CONV_2D; 889 | break; 890 | 891 | case tflite::BuiltinOperator_DEQUANTIZE: 892 | return BuiltinOperator::DEQUANTIZE; 893 | break; 894 | 895 | case tflite::BuiltinOperator_EMBEDDING_LOOKUP: 896 | return BuiltinOperator::EMBEDDING_LOOKUP; 897 | break; 898 | #ifdef NEWER_TENSORFLOW 899 | case tflite::BuiltinOperator_FLOOR: 900 | return BuiltinOperator::FLOOR; 901 | break; 902 | #endif 903 | case tflite::BuiltinOperator_FULLY_CONNECTED: 904 | return BuiltinOperator::FULLY_CONNECTED; 905 | break; 906 | 907 | case tflite::BuiltinOperator_HASHTABLE_LOOKUP: 908 | return BuiltinOperator::HASHTABLE_LOOKUP; 909 | break; 910 | 911 | case tflite::BuiltinOperator_L2_NORMALIZATION: 912 | return BuiltinOperator::L2_NORMALIZATION; 913 | break; 914 | 915 | case tflite::BuiltinOperator_L2_POOL_2D: 916 | return BuiltinOperator::L2_POOL_2D; 917 | break; 918 | 919 | case tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: 920 | return BuiltinOperator::LOCAL_RESPONSE_NORMALIZATION; 921 | break; 922 | 923 | case tflite::BuiltinOperator_LOGISTIC: 924 | return BuiltinOperator::LOGISTIC; 925 | break; 926 | 927 | case tflite::BuiltinOperator_LSH_PROJECTION: 928 | return BuiltinOperator::LSH_PROJECTION; 929 | break; 930 | 931 | case tflite::BuiltinOperator_LSTM: 932 | return BuiltinOperator::LSTM; 933 | break; 934 | 935 | case tflite::BuiltinOperator_MAX_POOL_2D: 936 | return BuiltinOperator::MAX_POOL_2D; 937 | break; 938 | 939 | case tflite::BuiltinOperator_MUL: 940 | return BuiltinOperator::MUL; 941 | break; 942 | 943 | case tflite::BuiltinOperator_RELU: 944 | return BuiltinOperator::RELU; 945 | break; 946 | 947 | case tflite::BuiltinOperator_RELU_N1_TO_1: 948 | return BuiltinOperator::RELU1; 949 | break; 950 | 951 | case tflite::BuiltinOperator_RELU6: 952 | return BuiltinOperator::RELU6; 953 | break; 954 | 955 | case tflite::BuiltinOperator_RESHAPE: 956 | return BuiltinOperator::RESHAPE; 957 | break; 958 | 959 | case tflite::BuiltinOperator_RNN: 960 | return BuiltinOperator::RNN; 961 | break; 962 | 963 | case tflite::BuiltinOperator_SOFTMAX: 964 | return BuiltinOperator::SOFTMAX; 965 | break; 966 | 967 | case tflite::BuiltinOperator_SPACE_TO_DEPTH: 968 | return BuiltinOperator::SPACE_TO_DEPTH; 969 | break; 970 | 971 | case tflite::BuiltinOperator_SVDF: 972 | return BuiltinOperator::SVDF; 973 | break; 974 | 975 | case tflite::BuiltinOperator_TANH: 976 | return BuiltinOperator::TANH; 977 | break; 978 | 979 | case tflite::BuiltinOperator_CONCAT_EMBEDDINGS: 980 | return BuiltinOperator::CONCAT_EMBEDDINGS; 981 | break; 982 | 983 | case tflite::BuiltinOperator_SKIP_GRAM: 984 | return BuiltinOperator::SKIP_GRAM; 985 | break; 986 | 987 | case tflite::BuiltinOperator_CALL: 988 | return BuiltinOperator::CALL; 989 | break; 990 | 991 | case tflite::BuiltinOperator_CUSTOM: 992 | return BuiltinOperator::CUSTOM; 993 | break; 994 | 995 | case tflite::BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: 996 | return BuiltinOperator::EMBEDDING_LOOKUP_SPARSE; 997 | break; 998 | 999 | case tflite::BuiltinOperator_PAD: 1000 | return BuiltinOperator::PAD; 1001 | break; 1002 | 1003 | case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: 1004 | return BuiltinOperator::UNIDIRECTIONAL_SEQUENCE_RNN; 1005 | break; 1006 | 1007 | case tflite::BuiltinOperator_GATHER: 1008 | return BuiltinOperator::GATHER; 1009 | break; 1010 | 1011 | case tflite::BuiltinOperator_BATCH_TO_SPACE_ND: 1012 | return BuiltinOperator::BATCH_TO_SPACE_ND; 1013 | break; 1014 | 1015 | case 
tflite::BuiltinOperator_SPACE_TO_BATCH_ND: 1016 | return BuiltinOperator::SPACE_TO_BATCH_ND; 1017 | break; 1018 | 1019 | case tflite::BuiltinOperator_TRANSPOSE: 1020 | return BuiltinOperator::TRANSPOSE; 1021 | break; 1022 | 1023 | case tflite::BuiltinOperator_MEAN: 1024 | return BuiltinOperator::MEAN; 1025 | break; 1026 | 1027 | case tflite::BuiltinOperator_SUB: 1028 | return BuiltinOperator::SUB; 1029 | break; 1030 | 1031 | case tflite::BuiltinOperator_DIV: 1032 | return BuiltinOperator::DIV; 1033 | break; 1034 | 1035 | case tflite::BuiltinOperator_SQUEEZE: 1036 | return BuiltinOperator::SQUEEZE; 1037 | break; 1038 | 1039 | case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: 1040 | return BuiltinOperator::UNIDIRECTIONAL_SEQUENCE_LSTM; 1041 | break; 1042 | 1043 | case tflite::BuiltinOperator_STRIDED_SLICE: 1044 | return BuiltinOperator::STRIDED_SLICE; 1045 | break; 1046 | 1047 | case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: 1048 | return BuiltinOperator::BIDIRECTIONAL_SEQUENCE_RNN; 1049 | break; 1050 | 1051 | case tflite::BuiltinOperator_EXP: 1052 | return BuiltinOperator::EXP; 1053 | break; 1054 | 1055 | case tflite::BuiltinOperator_TOPK_V2: 1056 | return BuiltinOperator::TOPK_V2; 1057 | break; 1058 | 1059 | case tflite::BuiltinOperator_SPLIT: 1060 | return BuiltinOperator::SPLIT; 1061 | break; 1062 | 1063 | case tflite::BuiltinOperator_LOG_SOFTMAX: 1064 | return BuiltinOperator::LOG_SOFTMAX; 1065 | break; 1066 | 1067 | case tflite::BuiltinOperator_DELEGATE: 1068 | return BuiltinOperator::DELEGATE; 1069 | break; 1070 | 1071 | case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: 1072 | return BuiltinOperator::BIDIRECTIONAL_SEQUENCE_LSTM; 1073 | break; 1074 | 1075 | case tflite::BuiltinOperator_CAST: 1076 | return BuiltinOperator::CAST; 1077 | break; 1078 | 1079 | case tflite::BuiltinOperator_PRELU: 1080 | return BuiltinOperator::PRELU; 1081 | break; 1082 | 1083 | case tflite::BuiltinOperator_MAXIMUM: 1084 | return BuiltinOperator::MAXIMUM; 1085 | break; 1086 | #ifdef NEWER_TENSORFLOW 1087 | case tflite::BuiltinOperator_ARG_MAX: 1088 | return BuiltinOperator::ARG_MAX; 1089 | break; 1090 | 1091 | case tflite::BuiltinOperator_MINIMUM: 1092 | return BuiltinOperator::MINIMUM; 1093 | break; 1094 | 1095 | case tflite::BuiltinOperator_LESS: 1096 | return BuiltinOperator::LESS; 1097 | break; 1098 | 1099 | case tflite::BuiltinOperator_NEG: 1100 | return BuiltinOperator::NEG; 1101 | break; 1102 | #endif 1103 | default: 1104 | return BuiltinOperator::NONE; 1105 | } 1106 | } 1107 | 1108 | void Model::PopulateOperatorsCode() { 1109 | auto op_codes_vec = fb_model_->operator_codes(); 1110 | 1111 | if (!op_codes_vec) { 1112 | return; 1113 | } 1114 | 1115 | for (auto it = op_codes_vec->begin(); it != op_codes_vec->end(); ++it) { 1116 | auto custom_code = it->custom_code(); 1117 | 1118 | OperatorCode op_code { 1119 | .builtin_code = ConvertOperatorCode(it->builtin_code()), 1120 | .custom_code = custom_code? 
"\"" + custom_code->str() +"\"" : "\"\"" 1121 | }; 1122 | 1123 | operators_code_.push_back(std::move(op_code)); 1124 | } 1125 | } 1126 | 1127 | const char* Model::description() { 1128 | return fb_model_->description()->c_str(); 1129 | } 1130 | 1131 | } 1132 | -------------------------------------------------------------------------------- /src/model.h: -------------------------------------------------------------------------------- 1 | #ifndef NNT_MODEL_H 2 | #define NNT_MODEL_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "schemas/schema_generated.h" 11 | 12 | namespace nnt { 13 | 14 | class FlatBufferModel { 15 | public: 16 | FlatBufferModel(const std::string& file); 17 | ~FlatBufferModel(); 18 | char* data(); 19 | int Length() const; 20 | 21 | private: 22 | char *data_; 23 | int len_; 24 | }; 25 | 26 | class Buffer { 27 | public: 28 | Buffer(std::vector&& buf): buf_(std::move(buf)) {} 29 | 30 | const std::vector& Data() const { 31 | return buf_; 32 | } 33 | 34 | const u_char* RawData() const { 35 | return buf_.data(); 36 | } 37 | 38 | private: 39 | std::vector buf_; 40 | }; 41 | 42 | enum class ActivationFunctionType: int8_t { 43 | NONE, 44 | RELU, 45 | RELU1, 46 | RELU6, 47 | TANH, 48 | SIGN_BIT 49 | }; 50 | 51 | enum class TensorType: int8_t { 52 | FLOAT32, 53 | FLOAT16, 54 | INT32, 55 | UINT8, 56 | INT64, 57 | STRING, 58 | #ifdef NEWER_TENSORFLOW 59 | BOOL 60 | #endif 61 | }; 62 | 63 | enum class Padding: int8_t { UNKNOWN = 0, SAME, VALID }; 64 | 65 | enum class BuiltinOperator { 66 | NONE, 67 | ADD, 68 | AVERAGE_POOL_2D, 69 | CONCATENATION, 70 | CONV_2D, 71 | DEPTHWISE_CONV_2D, 72 | DEQUANTIZE, 73 | EMBEDDING_LOOKUP, 74 | #ifdef NEWER_TENSORFLOW 75 | FLOOR, 76 | #endif 77 | FULLY_CONNECTED, 78 | HASHTABLE_LOOKUP, 79 | L2_NORMALIZATION, 80 | L2_POOL_2D, 81 | LOCAL_RESPONSE_NORMALIZATION, 82 | LOGISTIC, 83 | LSH_PROJECTION, 84 | LSTM, 85 | MAX_POOL_2D, 86 | MUL, 87 | RELU, 88 | RELU1, 89 | RELU6, 90 | RESHAPE, 91 | RESIZE_BILINEAR, 92 | RNN, 93 | SOFTMAX, 94 | SPACE_TO_DEPTH, 95 | SVDF, 96 | TANH, 97 | CONCAT_EMBEDDINGS, 98 | SKIP_GRAM, 99 | CALL, 100 | CUSTOM, 101 | EMBEDDING_LOOKUP_SPARSE, 102 | PAD, 103 | UNIDIRECTIONAL_SEQUENCE_RNN, 104 | GATHER, 105 | BATCH_TO_SPACE_ND, 106 | SPACE_TO_BATCH_ND, 107 | TRANSPOSE, 108 | MEAN, 109 | SUB, 110 | DIV, 111 | SQUEEZE, 112 | UNIDIRECTIONAL_SEQUENCE_LSTM, 113 | STRIDED_SLICE, 114 | BIDIRECTIONAL_SEQUENCE_RNN, 115 | EXP, 116 | TOPK_V2, 117 | SPLIT, 118 | LOG_SOFTMAX, 119 | DELEGATE, 120 | BIDIRECTIONAL_SEQUENCE_LSTM, 121 | CAST, 122 | PRELU, 123 | MAXIMUM, 124 | ARG_MAX, 125 | MINIMUM, 126 | LESS, 127 | NEG 128 | }; 129 | 130 | enum class BuiltinOptionsType { 131 | None, 132 | Conv2DOptions, 133 | DepthwiseConv2DOptions, 134 | ConcatEmbeddingsOptions, 135 | LSHProjectionOptions, 136 | Pool2DOptions, 137 | SVDFOptions, 138 | RNNOptions, 139 | FullyConnectedOptions, 140 | SoftmaxOptions, 141 | ConcatenationOptions, 142 | AddOptions, 143 | L2NormOptions, 144 | LocalResponseNormalizationOptions, 145 | LSTMOptions, 146 | ResizeBilinearOptions, 147 | CallOptions, 148 | ReshapeOptions, 149 | SkipGramOptions, 150 | SpaceToDepthOptions, 151 | EmbeddingLookupSparseOptions, 152 | MulOptions, 153 | PadOptions, 154 | GatherOptions, 155 | BatchToSpaceNDOptions, 156 | SpaceToBatchNDOptions, 157 | TransposeOptions, 158 | MeanOptions, 159 | SubOptions, 160 | DivOptions, 161 | SqueezeOptions, 162 | SequenceRNNOptions, 163 | StridedSliceOptions, 164 | ExpOptions, 165 | TopKV2Options, 166 | SplitOptions, 167 | 
LogSoftmaxOptions, 168 | CastOptions, 169 | DequantizeOptions, 170 | #ifdef NEWER_TENSORFLOW 171 | MaximumMinimumOptions, 172 | ArgMaxOptions, 173 | LessOptions, 174 | NegOptions, 175 | #else 176 | MaximumOptions 177 | #endif 178 | }; 179 | 180 | struct OperatorCode { 181 | BuiltinOperator builtin_code; 182 | std::string custom_code; 183 | }; 184 | 185 | struct BuiltinOptions { 186 | BuiltinOptions(BuiltinOptionsType type_op): type(type_op) {} 187 | 188 | BuiltinOptionsType type; 189 | }; 190 | 191 | struct NoneOptions: public BuiltinOptions { 192 | NoneOptions(): BuiltinOptions(BuiltinOptionsType::None) {} 193 | }; 194 | 195 | struct Conv2DOptions: public BuiltinOptions { 196 | Conv2DOptions(): BuiltinOptions(BuiltinOptionsType::Conv2DOptions) {} 197 | 198 | Padding padding; 199 | int stride_w; 200 | int stride_h; 201 | #ifdef NEWER_TENSORFLOW 202 | int dilation_w_factor; 203 | int dilation_h_factor; 204 | #endif 205 | ActivationFunctionType fused_activation_function; 206 | }; 207 | 208 | struct Pool2DOptions: public BuiltinOptions { 209 | Pool2DOptions(): BuiltinOptions(BuiltinOptionsType::Pool2DOptions) {} 210 | 211 | Padding padding; 212 | int stride_w; 213 | int stride_h; 214 | int filter_width; 215 | int filter_height; 216 | ActivationFunctionType fused_activation_function; 217 | }; 218 | 219 | struct DepthwiseConv2DOptions: public BuiltinOptions { 220 | DepthwiseConv2DOptions() 221 | : BuiltinOptions(BuiltinOptionsType::DepthwiseConv2DOptions) {} 222 | 223 | Padding padding; 224 | int stride_w; 225 | int stride_h; 226 | int depth_multiplier; 227 | ActivationFunctionType fused_activation_function; 228 | }; 229 | 230 | struct ConcatEmbeddingsOptions: public BuiltinOptions { 231 | ConcatEmbeddingsOptions() 232 | : BuiltinOptions(BuiltinOptionsType::ConcatEmbeddingsOptions) {} 233 | 234 | int num_channels; 235 | std::vector num_columns_per_channel; 236 | std::vector embedding_dim_per_channel; 237 | }; 238 | 239 | enum class LSHProjectionType: int8_t { 240 | UNKNOWN = 0, 241 | SPARSE = 1, 242 | DENSE = 2, 243 | }; 244 | 245 | struct LSHProjectionOptions: public BuiltinOptions { 246 | LSHProjectionOptions() 247 | : BuiltinOptions(BuiltinOptionsType::LSHProjectionOptions) {} 248 | 249 | LSHProjectionType type; 250 | }; 251 | 252 | struct SVDFOptions: public BuiltinOptions { 253 | SVDFOptions(): BuiltinOptions(BuiltinOptionsType::SVDFOptions) {} 254 | 255 | int rank; 256 | ActivationFunctionType fused_activation_function; 257 | }; 258 | 259 | struct RNNOptions: public BuiltinOptions { 260 | RNNOptions(): BuiltinOptions(BuiltinOptionsType::RNNOptions) {} 261 | 262 | ActivationFunctionType fused_activation_function; 263 | }; 264 | 265 | struct SequenceRNNOptions: public BuiltinOptions { 266 | SequenceRNNOptions() 267 | : BuiltinOptions(BuiltinOptionsType::SequenceRNNOptions) {} 268 | 269 | bool time_major; 270 | ActivationFunctionType fused_activation_function; 271 | }; 272 | 273 | struct BidirectionalSequenceRNNOptions: public BuiltinOptions { 274 | bool time_major; 275 | ActivationFunctionType fused_activation_function; 276 | }; 277 | 278 | struct FullyConnectedOptions: public BuiltinOptions { 279 | FullyConnectedOptions() 280 | : BuiltinOptions(BuiltinOptionsType::FullyConnectedOptions) {} 281 | 282 | ActivationFunctionType fused_activation_function; 283 | }; 284 | 285 | struct SoftmaxOptions: public BuiltinOptions { 286 | SoftmaxOptions() 287 | : BuiltinOptions(BuiltinOptionsType::SoftmaxOptions) {} 288 | 289 | float beta; 290 | }; 291 | 292 | struct ConcatenationOptions: public 
BuiltinOptions {
293 |   ConcatenationOptions()
294 |     : BuiltinOptions(BuiltinOptionsType::ConcatenationOptions) {}
295 | 
296 |   int axis;
297 |   ActivationFunctionType fused_activation_function;
298 | };
299 | 
300 | struct AddOptions: public BuiltinOptions {
301 |   AddOptions(): BuiltinOptions(BuiltinOptionsType::AddOptions) {}
302 | 
303 |   ActivationFunctionType fused_activation_function;
304 | };
305 | 
306 | struct MulOptions: public BuiltinOptions {
307 |   MulOptions(): BuiltinOptions(BuiltinOptionsType::MulOptions) {}
308 | 
309 |   ActivationFunctionType fused_activation_function;
310 | };
311 | 
312 | struct L2NormOptions: public BuiltinOptions {
313 |   L2NormOptions(): BuiltinOptions(BuiltinOptionsType::L2NormOptions) {}
314 | 
315 |   ActivationFunctionType fused_activation_function;
316 | };
317 | 
318 | struct LocalResponseNormalizationOptions: public BuiltinOptions {
319 |   LocalResponseNormalizationOptions()
320 |     : BuiltinOptions(BuiltinOptionsType::LocalResponseNormalizationOptions) {}
321 | 
322 |   int radius;
323 |   float bias;
324 |   float alpha;
325 |   float beta;
326 | };
327 | 
328 | struct LSTMOptions: public BuiltinOptions {
329 |   LSTMOptions(): BuiltinOptions(BuiltinOptionsType::LSTMOptions) {}
330 | 
331 |   float cell_clip;
332 |   float proj_clip;
333 |   ActivationFunctionType fused_activation_function;
334 | };
335 | 
336 | struct ResizeBilinearOptions: public BuiltinOptions {
337 |   ResizeBilinearOptions()
338 |     : BuiltinOptions(BuiltinOptionsType::ResizeBilinearOptions) {}
339 | 
340 |   bool align_corners;
341 | };
342 | 
343 | struct CallOptions: public BuiltinOptions {
344 |   CallOptions(): BuiltinOptions(BuiltinOptionsType::CallOptions) {}
345 | 
346 |   uint subgraph;
347 | };
348 | 
349 | struct PadOptions: public BuiltinOptions {
350 |   PadOptions(): BuiltinOptions(BuiltinOptionsType::PadOptions) {}
351 | };
352 | 
353 | struct ReshapeOptions: public BuiltinOptions {
354 |   ReshapeOptions(): BuiltinOptions(BuiltinOptionsType::ReshapeOptions) {}
355 | 
356 |   std::vector<int> new_shape;
357 | };
358 | 
359 | struct SpaceToBatchNDOptions: public BuiltinOptions {
360 |   SpaceToBatchNDOptions()
361 |     : BuiltinOptions(BuiltinOptionsType::SpaceToBatchNDOptions) {}
362 | };
363 | 
364 | struct BatchToSpaceNDOptions: public BuiltinOptions {
365 |   BatchToSpaceNDOptions()
366 |     : BuiltinOptions(BuiltinOptionsType::BatchToSpaceNDOptions) {}
367 | };
368 | 
369 | struct SkipGramOptions: public BuiltinOptions {
370 |   SkipGramOptions(): BuiltinOptions(BuiltinOptionsType::SkipGramOptions) {}
371 | 
372 |   int ngram_size;
373 |   int max_skip_size;
374 |   bool include_all_ngrams;
375 | };
376 | 
377 | struct SpaceToDepthOptions: public BuiltinOptions {
378 |   SpaceToDepthOptions()
379 |     : BuiltinOptions(BuiltinOptionsType::SpaceToDepthOptions) {}
380 | 
381 |   int block_size;
382 | };
383 | 
384 | struct SubOptions: public BuiltinOptions {
385 |   SubOptions(): BuiltinOptions(BuiltinOptionsType::SubOptions) {}
386 | 
387 |   ActivationFunctionType fused_activation_function;
388 | };
389 | 
390 | struct DivOptions: public BuiltinOptions {
391 |   DivOptions(): BuiltinOptions(BuiltinOptionsType::DivOptions) {}
392 | 
393 |   ActivationFunctionType fused_activation_function;
394 | };
395 | 
396 | struct TopKV2Options: public BuiltinOptions {
397 |   TopKV2Options(): BuiltinOptions(BuiltinOptionsType::TopKV2Options) {}
398 | };
399 | 
400 | enum class CombinerType : int8_t {
401 |   SUM = 0,
402 |   MEAN = 1,
403 |   SQRTN = 2,
404 | };
405 | 
406 | struct EmbeddingLookupSparseOptions: public BuiltinOptions {
407 |   EmbeddingLookupSparseOptions()
408 |     : BuiltinOptions(BuiltinOptionsType::EmbeddingLookupSparseOptions) {}
409 | 
410 |   CombinerType combiner;
411 | };
412 | 
413 | struct GatherOptions: public BuiltinOptions {
414 |   GatherOptions(): BuiltinOptions(BuiltinOptionsType::GatherOptions) {}
415 | 
416 |   int axis;
417 | };
418 | 
419 | struct TransposeOptions: public BuiltinOptions {
420 |   TransposeOptions(): BuiltinOptions(BuiltinOptionsType::TransposeOptions) {}
421 | };
422 | 
423 | struct ExpOptions: public BuiltinOptions {
424 |   ExpOptions(): BuiltinOptions(BuiltinOptionsType::ExpOptions) {}
425 | };
426 | 
427 | struct MeanOptions: public BuiltinOptions {
428 |   MeanOptions(): BuiltinOptions(BuiltinOptionsType::MeanOptions) {}
429 | 
430 |   bool keep_dims;
431 | };
432 | 
433 | struct SqueezeOptions: public BuiltinOptions {
434 |   SqueezeOptions(): BuiltinOptions(BuiltinOptionsType::SqueezeOptions) {}
435 | 
436 |   std::vector<int> squeeze_dims;
437 | };
438 | 
439 | struct SplitOptions: public BuiltinOptions {
440 |   SplitOptions(): BuiltinOptions(BuiltinOptionsType::SplitOptions) {}
441 | 
442 |   int num_splits;
443 | };
444 | 
445 | struct StridedSliceOptions: public BuiltinOptions {
446 |   StridedSliceOptions()
447 |     : BuiltinOptions(BuiltinOptionsType::StridedSliceOptions) {}
448 | 
449 |   int begin_mask;
450 |   int end_mask;
451 |   int ellipsis_mask;
452 |   int new_axis_mask;
453 |   int shrink_axis_mask;
454 | };
455 | 
456 | struct LogSoftmaxOptions: public BuiltinOptions {
457 |   LogSoftmaxOptions(): BuiltinOptions(BuiltinOptionsType::LogSoftmaxOptions) {}
458 | };
459 | 
460 | struct CastOptions: public BuiltinOptions {
461 |   CastOptions(): BuiltinOptions(BuiltinOptionsType::CastOptions) {}
462 | 
463 |   TensorType in_data_type;
464 |   TensorType out_data_type;
465 | };
466 | 
467 | struct DequantizeOptions: public BuiltinOptions {
468 |   DequantizeOptions(): BuiltinOptions(BuiltinOptionsType::DequantizeOptions) {}
469 | };
470 | 
471 | #ifdef NEWER_TENSORFLOW
472 | struct MaximumMinimumOptions: public BuiltinOptions {
473 |   MaximumMinimumOptions()
474 |     : BuiltinOptions(BuiltinOptionsType::MaximumMinimumOptions) {}
475 | };
476 | 
477 | struct ArgMaxOptions: public BuiltinOptions {
478 |   ArgMaxOptions(): BuiltinOptions(BuiltinOptionsType::ArgMaxOptions) {}
479 | 
480 |   TensorType output_type;
481 | };
482 | 
483 | struct LessOptions: public BuiltinOptions {
484 |   LessOptions(): BuiltinOptions(BuiltinOptionsType::LessOptions) {}
485 | };
486 | 
487 | struct NegOptions: public BuiltinOptions {
488 |   NegOptions(): BuiltinOptions(BuiltinOptionsType::NegOptions) {}
489 | };
490 | #else
491 | struct MaximumOptions: public BuiltinOptions {
492 |   MaximumOptions(): BuiltinOptions(BuiltinOptionsType::MaximumOptions) {}
493 | };
494 | #endif
495 | 
496 | class Operator {
497 |  public:
498 |   Operator(int index, const OperatorCode& op_code,
499 |            std::unique_ptr<BuiltinOptions> builtin_op,
500 |            const std::string& builtin_op_str, std::vector<int>&& inputs,
501 |            std::vector<int>&& outputs)
502 |       : index_(index)
503 |       , op_code_(op_code)
504 |       , builtin_op_str_(builtin_op_str)
505 |       , inputs_(std::move(inputs))
506 |       , outputs_(std::move(outputs))
507 |       , builtin_op_(std::move(builtin_op)) {}
508 | 
509 |   Operator(Operator&& op)
510 |       : index_(op.index_)
511 |       , op_code_(op.op_code_)
512 |       , builtin_op_str_(std::move(op.builtin_op_str_))
513 |       , inputs_(std::move(op.inputs_))
514 |       , outputs_(std::move(op.outputs_))
515 |       , builtin_op_(std::move(op.builtin_op_)) {}
516 | 
517 |   Operator(const Operator&) = delete;
518 | 
519 |   int index() const {
520 |     return index_;
521 |   }
522 | 
523 |   const std::string& builtin_op_str() const {
524 |     return builtin_op_str_;
525 |   }
526 | 
527 |   const std::vector<int>& inputs() const {
528 |     return inputs_;
529 |   }
530 | 
531 |   const std::vector<int>& outputs() const {
532 |     return outputs_;
533 |   }
534 | 
535 |   const BuiltinOptions& builtin_op() const {
536 |     return *builtin_op_;
537 |   }
538 | 
539 |   const OperatorCode& op_code() const {
540 |     return op_code_;
541 |   }
542 | 
543 |  private:
544 |   int index_;
545 |   const OperatorCode& op_code_;
546 |   std::string builtin_op_str_;
547 |   std::vector<int> inputs_;
548 |   std::vector<int> outputs_;
549 |   std::unique_ptr<BuiltinOptions> builtin_op_;
550 | };
551 | 
552 | struct QuantizationParameters {
553 |   std::vector<float> min;
554 |   std::vector<float> max;
555 |   std::vector<float> scale;
556 |   std::vector<int64_t> zero_point;
557 | };
558 | 
559 | class Tensor {
560 |  public:
561 |   Tensor(std::vector<int>&& shape, TensorType tensor_type,
562 |          const std::string& name, const Buffer& buffer, uint buffer_index,
563 |          std::unique_ptr<QuantizationParameters> quantization)
564 |       : shape_(std::move(shape))
565 |       , tensor_type_(tensor_type)
566 |       , name_(name)
567 |       , buffer_(buffer)
568 |       , buffer_index_(buffer_index)
569 |       , quantization_(std::move(quantization)) {}
570 | 
571 |   Tensor(const Tensor& tensor) = delete;
572 | 
573 |   Tensor(Tensor&& tensor)
574 |       : shape_(std::move(tensor.shape_))
575 |       , tensor_type_(tensor.tensor_type_)
576 |       , name_(std::move(tensor.name_))
577 |       , buffer_(tensor.buffer_)
578 |       , buffer_index_(tensor.buffer_index_)
579 |       , quantization_(std::move(tensor.quantization_)) {}
580 | 
581 |   const std::string& name() const {
582 |     return name_;
583 |   }
584 | 
585 |   const std::vector<int>& shape() const {
586 |     return shape_;
587 |   }
588 | 
589 |   TensorType tensor_type() const {
590 |     return tensor_type_;
591 |   }
592 | 
593 |   const Buffer& buffer() const {
594 |     return buffer_;
595 |   }
596 | 
597 |   uint buffer_index() const {
598 |     return buffer_index_;
599 |   }
600 | 
601 |   bool HasQuantization() const {
602 |     return bool(quantization_);
603 |   }
604 | 
605 |   const QuantizationParameters& quantization() const {
606 |     return *quantization_;
607 |   }
608 | 
609 |  private:
610 |   std::vector<int> shape_;
611 |   TensorType tensor_type_;
612 |   std::string name_;
613 |   const Buffer& buffer_;
614 |   uint buffer_index_;
615 |   std::unique_ptr<QuantizationParameters> quantization_;
616 | };
617 | 
618 | class Graph {
619 |  public:
620 |   Graph() = default;
621 | 
622 |   void SetInputs(std::vector<int>&& inputs) {
623 |     inputs_ = std::move(inputs);
624 |   }
625 | 
626 |   void SetOutputs(std::vector<int>&& outputs) {
627 |     outputs_ = std::move(outputs);
628 |   }
629 | 
630 |   void AddTensor(Tensor&& tensor) {
631 |     tensors_.push_back(std::move(tensor));
632 |   }
633 | 
634 |   void AddOperator(Operator&& op) {
635 |     operators_.push_back(std::move(op));
636 |   }
637 | 
638 |   const std::vector<Tensor>& Tensors() const {
639 |     return tensors_;
640 |   }
641 | 
642 |   const std::vector<Operator>& Operators() const {
643 |     return operators_;
644 |   }
645 | 
646 |   const std::vector<int>& Inputs() const {
647 |     return inputs_;
648 |   }
649 | 
650 |   const std::vector<int>& Outputs() const {
651 |     return outputs_;
652 |   }
653 | 
654 |  private:
655 |   std::vector<Tensor> tensors_;
656 |   std::vector<Operator> operators_;
657 |   std::vector<int> inputs_;
658 |   std::vector<int> outputs_;
659 | };
660 | 
661 | class Model {
662 |  public:
663 |   Model(const std::string& file);
664 | 
665 |   const char* description();
666 | 
667 |   Graph& graph() {
668 |     return graph_;
669 |   }
670 | 
671 |   const std::vector<Buffer>& Buffers() const {
672 |     return buffers_;
673 |   }
674 | 
675 |  private:
676 |   void PopulateGraph();
677 | 
678 |   void PopulateGraphInputs(const tflite::SubGraph* graph);
679 | 
680 |   void PopulateGraphOutputs(const tflite::SubGraph* graph);
681 | 
682 |   TensorType ConvertTensorType(tflite::TensorType type);
683 | 
684 |   void PopulateGraphTensors(const tflite::SubGraph* graph);
685 | 
686 |   void PopulateGraphOperators(const tflite::SubGraph* graph);
687 | 
688 |   void PopulateBuffers();
689 | 
690 |   BuiltinOperator ConvertOperatorCode(tflite::BuiltinOperator type);
691 | 
692 |   void PopulateOperatorsCode();
693 | 
694 |   std::unique_ptr<NoneOptions> MakeNoneOptions(const tflite::Operator* op);
695 | 
696 |   std::unique_ptr<Conv2DOptions> MakeConv2DOptions(
697 |       const tflite::Operator* op);
698 | 
699 |   std::unique_ptr<Pool2DOptions> MakePool2DOptions(
700 |       const tflite::Operator* op);
701 | 
702 |   std::unique_ptr<DepthwiseConv2DOptions> MakeDepthwiseConv2DOptions(
703 |       const tflite::Operator* op);
704 | 
705 |   std::unique_ptr<ConcatEmbeddingsOptions> MakeConcatEmbeddingsOptions(
706 |       const tflite::Operator* op);
707 | 
708 |   std::unique_ptr<LSHProjectionOptions> MakeLSHProjectionOptions(
709 |       const tflite::Operator* op);
710 | 
711 |   std::unique_ptr<SVDFOptions> MakeSVDFOptions(const tflite::Operator* op);
712 | 
713 |   std::unique_ptr<RNNOptions> MakeRNNOptions(const tflite::Operator* op);
714 | 
715 |   std::unique_ptr<SequenceRNNOptions> MakeSequenceRNNOptions(
716 |       const tflite::Operator* op);
717 | 
718 |   std::unique_ptr<FullyConnectedOptions> MakeFullyConnectedOptions(
719 |       const tflite::Operator* op);
720 | 
721 |   std::unique_ptr<SoftmaxOptions> MakeSoftmaxOptions(
722 |       const tflite::Operator* op);
723 | 
724 |   std::unique_ptr<ConcatenationOptions> MakeConcatenationOptions(
725 |       const tflite::Operator* op);
726 | 
727 |   std::unique_ptr<BuiltinOptions> HandleBuiltinOptions(
728 |       const tflite::Operator* op);
729 | 
730 |   std::unique_ptr<AddOptions> MakeAddOptions(const tflite::Operator* op);
731 | 
732 |   std::unique_ptr<MulOptions> MakeMulOptions(const tflite::Operator* op);
733 | 
734 |   std::unique_ptr<L2NormOptions> MakeL2NormOptions(
735 |       const tflite::Operator* op);
736 | 
737 |   std::unique_ptr<LocalResponseNormalizationOptions>
738 |       MakeLocalResponseNormalizationOptions(const tflite::Operator* op);
739 | 
740 |   std::unique_ptr<LSTMOptions> MakeLSTMOptions(const tflite::Operator* op);
741 | 
742 |   std::unique_ptr<ResizeBilinearOptions> MakeResizeBilinearOptions(
743 |       const tflite::Operator* op);
744 | 
745 |   std::unique_ptr<CallOptions> MakeCallOptions(const tflite::Operator* op);
746 | 
747 |   std::unique_ptr<PadOptions> MakePadOptions(const tflite::Operator* op);
748 | 
749 |   std::unique_ptr<ReshapeOptions> MakeReshapeOptions(
750 |       const tflite::Operator* op);
751 | 
752 |   std::unique_ptr<SpaceToBatchNDOptions> MakeSpaceToBatchNDOptions(
753 |       const tflite::Operator* op);
754 | 
755 |   std::unique_ptr<SkipGramOptions> MakeSkipGramOptions(
756 |       const tflite::Operator* op);
757 | 
758 |   std::unique_ptr<SpaceToDepthOptions> MakeSpaceToDepthOptions(
759 |       const tflite::Operator* op);
760 | 
761 |   std::unique_ptr<BatchToSpaceNDOptions> MakeBatchToSpaceNDOptions(
762 |       const tflite::Operator* op);
763 | 
764 |   std::unique_ptr<SubOptions> MakeSubOptions(const tflite::Operator* op);
765 | 
766 |   std::unique_ptr<DivOptions> MakeDivOptions(const tflite::Operator* op);
767 | 
768 |   std::unique_ptr<EmbeddingLookupSparseOptions>
769 |       MakeEmbeddingLookupSparseOptions(const tflite::Operator* op);
770 | 
771 |   std::unique_ptr<GatherOptions> MakeGatherOptions(
772 |       const tflite::Operator* op);
773 | 
774 |   std::unique_ptr<TransposeOptions> MakeTransposeOptions(
775 |       const tflite::Operator* op);
776 | 
777 |   std::unique_ptr<MeanOptions> MakeMeanOptions(const tflite::Operator* op);
778 | 
779 |   std::unique_ptr<SqueezeOptions> MakeSqueezeOptions(
780 |       const tflite::Operator* op);
781 | 
782 |   std::unique_ptr<ExpOptions> MakeExpOptions(const tflite::Operator* op);
783 | 
784 |   std::unique_ptr<TopKV2Options> MakeTopKV2Options(
785 |       const tflite::Operator* op);
786 | 
787 |   std::unique_ptr<SplitOptions> MakeSplitOptions(const tflite::Operator* op);
788 | 
789 |   std::unique_ptr<LogSoftmaxOptions> MakeLogSoftmaxOptions(
790 |       const tflite::Operator* op);
791 | 
792 |   std::unique_ptr<CastOptions> MakeCastOptions(const tflite::Operator* op);
793 | 
794 |   std::unique_ptr<DequantizeOptions> MakeDequantizeOptions(const tflite::Operator* op);
795 | 
796 | #ifdef NEWER_TENSORFLOW
797 |   std::unique_ptr<MaximumMinimumOptions> MakeMaximumMinimumOptions(
798 |       const tflite::Operator* op);
799 | 
800 |   std::unique_ptr<ArgMaxOptions> MakeArgMaxOptions(const tflite::Operator* op);
801 | 
802 |   std::unique_ptr<LessOptions> MakeLessOptions(const tflite::Operator* op);
803 | 
804 |   std::unique_ptr<NegOptions> MakeNegOptions(const tflite::Operator* op);
805 | #else
806 |   std::unique_ptr<MaximumOptions> MakeMaximumOptions(
807 |       const tflite::Operator* op);
808 | #endif
809 |   Padding ConvertPadding(tflite::Padding padding);
810 | 
811 |   ActivationFunctionType ConvertActivationFunction(
812 |       tflite::ActivationFunctionType fn_activation_type);
813 | 
814 |   FlatBufferModel flat_buffers_;
815 |   const tflite::Model *fb_model_;
816 |   std::vector<Buffer> buffers_;
817 |   std::vector<OperatorCode> operators_code_;
818 |   Graph graph_;
819 | };
820 | 
821 | template <class T, class Ptr>
822 | std::vector<T> AssignVector(Ptr ptr) {
823 |   std::vector<T> vec;
824 | 
825 |   if (!ptr) {
826 |     return vec;
827 |   }
828 | 
829 |   for (auto it = ptr->begin(); it != ptr->end(); ++it) {
830 |     vec.push_back(*it);
831 |   }
832 | 
833 |   return vec;
834 | }
835 | 
836 | }
837 | 
838 | #endif
839 | 
--------------------------------------------------------------------------------
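The classes above are the in-memory form of a parsed .tflite file: a Model owns a Graph, the Graph owns its Tensor and Operator lists plus the graph-level input/output tensor indices, and each Operator carries its decoded BuiltinOptions together with the indices of the tensors it reads and writes. A minimal traversal sketch follows; it assumes the classes live in a namespace named nnt (suggested by the project name but not visible in this part of the header) and that the header is included as "src/model.h", so treat both as placeholders rather than the project's actual conventions.

#include <iostream>
#include <string>

#include "src/model.h"  // assumed include path

// Walk a parsed model and print each operator with its input tensors.
// Sketch only: the nnt namespace and the include path are assumptions.
void DumpOperators(const std::string& tflite_file) {
  nnt::Model model(tflite_file);
  nnt::Graph& graph = model.graph();

  for (const auto& op : graph.Operators()) {
    std::cout << op.index() << ": " << op.builtin_op_str() << "\n";
    for (int tensor_idx : op.inputs()) {
      const nnt::Tensor& t = graph.Tensors()[tensor_idx];
      std::cout << "  in: " << t.name()
                << " (rank " << t.shape().size() << ")\n";
    }
  }
}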
/src/templates/jni.tpl:
--------------------------------------------------------------------------------
1 | "#include <jni.h>\n\
2 | #include <string>\n\
3 | #include \"nn.h\"\n\
4 | \n\
5 | jint throwException(JNIEnv *env, std::string message) {\n\
6 |   jclass exClass;\n\
7 |   std::string className = \"java/lang/RuntimeException\";\n\
8 | \n\
9 |   exClass = env->FindClass(className.c_str());\n\
10 | \n\
11 |   return env->ThrowNew(exClass, message.c_str());\n\
12 | }\n\
13 | \n\
14 | extern \"C\"\n\
15 | JNIEXPORT void\n\
16 | JNICALL\n\
17 | @JAVA_PACKAGE_readFile(\n\
18 |     JNIEnv *env,\n\
19 |     jobject /* this */,\n\
20 |     jstring params_file,\n\
21 |     jint preference) {\n\
22 |   std::string filename = std::string(env->GetStringUTFChars(params_file,\n\
23 |       nullptr));\n\
24 | \n\
25 |   if (!nnc::OpenTrainingData(filename.c_str())) {\n\
26 |     throwException(env, \"Error on open file: \" + filename);\n\
27 |     return;\n\
28 |   }\n\
29 | \n\
30 |   if (!nnc::CreateModel()) {\n\
31 |     throwException(env, \"Error on create nnapi model\");\n\
32 |     return;\n\
33 |   }\n\
34 | \n\
35 |   if (!nnc::Compile(preference)) {\n\
36 |     throwException(env, \"Error on compile nnapi model\");\n\
37 |     return;\n\
38 |   }\n\
39 | \n\
40 |   if (!nnc::BuildModel()) {\n\
41 |     throwException(env, \"Error on build model\");\n\
42 |     return;\n\
43 |   }\n\
44 | }\n\
45 | \n\
46 | extern \"C\"\n\
47 | JNIEXPORT void\n\
48 | JNICALL\n\
49 | @JAVA_PACKAGE_cleanup(\n\
50 |     JNIEnv *env,\n\
51 |     jobject /* this */) {\n\
52 |   nnc::Cleanup();\n\
53 | }\n\
54 | \n\
55 | extern \"C\"\n\
56 | JNIEXPORT void\n\
57 | JNICALL\n\
58 | @JAVA_PACKAGE_execute(\n\
59 |     JNIEnv *env,\n\
60 |     jobject /* this */) {\n\
61 |   if (!nnc::Execute()) {\n\
62 |     throwException(env, \"Error on execute model\");\n\
63 |     return;\n\
64 |   }\n\
65 | }\n\
66 | \n\
67 | extern \"C\"\n\
68 | JNIEXPORT void\n\
69 | JNICALL\n\
70 | @JAVA_PACKAGE_setInput(\n\
71 |     JNIEnv *env,\n\
72 |     jobject /* this */,\n\
73 |     jbyteArray input_data) {\n\
74 |   jsize input_len = env->GetArrayLength(input_data);\n\
75 | \n\
76 |   if (input_len != @TOTAL_INPUT_SIZE) {\n\
77 |     throwException(env, \"Input data has wrong length\");\n\
78 |     return;\n\
79 |   }\n\
80 | \n\
81 |   jbyte *bytes = env->GetByteArrayElements(input_data, 0);\n\
82 | \n\
83 |   if (bytes == NULL) {\n\
84 |     throwException(env, \"Error getting elements from java array input data\");\n\
85 |     return;\n\
86 |   }\n\
87 | \n\
88 |   if (!nnc::SetInput(bytes)) {\n\
89 |     env->ReleaseByteArrayElements(input_data, bytes, JNI_ABORT);\n\
90 |     throwException(env, \"Error on set input\");\n\
91 |     return;\n\
92 |   }\n\
93 | \n\
94 |   env->ReleaseByteArrayElements(input_data, bytes, 0);\n\
95 | }\n\
96 | \n\
97 | extern \"C\"\n\
98 | JNIEXPORT jbyteArray\n\
99 | JNICALL\n\
100 | @JAVA_PACKAGE_getOutput(\n\
101 |     JNIEnv *env,\n\
102 |     jobject /* this */) {\n\
103 |   jbyteArray result;\n\
104 |   result = env->NewByteArray(@TOTAL_OUTPUT_SIZE);\n\
105 |   if (result == NULL) {\n\
106 |     throwException(env, \"out of memory\");\n\
107 |     return NULL; /* out of memory error thrown */\n\
108 |   }\n\
109 | \n\
110 |   jbyte data[@TOTAL_OUTPUT_SIZE];\n\
111 |   if (!nnc::SetOutput(data)) {\n\
112 |     throwException(env, \"Error on get output\");\n\
113 |     return NULL;\n\
114 |   }\n\
115 | \n\
116 |   env->SetByteArrayRegion(result, 0, @TOTAL_OUTPUT_SIZE, data);\n\
117 |   return result;\n\
118 | }\n\
119 | "
120 | 
--------------------------------------------------------------------------------
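In the template above, @JAVA_PACKAGE_, @TOTAL_INPUT_SIZE and @TOTAL_OUTPUT_SIZE are placeholders filled in by the generator; @JAVA_PACKAGE_ has to expand to the JNI-mangled name of the Java class that declares the corresponding native methods. For a hypothetical Java class com.example.nn.Model (an invented name used only to illustrate the standard JNI mangling, dots becoming underscores), the readFile entry point would come out as:

// Hypothetical expansion of "@JAVA_PACKAGE_readFile(" for com.example.nn.Model.
extern "C"
JNIEXPORT void
JNICALL
Java_com_example_nn_Model_readFile(
    JNIEnv *env,
    jobject /* this */,
    jstring params_file,
    jint preference);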
/src/templates/top_nn_cc.tpl:
--------------------------------------------------------------------------------
1 | "#include <android/NeuralNetworks.h>\n\
2 | #include <android/log.h>\n\
3 | #include <fcntl.h>\n\
4 | #include <cstdint>\n\
5 | #include <string>\n\
6 | #include <sys/mman.h>\n\
7 | #include <sys/stat.h>\n\
8 | #include <unistd.h>\n\
9 | \n\
10 | #include \"nn.h\"\n\
11 | \n\
12 | #define LOG_TAG \"NNC\"\n\
13 | \n\
14 | namespace nnc {\n\
15 | \n\
16 | static ANeuralNetworksMemory* mem = NULL;\n\
17 | static int fd;\n\
18 | static ANeuralNetworksModel* model = NULL;\n\
19 | static ANeuralNetworksCompilation* compilation = NULL;\n\
20 | static ANeuralNetworksExecution* run = NULL;\n\
21 | \n\
22 | bool OpenTrainingData(const char* file_name) {\n\
23 |   fd = open(file_name, O_RDONLY);\n\
24 | \n\
25 |   if (fd < 0) {\n\
26 |     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,\n\
27 |                         \"open failed\");\n\
28 |     return false;\n\
29 |   }\n\
30 | \n\
31 |   struct stat sb;\n\
32 |   fstat(fd, &sb);\n\
33 |   size_t buffer_size_bytes = sb.st_size;\n\
34 |   int status = ANeuralNetworksMemory_createFromFd(buffer_size_bytes, PROT_READ, fd, 0, &mem);\n\
35 |   if (status != ANEURALNETWORKS_NO_ERROR) {\n\
36 |     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,\n\
37 |                         \"ANeuralNetworksMemory_createFromFd failed\");\n\
38 |     return false;\n\
39 |   }\n\
40 | \n\
41 |   return true;\n\
42 | }\n\
43 | \n\
44 | bool CreateModel() {\n\
45 |   int status = ANeuralNetworksModel_create(&model);\n\
46 | \n\
47 |   if (status != ANEURALNETWORKS_NO_ERROR) {\n\
48 |     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,\n\
49 |                         \"ANeuralNetworksModel_create failed\");\n\
50 |     return false;\n\
51 |   }\n\
52 | \n\
53 |   return true;\n\
54 | }\n\
55 | \n\
56 | bool Compile(int32_t preference) {\n\
57 |   int status = ANeuralNetworksCompilation_create(model, &compilation);\n\
58 |   if (status != ANEURALNETWORKS_NO_ERROR) {\n\
59 |     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,\n\
60 |                         \"ANeuralNetworksCompilation_create failed\");\n\
61 |     return false;\n\
62 |   }\n\
63 | \n\
64 |   status = ANeuralNetworksCompilation_setPreference(compilation, preference);\n\
65 |   if (status != ANEURALNETWORKS_NO_ERROR) {\n\
66 |     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,\n\
67 |                         \"ANeuralNetworksCompilation_setPreference failed\");\n\
68 |     return false;\n\
69 |   }\n\
70 | \n\
71 |   return true;\n\
72 | }\n\
73 | \n\
74 | \n\
75 | bool Execute() {\n\
76 |   int status = ANeuralNetworksExecution_create(compilation, &run);\n\
77 | \n\
78 |   if (status != ANEURALNETWORKS_NO_ERROR) {\n\
79 |     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,\n\
80 |                         \"ANeuralNetworksExecution_create failed\");\n\
81 |     return false;\n\
82 |   }\n\
83 | \n\
84 |   ANeuralNetworksEvent* run_end = NULL;\n\
85 |   ANeuralNetworksExecution_startCompute(run, &run_end);\n\
86 |   ANeuralNetworksEvent_wait(run_end);\n\
87 |   ANeuralNetworksEvent_free(run_end);\n\
88 |   ANeuralNetworksExecution_free(run);\n\
89 |   return true;\n\
90 | }\n\
91 | \n\
92 | void Cleanup() {\n\
93 |   ANeuralNetworksCompilation_free(compilation);\n\
94 |   ANeuralNetworksModel_free(model);\n\
95 |   ANeuralNetworksMemory_free(mem);\n\
96 | }\n\
97 | \n\
98 | #define CHECK_ADD_SCALAR(x) \\\n\
99 |   if (!x) { \\\n\
100 |     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, \\\n\
101 |                         \"AddScalar Failed\"); \\\n\
102 |     return false; \\\n\
103 |   }\n\
104 | \n\
105 | bool AddScalarInt32(int32_t id, int value) {\n\
106 |   ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_INT32};\n\
107 | \n\
108 |   int status = ANeuralNetworksModel_addOperand(model, &operand_type);\n\
109 |   if (status != ANEURALNETWORKS_NO_ERROR) {\n\
110 |     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,\n\
111 |                         \"ANeuralNetworksModel_addOperand failed\");\n\
112 |     return false;\n\
113 |   }\n\
114 | \n\
115 |   status = ANeuralNetworksModel_setOperandValue(model, id, &value, sizeof(int32_t));\n\
116 |   if (status != ANEURALNETWORKS_NO_ERROR) {\n\
117 |     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,\n\
118 |                         \"ANeuralNetworksModel_setOperandValue failed\");\n\
119 |     return false;\n\
120 |   }\n\
121 | \n\
122 |   return true;\n\
123 | }\n\
124 | \n\
125 | bool AddScalarFloat32(int32_t id, float value) {\n\
126 |   ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_FLOAT32};\n\
127 | \n\
128 |   int status = ANeuralNetworksModel_addOperand(model, &operand_type);\n\
129 |   if (status != ANEURALNETWORKS_NO_ERROR) {\n\
130 |     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,\n\
131 |                         \"ANeuralNetworksModel_addOperand failed\");\n\
132 |     return false;\n\
133 |   }\n\
134 | \n\
135 |   status = ANeuralNetworksModel_setOperandValue(model, id, &value, sizeof(float));\n\
136 |   if (status != ANEURALNETWORKS_NO_ERROR) {\n\
137 |     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,\n\
138 |                         \"ANeuralNetworksModel_setOperandValue failed\");\n\
139 |     return false;\n\
140 |   }\n\
141 | \n\
142 |   return true;\n\
143 | }\n\
144 | \n\
145 | bool BuildModel() {\n\
146 |   int tensor_size = 0;\n\
147 |   int offset = 0;\n\
148 |   int status;\n\
149 | ";
150 | 
--------------------------------------------------------------------------------
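The string above deliberately stops in the middle of BuildModel(): the generator (src/cpp-gen.cc) is expected to append per-operand and per-operation code that works against the statics declared at the top (model, mem) and the tensor_size/offset/status locals, and NNAPI additionally requires ANeuralNetworksModel_finish() on the assembled model before it can be compiled. The helper below is a sketch, in the same style as AddScalarInt32, of how a tensor operand whose weights live in the mmapped parameters file could be added; the function name, its parameters and the idea that the generator emits calls like it are assumptions for illustration, not part of the real template.

// Hypothetical helper: add a FLOAT32 tensor operand whose constant data is a
// slice of the ANeuralNetworksMemory mapped in OpenTrainingData(). Assumes
// the nnc statics `model`, `mem` and the LOG_TAG macro from the template.
bool AddTensorFloat32FromMemory(const uint32_t* dimensions,
                                uint32_t dimension_count, int32_t id,
                                size_t offset, size_t byte_size) {
  ANeuralNetworksOperandType operand_type{
      .type = ANEURALNETWORKS_TENSOR_FLOAT32,
      .dimensionCount = dimension_count,
      .dimensions = dimensions,
      .scale = 0.0f,
      .zeroPoint = 0};

  int status = ANeuralNetworksModel_addOperand(model, &operand_type);
  if (status != ANEURALNETWORKS_NO_ERROR) {
    __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
                        "ANeuralNetworksModel_addOperand failed");
    return false;
  }

  // The weight bytes are read straight out of the mmapped file, so nothing is
  // copied onto the app heap.
  status = ANeuralNetworksModel_setOperandValueFromMemory(model, id, mem,
                                                          offset, byte_size);
  if (status != ANEURALNETWORKS_NO_ERROR) {
    __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
                        "ANeuralNetworksModel_setOperandValueFromMemory failed");
    return false;
  }

  return true;
}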
/src/templates/top_nn_h.tpl:
--------------------------------------------------------------------------------
1 | "namespace nnc {\n\
2 | \n\
3 | bool OpenTrainingData(const char* file_name);\n\
4 | bool CreateModel();\n\
5 | bool Compile(int32_t preference);\n\
6 | bool Execute();\n\
7 | void Cleanup();\n\
8 | bool BuildModel();\n\
9 | bool SetInput(const int8_t *buffer);\n\
10 | bool SetOutput(int8_t *buffer);\n\
11 | "
12 | 
--------------------------------------------------------------------------------
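Everything the JNI wrapper calls goes through this small generated surface, so the same model can also be driven from plain native code. A minimal sketch follows; it assumes nn.h is the header generated from the two templates above and that the caller supplies input and output buffers of the sizes baked into the generated code (both assumptions about the generated project layout, not something this repository ships as-is).

#include <cstdint>

#include <android/NeuralNetworks.h>  // for ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER

#include "nn.h"  // assumed: the header produced from top_nn_h.tpl

// Mirrors the call order used by the JNI template: load parameters, build and
// compile the NNAPI model once, then run a single inference.
bool RunOnce(const char* params_file, const int8_t* input, int8_t* output) {
  if (!nnc::OpenTrainingData(params_file)) return false;
  if (!nnc::CreateModel()) return false;
  if (!nnc::Compile(ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER)) return false;
  if (!nnc::BuildModel()) return false;

  if (!nnc::SetInput(input)) return false;
  if (!nnc::Execute()) return false;
  if (!nnc::SetOutput(output)) return false;

  nnc::Cleanup();
  return true;
}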