├── .github
    └── workflows
    │   ├── config.rb
    │   └── mrbgem.yml
├── .gitignore
├── README.md
├── example
    ├── fizzbuzz
    │   ├── fizzbuzz.rb
    │   ├── fizzbuzz_model.tflite
    │   └── make.py
    └── xor
    │   ├── maketflite.py
    │   ├── makexor.py
    │   ├── xor.rb
    │   ├── xor_model.h5
    │   └── xor_model.tflite
├── mrbgem.rake
├── src
    └── mrb_tflite.c
├── tensorflow.patch
└── test
    ├── tflite_test.rb
    └── xor_model.tflite


/.github/workflows/config.rb:
--------------------------------------------------------------------------------
# mruby build configuration used by CI (mrbgem.yml passes it via MRUBY_CONFIG).
# Builds mruby with tests and debug info, plus this gem, which lives in the
# repository root one level above the cloned mruby checkout.
MRuby::Build.new do |conf|
  toolchain :gcc
  enable_test
  enable_debug

  # Run the test suite under AddressSanitizer and UndefinedBehaviorSanitizer;
  # the flag must reach the C and C++ compilers as well as the linker.
  conf.cc.flags << '-fsanitize=address,undefined'
  conf.cxx.flags << '-fsanitize=address,undefined'
  conf.linker.flags << '-fsanitize=address,undefined'

  conf.gem "#{MRUBY_ROOT}/.."
end

--------------------------------------------------------------------------------
/.github/workflows/mrbgem.yml:
--------------------------------------------------------------------------------
name: mrbgem test

on:
  push: {}
  pull_request:
    branches: [ master ]

jobs:
  build:
    strategy:
      fail-fast: false
      matrix:
        # NOTE(review): GitHub has retired the hosted ubuntu-18.04 runner image,
        # so this job presumably no longer gets scheduled -- confirm and move to
        # a supported image (the apt package names below will need updating too).
        os: [ubuntu-18.04]
        mruby_version: [master, 2.1.2, 1.4.1]
    runs-on: ${{ matrix.os }}
    name: ${{ matrix.os }} & mruby-${{ matrix.mruby_version }}
    env:
      MRUBY_VERSION: ${{ matrix.mruby_version }}
    steps:
      - uses: actions/checkout@v2
      # Toolchain for building TensorFlow Lite from source in mrbgem.rake:
      # Bazel 3.1.0 from Bazel's apt repository, Python packages presumably
      # needed by TensorFlow's Bazel build, and EGL/GLES headers for the GPU
      # delegate library the rake file builds.
      - name: install package
        run: |
          sudo apt install curl gnupg
          curl -fsSL https://bazel.build/bazel-release.pub.gpg | gpg --dearmor > bazel.gpg
          sudo mv bazel.gpg /etc/apt/trusted.gpg.d/
          echo "deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8" | sudo tee /etc/apt/sources.list.d/bazel.list
          sudo apt update && sudo apt install bazel bazel-3.1.0 python-pip python-dev libegl1-mesa-dev libgles2-mesa-dev
          sudo pip install numpy future
      - name: download mruby
        run: git clone --depth 1 -b $MRUBY_VERSION "https://github.com/mruby/mruby.git" mruby
      - name: run test
        # Build mruby (with this gem) using config.rb and run its test suite.
        run: cd mruby && MRUBY_CONFIG="../.github/workflows/config.rb" ./minirake all test

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
gem_*
gem-*
mrb-*.a
src/*.o
*.d
compile_flags.txt

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# mruby-tflite

interface to TensorFlow Lite for mruby

## Usage

```ruby
model = TfLite::Model.from_file "xor_model.tflite"
interpreter = TfLite::Interpreter.new(model)
interpreter.allocate_tensors
input = interpreter.input_tensor(0)
output = interpreter.output_tensor(0)
[[0,0], [1,0], [0,1], [1,1]].each do |x|
  input.data = x
  interpreter.invoke
  puts output.data[0].round
end
```

## Requirements

* TensorFlow Lite

## License

MIT

## Author

Yasuhiro Matsumoto (a.k.a.
mattn) 31 | -------------------------------------------------------------------------------- /example/fizzbuzz/fizzbuzz.rb: -------------------------------------------------------------------------------- 1 | #!mruby 2 | 3 | def bin(n, num_digits) 4 | f = [] 5 | 0.upto(num_digits-1) do |x| 6 | f[x] = (n >> x) & 1 7 | end 8 | return f 9 | end 10 | 11 | def dec(b, n) 12 | b.each_with_index do |x, i| 13 | if x > 0.4 14 | return case i+1 15 | when 1; n.to_s 16 | when 2; 'Fizz' 17 | when 3; 'Buzz' 18 | when 4; 'FizzBuzz' 19 | end 20 | end 21 | end 22 | raise "f*ck" 23 | end 24 | 25 | model = TfLite::Model.from_file 'fizzbuzz_model.tflite' 26 | interpreter = TfLite::Interpreter.new(model) 27 | interpreter.allocate_tensors 28 | input = interpreter.input_tensor(0) 29 | output = interpreter.output_tensor(0) 30 | 1.upto(100) do |x| 31 | input.data = bin(x, 7) 32 | interpreter.invoke 33 | puts dec(output.data, x) 34 | end 35 | -------------------------------------------------------------------------------- /example/fizzbuzz/fizzbuzz_model.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattn/mruby-tflite/385b658cca626ab932642aafe09846f64d58255f/example/fizzbuzz/fizzbuzz_model.tflite -------------------------------------------------------------------------------- /example/fizzbuzz/make.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow.contrib.keras.api.keras.models import Sequential, model_from_json 3 | from tensorflow.contrib.keras.api.keras.layers import Dense, Dropout, Activation 4 | from tensorflow.contrib.keras.api.keras.optimizers import SGD, Adam 5 | import tensorflow.contrib.lite as lite 6 | 7 | 8 | def fizzbuzz(i): 9 | if i % 15 == 0: return np.array([0, 0, 0, 1]) 10 | elif i % 5 == 0: return np.array([0, 0, 1, 0]) 11 | elif i % 3 == 0: return np.array([0, 1, 0, 0]) 12 | else: return np.array([1, 0, 0, 0]) 13 | 14 | 
def bin(i, num_digits): 15 | return np.array([i >> d & 1 for d in range(num_digits)]) 16 | 17 | NUM_DIGITS = 7 18 | trX = np.array([bin(i, NUM_DIGITS) for i in range(1, 101)]) 19 | trY = np.array([fizzbuzz(i) for i in range(1, 101)]) 20 | model = Sequential() 21 | model.add(Dense(64, input_dim = 7)) 22 | model.add(Activation('tanh')) 23 | model.add(Dense(4, input_dim = 64)) 24 | model.add(Activation('softmax')) 25 | model.compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics = ['accuracy']) 26 | model.fit(trX, trY, epochs = 3600, batch_size = 64) 27 | model.save('fizzbuzz_model.h5') 28 | 29 | converter = lite.TFLiteConverter.from_keras_model_file('fizzbuzz_model.h5') 30 | tflite_model = converter.convert() 31 | open('fizzbuzz_model.tflite', 'wb').write(tflite_model) 32 | -------------------------------------------------------------------------------- /example/xor/maketflite.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.lite as lite 3 | 4 | converter = lite.TFLiteConverter.from_keras_model_file("xor_model.h5") 5 | tflite_model = converter.convert() 6 | open("xor_model.tflite", "wb").write(tflite_model) 7 | -------------------------------------------------------------------------------- /example/xor/makexor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from keras.models import Sequential 3 | from keras.layers import Dense, Activation 4 | from keras.optimizers import SGD 5 | 6 | model = Sequential() 7 | model.add(Dense(8, input_dim=2)) 8 | model.add(Activation('tanh')) 9 | model.add(Dense(1)) 10 | model.add(Activation('sigmoid')) 11 | 12 | sgd = SGD(lr=0.1) 13 | model.compile(loss='binary_crossentropy', optimizer=sgd) 14 | X = np.array([[0,0],[0,1],[1,0],[1,1]]) 15 | y = np.array([[0],[1],[1],[0]]) 16 | model.fit(X, y, verbose=True, batch_size=1, epochs=1000) 17 | model.save('xor_model.h5') 
18 | print(model.predict_proba(X)) 19 | -------------------------------------------------------------------------------- /example/xor/xor.rb: -------------------------------------------------------------------------------- 1 | #!mruby 2 | 3 | model = TfLite::Model.from_file "xor_model.tflite" 4 | interpreter = TfLite::Interpreter.new(model) 5 | interpreter.allocate_tensors 6 | input = interpreter.input_tensor(0) 7 | output = interpreter.output_tensor(0) 8 | [[0,0], [1,0], [0,1], [1,1]].each do |x| 9 | input.data = x 10 | interpreter.invoke 11 | puts "#{x[0]} ^ #{x[1]} = #{output.data[0].round}" 12 | end 13 | 14 | -------------------------------------------------------------------------------- /example/xor/xor_model.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattn/mruby-tflite/385b658cca626ab932642aafe09846f64d58255f/example/xor/xor_model.h5 -------------------------------------------------------------------------------- /example/xor/xor_model.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattn/mruby-tflite/385b658cca626ab932642aafe09846f64d58255f/example/xor/xor_model.tflite -------------------------------------------------------------------------------- /mrbgem.rake: -------------------------------------------------------------------------------- 1 | MRuby::Gem::Specification.new('mruby-tflite') do |spec| 2 | spec.license = 'MIT' 3 | spec.authors = 'mattn' 4 | spec.version = '2.3.0' 5 | 6 | add_test_dependency 'mruby-env' 7 | ENV['MRB_TFLITE_XORMODEL'] = "#{dir}/test/xor_model.tflite" 8 | 9 | if ENV['TENSORFLOW_ROOT'] 10 | spec.cc.include_paths << ENV['TENSORFLOW_ROOT'] 11 | spec.linker.library_paths << ENV['TENSORFLOW_ROOT'] + "tensorflow/lite/experimental/c/" 12 | spec.linker.libraries << 'tensorflowlite_c' 13 | else 14 | header = "#{build_dir}/tensorflow/tensorflow/lite/c/c_api.h" 15 | file header => __FILE__ 
do 16 | FileUtils.mkdir_p build_dir 17 | Dir.chdir build_dir do 18 | unless Dir.exists? 'tensorflow' 19 | sh "git clone https://github.com/tensorflow/tensorflow.git --depth 1 -b v#{version}" 20 | sh "cd tensorflow; patch -p1 -i #{dir}/tensorflow.patch" 21 | end 22 | Dir.chdir 'tensorflow' do 23 | sh 'bazel build --define tflite_with_xnnpack=true ' \ 24 | '//tensorflow/lite:libtensorflowlite.so ' \ 25 | '//tensorflow/lite/delegates/gpu:libtensorflowlite_gpu_delegate.so ' \ 26 | '//tensorflow/lite/c:libtensorflowlite_c.so' 27 | end 28 | end 29 | end 30 | file "#{dir}/src/mrb_tflite.c" => header 31 | cc.include_paths << "#{build_dir}/tensorflow" 32 | lib_paths = [ 33 | "#{build_dir}/tensorflow/bazel-bin/tensorflow/lite", 34 | "#{build_dir}/tensorflow/bazel-bin/tensorflow/lite/delegates/gpu", 35 | "#{build_dir}/tensorflow/bazel-bin/tensorflow/lite/c", 36 | ] 37 | linker.library_paths += lib_paths 38 | ENV['LD_LIBRARY_PATH'] = "#{ENV['LD_LIBRARY_PATH']}:#{lib_paths.join(':')}" 39 | linker.libraries << 'tensorflowlite' << 'tensorflowlite_gpu_delegate' << 'tensorflowlite_c' 40 | end 41 | end 42 | -------------------------------------------------------------------------------- /src/mrb_tflite.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #if 1 16 | #define ARENA_SAVE \ 17 | int ai = mrb_gc_arena_save(mrb); 18 | #define ARENA_RESTORE \ 19 | mrb_gc_arena_restore(mrb, ai); 20 | #else 21 | #define ARENA_SAVE 22 | #define ARENA_RESTORE 23 | #endif 24 | 25 | static const char* 26 | tensor_type_name(TfLiteType type) { 27 | switch (type) { 28 | case kTfLiteNoType: 29 | return "none"; 30 | case kTfLiteFloat32: 31 | return "float32"; 32 | case kTfLiteInt32: 33 | return "int32"; 34 | case kTfLiteUInt8: 35 | return "uint8"; 36 | case kTfLiteInt64: 37 | return "int64"; 38 | 
case kTfLiteString: 39 | return "string"; 40 | case kTfLiteBool: 41 | return "bool"; 42 | case kTfLiteInt16: 43 | return "int16"; 44 | case kTfLiteComplex64: 45 | return "complex64"; 46 | case kTfLiteInt8: 47 | return "int8"; 48 | default: 49 | return "unknown"; 50 | } 51 | } 52 | 53 | static void 54 | mrb_tflite_model_free(mrb_state *mrb, void *p) { 55 | TfLiteModelDelete((TfLiteModel*)p); 56 | } 57 | 58 | static void 59 | mrb_tflite_interpreter_options_free(mrb_state *mrb, void *p) { 60 | TfLiteInterpreterOptionsDelete((TfLiteInterpreterOptions*)p); 61 | } 62 | 63 | static void 64 | mrb_tflite_interpreter_free(mrb_state *mrb, void *p) { 65 | TfLiteInterpreterDelete((TfLiteInterpreter*)p); 66 | } 67 | 68 | static const struct mrb_data_type mrb_tflite_tensor_type_ = { 69 | "mrb_tflite_tensor", NULL, 70 | }; 71 | 72 | static const struct mrb_data_type mrb_tflite_model_type = { 73 | "mrb_tflite_model", mrb_tflite_model_free, 74 | }; 75 | 76 | static const struct mrb_data_type mrb_tflite_interpreter_options_type = { 77 | "mrb_tflite_interpreter_options", mrb_tflite_interpreter_options_free 78 | }; 79 | 80 | static const struct mrb_data_type mrb_tflite_interpreter_type = { 81 | "mrb_tflite_interpreter", mrb_tflite_interpreter_free 82 | }; 83 | 84 | static mrb_value 85 | mrb_tflite_model_init(mrb_state *mrb, mrb_value self) { 86 | TfLiteModel* model; 87 | mrb_value str; 88 | mrb_get_args(mrb, "S", &str); 89 | model = TfLiteModelCreate(RSTRING_PTR(str), RSTRING_LEN(str)); 90 | if (model == NULL) { 91 | mrb_raise(mrb, E_RUNTIME_ERROR, "cannot create model"); 92 | } 93 | DATA_TYPE(self) = &mrb_tflite_model_type; 94 | DATA_PTR(self) = model; 95 | return self; 96 | } 97 | 98 | static mrb_value 99 | mrb_tflite_model_from_file(mrb_state *mrb, mrb_value self) { 100 | TfLiteModel* model; 101 | mrb_value str; 102 | struct RClass* _class_tflite_model; 103 | 104 | mrb_get_args(mrb, "S", &str); 105 | model = TfLiteModelCreateFromFile(RSTRING_PTR(str)); 106 | if (model == NULL) { 107 
| mrb_raise(mrb, E_RUNTIME_ERROR, "cannot create model"); 108 | } 109 | _class_tflite_model = mrb_class_get_under(mrb, mrb_module_get(mrb, "TfLite"), "Model"); 110 | return mrb_obj_value(Data_Wrap_Struct(mrb, (struct RClass*) _class_tflite_model, 111 | &mrb_tflite_model_type, (void*) model)); 112 | } 113 | 114 | static mrb_value 115 | mrb_tflite_interpreter_options_init(mrb_state *mrb, mrb_value self) { 116 | TfLiteInterpreterOptions* interpreter_options; 117 | 118 | interpreter_options = TfLiteInterpreterOptionsCreate(); 119 | if (interpreter_options == NULL) { 120 | mrb_raise(mrb, E_RUNTIME_ERROR, "cannot create interpreter options"); 121 | } 122 | DATA_TYPE(self) = &mrb_tflite_interpreter_options_type; 123 | DATA_PTR(self) = interpreter_options; 124 | return self; 125 | } 126 | 127 | static mrb_value 128 | mrb_tflite_interpreter_options_num_threads_set(mrb_state *mrb, mrb_value self) { 129 | TfLiteInterpreterOptions* interpreter_options; 130 | int num_threads = 0; 131 | mrb_get_args(mrb, "i", &num_threads); 132 | interpreter_options = DATA_PTR(self); 133 | TfLiteInterpreterOptionsSetNumThreads(interpreter_options, num_threads); 134 | return mrb_nil_value(); 135 | } 136 | 137 | static mrb_value 138 | mrb_tflite_interpreter_options_add_delegate(mrb_state *mrb, mrb_value self) { 139 | TfLiteInterpreterOptions* interpreter_options; 140 | mrb_value delegate; 141 | mrb_get_args(mrb, "o", &delegate); 142 | interpreter_options = DATA_PTR(self); 143 | TfLiteInterpreterOptionsAddDelegate(interpreter_options, DATA_PTR(delegate)); 144 | return mrb_nil_value(); 145 | } 146 | 147 | static mrb_value 148 | mrb_tflite_interpreter_init(mrb_state *mrb, mrb_value self) { 149 | TfLiteInterpreter* interpreter; 150 | TfLiteInterpreterOptions* interpreter_options = NULL; 151 | mrb_value arg_model; 152 | mrb_value arg_options = mrb_nil_value(); 153 | 154 | mrb_get_args(mrb, "o|o", &arg_model, &arg_options); 155 | if (mrb_nil_p(arg_model) || DATA_TYPE(arg_model) != 
&mrb_tflite_model_type) { 156 | mrb_raise(mrb, E_ARGUMENT_ERROR, "invalid argument"); 157 | } 158 | if (!mrb_nil_p(arg_options) && DATA_TYPE(arg_options) == &mrb_tflite_interpreter_options_type) { 159 | interpreter_options = DATA_PTR(arg_options); 160 | } 161 | interpreter = TfLiteInterpreterCreate((TfLiteModel*) DATA_PTR(arg_model), interpreter_options); 162 | if (interpreter == NULL) { 163 | mrb_raise(mrb, E_RUNTIME_ERROR, "cannot create interpreter"); 164 | } 165 | DATA_TYPE(self) = &mrb_tflite_interpreter_type; 166 | DATA_PTR(self) = interpreter; 167 | return self; 168 | } 169 | 170 | static mrb_value 171 | mrb_tflite_interpreter_allocate_tensors(mrb_state *mrb, mrb_value self) { 172 | TfLiteInterpreter* interpreter = DATA_PTR(self); 173 | if (TfLiteInterpreterAllocateTensors(interpreter) != kTfLiteOk) { 174 | mrb_raise(mrb, E_RUNTIME_ERROR, "cannot allocate tensors"); 175 | } 176 | return mrb_nil_value(); 177 | } 178 | 179 | static mrb_value 180 | mrb_tflite_interpreter_invoke(mrb_state *mrb, mrb_value self) { 181 | TfLiteInterpreter* interpreter = DATA_PTR(self); 182 | if (TfLiteInterpreterInvoke(interpreter) != kTfLiteOk) { 183 | mrb_raise(mrb, E_RUNTIME_ERROR, "cannot invoke"); 184 | } 185 | return mrb_nil_value(); 186 | } 187 | 188 | static mrb_value 189 | mrb_tflite_interpreter_input_tensor_count(mrb_state *mrb, mrb_value self) { 190 | TfLiteInterpreter* interpreter = DATA_PTR(self); 191 | return mrb_fixnum_value(TfLiteInterpreterGetInputTensorCount(interpreter)); 192 | } 193 | 194 | static mrb_value 195 | mrb_tflite_interpreter_output_tensor_count(mrb_state *mrb, mrb_value self) { 196 | TfLiteInterpreter* interpreter = DATA_PTR(self); 197 | return mrb_fixnum_value(TfLiteInterpreterGetOutputTensorCount(interpreter)); 198 | } 199 | 200 | static mrb_value 201 | mrb_tflite_interpreter_input_tensor(mrb_state *mrb, mrb_value self) { 202 | TfLiteTensor* tensor; 203 | mrb_int index; 204 | struct RClass* _class_tflite_tensor; 205 | mrb_value c; 206 | 
TfLiteInterpreter* interpreter = DATA_PTR(self); 207 | mrb_get_args(mrb, "i", &index); 208 | tensor = TfLiteInterpreterGetInputTensor(interpreter, index); 209 | if (tensor == NULL) { 210 | mrb_raise(mrb, E_ARGUMENT_ERROR, "invalid argument"); 211 | } 212 | _class_tflite_tensor = mrb_class_get_under(mrb, mrb_module_get(mrb, "TfLite"), "Tensor"); 213 | c = mrb_obj_new(mrb, _class_tflite_tensor, 0, NULL); 214 | DATA_TYPE(c) = &mrb_tflite_tensor_type_; 215 | DATA_PTR(c) = tensor; 216 | return c; 217 | } 218 | 219 | static mrb_value 220 | mrb_tflite_interpreter_output_tensor(mrb_state *mrb, mrb_value self) { 221 | TfLiteTensor* tensor; 222 | mrb_int index; 223 | struct RClass* _class_tflite_tensor; 224 | mrb_value c; 225 | TfLiteInterpreter* interpreter = DATA_PTR(self); 226 | mrb_get_args(mrb, "i", &index); 227 | tensor = (TfLiteTensor*) TfLiteInterpreterGetOutputTensor(interpreter, index); 228 | _class_tflite_tensor = mrb_class_get_under(mrb, mrb_module_get(mrb, "TfLite"), "Tensor"); 229 | c = mrb_obj_new(mrb, _class_tflite_tensor, 0, NULL); 230 | DATA_TYPE(c) = &mrb_tflite_tensor_type_; 231 | DATA_PTR(c) = tensor; 232 | return c; 233 | } 234 | 235 | static mrb_value 236 | mrb_tflite_tensor_type(mrb_state *mrb, mrb_value self) { 237 | TfLiteTensor* tensor = DATA_PTR(self); 238 | return mrb_fixnum_value(TfLiteTensorType(tensor)); 239 | } 240 | 241 | static mrb_value 242 | mrb_tflite_tensor_name(mrb_state *mrb, mrb_value self) { 243 | TfLiteTensor* tensor = DATA_PTR(self); 244 | return mrb_str_new_cstr(mrb, TfLiteTensorName(tensor)); 245 | } 246 | 247 | static mrb_value 248 | mrb_tflite_tensor_num_dims(mrb_state *mrb, mrb_value self) { 249 | TfLiteTensor* tensor = DATA_PTR(self); 250 | return mrb_fixnum_value(TfLiteTensorNumDims(tensor)); 251 | } 252 | 253 | static mrb_value 254 | mrb_tflite_tensor_dim(mrb_state *mrb, mrb_value self) { 255 | TfLiteTensor* tensor = DATA_PTR(self); 256 | mrb_int index; 257 | mrb_get_args(mrb, "i", &index); 258 | return 
mrb_fixnum_value(TfLiteTensorDim(tensor, index)); 259 | } 260 | 261 | static mrb_value 262 | mrb_tflite_tensor_byte_size(mrb_state *mrb, mrb_value self) { 263 | TfLiteTensor* tensor = DATA_PTR(self); 264 | return mrb_fixnum_value(TfLiteTensorByteSize(tensor)); 265 | } 266 | 267 | static mrb_value 268 | mrb_tflite_tensor_data_get(mrb_state *mrb, mrb_value self) { 269 | TfLiteTensor* tensor = DATA_PTR(self); 270 | int ai, i; 271 | mrb_value ret; 272 | int len; 273 | TfLiteType type; 274 | uint8_t *uint8s; 275 | float *float32s; 276 | 277 | type = TfLiteTensorType(tensor); 278 | switch (type) { 279 | case kTfLiteUInt8: 280 | case kTfLiteInt8: 281 | len = TfLiteTensorByteSize(tensor); 282 | uint8s = (uint8_t*) TfLiteTensorData(tensor); 283 | ret = mrb_ary_new_capa(mrb, len); 284 | ai = mrb_gc_arena_save(mrb); 285 | for (i = 0; i < len; i++) { 286 | mrb_ary_push(mrb, ret, mrb_fixnum_value(uint8s[i])); 287 | mrb_gc_arena_restore(mrb, ai); 288 | } 289 | break; 290 | case kTfLiteFloat32: 291 | len = TfLiteTensorByteSize(tensor) / 4; 292 | float32s = (float*) TfLiteTensorData(tensor); 293 | ret = mrb_ary_new_capa(mrb, len); 294 | ai = mrb_gc_arena_save(mrb); 295 | for (i = 0; i < len; i++) { 296 | mrb_ary_push(mrb, ret, mrb_float_value(mrb, float32s[i])); 297 | mrb_gc_arena_restore(mrb, ai); 298 | } 299 | break; 300 | default: 301 | mrb_raisef(mrb, E_RUNTIME_ERROR, "tensor type %S not supported", mrb_str_new_cstr(mrb, tensor_type_name(type))); 302 | } 303 | MRB_SET_FROZEN_FLAG(mrb_basic_ptr(ret)); 304 | return ret; 305 | } 306 | 307 | static mrb_value 308 | mrb_tflite_tensor_data_set(mrb_state *mrb, mrb_value self) { 309 | TfLiteTensor* tensor = DATA_PTR(self); 310 | int i; 311 | int len, ary_len; 312 | TfLiteType type; 313 | uint8_t *uint8s; 314 | float *float32s; 315 | mrb_value arg_data; 316 | 317 | mrb_get_args(mrb, "o", &arg_data); 318 | if (mrb_nil_p(arg_data) || mrb_type(arg_data) != MRB_TT_ARRAY) { 319 | mrb_raise(mrb, E_ARGUMENT_ERROR, "argument must be array"); 
320 | } 321 | ary_len = RARRAY_LEN(arg_data); 322 | 323 | type = TfLiteTensorType(tensor); 324 | switch (type) { 325 | case kTfLiteUInt8: 326 | case kTfLiteInt8: 327 | len = TfLiteTensorByteSize(tensor); 328 | if (ary_len != len) { 329 | mrb_raise(mrb, E_ARGUMENT_ERROR, "argument size mismatched"); 330 | } 331 | uint8s = (uint8_t*) TfLiteTensorData(tensor); 332 | for (i = 0; i < len; i++) { 333 | uint8s[i] = (uint8_t) mrb_fixnum(mrb_ary_entry(arg_data, i)); 334 | } 335 | break; 336 | case kTfLiteFloat32: 337 | len = TfLiteTensorByteSize(tensor) / 4; 338 | if (ary_len != len) { 339 | mrb_raise(mrb, E_ARGUMENT_ERROR, "argument size mismatched"); 340 | } 341 | float32s = (float*) TfLiteTensorData(tensor); 342 | for (i = 0; i < len; i++) { 343 | float32s[i] = mrb_as_float(mrb, mrb_ary_entry(arg_data, i)); 344 | } 345 | break; 346 | default: 347 | mrb_raisef(mrb, E_RUNTIME_ERROR, "tensor type %S not supported", mrb_str_new_cstr(mrb, tensor_type_name(type))); 348 | } 349 | return mrb_nil_value(); 350 | } 351 | 352 | void 353 | mrb_mruby_tflite_gem_init(mrb_state* mrb) { 354 | struct RClass *_class_tflite; 355 | struct RClass *_class_tflite_model; 356 | struct RClass *_class_tflite_interpreter; 357 | struct RClass *_class_tflite_interpreter_options; 358 | struct RClass *_class_tflite_tensor; 359 | ARENA_SAVE; 360 | 361 | _class_tflite = mrb_define_module(mrb, "TfLite"); 362 | 363 | _class_tflite_model = mrb_define_class_under(mrb, _class_tflite, "Model", mrb->object_class); 364 | MRB_SET_INSTANCE_TT(_class_tflite_model, MRB_TT_DATA); 365 | mrb_define_method(mrb, _class_tflite_model, "initialize", mrb_tflite_model_init, MRB_ARGS_REQ(1)); 366 | mrb_define_module_function(mrb, _class_tflite_model, "from_file", mrb_tflite_model_from_file, MRB_ARGS_REQ(1)); 367 | ARENA_RESTORE; 368 | 369 | _class_tflite_interpreter_options = mrb_define_class_under(mrb, _class_tflite, "InterpreterOptions", mrb->object_class); 370 | MRB_SET_INSTANCE_TT(_class_tflite_interpreter_options, 
MRB_TT_DATA); 371 | mrb_define_method(mrb, _class_tflite_interpreter_options, "initialize", mrb_tflite_interpreter_options_init, MRB_ARGS_NONE()); 372 | mrb_define_method(mrb, _class_tflite_interpreter_options, "num_threads=", mrb_tflite_interpreter_options_num_threads_set, MRB_ARGS_REQ(1)); 373 | mrb_define_method(mrb, _class_tflite_interpreter_options, "add_delegate", mrb_tflite_interpreter_options_add_delegate, MRB_ARGS_REQ(1)); 374 | 375 | _class_tflite_interpreter = mrb_define_class_under(mrb, _class_tflite, "Interpreter", mrb->object_class); 376 | MRB_SET_INSTANCE_TT(_class_tflite_interpreter, MRB_TT_DATA); 377 | mrb_define_method(mrb, _class_tflite_interpreter, "initialize", mrb_tflite_interpreter_init, MRB_ARGS_ARG(1, 1)); 378 | mrb_define_method(mrb, _class_tflite_interpreter, "allocate_tensors", mrb_tflite_interpreter_allocate_tensors, MRB_ARGS_NONE()); 379 | mrb_define_method(mrb, _class_tflite_interpreter, "invoke", mrb_tflite_interpreter_invoke, MRB_ARGS_NONE()); 380 | mrb_define_method(mrb, _class_tflite_interpreter, "input_tensor_count", mrb_tflite_interpreter_input_tensor_count, MRB_ARGS_NONE()); 381 | mrb_define_method(mrb, _class_tflite_interpreter, "input_tensor", mrb_tflite_interpreter_input_tensor, MRB_ARGS_REQ(1)); 382 | mrb_define_method(mrb, _class_tflite_interpreter, "output_tensor_count", mrb_tflite_interpreter_output_tensor_count, MRB_ARGS_NONE()); 383 | mrb_define_method(mrb, _class_tflite_interpreter, "output_tensor", mrb_tflite_interpreter_output_tensor, MRB_ARGS_REQ(1)); 384 | ARENA_RESTORE; 385 | 386 | _class_tflite_tensor = mrb_define_class_under(mrb, _class_tflite, "Tensor", mrb->object_class); 387 | MRB_SET_INSTANCE_TT(_class_tflite_tensor, MRB_TT_DATA); 388 | mrb_define_method(mrb, _class_tflite_tensor, "type", mrb_tflite_tensor_type, MRB_ARGS_NONE()); 389 | mrb_define_method(mrb, _class_tflite_tensor, "name", mrb_tflite_tensor_name, MRB_ARGS_NONE()); 390 | mrb_define_method(mrb, _class_tflite_tensor, "num_dims", 
mrb_tflite_tensor_num_dims, MRB_ARGS_NONE()); 391 | mrb_define_method(mrb, _class_tflite_tensor, "dim", mrb_tflite_tensor_dim, MRB_ARGS_REQ(1)); 392 | mrb_define_method(mrb, _class_tflite_tensor, "byte_size", mrb_tflite_tensor_byte_size, MRB_ARGS_NONE()); 393 | mrb_define_method(mrb, _class_tflite_tensor, "data", mrb_tflite_tensor_data_get, MRB_ARGS_NONE()); 394 | mrb_define_method(mrb, _class_tflite_tensor, "data=", mrb_tflite_tensor_data_set, MRB_ARGS_REQ(1)); 395 | ARENA_RESTORE; 396 | } 397 | 398 | void 399 | mrb_mruby_tflite_gem_final(mrb_state* mrb) { 400 | } 401 | 402 | /* vim:set et ts=2 sts=2 sw=2 tw=0: */ 403 | -------------------------------------------------------------------------------- /tensorflow.patch: -------------------------------------------------------------------------------- 1 | diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD 2 | index e9877b63fb..fe625c25e7 100644 3 | --- a/tensorflow/lite/delegates/gpu/common/BUILD 4 | +++ b/tensorflow/lite/delegates/gpu/common/BUILD 5 | @@ -251,6 +251,7 @@ cc_library( 6 | name = "status", 7 | hdrs = ["status.h"], 8 | deps = ["@com_google_absl//absl/status"], 9 | + defines = ["EGL_NO_X11=1"], 10 | ) 11 | 12 | cc_library( 13 | -------------------------------------------------------------------------------- /test/tflite_test.rb: -------------------------------------------------------------------------------- 1 | assert('xor') do 2 | model = TfLite::Model.from_file(ENV['MRB_TFLITE_XORMODEL']) 3 | interpreter = TfLite::Interpreter.new(model) 4 | interpreter.allocate_tensors 5 | input = interpreter.input_tensor(0) 6 | output = interpreter.output_tensor(0) 7 | [[0, 0], [1, 0], [0, 1], [1, 1]].each do |x| 8 | input.data = x 9 | interpreter.invoke 10 | assert_equal(x[0] ^ x[1], output.data[0].round) 11 | end 12 | end 13 | -------------------------------------------------------------------------------- /test/xor_model.tflite: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattn/mruby-tflite/385b658cca626ab932642aafe09846f64d58255f/test/xor_model.tflite --------------------------------------------------------------------------------