├── .gitignore
├── .travis.yml
├── model
│   └── mnist.tflite
├── readme.md
├── makefile
└── src
    └── tflite_model_parse.cpp
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
flatbuffers
build
include
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
dist: bionic
language: cpp

script:
  - make run
--------------------------------------------------------------------------------
/model/mnist.tflite:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tutorials-with-ci/tflite-model-parse/HEAD/model/mnist.tflite
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
# TF-Lite Model Parse

[![Build Status](https://travis-ci.com/tutorials-with-ci/tflite-model-parse.svg?branch=master)](https://travis-ci.com/tutorials-with-ci/tflite-model-parse)

Parse a TF-Lite model file in C++.

## Steps

Just `make run`!

The detailed steps are in the `makefile`:

1. clone [FlatBuffers](https://github.com/google/flatbuffers) and compile it
2. download the [TF-Lite model schema](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs) file
3. compile the schema into a C++ header file
4. compile & run (sketched below)
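The core of step 4 is loading the raw `.tflite` bytes and walking them through the header that `flatc --gen-mutable` generates. The sketch below is only a condensed variant of `src/tflite_model_parse.cpp`; the `2.0f` scale value is an arbitrary example, and it assumes the first input tensor carries quantization parameters (the bundled MNIST model does).

```cpp
// Minimal sketch: read, inspect, mutate, and re-save a TF-Lite model in place.
#include <fstream>
#include <iostream>
#include <iterator>
#include <vector>
#include "schema_generated.h"  // generated by `flatc -c schema.fbs --gen-mutable`

int main() {
  // Load the raw model bytes into memory.
  std::ifstream infile("./model/mnist.tflite", std::ios::binary);
  std::vector<char> data((std::istreambuf_iterator<char>(infile)),
                         std::istreambuf_iterator<char>());

  // View the buffer through the generated mutable accessors (no copy, no parse step).
  auto &model = *tflite::GetMutableModel(data.data());
  auto &graph = *model.mutable_subgraphs()->GetMutableObject(0);
  auto &tensor = *graph.mutable_tensors()->GetMutableObject(graph.inputs()->Get(0));
  std::cout << "First input tensor: " << tensor.name()->c_str() << "\n";

  // Fixed-size scalars can be mutated in place; 2.0f is just an example value.
  tensor.mutable_quantization()->mutable_scale()->Mutate(0, 2.0f);

  // Write the (now mutated) buffer back out.
  std::ofstream outfile("./build/mnist_mutated.tflite", std::ios::binary);
  outfile.write(data.data(), data.size());
}
```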
--------------------------------------------------------------------------------
/makefile:
--------------------------------------------------------------------------------
all: build/tflite_model_parse

build/tflite_model_parse: src/tflite_model_parse.cpp include/flatbuffers/flatbuffers.h include/schema_generated.h
	g++ -I./include src/tflite_model_parse.cpp -std=c++11 -O2 -o build/tflite_model_parse

run: build/tflite_model_parse
	./build/tflite_model_parse
	diff ./model/mnist.tflite ./build/mnist_mutated.tflite || echo tf-lite model is mutated!

include/:
	mkdir -p include

include/flatbuffers/flatbuffers.h: include/ build/flatc
	mkdir -p include/flatbuffers
	cp flatbuffers/include/flatbuffers/base.h include/flatbuffers
	cp flatbuffers/include/flatbuffers/flatbuffers.h include/flatbuffers
	cp flatbuffers/include/flatbuffers/stl_emulation.h include/flatbuffers

include/schema_generated.h: include/ build/flatc build/schema.fbs
	cd build && ./flatc -c ./schema.fbs --gen-mutable
	mv build/schema_generated.h include

build/flatc:
	[ -d flatbuffers ] || git clone https://github.com/google/flatbuffers.git --depth 1
	mkdir -p build
	cd build && cmake ../flatbuffers && make -j8 flatc

build/schema.fbs:
	mkdir -p build
	wget https://github.com/tensorflow/tensorflow/raw/master/tensorflow/lite/schema/schema.fbs -P build
--------------------------------------------------------------------------------
/src/tflite_model_parse.cpp:
--------------------------------------------------------------------------------
// Copyright [2018]
// Author: SF-Zhou
#include <cstdio>    // C header file for printf
#include <fstream>   // C++ header file for file access
#include <iostream>  // C++ header file for printing
#include <vector>    // C++ header file for std::vector
#include "flatbuffers/flatbuffers.h"
#include "schema_generated.h"

template <typename T>
void PrintVector(const flatbuffers::Vector<T> &values) {
  printf("[");
  for (int i = 0; i < values.Length(); ++i) {
    if (i != 0) {
      printf(", ");
    }
    std::cout << values[i];
  }
  printf("]\n");
}

int main() {
  // Read the whole model file into a byte buffer.
  std::ifstream infile;
  infile.open("./model/mnist.tflite", std::ios::binary | std::ios::in);
  infile.seekg(0, std::ios::end);
  int length = infile.tellg();
  infile.seekg(0, std::ios::beg);
  std::vector<char> data(length);
  infile.read(data.data(), length);
  infile.close();

  // Access the buffer in place through the generated mutable API.
  auto &model = *tflite::GetMutableModel(data.data());
  auto &graph_list = *model.mutable_subgraphs();
  auto &graph = *graph_list.GetMutableObject(0);

  auto &inputs = *graph.inputs();
  printf("Inputs Count: %d\n", inputs.Length());
  auto input_index = inputs[0];
  printf("First Input Tensor Index: %d\n", input_index);

  auto &outputs = *graph.outputs();
  printf("Outputs Count: %d\n", outputs.Length());
  auto output_index = outputs[0];
  printf("First Output Tensor Index: %d\n", output_index);

  auto &tensors = *graph.mutable_tensors();
  printf("Tensors Count: %d\n", tensors.Length());

  auto &input_tensor = *tensors.GetMutableObject(input_index);
  printf("Input Tensor Name: %s\n", input_tensor.name()->c_str());
  auto &input_shape = *input_tensor.shape();
  printf("Input Shape: ");
  PrintVector(input_shape);

  auto &output_tensor = *tensors.GetMutableObject(output_index);
  printf("Output Tensor Name: %s\n", output_tensor.name()->c_str());
  auto &output_shape = *output_tensor.shape();
  printf("Output Shape: ");
  PrintVector(output_shape);

  // Quantization scales are fixed-size floats, so they can be mutated in place.
  auto &input_quantization = *input_tensor.mutable_quantization();
  auto &scales = *input_quantization.mutable_scale();
  printf("Input Quantization Scales: ");
  PrintVector(scales);
  scales.Mutate(0, 2.3333);
  printf("Input Quantization Scales (Mutated): ");
  PrintVector(scales);
  printf("Input Quantization Zero-Points: ");
  PrintVector(*input_quantization.zero_point());

  auto &output_quantization = *output_tensor.mutable_quantization();
  printf("Output Quantization Scales: ");
  PrintVector(*output_quantization.scale());
  printf("Output Quantization Zero-Points: ");
  PrintVector(*output_quantization.zero_point());

  // Write the mutated buffer back out; `make run` diffs it against the original.
  std::ofstream outfile;
  outfile.open("./build/mnist_mutated.tflite", std::ios::binary | std::ios::out);
  outfile.write(data.data(), length);
  outfile.close();
}
--------------------------------------------------------------------------------