├── LICENSE ├── README.md ├── package.json ├── plugin.xml ├── src ├── android │ ├── TensorFlow.java │ └── tf_libs │ │ ├── Classifier.java │ │ ├── TensorFlowImageClassifier.java │ │ ├── armeabi-v7a │ │ └── libtensorflow_inference.so │ │ └── libandroid_tensorflow_inference_java.jar └── ios │ ├── TensorFlow.h │ ├── TensorFlow.mm │ └── tf_libs │ ├── ios_image_load.h │ ├── ios_image_load.mm │ ├── tensorflow_utils.h │ └── tensorflow_utils.mm └── www └── tensorflow.js /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2017 Houston Engineering, Inc., http://www.houstoneng.com 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cordova-plugin-tensorflow 2 | 3 | Integrate the TensorFlow inference library into your PhoneGap/Cordova application! 4 | 5 | ```javascript 6 | var tf = new TensorFlow('inception-v1'); 7 | var imgData = "/9j/4AAQSkZJRgABAQEAYABgAAD//gBGRm ..."; 8 | 9 | tf.classify(imgData).then(function(results) { 10 | results.forEach(function(result) { 11 | console.log(result.title + " " + result.confidence); 12 | }); 13 | }); 14 | 15 | /* Output: 16 | military uniform 0.647296 17 | suit 0.0477196 18 | academic gown 0.0232411 19 | */ 20 | ``` 21 | ## Installation 22 | 23 | ### Cordova 24 | ```bash 25 | cordova plugin add https://github.com/heigeo/cordova-plugin-tensorflow 26 | ``` 27 | 28 | ### PhoneGap Build 29 | ```xml 30 | 31 | 32 | ``` 33 | 34 | ## Supported Platforms 35 | 36 | * Android 37 | * iOS 38 | 39 | ## API 40 | 41 | The plugin provides a `TensorFlow` class that can be used to initialize graphs and run the inference algorithm. 42 | 43 | ### Initialization 44 | 45 | ```javascript 46 | // Use the Inception model (will be downloaded on first use) 47 | var tf = new TensorFlow('inception-v1'); 48 | 49 | // Use a custom retrained model 50 | var tf = new TensorFlow('custom-model', { 51 | 'label': 'My Custom Model', 52 | 'model_path': "https://example.com/graphs/custom-model-2017.zip#rounded_graph.pb", 53 | 'label_path': "https://example.com/graphs/custom-model-2017.zip#retrained_labels.txt", 54 | 'input_size': 299, 55 | 'image_mean': 128, 56 | 'image_std': 128, 57 | 'input_name': 'Mul', 58 | 'output_name': 'final_result' 59 | }) 60 | ``` 61 | 62 | To use a custom model, follow the steps to [retrain the model](https://www.tensorflow.org/tutorials/image_retraining) and [optimize it for mobile use](https://petewarden.com/2016/09/27/tensorflow-for-mobile-poets/). 
63 | Put the .pb and .txt files in an HTTP-accessible zip file, which will be downloaded via the [FileTransfer plugin](https://cordova.apache.org/docs/en/latest/reference/cordova-plugin-file-transfer/). If you use the generic Inception model, it will be downloaded from [the TensorFlow website](https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip) on first use. 64 | 65 | ### Methods 66 | 67 | Each method returns a `Promise` (if available) and also accepts optional `callback` and `errorCallback` arguments. 68 | 69 | 70 | ### classify(image[, callback, errorCallback]) 71 | Classifies an image with TensorFlow's inference algorithm and the registered model. The model is downloaded and initialized automatically if necessary, but calling `load()` explicitly first is recommended for the best user experience. 72 | 73 | Note that the image must be provided as base64-encoded JPEG or PNG data. Support for file paths may be added in a future release. 74 | 75 | ```javascript 76 | var tf = new TensorFlow(...); 77 | var imgData = "/9j/4AAQSkZJRgABAQEAYABgAAD//gBGRm ..."; 78 | tf.classify(imgData).then(function(results) { 79 | results.forEach(function(result) { 80 | console.log(result.title + " " + result.confidence); 81 | }); 82 | }); 83 | ``` 84 | 85 | ### load() 86 | 87 | Downloads the referenced model files and loads the graph into TensorFlow. 88 | 89 | ```javascript 90 | var tf = new TensorFlow(...); 91 | tf.load().then(function() { 92 | console.log("Model loaded"); 93 | }); 94 | ``` 95 | 96 | Downloading the model files can take some time.
If you would like to provide a progress indicator, you can do that by assigning an `onprogress` handler: 97 | ```javascript 98 | var tf = new TensorFlow(...); 99 | tf.onprogress = function(evt) { 100 | if (evt['status'] == 'downloading') { 101 | console.log("Downloading model files..."); 102 | console.log(evt.label); 103 | if (evt.detail) { 104 | // evt.detail is from the FileTransfer API 105 | var $elem = $('progress'); 106 | $elem.attr('max', evt.detail.total); 107 | $elem.attr('value', evt.detail.loaded); 108 | } 109 | } else if (evt['status'] == 'unzipping') { 110 | console.log("Extracting contents..."); 111 | } else if (evt['status'] == 'initializing') { 112 | console.log("Initializing TensorFlow"); 113 | } 114 | }; 115 | tf.load().then(...); 116 | ``` 117 | 118 | ### checkCached() 119 | Checks whether the requisite model files have already been downloaded. This is useful if you want to provide an interface for downloading and managing TensorFlow graphs that is separate from the classification interface. 120 | 121 | ```javascript 122 | var tf = new TensorFlow(...); 123 | tf.checkCached().then(function(isCached) { 124 | if (isCached) { 125 | $('button#download').hide(); 126 | } 127 | }); 128 | ``` 129 | 130 | ## References 131 | 132 | This plugin is made possible by the following libraries and tutorials: 133 | 134 | Source | Files 135 | -------|-------- 136 | [TensorFlow Android Inference Interface] | [libtensorflow_inference.so],
[libandroid_tensorflow_inference_java.jar] 137 | [TensorFlow Android Demo] | [Classifier.java][Classifer.java],
[TensorFlowImageClassifier.java][TensorFlowImageClassifier.java] (modified) 138 | [TensorflowPod] | Referenced via [podspec] 139 | [TensorFlow iOS Examples] | [ios_image_load.mm][ios_image_load.mm] (modified),
[tensorflow_utils.mm][tensorflow_utils.mm] (+ RunModelViewController.mm) 140 | 141 | [TensorFlow Android Inference Interface]: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/android 142 | [libtensorflow_inference.so]: https://github.com/heigeo/cordova-plugin-tensorflow/blob/master/src/android/tf_libs/armeabi-v7a/libtensorflow_inference.so 143 | [libandroid_tensorflow_inference_java.jar]: https://github.com/heigeo/cordova-plugin-tensorflow/blob/master/src/android/tf_libs/libandroid_tensorflow_inference_java.jar 144 | [TensorFlow Android Demo]: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android 145 | [Classifer.java]: https://github.com/heigeo/cordova-plugin-tensorflow/blob/master/src/android/tf_libs/Classifier.java 146 | [TensorFlowImageClassifier.java]: https://github.com/heigeo/cordova-plugin-tensorflow/blob/master/src/android/tf_libs/TensorFlowImageClassifier.java 147 | [TensorflowPod]: https://github.com/rainbean/TensorflowPod 148 | [podspec]: https://github.com/heigeo/cordova-plugin-tensorflow/blob/master/plugin.xml#L38 149 | [TensorFlow iOS Examples]: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/ios_examples 150 | [ios_image_load.mm]: https://github.com/heigeo/cordova-plugin-tensorflow/blob/master/src/ios/tf_libs/ios_image_load.mm 151 | [tensorflow_utils.mm]: https://github.com/heigeo/cordova-plugin-tensorflow/blob/master/src/ios/tf_libs/tensorflow_utils.mm 152 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "cordova-plugin-tensorflow", 3 | "version": "0.0.1", 4 | "description": "TensorFlow for Cordova", 5 | "cordova": { 6 | "id": "cordova-plugin-tensorflow", 7 | "platforms": [ 8 | "android" 9 | ] 10 | }, 11 | "repository": { 12 | "type": "git", 13 | "url": "https://github.com/heigeo/cordova-plugin-tensorflow.git" 14 | }, 15 | 
"keywords": [ 16 | "ai", 17 | "inference", 18 | "classification", 19 | "imagerecognition", 20 | "neuralnetworks", 21 | "machinelearning", 22 | "tensorflow", 23 | "inception", 24 | "ecosystem:cordova", 25 | "cordova-android" 26 | ], 27 | "engines": [ 28 | { 29 | "name": "cordova-android", 30 | "version": ">=5.1.0" 31 | } 32 | ], 33 | "author": "Houston Engineering, Inc.", 34 | "license": "MIT", 35 | "bugs": { 36 | "url": "https://github.com/heigeo/cordova-plugin-tensorflow/issues" 37 | }, 38 | "homepage": "https://github.com/heigeo/cordova-plugin-tensorflow#readme" 39 | } 40 | -------------------------------------------------------------------------------- /plugin.xml: -------------------------------------------------------------------------------- 1 | 2 | 4 | TensorFlow for Cordova 5 | Cordova/PhoneGap wrapper for TensorFlow's image recognition binary library. 6 | MIT 7 | cordova,device 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /src/android/TensorFlow.java: -------------------------------------------------------------------------------- 1 | package io.wq.tensorflow; 2 | 3 | import org.apache.cordova.CordovaPlugin; 4 | import org.apache.cordova.CallbackContext; 5 | import org.json.JSONArray; 6 | import org.json.JSONObject; 7 | import org.json.JSONException; 8 | 9 | import android.util.Base64; 10 | import android.graphics.Bitmap; 11 | import android.graphics.Bitmap.Config; 12 | import android.graphics.BitmapFactory; 13 | 14 | import org.tensorflow.demo.TensorFlowImageClassifier; 15 | import org.tensorflow.demo.Classifier.Recognition; 16 | import org.tensorflow.demo.Classifier; 17 | import java.util.List; 18 | import java.util.Map; 19 | import java.util.HashMap; 20 | 21 | import android.media.ThumbnailUtils; 22 
| 23 | public class TensorFlow extends CordovaPlugin { 24 | 25 | @Override 26 | public boolean execute(String action, JSONArray args, CallbackContext callbackContext) throws JSONException { 27 | if (action.equals("loadModel")) { 28 | this.loadModel( 29 | args.getString(0), 30 | args.getString(1), 31 | args.getString(2), 32 | args.getInt(3), 33 | args.getInt(4), 34 | (float) args.getDouble(5), 35 | args.getString(6), 36 | args.getString(7), 37 | callbackContext 38 | ); 39 | return true; 40 | } else if (action.equals("classify")) { 41 | this.classify(args.getString(0), args.getString(1), callbackContext); 42 | return true; 43 | } else { 44 | return false; 45 | } 46 | } 47 | 48 | 49 | private Map classifiers = new HashMap(); 50 | private Map sizes = new HashMap(); 51 | 52 | private void loadModel(String modelName, String modelFile, String labelFile, 53 | int inputSize, int imageMean, float imageStd, 54 | String inputName, String outputName, 55 | CallbackContext callbackContext) { 56 | classifiers.put(modelName, TensorFlowImageClassifier.create( 57 | cordova.getActivity().getAssets(), 58 | modelFile, 59 | labelFile, 60 | inputSize, 61 | imageMean, 62 | imageStd, 63 | inputName, 64 | outputName 65 | )); 66 | sizes.put(modelName, inputSize); 67 | callbackContext.success(); 68 | } 69 | 70 | private void classify(String modelName, String image, CallbackContext callbackContext) { 71 | byte[] imageData = Base64.decode(image, Base64.DEFAULT); 72 | Classifier classifier = classifiers.get(modelName); 73 | int size = sizes.get(modelName); 74 | Bitmap bitmap = BitmapFactory.decodeByteArray(imageData, 0, imageData.length); 75 | Bitmap cropped = ThumbnailUtils.extractThumbnail(bitmap, size, size); 76 | List results = classifier.recognizeImage(cropped); 77 | JSONArray output = new JSONArray(); 78 | try { 79 | for (Recognition result : results) { 80 | JSONObject record = new JSONObject(); 81 | record.put("title", result.getTitle()); 82 | record.put("confidence", 
result.getConfidence()); 83 | output.put(record); 84 | } 85 | } catch (JSONException e) { 86 | } 87 | callbackContext.success(output); 88 | } 89 | 90 | 91 | } 92 | -------------------------------------------------------------------------------- /src/android/tf_libs/Classifier.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.demo; 17 | 18 | import android.graphics.Bitmap; 19 | import android.graphics.RectF; 20 | import java.util.List; 21 | 22 | /** 23 | * Generic interface for interacting with different recognition engines. 24 | */ 25 | public interface Classifier { 26 | /** 27 | * An immutable result returned by a Classifier describing what was recognized. 28 | */ 29 | public class Recognition { 30 | /** 31 | * A unique identifier for what has been recognized. Specific to the class, not the instance of 32 | * the object. 33 | */ 34 | private final String id; 35 | 36 | /** 37 | * Display name for the recognition. 38 | */ 39 | private final String title; 40 | 41 | /** 42 | * A sortable score for how good the recognition is relative to others. Higher should be better. 
43 | */ 44 | private final Float confidence; 45 | 46 | /** Optional location within the source image for the location of the recognized object. */ 47 | private RectF location; 48 | 49 | public Recognition( 50 | final String id, final String title, final Float confidence, final RectF location) { 51 | this.id = id; 52 | this.title = title; 53 | this.confidence = confidence; 54 | this.location = location; 55 | } 56 | 57 | public String getId() { 58 | return id; 59 | } 60 | 61 | public String getTitle() { 62 | return title; 63 | } 64 | 65 | public Float getConfidence() { 66 | return confidence; 67 | } 68 | 69 | public RectF getLocation() { 70 | return new RectF(location); 71 | } 72 | 73 | public void setLocation(RectF location) { 74 | this.location = location; 75 | } 76 | 77 | @Override 78 | public String toString() { 79 | String resultString = ""; 80 | if (id != null) { 81 | resultString += "[" + id + "] "; 82 | } 83 | 84 | if (title != null) { 85 | resultString += title + " "; 86 | } 87 | 88 | if (confidence != null) { 89 | resultString += String.format("(%.1f%%) ", confidence * 100.0f); 90 | } 91 | 92 | if (location != null) { 93 | resultString += location + " "; 94 | } 95 | 96 | return resultString.trim(); 97 | } 98 | } 99 | 100 | List recognizeImage(Bitmap bitmap); 101 | 102 | void enableStatLogging(final boolean debug); 103 | 104 | String getStatString(); 105 | 106 | void close(); 107 | } 108 | -------------------------------------------------------------------------------- /src/android/tf_libs/TensorFlowImageClassifier.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.demo; 17 | 18 | import android.content.res.AssetManager; 19 | import android.graphics.Bitmap; 20 | import android.os.Trace; 21 | import android.util.Log; 22 | import java.io.FileInputStream; 23 | import java.io.BufferedReader; 24 | import java.io.IOException; 25 | import java.io.InputStreamReader; 26 | import java.util.ArrayList; 27 | import java.util.Comparator; 28 | import java.util.List; 29 | import java.util.PriorityQueue; 30 | import java.util.Vector; 31 | import org.tensorflow.Operation; 32 | import org.tensorflow.contrib.android.TensorFlowInferenceInterface; 33 | 34 | /** A classifier specialized to label images using TensorFlow. */ 35 | public class TensorFlowImageClassifier implements Classifier { 36 | 37 | private static final String TAG = "TensorFlowImageClassifier"; 38 | 39 | // Only return this many results with at least this confidence. 40 | private static final int MAX_RESULTS = 3; 41 | private static final float THRESHOLD = 0.1f; 42 | 43 | // Config values. 44 | private String inputName; 45 | private String outputName; 46 | private int inputSize; 47 | private int imageMean; 48 | private float imageStd; 49 | 50 | // Pre-allocated buffers.
51 | private Vector labels = new Vector(); 52 | private int[] intValues; 53 | private float[] floatValues; 54 | private float[] outputs; 55 | private String[] outputNames; 56 | 57 | private TensorFlowInferenceInterface inferenceInterface; 58 | 59 | private TensorFlowImageClassifier() {} 60 | 61 | /** 62 | * Initializes a native TensorFlow session for classifying images. 63 | * 64 | * @param assetManager The asset manager to be used to load assets. 65 | * @param modelFilename The filepath of the model GraphDef protocol buffer. 66 | * @param labelFilename The filepath of label file for classes (an asset, optionally prefixed with file:///android_asset/, or a file-system path). 67 | * @param inputSize The input size. A square image of inputSize x inputSize is assumed. 68 | * @param imageMean The assumed mean of the image values. 69 | * @param imageStd The assumed std of the image values. 70 | * @param inputName The label of the image input node. 71 | * @param outputName The label of the output node. 72 | * @throws RuntimeException if the label file cannot be read, TensorFlow initialization fails, or the output node is missing from the graph. 73 | */ 74 | public static Classifier create( 75 | AssetManager assetManager, 76 | String modelFilename, 77 | String labelFilename, 78 | int inputSize, 79 | int imageMean, 80 | float imageStd, 81 | String inputName, 82 | String outputName) { 83 | TensorFlowImageClassifier c = new TensorFlowImageClassifier(); 84 | c.inputName = inputName; 85 | c.outputName = outputName; 86 | 87 | // Read the label names into memory. 88 | // TODO(andrewharp): make this handle non-assets. 89 | final boolean hasAssetPrefix = labelFilename.startsWith("file:///android_asset/"); 90 | String actualFilename = hasAssetPrefix ?
labelFilename.split("file:///android_asset/")[1] : labelFilename; 91 | Log.i(TAG, "Reading labels from: " + actualFilename); 92 | BufferedReader br = null; 93 | try { 94 | br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename))); 95 | String line; 96 | while ((line = br.readLine()) != null) { 97 | c.labels.add(line); 98 | } 99 | br.close(); 100 | } catch (IOException e) { closeQuietly(br); br = null; 101 | if (hasAssetPrefix) { 102 | throw new RuntimeException("Problem reading label file!" , e); 103 | } 104 | /* Not an asset: drop any partially read labels and retry from the file system. */ c.labels.clear(); try { 105 | br = new BufferedReader(new InputStreamReader(new FileInputStream(actualFilename))); 106 | String line; 107 | while ((line = br.readLine()) != null) { 108 | c.labels.add(line); 109 | } 110 | br.close(); 111 | } catch (IOException e2) { closeQuietly(br); 112 | /* report the file-system failure (e2), not the earlier asset lookup failure */ throw new RuntimeException("Problem reading label file!" , e2); 113 | } 114 | } 115 | 116 | c.inferenceInterface = new TensorFlowInferenceInterface(); 117 | if (c.inferenceInterface.initializeTensorFlow(assetManager, modelFilename) != 0) { 118 | throw new RuntimeException("TF initialization failed"); 119 | } 120 | // The shape of the output is [N, NUM_CLASSES], where N is the batch size. 121 | final Operation operation = c.inferenceInterface.graph().operation(outputName); 122 | if (operation == null) { 123 | throw new RuntimeException("Node '" + outputName + "' does not exist in model '" 124 | + modelFilename + "'"); 125 | } 126 | final int numClasses = (int) operation.output(0).shape().size(1); 127 | Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses); 128 | 129 | // Ideally, inputSize could have been retrieved from the shape of the input operation. Alas, 130 | // the placeholder node for input in the graphdef typically used does not specify a shape, so it 131 | // must be passed in as a parameter. 132 | c.inputSize = inputSize; 133 | c.imageMean = imageMean; 134 | c.imageStd = imageStd; 135 | 136 | // Pre-allocate buffers.
137 | c.outputNames = new String[] {outputName}; 138 | c.intValues = new int[inputSize * inputSize]; 139 | c.floatValues = new float[inputSize * inputSize * 3]; 140 | c.outputs = new float[numClasses]; 141 | 142 | return c; 143 | } 144 | 145 | @Override 146 | public List recognizeImage(final Bitmap bitmap) { 147 | // Log this method so that it can be analyzed with systrace. 148 | Trace.beginSection("recognizeImage"); 149 | 150 | Trace.beginSection("preprocessBitmap"); 151 | // Preprocess the image data from 0-255 int to normalized float based 152 | // on the provided parameters. 153 | bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); 154 | for (int i = 0; i < intValues.length; ++i) { 155 | final int val = intValues[i]; 156 | floatValues[i * 3 + 0] = (((val >> 16) & 0xFF) - imageMean) / imageStd; 157 | floatValues[i * 3 + 1] = (((val >> 8) & 0xFF) - imageMean) / imageStd; 158 | floatValues[i * 3 + 2] = ((val & 0xFF) - imageMean) / imageStd; 159 | } 160 | Trace.endSection(); 161 | 162 | // Copy the input data into TensorFlow. 163 | Trace.beginSection("fillNodeFloat"); 164 | inferenceInterface.fillNodeFloat( 165 | inputName, new int[] {1, inputSize, inputSize, 3}, floatValues); 166 | Trace.endSection(); 167 | 168 | // Run the inference call. 169 | Trace.beginSection("runInference"); 170 | inferenceInterface.runInference(outputNames); 171 | Trace.endSection(); 172 | 173 | // Copy the output Tensor back into the output array. 174 | Trace.beginSection("readNodeFloat"); 175 | inferenceInterface.readNodeFloat(outputName, outputs); 176 | Trace.endSection(); 177 | 178 | // Find the best classifications. 179 | PriorityQueue pq = 180 | new PriorityQueue( 181 | 3, 182 | new Comparator() { 183 | @Override 184 | public int compare(Recognition lhs, Recognition rhs) { 185 | // Intentionally reversed to put high confidence at the head of the queue.
186 | return Float.compare(rhs.getConfidence(), lhs.getConfidence()); 187 | } 188 | }); 189 | for (int i = 0; i < outputs.length; ++i) { 190 | if (outputs[i] > THRESHOLD) { 191 | pq.add( 192 | new Recognition( 193 | "" + i, labels.size() > i ? labels.get(i) : "unknown", outputs[i], null)); 194 | } 195 | } 196 | final ArrayList recognitions = new ArrayList(); 197 | int recognitionsSize = Math.min(pq.size(), MAX_RESULTS); 198 | for (int i = 0; i < recognitionsSize; ++i) { 199 | recognitions.add(pq.poll()); 200 | } 201 | Trace.endSection(); // "recognizeImage" 202 | return recognitions; 203 | } 204 | 205 | @Override 206 | public void enableStatLogging(boolean debug) { 207 | inferenceInterface.enableStatLogging(debug); 208 | } 209 | 210 | @Override 211 | public String getStatString() { 212 | return inferenceInterface.getStatString(); 213 | } 214 | 215 | @Override 216 | public void close() { 217 | inferenceInterface.close(); 218 | } /** Best-effort close used on the label-reading error paths; a failure to close must not mask the original error. */ private static void closeQuietly(BufferedReader reader) { if (reader == null) { return; } try { reader.close(); } catch (IOException ignored) { /* nothing useful to do during cleanup */ } } 219 | } 220 | -------------------------------------------------------------------------------- /src/android/tf_libs/armeabi-v7a/libtensorflow_inference.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heigeo/cordova-plugin-tensorflow/9c8b74c81a642b1381be517de8f22e0caa649180/src/android/tf_libs/armeabi-v7a/libtensorflow_inference.so -------------------------------------------------------------------------------- /src/android/tf_libs/libandroid_tensorflow_inference_java.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heigeo/cordova-plugin-tensorflow/9c8b74c81a642b1381be517de8f22e0caa649180/src/android/tf_libs/libandroid_tensorflow_inference_java.jar -------------------------------------------------------------------------------- /src/ios/TensorFlow.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "tensorflow/core/public/session.h"
4 | 5 | #import "ios_image_load.h" 6 | #import "tensorflow_utils.h" 7 | #import 8 | 9 | @interface TensorFlow : CDVPlugin { 10 | NSMutableDictionary *classifiers; 11 | } 12 | 13 | - (void)loadModel:(CDVInvokedUrlCommand*)command; 14 | - (void)classify:(CDVInvokedUrlCommand*)command; 15 | 16 | @end 17 | 18 | @interface Classifier : NSObject { 19 | std::unique_ptr session; 20 | NSString* model_file; 21 | NSString* label_file; 22 | std::vector labels; 23 | int input_size; 24 | int image_mean; 25 | float image_std; 26 | NSString* input_name; 27 | NSString* output_name; 28 | } 29 | 30 | - (id)initWithModel:(NSString *)model_file_ 31 | label_file:(NSString *)label_file_ 32 | input_size:(int)input_size_ 33 | image_mean:(int)image_mean_ 34 | image_std:(float)image_std_ 35 | input_name:(NSString *)input_name_ 36 | output_name:(NSString *)output_name_; 37 | - (tensorflow::Status)load; 38 | - (tensorflow::Status)classify:(NSString *)image results:(NSMutableArray *)results; 39 | 40 | @end 41 | -------------------------------------------------------------------------------- /src/ios/TensorFlow.mm: -------------------------------------------------------------------------------- 1 | #import "TensorFlow.h" 2 | #import 3 | #import "ios_image_load.h" 4 | #import "tensorflow_utils.h" 5 | 6 | @implementation TensorFlow 7 | /* Cordova entry points. Loaded classifiers are kept in `classifiers`, keyed by the JS-supplied model name. */ 8 | - (void)loadModel:(CDVInvokedUrlCommand*)command 9 | { 10 | CDVPluginResult* pluginResult = nil; 11 | NSString* model_name = [command.arguments objectAtIndex:0]; 12 | Classifier* classifier = [ 13 | [Classifier alloc] 14 | initWithModel:[command.arguments objectAtIndex:1] 15 | label_file:[command.arguments objectAtIndex:2] 16 | input_size:[(NSNumber *)[command.arguments objectAtIndex:3] intValue] 17 | image_mean:[(NSNumber *)[command.arguments objectAtIndex:4] intValue] 18 | image_std:[(NSNumber *)[command.arguments objectAtIndex:5] floatValue] 19 | input_name:[command.arguments objectAtIndex:6] 20 | output_name:[command.arguments objectAtIndex:7] 21 | ]; 22 |
if (classifiers == nil) { 23 | classifiers = [NSMutableDictionary dictionaryWithDictionary:@{}]; 24 | } 25 | /* Load first and register only on success, so a failed load cannot replace a working classifier of the same name. */ 26 | tensorflow::Status result = [classifier load]; 27 | if (result.ok()) { classifiers[model_name] = classifier; 28 | pluginResult = [CDVPluginResult resultWithStatus:CDVCommandStatus_OK]; 29 | } else { 30 | pluginResult = [CDVPluginResult resultWithStatus:CDVCommandStatus_ERROR messageAsString:[NSString stringWithUTF8String:result.error_message().c_str()]]; 31 | } 32 | 33 | [self.commandDelegate sendPluginResult:pluginResult callbackId:command.callbackId]; 34 | } 35 | 36 | - (void)classify:(CDVInvokedUrlCommand*)command 37 | { 38 | CDVPluginResult* pluginResult = nil; 39 | NSString* model_name = [command.arguments objectAtIndex:0]; 40 | NSString* image = [command.arguments objectAtIndex:1]; 41 | Classifier* classifier = classifiers[model_name]; if (classifier == nil) { /* a message to nil would not produce a valid tensorflow::Status */ pluginResult = [CDVPluginResult resultWithStatus:CDVCommandStatus_ERROR messageAsString:[NSString stringWithFormat:@"Model not loaded: %@", model_name]]; [self.commandDelegate sendPluginResult:pluginResult callbackId:command.callbackId]; return; } 42 | NSMutableArray* results = [[NSMutableArray alloc] init]; 43 | tensorflow::Status result = [classifier classify:image results:results]; 44 | if (result.ok()) { 45 | pluginResult = [CDVPluginResult resultWithStatus:CDVCommandStatus_OK messageAsArray:results]; 46 | } else { 47 | pluginResult = [CDVPluginResult resultWithStatus:CDVCommandStatus_ERROR messageAsString:[NSString stringWithUTF8String:result.error_message().c_str()]]; 48 | } 49 | 50 | [self.commandDelegate sendPluginResult:pluginResult callbackId:command.callbackId]; 51 | } 52 | 53 | @end 54 | 55 | @implementation Classifier 56 | 57 | - (id)initWithModel:(NSString *)model_file_ 58 | label_file:(NSString *)label_file_ 59 | input_size:(int)input_size_ 60 | image_mean:(int)image_mean_ 61 | image_std:(float)image_std_ 62 | input_name:(NSString *)input_name_ 63 | output_name:(NSString *)output_name_ 64 | { 65 | self = [super init]; 66 | if (self) { 67 | model_file = model_file_; 68 | label_file = label_file_; 69 | input_size = input_size_; 70 | image_mean = image_mean_; 71 | image_std = image_std_; 72 | input_name = input_name_; 73 | output_name = output_name_; 74 | } 75 | return self; 76 | } 77 | /* Loads the graph, then the labels; returns the first failing status. */ 78 | - (tensorflow::Status)load 79 | { 80 | tensorflow::Status result; 81 | result = LoadModel(model_file, &session);
82 | if (result.ok()) { 83 | result = LoadLabels(label_file, &labels); 84 | } 85 | return result; 86 | } 87 | /* Runs inference on `image` via RunInferenceOnImage and appends @{title, confidence} dictionaries to `results`. */ 88 | - (tensorflow::Status)classify:(NSString *)image results:(NSMutableArray *)results 89 | { 90 | std::vector tfresults; 91 | tensorflow::Status result = RunInferenceOnImage( 92 | image, 93 | input_size, 94 | image_mean, 95 | image_std, 96 | [input_name UTF8String], 97 | [output_name UTF8String], 98 | &session, 99 | &labels, 100 | &tfresults 101 | ); 102 | if (!result.ok()) { 103 | return result; 104 | } 105 | for (const struct Result& r : tfresults) { 106 | [results addObject: @{ 107 | @"title": r.label, 108 | @"confidence": r.confidence 109 | }]; 110 | } 111 | /* `r` no longer shadows the status, which is OK here; returning it avoids falling off the end of a non-void method. */ return result; } 112 | 113 | @end 114 | -------------------------------------------------------------------------------- /src/ios/tf_libs/ios_image_load.h: -------------------------------------------------------------------------------- 1 | // Copyright 2015 Google Inc. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | #ifndef TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ 16 | #define TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ 17 | 18 | #include 19 | 20 | #include "tensorflow/core/framework/types.h" 21 | 22 | std::vector LoadImageFromFile(const char* file_name, 23 | int* out_width, 24 | int* out_height, 25 | int* out_channels); 26 | 27 | std::vector LoadImageFromBase64(NSString* base64data, 28 | int* out_width, 29 | int* out_height, 30 | int* out_channels); 31 | 32 | std::vector LoadImageFromData(CFDataRef data_ref, 33 | const char* suffix, 34 | int* out_width, 35 | int* out_height, 36 | int* out_channels); 37 | 38 | #endif // TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ 39 | -------------------------------------------------------------------------------- /src/ios/tf_libs/ios_image_load.mm: -------------------------------------------------------------------------------- 1 | // Copyright 2015 Google Inc. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 |
15 | #include "ios_image_load.h"
16 |
17 | #include <stdlib.h>
18 | #include <string.h>
19 | #include <sys/types.h>
20 | #include <unistd.h>
21 |
22 | #import <CoreImage/CoreImage.h>
23 | #import <ImageIO/ImageIO.h>
24 |
25 | using tensorflow::uint8;
26 |
27 | // Zeroes the reported dimensions and returns an empty pixel buffer. Every
28 | // failure path below uses this so callers can detect failure with empty().
29 | static std::vector<uint8> EmptyImage(int* out_width, int* out_height,
30 |                                      int* out_channels) {
31 |   *out_width = 0;
32 |   *out_height = 0;
33 |   *out_channels = 0;
34 |   return std::vector<uint8>();
35 | }
36 |
37 | // Reads the file at file_name and decodes it with LoadImageFromData, picking
38 | // the decoder from the file's suffix. Returns an empty buffer (and zero
39 | // dimensions) if the file cannot be opened, read, or decoded.
40 | std::vector<uint8> LoadImageFromFile(const char* file_name,
41 |                                      int* out_width, int* out_height,
42 |                                      int* out_channels) {
43 |   FILE* file_handle = fopen(file_name, "rb");
44 |   if (!file_handle) {
45 |     fprintf(stderr, "Could not open '%s'\n", file_name);
46 |     return EmptyImage(out_width, out_height, out_channels);
47 |   }
48 |   fseek(file_handle, 0, SEEK_END);
49 |   const long file_size = ftell(file_handle);
50 |   fseek(file_handle, 0, SEEK_SET);
51 |   if (file_size <= 0) {
52 |     fclose(file_handle);
53 |     fprintf(stderr, "Could not read '%s'\n", file_name);
54 |     return EmptyImage(out_width, out_height, out_channels);
55 |   }
56 |   const size_t bytes_in_file = (size_t)file_size;
57 |   std::vector<uint8> file_data(bytes_in_file);
58 |   const size_t bytes_read =
59 |       fread(file_data.data(), 1, bytes_in_file, file_handle);
60 |   fclose(file_handle);
61 |   if (bytes_read != bytes_in_file) {
62 |     fprintf(stderr, "Could not read '%s'\n", file_name);
63 |     return EmptyImage(out_width, out_height, out_channels);
64 |   }
65 |   // No copy needed: file_data outlives the LoadImageFromData call below, and
66 |   // that function releases file_data_ref before it returns.
67 |   CFDataRef file_data_ref = CFDataCreateWithBytesNoCopy(
68 |       NULL, file_data.data(), bytes_in_file, kCFAllocatorNull);
69 |   const char* suffix = strrchr(file_name, '.');
70 |   if (!suffix || suffix == file_name) {
71 |     suffix = "";
72 |   }
73 |   return LoadImageFromData(file_data_ref, suffix, out_width, out_height,
74 |                            out_channels);
75 | }
76 |
77 | // Decodes base64data and passes the bytes to LoadImageFromData. Data that
78 | // starts with the PNG signature is decoded as PNG; anything else as JPEG.
79 | // Returns an empty buffer (and zero dimensions) if the string is not valid
80 | // base64 or the image cannot be decoded.
81 | std::vector<uint8> LoadImageFromBase64(NSString* base64data,
82 |                                        int* out_width, int* out_height,
83 |                                        int* out_channels) {
84 |   NSData* data = [[NSData alloc] initWithBase64EncodedString:base64data options:0];
85 |   if (data == nil || [data length] == 0) {
86 |     fprintf(stderr, "Could not decode base64 image data\n");
87 |     return EmptyImage(out_width, out_height, out_channels);
88 |   }
89 |   // Copy the bytes so data_ref does not depend on the lifetime of `data`;
90 |   // LoadImageFromData takes ownership of data_ref and releases it.
91 |   CFDataRef data_ref =
92 |       CFDataCreate(NULL, (const UInt8*)[data bytes], [data length]);
93 |   static const uint8 kPngSignature[] = {0x89, 'P', 'N', 'G'};
94 |   const char* suffix = ".jpeg";
95 |   if ([data length] >= sizeof(kPngSignature) &&
96 |       memcmp([data bytes], kPngSignature, sizeof(kPngSignature)) == 0) {
97 |     suffix = ".png";
98 |   }
99 |   return LoadImageFromData(data_ref, suffix, out_width, out_height,
100 |                            out_channels);
101 | }
102 |
103 | // Decodes PNG (".png") or JPEG (".jpg"/".jpeg", case-insensitive) data into a
104 | // tightly packed 8-bit RGBA buffer (premultiplied alpha, 4 channels) and
105 | // reports its dimensions through the out_* parameters. Takes ownership of
106 | // data_ref and releases it on every path. Returns an empty buffer (and zero
107 | // dimensions) for an unknown suffix or undecodable data.
108 | std::vector<uint8> LoadImageFromData(CFDataRef data_ref,
109 |                                      const char* suffix,
110 |                                      int* out_width, int* out_height,
111 |                                      int* out_channels) {
112 |   CGDataProviderRef image_provider =
113 |       data_ref ? CGDataProviderCreateWithCFData(data_ref) : NULL;
114 |
115 |   CGImageRef image = NULL;
116 |   if (image_provider == NULL) {
117 |     fprintf(stderr, "Could not read image data\n");
118 |   } else if (strcasecmp(suffix, ".png") == 0) {
119 |     image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true,
120 |                                              kCGRenderingIntentDefault);
121 |   } else if ((strcasecmp(suffix, ".jpg") == 0) ||
122 |              (strcasecmp(suffix, ".jpeg") == 0)) {
123 |     image = CGImageCreateWithJPEGDataProvider(image_provider, NULL, true,
124 |                                               kCGRenderingIntentDefault);
125 |   } else {
126 |     fprintf(stderr, "Unknown suffix '%s'\n", suffix);
127 |   }
128 |
129 |   if (image == NULL) {
130 |     // Unknown suffix or undecodable bytes. CFRelease(NULL) crashes, so use
131 |     // the NULL-tolerant CoreGraphics release and guard data_ref.
132 |     CGDataProviderRelease(image_provider);
133 |     if (data_ref) {
134 |       CFRelease(data_ref);
135 |     }
136 |     return EmptyImage(out_width, out_height, out_channels);
137 |   }
138 |
139 |   const int width = (int)CGImageGetWidth(image);
140 |   const int height = (int)CGImageGetHeight(image);
141 |   const int channels = 4;
142 |   CGColorSpaceRef color_space = CGColorSpaceCreateDeviceRGB();
143 |   const int bytes_per_row = (width * channels);
144 |   const int bytes_in_image = (bytes_per_row * height);
145 |   std::vector<uint8> result(bytes_in_image);
146 |   const int bits_per_component = 8;
147 |   CGContextRef context = CGBitmapContextCreate(result.data(), width, height,
148 |       bits_per_component, bytes_per_row, color_space,
149 |       kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big);
150 |   CGColorSpaceRelease(color_space);
151 |   CGContextDrawImage(context, CGRectMake(0, 0, width, height), image);
152 |   CGContextRelease(context);
153 |   CGImageRelease(image);
154 |   CGDataProviderRelease(image_provider);
155 |   CFRelease(data_ref);
156 |
157 |   *out_width = width;
158 |   *out_height = height;
159 |   *out_channels = channels;
160 |   return result;
161 | }
162 |
--------------------------------------------------------------------------------
/src/ios/tf_libs/tensorflow_utils.h:
--------------------------------------------------------------------------------
1 | // Copyright 2015 Google Inc. All rights reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_TENSORFLOW_UTILS_H_ 16 | #define TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_TENSORFLOW_UTILS_H_ 17 | 18 | #include 19 | #include 20 | 21 | #include "tensorflow/core/public/session.h" 22 | #include "tensorflow/core/util/memmapped_file_system.h" 23 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 24 | 25 | // Reads a serialized GraphDef protobuf file from the filesystem path file_name, 26 | // typically created with the freeze_graph script. Populates the session argument 27 | // with a Session object that has the model loaded. 28 | tensorflow::Status LoadModel(NSString* file_name, 29 | std::unique_ptr* session); 30 | 31 | // Loads a model from a file that has been created using the 32 | // convert_graphdef_memmapped_format tool. This bundles a GraphDef 33 | // proto together with a file that can be memory-mapped, containing the weight 34 | // parameters for the model. This is useful because it reduces the overall 35 | // memory pressure, since the read-only parameter regions can be easily paged 36 | // out and don't count toward memory limits on iOS. 37 | tensorflow::Status LoadMemoryMappedModel( 38 | NSString* file_name, 39 | std::unique_ptr* session, 40 | std::unique_ptr* memmapped_env); 41 | 42 | // Reads the text file at file_name, one label per line, appending the labels to *label_strings. 43 | tensorflow::Status LoadLabels(NSString* file_name, 44 | std::vector* label_strings); 45 | 46 | // Keeps the entries of prediction scoring at least threshold, limits them to the num_results highest, and stores them in *top_results as (score, index) pairs in descending score order.
47 | void GetTopN(const Eigen::TensorMap, 48 | Eigen::Aligned>& prediction, 49 | const int num_results, const float threshold, 50 | std::vector >* top_results); 51 | 52 | /* One classification result produced by RunInferenceOnImage: the label text and its score. */ struct Result { 53 | NSString* label; 54 | NSNumber* confidence; 55 | }; 56 | 57 | /* Decodes the base64-encoded image, resamples it to input_size x input_size, normalizes its first three channels as (pixel - input_mean) / input_std, feeds it to input_layer, and appends the highest-scoring classes read from output_layer to *results (at most 5, each scoring at least 0.1). Class indices are mapped to labels modulo labels->size(). */ tensorflow::Status RunInferenceOnImage(NSString* image, int input_size, float input_mean, float input_std, std::string input_layer, std::string output_layer, std::unique_ptr* session, std::vector* labels, std::vector* results); 58 | 59 | #endif // TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_TENSORFLOW_UTILS_H_ 60 | -------------------------------------------------------------------------------- /src/ios/tf_libs/tensorflow_utils.mm: -------------------------------------------------------------------------------- 1 | // Copyright 2015 Google Inc. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 |
15 | #import <Foundation/Foundation.h>
16 |
17 | #include "tensorflow_utils.h"
18 | #include "ios_image_load.h"
19 |
20 | #include <pthread.h>
21 | #include <unistd.h>
22 | #include <algorithm>
23 | #include <fstream>
24 | #include <queue>
25 | #include <sstream>
26 | #include <string>
27 |
28 | #include "google/protobuf/io/coded_stream.h"
29 | #include "google/protobuf/io/zero_copy_stream_impl.h"
30 | #include "google/protobuf/io/zero_copy_stream_impl_lite.h"
31 | #include "google/protobuf/message_lite.h"
32 | #include "tensorflow/core/framework/tensor.h"
33 | #include "tensorflow/core/framework/types.pb.h"
34 | #include "tensorflow/core/platform/env.h"
35 | #include "tensorflow/core/platform/logging.h"
36 | #include "tensorflow/core/platform/mutex.h"
37 | #include "tensorflow/core/platform/types.h"
38 | #include "tensorflow/core/public/session.h"
39 |
40 | namespace {
41 |
42 | // Helper class used to load protobufs efficiently.
43 | class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream {
44 |  public:
45 |   explicit IfstreamInputStream(const std::string& file_name)
46 |       : ifs_(file_name.c_str(), std::ios::in | std::ios::binary) {}
47 |   ~IfstreamInputStream() { ifs_.close(); }
48 |
49 |   // Reads up to `size` bytes into `buffer` and returns how many were read,
50 |   // or -1 if the stream is already in a failed state.
51 |   int Read(void* buffer, int size) {
52 |     if (!ifs_) {
53 |       return -1;
54 |     }
55 |     ifs_.read(static_cast<char*>(buffer), size);
56 |     return ifs_.gcount();
57 |   }
58 |
59 |  private:
60 |   std::ifstream ifs_;
61 | };
62 | }  // namespace
63 |
64 | // Returns the top N confidence values over threshold in the provided vector,
65 | // sorted by confidence in descending order. Each entry is a (score, index)
66 | // pair. Entries are pushed onto *top_results and the whole vector is then
67 | // reversed, so callers should pass an empty vector.
68 | void GetTopN(
69 |     const Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>,
70 |                            Eigen::Aligned>& prediction,
71 |     const int num_results, const float threshold,
72 |     std::vector<std::pair<float, int> >* top_results) {
73 |   // Min-heap of the best candidates so far, so the weakest is evicted first.
74 |   std::priority_queue<std::pair<float, int>,
75 |                       std::vector<std::pair<float, int> >,
76 |                       std::greater<std::pair<float, int> > >
77 |       top_result_pq;
78 |
79 |   const int count = prediction.size();
80 |   for (int i = 0; i < count; ++i) {
81 |     const float value = prediction(i);
82 |
83 |     // Only add it if it beats the threshold and has a chance at being in
84 |     // the top N.
85 |     if (value < threshold) {
86 |       continue;
87 |     }
88 |
89 |     top_result_pq.push(std::pair<float, int>(value, i));
90 |
91 |     // If at capacity, kick the smallest value out.
92 |     if (top_result_pq.size() > static_cast<size_t>(num_results)) {
93 |       top_result_pq.pop();
94 |     }
95 |   }
96 |
97 |   // Copy to output vector and reverse into descending order.
98 |   while (!top_result_pq.empty()) {
99 |     top_results->push_back(top_result_pq.top());
100 |     top_result_pq.pop();
101 |   }
102 |   std::reverse(top_results->begin(), top_results->end());
103 | }
104 |
105 | // Parses the binary protobuf at file_name into *proto; returns false on failure.
106 | bool PortableReadFileToProto(const std::string& file_name,
107 |                              ::google::protobuf::MessageLite* proto) {
108 |   ::google::protobuf::io::CopyingInputStreamAdaptor stream(
109 |       new IfstreamInputStream(file_name));
110 |   stream.SetOwnsCopyingStream(true);
111 |   ::google::protobuf::io::CodedInputStream coded_stream(&stream);
112 |   // Total bytes hard limit / warning limit are set to 1GB and 512MB
113 |   // respectively.
114 |   coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
115 |   return proto->ParseFromCodedStream(&coded_stream);
116 | }
117 |
118 | // Creates a new session in *session and loads the GraphDef at model_path into
119 | // it. *session is replaced before the graph is read, so after a graph error it
120 | // holds a session without the model.
121 | tensorflow::Status LoadModel(NSString* model_path,
122 |                              std::unique_ptr<tensorflow::Session>* session) {
123 |   tensorflow::SessionOptions options;
124 |
125 |   tensorflow::Session* session_pointer = nullptr;
126 |   tensorflow::Status session_status =
127 |       tensorflow::NewSession(options, &session_pointer);
128 |   if (!session_status.ok()) {
129 |     LOG(ERROR) << "Could not create TensorFlow Session: " << session_status;
130 |     return session_status;
131 |   }
132 |   session->reset(session_pointer);
133 |
134 |   tensorflow::GraphDef tensorflow_graph;
135 |
136 |   const bool read_proto_succeeded =
137 |       PortableReadFileToProto([model_path UTF8String], &tensorflow_graph);
138 |   if (!read_proto_succeeded) {
139 |     LOG(ERROR) << "Failed to load model proto from " << [model_path UTF8String];
140 |     return tensorflow::errors::NotFound([model_path UTF8String]);
141 |   }
142 |
143 |   tensorflow::Status create_status = (*session)->Create(tensorflow_graph);
144 |   if (!create_status.ok()) {
145 |     LOG(ERROR) << "Could not create TensorFlow Graph: " << create_status;
146 |     return create_status;
147 |   }
148 |
149 |   return tensorflow::Status::OK();
150 | }
151 |
152 | // Loads a graph packaged by convert_graphdef_memmapped_format. *session is
153 | // only replaced once the graph has been created successfully.
154 | tensorflow::Status LoadMemoryMappedModel(
155 |     NSString* network_path,
156 |     std::unique_ptr<tensorflow::Session>* session,
157 |     std::unique_ptr<tensorflow::MemmappedEnv>* memmapped_env) {
158 |   memmapped_env->reset(
159 |       new tensorflow::MemmappedEnv(tensorflow::Env::Default()));
160 |   tensorflow::Status mmap_status =
161 |       (memmapped_env->get())->InitializeFromFile([network_path UTF8String]);
162 |   if (!mmap_status.ok()) {
163 |     LOG(ERROR) << "MMap failed with " << mmap_status.error_message();
164 |     return mmap_status;
165 |   }
166 |
167 |   tensorflow::GraphDef tensorflow_graph;
168 |   tensorflow::Status load_graph_status = ReadBinaryProto(
169 |       memmapped_env->get(),
170 |       tensorflow::MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
171 |       &tensorflow_graph);
172 |   if (!load_graph_status.ok()) {
173 |     LOG(ERROR) << "MMap load graph failed with "
174 |                << load_graph_status.error_message();
175 |     return load_graph_status;
176 |   }
177 |
178 |   tensorflow::SessionOptions options;
179 |   // Disable optimizations on this graph so that constant folding doesn't
180 |   // increase the memory footprint by creating new constant copies of the weight
181 |   // parameters.
182 |   options.config.mutable_graph_options()
183 |       ->mutable_optimizer_options()
184 |       ->set_opt_level(::tensorflow::OptimizerOptions::L0);
185 |   options.env = memmapped_env->get();
186 |
187 |   tensorflow::Session* session_pointer = nullptr;
188 |   tensorflow::Status session_status =
189 |       tensorflow::NewSession(options, &session_pointer);
190 |   if (!session_status.ok()) {
191 |     LOG(ERROR) << "Could not create TensorFlow Session: " << session_status;
192 |     return session_status;
193 |   }
194 |   // Take ownership right away so the session is freed if Create() fails.
195 |   std::unique_ptr<tensorflow::Session> new_session(session_pointer);
196 |
197 |   tensorflow::Status create_status = new_session->Create(tensorflow_graph);
198 |   if (!create_status.ok()) {
199 |     LOG(ERROR) << "Could not create TensorFlow Graph: " << create_status;
200 |     return create_status;
201 |   }
202 |
203 |   session->reset(new_session.release());
204 |
205 |   return tensorflow::Status::OK();
206 | }
207 |
208 | // Reads the labels file at labels_path and appends one entry per line to
209 | // *label_strings (a trailing "\r" is stripped so CRLF files work). Returns
210 | // NotFound if the file cannot be opened.
211 | tensorflow::Status LoadLabels(NSString* labels_path,
212 |                               std::vector<std::string>* label_strings) {
213 |   std::ifstream t([labels_path UTF8String]);
214 |   if (!t) {
215 |     LOG(ERROR) << "Could not open labels file " << [labels_path UTF8String];
216 |     return tensorflow::errors::NotFound([labels_path UTF8String]);
217 |   }
218 |   std::string line;
219 |   // Test the result of getline itself: checking the stream before reading
220 |   // appends a spurious empty label after the file's final newline.
221 |   while (std::getline(t, line)) {
222 |     if (!line.empty() && line[line.size() - 1] == '\r') {
223 |       line.erase(line.size() - 1);
224 |     }
225 |     label_strings->push_back(line);
226 |   }
227 |   return tensorflow::Status::OK();
228 | }
229 |
230 | // Decodes the base64-encoded image, resamples it to input_size x input_size,
231 | // normalizes it as (pixel - input_mean) / input_std, runs it through the
232 | // session, and appends the top (at most kNumResults) labelled scores to
233 | // *results.
234 | tensorflow::Status RunInferenceOnImage(
235 |     NSString* image, int input_size, float input_mean, float input_std,
236 |     std::string input_layer, std::string output_layer,
237 |     std::unique_ptr<tensorflow::Session>* session,
238 |     std::vector<std::string>* labels, std::vector<Result>* results) {
239 |   if (!session || !*session) {
240 |     return tensorflow::errors::FailedPrecondition("Model not loaded");
241 |   }
242 |   // Labels are looked up modulo labels->size() below, which must be nonzero.
243 |   if (labels->empty()) {
244 |     LOG(ERROR) << "No labels loaded";
245 |     return tensorflow::errors::FailedPrecondition("No labels loaded");
246 |   }
247 |
248 |   int image_width = 0;
249 |   int image_height = 0;
250 |   int image_channels = 0;
251 |   std::vector<tensorflow::uint8> image_data = LoadImageFromBase64(
252 |       image, &image_width, &image_height, &image_channels);
253 |   const int wanted_width = input_size;
254 |   const int wanted_height = input_size;
255 |   const int wanted_channels = 3;
256 |
257 |   // The loader reports failure with an empty buffer and zero dimensions.
258 |   // Sampling from that would read out of bounds, and an assert() is compiled
259 |   // out of release builds, so fail explicitly.
260 |   if (image_data.empty() || image_channels < wanted_channels) {
261 |     LOG(ERROR) << "Could not decode input image";
262 |     return tensorflow::errors::InvalidArgument("Could not decode input image");
263 |   }
264 |
265 |   tensorflow::Tensor image_tensor(
266 |       tensorflow::DT_FLOAT,
267 |       tensorflow::TensorShape({
268 |           1, wanted_height, wanted_width, wanted_channels}));
269 |   auto image_tensor_mapped = image_tensor.tensor<float, 4>();
270 |   tensorflow::uint8* in = image_data.data();
271 |   float* out = image_tensor_mapped.data();
272 |   // Resample by integer index scaling and keep the first wanted_channels
273 |   // channels of each pixel.
274 |   for (int y = 0; y < wanted_height; ++y) {
275 |     const int in_y = (y * image_height) / wanted_height;
276 |     tensorflow::uint8* in_row = in + (in_y * image_width * image_channels);
277 |     float* out_row = out + (y * wanted_width * wanted_channels);
278 |     for (int x = 0; x < wanted_width; ++x) {
279 |       const int in_x = (x * image_width) / wanted_width;
280 |       tensorflow::uint8* in_pixel = in_row + (in_x * image_channels);
281 |       float* out_pixel = out_row + (x * wanted_channels);
282 |       for (int c = 0; c < wanted_channels; ++c) {
283 |         out_pixel[c] = (in_pixel[c] - input_mean) / input_std;
284 |       }
285 |     }
286 |   }
287 |
288 |   std::vector<tensorflow::Tensor> outputs;
289 |   tensorflow::Status run_status = (*session)->Run({{input_layer, image_tensor}},
290 |                                                   {output_layer}, {}, &outputs);
291 |   if (!run_status.ok()) {
292 |     LOG(ERROR) << "Running model failed: " << run_status;
293 |     return run_status;
294 |   }
295 |
296 |   tensorflow::Tensor* output = &outputs[0];
297 |   const int kNumResults = 5;
298 |   const float kThreshold = 0.1f;
299 |   std::vector<std::pair<float, int> > top_results;
300 |   GetTopN(output->flat<float>(), kNumResults, kThreshold, &top_results);
301 |   for (const auto& result : top_results) {
302 |     struct Result res;
303 |     // Wrap the class index into the label list. NOTE(review): presumably this
304 |     // tolerates models whose output has more classes than the labels file;
305 |     // confirm for custom models.
306 |     std::string label = labels->at(result.second % labels->size());
307 |     res.label = [NSString stringWithUTF8String:label.c_str()];
308 |     res.confidence = [NSNumber numberWithFloat:result.first];
309 |     results->push_back(res);
310 |   }
311 |
312 |   return run_status;
313 | }
314 |
--------------------------------------------------------------------------------
/www/tensorflow.js:
--------------------------------------------------------------------------------
1 | function TensorFlow(modelId, model) {
2 |   this.modelId = modelId;
3 |   if (model) {
4 |     model = registerModel(modelId, model);
5 |   } else {
6 |     model = getModel(modelId);
7 |   }
8 |   this.model = model;
9 |   this.onprogress = function() {};
10 | }
11 |
12 |
13 | TensorFlow.prototype.load = function(successCallback, errorCallback) {
14 |   var promise;
15 |   if (window.Promise && !successCallback) {
16 |     promise = new Promise(function(resolve, reject) {
17 |       successCallback = resolve;
18 |       errorCallback = reject;
19 |     });
20 |   }
21 |   loadModel(
22 |     this.modelId,
23 |     successCallback,
24 |     orLogError(errorCallback),
25 |     this.onprogress
26 |   );
27 |   return promise;
28 | };
29 |
30 | TensorFlow.prototype.checkCached = function(successCallback, errorCallback) {
31 |   var promise;
32 |   if (window.Promise && !successCallback) {
33 |     promise = new Promise(function(resolve, reject) {
34 |       successCallback = resolve;
35 |       errorCallback = reject;
36 |     });
37 |   }
38 |   checkCached(this.modelId, successCallback, orLogError(errorCallback));
39 |   return promise;
40 | };
41 |
42 | TensorFlow.prototype.classify = function(image, successCallback, errorCallback) {
43 |   var promise;
44 |   if (window.Promise && !successCallback) {
45 |     promise = new Promise(function(resolve, reject) {
46 |       successCallback = resolve;
47 |       errorCallback = reject;
48 |     });
49 |   }
50 |   errorCallback = orLogError(errorCallback);
51 |
52 |   var self = this;
53 |   if (!self.model.loaded) {
54 |     self.load(function() {
55 |       if (!self.model.loaded) {
56 |         errorCallback("Error loading model!");
57 |         return;
58 |       }
59 |       self.classify(image, successCallback, errorCallback);
60 |     }, errorCallback);
61 |     return promise;
62 |   }
63 |
64 |   cordova.exec(
65 |     successCallback, errorCallback,
66 |     "TensorFlow", "classify", [self.modelId, stripDataUrlPrefix(image)]
67 |   );
68 |   return promise;
69 | };
70 |
71 | // Returns the bare base64 payload of a "data:<type>;base64,..." URL (as
72 | // produced by canvas.toDataURL()); any other value is returned unchanged.
73 | // The iOS side decodes the string as plain base64, which rejects the prefix.
74 | function stripDataUrlPrefix(image) {
75 |   if (typeof image !== 'string') {
76 |     return image;
77 |   }
78 |   return image.replace(/^data:[^,]*;base64,/, '');
79 | }
80 |
81 | // Returns errorCallback, or a console logger when none was supplied (and no
82 | // Promise was created), so error paths never invoke undefined.
83 | function orLogError(errorCallback) {
84 |   return errorCallback || function(error) {
85 |     console.error(error);
86 |   };
87 | }
88 |
89 | // Internal API for downloading and caching model files
90 | var models = {};
91 |
92 | var FIELDS = [
93 |   'label',
94 |
95 |   'model_path',
96 |   'label_path',
97 |
98 |   'input_size',
99 |   'image_mean',
100 |   'image_std',
101 |   'input_name',
102 |   'output_name'
103 | ];
104 |
105 | // Validates and stores a model description under modelId. Models whose files
106 | // are fetched over HTTP start out uncached.
107 | function registerModel(modelId, model) {
108 |   FIELDS.forEach(function(field) {
109 |     // Reject only absent/empty values: numeric settings such as image_mean
110 |     // may legitimately be 0, which a truthiness test would reject.
111 |     var value = model[field];
112 |     if (value === undefined || value === null || value === '') {
113 |       throw 'Missing "' + field + '" on model description';
114 |     }
115 |   });
116 |
117 |   if (model.model_path.match(/^http/) || model.label_path.match(/^http/)) {
118 |     model.cached = false;
119 |   } else {
120 |     model.cached = true;
121 |   }
122 |   models[modelId] = model;
123 |   model.id = modelId;
124 |   return model;
125 | }
126 |
127 | function getModel(modelId) {
128 |   // hasOwnProperty keeps ids such as "toString" from resolving to
129 |   // Object.prototype members.
130 |   var model = Object.prototype.hasOwnProperty.call(models, modelId) ?
131 |     models[modelId] : undefined;
132 |   if (!model) {
133 |     throw "Unknown model " + modelId;
134 |   }
135 |   return model;
136 | }
137 |
138 | var INCEPTION = 'https://storage.googleapis.com/download.tensorflow.org/models/';
139 | registerModel('inception-v1', {
140 |   'label': 'Inception v1',
141 |   'model_path': INCEPTION + 'inception5h.zip#tensorflow_inception_graph.pb',
142 |   'label_path': INCEPTION + 'inception5h.zip#imagenet_comp_graph_label_strings.txt',
143 |   'input_size': 224,
144 |   'image_mean': 117,
145 |   'image_std': 1,
146 |   'input_name': 'input',
147 |   'output_name': 'output'
148 | });
149 |
150 | registerModel('inception-v3', {
151 |   'label': 'Inception v3',
152 |   'model_path': INCEPTION + 'inception_dec_2015.zip#tensorflow_inception_graph.pb',
153 |   'label_path': INCEPTION + 'inception_dec_2015.zip#imagenet_comp_graph_label_strings.txt',
154 |   'input_size': 299,
155 |   'image_mean': 128,
156 |   'image_std': 128,
157 |   'input_name': 'Mul',
158 |   'output_name': 'final_result'
159 | });
160 |
161 | function loadModel(modelId, callback, errorCallback, progressCallback) {
162 |   var model;
163 |   try {
164 |     model = getModel(modelId);
165 |   } catch (e) {
166 |     errorCallback(e);
167 |     return;
168 |   }
169 |   if (!progressCallback) {
170 |     progressCallback = function(stat) {
171 |       console.log(stat.label);
172 |     };
173 |   }
174 |   if (!model.cached) {
175 |     checkCached(modelId, function(cached) {
176 |       if (!cached) {
177 |         fetchModel(
178 |           model,
179 |           initClassifier,
180 |           errorCallback,
181 |           progressCallback
182 |         );
183 |       } else {
184 |         initClassifier();
185 |       }
186 |     }, errorCallback);
187 |   } else {
188 |     initClassifier();
189 |   }
190 |   function initClassifier() {
191 |     var modelPath = (model.local_model_path || model.model_path),
192 |         labelPath = (model.local_label_path || model.label_path);
193 |     modelPath = modelPath.replace(/^file:\/\//, '');
194 |     labelPath = labelPath.replace(/^file:\/\//, '');
195 |     progressCallback({
196 |       'status': 'initializing',
197 |       'label': 'Initializing classifier'
198 |     });
199 |     cordova.exec(function() {
200 |       model.loaded = true;
201 |       callback(model);
202 |     }, errorCallback, "TensorFlow", "loadModel", [
203 |       model.id,
204 |       modelPath,
205 |       labelPath,
206 |       model.input_size,
207 |       model.image_mean,
208 |       model.image_std,
209 |       model.input_name,
210 |       model.output_name
211 |     ]);
212 |   }
213 | }
214 |
215 | function getPath(filename) {
216 |   return (
217 |     cordova.file.externalDataDirectory || cordova.file.dataDirectory
218 |   ) + filename;
219 | }
220 |
221 | function fetchModel(model, callback, errorCallback, progressCallback) {
222 |   fetchZip(model, callback, errorCallback, progressCallback);
223 | }
224 |
225 | // Works out where the model's zip contents are (or would be) extracted,
226 | // records those paths on the model, and reports via callback(boolean)
227 | // whether the extracted model file exists.
228 | function checkCached(modelId, callback, errorCallback) {
229 |   var model;
230 |   try {
231 |     model = getModel(modelId);
232 |   } catch (e) {
233 |     errorCallback(e);
234 |     return;
235 |   }
236 |   var zipUrl = model.model_path.split('#')[0];
237 |   if (model.label_path.indexOf(zipUrl) === -1) {
238 |     errorCallback('Model and labels must be in same zip file!');
239 |     return;
240 |   }
241 |   var modelZipName = model.model_path.replace(zipUrl + '#', '');
242 |   var labelZipName = model.label_path.replace(zipUrl + '#', '');
243 |   var dir = getPath(model.id);
244 |
245 |   model.local_model_path = dir + '/' + modelZipName;
246 |   model.local_label_path = dir + '/' + labelZipName;
247 |
248 |   resolveLocalFileSystemURL(
249 |     model.local_model_path, cached(true), cached(false)
250 |   );
251 |
252 |   function cached(result) {
253 |     return function() {
254 |       model.cached = result;
255 |       callback(model.cached);
256 |     };
257 |   }
258 | }
259 |
260 | // Downloads the model's zip into the data directory, extracts it into a
261 | // per-model folder, marks the model cached and calls callback().
262 | function fetchZip(model, callback, errorCallback, progressCallback) {
263 |   var zipUrl = model.model_path.split('#')[0];
264 |   var zipPath = getPath(model.id + '.zip');
265 |   var dir = getPath(model.id);
266 |   var fileTransfer = new FileTransfer();
267 |   progressCallback({
268 |     'status': 'downloading',
269 |     'label': 'Downloading model files',
270 |   });
271 |   fileTransfer.onprogress = function(evt) {
272 |     var label = 'Downloading';
273 |     if (evt.lengthComputable) {
274 |       label += ' (' + evt.loaded + '/' + evt.total + ')';
275 |     } else {
276 |       label += '...';
277 |     }
278 |     progressCallback({
279 |       'status': 'downloading',
280 |       'label': label,
281 |       'detail': evt
282 |     });
283 |   };
284 |   fileTransfer.download(zipUrl, zipPath, function() {
285 |     progressCallback({
286 |       'status': 'unzipping',
287 |       'label': 'Extracting contents'
288 |     });
289 |     zip.unzip(zipPath, dir, function(result) {
290 |       if (result == -1) {
291 |         errorCallback('Error unzipping file');
292 |         return;
293 |       }
294 |       model.cached = true;
295 |       callback();
296 |     });
297 |   }, errorCallback);
298 | }
299 |
300 | TensorFlow._models = models;
301 | module.exports = TensorFlow;
302 |
--------------------------------------------------------------------------------