├── testdata ├── bad_model.tflite ├── permute_uint8.tflite └── mobilenet_quant.tflite ├── .gitignore ├── lib ├── src │ ├── blobs │ │ ├── libtensorflowlite_c-mac64.so │ │ ├── libtensorflowlite_c-win64.dll │ │ └── libtensorflowlite_c-linux64.so │ ├── check.dart │ ├── bindings │ │ ├── bindings.dart │ │ ├── types.dart │ │ ├── model.dart │ │ ├── dlib.dart │ │ ├── interpreter_options.dart │ │ ├── tensor.dart │ │ └── interpreter.dart │ ├── ffi │ │ └── helper.dart │ ├── model.dart │ ├── interpreter_options.dart │ ├── tensor.dart │ └── interpreter.dart └── tflite.dart ├── AUTHORS ├── pubspec.yaml ├── .github └── workflows │ └── build.yaml ├── CHANGELOG.md ├── README.md ├── LICENSE ├── analysis_options.yaml └── test └── tflite_test.dart /testdata/bad_model.tflite: -------------------------------------------------------------------------------- 1 | I don't know what I'm doing! 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .dart_tool 2 | .idea/ 3 | .packages 4 | pubspec.lock 5 | -------------------------------------------------------------------------------- /testdata/permute_uint8.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dart-archive/tflite_native/HEAD/testdata/permute_uint8.tflite -------------------------------------------------------------------------------- /testdata/mobilenet_quant.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dart-archive/tflite_native/HEAD/testdata/mobilenet_quant.tflite -------------------------------------------------------------------------------- /lib/src/blobs/libtensorflowlite_c-mac64.so: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dart-archive/tflite_native/HEAD/lib/src/blobs/libtensorflowlite_c-mac64.so -------------------------------------------------------------------------------- /lib/src/blobs/libtensorflowlite_c-win64.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dart-archive/tflite_native/HEAD/lib/src/blobs/libtensorflowlite_c-win64.dll -------------------------------------------------------------------------------- /lib/src/blobs/libtensorflowlite_c-linux64.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dart-archive/tflite_native/HEAD/lib/src/blobs/libtensorflowlite_c-linux64.so -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | # Below is a list of people and organizations that have contributed 2 | # to the tflite_native project. Names should be added to the list like so: 3 | # 4 | # Name/Organization 5 | 6 | Google Inc. 7 | -------------------------------------------------------------------------------- /pubspec.yaml: -------------------------------------------------------------------------------- 1 | name: tflite_native 2 | version: 0.4.1 3 | author: Dart Team 4 | description: >- 5 | A Dart interface to TensorFlow Lite through Dart FFI. This library wraps the 6 | experimental tflite C API. 7 | homepage: https://github.com/dart-lang/tflite_native 8 | 9 | environment: 10 | sdk: '>=2.6.0 <3.0.0' 11 | 12 | dependencies: 13 | path: ^1.6.2 14 | ffi: ^0.1.3 15 | 16 | dev_dependencies: 17 | pedantic: ^1.0.0 18 | test: ^1.6.4 19 | -------------------------------------------------------------------------------- /lib/src/check.dart: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2020, the Dart project authors.
Please see the AUTHORS file 2 | // for details. All rights reserved. Use of this source code is governed by a 3 | // BSD-style license that can be found in the LICENSE file. 4 | 5 | /// Throws an [ArgumentError] with the optional [message] if [test] is false. 6 | void checkArgument(bool test, {String message}) { 7 | if (!test) throw ArgumentError(message); 8 | } 9 | 10 | /// Throws a [StateError] with [message] (or a default) if [test] is false. 11 | void checkState(bool test, {String message}) { 12 | if (!test) throw StateError(message ?? 'failed precondition'); 13 | } 14 | -------------------------------------------------------------------------------- /lib/tflite.dart: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file 2 | // for details. All rights reserved. Use of this source code is governed by a 3 | // BSD-style license that can be found in the LICENSE file. 4 | 5 | /// TensorFlow Lite for Dart 6 | library tflite; 7 | 8 | import 'package:ffi/ffi.dart'; 9 | 10 | import 'src/bindings/bindings.dart'; 11 | 12 | export 'src/interpreter.dart'; 13 | export 'src/interpreter_options.dart'; 14 | export 'src/model.dart'; 15 | export 'src/tensor.dart'; 16 | 17 | /// Version string reported by the loaded tflite C library. 18 | String get version => Utf8.fromUtf8(TfLiteVersion()); 19 | -------------------------------------------------------------------------------- /lib/src/bindings/bindings.dart: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file 2 | // for details. All rights reserved. Use of this source code is governed by a 3 | // BSD-style license that can be found in the LICENSE file. 4 | 5 | import 'dart:ffi'; 6 | 7 | import 'package:ffi/ffi.dart'; 8 | 9 | import 'dlib.dart'; 10 | 11 | /// Returns the TensorFlowLite library version as a C string.
12 | Pointer Function() TfLiteVersion = tflitelib 13 | .lookup>('TfLiteVersion') 14 | .asFunction(); 15 | 16 | typedef _TfLiteVersion_native_t = Pointer Function(); 17 | -------------------------------------------------------------------------------- /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | name: Dart 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - master 8 | schedule: 9 | # “At 00:00 (UTC) on Sunday.” 10 | - cron: '0 0 * * 0' 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | 16 | container: 17 | image: google/dart:beta 18 | 19 | steps: 20 | - uses: actions/checkout@v2 21 | 22 | - name: pub get 23 | run: pub get 24 | 25 | - name: dart format 26 | run: dart format --output=none --set-exit-if-changed . 27 | 28 | - name: dart analyze 29 | run: dart analyze --fatal-infos 30 | 31 | - name: dart test 32 | run: dart test 33 | -------------------------------------------------------------------------------- /lib/src/ffi/helper.dart: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file 2 | // for details. All rights reserved. Use of this source code is governed by a 3 | // BSD-style license that can be found in the LICENSE file. 4 | 5 | import 'dart:ffi'; 6 | 7 | /// Forces casting a [Pointer] to a different type. 8 | /// 9 | /// Unlike [Pointer.cast], the output type does not have to be a subclass of 10 | /// the input type. This is especially useful when working with [Pointer]. 11 | Pointer cast(Pointer ptr) => 12 | Pointer.fromAddress(ptr.address); 13 | 14 | /// Returns true if [ptr] does not hold the null address ([nullptr]).
15 | bool isNotNull(Pointer ptr) => ptr.address != nullptr.address; 16 | -------------------------------------------------------------------------------- /lib/src/bindings/types.dart: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file 2 | // for details. All rights reserved. Use of this source code is governed by a 3 | // BSD-style license that can be found in the LICENSE file. 4 | 5 | import 'dart:ffi'; 6 | 7 | /// Wraps a model interpreter. 8 | class TfLiteInterpreter extends Struct {} 9 | 10 | /// Wraps customized interpreter configuration options. 11 | class TfLiteInterpreterOptions extends Struct {} 12 | 13 | /// Wraps a loaded TensorFlowLite model. 14 | class TfLiteModel extends Struct {} 15 | 16 | /// Wraps data associated with a graph tensor. 17 | class TfLiteTensor extends Struct {} 18 | 19 | /// Integer status codes returned by TensorFlowLite C functions. 20 | class TfLiteStatus { 21 | static const ok = 0; 22 | static const error = 1; 23 | } 24 | 25 | /// Tensor element types, indexed by the integer code the C API returns. 26 | enum TfLiteType { 27 | none, 28 | float32, 29 | int32, 30 | uint8, 31 | int64, 32 | string, 33 | bool, 34 | int16, 35 | complex64, 36 | int8, 37 | float16 38 | } 39 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # 0.4.1 2 | 3 | * Updated the documentation of this package to clarify that it is discontinued. 4 | 5 | # 0.4.0+1 6 | 7 | * Removed the dependency on quiver. 8 | 9 | # 0.4.0 10 | 11 | * Replaced uses of deprecated `asExternalTypedData` with `asTypedList`.
12 | * Updated to use Dart 2.6 and latest dart:ffi APIs. 13 | 14 | # 0.3.0 15 | 16 | * Update to new tflite C API 17 | 18 | # 0.2.3 19 | 20 | * Specify `>=2.5.0` in `pubspec.yaml` 21 | 22 | # 0.2.2 23 | 24 | * Use `Pointer.asExternalTypedData` for reading output tensor 25 | 26 | # 0.2.1 27 | 28 | * Rebuild tflite shared libraries to include 29 | [tensorflow/tensorflow@4735b9b#diff-4e94df4d2961961ba5f69bbd666e0552](https://github.com/tensorflow/tensorflow/commit/4735b9bc02d35864eaba6ab48d73b73b333390e2#diff-4e94df4d2961961ba5f69bbd666e0552) 30 | 31 | # 0.2.0 32 | 33 | * Load shared object as sibling when running from snapshot 34 | 35 | # 0.1.0 36 | 37 | * Apply fixes from `dartfmt` 38 | * Update package description 39 | 40 | # 0.0.3 41 | 42 | * Bundle binaries for 64-bit linux, mac, and windows systems 43 | * Bindings for 24 / 29 functions exposed by libtensorflow_c.so 44 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://github.com/dart-lang/tflite_native/workflows/Dart/badge.svg)](https://github.com/dart-lang/tflite_native/actions) 2 | 3 | ## Status 4 | 5 | Note, this package has been **discontinued** - no support or maintenance is 6 | planned. 7 | 8 | ## What is it? 9 | 10 | A Dart interface to [TensorFlow Lite (tflite)](https://www.tensorflow.org/lite) 11 | through Dart's 12 | [foreign function interface (FFI)](https://dart.dev/server/c-interop). 13 | This library wraps the experimental tflite 14 | [C API](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/c/c_api.h). 15 | 16 | ## What Dart platforms does this package support? 17 | 18 | This package supports 64-bit desktop platforms: Linux, macOS, and Windows. 19 | 20 | ## What if I want TensorFlow Lite support for Flutter apps?
21 | 22 | Flutter developers should instead 23 | consider using the Flutter plugin [flutter_tflite](https://github.com/shaqian/flutter_tflite) 24 | (among the issues with using this package in Flutter: we locate the tflite dynamic library through 25 | `Isolate.resolvePackageUri`, which doesn't translate perfectly to the Flutter context; see 26 | https://github.com/flutter/flutter/issues/14815). 27 | -------------------------------------------------------------------------------- /lib/src/model.dart: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file 2 | // for details. All rights reserved. Use of this source code is governed by a 3 | // BSD-style license that can be found in the LICENSE file. 4 | 5 | import 'dart:ffi'; 6 | 7 | import 'package:ffi/ffi.dart'; 8 | 9 | import 'bindings/model.dart'; 10 | import 'bindings/types.dart'; 11 | import 'check.dart'; 12 | import 'ffi/helper.dart'; 13 | 14 | /// TensorFlowLite model. 15 | class Model { 16 | final Pointer _model; 17 | bool _deleted = false; 18 | 19 | Pointer get base => _model; 20 | 21 | Model._(this._model); 22 | 23 | /// Loads a model from the file at [path]; throws [ArgumentError] on failure. 24 | factory Model.fromFile(String path) { 25 | final cpath = Utf8.toUtf8(path); 26 | final model = TfLiteModelCreateFromFile(cpath); 27 | free(cpath); 28 | checkArgument(isNotNull(model), message: 'Unable to create model.'); 29 | return Model._(model); 30 | } 31 | 32 | /// Destroys the model instance; throws [StateError] if already deleted.
33 | void delete() { 34 | checkState(!_deleted, message: 'Model already deleted.'); 35 | TfLiteModelDelete(_model); 36 | _deleted = true; 37 | } 38 | 39 | // Unimplemented: 40 | // Model.fromBuffer => TfLiteNewModel 41 | } 42 | -------------------------------------------------------------------------------- /lib/src/interpreter_options.dart: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file 2 | // for details. All rights reserved. Use of this source code is governed by a 3 | // BSD-style license that can be found in the LICENSE file. 4 | 5 | import 'dart:ffi'; 6 | 7 | import 'bindings/interpreter_options.dart'; 8 | import 'bindings/types.dart'; 9 | import 'check.dart'; 10 | 11 | /// TensorFlowLite interpreter options. 12 | class InterpreterOptions { 13 | final Pointer _options; 14 | bool _deleted = false; 15 | 16 | Pointer get base => _options; 17 | 18 | InterpreterOptions._(this._options); 19 | 20 | /// Creates a new options instance. 21 | factory InterpreterOptions() => 22 | InterpreterOptions._(TfLiteInterpreterOptionsCreate()); 23 | 24 | /// Destroys the options instance; throws [StateError] if already deleted. 25 | void delete() { 26 | checkState(!_deleted, message: 'InterpreterOptions already deleted.'); 27 | TfLiteInterpreterOptionsDelete(_options); 28 | _deleted = true; 29 | } 30 | 31 | /// Sets the number of CPU threads to use for the interpreter. 32 | set threads(int threads) => 33 | TfLiteInterpreterOptionsSetNumThreads(_options, threads); 34 | 35 | // Unimplemented: 36 | // TfLiteInterpreterOptionsSetErrorReporter 37 | } 38 | -------------------------------------------------------------------------------- /lib/src/bindings/model.dart: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file 2 | // for details. All rights reserved.
Use of this source code is governed by a 3 | // BSD-style license that can be found in the LICENSE file. 4 | 5 | import 'dart:ffi'; 6 | 7 | import 'package:ffi/ffi.dart'; 8 | 9 | import 'dlib.dart'; 10 | import 'types.dart'; 11 | 12 | /// Returns a model from the provided buffer, or a null pointer on failure. 13 | Pointer Function(Pointer data, int size) TfLiteNewModel = 14 | tflitelib 15 | .lookup>('TfLiteNewModel') 16 | .asFunction(); 17 | 18 | typedef _TfLiteNewModel_native_t = Pointer Function( 19 | Pointer data, Int32 size); 20 | 21 | /// Returns a model from the provided file, or a null pointer on failure. 22 | Pointer Function(Pointer path) TfLiteModelCreateFromFile = 23 | tflitelib 24 | .lookup>( 25 | 'TfLiteModelCreateFromFile') 26 | .asFunction(); 27 | 28 | typedef _TfLiteModelCreateFromFile_native_t = Pointer Function( 29 | Pointer path); 30 | 31 | /// Destroys the model instance. 32 | void Function(Pointer) TfLiteModelDelete = tflitelib 33 | .lookup>('TfLiteModelDelete') 34 | .asFunction(); 35 | 36 | typedef _TfLiteModelDelete_native_t = Void Function(Pointer); 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2019, the Dart project authors. All rights reserved. 2 | Redistribution and use in source and binary forms, with or without 3 | modification, are permitted provided that the following conditions are 4 | met: 5 | 6 | * Redistributions of source code must retain the above copyright 7 | notice, this list of conditions and the following disclaimer. 8 | * Redistributions in binary form must reproduce the above 9 | copyright notice, this list of conditions and the following 10 | disclaimer in the documentation and/or other materials provided 11 | with the distribution. 12 | * Neither the name of Google Inc.
nor the names of its 13 | contributors may be used to endorse or promote products derived 14 | from this software without specific prior written permission. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 20 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 21 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 22 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 23 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 24 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 26 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | -------------------------------------------------------------------------------- /lib/src/bindings/dlib.dart: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file 2 | // for details. All rights reserved. Use of this source code is governed by a 3 | // BSD-style license that can be found in the LICENSE file. 4 | 5 | import 'dart:cli' as cli; 6 | import 'dart:ffi'; 7 | import 'dart:io'; 8 | import 'dart:isolate' show Isolate; 9 | 10 | const Set _supported = {'linux64', 'mac64', 'win64'}; 11 | 12 | /// Computes the bundled shared-library filename for this OS and architecture. 13 | /// 14 | /// Throws an exception if invoked on an unsupported platform. 15 | String _getObjectFilename() { 16 | final architecture = sizeOf() == 4 ?
'32' : '64'; 17 | String os, extension; 18 | if (Platform.isLinux) { 19 | os = 'linux'; 20 | extension = 'so'; 21 | } else if (Platform.isMacOS) { 22 | os = 'mac'; 23 | extension = 'so'; 24 | } else if (Platform.isWindows) { 25 | os = 'win'; 26 | extension = 'dll'; 27 | } else { 28 | throw Exception('Unsupported platform!'); 29 | } 30 | 31 | final result = os + architecture; 32 | if (!_supported.contains(result)) { 33 | throw Exception('Unsupported platform: $result!'); 34 | } 35 | 36 | return 'libtensorflowlite_c-$result.$extension'; 37 | } 38 | 39 | /// TensorFlowLite C library, opened from `src/blobs/` or beside a snapshot. 40 | DynamicLibrary tflitelib = () { 41 | String objectFile; 42 | if (Platform.script.path.endsWith('.snapshot')) { 43 | // If we're running from snapshot, assume that the shared object 44 | // file is a sibling. 45 | objectFile = 46 | File.fromUri(Platform.script).parent.path + '/' + _getObjectFilename(); 47 | } else { 48 | final rootLibrary = 'package:tflite_native/tflite.dart'; 49 | final blobs = cli 50 | .waitFor(Isolate.resolvePackageUri(Uri.parse(rootLibrary))) 51 | .resolve('src/blobs/'); 52 | objectFile = blobs.resolve(_getObjectFilename()).toFilePath(); 53 | } 54 | 55 | return DynamicLibrary.open(objectFile); 56 | }(); 57 | -------------------------------------------------------------------------------- /lib/src/bindings/interpreter_options.dart: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file 2 | // for details. All rights reserved. Use of this source code is governed by a 3 | // BSD-style license that can be found in the LICENSE file. 4 | 5 | import 'dart:ffi'; 6 | 7 | import 'package:ffi/ffi.dart'; 8 | 9 | import 'dlib.dart'; 10 | import 'types.dart'; 11 | 12 | /// Returns a new interpreter options instance.
13 | Pointer Function() TfLiteInterpreterOptionsCreate = 14 | tflitelib 15 | .lookup>( 16 | 'TfLiteInterpreterOptionsCreate') 17 | .asFunction(); 18 | 19 | typedef _TfLiteInterpreterOptionsCreate_native_t 20 | = Pointer Function(); 21 | 22 | /// Destroys the interpreter options instance. 23 | void Function(Pointer) 24 | TfLiteInterpreterOptionsDelete = tflitelib 25 | .lookup>( 26 | 'TfLiteInterpreterOptionsDelete') 27 | .asFunction(); 28 | 29 | typedef _TfLiteInterpreterOptionsDelete_native_t = Void Function( 30 | Pointer); 31 | 32 | /// Sets the number of CPU threads to use for the interpreter. 33 | void Function( 34 | Pointer options, 35 | int 36 | threads) TfLiteInterpreterOptionsSetNumThreads = tflitelib 37 | .lookup>( 38 | 'TfLiteInterpreterOptionsSetNumThreads') 39 | .asFunction(); 40 | 41 | typedef _TfLiteInterpreterOptionsSetNumThreads_native_t = Void Function( 42 | Pointer options, Int32 threads); 43 | 44 | /// Sets a custom error reporter for interpreter execution. 45 | /// 46 | /// * `reporter` takes the provided `user_data` object, as well as a C-style 47 | /// format string and arg list (see also vprintf). 48 | /// * `user_data` is optional. If provided, it is owned by the client and must 49 | /// remain valid for the duration of the interpreter lifetime. 50 | void Function( 51 | Pointer options, 52 | Pointer> reporter, 53 | Pointer user_data, 54 | ) TfLiteInterpreterOptionsSetErrorReporter = tflitelib 55 | .lookup>( 56 | 'TfLiteInterpreterOptionsSetErrorReporter') 57 | .asFunction(); 58 | 59 | typedef _TfLiteInterpreterOptionsSetErrorReporter_native_t = Void Function( 60 | Pointer options, 61 | Pointer> reporter, 62 | Pointer user_data, 63 | ); 64 | 65 | /// Custom error reporter function for interpreter execution.
66 | typedef Reporter = Void Function( 67 | Pointer user_data, 68 | Pointer format, 69 | /*va_list*/ Pointer args); 70 | -------------------------------------------------------------------------------- /lib/src/tensor.dart: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file 2 | // for details. All rights reserved. Use of this source code is governed by a 3 | // BSD-style license that can be found in the LICENSE file. 4 | 5 | import 'dart:ffi'; 6 | import 'dart:typed_data'; 7 | 8 | import 'package:ffi/ffi.dart'; 9 | 10 | import 'bindings/tensor.dart'; 11 | import 'bindings/types.dart'; 12 | import 'check.dart'; 13 | import 'ffi/helper.dart'; 14 | 15 | export 'bindings/types.dart' show TfLiteType; 16 | 17 | /// TensorFlowLite tensor. 18 | class Tensor { 19 | final Pointer _tensor; 20 | 21 | Tensor(this._tensor) { 22 | ArgumentError.checkNotNull(_tensor); 23 | } 24 | 25 | /// Name of the tensor element. 26 | String get name => Utf8.fromUtf8(TfLiteTensorName(_tensor)); 27 | 28 | /// Data type of the tensor element. 29 | TfLiteType get type => TfLiteTensorType(_tensor); 30 | 31 | /// Dimensions of the tensor. 32 | List get shape => List.generate( 33 | TfLiteTensorNumDims(_tensor), (i) => TfLiteTensorDim(_tensor, i)); 34 | 35 | /// Read-only view of the underlying data buffer; throws [StateError] if null. 36 | Uint8List get data { 37 | final data = cast(TfLiteTensorData(_tensor)); 38 | checkState(isNotNull(data), message: 'Tensor data is null.'); 39 | return UnmodifiableUint8ListView( 40 | data.asTypedList(TfLiteTensorByteSize(_tensor))); 41 | } 42 | 43 | /// Updates the underlying data buffer with new bytes. 44 | /// 45 | /// The size must match the size of the tensor.
46 | set data(Uint8List bytes) { 47 | final tensorByteSize = TfLiteTensorByteSize(_tensor); 48 | checkArgument(tensorByteSize == bytes.length, 49 | message: 'Tensor expects $tensorByteSize bytes, got ${bytes.length}.'); 50 | final data = cast(TfLiteTensorData(_tensor)); 51 | checkState(isNotNull(data), message: 'Tensor data is null.'); 52 | final externalTypedData = data.asTypedList(tensorByteSize); 53 | externalTypedData.setRange(0, tensorByteSize, bytes); 54 | } 55 | 56 | /// Copies the input bytes to the underlying data buffer. 57 | /// 58 | /// The C API requires `bytes.length` to equal the tensor's byte size; 59 | /// throws a [StateError] if the copy does not report [TfLiteStatus.ok]. 60 | // TODO(shanehop): Prevent access if unallocated. 61 | void copyFrom(Uint8List bytes) { 62 | final size = bytes.length; 63 | final ptr = allocate(count: size); 64 | try { 65 | ptr.asTypedList(size).setRange(0, size, bytes); 66 | checkState(TfLiteTensorCopyFromBuffer(_tensor, ptr.cast(), size) == 67 | TfLiteStatus.ok); 68 | } finally { 69 | // Release the scratch buffer even when the copy fails. 70 | free(ptr); 71 | } 72 | } 73 | 74 | /// Returns a copy of the underlying data buffer. 75 | /// 76 | /// Throws a [StateError] if the copy does not report [TfLiteStatus.ok]. 77 | // TODO(shanehop): Prevent access if unallocated. 78 | Uint8List copyTo() { 79 | final size = TfLiteTensorByteSize(_tensor); 80 | final ptr = allocate(count: size); 81 | try { 82 | // The C API requires the output size to equal the tensor's byte size. 83 | checkState(TfLiteTensorCopyToBuffer(_tensor, ptr.cast(), size) == 84 | TfLiteStatus.ok); 85 | // Clone the data: the typed-list view over `ptr` dangles once it is freed. 86 | return ptr.asTypedList(size).sublist(0); 87 | } finally { 88 | free(ptr); 89 | } 90 | } 91 | 92 | // Unimplemented: 93 | // TfLiteTensorQuantizationParams 94 | } 95 | -------------------------------------------------------------------------------- /analysis_options.yaml: -------------------------------------------------------------------------------- 1 | include: package:pedantic/analysis_options.yaml 2 | 3 | analyzer: 4 | strong-mode: 5 | implicit-casts: false 6 | 7 | linter: 8 | rules: 9 | - always_declare_return_types 10 | - annotate_overrides 11 | - avoid_bool_literals_in_conditional_expressions 12 | - avoid_classes_with_only_static_members 13 | -
avoid_empty_else 14 | - avoid_function_literals_in_foreach_calls 15 | - avoid_init_to_null 16 | - avoid_null_checks_in_equality_operators 17 | - avoid_relative_lib_imports 18 | - avoid_renaming_method_parameters 19 | - avoid_return_types_on_setters 20 | - avoid_returning_null 21 | - avoid_returning_null_for_future 22 | - avoid_returning_null_for_void 23 | - avoid_returning_this 24 | - avoid_shadowing_type_parameters 25 | - avoid_single_cascade_in_expression_statements 26 | - avoid_types_as_parameter_names 27 | - avoid_unused_constructor_parameters 28 | - await_only_futures 29 | - camel_case_types 30 | - cancel_subscriptions 31 | #- cascade_invocations 32 | - comment_references 33 | - constant_identifier_names 34 | - control_flow_in_finally 35 | - directives_ordering 36 | - empty_catches 37 | - empty_constructor_bodies 38 | - empty_statements 39 | - file_names 40 | - hash_and_equals 41 | - implementation_imports 42 | - invariant_booleans 43 | - iterable_contains_unrelated_type 44 | - join_return_with_assignment 45 | - library_names 46 | - library_prefixes 47 | - list_remove_unrelated_type 48 | - literal_only_boolean_expressions 49 | - no_adjacent_strings_in_list 50 | - no_duplicate_case_values 51 | #- non_constant_identifier_names 52 | - null_closures 53 | - omit_local_variable_types 54 | - only_throw_errors 55 | - overridden_fields 56 | - package_api_docs 57 | - package_names 58 | - package_prefixed_library_names 59 | - prefer_adjacent_string_concatenation 60 | - prefer_collection_literals 61 | - prefer_conditional_assignment 62 | - prefer_const_constructors 63 | - prefer_contains 64 | - prefer_equal_for_default_values 65 | - prefer_final_fields 66 | #- prefer_final_locals 67 | - prefer_generic_function_type_aliases 68 | - prefer_initializing_formals 69 | #- prefer_interpolation_to_compose_strings 70 | - prefer_is_empty 71 | - prefer_is_not_empty 72 | - prefer_null_aware_operators 73 | - prefer_single_quotes 74 | - prefer_typing_uninitialized_variables 75 | - 
recursive_getters 76 | - slash_for_doc_comments 77 | - test_types_in_equals 78 | - throw_in_finally 79 | - type_init_formals 80 | - unawaited_futures 81 | - unnecessary_await_in_return 82 | - unnecessary_brace_in_string_interps 83 | - unnecessary_const 84 | - unnecessary_getters_setters 85 | - unnecessary_lambdas 86 | - unnecessary_new 87 | - unnecessary_null_aware_assignments 88 | - unnecessary_parenthesis 89 | - unnecessary_statements 90 | - unnecessary_this 91 | - unrelated_type_equality_checks 92 | - use_function_type_syntax_for_parameters 93 | - use_rethrow_when_possible 94 | - valid_regexps 95 | - void_checks 96 | -------------------------------------------------------------------------------- /lib/src/interpreter.dart: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file 2 | // for details. All rights reserved. Use of this source code is governed by a 3 | // BSD-style license that can be found in the LICENSE file. 4 | 5 | import 'dart:ffi'; 6 | 7 | import 'package:ffi/ffi.dart'; 8 | 9 | import 'bindings/interpreter.dart'; 10 | import 'bindings/types.dart'; 11 | import 'check.dart'; 12 | import 'ffi/helper.dart'; 13 | import 'interpreter_options.dart'; 14 | import 'model.dart'; 15 | import 'tensor.dart'; 16 | 17 | /// TensorFlowLite interpreter for running inference on a model. 18 | class Interpreter { 19 | final Pointer _interpreter; 20 | bool _deleted = false; 21 | bool _allocated = false; 22 | 23 | Interpreter._(this._interpreter); 24 | 25 | /// Creates an interpreter for [model]; throws [ArgumentError] on failure. 26 | factory Interpreter(Model model, {InterpreterOptions options}) { 27 | final interpreter = TfLiteInterpreterCreate( 28 | model.base, options?.base ??
cast(nullptr)); 29 | checkArgument(isNotNull(interpreter), 30 | message: 'Unable to create interpreter.'); 31 | return Interpreter._(interpreter); 32 | } 33 | 34 | /// Creates interpreter from a model file or throws if unsuccessful. 35 | factory Interpreter.fromFile(String file, {InterpreterOptions options}) { 36 | final model = Model.fromFile(file); 37 | try { 38 | return Interpreter(model, options: options); 39 | } finally { 40 | // The interpreter keeps its own reference to the model data. 41 | model.delete(); 42 | } 43 | } 44 | 45 | /// Destroys the interpreter instance. 46 | void delete() { 47 | checkState(!_deleted, message: 'Interpreter already deleted.'); 48 | TfLiteInterpreterDelete(_interpreter); 49 | _deleted = true; 50 | } 51 | 52 | /// Updates allocations for all tensors. 53 | void allocateTensors() { 54 | checkState(!_allocated, message: 'Interpreter already allocated.'); 55 | checkState( 56 | TfLiteInterpreterAllocateTensors(_interpreter) == TfLiteStatus.ok); 57 | _allocated = true; 58 | } 59 | 60 | /// Runs inference for the loaded graph. 61 | void invoke() { 62 | checkState(_allocated, message: 'Interpreter not allocated.'); 63 | checkState(TfLiteInterpreterInvoke(_interpreter) == TfLiteStatus.ok); 64 | } 65 | 66 | /// Gets all input tensors associated with the model. 67 | List getInputTensors() => List.generate( 68 | TfLiteInterpreterGetInputTensorCount(_interpreter), 69 | (i) => Tensor(TfLiteInterpreterGetInputTensor(_interpreter, i)), 70 | growable: false); 71 | 72 | /// Gets all output tensors associated with the model. 73 | List getOutputTensors() => List.generate( 74 | TfLiteInterpreterGetOutputTensorCount(_interpreter), 75 | (i) => Tensor(TfLiteInterpreterGetOutputTensor(_interpreter, i)), 76 | growable: false); 77 | 78 | /// Resizes the input tensor at [tensorIndex]; then call [allocateTensors].
79 | void resizeInputTensor(int tensorIndex, List shape) { 80 | final dimensionSize = shape.length; 81 | final dimensions = allocate(count: dimensionSize); 82 | final externalTypedData = dimensions.asTypedList(dimensionSize); 83 | externalTypedData.setRange(0, dimensionSize, shape); 84 | final status = TfLiteInterpreterResizeInputTensor( 85 | _interpreter, tensorIndex, dimensions, dimensionSize); 86 | free(dimensions); 87 | checkState(status == TfLiteStatus.ok); 88 | _allocated = false; 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /lib/src/bindings/tensor.dart: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file 2 | // for details. All rights reserved. Use of this source code is governed by a 3 | // BSD-style license that can be found in the LICENSE file. 4 | 5 | import 'dart:ffi'; 6 | 7 | import 'package:ffi/ffi.dart'; 8 | 9 | import 'dlib.dart'; 10 | import 'types.dart'; 11 | 12 | /// Returns the type of a tensor element. 13 | TfLiteType TfLiteTensorType(Pointer t) => 14 | TfLiteType.values[_TfLiteTensorType(t)]; 15 | int Function(Pointer) _TfLiteTensorType = tflitelib 16 | .lookup>('TfLiteTensorType') 17 | .asFunction(); 18 | 19 | typedef _TfLiteTensorType_native_t = /*TfLiteType*/ Int32 Function( 20 | Pointer); 21 | 22 | /// Returns the number of dimensions that the tensor has. 23 | int Function(Pointer) TfLiteTensorNumDims = tflitelib 24 | .lookup>( 25 | 'TfLiteTensorNumDims') 26 | .asFunction(); 27 | 28 | typedef _TfLiteTensorNumDims_native_t = Int32 Function(Pointer); 29 | 30 | /// Returns the length of the tensor in the 'dim_index' dimension.
31 | /// 32 | /// REQUIRES: 0 <= dim_index < TfLiteTensorNumDims(tensor) 33 | int Function(Pointer tensor, int dim_index) TfLiteTensorDim = 34 | tflitelib 35 | .lookup>('TfLiteTensorDim') 36 | .asFunction(); 37 | 38 | typedef _TfLiteTensorDim_native_t = Int32 Function( 39 | Pointer tensor, Int32 dim_index); 40 | 41 | /// Returns the size of the underlying data in bytes, as a C `size_t`. 42 | int Function(Pointer) TfLiteTensorByteSize = tflitelib 43 | .lookup>( 44 | 'TfLiteTensorByteSize') 45 | .asFunction(); 46 | 47 | typedef _TfLiteTensorByteSize_native_t = IntPtr Function(Pointer); 48 | 49 | /// Returns a pointer to the underlying data buffer. 50 | /// 51 | /// NOTE: The result may be null if tensors have not yet been allocated, e.g., 52 | /// if the tensor has just been created or resized and 53 | /// `TfLiteInterpreterAllocateTensors()` has yet to be called, or if the output 54 | /// tensor is dynamically sized and the interpreter hasn't been invoked. 55 | Pointer Function(Pointer) TfLiteTensorData = tflitelib 56 | .lookup>('TfLiteTensorData') 57 | .asFunction(); 58 | 59 | typedef _TfLiteTensorData_native_t = Pointer Function( 60 | Pointer); 61 | 62 | /// Returns the (null-terminated) name of the tensor. 63 | Pointer Function(Pointer) TfLiteTensorName = tflitelib 64 | .lookup>('TfLiteTensorName') 65 | .asFunction(); 66 | 67 | typedef _TfLiteTensorName_native_t = Pointer Function( 68 | Pointer); 69 | 70 | /// Copies from the provided input buffer into the tensor's buffer.
71 | /// 72 | /// REQUIRES: input_data_size == TfLiteTensorByteSize(tensor) 73 | /*TfLiteStatus*/ 74 | int Function( 75 | Pointer tensor, 76 | Pointer input_data, 77 | int input_data_size, 78 | ) TfLiteTensorCopyFromBuffer = tflitelib 79 | .lookup>( 80 | 'TfLiteTensorCopyFromBuffer') 81 | .asFunction(); 82 | 83 | typedef _TfLiteTensorCopyFromBuffer_native_t = /*TfLiteStatus*/ Int32 Function( 84 | Pointer tensor, 85 | Pointer input_data, 86 | IntPtr input_data_size, 87 | ); 88 | 89 | /// Copies to the provided output buffer from the tensor's buffer. 90 | /// 91 | /// REQUIRES: output_data_size == TfLiteTensorByteSize(tensor) 92 | /*TfLiteStatus*/ 93 | int Function( 94 | Pointer tensor, 95 | Pointer output_data, 96 | int output_data_size, 97 | ) TfLiteTensorCopyToBuffer = tflitelib 98 | .lookup>( 99 | 'TfLiteTensorCopyToBuffer') 100 | .asFunction(); 101 | 102 | typedef _TfLiteTensorCopyToBuffer_native_t = /*TfLiteStatus*/ Int32 Function( 103 | Pointer tensor, 104 | Pointer output_data, 105 | IntPtr output_data_size, 106 | ); 107 | 108 | // Unimplemented functions: 109 | // TfLiteTensorQuantizationParams 110 | -------------------------------------------------------------------------------- /lib/src/bindings/interpreter.dart: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file 2 | // for details. All rights reserved. Use of this source code is governed by a 3 | // BSD-style license that can be found in the LICENSE file. 4 | 5 | import 'dart:ffi'; 6 | 7 | import 'dlib.dart'; 8 | import 'types.dart'; 9 | 10 | /// Returns a new interpreter using the provided model and options, or null on 11 | /// failure. 12 | /// 13 | /// * `model` must be a valid model instance. The caller retains ownership of the 14 | /// object, and can destroy it immediately after creating the interpreter; the 15 | /// interpreter will maintain its own reference to the underlying model data.
16 | /// * `optional_options` may be null. The caller retains ownership of the object, 17 | /// and can safely destroy it immediately after creating the interpreter. 18 | /// 19 | /// NOTE: The client *must* explicitly allocate tensors before attempting to 20 | /// access input tensor data or invoke the interpreter. 21 | Pointer Function(Pointer model, 22 | Pointer optional_options) 23 | TfLiteInterpreterCreate = tflitelib 24 | .lookup>( 25 | 'TfLiteInterpreterCreate') 26 | .asFunction(); 27 | 28 | typedef _TfLiteInterpreterCreate_native_t = Pointer Function( 29 | Pointer model, 30 | Pointer optional_options); 31 | 32 | /// Destroys the interpreter. 33 | void Function(Pointer) TfLiteInterpreterDelete = tflitelib 34 | .lookup>( 35 | 'TfLiteInterpreterDelete') 36 | .asFunction(); 37 | 38 | typedef _TfLiteInterpreterDelete_native_t = Void Function( 39 | Pointer); 40 | 41 | /// Returns the number of input tensors associated with the model. 42 | int Function(Pointer) TfLiteInterpreterGetInputTensorCount = 43 | tflitelib 44 | .lookup>( 45 | 'TfLiteInterpreterGetInputTensorCount') 46 | .asFunction(); 47 | 48 | typedef _TfLiteInterpreterGetInputTensorCount_native_t = Int32 Function( 49 | Pointer); 50 | 51 | /// Returns the tensor associated with the input index. 52 | /// 53 | /// REQUIRES: 0 <= input_index < TfLiteInterpreterGetInputTensorCount(interpreter) 54 | Pointer Function( 55 | Pointer interpreter, int input_index) 56 | TfLiteInterpreterGetInputTensor = tflitelib 57 | .lookup>( 58 | 'TfLiteInterpreterGetInputTensor') 59 | .asFunction(); 60 | 61 | typedef _TfLiteInterpreterGetInputTensor_native_t = Pointer 62 | Function(Pointer interpreter, Int32 input_index); 63 | 64 | /// Resizes the specified input tensor. 65 | /// 66 | /// NOTE: After a resize, the client *must* explicitly allocate tensors before 67 | /// attempting to access the resized tensor data or invoke the interpreter.
68 | /// REQUIRES: 0 <= input_index < TfLiteInterpreterGetInputTensorCount(interpreter) 69 | /*TfLiteStatus*/ 70 | int Function(Pointer interpreter, int input_index, 71 | Pointer input_dims, int input_dims_size) 72 | TfLiteInterpreterResizeInputTensor = tflitelib 73 | .lookup>( 74 | 'TfLiteInterpreterResizeInputTensor') 75 | .asFunction(); 76 | 77 | typedef _TfLiteInterpreterResizeInputTensor_native_t 78 | = /*TfLiteStatus*/ Int32 Function(Pointer interpreter, 79 | Int32 input_index, Pointer input_dims, Int32 input_dims_size); 80 | 81 | /// Updates allocations for all tensors, resizing dependent tensors using the 82 | /// specified input tensor dimensionality. 83 | /// 84 | /// This is a relatively expensive operation, and need only be called after 85 | /// creating the graph and/or resizing any inputs. 86 | /*TfLiteStatus*/ 87 | int Function(Pointer) TfLiteInterpreterAllocateTensors = 88 | tflitelib 89 | .lookup>( 90 | 'TfLiteInterpreterAllocateTensors') 91 | .asFunction(); 92 | 93 | typedef _TfLiteInterpreterAllocateTensors_native_t = /*TfLiteStatus*/ Int32 94 | Function(Pointer); 95 | 96 | /// Runs inference for the loaded graph. 97 | /// 98 | /// NOTE: It is possible that the interpreter is not in a ready state to 99 | /// evaluate (e.g., if a ResizeInputTensor() has been performed without a call to 100 | /// AllocateTensors()). 101 | /*TfLiteStatus*/ 102 | int Function(Pointer) TfLiteInterpreterInvoke = tflitelib 103 | .lookup>( 104 | 'TfLiteInterpreterInvoke') 105 | .asFunction(); 106 | 107 | typedef _TfLiteInterpreterInvoke_native_t = /*TfLiteStatus*/ Int32 Function( 108 | Pointer); 109 | 110 | /// Returns the number of output tensors associated with the model.
111 | int Function( 112 | Pointer< 113 | TfLiteInterpreter>) TfLiteInterpreterGetOutputTensorCount = tflitelib 114 | .lookup>( 115 | 'TfLiteInterpreterGetOutputTensorCount') 116 | .asFunction(); 117 | 118 | typedef _TfLiteInterpreterGetOutputTensorCount_native_t = Int32 Function( 119 | Pointer); 120 | 121 | /// Returns the tensor associated with the output index. 122 | /// 123 | /// REQUIRES: 0 <= output_index < TfLiteInterpreterGetOutputTensorCount(interpreter) 124 | /// 125 | /// NOTE: The shape and underlying data buffer for output tensors may not 126 | /// be available until after the output tensor has been both sized and allocated. 127 | /// In general, best practice is to interact with the output tensor *after* 128 | /// calling TfLiteInterpreterInvoke(). 129 | Pointer Function( 130 | Pointer interpreter, int output_index) 131 | TfLiteInterpreterGetOutputTensor = tflitelib 132 | .lookup>( 133 | 'TfLiteInterpreterGetOutputTensor') 134 | .asFunction(); 135 | 136 | typedef _TfLiteInterpreterGetOutputTensor_native_t = Pointer 137 | Function(Pointer interpreter, Int32 output_index); 138 | -------------------------------------------------------------------------------- /test/tflite_test.dart: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file 2 | // for details. All rights reserved. Use of this source code is governed by a 3 | // BSD-style license that can be found in the LICENSE file.
4 | 5 | import 'dart:io'; 6 | import 'dart:typed_data'; 7 | 8 | import 'package:path/path.dart' as path; 9 | import 'package:test/test.dart'; 10 | import 'package:tflite_native/tflite.dart' as tfl; 11 | 12 | final dataDir = path.join(Directory.current.path, 'testdata'); 13 | final dataFile = path.join(dataDir, 'permute_uint8.tflite'); 14 | final missingFile = path.join(dataDir, 'missing.tflite'); 15 | final badFile = path.join(dataDir, 'bad_model.tflite'); 16 | 17 | void main() { 18 | test('version', () { 19 | expect(tfl.version, isNotEmpty); 20 | }); 21 | 22 | group('model', () { 23 | test('from file', () { 24 | var model = tfl.Model.fromFile(dataFile); 25 | model.delete(); 26 | }); 27 | 28 | test('deleting a deleted model throws', () { 29 | var model = tfl.Model.fromFile(dataFile); 30 | model.delete(); 31 | expect(() => model.delete(), throwsA(isStateError)); 32 | }); 33 | 34 | test('missing file throws', () { 35 | if (File(missingFile).existsSync()) { 36 | fail('missingFile is not missing.'); 37 | } 38 | expect(() => tfl.Model.fromFile(missingFile), throwsA(isArgumentError)); 39 | }); 40 | 41 | test('bad file throws', () { 42 | if (!File(badFile).existsSync()) { 43 | fail('badFile is missing.'); 44 | } 45 | expect(() => tfl.Model.fromFile(badFile), throwsA(isArgumentError)); 46 | }); 47 | }); 48 | 49 | test('interpreter from model', () { 50 | var model = tfl.Model.fromFile(dataFile); 51 | var interpreter = tfl.Interpreter(model); 52 | model.delete(); 53 | interpreter.delete(); 54 | }); 55 | 56 | test('interpreter from file', () { 57 | var interpreter = tfl.Interpreter.fromFile(dataFile); 58 | interpreter.delete(); 59 | }); 60 | 61 | group('interpreter options', () { 62 | test('default', () { 63 | var options = tfl.InterpreterOptions(); 64 | var interpreter = tfl.Interpreter.fromFile(dataFile, options: options); 65 | options.delete(); 66 | interpreter.allocateTensors(); 67 | interpreter.invoke(); 68 | interpreter.delete(); 69 | }); 70 | 71 | test('threads', () { 72 | var options =
tfl.InterpreterOptions()..threads = 1; 73 | var interpreter = tfl.Interpreter.fromFile(dataFile, options: options); 74 | options.delete(); 75 | interpreter.allocateTensors(); 76 | interpreter.invoke(); 77 | interpreter.delete(); 78 | // TODO(shanehop): Is there something meaningful to verify? 79 | }); 80 | }); 81 | 82 | group('interpreter', () { 83 | tfl.Interpreter interpreter; 84 | 85 | setUp(() => interpreter = tfl.Interpreter.fromFile(dataFile)); 86 | tearDown(() => interpreter.delete()); 87 | 88 | test('allocate', () { 89 | interpreter.allocateTensors(); 90 | }); 91 | 92 | test('allocate throws if already allocated', () { 93 | interpreter.allocateTensors(); 94 | expect(() => interpreter.allocateTensors(), throwsA(isStateError)); 95 | }); 96 | 97 | test('invoke throws if not allocated', () { 98 | expect(() => interpreter.invoke(), throwsA(isStateError)); 99 | }); 100 | 101 | test('invoke throws if not allocated after resized', () { 102 | interpreter.allocateTensors(); 103 | interpreter.resizeInputTensor(0, [1, 2, 4]); 104 | expect(() => interpreter.invoke(), throwsA(isStateError)); 105 | }); 106 | 107 | test('invoke succeeds if allocated', () { 108 | interpreter.allocateTensors(); 109 | interpreter.invoke(); 110 | }); 111 | 112 | test('get input tensors', () { 113 | expect(interpreter.getInputTensors(), hasLength(1)); 114 | }); 115 | 116 | test('get output tensors', () { 117 | expect(interpreter.getOutputTensors(), hasLength(1)); 118 | }); 119 | 120 | test('resize input tensor', () { 121 | interpreter.resizeInputTensor(0, [2, 3, 5]); 122 | expect(interpreter.getInputTensors().single.shape, [2, 3, 5]); 123 | }); 124 | 125 | group('tensors', () { 126 | List tensors; 127 | setUp(() => tensors = interpreter.getInputTensors()); 128 | 129 | test('name', () { 130 | expect(tensors[0].name, 'input'); 131 | }); 132 | 133 | test('type', () { 134 |
expect(tensors[0].type, tfl.TfLiteType.uint8); 135 | }); 136 | 137 | test('shape', () { 138 | expect(tensors[0].shape, [1, 4]); 139 | }); 140 | 141 | group('data', () { 142 | test('get throws if not allocated', () { 143 | expect(() => tensors[0].data, throwsA(isStateError)); 144 | }); 145 | 146 | test('get', () { 147 | interpreter.allocateTensors(); 148 | expect(tensors[0].data, hasLength(4)); 149 | }); 150 | 151 | test('set throws if not allocated', () { 152 | expect(() => tensors[0].data = Uint8List.fromList(const [0, 0, 0, 0]), 153 | throwsA(isStateError)); 154 | }); 155 | 156 | test('set', () { 157 | interpreter.allocateTensors(); 158 | tensors[0].data = Uint8List.fromList(const [0, 0, 0, 0]); 159 | expect(tensors[0].data, [0, 0, 0, 0]); 160 | tensors[0].data = Uint8List.fromList(const [0, 1, 10, 100]); 161 | expect(tensors[0].data, [0, 1, 10, 100]); 162 | }); 163 | 164 | test('copyTo', () { 165 | interpreter.allocateTensors(); 166 | expect(tensors[0].copyTo(), hasLength(4)); 167 | }); 168 | 169 | test('copyFrom throws if not allocated', () { 170 | expect( 171 | () => tensors[0].copyFrom(Uint8List.fromList(const [0, 0, 0, 0])), 172 | throwsA(isStateError)); 173 | }, skip: 'segmentation fault!'); 174 | // TODO(shanehop): Prevent data access for unallocated tensors. 175 | 176 | test('copyFrom', () { 177 | interpreter.allocateTensors(); 178 | tensors[0].copyFrom(Uint8List.fromList(const [0, 0, 0, 0])); 179 | expect(tensors[0].data, [0, 0, 0, 0]); 180 | tensors[0].copyFrom(Uint8List.fromList(const [0, 1, 10, 100])); 181 | expect(tensors[0].data, [0, 1, 10, 100]); 182 | }); 183 | }); 184 | }); 185 | }); 186 | } 187 | --------------------------------------------------------------------------------