├── .gitignore ├── CMakeLists.txt ├── README.md ├── _inner_product_grad.py ├── inner_product.cc ├── inner_product_grad.cc └── inner_product_tests.py /.gitignore: -------------------------------------------------------------------------------- 1 | /build/ -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8) 2 | 3 | # get tensorflow include dirs, see https://www.tensorflow.org/how_tos/adding_an_op/ 4 | execute_process(COMMAND python3 -c "import tensorflow; print(tensorflow.sysconfig.get_include())" OUTPUT_VARIABLE Tensorflow_INCLUDE_DIRS) 5 | execute_process(COMMAND python3 -c "import tensorflow as tf; print(' '.join(tf.sysconfig.get_link_flags()), end='')" OUTPUT_VARIABLE Tensorflow_LINK_FLAGS) 6 | execute_process(COMMAND python3 -c "import tensorflow as tf; print(' '.join(tf.sysconfig.get_compile_flags()), end='')" OUTPUT_VARIABLE Tensorflow_COMPILE_FLAGS) 7 | 8 | 9 | # C++11 required for tensorflow 10 | set(CMAKE_CXX_FLAGS "-std=c++11 ${Tensorflow_COMPILE_FLAGS} ${CMAKE_CXX_FLAGS}") 11 | set(CMAKE_SHARED_LINKER_FLAGS "-Wl,--no-as-needed ${Tensorflow_LINK_FLAGS} ${CMAKE_SHARED_LINKER_FLAGS}") 12 | 13 | # if GCC > 5 14 | # if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 5.0 OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL 5.0) 15 | # set(CMAKE_CXX_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 ${CMAKE_CXX_FLAGS}") 16 | # endif() 17 | 18 | # build the actual operation which can be used directly 19 | include_directories(${Tensorflow_INCLUDE_DIRS}) 20 | add_library(inner_product SHARED inner_product.cc) 21 | 22 | # build the gradient operation which is used in _inner_product_grad.py 23 | # to register it 24 | include_directories(${Tensorflow_INCLUDE_DIRS}) 25 | add_library(inner_product_grad SHARED inner_product_grad.cc) 26 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # Example of Tensorflow Operation in C++ 2 | 3 | This repository contains an example of a simple Tensorflow operation and its gradient both implemented in C++, as described in [this article](http://davidstutz.de/implementing-tensorflow-operations-in-c-including-gradients/). 4 | 5 | ## Building 6 | 7 | The operation is built using [CMake](https://cmake.org/) and requires an appropriate version of Tensorflow to be installed. In order to get the necessary include directories containing the Tensorflow header files, the following trick is used (also see the [Tensorflow documentation](https://www.tensorflow.org/how_tos/adding_an_op/)): 8 | 9 | import tensorflow 10 | print(tensorflow.sysconfig.get_include()) 11 | 12 | In the `CMakeLists.txt` this is used as follows: 13 | 14 | execute_process(COMMAND python3 -c "import tensorflow; print(tensorflow.sysconfig.get_include())" OUTPUT_VARIABLE Tensorflow_INCLUDE_DIRS) 15 | 16 | The remaining contents are pretty standard. Building is now done using: 17 | 18 | $ mkdir build 19 | $ cd build 20 | $ cmake .. 21 | $ make 22 | Scanning dependencies of target inner_product 23 | [ 50%] Building CXX object CMakeFiles/inner_product.dir/inner_product.cc.o 24 | Linking CXX shared library libinner_product.so 25 | [ 50%] Built target inner_product 26 | Scanning dependencies of target inner_product_grad 27 | [100%] Building CXX object CMakeFiles/inner_product_grad.dir/inner_product_grad.cc.o 28 | Linking CXX shared library libinner_product_grad.so 29 | [100%] Built target inner_product_grad 30 | 31 | `libinner_product.so` and `libinner_product_grad.so` can be found in `build` and need to be included in order to load the module in Python: 32 | 33 | import tensorflow as tf 34 | inner_product_module = tf.load_op_library('build/libinner_product.so') 35 | 36 | See `inner_product_tests.py` for usage examples. 
#!/usr/bin/env python3
"""
Gradients for inner product.

.. moduleauthor:: David Stutz
"""

import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
inner_product_grad_module = tf.load_op_library('libinner_product_grad.so')

@ops.RegisterGradient("InnerProduct")
def _inner_product_grad_cc(op, grad):
    """
    The gradient for `inner_product` using the operation implemented in C++.

    :param op: `inner_product` `Operation` that we are differentiating, which we can use
        to find the inputs and outputs of the original op.
    :param grad: gradient with respect to the output of the `inner_product` op.
    :return: gradients with respect to the input of `inner_product`, as produced
        by the `InnerProductGrad` op (gradient w.r.t. the input first, then
        gradient w.r.t. the weights).
    """

    return inner_product_grad_module.inner_product_grad(grad, op.inputs[0], op.inputs[1])

# uncomment this and comment the corresponding line above to use the Python
# implementation of the inner product gradient
#@ops.RegisterGradient("InnerProduct")
def _inner_product_grad(op, grad):
    """
    The gradients for `inner_product`, implemented with standard Tensorflow ops.

    The forward op computes `W x` for a column vector `x` of shape `[n, 1]`
    and a weight matrix `W` of shape `[m, n]`, so `grad` has shape `[m, 1]`.

    :param op: `inner_product` `Operation` that we are differentiating, which we can use
        to find the inputs and outputs of the original op.
    :param grad: gradient with respect to the output of the `inner_product` op.
    :return: list `[grad_input, grad_weights]` with the gradients with respect
        to the input (shape `[n, 1]`) and the weights (shape `[m, n]`).
    """

    input_tensor = op.inputs[0]
    weight_tensor = op.inputs[1]

    # d(Wx)/dx applied to grad: W^T grad, shape [n, 1].
    grad_input = tf.matmul(weight_tensor, grad, transpose_a=True)

    # d(Wx)/dW applied to grad: outer product grad x^T, shape [m, n], i.e.
    # grad_weights[i, j] = grad[i] * x[j]. An element-wise product with the
    # transposed gradient only broadcasts for m == n and then yields
    # grad[j] * x[j], which is only correct for a constant gradient.
    grad_weights = tf.matmul(grad, input_tensor, transpose_b=True)

    return [grad_input, grad_weights]
5 | 6 | #include "tensorflow/core/framework/op_kernel.h" 7 | #include "tensorflow/core/framework/tensor_shape.h" 8 | #include "tensorflow/core/platform/default/logging.h" 9 | #include "tensorflow/core/framework/shape_inference.h" 10 | 11 | using namespace tensorflow; 12 | 13 | REGISTER_OP("InnerProduct") 14 | .Input("input: float") 15 | .Input("weights: float") 16 | .Output("inner_product: float") 17 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 18 | shape_inference::ShapeHandle input_shape; 19 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input_shape)); 20 | 21 | shape_inference::ShapeHandle weight_shape; 22 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &weight_shape)); 23 | 24 | shape_inference::DimensionHandle output_rows = c->Dim(weight_shape, 0); 25 | 26 | shape_inference::DimensionHandle input_rows = c->Dim(input_shape, 0); 27 | shape_inference::DimensionHandle weight_cols = c->Dim(weight_shape, 1); 28 | shape_inference::DimensionHandle merged; 29 | TF_RETURN_IF_ERROR(c->Merge(input_rows, weight_cols, &merged)); 30 | 31 | c->set_output(0, c->Matrix(output_rows, 1)); 32 | return Status::OK(); 33 | }); 34 | 35 | /// \brief Implementation of an inner product operation. 36 | /// \param context 37 | /// \author David Stutz 38 | class InnerProductOp : public OpKernel { 39 | public: 40 | /// \brief Constructor. 41 | /// \param context 42 | explicit InnerProductOp(OpKernelConstruction* context) : OpKernel(context) { 43 | 44 | } 45 | 46 | /// \brief Compute the inner product. 47 | /// \param context 48 | void Compute(OpKernelContext* context) override { 49 | 50 | // some checks to be sure ... 
51 | DCHECK_EQ(2, context->num_inputs()); 52 | 53 | // get the input tensor 54 | const Tensor& input = context->input(0); 55 | 56 | // get the weight tensor 57 | const Tensor& weights = context->input(1); 58 | 59 | // check shapes of input and weights 60 | const TensorShape& input_shape = input.shape(); 61 | const TensorShape& weights_shape = weights.shape(); 62 | 63 | // check input is a standing vector 64 | DCHECK_EQ(input_shape.dims(), 2); 65 | DCHECK_EQ(input_shape.dim_size(1), 1); 66 | 67 | // check weights is matrix of correct size 68 | DCHECK_EQ(weights_shape.dims(), 2); 69 | DCHECK_EQ(input_shape.dim_size(0), weights_shape.dim_size(1)); 70 | 71 | // create output shape 72 | TensorShape output_shape; 73 | output_shape.AddDim(weights_shape.dim_size(0)); 74 | output_shape.AddDim(1); 75 | 76 | // create output tensor 77 | Tensor* output = NULL; 78 | OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); 79 | 80 | // get the corresponding Eigen tensors for data access 81 | auto input_tensor = input.matrix(); 82 | auto weights_tensor = weights.matrix(); 83 | auto output_tensor = output->matrix(); 84 | 85 | for (int i = 0; i < output->shape().dim_size(0); i++) { 86 | output_tensor(i, 0) = 0; 87 | for (int j = 0; j < weights.shape().dim_size(1); j++) { 88 | output_tensor(i, 0) += weights_tensor(i, j)*input_tensor(j, 0); 89 | } 90 | } 91 | } 92 | }; 93 | 94 | REGISTER_KERNEL_BUILDER(Name("InnerProduct").Device(DEVICE_CPU), InnerProductOp); 95 | -------------------------------------------------------------------------------- /inner_product_grad.cc: -------------------------------------------------------------------------------- 1 | /// \file inner_product_grad.cc 2 | /// \author David Stutz 3 | /// \brief Implementation of the gradient of a inner product operation, see 4 | /// inner_product.cc. 
5 | 6 | #include "tensorflow/core/framework/op_kernel.h" 7 | #include "tensorflow/core/framework/shape_inference.h" 8 | 9 | using namespace tensorflow; 10 | 11 | // the gradients are simply passed as additional arguments as 12 | // they are available in the Python function for registering the gradient operation. 13 | REGISTER_OP("InnerProductGrad") 14 | .Input("grad: float32") 15 | .Input("input: float32") 16 | .Input("weights: float32") 17 | .Output("grad_input: float32") 18 | .Output("grad_weights: float32"); 19 | 20 | /// \brief Implementation of an inner product gradient operation. 21 | /// Note that this operation is used in Python to register the gradient as 22 | /// this is not possible in C*+ right now. 23 | /// \param context 24 | /// \author David Stutz 25 | class InnerProductGradOp : public OpKernel { 26 | public: 27 | /// \brief Constructor. 28 | /// \param context 29 | explicit InnerProductGradOp(OpKernelConstruction* context) : OpKernel(context) { 30 | 31 | } 32 | 33 | /// \brief Compute the inner product gradients. 
34 | /// \param context 35 | void Compute(OpKernelContext* context) override { 36 | 37 | // output and grad is provided as input 38 | DCHECK_EQ(3, context->num_inputs()); 39 | 40 | // get the gradient tensor 41 | const Tensor& grad = context->input(0); 42 | 43 | // get the original input tensor 44 | const Tensor& input = context->input(1); 45 | 46 | // get the weight tensor 47 | const Tensor& weights = context->input(2); 48 | 49 | // create input shape (inferred from the additional attribute `n`) 50 | TensorShape input_shape = input.shape(); 51 | TensorShape weights_shape = weights.shape(); 52 | 53 | DCHECK_EQ(input_shape.dim_size(0), weights_shape.dim_size(1)); 54 | DCHECK_EQ(weights_shape.dim_size(0), grad.shape().dim_size(0)); 55 | 56 | // create output tensors 57 | Tensor* grad_input = NULL; 58 | Tensor* grad_weights = NULL; 59 | OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &grad_input)); 60 | OP_REQUIRES_OK(context, context->allocate_output(1, weights_shape, &grad_weights)); 61 | 62 | // get the Eigen tensors for data access 63 | auto grad_tensor = grad.matrix(); 64 | auto weights_tensor = weights.matrix(); 65 | auto input_tensor = input.matrix(); 66 | auto grad_input_tensor = grad_input->matrix(); 67 | auto grad_weights_tensor = grad_weights->matrix(); 68 | 69 | // doign it manually for ismplicity 70 | for (int i = 0; i < weights_shape.dim_size(0); i++) { 71 | grad_input_tensor(i, 0) = 0; 72 | for (int j = 0; j < grad.shape().dim_size(0); j++) { 73 | grad_input_tensor(i, 0) += grad_tensor(j, 0)*weights_tensor(j, i); 74 | } 75 | } 76 | 77 | for (int i = 0; i < weights_shape.dim_size(0); i++) { 78 | for (int j = 0; j < weights_shape.dim_size(1); j++) { 79 | grad_weights_tensor(i, j) = grad_tensor(i, 0)*input_tensor(j, 0);; 80 | } 81 | } 82 | } 83 | }; 84 | 85 | REGISTER_KERNEL_BUILDER(Name("InnerProductGrad").Device(DEVICE_CPU), InnerProductGradOp); -------------------------------------------------------------------------------- 
#!/usr/bin/env python3
"""
Tests for the inner product Tensorflow operation.

.. moduleauthor:: David Stutz
"""

import unittest
import numpy as np
import tensorflow as tf
import _inner_product_grad
inner_product_module = tf.load_op_library('libinner_product.so')

class InnerProductOpTest(unittest.TestCase):
    """
    Tests the `inner_product` op and its registered gradient against the
    results of numpy and of Tensorflow's own `tf.matmul`.
    """

    def test_raisesExceptionWithIncompatibleDimensions(self):
        """Operands that are not rank-2 tensors are rejected with a `ValueError`."""
        invalid_arguments = [
            ([1, 2], [[1, 2], [3, 4]]),
            ([1, 2], [1, 2, 3, 4]),
            ([1, 2, 3], [[1, 2], [3, 4]]),
        ]

        with tf.Session(''):
            for x, W in invalid_arguments:
                # `assertRaises` is used as a context manager only; passing an
                # already evaluated expression and the exception type as
                # positional arguments is not a valid way to call it.
                with self.subTest(x = x, W = W):
                    with self.assertRaises(ValueError):
                        inner_product_module.inner_product(x, W).eval()

    def test_innerProductHardCoded(self):
        """The op reproduces a hand-computed matrix-vector product."""
        with tf.Session(''):
            result = inner_product_module.inner_product([[1], [2]], [[1, 2], [3, 4]]).eval()

            self.assertEqual(result.shape[0], 2)
            np.testing.assert_array_equal(result, [[5], [11]])

    def test_innerProductGradientXHardCoded(self):
        """The gradient w.r.t. the input matches the one of `tf.matmul` for fixed values."""
        with tf.Session('') as sess:
            x = tf.placeholder(tf.float32, shape = (2,))
            W = tf.constant(np.asarray([[1, 2], [3, 4]]).astype(np.float32))

            Wx_tf = tf.matmul(W, tf.reshape(x, [-1, 1]))
            Wx_inner_product = inner_product_module.inner_product(tf.reshape(x, [-1, 1]), W)

            grad_x_tf = tf.gradients(Wx_tf, x)
            grad_x_inner_product = tf.gradients(Wx_inner_product, x)

            feed_dict = {x: np.asarray([1, 2]).astype(np.float32)}
            gradient_tf = sess.run(grad_x_tf, feed_dict = feed_dict)
            gradient_inner_product = sess.run(grad_x_inner_product, feed_dict = feed_dict)

            np.testing.assert_array_equal(gradient_tf, gradient_inner_product)

    def test_innerProductGradientWHardCoded(self):
        """The gradient w.r.t. the weights matches the one of `tf.matmul` for fixed values."""
        with tf.Session('') as sess:
            x = tf.constant(np.asarray([1, 2]).astype(np.float32))
            W = tf.placeholder(tf.float32, shape = (2, 2))

            Wx_tf = tf.matmul(W, tf.reshape(x, [-1, 1]))
            Wx_inner_product = inner_product_module.inner_product(tf.reshape(x, [-1, 1]), W)

            grad_W_tf = tf.gradients(Wx_tf, W)
            grad_W_inner_product = tf.gradients(Wx_inner_product, W)

            feed_dict = {W: np.asarray([[1, 2], [3, 4]]).astype(np.float32)}
            gradient_tf = sess.run(grad_W_tf, feed_dict = feed_dict)
            gradient_inner_product = sess.run(grad_W_inner_product, feed_dict = feed_dict)

            np.testing.assert_array_equal(gradient_tf, gradient_inner_product)

    def test_innerProductRandom(self):
        """The op agrees with `np.dot` on random integer-valued operands."""
        with tf.Session(''):
            n = 4
            m = 5

            for i in range(100):
                x_rand = np.random.randint(10, size = (n, 1))
                W_rand = np.random.randint(10, size = (m, n))
                result_rand = np.dot(W_rand, x_rand)

                result = inner_product_module.inner_product(x_rand, W_rand).eval()
                np.testing.assert_array_equal(result, result_rand)

    def test_innerProductGradientXRandom(self):
        """The gradient w.r.t. the input matches the one of `tf.matmul` on random operands."""
        with tf.Session('') as sess:
            n = 4
            m = 5

            x = tf.placeholder(tf.float32, shape = (n,))
            W = tf.placeholder(tf.float32, shape = (m, n))

            Wx_tf = tf.matmul(W, tf.reshape(x, [-1, 1]))
            Wx_inner_product = inner_product_module.inner_product(tf.reshape(x, [-1, 1]), W)

            grad_x_tf = tf.gradients(Wx_tf, x)
            grad_x_inner_product = tf.gradients(Wx_inner_product, x)

            for i in range(100):
                x_rand = np.random.randint(10, size = (n,))
                W_rand = np.random.randint(10, size = (m, n))

                gradient_tf = sess.run(grad_x_tf, feed_dict = {x: x_rand, W: W_rand})
                gradient_inner_product = sess.run(grad_x_inner_product, feed_dict = {x: x_rand, W: W_rand})

                np.testing.assert_array_equal(gradient_tf, gradient_inner_product)

    def test_innerProductGradientWRandom(self):
        """The gradient w.r.t. the weights matches the one of `tf.matmul` on random operands."""
        with tf.Session('') as sess:
            n = 4
            m = 5

            x = tf.placeholder(tf.float32, shape = (n,))
            W = tf.placeholder(tf.float32, shape = (m, n))

            Wx_tf = tf.matmul(W, tf.reshape(x, [-1, 1]))
            Wx_inner_product = inner_product_module.inner_product(tf.reshape(x, [-1, 1]), W)

            grad_W_tf = tf.gradients(Wx_tf, W)
            grad_W_inner_product = tf.gradients(Wx_inner_product, W)

            for i in range(100):
                x_rand = np.random.randint(10, size = (n,))
                W_rand = np.random.randint(10, size = (m, n))

                gradient_tf = sess.run(grad_W_tf, feed_dict = {x: x_rand, W: W_rand})
                gradient_inner_product = sess.run(grad_W_inner_product, feed_dict = {x: x_rand, W: W_rand})

                np.testing.assert_array_equal(gradient_tf, gradient_inner_product)


if __name__ == '__main__':
    unittest.main()