├── .gitignore ├── LICENSE ├── README.md ├── env ├── Dockerfile ├── build.sh ├── run.sh └── shell.sh ├── package.json ├── samples ├── README.md └── graphs │ ├── basic │ ├── graph.proto │ ├── graph.proto.json │ ├── graph.proto.txt │ ├── graph.py │ ├── main.js │ └── package.json │ ├── json │ ├── main.js │ └── package.json │ ├── matrix │ ├── graph.proto │ ├── graph.proto.json │ ├── graph.proto.txt │ ├── graph.py │ ├── main.js │ └── package.json │ └── strings │ ├── graph.proto │ ├── graph.proto.json │ ├── graph.proto.txt │ ├── graph.py │ ├── main.js │ └── package.json ├── setup └── setup.js └── src ├── graph.js ├── index.js ├── interop ├── api.js ├── messages.js ├── messages.proto └── serializers.js ├── session.js └── tensor.js /.gitignore: -------------------------------------------------------------------------------- 1 | include/ 2 | lib/ 3 | node_modules/ 4 | *.pyc 5 | package-lock.json 6 | tensorflow*.tgz 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow + Node.js 2 | 3 | [TensorFlow](https://tensorflow.org) is Google's machine learning runtime. It 4 | is implemented as C++ runtime, along with Python framework to support building 5 | a variety of models, especially neural networks for deep learning. 6 | 7 | It is interesting to be able to use TensorFlow in a node.js application 8 | using just JavaScript (or TypeScript if that's your preference). However, 9 | the Python functionality is vast (several ops, estimator implementations etc.) 10 | and continually expanding. Instead, it would be more practical to consider 11 | building Graphs and training models in Python, and then consuming those 12 | for runtime use-cases (like prediction or inference) in a pure node.js and 13 | Python-free deployment. This is what this node module enables. 
14 | 15 | This module takes care of the building blocks and mechanics for working 16 | with the TensorFlow C API, and instead provides an API around Tensors, Graphs, 17 | Sessions and Models. 18 | 19 | This is still in the works, and recently revamped to support TensorFlow 1.4+. 20 | 21 | ## High Level Interface - Models 22 | 23 | This is in plan. The idea here is to point to a saved model and be able to 24 | use it for predictions. Instances-in, inferences-out. 25 | 26 | Stay tuned for a future update. 27 | 28 | ## Low Level Interface - Tensors, Graphs and Sessions 29 | 30 | ### Trivial Example - Loading and Running Graphs 31 | Lets assume we have a simple TensorFlow graph. For illustration purposes, a 32 | trivial graph produced from this Python code, and saved as a GraphDef 33 | protocol buffer file. 34 | 35 | ```python 36 | import tensorflow as tf 37 | 38 | with tf.Graph().as_default() as graph: 39 | c1 = tf.constant(1, name='c1') 40 | c2 = tf.constant(41, name='c2') 41 | result = tf.add(c1, c2, name='result') 42 | 43 | tf.train.write_graph(graph, '.', 'trivial.graph.proto', as_text=False) 44 | ``` 45 | 46 | Now, in node.js, you can load this serialized graph definition, load a 47 | TensorFlow session, and then run specific operations to retrive tensors. 48 | 49 | ```javascript 50 | const tf = require('tensorflow'); 51 | 52 | // Load the Graph and create a Session to be able to run the operations 53 | // defined in the graph. 54 | let graph = tf.graph('trivial.graph.proto'); 55 | let session = graph.createSession(); 56 | 57 | // Run to evaluate and retrieve the value of the 'result' op. 58 | let result = session.run(/* inputs */ null, 59 | /* outputs */ 'result', 60 | /* targets */ null); 61 | 62 | // The result is a Tensor, which contains value, type and shape fields. 
63 | // This Should print out '42' 64 | console.log(result.value); 65 | 66 | // Cleanup 67 | graph.delete(); 68 | ``` 69 | 70 | ### Feeding and Fetching Tensors with a Session 71 | 72 | This example goes a bit further - in particular, the Graph contains 73 | variables, and placeholders, requiring initialization as well as feeding values, 74 | when executing the graph. Additionally the Tensors are integer matrices. 75 | 76 | ```python 77 | import tensorflow as tf 78 | 79 | with tf.Graph().as_default() as graph: 80 | var1 = tf.placeholder(dtype=tf.int32, shape=[2,2], name='var1') 81 | var2 = tf.placeholder(dtype=tf.int32, shape=[2,1], name='var2') 82 | var3 = tf.Variable(initial_value=[[1],[1]], dtype=tf.int32) 83 | 84 | tf.variables_initializer(tf.global_variables(), name='init') 85 | 86 | with tf.name_scope('computation'): 87 | tf.add(tf.matmul(var1, var2), var3, name='result') 88 | 89 | tf.train.write_graph(graph, '.', 'graph.proto', as_text=False) 90 | ``` 91 | 92 | Here is the corresponding node.js snippet to work with the Graph defined above: 93 | 94 | ```javascript 95 | const tf = require('tensorflow'); 96 | 97 | let graph = tf.graph('graph.proto'); 98 | let session = graph.createSession(); 99 | 100 | // Run the 'init' op to initialize variables defined in the graph. 101 | session.run(null, null, 'init'); 102 | 103 | // Generally you can use arrays directly. This samples demonstrates creating 104 | // Tensors to explicitly specify types to match the int32 types that the graph 105 | // expects. 106 | let a = tf.tensor([[2,2],[4,4]], tf.Types.int32); 107 | let b = tf.tensor([[3],[5]], tf.Types.int32); 108 | 109 | // You can fetch multiple outputs as well. 110 | let outputs = session.run({ var1: a, var2: b }, ['var3', 'computation/result']); 111 | console.log(outputs.var3.value) 112 | console.log(outputs['computation/result'].value); 113 | 114 | graph.delete(); 115 | ``` 116 | 117 | ## Installation 118 | 119 | Installation is pretty straight-forward. 
Installing this module automatically installs 120 | the TensorFlow binary dependencies (by default, TensorFlow CPU v1.4.1). 121 | 122 | npm install tensorflow 123 | 124 | Optionally, you can specify the build of TensorFlow binaries to install using environment 125 | variables. 126 | 127 | export TENSORFLOW_LIB_TYPE=gpu 128 | export TENSORFLOW_LIB_VERSION=1.5.0 129 | npm install tensorflow 130 | 131 | The TensorFlow binaries are automatically installed within the directory containing the node 132 | module. If you have a custom build of TensorFlow you would like to use instead, you can 133 | suppress downloading the binaries at installation time.
2 | # 3 | 4 | FROM ubuntu:16.04 5 | MAINTAINER Nikhil Kothari 6 | 7 | # Setup OS and core packages 8 | RUN apt-get update -y && \ 9 | apt-get install --no-install-recommends -y -q \ 10 | curl wget unzip bzip2 git vim build-essential ca-certificates pkg-config \ 11 | python2.7 python-dev python-pip python-setuptools 12 | 13 | # Setup Node.js 14 | RUN mkdir -p /tools/node && \ 15 | wget -nv https://nodejs.org/dist/v8.9.3/node-v8.9.3-linux-x64.tar.gz -O node.tar.gz && \ 16 | tar xf node.tar.gz -C /tools/node --strip-components=1 && \ 17 | rm node.tar.gz 18 | 19 | # Setup TensorFlow 20 | RUN pip install --upgrade pip && \ 21 | pip install setuptools && \ 22 | pip install tensorflow==1.4.1 23 | 24 | # Configuration 25 | ENV PATH $PATH:/tools/node/bin 26 | ENTRYPOINT [ "/bin/bash" ] 27 | -------------------------------------------------------------------------------- /env/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | docker build -t tf-env . 
4 | 5 | -------------------------------------------------------------------------------- /env/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | SCRIPT=$0 4 | if [ "$SCRIPT" == "-bash" ]; then 5 | SCRIPT=${BASH_SOURCE[0]} 6 | fi 7 | REPO_DIR=$(git rev-parse --show-toplevel) 8 | 9 | docker run -it --rm --name tf-env -v $REPO_DIR:/repo -p 8080:8080 tf-env 10 | -------------------------------------------------------------------------------- /env/shell.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | docker exec -it tf-env bash 4 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tensorflow", 3 | "version": "0.7.0", 4 | "description": "Node.js module for using TensorFlow graphs and models", 5 | "keywords": [ 6 | "tensorflow", 7 | "tf", 8 | "tensor", 9 | "machine learning", 10 | "ml", 11 | "ai", 12 | "neural networks", 13 | "neuralnetworks", 14 | "deeplearning", 15 | "model", 16 | "numerical computation", 17 | "google" 18 | ], 19 | "homepage": "https://github.com/nikhilk/node-tensorflow#readme", 20 | "author": "Nikhil Kothari ", 21 | "contributors": [ 22 | { 23 | "name": "Nikhil Kothari", 24 | "email": "nikhilko@google.com", 25 | "url": "http://www.nikhilk.net" 26 | } 27 | ], 28 | "license": "Apache 2.0", 29 | "repository": "github:nikhilk/node-tensorflow", 30 | "files": [ 31 | "src", 32 | "setup", 33 | "README.md" 34 | ], 35 | "main": "src/index.js", 36 | "dependencies": { 37 | "ffi": "^2.2.0", 38 | "pbf": "^3.1.0", 39 | "ref": "^1.3.5", 40 | "ref-array": "^1.2.0", 41 | "ref-struct": "^1.1.0" 42 | }, 43 | "engines": { 44 | "node": ">=8.9.3" 45 | }, 46 | "os": [ 47 | "linux", 48 | "darwin" 49 | ], 50 | "scripts": { 51 | "postinstall": "npm run -s installtf", 52 | "installtf": "node setup/setup.js" 53 | } 54 | 
} 55 | -------------------------------------------------------------------------------- /samples/README.md: -------------------------------------------------------------------------------- 1 | # Samples List 2 | 3 | This directory contains a few samples of using TensorFlow with node.js. 4 | 5 | ## Graphs 6 | The graphs set of samples demonstrate using TensorFlow graphs. 7 | 8 | ### basic 9 | 'Hello World' style sample, demonstrating graph loading and execution. 10 | 11 | ### matrix 12 | Builds on the above to use matrices instead of scalar tensors, as well as load graphs containing 13 | namescopes. This sample will be updated to demonstrate feeding in tensors when executing graphs. 14 | 15 | ### strings 16 | This demonstrates a graph that accepts a batch of strings input tensor, and produces a batch of 17 | strings as output. 18 | 19 | # Running Samples 20 | 21 | Within each sample directory: 22 | 23 | npm install 24 | npm run -s sample 25 | 26 | This will install the tensorflow node.js package along with associated TensorFlow binaries. 27 | Once installed, running will first run the Python code to produce the TensorFlow artifact 28 | (eg. graph) and then run the node.js sample. 29 | 30 | ## If you're making changes to the tensorflow package ... 31 | ... and you want to run the samples without first publishing the package to npm, you can 32 | create a package from your environment and install using that to run the samples. 33 | 34 | # Build local tensorflow-\.tgz package 35 | # From the root of the repo directory... 36 | npm pack 37 | 38 | cd samples/graphs/basic 39 | npm install ../../../tensorflow-.tgz 40 | 41 | This will install from your local package (even if the versions have not changed). 42 | 43 | Note that it will record this in the sample's package.json file. Be sure to revert the package.json 44 | file to avoid committing this change. 
45 | -------------------------------------------------------------------------------- /samples/graphs/basic/graph.proto: -------------------------------------------------------------------------------- 1 | 2 | , 3 | c1Const* 4 | value B:* 5 | dtype0 6 | , 7 | c2Const* 8 | value B:)* 9 | dtype0 10 |  11 | resultAddc1c2* 12 | T0" -------------------------------------------------------------------------------- /samples/graphs/basic/graph.proto.json: -------------------------------------------------------------------------------- 1 | { 2 | "node": [ 3 | { 4 | "attr": { 5 | "dtype": { 6 | "type": "DT_INT32" 7 | }, 8 | "value": { 9 | "tensor": { 10 | "dtype": "DT_INT32", 11 | "tensorShape": {}, 12 | "intVal": [ 13 | 1 14 | ] 15 | } 16 | } 17 | }, 18 | "name": "c1", 19 | "op": "Const" 20 | }, 21 | { 22 | "attr": { 23 | "dtype": { 24 | "type": "DT_INT32" 25 | }, 26 | "value": { 27 | "tensor": { 28 | "dtype": "DT_INT32", 29 | "tensorShape": {}, 30 | "intVal": [ 31 | 41 32 | ] 33 | } 34 | } 35 | }, 36 | "name": "c2", 37 | "op": "Const" 38 | }, 39 | { 40 | "input": [ 41 | "c1", 42 | "c2" 43 | ], 44 | "attr": { 45 | "T": { 46 | "type": "DT_INT32" 47 | } 48 | }, 49 | "name": "result", 50 | "op": "Add" 51 | } 52 | ], 53 | "versions": { 54 | "producer": 24 55 | } 56 | } -------------------------------------------------------------------------------- /samples/graphs/basic/graph.proto.txt: -------------------------------------------------------------------------------- 1 | node { 2 | name: "c1" 3 | op: "Const" 4 | attr { 5 | key: "dtype" 6 | value { 7 | type: DT_INT32 8 | } 9 | } 10 | attr { 11 | key: "value" 12 | value { 13 | tensor { 14 | dtype: DT_INT32 15 | tensor_shape { 16 | } 17 | int_val: 1 18 | } 19 | } 20 | } 21 | } 22 | node { 23 | name: "c2" 24 | op: "Const" 25 | attr { 26 | key: "dtype" 27 | value { 28 | type: DT_INT32 29 | } 30 | } 31 | attr { 32 | key: "value" 33 | value { 34 | tensor { 35 | dtype: DT_INT32 36 | tensor_shape { 37 | } 38 | int_val: 41 39 | } 40 | } 41 
| } 42 | } 43 | node { 44 | name: "result" 45 | op: "Add" 46 | input: "c1" 47 | input: "c2" 48 | attr { 49 | key: "T" 50 | value { 51 | type: DT_INT32 52 | } 53 | } 54 | } 55 | versions { 56 | producer: 24 57 | } 58 | -------------------------------------------------------------------------------- /samples/graphs/basic/graph.py: -------------------------------------------------------------------------------- 1 | # graph.py 2 | # Builds a trivial graph for most basic example of loading/running TensorFlow. 3 | # 4 | # Run with the following command: 5 | # python graph.py 6 | # 7 | # This should produce graph.proto (which is used from node.js) along with graph.proto.txt and 8 | # graph.proto.json for readable versions. 9 | 10 | import google.protobuf.json_format as json 11 | import tensorflow as tf 12 | 13 | def save_graph(graph, name='graph'): 14 | tf.train.write_graph(graph, '.', name + '.proto', as_text=False) 15 | tf.train.write_graph(graph, '.', name + '.proto.txt', as_text=True) 16 | 17 | data = json.MessageToJson(graph.as_graph_def()) 18 | with open(name + '.proto.json', 'w') as f: 19 | f.write(data) 20 | 21 | 22 | def build_graph(): 23 | with tf.Graph().as_default() as graph: 24 | c1 = tf.constant(1, name='c1') 25 | c2 = tf.constant(41, name='c2') 26 | result = tf.add(c1, c2, name='result') 27 | 28 | return graph 29 | 30 | save_graph(build_graph()) 31 | -------------------------------------------------------------------------------- /samples/graphs/basic/main.js: -------------------------------------------------------------------------------- 1 | const tf = require('tensorflow'); 2 | 3 | console.log('Tensor Test'); 4 | 5 | var tensor = tf.tensor([0.5, 42.0]); 6 | console.log(tensor.shape); 7 | console.log(tensor.value); 8 | 9 | let items = 10 | [ 11 | [ 12 | [1,2], [3,4] 13 | ], 14 | 15 | [ 16 | [5,6], [7,8] 17 | ] 18 | ] 19 | var tensor2 = tf.tensor(items); 20 | console.log(tensor2.shape); 21 | console.log(tensor2.value); 22 | 23 | 24 | console.log('Graph 
Test'); 25 | 26 | let graph = tf.graph('./graph.proto'); 27 | let session = graph.createSession(); 28 | let result = session.run(null, 'result'); 29 | console.log(result); 30 | 31 | graph.delete(); 32 | -------------------------------------------------------------------------------- /samples/graphs/basic/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "dependencies": { 3 | "tensorflow": "^0.7.0" 4 | }, 5 | "scripts": { 6 | "presample": "python graph.py", 7 | "sample": "node main.js" 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /samples/graphs/json/main.js: -------------------------------------------------------------------------------- 1 | const tf = require('tensorflow'); 2 | 3 | var const1 = { 4 | name: 'c1', 5 | op: 'Const', 6 | attr: { 7 | value: { 8 | value: 'tensor', 9 | tensor: { 10 | dtype: 3, 11 | tensor_shape: { dim: [] }, 12 | int_val: [1] 13 | } 14 | }, 15 | dtype: { 16 | value: 'type', 17 | type: 3 18 | } 19 | } 20 | }; 21 | 22 | var const2 = { 23 | name: 'c2', 24 | op: 'Const', 25 | attr: { 26 | value: { 27 | value: 'tensor', 28 | tensor: { 29 | dtype: 3, 30 | tensor_shape: { dim: [] }, 31 | int_val: [41] 32 | } 33 | }, 34 | dtype: { 35 | value: 'type', 36 | type: 3 37 | } 38 | } 39 | }; 40 | 41 | var add = { 42 | name: 'sum', 43 | op: 'Add', 44 | input: [ 45 | 'c1', 46 | 'c2' 47 | ], 48 | attr: { 49 | T: { 50 | value: 'type', 51 | type: 3 52 | } 53 | } 54 | }; 55 | 56 | var graphDef = { 57 | node: [ const1, const2, add ] 58 | } 59 | 60 | let graph = tf.graph(graphDef); 61 | let session = graph.createSession(); 62 | 63 | let results = session.run(null, ['sum'], null); 64 | 65 | console.log(results.sum.value); 66 | 67 | graph.delete(); 68 | -------------------------------------------------------------------------------- /samples/graphs/json/package.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"dependencies": { 3 | "tensorflow": "^0.7.0" 4 | }, 5 | "scripts": { 6 | "sample": "node main.js" 7 | } 8 | } 9 | -------------------------------------------------------------------------------- /samples/graphs/matrix/graph.proto: -------------------------------------------------------------------------------- 1 | 2 | 5 3 | var1 Placeholder* 4 | dtype0* 5 | shape 6 | : 7 | 5 8 | var2 Placeholder* 9 | shape 10 | :* 11 | dtype0 12 | K 13 | var3/initial_valueConst* 14 | dtype0*! 15 | valueB" 16 | X 17 | var3 18 | VariableV2* 19 | shared_name* 20 | dtype0* 21 | container* 22 | shape 23 | : 24 | z 25 | var3/AssignAssignvar3var3/initial_value* 26 | T0* 27 | _class 28 |  loc:@var3* 29 | validate_shape(* 30 | use_locking( 31 | = 32 | var3/readIdentityvar3* 33 | T0* 34 | _class 35 |  loc:@var3 36 |  37 | initNoOp ^var3/Assign 38 | W 39 | computation/MatMulMatMulvar1var2* 40 | T0* 41 | transpose_a(* 42 | transpose_b( 43 | A 44 | computation/resultAddcomputation/MatMul var3/read* 45 | T0" -------------------------------------------------------------------------------- /samples/graphs/matrix/graph.proto.json: -------------------------------------------------------------------------------- 1 | { 2 | "node": [ 3 | { 4 | "attr": { 5 | "dtype": { 6 | "type": "DT_INT32" 7 | }, 8 | "shape": { 9 | "shape": { 10 | "dim": [ 11 | { 12 | "size": "2" 13 | }, 14 | { 15 | "size": "2" 16 | } 17 | ] 18 | } 19 | } 20 | }, 21 | "name": "var1", 22 | "op": "Placeholder" 23 | }, 24 | { 25 | "attr": { 26 | "dtype": { 27 | "type": "DT_INT32" 28 | }, 29 | "shape": { 30 | "shape": { 31 | "dim": [ 32 | { 33 | "size": "2" 34 | }, 35 | { 36 | "size": "1" 37 | } 38 | ] 39 | } 40 | } 41 | }, 42 | "name": "var2", 43 | "op": "Placeholder" 44 | }, 45 | { 46 | "attr": { 47 | "dtype": { 48 | "type": "DT_INT32" 49 | }, 50 | "value": { 51 | "tensor": { 52 | "dtype": "DT_INT32", 53 | "tensorShape": { 54 | "dim": [ 55 | { 56 | "size": "2" 57 | }, 58 | { 59 | "size": "1" 60 | } 61 | ] 62 | }, 63 | "tensorContent": 
"AQAAAAEAAAA=" 64 | } 65 | } 66 | }, 67 | "name": "var3/initial_value", 68 | "op": "Const" 69 | }, 70 | { 71 | "attr": { 72 | "dtype": { 73 | "type": "DT_INT32" 74 | }, 75 | "shape": { 76 | "shape": { 77 | "dim": [ 78 | { 79 | "size": "2" 80 | }, 81 | { 82 | "size": "1" 83 | } 84 | ] 85 | } 86 | }, 87 | "container": { 88 | "s": "" 89 | }, 90 | "shared_name": { 91 | "s": "" 92 | } 93 | }, 94 | "name": "var3", 95 | "op": "VariableV2" 96 | }, 97 | { 98 | "input": [ 99 | "var3", 100 | "var3/initial_value" 101 | ], 102 | "attr": { 103 | "validate_shape": { 104 | "b": true 105 | }, 106 | "_class": { 107 | "list": { 108 | "s": [ 109 | "bG9jOkB2YXIz" 110 | ] 111 | } 112 | }, 113 | "use_locking": { 114 | "b": true 115 | }, 116 | "T": { 117 | "type": "DT_INT32" 118 | } 119 | }, 120 | "name": "var3/Assign", 121 | "op": "Assign" 122 | }, 123 | { 124 | "input": [ 125 | "var3" 126 | ], 127 | "attr": { 128 | "_class": { 129 | "list": { 130 | "s": [ 131 | "bG9jOkB2YXIz" 132 | ] 133 | } 134 | }, 135 | "T": { 136 | "type": "DT_INT32" 137 | } 138 | }, 139 | "name": "var3/read", 140 | "op": "Identity" 141 | }, 142 | { 143 | "input": [ 144 | "^var3/Assign" 145 | ], 146 | "name": "init", 147 | "op": "NoOp" 148 | }, 149 | { 150 | "input": [ 151 | "var1", 152 | "var2" 153 | ], 154 | "attr": { 155 | "transpose_b": { 156 | "b": false 157 | }, 158 | "transpose_a": { 159 | "b": false 160 | }, 161 | "T": { 162 | "type": "DT_INT32" 163 | } 164 | }, 165 | "name": "computation/MatMul", 166 | "op": "MatMul" 167 | }, 168 | { 169 | "input": [ 170 | "computation/MatMul", 171 | "var3/read" 172 | ], 173 | "attr": { 174 | "T": { 175 | "type": "DT_INT32" 176 | } 177 | }, 178 | "name": "computation/result", 179 | "op": "Add" 180 | } 181 | ], 182 | "versions": { 183 | "producer": 24 184 | } 185 | } -------------------------------------------------------------------------------- /samples/graphs/matrix/graph.proto.txt: -------------------------------------------------------------------------------- 1 | node 
{ 2 | name: "var1" 3 | op: "Placeholder" 4 | attr { 5 | key: "dtype" 6 | value { 7 | type: DT_INT32 8 | } 9 | } 10 | attr { 11 | key: "shape" 12 | value { 13 | shape { 14 | dim { 15 | size: 2 16 | } 17 | dim { 18 | size: 2 19 | } 20 | } 21 | } 22 | } 23 | } 24 | node { 25 | name: "var2" 26 | op: "Placeholder" 27 | attr { 28 | key: "dtype" 29 | value { 30 | type: DT_INT32 31 | } 32 | } 33 | attr { 34 | key: "shape" 35 | value { 36 | shape { 37 | dim { 38 | size: 2 39 | } 40 | dim { 41 | size: 1 42 | } 43 | } 44 | } 45 | } 46 | } 47 | node { 48 | name: "var3/initial_value" 49 | op: "Const" 50 | attr { 51 | key: "dtype" 52 | value { 53 | type: DT_INT32 54 | } 55 | } 56 | attr { 57 | key: "value" 58 | value { 59 | tensor { 60 | dtype: DT_INT32 61 | tensor_shape { 62 | dim { 63 | size: 2 64 | } 65 | dim { 66 | size: 1 67 | } 68 | } 69 | tensor_content: "\001\000\000\000\001\000\000\000" 70 | } 71 | } 72 | } 73 | } 74 | node { 75 | name: "var3" 76 | op: "VariableV2" 77 | attr { 78 | key: "container" 79 | value { 80 | s: "" 81 | } 82 | } 83 | attr { 84 | key: "dtype" 85 | value { 86 | type: DT_INT32 87 | } 88 | } 89 | attr { 90 | key: "shape" 91 | value { 92 | shape { 93 | dim { 94 | size: 2 95 | } 96 | dim { 97 | size: 1 98 | } 99 | } 100 | } 101 | } 102 | attr { 103 | key: "shared_name" 104 | value { 105 | s: "" 106 | } 107 | } 108 | } 109 | node { 110 | name: "var3/Assign" 111 | op: "Assign" 112 | input: "var3" 113 | input: "var3/initial_value" 114 | attr { 115 | key: "T" 116 | value { 117 | type: DT_INT32 118 | } 119 | } 120 | attr { 121 | key: "_class" 122 | value { 123 | list { 124 | s: "loc:@var3" 125 | } 126 | } 127 | } 128 | attr { 129 | key: "use_locking" 130 | value { 131 | b: true 132 | } 133 | } 134 | attr { 135 | key: "validate_shape" 136 | value { 137 | b: true 138 | } 139 | } 140 | } 141 | node { 142 | name: "var3/read" 143 | op: "Identity" 144 | input: "var3" 145 | attr { 146 | key: "T" 147 | value { 148 | type: DT_INT32 149 | } 150 | } 151 | attr { 152 | 
key: "_class" 153 | value { 154 | list { 155 | s: "loc:@var3" 156 | } 157 | } 158 | } 159 | } 160 | node { 161 | name: "init" 162 | op: "NoOp" 163 | input: "^var3/Assign" 164 | } 165 | node { 166 | name: "computation/MatMul" 167 | op: "MatMul" 168 | input: "var1" 169 | input: "var2" 170 | attr { 171 | key: "T" 172 | value { 173 | type: DT_INT32 174 | } 175 | } 176 | attr { 177 | key: "transpose_a" 178 | value { 179 | b: false 180 | } 181 | } 182 | attr { 183 | key: "transpose_b" 184 | value { 185 | b: false 186 | } 187 | } 188 | } 189 | node { 190 | name: "computation/result" 191 | op: "Add" 192 | input: "computation/MatMul" 193 | input: "var3/read" 194 | attr { 195 | key: "T" 196 | value { 197 | type: DT_INT32 198 | } 199 | } 200 | } 201 | versions { 202 | producer: 24 203 | } 204 | -------------------------------------------------------------------------------- /samples/graphs/matrix/graph.py: -------------------------------------------------------------------------------- 1 | # graph.py 2 | # Builds a trivial graph for most basic example of loading/running TensorFlow. 3 | # 4 | # Run with the following command: 5 | # python graph.py 6 | # 7 | # This should produce graph.proto (which is used from node.js) along with graph.proto.txt and 8 | # graph.proto.json for readable versions. 
9 | 10 | import google.protobuf.json_format as json 11 | import tensorflow as tf 12 | 13 | def save_graph(graph, name='graph'): 14 | tf.train.write_graph(graph, '.', name + '.proto', as_text=False) 15 | tf.train.write_graph(graph, '.', name + '.proto.txt', as_text=True) 16 | 17 | data = json.MessageToJson(graph.as_graph_def()) 18 | with open(name + '.proto.json', 'w') as f: 19 | f.write(data) 20 | 21 | 22 | def build_graph(): 23 | with tf.Graph().as_default() as graph: 24 | var1 = tf.placeholder(dtype=tf.int32, shape=[2,2], name='var1') 25 | var2 = tf.placeholder(dtype=tf.int32, shape=[2,1], name='var2') 26 | var3 = tf.Variable(initial_value=[[1],[1]], dtype=tf.int32, name='var3') 27 | 28 | tf.variables_initializer(tf.global_variables(), name='init') 29 | 30 | with tf.name_scope('computation'): 31 | tf.add(tf.matmul(var1, var2), var3, name='result') 32 | 33 | return graph 34 | 35 | save_graph(build_graph()) 36 | -------------------------------------------------------------------------------- /samples/graphs/matrix/main.js: -------------------------------------------------------------------------------- 1 | const tf = require('tensorflow'); 2 | 3 | let graph = tf.graph('./graph.proto'); 4 | let session = graph.createSession(); 5 | 6 | session.run(null, null, 'init'); 7 | 8 | let a = tf.tensor([[2,2],[4,4]], tf.Types.int32); 9 | let b = tf.tensor([[3],[5]], tf.Types.int32); 10 | 11 | let outputs = session.run({ var1: a, var2: b }, ['var3', 'computation/result']); 12 | console.log(outputs.var3.value) 13 | console.log(outputs['computation/result'].value); 14 | 15 | graph.delete(); 16 | -------------------------------------------------------------------------------- /samples/graphs/matrix/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "dependencies": { 3 | "tensorflow": "^0.7.0" 4 | }, 5 | "scripts": { 6 | "presample": "python graph.py", 7 | "sample": "node main.js" 8 | } 9 | } 10 | 
-------------------------------------------------------------------------------- /samples/graphs/strings/graph.proto: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilk/node-tensorflow/db0b0db423f0b8354d29819129c13ed3c8475971/samples/graphs/strings/graph.proto -------------------------------------------------------------------------------- /samples/graphs/strings/graph.proto.json: -------------------------------------------------------------------------------- 1 | { 2 | "node": [ 3 | { 4 | "attr": { 5 | "dtype": { 6 | "type": "DT_STRING" 7 | }, 8 | "shape": { 9 | "shape": { 10 | "dim": [ 11 | { 12 | "size": "-1" 13 | }, 14 | { 15 | "size": "1" 16 | } 17 | ] 18 | } 19 | } 20 | }, 21 | "name": "input", 22 | "op": "Placeholder" 23 | }, 24 | { 25 | "input": [ 26 | "input" 27 | ], 28 | "attr": { 29 | "T": { 30 | "type": "DT_STRING" 31 | } 32 | }, 33 | "name": "output", 34 | "op": "Identity" 35 | } 36 | ], 37 | "versions": { 38 | "producer": 24 39 | } 40 | } -------------------------------------------------------------------------------- /samples/graphs/strings/graph.proto.txt: -------------------------------------------------------------------------------- 1 | node { 2 | name: "input" 3 | op: "Placeholder" 4 | attr { 5 | key: "dtype" 6 | value { 7 | type: DT_STRING 8 | } 9 | } 10 | attr { 11 | key: "shape" 12 | value { 13 | shape { 14 | dim { 15 | size: -1 16 | } 17 | dim { 18 | size: 1 19 | } 20 | } 21 | } 22 | } 23 | } 24 | node { 25 | name: "output" 26 | op: "Identity" 27 | input: "input" 28 | attr { 29 | key: "T" 30 | value { 31 | type: DT_STRING 32 | } 33 | } 34 | } 35 | versions { 36 | producer: 24 37 | } 38 | -------------------------------------------------------------------------------- /samples/graphs/strings/graph.py: -------------------------------------------------------------------------------- 1 | # graph.py 2 | # Builds a graph that operates over string tensors. 
3 | # 4 | # Run with the following command: 5 | # python graph.py 6 | # 7 | # This should produce graph.proto (which is used from node.js) along with graph.proto.txt and 8 | # graph.proto.json for readable versions. 9 | 10 | import google.protobuf.json_format as json 11 | import tensorflow as tf 12 | 13 | def save_graph(graph, name='graph'): 14 | tf.train.write_graph(graph, '.', name + '.proto', as_text=False) 15 | tf.train.write_graph(graph, '.', name + '.proto.txt', as_text=True) 16 | 17 | data = json.MessageToJson(graph.as_graph_def()) 18 | with open(name + '.proto.json', 'w') as f: 19 | f.write(data) 20 | 21 | 22 | def build_graph(): 23 | with tf.Graph().as_default() as graph: 24 | strings = tf.placeholder(dtype=tf.string, shape=[None,1], name='input') 25 | tf.identity(strings, name='output') 26 | 27 | return graph 28 | 29 | save_graph(build_graph()) 30 | -------------------------------------------------------------------------------- /samples/graphs/strings/main.js: -------------------------------------------------------------------------------- 1 | const tf = require('tensorflow'); 2 | 3 | let graph = tf.graph('./graph.proto'); 4 | let session = graph.createSession(); 5 | 6 | let result = session.run({ input: ['example', 'data'] }, 'output'); 7 | console.log(result); 8 | 9 | graph.delete(); 10 | -------------------------------------------------------------------------------- /samples/graphs/strings/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "dependencies": { 3 | "tensorflow": "^0.8.0" 4 | }, 5 | "scripts": { 6 | "presample": "python graph.py", 7 | "sample": "node main.js" 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /setup/setup.js: -------------------------------------------------------------------------------- 1 | // setup.js 2 | // Installs pre-built TensorFlow libraries released by TensorFlow. 
3 | // For more details, see https://www.tensorflow.org/install/install_c 4 | // 5 | 6 | 'use strict'; 7 | 8 | const fs = require('fs'), 9 | https = require('https'), 10 | os = require('os'), 11 | path = require('path'), 12 | processes = require('child_process'), 13 | zlib = require('zlib'); 14 | 15 | const libPlatform = os.platform(); 16 | 17 | // TODO: Add support for GPU-enabled TensorFlow builds 18 | // TODO: Add support for GPU - is there a way to detect NVIDIA GPU availability and/or 19 | // relevant driver/software and automatically install the right one? 20 | // For now, we'll use an environment variable, and otherwise default to CPU-only. 21 | const libType = process.env['TENSORFLOW_LIB_TYPE'] || 'cpu'; 22 | 23 | // TODO: Add support for specifying the version. 24 | // One way is to have this node package version match, but that seems like it may not pan 25 | // out always. 26 | // For now, we'll use an environment variable, and otherwise default to 1.4.1. 27 | const libVersion = process.env['TENSORFLOW_LIB_VERSION'] || '1.4.1'; 28 | 29 | function isInstallationRequired() { 30 | let libPath = process.env['TENSORFLOW_LIB_PATH'] || null; 31 | if (!libPath) { 32 | return true; 33 | } 34 | 35 | if (!fs.existsSync(path.join(libPath, 'libtensorflow.so'))) { 36 | console.log(`libtensorflow.so was not found at "${libPath}"`); 37 | process.exit(1); 38 | } 39 | if (!fs.existsSync(path.join(libPath, 'libtensorflow_framework.so'))) { 40 | console.log(`libtensorflow_framework.so was not found at "${libPath}"`); 41 | process.exit(1); 42 | } 43 | 44 | console.log(`TensorFlow libraries are already available at ${libPath}.`); 45 | return false; 46 | } 47 | 48 | function getSourceUrl() { 49 | } 50 | 51 | function downloadPackage(url, downloadPath, cb) { 52 | console.log(`Downloading ...\n${url}\n --> ${downloadPath} ...`); 53 | 54 | var file = fs.createWriteStream(downloadPath); 55 | var request = https.get(url, function(response) { 56 | response.on('data', function(chunk) 
{ 57 | file.write(chunk); 58 | }) 59 | .on('end', function() { 60 | file.end(cb); 61 | }) 62 | }); 63 | 64 | request.on('error', function(e) { 65 | fs.unlink(downloadPath); 66 | cb(e); 67 | }); 68 | } 69 | 70 | function expandPackage(tarPath, expandPath) { 71 | console.log('Expanding and installing ...') 72 | try { 73 | processes.execSync(`tar -C ${expandPath} -xzf ${tarPath}`); 74 | } 75 | catch (e) { 76 | console.log('Unable to setup TensorFlow libraries.'); 77 | process.exit(1); 78 | } 79 | } 80 | 81 | function install() { 82 | if ((libPlatform != 'linux') && (libPlatform != 'darwin')) { 83 | console.log('Only Linux and Mac OS platforms are supported.\n' + 84 | 'See https://www.tensorflow.org/install/install_c for more information'); 85 | process.exit(1); 86 | } 87 | 88 | let url = 'https://storage.googleapis.com/tensorflow/libtensorflow/' + 89 | `libtensorflow-${libType}-${libPlatform}-x86_64-${libVersion}.tar.gz`; 90 | let tarPath = path.join(os.tmpdir(), 'tensorflow.tar.gz'); 91 | let installPath = path.join(__dirname, '..'); 92 | 93 | downloadPackage(url, tarPath, function(e) { 94 | if (e) { 95 | console.log(e.message); 96 | process.exit(1); 97 | } 98 | 99 | expandPackage(tarPath, installPath); 100 | }); 101 | } 102 | 103 | 104 | if (isInstallationRequired()) { 105 | install(); 106 | } 107 | -------------------------------------------------------------------------------- /src/graph.js: -------------------------------------------------------------------------------- 1 | // graph.js 2 | // Implements the Graph class to represent a Graph built from a GraphDef. 
3 | // 4 | 5 | 'use strict'; 6 | 7 | const api = require('./interop/api'), 8 | fs = require('fs'), 9 | session = require('./session'); 10 | 11 | 12 | class Graph { 13 | 14 | constructor(graphHandle) { 15 | this._graphHandle = graphHandle; 16 | this._opCache = {}; 17 | 18 | this._sessions = []; 19 | } 20 | 21 | delete() { 22 | if (this._sessions) { 23 | this._sessions.forEach((session) => session.delete()); 24 | this._sessions = null; 25 | } 26 | 27 | if (this._graphHandle) { 28 | api.TF_DeleteGraph(this._graphHandle); 29 | this._graphHandle = null; 30 | } 31 | } 32 | 33 | createSession() { 34 | this._ensureValid(); 35 | 36 | if (this._sessions === null) { 37 | this._sessions = []; 38 | } 39 | 40 | let s = session.create(this._graphHandle, this._opCache); 41 | this._sessions.push(s); 42 | 43 | return s; 44 | } 45 | 46 | _ensureValid() { 47 | if (!this._graphHandle) { 48 | throw new Error('The Graph instance has been deleted.'); 49 | } 50 | } 51 | } 52 | 53 | 54 | function createGraph(graphDef) { 55 | let protobuf = loadGraphDef(graphDef); 56 | 57 | let graphDefBuffer = api.TF_NewBufferFromString(protobuf, protobuf.length); 58 | let graphDefOptions = api.TF_NewImportGraphDefOptions(); 59 | 60 | let graphHandle = api.TF_NewGraph(); 61 | api.TF_GraphImportGraphDef(graphHandle, graphDefBuffer, graphDefOptions, api.Status); 62 | 63 | api.TF_DeleteImportGraphDefOptions(graphDefOptions); 64 | api.TF_DeleteBuffer(graphDefBuffer); 65 | 66 | if (api.TF_GetCode(api.Status) !== api.StatusCodes.ok) { 67 | api.TF_DeleteGraph(graphHandle); 68 | 69 | let error = api.TF_Message(api.Status); 70 | throw new Error(error); 71 | } 72 | 73 | return new Graph(graphHandle); 74 | } 75 | 76 | function loadGraphDef(graphDef) { 77 | if (graphDef.constructor == String) { 78 | return fs.readFileSync(graphDef); 79 | } 80 | else if (Buffer.isBuffer(graphDef)) { 81 | return graphdef; 82 | } 83 | else { 84 | let ProtobufWriter = require('pbf'); 85 | 86 | let writer = new ProtobufWriter(); 87 | 
api.Protos.GraphDef.write(graphDef, writer); 88 | 89 | return writer.finish(); 90 | } 91 | } 92 | 93 | 94 | module.exports = { 95 | create: createGraph 96 | }; 97 | -------------------------------------------------------------------------------- /src/index.js: -------------------------------------------------------------------------------- 1 | // index.js 2 | // Defines the TensorFlow module. 3 | // 4 | 5 | const api = require('./interop/api'), 6 | tensor = require('./tensor'), 7 | graph = require('./graph'); 8 | 9 | module.exports = { 10 | Types: api.Types, 11 | graph: graph.create, 12 | tensor: tensor.create, 13 | }; 14 | -------------------------------------------------------------------------------- /src/interop/api.js: -------------------------------------------------------------------------------- 1 | // tf.js 2 | // Interface for TensorFlow C API 3 | // 4 | // This defines the TensorFlow library matching a subset of the C API methods as defined in 5 | // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.h 6 | // 7 | 8 | 'use strict'; 9 | 10 | const ffi = require('ffi'), 11 | fs = require('fs'), 12 | path = require('path'), 13 | protobuf = require('protocol-buffers'), 14 | ref = require('ref'), 15 | refArray = require('ref-array'), 16 | refStruct = require('ref-struct'); 17 | 18 | // Interop types to work with the C API. 
19 | const types = { 20 | Void: 'void', 21 | Int: 'int32', 22 | IntArray: refArray('int32'), 23 | LongLong: 'longlong', 24 | LongLongArray: refArray('longlong'), 25 | Float: 'float', 26 | FloatArray: refArray('float'), 27 | String: 'string', 28 | StringArray: refArray('string'), 29 | Size: 'size_t', 30 | SizePtr: ref.refType('size_t'), 31 | Byte: 'byte', 32 | BytePtr: ref.refType('byte'), 33 | Any: ref.refType('void') 34 | }; 35 | types.Tensor = types.Any; 36 | types.TensorArray = refArray(types.Tensor); 37 | types.Status = types.Any; 38 | types.Buffer = types.Any; 39 | types.Graph = types.Any; 40 | types.ImportGraphDefOptions = types.Any; 41 | types.Operation = types.Any; 42 | types.OperationArray = refArray(types.Operation); 43 | types.OperationValue = refStruct({ op: types.Operation, index: 'int32' }); 44 | types.OperationValueArray = refArray(types.OperationValue); 45 | types.Session = types.Any; 46 | types.SessionOptions = types.Any; 47 | 48 | // Tensor data types supported by TensorFlow. 49 | const tensorTypes = { 50 | float: 1, 51 | double: 2, 52 | int32: 3, 53 | uint8: 4, 54 | int16: 5, 55 | int8: 6, 56 | string: 7, 57 | complex64: 8, 58 | int64: 9, 59 | bool: 10, 60 | qint8: 11, 61 | quint8: 12, 62 | qint32: 13, 63 | bfloat16: 14, 64 | qint16: 15, 65 | quint16: 16, 66 | complex128: 18, 67 | half: 19, 68 | resource: 20, 69 | variant: 21, 70 | uint32: 22, 71 | uint64: 23 72 | }; 73 | 74 | // Status codes used by the TensorFlow API. 
75 | const statusCodes = { 76 | ok: 0, 77 | cancelled: 1, 78 | unknown: 2, 79 | invalidArgument: 3, 80 | deadlineExceeded: 4, 81 | notFound: 5, 82 | alreadyExists: 6, 83 | permissionDenied: 7, 84 | resourceExhausted: 8, 85 | failedPrecondition: 9, 86 | aborted: 10, 87 | outOfRange: 11, 88 | unimplemented: 12, 89 | internal: 13, 90 | unavailable: 14, 91 | dataLoss: 15, 92 | unauthenticated: 16, 93 | }; 94 | 95 | let libPath = process.env['TENSORFLOW_LIB_PATH']; 96 | if (!libPath) { 97 | libPath = path.join(__dirname, '..', '..', 'lib'); 98 | } 99 | if (!fs.existsSync(path.join(libPath, 'libtensorflow.so'))) { 100 | throw new Error(`libtensorflow.so was not found at "${libPath}"`); 101 | } 102 | if (!fs.existsSync(path.join(libPath, 'libtensorflow_framework.so'))) { 103 | throw new Error(`libtensorflow_framework.so was not found at "${libPath}"`); 104 | } 105 | 106 | // Change the TensorFlow logging level to WARNING (default is INFO, which gets pretty noisy). 107 | // 0 -> all logs 108 | // 1 -> filter out INFO 109 | // 2 -> filter out WARN 110 | // 3 -> filter out ERROR 111 | process.env['TF_CPP_MIN_LOG_LEVEL'] = process.env['TENSORFLOW_LIB_LOG_LEVEL'] || '1'; 112 | 113 | // Defines the subset of relevant TensorFlow APIs. 114 | // Each entry corresponds to an exported API signature in form of name -> [return type, arg types]. 
115 | const libApi = { 116 | // Status TF_NewStatus() 117 | TF_NewStatus: [types.Status, []], 118 | 119 | // void TF_DeleteStatus(Status) 120 | TF_DeleteStatus: [types.Void, [types.Status]], 121 | 122 | // void TF_SetStatus(Status, int code, string message) 123 | TF_SetStatus: [types.Void, [types.Status, types.Int, types.String]], 124 | 125 | // int TF_GetCode(Status) 126 | TF_GetCode: [types.Int, [types.Status]], 127 | 128 | // string TF_Message(Status) 129 | TF_Message: [types.String, [types.Status]], 130 | 131 | // Tensor TF_NewTensor(int dataType, longlong* dimLengths, int dims, void* data, size_t length, 132 | // void* dealloc, void* deallocarg) 133 | TF_NewTensor: [types.Tensor, [types.Int, types.LongLongArray, types.Int, types.Any, types.Size, 134 | types.Any, types.Any]], 135 | 136 | // void TF_DeleteTensor(Tensor) 137 | TF_DeleteTensor: [types.Void, [types.Tensor]], 138 | 139 | // int TF_TensorType(tensor) 140 | TF_TensorType: [types.Int, [types.Tensor]], 141 | 142 | // int TF_NumDims(tensor) 143 | TF_NumDims: [types.Int, [types.Tensor]], 144 | 145 | // longlong TF_Dim(tensor, int dimensionIndex) 146 | TF_Dim: [types.LongLong, [types.Tensor, types.Int]], 147 | 148 | // size_t TF_TensorByteSize(tensor) 149 | TF_TensorByteSize: [types.Size, [types.Tensor]], 150 | 151 | // void* TF_TensorData(tensor) 152 | TF_TensorData: [types.Any, [types.Tensor]], 153 | 154 | // size_t TF_StringEncodedSize(size_t len) 155 | TF_StringEncodedSize: [types.Size, [types.Size]], 156 | 157 | // size_t TF_StringEncode(char* src, size_t src_len, char* dst, size_t dst_len, status) 158 | TF_StringEncode: [types.Size, [types.Any, types.Size, types.Any, types.Size, types.Status]], 159 | 160 | // size_t TF_StringDecode(char* src, size_t src_len, char** dst, size_t* dst_len, status) 161 | TF_StringDecode: [types.Size, [types.Any, types.Size, types.Any, types.SizePtr, types.Status]], 162 | 163 | // Buffer TF_NewBufferFromString(void* data, size_t len) 164 | TF_NewBufferFromString: 
[types.Buffer, [types.Any, types.Size]],
165 | 
166 | // void TF_DeleteBuffer(Buffer)
167 | TF_DeleteBuffer: [types.Void, [types.Buffer]],
168 | 
169 | // ImportGraphDefOptions TF_NewImportGraphDefOptions()
170 | TF_NewImportGraphDefOptions: [types.ImportGraphDefOptions, []],
171 | 
172 | // void TF_DeleteImportGraphDefOptions(ImportGraphDefOptions)
173 | TF_DeleteImportGraphDefOptions: [types.Void, [types.ImportGraphDefOptions]],
174 | 
175 | // Graph TF_NewGraph()
176 | TF_NewGraph: [types.Graph, []],
177 | 
178 | // void TF_DeleteGraph(Graph)
179 | TF_DeleteGraph: [types.Void, [types.Graph]],
180 | 
181 | // void TF_GraphImportGraphDef(Graph, Buffer graph_def, ImportGraphDefOptions options, Status)
182 | TF_GraphImportGraphDef: [types.Void,
183 | [types.Graph, types.Buffer, types.ImportGraphDefOptions, types.Status]],
184 | 
185 | // Operation TF_GraphOperationByName(Graph graph, char* oper_name);
186 | TF_GraphOperationByName: [types.Operation, [types.Graph, types.String]],
187 | 
188 | // SessionOptions TF_NewSessionOptions()
189 | TF_NewSessionOptions: [types.SessionOptions, []],
190 | 
191 | // void TF_DeleteSessionOptions(SessionOptions)
192 | TF_DeleteSessionOptions: [types.Void, [types.SessionOptions]],
193 | 
194 | // Session TF_NewSession(Graph graph, SessionOptions options, Status status);
195 | TF_NewSession: [types.Session, [types.Graph, types.SessionOptions, types.Status]],
196 | 
197 | // void TF_DeleteSession(Session, Status)
198 | TF_DeleteSession: [types.Void, [types.Session, types.Status]],
199 | 
200 | // void TF_SessionRun(Session, Buffer options,
201 | // Input* input_ops, Tensor* input_values, int inputs,
202 | // Output* output_ops, Tensor* output_values, int outputs,
203 | // Operation* target_ops, int targets,
204 | // Buffer metadata, Status)
205 | TF_SessionRun: [types.Void, [types.Session, types.Buffer,
206 | types.OperationValueArray, types.TensorArray, types.Int,
207 | types.OperationValueArray, types.TensorArray, types.Int,
208 | types.OperationArray, types.Int,
209 | 
types.Buffer, types.Status]] 210 | }; 211 | 212 | const library = ffi.Library(path.join(libPath, 'libtensorflow'), libApi); 213 | library.Protos = require('./messages'); 214 | library.ApiTypes = types; 215 | library.Status = library.TF_NewStatus(); 216 | library.StatusCodes = statusCodes; 217 | library.Types = tensorTypes; 218 | 219 | // A no-op deallocator, that can be passed in when creating tensors. 220 | // The buffer allocated to hold tensors is automatically freed up within node.js. 221 | library.TensorDeallocator = ffi.Callback(types.Void, [types.Any, types.Size, types.Any], 222 | function() {}); 223 | 224 | module.exports = library; 225 | -------------------------------------------------------------------------------- /src/interop/messages.js: -------------------------------------------------------------------------------- 1 | 'use strict'; // code generated by pbf v3.1.0 2 | 3 | var DataType = exports.DataType = { 4 | "DT_INVALID": 0, 5 | "DT_FLOAT": 1, 6 | "DT_DOUBLE": 2, 7 | "DT_INT32": 3, 8 | "DT_UINT8": 4, 9 | "DT_INT16": 5, 10 | "DT_INT8": 6, 11 | "DT_STRING": 7, 12 | "DT_COMPLEX64": 8, 13 | "DT_INT64": 9, 14 | "DT_BOOL": 10, 15 | "DT_QINT8": 11, 16 | "DT_QUINT8": 12, 17 | "DT_QINT32": 13, 18 | "DT_BFLOAT16": 14, 19 | "DT_FLOAT_REF": 101, 20 | "DT_DOUBLE_REF": 102, 21 | "DT_INT32_REF": 103, 22 | "DT_UINT8_REF": 104, 23 | "DT_INT16_REF": 105, 24 | "DT_INT8_REF": 106, 25 | "DT_STRING_REF": 107, 26 | "DT_COMPLEX64_REF": 108, 27 | "DT_INT64_REF": 109, 28 | "DT_BOOL_REF": 110, 29 | "DT_QINT8_REF": 111, 30 | "DT_QUINT8_REF": 112, 31 | "DT_QINT32_REF": 113, 32 | "DT_BFLOAT16_REF": 114 33 | }; 34 | 35 | // Any ======================================== 36 | 37 | var Any = exports.Any = {}; 38 | 39 | Any.read = function (pbf, end) { 40 | return pbf.readFields(Any._readField, {type_url: "", value: null}, end); 41 | }; 42 | Any._readField = function (tag, obj, pbf) { 43 | if (tag === 1) obj.type_url = pbf.readString(); 44 | else if (tag === 2) obj.value = 
pbf.readBytes(); 45 | }; 46 | Any.write = function (obj, pbf) { 47 | if (obj.type_url) pbf.writeStringField(1, obj.type_url); 48 | if (obj.value) pbf.writeBytesField(2, obj.value); 49 | }; 50 | 51 | // TensorShape ======================================== 52 | 53 | var TensorShape = exports.TensorShape = {}; 54 | 55 | TensorShape.read = function (pbf, end) { 56 | return pbf.readFields(TensorShape._readField, {dim: [], unknown_rank: false}, end); 57 | }; 58 | TensorShape._readField = function (tag, obj, pbf) { 59 | if (tag === 2) obj.dim.push(TensorShape.Dim.read(pbf, pbf.readVarint() + pbf.pos)); 60 | else if (tag === 3) obj.unknown_rank = pbf.readBoolean(); 61 | }; 62 | TensorShape.write = function (obj, pbf) { 63 | if (obj.dim) for (var i = 0; i < obj.dim.length; i++) pbf.writeMessage(2, TensorShape.Dim.write, obj.dim[i]); 64 | if (obj.unknown_rank) pbf.writeBooleanField(3, obj.unknown_rank); 65 | }; 66 | 67 | // TensorShape.Dim ======================================== 68 | 69 | TensorShape.Dim = {}; 70 | 71 | TensorShape.Dim.read = function (pbf, end) { 72 | return pbf.readFields(TensorShape.Dim._readField, {size: 0, name: ""}, end); 73 | }; 74 | TensorShape.Dim._readField = function (tag, obj, pbf) { 75 | if (tag === 1) obj.size = pbf.readVarint(true); 76 | else if (tag === 2) obj.name = pbf.readString(); 77 | }; 78 | TensorShape.Dim.write = function (obj, pbf) { 79 | if (obj.size) pbf.writeVarintField(1, obj.size); 80 | if (obj.name) pbf.writeStringField(2, obj.name); 81 | }; 82 | 83 | // Tensor ======================================== 84 | 85 | var Tensor = exports.Tensor = {}; 86 | 87 | Tensor.read = function (pbf, end) { 88 | return pbf.readFields(Tensor._readField, {dtype: 0, tensor_shape: null, version_number: 0, tensor_content: null, float_val: [], double_val: [], int_val: [], string_val: [], scomplex_val: [], int64_val: [], bool_val: [], uint32_val: [], uint64_val: []}, end); 89 | }; 90 | Tensor._readField = function (tag, obj, pbf) { 91 | if (tag === 1) 
obj.dtype = pbf.readVarint(); 92 | else if (tag === 2) obj.tensor_shape = TensorShape.read(pbf, pbf.readVarint() + pbf.pos); 93 | else if (tag === 3) obj.version_number = pbf.readVarint(true); 94 | else if (tag === 4) obj.tensor_content = pbf.readBytes(); 95 | else if (tag === 5) pbf.readPackedFloat(obj.float_val); 96 | else if (tag === 6) pbf.readPackedDouble(obj.double_val); 97 | else if (tag === 7) pbf.readPackedVarint(obj.int_val, true); 98 | else if (tag === 8) obj.string_val.push(pbf.readBytes()); 99 | else if (tag === 9) pbf.readPackedFloat(obj.scomplex_val); 100 | else if (tag === 10) pbf.readPackedVarint(obj.int64_val, true); 101 | else if (tag === 11) pbf.readPackedBoolean(obj.bool_val); 102 | else if (tag === 16) pbf.readPackedVarint(obj.uint32_val); 103 | else if (tag === 17) pbf.readPackedVarint(obj.uint64_val); 104 | }; 105 | Tensor.write = function (obj, pbf) { 106 | if (obj.dtype) pbf.writeVarintField(1, obj.dtype); 107 | if (obj.tensor_shape) pbf.writeMessage(2, TensorShape.write, obj.tensor_shape); 108 | if (obj.version_number) pbf.writeVarintField(3, obj.version_number); 109 | if (obj.tensor_content) pbf.writeBytesField(4, obj.tensor_content); 110 | if (obj.float_val) pbf.writePackedFloat(5, obj.float_val); 111 | if (obj.double_val) pbf.writePackedDouble(6, obj.double_val); 112 | if (obj.int_val) pbf.writePackedVarint(7, obj.int_val); 113 | if (obj.string_val) for (var i = 0; i < obj.string_val.length; i++) pbf.writeBytesField(8, obj.string_val[i]); 114 | if (obj.scomplex_val) pbf.writePackedFloat(9, obj.scomplex_val); 115 | if (obj.int64_val) pbf.writePackedVarint(10, obj.int64_val); 116 | if (obj.bool_val) pbf.writePackedBoolean(11, obj.bool_val); 117 | if (obj.uint32_val) pbf.writePackedVarint(16, obj.uint32_val); 118 | if (obj.uint64_val) pbf.writePackedVarint(17, obj.uint64_val); 119 | }; 120 | 121 | // AttrValue ======================================== 122 | 123 | var AttrValue = exports.AttrValue = {}; 124 | 125 | AttrValue.read = function 
(pbf, end) { 126 | return pbf.readFields(AttrValue._readField, {list: null, value: null, s: null, i: 0, f: 0, b: false, type: 0, shape: null, tensor: null, placeholder: "", func: null}, end); 127 | }; 128 | AttrValue._readField = function (tag, obj, pbf) { 129 | if (tag === 1) obj.list = AttrValue.ListValue.read(pbf, pbf.readVarint() + pbf.pos), obj.value = "list"; 130 | else if (tag === 2) obj.s = pbf.readBytes(), obj.value = "s"; 131 | else if (tag === 3) obj.i = pbf.readVarint(true), obj.value = "i"; 132 | else if (tag === 4) obj.f = pbf.readFloat(), obj.value = "f"; 133 | else if (tag === 5) obj.b = pbf.readBoolean(), obj.value = "b"; 134 | else if (tag === 6) obj.type = pbf.readVarint(), obj.value = "type"; 135 | else if (tag === 7) obj.shape = TensorShape.read(pbf, pbf.readVarint() + pbf.pos), obj.value = "shape"; 136 | else if (tag === 8) obj.tensor = Tensor.read(pbf, pbf.readVarint() + pbf.pos), obj.value = "tensor"; 137 | else if (tag === 9) obj.placeholder = pbf.readString(), obj.value = "placeholder"; 138 | else if (tag === 10) obj.func = NameAttrList.read(pbf, pbf.readVarint() + pbf.pos), obj.value = "func"; 139 | }; 140 | AttrValue.write = function (obj, pbf) { 141 | if (obj.list) pbf.writeMessage(1, AttrValue.ListValue.write, obj.list); 142 | if (obj.s) pbf.writeBytesField(2, obj.s); 143 | if (obj.i) pbf.writeVarintField(3, obj.i); 144 | if (obj.f) pbf.writeFloatField(4, obj.f); 145 | if (obj.b) pbf.writeBooleanField(5, obj.b); 146 | if (obj.type) pbf.writeVarintField(6, obj.type); 147 | if (obj.shape) pbf.writeMessage(7, TensorShape.write, obj.shape); 148 | if (obj.tensor) pbf.writeMessage(8, Tensor.write, obj.tensor); 149 | if (obj.placeholder) pbf.writeStringField(9, obj.placeholder); 150 | if (obj.func) pbf.writeMessage(10, NameAttrList.write, obj.func); 151 | }; 152 | 153 | // AttrValue.ListValue ======================================== 154 | 155 | AttrValue.ListValue = {}; 156 | 157 | AttrValue.ListValue.read = function (pbf, end) { 158 | return 
pbf.readFields(AttrValue.ListValue._readField, {s: [], i: [], f: [], b: [], type: [], shape: [], tensor: [], func: []}, end); 159 | }; 160 | AttrValue.ListValue._readField = function (tag, obj, pbf) { 161 | if (tag === 2) obj.s.push(pbf.readBytes()); 162 | else if (tag === 3) pbf.readPackedVarint(obj.i, true); 163 | else if (tag === 4) pbf.readPackedFloat(obj.f); 164 | else if (tag === 5) pbf.readPackedBoolean(obj.b); 165 | else if (tag === 6) pbf.readPackedVarint(obj.type); 166 | else if (tag === 7) obj.shape.push(TensorShape.read(pbf, pbf.readVarint() + pbf.pos)); 167 | else if (tag === 8) obj.tensor.push(Tensor.read(pbf, pbf.readVarint() + pbf.pos)); 168 | else if (tag === 9) obj.func.push(NameAttrList.read(pbf, pbf.readVarint() + pbf.pos)); 169 | }; 170 | AttrValue.ListValue.write = function (obj, pbf) { 171 | if (obj.s) for (var i = 0; i < obj.s.length; i++) pbf.writeBytesField(2, obj.s[i]); 172 | if (obj.i) pbf.writePackedVarint(3, obj.i); 173 | if (obj.f) pbf.writePackedFloat(4, obj.f); 174 | if (obj.b) pbf.writePackedBoolean(5, obj.b); 175 | if (obj.type) pbf.writePackedVarint(6, obj.type); 176 | if (obj.shape) for (i = 0; i < obj.shape.length; i++) pbf.writeMessage(7, TensorShape.write, obj.shape[i]); 177 | if (obj.tensor) for (i = 0; i < obj.tensor.length; i++) pbf.writeMessage(8, Tensor.write, obj.tensor[i]); 178 | if (obj.func) for (i = 0; i < obj.func.length; i++) pbf.writeMessage(9, NameAttrList.write, obj.func[i]); 179 | }; 180 | 181 | // NameAttrList ======================================== 182 | 183 | var NameAttrList = exports.NameAttrList = {}; 184 | 185 | NameAttrList.read = function (pbf, end) { 186 | return pbf.readFields(NameAttrList._readField, {name: "", attr: {}}, end); 187 | }; 188 | NameAttrList._readField = function (tag, obj, pbf) { 189 | if (tag === 1) obj.name = pbf.readString(); 190 | else if (tag === 2) { var entry = NameAttrList._FieldEntry2.read(pbf, pbf.readVarint() + pbf.pos); obj.attr[entry.key] = entry.value; } 191 | }; 192 | 
NameAttrList.write = function (obj, pbf) {
    if (obj.name) pbf.writeStringField(1, obj.name);
    // emit each attr map entry as its own key/value submessage (field 2)
    if (obj.attr) for (var i in obj.attr) if (Object.prototype.hasOwnProperty.call(obj.attr, i)) pbf.writeMessage(2, NameAttrList._FieldEntry2.write, { key: i, value: obj.attr[i] });
};

// NameAttrList._FieldEntry2 ========================================

// Synthetic entry message backing NameAttrList's attr map (key -> AttrValue).
NameAttrList._FieldEntry2 = {};

NameAttrList._FieldEntry2.read = function (pbf, end) {
    return pbf.readFields(NameAttrList._FieldEntry2._readField, {key: "", value: null}, end);
};
NameAttrList._FieldEntry2._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.key = pbf.readString();
    else if (tag === 2) obj.value = AttrValue.read(pbf, pbf.readVarint() + pbf.pos);
};
NameAttrList._FieldEntry2.write = function (obj, pbf) {
    if (obj.key) pbf.writeStringField(1, obj.key);
    if (obj.value) pbf.writeMessage(2, AttrValue.write, obj.value);
};

// NodeDef ========================================

var NodeDef = exports.NodeDef = {};

// Reads a NodeDef: one node of a TensorFlow GraphDef.
NodeDef.read = function (pbf, end) {
    return pbf.readFields(NodeDef._readField, {name: "", op: "", input: [], device: "", attr: {}}, end);
};
NodeDef._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.name = pbf.readString();
    else if (tag === 2) obj.op = pbf.readString();
    else if (tag === 3) obj.input.push(pbf.readString());
    else if (tag === 4) obj.device = pbf.readString();
    // attr map entries arrive as repeated key/value submessages
    else if (tag === 5) { var entry = NodeDef._FieldEntry5.read(pbf, pbf.readVarint() + pbf.pos); obj.attr[entry.key] = entry.value; }
};
NodeDef.write = function (obj, pbf) {
    if (obj.name) pbf.writeStringField(1, obj.name);
    if (obj.op) pbf.writeStringField(2, obj.op);
    if (obj.input) for (var i = 0; i < obj.input.length; i++) pbf.writeStringField(3, obj.input[i]);
    if (obj.device) pbf.writeStringField(4,
obj.device);
    if (obj.attr) for (i in obj.attr) if (Object.prototype.hasOwnProperty.call(obj.attr, i)) pbf.writeMessage(5, NodeDef._FieldEntry5.write, { key: i, value: obj.attr[i] });
};

// NodeDef._FieldEntry5 ========================================

// Synthetic entry message backing NodeDef's attr map (key -> AttrValue).
NodeDef._FieldEntry5 = {};

NodeDef._FieldEntry5.read = function (pbf, end) {
    return pbf.readFields(NodeDef._FieldEntry5._readField, {key: "", value: null}, end);
};
NodeDef._FieldEntry5._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.key = pbf.readString();
    else if (tag === 2) obj.value = AttrValue.read(pbf, pbf.readVarint() + pbf.pos);
};
NodeDef._FieldEntry5.write = function (obj, pbf) {
    if (obj.key) pbf.writeStringField(1, obj.key);
    if (obj.value) pbf.writeMessage(2, AttrValue.write, obj.value);
};

// VersionDef ========================================

var VersionDef = exports.VersionDef = {};

// Reads a VersionDef; the `true` flags decode int32 fields as signed varints.
VersionDef.read = function (pbf, end) {
    return pbf.readFields(VersionDef._readField, {producer: 0, min_consumer: 0, bad_consumers: []}, end);
};
VersionDef._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.producer = pbf.readVarint(true);
    else if (tag === 2) obj.min_consumer = pbf.readVarint(true);
    else if (tag === 3) pbf.readPackedVarint(obj.bad_consumers, true);
};
VersionDef.write = function (obj, pbf) {
    if (obj.producer) pbf.writeVarintField(1, obj.producer);
    if (obj.min_consumer) pbf.writeVarintField(2, obj.min_consumer);
    if (obj.bad_consumers) pbf.writePackedVarint(3, obj.bad_consumers);
};

// GraphDef ========================================

var GraphDef = exports.GraphDef = {};

// Reads a GraphDef: the node list and version info of a serialized graph.
GraphDef.read = function (pbf, end) {
    return pbf.readFields(GraphDef._readField, {node: [], versions: null}, end);
};
GraphDef._readField = function (tag, obj,
pbf) {
    // note: field numbers 1 and 4 match graph.proto (2 and 3 are unused here)
    if (tag === 1) obj.node.push(NodeDef.read(pbf, pbf.readVarint() + pbf.pos));
    else if (tag === 4) obj.versions = VersionDef.read(pbf, pbf.readVarint() + pbf.pos);
};
GraphDef.write = function (obj, pbf) {
    if (obj.node) for (var i = 0; i < obj.node.length; i++) pbf.writeMessage(1, NodeDef.write, obj.node[i]);
    if (obj.versions) pbf.writeMessage(4, VersionDef.write, obj.versions);
};

// CollectionDef ========================================

var CollectionDef = exports.CollectionDef = {};

// Reads a CollectionDef; the *_list fields form a oneof, and `kind`
// records which member was present on the wire.
CollectionDef.read = function (pbf, end) {
    return pbf.readFields(CollectionDef._readField, {node_list: null, kind: null, bytes_list: null, int64_list: null, float_list: null, any_list: null}, end);
};
CollectionDef._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.node_list = CollectionDef.NodeList.read(pbf, pbf.readVarint() + pbf.pos), obj.kind = "node_list";
    else if (tag === 2) obj.bytes_list = CollectionDef.BytesList.read(pbf, pbf.readVarint() + pbf.pos), obj.kind = "bytes_list";
    else if (tag === 3) obj.int64_list = CollectionDef.Int64List.read(pbf, pbf.readVarint() + pbf.pos), obj.kind = "int64_list";
    else if (tag === 4) obj.float_list = CollectionDef.FloatList.read(pbf, pbf.readVarint() + pbf.pos), obj.kind = "float_list";
    else if (tag === 5) obj.any_list = CollectionDef.AnyList.read(pbf, pbf.readVarint() + pbf.pos), obj.kind = "any_list";
};
CollectionDef.write = function (obj, pbf) {
    if (obj.node_list) pbf.writeMessage(1, CollectionDef.NodeList.write, obj.node_list);
    if (obj.bytes_list) pbf.writeMessage(2, CollectionDef.BytesList.write, obj.bytes_list);
    if (obj.int64_list) pbf.writeMessage(3, CollectionDef.Int64List.write, obj.int64_list);
    if (obj.float_list) pbf.writeMessage(4, CollectionDef.FloatList.write, obj.float_list);
    if (obj.any_list) pbf.writeMessage(5, CollectionDef.AnyList.write,
obj.any_list);
};

// CollectionDef.NodeList ========================================

// List of graph node names.
CollectionDef.NodeList = {};

CollectionDef.NodeList.read = function (pbf, end) {
    return pbf.readFields(CollectionDef.NodeList._readField, {value: []}, end);
};
CollectionDef.NodeList._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.value.push(pbf.readString());
};
CollectionDef.NodeList.write = function (obj, pbf) {
    if (obj.value) for (var i = 0; i < obj.value.length; i++) pbf.writeStringField(1, obj.value[i]);
};

// CollectionDef.BytesList ========================================

// List of raw byte buffers.
CollectionDef.BytesList = {};

CollectionDef.BytesList.read = function (pbf, end) {
    return pbf.readFields(CollectionDef.BytesList._readField, {value: []}, end);
};
CollectionDef.BytesList._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.value.push(pbf.readBytes());
};
CollectionDef.BytesList.write = function (obj, pbf) {
    if (obj.value) for (var i = 0; i < obj.value.length; i++) pbf.writeBytesField(1, obj.value[i]);
};

// CollectionDef.Int64List ========================================

// Packed list of int64 values (decoded as signed varints).
CollectionDef.Int64List = {};

CollectionDef.Int64List.read = function (pbf, end) {
    return pbf.readFields(CollectionDef.Int64List._readField, {value: []}, end);
};
CollectionDef.Int64List._readField = function (tag, obj, pbf) {
    if (tag === 1) pbf.readPackedVarint(obj.value, true);
};
CollectionDef.Int64List.write = function (obj, pbf) {
    if (obj.value) pbf.writePackedVarint(1, obj.value);
};

// CollectionDef.FloatList ========================================

// Packed list of 32-bit floats.
CollectionDef.FloatList = {};

CollectionDef.FloatList.read = function (pbf, end) {
    return pbf.readFields(CollectionDef.FloatList._readField, {value: []},
end);
};
CollectionDef.FloatList._readField = function (tag, obj, pbf) {
    if (tag === 1) pbf.readPackedFloat(obj.value);
};
CollectionDef.FloatList.write = function (obj, pbf) {
    if (obj.value) pbf.writePackedFloat(1, obj.value);
};

// CollectionDef.AnyList ========================================

// List of Any messages (arbitrary serialized payloads).
CollectionDef.AnyList = {};

CollectionDef.AnyList.read = function (pbf, end) {
    return pbf.readFields(CollectionDef.AnyList._readField, {value: []}, end);
};
CollectionDef.AnyList._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.value.push(Any.read(pbf, pbf.readVarint() + pbf.pos));
};
CollectionDef.AnyList.write = function (obj, pbf) {
    if (obj.value) for (var i = 0; i < obj.value.length; i++) pbf.writeMessage(1, Any.write, obj.value[i]);
};

// SaverDef ========================================

var SaverDef = exports.SaverDef = {};

// Reads a SaverDef: configuration of the checkpoint saver.
SaverDef.read = function (pbf, end) {
    return pbf.readFields(SaverDef._readField, {filename_tensor_name: "", save_tensor_name: "", restore_op_name: "", max_to_keep: 0, sharded: false, keep_checkpoint_every_n_hours: 0, version: 0}, end);
};
SaverDef._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.filename_tensor_name = pbf.readString();
    else if (tag === 2) obj.save_tensor_name = pbf.readString();
    else if (tag === 3) obj.restore_op_name = pbf.readString();
    else if (tag === 4) obj.max_to_keep = pbf.readVarint(true);
    else if (tag === 5) obj.sharded = pbf.readBoolean();
    else if (tag === 6) obj.keep_checkpoint_every_n_hours = pbf.readFloat();
    // version is a CheckpointFormatVersion enum value
    else if (tag === 7) obj.version = pbf.readVarint();
};
SaverDef.write = function (obj, pbf) {
    if (obj.filename_tensor_name) pbf.writeStringField(1, obj.filename_tensor_name);
    if (obj.save_tensor_name) pbf.writeStringField(2, obj.save_tensor_name);
if (obj.restore_op_name) pbf.writeStringField(3, obj.restore_op_name);
    if (obj.max_to_keep) pbf.writeVarintField(4, obj.max_to_keep);
    if (obj.sharded) pbf.writeBooleanField(5, obj.sharded);
    if (obj.keep_checkpoint_every_n_hours) pbf.writeFloatField(6, obj.keep_checkpoint_every_n_hours);
    if (obj.version) pbf.writeVarintField(7, obj.version);
};

// Enum values for SaverDef's `version` field.
SaverDef.CheckpointFormatVersion = {
    "LEGACY": 0,
    "V1": 1,
    "V2": 2
};

// TensorInfo ========================================

var TensorInfo = exports.TensorInfo = {};

// Reads a TensorInfo; `name` and `coo_sparse` form a oneof, and
// `encoding` records which of the two was present on the wire.
TensorInfo.read = function (pbf, end) {
    return pbf.readFields(TensorInfo._readField, {name: "", encoding: null, coo_sparse: null, dtype: 0, tensor_shape: null}, end);
};
TensorInfo._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.name = pbf.readString(), obj.encoding = "name";
    else if (tag === 4) obj.coo_sparse = TensorInfo.CooSparse.read(pbf, pbf.readVarint() + pbf.pos), obj.encoding = "coo_sparse";
    else if (tag === 2) obj.dtype = pbf.readVarint();
    else if (tag === 3) obj.tensor_shape = TensorShape.read(pbf, pbf.readVarint() + pbf.pos);
};
TensorInfo.write = function (obj, pbf) {
    if (obj.name) pbf.writeStringField(1, obj.name);
    if (obj.coo_sparse) pbf.writeMessage(4, TensorInfo.CooSparse.write, obj.coo_sparse);
    if (obj.dtype) pbf.writeVarintField(2, obj.dtype);
    if (obj.tensor_shape) pbf.writeMessage(3, TensorShape.write, obj.tensor_shape);
};

// TensorInfo.CooSparse ========================================

// Tensor-name triple describing a COO-encoded sparse tensor.
TensorInfo.CooSparse = {};

TensorInfo.CooSparse.read = function (pbf, end) {
    return pbf.readFields(TensorInfo.CooSparse._readField, {values_tensor_name: "", indices_tensor_name: "", dense_shape_tensor_name: ""}, end);
};
TensorInfo.CooSparse._readField = function (tag, obj, pbf) {
    if (tag === 1)
obj.values_tensor_name = pbf.readString();
    else if (tag === 2) obj.indices_tensor_name = pbf.readString();
    else if (tag === 3) obj.dense_shape_tensor_name = pbf.readString();
};
TensorInfo.CooSparse.write = function (obj, pbf) {
    if (obj.values_tensor_name) pbf.writeStringField(1, obj.values_tensor_name);
    if (obj.indices_tensor_name) pbf.writeStringField(2, obj.indices_tensor_name);
    if (obj.dense_shape_tensor_name) pbf.writeStringField(3, obj.dense_shape_tensor_name);
};

// SignatureDef ========================================

var SignatureDef = exports.SignatureDef = {};

// Reads a SignatureDef: named input/output TensorInfo maps plus a method name.
SignatureDef.read = function (pbf, end) {
    return pbf.readFields(SignatureDef._readField, {inputs: {}, outputs: {}, method_name: ""}, end);
};
SignatureDef._readField = function (tag, obj, pbf) {
    if (tag === 1) { var entry = SignatureDef._FieldEntry1.read(pbf, pbf.readVarint() + pbf.pos); obj.inputs[entry.key] = entry.value; }
    else if (tag === 2) { entry = SignatureDef._FieldEntry2.read(pbf, pbf.readVarint() + pbf.pos); obj.outputs[entry.key] = entry.value; }
    else if (tag === 3) obj.method_name = pbf.readString();
};
SignatureDef.write = function (obj, pbf) {
    if (obj.inputs) for (var i in obj.inputs) if (Object.prototype.hasOwnProperty.call(obj.inputs, i)) pbf.writeMessage(1, SignatureDef._FieldEntry1.write, { key: i, value: obj.inputs[i] });
    if (obj.outputs) for (i in obj.outputs) if (Object.prototype.hasOwnProperty.call(obj.outputs, i)) pbf.writeMessage(2, SignatureDef._FieldEntry2.write, { key: i, value: obj.outputs[i] });
    if (obj.method_name) pbf.writeStringField(3, obj.method_name);
};

// SignatureDef._FieldEntry1 ========================================

// Synthetic entry message backing SignatureDef's inputs map (key -> TensorInfo).
SignatureDef._FieldEntry1 = {};

SignatureDef._FieldEntry1.read = function (pbf, end) {
    return pbf.readFields(SignatureDef._FieldEntry1._readField,
{key: "", value: null}, end);
};
SignatureDef._FieldEntry1._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.key = pbf.readString();
    else if (tag === 2) obj.value = TensorInfo.read(pbf, pbf.readVarint() + pbf.pos);
};
SignatureDef._FieldEntry1.write = function (obj, pbf) {
    if (obj.key) pbf.writeStringField(1, obj.key);
    if (obj.value) pbf.writeMessage(2, TensorInfo.write, obj.value);
};

// SignatureDef._FieldEntry2 ========================================

// Synthetic entry message backing SignatureDef's outputs map (key -> TensorInfo).
SignatureDef._FieldEntry2 = {};

SignatureDef._FieldEntry2.read = function (pbf, end) {
    return pbf.readFields(SignatureDef._FieldEntry2._readField, {key: "", value: null}, end);
};
SignatureDef._FieldEntry2._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.key = pbf.readString();
    else if (tag === 2) obj.value = TensorInfo.read(pbf, pbf.readVarint() + pbf.pos);
};
SignatureDef._FieldEntry2.write = function (obj, pbf) {
    if (obj.key) pbf.writeStringField(1, obj.key);
    if (obj.value) pbf.writeMessage(2, TensorInfo.write, obj.value);
};

// AssetFileDef ========================================

var AssetFileDef = exports.AssetFileDef = {};

// Reads an AssetFileDef: a tensor reference plus the asset's filename.
AssetFileDef.read = function (pbf, end) {
    return pbf.readFields(AssetFileDef._readField, {tensor_info: null, filename: ""}, end);
};
AssetFileDef._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.tensor_info = TensorInfo.read(pbf, pbf.readVarint() + pbf.pos);
    else if (tag === 2) obj.filename = pbf.readString();
};
AssetFileDef.write = function (obj, pbf) {
    if (obj.tensor_info) pbf.writeMessage(1, TensorInfo.write, obj.tensor_info);
    if (obj.filename) pbf.writeStringField(2, obj.filename);
};

// OpDef ========================================

var OpDef = exports.OpDef = {};

// Reads an OpDef: the full schema of one TensorFlow op.
OpDef.read =
function (pbf, end) {
    return pbf.readFields(OpDef._readField, {name: "", input_arg: [], output_arg: [], attr: [], deprecation: null, summary: "", description: "", is_commutative: false, is_aggregate: false, is_stateful: false, allows_uninitialized_input: false}, end);
};
// Note: tag numbers are non-sequential (8, 18, 16, 17, 19), matching op_def.proto.
OpDef._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.name = pbf.readString();
    else if (tag === 2) obj.input_arg.push(OpDef.ArgDef.read(pbf, pbf.readVarint() + pbf.pos));
    else if (tag === 3) obj.output_arg.push(OpDef.ArgDef.read(pbf, pbf.readVarint() + pbf.pos));
    else if (tag === 4) obj.attr.push(OpDef.AttrDef.read(pbf, pbf.readVarint() + pbf.pos));
    else if (tag === 8) obj.deprecation = OpDef.OpDeprecation.read(pbf, pbf.readVarint() + pbf.pos);
    else if (tag === 5) obj.summary = pbf.readString();
    else if (tag === 6) obj.description = pbf.readString();
    else if (tag === 18) obj.is_commutative = pbf.readBoolean();
    else if (tag === 16) obj.is_aggregate = pbf.readBoolean();
    else if (tag === 17) obj.is_stateful = pbf.readBoolean();
    else if (tag === 19) obj.allows_uninitialized_input = pbf.readBoolean();
};
OpDef.write = function (obj, pbf) {
    if (obj.name) pbf.writeStringField(1, obj.name);
    if (obj.input_arg) for (var i = 0; i < obj.input_arg.length; i++) pbf.writeMessage(2, OpDef.ArgDef.write, obj.input_arg[i]);
    if (obj.output_arg) for (i = 0; i < obj.output_arg.length; i++) pbf.writeMessage(3, OpDef.ArgDef.write, obj.output_arg[i]);
    if (obj.attr) for (i = 0; i < obj.attr.length; i++) pbf.writeMessage(4, OpDef.AttrDef.write, obj.attr[i]);
    if (obj.deprecation) pbf.writeMessage(8, OpDef.OpDeprecation.write, obj.deprecation);
    if (obj.summary) pbf.writeStringField(5, obj.summary);
    if (obj.description) pbf.writeStringField(6, obj.description);
    if (obj.is_commutative) pbf.writeBooleanField(18, obj.is_commutative);
    if (obj.is_aggregate)
pbf.writeBooleanField(16, obj.is_aggregate);
    if (obj.is_stateful) pbf.writeBooleanField(17, obj.is_stateful);
    if (obj.allows_uninitialized_input) pbf.writeBooleanField(19, obj.allows_uninitialized_input);
};

// OpDef.ArgDef ========================================

// Schema of a single op input or output argument.
OpDef.ArgDef = {};

OpDef.ArgDef.read = function (pbf, end) {
    return pbf.readFields(OpDef.ArgDef._readField, {name: "", description: "", type: 0, type_attr: "", number_attr: "", type_list_attr: "", is_ref: false}, end);
};
OpDef.ArgDef._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.name = pbf.readString();
    else if (tag === 2) obj.description = pbf.readString();
    else if (tag === 3) obj.type = pbf.readVarint();
    else if (tag === 4) obj.type_attr = pbf.readString();
    else if (tag === 5) obj.number_attr = pbf.readString();
    else if (tag === 6) obj.type_list_attr = pbf.readString();
    else if (tag === 16) obj.is_ref = pbf.readBoolean();
};
OpDef.ArgDef.write = function (obj, pbf) {
    if (obj.name) pbf.writeStringField(1, obj.name);
    if (obj.description) pbf.writeStringField(2, obj.description);
    if (obj.type) pbf.writeVarintField(3, obj.type);
    if (obj.type_attr) pbf.writeStringField(4, obj.type_attr);
    if (obj.number_attr) pbf.writeStringField(5, obj.number_attr);
    if (obj.type_list_attr) pbf.writeStringField(6, obj.type_list_attr);
    if (obj.is_ref) pbf.writeBooleanField(16, obj.is_ref);
};

// OpDef.AttrDef ========================================

// Schema of a single op attribute (type, default, constraints).
OpDef.AttrDef = {};

OpDef.AttrDef.read = function (pbf, end) {
    return pbf.readFields(OpDef.AttrDef._readField, {name: "", type: "", default_value: null, description: "", has_minimum: false, minimum: 0, allowed_values: null}, end);
};
OpDef.AttrDef._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.name = pbf.readString();
else if (tag === 2) obj.type = pbf.readString();
    else if (tag === 3) obj.default_value = AttrValue.read(pbf, pbf.readVarint() + pbf.pos);
    else if (tag === 4) obj.description = pbf.readString();
    else if (tag === 5) obj.has_minimum = pbf.readBoolean();
    else if (tag === 6) obj.minimum = pbf.readVarint(true);
    else if (tag === 7) obj.allowed_values = AttrValue.read(pbf, pbf.readVarint() + pbf.pos);
};
OpDef.AttrDef.write = function (obj, pbf) {
    if (obj.name) pbf.writeStringField(1, obj.name);
    if (obj.type) pbf.writeStringField(2, obj.type);
    if (obj.default_value) pbf.writeMessage(3, AttrValue.write, obj.default_value);
    if (obj.description) pbf.writeStringField(4, obj.description);
    if (obj.has_minimum) pbf.writeBooleanField(5, obj.has_minimum);
    if (obj.minimum) pbf.writeVarintField(6, obj.minimum);
    if (obj.allowed_values) pbf.writeMessage(7, AttrValue.write, obj.allowed_values);
};

// OpDef.OpDeprecation ========================================

// Version at which an op was deprecated, plus an explanation string.
OpDef.OpDeprecation = {};

OpDef.OpDeprecation.read = function (pbf, end) {
    return pbf.readFields(OpDef.OpDeprecation._readField, {version: 0, explanation: ""}, end);
};
OpDef.OpDeprecation._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.version = pbf.readVarint(true);
    else if (tag === 2) obj.explanation = pbf.readString();
};
OpDef.OpDeprecation.write = function (obj, pbf) {
    if (obj.version) pbf.writeVarintField(1, obj.version);
    if (obj.explanation) pbf.writeStringField(2, obj.explanation);
};

// OpList ========================================

var OpList = exports.OpList = {};

// Reads an OpList: a flat collection of OpDef messages.
OpList.read = function (pbf, end) {
    return pbf.readFields(OpList._readField, {op: []}, end);
};
OpList._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.op.push(OpDef.read(pbf, pbf.readVarint() +
pbf.pos));
};
OpList.write = function (obj, pbf) {
    if (obj.op) for (var i = 0; i < obj.op.length; i++) pbf.writeMessage(1, OpDef.write, obj.op[i]);
};

// MetaGraphDef ========================================

var MetaGraphDef = exports.MetaGraphDef = {};

// Reads a MetaGraphDef: graph, saver, collection and signature data for a saved model.
MetaGraphDef.read = function (pbf, end) {
    return pbf.readFields(MetaGraphDef._readField, {meta_info_def: null, graph_def: null, saver_def: null, collection_def: {}, signature_def: {}, asset_file_def: []}, end);
};
MetaGraphDef._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.meta_info_def = MetaGraphDef.MetaInfoDef.read(pbf, pbf.readVarint() + pbf.pos);
    else if (tag === 2) obj.graph_def = GraphDef.read(pbf, pbf.readVarint() + pbf.pos);
    else if (tag === 3) obj.saver_def = SaverDef.read(pbf, pbf.readVarint() + pbf.pos);
    // map fields 4 and 5 arrive as repeated key/value entry submessages
    else if (tag === 4) { var entry = MetaGraphDef._FieldEntry4.read(pbf, pbf.readVarint() + pbf.pos); obj.collection_def[entry.key] = entry.value; }
    else if (tag === 5) { entry = MetaGraphDef._FieldEntry5.read(pbf, pbf.readVarint() + pbf.pos); obj.signature_def[entry.key] = entry.value; }
    else if (tag === 6) obj.asset_file_def.push(AssetFileDef.read(pbf, pbf.readVarint() + pbf.pos));
};
MetaGraphDef.write = function (obj, pbf) {
    if (obj.meta_info_def) pbf.writeMessage(1, MetaGraphDef.MetaInfoDef.write, obj.meta_info_def);
    if (obj.graph_def) pbf.writeMessage(2, GraphDef.write, obj.graph_def);
    if (obj.saver_def) pbf.writeMessage(3, SaverDef.write, obj.saver_def);
    if (obj.collection_def) for (var i in obj.collection_def) if (Object.prototype.hasOwnProperty.call(obj.collection_def, i)) pbf.writeMessage(4, MetaGraphDef._FieldEntry4.write, { key: i, value: obj.collection_def[i] });
    if (obj.signature_def) for (i in obj.signature_def) if (Object.prototype.hasOwnProperty.call(obj.signature_def, i)) pbf.writeMessage(5, MetaGraphDef._FieldEntry5.write,
{ key: i, value: obj.signature_def[i] });
    if (obj.asset_file_def) for (i = 0; i < obj.asset_file_def.length; i++) pbf.writeMessage(6, AssetFileDef.write, obj.asset_file_def[i]);
};

// MetaGraphDef.MetaInfoDef ========================================

// Version and op metadata for the enclosing MetaGraphDef.
MetaGraphDef.MetaInfoDef = {};

MetaGraphDef.MetaInfoDef.read = function (pbf, end) {
    return pbf.readFields(MetaGraphDef.MetaInfoDef._readField, {meta_graph_version: "", stripped_op_list: null, any_info: null, tags: [], tensorflow_version: "", tensorflow_git_version: ""}, end);
};
MetaGraphDef.MetaInfoDef._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.meta_graph_version = pbf.readString();
    else if (tag === 2) obj.stripped_op_list = OpList.read(pbf, pbf.readVarint() + pbf.pos);
    else if (tag === 3) obj.any_info = Any.read(pbf, pbf.readVarint() + pbf.pos);
    else if (tag === 4) obj.tags.push(pbf.readString());
    else if (tag === 5) obj.tensorflow_version = pbf.readString();
    else if (tag === 6) obj.tensorflow_git_version = pbf.readString();
};
MetaGraphDef.MetaInfoDef.write = function (obj, pbf) {
    if (obj.meta_graph_version) pbf.writeStringField(1, obj.meta_graph_version);
    if (obj.stripped_op_list) pbf.writeMessage(2, OpList.write, obj.stripped_op_list);
    if (obj.any_info) pbf.writeMessage(3, Any.write, obj.any_info);
    if (obj.tags) for (var i = 0; i < obj.tags.length; i++) pbf.writeStringField(4, obj.tags[i]);
    if (obj.tensorflow_version) pbf.writeStringField(5, obj.tensorflow_version);
    if (obj.tensorflow_git_version) pbf.writeStringField(6, obj.tensorflow_git_version);
};

// MetaGraphDef._FieldEntry4 ========================================

// Synthetic entry message backing MetaGraphDef's collection_def map (key -> CollectionDef).
MetaGraphDef._FieldEntry4 = {};

MetaGraphDef._FieldEntry4.read = function (pbf, end) {
    return pbf.readFields(MetaGraphDef._FieldEntry4._readField, {key: "", value: null}, end);
};
MetaGraphDef._FieldEntry4._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.key = pbf.readString();
    else if (tag === 2) obj.value = CollectionDef.read(pbf, pbf.readVarint() + pbf.pos);
};
MetaGraphDef._FieldEntry4.write = function (obj, pbf) {
    if (obj.key) pbf.writeStringField(1, obj.key);
    if (obj.value) pbf.writeMessage(2, CollectionDef.write, obj.value);
};

// MetaGraphDef._FieldEntry5 ========================================

// Synthetic entry message backing MetaGraphDef's signature_def map (key -> SignatureDef).
MetaGraphDef._FieldEntry5 = {};

MetaGraphDef._FieldEntry5.read = function (pbf, end) {
    return pbf.readFields(MetaGraphDef._FieldEntry5._readField, {key: "", value: null}, end);
};
MetaGraphDef._FieldEntry5._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.key = pbf.readString();
    else if (tag === 2) obj.value = SignatureDef.read(pbf, pbf.readVarint() + pbf.pos);
};
MetaGraphDef._FieldEntry5.write = function (obj, pbf) {
    if (obj.key) pbf.writeStringField(1, obj.key);
    if (obj.value) pbf.writeMessage(2, SignatureDef.write, obj.value);
};

// SavedModel ========================================

var SavedModel = exports.SavedModel = {};

// Reads a SavedModel: the top-level message of a saved-model protobuf file.
SavedModel.read = function (pbf, end) {
    return pbf.readFields(SavedModel._readField, {saved_model_schema_version: 0, meta_graphs: []}, end);
};
SavedModel._readField = function (tag, obj, pbf) {
    if (tag === 1) obj.saved_model_schema_version = pbf.readVarint(true);
    else if (tag === 2) obj.meta_graphs.push(MetaGraphDef.read(pbf, pbf.readVarint() + pbf.pos));
};
SavedModel.write = function (obj, pbf) {
    if (obj.saved_model_schema_version) pbf.writeVarintField(1, obj.saved_model_schema_version);
    if (obj.meta_graphs) for (var i = 0; i < obj.meta_graphs.length; i++) pbf.writeMessage(2, MetaGraphDef.write, obj.meta_graphs[i]);
};
-------------------------------------------------------------------------------- /src/interop/messages.proto: --------------------------------------------------------------------------------
// messages.js.proto
// Definition of various TensorFlow protobuf messages for use with the TensorFlow API.
//
// Assembled from these relevant proto sources:
// https://github.com/google/protobuf/blob/master/src/google/protobuf/any.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/types.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/tensor.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/tensor_shape.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/attr_value.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/versions.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/node_def.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/op_def.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/saver.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/meta_graph.proto
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/saved_model.proto
//
// These definitions are used to produce messages.js using the following command:
// ../../node_modules/pbf/bin/pbf messages.proto > messages.js
//

syntax = "proto3";
package tensorflow;

// Mirrors google.protobuf.Any: a serialized payload plus its type URL.
message Any {
  string type_url = 1;
  bytes value = 2;
}

enum DataType {
  // Not a legal value for DataType. Used to indicate a DataType field
  // has not been set.
  DT_INVALID = 0;

  // Data types that all computation devices are expected to be
  // capable to support.
  DT_FLOAT = 1;
  DT_DOUBLE = 2;
  DT_INT32 = 3;
  DT_UINT8 = 4;
  DT_INT16 = 5;
  DT_INT8 = 6;
  DT_STRING = 7;
  DT_COMPLEX64 = 8; // Single-precision complex
  DT_INT64 = 9;
  DT_BOOL = 10;
  DT_QINT8 = 11; // Quantized int8
  DT_QUINT8 = 12; // Quantized uint8
  DT_QINT32 = 13; // Quantized int32
  DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops.

  // Do not use! These are only for parameters. Every enum above
  // should have a corresponding value below (verified by types_test).
  DT_FLOAT_REF = 101;
  DT_DOUBLE_REF = 102;
  DT_INT32_REF = 103;
  DT_UINT8_REF = 104;
  DT_INT16_REF = 105;
  DT_INT8_REF = 106;
  DT_STRING_REF = 107;
  DT_COMPLEX64_REF = 108;
  DT_INT64_REF = 109;
  DT_BOOL_REF = 110;
  DT_QINT8_REF = 111;
  DT_QUINT8_REF = 112;
  DT_QINT32_REF = 113;
  DT_BFLOAT16_REF = 114;
}

message TensorShape {
  // One dimension of the tensor.
  message Dim {
    // Size of the tensor in that dimension.
    int64 size = 1;

    // Optional name of the tensor dimension.
    string name = 2;
  }

  // Dimensions of the tensor, such as {"input", 30}, {"output", 40} for a 30 x
  // 40 2D tensor. The names are optional.
  //
  // The order of entries in "dim" matters: It indicates the layout of the
  // values in the tensor in-memory representation.
  //
  // The first entry in "dim" is the outermost dimension used to layout the
  // values, the last entry is the innermost dimension. This matches the
  // in-memory layout of RowMajor Eigen tensors.
  repeated Dim dim = 2;

  bool unknown_rank = 3;
}

message Tensor {
  DataType dtype = 1;

  // Shape of the tensor. TODO(touts): sort out the 0-rank issues.
  TensorShape tensor_shape = 2;

  // Only one of the representations below is set, one of "tensor_contents" and
  // the "xxx_val" attributes. We are not using oneof because as oneofs cannot
  // contain repeated fields it would require another extra set of messages.

  // Version number.
  //
  // In version 0, if the "repeated xxx" representations contain only one
  // element, that element is repeated to fill the shape. This makes it easy
  // to represent a constant Tensor with a single value.
  int32 version_number = 3;

  // Serialized content from TensorBase::Serialize() This representation can be
  // used for all tensor types.
  bytes tensor_content = 4;

  // Type specific representations that make it easy to create tensor protos in
  // all languages. Only the representation corresponding to "dtype" can
  // be set. The values hold the flattened representation of the tensor in
  // row major order.

  // DT_FLOAT.
  repeated float float_val = 5 [packed = true];

  // DT_DOUBLE.
  repeated double double_val = 6 [packed = true];

  // DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
  repeated int32 int_val = 7 [packed = true];

  // DT_STRING
  repeated bytes string_val = 8;

  // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real
  // and imaginary parts of i-th single precision complex.
134 | repeated float scomplex_val = 9 [packed = true]; 135 | 136 | // DT_INT64 137 | repeated int64 int64_val = 10 [packed = true]; 138 | 139 | // DT_BOOL 140 | repeated bool bool_val = 11 [packed = true]; 141 | 142 | // DT_UINT32 143 | repeated uint32 uint32_val = 16 [packed = true]; 144 | 145 | // DT_UINT64 146 | repeated uint64 uint64_val = 17 [packed = true]; 147 | } 148 | 149 | message AttrValue { 150 | message ListValue { 151 | repeated bytes s = 2; 152 | repeated int64 i = 3 [packed = true]; 153 | repeated float f = 4 [packed = true]; 154 | repeated bool b = 5 [packed = true]; 155 | repeated DataType type = 6 [packed = true]; 156 | repeated TensorShape shape = 7; 157 | repeated Tensor tensor = 8; 158 | repeated NameAttrList func = 9; 159 | } 160 | 161 | oneof value { 162 | ListValue list = 1; 163 | bytes s = 2; 164 | int64 i = 3; 165 | float f = 4; 166 | bool b = 5; 167 | DataType type = 6; 168 | TensorShape shape = 7; 169 | Tensor tensor = 8; 170 | string placeholder = 9; 171 | NameAttrList func = 10; 172 | } 173 | } 174 | 175 | message NameAttrList { 176 | string name = 1; 177 | map attr = 2; 178 | } 179 | 180 | message NodeDef { 181 | string name = 1; 182 | string op = 2; 183 | repeated string input = 3; 184 | string device = 4; 185 | map attr = 5; 186 | } 187 | 188 | message VersionDef { 189 | int32 producer = 1; 190 | int32 min_consumer = 2; 191 | repeated int32 bad_consumers = 3; 192 | } 193 | 194 | message GraphDef { 195 | repeated NodeDef node = 1; 196 | VersionDef versions = 4; 197 | } 198 | 199 | message CollectionDef { 200 | message NodeList { 201 | repeated string value = 1; 202 | } 203 | message BytesList { 204 | repeated bytes value = 1; 205 | } 206 | message Int64List { 207 | repeated int64 value = 1 [packed = true]; 208 | } 209 | message FloatList { 210 | repeated float value = 1 [packed = true]; 211 | } 212 | message AnyList { 213 | repeated Any value = 1; 214 | } 215 | 216 | oneof kind { 217 | NodeList node_list = 1; 218 | BytesList 
bytes_list = 2; 219 | Int64List int64_list = 3; 220 | FloatList float_list = 4; 221 | AnyList any_list = 5; 222 | } 223 | } 224 | 225 | message SaverDef { 226 | string filename_tensor_name = 1; 227 | string save_tensor_name = 2; 228 | string restore_op_name = 3; 229 | int32 max_to_keep = 4; 230 | bool sharded = 5; 231 | float keep_checkpoint_every_n_hours = 6; 232 | 233 | enum CheckpointFormatVersion { 234 | LEGACY = 0; 235 | V1 = 1; 236 | V2 = 2; 237 | } 238 | CheckpointFormatVersion version = 7; 239 | } 240 | 241 | message TensorInfo { 242 | message CooSparse { 243 | string values_tensor_name = 1; 244 | string indices_tensor_name = 2; 245 | string dense_shape_tensor_name = 3; 246 | } 247 | 248 | oneof encoding { 249 | string name = 1; 250 | CooSparse coo_sparse = 4; 251 | } 252 | DataType dtype = 2; 253 | TensorShape tensor_shape = 3; 254 | } 255 | 256 | message SignatureDef { 257 | map inputs = 1; 258 | map outputs = 2; 259 | string method_name = 3; 260 | } 261 | 262 | message AssetFileDef { 263 | TensorInfo tensor_info = 1; 264 | string filename = 2; 265 | } 266 | 267 | message OpDef { 268 | string name = 1; 269 | 270 | message ArgDef { 271 | string name = 1; 272 | string description = 2; 273 | DataType type = 3; 274 | string type_attr = 4; // if specified, attr must have type "type" 275 | string number_attr = 5; // if specified, attr must have type "int" 276 | string type_list_attr = 6; 277 | bool is_ref = 16; 278 | } 279 | repeated ArgDef input_arg = 2; 280 | repeated ArgDef output_arg = 3; 281 | 282 | message AttrDef { 283 | string name = 1; 284 | string type = 2; 285 | AttrValue default_value = 3; 286 | string description = 4; 287 | bool has_minimum = 5; 288 | int64 minimum = 6; 289 | AttrValue allowed_values = 7; 290 | } 291 | repeated AttrDef attr = 4; 292 | 293 | message OpDeprecation { 294 | int32 version = 1; 295 | string explanation = 2; 296 | } 297 | OpDeprecation deprecation = 8; 298 | 299 | string summary = 5; 300 | string description = 6; 301 | 
bool is_commutative = 18; 302 | bool is_aggregate = 16; // for things like add 303 | bool is_stateful = 17; // for things like variables, queue 304 | bool allows_uninitialized_input = 19; // for Assign, etc. 305 | } 306 | 307 | message OpList { 308 | repeated OpDef op = 1; 309 | } 310 | 311 | message MetaGraphDef { 312 | message MetaInfoDef { 313 | string meta_graph_version = 1; 314 | OpList stripped_op_list = 2; 315 | Any any_info = 3; 316 | repeated string tags = 4; 317 | string tensorflow_version = 5; 318 | string tensorflow_git_version = 6; 319 | } 320 | MetaInfoDef meta_info_def = 1; 321 | GraphDef graph_def = 2; 322 | SaverDef saver_def = 3; 323 | map collection_def = 4; 324 | map signature_def = 5; 325 | repeated AssetFileDef asset_file_def = 6; 326 | } 327 | 328 | message SavedModel { 329 | int64 saved_model_schema_version = 1; 330 | repeated MetaGraphDef meta_graphs = 2; 331 | } 332 | -------------------------------------------------------------------------------- /src/interop/serializers.js: -------------------------------------------------------------------------------- 1 | // serializers.js 2 | // Implements serialization logic to convert between raw C Tensor representation and script types. 
3 | // 4 | 5 | 'use strict'; 6 | 7 | const api = require('./api'), 8 | os = require('os'), 9 | ref = require('ref'); 10 | 11 | class GenericSerializer { 12 | 13 | fromBuffer(buffer, shape) { 14 | return buffer; 15 | } 16 | 17 | toBuffer(data) { 18 | if (Buffer.isBuffer(data)) { 19 | return data; 20 | } 21 | 22 | throw new Error('Unsupported Tensor data.'); 23 | } 24 | } 25 | 26 | class NumberSerializer { 27 | 28 | constructor(readFn, size) { 29 | this._readFn = Buffer.prototype[readFn]; 30 | this._size = size; 31 | } 32 | 33 | fromBuffer(buffer, shape) { 34 | if (shape.length === 0) { 35 | // Scalar tensor value 36 | return this._readFn.call(buffer, 0); 37 | } 38 | else { 39 | let array = createItemArray(shape); 40 | 41 | for (let i = 0; i < array.length; i++) { 42 | array[i] = this._readFn.call(buffer, i * this._size); 43 | } 44 | 45 | return array; 46 | } 47 | } 48 | } 49 | 50 | class Int32Serializer extends NumberSerializer { 51 | 52 | constructor() { 53 | super('readInt32' + os.endianness(), /* size */ 4) 54 | } 55 | 56 | toBuffer(data) { 57 | return api.ApiTypes.IntArray(data).buffer; 58 | } 59 | } 60 | 61 | class FloatSerializer extends NumberSerializer { 62 | 63 | constructor() { 64 | super('readFloat' + os.endianness(), /* size */ 4) 65 | } 66 | 67 | toBuffer(data) { 68 | return api.ApiTypes.FloatArray(data).buffer; 69 | } 70 | } 71 | 72 | class Int64Serializer { 73 | 74 | fromBuffer(buffer, shape) { 75 | // TODO: Handle representing int64 values in script (likely using the int64-buffer module) 76 | return buffer; 77 | } 78 | 79 | toBuffer(data) { 80 | return api.ApiTypes.LongLongArray(data).buffer; 81 | } 82 | } 83 | 84 | class StringSerializer { 85 | 86 | // TODO: Handle the case where the input is not a string, but represented as a Buffer. 87 | 88 | // Strings are encoded as an array of uint64 values containing offsets into the buffer for 89 | // each string. Each string is a 7-bit encoded length prefix followed by the bytes. 
90 | 91 | fromBuffer(buffer, shape) { 92 | let array = createItemArray(shape); 93 | let count = array.length; 94 | 95 | let header = 8 * count; 96 | for (let i = 0; i < count; i++) { 97 | let offset = ref.readUInt64LE(buffer, i * 8); 98 | let nextOffset = i === count - 1 ? buffer.size : ref.readUInt64LE(buffer, i * 8 + 8) 99 | 100 | let srcLength = nextOffset - offset; 101 | let srcBuffer = buffer.reinterpret(srcLength, offset + header); 102 | 103 | let decodedBuffer = ref.alloc(api.ApiTypes.BytePtr); 104 | let decodedLength = ref.alloc(api.ApiTypes.Size); 105 | api.TF_StringDecode(srcBuffer, srcLength, decodedBuffer, decodedLength, api.Status); 106 | 107 | decodedLength = decodedLength.deref(); 108 | decodedBuffer = decodedBuffer.deref(); 109 | array[i] = decodedBuffer.reinterpret(decodedLength, 0).toString('binary'); 110 | } 111 | 112 | return array; 113 | } 114 | 115 | toBuffer(data) { 116 | let maxLength = 0; 117 | let size = 0; 118 | data = data.map((s) => { 119 | let length = Buffer.byteLength(s, 'binary'); 120 | let encodedLength = api.TF_StringEncodedSize(length); 121 | 122 | let item = { s: s, offset: size, length: length, encodedLength: encodedLength }; 123 | size += encodedLength; 124 | 125 | maxLength = Math.max(maxLength, length); 126 | return item; 127 | }); 128 | 129 | // Add for the list of offsets (uint64 value per string) 130 | let header = 8 * data.length; 131 | size += header; 132 | 133 | // Allocate a buffer that is large enough to hold the longest string; Add 1 for trailing null. 
134 | let srcBuffer = new Buffer(maxLength + 1); 135 | 136 | let buffer = new Buffer(size); 137 | data.forEach((item, i) => { 138 | ref.writeUInt64LE(buffer, i * 8, item.offset); 139 | 140 | ref.writeCString(srcBuffer, 0, item.s, 'binary'); 141 | 142 | let destBuffer = ref.reinterpret(buffer, item.encodedLength, item.offset + header); 143 | api.TF_StringEncode(srcBuffer, item.length, destBuffer, item.encodedLength, api.Status); 144 | }); 145 | 146 | return buffer; 147 | } 148 | } 149 | 150 | 151 | function createItemArray(shape) { 152 | // The number of items is the product of the dimensions specified by the shape. 153 | let totalItems = shape.reduce(function(dim, items) { return dim * items}, 1); 154 | return new Array(totalItems); 155 | } 156 | 157 | 158 | const _genericSerializer = new GenericSerializer(); 159 | const _serializers = {}; 160 | _serializers[api.Types.int32.toString()] = new Int32Serializer(); 161 | _serializers[api.Types.int64.toString()] = new Int64Serializer(); 162 | _serializers[api.Types.float.toString()] = new FloatSerializer(); 163 | _serializers[api.Types.string.toString()] = new StringSerializer(); 164 | 165 | function createSerializer(type) { 166 | return _serializers[type.toString()] || _genericSerializer; 167 | } 168 | 169 | 170 | module.exports = { 171 | create: createSerializer 172 | }; 173 | -------------------------------------------------------------------------------- /src/session.js: -------------------------------------------------------------------------------- 1 | // session.js 2 | // Implements the Session class to wrap a TensorFlow session, and session.run functionality. 
3 | // 4 | 5 | 'use strict'; 6 | 7 | const api = require('./interop/api'), 8 | tensor = require('./tensor'); 9 | 10 | 11 | class Session { 12 | 13 | constructor(sessionHandle, graphHandle, graphOps) { 14 | this._sessionHandle = sessionHandle; 15 | this._graphHandle = graphHandle; 16 | this._graphOps = graphOps; 17 | } 18 | 19 | delete() { 20 | if (this._sessionHandle) { 21 | api.TF_DeleteSession(this._sessionHandle, api.Status); 22 | this._sessionHandle = null; 23 | } 24 | } 25 | 26 | _ensureValid() { 27 | if (!this._sessionHandle) { 28 | throw new Error('The Session instance has been closed and deleted.'); 29 | } 30 | } 31 | 32 | run(inputs, outputs, targets) { 33 | this._ensureValid(); 34 | 35 | let singleOutput = false; 36 | if (outputs && !Array.isArray(outputs)) { 37 | outputs = [outputs]; 38 | singleOutput = true; 39 | } 40 | if (targets && !Array.isArray(targets)) { 41 | targets = [targets]; 42 | } 43 | 44 | let params = createRunParameters(this._graphHandle, this._graphOps, 45 | inputs, outputs, targets); 46 | 47 | api.TF_SessionRun(this._sessionHandle, 48 | /* options */ null, 49 | params.inputOps, params.inputTensors, params.inputs, 50 | params.outputOps, params.outputTensors, params.outputs, 51 | params.targetOps, params.targets, 52 | /* metadata */ null, 53 | api.Status); 54 | 55 | if (params.inputs) { 56 | for (let i = 0; i < params.inputs; i++) { 57 | api.TF_DeleteTensor(params.inputTensors[i]); 58 | } 59 | } 60 | 61 | let code = api.TF_GetCode(api.Status); 62 | if (code !== api.StatusCodes.ok) { 63 | let message = api.TF_Message(api.Status); 64 | throw new Error(message); 65 | } 66 | 67 | let results = undefined; 68 | if (params.outputs) { 69 | results = createRunResults(outputs, params.outputTensors, singleOutput); 70 | 71 | for (let i = 0; i < params.outputs; i++) { 72 | api.TF_DeleteTensor(params.outputTensors[i]); 73 | } 74 | } 75 | 76 | return results; 77 | } 78 | } 79 | 80 | 81 | function createSession(graphHandle, graphOps) { 82 | let 
sessionOptions = api.TF_NewSessionOptions(); 83 | let sessionHandle = api.TF_NewSession(graphHandle, sessionOptions, api.Status); 84 | let code = api.TF_GetCode(api.Status); 85 | 86 | api.TF_DeleteSessionOptions(sessionOptions); 87 | 88 | if (api.TF_GetCode(api.Status) !== api.StatusCodes.ok) { 89 | let error = api.TF_Message(api.Status); 90 | throw new Error(error); 91 | } 92 | 93 | return new Session(sessionHandle, graphHandle, graphOps); 94 | } 95 | 96 | function createRunParameters(graphHandle, ops, inputs, outputs, targets) { 97 | let params = { 98 | inputOps: null, 99 | inputTensors: null, 100 | inputs: 0, 101 | outputOps: null, 102 | outputTensors: null, 103 | outputs: 0, 104 | targetOps: null, 105 | targets: 0 106 | }; 107 | 108 | if (inputs) { 109 | params.inputOps = []; 110 | params.inputTensors = []; 111 | 112 | for (let op in inputs) { 113 | let parts = op.split(':'); 114 | let name = parts[0]; 115 | let opReference = resolveOp(graphHandle, ops, name); 116 | 117 | params.inputOps.push(api.ApiTypes.OperationValue({op: opReference, index: parts[1] || 0})); 118 | params.inputTensors.push(tensor.toHandle(inputs[op])); 119 | 120 | params.inputs++; 121 | } 122 | 123 | params.inputOps = api.ApiTypes.OperationValueArray(params.inputOps); 124 | params.inputTensors = api.ApiTypes.TensorArray(params.inputTensors); 125 | } 126 | 127 | if (outputs) { 128 | params.outputs = outputs.length; 129 | params.outputOps = outputs.map((o) => { 130 | let parts = o.split(':'); 131 | let name = parts[0]; 132 | let opReference = resolveOp(graphHandle, ops, name) 133 | 134 | return api.ApiTypes.OperationValue({op: opReference, index: parts[1] || 0}); 135 | }); 136 | params.outputOps = api.ApiTypes.OperationValueArray(params.outputOps); 137 | params.outputTensors = api.ApiTypes.TensorArray(params.outputs); 138 | } 139 | 140 | if (targets) { 141 | params.targets = targets.length; 142 | params.targetOps = targets.map((name) => { 143 | return resolveOp(graphHandle, ops, name); 144 | 
}); 145 | params.targetOps = api.ApiTypes.OperationArray(params.targetOps); 146 | } 147 | 148 | return params; 149 | } 150 | 151 | function createRunResults(outputs, outputTensors, singleOutput) { 152 | if (singleOutput) { 153 | return tensor.fromHandle(outputTensors[0]); 154 | } 155 | 156 | let results = {}; 157 | outputs.forEach((name, i) => { 158 | results[name] = tensor.fromHandle(outputTensors[i]); 159 | }); 160 | 161 | return results; 162 | } 163 | 164 | function resolveOp(graphHandle, opCache, name) { 165 | let op = opCache[name]; 166 | if (op !== undefined) { 167 | return op; 168 | } 169 | 170 | op = api.TF_GraphOperationByName(graphHandle, name); 171 | if (op && !op.isNull()) { 172 | opCache[name] = op; 173 | return op; 174 | } 175 | 176 | throw new Error(`An operation with the name "${name}" was not found in the graph.`); 177 | } 178 | 179 | 180 | module.exports = { 181 | create: createSession 182 | }; 183 | -------------------------------------------------------------------------------- /src/tensor.js: -------------------------------------------------------------------------------- 1 | // tensor.js 2 | // Implements the Tensor class to represent tensor data and encapsulates marshalling logic. 
3 | // 4 | 5 | 'use strict'; 6 | 7 | const api = require('./interop/api'), 8 | serializers = require('./interop/serializers'); 9 | 10 | 11 | class Tensor { 12 | 13 | constructor(value, type, shape) { 14 | this._value = value; 15 | this._type = type; 16 | this._shape = shape; 17 | } 18 | 19 | get shape() { 20 | return this._shape; 21 | } 22 | 23 | get type() { 24 | return this._type; 25 | } 26 | 27 | get value() { 28 | return this._value; 29 | } 30 | } 31 | 32 | 33 | function createTensor(value, type, shape) { 34 | if ((value === null) || (value === undefined)) { 35 | throw new Error('A value representing the Tensor data must be specified.'); 36 | } 37 | 38 | if (value.constructor === Tensor) { 39 | return value; 40 | } 41 | 42 | if (Buffer.isBuffer(value)) { 43 | if ((type === null) || (type === undefined) || (shape === null) || (shape === undefined)) { 44 | throw new Error('The type and shape of a raw tensor data buffer must be specified.'); 45 | } 46 | 47 | return new Tensor(value, type, shape); 48 | } 49 | 50 | if ((shape === null) || (shape === undefined)) { 51 | shape = calculateShape(value); 52 | 53 | // Convert to value to a flat array 54 | if (shape.length === 0) { 55 | // Ensure the value is represented as an array, even for scalars 56 | value = [value]; 57 | } 58 | else if (shape.length > 1) { 59 | // Flatten the value, so it can be converted into a buffer containing all the values. 
60 | value = flattenList(value); 61 | } 62 | 63 | if (type === undefined) { 64 | if (value[0].constructor == Number) { 65 | type = api.Types.float; 66 | } 67 | else if (value[0].constructor == String) { 68 | type = api.Types.string; 69 | } 70 | else { 71 | throw new Error('Unsupported data type for creating a Tensor'); 72 | } 73 | } 74 | } 75 | 76 | return new Tensor(value, type, shape); 77 | } 78 | 79 | function createHandleFromTensor(value) { 80 | let tensor = createTensor(value); 81 | 82 | // Convert to a buffer containing raw byte representation of the Tensor 83 | let serializer = serializers.create(tensor.type); 84 | let data = serializer.toBuffer(tensor.value); 85 | 86 | return api.TF_NewTensor(tensor.type, 87 | api.ApiTypes.LongLongArray(tensor.shape), tensor.shape.length, 88 | data, data.length, 89 | api.TensorDeallocator, null); 90 | } 91 | 92 | function createTensorFromHandle(tensorHandle) { 93 | let shape = []; 94 | 95 | let dimensions = api.TF_NumDims(tensorHandle); 96 | for (let i = 0; i < dimensions; i++) { 97 | shape.push(api.TF_Dim(tensorHandle, i)); 98 | } 99 | 100 | // Read data into a buffer and reset the current position in the buffer to be at the start. 101 | let dataLength = api.TF_TensorByteSize(tensorHandle); 102 | let data = api.TF_TensorData(tensorHandle); 103 | data = data.reinterpret(dataLength, 0); 104 | 105 | let type = api.TF_TensorType(tensorHandle); 106 | let serializer = serializers.create(type); 107 | 108 | let value = serializer.fromBuffer(data, shape); 109 | if ((shape.length > 1) && !Buffer.isBuffer(value)) { 110 | value = reshapeList(value, shape); 111 | } 112 | 113 | return new Tensor(value, type, shape); 114 | } 115 | 116 | function calculateShape(value) { 117 | if ((value.shape !== undefined) && (value.shape !== null)) { 118 | return value.shape; 119 | } 120 | 121 | // Detect the shape by walking the arrays (to handle nested arrays). This assumes the arrays 122 | // are not jagged. 
123 | let shape = []; 124 | 125 | let element = value; 126 | while (Array.isArray(element)) { 127 | shape.push(element.length); 128 | element = element[0]; 129 | } 130 | 131 | return shape; 132 | } 133 | 134 | // Flatten the list. This is only relevant for multi-dimensional tensors. 135 | function flattenList(list) { 136 | // Make a copy, so the original tensor is not modified. 137 | list = [].concat(list); 138 | 139 | // Note that i must be checked against the length of the list each time through the loop, as the 140 | // list is modified within the iterations. 141 | for (let i = 0; i < list.length; i++) { 142 | if (Array.isArray(list[i])) { 143 | // Replace the item with the flattened version of the item (using the ... operator). 144 | // Replace with the items and backtrack 1 position 145 | list.splice(i, 1, ...list[i]); 146 | 147 | // Decrement i to look at the element again; we'll keep looking at this i index, until 148 | // the most deeply nested item has been flattened. 149 | i--; 150 | } 151 | } 152 | 153 | return list; 154 | } 155 | 156 | function reshapeList(list, shape) { 157 | // This modifies the list in place, given this is run on a temporary list; hence avoiding 158 | // copying cost. 159 | 160 | // Work from the inner-most dimension to outer-most, building up arrays of items matching 161 | // dimension length (this is essentially the inverse of the tensorToList implementation). 162 | for (let i = shape.length - 1; i > 0; i--) { 163 | let dimension = shape[i]; 164 | 165 | for (let j = 0; j < list.length; j++) { 166 | let items = list.splice(j, dimension); 167 | list.splice(j, 0, items); 168 | } 169 | } 170 | 171 | return list; 172 | } 173 | 174 | 175 | module.exports = { 176 | create: createTensor, 177 | fromHandle: createTensorFromHandle, 178 | toHandle: createHandleFromTensor 179 | }; 180 | --------------------------------------------------------------------------------