├── lib ├── Graph.js ├── SessionOptions.js ├── Session.js └── Operation.js ├── .gitignore ├── scripts ├── test.js ├── swig.sh └── install.js ├── package.json ├── LICENSE ├── index.js ├── test ├── session.js └── c-api.js ├── binding.gyp ├── README.md └── src └── tensorflow.i /lib/Graph.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | module.exports = function (bindings) { 4 | var Graph = bindings.Graph = function Graph() { 5 | this.instance = new bindings.TF_NewGraph(); 6 | } 7 | 8 | Graph.prototype = Object.create(bindings.OperationTarget); 9 | }; 10 | -------------------------------------------------------------------------------- /lib/SessionOptions.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | module.exports = function (bindings) { 4 | var SessionOptions = 5 | bindings.SessionOptions = 6 | function SessionOptions(options) { 7 | this.opts =options; 8 | this.instance = 9 | new bindings.TF_NewSessionOptions(); 10 | } 11 | }; 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | logs 2 | *.log 3 | npm-debug.log* 4 | lib-cov 5 | coverage 6 | .nyc_output 7 | .grunt 8 | .lock-wscript 9 | build/Release 10 | node_modules 11 | jspm_packages 12 | .npm 13 | .node_repl_history 14 | /build 15 | /bazel-* 16 | /vendor/* 17 | !/vendor/BUILD 18 | !/vendor/WORKSPACE 19 | !/vendor/workspace.bzl 20 | /.vscode 21 | .npmignore 22 | -------------------------------------------------------------------------------- /scripts/test.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | var fs = require('fs'), 4 | path = require('path'), 5 | mocha = require('mocha'), 6 | LIB_TEST_PATH = path.resolve(__dirname, '../test'), 7 | CONTRIB_PATH = path.resolve(__dirname, '../contrib/test'); 8 
| 9 | [LIB_TEST_PATH, CONTRIB_PATH].forEach(function (dir) { 10 | fs.readdirSync(dir) 11 | .filter(function(file) { 12 | return path.extname(file) === '.js'; 13 | }) 14 | .forEach(function(file) { 15 | mocha.addFile(path.resolve(dir, file)); 16 | }); 17 | }); 18 | 19 | mocha.run(function (failures) { 20 | process.on('exit', function () { 21 | process.exit(failures); 22 | }); 23 | }); 24 | -------------------------------------------------------------------------------- /lib/Session.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | module.exports = function (bindings) { 4 | var Session = bindings.Session = function Session(graph, options) { 5 | if (graph instanceof bindings.Graph) { 6 | this.graph = graph; 7 | } else { 8 | if (graph !== undefined && options === undefined) { 9 | this.options = graph; 10 | } 11 | 12 | this.graph = new bindings.Graph(); 13 | } 14 | 15 | if (options instanceof bindings.SessionOptions) { 16 | this.options = options; 17 | } else if (!(this.options instanceof bindings.SessionOptions)) { 18 | this.options = new bindings.SessionOptions(this.options); 19 | } 20 | 21 | this.instance = 22 | new bindings.CSession(this.graph.instance); 23 | } 24 | 25 | Session.prototype.run = function () { 26 | this.instance.Run(); 27 | } 28 | }; 29 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tensorflow", 3 | "version": "0.0.1", 4 | "description": "Tensorflow for node.js", 5 | "main": "index.js", 6 | "scripts": { 7 | "preinstall": "node scripts/install.js", 8 | "postinstall": "node-gyp build", 9 | "build": "node-gyp rebuild", 10 | "swig": "scripts/swig.sh", 11 | "test": "mocha --recursive", 12 | "debug": "lldb -n node -w -s scripts/lldbinit" 13 | }, 14 | "repository": { 15 | "type": "git", 16 | "url": "https://github.com/rchipka/node-tensorflow.git" 17 
| }, 18 | "keywords": [ 19 | "tensorflow", "tf", 20 | "machinelearning", "ml", 21 | "deeplearning", "dl", 22 | "neuralnetworks", "nn", 23 | "math", "matrix", "tensor" 24 | ], 25 | "author": "Robbie Chipka", 26 | "license": "MIT", 27 | "bugs": { 28 | "url": "https://github.com/rchipka/node-tensorflow/issues" 29 | }, 30 | "dependencies": { 31 | "nan": "^2.5.1", 32 | "node-gyp": "^3.4.0" 33 | }, 34 | "devDependencies": { 35 | "mocha": "3.2.0" 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Robbie Chipka 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /index.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | var fs = require('fs'), 4 | path = require('path'), 5 | bindings = require('bindings')('tensorflow'), 6 | CONTRIB_PATH = path.resolve(__dirname, 'contrib'); 7 | 8 | var SegfaultHandler = require('segfault-handler'); 9 | 10 | SegfaultHandler.registerHandler("crash.log"); 11 | 12 | bindings = Object.create(bindings); 13 | require('./lib/types.js')(bindings); 14 | require('./lib/Operation.js')(bindings); 15 | require('./lib/OperationDescription.js')(bindings); 16 | require('./lib/Graph.js')(bindings); 17 | require('./lib/Session.js')(bindings); 18 | require('./lib/SessionOptions.js')(bindings); 19 | 20 | 21 | // console.log(bindings); 22 | 23 | 24 | // console.log('Press any key to continue'); 25 | // 26 | // process.stdin.setRawMode(true); 27 | // process.stdin.resume(); 28 | // process.stdin.on('data', function () { 29 | var g = new bindings.Graph(), 30 | s = new bindings.Session(g); 31 | 32 | var p = s.graph.placeholder().name('test').shape([1]); 33 | console.log(p.instance); 34 | // console.log("S", bindings.TF_GraphGetTensorNumDims(s.graph.instance, p.instance)); 35 | 36 | s.instance.SetTargets([p.finish(s.graph)]); 37 | 38 | s.run(); 39 | // }); 40 | 41 | 42 | fs.readdirSync(CONTRIB_PATH) 43 | .filter(function (file) { 44 | return path.extname(file) === '.js'; 45 | }) 46 | .forEach(function (file) { 47 | require(path.resolve(CONTRIB_PATH, file))(bindings); 48 | }); 49 | 50 | module.exports = bindings; 51 | -------------------------------------------------------------------------------- /test/session.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | var tf = require('../'), 4 | assert = require('assert'); 5 | 6 | describe('Session', function () { 7 | describe('#Constructor', function () { 8 | var graph 
= new tf.Graph(), 9 | options = new tf.SessionOptions(); 10 | 11 | it('should work without arguments', function () { 12 | var session = new tf.Session(); 13 | 14 | assert.ok(session instanceof tf.Session); 15 | assert.equal(session.instance.toString(), '[object SwigProxy]'); 16 | }); 17 | 18 | it('should work with just Graph', function () { 19 | var session = new tf.Session(graph); 20 | 21 | assert.ok(session instanceof tf.Session); 22 | assert.ok(session.options instanceof tf.SessionOptions); 23 | assert.strictEqual(session.graph, graph); 24 | }); 25 | 26 | it('should work with just SessionOptions', function () { 27 | var session = new tf.Session(options); 28 | 29 | assert.ok(session instanceof tf.Session); 30 | assert.ok(session.graph instanceof tf.Graph); 31 | assert.ok(options.instance.equals(session.options.instance)); 32 | }); 33 | 34 | it('should work with both Graph and SessionOptions', function () { 35 | var session = new tf.Session(graph, options); 36 | 37 | assert.ok(session instanceof tf.Session); 38 | assert.ok(session.graph.instance.equals(graph.instance)); 39 | assert.ok(session.options.instance.equals(options.instance)); 40 | }); 41 | }); 42 | }); 43 | -------------------------------------------------------------------------------- /binding.gyp: -------------------------------------------------------------------------------- 1 | { 2 | 'targets': [{ 3 | 'target_name': 'tensorflow', 4 | 'product_extension': 'node', 5 | 'sources': [ 6 | 'src/tensorflow.cc' 7 | ], 8 | 'libraries' : [ 9 | '-lpython', 10 | '../bazel-out/local-opt/bin/tensorflow/python/_pywrap_tensorflow.so', 11 | '../bazel-out/local-opt/bin/external/protobuf/pyext/_message.so', 12 | '/Users/administrator/Downloads/tensorflow-1.0.1/tensorflow/contrib/cmake/build/protobuf/src/protobuf/libprotobuf.a' 13 | ], 14 | 'include_dirs' : [ 15 | 'src/', 16 | 'src/include', 17 | 'bazel-out/local-opt/bin/tensorflow/include', 18 | " $TARGET 36 | 37 | 38 | # Fix void ptr bug 39 | mv $TARGET $TARGET.bak 
40 | sed -E s/\(SWIG_as_voidptrptr[^,]+,\ *\)0/\\1SWIGTYPE_p_void/g $TARGET.bak > $TARGET 41 | 42 | 43 | # Use consistent string type 44 | mv $TARGET $TARGET.bak 45 | sed s/SWIGTYPE_p_string/SWIGTYPE_p_std__string/g $TARGET.bak > $TARGET 46 | 47 | 48 | # Useful error reporting 49 | mv $TARGET $TARGET.bak 50 | perl -0pe 's/(args.Length\(\) != ([0-9]+)\)\s*SWIG_exception_fail\(SWIG_ERROR, "[^\.]+)./$1 (should be $2)/smg' $TARGET.bak > $TARGET 51 | 52 | rm -f $TARGET.bak 53 | 54 | fi 55 | -------------------------------------------------------------------------------- /lib/Operation.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | var alias = require('./aliases.js'); 4 | 5 | module.exports = function (bindings) { 6 | var Operation = bindings.Operation = function Operation(op) { 7 | this.name = op.name(); 8 | this.summary = op.summary(); 9 | this.description = op.description(); 10 | 11 | var attrs = this.attrs = [], args = this.args = [], 12 | outputs = this.outputs = []; 13 | 14 | (function () { 15 | var i = 0, 16 | total = op.input_arg_size(), 17 | arg; 18 | 19 | for (; i < total; i++) { 20 | arg = op.input_arg(i); 21 | 22 | args[i] = { 23 | name: arg.name(), 24 | desc: arg.description(), 25 | type: bindings.types[arg.type()] 26 | }; 27 | } 28 | 29 | return args; 30 | })(); 31 | 32 | (function () { 33 | var i = 0, 34 | total = op.output_arg_size(), 35 | arg; 36 | 37 | for (; i < total; i++) { 38 | arg = op.output_arg(i); 39 | 40 | outputs[i] = { 41 | name: arg.name(), 42 | desc: arg.description(), 43 | type: bindings.types[arg.type()] 44 | }; 45 | } 46 | 47 | return args; 48 | })(); 49 | 50 | (function () { 51 | var i = 0, total = op.attr_size(), attr; 52 | 53 | for (; i < total; i++) { 54 | var attr = op.attr(i); 55 | 56 | attrs[i] = { 57 | name: attr.name(), 58 | desc: attr.description(), 59 | type: attr.type() 60 | }; 61 | } 62 | })(); 63 | 64 | return this; 65 | }; 66 | 67 | Operation.prototype.usage = 
function () { 68 | 69 | }; 70 | 71 | (function init() { 72 | var buffer = bindings.TF_GetAllOpList(), 73 | list = (new bindings.OpList()), 74 | total, i = 0, op, name, 75 | opTarget = bindings.OperationTarget = {}; 76 | 77 | list.ParseFromArray(buffer.data, buffer.length); 78 | total = list.op_size(); 79 | 80 | for (; i < total; i++) { 81 | op = new Operation(list.op(i)); 82 | 83 | name = op.name;//.replace(/V2$/, ''); 84 | 85 | if (opTarget[name] !== undefined) { 86 | continue; 87 | } 88 | 89 | // console.log(op); 90 | 91 | opTarget[name] = 92 | opTarget[alias(name)] = 93 | opTarget[name.charAt(0).toLowerCase() + name.substr(1)] = 94 | (function (op) { 95 | return function () { 96 | var args = Array.prototype.slice.call(arguments); 97 | 98 | args.unshift(this); 99 | 100 | return new bindings.OperationDescription(op, args); 101 | }; 102 | })(op); 103 | } 104 | 105 | // for (i in bindings.OperationTarget) { 106 | // bindings[i] = bindings.OperationTarget[i]; 107 | // } 108 | })(); 109 | }; 110 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow 2 | Native TensorFlow bindings for Node.JS 3 | 4 | # Install 5 | 6 | `npm install tensorflow` 7 | 8 | # Features 9 | 10 | * Fully exposes the TensorFlow [C API](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.h) 11 | * Fast native bindings (no FFI) 12 | * Composable, intuitive wrapper API 13 | * No TensorFlow install/build required 14 | 15 | # Globals 16 | 17 | ## tf.Graph() 18 | ## tf.Session() 19 | ## tf.Operation() 20 | ## tf.Tensor(tf.DataType, [tf.TensorShape]) 21 | 22 | ## Instance methods 23 | 24 | ### Tensor.batchToSpace(block_shape, crops) 25 | ### Tensor.cast(type) 26 | ### Tensor.broadcast(tensor) 27 | ### Tensor.isNumeric() 28 | Returns true if the tensor doesn't contain NaN or Infinity values. 
29 | ### Tensor.concat(tensors[], axis) // ops::Concat, ops::ParallelConcat 30 | ### Tensor.copy([name]) 31 | ### Tensor.copyHost([name]) 32 | ### Tensor.debugID() 33 | ### Tensor.countNAN() 34 | ### Tensor.summary() 35 | ### Tensor.depthToSpace(block_size) 36 | ### Tensor.dequantize(minRange, maxRange, quantize) 37 | 38 | quantize = true || { 39 | signed: // If the quantization is signed or unsigned. 40 | bits: // The bitwidth of the quantization. 41 | minRange: 42 | maxRange: 43 | } 44 | ### Tensor.diagonal([diagonal]) // ops::Diag and ops::MatrixBandPart 45 | opts = { 46 | upper: true || number of superdiagonals to keep 47 | lower: true || number of subdiagonals to keep 48 | } 49 | If given a diagonal, this method sets the diagonal values. 50 | ### Tensor.leven(tensor, { normalized: false }) 51 | ### Tensor.expand(axis) 52 | ### Tensor.quantize(min, max, [type], [gradient]) 53 | ### Tensor.fill(value, [dimensions]) 54 | ### Tensor.gather(indices) 55 | ### Tensor.clone() // ops::Identity, ops::ZerosLike 56 | ### Tensor.freeze() // ops::ImmutableConst 57 | ### Tensor.invert() // ops::InvertPermutation 58 | ### Tensor.pad(paddings, mode) // ops::Pad, ops::MirrorPad 59 | ### Tensor.quantize(minRange, maxRange, type) 60 | ### Tensor.quantizedInstanceNorm(min, max, opts) 61 | 62 | opts = { 63 | output_range_given: If True, given_y_min and given_y_max are used as the output range. Otherwise, the implementation computes the output range. 64 | given_y_min: Output in y_min if output_range_given is True. 65 | given_y_max: Output in y_max if output_range_given is True. 66 | variance_epsilon: A small float number to avoid dividing by 0. 67 | min_separation: Minimum value of y_max - y_min 68 | } 69 | 70 | ### Tensor.reverse(axis) 71 | ### Tensor.reverseSequence(sequenceLengths, sequenceDimension, reverseDimension) 72 | ### Tensor.scatter(indices, values, [shape]) 73 | ### Tensor.diff(tensor) 74 | ### Tensor.shape([shape]) 75 | Return the shape of the tensor.
76 | 77 | If shape is provided, sets the shape of the tensor. 78 | 79 | ### Tensor.size() 80 | ### Tensor.slice(offset, length, strides) 81 | ### Tensor.spaceToBatch(block_shape, paddings) 82 | ### Tensor.squeeze([dimensionIndices]) 83 | 84 | ## Static methods 85 | 86 | ### Tensor.diagonal(diagonal) 87 | ### Tensor.oneHot(indices, value, [opts]) 88 | ### Tensor.quantize(tensors[], dimensions, min[], max[]) 89 | ### Tensor.shapes(tensors[]) 90 | ### Tensor.split(axis, count, byValue (ops::SplitV)) 91 | ### Tensor.stack(tensors[], [axis]) 92 | ### Tensor.tile(multiples) 93 | ### Tensor.transpose(permutation) 94 | ### Tensor.unique(counts == false) 95 | ### Tensor.unstack([axis]) 96 | ### Tensor.where() 97 | 98 | opts = { 99 | axis: 1, 100 | depth: 1, 101 | zeros: 0.1 102 | } 103 | 104 | 105 | ## tf.Image() 106 | 107 | ## Static methods 108 | 109 | ### Image.extractPatches([images], windowSize, centerStride, rates, padding) 110 | 111 | # Global methods 112 | 113 | ### tf.IsGoogleCudaEnabled() 114 | ### tf.LogAllRegisteredKernels() 115 | -------------------------------------------------------------------------------- /scripts/install.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | var os = require('os'), 4 | fs = require('fs'), 5 | https = require('https'), 6 | path = require('path'), 7 | exec = require('child_process').exec, 8 | TF_VERSION = '0.12.0', 9 | TF_PLATFORM = os.platform() === 'darwin' ? 'mac': 'linux', 10 | TF_PY_VERSION = TF_PLATFORM === 'mac' ? 'py2-none' : 'cp27-none', 11 | TF_ARCH = TF_PLATFORM === 'mac' ? 'any' : 'linux_x86_64', 12 | TF_GPU_ENABLED = TF_PLATFORM !== 'mac', 13 | TF_BINARY_URL = 'https://storage.googleapis.com/tensorflow/' + 14 | TF_PLATFORM + '/' + (TF_GPU_ENABLED ? 
'gpu' : 'cpu') + 15 | ['/tensorflow', TF_VERSION, TF_PY_VERSION, TF_ARCH].join('-') + '.whl', 16 | TF_ROOT_DIR = path.resolve(__dirname, '..'), 17 | TF_BINARY_FILE = path.resolve(TF_ROOT_DIR, 'binary.zip'), 18 | TF_TARGET_DIR = path.resolve(TF_ROOT_DIR, 'bazel-out'), 19 | UNZIP_COMMAND = 'unzip -q -o -d ' + TF_TARGET_DIR + ' ' + TF_BINARY_FILE, 20 | TF_SOURCE_FILES = [ 21 | 'c/c_api.cc', 'c/c_api.h', 'c/c_api_test.cc', 22 | 'cc/saved_model/loader.h' 23 | ], 24 | TF_SOURCE_DIR = path.resolve(__dirname, '../src/tensorflow'), 25 | TF_SOURCE_URL = 'https://raw.githubusercontent.com/tensorflow/' + 26 | 'tensorflow/v' + TF_VERSION +'/tensorflow/'; 27 | 28 | // URL structure obtained from: 29 | // https://www.tensorflow.org/versions/r0.10/get_started/os_setup 30 | 31 | fs.stat(TF_BINARY_FILE, function (err, stats) { 32 | if (!err) { 33 | unzip(); 34 | return; 35 | } 36 | 37 | https.get(TF_BINARY_URL, function (response) { 38 | var read = 0, total = parseInt(response.headers['content-length'], 10); 39 | 40 | response.on('data', function (buffer) { 41 | process.stderr.write( 42 | '\r' + (((read += buffer.length) / total) * 100).toFixed(0) + '%' + 43 | ' - Downloading ' + TF_BINARY_URL); 44 | }).pipe(fs.createWriteStream(TF_BINARY_FILE)); 45 | 46 | response.on('end', function () { 47 | console.log('\nSaved to ' + TF_BINARY_FILE); 48 | unzip(); 49 | }); 50 | }); 51 | }); 52 | 53 | function unzip() { 54 | exec(UNZIP_COMMAND, function (err, stdout, stderr) { 55 | if (err) { 56 | console.error(err.toString()); 57 | return; 58 | } 59 | 60 | fs.renameSync( 61 | path.resolve(TF_TARGET_DIR, 'tensorflow-' + TF_VERSION + '.data'), 62 | path.resolve(TF_TARGET_DIR, 'local-opt')); 63 | 64 | fs.renameSync( 65 | path.resolve(TF_TARGET_DIR, 'local-opt', 'purelib'), 66 | path.resolve(TF_TARGET_DIR, 'local-opt', 'bin')); 67 | 68 | fs.unlinkSync(TF_BINARY_FILE); 69 | 70 | getSourceFile(0); 71 | }); 72 | } 73 | 74 | function getSourceFile(index) { 75 | var file = TF_SOURCE_FILES[index], 76 | 
save, url; 77 | 78 | if (!file) { 79 | return; 80 | } 81 | 82 | save = path.resolve(TF_SOURCE_DIR, file); 83 | url = TF_SOURCE_URL + file; 84 | 85 | mkpathsync(save.substr(0, save.length - path.basename(file).length)); 86 | 87 | https.get(url, function (response) { 88 | var read = 0, total = parseInt(response.headers['content-length'], 10); 89 | 90 | response.on('data', function (buffer) { 91 | process.stderr.write( 92 | '\r' + (((read += buffer.length) / total) * 100).toFixed(0) + '%' + 93 | ' - Downloading ' + url); 94 | }).pipe(fs.createWriteStream(save)); 95 | 96 | response.on('end', function () { 97 | console.log('\nSaved to ' + save); 98 | getSourceFile(index + 1); 99 | }); 100 | }); 101 | } 102 | 103 | function mkpathsync(dirpath, mode) { 104 | dirpath = path.resolve(dirpath); 105 | 106 | if (typeof mode === 'undefined') { 107 | mode = parseInt('0777', 8) & (~process.umask()); 108 | } 109 | 110 | try { 111 | if (!fs.statSync(dirpath).isDirectory()) { 112 | throw new Error(dirpath + ' exists and is not a directory'); 113 | } 114 | } catch (err) { 115 | if (err.code === 'ENOENT') { 116 | mkpathsync(path.dirname(dirpath), mode); 117 | fs.mkdirSync(dirpath, mode); 118 | } else { 119 | throw err; 120 | } 121 | } 122 | }; 123 | -------------------------------------------------------------------------------- /src/tensorflow.i: -------------------------------------------------------------------------------- 1 | %ignore TF_Operation(); 2 | %ignore ~TF_Operation(); 3 | %ignore TF_Session::mu; 4 | %ignore TF_Graph::mu; 5 | 6 | %{ 7 | #include "tensorflow/core/framework/allocator.h" 8 | 9 | static void Deallocator(void* data, size_t, void* arg) { 10 | tensorflow::cpu_allocator()->DeallocateRaw(data); 11 | *reinterpret_cast(arg) = true; 12 | } 13 | %} 14 | 15 | %typemap(in, numinputs=0) void (*deallocator)(void* data, size_t len, void* arg) { 16 | $1 = &Deallocator; 17 | } 18 | 19 | %typemap(in, numinputs=0) void* deallocator_arg { 20 | bool deallocator_called = false; 
21 | $1 = &deallocator_called; 22 | } 23 | 24 | %typemap(in, numinputs=1) (const int64_t* dims, int) { 25 | v8::Local array; 26 | v8::Local jsvalue; 27 | int i = 0, res = 0; 28 | unsigned long long temp; 29 | 30 | if ($input->IsArray()) { 31 | array = v8::Local::Cast($input); 32 | 33 | $2 = array->Length(); 34 | $1 = ($1_ltype) malloc(sizeof($1_ltype) * $2); 35 | 36 | // Get each element from array 37 | for (i = 0; i < $2; i++) { 38 | jsvalue = array->Get(i); 39 | 40 | // Get primitive value from JSObject 41 | res = SWIG_AsVal(unsigned long long)(jsvalue, &temp); 42 | if (!SWIG_IsOK(res)) { 43 | SWIG_exception_fail(SWIG_ERROR, "Failed to convert $input to double"); 44 | } 45 | 46 | $1[i] = (int64_t) temp; 47 | } 48 | } else { 49 | SWIG_exception_fail(SWIG_ERROR, "$input is not an array"); 50 | } 51 | } 52 | 53 | %typemap(in, numinputs=0) (void* data, size_t len) { 54 | int total = 0, length = arg3, i = 0; 55 | 56 | for (; i < length; i++) { 57 | total += arg2[i]; 58 | } 59 | 60 | $2 = total * sizeof(float); 61 | $1 = reinterpret_cast(tensorflow::cpu_allocator()->AllocateRaw( 62 | EIGEN_MAX_ALIGN_BYTES, $2)); 63 | } 64 | 65 | %typemap(freearg) (const int64_t* dims, int) { 66 | free($1); 67 | } 68 | 69 | %typemap(in, numinputs=0) (tensorflow::Node** created_node) { 70 | Node* node; 71 | $1 = &node; 72 | } 73 | 74 | %typemap(argout) (tensorflow::Node** created_node) { 75 | $result = SWIG_NewPointerObj(ToOperation(*$1), SWIGTYPE_p_TF_Operation, 0 | 0 ); 76 | } 77 | 78 | %typemap(out) TF_Status* { 79 | if (!SWIG_IsOK($1)) { 80 | SWIG_exception_fail(SWIG_ERROR, TF_Message($1)); 81 | } 82 | } 83 | 84 | %typemap(out) Status { 85 | if (!$1.ok()) { 86 | SWIG_exception_fail(SWIG_ERROR, $1.error_message().c_str()); 87 | } 88 | } 89 | 90 | %{ 91 | #include "tensorflow/c/c_api.h" 92 | #include "tensorflow/core/util/port.h" 93 | #include "tensorflow/core/public/version.h" 94 | #include "tensorflow/core/framework/types.pb.h" 95 | #include "tensorflow/core/framework/node_def.pb.h" 
96 | #include "tensorflow/core/framework/variable.pb.h" 97 | #include "tensorflow/core/framework/attr_value.pb.h" 98 | #include "tensorflow/core/framework/tensor.pb.h" 99 | #include "tensorflow/core/framework/op_def.pb.h" 100 | #include "tensorflow/core/framework/graph.pb.h" 101 | #include "tensorflow/core/graph/graph.h" 102 | #include "tensorflow/core/lib/core/stringpiece.h" 103 | 104 | #include "tensorflow/core/platform/macros.h" 105 | #include "tensorflow/core/platform/env.h" 106 | #include "tensorflow/core/public/session.h" 107 | #include "tensorflow/core/public/session_options.h" 108 | #include "tensorflow/core/framework/tensor.h" 109 | #include "tensorflow/core/framework/tensor_shape.h" 110 | #include "tensorflow/core/framework/node_def.pb.h" 111 | #include "tensorflow/core/platform/default/thread_annotations.h" 112 | #include "tensorflow/core/platform/mutex.h" 113 | #include "tensorflow/core/graph/graph.h" 114 | #include "tensorflow/core/common_runtime/shape_refiner.h" 115 | #include "tensorflow/core/graph/node_builder.h" 116 | #include "tensorflow/core/platform/default/mutex.h" 117 | #include "tensorflow/core/lib/gtl/iterator_range.h" 118 | #include "tensorflow/core/graph/graph.h" 119 | #include "tensorflow/core/lib/core/status.h" 120 | #include "tensorflow/core/framework/op.h" 121 | #include "tensorflow/c/c_api.cc" 122 | #include "c_session.cc" 123 | 124 | using tensorflow::mutex; 125 | using tensorflow::mutex_lock; 126 | 127 | using namespace tensorflow; 128 | /*using namespace std;*/ 129 | #include "google/protobuf/stubs/port.h" 130 | #include "google/protobuf/message_lite.h" 131 | #include "google/protobuf/message.h" 132 | %} 133 | 134 | %import "stl.i" 135 | %import "std_string.i" 136 | %import "std_pair.i" 137 | %import "std_map.i" 138 | %import "std_vector.i" 139 | %import "status.i" 140 | 141 | %import "google/protobuf/stubs/port.h" 142 | %import "google/protobuf/descriptor.h" 143 | 144 | %include "google/protobuf/message_lite.h" 145 | %include 
"google/protobuf/message.h" 146 | 147 | %include "tensorflow/c/c_api.h" 148 | %include "tensorflow/core/util/port.h" 149 | %include "tensorflow/core/public/version.h" 150 | %include "tensorflow/core/framework/types.pb.h" 151 | %include "tensorflow/core/framework/node_def.pb.h" 152 | %include "tensorflow/core/framework/variable.pb.h" 153 | %include "tensorflow/core/framework/attr_value.pb.h" 154 | %include "tensorflow/core/framework/tensor.pb.h" 155 | %include "tensorflow/core/framework/op_def.pb.h" 156 | %include "tensorflow/core/lib/core/stringpiece.h" 157 | 158 | %include "tensorflow/core/platform/macros.h" 159 | %include "tensorflow/core/platform/env.h" 160 | %include "tensorflow/core/public/session.h" 161 | %include "tensorflow/core/public/session_options.h" 162 | %include "tensorflow/core/framework/tensor.h" 163 | %include "tensorflow/core/framework/tensor_shape.h" 164 | %include "tensorflow/core/framework/node_def.pb.h" 165 | %include "tensorflow/core/platform/default/thread_annotations.h" 166 | %include "tensorflow/core/platform/mutex.h" 167 | %include "tensorflow/core/lib/gtl/iterator_range.h" 168 | %include "tensorflow/core/graph/graph.h" 169 | %include "tensorflow/core/common_runtime/shape_refiner.h" 170 | %include "tensorflow/core/graph/node_builder.h" 171 | %include "tensorflow/core/lib/core/status.h" 172 | %include "tensorflow/core/framework/op.h" 173 | /*%include "tensorflow/core/platform/default/mutex.h"*/ 174 | %include "tensorflow/c/c_api.cc" 175 | %include "c_session.cc" 176 | -------------------------------------------------------------------------------- /test/c-api.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | // TODO: 4 | // OpDef_ArgDef_default_instance_: SwigProxy {}, 5 | // OpDef_AttrDef_default_instance_: SwigProxy {}, 6 | // OpDef_default_instance_: SwigProxy {}, 7 | // OpDeprecation_default_instance_: SwigProxy {}, 8 | // OpList_default_instance_: SwigProxy {}, 9 | // 
OpDef_ArgDef: 10 | // { [Function: OpDef_ArgDef] 11 | // descriptor: [Function], 12 | // default_instance: [Function], 13 | // internal_default_instance: [Function], 14 | // kNameFieldNumber: 1, 15 | // kDescriptionFieldNumber: 2, 16 | // kTypeFieldNumber: 3, 17 | // kTypeAttrFieldNumber: 4, 18 | // kNumberAttrFieldNumber: 5, 19 | // kTypeListAttrFieldNumber: 6, 20 | // kIsRefFieldNumber: 16 }, 21 | // OpDef_AttrDef: 22 | // { [Function: OpDef_AttrDef] 23 | // descriptor: [Function], 24 | // default_instance: [Function], 25 | // internal_default_instance: [Function], 26 | // kNameFieldNumber: 1, 27 | // kTypeFieldNumber: 2, 28 | // kDefaultValueFieldNumber: 3, 29 | // kDescriptionFieldNumber: 4, 30 | // kHasMinimumFieldNumber: 5, 31 | // kMinimumFieldNumber: 6, 32 | // kAllowedValuesFieldNumber: 7 }, 33 | // OpDef: 34 | // { [Function: OpDef] 35 | // descriptor: [Function], 36 | // default_instance: [Function], 37 | // internal_default_instance: [Function], 38 | // kNameFieldNumber: 1, 39 | // kInputArgFieldNumber: 2, 40 | // kOutputArgFieldNumber: 3, 41 | // kAttrFieldNumber: 4, 42 | // kDeprecationFieldNumber: 8, 43 | // kSummaryFieldNumber: 5, 44 | // kDescriptionFieldNumber: 6, 45 | // kIsCommutativeFieldNumber: 18, 46 | // kIsAggregateFieldNumber: 16, 47 | // kIsStatefulFieldNumber: 17, 48 | // kAllowsUninitializedInputFieldNumber: 19 }, 49 | // OpDeprecation: 50 | // { [Function: OpDeprecation] 51 | // descriptor: [Function], 52 | // default_instance: [Function], 53 | // internal_default_instance: [Function], 54 | // kVersionFieldNumber: 1, 55 | // kExplanationFieldNumber: 2 }, 56 | // OpList: 57 | // { [Function: OpList] 58 | // descriptor: [Function], 59 | // default_instance: [Function], 60 | // internal_default_instance: [Function], 61 | // kOpFieldNumber: 1 }, 62 | // StringPiece: { [Function: StringPiece] npos: 4294967295 } 63 | 64 | var tf = require('../'), 65 | assert = require('assert'), 66 | constants = { 67 | numeric: [ 68 | 'TF_FLOAT', 
'TF_DOUBLE', 'TF_INT32', 'TF_UINT8', 'TF_INT16', 69 | 'TF_INT8', 'TF_STRING', 'TF_COMPLEX64', 'TF_COMPLEX', 'TF_INT64', 70 | 'TF_BOOL', 'TF_QINT8', 'TF_QUINT8', 'TF_QINT32', 'TF_BFLOAT16', 71 | 'TF_QINT16', 'TF_QUINT16', 'TF_UINT16', 'TF_COMPLEX128', 'TF_HALF', 72 | 'TF_RESOURCE', 'TF_OK', 'TF_CANCELLED', 'TF_UNKNOWN', 73 | 'TF_INVALID_ARGUMENT', 'TF_DEADLINE_EXCEEDED', 'TF_NOT_FOUND', 74 | 'TF_ALREADY_EXISTS', 'TF_PERMISSION_DENIED', 'TF_UNAUTHENTICATED', 75 | 'TF_RESOURCE_EXHAUSTED', 'TF_FAILED_PRECONDITION', 'TF_ABORTED', 76 | 'TF_OUT_OF_RANGE', 'TF_UNIMPLEMENTED', 'TF_INTERNAL', 'TF_UNAVAILABLE', 77 | 'TF_DATA_LOSS', 'TF_ATTR_STRING', 'TF_ATTR_INT', 'TF_ATTR_FLOAT', 78 | 'TF_ATTR_BOOL', 'TF_ATTR_TYPE', 'TF_ATTR_SHAPE', 'TF_ATTR_TENSOR', 79 | 'TF_ATTR_PLACEHOLDER', 'TF_ATTR_FUNC', 'TF_MAJOR_VERSION', 80 | 'TF_MINOR_VERSION', 'TF_PATCH_VERSION', 81 | 'TF_GRAPH_DEF_VERSION_MIN_PRODUCER', 82 | 'TF_GRAPH_DEF_VERSION_MIN_CONSUMER', 'TF_GRAPH_DEF_VERSION', 83 | 'TF_CHECKPOINT_VERSION_MIN_PRODUCER', 84 | 'TF_CHECKPOINT_VERSION_MIN_CONSUMER', 'TF_CHECKPOINT_VERSION' 85 | ], 86 | string: [ 87 | 'TF_VERSION_SUFFIX', 'TF_VERSION_STRING' 88 | ], 89 | function: [ 90 | 'TF_Version', 'TF_DataTypeSize', 'TF_NewStatus', 'TF_DeleteStatus', 91 | 'TF_SetStatus', 'TF_GetCode', 'TF_Message', 'TF_NewBufferFromString', 92 | 'TF_NewBuffer', 'TF_DeleteBuffer', 'TF_GetBuffer', 'TF_NewTensor', 93 | 'TF_AllocateTensor', 'TF_DeleteTensor', 'TF_TensorType', 'TF_NumDims', 94 | 'TF_Dim', 'TF_TensorByteSize', 'TF_TensorData', 'TF_StringEncode', 95 | 'TF_StringDecode', 'TF_StringEncodedSize', 'TF_NewSessionOptions', 96 | 'TF_SetTarget', 'TF_SetConfig', 'TF_DeleteSessionOptions', 97 | 'TF_NewGraph', 'TF_DeleteGraph', 'TF_GraphSetTensorShape', 98 | 'TF_GraphGetTensorNumDims', 'TF_GraphGetTensorShape', 99 | 'TF_NewOperation', 'TF_SetDevice', 'TF_AddInput', 'TF_AddInputList', 100 | 'TF_AddControlInput', 'TF_ColocateWith', 'TF_SetAttrString', 101 | 'TF_SetAttrStringList', 'TF_SetAttrInt', 
'TF_SetAttrIntList', 102 | 'TF_SetAttrFloat', 'TF_SetAttrFloatList', 'TF_SetAttrBool', 103 | 'TF_SetAttrBoolList', 'TF_SetAttrType', 'TF_SetAttrTypeList', 104 | 'TF_SetAttrShape', 'TF_SetAttrShapeList', 105 | 'TF_SetAttrTensorShapeProto', 'TF_SetAttrTensorShapeProtoList', 106 | 'TF_SetAttrTensor', 'TF_SetAttrTensorList', 'TF_SetAttrValueProto', 107 | 'TF_FinishOperation', 'TF_OperationName', 'TF_OperationOpType', 108 | 'TF_OperationDevice', 'TF_OperationNumOutputs', 109 | 'TF_OperationOutputType', 'TF_OperationOutputListLength', 110 | 'TF_OperationNumInputs', 'TF_OperationInputType', 111 | 'TF_OperationInputListLength', 'TF_OperationInput', 112 | 'TF_OperationOutputNumConsumers', 'TF_OperationOutputConsumers', 113 | 'TF_OperationNumControlInputs', 'TF_OperationGetControlInputs', 114 | 'TF_OperationNumControlOutputs', 'TF_OperationGetControlOutputs', 115 | 'TF_OperationGetAttrMetadata', 'TF_OperationGetAttrString', 116 | 'TF_OperationGetAttrStringList', 'TF_OperationGetAttrInt', 117 | 'TF_OperationGetAttrIntList', 'TF_OperationGetAttrFloat', 118 | 'TF_OperationGetAttrFloatList', 'TF_OperationGetAttrBool', 119 | 'TF_OperationGetAttrBoolList', 'TF_OperationGetAttrType', 120 | 'TF_OperationGetAttrTypeList', 'TF_OperationGetAttrShape', 121 | 'TF_OperationGetAttrShapeList', 'TF_OperationGetAttrTensorShapeProto', 122 | 'TF_OperationGetAttrTensorShapeProtoList', 'TF_OperationGetAttrTensor', 123 | 'TF_OperationGetAttrTensorList', 'TF_OperationGetAttrValueProto', 124 | 'TF_GraphOperationByName', 'TF_GraphNextOperation', 125 | 'TF_GraphToGraphDef', 'TF_NewImportGraphDefOptions', 126 | 'TF_DeleteImportGraphDefOptions', 'TF_ImportGraphDefOptionsSetPrefix', 127 | 'TF_GraphImportGraphDef', 'TF_OperationToNodeDef', 'TF_NewSession', 128 | 'TF_LoadSessionFromSavedModel', 'TF_CloseSession', 'TF_DeleteSession', 129 | 'TF_SessionRun', 'TF_SessionPRunSetup', 'TF_SessionPRun', 130 | 'TF_NewDeprecatedSession', 'TF_CloseDeprecatedSession', 131 | 'TF_DeleteDeprecatedSession', 'TF_Reset', 
'TF_ExtendGraph', 'TF_Run', 132 | 'TF_PRunSetup', 'TF_PRun', 'TF_LoadLibrary', 'TF_GetOpList', 133 | 'TF_DeleteLibraryHandle', 'TF_GetAllOpList', 'IsGoogleCudaEnabled', 134 | 'CudaSupportsHalfMatMulAndConv', 'tf_compiler_version', 'tf_git_version' 135 | ], 136 | constructor: [ 137 | 'MessageLite', 'Metadata', 'Message', 'Reflection', 'MessageFactory', 138 | 'TF_Buffer', 'TF_Input', 'TF_Output', 'TF_AttrMetadata' 139 | ] 140 | } 141 | describe('C API', function () { 142 | describe('#Constants', function () { 143 | it('numeric', function () { 144 | constants.numeric.forEach(function (c) { 145 | assert.equal(typeof tf[c], 'number'); 146 | }); 147 | }); 148 | 149 | it('string', function () { 150 | constants.string.forEach(function (c) { 151 | assert.equal(typeof tf[c], 'string'); 152 | }); 153 | }); 154 | 155 | it('function', function () { 156 | constants.function.forEach(function (c) { 157 | assert.equal(typeof tf[c], 'function'); 158 | }); 159 | }); 160 | 161 | it('constructor', function () { 162 | constants.constructor.forEach(function (c) { 163 | assert.equal(typeof tf[c], 'function'); 164 | assert.equal( 165 | tf[c].toString(), 'function ' + c + '() { [native code] }'); 166 | }); 167 | }); 168 | }); 169 | }); 170 | --------------------------------------------------------------------------------