├── LICENSE
├── README.md
├── package.json
├── plugin.xml
├── src
├── android
│ ├── TensorFlow.java
│ └── tf_libs
│ │ ├── Classifier.java
│ │ ├── TensorFlowImageClassifier.java
│ │ ├── armeabi-v7a
│ │ └── libtensorflow_inference.so
│ │ └── libandroid_tensorflow_inference_java.jar
└── ios
│ ├── TensorFlow.h
│ ├── TensorFlow.mm
│ └── tf_libs
│ ├── ios_image_load.h
│ ├── ios_image_load.mm
│ ├── tensorflow_utils.h
│ └── tensorflow_utils.mm
└── www
└── tensorflow.js
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2017 Houston Engineering, Inc., http://www.houstoneng.com
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of
6 | this software and associated documentation files (the "Software"), to deal in
7 | the Software without restriction, including without limitation the rights to
8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
9 | the Software, and to permit persons to whom the Software is furnished to do so,
10 | subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # cordova-plugin-tensorflow
2 |
3 | Integrate the TensorFlow inference library into your PhoneGap/Cordova application!
4 |
5 | ```javascript
6 | var tf = new TensorFlow('inception-v1');
7 | var imgData = "/9j/4AAQSkZJRgABAQEAYABgAAD//gBGRm ...";
8 |
9 | tf.classify(imgData).then(function(results) {
10 | results.forEach(function(result) {
11 | console.log(result.title + " " + result.confidence);
12 | });
13 | });
14 |
15 | /* Output:
16 | military uniform 0.647296
17 | suit 0.0477196
18 | academic gown 0.0232411
19 | */
20 | ```
21 | ## Installation
22 |
23 | ### Cordova
24 | ```bash
25 | cordova plugin add https://github.com/heigeo/cordova-plugin-tensorflow
26 | ```
27 |
28 | ### PhoneGap Build
29 | ```xml
30 |
31 |
32 | ```
33 |
34 | ## Supported Platforms
35 |
36 | * Android
37 | * iOS
38 |
39 | ## API
40 |
41 | The plugin provides a `TensorFlow` class that can be used to initialize graphs and run the inference algorithm.
42 |
43 | ### Initialization
44 |
45 | ```javascript
46 | // Use the Inception model (will be downloaded on first use)
47 | var tf = new TensorFlow('inception-v1');
48 |
49 | // Use a custom retrained model
50 | var tf = new TensorFlow('custom-model', {
51 | 'label': 'My Custom Model',
52 | 'model_path': "https://example.com/graphs/custom-model-2017.zip#rounded_graph.pb",
53 | 'label_path': "https://example.com/graphs/custom-model-2017.zip#retrained_labels.txt",
54 | 'input_size': 299,
55 | 'image_mean': 128,
56 | 'image_std': 128,
57 | 'input_name': 'Mul',
58 | 'output_name': 'final_result'
59 | })
60 | ```
61 |
62 | To use a custom model, follow the steps to [retrain the model](https://www.tensorflow.org/tutorials/image_retraining) and [optimize it for mobile use](https://petewarden.com/2016/09/27/tensorflow-for-mobile-poets/).
63 | Put the .pb and .txt files in a HTTP-accessible zip file, which will be downloaded via the [FileTransfer plugin](https://cordova.apache.org/docs/en/latest/reference/cordova-plugin-file-transfer/). If you use the generic Inception model it will be downloaded from [the TensorFlow website](https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip) on first use.
64 |
65 | ### Methods
66 |
67 | Each method returns a `Promise` (if available) and also accepts a callback and errorCallback.
68 |
69 |
70 | ### classify(image[, callback, errorCallback])
71 | Classifies an image with TensorFlow's inference algorithm and the registered model. Will automatically download and initialize the model if necessary, but it is recommended to call `load()` explicitly for the best user experience.
72 |
73 | Note that the image must be provided as base64 encoded JPEG or PNG data. Support for file paths may be added in a future release.
74 |
75 | ```javascript
76 | var tf = new TensorFlow(...);
77 | var imgData = "/9j/4AAQSkZJRgABAQEAYABgAAD//gBGRm ...";
78 | tf.classify(imgData).then(function(results) {
79 | results.forEach(function(result) {
80 | console.log(result.title + " " + result.confidence);
81 | });
82 | });
83 | ```
84 |
85 | ### load()
86 |
87 | Downloads the referenced model files and loads the graph into TensorFlow.
88 |
89 | ```javascript
90 | var tf = new TensorFlow(...);
91 | tf.load().then(function() {
92 | console.log("Model loaded");
93 | });
94 | ```
95 |
96 | Downloading the model files can take some time. If you would like to provide a progress indicator, you can do that with an `onprogress` event:
97 | ```javascript
98 | var tf = new TensorFlow(...);
99 | tf.onprogress = function(evt) {
100 |     if (evt['status'] == 'downloading') {
101 | console.log("Downloading model files...");
102 | console.log(evt.label);
103 | if (evt.detail) {
104 | // evt.detail is from the FileTransfer API
105 | var $elem = $('progress');
106 | $elem.attr('max', evt.detail.total);
107 | $elem.attr('value', evt.detail.loaded);
108 | }
109 | } else if (evt['status'] == 'unzipping') {
110 | console.log("Extracting contents...");
111 | } else if (evt['status'] == 'initializing') {
112 | console.log("Initializing TensorFlow");
113 | }
114 | };
115 | tf.load().then(...);
116 | ```
117 |
118 | ### checkCached()
119 | Checks whether the requisite model files have already been downloaded. This is useful if you want to provide an interface for downloading and managing TensorFlow graphs that is separate from the classification interface.
120 |
121 | ```javascript
122 | var tf = new TensorFlow(...);
123 | tf.checkCached().then(function(isCached) {
124 | if (isCached) {
125 | $('button#download').hide();
126 | }
127 | });
128 | ```
129 |
130 | ## References
131 |
132 | This plugin is made possible by the following libraries and tutorials:
133 |
134 | Source | Files
135 | -------|--------
136 | [TensorFlow Android Inference Interface] | [libtensorflow_inference.so],
[libandroid_tensorflow_inference_java.jar]
137 | [TensorFlow Android Demo] |[Classifier.java][Classifer.java],
[TensorFlowImageClassifier.java][TensorFlowImageClassifier.java] (modified)
138 | [TensorflowPod] | Referenced via [podspec]
139 | [TensorFlow iOS Examples] | [ios_image_load.mm][ios_image_load.mm] (modified),
[tensorflow_utils.mm][tensorflow_utils.mm] (+ RunModelViewController.mm)
140 |
141 | [TensorFlow Android Inference Interface]: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/android
142 | [libtensorflow_inference.so]: https://github.com/heigeo/cordova-plugin-tensorflow/blob/master/src/android/tf_libs/armeabi-v7a/libtensorflow_inference.so
143 | [libandroid_tensorflow_inference_java.jar]: https://github.com/heigeo/cordova-plugin-tensorflow/blob/master/src/android/tf_libs/libandroid_tensorflow_inference_java.jar
144 | [TensorFlow Android Demo]: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android
145 | [Classifer.java]: https://github.com/heigeo/cordova-plugin-tensorflow/blob/master/src/android/tf_libs/Classifier.java
146 | [TensorFlowImageClassifier.java]: https://github.com/heigeo/cordova-plugin-tensorflow/blob/master/src/android/tf_libs/TensorFlowImageClassifier.java
147 | [TensorflowPod]: https://github.com/rainbean/TensorflowPod
148 | [podspec]: https://github.com/heigeo/cordova-plugin-tensorflow/blob/master/plugin.xml#L38
149 | [TensorFlow iOS Examples]: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/ios_examples
150 | [ios_image_load.mm]: https://github.com/heigeo/cordova-plugin-tensorflow/blob/master/src/ios/tf_libs/ios_image_load.mm
151 | [tensorflow_utils.mm]: https://github.com/heigeo/cordova-plugin-tensorflow/blob/master/src/ios/tf_libs/tensorflow_utils.mm
152 |
--------------------------------------------------------------------------------
/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "cordova-plugin-tensorflow",
3 | "version": "0.0.1",
4 | "description": "TensorFlow for Cordova",
5 | "cordova": {
6 | "id": "cordova-plugin-tensorflow",
7 | "platforms": [
 8 |       "android",
 8 |       "ios"
9 | ]
10 | },
11 | "repository": {
12 | "type": "git",
13 | "url": "https://github.com/heigeo/cordova-plugin-tensorflow.git"
14 | },
15 | "keywords": [
16 | "ai",
17 | "inference",
18 | "classification",
19 | "imagerecognition",
20 | "neuralnetworks",
21 | "machinelearning",
22 | "tensorflow",
23 | "inception",
24 | "ecosystem:cordova",
25 | "cordova-android"
26 | ],
27 | "engines": [
28 | {
29 | "name": "cordova-android",
30 | "version": ">=5.1.0"
31 | }
32 | ],
33 | "author": "Houston Engineering, Inc.",
34 | "license": "MIT",
35 | "bugs": {
36 | "url": "https://github.com/heigeo/cordova-plugin-tensorflow/issues"
37 | },
38 | "homepage": "https://github.com/heigeo/cordova-plugin-tensorflow#readme"
39 | }
40 |
--------------------------------------------------------------------------------
/plugin.xml:
--------------------------------------------------------------------------------
1 |
2 |
4 | TensorFlow for Cordova
5 | Cordova/PhoneGap wrapper for TensorFlow's image recognition binary library.
6 | MIT
7 | cordova,device
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
--------------------------------------------------------------------------------
/src/android/TensorFlow.java:
--------------------------------------------------------------------------------
1 | package io.wq.tensorflow;
2 |
3 | import org.apache.cordova.CordovaPlugin;
4 | import org.apache.cordova.CallbackContext;
5 | import org.json.JSONArray;
6 | import org.json.JSONObject;
7 | import org.json.JSONException;
8 |
9 | import android.util.Base64;
10 | import android.graphics.Bitmap;
11 | import android.graphics.Bitmap.Config;
12 | import android.graphics.BitmapFactory;
13 |
14 | import org.tensorflow.demo.TensorFlowImageClassifier;
15 | import org.tensorflow.demo.Classifier.Recognition;
16 | import org.tensorflow.demo.Classifier;
17 | import java.util.List;
18 | import java.util.Map;
19 | import java.util.HashMap;
20 |
21 | import android.media.ThumbnailUtils;
22 |
23 | public class TensorFlow extends CordovaPlugin {
24 |
25 | @Override
26 | public boolean execute(String action, JSONArray args, CallbackContext callbackContext) throws JSONException {
27 | if (action.equals("loadModel")) {
28 | this.loadModel(
29 | args.getString(0),
30 | args.getString(1),
31 | args.getString(2),
32 | args.getInt(3),
33 | args.getInt(4),
34 | (float) args.getDouble(5),
35 | args.getString(6),
36 | args.getString(7),
37 | callbackContext
38 | );
39 | return true;
40 | } else if (action.equals("classify")) {
41 | this.classify(args.getString(0), args.getString(1), callbackContext);
42 | return true;
43 | } else {
44 | return false;
45 | }
46 | }
47 |
48 |
49 | private Map classifiers = new HashMap();
50 | private Map sizes = new HashMap();
51 |
52 | private void loadModel(String modelName, String modelFile, String labelFile,
53 | int inputSize, int imageMean, float imageStd,
54 | String inputName, String outputName,
55 | CallbackContext callbackContext) {
56 | classifiers.put(modelName, TensorFlowImageClassifier.create(
57 | cordova.getActivity().getAssets(),
58 | modelFile,
59 | labelFile,
60 | inputSize,
61 | imageMean,
62 | imageStd,
63 | inputName,
64 | outputName
65 | ));
66 | sizes.put(modelName, inputSize);
67 | callbackContext.success();
68 | }
69 |
70 | private void classify(String modelName, String image, CallbackContext callbackContext) {
71 | byte[] imageData = Base64.decode(image, Base64.DEFAULT);
72 | Classifier classifier = classifiers.get(modelName);
73 | int size = sizes.get(modelName);
74 | Bitmap bitmap = BitmapFactory.decodeByteArray(imageData, 0, imageData.length);
75 | Bitmap cropped = ThumbnailUtils.extractThumbnail(bitmap, size, size);
76 | List results = classifier.recognizeImage(cropped);
77 | JSONArray output = new JSONArray();
78 | try {
79 | for (Recognition result : results) {
80 | JSONObject record = new JSONObject();
81 | record.put("title", result.getTitle());
82 | record.put("confidence", result.getConfidence());
83 | output.put(record);
84 | }
85 | } catch (JSONException e) {
86 | }
87 | callbackContext.success(output);
88 | }
89 |
90 |
91 | }
92 |
--------------------------------------------------------------------------------
/src/android/tf_libs/Classifier.java:
--------------------------------------------------------------------------------
1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | package org.tensorflow.demo;
17 |
18 | import android.graphics.Bitmap;
19 | import android.graphics.RectF;
20 | import java.util.List;
21 |
22 | /**
23 | * Generic interface for interacting with different recognition engines.
24 | */
25 | public interface Classifier {
26 | /**
27 | * An immutable result returned by a Classifier describing what was recognized.
28 | */
29 | public class Recognition {
30 | /**
31 | * A unique identifier for what has been recognized. Specific to the class, not the instance of
32 | * the object.
33 | */
34 | private final String id;
35 |
36 | /**
37 | * Display name for the recognition.
38 | */
39 | private final String title;
40 |
41 | /**
42 | * A sortable score for how good the recognition is relative to others. Higher should be better.
43 | */
44 | private final Float confidence;
45 |
46 | /** Optional location within the source image for the location of the recognized object. */
47 | private RectF location;
48 |
49 | public Recognition(
50 | final String id, final String title, final Float confidence, final RectF location) {
51 | this.id = id;
52 | this.title = title;
53 | this.confidence = confidence;
54 | this.location = location;
55 | }
56 |
57 | public String getId() {
58 | return id;
59 | }
60 |
61 | public String getTitle() {
62 | return title;
63 | }
64 |
65 | public Float getConfidence() {
66 | return confidence;
67 | }
68 |
69 | public RectF getLocation() {
70 | return new RectF(location);
71 | }
72 |
73 | public void setLocation(RectF location) {
74 | this.location = location;
75 | }
76 |
77 | @Override
78 | public String toString() {
79 | String resultString = "";
80 | if (id != null) {
81 | resultString += "[" + id + "] ";
82 | }
83 |
84 | if (title != null) {
85 | resultString += title + " ";
86 | }
87 |
88 | if (confidence != null) {
89 | resultString += String.format("(%.1f%%) ", confidence * 100.0f);
90 | }
91 |
92 | if (location != null) {
93 | resultString += location + " ";
94 | }
95 |
96 | return resultString.trim();
97 | }
98 | }
99 |
100 | List recognizeImage(Bitmap bitmap);
101 |
102 | void enableStatLogging(final boolean debug);
103 |
104 | String getStatString();
105 |
106 | void close();
107 | }
108 |
--------------------------------------------------------------------------------
/src/android/tf_libs/TensorFlowImageClassifier.java:
--------------------------------------------------------------------------------
1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | package org.tensorflow.demo;
17 |
18 | import android.content.res.AssetManager;
19 | import android.graphics.Bitmap;
20 | import android.os.Trace;
21 | import android.util.Log;
22 | import java.io.FileInputStream;
23 | import java.io.BufferedReader;
24 | import java.io.IOException;
25 | import java.io.InputStreamReader;
26 | import java.util.ArrayList;
27 | import java.util.Comparator;
28 | import java.util.List;
29 | import java.util.PriorityQueue;
30 | import java.util.Vector;
31 | import org.tensorflow.Operation;
32 | import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
33 |
34 | /** A classifier specialized to label images using TensorFlow. */
35 | public class TensorFlowImageClassifier implements Classifier {
36 |
37 | private static final String TAG = "TensorFlowImageClassifier";
38 |
39 | // Only return this many results with at least this confidence.
40 | private static final int MAX_RESULTS = 3;
41 | private static final float THRESHOLD = 0.1f;
42 |
43 | // Config values.
44 | private String inputName;
45 | private String outputName;
46 | private int inputSize;
47 | private int imageMean;
48 | private float imageStd;
49 |
50 | // Pre-allocated buffers.
51 | private Vector labels = new Vector();
52 | private int[] intValues;
53 | private float[] floatValues;
54 | private float[] outputs;
55 | private String[] outputNames;
56 |
57 | private TensorFlowInferenceInterface inferenceInterface;
58 |
59 | private TensorFlowImageClassifier() {}
60 |
61 | /**
62 | * Initializes a native TensorFlow session for classifying images.
63 | *
64 | * @param assetManager The asset manager to be used to load assets.
65 | * @param modelFilename The filepath of the model GraphDef protocol buffer.
66 | * @param labelFilename The filepath of label file for classes.
67 | * @param inputSize The input size. A square image of inputSize x inputSize is assumed.
68 | * @param imageMean The assumed mean of the image values.
69 | * @param imageStd The assumed std of the image values.
70 | * @param inputName The label of the image input node.
71 | * @param outputName The label of the output node.
72 | * @throws IOException
73 | */
74 | public static Classifier create(
75 | AssetManager assetManager,
76 | String modelFilename,
77 | String labelFilename,
78 | int inputSize,
79 | int imageMean,
80 | float imageStd,
81 | String inputName,
82 | String outputName) {
83 | TensorFlowImageClassifier c = new TensorFlowImageClassifier();
84 | c.inputName = inputName;
85 | c.outputName = outputName;
86 |
87 | // Read the label names into memory.
88 | // TODO(andrewharp): make this handle non-assets.
89 | final boolean hasAssetPrefix = labelFilename.startsWith("file:///android_asset/");
90 | String actualFilename = hasAssetPrefix ? labelFilename.split("file:///android_asset/")[1] : labelFilename;
91 | Log.i(TAG, "Reading labels from: " + actualFilename);
92 | BufferedReader br = null;
93 | try {
94 | br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
95 | String line;
96 | while ((line = br.readLine()) != null) {
97 | c.labels.add(line);
98 | }
99 | br.close();
100 | } catch (IOException e) {
101 | if (hasAssetPrefix) {
102 | throw new RuntimeException("Problem reading label file!" , e);
103 | }
104 | try {
105 | br = new BufferedReader(new InputStreamReader(new FileInputStream(actualFilename)));
106 | String line;
107 | while ((line = br.readLine()) != null) {
108 | c.labels.add(line);
109 | }
110 | br.close();
111 | } catch (IOException e2) {
112 | throw new RuntimeException("Problem reading label file!" , e);
113 | }
114 | }
115 |
116 | c.inferenceInterface = new TensorFlowInferenceInterface();
117 | if (c.inferenceInterface.initializeTensorFlow(assetManager, modelFilename) != 0) {
118 | throw new RuntimeException("TF initialization failed");
119 | }
120 | // The shape of the output is [N, NUM_CLASSES], where N is the batch size.
121 | final Operation operation = c.inferenceInterface.graph().operation(outputName);
122 | if (operation == null) {
123 | throw new RuntimeException("Node '" + outputName + "' does not exist in model '"
124 | + modelFilename + "'");
125 | }
126 | final int numClasses = (int) operation.output(0).shape().size(1);
127 | Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses);
128 |
129 | // Ideally, inputSize could have been retrieved from the shape of the input operation. Alas,
130 | // the placeholder node for input in the graphdef typically used does not specify a shape, so it
131 | // must be passed in as a parameter.
132 | c.inputSize = inputSize;
133 | c.imageMean = imageMean;
134 | c.imageStd = imageStd;
135 |
136 | // Pre-allocate buffers.
137 | c.outputNames = new String[] {outputName};
138 | c.intValues = new int[inputSize * inputSize];
139 | c.floatValues = new float[inputSize * inputSize * 3];
140 | c.outputs = new float[numClasses];
141 |
142 | return c;
143 | }
144 |
145 | @Override
146 | public List recognizeImage(final Bitmap bitmap) {
147 | // Log this method so that it can be analyzed with systrace.
148 | Trace.beginSection("recognizeImage");
149 |
150 | Trace.beginSection("preprocessBitmap");
151 | // Preprocess the image data from 0-255 int to normalized float based
152 | // on the provided parameters.
153 | bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
154 | for (int i = 0; i < intValues.length; ++i) {
155 | final int val = intValues[i];
156 | floatValues[i * 3 + 0] = (((val >> 16) & 0xFF) - imageMean) / imageStd;
157 | floatValues[i * 3 + 1] = (((val >> 8) & 0xFF) - imageMean) / imageStd;
158 | floatValues[i * 3 + 2] = ((val & 0xFF) - imageMean) / imageStd;
159 | }
160 | Trace.endSection();
161 |
162 | // Copy the input data into TensorFlow.
163 | Trace.beginSection("fillNodeFloat");
164 | inferenceInterface.fillNodeFloat(
165 | inputName, new int[] {1, inputSize, inputSize, 3}, floatValues);
166 | Trace.endSection();
167 |
168 | // Run the inference call.
169 | Trace.beginSection("runInference");
170 | inferenceInterface.runInference(outputNames);
171 | Trace.endSection();
172 |
173 | // Copy the output Tensor back into the output array.
174 | Trace.beginSection("readNodeFloat");
175 | inferenceInterface.readNodeFloat(outputName, outputs);
176 | Trace.endSection();
177 |
178 | // Find the best classifications.
179 | PriorityQueue pq =
180 | new PriorityQueue(
181 | 3,
182 | new Comparator() {
183 | @Override
184 | public int compare(Recognition lhs, Recognition rhs) {
185 | // Intentionally reversed to put high confidence at the head of the queue.
186 | return Float.compare(rhs.getConfidence(), lhs.getConfidence());
187 | }
188 | });
189 | for (int i = 0; i < outputs.length; ++i) {
190 | if (outputs[i] > THRESHOLD) {
191 | pq.add(
192 | new Recognition(
193 | "" + i, labels.size() > i ? labels.get(i) : "unknown", outputs[i], null));
194 | }
195 | }
196 | final ArrayList recognitions = new ArrayList();
197 | int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
198 | for (int i = 0; i < recognitionsSize; ++i) {
199 | recognitions.add(pq.poll());
200 | }
201 | Trace.endSection(); // "recognizeImage"
202 | return recognitions;
203 | }
204 |
205 | @Override
206 | public void enableStatLogging(boolean debug) {
207 | inferenceInterface.enableStatLogging(debug);
208 | }
209 |
210 | @Override
211 | public String getStatString() {
212 | return inferenceInterface.getStatString();
213 | }
214 |
215 | @Override
216 | public void close() {
217 | inferenceInterface.close();
218 | }
219 | }
220 |
--------------------------------------------------------------------------------
/src/android/tf_libs/armeabi-v7a/libtensorflow_inference.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heigeo/cordova-plugin-tensorflow/9c8b74c81a642b1381be517de8f22e0caa649180/src/android/tf_libs/armeabi-v7a/libtensorflow_inference.so
--------------------------------------------------------------------------------
/src/android/tf_libs/libandroid_tensorflow_inference_java.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heigeo/cordova-plugin-tensorflow/9c8b74c81a642b1381be517de8f22e0caa649180/src/android/tf_libs/libandroid_tensorflow_inference_java.jar
--------------------------------------------------------------------------------
/src/ios/TensorFlow.h:
--------------------------------------------------------------------------------
// NOTE(review): angle-bracket tokens were stripped from this header by a
// text-extraction step; the include targets and template arguments below are
// reconstructed from the types actually used in TensorFlow.mm.
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/core/public/session.h"

#import "ios_image_load.h"
#import "tensorflow_utils.h"
#import <Cordova/CDV.h>

// Cordova plugin entry point: bridges JS calls to native TensorFlow classifiers.
@interface TensorFlow : CDVPlugin {
    // Classifier instances keyed by model name (NSString* -> Classifier*).
    NSMutableDictionary *classifiers;
}

- (void)loadModel:(CDVInvokedUrlCommand*)command;
- (void)classify:(CDVInvokedUrlCommand*)command;

@end

// Wraps a TensorFlow session plus the model metadata needed for inference.
@interface Classifier : NSObject {
    std::unique_ptr<tensorflow::Session> session;
    NSString* model_file;
    NSString* label_file;
    std::vector<std::string> labels;
    int input_size;
    int image_mean;
    float image_std;
    NSString* input_name;
    NSString* output_name;
}

- (id)initWithModel:(NSString *)model_file_
         label_file:(NSString *)label_file_
         input_size:(int)input_size_
         image_mean:(int)image_mean_
          image_std:(float)image_std_
         input_name:(NSString *)input_name_
        output_name:(NSString *)output_name_;
- (tensorflow::Status)load;
- (tensorflow::Status)classify:(NSString *)image results:(NSMutableArray *)results;

@end
41 |
--------------------------------------------------------------------------------
/src/ios/TensorFlow.mm:
--------------------------------------------------------------------------------
1 | #import "TensorFlow.h"
2 | #import
3 | #import "ios_image_load.h"
4 | #import "tensorflow_utils.h"
5 |
@implementation TensorFlow

// Registers and loads a model. JS args (by index):
// [model_name, model_file, label_file, input_size, image_mean, image_std,
//  input_name, output_name]
- (void)loadModel:(CDVInvokedUrlCommand*)command
{
    CDVPluginResult* pluginResult = nil;
    NSString* model_name = [command.arguments objectAtIndex:0];
    Classifier* classifier = [
        [Classifier alloc]
        initWithModel:[command.arguments objectAtIndex:1]
        label_file:[command.arguments objectAtIndex:2]
        input_size:[(NSNumber *)[command.arguments objectAtIndex:3] intValue]
        image_mean:[(NSNumber *)[command.arguments objectAtIndex:4] intValue]
        image_std:[(NSNumber *)[command.arguments objectAtIndex:5] floatValue]
        input_name:[command.arguments objectAtIndex:6]
        output_name:[command.arguments objectAtIndex:7]
    ];
    if (classifiers == nil) {
        classifiers = [NSMutableDictionary dictionary];
    }
    tensorflow::Status result = [classifier load];
    if (result.ok()) {
        // Only register the classifier once its graph loaded successfully;
        // previously a failed load left a broken classifier registered.
        classifiers[model_name] = classifier;
        pluginResult = [CDVPluginResult resultWithStatus:CDVCommandStatus_OK];
    } else {
        pluginResult = [CDVPluginResult resultWithStatus:CDVCommandStatus_ERROR];
    }

    [self.commandDelegate sendPluginResult:pluginResult callbackId:command.callbackId];
}

// Classifies a base64-encoded image with a previously loaded model.
- (void)classify:(CDVInvokedUrlCommand*)command
{
    CDVPluginResult* pluginResult = nil;
    NSString* model_name = [command.arguments objectAtIndex:0];
    NSString* image = [command.arguments objectAtIndex:1];
    Classifier* classifier = classifiers[model_name];
    if (classifier == nil) {
        // BUG FIX: messaging nil previously yielded a default status, so an
        // unknown model silently reported OK with an empty result list.
        pluginResult = [CDVPluginResult resultWithStatus:CDVCommandStatus_ERROR
                                         messageAsString:@"Model not loaded"];
        [self.commandDelegate sendPluginResult:pluginResult callbackId:command.callbackId];
        return;
    }
    NSMutableArray* results = [[NSMutableArray alloc] init];
    tensorflow::Status result = [classifier classify:image results:results];
    if (result.ok()) {
        pluginResult = [CDVPluginResult resultWithStatus:CDVCommandStatus_OK messageAsArray:results];
    } else {
        pluginResult = [CDVPluginResult resultWithStatus:CDVCommandStatus_ERROR];
    }

    [self.commandDelegate sendPluginResult:pluginResult callbackId:command.callbackId];
}

@end
54 |
@implementation Classifier

// Stores model configuration only; the graph itself is loaded lazily by -load.
- (id)initWithModel:(NSString *)model_file_
         label_file:(NSString *)label_file_
         input_size:(int)input_size_
         image_mean:(int)image_mean_
          image_std:(float)image_std_
         input_name:(NSString *)input_name_
        output_name:(NSString *)output_name_
{
    self = [super init];
    if (self) {
        model_file = model_file_;
        label_file = label_file_;
        input_size = input_size_;
        image_mean = image_mean_;
        input_name = input_name_;
        image_std = image_std_;
        output_name = output_name_;
    }
    return self;
}

// Loads the graph into a TensorFlow session, then reads the label list.
- (tensorflow::Status)load
{
    tensorflow::Status result = LoadModel(model_file, &session);
    if (result.ok()) {
        result = LoadLabels(label_file, &labels);
    }
    return result;
}

// Runs inference on a base64-encoded image, appending {title, confidence}
// dictionaries to `results`. Returns the status of the inference run.
- (tensorflow::Status)classify:(NSString *)image results:(NSMutableArray *)results
{
    std::vector<Result> tfresults;
    tensorflow::Status result = RunInferenceOnImage(
        image,
        input_size,
        image_mean,
        image_std,
        [input_name UTF8String],
        [output_name UTF8String],
        &session,
        &labels,
        &tfresults
    );
    if (!result.ok()) {
        return result;
    }
    // NOTE(review): assumes Result.label / Result.confidence are ObjC objects
    // usable in a dictionary literal — confirm against tensorflow_utils.h.
    for (const Result& tfresult : tfresults) {
        [results addObject: @{
            @"title": tfresult.label,
            @"confidence": tfresult.confidence
        }];
    }
    // BUG FIX: the original fell off the end of this method without returning,
    // which is undefined behavior for a non-void C++ return type.
    return result;
}

@end
114 |
--------------------------------------------------------------------------------
/src/ios/tf_libs/ios_image_load.h:
--------------------------------------------------------------------------------
1 | // Copyright 2015 Google Inc. All rights reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_
16 | #define TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_
17 |
18 | #include <vector>
19 |
20 | #import <CoreFoundation/CoreFoundation.h>
21 | #import <Foundation/Foundation.h>
22 |
23 | #include "tensorflow/core/framework/types.h"
24 |
25 | // Decodes the PNG/JPEG file at `file_name` (format chosen by extension)
26 | // into RGBA bytes; writes dimensions/channel count to the out-parameters.
27 | std::vector<tensorflow::uint8> LoadImageFromFile(const char* file_name,
28 |                                                  int* out_width,
29 |                                                  int* out_height,
30 |                                                  int* out_channels);
31 |
32 | // Decodes a base64-encoded JPEG string into RGBA bytes.
33 | std::vector<tensorflow::uint8> LoadImageFromBase64(NSString* base64data,
34 |                                                    int* out_width,
35 |                                                    int* out_height,
36 |                                                    int* out_channels);
37 |
38 | // Decodes raw PNG/JPEG bytes (selected by `suffix`) into RGBA bytes.
39 | // Releases `data_ref` before returning.
40 | std::vector<tensorflow::uint8> LoadImageFromData(CFDataRef data_ref,
41 |                                                  const char* suffix,
42 |                                                  int* out_width,
43 |                                                  int* out_height,
44 |                                                  int* out_channels);
45 |
46 | #endif  // TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_
39 |
--------------------------------------------------------------------------------
/src/ios/tf_libs/ios_image_load.mm:
--------------------------------------------------------------------------------
1 | // Copyright 2015 Google Inc. All rights reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include "ios_image_load.h"
16 |
17 | #include
18 | #include
19 | #include
20 | #include
21 |
22 | #import
23 | #import
24 |
25 | using tensorflow::uint8;
26 |
27 | // Reads the file at `file_name` into memory and decodes it via
28 | // LoadImageFromData, using the file extension to pick the decoder.
29 | // Returns an empty vector (and zeroed out-parameters) on failure.
30 | std::vector<tensorflow::uint8> LoadImageFromFile(const char* file_name,
31 |                                                  int* out_width, int* out_height,
32 |                                                  int* out_channels) {
33 |   FILE* file_handle = fopen(file_name, "rb");
34 |   if (!file_handle) {
35 |     // BUG FIX: the original dereferenced a NULL handle when the file
36 |     // could not be opened.
37 |     fprintf(stderr, "Could not open file '%s'\n", file_name);
38 |     *out_width = 0;
39 |     *out_height = 0;
40 |     *out_channels = 0;
41 |     return std::vector<tensorflow::uint8>();
42 |   }
43 |   fseek(file_handle, 0, SEEK_END);
44 |   const size_t bytes_in_file = ftell(file_handle);
45 |   fseek(file_handle, 0, SEEK_SET);
46 |   // file_data stays alive until LoadImageFromData returns, so wrapping it
47 |   // in a no-copy CFData below is safe.
48 |   std::vector<tensorflow::uint8> file_data(bytes_in_file);
49 |   fread(file_data.data(), 1, bytes_in_file, file_handle);
50 |   fclose(file_handle);
51 |   CFDataRef file_data_ref = CFDataCreateWithBytesNoCopy(NULL, file_data.data(),
52 |                                                         bytes_in_file,
53 |                                                         kCFAllocatorNull);
54 |   const char* suffix = strrchr(file_name, '.');
55 |   if (!suffix || suffix == file_name) {
56 |     suffix = "";
57 |   }
58 |   return LoadImageFromData(file_data_ref, suffix, out_width, out_height, out_channels);
59 | }
46 |
47 | // Decodes a base64 string (assumed to contain JPEG data) into RGBA bytes.
48 | std::vector<tensorflow::uint8> LoadImageFromBase64(NSString* base64data,
49 |                                                    int* out_width, int* out_height,
50 |                                                    int* out_channels) {
51 |   NSData *data = [[NSData alloc] initWithBase64EncodedString:base64data options:0];
52 |   if (!data) {
53 |     // BUG FIX: invalid base64 yields nil; avoid wrapping NULL bytes.
54 |     *out_width = 0;
55 |     *out_height = 0;
56 |     *out_channels = 0;
57 |     return std::vector<tensorflow::uint8>();
58 |   }
59 |   // `data` outlives the call below, so a no-copy CFData wrapper is safe;
60 |   // LoadImageFromData releases the wrapper itself.
61 |   CFDataRef data_ref = CFDataCreateWithBytesNoCopy(NULL, (UInt8 *)[data bytes], [data length], kCFAllocatorNull);
62 |   return LoadImageFromData(data_ref, ".jpeg", out_width, out_height, out_channels);
63 | }
54 |
55 | // Decodes PNG/JPEG bytes in `data_ref` into an RGBA byte vector.
56 | // Always releases `data_ref`. On failure the out-parameters are zeroed
57 | // and an empty vector is returned.
58 | std::vector<tensorflow::uint8> LoadImageFromData(CFDataRef data_ref,
59 |                                                  const char* suffix,
60 |                                                  int* out_width, int* out_height,
61 |                                                  int* out_channels) {
62 |   CGDataProviderRef image_provider =
63 |       CGDataProviderCreateWithCFData(data_ref);
64 |
65 |   CGImageRef image = NULL;
66 |   if (strcasecmp(suffix, ".png") == 0) {
67 |     image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true,
68 |                                              kCGRenderingIntentDefault);
69 |   } else if ((strcasecmp(suffix, ".jpg") == 0) ||
70 |              (strcasecmp(suffix, ".jpeg") == 0)) {
71 |     image = CGImageCreateWithJPEGDataProvider(image_provider, NULL, true,
72 |                                               kCGRenderingIntentDefault);
73 |   } else {
74 |     fprintf(stderr, "Unknown suffix '%s'\n", suffix);
75 |   }
76 |   if (!image) {
77 |     // BUG FIX: also covers corrupt image data, where the CGImageCreate*
78 |     // calls return NULL and the original crashed in CGImageGetWidth.
79 |     CFRelease(image_provider);
80 |     CFRelease(data_ref);
81 |     *out_width = 0;
82 |     *out_height = 0;
83 |     *out_channels = 0;
84 |     return std::vector<tensorflow::uint8>();
85 |   }
86 |
87 |   const int width = (int)CGImageGetWidth(image);
88 |   const int height = (int)CGImageGetHeight(image);
89 |   const int channels = 4;  // CoreGraphics renders into RGBA.
90 |   CGColorSpaceRef color_space = CGColorSpaceCreateDeviceRGB();
91 |   const int bytes_per_row = (width * channels);
92 |   const int bytes_in_image = (bytes_per_row * height);
93 |   std::vector<tensorflow::uint8> result(bytes_in_image);
94 |   const int bits_per_component = 8;
95 |   CGContextRef context = CGBitmapContextCreate(result.data(), width, height,
96 |                                                bits_per_component, bytes_per_row, color_space,
97 |                                                kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big);
98 |   CGColorSpaceRelease(color_space);
99 |   CGContextDrawImage(context, CGRectMake(0, 0, width, height), image);
100 |   CGContextRelease(context);
101 |   CFRelease(image);
102 |   CFRelease(image_provider);
103 |   CFRelease(data_ref);
104 |
105 |   *out_width = width;
106 |   *out_height = height;
107 |   *out_channels = channels;
108 |   return result;
109 | }
103 |
--------------------------------------------------------------------------------
/src/ios/tf_libs/tensorflow_utils.h:
--------------------------------------------------------------------------------
1 | // Copyright 2015 Google Inc. All rights reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_TENSORFLOW_UTILS_H_
16 | #define TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_TENSORFLOW_UTILS_H_
17 |
18 | #include <memory>
19 | #include <vector>
20 |
21 | #include "tensorflow/core/public/session.h"
22 | #include "tensorflow/core/util/memmapped_file_system.h"
23 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24 |
25 | // Reads a serialized GraphDef protobuf file from the bundle, typically
26 | // created with the freeze_graph script. Populates the session argument with a
27 | // Session object that has the model loaded.
28 | tensorflow::Status LoadModel(NSString* file_name,
29 |                              std::unique_ptr<tensorflow::Session>* session);
30 |
31 | // Loads a model from a file that has been created using the
32 | // convert_graphdef_memmapped_format tool. This bundles together a GraphDef
33 | // proto together with a file that can be memory-mapped, containing the weight
34 | // parameters for the model. This is useful because it reduces the overall
35 | // memory pressure, since the read-only parameter regions can be easily paged
36 | // out and don't count toward memory limits on iOS.
37 | tensorflow::Status LoadMemoryMappedModel(
38 |     NSString* file_name,
39 |     std::unique_ptr<tensorflow::Session>* session,
40 |     std::unique_ptr<tensorflow::MemmappedEnv>* memmapped_env);
41 |
42 | // Takes a text file with a single label on each line, and returns a list.
43 | tensorflow::Status LoadLabels(NSString* file_name,
44 |                               std::vector<std::string>* label_strings);
45 |
46 | // Sorts the results from a model execution, and returns the highest scoring.
47 | void GetTopN(const Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>,
48 |                                     Eigen::Aligned>& prediction,
49 |              const int num_results, const float threshold,
50 |              std::vector<std::pair<float, int> >* top_results);
51 |
52 | // One classification hit handed back to the Objective-C layer.
53 | struct Result {
54 |   NSString* label;
55 |   NSNumber* confidence;
56 | };
57 |
58 | // Runs the loaded model on a base64-encoded image and fills `results`
59 | // with the top-scoring labels.
60 | tensorflow::Status RunInferenceOnImage(NSString* image, int input_size, float input_mean, float input_std, std::string input_layer, std::string output_layer, std::unique_ptr<tensorflow::Session>* session, std::vector<std::string>* labels, std::vector<Result>* results);
61 |
62 | #endif  // TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_TENSORFLOW_UTILS_H_
60 |
--------------------------------------------------------------------------------
/src/ios/tf_libs/tensorflow_utils.mm:
--------------------------------------------------------------------------------
1 | // Copyright 2015 Google Inc. All rights reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #import
16 |
17 | #include "tensorflow_utils.h"
18 | #include "ios_image_load.h"
19 |
20 | #include
21 | #include
22 | #include
23 | #include
24 | #include
25 | #include
26 |
27 | #include "google/protobuf/io/coded_stream.h"
28 | #include "google/protobuf/io/zero_copy_stream_impl.h"
29 | #include "google/protobuf/io/zero_copy_stream_impl_lite.h"
30 | #include "google/protobuf/message_lite.h"
31 | #include "tensorflow/core/framework/tensor.h"
32 | #include "tensorflow/core/framework/types.pb.h"
33 | #include "tensorflow/core/platform/env.h"
34 | #include "tensorflow/core/platform/logging.h"
35 | #include "tensorflow/core/platform/mutex.h"
36 | #include "tensorflow/core/platform/types.h"
37 | #include "tensorflow/core/public/session.h"
38 |
39 | namespace {
40 |
41 | // Helper class used to load protobufs efficiently: adapts a std::ifstream
42 | // to protobuf's CopyingInputStream interface so files are streamed rather
43 | // than slurped into memory.
44 | class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream {
45 |  public:
46 |   explicit IfstreamInputStream(const std::string& file_name)
47 |       : ifs_(file_name.c_str(), std::ios::in | std::ios::binary) {}
48 |   ~IfstreamInputStream() { ifs_.close(); }
49 |
50 |   // Reads up to `size` bytes into `buffer`; returns the byte count read,
51 |   // or -1 once the stream has failed.
52 |   int Read(void* buffer, int size) {
53 |     if (!ifs_) {
54 |       return -1;
55 |     }
56 |     ifs_.read(static_cast<char*>(buffer), size);
57 |     return ifs_.gcount();
58 |   }
59 |
60 |  private:
61 |   std::ifstream ifs_;
62 | };
59 | } // namespace
60 |
61 | // Returns the top N confidence values over threshold in the provided vector,
62 | // sorted by confidence in descending order.
63 | void GetTopN(const Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>,
64 |                                     Eigen::Aligned>& prediction,
65 |              const int num_results, const float threshold,
66 |              std::vector<std::pair<float, int> >* top_results) {
67 |   // Min-heap: will contain the top N results in ascending order.
68 |   std::priority_queue<std::pair<float, int>,
69 |                       std::vector<std::pair<float, int> >,
70 |                       std::greater<std::pair<float, int> > >
71 |       top_result_pq;
72 |
73 |   const int count = prediction.size();
74 |   for (int i = 0; i < count; ++i) {
75 |     const float value = prediction(i);
76 |
77 |     // Only add it if it beats the threshold and has a chance at being in
78 |     // the top N.
79 |     if (value < threshold) {
80 |       continue;
81 |     }
82 |
83 |     top_result_pq.push(std::pair<float, int>(value, i));
84 |
85 |     // If at capacity, kick the smallest value out.  (Cast avoids a
86 |     // signed/unsigned comparison warning.)
87 |     if (top_result_pq.size() > static_cast<size_t>(num_results)) {
88 |       top_result_pq.pop();
89 |     }
90 |   }
91 |
92 |   // Copy to output vector and reverse into descending order.
93 |   while (!top_result_pq.empty()) {
94 |     top_results->push_back(top_result_pq.top());
95 |     top_result_pq.pop();
96 |   }
97 |   std::reverse(top_results->begin(), top_results->end());
98 | }
98 |
99 | // Parses the binary protobuf at `file_name` into `proto`, streaming the
100 | // file through a coded stream instead of loading it all at once.
101 | bool PortableReadFileToProto(const std::string& file_name,
102 |                              ::google::protobuf::MessageLite* proto) {
103 |   ::google::protobuf::io::CopyingInputStreamAdaptor adaptor(
104 |       new IfstreamInputStream(file_name));
105 |   adaptor.SetOwnsCopyingStream(true);
106 |   ::google::protobuf::io::CodedInputStream input(&adaptor);
107 |   // Total bytes hard limit / warning limit are set to 1GB and 512MB
108 |   // respectively.
109 |   input.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
110 |   return proto->ParseFromCodedStream(&input);
111 | }
110 |
111 | // Loads a frozen GraphDef from `model_path` and creates a Session holding
112 | // the loaded graph. Returns a non-OK status on any failure.
113 | tensorflow::Status LoadModel(NSString* model_path,
114 |                              std::unique_ptr<tensorflow::Session>* session) {
115 |   tensorflow::SessionOptions options;
116 |
117 |   tensorflow::Session* session_pointer = nullptr;
118 |   tensorflow::Status session_status =
119 |       tensorflow::NewSession(options, &session_pointer);
120 |   if (!session_status.ok()) {
121 |     LOG(ERROR) << "Could not create TensorFlow Session: " << session_status;
122 |     return session_status;
123 |   }
124 |   session->reset(session_pointer);
125 |
126 |   tensorflow::GraphDef tensorflow_graph;
127 |
128 |   const bool read_proto_succeeded =
129 |       PortableReadFileToProto([model_path UTF8String], &tensorflow_graph);
130 |   if (!read_proto_succeeded) {
131 |     // BUG FIX: added the missing space after "from" in the log message.
132 |     LOG(ERROR) << "Failed to load model proto from " << [model_path UTF8String];
133 |     return tensorflow::errors::NotFound([model_path UTF8String]);
134 |   }
135 |
136 |   tensorflow::Status create_status = (*session)->Create(tensorflow_graph);
137 |   if (!create_status.ok()) {
138 |     LOG(ERROR) << "Could not create TensorFlow Graph: " << create_status;
139 |     return create_status;
140 |   }
141 |
142 |   return tensorflow::Status::OK();
143 | }
141 |
142 | // Loads a model produced by convert_graphdef_memmapped_format. The weights
143 | // stay memory-mapped (read-only, pageable), which keeps peak memory low on
144 | // iOS; graph optimizations are disabled so constant folding cannot copy them.
145 | tensorflow::Status LoadMemoryMappedModel(
146 |     NSString* network_path,
147 |     std::unique_ptr<tensorflow::Session>* session,
148 |     std::unique_ptr<tensorflow::MemmappedEnv>* memmapped_env) {
149 |   memmapped_env->reset(
150 |       new tensorflow::MemmappedEnv(tensorflow::Env::Default()));
151 |   tensorflow::Status mmap_status =
152 |       (memmapped_env->get())->InitializeFromFile([network_path UTF8String]);
153 |   if (!mmap_status.ok()) {
154 |     LOG(ERROR) << "MMap failed with " << mmap_status.error_message();
155 |     return mmap_status;
156 |   }
157 |
158 |   tensorflow::GraphDef tensorflow_graph;
159 |   tensorflow::Status load_graph_status = ReadBinaryProto(
160 |       memmapped_env->get(),
161 |       tensorflow::MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
162 |       &tensorflow_graph);
163 |   if (!load_graph_status.ok()) {
164 |     LOG(ERROR) << "MMap load graph failed with "
165 |                << load_graph_status.error_message();
166 |     return load_graph_status;
167 |   }
168 |
169 |   tensorflow::SessionOptions options;
170 |   // Disable optimizations on this graph so that constant folding doesn't
171 |   // increase the memory footprint by creating new constant copies of the weight
172 |   // parameters.
173 |   options.config.mutable_graph_options()
174 |       ->mutable_optimizer_options()
175 |       ->set_opt_level(::tensorflow::OptimizerOptions::L0);
176 |   options.env = memmapped_env->get();
177 |
178 |   tensorflow::Session* session_pointer = nullptr;
179 |   tensorflow::Status session_status =
180 |       tensorflow::NewSession(options, &session_pointer);
181 |   if (!session_status.ok()) {
182 |     LOG(ERROR) << "Could not create TensorFlow Session: " << session_status;
183 |     return session_status;
184 |   }
185 |
186 |   tensorflow::Status create_status = session_pointer->Create(tensorflow_graph);
187 |   if (!create_status.ok()) {
188 |     LOG(ERROR) << "Could not create TensorFlow Graph: " << create_status;
189 |     return create_status;
190 |   }
191 |
192 |   session->reset(session_pointer);
193 |
194 |   return tensorflow::Status::OK();
195 | }
193 |
194 | // Reads one label per line from `labels_path` into `label_strings`.
195 | tensorflow::Status LoadLabels(NSString* labels_path,
196 |                               std::vector<std::string>* label_strings) {
197 |   // Read the label list
198 |   std::ifstream t;
199 |   t.open([labels_path UTF8String]);
200 |   std::string line;
201 |   // BUG FIX: loop on getline's result; the original `while (t)` form
202 |   // appended a spurious empty label after the final line.
203 |   while (std::getline(t, line)) {
204 |     label_strings->push_back(line);
205 |   }
206 |   t.close();
207 |   return tensorflow::Status::OK();
208 | }
207 |
208 | // Decodes the base64 `image`, resizes it (nearest-neighbor) to
209 | // input_size x input_size RGB, normalizes with (x - mean) / std, runs the
210 | // session and appends the top matches to `results`.
211 | tensorflow::Status RunInferenceOnImage(NSString* image, int input_size, float input_mean, float input_std, std::string input_layer, std::string output_layer, std::unique_ptr<tensorflow::Session>* session, std::vector<std::string>* labels, std::vector<Result>* results) {
212 |
213 |   int image_width;
214 |   int image_height;
215 |   int image_channels;
216 |   std::vector<tensorflow::uint8> image_data = LoadImageFromBase64(
217 |       image, &image_width, &image_height, &image_channels);
218 |   if (image_data.empty()) {
219 |     // BUG FIX: bail out on decode failure instead of reading from an
220 |     // empty pixel buffer below.
221 |     return tensorflow::errors::InvalidArgument("Could not decode input image");
222 |   }
223 |   const int wanted_width = input_size;
224 |   const int wanted_height = input_size;
225 |   const int wanted_channels = 3;
226 |
227 |   assert(image_channels >= wanted_channels);
228 |   tensorflow::Tensor image_tensor(
229 |       tensorflow::DT_FLOAT,
230 |       tensorflow::TensorShape({
231 |           1, wanted_height, wanted_width, wanted_channels}));
232 |   auto image_tensor_mapped = image_tensor.tensor<float, 4>();
233 |   tensorflow::uint8* in = image_data.data();
234 |   float* out = image_tensor_mapped.data();
235 |   // Nearest-neighbor resample: map each output pixel back to a source pixel.
236 |   for (int y = 0; y < wanted_height; ++y) {
237 |     const int in_y = (y * image_height) / wanted_height;
238 |     tensorflow::uint8* in_row = in + (in_y * image_width * image_channels);
239 |     float* out_row = out + (y * wanted_width * wanted_channels);
240 |     for (int x = 0; x < wanted_width; ++x) {
241 |       const int in_x = (x * image_width) / wanted_width;
242 |       tensorflow::uint8* in_pixel = in_row + (in_x * image_channels);
243 |       float* out_pixel = out_row + (x * wanted_channels);
244 |       for (int c = 0; c < wanted_channels; ++c) {
245 |         out_pixel[c] = (in_pixel[c] - input_mean) / input_std;
246 |       }
247 |     }
248 |   }
249 |
250 |   std::vector<tensorflow::Tensor> outputs;
251 |   tensorflow::Status run_status = (*session)->Run({{input_layer, image_tensor}},
252 |                                                   {output_layer}, {}, &outputs);
253 |   if (!run_status.ok()) {
254 |     LOG(ERROR) << "Running model failed: " << run_status;
255 |     return run_status;
256 |   }
257 |
258 |   tensorflow::Tensor* output = &outputs[0];
259 |   const int kNumResults = 5;
260 |   const float kThreshold = 0.1f;
261 |   std::vector<std::pair<float, int> > top_results;
262 |   GetTopN(output->flat<float>(), kNumResults, kThreshold, &top_results);
263 |   for (const auto& result : top_results) {
264 |     struct Result res;
265 |     // NOTE(review): the modulo guards against out-of-range indices but
266 |     // assumes labels line up with output indices -- confirm against model.
267 |     std::string label = labels->at(result.second % labels->size());
268 |     res.label = [NSString stringWithUTF8String:label.c_str()];
269 |     res.confidence = [NSNumber numberWithFloat:result.first];
270 |     results->push_back(res);
271 |   }
272 |
273 |   return run_status;
274 | }
265 |
--------------------------------------------------------------------------------
/www/tensorflow.js:
--------------------------------------------------------------------------------
1 | function TensorFlow(modelId, model) {
2 | this.modelId = modelId;
3 | if (model) {
4 | model = registerModel(modelId, model);
5 | } else {
6 | model = getModel(modelId);
7 | }
8 | this.model = model;
9 | this.onprogress = function() {};
10 | }
11 |
12 |
13 | TensorFlow.prototype.load = function(successCallback, errorCallback) {
14 | var promise;
15 | if (window.Promise && !successCallback) {
16 | promise = new Promise(function(resolve, reject) {
17 | successCallback = resolve;
18 | errorCallback = reject;
19 | });
20 | }
21 | loadModel(
22 | this.modelId,
23 | successCallback,
24 | errorCallback,
25 | this.onprogress
26 | );
27 | return promise;
28 | };
29 |
30 | TensorFlow.prototype.checkCached = function(successCallback, errorCallback) {
31 | var promise;
32 | if (window.Promise && !successCallback) {
33 | promise = new Promise(function(resolve, reject) {
34 | successCallback = resolve;
35 | errorCallback = reject;
36 | });
37 | }
38 | checkCached(this.modelId, successCallback, errorCallback);
39 | return promise;
40 | };
41 |
42 | TensorFlow.prototype.classify = function(image, successCallback, errorCallback) {
43 | var promise;
44 | if (window.Promise && !successCallback) {
45 | promise = new Promise(function(resolve, reject) {
46 | successCallback = resolve;
47 | errorCallback = reject;
48 | });
49 | }
50 |
51 | var self = this;
52 | if (!self.model.loaded) {
53 | self.load(function() {
54 | if (!self.model.loaded) {
55 | errorCallback("Error loading model!");
56 | return;
57 | }
58 | self.classify(image, successCallback, errorCallback);
59 | }, errorCallback);
60 | return promise;
61 | }
62 |
63 | cordova.exec(
64 | successCallback, errorCallback,
65 | "TensorFlow", "classify", [self.modelId, image]
66 | );
67 | return promise;
68 | };
69 |
70 | // Internal API for downloading and caching model files
71 | var models = {}; // registry: modelId -> model description object
72 |
73 | // Keys that every model description must define (validated in registerModel).
74 | var FIELDS = [
75 |     'label',
76 |
77 |     'model_path',
78 |     'label_path',
79 |
80 |     'input_size',
81 |     'image_mean',
82 |     'image_std',
83 |     'input_name',
84 |     'output_name'
85 | ];
85 |
86 | function registerModel(modelId, model) {
87 | FIELDS.forEach(function(field) {
88 | if (!model[field]) {
89 | throw 'Missing "' + field + '" on model description';
90 | }
91 | });
92 |
93 | if (model.model_path.match(/^http/) || model.label_path.match(/^http/)) {
94 | model.cached = false;
95 | } else {
96 | model.cached = true;
97 | }
98 | models[modelId] = model;
99 | model.id = modelId;
100 | return model;
101 | }
102 |
103 | function getModel(modelId) {
104 | var model = models[modelId];
105 | if (!model) {
106 | throw "Unknown model " + modelId;
107 | }
108 | return model;
109 | }
110 |
111 | var INCEPTION = 'https://storage.googleapis.com/download.tensorflow.org/models/';
112 | registerModel('inception-v1', {
113 | 'label': 'Inception v1',
114 | 'model_path': INCEPTION + 'inception5h.zip#tensorflow_inception_graph.pb',
115 | 'label_path': INCEPTION + 'inception5h.zip#imagenet_comp_graph_label_strings.txt',
116 | 'input_size': 224,
117 | 'image_mean': 117,
118 | 'image_std': 1,
119 | 'input_name': 'input',
120 | 'output_name': 'output'
121 | });
122 |
123 | registerModel('inception-v3', {
124 | 'label': 'Inception v3',
125 | 'model_path': INCEPTION + 'inception_dec_2015.zip#tensorflow_inception_graph.pb',
126 | 'label_path': INCEPTION + 'inception_dec_2015.zip#imagenet_comp_graph_label_strings.txt',
127 | 'input_size': 299,
128 | 'image_mean': 128,
129 | 'image_std': 128,
130 | 'input_name': 'Mul',
131 | 'output_name': 'final_result'
132 | });
133 |
134 | function loadModel(modelId, callback, errorCallback, progressCallback) {
135 | var model;
136 | try {
137 | model = getModel(modelId);
138 | } catch (e) {
139 | errorCallback(e);
140 | return;
141 | }
142 | if (!progressCallback) {
143 | progressCallback = function(stat) {
144 | console.log(stat.label);
145 | };
146 | }
147 | if (!model.cached) {
148 | checkCached(modelId, function(cached) {
149 | if (!cached) {
150 | fetchModel(
151 | model,
152 | initClassifier,
153 | errorCallback,
154 | progressCallback
155 | );
156 | } else {
157 | initClassifier();
158 | }
159 | }, errorCallback);
160 | } else {
161 | initClassifier();
162 | }
163 | function initClassifier() {
164 | var modelPath = (model.local_model_path || model.model_path),
165 | labelPath = (model.local_label_path || model.label_path);
166 | modelPath = modelPath.replace(/^file:\/\//, '');
167 | labelPath = labelPath.replace(/^file:\/\//, '');
168 | progressCallback({
169 | 'status': 'initializing',
170 | 'label': 'Initializing classifier'
171 | });
172 | cordova.exec(function() {
173 | model.loaded = true;
174 | callback(model);
175 | }, errorCallback, "TensorFlow", "loadModel", [
176 | model.id,
177 | modelPath,
178 | labelPath,
179 | model.input_size,
180 | model.image_mean,
181 | model.image_std,
182 | model.input_name,
183 | model.output_name
184 | ]);
185 | }
186 | }
187 |
188 | function getPath(filename) {
189 | return (
190 | cordova.file.externalDataDirectory || cordova.file.dataDirectory
191 | ) + filename;
192 | }
193 |
194 | // Downloads the model's files; currently a thin wrapper around fetchZip.
195 | function fetchModel(model, callback, errorCallback, progressCallback) {
196 |     fetchZip(model, callback, errorCallback, progressCallback);
197 | }
197 |
198 | function checkCached(modelId, callback, errorCallback) {
199 | var model;
200 | try {
201 | model = getModel(modelId);
202 | } catch (e) {
203 | errorCallback(e);
204 | return;
205 | }
206 | var zipUrl = model.model_path.split('#')[0];
207 | if (model.label_path.indexOf(zipUrl) == -1) {
208 | errorCallback('Model and labels must be in same zip file!');
209 | return;
210 | }
211 | var modelZipName = model.model_path.replace(zipUrl + '#', '');
212 | var labelZipName = model.label_path.replace(zipUrl + '#', '');
213 | var zipPath = getPath(model.id + '.zip');
214 | var dir = getPath(model.id);
215 |
216 | model.local_model_path = dir + '/' + modelZipName;
217 | model.local_label_path = dir + '/' + labelZipName;
218 |
219 | resolveLocalFileSystemURL(
220 | model.local_model_path, cached(true), cached(false)
221 | );
222 |
223 | function cached(result) {
224 | return function() {
225 | model.cached = result;
226 | callback(model.cached);
227 | };
228 | }
229 | }
230 |
231 | function fetchZip(model, callback, errorCallback, progressCallback) {
232 | var zipUrl = model.model_path.split('#')[0];
233 | var zipPath = getPath(model.id + '.zip');
234 | var dir = getPath(model.id);
235 | var fileTransfer = new FileTransfer();
236 | progressCallback({
237 | 'status': 'downloading',
238 | 'label': 'Downloading model files',
239 | });
240 | fileTransfer.onprogress = function(evt) {
241 | var label = 'Downloading';
242 | if (evt.lengthComputable) {
243 | label += ' (' + evt.loaded + '/' + evt.total + ')';
244 | } else {
245 | label += '...';
246 | }
247 | progressCallback({
248 | 'status': 'downloading',
249 | 'label': label,
250 | 'detail': evt
251 | });
252 | };
253 | fileTransfer.download(zipUrl, zipPath, function(entry) {
254 | progressCallback({
255 | 'status': 'unzipping',
256 | 'label': 'Extracting contents'
257 | });
258 | zip.unzip(zipPath, dir, function(result) {
259 | if (result == -1) {
260 | errorCallback('Error unzipping file');
261 | return;
262 | }
263 | model.cached = true;
264 | callback();
265 | });
266 | }, errorCallback);
267 | }
268 |
269 | TensorFlow._models = models; // expose the internal registry (debug/testing)
270 | module.exports = TensorFlow;
271 |
--------------------------------------------------------------------------------