├── .gitignore ├── README.md ├── examples └── SineExample │ └── SineExample.ino ├── library.json ├── library.properties └── src ├── .DS_Store ├── eloquent_tensorflow32.h └── exception.h /.gitignore: -------------------------------------------------------------------------------- 1 | publish 2 | .idea 3 | .DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EloquentTensorFlow32 2 | 3 | An Arduino library to run TensorFlow Lite models on ESP32 chips without pain. 4 | 5 | ## How to use 6 | 7 | Once you have exported your TensorFlow Lite model as a C header (for example using `xxd -i`), 8 | running it is as easy as: 9 | 10 | ```cpp 11 | #include <eloquent_tensorflow32.h> 12 | #include "your_tf_model.h" 13 | #define NUM_OPS 1 14 | #define ARENA_SIZE 2000 15 | 16 | using Eloquent::Esp32::TensorFlow; 17 | TensorFlow<NUM_OPS, ARENA_SIZE> tf; 18 | 19 | void setup() { 20 | Serial.begin(115200); 21 | 22 | tf.setNumInputs(1); 23 | tf.setNumOutputs(1); 24 | // add required ops 25 | tf.resolver.AddFullyConnected(); 26 | 27 | // init model 28 | while (!tf.begin(your_tf_model).isOk()) 29 | Serial.println(tf.exception.toString()); 30 | } 31 | 32 | void loop() { 33 | // fill your input vector 34 | float input[1] = {0}; 35 | 36 | while (!tf.predict(input).isOk()) 37 | Serial.println(tf.exception.toString()); 38 | 39 | // one output 40 | Serial.print("One output: "); 41 | Serial.println(tf.result()); 42 | 43 | // many outputs 44 | Serial.print("Many outputs: "); 45 | 46 | for (int i = 0; i < tf.numOutputs; i++) { 47 | Serial.print(tf.result(i)); 48 | Serial.print(", "); 49 | } 50 | 51 | Serial.println(); 52 | } 53 | ``` -------------------------------------------------------------------------------- /examples/SineExample/SineExample.ino: -------------------------------------------------------------------------------- 1 | /** 2 | * Run a TensorFlow NN to predict sin(x) 3 | * For a complete guide, visit
4 | * https://eloquentarduino.com/tensorflow-lite-esp32 5 | */ 6 | #include 7 | // replace with your own model 8 | #include "sine_model.h" 9 | // replace with the correct number of ops 10 | #define NUM_OPS 1 11 | // this is trial-and-error 12 | // when developing a new model, start with a high value 13 | // (e.g. 10000), then decrease until the model stops 14 | // working as expected 15 | #define ARENA_SIZE 2000 16 | 17 | using Eloquent::Esp32::TensorFlow; 18 | 19 | TensorFlow tf; 20 | 21 | /** 22 | * 23 | */ 24 | void setup() { 25 | Serial.begin(115200); 26 | delay(3000); 27 | Serial.println("__TENSORFLOW ESP32 SINE__"); 28 | 29 | // replace with the correct values 30 | tf.setNumInputs(1); 31 | tf.setNumOutputs(1); 32 | // add required ops 33 | tf.resolver.AddFullyConnected(); 34 | 35 | while (!tf.begin(sine_model).isOk()) 36 | Serial.println(tf.exception.toString()); 37 | } 38 | 39 | 40 | void loop() { 41 | float x = random(1000) / 1000 * 3.14; 42 | float input[1] = {x}; 43 | 44 | while (!tf.predict(input).isOk()) 45 | Serial.println(tf.exception.toString()); 46 | 47 | Serial.printf( 48 | "sin(%.2f) = %.2f. 
Predicted %.2f\n", 49 | x, 50 | sin(x), 51 | tf.result(0) 52 | ); 53 | } -------------------------------------------------------------------------------- /library.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "EloquentTensorFlow32", 3 | "keywords": "tinyml", 4 | "description": "Eloquent interface to Tensorflow Lite for Microcontrollers for ESP32 chipset", 5 | "repository": { 6 | "type": "git", 7 | "url": "https://github.com/eloquentarduino/EloquentTensorFlow32" 8 | }, 9 | "version": "1.0.1", 10 | "authors": { 11 | "name": "Simone Salerno", 12 | "url": "https://github.com/eloquentarduino" 13 | }, 14 | "frameworks": "arduino", 15 | "platforms": "*" 16 | } 17 | -------------------------------------------------------------------------------- /library.properties: -------------------------------------------------------------------------------- 1 | name=EloquentTensorFlow32 2 | version=1.0.1 3 | author=Simone Salerno,eloquentarduino@gmail.com 4 | maintainer=Simone Salerno,eloquentarduino@gmail.com 5 | sentence=An eloquent interface to Tensorflow Lite for Microcontrollers for ESP32 chipset 6 | paragraph= 7 | category=Other 8 | url=https://github.com/eloquentarduino/EloquentTensorFlow32 9 | architectures=* 10 | depends=TensorFlowLite_ESP32 -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eloquentarduino/EloquentTensorFlow32/7f0ed8a3a1dcef1cb864f880d5fac83048ecb3f1/src/.DS_Store -------------------------------------------------------------------------------- /src/eloquent_tensorflow32.h: -------------------------------------------------------------------------------- 1 | #ifndef ELOQUENT_TENSORFLOW_32 2 | #define ELOQUENT_TENSORFLOW_32 3 | 4 | #include 5 | #include "tensorflow/lite/micro/micro_mutable_op_resolver.h" 6 | #include 
"tensorflow/lite/micro/micro_error_reporter.h" 7 | #include "tensorflow/lite/micro/micro_interpreter.h" 8 | #include "tensorflow/lite/micro/system_setup.h" 9 | #include "tensorflow/lite/schema/schema_generated.h" 10 | #include "./exception.h" 11 | 12 | using Eloquent::Extra::Exception; 13 | using tflite::Model; 14 | using tflite::ErrorReporter; 15 | using tflite::MicroErrorReporter; 16 | using tflite::MicroMutableOpResolver; 17 | using tflite::MicroInterpreter; 18 | 19 | namespace Eloquent { 20 | namespace Esp32 { 21 | /** 22 | * Run TensorFlow models the Eloquent-style 23 | */ 24 | template 25 | class TensorFlow { 26 | public: 27 | const Model *model; 28 | ErrorReporter *reporter; 29 | MicroMutableOpResolver resolver; 30 | MicroInterpreter *interpreter; 31 | TfLiteTensor *input; 32 | TfLiteTensor *output; 33 | Exception exception; 34 | uint8_t arena[tensorArenaSize]; 35 | uint16_t numInputs; 36 | uint16_t numOutputs; 37 | float *outputs; 38 | 39 | /** 40 | * Constructor 41 | */ 42 | TensorFlow() : 43 | exception("TF"), 44 | reporter(nullptr), 45 | model(nullptr), 46 | interpreter(nullptr), 47 | input(nullptr), 48 | output(nullptr), 49 | numInputs(0), 50 | numOutputs(0), 51 | outputs(NULL) 52 | { 53 | 54 | } 55 | 56 | /** 57 | * 58 | */ 59 | void setNumInputs(uint16_t n) { 60 | numInputs = n; 61 | } 62 | 63 | /** 64 | * 65 | */ 66 | void setNumOutputs(uint16_t n) { 67 | numOutputs = n; 68 | } 69 | 70 | /** 71 | * Get i-th output 72 | */ 73 | float result(uint16_t i = 0) { 74 | if (outputs == NULL || i >= numOutputs) 75 | return sqrt(-1); 76 | 77 | return outputs[i]; 78 | } 79 | 80 | /** 81 | * Init model 82 | */ 83 | Exception& begin(const unsigned char *data) { 84 | if (!numInputs) 85 | return exception.set("You must set the number of inputs"); 86 | 87 | if (!numOutputs) 88 | return exception.set("You must set the number of outputs"); 89 | 90 | model = tflite::GetModel(data); 91 | 92 | if (model->version() != TFLITE_SCHEMA_VERSION) 93 | return 
exception.set(String("Model version mismatch. Expected ") + TFLITE_SCHEMA_VERSION + ", got " + model->version()); 94 | 95 | reporter = new MicroErrorReporter(); 96 | interpreter = new MicroInterpreter(model, resolver, arena, tensorArenaSize, reporter); 97 | 98 | TfLiteStatus status = interpreter->AllocateTensors(); 99 | 100 | if (status != kTfLiteOk) 101 | return exception.set("AllocateTensors() failed"); 102 | 103 | input = interpreter->input(0); 104 | output = interpreter->output(0); 105 | 106 | return exception.clear(); 107 | } 108 | 109 | /** 110 | * 111 | */ 112 | template 113 | Exception& predict(T *x) { 114 | // quantize 115 | float inputScale = input->params.scale; 116 | float inputOffset = input->params.zero_point; 117 | 118 | for (uint16_t i = 0; i < numInputs; i++) 119 | input->data.int8[i] = (x[i] / inputScale) + inputOffset; 120 | 121 | // execute 122 | TfLiteStatus status = interpreter->Invoke(); 123 | 124 | if (status != kTfLiteOk) 125 | return exception.set("Invoke() failed"); 126 | 127 | // allocate outputs 128 | if (outputs == NULL) 129 | outputs = (float*) calloc(numOutputs, sizeof(float)); 130 | 131 | // dequantize 132 | float outputScale = output->params.scale; 133 | float outputOffset = output->params.zero_point; 134 | 135 | for (uint16_t i = 0; i < numOutputs; i++) 136 | outputs[i] = (output->data.int8[0] - outputOffset) * outputScale; 137 | 138 | return exception.clear(); 139 | } 140 | 141 | protected: 142 | 143 | 144 | 145 | }; 146 | } 147 | } 148 | 149 | 150 | #endif -------------------------------------------------------------------------------- /src/exception.h: -------------------------------------------------------------------------------- 1 | #ifndef ELOQUENT_EXTRA_ERROR_EXCEPTION_H 2 | #define ELOQUENT_EXTRA_ERROR_EXCEPTION_H 3 | 4 | namespace Eloquent { 5 | namespace Extra { 6 | /** 7 | * Application expcetion 8 | */ 9 | class Exception { 10 | public: 11 | /** 12 | * 13 | */ 14 | Exception(const char* tag) : 15 | _tag(tag), 16 | 
_message(""), 17 | _isSevere(true) { 18 | } 19 | 20 | /** 21 | * Test if there's an exception 22 | */ 23 | operator bool() const { 24 | return !isOk(); 25 | } 26 | 27 | /** 28 | * Test if there's an exception 29 | */ 30 | bool isOk() const { 31 | return _message == ""; 32 | } 33 | 34 | /** 35 | * Test if exception is severe 36 | */ 37 | bool isSevere() const { 38 | return _isSevere && !isOk(); 39 | } 40 | 41 | /** 42 | * Mark error as not severe 43 | */ 44 | Exception& soft() { 45 | _isSevere = false; 46 | 47 | return *this; 48 | } 49 | 50 | /** 51 | * Set exception message 52 | */ 53 | Exception& set(String error) { 54 | _message = error; 55 | _isSevere = true; 56 | 57 | if (error.length() > 0) { 58 | const char *c_str = error.c_str(); 59 | ESP_LOGE(_tag, "%s", c_str); 60 | } 61 | 62 | return *this; 63 | } 64 | 65 | /** 66 | * Clear exception 67 | */ 68 | Exception& clear() { 69 | return set(""); 70 | } 71 | 72 | /** 73 | * 74 | */ 75 | template 76 | Exception& propagate(Other& other) { 77 | set(other.exception.toString()); 78 | 79 | return *this; 80 | } 81 | 82 | /** 83 | * Convert exception to string 84 | */ 85 | inline String toString() { 86 | return _message; 87 | } 88 | 89 | /** 90 | * Convert exception to char* 91 | */ 92 | inline const char* toCString() { 93 | return toString().c_str(); 94 | } 95 | 96 | /** 97 | * 98 | */ 99 | static Exception none() { 100 | return Exception(""); 101 | } 102 | 103 | protected: 104 | const char* _tag; 105 | bool _isSevere; 106 | String _message; 107 | }; 108 | } 109 | } 110 | 111 | #endif --------------------------------------------------------------------------------