├── python ├── .gitignore └── tinyflow │ ├── __init__.py │ ├── _ops.py │ ├── nn.py │ ├── _util.py │ ├── train.py │ ├── _base.py │ ├── _session.py │ └── datasets.py ├── .gitmodules ├── .gitignore ├── src ├── torch │ ├── op_special_torch.cc │ ├── op_nn_torch.cc │ ├── torch_util.h │ └── op_tensor_torch.cc ├── op_special.cc ├── rtc │ └── op_fusion.cc ├── c_api.cc ├── op_util.h ├── op_nn.cc ├── op_tensor.cc └── session.cc ├── example ├── mnist_softmax.py ├── mnist_softmax_minimum.py ├── mnist_mlp_auto_shape_inference.py ├── mnist_lenet.py └── cifar_resnet.py ├── include └── tinyflow │ ├── c_api.h │ └── base.h ├── tests └── python │ ├── test_states.py │ ├── test_gradients.py │ └── test_ops.py ├── Makefile ├── README.md └── LICENSE /python/.gitignore: -------------------------------------------------------------------------------- 1 | *.c 2 | *.cpp -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "nnvm"] 2 | path = nnvm 3 | url = https://github.com/dmlc/nnvm 4 | [submodule "dmlc-core"] 5 | path = dmlc-core 6 | url = https://github.com/dmlc/dmlc-core 7 | -------------------------------------------------------------------------------- /python/tinyflow/__init__.py: -------------------------------------------------------------------------------- 1 | """Tinyflow trial.""" 2 | from __future__ import absolute_import as _abs 3 | from . import _base 4 | from nnvm.symbol import * 5 | from . import nn 6 | from . import train 7 | 8 | from ._base import * 9 | from ._ops import * 10 | 11 | from ._session import Session 12 | 13 | from ._util import infer_variable_shapes 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files 2 | *.slo 3 | *.lo 4 | *.o 5 | *.obj 6 | 7 | # Precompiled Headers 8 | *.gch 9 | *.pch 10 | 11 | # Compiled Dynamic libraries 12 | *.so 13 | *.dylib 14 | *.dll 15 | 16 | # Fortran module files 17 | *.mod 18 | *.smod 19 | 20 | # Compiled Static libraries 21 | *.lai 22 | *.la 23 | *.a 24 | *.lib 25 | 26 | # Executables 27 | *.exe 28 | *.out 29 | *.app 30 | 31 | build 32 | lib 33 | *~ 34 | dmlc-core 35 | cli_test 36 | *.pyc 37 | test.* 38 | log 39 | -------------------------------------------------------------------------------- /python/tinyflow/_ops.py: -------------------------------------------------------------------------------- 1 | """Wrapping of certain ops for positional arguments. 2 | 3 | Mainly because NNVM accepts kwargs for some additional arguments, 4 | while TF sometimes support positional ops. 
5 | """ 6 | from __future__ import absolute_import as _abs 7 | from nnvm import symbol 8 | from nnvm import _symbol_internal 9 | 10 | 11 | def argmax(x, axis): 12 | return _symbol_internal._argmax(x, reduction_indices=[axis]) 13 | 14 | 15 | def zeros(shape): 16 | return symbol.zeros(shape=shape) 17 | 18 | 19 | def normal(shape, stdev=1.0): 20 | return symbol.normal(shape=shape, stdev=stdev) 21 | -------------------------------------------------------------------------------- /python/tinyflow/nn.py: -------------------------------------------------------------------------------- 1 | from nnvm.symbol import * 2 | from nnvm import symbol as _sym 3 | 4 | def conv2d(data, weight=None, 5 | strides=[1, 1, 1, 1], 6 | padding='VALID', 7 | data_format='NCHW', 8 | **kwargs): 9 | kwargs = kwargs.copy() 10 | kwargs['data'] = data 11 | if weight: 12 | kwargs['weight'] = weight 13 | return _sym.conv2d(strides=strides, padding=padding, data_format=data_format, **kwargs) 14 | 15 | def max_pool(data, 16 | strides=[1, 1, 1, 1], 17 | padding='VALID', 18 | data_format='NCHW', **kwargs): 19 | return _sym.max_pool(data, strides=strides, padding=padding, 20 | data_format=data_format, **kwargs) 21 | -------------------------------------------------------------------------------- /src/torch/op_special_torch.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2016 by Contributors 2 | // implementation of common nn operators 3 | #include 4 | 5 | namespace tinyflow { 6 | 7 | const FLuaCompute kLuaNOP = "function(x, y, kwarg) return function() end end"; 8 | 9 | NNVM_REGISTER_OP(placeholder) 10 | .set_attr("FLuaCompute", kLuaNOP); 11 | 12 | NNVM_REGISTER_OP(_nop) 13 | .set_attr("FLuaCompute", kLuaNOP); 14 | 15 | NNVM_REGISTER_OP(assign) 16 | .set_attr( 17 | "FLuaCompute", R"( 18 | function(x, y, kwarg) 19 | return function() 20 | x[1]:copy(x[2]) 21 | -- normally inplace optimization prevent this 22 | if y[1]:storage() ~= x[2]:storage() then 23 | y[1]:copy(x[2]) 24 | end 25 | end 26 | end 27 | )"); 28 | 29 | } // namespace tinyflow 30 | -------------------------------------------------------------------------------- /example/mnist_softmax.py: -------------------------------------------------------------------------------- 1 | """Tinyflow example code. 2 | 3 | This code is adapted from Tensorflow's MNIST Tutorial with minimum code changes. 
4 | """ 5 | import tinyflow as tf 6 | from tinyflow.datasets import get_mnist 7 | 8 | # Create the model 9 | x = tf.placeholder(tf.float32, [None, 784]) 10 | W = tf.Variable(tf.zeros([784, 10])) 11 | y = tf.nn.softmax(tf.matmul(x, W)) 12 | 13 | # Define loss and optimizer 14 | y_ = tf.placeholder(tf.float32, [None, 10]) 15 | 16 | cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) 17 | train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) 18 | 19 | sess = tf.Session() 20 | sess.run(tf.initialize_all_variables()) 21 | 22 | # get the mnist dataset 23 | mnist = get_mnist(flatten=True, onehot=True) 24 | 25 | for i in range(1000): 26 | batch_xs, batch_ys = mnist.train.next_batch(100) 27 | sess.run(train_step, feed_dict={x: batch_xs, y_:batch_ys}) 28 | 29 | correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) 30 | accuracy = tf.reduce_mean(correct_prediction) 31 | 32 | print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})) 33 | -------------------------------------------------------------------------------- /example/mnist_softmax_minimum.py: -------------------------------------------------------------------------------- 1 | """Tinyflow example code. 2 | 3 | Minimum softmax code that exposes the optimizer. 4 | """ 5 | import tinyflow as tf 6 | from tinyflow.datasets import get_mnist 7 | 8 | # Create the model 9 | x = tf.placeholder(tf.float32, [None, 784]) 10 | W = tf.Variable(tf.zeros([784, 10])) 11 | y = tf.nn.softmax(tf.matmul(x, W)) 12 | 13 | # Define loss and optimizer 14 | y_ = tf.placeholder(tf.float32, [None, 10]) 15 | 16 | cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) 17 | 18 | learning_rate = 0.5 19 | W_grad = tf.gradients(cross_entropy, [W])[0] 20 | train_step = tf.assign(W, W - learning_rate * W_grad) 21 | 22 | sess = tf.Session() 23 | sess.run(tf.initialize_all_variables()) 24 | 25 | # get the mnist dataset 26 | mnist = get_mnist(flatten=True, onehot=True) 27 | 28 | for i in range(1000): 29 | batch_xs, batch_ys = mnist.train.next_batch(100) 30 | sess.run(train_step, feed_dict={x: batch_xs, y_:batch_ys}) 31 | 32 | correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) 33 | accuracy = tf.reduce_mean(correct_prediction) 34 | 35 | print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})) 36 | -------------------------------------------------------------------------------- /include/tinyflow/c_api.h: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2016 by Contributors 3 | * \file c_api.h 4 | * \brief C API to tiny flow 5 | */ 6 | #ifndef TINYFLOW_C_API_H_ 7 | #define TINYFLOW_C_API_H_ 8 | 9 | #include 10 | 11 | typedef void* SessionHandle; 12 | 13 | NNVM_DLL int NNSessionCreate(SessionHandle* handle, const char* option); 14 | 15 | NNVM_DLL int NNSessionClose(SessionHandle handle); 16 | 17 | NNVM_DLL int NNSessionRun(SessionHandle handle, 18 | SymbolHandle graph, 19 | nn_uint num_feed, 20 | const SymbolHandle* feed_placeholders, 21 | const float** feed_dptr, 22 | const nn_uint* feed_dtype, 23 | const nn_uint* feed_shape_csr_ptr, 24 | const nn_uint* feed_shape_data, 25 | nn_uint* num_out, 26 | const float*** out_dptr, 27 | const nn_uint** out_dtype, 28 | const nn_uint **out_shape_ndim, 29 | const nn_uint ***out_shape_data); 30 | 31 | #endif // TINYFLOW_C_API_H_ 32 | -------------------------------------------------------------------------------- /tests/python/test_states.py: -------------------------------------------------------------------------------- 1 | import tinyflow as tf 2 | import numpy as np 3 | 4 | def test_assign(): 5 | x = tf.Variable(tf.zeros(shape=[2,3])) 6 | sess = tf.Session() 7 | sess.run(tf.assign(x, tf.zeros(shape=[2,3]))) 8 | ax = sess.run(x) 9 | np.testing.assert_almost_equal(ax, np.zeros((2,3))) 10 | 11 | def test_group(): 12 | x1 = tf.Variable(tf.zeros(shape=[2,3])) 13 | x2 = tf.Variable(tf.zeros(shape=[2,3])) 14 | a1 = tf.assign(x1, tf.zeros(shape=[2,3])) 15 | a2 = tf.assign(x2, tf.ones(shape=[2,3])) 16 | sess = tf.Session() 17 | sess.run(tf.group(a1, a2)) 18 | ax1 = sess.run(x1) 19 | ax2 = sess.run(x2) 20 | np.testing.assert_almost_equal(ax1, np.zeros((2,3))) 21 | np.testing.assert_almost_equal(ax2, np.ones((2,3))) 22 | 23 | def test_init(): 24 | x1 = tf.Variable(tf.ones(shape=[2,3])) 25 | x2 = tf.Variable(tf.zeros(shape=[2,3])) 26 | sess = tf.Session() 27 | sess.run(tf.initialize_all_variables()) 28 | ax1 = sess.run(x1) 29 | ax2 = sess.run(x2) 30 | np.testing.assert_almost_equal(ax1, np.ones((2,3))) 31 | np.testing.assert_almost_equal(ax2, np.zeros((2,3))) 32 | 33 | if __name__ == "__main__": 34 | 35 | pass 36 | -------------------------------------------------------------------------------- /python/tinyflow/_util.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import as _abs 2 | import json 3 | from nnvm import symbol, graph 4 | 5 | def infer_variable_shapes(net, feed_dict): 6 | """Inference shape of all variables in the net. 7 | 8 | Parameters 9 | ---------- 10 | net : tf.Symbol 11 | The symbolic network containing all the variables. 12 | 13 | feed_dict : dict 14 | dict of placeholder to known shape 15 | 16 | Returns 17 | ------- 18 | Generator of (var, vname, vshape) 19 | Enables enumeration of variables in the net with corresponding name and shape. 
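Example
-------
An illustrative sketch, adapted from example/mnist_mlp_auto_shape_inference.py
(``tf`` is the tinyflow module; ``x``, ``label`` and ``cross_entropy`` come from that script)::

    known_shape = {x: [100, 28 * 28], label: [100]}
    init_step = []
    for v, name, shape in tf.infer_variable_shapes(cross_entropy, feed_dict=known_shape):
        init_step.append(tf.assign(v, tf.normal(shape)))
    sess.run(init_step)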
20 | """ 21 | g = graph.create(net) 22 | jgraph = json.loads(g.apply('SaveJSON').json_attr('json')) 23 | jnode_row_ptr = jgraph["node_row_ptr"] 24 | jnodes = jgraph["nodes"] 25 | shape = [[]] * jnode_row_ptr[-1] 26 | nindex = {n['name']: i for i, n in enumerate(jnodes)} 27 | 28 | for k, v in feed_dict.items(): 29 | node_name = k.attr("name") 30 | shape[jnode_row_ptr[nindex[node_name]]] = v 31 | g._set_json_attr("shape", shape, "list_shape") 32 | g = g.apply("InferShape") 33 | shape = g.json_attr("shape") 34 | ret = {} 35 | for v in net.list_input_variables(): 36 | vname = v.attr("name") 37 | vshape = shape[jnode_row_ptr[nindex[vname]]] 38 | if len(vshape) == 0: 39 | raise ValueError("not sufficient information in feed_dict") 40 | yield (v, vname, vshape) 41 | -------------------------------------------------------------------------------- /src/op_special.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2016 by Contributors 2 | // implementation of common nn operators 3 | #include 4 | #include 5 | #include 6 | #include "./op_util.h" 7 | 8 | namespace tinyflow { 9 | 10 | using namespace nnvm; 11 | 12 | const FLuaCompute kLuaNOP = "function(x, y, kwarg) return function() end end"; 13 | 14 | NNVM_REGISTER_OP(placeholder) 15 | .describe("placeholder op") 16 | .set_num_inputs(0); 17 | 18 | template 19 | inline bool EmptyAttr(const NodeAttrs& attrs, 20 | std::vector *ishape, 21 | std::vector *oshape) { 22 | oshape->at(0) = Attr{0}; return true; 23 | } 24 | 25 | NNVM_REGISTER_OP(_nop) 26 | .describe("no operation") 27 | .set_num_inputs(0) 28 | .set_num_outputs(1) 29 | .set_attr("FInferShape", EmptyAttr) 30 | .set_attr("FInferType", EmptyAttr); 31 | 32 | 33 | NNVM_REGISTER_OP(assign) 34 | .describe("assign second to the first") 35 | .set_num_inputs(2) 36 | .set_attr("FMutateInputs", [](const NodeAttrs& attrs) { 37 | return std::vector{0}; 38 | }) 39 | .set_attr("FInferShape", SameShape) 40 | .set_attr("FInplaceOption", InplaceIn1Out0); 41 | 42 | // special no gradient op to report error when take 43 | // gradient wrt non-differentiable inputs 44 | NNVM_REGISTER_OP(_no_gradient) 45 | .describe("Special op indicating no gradient") 46 | .set_num_inputs(0); 47 | 48 | } // namespace tinyflow 49 | -------------------------------------------------------------------------------- /example/mnist_mlp_auto_shape_inference.py: -------------------------------------------------------------------------------- 1 | """TinyFlow Example code. 2 | 3 | Automatic variable creation and shape inductions. 4 | The network structure is directly specified via forward node numbers 5 | The variables are automatically created, and their shape infered by tf.infer_variable_shapes 6 | """ 7 | import tinyflow as tf 8 | from tinyflow.datasets import get_mnist 9 | 10 | # Create the model 11 | x = tf.placeholder(tf.float32) 12 | fc1 = tf.nn.linear(x, num_hidden=100, name="fc1", no_bias=False) 13 | relu1 = tf.nn.relu(fc1) 14 | fc2 = tf.nn.linear(relu1, num_hidden=10, name="fc2") 15 | 16 | # define loss 17 | label = tf.placeholder(tf.float32) 18 | cross_entropy = tf.nn.mean_sparse_softmax_cross_entropy_with_logits(fc2, label) 19 | train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) 20 | 21 | sess = tf.Session(device='gpu') 22 | 23 | # Automatic variable shape inference API, infers the shape and initialize the weights. 
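# tf.infer_variable_shapes builds the NNVM graph for the loss symbol, seeds the
# placeholder shapes from feed_dict, runs the InferShape pass, and yields
# (variable, name, shape) for every input variable, so each weight can be
# initialized explicitly below with tf.assign(v, tf.normal(shape)).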
24 | known_shape = {x: [100, 28 * 28], label: [100]} 25 | init_step = [] 26 | for v, name, shape in tf.infer_variable_shapes( 27 | cross_entropy, feed_dict=known_shape): 28 | init_step.append(tf.assign(v, tf.normal(shape))) 29 | print("shape[%s]=%s" % (name, str(shape))) 30 | sess.run(init_step) 31 | 32 | # get the mnist dataset 33 | mnist = get_mnist(flatten=True, onehot=False) 34 | 35 | print_period = 1000 36 | for epoch in range(10): 37 | sum_loss = 0.0 38 | num_batch = 600 39 | for i in range(num_batch): 40 | batch_xs, batch_ys = mnist.train.next_batch(100) 41 | loss, _ = sess.run([cross_entropy, train_step], feed_dict={x: batch_xs, label:batch_ys}) 42 | sum_loss += loss 43 | print("epoch[%d] cross_entropy=%g" % (epoch, sum_loss /num_batch)) 44 | 45 | correct_prediction = tf.equal(tf.argmax(fc2, 1), label) 46 | accuracy = tf.reduce_mean(correct_prediction) 47 | print(sess.run(accuracy, feed_dict={x: mnist.test.images, label: mnist.test.labels})) 48 | -------------------------------------------------------------------------------- /python/tinyflow/train.py: -------------------------------------------------------------------------------- 1 | from . import _base 2 | from nnvm import symbol as _sym 3 | 4 | class GradientDescentOptimizer(object): 5 | def __init__(self, learning_rate, name="GradientDescent"): 6 | self.learning_rate = learning_rate 7 | 8 | def minimize(self, obj): 9 | variables = obj.list_input_variables() 10 | grads = _base.gradients(obj, variables) 11 | updates = [] 12 | for v, g in zip(variables, grads): 13 | updates.append(_sym.assign(v, v + (-self.learning_rate) * g)) 14 | return _base.group(*updates) 15 | 16 | class AdamOptimizer(object): 17 | def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-04, name='Adam'): 18 | self.name = name 19 | self.t = _base.Variable(_sym.zeros(shape=[1]), name+'_t') 20 | self.learning_rate = learning_rate 21 | self.beta1 = beta1 22 | self.beta2 = beta2 23 | self.epsilon = epsilon 24 | self.m = [] 25 | self.v = [] 26 | 27 | def minimize(self, obj): 28 | variables = obj.list_input_variables() 29 | grads = _base.gradients(obj, variables) 30 | updates = [] 31 | for i, v in enumerate(variables): 32 | self.m.append(_base.Variable(_sym.zeros_like(v), self.name + '_m' + str(i))) 33 | self.v.append(_base.Variable(_sym.zeros_like(v), self.name + '_v' + str(i))) 34 | update_t = _sym.assign(self.t, self.t + 1) 35 | rate = _sym.sqrt(1 - self.beta2 ** update_t) / (1 - self.beta1 ** update_t) 36 | lr_t = self.learning_rate * rate 37 | for var, g, m, v in zip(variables, grads, self.m, self.v): 38 | update_m = _sym.assign(m, self.beta1 * m + (1 - self.beta1) * g) 39 | update_v = _sym.assign(v, self.beta2 * v + (1 - self.beta2) * g * g) 40 | update_var = _sym.assign(var, 41 | var - lr_t * update_m / (_sym.sqrt(update_v) + self.epsilon)) 42 | updates.append(update_var) 43 | return _base.group(*updates) 44 | -------------------------------------------------------------------------------- /tests/python/test_gradients.py: -------------------------------------------------------------------------------- 1 | import tinyflow as tf 2 | import numpy as np 3 | 4 | def test_add_grad(): 5 | x = tf.placeholder(tf.float32) 6 | y = tf.placeholder(tf.float32) 7 | ax = np.ones((2, 3)) 8 | ay = np.ones((2, 3)) * 4 9 | z = x + y 10 | gx, gy = tf.gradients(z, [x, y]) 11 | sess = tf.Session() 12 | agx = sess.run(gx, feed_dict={x:ax, y:ay}) 13 | np.testing.assert_almost_equal(agx, np.ones((2,3))) 14 | 15 | def test_mul_grad(): 16 | x = 
tf.placeholder(tf.float32) 17 | ax = np.ones((2, 3)) 18 | z = x * 14 19 | gx = tf.gradients(z, [x])[0] 20 | sess = tf.Session() 21 | agx = sess.run(gx, feed_dict={x:ax}) 22 | np.testing.assert_almost_equal(agx, np.ones((2,3)) * 14) 23 | 24 | def test_sum_grad(): 25 | x = tf.placeholder(tf.float32) 26 | ax = np.ones((2, 3)) 27 | z = -tf.reduce_sum(x) * 14 28 | gx = tf.gradients(z, [x])[0] 29 | sess = tf.Session() 30 | agx = sess.run(gx, feed_dict={x:ax}) 31 | np.testing.assert_almost_equal(agx, -np.ones((2,3)) * 14) 32 | 33 | def test_mean_grad(): 34 | x = tf.placeholder(tf.float32) 35 | ax = np.ones((2, 3)) 36 | z = -tf.reduce_mean(x) * 14 37 | gx = tf.gradients(z, [x])[0] 38 | sess = tf.Session() 39 | agx = sess.run(gx, feed_dict={x:ax}) 40 | np.testing.assert_almost_equal(agx, -np.ones((2,3)) * 14 / 6.0) 41 | 42 | def test_matmul_grad(): 43 | x = tf.placeholder(tf.float32) 44 | y = tf.placeholder(tf.float32) 45 | ax = np.ones((2, 3)) 46 | ay = np.ones((3, 4)) * 4 47 | z = tf.matmul(x, y) * 4 48 | gx, gy = tf.gradients(z, [x, y]) 49 | sess = tf.Session() 50 | agx = sess.run(gx, feed_dict={x:ax, y:ay}) 51 | agy = sess.run(gy, feed_dict={x:ax, y:ay}) 52 | np.testing.assert_almost_equal( 53 | agx, 54 | np.dot(np.ones((2,4)), ay.T) * 4) 55 | np.testing.assert_almost_equal( 56 | agy, 57 | np.dot(ax.T, np.ones((2,4))) * 4) 58 | 59 | 60 | if __name__ == "__main__": 61 | test_mean_grad() 62 | pass 63 | -------------------------------------------------------------------------------- /example/mnist_lenet.py: -------------------------------------------------------------------------------- 1 | """TinyFlow Example: LeNet for Digits classification. 2 | 3 | This code uses automatic variable shape inference for shorter code. 4 | """ 5 | import tinyflow as tf 6 | from tinyflow.datasets import get_mnist 7 | 8 | # Create the model 9 | x = tf.placeholder(tf.float32) 10 | conv1 = tf.nn.conv2d(x, num_filter=20, ksize=[1, 5, 5, 1], name="conv1", no_bias=False) 11 | tanh1 = tf.tanh(conv1) 12 | pool1 = tf.nn.max_pool(tanh1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1]) 13 | conv2 = tf.nn.conv2d(pool1, num_filter=50, ksize=[1, 5, 5, 1], name="conv2", no_bias=False) 14 | tanh2 = tf.tanh(conv2) 15 | pool2 = tf.nn.max_pool(tanh2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1]) 16 | flatten = tf.nn.flatten_layer(pool2) 17 | fc1 = tf.nn.linear(flatten, num_hidden=500, name="fc1") 18 | tanh3 = tf.tanh(fc1) 19 | fc2 = tf.nn.linear(tanh3, num_hidden=10, name="fc2") 20 | 21 | # define loss 22 | label = tf.placeholder(tf.float32) 23 | cross_entropy = tf.nn.mean_sparse_softmax_cross_entropy_with_logits(fc2, label) 24 | train_step = tf.train.AdamOptimizer(0.005).minimize(cross_entropy) 25 | 26 | sess = tf.Session(config='gpu') 27 | 28 | # Auromatic variable shape inference API, infers the shape and initialize the weights. 
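# The weights created implicitly by tf.nn.conv2d/linear are initialized below from
# the inferred shapes via tf.assign(v, tf.normal(shape, stdev)); the later call to
# tf.initialize_all_variables() initializes the Adam optimizer state (the t/m/v
# variables created by minimize()).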
29 | known_shape = {x: [100, 1, 28, 28], label: [100]} 30 | stdev = 0.01 31 | init_step = [] 32 | for v, name, shape in tf.infer_variable_shapes( 33 | cross_entropy, feed_dict=known_shape): 34 | init_step.append(tf.assign(v, tf.normal(shape, stdev))) 35 | print("shape[%s]=%s" % (name, str(shape))) 36 | sess.run(init_step) 37 | sess.run(tf.initialize_all_variables()) 38 | 39 | # get the mnist dataset 40 | mnist = get_mnist(flatten=False, onehot=False) 41 | 42 | print_period = 1000 43 | for epoch in range(10): 44 | sum_loss = 0.0 45 | num_batch = 600 46 | for i in range(num_batch): 47 | batch_xs, batch_ys = mnist.train.next_batch(100) 48 | loss, _ = sess.run([cross_entropy, train_step], feed_dict={x: batch_xs, label:batch_ys}) 49 | sum_loss += loss 50 | print("epoch[%d] cross_entropy=%g" % (epoch, sum_loss /num_batch)) 51 | 52 | correct_prediction = tf.equal(tf.argmax(fc2, 1), label) 53 | accuracy = tf.reduce_mean(correct_prediction) 54 | print(sess.run(accuracy, feed_dict={x: mnist.test.images, label: mnist.test.labels})) 55 | -------------------------------------------------------------------------------- /python/tinyflow/_base.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import as _abs 2 | import os 3 | import sys 4 | 5 | if sys.version_info[0] == 3: 6 | import builtins as __builtin__ 7 | else: 8 | import __builtin__ 9 | 10 | curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) 11 | 12 | if hasattr(__builtin__, "NNVM_BASE_PATH"): 13 | assert __builtin__.NNVM_BASE_PATH == curr_path 14 | else: 15 | __builtin__.NNVM_BASE_PATH = curr_path 16 | 17 | if hasattr(__builtin__, "NNVM_LIBRARY_NAME"): 18 | assert __builtin__.NNVM_LIBRARY_NAME == curr_path 19 | else: 20 | __builtin__.NNVM_LIBRARY_NAME = "libtinyflow" 21 | 22 | 23 | import ctypes as _ctypes 24 | from nnvm.name import NameManager 25 | from nnvm._base import c_str, check_call, _LIB 26 | from nnvm import symbol, graph 27 | from nnvm import _symbol_internal 28 | 29 | __all__ = ["float32", "placeholder", "Variable", "group", 30 | "initialize_all_variables", "gradients"] 31 | 32 | # data type table 33 | float32 = 0 34 | 35 | # global list of all variable initializers 36 | _all_variable_inits = [] 37 | 38 | 39 | def Variable(init=None, name=None): 40 | name = NameManager.current.get(name, 'variable') 41 | v = symbol.Variable(name) 42 | if init is not None: 43 | if not isinstance(init, symbol.Symbol): 44 | raise TypeError("Expect initialization expression to be Symbol") 45 | _all_variable_inits.append(symbol.assign(v, init)) 46 | return v 47 | 48 | 49 | def initialize_all_variables(): 50 | global _all_variable_inits 51 | init_op = group(*_all_variable_inits) 52 | _all_variable_inits = [] 53 | return init_op 54 | 55 | 56 | def placeholder(dtype, shape=None, name=None): 57 | v = symbol.placeholder(name=name, dtype=dtype) 58 | return v 59 | 60 | 61 | def group(*inputs): 62 | x = _symbol_internal._nop() 63 | x._add_control_deps(symbol.Group(inputs)) 64 | return x 65 | 66 | 67 | def gradients(ys, xs, grad_ys=None): 68 | if isinstance(ys, list): 69 | ys = symbol.Group(ys) 70 | g = graph.create(ys) 71 | g._set_symbol_list_attr('grad_ys', ys) 72 | g._set_symbol_list_attr('grad_xs', xs) 73 | ny = len(ys.list_output_names()) 74 | if grad_ys is None: 75 | grad_ys = [symbol.ones_like(ys[i]) for i in range(ny)] 76 | g._set_symbol_list_attr('grad_ys_out_grad', grad_ys) 77 | sym = g.apply('Gradient').symbol 78 | nx = len(xs) if isinstance(xs, list) else 
len(xs.list_output_names()) 79 | ret = [sym[i] for i in range(nx)] 80 | return ret 81 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | TORCH_PATH=${TORCH_HOME} 2 | 3 | ROOTDIR = $(CURDIR) 4 | 5 | ifndef CUDA_PATH 6 | CUDA_PATH = /usr/local/cuda 7 | endif 8 | 9 | ifndef NNVM_PATH 10 | NNVM_PATH = $(ROOTDIR)/nnvm 11 | endif 12 | 13 | export LDFLAGS = -pthread -lm 14 | export CFLAGS = -std=c++11 -Wall -O2 -msse2 -Wno-unknown-pragmas -funroll-loops\ 15 | -fPIC -Iinclude -Idmlc-core/include -I$(NNVM_PATH)/include 16 | 17 | # whether use fusion 18 | USE_FUSION = 0 19 | ifeq ($(USE_FUSION), 1) 20 | ifndef NNVM_FUSION_PATH 21 | NNVM_FUSION_PATH = $(NNVM_PATH)/plugin/nnvm-fusion/ 22 | endif 23 | CFLAGS += -DTINYFLOW_USE_FUSION=1 -I$(NNVM_FUSION_PATH)/include -I$(CUDA_PATH)/include 24 | LDFLAGS += -L$(CUDA_PATH)/lib64 -lcuda -lnvrtc -lcudart 25 | endif 26 | 27 | .PHONY: clean all test lint doc 28 | 29 | UNAME_S := $(shell uname -s) 30 | 31 | ifeq ($(UNAME_S), Darwin) 32 | WHOLE_ARCH= -all_load 33 | NO_WHOLE_ARCH= -noall_load 34 | CFLAGS += -I$(TORCH_PATH)/install/include -I$(TORCH_PATH)/install/include/TH 35 | LDFLAGS += -L$(TORCH_PATH)/install/lib -llua -lluaT -lTH 36 | else 37 | WHOLE_ARCH= --whole-archive 38 | NO_WHOLE_ARCH= --no-whole-archive 39 | CFLAGS += -I$(TORCH_PATH)/install/include -I$(TORCH_PATH)/install/include/TH \ 40 | -I$(TORCH_PATH)/install/include/THC/ 41 | LDFLAGS += -L$(TORCH_PATH)/install/lib -lluajit -lluaT -lTH -lTHC 42 | endif 43 | 44 | SRC = $(wildcard src/*.cc src/*/*.cc src/*/*/*.cc) 45 | OBJ = $(patsubst %.cc, build/%.o, $(SRC)) 46 | CUSRC = $(wildcard src/*.cu src/*/*.cu src/*/*/*.cu) 47 | CUOBJ = $(patsubst %.cu, build/%_gpu.o, $(CUSRC)) 48 | 49 | LIB_DEP = $(NNVM_PATH)/lib/libnnvm.a 50 | ALL_DEP = $(OBJ) $(LIB_DEP) 51 | 52 | all: lib/libtinyflow.so 53 | 54 | build/src/%.o: src/%.cc 55 | @mkdir -p $(@D) 56 | $(CXX) -std=c++11 $(CFLAGS) -MM -MT build/src/$*.o $< >build/src/$*.d 57 | $(CXX) -std=c++11 -c $(CFLAGS) -c $< -o $@ 58 | 59 | build/src/%_gpu.o: src/%.cu 60 | @mkdir -p $(@D) 61 | $(NVCC) $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" -M -MT build/src/$*_gpu.o $< >build/src/$*_gpu.d 62 | $(NVCC) -c -o $@ $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" $< 63 | 64 | lib/libtinyflow.so: $(ALL_DEP) 65 | @mkdir -p $(@D) 66 | $(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) \ 67 | -Wl,${WHOLE_ARCH} $(filter %.a, $^) -Wl,${NO_WHOLE_ARCH} $(LDFLAGS) 68 | 69 | $(NNVM_PATH)/lib/libnnvm.a: 70 | + cd $(NNVM_PATH); make lib/libnnvm.a; cd $(ROOTDIR) 71 | 72 | lint: 73 | python2 dmlc-core/scripts/lint.py tinyflow cpp include src 74 | 75 | clean: 76 | $(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o 77 | 78 | -include build/*.d 79 | -include build/*/*.d 80 | -include build/*/*/*.d 81 | -------------------------------------------------------------------------------- /python/tinyflow/_session.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import as _abs 2 | import ctypes as _ctypes 3 | import numpy as np 4 | from nnvm import symbol 5 | from nnvm._base import c_str, check_call, _LIB, c_array, nn_uint 6 | 7 | SessionHandle = _ctypes.c_void_p 8 | nn_float = _ctypes.c_float 9 | 10 | def _get_numpy(cptr, dtype, shape): 11 | if dtype != 0: 12 | raise ValueError("only float32 is supported so far") 13 | size = 1 14 | for s in shape: 15 | size *= s 16 | if size != 0 and shape: 17 | dbuffer 
= (nn_float * size).from_address(_ctypes.addressof(cptr.contents)) 18 | return np.frombuffer(dbuffer, dtype=np.float32).reshape(shape).copy() 19 | else: 20 | return None 21 | 22 | class Session(object): 23 | def __init__(self, config='cpu'): 24 | handle = SessionHandle() 25 | check_call(_LIB.NNSessionCreate(_ctypes.byref(handle), c_str(config))) 26 | self.handle = handle 27 | 28 | def __del__(self): 29 | check_call(_LIB.NNSessionClose(self.handle)) 30 | 31 | def run(self, fetch, feed_dict=None): 32 | if isinstance(fetch, list): 33 | fetch = symbol.Group(fetch) 34 | feed_dict = feed_dict if feed_dict else {} 35 | feed_placeholders = [] 36 | feed_dptr = [] 37 | feed_dtype = [] 38 | feed_shape_csr_ptr = [0] 39 | feed_shape_data = [] 40 | src_list = [] 41 | 42 | for k, v in feed_dict.items(): 43 | assert isinstance(k, symbol.Symbol) 44 | assert isinstance(v, np.ndarray) 45 | feed_placeholders.append(k.handle) 46 | # only convert to float32 for now 47 | source_array = np.ascontiguousarray(v, dtype=np.float32) 48 | # leep src_list alive for the period 49 | src_list.append(source_array) 50 | feed_dptr.append(source_array.ctypes.data_as(_ctypes.c_void_p)) 51 | feed_dtype.append(0) 52 | feed_shape_data.extend(source_array.shape) 53 | feed_shape_csr_ptr.append(len(feed_shape_data)) 54 | out_size = nn_uint() 55 | out_dptr = _ctypes.POINTER(_ctypes.POINTER(nn_float))() 56 | out_dtype = _ctypes.POINTER(nn_uint)() 57 | out_shape_ndim = _ctypes.POINTER(nn_uint)() 58 | out_shape_data = _ctypes.POINTER(_ctypes.POINTER(nn_uint))() 59 | 60 | check_call(_LIB.NNSessionRun( 61 | self.handle, fetch.handle, nn_uint(len(src_list)), 62 | c_array(_ctypes.c_void_p, feed_placeholders), 63 | c_array(_ctypes.c_void_p, feed_dptr), 64 | c_array(nn_uint, feed_dtype), 65 | c_array(nn_uint, feed_shape_csr_ptr), 66 | c_array(nn_uint, feed_shape_data), 67 | _ctypes.byref(out_size), 68 | _ctypes.byref(out_dptr), 69 | _ctypes.byref(out_dtype), 70 | _ctypes.byref(out_shape_ndim), 71 | _ctypes.byref(out_shape_data))) 72 | ret = [] 73 | for i in range(out_size.value): 74 | shape = tuple(out_shape_data[i][:out_shape_ndim[i]]) 75 | ret.append(_get_numpy(out_dptr[i], out_dtype[i], shape)) 76 | 77 | return ret[0] if len(ret) == 1 else ret 78 | -------------------------------------------------------------------------------- /example/cifar_resnet.py: -------------------------------------------------------------------------------- 1 | import tinyflow as tf 2 | from tinyflow.datasets import get_cifar10 3 | import numpy as np 4 | 5 | num_epoch = 10 6 | num_batch = 600 7 | batch_size = 100 8 | 9 | 10 | def conv_factory(x, filter_size, in_filters, out_filters): 11 | x = tf.nn.conv2d(x, num_filter=out_filters, 12 | ksize=[1, filter_size, filter_size, 1], padding='SAME') 13 | x = tf.nn.batch_normalization(x) 14 | x = tf.nn.relu(x) 15 | return x 16 | 17 | def residual_factory(x, in_filters, out_filters): 18 | if in_filters == out_filters: 19 | orig_x = x 20 | conv1 = conv_factory(x, 3, in_filters, out_filters) 21 | conv2 = conv_factory(conv1, 3, out_filters, out_filters) 22 | new = orig_x + conv2 23 | return tf.nn.relu(new) 24 | else: 25 | conv1 = conv_factory(x, 3, in_filters, out_filters) 26 | conv2 = conv_factory(conv1, 3, out_filters, out_filters) 27 | project_x = conv_factory(x, 1, in_filters, out_filters) 28 | new = project_x + conv2 29 | return tf.nn.relu(new) 30 | 31 | def resnet(x, n, in_filters, out_filters): 32 | for i in range(n): 33 | if i == 0: 34 | x = residual_factory(x, in_filters, 16) 35 | else: 36 | x = residual_factory(x, 16, 
16) 37 | for i in range(n): 38 | if i == 0: 39 | x = residual_factory(x, 16, 32) 40 | else: 41 | x = residual_factory(x, 32, 32) 42 | for i in range(n): 43 | if i == 0: 44 | x = residual_factory(x, 32, 64) 45 | else: 46 | x = residual_factory(x, 64, 64) 47 | return x 48 | 49 | 50 | x = tf.placeholder(tf.float32) 51 | conv1 = tf.nn.conv2d(x, num_filter=16, ksize=[1, 5, 5, 1], padding='SAME') 52 | tanh1 = tf.tanh(conv1) 53 | res = resnet(tanh1, 1, 16, 64) 54 | pool1 = tf.nn.avg_pool(res, ksize=[1, 4, 4, 1], strides=[1, 2, 2, 1], padding='SAME', data_format='NCHW') 55 | conv2 = tf.nn.conv2d(pool1, num_filter=16, ksize=[1, 5, 5, 1]) 56 | flatten = tf.nn.flatten_layer(conv2) 57 | fc1 = tf.nn.linear(flatten, num_hidden=10, name="fc1") 58 | 59 | # define loss 60 | label = tf.placeholder(tf.float32) 61 | cross_entropy = tf.nn.mean_sparse_softmax_cross_entropy_with_logits(fc1, label) 62 | train_step = tf.train.AdamOptimizer(0.0005).minimize(cross_entropy) 63 | 64 | sess = tf.Session(config='gpu') 65 | 66 | # Auromatic variable shape inference API, infers the shape and initialize the weights. 67 | known_shape = {x: [batch_size, 3, 32, 32], label: [batch_size]} 68 | stdev = 0.01 69 | init_step = [] 70 | for v, name, shape in tf.infer_variable_shapes( 71 | cross_entropy, feed_dict=known_shape): 72 | init_step.append(tf.assign(v, tf.normal(shape, stdev))) 73 | print("shape[%s]=%s" % (name, str(shape))) 74 | sess.run(init_step) 75 | sess.run(tf.initialize_all_variables()) 76 | 77 | # get the cifar dataset 78 | cifar = get_cifar10() 79 | 80 | for epoch in range(num_epoch): 81 | sum_loss = 0.0 82 | for i in range(num_batch): 83 | batch_xs, batch_ys = cifar.train.next_batch(batch_size) 84 | loss, _ = sess.run([cross_entropy, train_step], feed_dict={x: batch_xs, label:batch_ys}) 85 | sum_loss += loss 86 | print("epoch[%d] cross_entropy=%g" % (epoch, sum_loss /num_batch)) 87 | 88 | correct_prediction = tf.equal(tf.argmax(fc1, 1), label) 89 | accuracy = tf.reduce_mean(correct_prediction) 90 | print(sess.run(accuracy, feed_dict={x: cifar.test.images, label: cifar.test.labels})) 91 | -------------------------------------------------------------------------------- /src/rtc/op_fusion.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2016 by Contributors 2 | // implementation of operators FCodeGen attribute 3 | #if TINYFLOW_USE_FUSION == 1 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | using nnvm::NodePtr; 10 | using nnvm::fusion::FCodeGen; 11 | using nnvm::fusion::ASTPtr; 12 | using nnvm::fusion::FloatAST; 13 | using nnvm::fusion::CallAST; 14 | 15 | namespace tinyflow { 16 | 17 | NNVM_REGISTER_OP(__add_symbol__) 18 | .set_attr( 19 | "FCodeGen", [](const NodePtr& n, 20 | const std::vector& inputs) { 21 | return std::vector{ 22 | inputs[0] + inputs[1], 23 | }; 24 | } 25 | ); 26 | 27 | 28 | NNVM_REGISTER_OP(__add_scalar__) 29 | .set_attr( 30 | "FCodeGen", [](const NodePtr& n, 31 | const std::vector& inputs) { 32 | double val = std::stod(n->attrs.dict["scalar"]); 33 | ASTPtr num = ASTPtr(new FloatAST(val)); 34 | return std::vector{ 35 | inputs[0] + num, 36 | }; 37 | } 38 | ); 39 | 40 | 41 | NNVM_REGISTER_OP(__sub_symbol__) 42 | .set_attr( 43 | "FCodeGen", [](const NodePtr& n, 44 | const std::vector& inputs) { 45 | return std::vector{ 46 | inputs[0] - inputs[1], 47 | }; 48 | } 49 | ); 50 | 51 | 52 | NNVM_REGISTER_OP(__rsub_scalar__) 53 | .set_attr( 54 | "FCodeGen", [](const NodePtr& n, 55 | const std::vector& inputs) { 56 | double val = 
std::stod(n->attrs.dict["scalar"]); 57 | ASTPtr num = ASTPtr(new FloatAST(val)); 58 | return std::vector{ 59 | num - inputs[0], 60 | }; 61 | } 62 | ); 63 | 64 | 65 | NNVM_REGISTER_OP(__mul_symbol__) 66 | .set_attr( 67 | "FCodeGen", [](const NodePtr& n, 68 | const std::vector& inputs) { 69 | return std::vector{ 70 | inputs[0] * inputs[1], 71 | }; 72 | } 73 | ); 74 | 75 | 76 | NNVM_REGISTER_OP(__mul_scalar__) 77 | .set_attr( 78 | "FCodeGen", [](const NodePtr& n, 79 | const std::vector& inputs) { 80 | double val = std::stod(n->attrs.dict["scalar"]); 81 | ASTPtr num = ASTPtr(new FloatAST(val)); 82 | return std::vector{ 83 | num * inputs[0], 84 | }; 85 | } 86 | ); 87 | 88 | 89 | NNVM_REGISTER_OP(__div_symbol__) 90 | .set_attr( 91 | "FCodeGen", [](const NodePtr& n, 92 | const std::vector& inputs) { 93 | return std::vector{ 94 | inputs[0] / inputs[1], 95 | }; 96 | } 97 | ); 98 | 99 | 100 | NNVM_REGISTER_OP(exp) 101 | .set_attr( 102 | "FCodeGen", [](const NodePtr& n, 103 | const std::vector& inputs) { 104 | return std::vector{ 105 | ASTPtr(new CallAST("exp", inputs)), 106 | }; 107 | } 108 | ); 109 | 110 | 111 | NNVM_REGISTER_OP(sqrt) 112 | .set_attr( 113 | "FCodeGen", [](const NodePtr& n, 114 | const std::vector& inputs) { 115 | return std::vector{ 116 | ASTPtr(new CallAST("sqrt", inputs)), 117 | }; 118 | } 119 | ); 120 | 121 | 122 | NNVM_REGISTER_OP(__rpow_scalar__) 123 | .set_attr( 124 | "FCodeGen", [](const NodePtr& n, 125 | const std::vector& inputs) { 126 | double val = std::stod(n->attrs.dict["scalar"]); 127 | ASTPtr num = ASTPtr(new FloatAST(val)); 128 | return std::vector{ 129 | ASTPtr(new CallAST("pow", {num, inputs[0]})), 130 | }; 131 | } 132 | ); 133 | 134 | } // namespace tinyflow 135 | #endif 136 | -------------------------------------------------------------------------------- /python/tinyflow/datasets.py: -------------------------------------------------------------------------------- 1 | """auxiliary utility to get the dataset for demo""" 2 | import numpy as np 3 | from collections import namedtuple 4 | from sklearn.datasets import fetch_mldata 5 | import cPickle 6 | import sys 7 | import os 8 | from subprocess import call 9 | 10 | 11 | class ArrayPacker(object): 12 | """Dataset packer for iterator""" 13 | def __init__(self, X, Y): 14 | self.images = X 15 | self.labels = Y 16 | self.ptr = 0 17 | 18 | def next_batch(self, batch_size): 19 | if self.ptr + batch_size >= self.labels.shape[0]: 20 | self.ptr = 0 21 | X = self.images[self.ptr:self.ptr+batch_size] 22 | Y = self.labels[self.ptr:self.ptr+batch_size] 23 | self.ptr += batch_size 24 | return X, Y 25 | 26 | MNISTData = namedtuple("MNISTData", ["train", "test"]) 27 | 28 | def get_mnist(flatten=False, onehot=False): 29 | mnist = fetch_mldata('MNIST original') 30 | np.random.seed(1234) # set seed for deterministic ordering 31 | p = np.random.permutation(mnist.data.shape[0]) 32 | X = mnist.data[p] 33 | Y = mnist.target[p] 34 | X = X.astype(np.float32) / 255.0 35 | if flatten: 36 | X = X.reshape((70000, 28 * 28)) 37 | else: 38 | X = X.reshape((70000, 1, 28, 28)) 39 | if onehot: 40 | onehot = np.zeros((Y.shape[0], 10)) 41 | onehot[np.arange(Y.shape[0]), Y.astype(np.int32)] = 1 42 | Y = onehot 43 | X_train = X[:60000] 44 | Y_train = Y[:60000] 45 | X_test = X[60000:] 46 | Y_test = Y[60000:] 47 | return MNISTData(train=ArrayPacker(X_train, Y_train), 48 | test=ArrayPacker(X_test, Y_test)) 49 | 50 | 51 | CIFAR10Data = namedtuple("CIFAR10Data", ["train", "test"]) 52 | 53 | def load_batch(fpath, label_key='labels'): 54 | f = open(fpath, 'rb') 
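    # CIFAR-10 batch files are pickled by Python 2; under Python 3 the pickle is
    # loaded as bytes and the keys are decoded back to str below.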
55 | if sys.version_info < (3,): 56 | d = cPickle.load(f) 57 | else: 58 | d = cPickle.load(f, encoding="bytes") 59 | # decode utf8 60 | for k, v in d.items(): 61 | del(d[k]) 62 | d[k.decode("utf8")] = v 63 | f.close() 64 | data = d["data"] 65 | labels = d[label_key] 66 | 67 | data = data.reshape(data.shape[0], 3, 32, 32).astype(np.float32) 68 | labels = np.array(labels, dtype="float32") 69 | return data, labels 70 | 71 | 72 | def get_cifar10(swap_axes=False): 73 | path = "cifar-10-batches-py" 74 | if not os.path.exists(path): 75 | tar_file = "cifar-10-python.tar.gz" 76 | origin = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 77 | if os.path.exists(tar_file): 78 | need_download = False 79 | else: 80 | need_download = True 81 | if need_download: 82 | call(["wget", origin]) 83 | call(["tar", "-xvf", "cifar-10-python.tar.gz"]) 84 | else: 85 | call(["tar", "-xvf", "cifar-10-python.tar.gz"]) 86 | 87 | nb_train_samples = 50000 88 | 89 | X_train = np.zeros((nb_train_samples, 3, 32, 32), dtype="float32") 90 | y_train = np.zeros((nb_train_samples,), dtype="float32") 91 | 92 | for i in range(1, 6): 93 | fpath = os.path.join(path, 'data_batch_' + str(i)) 94 | data, labels = load_batch(fpath) 95 | X_train[(i - 1) * 10000: i * 10000, :, :, :] = data 96 | y_train[(i - 1) * 10000: i * 10000] = labels 97 | 98 | fpath = os.path.join(path, 'test_batch') 99 | X_test, y_test = load_batch(fpath) 100 | 101 | if swap_axes: 102 | X_train = np.swapaxes(X_train, 1, 3) 103 | X_test = np.swapaxes(X_test, 1, 3) 104 | 105 | return CIFAR10Data(train=ArrayPacker(X_train, y_train), 106 | test=ArrayPacker(X_test, y_test)) 107 | -------------------------------------------------------------------------------- /include/tinyflow/base.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2016 by Contributors 3 | * \file base.h 4 | * \brief Basic data structures. 5 | */ 6 | #ifndef TINYFLOW_BASE_H_ 7 | #define TINYFLOW_BASE_H_ 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | namespace tinyflow { 18 | 19 | using nnvm::Op; 20 | using nnvm::Node; 21 | using nnvm::Symbol; 22 | using nnvm::TShape; 23 | 24 | /*! \brief device mask for each device */ 25 | enum DeviceMask { 26 | kCPU = 1, 27 | kGPU = 2 28 | }; 29 | 30 | /*! \brief data type enumeration */ 31 | enum DataType { 32 | kFloat32 = 0 33 | }; 34 | 35 | /*! \brief contiguous tensor block data structure */ 36 | struct TBlob { 37 | /*! \brief pointer to the data */ 38 | void* data{nullptr}; 39 | /*! \brief shape of the tensor */ 40 | TShape shape; 41 | /*! \brief device mask of the corresponding device type */ 42 | int dev_mask{kCPU}; 43 | /*! \brief type of the tensor */ 44 | int dtype{kFloat32}; 45 | }; 46 | 47 | /*! 48 | * \brief a lua function to return closure to carry out computation of an op. 49 | * 50 | * Signature: 51 | * function(inputs, outputs, kwarg) 52 | * - inputs: array of input torch.FloatTensor 53 | * - outputs: array of input torch.FloatTensor 54 | * - kwargs: additional table arguments passed from kwarg. 55 | * - return: a lua closure, with signature function() that carrys out the computation. 56 | * 57 | * After this function, outputs content are set to be correct value. 58 | * This function cannot change the storage of inputs and outputs. 59 | * \note Register as FLuaCompute 60 | * only one of FLuaCreateNNModule/FLuaCompute is needed. 61 | */ 62 | using FLuaCompute = std::string; 63 | 64 | /*! 
65 | * \brief a lua code to create a NN module/Criterion for an op. 66 | * Only allows register normal module that takes one tensor and returns one tensor. 67 | * 68 | * Signature: 69 | * function(ishape, kwarg) 70 | * - ishape: array of array, shape of input arguments, including weights 71 | * - kwarg: table: str->str, additional arguments to the op. 72 | * - return: a nn.Module for corresponding op. 73 | * 74 | * \note Register as FLuaCreateNNModule, 75 | * only one of FLuaCreateNNModule/FLuaCompute is needed. 76 | */ 77 | using FLuaCreateNNModule = std::string; 78 | 79 | /*! 80 | * \brief If registered and TBackwardNumNoGrad=k 81 | * The last k inputs do not have gradient. 82 | * \note Register as TBackwardNumNoGradInputs 83 | */ 84 | using TBackwardNumNoGradInputs = int; 85 | 86 | /*! 87 | * \brief Whether backward need weight. 88 | * \note Register as TBackwardNeedInputs 89 | */ 90 | using TBackwardNeedInputs = bool; 91 | 92 | /*! 93 | * \brief Whether backward op need outputs. 94 | * \note Register as TBackwardNeedOutputs 95 | */ 96 | using TBackwardNeedOutputs = bool; 97 | 98 | /*! \brief Executor of a graph */ 99 | class Session { 100 | public: 101 | /*! 102 | * \brief Run the given graph 103 | * \param g the graph to run. 104 | * \param inputs The input feed_dict mapping 105 | * \note The session hold the ownership of the outputs. 106 | * The results are only valid before calling any functions of this session again. 107 | * \return The output tensors. 108 | */ 109 | virtual const std::vector& Run( 110 | Symbol* g, 111 | const std::unordered_map& inputs) = 0; 112 | /*! \brief virtual destructor */ 113 | virtual ~Session() {} 114 | /*! 115 | * \brief create a new session of given type. 116 | * \param type The type of the session. 117 | * \return a new created session. 118 | */ 119 | static Session* Create(const std::string& type); 120 | }; 121 | 122 | } // namespace tinyflow 123 | 124 | #endif // TINYFLOW_BASE_H_ 125 | -------------------------------------------------------------------------------- /src/c_api.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2016 by Contributors 2 | #include 3 | #include 4 | 5 | /*! 6 | * \brief handle exception throwed out 7 | * \param e the exception 8 | * \return the return value of API after exception is handled 9 | */ 10 | inline int NNAPIHandleException(const dmlc::Error &e) { 11 | NNAPISetLastError(e.what()); 12 | return -1; 13 | } 14 | 15 | /*! \brief macro to guard beginning and end section of all functions */ 16 | #define API_BEGIN() try { 17 | /*! \brief every function starts with API_BEGIN(); 18 | and finishes with API_END() or API_END_HANDLE_ERROR */ 19 | #define API_END() } catch(dmlc::Error &_except_) { return NNAPIHandleException(_except_); } return 0; // NOLINT(*) 20 | /*! 21 | * \brief every function starts with API_BEGIN(); 22 | * and finishes with API_END() or API_END_HANDLE_ERROR 23 | * The finally clause contains procedure to cleanup states when an error happens. 24 | */ 25 | #define API_END_HANDLE_ERROR(Finalize) } catch(dmlc::Error &_except_) { Finalize; return NNAPIHandleException(_except_); } return 0; // NOLINT(*) 26 | 27 | /*! \brief entry to to easily hold returning information */ 28 | struct TinyAPIThreadLocalEntry { 29 | /*! \brief result holder for returning handles */ 30 | std::vector floatp; 31 | /*! \brief result holder for returning handles */ 32 | std::vector dtype; 33 | /*! \brief result holder for returning handles */ 34 | std::vector shape_ndim; 35 | /*! 
\brief result holder for returning handles */ 36 | std::vector shape_data; 37 | }; 38 | 39 | using namespace tinyflow; 40 | 41 | int NNSessionCreate(SessionHandle* handle, const char* option) { 42 | API_BEGIN(); 43 | *handle = Session::Create(option); 44 | API_END(); 45 | } 46 | 47 | int NNSessionClose(SessionHandle handle) { 48 | API_BEGIN(); 49 | delete static_cast(handle); 50 | API_END(); 51 | } 52 | 53 | int NNSessionRun(SessionHandle handle, 54 | SymbolHandle graph, 55 | nn_uint num_feed, 56 | const SymbolHandle* feed_placeholders, 57 | const float** feed_dptr, 58 | const nn_uint* feed_dtype, 59 | const nn_uint* feed_shape_csr_ptr, 60 | const nn_uint* feed_shape_data, 61 | nn_uint* num_out, 62 | const float*** out_dptr, 63 | const nn_uint** out_dtype, 64 | const nn_uint** out_shape_ndim, 65 | const nn_uint*** out_shape_data) { 66 | API_BEGIN(); 67 | std::unordered_map feed; 68 | for (nn_uint i = 0; i < num_feed; ++i) { 69 | const std::string& key = 70 | static_cast(feed_placeholders[i])->outputs[0].node->attrs.name; 71 | TBlob tmp; 72 | tmp.data = (void*)feed_dptr[i]; // NOLINT(*) 73 | tmp.shape = TShape(feed_shape_data + feed_shape_csr_ptr[i], 74 | feed_shape_data + feed_shape_csr_ptr[i + 1]); 75 | feed[key] = tmp; 76 | } 77 | 78 | const std::vector& out = static_cast(handle)->Run( 79 | static_cast(graph), feed); 80 | *num_out = static_cast(out.size()); 81 | auto* ret = dmlc::ThreadLocalStore::Get(); 82 | ret->floatp.resize(out.size()); 83 | ret->dtype.resize(out.size()); 84 | ret->shape_ndim.resize(out.size()); 85 | ret->shape_data.resize(out.size()); 86 | 87 | for (size_t i = 0; i < out.size(); ++i) { 88 | ret->floatp[i] = static_cast(out[i].data); 89 | ret->dtype[i] = out[i].dtype; 90 | ret->shape_ndim[i] = out[i].shape.ndim(); 91 | ret->shape_data[i] = out[i].shape.data(); 92 | } 93 | *out_dptr = dmlc::BeginPtr(ret->floatp); 94 | *out_dtype = dmlc::BeginPtr(ret->dtype); 95 | *out_shape_ndim = dmlc::BeginPtr(ret->shape_ndim); 96 | *out_shape_data = dmlc::BeginPtr(ret->shape_data); 97 | API_END(); 98 | return 0; 99 | } 100 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TinyFlow: Build Your Own DL System in 2K Lines
2 | 
3 | TinyFlow is "example code" for [NNVM](https://github.com/dmlc/nnvm/).
4 | 
5 | It demonstrates how we can build a clean, minimal and powerful computational-graph-based
6 | deep learning system with the same API as TensorFlow.
7 | The operator code is implemented with [Torch7](https://github.com/torch/torch7) to reduce the effort of writing operators while still demonstrating the concepts of the system (and embedding Lua in C++ is kind of fun :).
8 | 
9 | TinyFlow is a real deep learning system that can run on GPUs and CPUs.
10 | To support the examples, it takes:
11 | - 927 lines of code for operators
12 | - 734 lines of code for the execution runtime
13 | - 71 lines of code for API glue
14 | - 233 lines of code for the front-end
15 | 
16 | Note that more operator code can easily be added to make it as
17 | feature-complete as most existing deep learning systems.
18 | 
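## A Quick Look

A minimal training loop, condensed from `example/mnist_softmax.py`, showing that the front-end mirrors the TensorFlow tutorial API:

```python
import tinyflow as tf
from tinyflow.datasets import get_mnist

# model: single-layer softmax classifier
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
y = tf.nn.softmax(tf.matmul(x, W))

# loss and optimizer
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

sess = tf.Session()
sess.run(tf.initialize_all_variables())

mnist = get_mnist(flatten=True, onehot=True)
for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
```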
19 | ## What is it for
20 | As explained in the goal of [NNVM](https://github.com/dmlc/nnvm/),
21 | it is important to make modular and reusable components that enable us to build
22 | customized learning systems easily.
23 | 
24 | - Course material for teaching DL systems. TinyFlow can be used to teach students the concepts in deep learning systems.
25 |   - e.g. design homework on implementing symbolic differentiation, memory allocation, or operator fusion.
26 | - Experiment bed for learning system researchers. TinyFlow allows easy addition of new system features, with
27 | the modular design being portable to other systems that reuse NNVM.
28 | - Showcase of an intermediate representation. It demonstrates how an intermediate representation like NNVM can
29 | target multiple front-ends (TF, MXNet) and backends (Torch7, MXNet) with a common set of optimizations.
30 | - Test bed for common reusable modules for DL systems. TinyFlow, together with other systems (e.g. MXNet), can be used as a test bed for the
31 | common reusable modules in deep learning, to encourage front-ends, optimization modules and backends
32 | that are shared across frameworks.
33 | - Just for fun :)
34 | 
35 | 
36 | We believe the Unix Philosophy can make building learning systems more fun, and help everyone build
37 | and understand learning systems better.
38 | 
39 | ## The Design
40 | - The graph construction API is automatically reused from NNVM.
41 | - We choose Torch7 as the default operator execution backend.
42 | - So TinyFlow can also be called "TorchFlow", since it is literally TensorFlow on top of Torch :)
43 | - This allows us to quickly implement the operators and focus the code on the system part.
44 | - We intentionally choose to avoid using [MXNet](https://github.com/dmlc/mxnet) as the front-end or backend,
45 | since MXNet already uses NNVM as its intermediate layer, and it is more fun to try something different.
46 | 
47 | Although it is minimal, TinyFlow still comes with many advanced design concepts found in deep learning systems:
48 | - Automatic differentiation.
49 | - Shape/type inference.
50 | - Static memory allocation on the graph for memory-efficient training/inference.
51 | 
52 | The operator implementation is easy thanks to Torch7. More fun demonstrations will be added to the project.
53 | 
54 | ## Dependencies
55 | Most of TinyFlow's code is self-contained.
56 | - TinyFlow depends on Torch7 for operator support with minimum code.
57 | - We use lightweight Lua bridge code from dmlc-core/dmlc/lua.h.
58 | - NNVM is used for graph representation and optimizations.
59 | 
60 | ## Build
61 | - Install Torch7.
62 | - For OS X users, please install Torch with Lua 5.1 instead of LuaJIT,
63 | i.e. 
```TORCH_LUA_VERSION=LUA51 ./install.sh``` 64 | - Set up environment variable ```TORCH_HOME``` to root of torch 65 | - Type ```make``` 66 | - Setup python path to include tinyflow and nnvm 67 | ```bash 68 | export PYTHONPATH=${PYTHONPATH}:/path/to/tinyflow/python:/path/to/tinyflow/nnvm/python 69 | ``` 70 | - Try example program ```python example/mnist_softmax.py``` 71 | 72 | ## Enable Fusion in TinyFlow 73 | - Build NNVM with Fusion: uncomment fusion plugin part in config.mk, then `make` 74 | - Build TinyFlow: enable `USE_FUSION` in Makefile, then `make` 75 | - Try Example program `example/mnist_lenet.py`, change the config of session from `tf.Session(config='gpu')` to `tf.Session(config='gpu fusion')` 76 | -------------------------------------------------------------------------------- /tests/python/test_ops.py: -------------------------------------------------------------------------------- 1 | import tinyflow as tf 2 | import numpy as np 3 | 4 | def check_ewise(ufunc): 5 | x = tf.placeholder(tf.float32) 6 | y = tf.placeholder(tf.float32) 7 | z = ufunc(x, y) 8 | ax = np.ones((2, 3)) 9 | ay = np.ones((2, 3)) * 4 10 | sess = tf.Session() 11 | az = sess.run(z, feed_dict={x:ax, y:ay}) 12 | np.testing.assert_almost_equal(az, ufunc(ax, ay)) 13 | 14 | def check_ewise_scalar(ufunc): 15 | x = tf.placeholder(tf.float32) 16 | y = 10; 17 | z = ufunc(x, y) 18 | ax = np.ones((2, 3)) 19 | sess = tf.Session() 20 | az = sess.run(z, feed_dict={x:ax}) 21 | np.testing.assert_almost_equal(az, ufunc(ax, y)) 22 | 23 | def check_ewise_rscalar(ufunc): 24 | x = 10; 25 | y = tf.placeholder(tf.float32) 26 | z = ufunc(x, y) 27 | ay = np.ones((2, 3)) 28 | sess = tf.Session() 29 | az = sess.run(z, feed_dict={y:ay}) 30 | np.testing.assert_almost_equal(az, ufunc(x, ay)) 31 | 32 | def test_ewise(): 33 | check_ewise(lambda x, y: x+y) 34 | check_ewise(lambda x, y: x-y) 35 | check_ewise(lambda x, y: x*y) 36 | check_ewise(lambda x, y: x/y) 37 | check_ewise(lambda x, y: x**y) 38 | check_ewise_scalar(lambda x, y: x+y) 39 | check_ewise_scalar(lambda x, y: x-y) 40 | check_ewise_scalar(lambda x, y: x*y) 41 | check_ewise_scalar(lambda x, y: x/y) 42 | check_ewise_rscalar(lambda x, y: x-y) 43 | check_ewise_rscalar(lambda x, y: x**y) 44 | 45 | def test_exp(): 46 | x = tf.placeholder(tf.float32) 47 | y = tf.exp(x) 48 | ax = np.ones((2, 3)) * 2 49 | sess = tf.Session() 50 | ay = sess.run(y, feed_dict={x:ax}) 51 | np.testing.assert_almost_equal(ay, np.exp(ax)) 52 | 53 | def test_log(): 54 | x = tf.placeholder(tf.float32) 55 | y = tf.log(x) 56 | ax = np.ones((2, 3)) * 2 57 | sess = tf.Session() 58 | ay = sess.run(y, feed_dict={x:ax}) 59 | np.testing.assert_almost_equal(ay, np.log(ax)) 60 | 61 | def test_sqrt(): 62 | x = tf.placeholder(tf.float32) 63 | y = tf.sqrt(x) 64 | ax = np.ones((2, 3)) * 2 65 | sess = tf.Session() 66 | ay = sess.run(y, feed_dict={x:ax}) 67 | np.testing.assert_almost_equal(ay, np.sqrt(ax)) 68 | 69 | def test_softmax(): 70 | x = tf.placeholder(tf.float32) 71 | y = tf.nn.softmax(x) 72 | ax = np.ones((2, 4)) 73 | sess = tf.Session() 74 | ay = sess.run(y, feed_dict={x:ax}) 75 | np.testing.assert_almost_equal( 76 | ay, ax / np.sum(ax, axis=1, keepdims=True)) 77 | 78 | def test_matmul(): 79 | x = tf.placeholder(tf.float32) 80 | y = tf.placeholder(tf.float32) 81 | ax = np.ones((2, 3)) 82 | ay = np.ones((3, 4)) * 4 83 | z = tf.matmul(x, y) * 4 84 | sess = tf.Session() 85 | az = sess.run(z, feed_dict={x:ax, y:ay}) 86 | np.testing.assert_almost_equal( 87 | az, np.dot(ax, ay) * 4) 88 | 89 | def test_sum(): 90 | axis = [1, 3] 
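    # ax below has shape (2, 4, 8, 7); reducing over axes 1 and 3 yields shape (2, 8)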
91 | x = tf.placeholder(tf.float32) 92 | y = tf.reduce_sum(x, reduction_indices=axis) 93 | ax = np.random.uniform(size=(2, 4, 8, 7)) 94 | sess = tf.Session() 95 | ay = sess.run(y, feed_dict={x:ax}) 96 | npy = ax.sum(axis=tuple(axis)) 97 | assert(np.mean(np.abs(ay - npy))) < 1e-6 98 | 99 | def test_mean(): 100 | axis = [1, 3] 101 | x = tf.placeholder(tf.float32) 102 | y = tf.reduce_mean(x, reduction_indices=axis) 103 | ax = np.random.uniform(size=(2, 4, 8, 7)) 104 | sess = tf.Session() 105 | ay = sess.run(y, feed_dict={x:ax}) 106 | npy = ax.mean(axis=tuple(axis)) 107 | assert(np.mean(np.abs(ay - npy))) < 1e-6 108 | 109 | def test_argmax(): 110 | x = tf.placeholder(tf.float32) 111 | y = tf.argmax(x, 1) 112 | ax = np.random.uniform(size=(700, 10)) 113 | sess = tf.Session() 114 | ay = sess.run(y, feed_dict={x:ax}) 115 | npy = np.argmax(ax, 1) 116 | assert(np.mean(np.abs(ay - npy))) < 1e-6 117 | 118 | def test_pad(): 119 | out_filter = 10 120 | in_filter = 4 121 | pad_width = (out_filter-in_filter)//2 122 | x = tf.placeholder(tf.float32) 123 | y = tf.pad(x, dim=1, pad=-pad_width) 124 | z = tf.pad(y, dim=1, pad=pad_width) 125 | nx = np.random.randn(100, 4, 28, 28) 126 | npy = np.pad(nx, ((0, 0), (pad_width, pad_width), (0, 0), (0, 0)), 127 | mode='constant', constant_values=0) 128 | sess = tf.Session() 129 | sess.run(tf.initialize_all_variables()) 130 | ay = sess.run(z, feed_dict={x : nx}) 131 | assert(np.mean(np.abs(ay - npy))) < 1e-6 132 | 133 | 134 | if __name__ == "__main__": 135 | test_ewise() 136 | test_exp() 137 | test_log() 138 | test_sqrt() 139 | test_sum() 140 | test_mean() 141 | test_matmul() 142 | test_softmax() 143 | test_argmax() 144 | test_pad() 145 | pass 146 | -------------------------------------------------------------------------------- /src/torch/op_nn_torch.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2016 by Contributors 2 | // implementation of common nn operators 3 | #include 4 | #include 5 | #include "../op_util.h" 6 | 7 | namespace tinyflow { 8 | 9 | NNVM_REGISTER_OP(softmax) 10 | .set_attr( 11 | "FLuaCreateNNModule", R"( 12 | function(ishape, kwarg) 13 | return nn.SoftMax() 14 | end 15 | )"); 16 | 17 | 18 | NNVM_REGISTER_OP(tanh) 19 | .set_attr( 20 | "FLuaCreateNNModule", R"( 21 | function(ishape, kwarg) 22 | return nn.Tanh() 23 | end 24 | )"); 25 | 26 | 27 | NNVM_REGISTER_OP(relu) 28 | .set_attr( 29 | "FLuaCreateNNModule", R"( 30 | function(ishape, kwarg) 31 | return nn.ReLU() 32 | end 33 | )"); 34 | 35 | 36 | NNVM_REGISTER_OP(linear) 37 | .set_attr( 38 | "FLuaCreateNNModule", R"( 39 | function(ishape, kwarg) 40 | local wshape = ishape[2] 41 | local m = nn.Linear(wshape[2], wshape[1]) 42 | if #ishape == 2 then 43 | m = m:noBias() 44 | end 45 | return m 46 | end 47 | )"); 48 | 49 | 50 | NNVM_REGISTER_OP(pad) 51 | .set_attr( 52 | "FLuaCreateNNModule", R"( 53 | function(ishape, kwarg) 54 | local dim = tonumber(kwarg.dim) + 1 55 | local pad = tonumber(kwarg.pad) 56 | local m = nn.Padding(dim, pad) 57 | return m 58 | end 59 | )"); 60 | 61 | 62 | NNVM_REGISTER_OP(conv2d) 63 | .set_attr( 64 | "FLuaCreateNNModule", R"( 65 | function(ishape, kwarg) 66 | local dshape = ishape[2] 67 | local fshape = ishape[2] 68 | local outPlane = fshape[1] 69 | local inPlane = fshape[2] 70 | local kH = fshape[3] 71 | local kW = fshape[4] 72 | local inH = dshape[3] 73 | local inW = dshape[4] 74 | local stride = nn_parse_tuple(kwarg.strides, {1,1,1,1}) 75 | local dH = stride[2] 76 | local dW = stride[3] 77 | local padH = 0 78 | local 
padW = 0 79 | 80 | assert(kwarg.data_format == 'NCHW') 81 | if kwarg.padding == 'SAME' then 82 | padW = math.floor((kW - 1) / 2) 83 | padH = math.floor((kH - 1) / 2) 84 | end 85 | local m = nn.SpatialConvolution( 86 | inPlane, outPlane, 87 | kW, kH, dW, dH, padW, padH) 88 | if #ishape == 2 then 89 | m = m:noBias() 90 | end 91 | return m 92 | end 93 | )"); 94 | 95 | 96 | NNVM_REGISTER_OP(max_pool) 97 | .set_attr( 98 | "FLuaCreateNNModule", R"( 99 | function(ishape, kwarg) 100 | local ksize = nn_parse_tuple(kwarg.ksize) 101 | local stride = nn_parse_tuple(kwarg.strides, {1,1,1,1}) 102 | local kH = ksize[2] 103 | local kW = ksize[3] 104 | local dH = stride[2] 105 | local dW = stride[3] 106 | local padH = 0 107 | local padW = 0 108 | assert(kwarg.data_format == 'NCHW') 109 | if kwarg.padding == 'SAME' then 110 | padW = math.floor((kW - 1) / 2) 111 | padH = math.floor((kH - 1) / 2) 112 | end 113 | return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH) 114 | end 115 | )"); 116 | 117 | 118 | NNVM_REGISTER_OP(avg_pool) 119 | .set_attr( 120 | "FLuaCreateNNModule", R"( 121 | function(ishape, kwarg) 122 | local ksize = nn_parse_tuple(kwarg.ksize) 123 | local stride = nn_parse_tuple(kwarg.strides, {1,1,1,1}) 124 | local kH = ksize[2] 125 | local kW = ksize[3] 126 | local dH = stride[2] 127 | local dW = stride[3] 128 | local padH = 0 129 | local padW = 0 130 | assert(kwarg.data_format == 'NCHW') 131 | if kwarg.padding == 'SAME' then 132 | padW = math.floor((kW - 1) / 2) 133 | padH = math.floor((kH - 1) / 2) 134 | end 135 | local m = nn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH) 136 | return m 137 | end 138 | )"); 139 | 140 | 141 | NNVM_REGISTER_OP(batch_normalization) 142 | .set_attr( 143 | "FLuaCreateNNModule", R"( 144 | function(ishape, kwarg) 145 | local n = ishape[1][2] 146 | return nn.SpatialBatchNormalization(n) 147 | end 148 | )"); 149 | 150 | 151 | NNVM_REGISTER_OP(mean_sparse_softmax_cross_entropy_with_logits) 152 | .set_attr( 153 | "FLuaCreateNNModule", R"( 154 | function(ishape, kwarg) 155 | return nn_zero_index_target_criterion( 156 | nn.CrossEntropyCriterion()) 157 | end 158 | )"); 159 | 160 | 161 | const char* LuaReshape = R"( 162 | function(x, y, kwarg) 163 | if x[1]:storage() == y[1]:storage() then 164 | return function() end 165 | else 166 | return function() y[1]:copy(x[1]:resizeAs(y[1])) end 167 | end 168 | end 169 | )"; 170 | 171 | 172 | NNVM_REGISTER_OP(flatten_layer) 173 | .set_attr("FLuaCompute", LuaReshape); 174 | 175 | 176 | NNVM_REGISTER_OP(_flatten_backward) 177 | .set_attr("FLuaCompute", LuaReshape); 178 | 179 | } // namespace tinyflow 180 | -------------------------------------------------------------------------------- /src/op_util.h: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2016 by Contributors 3 | * \file op_util.h 4 | * \brief common util to define operators 5 | */ 6 | #ifndef TINYFLOW_OP_UTIL_H_ 7 | #define TINYFLOW_OP_UTIL_H_ 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | namespace tinyflow { 17 | 18 | using namespace nnvm; 19 | 20 | // assign rhs to lhs, check if shape is consistent 21 | #define SHAPE_ASSIGN(lhs, rhs) \ 22 | if ((lhs).ndim() == 0) (lhs) = (rhs); \ 23 | else \ 24 | CHECK_EQ(lhs, rhs) << "shape inference inconsistent"; \ 25 | 26 | // assign rhs to lhs, check if type is consistent 27 | #define DTYPE_ASSIGN(lhs, rhs) \ 28 | if ((lhs) == -1) (lhs) = (rhs); \ 29 | else \ 30 | CHECK_EQ(lhs, rhs) << "type inference inconsistent"; \ 31 | 32 | 33 | // simply return the shape as same 34 | inline bool SameShape(const NodeAttrs& attrs, 35 | std::vector *ishape, 36 | std::vector *oshape) { 37 | TShape def_v; 38 | for (TShape& pshape : *oshape) { 39 | if (pshape.ndim() != 0) { 40 | def_v = pshape; break; 41 | } 42 | } 43 | if (def_v.ndim() == 0) { 44 | for (TShape& pshape : *ishape) { 45 | if (pshape.ndim() != 0) { 46 | def_v = pshape; 47 | if (pshape.ndim() != 1 || pshape[0] != 1) { 48 | break; 49 | } 50 | } 51 | } 52 | } 53 | if (def_v.ndim() == 0) return false; 54 | 55 | for (TShape& pshape : *oshape) { 56 | SHAPE_ASSIGN(pshape, def_v); 57 | } 58 | for (TShape& pshape : *ishape) { 59 | if (pshape.ndim() == 1 && pshape[0] == 1) { 60 | continue; 61 | } 62 | SHAPE_ASSIGN(pshape, def_v); 63 | } 64 | return true; 65 | } 66 | 67 | // The output is a scalar. 68 | inline bool ScalarShape(const NodeAttrs& attrs, 69 | std::vector *ishape, 70 | std::vector *oshape) { 71 | for (TShape& pshape : *ishape) { 72 | if (pshape.ndim() == 0) return false; 73 | } 74 | SHAPE_ASSIGN(oshape->at(0), TShape{1}); 75 | return true; 76 | } 77 | 78 | inline std::vector > InplaceIn0Out0(const NodeAttrs& attrs) { 79 | return {{0, 0}}; 80 | } 81 | 82 | inline std::vector > InplaceIn1Out0(const NodeAttrs& attrs) { 83 | return {{1, 0}}; 84 | } 85 | 86 | /*! \brief Parse keyword arguments as PType arguments and save to parsed */ 87 | template 88 | inline void ParamParser(nnvm::NodeAttrs* attrs) { 89 | PType param; 90 | param.Init(attrs->dict); 91 | attrs->parsed = std::move(param); 92 | } 93 | 94 | // quick helper to make node 95 | inline NodeEntry MakeNode(const char* op_name, 96 | std::string node_name, 97 | std::vector inputs, 98 | std::unordered_map kwarg = {}) { 99 | NodePtr p = Node::Create(); 100 | p->attrs.op = nnvm::Op::Get(op_name); 101 | p->attrs.name = std::move(node_name); 102 | p->inputs = std::move(inputs); 103 | p->attrs.dict = std::move(kwarg); 104 | if (p->op()->attr_parser != nullptr) { 105 | p->op()->attr_parser(&(p->attrs)); 106 | } 107 | return NodeEntry{p, 0, 0}; 108 | } 109 | 110 | // make backward convention node of an op 111 | inline std::vector MakeBackwardGrads( 112 | const char* op_name, 113 | const NodePtr& n, 114 | std::vector inputs, 115 | std::unordered_map kwarg = {}) { 116 | NodePtr p = Node::Create(); 117 | p->attrs.op = nnvm::Op::Get(op_name); 118 | p->attrs.name = std::move(n->attrs.name + "_backward"); 119 | p->inputs = std::move(inputs); 120 | p->attrs.dict = std::move(kwarg); 121 | p->control_deps.push_back(n); 122 | std::vector ret; 123 | for (uint32_t i = 0; i < p->num_outputs(); ++i) { 124 | ret.emplace_back(NodeEntry{p, i, 0}); 125 | } 126 | return ret; 127 | } 128 | 129 | // special parameter stored in backward node. 
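// The fields below drive MakeNNBackwardNode (see src/op_nn.cc): the generated
// "_backward" node takes inputs laid out as
//   [output_grad, forward inputs (if need_inputs), states, forward outputs (if need_outputs)]
// and emits one gradient per forward read-only input, minus num_no_grad_inputs.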
130 | struct NNBackwardParam { 131 | // total number of inputs in forward op 132 | uint32_t forward_readonly_inputs; 133 | // number of internal states in the op 134 | uint32_t num_states{0}; 135 | // number of inputs who do not have gradients. 136 | uint32_t num_no_grad_inputs{0}; 137 | // whether backward need all te inputs. 138 | bool need_inputs{true}; 139 | // whether backward need all the outputs. 140 | bool need_outputs{true}; 141 | }; 142 | 143 | } // namespace tinyflow 144 | 145 | #endif // TINYFLOW_OP_UTIL_H_ 146 | -------------------------------------------------------------------------------- /src/torch/torch_util.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2016 by Contributors 3 | * \file torch_util.h 4 | * \brief common util to reuse things from torch. 5 | */ 6 | #ifndef TINYFLOW_TORCH_UTIL_H_ 7 | #define TINYFLOW_TORCH_UTIL_H_ 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | namespace dmlc { 15 | namespace lua_stack { 16 | // enable pass in TShape as arguments 17 | template<> 18 | struct Handler { 19 | static inline nnvm::TShape Get(lua_State* L, int index, LuaState* s) { 20 | std::vector v = Handler >::Get(L, index, s); 21 | return nnvm::TShape(v.begin(), v.end()); 22 | } 23 | static inline void Push(lua_State* L, const nnvm::TShape& shape) { 24 | std::vector v(shape.begin(), shape.end()); 25 | Handler >::Push(L, v); 26 | } 27 | }; 28 | 29 | } // namespace lua_stack 30 | } // namespace dmlc 31 | 32 | namespace tinyflow { 33 | 34 | using dmlc::LuaRef; 35 | using dmlc::LuaState; 36 | 37 | // hhelper to create new functions 38 | class TorchState { 39 | public: 40 | TorchState() { 41 | auto* lua = LuaState::ThreadLocalState(); 42 | lua->Eval("require 'torch'"); 43 | lua->Eval("require 'nn'"); 44 | lua->Eval("torch.setdefaulttensortype('torch.FloatTensor')"); 45 | LuaRef parse_tuple = lua->Eval(R"( 46 | return function(s, def) 47 | if s == nil then 48 | return def 49 | end 50 | t = {} 51 | for k in string.gmatch(s, '%d') do 52 | table.insert(t, tonumber(k)) 53 | end 54 | return t 55 | end 56 | )"); 57 | LuaRef zero_index_target_criterion = lua->Eval(R"( 58 | return function(c) 59 | local updateOutput = c.updateOutput 60 | local updateGradInput = c.updateGradInput 61 | c.updateOutput = function(self, input, target) 62 | return updateOutput(self, input, target + 1) 63 | end 64 | c.updateGradInput = function(self, input, target) 65 | return updateGradInput(self, input, target + 1) 66 | end 67 | return c 68 | end 69 | )"); 70 | lua->SetGlobalField("nn_parse_tuple", parse_tuple); 71 | lua->SetGlobalField("nn_zero_index_target_criterion", zero_index_target_criterion); 72 | } 73 | // prepare for GPU ops 74 | inline void InitGPU() { 75 | if (gpu_init_) return; 76 | LOG(INFO) << "start to initialize ..."; 77 | auto* lua = LuaState::ThreadLocalState(); 78 | lua->Eval("require 'cutorch'"); 79 | lua->Eval("require 'cunn'"); 80 | lua->Eval("require 'cudnn'"); 81 | LOG(INFO) << "finished gpu initialization..."; 82 | gpu_init_ = true; 83 | } 84 | // create a new storage with given size 85 | LuaRef NewStorage(size_t size, int dev_mask = kCPU, int dtype = 0) { 86 | CHECK_EQ(dtype, 0) << "only float is supported so far"; 87 | if (fstorage_new_.is_nil()) { 88 | auto* lua = LuaState::ThreadLocalState(); 89 | fstorage_new_ = lua->Eval(R"( 90 | return 91 | function(size, dev_mask) 92 | if dev_mask == 1 then 93 | return torch.FloatStorage(size) 94 | else 95 | return torch.CudaTensor(size) 96 | end 97 | end 98 | 
)"); 99 | } 100 | return fstorage_new_(size, dev_mask); 101 | } 102 | // create a new empty tensor container 103 | LuaRef NewTensorEmpty(int dev_mask = kCPU, int dtype = 0) { 104 | CHECK_EQ(dtype, 0) << "only float is supported so far"; 105 | if (ftensor_new_.is_nil()) { 106 | auto* lua = LuaState::ThreadLocalState(); 107 | ftensor_new_ = lua->Eval(R"( 108 | return 109 | function(dev_mask) 110 | if dev_mask == 1 then 111 | return torch.FloatTensor() 112 | else 113 | return torch.CudaTensor() 114 | end 115 | end 116 | )"); 117 | } 118 | return ftensor_new_(dev_mask); 119 | } 120 | // create a new tensor that shares space with src 121 | // The memory is managed by src. 122 | LuaRef NewTensorShared(TBlob src) { 123 | CHECK_EQ(src.dtype, 0) << "only float is supported so far"; 124 | if (ftensor_new_shared_.is_nil()) { 125 | auto* lua = LuaState::ThreadLocalState(); 126 | ftensor_new_shared_ = lua->Eval(R"( 127 | return 128 | function(ptr, shape, size, dev_mask) 129 | local sz = torch.LongStorage(shape) 130 | local storage 131 | if dev_mask == 1 then 132 | storage = torch.FloatStorage(size, ptr) 133 | return torch.FloatTensor(storage, 1, sz) 134 | else 135 | storage = torch.CudaStorage(size, ptr) 136 | return torch.CudaTensor(storage, 1, sz) 137 | end 138 | end 139 | )"); 140 | } 141 | return ftensor_new_shared_( 142 | reinterpret_cast(src.data), 143 | src.shape, src.shape.Size(), src.dev_mask); 144 | } 145 | // copy from one tensor to another one 146 | void CopyFromTo(LuaRef from, LuaRef to) { 147 | if (fcopy_from_to_.is_nil()) { 148 | auto* lua = LuaState::ThreadLocalState(); 149 | fcopy_from_to_ = lua->Eval(R"( 150 | return 151 | function(from, to) 152 | to:copy(from) 153 | end 154 | )"); 155 | } 156 | fcopy_from_to_(from, to); 157 | } 158 | // reset the storage of tensor to storage. 159 | void ResetStorage(LuaRef tensor, 160 | LuaRef storage, 161 | TShape shape) { 162 | if (ftensor_set_.is_nil()) { 163 | auto* lua = LuaState::ThreadLocalState(); 164 | ftensor_set_ = lua->Eval(R"( 165 | return 166 | function(tensor, storage, shape) 167 | sz = torch.LongStorage(shape) 168 | -- cutorch does not support pass size in set 169 | tensor:set(storage, 1) 170 | tensor:resize(sz) 171 | end 172 | )"); 173 | } 174 | ftensor_set_(tensor, storage, shape); 175 | } 176 | // Get the internal TBlob representation of 177 | // The tensor object must stay alive to keep the space valid. 178 | TBlob GetTBlob(LuaRef tensor) { 179 | if (fget_internal_.is_nil()) { 180 | auto* lua = LuaState::ThreadLocalState(); 181 | fget_internal_ = lua->Eval(R"( 182 | return 183 | function(tensor) 184 | local dev_mask 185 | t = tensor:type() 186 | if t == 'torch.FloatTensor' then 187 | dev_mask = 1 188 | elseif t == 'torch.CudaTensor' then 189 | dev_mask = 2 190 | else 191 | error('only float tensor is supported') 192 | end 193 | local data = tonumber(torch.data(tensor, true)) 194 | local shape = tensor:size():totable() 195 | return {data, shape, dev_mask} 196 | end 197 | )"); 198 | } 199 | LuaRef temp = fget_internal_(tensor); 200 | TBlob ret; 201 | ret.data = reinterpret_cast(temp[1].Get()); 202 | ret.shape = temp[2].Get(); 203 | ret.dev_mask = temp[3].Get(); 204 | return ret; 205 | } 206 | // return threadlocal state for torch. 
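// Typical call pattern, as a minimal sketch mirroring VarState::ResetSpace in
// src/session.cc (shape, dev_mask and dtype are whatever the caller holds):
//   TorchState* th = TorchState::ThreadLocalState();
//   LuaRef t = th->NewTensorEmpty(dev_mask, dtype);
//   th->ResetStorage(t, th->NewStorage(shape.Size(), dev_mask, dtype), shape);
//   TBlob blob = th->GetTBlob(t);  // raw view, valid only while `t` stays alive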
207 | static TorchState* ThreadLocalState() { 208 | return dmlc::ThreadLocalStore::Get(); 209 | } 210 | 211 | private: 212 | bool gpu_init_{false}; 213 | LuaRef fstorage_new_; 214 | LuaRef ftensor_new_; 215 | LuaRef ftensor_new_shared_; 216 | LuaRef ftensor_set_; 217 | LuaRef fcopy_from_to_; 218 | LuaRef fget_internal_; 219 | }; 220 | 221 | } // namespace tinyflow 222 | 223 | #endif // TINYFLOW_TORCH_UTIL_H_ 224 | -------------------------------------------------------------------------------- /src/torch/op_tensor_torch.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2016 by Contributors 2 | // implementation of common tensor operators 3 | #include 4 | 5 | namespace tinyflow { 6 | 7 | NNVM_REGISTER_OP(zeros) 8 | .set_attr( 9 | "FLuaCompute", R"( 10 | function(x, y, kwarg) 11 | return function() 12 | y[1]:fill(0) 13 | end 14 | end 15 | )"); 16 | 17 | 18 | NNVM_REGISTER_OP(zeros_like) 19 | .set_attr( 20 | "FLuaCompute", R"( 21 | function(x, y, kwarg) 22 | return function() 23 | y[1]:fill(0) 24 | end 25 | end 26 | )"); 27 | 28 | 29 | NNVM_REGISTER_OP(ones) 30 | .set_attr( 31 | "FLuaCompute", R"( 32 | function(x, y, kwarg) 33 | return function() 34 | y[1]:fill(1) 35 | end 36 | end 37 | )"); 38 | 39 | 40 | NNVM_REGISTER_OP(ones_like) 41 | .set_attr( 42 | "FLuaCompute", R"( 43 | function(x, y, kwarg) 44 | return function() 45 | y[1]:fill(1) 46 | end 47 | end 48 | )"); 49 | 50 | 51 | NNVM_REGISTER_OP(normal) 52 | .set_attr( 53 | "FLuaCompute", R"( 54 | function(x, y, kwarg) 55 | return function() 56 | local scale = 1 57 | if kwarg.stdev ~= nil then 58 | scale = tonumber(kwarg.stdev) 59 | end 60 | y[1]:copy(torch.randn(y[1]:size()) * scale) 61 | end 62 | end 63 | )"); 64 | 65 | 66 | NNVM_REGISTER_OP(equal) 67 | .set_attr( 68 | "FLuaCompute", R"( 69 | function(x, y, kwarg) 70 | return function() 71 | y[1]:copy(torch.eq(x[1], x[2])) 72 | end 73 | end 74 | )"); 75 | 76 | 77 | NNVM_REGISTER_OP(__ewise_sum__) 78 | .set_attr( 79 | "FLuaCompute", R"( 80 | function(x, y, kwarg) 81 | return function() 82 | y[1]:copy(x[1]) 83 | for i = 2, #x do 84 | torch.add(y[1], y[1], x[i]) 85 | end 86 | end 87 | end 88 | )"); 89 | 90 | 91 | NNVM_REGISTER_OP(__add_symbol__) 92 | .set_attr( 93 | "FLuaCompute", R"( 94 | function(x, y, kwarg) 95 | return function() 96 | torch.add(y[1], x[1], x[2]) 97 | end 98 | end 99 | )"); 100 | 101 | 102 | NNVM_REGISTER_OP(__add_scalar__) 103 | .set_attr( 104 | "FLuaCompute", R"( 105 | function(x, y, kwarg) 106 | local scalar = tonumber(kwarg.scalar) 107 | return function() 108 | torch.add(y[1], x[1], scalar) 109 | end 110 | end 111 | )"); 112 | 113 | 114 | NNVM_REGISTER_OP(__sub_symbol__) 115 | .set_attr( 116 | "FLuaCompute", R"( 117 | function(x, y, kwarg) 118 | return function() 119 | torch.add(y[1], x[1], -x[2]) 120 | end 121 | end 122 | )"); 123 | 124 | 125 | NNVM_REGISTER_OP(__sub_scalar__) 126 | .set_attr( 127 | "FLuaCompute", R"( 128 | function(x, y, kwarg) 129 | local scalar = tonumber(kwarg.scalar) 130 | return function() 131 | torch.add(y[1], x[1], -scalar) 132 | end 133 | end 134 | )"); 135 | 136 | 137 | NNVM_REGISTER_OP(__rsub_scalar__) 138 | .set_attr( 139 | "FLuaCompute", R"( 140 | function(x, y, kwarg) 141 | local scalar = tonumber(kwarg.scalar) 142 | return function() 143 | torch.add(y[1], -x[1], scalar) 144 | end 145 | end 146 | )"); 147 | 148 | 149 | NNVM_REGISTER_OP(__mul_symbol__) 150 | .set_attr( 151 | "FLuaCompute", R"( 152 | function(x, y, kwarg) 153 | return function() 154 | if x[1]:dim() == 1 and 
x[1]:size(1) == 1 then 155 | local scalar = x[1][1] 156 | torch.mul(y[1], x[2], scalar) 157 | return 158 | end 159 | if x[2]:dim() == 1 and x[2]:size(1) == 1 then 160 | local scalar = x[2][1] 161 | torch.mul(y[1], x[1], scalar) 162 | return 163 | end 164 | torch.cmul(y[1], x[1], x[2]) 165 | end 166 | end 167 | )"); 168 | 169 | 170 | NNVM_REGISTER_OP(__mul_scalar__) 171 | .set_attr( 172 | "FLuaCompute", R"( 173 | function(x, y, kwarg) 174 | local scalar = tonumber(kwarg.scalar) 175 | return function() 176 | torch.mul(y[1], x[1], scalar) 177 | end 178 | end 179 | )"); 180 | 181 | 182 | NNVM_REGISTER_OP(__div_symbol__) 183 | .set_attr( 184 | "FLuaCompute", R"( 185 | function(x, y, kwarg) 186 | return function() 187 | torch.cdiv(y[1], x[1], x[2]) 188 | end 189 | end 190 | )"); 191 | 192 | 193 | NNVM_REGISTER_OP(__div_scalar__) 194 | .set_attr( 195 | "FLuaCompute", R"( 196 | function(x, y, kwarg) 197 | local scalar = tonumber(kwarg.scalar) 198 | return function() 199 | torch.div(y[1], x[1], scalar) 200 | end 201 | end 202 | )"); 203 | 204 | 205 | NNVM_REGISTER_OP(exp) 206 | .set_attr( 207 | "FLuaCompute", R"( 208 | function(x, y, kwarg) 209 | return function() 210 | torch.exp(y[1], x[1]) 211 | end 212 | end 213 | )"); 214 | 215 | 216 | NNVM_REGISTER_OP(log) 217 | .set_attr( 218 | "FLuaCompute", R"( 219 | function(x, y, kwarg) 220 | return function() 221 | torch.log(y[1], x[1]) 222 | end 223 | end 224 | )"); 225 | 226 | 227 | NNVM_REGISTER_OP(sqrt) 228 | .set_attr( 229 | "FLuaCompute", R"( 230 | function(x, y, kwarg) 231 | return function() 232 | torch.sqrt(y[1], x[1]) 233 | end 234 | end 235 | )"); 236 | 237 | 238 | NNVM_REGISTER_OP(__pow_symbol__) 239 | .set_attr( 240 | "FLuaCompute", R"( 241 | function(x, y, kwarg) 242 | return function() 243 | torch.cpow(y[1], x[1], x[2]) 244 | end 245 | end 246 | )"); 247 | 248 | 249 | NNVM_REGISTER_OP(__rpow_scalar__) 250 | .set_attr( 251 | "FLuaCompute", R"( 252 | function(x, y, kwarg) 253 | local scalar = tonumber(kwarg.scalar) 254 | return function() 255 | torch.pow(y[1], scalar, x[1]) 256 | end 257 | end 258 | )"); 259 | 260 | 261 | NNVM_REGISTER_OP(matmul) 262 | .set_attr( 263 | "FLuaCompute", R"( 264 | function(x, y, kwarg) 265 | return function() 266 | torch.mm(y[1], x[1], x[2]) 267 | end 268 | end 269 | )"); 270 | 271 | 272 | // simply register a bulk op for backward 273 | NNVM_REGISTER_OP(_matmul_backward) 274 | .set_attr( 275 | "FLuaCompute", R"( 276 | function(x, y, kwarg) 277 | local gradOutput = x[1] 278 | local lhs = x[2] 279 | local rhs = x[3] 280 | local gradLhs = y[1] 281 | local gradRhs = y[2] 282 | return function() 283 | torch.mm(gradRhs, lhs:t(), gradOutput) 284 | torch.mm(gradLhs, gradOutput, rhs:t()) 285 | end 286 | end 287 | )"); 288 | 289 | 290 | NNVM_REGISTER_OP(reduce_sum) 291 | .set_attr( 292 | "FLuaCompute", R"( 293 | function(x, y, kwarg) 294 | local rhs = x[1] 295 | local lhs = y[1] 296 | if kwarg.reduction_indices == nil then 297 | rhs = rhs:view(rhs:nElement()) 298 | return function() 299 | torch.sum(lhs, rhs, 1) 300 | end 301 | else 302 | local axis = nn_parse_tuple(kwarg.reduction_indices) 303 | table.sort(axis) 304 | local k = #axis 305 | return function() 306 | for i = 1, (k - 1) do 307 | rhs = torch.sum(rhs, axis[k - i + 1] + 1) 308 | end 309 | torch.sum(lhs, rhs, axis[1] + 1) 310 | end 311 | end 312 | end 313 | )"); 314 | 315 | 316 | NNVM_REGISTER_OP(reduce_mean) 317 | .set_attr( 318 | "FLuaCompute", R"( 319 | function(x, y, kwarg) 320 | local rhs = x[1] 321 | local lhs = y[1] 322 | if kwarg.reduction_indices == nil 
then 323 | rhs = rhs:view(rhs:nElement()) 324 | return function() 325 | torch.mean(lhs, rhs, 1) 326 | end 327 | else 328 | local axis = nn_parse_tuple(kwarg.reduction_indices) 329 | table.sort(axis) 330 | local k = #axis 331 | return function() 332 | for i = 1, (k - 1) do 333 | rhs = torch.mean(rhs, axis[k - i + 1] + 1) 334 | end 335 | torch.mean(lhs, rhs, axis[1] + 1) 336 | end 337 | end 338 | end 339 | )"); 340 | 341 | 342 | NNVM_REGISTER_OP(_reduce_sum_backward) 343 | .set_attr( 344 | "FLuaCompute", R"( 345 | function(x, y, kwarg) 346 | local rhs = x[1] 347 | local lhs = y[1] 348 | if kwarg.reduction_indices == nil then 349 | lhs = lhs:view(lhs:nElement()) 350 | rhs = rhs:expandAs(lhs) 351 | else 352 | local axis = nn_parse_tuple(kwarg.reduction_indices) 353 | local vshape = lhs:size() 354 | for i = 1, #axis do 355 | vshape[axis[i] + 1] = 1 356 | end 357 | rhs = rhs:view(vshape):expandAs(lhs) 358 | end 359 | return function() 360 | lhs:copy(rhs) 361 | end 362 | end 363 | )"); 364 | 365 | 366 | NNVM_REGISTER_OP(_reduce_mean_backward) 367 | .set_attr( 368 | "FLuaCompute", R"( 369 | function(x, y, kwarg) 370 | local rhs = x[1] 371 | local lhs = y[1] 372 | local scale = 1 373 | if kwarg.reduction_indices == nil then 374 | lhs = lhs:view(lhs:nElement()) 375 | rhs = rhs:expandAs(lhs) 376 | scale = lhs:nElement() 377 | else 378 | local axis = nn_parse_tuple(kwarg.reduction_indices) 379 | local vshape = lhs:size() 380 | for i = 1, #axis do 381 | scale = scale * vshape[axis[i] + 1] 382 | vshape[axis[i] + 1] = 1 383 | end 384 | rhs = rhs:view(vshape):expandAs(lhs) 385 | end 386 | return function() 387 | torch.div(lhs, rhs, scale) 388 | end 389 | end 390 | )"); 391 | 392 | 393 | NNVM_REGISTER_OP(_argmax) 394 | .set_attr( 395 | "FLuaCompute", R"( 396 | function(x, y, kwarg) 397 | local rhs = x[1] 398 | local lhs = y[1] 399 | local axis = nn_parse_tuple(kwarg.reduction_indices) 400 | return function() 401 | local mx, ind = torch.max(rhs, axis[1] + 1) 402 | torch.add(lhs, ind:typeAs(lhs), -1) 403 | end 404 | end 405 | )"); 406 | 407 | } // namespace tinyflow 408 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2016 by contributors. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /src/op_nn.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2016 by Contributors 2 | // implementation of common nn operators 3 | #include 4 | #include 5 | #include "./op_util.h" 6 | 7 | namespace tinyflow { 8 | 9 | using namespace nnvm; 10 | 11 | // create a backward node 12 | inline std::vector MakeNNBackwardNode( 13 | const NodePtr& n, 14 | const std::vector& ograds) { 15 | static auto& backward_need_inputs = Op::GetAttr("TBackwardNeedInputs"); 16 | static auto& backward_need_outputs = Op::GetAttr("TBackwardNeedOutputs"); 17 | static auto& backward_num_nograd = Op::GetAttr("TBackwardNumNoGradInputs"); 18 | nnvm::NodePtr p = nnvm::Node::Create(); 19 | p->attrs.op = nnvm::Op::Get("_backward"); 20 | p->attrs.name = n->attrs.name + "_backward"; 21 | 22 | NNBackwardParam param; 23 | param.forward_readonly_inputs = static_cast(n->inputs.size()); 24 | param.need_inputs = backward_need_inputs[n->op()]; 25 | param.need_outputs = backward_need_outputs[n->op()]; 26 | param.num_no_grad_inputs = backward_num_nograd.get(n->op(), 0); 27 | CHECK_EQ(ograds.size(), 1); 28 | CHECK_EQ(param.forward_readonly_inputs + param.num_states, 29 | static_cast(n->inputs.size())); 30 | 31 | p->attrs.parsed = param; 32 | p->control_deps.emplace_back(n); 33 | // layout [output_grad, inputs, states, output] 34 | p->inputs.push_back(ograds[0]); 35 | if (param.need_inputs) { 36 | for (index_t i = 0; i < param.forward_readonly_inputs; ++i) { 37 | p->inputs.push_back(n->inputs[i]); 38 | } 39 | } 40 | for (index_t i = 0; i < param.num_states; ++i) { 41 | p->inputs.push_back(n->inputs[i + param.forward_readonly_inputs]); 42 | } 43 | if (param.need_outputs) { 44 | for (uint32_t i = 0; i < n->num_outputs(); ++i) { 45 | p->inputs.emplace_back(nnvm::NodeEntry{n, i, 0}); 46 | } 47 | } 48 | 49 | std::vector ret; 50 | for (index_t i = 0; i < param.forward_readonly_inputs; ++i) { 51 | ret.emplace_back(nnvm::NodeEntry{p, i, 0}); 52 | } 53 | if (param.num_states != 0 || param.num_no_grad_inputs != 0) { 54 | nnvm::NodePtr np = nnvm::Node::Create(); 55 | np->attrs.op = nnvm::Op::Get("_no_gradient"); 56 | for (uint32_t i = 0; i < param.num_no_grad_inputs; ++i) { 57 | ret.at(ret.size() - i - 1) = nnvm::NodeEntry{np, 0, 0}; 58 | } 59 | for (index_t i = 0; i < param.num_states; ++i) { 60 | ret.emplace_back(nnvm::NodeEntry{np, 0, 0}); 61 | } 62 | } 63 | return ret; 64 | } 65 | 66 | 67 | NNVM_REGISTER_OP(_backward) 68 | .describe("backward operator of NN module") 69 | .set_num_outputs([] (const NodeAttrs& attrs) { 70 | const NNBackwardParam& param = dmlc::get(attrs.parsed); 71 | return param.forward_readonly_inputs - param.num_no_grad_inputs; 72 | }) 73 | .set_num_inputs([] (const NodeAttrs& attrs) { 74 | const NNBackwardParam& param = dmlc::get(attrs.parsed); 75 | uint32_t n = param.num_states + 1; 76 | if (param.need_inputs) n += param.forward_readonly_inputs; 77 | if (param.need_outputs) n += 1; 78 | return n; 79 | }) 80 | .set_attr("TIsBackward", true); 81 | 82 | 83 | // common attributes for nn module. 
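// Ops opt in with .include("nn_module") (softmax, tanh, relu, linear, pad, conv2d,
// max_pool, avg_pool, batch_normalization) or .include("nn_criterion") (the softmax
// cross-entropy loss below); both groups attach MakeNNBackwardNode as their FGradient.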
84 | NNVM_REGISTER_OP_GROUP(nn_module) 85 | .set_attr("FGradient", MakeNNBackwardNode) 86 | .set_attr("TBackwardNeedInputs", true) 87 | .set_attr("TBackwardNeedOutputs", true); 88 | 89 | 90 | NNVM_REGISTER_OP_GROUP(nn_criterion) 91 | .set_attr("FGradient", MakeNNBackwardNode) 92 | .set_attr("TBackwardNumNoGradInputs", 1) 93 | .set_attr("TBackwardNeedInputs", true) 94 | .set_attr("TBackwardNeedOutputs", false) 95 | .set_attr("FInferShape", ScalarShape); 96 | 97 | 98 | NNVM_REGISTER_OP(softmax) 99 | .describe("Softmax operation") 100 | .set_num_inputs(1) 101 | .include("nn_module") 102 | .set_attr("FInferShape", SameShape); 103 | 104 | 105 | NNVM_REGISTER_OP(relu) 106 | .describe("Relu operation") 107 | .set_num_inputs(1) 108 | .include("nn_module") 109 | .set_attr("FInferShape", SameShape) 110 | .set_attr("TBackwardNeedOutputs", true); 111 | 112 | 113 | NNVM_REGISTER_OP(tanh) 114 | .describe("Tanh operation") 115 | .set_num_inputs(1) 116 | .include("nn_module") 117 | .set_attr("FInferShape", SameShape); 118 | 119 | 120 | // same as matrix multiplication, but automatically infers shape 121 | struct LinearParam : public dmlc::Parameter { 122 | uint32_t num_hidden; 123 | bool no_bias; 124 | 125 | DMLC_DECLARE_PARAMETER(LinearParam) { 126 | DMLC_DECLARE_FIELD(num_hidden).set_default(0); 127 | DMLC_DECLARE_FIELD(no_bias).set_default(true); 128 | } 129 | }; 130 | DMLC_REGISTER_PARAMETER(LinearParam); 131 | 132 | inline bool LinearShape(const NodeAttrs& attrs, 133 | std::vector *ishape, 134 | std::vector *oshape) { 135 | const auto& param = dmlc::get(attrs.parsed); 136 | if (ishape->at(0).ndim() == 0) return false; 137 | const TShape& in = ishape->at(0); 138 | TShape wshape; 139 | if (param.num_hidden != 0) { 140 | wshape = TShape{param.num_hidden, in[1]}; 141 | SHAPE_ASSIGN(ishape->at(1), wshape); 142 | } else { 143 | if (ishape->at(1).ndim() == 0) return false; 144 | } 145 | if (ishape->size() > 2) { 146 | TShape bshape{ishape->at(1)[0]}; 147 | SHAPE_ASSIGN(ishape->at(2), bshape); 148 | } 149 | TShape out{in[0], wshape[0]}; 150 | SHAPE_ASSIGN(oshape->at(0), out); 151 | return true; 152 | } 153 | 154 | NNVM_REGISTER_OP(linear) 155 | .describe("A linear transformation layer") 156 | .set_attr_parser(ParamParser) 157 | .set_num_inputs([](const NodeAttrs& attrs) { 158 | return (dmlc::get(attrs.parsed).no_bias? 
2 : 3); 159 | }) 160 | .set_attr("FListInputNames", [](const NodeAttrs& attrs) { 161 | if (dmlc::get(attrs.parsed).no_bias) { 162 | return std::vector{"data", "weight"}; 163 | } else { 164 | return std::vector{"data", "weight", "bias"}; 165 | } 166 | }) 167 | .include("nn_module") 168 | .set_attr("FInferShape", LinearShape); 169 | 170 | 171 | struct PadParam : public dmlc::Parameter { 172 | uint32_t dim; 173 | int pad; 174 | 175 | DMLC_DECLARE_PARAMETER(PadParam) { 176 | DMLC_DECLARE_FIELD(dim).set_default(0); 177 | DMLC_DECLARE_FIELD(pad).set_default(0); 178 | } 179 | }; 180 | DMLC_REGISTER_PARAMETER(PadParam); 181 | 182 | inline bool PadShape(const NodeAttrs& attrs, 183 | std::vector *ishape, 184 | std::vector *oshape) { 185 | const auto& param = dmlc::get(attrs.parsed); 186 | if (ishape->at(0).ndim() == 0) { 187 | return false; 188 | } 189 | TShape out = ishape->at(0); 190 | out[param.dim] += abs(param.pad); 191 | oshape->at(0) = out; 192 | return true; 193 | } 194 | 195 | NNVM_REGISTER_OP(pad) 196 | .describe("pads a tensor") 197 | .set_num_inputs(1) 198 | .include("nn_module") 199 | .set_attr_parser(ParamParser) 200 | .set_attr("FInferShape", PadShape); 201 | 202 | 203 | struct ConvPoolParam : public dmlc::Parameter { 204 | TShape ksize; 205 | TShape strides; 206 | std::string padding; 207 | std::string data_format; 208 | bool no_bias; 209 | uint32_t num_filter; 210 | 211 | DMLC_DECLARE_PARAMETER(ConvPoolParam) { 212 | DMLC_DECLARE_FIELD(ksize).set_default(TShape{1, 1, 1, 1}); 213 | DMLC_DECLARE_FIELD(strides).set_default(TShape{1, 1, 1, 1}); 214 | DMLC_DECLARE_FIELD(padding).set_default("SAME"); 215 | DMLC_DECLARE_FIELD(data_format).set_default("NCHW"); 216 | DMLC_DECLARE_FIELD(no_bias).set_default(true); 217 | DMLC_DECLARE_FIELD(num_filter).set_default(0); 218 | } 219 | }; 220 | DMLC_REGISTER_PARAMETER(ConvPoolParam); 221 | 222 | inline bool ConvPoolShape(const NodeAttrs& attrs, 223 | std::vector *ishape, 224 | std::vector *oshape) { 225 | const auto& param = dmlc::get(attrs.parsed); 226 | if (ishape->at(0).ndim() == 0) return false; 227 | const TShape& in = ishape->at(0); 228 | TShape filter; 229 | if (ishape->size() == 1) { 230 | // pooling 231 | CHECK_EQ(param.ksize.ndim(), 4); 232 | CHECK(param.ksize[0] == param.ksize[3] && param.ksize[0] == 1); 233 | filter = TShape{in[1], in[1], param.ksize[1], param.ksize[2]}; 234 | } else if (param.ksize.ndim() == 4 && param.num_filter != 0) { 235 | CHECK(param.ksize[0] == param.ksize[3] && param.ksize[0] == 1); 236 | filter = TShape{param.num_filter, in[1], param.ksize[1], param.ksize[2]}; 237 | SHAPE_ASSIGN(ishape->at(1), filter); 238 | } else { 239 | if (ishape->at(1).ndim() == 0) return false; 240 | filter = ishape->at(1); 241 | } 242 | if (ishape->size() > 2) { 243 | SHAPE_ASSIGN(ishape->at(2), TShape{filter[0]}); 244 | } 245 | CHECK(param.strides[0] == param.strides[3] && param.strides[0] == 1); 246 | uint32_t dH = param.strides[1]; 247 | uint32_t dW = param.strides[2]; 248 | uint32_t padW = 0, padH = 0; 249 | if (param.padding == "SAME") { 250 | padH = (filter[2] - 1) / 2; 251 | padW = (filter[3] - 1) / 2; 252 | } 253 | CHECK_EQ(in[1], filter[1]) 254 | << "in=" << in << ", filter=" << filter; 255 | // batch, out, height, width 256 | oshape->at(0) = TShape{in[0], filter[0], 257 | (in[2] + 2 * padH - filter[2]) / dH + 1, 258 | (in[3] + 2 * padW - filter[3]) / dW + 1}; 259 | return true; 260 | } 261 | 262 | NNVM_REGISTER_OP(conv2d) 263 | .describe("Convolution operation") 264 | .set_num_inputs([](const NodeAttrs& attrs){ 265 | return 
(dmlc::get(attrs.parsed).no_bias? 2 : 3); 266 | }) 267 | .set_attr_parser(ParamParser) 268 | .include("nn_module") 269 | .set_attr("FListInputNames", [](const NodeAttrs& attrs) { 270 | if (dmlc::get(attrs.parsed).no_bias) { 271 | return std::vector{"data", "weight"}; 272 | } else { 273 | return std::vector{"data", "weight", "bias"}; 274 | } 275 | }) 276 | .set_attr("FInferShape", ConvPoolShape) 277 | .set_attr("TBackwardNeedOutputs", false); 278 | 279 | 280 | NNVM_REGISTER_OP(max_pool) 281 | .describe("Max pooling") 282 | .set_num_inputs(1) 283 | .set_attr_parser(ParamParser) 284 | .include("nn_module") 285 | .set_attr("FInferShape", ConvPoolShape); 286 | 287 | 288 | NNVM_REGISTER_OP(avg_pool) 289 | .describe("Avg pooling") 290 | .set_num_inputs(1) 291 | .set_attr_parser(ParamParser) 292 | .include("nn_module") 293 | .set_attr("FInferShape", ConvPoolShape); 294 | 295 | 296 | struct BatchNormalizationParam : public dmlc::Parameter { 297 | std::string name; 298 | DMLC_DECLARE_PARAMETER(BatchNormalizationParam) { 299 | DMLC_DECLARE_FIELD(name).set_default("batch_normalization"); 300 | } 301 | }; 302 | DMLC_REGISTER_PARAMETER(BatchNormalizationParam); 303 | 304 | inline bool BatchNormalizationShape(const NodeAttrs& attrs, 305 | std::vector *ishape, 306 | std::vector *oshape) { 307 | if (ishape->at(0).ndim() == 0) return false; 308 | const TShape& in = ishape->at(0); 309 | CHECK_EQ(in.ndim(), 4); 310 | TShape mean = TShape{in[1]}; 311 | SHAPE_ASSIGN(ishape->at(1), mean); 312 | SHAPE_ASSIGN(ishape->at(2), mean); 313 | oshape->at(0) = in; 314 | return true; 315 | } 316 | 317 | NNVM_REGISTER_OP(batch_normalization) 318 | .describe("batch normalization") 319 | .set_num_inputs(3) 320 | .set_attr("FListInputNames", [](const NodeAttrs& attrs) { 321 | return std::vector{"data", "gamma", "beta"}; 322 | }) 323 | .set_attr_parser(ParamParser) 324 | .include("nn_module") 325 | .set_attr("FInferShape", BatchNormalizationShape); 326 | 327 | 328 | NNVM_REGISTER_OP(mean_sparse_softmax_cross_entropy_with_logits) 329 | .describe("Softmax cross entropy given logit and label") 330 | .set_num_inputs(2) 331 | .include("nn_criterion"); 332 | 333 | 334 | NNVM_REGISTER_OP(flatten_layer) 335 | .describe("Flatten to 2D") 336 | .set_num_inputs(1) 337 | .set_attr("FInplaceOption", InplaceIn0Out0) 338 | .set_attr( 339 | "FInferShape", [](const NodeAttrs& attrs, 340 | std::vector *ishape, 341 | std::vector *oshape) { 342 | const TShape& in = ishape->at(0); 343 | if (in.ndim() == 0) return false; 344 | TShape out{in[0], in.ProdShape(1, in.ndim())}; 345 | SHAPE_ASSIGN(oshape->at(0), out); 346 | return true; 347 | }) 348 | .set_attr( 349 | "FGradient", [](const NodePtr& n, 350 | const std::vector& ograds) { 351 | return MakeBackwardGrads("_flatten_backward", n, 352 | {ograds[0], n->inputs[0]}); 353 | }); 354 | 355 | NNVM_REGISTER_OP(_flatten_backward) 356 | .set_num_inputs(1) 357 | .set_attr("FInplaceOption", InplaceIn0Out0) 358 | .set_attr("TIsBackward", true); 359 | 360 | } // namespace tinyflow 361 | -------------------------------------------------------------------------------- /src/op_tensor.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2016 by Contributors 2 | // implementation of common tensor operators 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "./op_util.h" 9 | 10 | namespace tinyflow { 11 | 12 | // shape given the ZeroParam 13 | using namespace nnvm; 14 | 15 | // shape parameter for zeros, ones 16 | struct ZeroParam : 
public dmlc::Parameter { 17 | TShape shape; 18 | int dtype; 19 | DMLC_DECLARE_PARAMETER(ZeroParam) { 20 | DMLC_DECLARE_FIELD(shape).set_default(TShape()); 21 | DMLC_DECLARE_FIELD(dtype).set_default(kFloat32); 22 | } 23 | }; 24 | DMLC_REGISTER_PARAMETER(ZeroParam); 25 | 26 | inline bool ZeroShape(const NodeAttrs& attrs, 27 | std::vector *ishape, 28 | std::vector *oshape) { 29 | const TShape& ts = dmlc::get(attrs.parsed).shape; 30 | if (ts.ndim() != 0) { 31 | SHAPE_ASSIGN(oshape->at(0), ts); 32 | return true; 33 | } else { 34 | return false; 35 | } 36 | } 37 | 38 | inline bool ZeroType(const NodeAttrs& attrs, 39 | std::vector *iattr, 40 | std::vector *oattr) { 41 | int dtype = dmlc::get(attrs.parsed).dtype; 42 | DTYPE_ASSIGN(oattr->at(0), dtype); 43 | return true; 44 | } 45 | 46 | 47 | NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr) 48 | .set_attr("IsElementWise", true) 49 | .set_attr("FInferShape", SameShape); 50 | 51 | 52 | NNVM_REGISTER_OP(zeros) 53 | .describe("zeros") 54 | .set_num_inputs(0) 55 | .set_attr_parser(ParamParser) 56 | .set_attr("FInferShape", ZeroShape) 57 | .set_attr("FInferType", ZeroType); 58 | 59 | NNVM_REGISTER_OP(zeros_like) 60 | .describe("zeros_like") 61 | .set_num_inputs(1) 62 | .set_attr("FInferShape", SameShape); 63 | 64 | NNVM_REGISTER_OP(ones) 65 | .describe("ones") 66 | .set_num_inputs(0) 67 | .set_attr_parser(ParamParser) 68 | .set_attr("FInferShape", ZeroShape) 69 | .set_attr("FInferType", ZeroType); 70 | 71 | 72 | NNVM_REGISTER_OP(ones_like) 73 | .describe("ones_like") 74 | .set_num_inputs(1) 75 | .set_attr("FInferShape", SameShape); 76 | 77 | 78 | NNVM_REGISTER_OP(normal) 79 | .describe("normal distribution") 80 | .set_num_inputs(0) 81 | .set_attr_parser(ParamParser) 82 | .set_attr("FInferShape", ZeroShape) 83 | .set_attr("FInferType", ZeroType); 84 | 85 | 86 | NNVM_REGISTER_OP(equal) 87 | .describe("Equal comparitor") 88 | .set_num_inputs(2) 89 | .set_attr("FInferShape", SameShape); 90 | 91 | 92 | NNVM_REGISTER_OP(__ewise_sum__) 93 | .describe("ewise sum") 94 | .set_num_inputs(nnvm::kVarg) 95 | .set_attr("FInplaceOption", InplaceIn0Out0) 96 | .set_attr("FInferShape", SameShape) 97 | .set_attr( 98 | "FGradient", [](const NodePtr& n, 99 | const std::vector& ograds) { 100 | return std::vector(n->num_inputs(), ograds[0]); 101 | }); 102 | 103 | 104 | NNVM_REGISTER_OP(__add_symbol__) 105 | .describe("add two data together") 106 | .set_num_inputs(2) 107 | .include("ElementwiseOpAttr") 108 | .set_attr("FInplaceOption", InplaceIn0Out0) 109 | .set_attr( 110 | "FGradient", [](const NodePtr& n, 111 | const std::vector& ograds){ 112 | return std::vector{ograds[0], ograds[0]}; 113 | }); 114 | 115 | 116 | NNVM_REGISTER_OP(__add_scalar__) 117 | .describe("add symbol with scalar") 118 | .set_num_inputs(1) 119 | .include("ElementwiseOpAttr") 120 | .set_attr("FInplaceOption", InplaceIn0Out0) 121 | .set_attr( 122 | "FGradient", [](const NodePtr& n, 123 | const std::vector& ograds){ 124 | return std::vector{ograds[0]}; 125 | }); 126 | 127 | 128 | NNVM_REGISTER_OP(__sub_symbol__) 129 | .describe("do subtract") 130 | .set_num_inputs(2) 131 | .include("ElementwiseOpAttr") 132 | .set_attr("FInplaceOption", InplaceIn0Out0) 133 | .set_attr( 134 | "FGradient", [](const NodePtr& n, 135 | const std::vector& ograds){ 136 | return std::vector{ 137 | MakeNode("__mul_scalar__", n->attrs.name + "_grad_0", 138 | {ograds[0]}, {{"scalar", "1"}}), 139 | MakeNode("__mul_scalar__", n->attrs.name + "_grad_1", 140 | {ograds[0]}, {{"scalar", "-1"}}), 141 | }; 142 | }); 143 | 144 | 145 | 
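// Gradient reading for the subtraction above: for z = a - b, dz/da = +1 and dz/db = -1,
// so the two __mul_scalar__ nodes return ograd * 1 and ograd * -1 respectively.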
NNVM_REGISTER_OP(__sub_scalar__) 146 | .describe("subtract symbol with scalar") 147 | .set_num_inputs(1) 148 | .include("ElementwiseOpAttr") 149 | .set_attr("FInplaceOption", InplaceIn0Out0) 150 | .set_attr( 151 | "FGradient", [](const NodePtr& n, 152 | const std::vector& ograds){ 153 | return std::vector{ograds[0]}; 154 | }); 155 | 156 | 157 | NNVM_REGISTER_OP(__rsub_scalar__) 158 | .describe("subtract scalar with symbol") 159 | .set_num_inputs(1) 160 | .include("ElementwiseOpAttr") 161 | .set_attr( 162 | "FGradient", [](const NodePtr& n, 163 | const std::vector& ograds){ 164 | return std::vector{ 165 | MakeNode("__mul_scalar__", n->attrs.name + "_grad_1", 166 | {ograds[0]}, {{"scalar", "-1"}}), 167 | }; 168 | }); 169 | 170 | 171 | NNVM_REGISTER_OP(mul) 172 | .add_alias("__mul_symbol__") 173 | .describe("add two data together") 174 | .set_num_inputs(2) 175 | .include("ElementwiseOpAttr") 176 | .set_attr("FInplaceOption", InplaceIn0Out0) 177 | .set_attr( 178 | "FGradient", [](const NodePtr& n, 179 | const std::vector& ograds){ 180 | return std::vector{ 181 | MakeNode("mul", n->attrs.name + "_grad_0", 182 | {ograds[0], n->inputs[1]}), 183 | MakeNode("mul", n->attrs.name + "_grad_1", 184 | {ograds[0], n->inputs[0]}) 185 | }; 186 | }); 187 | 188 | 189 | NNVM_REGISTER_OP(__mul_scalar__) 190 | .describe("Multiply symbol with scalar") 191 | .set_num_inputs(1) 192 | .include("ElementwiseOpAttr") 193 | .set_attr("FInplaceOption", InplaceIn0Out0) 194 | .set_attr( 195 | "FGradient", [](const NodePtr& n, 196 | const std::vector& ograds){ 197 | return std::vector{ 198 | MakeNode("__mul_scalar__", n->attrs.name + "_grad_0", 199 | {ograds[0]}, {{"scalar", n->attrs.dict["scalar"]}}), 200 | }; 201 | }); 202 | 203 | 204 | NNVM_REGISTER_OP(__div_symbol__) 205 | .add_alias("div") 206 | .describe("do division") 207 | .set_num_inputs(2) 208 | .include("ElementwiseOpAttr") 209 | .set_attr("FInplaceOption", InplaceIn0Out0) 210 | .set_attr( 211 | "FGradient", [](const NodePtr& n, 212 | const std::vector& ograds){ 213 | NodeEntry n1 = MakeNode("mul", n->attrs.name + "_grad_sub_0", 214 | {ograds[0], n->inputs[0]}); 215 | NodeEntry n2 = MakeNode("__mul_scalar__", n->attrs.name + "_grad_sub_1", 216 | {n1}, {{"scalar", "-1"}}); 217 | NodeEntry n3 = MakeNode("mul", n->attrs.name + "_grad_sub_2", 218 | {n->inputs[1], n->inputs[1]}); 219 | return std::vector{ 220 | MakeNode("__div_symbol__", n->attrs.name + "_grad_0", 221 | {ograds[0], n->inputs[1]}), 222 | MakeNode("__div_symbol__", n->attrs.name + "_grad_1", 223 | {n1, n2}) 224 | }; 225 | }); 226 | 227 | 228 | NNVM_REGISTER_OP(__div_scalar__) 229 | .describe("division symbol with scalar") 230 | .set_num_inputs(1) 231 | .include("ElementwiseOpAttr") 232 | .set_attr("FInplaceOption", InplaceIn0Out0) 233 | .set_attr( 234 | "FGradient", [](const NodePtr& n, 235 | const std::vector& ograds){ 236 | return std::vector{ 237 | MakeNode("__div_scalar__", n->attrs.name + "_grad_0", 238 | {ograds[0]}, {{"scalar", n->attrs.dict["scalar"]}}), 239 | }; 240 | }); 241 | 242 | 243 | NNVM_REGISTER_OP(exp) 244 | .describe("take elemtnwise exponation") 245 | .set_num_inputs(1) 246 | .include("ElementwiseOpAttr") 247 | .set_attr("FInplaceOption", InplaceIn0Out0) 248 | .set_attr( 249 | "FGradient", [](const NodePtr& n, 250 | const std::vector& ograds) { 251 | return std::vector{ 252 | MakeNode("__mul_symbol__", n->attrs.name + "_grad_0", 253 | {ograds[0], NodeEntry{n, 0, 0}}) 254 | }; 255 | }); 256 | 257 | 258 | NNVM_REGISTER_OP(log) 259 | .describe("take elemtnwise logarithm") 260 | 
.set_num_inputs(1) 261 | .include("ElementwiseOpAttr") 262 | .set_attr("FInplaceOption", InplaceIn0Out0) 263 | .set_attr( 264 | "FGradient", [](const NodePtr& n, 265 | const std::vector& ograds) { 266 | return std::vector{ 267 | MakeNode("__div_symbol__", n->attrs.name + "_grad_0", 268 | {ograds[0], n->inputs[0]}) 269 | }; 270 | }); 271 | 272 | 273 | NNVM_REGISTER_OP(sqrt) 274 | .describe("return square root of input") 275 | .set_num_inputs(1) 276 | .include("ElementwiseOpAttr") 277 | .set_attr("FInplaceOption", InplaceIn0Out0) 278 | .set_attr( 279 | // 1 / (2 * sqrt(x)) == 1 / (2 * y) 280 | "FGradient", [](const NodePtr& n, 281 | const std::vector& ograds) { 282 | NodeEntry n1 = MakeNode("__mul_scalar__", n->attrs.name + "_grad_sub_1", 283 | {NodeEntry{n, 0, 0}}, {{"scalar", "2"}}); 284 | return std::vector{ 285 | MakeNode("__div_symbol__", n->attrs.name + "_grad_0", 286 | {ograds[0], n1}) 287 | }; 288 | }); 289 | 290 | 291 | NNVM_REGISTER_OP(__pow_symbol__) 292 | .add_alias("pow") 293 | .describe("take elmtnwise power between two tensor") 294 | .set_num_inputs(2) 295 | .include("ElementwiseOpAttr") 296 | .set_attr("FInplaceOption", InplaceIn0Out0) 297 | .set_attr( 298 | "FGradient", [](const NodePtr& n, 299 | const std::vector& ograds) { 300 | // lhs: b*pow(a, b-1), rhs: pow(a, b)*ln(a) 301 | NodeEntry n0 = MakeNode("__add_scalar__", n->attrs.name + "_grad_sub_0", 302 | {n->inputs[1]}, {{"scalar", "-1"}}); 303 | NodeEntry n1 = MakeNode("pow", n->attrs.name + "_grad_sub_1", 304 | {n->inputs[0], n0}); 305 | NodeEntry d_lhs = MakeNode("mul", n->attrs.name + "_grad_sub_2", 306 | {n1, n->inputs[1]}); 307 | NodeEntry n2 = MakeNode("log", n->attrs.name + "_grad_sub_3", 308 | {n->inputs[0]}); 309 | NodeEntry d_rhs = MakeNode("mul", n->attrs.name + "_grad_sub_4", 310 | {NodeEntry{n, 0, 0}, n2}); 311 | return std::vector{ 312 | MakeNode("__mul_symbol__", n->attrs.name + "_grad_0", 313 | {ograds[0], d_lhs}), 314 | MakeNode("__mul_symbol__", n->attrs.name + "_grad_1", 315 | {ograds[0], d_rhs}) 316 | }; 317 | 318 | }); 319 | 320 | 321 | NNVM_REGISTER_OP(__rpow_scalar__) 322 | .describe("take elmtnwise power between a number and a tensor") 323 | .set_num_inputs(1) 324 | .include("ElementwiseOpAttr") 325 | .set_attr("FInplaceOption", InplaceIn0Out0) 326 | .set_attr( 327 | "FGradient", [](const NodePtr& n, 328 | const std::vector& ograds) { 329 | // pow(m, x) * ln(m) 330 | double num = std::stod(n->attrs.dict["scalar"]); 331 | NodeEntry n0 = MakeNode("__mul_scalar__", n->attrs.name + "_grad_sub_4", 332 | {NodeEntry{n, 0, 0}}, {{"scalar", std::to_string(std::log(num))}}); 333 | return std::vector{ 334 | MakeNode("__mul_symbol__", n->attrs.name + "_grad_0", 335 | {ograds[0], n0}) 336 | }; 337 | }); 338 | 339 | 340 | NNVM_REGISTER_OP(matmul) 341 | .describe("Matrix multiplication") 342 | .set_num_inputs(2) 343 | .set_attr( 344 | "FInferShape", [](const NodeAttrs& attrs, 345 | std::vector *ishape, 346 | std::vector *oshape) { 347 | if (ishape->at(0).ndim() == 0) return false; 348 | if (ishape->at(1).ndim() == 0) return false; 349 | CHECK_EQ(ishape->at(0).ndim(), 2); 350 | CHECK_EQ(ishape->at(1).ndim(), 2); 351 | CHECK_EQ(ishape->at(0)[1], ishape->at(1)[0]); 352 | TShape target{ishape->at(0)[0], ishape->at(1)[1]}; 353 | SHAPE_ASSIGN(oshape->at(0), target); 354 | return true; 355 | }) 356 | .set_attr( 357 | "FGradient", [](const NodePtr& n, 358 | const std::vector& ograds) { 359 | return MakeBackwardGrads("_matmul_backward", n, 360 | {ograds[0], n->inputs[0], n->inputs[1]}); 361 | }); 362 | 363 | 364 | // 
simply register a bulk op for backward 365 | NNVM_REGISTER_OP(_matmul_backward) 366 | .set_num_inputs(3) 367 | .set_num_outputs(2) 368 | .set_attr("TIsBackward", true); 369 | 370 | struct ReduceParam : public dmlc::Parameter { 371 | Tuple reduction_indices; 372 | DMLC_DECLARE_PARAMETER(ReduceParam) { 373 | DMLC_DECLARE_FIELD(reduction_indices).set_default(Tuple()); 374 | } 375 | }; 376 | DMLC_REGISTER_PARAMETER(ReduceParam); 377 | 378 | 379 | inline bool ReduceShape(const NodeAttrs& attrs, 380 | std::vector *ishape, 381 | std::vector *oshape) { 382 | const auto& axis 383 | = dmlc::get(attrs.parsed).reduction_indices; 384 | if (ishape->at(0).ndim() == 0) return false; 385 | if (axis.ndim() == 0) { 386 | SHAPE_ASSIGN(oshape->at(0), TShape{1}); 387 | } else { 388 | TShape tmp = ishape->at(0); 389 | for (uint32_t idx : axis) { 390 | tmp[idx] = 0; 391 | } 392 | std::vector ret; 393 | for (uint32_t x : tmp) { 394 | if (x != 0) ret.push_back(x); 395 | } 396 | if (ret.size() == 0) ret.push_back(1); 397 | SHAPE_ASSIGN(oshape->at(0), TShape(ret.begin(), ret.end())); 398 | } 399 | return true; 400 | } 401 | 402 | 403 | NNVM_REGISTER_OP(reduce_sum) 404 | .describe("reduce sum") 405 | .set_attr_parser(ParamParser) 406 | .set_num_inputs(1) 407 | .set_attr("FInferShape", ReduceShape) 408 | .set_attr( 409 | "FGradient", [](const NodePtr& n, 410 | const std::vector& ograds) { 411 | return MakeBackwardGrads("_reduce_sum_backward", n, 412 | {ograds[0]}, n->attrs.dict); 413 | }); 414 | 415 | 416 | NNVM_REGISTER_OP(reduce_mean) 417 | .describe("reduce mean") 418 | .set_attr_parser(ParamParser) 419 | .set_num_inputs(1) 420 | .set_attr("FInferShape", ReduceShape) 421 | .set_attr( 422 | "FGradient", [](const NodePtr& n, 423 | const std::vector& ograds) { 424 | return MakeBackwardGrads("_reduce_mean_backward", n, 425 | {ograds[0]}, n->attrs.dict); 426 | }); 427 | 428 | 429 | NNVM_REGISTER_OP_GROUP(ReduceBackwardIndeAttr) 430 | .set_attr("TIsBackward", true); 431 | 432 | 433 | NNVM_REGISTER_OP(_reduce_sum_backward) 434 | .set_num_inputs(1) 435 | .set_num_outputs(1) 436 | .include("ReduceBackwardIndeAttr"); 437 | 438 | 439 | NNVM_REGISTER_OP(_reduce_mean_backward) 440 | .set_num_inputs(1) 441 | .set_num_outputs(1) 442 | .include("ReduceBackwardIndeAttr"); 443 | 444 | 445 | NNVM_REGISTER_OP(_argmax) 446 | .set_attr_parser(ParamParser) 447 | .set_num_inputs(1) 448 | .set_attr("FInferShape", ReduceShape); 449 | 450 | } // namespace tinyflow 451 | -------------------------------------------------------------------------------- /src/session.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2016 by Contributors 2 | #include 3 | #include 4 | #if TINYFLOW_USE_FUSION == 1 5 | #include 6 | #include 7 | #endif 8 | #include 9 | #include 10 | #include "./op_util.h" 11 | #include "./torch/torch_util.h" 12 | 13 | namespace tinyflow { 14 | 15 | using dmlc::any; 16 | using nnvm::Graph; 17 | using nnvm::IndexedGraph; 18 | using nnvm::ShapeVector; 19 | using nnvm::DTypeVector; 20 | using nnvm::StorageVector; 21 | #if TINYFLOW_USE_FUSION == 1 22 | using nnvm::fusion::RTC; 23 | using nnvm::fusion::RTCMap; 24 | #endif 25 | 26 | class TorchExecutor; 27 | 28 | /*! \brief shared variable */ 29 | struct VarState { 30 | /*! \brief The internal internal tensor */ 31 | LuaRef tensor; 32 | /*! \brief The corresponding tblob */ 33 | TBlob blob; 34 | 35 | /*! 
\return Whether the tensor is initialized already */ 36 | inline bool initialized() const { 37 | return !tensor.is_nil(); 38 | } 39 | // reset the space. 40 | inline void ResetSpace(TShape shape, int dev_mask = kCPU, int dtype = 0) { 41 | if (tensor.is_nil() || 42 | shape != blob.shape || 43 | dev_mask != blob.dev_mask || 44 | dtype != blob.dtype) { 45 | TorchState* th = TorchState::ThreadLocalState(); 46 | if (tensor.is_nil()) { 47 | tensor = th->NewTensorEmpty(dev_mask, dtype); 48 | } 49 | th->ResetStorage( 50 | tensor, th->NewStorage(shape.Size(), dev_mask, dtype), shape); 51 | this->blob = th->GetTBlob(tensor); 52 | } 53 | } 54 | }; 55 | 56 | 57 | // shared variable map structure 58 | using VarStateMap = std::unordered_map >; 59 | // operator executor closures 60 | using FOpExec = std::function; 61 | 62 | // torch session. 63 | class TorchSession : public Session { 64 | public: 65 | // simple session that binds to one device. 66 | explicit TorchSession(const std::string& config) { 67 | if (config.find("gpu") != std::string::npos) { 68 | default_dev_mask_ = kGPU; 69 | if (config.find("fusion") != std::string::npos) { 70 | enable_fusion_ = true; 71 | } 72 | } 73 | } 74 | const std::vector& 75 | Run(nnvm::Symbol* sym, 76 | const std::unordered_map& inputs) override; 77 | 78 | private: 79 | // entry to store cached executor 80 | struct ExecEntry { 81 | nnvm::Symbol cached_symbol; 82 | std::shared_ptr exec; 83 | size_t use_count{0}; 84 | }; 85 | int default_dev_mask_{kCPU}; 86 | bool enable_fusion_{false}; 87 | // local cached variable states. 88 | VarStateMap states_; 89 | // cached executor 90 | std::unordered_map cached_execs_; 91 | }; 92 | 93 | 94 | class TorchExecutor { 95 | public: 96 | // initialize the executor 97 | // possibly update the states. 98 | void Init(nnvm::Symbol symbol, VarStateMap* states, int default_dev_mask, bool enable_fusion); 99 | /// run the executor, return the outputs. 100 | const std::vector& Run(const std::unordered_map& inputs); 101 | // return corresponding internal symbol 102 | inline const nnvm::Symbol& symbol() const { 103 | return symbol_; 104 | } 105 | 106 | private: 107 | // setup the executor space. 108 | void SetupAuxiliaryMembers(); 109 | void ClearAuxiliaryMembers(); 110 | void Setup(const std::unordered_map& inputs); 111 | void SetupShapeDType(const std::unordered_map& inputs, bool* need_redo_infer); 112 | void SetupStorage(); 113 | void SetupOpExecs(); 114 | #if TINYFLOW_USE_FUSION == 1 115 | FOpExec GenerateRTCClosure(RTC& rtc, 116 | const std::vector& input_luaref, std::vector& output_luaref); 117 | #endif 118 | // internal symbol and graph 119 | nnvm::Symbol symbol_; 120 | nnvm::Graph graph_; 121 | // variable states map. 122 | VarStateMap* var_states_; 123 | // shape vector in graph attribute 124 | const ShapeVector* node_shape_{nullptr}; 125 | // type vector in graph attribute 126 | const DTypeVector* node_dtype_{nullptr}; 127 | #if TINYFLOW_USE_FUSION == 1 128 | // map nid->rtc 129 | RTCMap* node_rtc_{nullptr}; 130 | #endif 131 | // ---------------------------- 132 | // node auxiliary data structures 133 | // The device of this executor 134 | int dev_mask_{kGPU}; 135 | // whether to enable fusion 136 | bool enable_fusion_; 137 | // node id of place holder ops 138 | std::vector placeholder_nids_; 139 | // size of number of node, placeholder_tblobs_[nid].data != nullptr 140 | // if nid is a placeholder and the content is the corresponding TBlob to be copied in. 
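// (Setup() refreshes these entries from the feed_dict on every Run; slots for
// non-placeholder nodes keep a default TBlob whose data pointer stays nullptr.)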
141 | std::vector placeholder_tblobs_; 142 | // node id of variable that is assigned in this executor 143 | std::vector assign_var_nids_; 144 | // node id of variable that is readed by this executor 145 | // can overlap with assign_var_nids_ 146 | std::vector read_var_nids_; 147 | // vector maps nid->state, nullptr for non variables. 148 | std::vector node_states_; 149 | // ---------------------------- 150 | // execution information 151 | // data of each outputs 152 | std::vector data_entry_; 153 | // whether data entry is variable. 154 | std::vector data_entry_is_var_; 155 | // internal storage space. 156 | std::vector storage_pool_; 157 | // operator executor closures 158 | std::vector op_execs_; 159 | // lua module states of each operator. 160 | std::vector op_exec_modules_; 161 | // The storage space to hold outputs. 162 | std::vector outputs_; 163 | std::vector output_blobs_; 164 | }; 165 | 166 | Session* Session::Create(const std::string& option) { 167 | return new TorchSession(option); 168 | } 169 | 170 | const std::vector& TorchSession::Run( 171 | nnvm::Symbol* new_sym, 172 | const std::unordered_map& inputs) { 173 | // compute the hash value 174 | uint64_t hash_value = new_sym->outputs.size(); 175 | for (NodeEntry& e : new_sym->outputs) { 176 | uint64_t value = reinterpret_cast(e.node.get()); 177 | hash_value ^= value + 0x9e3779b9 + (hash_value << 6) + (hash_value >> 2); 178 | } 179 | if (cached_execs_.count(hash_value) != 0) { 180 | auto& entry = cached_execs_.at(hash_value); 181 | const nnvm::Symbol& old_sym = entry.cached_symbol; 182 | bool stale_exec = (old_sym.outputs.size() != new_sym->outputs.size()); 183 | if (!stale_exec) { 184 | for (size_t i = 0; i < old_sym.outputs.size(); ++i) { 185 | if (old_sym.outputs[i].node.get() != new_sym->outputs[i].node.get() || 186 | old_sym.outputs[i].index != new_sym->outputs[i].index || 187 | old_sym.outputs[i].version != new_sym->outputs[i].version) { 188 | stale_exec = true; break; 189 | } 190 | } 191 | } 192 | if (!stale_exec) { 193 | ++entry.use_count; 194 | return entry.exec->Run(inputs); 195 | } else { 196 | cached_execs_.erase(hash_value); 197 | } 198 | } 199 | // dump technique, remove all previous executors 200 | // better strategy, LRU? 201 | cached_execs_.clear(); 202 | ExecEntry e; 203 | e.cached_symbol = *new_sym; 204 | e.exec = std::make_shared(); 205 | e.exec->Init(*new_sym, &states_, default_dev_mask_, enable_fusion_); 206 | cached_execs_[hash_value] = e; 207 | return e.exec->Run(inputs); 208 | } 209 | 210 | void TorchExecutor::Init(nnvm::Symbol symbol, 211 | VarStateMap* states, 212 | int default_dev_mask, 213 | bool enable_fusion) { 214 | dev_mask_ = default_dev_mask; 215 | if (dev_mask_ == kGPU) TorchState::ThreadLocalState()->InitGPU(); 216 | enable_fusion_ = enable_fusion; 217 | graph_.outputs = symbol.outputs; 218 | symbol_.outputs = graph_.outputs; 219 | var_states_ = states; 220 | SetupAuxiliaryMembers(); 221 | } 222 | 223 | void TorchExecutor::SetupAuxiliaryMembers() { 224 | // initialize all node auxiliary data structures. 
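// The reverse walk below counts, for every node, how often it is read and how
// often it is the target of an assign; because consumers carry larger node ids,
// a variable's counters are final by the time the variable node itself is
// visited, so it can be classified as a read and/or assign variable right away.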
225 | const Op* assign_op = Op::Get("assign"); 226 | const Op* placeholder_op = Op::Get("placeholder"); 227 | const auto& idx = graph_.indexed_graph(); 228 | node_states_.resize(idx.num_nodes(), nullptr); 229 | 230 | std::vector read_count(idx.num_nodes(), 0); 231 | std::vector assign_count(idx.num_nodes(), 0); 232 | placeholder_tblobs_.resize(idx.num_nodes()); 233 | 234 | for (uint32_t i = idx.num_nodes(); i != 0; --i) { 235 | uint32_t nid = i - 1; 236 | auto& inode = idx[nid]; 237 | if (inode.source->is_variable()) { 238 | const std::string& key = inode.source->attrs.name; 239 | if (var_states_->count(key) == 0) { 240 | (*var_states_)[key] = std::make_shared(); 241 | } 242 | node_states_[nid] = var_states_->at(key).get(); 243 | if (read_count[nid] != 0 || assign_count[nid] == 0) { 244 | read_var_nids_.push_back(nid); 245 | } 246 | if (assign_count[nid] != 0) { 247 | assign_var_nids_.push_back(nid); 248 | } 249 | } else { 250 | if (inode.source->op() == placeholder_op) { 251 | placeholder_nids_.push_back(nid); 252 | } else if (inode.source->op() == assign_op) { 253 | CHECK_EQ(inode.inputs.size(), 2); 254 | ++read_count[inode.inputs[1].node_id]; 255 | ++assign_count[inode.inputs[0].node_id]; 256 | } else { 257 | for (auto e : inode.inputs) { 258 | ++read_count[e.node_id]; 259 | } 260 | } 261 | } 262 | } 263 | } 264 | 265 | void TorchExecutor::ClearAuxiliaryMembers() { 266 | placeholder_nids_.clear(); 267 | placeholder_tblobs_.clear(); 268 | assign_var_nids_.clear(); 269 | read_var_nids_.clear(); 270 | node_states_.clear(); 271 | } 272 | 273 | const std::vector& 274 | TorchExecutor::Run(const std::unordered_map& inputs) { 275 | Setup(inputs); 276 | { 277 | // execution 278 | const auto& idx = graph_.indexed_graph(); 279 | auto* th = TorchState::ThreadLocalState(); 280 | for (size_t i = 0; i < op_execs_.size(); ++i) { 281 | // copy in place holder as demanded. 282 | if (placeholder_tblobs_[i].data != nullptr) { 283 | th->CopyFromTo(th->NewTensorShared(placeholder_tblobs_[i]), 284 | data_entry_[idx.entry_id(i, 0)]); 285 | } 286 | try { 287 | // TODO op_execs_[i].nil()? 
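// op_execs_[i] stays an empty std::function for variable nodes
// (SetupOpExecs skips them), so the guard below simply skips those slots.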
288 | if (op_execs_[i]) { 289 | op_execs_[i](); 290 | } 291 | } catch (dmlc::Error& e) { 292 | LOG(INFO) << "error caught in op " << idx[i].source->op()->name; 293 | throw e; 294 | } 295 | } 296 | } 297 | { 298 | // copy outputs 299 | output_blobs_.clear(); 300 | auto* th = TorchState::ThreadLocalState(); 301 | const auto& idx = graph_.indexed_graph(); 302 | for (size_t i = 0; i < outputs_.size(); ++i) { 303 | uint32_t eid = idx.entry_id(idx.outputs()[i]); 304 | th->CopyFromTo(data_entry_[eid], outputs_[i]); 305 | output_blobs_.push_back(th->GetTBlob(outputs_[i])); 306 | } 307 | } 308 | return output_blobs_; 309 | } 310 | 311 | void TorchExecutor::Setup(const std::unordered_map& inputs) { 312 | bool need_redo_infer; 313 | SetupShapeDType(inputs, &need_redo_infer); 314 | #if TINYFLOW_USE_FUSION == 1 315 | if (enable_fusion_ && need_redo_infer) { 316 | graph_ = ApplyPasses(std::move(graph_), {"Fusion", "CodeGen", "RTCGen"}); 317 | node_rtc_ = const_cast(&(graph_.GetAttr("rtc"))); 318 | ClearAuxiliaryMembers(); 319 | SetupAuxiliaryMembers(); 320 | 321 | node_shape_ = nullptr; 322 | node_dtype_ = nullptr; 323 | SetupShapeDType(inputs, &need_redo_infer); 324 | } 325 | #endif 326 | if (need_redo_infer) SetupStorage(); 327 | if (need_redo_infer) { 328 | op_execs_.clear(); 329 | op_exec_modules_.clear(); 330 | SetupOpExecs(); 331 | } 332 | { 333 | // copy inputs 334 | const auto& idx = graph_.indexed_graph(); 335 | for (uint32_t nid : placeholder_nids_) { 336 | const std::string& key = idx[nid].source->attrs.name; 337 | const TBlob& value = inputs.at(key); 338 | placeholder_tblobs_[nid] = value; 339 | } 340 | } 341 | } 342 | 343 | void TorchExecutor::SetupShapeDType( 344 | const std::unordered_map& inputs, 345 | bool* p_need_redo_infer) { 346 | const auto& idx = graph_.indexed_graph(); 347 | bool& need_redo_infer = *p_need_redo_infer; 348 | need_redo_infer = (node_shape_ == nullptr); 349 | 350 | // check the variable states 351 | if (!need_redo_infer) { 352 | CHECK(node_dtype_ != nullptr); 353 | for (uint32_t nid : read_var_nids_) { 354 | VarState* state = node_states_[nid]; 355 | CHECK(state != nullptr); 356 | CHECK(state->initialized()) 357 | << "Attempt to execute a graph with an uninitialized Variable"; 358 | if (node_shape_->at(idx.entry_id(nid, 0)) != state->blob.shape) { 359 | need_redo_infer = true; break; 360 | } 361 | if (node_dtype_->at(idx.entry_id(nid, 0)) != state->blob.dtype) { 362 | need_redo_infer = true; break; 363 | } 364 | } 365 | } 366 | // check placeholder shapes. 367 | if (!need_redo_infer) { 368 | for (uint32_t nid : placeholder_nids_) { 369 | const std::string& key = idx[nid].source->attrs.name; 370 | CHECK(inputs.count(key)) 371 | << "Not enough placeholder arguments in feed_dict"; 372 | const TBlob& value = inputs.at(key); 373 | if (node_shape_->at(idx.entry_id(nid, 0)) != value.shape) { 374 | need_redo_infer = true; break; 375 | } 376 | if (node_dtype_->at(idx.entry_id(nid, 0)) != value.dtype) { 377 | need_redo_infer = true; break; 378 | } 379 | } 380 | } 381 | 382 | if (!need_redo_infer) return; 383 | // run shape inference.
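// Seed the per-entry shape/dtype vectors from already-initialized Variables and
// the placeholder feeds, then let the InferShape/InferType passes fill in the
// rest and fail loudly if anything remains unknown.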
384 | ShapeVector new_shape(idx.num_node_entries(), TShape()); 385 | DTypeVector new_dtype(idx.num_node_entries(), -1); 386 | 387 | for (uint32_t nid : read_var_nids_) { 388 | VarState* state = node_states_[nid]; 389 | // TODO more strict rule 390 | if (state->initialized()) { 391 | new_shape[idx.entry_id(nid, 0)] = state->blob.shape; 392 | new_dtype[idx.entry_id(nid, 0)] = state->blob.dtype; 393 | } else if (std::find(assign_var_nids_.cbegin(), 394 | assign_var_nids_.cend(), nid) == assign_var_nids_.cend()) { 395 | CHECK(state->initialized()) 396 | << "Attempt to execute a graph with an uninitialized Variable"; 397 | } 398 | } 399 | for (uint32_t nid : placeholder_nids_) { 400 | const std::string& key = idx[nid].source->attrs.name; 401 | const TBlob& value = inputs.at(key); 402 | new_shape[idx.entry_id(nid, 0)] = value.shape; 403 | new_dtype[idx.entry_id(nid, 0)] = value.dtype; 404 | } 405 | graph_.attrs["shape"] = std::make_shared(std::move(new_shape)); 406 | graph_.attrs["dtype"] = std::make_shared(std::move(new_dtype)); 407 | graph_ = ApplyPasses(std::move(graph_), {"InferShape", "InferType"}); 408 | CHECK_EQ(graph_.GetAttr("shape_num_unknown_nodes"), 0) 409 | << "Shape information in the graph is incomplete"; 410 | CHECK_EQ(graph_.GetAttr("dtype_num_unknown_nodes"), 0) 411 | << "Type information in the graph is incomplete"; 412 | node_shape_ = &(graph_.GetAttr("shape")); 413 | node_dtype_ = &(graph_.GetAttr("dtype")); 414 | // set up space for output Variables. 415 | for (uint32_t nid : assign_var_nids_) { 416 | node_states_[nid]->ResetSpace( 417 | node_shape_->at(idx.entry_id(nid, 0)), 418 | dev_mask_, 419 | node_dtype_->at(idx.entry_id(nid, 0))); 420 | } 421 | } 422 | 423 | void TorchExecutor::SetupStorage() { 424 | const auto& idx = graph_.indexed_graph(); 425 | if (storage_pool_.size() == 0) { 426 | graph_ = nnvm::ApplyPass(std::move(graph_), "PlanMemory"); 427 | } 428 | const auto& vstorage = graph_.GetAttr("storage_id"); 429 | const auto& vshape = graph_.GetAttr("shape"); 430 | auto* th = TorchState::ThreadLocalState(); 431 | 432 | if (data_entry_.size() == 0) { 433 | data_entry_.resize(idx.num_node_entries()); 434 | data_entry_is_var_.resize(idx.num_node_entries(), false); 435 | for (size_t i = 0; i < data_entry_.size(); ++i) { 436 | data_entry_[i] = th->NewTensorEmpty(dev_mask_); 437 | } 438 | for (uint32_t nid : idx.input_nodes()) { 439 | CHECK(node_states_[nid] != nullptr); 440 | data_entry_[idx.entry_id(nid, 0)] = node_states_[nid]->tensor; 441 | data_entry_is_var_[idx.entry_id(nid, 0)] = true; 442 | } 443 | } 444 | 445 | 446 | // size of each storage pool entry 447 | std::vector pool_entry_size; 448 | for (size_t i = 0; i < vshape.size(); ++i) { 449 | if (data_entry_is_var_[i]) continue; 450 | int storage_id = vstorage[i]; 451 | size_t size = vshape[i].Size(); 452 | CHECK_GE(storage_id, 0) << "Do not support runtime shape op yet"; 453 | size_t sid = static_cast(storage_id); 454 | if (sid >= pool_entry_size.size()) { 455 | pool_entry_size.resize(sid + 1, 0); 456 | } 457 | pool_entry_size[sid] = std::max(pool_entry_size[sid], size); 458 | } 459 | storage_pool_.clear(); 460 | for (size_t i = 0; i < pool_entry_size.size(); ++i) { 461 | storage_pool_.push_back( 462 | th->NewStorage(pool_entry_size[i], dev_mask_)); 463 | } 464 | // assign pooled data to entry 465 | for (size_t i = 0; i < data_entry_.size(); ++i) { 466 | if (data_entry_is_var_[i]) continue; 467 | int storage_id = vstorage[i]; 468 | th->ResetStorage(data_entry_[i], storage_pool_.at(storage_id), vshape[i]); 469 | } 470 | 471 |
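// Stage one host-side (CPU) tensor per graph output; Run() copies the device
// results into these buffers and exposes them through output_blobs_.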
outputs_.resize(idx.outputs().size()); 472 | for (size_t i = 0; i < outputs_.size(); ++i) { 473 | uint32_t eid = idx.entry_id(idx.outputs()[i]); 474 | LuaRef t = th->NewTensorEmpty(kCPU); 475 | th->ResetStorage(t, th->NewStorage(vshape[eid].Size(), kCPU), vshape[eid]); 476 | outputs_[i] = t; 477 | } 478 | } 479 | 480 | void TorchExecutor::SetupOpExecs() { 481 | // a slightly big function to setup execution functors 482 | // We can separate some logics into a new pass later. 483 | auto* lua = LuaState::ThreadLocalState(); 484 | const auto& idx = graph_.indexed_graph(); 485 | const auto& lua_create_module = 486 | nnvm::Op::GetAttr("FLuaCreateNNModule"); 487 | const auto& lua_compute_code = 488 | nnvm::Op::GetAttr("FLuaCompute"); 489 | LuaRef lempty_tensor = lua->Eval(R"( 490 | return 491 | function(dev_mask) 492 | local empty = torch.FloatTensor() 493 | if dev_mask == 2 then 494 | empty = empty:cuda() 495 | end 496 | return empty 497 | end 498 | )")(dev_mask_); 499 | LuaRef fremove_module_storage = lua->Eval(R"( 500 | return 501 | function(m, dev_mask, empty) 502 | if dev_mask == 2 then 503 | if torch.isTypeOf(m, nn.Criterion) then 504 | return m:cuda() 505 | end 506 | local net = nn.Sequential():add(m):cuda() 507 | net = cudnn.convert(net, cudnn) 508 | return net.modules[1] 509 | end 510 | if torch.isTypeOf(m, nn.Module) then 511 | local W, gW = m:parameters() 512 | if W ~= nil then 513 | for i, t in ipairs(W) do 514 | t:set(empty) 515 | end 516 | for i, t in ipairs(gW) do 517 | t:set(empty) 518 | end 519 | end 520 | end 521 | return m 522 | end 523 | )"); 524 | LuaRef fcreate_nnforward_closure = lua->Eval(R"( 525 | return 526 | function(m, input, output, weight) 527 | if torch.isTypeOf(m, nn.Module) then 528 | if m:parameters() ~= nil then 529 | return function() 530 | local W, gW = m:parameters() 531 | for i, t in ipairs(W) do 532 | t:set(weight[i]) 533 | end 534 | m.output:set(output) 535 | m:updateOutput(input) 536 | if not m.output:isSetTo(output) then 537 | output:copy(m.output) 538 | m.output:set(output) 539 | end 540 | end 541 | else 542 | return function() 543 | m.output:set(output) 544 | m:updateOutput(input) 545 | if not m.output:isSetTo(output) then 546 | output:copy(m.output) 547 | m.output:set(output) 548 | end 549 | end 550 | end 551 | else 552 | target = weight[1] 553 | assert(torch.isTypeOf(m, nn.Criterion)) 554 | return function() 555 | local x = m:updateOutput(input, target) 556 | output:fill(x) 557 | end 558 | end 559 | end 560 | )"); 561 | LuaRef fcreate_nnbackward_closure = lua->Eval(R"( 562 | return 563 | function(m, input, output, weight, gradInput, gradOutput, gradWeight) 564 | if torch.isTypeOf(m, nn.Module) then 565 | if m:parameters() ~= nil then 566 | return function() 567 | local W, gW = m:parameters() 568 | for i, t in ipairs(W) do 569 | t:set(weight[i]) 570 | end 571 | for i, t in ipairs(gW) do 572 | t:set(gradWeight[i]) 573 | end 574 | m.output:set(output) 575 | m.gradInput:set(gradInput) 576 | m:zeroGradParameters() 577 | m:accGradParameters(input, gradOutput, 1) 578 | m:updateGradInput(input, gradOutput) 579 | if not m.gradInput:isSetTo(gradInput) then 580 | gradInput:copy(m.gradInput) 581 | m.gradInput:set(gradInput) 582 | end 583 | for i, t in ipairs(gW) do 584 | if not t:isSetTo(gradWeight[i]) then 585 | gradWeight[i]:copy(t) 586 | t:set(gradWeight[i]) 587 | end 588 | end 589 | end 590 | else 591 | return function() 592 | m.output:set(output) 593 | m.gradInput:set(gradInput) 594 | m:updateGradInput(input, gradOutput) 595 | if not 
m.gradInput:isSetTo(gradInput) then 596 | gradInput:copy(m.gradInput) 597 | m.gradInput:set(gradInput) 598 | end 599 | end 600 | end 601 | else 602 | assert(torch.isTypeOf(m, nn.Criterion)) 603 | target = weight[1] 604 | return function() 605 | m.gradInput:set(gradInput) 606 | m:updateGradInput(input, target) 607 | if not m.gradInput:isSetTo(gradInput) then 608 | gradInput:copy(m.gradInput) 609 | m.gradInput:set(gradInput) 610 | end 611 | end 612 | end 613 | end 614 | )"); 615 | 616 | op_exec_modules_.resize(idx.num_nodes()); 617 | // setup torch.nn modules when available. 618 | // setup the array and requirements. 619 | for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { 620 | const auto& inode = idx[nid]; 621 | if (inode.source->is_variable()) continue; 622 | std::string lua_code; 623 | if (lua_create_module.count(inode.source->op())) { 624 | lua_code = "return " + lua_create_module[inode.source->op()]; 625 | LuaRef fcreate = lua->Eval(lua_code); 626 | std::vector ishape; 627 | for (auto& e : inode.inputs) { 628 | ishape.push_back(node_shape_->at(idx.entry_id(e))); 629 | } 630 | op_exec_modules_[nid] = fremove_module_storage( 631 | fcreate(ishape, inode.source->attrs.dict), dev_mask_, lempty_tensor); 632 | } 633 | } 634 | 635 | // setup executor closure 636 | const Op* backward_op = Op::Get("_backward"); 637 | op_execs_.resize(idx.num_nodes()); 638 | // setup the array and requirements. 639 | for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { 640 | const auto& inode = idx[nid]; 641 | if (inode.source->is_variable()) continue; 642 | std::vector in_array, out_array; 643 | for (const auto& e : inode.inputs) { 644 | in_array.push_back(data_entry_[idx.entry_id(e)]); 645 | } 646 | for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { 647 | uint32_t eid = idx.entry_id(nid, index); 648 | out_array.push_back(data_entry_[eid]); 649 | } 650 | 651 | #if TINYFLOW_USE_FUSION == 1 652 | if (node_rtc_ && node_rtc_->count(nid)) { 653 | // rtc compute 654 | op_execs_[nid] = GenerateRTCClosure(node_rtc_->at(nid), in_array, out_array); 655 | } else if (lua_compute_code.count(inode.source->op())) { 656 | #else 657 | if (lua_compute_code.count(inode.source->op())) { 658 | #endif 659 | // compute function 660 | std::string lua_str = "return " + lua_compute_code[inode.source->op()]; 661 | LuaRef fcompute = lua->Eval(lua_str); 662 | op_execs_[nid] = fcompute( 663 | in_array, out_array, inode.source->attrs.dict); 664 | } else if (!op_exec_modules_[nid].is_nil()) { 665 | // nn module forward 666 | std::vector weights; 667 | for (size_t i = 1; i < in_array.size(); ++i) { 668 | weights.push_back(in_array[i]); 669 | } 670 | op_execs_[nid] = fcreate_nnforward_closure( 671 | op_exec_modules_[nid], in_array[0], out_array[0], weights); 672 | CHECK_EQ(out_array.size(), 1) << "only support tensor nn module"; 673 | } else if (inode.source->op() == backward_op) { 674 | // nn module backward 675 | CHECK_GE(inode.control_deps.size(), 1); 676 | const NNBackwardParam& param = 677 | dmlc::get(inode.source->attrs.parsed); 678 | std::vector weight, gradWeight; 679 | LuaRef gradInput, gradOutput, input = lempty_tensor, output = lempty_tensor; 680 | gradInput = out_array[0]; 681 | for (size_t i = 1; i < out_array.size(); ++i) { 682 | gradWeight.push_back(out_array[i]); 683 | } 684 | gradOutput = in_array[0]; 685 | // set the non-needed to be empty tensor. 
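// in_array layout for a backward node: [0] is gradOutput, followed by the
// forward inputs/weights when param.need_inputs is set, and finally the forward
// output when param.need_outputs is set; in_ptr walks that layout below.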
686 | size_t in_ptr = 1; 687 | if (param.need_inputs) { 688 | input = in_array[in_ptr]; 689 | for (size_t i = 1; i < param.forward_readonly_inputs; ++i) { 690 | weight.push_back(in_array[i + in_ptr]); 691 | } 692 | in_ptr += param.forward_readonly_inputs; 693 | } else { 694 | weight.resize(param.forward_readonly_inputs, lempty_tensor); 695 | } 696 | CHECK_EQ(param.num_states, 0); 697 | if (param.need_outputs) { 698 | output = in_array[in_ptr]; 699 | } 700 | op_execs_[nid] = fcreate_nnbackward_closure( 701 | op_exec_modules_[inode.control_deps[0]], 702 | input, output, weight, gradInput, gradOutput, gradWeight); 703 | } else { 704 | LOG(FATAL) << "Function FLuaCompute is not registered on " 705 | << inode.source->op()->name; 706 | } 707 | } 708 | } 709 | 710 | #if TINYFLOW_USE_FUSION == 1 711 | FOpExec TorchExecutor::GenerateRTCClosure(RTC& rtc, 712 | const std::vector& input_luaref, std::vector& output_luaref) { 713 | auto ret = [&rtc, input_luaref, output_luaref]() { 714 | auto* th = TorchState::ThreadLocalState(); 715 | CUdeviceptr input_dptr[input_luaref.size()], output_dptr[output_luaref.size()]; 716 | std::vector input, output; 717 | 718 | for (size_t i = 0; i < input_luaref.size(); ++i) { 719 | input_dptr[i] = reinterpret_cast(th->GetTBlob(input_luaref[i]).data); 720 | input.push_back(&input_dptr[i]); 721 | } 722 | 723 | for (size_t i = 0; i < output_luaref.size(); ++i) { 724 | output_dptr[i] = reinterpret_cast(th->GetTBlob(output_luaref[i]).data); 725 | output.push_back(&output_dptr[i]); 726 | } 727 | 728 | TShape ewise_shape = th->GetTBlob(output_luaref[0]).shape; 729 | int num_elements = 1; 730 | for (auto it = ewise_shape.begin(); it != ewise_shape.end(); ++it) { 731 | num_elements *= (*it); 732 | } 733 | rtc.Run(input, output, num_elements); 734 | }; 735 | return ret; 736 | } 737 | #endif 738 | 739 | } // namespace tinyflow 740 | --------------------------------------------------------------------------------