├── .eslintignore ├── .eslintrc ├── .gitignore ├── Makefile ├── README.md ├── binding.gyp ├── demo.js ├── hello.js ├── index.js ├── lib ├── ops.js └── pb.js ├── package.json ├── src ├── binding.cc ├── graph.cc ├── graph.h ├── session.cc ├── session.h ├── status.cc └── status.h └── test └── index.test.js /.eslintignore: -------------------------------------------------------------------------------- 1 | *.debug.js 2 | *.min.js 3 | node_modules/* 4 | -------------------------------------------------------------------------------- /.eslintrc: -------------------------------------------------------------------------------- 1 | { 2 | "rules": { 3 | "indent": [ 4 | 2, 5 | 2 6 | ], 7 | "quotes": [ 8 | 2, 9 | "single" 10 | ], 11 | "linebreak-style": [ 12 | 2, 13 | "unix" 14 | ], 15 | "semi": [2, "always"], 16 | "strict": [2, "global"], 17 | "curly": 2, 18 | "eqeqeq": 2, 19 | "no-eval": 2, 20 | "guard-for-in": 2, 21 | "no-caller": 2, 22 | "no-else-return": 2, 23 | "no-eq-null": 2, 24 | "no-extend-native": 2, 25 | "no-extra-bind": 2, 26 | "no-floating-decimal": 2, 27 | "no-implied-eval": 2, 28 | "no-labels": 2, 29 | "no-with": 2, 30 | "no-loop-func": 1, 31 | "no-native-reassign": 2, 32 | "no-redeclare": [2, {"builtinGlobals": true}], 33 | "no-delete-var": 2, 34 | "no-shadow-restricted-names": 2, 35 | "no-undef-init": 2, 36 | "no-use-before-define": 2, 37 | "no-unused-vars": [2, {"args": "none"}], 38 | "no-undef": 2, 39 | "callback-return": [2, ["callback", "cb", "next"]], 40 | "global-require": 0, 41 | "no-console": 0, 42 | "require-yield": 0 43 | }, 44 | "env": { 45 | "es6": true, 46 | "node": true, 47 | "browser": true 48 | }, 49 | "globals": { 50 | "describe": true, 51 | "it": true, 52 | "before": true, 53 | "after": true 54 | }, 55 | "parserOptions": { 56 | "ecmaVersion": 8, 57 | "sourceType": "script", 58 | "ecmaFeatures": { 59 | "jsx": true 60 | } 61 | }, 62 | "extends": "eslint:recommended" 63 | } 64 | 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | coverage 3 | node_modules 4 | libtensorflow-*.tar.gz 5 | .nyc_output 6 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | TESTS = test/*.js 2 | REPORTER = spec 3 | TIMEOUT = 20000 4 | MOCHA = ./node_modules/mocha/bin/_mocha 5 | PATH := ./node_modules/.bin:$(PATH) 6 | 7 | lint: 8 | @eslint --fix lib index.js test 9 | 10 | build: 11 | @node-gyp rebuild 12 | 13 | install: 14 | @cnpm install 15 | 16 | test: build 17 | @mocha -t $(TIMEOUT) -R spec $(TESTS) 18 | 19 | test-cov: build 20 | @nyc --reporter=html --reporter=text mocha -t $(TIMEOUT) -R spec $(TESTS) 21 | 22 | test-coveralls: build 23 | @nyc mocha -t $(TIMEOUT) -R spec $(TESTS) 24 | @echo TRAVIS_JOB_ID $(TRAVIS_JOB_ID) 25 | @nyc report --reporter=text-lcov | coveralls 26 | 27 | .PHONY: test build 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tensorflow-node 2 | 3 | TensorFlow bindings for Node.js. 4 | 5 | ## Installation 6 | 7 | Step 1: Install the TensorFlow C library. 8 | 9 | See <https://www.tensorflow.org/install/lang_c>. 10 | 11 | ## License 12 | 13 | The MIT license. 
14 | -------------------------------------------------------------------------------- /binding.gyp: -------------------------------------------------------------------------------- 1 | { 2 | "targets": [ 3 | { 4 | "target_name": "binding", 5 | "sources": [ 6 | "src/binding.cc", 7 | "src/status.cc", 8 | "src/graph.cc", 9 | "src/session.cc" 10 | ], 11 | "include_dirs" : [ 12 | #"/usr/local/include" 13 | ], 14 | 'libraries': [ 15 | '-ltensorflow' 16 | ] 17 | } 18 | ] 19 | } 20 | -------------------------------------------------------------------------------- /demo.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | const tf = require('./'); 4 | 5 | const a = tf.constant([1.0, 2.0], 'a'); 6 | const b = tf.constant([2.0, 3.0], 'b'); 7 | 8 | const result = tf.add(a, b, 'add'); 9 | 10 | console.log(result); 11 | -------------------------------------------------------------------------------- /hello.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | const tf = require('./'); 4 | 5 | const hello = tf.constant('Hello, TensorFlow!'); 6 | const sess = new tf.Session(); 7 | 8 | sess.run(hello); 9 | -------------------------------------------------------------------------------- /index.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | const tf = require('bindings')('binding.node'); 4 | // const ops = require('./lib/ops'); 5 | // Builds a 'Placeholder' operation named `name` (default 'feed') on `graph` with dtype INT32 and returns tf.finishOperation(desc, status). NOTE(review): init() in src/binding.cc registers no `Operation` export and `finishOperation` currently returns the version string - confirm the native side before relying on this. 6 | tf.placeholder = function (graph, status, name = 'feed') { 7 | const desc = new tf.Operation(graph, 'Placeholder', name); 8 | desc.setAttrType('dtype', tf.DataType.INT32); 9 | return tf.finishOperation(desc, status); 10 | }; 11 | 12 | module.exports = tf; 13 | -------------------------------------------------------------------------------- /lib/ops.js: -------------------------------------------------------------------------------- 1 | // 'use strict'; 2 | 3 | // class 
DefaultStack { 4 | // // """A thread-local stack of objects for providing implicit defaults.""" 5 | 6 | // constructor() { 7 | // this._enforce_nesting = true; 8 | // this.stack = []; 9 | // } 10 | 11 | // getDefault() { 12 | // if (this.stack.length >= 1) { 13 | // return this.stack[this.stack.length - 1]; 14 | // } 15 | // return null; 16 | // } 17 | 18 | // reset() { 19 | // this.stack = []; 20 | // } 21 | 22 | // get enforce_nesting() { 23 | // return this._enforce_nesting; 24 | // } 25 | 26 | // set enforce_nesting(value) { 27 | // this._enforce_nesting = value; 28 | // } 29 | 30 | // * get_controller(_default) { 31 | // // """A context manager for manipulating a default stack.""" 32 | // try { 33 | // this.stack.append(_default); 34 | // yield _default; 35 | // } finally { 36 | // if (this._enforce_nesting) { 37 | // if (this.stack[this.stack.length - 1] !== _default) { 38 | // throw new TypeError( 39 | // `Nesting violated for default stack of ${typeof _default} objects`); 40 | // } 41 | 42 | // this.stack.pop(); 43 | // } else { 44 | // // TODO 45 | // this.stack.remove(_default); 46 | // } 47 | // } 48 | // } 49 | // } 50 | 51 | // const _default_session_stack = new DefaultStack(); 52 | 53 | // class DefaultGraphStack extends DefaultStack { 54 | 55 | // /// """A thread-local stack of objects for providing an implicit default graph.""" 56 | 57 | // constructor() { 58 | // super(); 59 | // this._global_default_graph = null; 60 | // } 61 | 62 | // getDefault() { 63 | // /// """Override that returns a global default if the stack is empty.""" 64 | // return super.getDefault() || this.getGlobalDefaultGraph(); 65 | // } 66 | 67 | // getGlobalDefaultGraph() { 68 | // if (!this._global_default_graph) { 69 | // this._global_default_graph = new Graph(); 70 | // } 71 | 72 | // return this._global_default_graph; 73 | // } 74 | 75 | // reset() { 76 | // super.reset(); 77 | // this._global_default_graph = null; 78 | // } 79 | // } 80 | 81 | // const _default_graph_stack 
= new DefaultGraphStack(); 82 | 83 | // exports.getDefaultGraph = function () { 84 | // return _default_graph_stack.getDefault(); 85 | // }; 86 | -------------------------------------------------------------------------------- /lib/pb.js: -------------------------------------------------------------------------------- 1 | // AttrValue; 2 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tensorflow-node", 3 | "version": "1.0.0", 4 | "description": "tensorflow node binding", 5 | "main": "index.js", 6 | "scripts": { 7 | "test": "make test" 8 | }, 9 | "author": "", 10 | "license": "MIT", 11 | "gypfile": true, 12 | "devDependencies": { 13 | "expect.js": "^0.3.1", 14 | "mocha": "^3.4.2", 15 | "nyc": "^11.0.1" 16 | }, 17 | "dependencies": { 18 | "bindings": "^1.2.1", 19 | "protobufjs": "^6.7.3" 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /src/binding.cc: -------------------------------------------------------------------------------- 1 | // binding.cc - addon entry point: exports version(), finishOperation(), the DataType and Code constant tables, and the Status and Graph classes. 2 | #include 3 | #include 4 | 5 | #include "status.h" 6 | #include "graph.h" 7 | #include "session.h" 8 | 9 | #define CONST_INT(n, v) \ 10 | obj->Set(String::NewFromUtf8(isolate, n), Integer::New(isolate, v)); 11 | 12 | namespace tensorflow_node { 13 | 14 | using v8::FunctionCallbackInfo; 15 | using v8::Integer; 16 | using v8::Isolate; 17 | using v8::Local; 18 | using v8::Object; 19 | using v8::String; 20 | using v8::Value; 21 | // Returns the TensorFlow library version string (TF_Version()) to JavaScript. 22 | void Version(const FunctionCallbackInfo& args) { 23 | Isolate* isolate = args.GetIsolate(); 24 | args.GetReturnValue().Set(String::NewFromUtf8(isolate, TF_Version())); 25 | } 26 | // TODO: stub - returns TF_Version() exactly like Version(); it ignores its arguments and does not finish an operation. 27 | void FinishOperation(const FunctionCallbackInfo& args) { 28 | Isolate* isolate = args.GetIsolate(); 29 | 30 | args.GetReturnValue().Set(String::NewFromUtf8(isolate, TF_Version())); 31 | } 32 | // Builds an object mapping data-type names to their TF_* enum values and publishes it as `exports.DataType`. 33 | void
InitDataType(Local exports) { 34 | Isolate* isolate = exports->GetIsolate(); 35 | Local obj = Object::New(isolate); 36 | CONST_INT("FLOAT", TF_FLOAT); 37 | CONST_INT("DOUBLE", TF_DOUBLE); 38 | CONST_INT("INT32", TF_INT32); // Int32 tensors are always in 'host' memory. 39 | CONST_INT("UINT8", TF_UINT8); 40 | CONST_INT("INT16", TF_INT16); 41 | CONST_INT("INT8", TF_INT8); 42 | CONST_INT("STRING", TF_STRING); 43 | CONST_INT("COMPLEX64", TF_COMPLEX64); // Single-precision complex 44 | CONST_INT("COMPLEX", TF_COMPLEX); // Old identifier kept for API backwards compatibility 45 | CONST_INT("INT64", TF_INT64); 46 | CONST_INT("BOOL", TF_BOOL); 47 | CONST_INT("QINT8", TF_QINT8); // Quantized int8 48 | CONST_INT("QUINT8", TF_QUINT8); // Quantized uint8 49 | CONST_INT("QINT32", TF_QINT32); // Quantized int32 50 | CONST_INT("BFLOAT16", TF_BFLOAT16); // Float32 truncated to 16 bits. Only for cast ops. 51 | CONST_INT("QINT16", TF_QINT16); // Quantized int16 52 | CONST_INT("QUINT16", TF_QUINT16); // Quantized uint16 53 | CONST_INT("UINT16", TF_UINT16); 54 | CONST_INT("COMPLEX128", TF_COMPLEX128); // Double-precision complex 55 | CONST_INT("HALF", TF_HALF); 56 | CONST_INT("RESOURCE", TF_RESOURCE); 57 | // Publish the populated table as `exports.DataType`. 58 | exports->Set(String::NewFromUtf8(isolate, "DataType"), 59 | obj); 60 | } 61 | // Builds an object mapping status-code names to their TF_Code values; it is published as `exports.Code` at the end of the function. 62 | void InitStatusCode(Local exports) { 63 | Isolate* isolate = exports->GetIsolate(); 64 | Local obj = Object::New(isolate); 65 | 66 | CONST_INT("OK", TF_OK); 67 | CONST_INT("CANCELLED", TF_CANCELLED); 68 | CONST_INT("UNKNOWN", TF_UNKNOWN); 69 | CONST_INT("INVALID_ARGUMENT", TF_INVALID_ARGUMENT); 70 | CONST_INT("DEADLINE_EXCEEDED", TF_DEADLINE_EXCEEDED); 71 | CONST_INT("NOT_FOUND", TF_NOT_FOUND); 72 | CONST_INT("ALREADY_EXISTS", TF_ALREADY_EXISTS); 73 | CONST_INT("PERMISSION_DENIED", TF_PERMISSION_DENIED); 74 | CONST_INT("UNAUTHENTICATED", TF_UNAUTHENTICATED); 75 | CONST_INT("RESOURCE_EXHAUSTED", TF_RESOURCE_EXHAUSTED); 76 | CONST_INT("FAILED_PRECONDITION", TF_FAILED_PRECONDITION); 77 |
CONST_INT("ABORTED", TF_ABORTED); 78 | CONST_INT("OUT_OF_RANGE", TF_OUT_OF_RANGE); 79 | CONST_INT("UNIMPLEMENTED", TF_UNIMPLEMENTED); 80 | CONST_INT("INTERNAL", TF_INTERNAL); 81 | CONST_INT("UNAVAILABLE", TF_UNAVAILABLE); 82 | CONST_INT("DATA_LOSS", TF_DATA_LOSS); 83 | exports->Set(String::NewFromUtf8(isolate, "Code"), 84 | obj); 85 | } 86 | /* Module initializer: registers version(), finishOperation(), the DataType and Code tables, and the Status and Graph classes on `exports`. */ 87 | void init(Local exports) { 88 | NODE_SET_METHOD(exports, "version", Version); 89 | NODE_SET_METHOD(exports, "finishOperation", FinishOperation); 90 | 91 | InitDataType(exports); 92 | 93 | InitStatusCode(exports); 94 | 95 | Status::Init(exports); 96 | Graph::Init(exports); 97 | } 98 | 99 | NODE_MODULE(addon, init) 100 | 101 | } // namespace tensorflow_node 102 | -------------------------------------------------------------------------------- /src/graph.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "graph.h" 3 | 4 | namespace tensorflow_node { 5 | 6 | using v8::Context; 7 | using v8::Function; 8 | using v8::FunctionCallbackInfo; 9 | using v8::FunctionTemplate; 10 | using v8::Isolate; 11 | using v8::Local; 12 | using v8::Object; 13 | using v8::Persistent; 14 | using v8::String; 15 | using v8::Integer; 16 | using v8::Value; 17 | 18 | Persistent Graph::constructor; 19 | // Stores the given TF_Graph pointer; the destructor below deletes it. 20 | Graph::Graph(TF_Graph* graph) : graph_(graph) { 21 | } 22 | // Deletes the wrapped TF_Graph. 23 | Graph::~Graph() { 24 | TF_DeleteGraph(graph_); 25 | } 26 | // Registers the Graph constructor on `exports` as "Graph"; no prototype methods are attached (the candidates below are commented out). 27 | void Graph::Init(Local exports) { 28 | Isolate* isolate = exports->GetIsolate(); 29 | // Prepare constructor template 30 | Local tpl = FunctionTemplate::New(isolate, New); 31 | tpl->SetClassName(String::NewFromUtf8(isolate, "Graph")); 32 | tpl->InstanceTemplate()->SetInternalFieldCount(1); 33 | // Prototype 34 | // NODE_SET_PROTOTYPE_METHOD(tpl, "getCode", GetCode); 35 | // NODE_SET_PROTOTYPE_METHOD(tpl, "getMessage", GetMessage); 36 | // NODE_SET_PROTOTYPE_METHOD(tpl, "set", Set); 37 | 38 | constructor.Reset(isolate, tpl->GetFunction()); 39 | 40 |
exports->Set(String::NewFromUtf8(isolate, "Graph"), 41 | tpl->GetFunction()); 42 | } 43 | /* JS constructor callback: wraps a fresh TF_Graph (TF_NewGraph) in a Graph bound to `this`; a plain call is re-dispatched as a construct call with args[0]. */ 44 | void Graph::New(const FunctionCallbackInfo& args) { 45 | Isolate* isolate = args.GetIsolate(); 46 | 47 | if (args.IsConstructCall()) { 48 | // Invoked as constructor: `new Graph(...)` 49 | TF_Graph* graph = TF_NewGraph(); 50 | Graph* obj = new Graph(graph); 51 | obj->Wrap(args.This()); 52 | args.GetReturnValue().Set(args.This()); 53 | } else { 54 | // Invoked as plain function `Graph(...)`, turn into construct call. 55 | const int argc = 1; 56 | Local argv[argc] = { args[0] }; 57 | Local context = isolate->GetCurrentContext(); 58 | Local cons = Local::New(isolate, constructor); 59 | Local instance = 60 | cons->NewInstance(context, argc, argv).ToLocalChecked(); 61 | args.GetReturnValue().Set(instance); 62 | } 63 | } 64 | // NOTE: the commented-out methods below were copied from Status (they still reference `status->status()`) and are not valid Graph members; kept for reference only. 65 | // void Graph::GetCode(const FunctionCallbackInfo& args) { 66 | // Isolate* isolate = args.GetIsolate(); 67 | // Graph* graph = ObjectWrap::Unwrap(args.Holder()); 68 | // TF_Code code = TF_GetCode(status->status()); 69 | // args.GetReturnValue().Set(Integer::New(isolate, code)); 70 | // } 71 | 72 | // void Graph::GetMessage(const FunctionCallbackInfo& args) { 73 | // Isolate* isolate = args.GetIsolate(); 74 | // Status* status = ObjectWrap::Unwrap(args.Holder()); 75 | // const char* message = TF_Message(status->status()); 76 | // args.GetReturnValue().Set(String::NewFromUtf8(isolate, message)); 77 | // } 78 | 79 | // void Graph::Set(const FunctionCallbackInfo& args) { 80 | // Isolate* isolate = args.GetIsolate(); 81 | // Status* status = ObjectWrap::Unwrap(args.Holder()); 82 | // int64_t _code = args[0]->IntegerValue(); 83 | // TF_Code code = static_cast(_code); 84 | // String::Utf8Value s(args[1]); 85 | // const char* _message = *s; 86 | // TF_SetStatus(status->status(), code, _message); 87 | // args.GetReturnValue().Set(Undefined(isolate)); 88 | // } 89 | 90 | } // namespace tensorflow_node 91 | 
-------------------------------------------------------------------------------- /src/graph.h: -------------------------------------------------------------------------------- 1 | #ifndef GRAPH_H 2 | #define GRAPH_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | namespace tensorflow_node { 9 | /* node::ObjectWrap subclass holding a TF_Graph*; exposed to JavaScript through Init(), with the raw pointer available via graph(). */ 10 | class Graph : public node::ObjectWrap { 11 | public: 12 | static void Init(v8::Local exports); 13 | // static void GetCode(const v8::FunctionCallbackInfo& args); 14 | // static void GetMessage(const v8::FunctionCallbackInfo& args); 15 | inline TF_Graph* graph() const { return graph_; } 16 | 17 | private: 18 | explicit Graph(TF_Graph* graph); 19 | ~Graph(); 20 | 21 | static void New(const v8::FunctionCallbackInfo& args); 22 | static v8::Persistent constructor; 23 | TF_Graph* graph_; 24 | }; 25 | 26 | } // namespace tensorflow_node 27 | 28 | #endif 29 | -------------------------------------------------------------------------------- /src/session.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "session.h" 3 | 4 | namespace tensorflow_node { 5 | // TODO: Session binding is not implemented yet (hello.js calls `new tf.Session()`). 6 | } 7 | -------------------------------------------------------------------------------- /src/session.h: -------------------------------------------------------------------------------- 1 | #ifndef SESSION_H 2 | #define SESSION_H 3 | 4 | #endif 5 | -------------------------------------------------------------------------------- /src/status.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "status.h" 3 | 4 | namespace tensorflow_node { 5 | 6 | using v8::Context; 7 | using v8::Function; 8 | using v8::FunctionCallbackInfo; 9 | using v8::FunctionTemplate; 10 | using v8::Isolate; 11 | using v8::Local; 12 | using v8::Object; 13 | using v8::Persistent; 14 | using v8::String; 15 | using v8::Integer; 16 | using v8::Value; 17 | 18 | Persistent Status::constructor; 19 | 20 | Status::Status(TF_Status* status) : status_(status) {
21 | } 22 | /* Deletes the wrapped TF_Status. */ 23 | Status::~Status() { 24 | TF_DeleteStatus(status_); 25 | } 26 | /* Registers the Status class on `exports` as "Status" with prototype methods getCode, getMessage and set. */ 27 | void Status::Init(Local exports) { 28 | Isolate* isolate = exports->GetIsolate(); 29 | // Prepare constructor template 30 | Local tpl = FunctionTemplate::New(isolate, New); 31 | tpl->SetClassName(String::NewFromUtf8(isolate, "Status")); 32 | tpl->InstanceTemplate()->SetInternalFieldCount(1); 33 | // Prototype 34 | NODE_SET_PROTOTYPE_METHOD(tpl, "getCode", GetCode); 35 | NODE_SET_PROTOTYPE_METHOD(tpl, "getMessage", GetMessage); 36 | NODE_SET_PROTOTYPE_METHOD(tpl, "set", Set); 37 | // Keep a persistent handle to the constructor so New() can re-dispatch plain calls as construct calls. 38 | constructor.Reset(isolate, tpl->GetFunction()); 39 | 40 | exports->Set(String::NewFromUtf8(isolate, "Status"), 41 | tpl->GetFunction()); 42 | } 43 | // JS constructor callback: wraps a fresh TF_Status (TF_NewStatus) in a Status bound to `this`. 44 | void Status::New(const FunctionCallbackInfo& args) { 45 | Isolate* isolate = args.GetIsolate(); 46 | 47 | if (args.IsConstructCall()) { 48 | // Invoked as constructor: `new Status(...)` 49 | TF_Status* status = TF_NewStatus(); 50 | Status* obj = new Status(status); 51 | obj->Wrap(args.This()); 52 | args.GetReturnValue().Set(args.This()); 53 | } else { 54 | // Invoked as plain function `Status(...)`, turn into construct call. 
55 | const int argc = 1; 56 | Local argv[argc] = { args[0] }; 57 | Local context = isolate->GetCurrentContext(); 58 | Local cons = Local::New(isolate, constructor); 59 | Local instance = 60 | cons->NewInstance(context, argc, argv).ToLocalChecked(); 61 | args.GetReturnValue().Set(instance); 62 | } 63 | } 64 | /* status.getCode(): returns the TF_Code of the wrapped status as an integer. */ 65 | void Status::GetCode(const FunctionCallbackInfo& args) { 66 | Isolate* isolate = args.GetIsolate(); 67 | Status* status = ObjectWrap::Unwrap(args.Holder()); 68 | TF_Code code = TF_GetCode(status->status()); 69 | args.GetReturnValue().Set(Integer::New(isolate, code)); 70 | } 71 | /* status.getMessage(): returns TF_Message() of the wrapped status as a string. */ 72 | void Status::GetMessage(const FunctionCallbackInfo& args) { 73 | Isolate* isolate = args.GetIsolate(); 74 | Status* status = ObjectWrap::Unwrap(args.Holder()); 75 | const char* message = TF_Message(status->status()); 76 | args.GetReturnValue().Set(String::NewFromUtf8(isolate, message)); 77 | } 78 | /* status.set(code, message): casts args[0] to TF_Code, converts args[1] to UTF-8 and passes both to TF_SetStatus; returns undefined. No argument validation is done. NOTE(review): the converted message pointer may be NULL if string conversion fails - confirm TF_SetStatus tolerates that. */ 79 | void Status::Set(const FunctionCallbackInfo& args) { 80 | Isolate* isolate = args.GetIsolate(); 81 | Status* status = ObjectWrap::Unwrap(args.Holder()); 82 | int64_t _code = args[0]->IntegerValue(); 83 | TF_Code code = static_cast(_code); 84 | String::Utf8Value s(args[1]); 85 | const char* _message = *s; 86 | TF_SetStatus(status->status(), code, _message); 87 | args.GetReturnValue().Set(Undefined(isolate)); 88 | } 89 | 90 | } // namespace tensorflow_node 91 | -------------------------------------------------------------------------------- /src/status.h: -------------------------------------------------------------------------------- 1 | #ifndef STATUS_H 2 | #define STATUS_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | namespace tensorflow_node { 9 | 10 | class Status : public node::ObjectWrap { 11 | public: 12 | static void Init(v8::Local exports); 13 | static void GetCode(const v8::FunctionCallbackInfo& args); 14 | static void GetMessage(const v8::FunctionCallbackInfo& args); 15 | static void Set(const v8::FunctionCallbackInfo& args); 16 | inline TF_Status* status() const { return
status_; } 17 | 18 | private: 19 | explicit Status(TF_Status* status); 20 | ~Status(); 21 | 22 | static void New(const v8::FunctionCallbackInfo& args); 23 | static v8::Persistent constructor; 24 | TF_Status* status_; 25 | }; 26 | 27 | } // namespace tensorflow_node 28 | 29 | #endif 30 | -------------------------------------------------------------------------------- /test/index.test.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | const expect = require('expect.js'); 4 | 5 | const tf = require('../'); 6 | 7 | describe('tensorflow', function () { 8 | it('version()', function() { 9 | expect(tf.version()).to.be.ok(); 10 | }); 11 | 12 | it('DataType', function () { 13 | console.log(tf); 14 | expect(tf.DataType).to.be.ok(); 15 | }); 16 | 17 | it('Status', function () { 18 | var status = new tf.Status(); 19 | var code = status.getCode(); 20 | expect(code).to.be(tf.Code.OK); 21 | var message = status.getMessage(); 22 | expect(message).to.be(''); 23 | status.set(tf.Code.CANCELLED, 'cancel'); 24 | expect(status.getCode()).to.be(tf.Code.CANCELLED); 25 | expect(status.getMessage()).to.be('cancel'); 26 | // TF_DeleteStatus(s); 27 | }); 28 | 29 | it('Tensor', function() { 30 | // TODO: empty test - nothing is asserted yet; the commented-out C API test below is kept as a porting reference. 31 | }); 32 | // TEST(CAPI, Tensor) { 33 | // const int num_bytes = 6 * sizeof(float); 34 | // float* values = 35 | // reinterpret_cast(tensorflow::cpu_allocator()->AllocateRaw( 36 | // EIGEN_MAX_ALIGN_BYTES, num_bytes)); 37 | // int64_t dims[] = {2, 3}; 38 | // bool deallocator_called = false; 39 | // TF_Tensor* t = TF_NewTensor(TF_FLOAT, dims, 2, values, num_bytes, 40 | // &Deallocator, &deallocator_called); 41 | // EXPECT_FALSE(deallocator_called); 42 | // EXPECT_EQ(TF_FLOAT, TF_TensorType(t)); 43 | // EXPECT_EQ(2, TF_NumDims(t)); 44 | // EXPECT_EQ(dims[0], TF_Dim(t, 0)); 45 | // EXPECT_EQ(dims[1], TF_Dim(t, 1)); 46 | // EXPECT_EQ(num_bytes, TF_TensorByteSize(t)); 47 | // EXPECT_EQ(static_cast(values), TF_TensorData(t)); 48 | // TF_DeleteTensor(t); 49 
| // EXPECT_TRUE(deallocator_called); 50 | // } 51 | 52 | it('Graph', function () { 53 | const s = new tf.Status(); 54 | const graph = new tf.Graph(); 55 | 56 | // Make a placeholder operation. 57 | const feed = tf.placeholder(graph, s); 58 | // TF_Operation* feed = Placeholder(graph, s); 59 | expect(tf.Code.OK).to.be(s.getCode()); 60 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 61 | 62 | // // Test TF_Operation*() query functions. 63 | expect(feed.operationName).to.be('feed'); 64 | // EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); 65 | expect(feed.operationOpType).to.be('Placeholder'); 66 | // EXPECT_EQ(string("Placeholder"), string(TF_OperationOpType(feed))); 67 | expect(feed.operationDevice).to.be(''); 68 | // EXPECT_EQ(string(""), string(TF_OperationDevice(feed))); 69 | expect(feed.operationNumOutputs).to.be(1); 70 | // EXPECT_EQ(1, TF_OperationNumOutputs(feed)); 71 | // expect(feed.operationOutputType).to.be(tf.DataType.INT32); 72 | // EXPECT_EQ(TF_INT32, TF_OperationOutputType(TF_Output{feed, 0})); 73 | // EXPECT_EQ(1, TF_OperationOutputListLength(feed, "output", s)); 74 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 75 | // EXPECT_EQ(0, TF_OperationNumInputs(feed)); 76 | // EXPECT_EQ(0, TF_OperationOutputNumConsumers(TF_Output{feed, 0})); 77 | // EXPECT_EQ(0, TF_OperationNumControlInputs(feed)); 78 | // EXPECT_EQ(0, TF_OperationNumControlOutputs(feed)); 79 | 80 | // tensorflow::AttrValue attr_value; 81 | // ASSERT_TRUE(GetAttrValue(feed, "dtype", &attr_value, s)) << TF_Message(s); 82 | // EXPECT_EQ(attr_value.type(), tensorflow::DT_INT32); 83 | 84 | // // Test not found errors in TF_Operation*() query functions. 
85 | // EXPECT_EQ(-1, TF_OperationOutputListLength(feed, "bogus", s)); 86 | // EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)); 87 | 88 | // ASSERT_FALSE(GetAttrValue(feed, "missing", &attr_value, s)); 89 | // EXPECT_EQ(string("Operation has no attr named 'missing'."), 90 | // string(TF_Message(s))); 91 | 92 | // // Make a constant oper with the scalar "3". 93 | // TF_Operation* three = ScalarConst(3, graph, s); 94 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 95 | 96 | // // Add oper. 97 | // TF_Operation* add = Add(feed, three, graph, s); 98 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 99 | 100 | // // Test TF_Operation*() query functions. 101 | // EXPECT_EQ(string("add"), string(TF_OperationName(add))); 102 | // EXPECT_EQ(string("AddN"), string(TF_OperationOpType(add))); 103 | // EXPECT_EQ(string(""), string(TF_OperationDevice(add))); 104 | // EXPECT_EQ(1, TF_OperationNumOutputs(add)); 105 | // EXPECT_EQ(TF_INT32, TF_OperationOutputType(TF_Output{add, 0})); 106 | // EXPECT_EQ(1, TF_OperationOutputListLength(add, "sum", s)); 107 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 108 | // EXPECT_EQ(2, TF_OperationNumInputs(add)); 109 | // EXPECT_EQ(2, TF_OperationInputListLength(add, "inputs", s)); 110 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 111 | // EXPECT_EQ(TF_INT32, TF_OperationInputType(TF_Input{add, 0})); 112 | // EXPECT_EQ(TF_INT32, TF_OperationInputType(TF_Input{add, 1})); 113 | // TF_Output add_in_0 = TF_OperationInput(TF_Input{add, 0}); 114 | // EXPECT_EQ(feed, add_in_0.oper); 115 | // EXPECT_EQ(0, add_in_0.index); 116 | // TF_Output add_in_1 = TF_OperationInput(TF_Input{add, 1}); 117 | // EXPECT_EQ(three, add_in_1.oper); 118 | // EXPECT_EQ(0, add_in_1.index); 119 | // EXPECT_EQ(0, TF_OperationOutputNumConsumers(TF_Output{add, 0})); 120 | // EXPECT_EQ(0, TF_OperationNumControlInputs(add)); 121 | // EXPECT_EQ(0, TF_OperationNumControlOutputs(add)); 122 | 123 | // ASSERT_TRUE(GetAttrValue(add, "T", &attr_value, s)) << 
TF_Message(s); 124 | // EXPECT_EQ(attr_value.type(), tensorflow::DT_INT32); 125 | // ASSERT_TRUE(GetAttrValue(add, "N", &attr_value, s)) << TF_Message(s); 126 | // EXPECT_EQ(attr_value.i(), 2); 127 | 128 | // // Placeholder oper now has a consumer. 129 | // ASSERT_EQ(1, TF_OperationOutputNumConsumers(TF_Output{feed, 0})); 130 | // TF_Input feed_port; 131 | // EXPECT_EQ(1, TF_OperationOutputConsumers(TF_Output{feed, 0}, &feed_port, 1)); 132 | // EXPECT_EQ(add, feed_port.oper); 133 | // EXPECT_EQ(0, feed_port.index); 134 | 135 | // // The scalar const oper also has a consumer. 136 | // ASSERT_EQ(1, TF_OperationOutputNumConsumers(TF_Output{three, 0})); 137 | // TF_Input three_port; 138 | // EXPECT_EQ(1, 139 | // TF_OperationOutputConsumers(TF_Output{three, 0}, &three_port, 1)); 140 | // EXPECT_EQ(add, three_port.oper); 141 | // EXPECT_EQ(1, three_port.index); 142 | 143 | // // Serialize to GraphDef. 144 | // GraphDef graph_def; 145 | // ASSERT_TRUE(GetGraphDef(graph, &graph_def)); 146 | 147 | // // Validate GraphDef is what we expect. 148 | // bool found_placeholder = false; 149 | // bool found_scalar_const = false; 150 | // bool found_add = false; 151 | // for (const auto& n : graph_def.node()) { 152 | // if (IsPlaceholder(n)) { 153 | // EXPECT_FALSE(found_placeholder); 154 | // found_placeholder = true; 155 | // } else if (IsScalarConst(n, 3)) { 156 | // EXPECT_FALSE(found_scalar_const); 157 | // found_scalar_const = true; 158 | // } else if (IsAddN(n, 2)) { 159 | // EXPECT_FALSE(found_add); 160 | // found_add = true; 161 | // } else { 162 | // ADD_FAILURE() << "Unexpected NodeDef: " << ProtoDebugString(n); 163 | // } 164 | // } 165 | // EXPECT_TRUE(found_placeholder); 166 | // EXPECT_TRUE(found_scalar_const); 167 | // EXPECT_TRUE(found_add); 168 | 169 | // // Add another oper to the graph. 170 | // TF_Operation* neg = Neg(add, graph, s); 171 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 172 | 173 | // // Serialize to NodeDef. 
174 | // NodeDef node_def; 175 | // ASSERT_TRUE(GetNodeDef(neg, &node_def)); 176 | 177 | // // Validate NodeDef is what we expect. 178 | // EXPECT_TRUE(IsNeg(node_def, "add")); 179 | 180 | // // Serialize to GraphDef. 181 | // GraphDef graph_def2; 182 | // ASSERT_TRUE(GetGraphDef(graph, &graph_def2)); 183 | 184 | // // Compare with first GraphDef + added NodeDef. 185 | // NodeDef* added_node = graph_def.add_node(); 186 | // *added_node = node_def; 187 | // EXPECT_EQ(ProtoDebugString(graph_def), ProtoDebugString(graph_def2)); 188 | 189 | // // Look up some nodes by name. 190 | // TF_Operation* neg2 = TF_GraphOperationByName(graph, "neg"); 191 | // EXPECT_TRUE(neg == neg2); 192 | // NodeDef node_def2; 193 | // ASSERT_TRUE(GetNodeDef(neg2, &node_def2)); 194 | // EXPECT_EQ(ProtoDebugString(node_def), ProtoDebugString(node_def2)); 195 | 196 | // TF_Operation* feed2 = TF_GraphOperationByName(graph, "feed"); 197 | // EXPECT_TRUE(feed == feed2); 198 | // ASSERT_TRUE(GetNodeDef(feed, &node_def)); 199 | // ASSERT_TRUE(GetNodeDef(feed2, &node_def2)); 200 | // EXPECT_EQ(ProtoDebugString(node_def), ProtoDebugString(node_def2)); 201 | 202 | // // Test iterating through the nodes of a graph. 
203 | // found_placeholder = false; 204 | // found_scalar_const = false; 205 | // found_add = false; 206 | // bool found_neg = false; 207 | // size_t pos = 0; 208 | // TF_Operation* oper; 209 | // while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) { 210 | // if (oper == feed) { 211 | // EXPECT_FALSE(found_placeholder); 212 | // found_placeholder = true; 213 | // } else if (oper == three) { 214 | // EXPECT_FALSE(found_scalar_const); 215 | // found_scalar_const = true; 216 | // } else if (oper == add) { 217 | // EXPECT_FALSE(found_add); 218 | // found_add = true; 219 | // } else if (oper == neg) { 220 | // EXPECT_FALSE(found_neg); 221 | // found_neg = true; 222 | // } else { 223 | // ASSERT_TRUE(GetNodeDef(oper, &node_def)); 224 | // ADD_FAILURE() << "Unexpected Node: " << ProtoDebugString(node_def); 225 | // } 226 | // } 227 | // EXPECT_TRUE(found_placeholder); 228 | // EXPECT_TRUE(found_scalar_const); 229 | // EXPECT_TRUE(found_add); 230 | // EXPECT_TRUE(found_neg); 231 | 232 | // // Clean up 233 | // TF_DeleteGraph(graph); 234 | // TF_DeleteStatus(s); 235 | }); 236 | }); 237 | 238 | // namespace { 239 | 240 | // typedef std::unique_ptr 241 | // unique_tensor_ptr; 242 | 243 | // TEST(CAPI, Status) { 244 | // TF_Status* s = TF_NewStatus(); 245 | // EXPECT_EQ(TF_OK, TF_GetCode(s)); 246 | // EXPECT_EQ(string(), TF_Message(s)); 247 | // TF_SetStatus(s, TF_CANCELLED, "cancel"); 248 | // EXPECT_EQ(TF_CANCELLED, TF_GetCode(s)); 249 | // EXPECT_EQ(string("cancel"), TF_Message(s)); 250 | // TF_DeleteStatus(s); 251 | // } 252 | 253 | // static void Deallocator(void* data, size_t, void* arg) { 254 | // tensorflow::cpu_allocator()->DeallocateRaw(data); 255 | // *reinterpret_cast(arg) = true; 256 | // } 257 | 258 | // TEST(CAPI, Tensor) { 259 | // const int num_bytes = 6 * sizeof(float); 260 | // float* values = 261 | // reinterpret_cast(tensorflow::cpu_allocator()->AllocateRaw( 262 | // EIGEN_MAX_ALIGN_BYTES, num_bytes)); 263 | // int64_t dims[] = {2, 3}; 264 | // 
bool deallocator_called = false; 265 | // TF_Tensor* t = TF_NewTensor(TF_FLOAT, dims, 2, values, num_bytes, 266 | // &Deallocator, &deallocator_called); 267 | // EXPECT_FALSE(deallocator_called); 268 | // EXPECT_EQ(TF_FLOAT, TF_TensorType(t)); 269 | // EXPECT_EQ(2, TF_NumDims(t)); 270 | // EXPECT_EQ(dims[0], TF_Dim(t, 0)); 271 | // EXPECT_EQ(dims[1], TF_Dim(t, 1)); 272 | // EXPECT_EQ(num_bytes, TF_TensorByteSize(t)); 273 | // EXPECT_EQ(static_cast(values), TF_TensorData(t)); 274 | // TF_DeleteTensor(t); 275 | // EXPECT_TRUE(deallocator_called); 276 | // } 277 | 278 | // TEST(CAPI, AllocateTensor) { 279 | // const int num_bytes = 6 * sizeof(float); 280 | // int64_t dims[] = {2, 3}; 281 | // TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, num_bytes); 282 | // EXPECT_EQ(TF_FLOAT, TF_TensorType(t)); 283 | // EXPECT_EQ(2, TF_NumDims(t)); 284 | // EXPECT_EQ(dims[0], TF_Dim(t, 0)); 285 | // EXPECT_EQ(dims[1], TF_Dim(t, 1)); 286 | // EXPECT_EQ(num_bytes, TF_TensorByteSize(t)); 287 | // TF_DeleteTensor(t); 288 | // } 289 | 290 | // TEST(CAPI, LibraryLoadFunctions) { 291 | // // Load the library. 292 | // TF_Status* status = TF_NewStatus(); 293 | // TF_Library* lib = 294 | // TF_LoadLibrary("tensorflow/c/test_op.so", status); 295 | // TF_Code code = TF_GetCode(status); 296 | // string status_msg(TF_Message(status)); 297 | // TF_DeleteStatus(status); 298 | // ASSERT_EQ(TF_OK, code) << status_msg; 299 | 300 | // // Test op list. 
301 | // TF_Buffer op_list_buf = TF_GetOpList(lib); 302 | // tensorflow::OpList op_list; 303 | // EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length)); 304 | // ASSERT_EQ(op_list.op_size(), 1); 305 | // EXPECT_EQ("TestCApi", op_list.op(0).name()); 306 | 307 | // TF_DeleteLibraryHandle(lib); 308 | // } 309 | 310 | // static void TestEncodeDecode(int line, const std::vector& data) { 311 | // const tensorflow::int64 n = data.size(); 312 | // for (const std::vector& dims : 313 | // std::vector>{ 314 | // {n}, {1, n}, {n, 1}, {n / 2, 2}}) { 315 | // // Create C++ Tensor 316 | // Tensor src(tensorflow::DT_STRING, TensorShape(dims)); 317 | // for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) { 318 | // src.flat()(i) = data[i]; 319 | // } 320 | // TF_Tensor* dst = TF_Tensor_EncodeStrings(src); 321 | 322 | // // Convert back to a C++ Tensor and ensure we get expected output. 323 | // TF_Status* status = TF_NewStatus(); 324 | // Tensor output; 325 | // ASSERT_TRUE(TF_Tensor_DecodeStrings(dst, &output, status)) << line; 326 | // ASSERT_EQ(TF_OK, TF_GetCode(status)) << line; 327 | // ASSERT_EQ(src.NumElements(), output.NumElements()) << line; 328 | // for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) { 329 | // ASSERT_EQ(data[i], output.flat()(i)) << line; 330 | // } 331 | 332 | // TF_DeleteStatus(status); 333 | // TF_DeleteTensor(dst); 334 | // } 335 | // } 336 | 337 | // TEST(CAPI, TensorEncodeDecodeStrings) { 338 | // TestEncodeDecode(__LINE__, {}); 339 | // TestEncodeDecode(__LINE__, {"hello"}); 340 | // TestEncodeDecode(__LINE__, 341 | // {"the", "quick", "brown", "fox", "jumped", "over"}); 342 | 343 | // string big(1000, 'a'); 344 | // TestEncodeDecode(__LINE__, {"small", big, "small2"}); 345 | // } 346 | 347 | // TEST(CAPI, SessionOptions) { 348 | // TF_SessionOptions* opt = TF_NewSessionOptions(); 349 | // TF_DeleteSessionOptions(opt); 350 | // } 351 | 352 | // TEST(CAPI, DeprecatedSession) { 353 | // TF_Status* s = TF_NewStatus(); 
354 | // TF_SessionOptions* opt = TF_NewSessionOptions(); 355 | // TF_DeprecatedSession* session = TF_NewDeprecatedSession(opt, s); 356 | // TF_DeleteSessionOptions(opt); 357 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 358 | 359 | // TF_Buffer* run_options = TF_NewBufferFromString("", 0); 360 | // TF_Buffer* run_metadata = TF_NewBuffer(); 361 | // TF_Run(session, run_options, nullptr, nullptr, 0, nullptr, nullptr, 0, 362 | // nullptr, 0, run_metadata, s); 363 | // EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s); 364 | // EXPECT_EQ(std::string("Session was not created with a graph before Run()!"), 365 | // std::string(TF_Message(s))); 366 | // TF_DeleteBuffer(run_metadata); 367 | // TF_DeleteBuffer(run_options); 368 | 369 | // TF_DeleteDeprecatedSession(session, s); 370 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 371 | 372 | // TF_DeleteStatus(s); 373 | // } 374 | 375 | // TEST(CAPI, DataTypeEnum) { 376 | // EXPECT_EQ(TF_FLOAT, static_cast(tensorflow::DT_FLOAT)); 377 | // EXPECT_EQ(TF_DOUBLE, static_cast(tensorflow::DT_DOUBLE)); 378 | // EXPECT_EQ(TF_INT32, static_cast(tensorflow::DT_INT32)); 379 | // EXPECT_EQ(TF_UINT8, static_cast(tensorflow::DT_UINT8)); 380 | // EXPECT_EQ(TF_INT16, static_cast(tensorflow::DT_INT16)); 381 | // EXPECT_EQ(TF_INT8, static_cast(tensorflow::DT_INT8)); 382 | // EXPECT_EQ(TF_STRING, static_cast(tensorflow::DT_STRING)); 383 | // EXPECT_EQ(TF_COMPLEX64, static_cast(tensorflow::DT_COMPLEX64)); 384 | // EXPECT_EQ(TF_COMPLEX, TF_COMPLEX64); 385 | // EXPECT_EQ(TF_INT64, static_cast(tensorflow::DT_INT64)); 386 | // EXPECT_EQ(TF_BOOL, static_cast(tensorflow::DT_BOOL)); 387 | // EXPECT_EQ(TF_QINT8, static_cast(tensorflow::DT_QINT8)); 388 | // EXPECT_EQ(TF_QUINT8, static_cast(tensorflow::DT_QUINT8)); 389 | // EXPECT_EQ(TF_QINT32, static_cast(tensorflow::DT_QINT32)); 390 | // EXPECT_EQ(TF_BFLOAT16, static_cast(tensorflow::DT_BFLOAT16)); 391 | // EXPECT_EQ(TF_QINT16, static_cast(tensorflow::DT_QINT16)); 392 | // 
EXPECT_EQ(TF_QUINT16, static_cast(tensorflow::DT_QUINT16)); 393 | // EXPECT_EQ(TF_UINT16, static_cast(tensorflow::DT_UINT16)); 394 | // EXPECT_EQ(TF_COMPLEX128, static_cast(tensorflow::DT_COMPLEX128)); 395 | // EXPECT_EQ(TF_HALF, static_cast(tensorflow::DT_HALF)); 396 | // EXPECT_EQ(TF_DataTypeSize(TF_DOUBLE), 397 | // tensorflow::DataTypeSize(tensorflow::DT_DOUBLE)); 398 | // EXPECT_EQ(TF_DataTypeSize(TF_STRING), 399 | // tensorflow::DataTypeSize(tensorflow::DT_STRING)); 400 | // // Test with invalid type; should always return 0 as documented 401 | // EXPECT_EQ(TF_DataTypeSize(static_cast(0)), 0); 402 | // } 403 | 404 | // TEST(CAPI, StatusEnum) { 405 | // EXPECT_EQ(TF_OK, static_cast(tensorflow::error::OK)); 406 | // EXPECT_EQ(TF_CANCELLED, static_cast(tensorflow::error::CANCELLED)); 407 | // EXPECT_EQ(TF_UNKNOWN, static_cast(tensorflow::error::UNKNOWN)); 408 | // EXPECT_EQ(TF_INVALID_ARGUMENT, 409 | // static_cast(tensorflow::error::INVALID_ARGUMENT)); 410 | // EXPECT_EQ(TF_DEADLINE_EXCEEDED, 411 | // static_cast(tensorflow::error::DEADLINE_EXCEEDED)); 412 | // EXPECT_EQ(TF_NOT_FOUND, static_cast(tensorflow::error::NOT_FOUND)); 413 | // EXPECT_EQ(TF_ALREADY_EXISTS, 414 | // static_cast(tensorflow::error::ALREADY_EXISTS)); 415 | // EXPECT_EQ(TF_PERMISSION_DENIED, 416 | // static_cast(tensorflow::error::PERMISSION_DENIED)); 417 | // EXPECT_EQ(TF_UNAUTHENTICATED, 418 | // static_cast(tensorflow::error::UNAUTHENTICATED)); 419 | // EXPECT_EQ(TF_RESOURCE_EXHAUSTED, 420 | // static_cast(tensorflow::error::RESOURCE_EXHAUSTED)); 421 | // EXPECT_EQ(TF_FAILED_PRECONDITION, 422 | // static_cast(tensorflow::error::FAILED_PRECONDITION)); 423 | // EXPECT_EQ(TF_ABORTED, static_cast(tensorflow::error::ABORTED)); 424 | // EXPECT_EQ(TF_OUT_OF_RANGE, 425 | // static_cast(tensorflow::error::OUT_OF_RANGE)); 426 | // EXPECT_EQ(TF_UNIMPLEMENTED, 427 | // static_cast(tensorflow::error::UNIMPLEMENTED)); 428 | // EXPECT_EQ(TF_INTERNAL, static_cast(tensorflow::error::INTERNAL)); 429 | // 
EXPECT_EQ(TF_UNAVAILABLE, 430 | // static_cast(tensorflow::error::UNAVAILABLE)); 431 | // EXPECT_EQ(TF_DATA_LOSS, static_cast(tensorflow::error::DATA_LOSS)); 432 | // } 433 | 434 | // TEST(CAPI, GetAllOpList) { 435 | // TF_Buffer* buf = TF_GetAllOpList(); 436 | // tensorflow::OpList op_list; 437 | // EXPECT_TRUE(op_list.ParseFromArray(buf->data, buf->length)); 438 | // EXPECT_GT(op_list.op_size(), 0); 439 | // TF_DeleteBuffer(buf); 440 | // } 441 | 442 | // static void Int32Deallocator(void* data, size_t, void* arg) { 443 | // delete[] static_cast(data); 444 | // } 445 | 446 | // static TF_Tensor* Int32Tensor(int32 v) { 447 | // const int num_bytes = sizeof(int32); 448 | // int32* values = new int32[1]; 449 | // values[0] = v; 450 | // return TF_NewTensor(TF_INT32, nullptr, 0, values, num_bytes, 451 | // &Int32Deallocator, nullptr); 452 | // } 453 | 454 | // TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, 455 | // const char* name = "feed") { 456 | // TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name); 457 | // TF_SetAttrType(desc, "dtype", TF_INT32); 458 | // return TF_FinishOperation(desc, s); 459 | // } 460 | 461 | // TF_Operation* ScalarConst(int32 v, TF_Graph* graph, TF_Status* s, 462 | // const char* name = "scalar") { 463 | // unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor); 464 | // TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name); 465 | // TF_SetAttrTensor(desc, "value", tensor.get(), s); 466 | // if (TF_GetCode(s) != TF_OK) return nullptr; 467 | // TF_SetAttrType(desc, "dtype", TF_INT32); 468 | // return TF_FinishOperation(desc, s); 469 | // } 470 | 471 | // TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, 472 | // TF_Status* s, const char* name = "add") { 473 | // TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); 474 | // TF_Output add_inputs[2] = {{l, 0}, {r, 0}}; 475 | // TF_AddInputList(desc, add_inputs, 2); 476 | // return TF_FinishOperation(desc, 
s); 477 | // } 478 | 479 | // TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, 480 | // const char* name = "add") { 481 | // TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); 482 | // TF_Output inputs[2] = {l, r}; 483 | // TF_AddInputList(desc, inputs, 2); 484 | // return TF_FinishOperation(desc, s); 485 | // } 486 | 487 | // TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) { 488 | // TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", "neg"); 489 | // TF_Output neg_input = {n, 0}; 490 | // TF_AddInput(desc, neg_input); 491 | // return TF_FinishOperation(desc, s); 492 | // } 493 | 494 | // TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, 495 | // TF_Status* s) { 496 | // TF_OperationDescription* desc = TF_NewOperation(graph, "Less", "less_than"); 497 | // TF_AddInput(desc, l); 498 | // TF_AddInput(desc, r); 499 | // return TF_FinishOperation(desc, s); 500 | // } 501 | 502 | // bool IsPlaceholder(const NodeDef& node_def) { 503 | // if (node_def.op() != "Placeholder" || node_def.name() != "feed") { 504 | // return false; 505 | // } 506 | // bool found_dtype = false; 507 | // bool found_shape = false; 508 | // for (const auto& attr : node_def.attr()) { 509 | // if (attr.first == "dtype") { 510 | // if (attr.second.type() == tensorflow::DT_INT32) { 511 | // found_dtype = true; 512 | // } else { 513 | // return false; 514 | // } 515 | // } else if (attr.first == "shape") { 516 | // found_shape = true; 517 | // } 518 | // } 519 | // return found_dtype && found_shape; 520 | // } 521 | 522 | // bool IsScalarConst(const NodeDef& node_def, int v) { 523 | // if (node_def.op() != "Const" || node_def.name() != "scalar") { 524 | // return false; 525 | // } 526 | // bool found_dtype = false; 527 | // bool found_value = false; 528 | // for (const auto& attr : node_def.attr()) { 529 | // if (attr.first == "dtype") { 530 | // if (attr.second.type() == tensorflow::DT_INT32) { 531 | // 
found_dtype = true; 532 | // } else { 533 | // return false; 534 | // } 535 | // } else if (attr.first == "value") { 536 | // if (attr.second.has_tensor() && 537 | // attr.second.tensor().int_val_size() == 1 && 538 | // attr.second.tensor().int_val(0) == v) { 539 | // found_value = true; 540 | // } else { 541 | // return false; 542 | // } 543 | // } 544 | // } 545 | // return found_dtype && found_value; 546 | // } 547 | 548 | // bool IsAddN(const NodeDef& node_def, int n) { 549 | // if (node_def.op() != "AddN" || node_def.name() != "add" || 550 | // node_def.input_size() != n) { 551 | // return false; 552 | // } 553 | // bool found_t = false; 554 | // bool found_n = false; 555 | // for (const auto& attr : node_def.attr()) { 556 | // if (attr.first == "T") { 557 | // if (attr.second.type() == tensorflow::DT_INT32) { 558 | // found_t = true; 559 | // } else { 560 | // return false; 561 | // } 562 | // } else if (attr.first == "N") { 563 | // if (attr.second.i() == n) { 564 | // found_n = true; 565 | // } else { 566 | // return false; 567 | // } 568 | // } 569 | // } 570 | // return found_t && found_n; 571 | // } 572 | 573 | // bool IsNeg(const NodeDef& node_def, const string& input) { 574 | // return node_def.op() == "Neg" && node_def.name() == "neg" && 575 | // node_def.input_size() == 1 && node_def.input(0) == input; 576 | // } 577 | 578 | // bool GetGraphDef(TF_Graph* graph, GraphDef* graph_def) { 579 | // TF_Status* s = TF_NewStatus(); 580 | // TF_Buffer* buffer = TF_NewBuffer(); 581 | // TF_GraphToGraphDef(graph, buffer, s); 582 | // bool ret = TF_GetCode(s) == TF_OK; 583 | // EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 584 | // if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length); 585 | // TF_DeleteBuffer(buffer); 586 | // TF_DeleteStatus(s); 587 | // return ret; 588 | // } 589 | 590 | // bool GetNodeDef(TF_Operation* oper, NodeDef* node_def) { 591 | // TF_Status* s = TF_NewStatus(); 592 | // TF_Buffer* buffer = TF_NewBuffer(); 593 | // 
TF_OperationToNodeDef(oper, buffer, s); 594 | // bool ret = TF_GetCode(s) == TF_OK; 595 | // EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 596 | // if (ret) ret = node_def->ParseFromArray(buffer->data, buffer->length); 597 | // TF_DeleteBuffer(buffer); 598 | // TF_DeleteStatus(s); 599 | // return ret; 600 | // } 601 | 602 | // bool GetAttrValue(TF_Operation* oper, const char* attr_name, 603 | // tensorflow::AttrValue* attr_value, TF_Status* s) { 604 | // TF_Buffer* buffer = TF_NewBuffer(); 605 | // TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); 606 | // bool ret = TF_GetCode(s) == TF_OK; 607 | // if (ret) ret = attr_value->ParseFromArray(buffer->data, buffer->length); 608 | // TF_DeleteBuffer(buffer); 609 | // return ret; 610 | // } 611 | 612 | // TEST(CAPI, SetShape) { 613 | // TF_Status* s = TF_NewStatus(); 614 | // TF_Graph* graph = TF_NewGraph(); 615 | 616 | // TF_Operation* feed = Placeholder(graph, s); 617 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 618 | // TF_Output feed_out_0 = TF_Output{feed, 0}; 619 | // int num_dims; 620 | 621 | // // Fetch the shape, it should be completely unknown. 622 | // num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s); 623 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 624 | // EXPECT_EQ(-1, num_dims); 625 | 626 | // // Set the shape to be 2 x Unknown 627 | // int64_t dims[] = {2, -1}; 628 | // TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); 629 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 630 | 631 | // // Fetch the shape and validate it is 2 by -1. 632 | // num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s); 633 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 634 | // EXPECT_EQ(2, num_dims); 635 | 636 | // // Resize the dimension vector appropriately. 
637 | // int64_t returned_dims[2]; 638 | // TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); 639 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 640 | // EXPECT_EQ(dims[0], returned_dims[0]); 641 | // EXPECT_EQ(dims[1], returned_dims[1]); 642 | 643 | // // Set to a new valid shape: [2, 3] 644 | // dims[1] = 3; 645 | // TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); 646 | // EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 647 | 648 | // // Fetch and see that the new value is returned. 649 | // TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); 650 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 651 | // EXPECT_EQ(dims[0], returned_dims[0]); 652 | // EXPECT_EQ(dims[1], returned_dims[1]); 653 | 654 | // // Try to set 'unknown' on the shape and see that 655 | // // it doesn't change. 656 | // dims[0] = -1; 657 | // dims[1] = -1; 658 | // TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); 659 | // EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 660 | // // Fetch and see that the previous shape [2, 3] is still returned. 661 | // TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); 662 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 663 | // EXPECT_EQ(2, num_dims); 664 | // EXPECT_EQ(2, returned_dims[0]); 665 | // EXPECT_EQ(3, returned_dims[1]); 666 | 667 | // // Try to fetch a shape with the wrong num_dims 668 | // TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); 669 | // EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s); 670 | 671 | // // Try to set an invalid shape (cannot change 2x3 to a 2x5). 672 | // dims[1] = 5; 673 | // TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); 674 | // EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s); 675 | 676 | // // Test for a scalar.
677 | // TF_Operation* three = ScalarConst(3, graph, s); 678 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 679 | // TF_Output three_out_0 = TF_Output{three, 0}; 680 | 681 | // num_dims = TF_GraphGetTensorNumDims(graph, three_out_0, s); 682 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 683 | // EXPECT_EQ(0, num_dims); 684 | // TF_GraphGetTensorShape(graph, three_out_0, returned_dims, num_dims, s); 685 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 686 | 687 | // // Clean up 688 | // TF_DeleteGraph(graph); 689 | // TF_DeleteStatus(s); 690 | // } 691 | 692 | // TEST(CAPI, Graph) { 693 | // TF_Status* s = TF_NewStatus(); 694 | // TF_Graph* graph = TF_NewGraph(); 695 | 696 | // // Make a placeholder operation. 697 | // TF_Operation* feed = Placeholder(graph, s); 698 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 699 | 700 | // // Test TF_Operation*() query functions. 701 | // EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); 702 | // EXPECT_EQ(string("Placeholder"), string(TF_OperationOpType(feed))); 703 | // EXPECT_EQ(string(""), string(TF_OperationDevice(feed))); 704 | // EXPECT_EQ(1, TF_OperationNumOutputs(feed)); 705 | // EXPECT_EQ(TF_INT32, TF_OperationOutputType(TF_Output{feed, 0})); 706 | // EXPECT_EQ(1, TF_OperationOutputListLength(feed, "output", s)); 707 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 708 | // EXPECT_EQ(0, TF_OperationNumInputs(feed)); 709 | // EXPECT_EQ(0, TF_OperationOutputNumConsumers(TF_Output{feed, 0})); 710 | // EXPECT_EQ(0, TF_OperationNumControlInputs(feed)); 711 | // EXPECT_EQ(0, TF_OperationNumControlOutputs(feed)); 712 | 713 | // tensorflow::AttrValue attr_value; 714 | // ASSERT_TRUE(GetAttrValue(feed, "dtype", &attr_value, s)) << TF_Message(s); 715 | // EXPECT_EQ(attr_value.type(), tensorflow::DT_INT32); 716 | 717 | // // Test not found errors in TF_Operation*() query functions. 
718 | // EXPECT_EQ(-1, TF_OperationOutputListLength(feed, "bogus", s)); 719 | // EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)); 720 | 721 | // ASSERT_FALSE(GetAttrValue(feed, "missing", &attr_value, s)); 722 | // EXPECT_EQ(string("Operation has no attr named 'missing'."), 723 | // string(TF_Message(s))); 724 | 725 | // // Make a constant oper with the scalar "3". 726 | // TF_Operation* three = ScalarConst(3, graph, s); 727 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 728 | 729 | // // Add oper. 730 | // TF_Operation* add = Add(feed, three, graph, s); 731 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 732 | 733 | // // Test TF_Operation*() query functions. 734 | // EXPECT_EQ(string("add"), string(TF_OperationName(add))); 735 | // EXPECT_EQ(string("AddN"), string(TF_OperationOpType(add))); 736 | // EXPECT_EQ(string(""), string(TF_OperationDevice(add))); 737 | // EXPECT_EQ(1, TF_OperationNumOutputs(add)); 738 | // EXPECT_EQ(TF_INT32, TF_OperationOutputType(TF_Output{add, 0})); 739 | // EXPECT_EQ(1, TF_OperationOutputListLength(add, "sum", s)); 740 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 741 | // EXPECT_EQ(2, TF_OperationNumInputs(add)); 742 | // EXPECT_EQ(2, TF_OperationInputListLength(add, "inputs", s)); 743 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 744 | // EXPECT_EQ(TF_INT32, TF_OperationInputType(TF_Input{add, 0})); 745 | // EXPECT_EQ(TF_INT32, TF_OperationInputType(TF_Input{add, 1})); 746 | // TF_Output add_in_0 = TF_OperationInput(TF_Input{add, 0}); 747 | // EXPECT_EQ(feed, add_in_0.oper); 748 | // EXPECT_EQ(0, add_in_0.index); 749 | // TF_Output add_in_1 = TF_OperationInput(TF_Input{add, 1}); 750 | // EXPECT_EQ(three, add_in_1.oper); 751 | // EXPECT_EQ(0, add_in_1.index); 752 | // EXPECT_EQ(0, TF_OperationOutputNumConsumers(TF_Output{add, 0})); 753 | // EXPECT_EQ(0, TF_OperationNumControlInputs(add)); 754 | // EXPECT_EQ(0, TF_OperationNumControlOutputs(add)); 755 | 756 | // ASSERT_TRUE(GetAttrValue(add, "T", 
&attr_value, s)) << TF_Message(s); 757 | // EXPECT_EQ(attr_value.type(), tensorflow::DT_INT32); 758 | // ASSERT_TRUE(GetAttrValue(add, "N", &attr_value, s)) << TF_Message(s); 759 | // EXPECT_EQ(attr_value.i(), 2); 760 | 761 | // // Placeholder oper now has a consumer. 762 | // ASSERT_EQ(1, TF_OperationOutputNumConsumers(TF_Output{feed, 0})); 763 | // TF_Input feed_port; 764 | // EXPECT_EQ(1, TF_OperationOutputConsumers(TF_Output{feed, 0}, &feed_port, 1)); 765 | // EXPECT_EQ(add, feed_port.oper); 766 | // EXPECT_EQ(0, feed_port.index); 767 | 768 | // // The scalar const oper also has a consumer. 769 | // ASSERT_EQ(1, TF_OperationOutputNumConsumers(TF_Output{three, 0})); 770 | // TF_Input three_port; 771 | // EXPECT_EQ(1, 772 | // TF_OperationOutputConsumers(TF_Output{three, 0}, &three_port, 1)); 773 | // EXPECT_EQ(add, three_port.oper); 774 | // EXPECT_EQ(1, three_port.index); 775 | 776 | // // Serialize to GraphDef. 777 | // GraphDef graph_def; 778 | // ASSERT_TRUE(GetGraphDef(graph, &graph_def)); 779 | 780 | // // Validate GraphDef is what we expect. 781 | // bool found_placeholder = false; 782 | // bool found_scalar_const = false; 783 | // bool found_add = false; 784 | // for (const auto& n : graph_def.node()) { 785 | // if (IsPlaceholder(n)) { 786 | // EXPECT_FALSE(found_placeholder); 787 | // found_placeholder = true; 788 | // } else if (IsScalarConst(n, 3)) { 789 | // EXPECT_FALSE(found_scalar_const); 790 | // found_scalar_const = true; 791 | // } else if (IsAddN(n, 2)) { 792 | // EXPECT_FALSE(found_add); 793 | // found_add = true; 794 | // } else { 795 | // ADD_FAILURE() << "Unexpected NodeDef: " << ProtoDebugString(n); 796 | // } 797 | // } 798 | // EXPECT_TRUE(found_placeholder); 799 | // EXPECT_TRUE(found_scalar_const); 800 | // EXPECT_TRUE(found_add); 801 | 802 | // // Add another oper to the graph. 803 | // TF_Operation* neg = Neg(add, graph, s); 804 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 805 | 806 | // // Serialize to NodeDef. 
807 | // NodeDef node_def; 808 | // ASSERT_TRUE(GetNodeDef(neg, &node_def)); 809 | 810 | // // Validate NodeDef is what we expect. 811 | // EXPECT_TRUE(IsNeg(node_def, "add")); 812 | 813 | // // Serialize to GraphDef. 814 | // GraphDef graph_def2; 815 | // ASSERT_TRUE(GetGraphDef(graph, &graph_def2)); 816 | 817 | // // Compare with first GraphDef + added NodeDef. 818 | // NodeDef* added_node = graph_def.add_node(); 819 | // *added_node = node_def; 820 | // EXPECT_EQ(ProtoDebugString(graph_def), ProtoDebugString(graph_def2)); 821 | 822 | // // Look up some nodes by name. 823 | // TF_Operation* neg2 = TF_GraphOperationByName(graph, "neg"); 824 | // EXPECT_TRUE(neg == neg2); 825 | // NodeDef node_def2; 826 | // ASSERT_TRUE(GetNodeDef(neg2, &node_def2)); 827 | // EXPECT_EQ(ProtoDebugString(node_def), ProtoDebugString(node_def2)); 828 | 829 | // TF_Operation* feed2 = TF_GraphOperationByName(graph, "feed"); 830 | // EXPECT_TRUE(feed == feed2); 831 | // ASSERT_TRUE(GetNodeDef(feed, &node_def)); 832 | // ASSERT_TRUE(GetNodeDef(feed2, &node_def2)); 833 | // EXPECT_EQ(ProtoDebugString(node_def), ProtoDebugString(node_def2)); 834 | 835 | // // Test iterating through the nodes of a graph. 
836 | // found_placeholder = false; 837 | // found_scalar_const = false; 838 | // found_add = false; 839 | // bool found_neg = false; 840 | // size_t pos = 0; 841 | // TF_Operation* oper; 842 | // while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) { 843 | // if (oper == feed) { 844 | // EXPECT_FALSE(found_placeholder); 845 | // found_placeholder = true; 846 | // } else if (oper == three) { 847 | // EXPECT_FALSE(found_scalar_const); 848 | // found_scalar_const = true; 849 | // } else if (oper == add) { 850 | // EXPECT_FALSE(found_add); 851 | // found_add = true; 852 | // } else if (oper == neg) { 853 | // EXPECT_FALSE(found_neg); 854 | // found_neg = true; 855 | // } else { 856 | // ASSERT_TRUE(GetNodeDef(oper, &node_def)); 857 | // ADD_FAILURE() << "Unexpected Node: " << ProtoDebugString(node_def); 858 | // } 859 | // } 860 | // EXPECT_TRUE(found_placeholder); 861 | // EXPECT_TRUE(found_scalar_const); 862 | // EXPECT_TRUE(found_add); 863 | // EXPECT_TRUE(found_neg); 864 | 865 | // // Clean up 866 | // TF_DeleteGraph(graph); 867 | // TF_DeleteStatus(s); 868 | // } 869 | 870 | // /* 871 | // TODO(skyewm): this test currently DCHECKs, change to bad status 872 | 873 | // TEST(CAPI, InputFromDifferentGraphError) { 874 | // TF_Status* s = TF_NewStatus(); 875 | // TF_Graph* g1 = TF_NewGraph(); 876 | // TF_Graph* g2 = TF_NewGraph(); 877 | 878 | // TF_Operation* feed = Placeholder(g1, s); 879 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 880 | 881 | // // Attempt to create node in g2 with input from g1 882 | // Neg(feed, g2, s); 883 | // EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)); 884 | // EXPECT_STREQ("foo", TF_Message(s)); 885 | 886 | // TF_DeleteGraph(g1); 887 | // TF_DeleteGraph(g2); 888 | // TF_DeleteStatus(s); 889 | // } 890 | // */ 891 | 892 | // TEST(CAPI, ImportGraphDef) { 893 | // TF_Status* s = TF_NewStatus(); 894 | // TF_Graph* graph = TF_NewGraph(); 895 | 896 | // // Create a simple graph with three nodes: "feed" (placeholder), "scalar" (const 3) and "neg" 897 | // Placeholder(graph, s);
898 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 899 | // ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr); 900 | // TF_Operation* oper = ScalarConst(3, graph, s); 901 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 902 | // ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr); 903 | // Neg(oper, graph, s); 904 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 905 | // ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr); 906 | 907 | // // Export to a GraphDef 908 | // TF_Buffer* graph_def = TF_NewBuffer(); 909 | // TF_GraphToGraphDef(graph, graph_def, s); 910 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 911 | 912 | // // Import it again, with a prefix, in a fresh graph. 913 | // TF_DeleteGraph(graph); 914 | // graph = TF_NewGraph(); 915 | // TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions(); 916 | // TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); 917 | // TF_GraphImportGraphDef(graph, graph_def, opts, s); 918 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 919 | 920 | // TF_Operation* scalar = TF_GraphOperationByName(graph, "imported/scalar"); 921 | // TF_Operation* feed = TF_GraphOperationByName(graph, "imported/feed"); 922 | // TF_Operation* neg = TF_GraphOperationByName(graph, "imported/neg"); 923 | // ASSERT_TRUE(scalar != nullptr); 924 | // ASSERT_TRUE(feed != nullptr); 925 | // ASSERT_TRUE(neg != nullptr); 926 | 927 | // // Import it again, with an input mapping and return outputs, into the same 928 | // // graph. 
929 | // TF_DeleteImportGraphDefOptions(opts); 930 | // opts = TF_NewImportGraphDefOptions(); 931 | // TF_ImportGraphDefOptionsSetPrefix(opts, "imported2"); 932 | // TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, {scalar, 0}); 933 | // TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); 934 | // TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0); 935 | // EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts)); 936 | // TF_Output return_outputs[2]; 937 | // TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, opts, 938 | // return_outputs, 2, s); 939 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 940 | 941 | // TF_Operation* scalar2 = TF_GraphOperationByName(graph, "imported2/scalar"); 942 | // TF_Operation* feed2 = TF_GraphOperationByName(graph, "imported2/feed"); 943 | // TF_Operation* neg2 = TF_GraphOperationByName(graph, "imported2/neg"); 944 | // ASSERT_TRUE(scalar2 != nullptr); 945 | // ASSERT_TRUE(feed2 != nullptr); 946 | // ASSERT_TRUE(neg2 != nullptr); 947 | 948 | // // Check input mapping 949 | // TF_Output neg_input = TF_OperationInput({neg, 0}); 950 | // EXPECT_EQ(scalar, neg_input.oper); 951 | // EXPECT_EQ(0, neg_input.index); 952 | 953 | // // Check return outputs 954 | // EXPECT_EQ(feed2, return_outputs[0].oper); 955 | // EXPECT_EQ(0, return_outputs[0].index); 956 | // EXPECT_EQ(scalar, return_outputs[1].oper); // remapped 957 | // EXPECT_EQ(0, return_outputs[1].index); 958 | 959 | // // Import again, with control dependencies, into the same graph. 
960 | // TF_DeleteImportGraphDefOptions(opts); 961 | // opts = TF_NewImportGraphDefOptions(); 962 | // TF_ImportGraphDefOptionsSetPrefix(opts, "imported3"); 963 | // TF_ImportGraphDefOptionsAddControlDependency(opts, feed); 964 | // TF_ImportGraphDefOptionsAddControlDependency(opts, feed2); 965 | // TF_GraphImportGraphDef(graph, graph_def, opts, s); 966 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 967 | 968 | // TF_Operation* scalar3 = TF_GraphOperationByName(graph, "imported3/scalar"); 969 | // TF_Operation* feed3 = TF_GraphOperationByName(graph, "imported3/feed"); 970 | // TF_Operation* neg3 = TF_GraphOperationByName(graph, "imported3/neg"); 971 | // ASSERT_TRUE(scalar3 != nullptr); 972 | // ASSERT_TRUE(feed3 != nullptr); 973 | // ASSERT_TRUE(neg3 != nullptr); 974 | 975 | // // Check that newly-imported scalar and feed have control deps (neg3 will 976 | // // inherit them from input) 977 | // TF_Operation* control_inputs[100]; 978 | // int num_control_inputs = TF_OperationGetControlInputs( 979 | // scalar3, control_inputs, TF_OperationNumControlInputs(scalar3)); 980 | // ASSERT_EQ(2, num_control_inputs); 981 | // EXPECT_EQ(feed, control_inputs[0]); 982 | // EXPECT_EQ(feed2, control_inputs[1]); 983 | 984 | // num_control_inputs = TF_OperationGetControlInputs( 985 | // feed3, control_inputs, TF_OperationNumControlInputs(feed3)); 986 | // ASSERT_EQ(2, num_control_inputs); 987 | // EXPECT_EQ(feed, control_inputs[0]); 988 | // EXPECT_EQ(feed2, control_inputs[1]); 989 | 990 | // TF_DeleteImportGraphDefOptions(opts); 991 | // TF_DeleteBuffer(graph_def); 992 | 993 | // // Can add nodes to the imported graph without trouble. 
994 | // Add(feed, scalar, graph, s); 995 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 996 | 997 | // TF_DeleteGraph(graph); 998 | // TF_DeleteStatus(s); 999 | // } 1000 | 1001 | // class CSession { 1002 | // public: 1003 | // CSession(TF_Graph* graph, TF_Status* s) { 1004 | // TF_SessionOptions* opts = TF_NewSessionOptions(); 1005 | // session_ = TF_NewSession(graph, opts, s); 1006 | // TF_DeleteSessionOptions(opts); 1007 | // } 1008 | 1009 | // CSession(TF_Session* session) { session_ = session; } 1010 | 1011 | // ~CSession() { 1012 | // TF_Status* s = TF_NewStatus(); 1013 | // CloseAndDelete(s); 1014 | // EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1015 | // TF_DeleteStatus(s); 1016 | // } 1017 | 1018 | // void SetInputs(std::vector> inputs) { 1019 | // DeleteInputValues(); 1020 | // inputs_.clear(); 1021 | // for (const auto& p : inputs) { 1022 | // inputs_.emplace_back(TF_Output{p.first, 0}); 1023 | // input_values_.emplace_back(p.second); 1024 | // } 1025 | // } 1026 | 1027 | // void SetOutputs(std::initializer_list outputs) { 1028 | // ResetOutputValues(); 1029 | // outputs_.clear(); 1030 | // for (TF_Operation* o : outputs) { 1031 | // outputs_.emplace_back(TF_Output{o, 0}); 1032 | // } 1033 | // } 1034 | 1035 | // void SetOutputs(const std::vector& outputs) { 1036 | // ResetOutputValues(); 1037 | // outputs_ = outputs; 1038 | // } 1039 | 1040 | // void SetTargets(std::initializer_list targets) { 1041 | // targets_.clear(); 1042 | // for (TF_Operation* t : targets) { 1043 | // targets_.emplace_back(t); 1044 | // } 1045 | // } 1046 | 1047 | // void Run(TF_Status* s) { 1048 | // if (inputs_.size() != input_values_.size()) { 1049 | // ADD_FAILURE() << "Call SetInputs() before Run()"; 1050 | // return; 1051 | // } 1052 | // ResetOutputValues(); 1053 | // output_values_.resize(outputs_.size(), nullptr); 1054 | 1055 | // const TF_Output* inputs_ptr = inputs_.empty() ? 
nullptr : &inputs_[0]; 1056 | // TF_Tensor* const* input_values_ptr = 1057 | // input_values_.empty() ? nullptr : &input_values_[0]; 1058 | 1059 | // const TF_Output* outputs_ptr = outputs_.empty() ? nullptr : &outputs_[0]; 1060 | // TF_Tensor** output_values_ptr = 1061 | // output_values_.empty() ? nullptr : &output_values_[0]; 1062 | 1063 | // TF_Operation* const* targets_ptr = 1064 | // targets_.empty() ? nullptr : &targets_[0]; 1065 | 1066 | // TF_SessionRun(session_, nullptr, inputs_ptr, input_values_ptr, 1067 | // inputs_.size(), outputs_ptr, output_values_ptr, 1068 | // outputs_.size(), targets_ptr, targets_.size(), nullptr, s); 1069 | 1070 | // DeleteInputValues(); 1071 | // } 1072 | 1073 | // void CloseAndDelete(TF_Status* s) { 1074 | // DeleteInputValues(); 1075 | // ResetOutputValues(); 1076 | // if (session_ != nullptr) { 1077 | // TF_CloseSession(session_, s); 1078 | // EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1079 | // TF_DeleteSession(session_, s); 1080 | // session_ = nullptr; 1081 | // } 1082 | // } 1083 | 1084 | // TF_Tensor* output_tensor(int i) { return output_values_[i]; } 1085 | 1086 | // private: 1087 | // void DeleteInputValues() { 1088 | // for (int i = 0; i < input_values_.size(); ++i) { 1089 | // TF_DeleteTensor(input_values_[i]); 1090 | // } 1091 | // input_values_.clear(); 1092 | // } 1093 | 1094 | // void ResetOutputValues() { 1095 | // for (int i = 0; i < output_values_.size(); ++i) { 1096 | // if (output_values_[i] != nullptr) TF_DeleteTensor(output_values_[i]); 1097 | // } 1098 | // output_values_.clear(); 1099 | // } 1100 | 1101 | // TF_Session* session_; 1102 | // std::vector inputs_; 1103 | // std::vector input_values_; 1104 | // std::vector outputs_; 1105 | // std::vector output_values_; 1106 | // std::vector targets_; 1107 | // }; 1108 | 1109 | // TEST(CAPI, Session) { 1110 | // TF_Status* s = TF_NewStatus(); 1111 | // TF_Graph* graph = TF_NewGraph(); 1112 | 1113 | // // Make a placeholder operation. 
1114 | // TF_Operation* feed = Placeholder(graph, s); 1115 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1116 | 1117 | // // Make a constant operation with the scalar "2". 1118 | // TF_Operation* two = ScalarConst(2, graph, s); 1119 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1120 | 1121 | // // Add operation. 1122 | // TF_Operation* add = Add(feed, two, graph, s); 1123 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1124 | 1125 | // // Create a session for this graph. 1126 | // CSession csession(graph, s); 1127 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1128 | 1129 | // // Run the graph. 1130 | // csession.SetInputs({{feed, Int32Tensor(3)}}); 1131 | // csession.SetOutputs({add}); 1132 | // csession.Run(s); 1133 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1134 | // TF_Tensor* out = csession.output_tensor(0); 1135 | // ASSERT_TRUE(out != nullptr); 1136 | // EXPECT_EQ(TF_INT32, TF_TensorType(out)); 1137 | // EXPECT_EQ(0, TF_NumDims(out)); // scalar 1138 | // ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out)); 1139 | // int32* output_contents = static_cast(TF_TensorData(out)); 1140 | // EXPECT_EQ(3 + 2, *output_contents); 1141 | 1142 | // // Add another operation to the graph. 1143 | // TF_Operation* neg = Neg(add, graph, s); 1144 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1145 | 1146 | // // Run up to the new operation. 
1147 | // csession.SetInputs({{feed, Int32Tensor(7)}}); 1148 | // csession.SetOutputs({neg}); 1149 | // csession.Run(s); 1150 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1151 | // out = csession.output_tensor(0); 1152 | // ASSERT_TRUE(out != nullptr); 1153 | // EXPECT_EQ(TF_INT32, TF_TensorType(out)); 1154 | // EXPECT_EQ(0, TF_NumDims(out)); // scalar 1155 | // ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out)); 1156 | // output_contents = static_cast(TF_TensorData(out)); 1157 | // EXPECT_EQ(-(7 + 2), *output_contents); 1158 | 1159 | // // Clean up 1160 | // csession.CloseAndDelete(s); 1161 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1162 | // TF_DeleteGraph(graph); 1163 | // TF_DeleteStatus(s); 1164 | // } 1165 | 1166 | // TEST(CAPI, SessionPRun) { 1167 | // TF_Status* s = TF_NewStatus(); 1168 | // TF_Graph* graph = TF_NewGraph(); 1169 | 1170 | // // Construct the graph: A + 2 + B 1171 | // TF_Operation* a = Placeholder(graph, s, "A"); 1172 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1173 | 1174 | // TF_Operation* b = Placeholder(graph, s, "B"); 1175 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1176 | 1177 | // TF_Operation* two = ScalarConst(2, graph, s); 1178 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1179 | 1180 | // TF_Operation* plus2 = Add(a, two, graph, s, "plus2"); 1181 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1182 | 1183 | // TF_Operation* plusB = Add(plus2, b, graph, s, "plusB"); 1184 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1185 | 1186 | // // Setup a session and a partial run handle. The partial run will allow 1187 | // // computation of A + 2 + B in two phases (calls to TF_SessionPRun): 1188 | // // 1. Feed A and get (A+2) 1189 | // // 2. 
Feed B and get (A+2)+B 1190 | // TF_SessionOptions* opts = TF_NewSessionOptions(); 1191 | // TF_Session* sess = TF_NewSession(graph, opts, s); 1192 | // TF_DeleteSessionOptions(opts); 1193 | 1194 | // TF_Output feeds[] = {TF_Output{a, 0}, TF_Output{b, 0}}; 1195 | // TF_Output fetches[] = {TF_Output{plus2, 0}, TF_Output{plusB, 0}}; 1196 | 1197 | // const char* handle = nullptr; 1198 | // TF_SessionPRunSetup(sess, feeds, TF_ARRAYSIZE(feeds), fetches, 1199 | // TF_ARRAYSIZE(fetches), nullptr, 0, &handle, s); 1200 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1201 | 1202 | // // Feed A and fetch A + 2. 1203 | // TF_Output feeds1[] = {TF_Output{a, 0}}; 1204 | // TF_Output fetches1[] = {TF_Output{plus2, 0}}; 1205 | // TF_Tensor* feedValues1[] = {Int32Tensor(1)}; 1206 | // TF_Tensor* fetchValues1[1]; 1207 | // TF_SessionPRun(sess, handle, feeds1, feedValues1, 1, fetches1, fetchValues1, 1208 | // 1, nullptr, 0, s); 1209 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1210 | // EXPECT_EQ(3, *(static_cast(TF_TensorData(fetchValues1[0])))); 1211 | // TF_DeleteTensor(feedValues1[0]); 1212 | // TF_DeleteTensor(fetchValues1[0]); 1213 | 1214 | // // Feed B and fetch (A + 2) + B. 1215 | // TF_Output feeds2[] = {TF_Output{b, 0}}; 1216 | // TF_Output fetches2[] = {TF_Output{plusB, 0}}; 1217 | // TF_Tensor* feedValues2[] = {Int32Tensor(4)}; 1218 | // TF_Tensor* fetchValues2[1]; 1219 | // TF_SessionPRun(sess, handle, feeds2, feedValues2, 1, fetches2, fetchValues2, 1220 | // 1, nullptr, 0, s); 1221 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1222 | // EXPECT_EQ(7, *(static_cast(TF_TensorData(fetchValues2[0])))); 1223 | // TF_DeleteTensor(feedValues2[0]); 1224 | // TF_DeleteTensor(fetchValues2[0]); 1225 | 1226 | // // Clean up. 
1227 | // TF_DeletePRunHandle(handle); 1228 | // TF_DeleteSession(sess, s); 1229 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1230 | // TF_DeleteGraph(graph); 1231 | // TF_DeleteStatus(s); 1232 | // } 1233 | 1234 | // TEST(CAPI, ColocateWith) { 1235 | // TF_Status* s = TF_NewStatus(); 1236 | // TF_Graph* graph = TF_NewGraph(); 1237 | 1238 | // TF_Operation* feed = Placeholder(graph, s); 1239 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1240 | 1241 | // TF_Operation* constant = ScalarConst(10, graph, s); 1242 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1243 | 1244 | // TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", "add"); 1245 | // TF_Output inputs[] = {{feed, 0}, {constant, 0}}; 1246 | // TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs)); 1247 | // TF_ColocateWith(desc, feed); 1248 | // TF_Operation* add = TF_FinishOperation(desc, s); 1249 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1250 | 1251 | // TF_AttrMetadata m = 1252 | // TF_OperationGetAttrMetadata(add, tensorflow::kColocationAttrName, s); 1253 | // EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1254 | // EXPECT_EQ(1, m.is_list); 1255 | // EXPECT_EQ(1, m.list_size); 1256 | // EXPECT_EQ(TF_ATTR_STRING, m.type); 1257 | // void* values[1]; 1258 | // size_t lens[1]; 1259 | // std::unique_ptr storage(new char[m.total_size]); 1260 | // TF_OperationGetAttrStringList(add, tensorflow::kColocationAttrName, values, 1261 | // lens, 1, storage.get(), m.total_size, s); 1262 | // EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1263 | // EXPECT_EQ("loc:@feed", string(static_cast(values[0]), lens[0])); 1264 | 1265 | // TF_DeleteGraph(graph); 1266 | // TF_DeleteStatus(s); 1267 | // } 1268 | 1269 | // TEST(CAPI, SavedModel) { 1270 | // // Load the saved model. 
1271 | // const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123"; 1272 | // const string saved_model_dir = tensorflow::io::JoinPath( 1273 | // tensorflow::testing::TensorFlowSrcRoot(), kSavedModel); 1274 | // TF_SessionOptions* opt = TF_NewSessionOptions(); 1275 | // TF_Buffer* run_options = TF_NewBufferFromString("", 0); 1276 | // TF_Buffer* metagraph = TF_NewBuffer(); 1277 | // TF_Status* s = TF_NewStatus(); 1278 | // const char* tags[] = {tensorflow::kSavedModelTagServe}; 1279 | // TF_Graph* graph = TF_NewGraph(); 1280 | // TF_Session* session = TF_LoadSessionFromSavedModel( 1281 | // opt, run_options, saved_model_dir.c_str(), tags, 1, graph, metagraph, s); 1282 | // TF_DeleteBuffer(run_options); 1283 | // TF_DeleteSessionOptions(opt); 1284 | // tensorflow::MetaGraphDef metagraph_def; 1285 | // metagraph_def.ParseFromArray(metagraph->data, metagraph->length); 1286 | // TF_DeleteBuffer(metagraph); 1287 | 1288 | // EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1289 | // CSession csession(session); 1290 | 1291 | // // Retrieve the regression signature from meta graph def. 1292 | // const auto signature_def_map = metagraph_def.signature_def(); 1293 | // const auto signature_def = signature_def_map.at("regress_x_to_y"); 1294 | 1295 | // const string input_name = 1296 | // signature_def.inputs().at(tensorflow::kRegressInputs).name(); 1297 | // const string output_name = 1298 | // signature_def.outputs().at(tensorflow::kRegressOutputs).name(); 1299 | 1300 | // // Write {0, 1, 2, 3} as tensorflow::Example inputs. 
1301 | // Tensor input(tensorflow::DT_STRING, TensorShape({4})); 1302 | // for (tensorflow::int64 i = 0; i < input.NumElements(); ++i) { 1303 | // tensorflow::Example example; 1304 | // auto* feature_map = example.mutable_features()->mutable_feature(); 1305 | // (*feature_map)["x"].mutable_float_list()->add_value(i); 1306 | // input.flat()(i) = example.SerializeAsString(); 1307 | // } 1308 | 1309 | // const tensorflow::string input_op_name = 1310 | // tensorflow::ParseTensorName(input_name).first.ToString(); 1311 | // TF_Operation* input_op = 1312 | // TF_GraphOperationByName(graph, input_op_name.c_str()); 1313 | // ASSERT_TRUE(input_op != nullptr); 1314 | // csession.SetInputs({{input_op, TF_Tensor_EncodeStrings(input)}}); 1315 | 1316 | // const tensorflow::string output_op_name = 1317 | // tensorflow::ParseTensorName(output_name).first.ToString(); 1318 | // TF_Operation* output_op = 1319 | // TF_GraphOperationByName(graph, output_op_name.c_str()); 1320 | // ASSERT_TRUE(output_op != nullptr); 1321 | // csession.SetOutputs({output_op}); 1322 | // csession.Run(s); 1323 | // ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1324 | 1325 | // TF_Tensor* out = csession.output_tensor(0); 1326 | // ASSERT_TRUE(out != nullptr); 1327 | // EXPECT_EQ(TF_FLOAT, TF_TensorType(out)); 1328 | // EXPECT_EQ(2, TF_NumDims(out)); 1329 | // EXPECT_EQ(4, TF_Dim(out, 0)); 1330 | // EXPECT_EQ(1, TF_Dim(out, 1)); 1331 | // float* values = static_cast(TF_TensorData(out)); 1332 | // // These values are defined to be (input / 2) + 2. 
1333 | // EXPECT_EQ(2, values[0]); 1334 | // EXPECT_EQ(2.5, values[1]); 1335 | // EXPECT_EQ(3, values[2]); 1336 | // EXPECT_EQ(3.5, values[3]); 1337 | 1338 | // csession.CloseAndDelete(s); 1339 | // EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1340 | // TF_DeleteGraph(graph); 1341 | // TF_DeleteStatus(s); 1342 | // } 1343 | 1344 | // TEST(CAPI, SavedModelNullArgsAreValid) { 1345 | // const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123"; 1346 | // const string saved_model_dir = tensorflow::io::JoinPath( 1347 | // tensorflow::testing::TensorFlowSrcRoot(), kSavedModel); 1348 | // TF_SessionOptions* opt = TF_NewSessionOptions(); 1349 | // TF_Status* s = TF_NewStatus(); 1350 | // const char* tags[] = {tensorflow::kSavedModelTagServe}; 1351 | // TF_Graph* graph = TF_NewGraph(); 1352 | // // NULL run_options and meta_graph_def should work. 1353 | // TF_Session* session = TF_LoadSessionFromSavedModel( 1354 | // opt, nullptr, saved_model_dir.c_str(), tags, 1, graph, nullptr, s); 1355 | // EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1356 | // TF_DeleteSessionOptions(opt); 1357 | // TF_CloseSession(session, s); 1358 | // EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1359 | // TF_DeleteSession(session, s); 1360 | // EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); 1361 | // TF_DeleteGraph(graph); 1362 | // TF_DeleteStatus(s); 1363 | // } 1364 | 1365 | // class CApiWhileLoopTest : public ::testing::Test { 1366 | // protected: 1367 | // CApiWhileLoopTest() : s_(TF_NewStatus()), graph_(TF_NewGraph()) {} 1368 | 1369 | // ~CApiWhileLoopTest() override { 1370 | // TF_DeleteGraph(graph_); 1371 | // TF_DeleteStatus(s_); 1372 | // } 1373 | 1374 | // void Init(int ninputs) { 1375 | // DCHECK(inputs_.empty()); 1376 | // DCHECK_GT(ninputs, 0); 1377 | 1378 | // for (int i = 0; i < ninputs; ++i) { 1379 | // TF_Operation* placeholder = Placeholder( 1380 | // graph_, s_, ::tensorflow::strings::StrCat("p", i).c_str()); 1381 | // DCHECK_EQ(TF_OK, 
TF_GetCode(s_)) << TF_Message(s_); 1382 | // inputs_.push_back({placeholder, 0}); 1383 | // } 1384 | 1385 | // original_graph_description_ = GraphDebugString(); 1386 | 1387 | // params_.reset(new TF_WhileParams( 1388 | // TF_NewWhile(graph_, &inputs_[0], inputs_.size(), s_))); 1389 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1390 | // ASSERT_EQ(original_graph_description_, GraphDebugString()) 1391 | // << "TF_NewWhile() altered graph"; 1392 | 1393 | // params_->name = "test_loop"; 1394 | 1395 | // // Initialize outputs_ so we can easily detect errors/bugs 1396 | // outputs_.resize(ninputs, {nullptr, -1}); 1397 | // } 1398 | 1399 | // void ExpectOK() { 1400 | // TF_FinishWhile(params_.get(), s_, &outputs_[0]); 1401 | // EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1402 | // } 1403 | 1404 | // void ExpectError(TF_Code expected_code, const string& expected_msg) { 1405 | // TF_FinishWhile(params_.get(), s_, &outputs_[0]); 1406 | // EXPECT_EQ(expected_code, TF_GetCode(s_)); 1407 | // EXPECT_EQ(expected_msg, TF_Message(s_)); 1408 | // // TODO(skyewm): this assert is currently broken. Fix or remove guarantee. 
1409 | // // ASSERT_EQ(original_graph_description_, GraphDebugString()) << 1410 | // // "TF_FinishWhile() altered graph on error"; 1411 | // } 1412 | 1413 | // void Run(std::initializer_list input_values) { 1414 | // DCHECK_EQ(inputs_.size(), input_values.size()); 1415 | // std::vector> inputs(inputs_.size()); 1416 | // int i = 0; 1417 | // for (int v : input_values) { 1418 | // inputs[i] = {inputs_[i].oper, Int32Tensor(v)}; 1419 | // ++i; 1420 | // } 1421 | // csession_.reset(new CSession(graph_, s_)); 1422 | // csession_->SetInputs(inputs); 1423 | // csession_->SetOutputs(outputs_); 1424 | // csession_->Run(s_); 1425 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1426 | // } 1427 | 1428 | // void ExpectOutputValue(int idx, int expected_value) { 1429 | // TF_Tensor* out = csession_->output_tensor(idx); 1430 | // ASSERT_TRUE(out != nullptr); 1431 | // EXPECT_EQ(TF_INT32, TF_TensorType(out)); 1432 | // EXPECT_EQ(0, TF_NumDims(out)); 1433 | // ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out)); 1434 | // int32* data = static_cast(TF_TensorData(out)); 1435 | // EXPECT_EQ(expected_value, *data); 1436 | // } 1437 | 1438 | // // Create a valid conditional graph. Useful for testing unrelated errors. 
1439 | // void CreateCondGraph() { 1440 | // TF_Operation* one = ScalarConst(1, params_->cond_graph, s_); 1441 | // TF_Operation* less_than = 1442 | // LessThan(params_->cond_inputs[0], {one, 0}, params_->cond_graph, s_); 1443 | // DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1444 | // params_->cond_output = {less_than, 0}; 1445 | // } 1446 | 1447 | // string GraphDebugString() const { 1448 | // TF_Buffer* buf = TF_NewBuffer(); 1449 | // TF_GraphToGraphDef(graph_, buf, s_); 1450 | // DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1451 | // GraphDef def; 1452 | // bool success = def.ParseFromArray(buf->data, buf->length); 1453 | // DCHECK(success); 1454 | // TF_DeleteBuffer(buf); 1455 | // return def.DebugString(); 1456 | // } 1457 | 1458 | // TF_Status* s_; 1459 | // TF_Graph* graph_; 1460 | // std::vector inputs_; // The inputs to the while loop 1461 | // std::vector outputs_; // The final outputs of the while loop 1462 | // std::unique_ptr params_; 1463 | // std::unique_ptr csession_; 1464 | 1465 | // private: 1466 | // // Used to verify that errors don't change graph_ 1467 | // string original_graph_description_; 1468 | // }; 1469 | 1470 | // TEST_F(CApiWhileLoopTest, BasicLoop) { 1471 | // Init(2); 1472 | 1473 | // // Validate TF_WhileParams returned by TF_NewWhile() 1474 | // EXPECT_TRUE(params_->body_graph != nullptr); 1475 | // EXPECT_TRUE(params_->cond_graph != nullptr); 1476 | 1477 | // EXPECT_EQ(params_->ninputs, 2); 1478 | 1479 | // ASSERT_TRUE(params_->cond_inputs != nullptr); 1480 | // ASSERT_TRUE(params_->cond_inputs[0].oper != nullptr); 1481 | // EXPECT_TRUE(params_->cond_inputs[1].oper != nullptr); 1482 | 1483 | // ASSERT_TRUE(params_->body_inputs != nullptr); 1484 | // EXPECT_TRUE(params_->body_inputs[0].oper != nullptr); 1485 | // EXPECT_TRUE(params_->body_inputs[1].oper != nullptr); 1486 | 1487 | // ASSERT_TRUE(params_->body_outputs != nullptr); 1488 | 1489 | // // Create loop: while (input1 < input2) input1 += input2 + 1 1490 | // 
TF_Operation* less_than = 1491 | // LessThan(params_->cond_inputs[0], params_->cond_inputs[1], 1492 | // params_->cond_graph, s_); 1493 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1494 | // params_->cond_output = {less_than, 0}; 1495 | 1496 | // TF_Operation* add1 = Add(params_->body_inputs[0], params_->body_inputs[1], 1497 | // params_->body_graph, s_, "add1"); 1498 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1499 | // TF_Operation* one = ScalarConst(1, params_->body_graph, s_); 1500 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1501 | // TF_Operation* add2 = Add(add1, one, params_->body_graph, s_, "add2"); 1502 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1503 | // params_->body_outputs[0] = {add2, 0}; 1504 | // params_->body_outputs[1] = params_->body_inputs[1]; 1505 | 1506 | // // Finalize while loop 1507 | // ExpectOK(); 1508 | 1509 | // // Validate while loop outputs returned by TF_FinishWhile() 1510 | // EXPECT_TRUE(outputs_[0].oper != nullptr); 1511 | // EXPECT_GE(outputs_[0].index, 0); 1512 | // EXPECT_TRUE(outputs_[1].oper != nullptr); 1513 | // EXPECT_GE(outputs_[1].index, 0); 1514 | 1515 | // // Run the graph 1516 | // Run({-9, 2}); 1517 | // ExpectOutputValue(0, 3); 1518 | // ExpectOutputValue(1, 2); 1519 | // } 1520 | 1521 | // TEST_F(CApiWhileLoopTest, NestedLoop) { 1522 | // Init(2); 1523 | // // Create nested loop: 1524 | // // while (input1 < 6) { 1525 | // // inner_input1 = input1 1526 | // // while (inner_input1 < 3) { 1527 | // // input2 += 1 1528 | // // inner_input1 += 2 1529 | // // } 1530 | // // input1 += input2 1531 | // // } 1532 | // // 1533 | // // Expected execution with initial values input1 = input2 = 0: 1534 | // // 1535 | // // outer inner inner_ 1536 | // // step# step# input1 input2 input1 1537 | // // ------------------------------------ 1538 | // // 0 0 0 0 0 1539 | // // 0 1 0 1 2 1540 | // // 0 2 0 2 4 1541 | // // 0 - 2 2 - 1542 | // // 1 0 2 2 2 1543 | // // 1 1 2 3 4 1544 | 
// // 1 - 5 3 - 1545 | // // 2 0 5 3 5 1546 | // // 2 - 8 3 - 1547 | 1548 | // // Create outer cond graph 1549 | // TF_Operation* six = ScalarConst(6, params_->cond_graph, s_); 1550 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1551 | // TF_Operation* less_than = 1552 | // LessThan(params_->cond_inputs[0], {six, 0}, params_->cond_graph, s_); 1553 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1554 | // params_->cond_output = {less_than, 0}; 1555 | 1556 | // // Create outer body graph 1557 | // // Init inner graph 1558 | // TF_Output inner_inputs[] = {params_->body_inputs[0], params_->body_inputs[1]}; 1559 | // TF_WhileParams inner_params = 1560 | // TF_NewWhile(params_->body_graph, inner_inputs, 2, s_); 1561 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1562 | // inner_params.name = "inner_loop"; 1563 | 1564 | // // Create inner cond graph 1565 | // TF_Operation* three = ScalarConst(3, inner_params.cond_graph, s_); 1566 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1567 | // TF_Operation* inner_less_than = LessThan( 1568 | // inner_params.cond_inputs[0], {three, 0}, inner_params.cond_graph, s_); 1569 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1570 | // inner_params.cond_output = {inner_less_than, 0}; 1571 | 1572 | // // Create inner body graph 1573 | // TF_Operation* one = ScalarConst(1, inner_params.body_graph, s_, "one"); 1574 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1575 | // TF_Operation* two = ScalarConst(2, inner_params.body_graph, s_, "two"); 1576 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1577 | 1578 | // TF_Operation* input2_add = 1579 | // Add(inner_params.body_inputs[1].oper, one, inner_params.body_graph, s_); 1580 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1581 | // inner_params.body_outputs[1] = {input2_add, 0}; 1582 | 1583 | // TF_Operation* inner_input1_add = Add(inner_params.body_inputs[0].oper, two, 1584 | // inner_params.body_graph, s_, "add2"); 
1585 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1586 | // inner_params.body_outputs[0] = {inner_input1_add, 0}; 1587 | 1588 | // // Finalize inner graph 1589 | // TF_Output inner_outputs[2] = {{nullptr, -1}}; 1590 | // TF_FinishWhile(&inner_params, s_, inner_outputs); 1591 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1592 | 1593 | // TF_Operation* input1_add = 1594 | // Add(params_->body_inputs[0], inner_outputs[1], params_->body_graph, s_); 1595 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1596 | // params_->body_outputs[0] = {input1_add, 0}; 1597 | 1598 | // params_->body_outputs[1] = inner_outputs[1]; 1599 | 1600 | // // Finalize outer graph 1601 | // ExpectOK(); 1602 | 1603 | // // Check for a few expected nodes 1604 | // const char* node_name = "test_loop/cond/scalar"; 1605 | // EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr); 1606 | // node_name = "test_loop/body/add"; 1607 | // EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr); 1608 | // node_name = "test_loop/body/inner_loop/body/one"; 1609 | // EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr); 1610 | // node_name = "test_loop/body/inner_loop/cond/less_than"; 1611 | // EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr); 1612 | 1613 | // // Run the graph 1614 | // Run({0, 0}); 1615 | // ExpectOutputValue(0, 8); 1616 | // ExpectOutputValue(1, 3); 1617 | // } 1618 | 1619 | // TEST_F(CApiWhileLoopTest, BadCondOutput) { 1620 | // Init(1); 1621 | // params_->body_outputs[0] = params_->body_inputs[0]; 1622 | // ExpectError(TF_INVALID_ARGUMENT, 1623 | // "TF_WhileParams `cond_output` field isn't set"); 1624 | // } 1625 | 1626 | // TEST_F(CApiWhileLoopTest, BadBodyOutput) { 1627 | // Init(1); 1628 | // CreateCondGraph(); 1629 | // ExpectError(TF_INVALID_ARGUMENT, 1630 | // "TF_WhileParams `body_outputs[0]` field isn't set"); 1631 | // } 1632 | 1633 | // TEST_F(CApiWhileLoopTest, NullName) { 1634 | // Init(1); 
1635 | // CreateCondGraph(); 1636 | // params_->body_outputs[0] = params_->body_inputs[0]; 1637 | // params_->name = nullptr; 1638 | // ExpectError(TF_INVALID_ARGUMENT, "TF_WhileParams `name` field is null"); 1639 | // } 1640 | 1641 | // TEST_F(CApiWhileLoopTest, WrongGraph) { 1642 | // Init(1); 1643 | // CreateCondGraph(); 1644 | // // Set body output to output from outer graph 1645 | // params_->body_outputs[0] = inputs_[0]; 1646 | // // TODO(skyewm): improve error message 1647 | // ExpectError(TF_INVALID_ARGUMENT, 1648 | // "Requested return node 'p0' not found in graph def"); 1649 | // } 1650 | 1651 | // TEST_F(CApiWhileLoopTest, BadTypes) { 1652 | // Init(1); 1653 | // CreateCondGraph(); 1654 | // // Op that has a float input + output 1655 | // TF_OperationDescription* desc = TF_NewOperation( 1656 | // params_->body_graph, "FakeQuantWithMinMaxArgs", "float_op"); 1657 | // TF_AddInput(desc, params_->body_inputs[0]); 1658 | // TF_FinishOperation(desc, s_); 1659 | // ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); 1660 | // string msg(TF_Message(s_)); 1661 | // EXPECT_NE(msg.find("Input 'inputs' passed int32 expected float while " 1662 | // "building NodeDef 'float_op'"), 1663 | // msg.npos); 1664 | // TF_AbortWhile(params_.get()); 1665 | // } 1666 | 1667 | // // Create a tensor with values of type TF_INT8 provided by `values`. 
1668 | // TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) { 1669 | // int64_t num_values = 1; 1670 | // for (int i = 0; i < num_dims; ++i) { 1671 | // num_values *= dims[i]; 1672 | // } 1673 | // TF_Tensor* t = 1674 | // TF_AllocateTensor(TF_INT8, dims, num_dims, sizeof(char) * num_values); 1675 | // memcpy(TF_TensorData(t), values, sizeof(char) * num_values); 1676 | // return t; 1677 | // } 1678 | 1679 | // void StringVectorToArrays(const std::vector& v, 1680 | // std::unique_ptr* ptrs, 1681 | // std::unique_ptr* lens) { 1682 | // ptrs->reset(new const void*[v.size()]); 1683 | // lens->reset(new size_t[v.size()]); 1684 | // for (size_t i = 0; i < v.size(); ++i) { 1685 | // (*ptrs)[i] = v[i].data(); 1686 | // (*lens)[i] = v[i].size(); 1687 | // } 1688 | // } 1689 | 1690 | // // REGISTER_OP for CApiTestAttributesTest test cases. 1691 | // // Registers two ops, each with a single attribute called 'v'. 1692 | // // The attribute in one op will have a type 'type', the other 1693 | // // will have list(type). 
1694 | // #define ATTR_TEST_REGISTER_OP(type) \ 1695 | // REGISTER_OP("CApiAttributesTestOp" #type).Attr("v: " #type); \ 1696 | // REGISTER_OP("CApiAttributesTestOpList" #type).Attr("v: list(" #type ")") 1697 | // ATTR_TEST_REGISTER_OP(string); 1698 | // ATTR_TEST_REGISTER_OP(int); 1699 | // ATTR_TEST_REGISTER_OP(float); 1700 | // ATTR_TEST_REGISTER_OP(bool); 1701 | // ATTR_TEST_REGISTER_OP(type); 1702 | // ATTR_TEST_REGISTER_OP(shape); 1703 | // ATTR_TEST_REGISTER_OP(tensor); 1704 | // #undef ATTR_TEST_REGISTER_OP 1705 | 1706 | // class CApiAttributesTest : public ::testing::Test { 1707 | // protected: 1708 | // CApiAttributesTest() 1709 | // : s_(TF_NewStatus()), graph_(TF_NewGraph()), counter_(0) {} 1710 | 1711 | // ~CApiAttributesTest() override { 1712 | // TF_DeleteGraph(graph_); 1713 | // TF_DeleteStatus(s_); 1714 | // } 1715 | 1716 | // TF_OperationDescription* init(string type) { 1717 | // // Construct op_name to match the name used by REGISTER_OP in the 1718 | // // ATTR_TEST_REGISTER calls above. 1719 | // string op_name = "CApiAttributesTestOp"; 1720 | // if (type.find("list(") == 0) { 1721 | // op_name += "List"; 1722 | // type = type.replace(0, 5, ""); 1723 | // type = type.replace(type.size() - 1, 1, ""); 1724 | // } 1725 | // op_name += type; 1726 | // return TF_NewOperation( 1727 | // graph_, op_name.c_str(), 1728 | // ::tensorflow::strings::StrCat("name", counter_++).c_str()); 1729 | // } 1730 | 1731 | // TF_Status* s_; 1732 | 1733 | // private: 1734 | // TF_Graph* graph_; 1735 | // int counter_; 1736 | // }; 1737 | 1738 | // // Helper macros for the TF_OperationGetAttr* tests. 1739 | // // TODO(ashankar): Use gmock matchers instead? 1740 | // // (https://github.com/google/googletest/blob/master/googlemock/docs/CookBook.md#writing-new-parameterized-matchers-quickly) 1741 | // // That will require setting up the tensorflow build with gmock. 
1742 | // #define EXPECT_TF_META(attr_name, expected_list_size, expected_type, \ 1743 | // expected_total_size) \ 1744 | // do { \ 1745 | // auto m = TF_OperationGetAttrMetadata(oper, attr_name, s_); \ 1746 | // EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); \ 1747 | // const unsigned char e = expected_list_size >= 0 ? 1 : 0; \ 1748 | // EXPECT_EQ(e, m.is_list); \ 1749 | // EXPECT_EQ(expected_list_size, m.list_size); \ 1750 | // EXPECT_EQ(expected_type, m.type); \ 1751 | // EXPECT_EQ(expected_total_size, m.total_size); \ 1752 | // } while (0) 1753 | 1754 | // TEST_F(CApiAttributesTest, String) { 1755 | // auto desc = init("string"); 1756 | // TF_SetAttrString(desc, "v", "bunny", 5); 1757 | 1758 | // auto oper = TF_FinishOperation(desc, s_); 1759 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1760 | // EXPECT_TF_META("v", -1, TF_ATTR_STRING, 5); 1761 | // std::unique_ptr value(new char[5]); 1762 | 1763 | // TF_OperationGetAttrString(oper, "v", value.get(), 5, s_); 1764 | // EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1765 | // EXPECT_EQ("bunny", string(static_cast(value.get()), 5)); 1766 | // } 1767 | 1768 | // TEST_F(CApiAttributesTest, StringList) { 1769 | // std::vector list = {"bugs", "bunny", "duck"}; 1770 | // std::unique_ptr list_ptrs; 1771 | // std::unique_ptr list_lens; 1772 | // StringVectorToArrays(list, &list_ptrs, &list_lens); 1773 | // int list_total_size = 0; 1774 | // for (const auto& s : list) { 1775 | // list_total_size += s.size(); 1776 | // } 1777 | 1778 | // auto desc = init("list(string)"); 1779 | // TF_SetAttrStringList(desc, "v", list_ptrs.get(), list_lens.get(), 1780 | // list.size()); 1781 | 1782 | // auto oper = TF_FinishOperation(desc, s_); 1783 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1784 | 1785 | // EXPECT_TF_META("v", list.size(), TF_ATTR_STRING, list_total_size); 1786 | // std::unique_ptr values(new void*[list.size()]); 1787 | // std::unique_ptr lens(new size_t[list.size()]); 1788 | // 
std::unique_ptr storage(new char[list_total_size]); 1789 | // TF_OperationGetAttrStringList(oper, "v", values.get(), lens.get(), 1790 | // list.size(), storage.get(), list_total_size, 1791 | // s_); 1792 | // EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1793 | // for (size_t i = 0; i < list.size(); ++i) { 1794 | // EXPECT_EQ(list[i].size(), lens[i]) << i; 1795 | // EXPECT_EQ(list[i], string(static_cast(values[i]), lens[i])) 1796 | // << i; 1797 | // } 1798 | // } 1799 | 1800 | // TEST_F(CApiAttributesTest, Int) { 1801 | // auto desc = init("int"); 1802 | // TF_SetAttrInt(desc, "v", 31415); 1803 | 1804 | // auto oper = TF_FinishOperation(desc, s_); 1805 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1806 | // EXPECT_TF_META("v", -1, TF_ATTR_INT, -1); 1807 | 1808 | // int64_t value; 1809 | // TF_OperationGetAttrInt(oper, "v", &value, s_); 1810 | // EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1811 | // EXPECT_EQ(31415, value); 1812 | // } 1813 | 1814 | // TEST_F(CApiAttributesTest, IntList) { 1815 | // const int64_t list[] = {1, 2, 3, 4}; 1816 | // const size_t list_size = TF_ARRAYSIZE(list); 1817 | 1818 | // auto desc = init("list(int)"); 1819 | // TF_SetAttrIntList(desc, "v", list, list_size); 1820 | 1821 | // auto oper = TF_FinishOperation(desc, s_); 1822 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1823 | 1824 | // int64_t values[list_size]; 1825 | // EXPECT_TF_META("v", list_size, TF_ATTR_INT, -1); 1826 | // TF_OperationGetAttrIntList(oper, "v", values, list_size, s_); 1827 | // EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1828 | // EXPECT_TRUE(std::equal(std::begin(list), std::end(list), std::begin(values))); 1829 | // } 1830 | 1831 | // TEST_F(CApiAttributesTest, Float) { 1832 | // auto desc = init("float"); 1833 | // TF_SetAttrFloat(desc, "v", 2.718); 1834 | 1835 | // auto oper = TF_FinishOperation(desc, s_); 1836 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1837 | // EXPECT_TF_META("v", -1, 
TF_ATTR_FLOAT, -1); 1838 | 1839 | // float value; 1840 | // TF_OperationGetAttrFloat(oper, "v", &value, s_); 1841 | // EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1842 | // EXPECT_FLOAT_EQ(2.718, value); 1843 | // } 1844 | 1845 | // TEST_F(CApiAttributesTest, FloatList) { 1846 | // const float list[] = {1.414, 2.718, 3.1415}; 1847 | // const size_t list_size = TF_ARRAYSIZE(list); 1848 | 1849 | // auto desc = init("list(float)"); 1850 | // TF_SetAttrFloatList(desc, "v", list, list_size); 1851 | 1852 | // auto oper = TF_FinishOperation(desc, s_); 1853 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1854 | 1855 | // float values[list_size]; 1856 | // EXPECT_TF_META("v", list_size, TF_ATTR_FLOAT, -1); 1857 | // TF_OperationGetAttrFloatList(oper, "v", values, list_size, s_); 1858 | // EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1859 | // EXPECT_TRUE(std::equal(std::begin(list), std::end(list), std::begin(values))); 1860 | // } 1861 | 1862 | // TEST_F(CApiAttributesTest, Bool) { 1863 | // auto desc = init("bool"); 1864 | // TF_SetAttrBool(desc, "v", 1); 1865 | 1866 | // auto oper = TF_FinishOperation(desc, s_); 1867 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1868 | // EXPECT_TF_META("v", -1, TF_ATTR_BOOL, -1); 1869 | 1870 | // unsigned char value; 1871 | // TF_OperationGetAttrBool(oper, "v", &value, s_); 1872 | // EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1873 | // EXPECT_EQ(1, value); 1874 | // } 1875 | 1876 | // TEST_F(CApiAttributesTest, BoolList) { 1877 | // const unsigned char list[] = {0, 1, 1, 0, 0, 1, 1}; 1878 | // const size_t list_size = TF_ARRAYSIZE(list); 1879 | 1880 | // auto desc = init("list(bool)"); 1881 | // TF_SetAttrBoolList(desc, "v", list, list_size); 1882 | 1883 | // auto oper = TF_FinishOperation(desc, s_); 1884 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1885 | 1886 | // unsigned char values[list_size]; 1887 | // EXPECT_TF_META("v", list_size, TF_ATTR_BOOL, -1); 1888 | // 
TF_OperationGetAttrBoolList(oper, "v", values, list_size, s_); 1889 | // EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1890 | // EXPECT_TRUE(std::equal(std::begin(list), std::end(list), std::begin(values))); 1891 | // } 1892 | 1893 | // TEST_F(CApiAttributesTest, Type) { 1894 | // auto desc = init("type"); 1895 | // TF_SetAttrType(desc, "v", TF_COMPLEX128); 1896 | 1897 | // auto oper = TF_FinishOperation(desc, s_); 1898 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1899 | // EXPECT_TF_META("v", -1, TF_ATTR_TYPE, -1); 1900 | 1901 | // TF_DataType value; 1902 | // TF_OperationGetAttrType(oper, "v", &value, s_); 1903 | // EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1904 | // EXPECT_EQ(TF_COMPLEX128, value); 1905 | // } 1906 | 1907 | // TEST_F(CApiAttributesTest, TypeList) { 1908 | // const TF_DataType list[] = {TF_FLOAT, TF_DOUBLE, TF_HALF, TF_COMPLEX128}; 1909 | // const size_t list_size = TF_ARRAYSIZE(list); 1910 | 1911 | // auto desc = init("list(type)"); 1912 | // TF_SetAttrTypeList(desc, "v", list, list_size); 1913 | 1914 | // auto oper = TF_FinishOperation(desc, s_); 1915 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1916 | 1917 | // TF_DataType values[list_size]; 1918 | // EXPECT_TF_META("v", list_size, TF_ATTR_TYPE, -1); 1919 | // TF_OperationGetAttrTypeList(oper, "v", values, list_size, s_); 1920 | // EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1921 | // EXPECT_TRUE(std::equal(std::begin(list), std::end(list), std::begin(values))); 1922 | // } 1923 | 1924 | // TEST_F(CApiAttributesTest, Shape) { 1925 | // // Unknown shape 1926 | // auto desc = init("shape"); 1927 | // TF_SetAttrShape(desc, "v", nullptr, -1); 1928 | // auto oper = TF_FinishOperation(desc, s_); 1929 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1930 | // EXPECT_TF_META("v", -1, TF_ATTR_SHAPE, -1); 1931 | // TF_OperationGetAttrShape(oper, "v", nullptr, 10, s_); 1932 | // EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1933 | 1934 | // // 
Partially specified shape 1935 | // const int64_t partial_shape[] = {17, -1}; 1936 | // const size_t sz = TF_ARRAYSIZE(partial_shape); 1937 | // desc = init("shape"); 1938 | // TF_SetAttrShape(desc, "v", partial_shape, sz); 1939 | // oper = TF_FinishOperation(desc, s_); 1940 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1941 | // EXPECT_TF_META("v", -1, TF_ATTR_SHAPE, sz); 1942 | // int64_t values[sz]; 1943 | // TF_OperationGetAttrShape(oper, "v", values, sz, s_); 1944 | // EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1945 | // EXPECT_TRUE( 1946 | // std::equal(std::begin(partial_shape), std::end(partial_shape), values)); 1947 | // } 1948 | 1949 | // TEST_F(CApiAttributesTest, ShapeList) { 1950 | // const int64_t shape_1[] = {1, 3}; 1951 | // const int64_t shape_2[] = {2, 4, 6}; 1952 | // const int64_t* list[] = {&shape_1[0], &shape_2[0]}; 1953 | // const size_t list_size = TF_ARRAYSIZE(list); 1954 | // const int ndims[] = {TF_ARRAYSIZE(shape_1), TF_ARRAYSIZE(shape_2)}; 1955 | // const int total_ndims = 5; // ndims[0] + ndims[1] 1956 | 1957 | // auto desc = init("list(shape)"); 1958 | // TF_SetAttrShapeList(desc, "v", list, ndims, list_size); 1959 | // auto oper = TF_FinishOperation(desc, s_); 1960 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1961 | 1962 | // EXPECT_TF_META("v", list_size, TF_ATTR_SHAPE, total_ndims); 1963 | // int64_t* values[list_size]; 1964 | // int values_ndims[list_size]; 1965 | // int64_t storage[total_ndims]; 1966 | // TF_OperationGetAttrShapeList(oper, "v", values, values_ndims, list_size, 1967 | // storage, total_ndims, s_); 1968 | // EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1969 | // for (size_t i = 0; i < list_size; ++i) { 1970 | // EXPECT_EQ(ndims[i], values_ndims[i]) << i; 1971 | // for (int j = 0; j < values_ndims[i]; ++j) { 1972 | // EXPECT_EQ(list[i][j], values[i][j]) << "(" << i << ", " << j << ")"; 1973 | // } 1974 | // } 1975 | // } 1976 | 1977 | // TEST_F(CApiAttributesTest, 
TensorShapeProto) { 1978 | // const tensorflow::int64 pts[] = {2, 4, -1, 8}; 1979 | // tensorflow::TensorShapeProto proto; 1980 | // tensorflow::PartialTensorShape(pts).AsProto(&proto); 1981 | // string bytes; 1982 | // proto.SerializeToString(&bytes); 1983 | 1984 | // auto desc = init("shape"); 1985 | // TF_SetAttrTensorShapeProto(desc, "v", bytes.data(), bytes.length(), s_); 1986 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1987 | // auto oper = TF_FinishOperation(desc, s_); 1988 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1989 | 1990 | // EXPECT_TF_META("v", -1, TF_ATTR_SHAPE, 4); 1991 | // TF_Buffer* value = TF_NewBuffer(); 1992 | // TF_OperationGetAttrTensorShapeProto(oper, "v", value, s_); 1993 | // EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 1994 | // EXPECT_EQ(bytes.length(), value->length); 1995 | // EXPECT_EQ(0, memcmp(bytes.data(), value->data, value->length)); 1996 | // TF_DeleteBuffer(value); 1997 | // } 1998 | 1999 | // TEST_F(CApiAttributesTest, TensorShapeProtoList) { 2000 | // string bytes1, bytes2; 2001 | // tensorflow::TensorShapeProto proto; 2002 | 2003 | // const tensorflow::int64 pts1[] = {2, 4, -1, 8}; 2004 | // tensorflow::PartialTensorShape(pts1).AsProto(&proto); 2005 | // proto.SerializeToString(&bytes1); 2006 | 2007 | // const tensorflow::int64 pts2[] = {1, 3, 5, 7}; 2008 | // tensorflow::PartialTensorShape(pts2).AsProto(&proto); 2009 | // proto.SerializeToString(&bytes2); 2010 | 2011 | // std::unique_ptr<const void*[]> list_ptrs; 2012 | // std::unique_ptr<size_t[]> list_lens; 2013 | // const std::vector<string> list = {bytes1, bytes2}; 2014 | // StringVectorToArrays(list, &list_ptrs, &list_lens); 2015 | 2016 | // auto desc = init("list(shape)"); 2017 | // TF_SetAttrTensorShapeProtoList(desc, "v", list_ptrs.get(), list_lens.get(), 2018 | // list.size(), s_); 2019 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 2020 | // auto oper = TF_FinishOperation(desc, s_); 2021 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 2022 | 
2023 | // EXPECT_TF_META("v", 2, TF_ATTR_SHAPE, 8); 2024 | // TF_Buffer* values[2]; 2025 | // TF_OperationGetAttrTensorShapeProtoList(oper, "v", values, 2, s_); 2026 | // EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 2027 | // for (int i = 0; i < 2; ++i) { 2028 | // int le = list_lens[i]; 2029 | // int la = values[i]->length; 2030 | // const void* e = list_ptrs[i]; 2031 | // const void* a = values[i]->data; 2032 | // EXPECT_EQ(le, la) << i; 2033 | // EXPECT_EQ(0, memcmp(e, a, std::min(le, la))) << i; 2034 | // TF_DeleteBuffer(values[i]); 2035 | // } 2036 | // } 2037 | 2038 | // TEST_F(CApiAttributesTest, Tensor) { 2039 | // const char tensor[] = {5, 7}; 2040 | // const int64_t dims[] = {1, 2}; 2041 | // const size_t ndims = TF_ARRAYSIZE(dims); 2042 | 2043 | // auto desc = init("tensor"); 2044 | // unique_tensor_ptr v(Int8Tensor(dims, ndims, tensor), TF_DeleteTensor); 2045 | // TF_SetAttrTensor(desc, "v", v.get(), s_); 2046 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 2047 | 2048 | // auto oper = TF_FinishOperation(desc, s_); 2049 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 2050 | 2051 | // EXPECT_TF_META("v", -1, TF_ATTR_TENSOR, -1); 2052 | // TF_Tensor* value; 2053 | // TF_OperationGetAttrTensor(oper, "v", &value, s_); 2054 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 2055 | // ASSERT_NE(nullptr, value); 2056 | // EXPECT_EQ(TF_INT8, TF_TensorType(value)); 2057 | // EXPECT_EQ(ndims, TF_NumDims(value)); 2058 | // for (int i = 0; i < TF_NumDims(value); ++i) { 2059 | // EXPECT_EQ(dims[i], TF_Dim(value, i)) << i; 2060 | // } 2061 | // EXPECT_EQ(sizeof(char) * TF_ARRAYSIZE(tensor), TF_TensorByteSize(value)); 2062 | // EXPECT_EQ(0, memcmp(tensor, TF_TensorData(value), TF_TensorByteSize(value))); 2063 | // TF_DeleteTensor(value); 2064 | // } 2065 | 2066 | // TEST_F(CApiAttributesTest, TensorList) { 2067 | // const char tensor1[] = {5, 7}; 2068 | // const int64_t dims1[] = {1, 2}; 2069 | // const size_t ndims1 = 
TF_ARRAYSIZE(dims1); 2070 | 2071 | // const char tensor2[] = {2, 4, 6, 8}; 2072 | // const int64_t dims2[] = {2, 2}; 2073 | // const size_t ndims2 = TF_ARRAYSIZE(dims2); 2074 | 2075 | // auto desc = init("list(tensor)"); 2076 | // TF_Tensor* tmp[] = { 2077 | // Int8Tensor(dims1, ndims1, tensor1), Int8Tensor(dims2, ndims2, tensor2), 2078 | // }; 2079 | // TF_SetAttrTensorList(desc, "v", tmp, TF_ARRAYSIZE(tmp), s_); 2080 | // for (int i = 0; i < TF_ARRAYSIZE(tmp); ++i) { 2081 | // TF_DeleteTensor(tmp[i]); 2082 | // } 2083 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 2084 | // auto oper = TF_FinishOperation(desc, s_); 2085 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 2086 | 2087 | // EXPECT_TF_META("v", 2, TF_ATTR_TENSOR, -1); 2088 | // TF_Tensor* values[2]; 2089 | // TF_OperationGetAttrTensorList(oper, "v", &values[0], TF_ARRAYSIZE(values), 2090 | // s_); 2091 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 2092 | 2093 | // const char* tensor_data[] = {&tensor1[0], &tensor2[0]}; 2094 | // const size_t tensor_size[] = {TF_ARRAYSIZE(tensor1), TF_ARRAYSIZE(tensor2)}; 2095 | // const int64_t* tensor_dims[] = {&dims1[0], &dims2[0]}; 2096 | // const size_t tensor_ndims[] = {ndims1, ndims2}; 2097 | // for (int i = 0; i < 2; ++i) { 2098 | // TF_Tensor* v = values[i]; 2099 | // ASSERT_NE(nullptr, v) << i; 2100 | // EXPECT_EQ(TF_INT8, TF_TensorType(v)) << i; 2101 | // EXPECT_EQ(tensor_ndims[i], TF_NumDims(v)) << i; 2102 | // for (int j = 0; j < TF_NumDims(v); ++j) { 2103 | // EXPECT_EQ(tensor_dims[i][j], TF_Dim(v, j)) 2104 | // << "Tensor #" << i << ", dimension #" << j; 2105 | // } 2106 | // EXPECT_EQ(sizeof(char) * tensor_size[i], TF_TensorByteSize(v)) << i; 2107 | // EXPECT_EQ(0, 2108 | // memcmp(tensor_data[i], TF_TensorData(v), TF_TensorByteSize(v))); 2109 | // TF_DeleteTensor(v); 2110 | // } 2111 | // } 2112 | 2113 | // TEST_F(CApiAttributesTest, EmptyList) { 2114 | // auto desc = init("list(int)"); 2115 | // TF_SetAttrIntList(desc, "v", 
nullptr, 0); 2116 | // auto oper = TF_FinishOperation(desc, s_); 2117 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 2118 | // EXPECT_TF_META("v", 0, TF_ATTR_INT, -1); 2119 | // } 2120 | 2121 | // TEST_F(CApiAttributesTest, Errors) { 2122 | // auto desc = init("int"); 2123 | // TF_SetAttrInt(desc, "v", 3); 2124 | // auto oper = TF_FinishOperation(desc, s_); 2125 | // ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); 2126 | // TF_OperationGetAttrString(oper, "v", nullptr, 0, s_); 2127 | // EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); 2128 | // } 2129 | // #undef EXPECT_TF_META 2130 | --------------------------------------------------------------------------------