├── example ├── cat15.jpg └── main.go ├── amalgamation ├── build.sh ├── gen.sh ├── mxnet0.cc ├── expand.py └── c_predict_api.h ├── input.go ├── argsort.go ├── README.md ├── predict.go └── mxnet.h /example/cat15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdeng/gomxnet/HEAD/example/cat15.jpg -------------------------------------------------------------------------------- /amalgamation/build.sh: -------------------------------------------------------------------------------- 1 | export CXX=g++ 2 | $CXX -O3 -std=c++11 -I/usr/local/Cellar/openblas/0.2.14_1/include -c -o mxnet.o mxnet.cc 3 | ar rcs mxnet.a mxnet.o 4 | 5 | 6 | -------------------------------------------------------------------------------- /amalgamation/gen.sh: -------------------------------------------------------------------------------- 1 | export MXNET_ROOT=~/Source/mxnet 2 | export OPENBLAS_ROOT=/usr/local/Cellar/openblas/0.2.14_1 3 | rm -f ./mxnet 4 | echo "Linking $MXNET_ROOT to ./mxnet" 5 | ln -s $MXNET_ROOT ./mxnet 6 | echo "Generating deps from $MXNET_ROOT to mxnet0.d with mxnet0.cc" 7 | g++ -MD -MF mxnet0.d -std=c++11 -Wall -I ./mxnet/ -I ./mxnet/mshadow/ -I ./mxnet/dmlc-core/include -I ./mxnet/include -I$OPENBLAS_ROOT/include -c mxnet0.cc 8 | 9 | echo "Generating amalgamation to mxnet.cc" 10 | python ./expand.py 11 | 12 | cp mxnet.cc ../ 13 | echo "Done" 14 | 15 | 16 | -------------------------------------------------------------------------------- /input.go: -------------------------------------------------------------------------------- 1 | package gomxnet 2 | 3 | import ( 4 | "fmt" 5 | "image" 6 | ) 7 | 8 | type ImageMean struct { 9 | R, G, B float32 10 | } 11 | 12 | func InputFrom(imgs []image.Image, mean ImageMean) ([]float32, error) { 13 | if len(imgs) == 0 { 14 | return nil, fmt.Errorf("No image") 15 | } 16 | height := imgs[0].Bounds().Max.Y - imgs[0].Bounds().Min.Y 17 | width := imgs[0].Bounds().Max.X - imgs[0].Bounds().Min.X 18 | 19 | out := make([]float32, height*width*3*len(imgs)) 20 | for i := 0; i < len(imgs); i++ { 21 | m := imgs[i] 22 | bounds := m.Bounds() 23 | h := bounds.Max.Y - bounds.Min.Y 24 | w := bounds.Max.X - bounds.Min.X 25 | if h != height || w != width { 26 | return nil, fmt.Errorf("Size not matched") 27 | } 28 | start := width * height * 3 * i 29 | for y := 0; y < height; y++ { 30 | for x := 0; x < width; x++ { 31 | r, g, b, _ := m.At(x+bounds.Min.X, y+bounds.Min.Y).RGBA() 32 | out[start+y*width+x] = float32(r>>8) - mean.R 33 | out[start+width*height+y*width+x] = float32(g>>8) - mean.G 34 | out[start+2*width*height+y*width+x] = float32(b>>8) - mean.B 35 | } 36 | } 37 | 38 | } 39 | return out, nil 40 | } 41 | -------------------------------------------------------------------------------- /argsort.go: -------------------------------------------------------------------------------- 1 | // https://github.com/gonum/floats/blob/master/floats.go 2 | package gomxnet 3 | 4 | import ( 5 | "sort" 6 | ) 7 | 8 | // argsort is a helper that implements sort.Interface, as used by 9 | // Argsort. 10 | type argsort struct { 11 | s []float32 12 | inds []int 13 | } 14 | 15 | func (a argsort) Len() int { 16 | return len(a.s) 17 | } 18 | 19 | func (a argsort) Less(i, j int) bool { 20 | return a.s[i] > a.s[j] 21 | } 22 | 23 | func (a argsort) Swap(i, j int) { 24 | a.s[i], a.s[j] = a.s[j], a.s[i] 25 | a.inds[i], a.inds[j] = a.inds[j], a.inds[i] 26 | } 27 | 28 | // Argsort sorts the elements of s while tracking their original order. 29 | // At the conclusion of Argsort, s will contain the original elements of s 30 | // but sorted in increasing order, and inds will contain the original position 31 | // of the elements in the slice such that dst[i] = origDst[inds[i]]. 32 | // It panics if the lengths of dst and inds do not match. 33 | func Argsort(dst []float32, inds []int) { 34 | if len(dst) != len(inds) { 35 | panic("floats: length of inds does not match length of slice") 36 | } 37 | for i := range dst { 38 | inds[i] = i 39 | } 40 | 41 | a := argsort{s: dst, inds: inds} 42 | sort.Sort(a) 43 | } 44 | -------------------------------------------------------------------------------- /example/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "C" 4 | 5 | import ( 6 | "github.com/jdeng/gomxnet" 7 | "bufio" 8 | "flag" 9 | "fmt" 10 | "image" 11 | // "image/jpeg" 12 | "io/ioutil" 13 | "os" 14 | 15 | "github.com/disintegration/imaging" 16 | ) 17 | 18 | 19 | func main() { 20 | flag.Parse() 21 | args := flag.Args() 22 | file := "cat15.jpg" 23 | if len(args) >= 1 { 24 | file = args[0] 25 | } 26 | 27 | reader, err := os.Open(file) 28 | if err != nil { 29 | panic(err) 30 | } 31 | 32 | var batch uint32 = 1 33 | img, _, _ := image.Decode(reader) 34 | img = imaging.Fill(img, 224, 224, imaging.Center, imaging.Lanczos) 35 | 36 | /* 37 | test, _ := os.OpenFile("test.jpg", os.O_CREATE|os.O_WRONLY, 0644) 38 | jpeg.Encode(test, img, nil) 39 | */ 40 | 41 | symbol, err := ioutil.ReadFile("./Inception-symbol.json") 42 | if err != nil { 43 | panic(err) 44 | } 45 | params, err := ioutil.ReadFile("./Inception-0009.params") 46 | if err != nil { 47 | panic(err) 48 | } 49 | synset, err := os.Open("./synset.txt") 50 | if err != nil { 51 | panic(err) 52 | } 53 | 54 | pred, err := gomxnet.NewPredictor(gomxnet.Model{symbol, params}, gomxnet.Device{gomxnet.CPU, 0}, []gomxnet.InputNode{{"data", []uint32{batch, 3, 224, 224}}}) 55 | if err != nil { 56 | panic(err) 57 | } 58 | 59 | input, _ := gomxnet.InputFrom([]image.Image{img}, gomxnet.ImageMean{117.0, 117.0, 117.0}) 60 | pred.Forward("data", input) 61 | output, _ := pred.GetOutput(0) 62 | pred.Free() 63 | 64 | dict := []string{} 65 | scanner := bufio.NewScanner(synset) 66 | for scanner.Scan() { 67 | dict = append(dict, scanner.Text()) 68 | } 69 | 70 | outputLen := uint32(len(output)) / batch 71 | var b uint32 = 0 72 | for ; b < batch; b++ { 73 | out := output[b*outputLen : (b+1)*outputLen] 74 | index := make([]int, len(out)) 75 | gomxnet.Argsort(out, index) 76 | 77 | fmt.Printf("image #%d\n", b) 78 | for i := 0; i < 20; i++ { 79 | fmt.Printf("%d: %f, %d, %s\n", i, out[i], index[i], dict[index[i]]) 80 | } 81 | fmt.Println("") 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /amalgamation/mxnet0.cc: -------------------------------------------------------------------------------- 1 | // mexnet.cc 2 | 3 | #define MSHADOW_FORCE_STREAM 4 | #define MSHADOW_USE_CUDA 0 5 | #define MSHADOW_USE_CBLAS 1 6 | #define MSHADOW_USE_MKL 0 7 | #define MSHADOW_RABIT_PS 0 8 | #define MSHADOW_DIST_PS 0 9 | 10 | #define MXNET_USE_OPENCV 0 11 | #define DISABLE_OPENMP 1 12 | 13 | #include "src/ndarray/unary_function.cc" 14 | #include "src/ndarray/ndarray_function.cc" 15 | #include "src/ndarray/ndarray.cc" 16 | #include "src/engine/engine.cc" 17 | #include "src/engine/naive_engine.cc" 18 | #include "src/engine/threaded_engine.cc" 19 | #include "src/engine/threaded_engine_perdevice.cc" 20 | #include "src/engine/threaded_engine_pooled.cc" 21 | #include "src/io/io.cc" 22 | #include "src/kvstore/kvstore.cc" 23 | #include "src/symbol/graph_executor.cc" 24 | #include "src/symbol/static_graph.cc" 25 | #include "src/symbol/symbol.cc" 26 | #include "src/operator/operator.cc" 27 | #include "src/operator/activation.cc" 28 | #include "src/operator/batch_norm.cc" 29 | #include "src/operator/block_grad.cc" 30 | #include "src/operator/concat.cc" 31 | #include "src/operator/convolution.cc" 32 | #include "src/operator/dropout.cc" 33 | #include "src/operator/elementwise_binary_op.cc" 34 | #include "src/operator/elementwise_sum.cc" 35 | #include "src/operator/fully_connected.cc" 36 | #include "src/operator/leaky_relu.cc" 37 | #include "src/operator/lrn.cc" 38 | #include "src/operator/pooling.cc" 39 | #include "src/operator/regression_output.cc" 40 | #include "src/operator/reshape.cc" 41 | #include "src/operator/slice_channel.cc" 42 | #include "src/operator/softmax_output.cc" 43 | #include "src/operator/deconvolution.cc" 44 | #include "src/operator/native_op.cc" 45 | #include "src/storage/storage.cc" 46 | #include "src/common/tblob_op_registry.cc" 47 | 48 | #include "src/resource.cc" 49 | 50 | #include "src/c_api/c_api.cc" 51 | #include "src/c_api/c_api_error.cc" 52 | #include "src/c_api/c_predict_api.cc" 53 | 54 | #include "dmlc-core/src/data.cc" 55 | #include "dmlc-core/src/io/input_split_base.cc" 56 | #include "dmlc-core/src/io/line_split.cc" 57 | #include "dmlc-core/src/io/local_filesys.cc" 58 | #include "dmlc-core/src/io/recordio_split.cc" 59 | #include "dmlc-core/src/io.cc" 60 | #include "dmlc-core/src/recordio.cc" 61 | 62 | -------------------------------------------------------------------------------- /amalgamation/expand.py: -------------------------------------------------------------------------------- 1 | import os.path, re, StringIO 2 | 3 | blacklist = [ 4 | 'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh', 'cuda_runtime.h', 'cudnn.h', 'cudnn_lrn-inl.h', 'curand.h', 'glog/logging.h', 'io/azure_filesys.h', 'io/hdfs_filesys.h', 'io/s3_filesys.h', 'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h', 'malloc.h', 'mkl.h', 'mkl_cblas.h', 'mkl_vsl.h', 'mkl_vsl_functions.h', 'nvml.h', 'opencv2/opencv.hpp', 'sys/stat.h', 'sys/types.h', 'emmintrin.h' 5 | ] 6 | 7 | sources = [] 8 | files = [] 9 | # g++ -MD -MF mxnet0.d -std=c++11 -Wall -I./mshadow/ -I./dmlc-core/include -Iinclude -I/usr/local//Cellar/openblas/0.2.14_1/include -c mxnet0.cc 10 | for line in open('mxnet0.d'): 11 | files = files + line.strip().split(' ') 12 | 13 | for f in files: 14 | f = f.strip() 15 | if not f or f == 'mxnet0.o:' or f == '\\': continue 16 | fn = os.path.relpath(f) 17 | if fn.find('/usr/') < 0 and fn not in sources: 18 | sources.append(fn) 19 | 20 | def find_source(name, start): 21 | candidates = [] 22 | for x in sources: 23 | if x == name or x.endswith('/' + name): candidates.append(x) 24 | if not candidates: return '' 25 | if len(candidates) == 1: return candidates[0] 26 | for x in candidates: 27 | # print 'multiple candidates: %s, looking for %s, candidates: %s' %(start, name, str(candidates)) 28 | if x.split('/')[1] == start.split('/')[1]: return x 29 | return '' 30 | 31 | 32 | re1 = re.compile('<([./a-zA-Z0-9_-]*)>') 33 | re2 = re.compile('"([./a-zA-Z0-9_-]*)"') 34 | 35 | sysheaders = [] 36 | history = {} 37 | 38 | out = StringIO.StringIO() 39 | def expand(x, pending): 40 | if x in history and x not in ['mshadow/mshadow/expr_scalar-inl.h']: # MULTIPLE includes 41 | return 42 | 43 | if x in pending: 44 | # print 'loop found: %s in ' % x, pending 45 | return 46 | 47 | print >>out, "//===== EXPANDIND: %s =====\n" %x 48 | for line in open(x): 49 | if line.find('#include') < 0: 50 | out.write(line) 51 | continue 52 | if line.strip().find('#include') > 0: 53 | print line 54 | continue 55 | m = re1.search(line) 56 | if not m: m = re2.search(line) 57 | if not m: 58 | print line + ' not found' 59 | continue 60 | h = m.groups()[0].strip('./') 61 | source = find_source(h, x) 62 | if not source: 63 | if h not in blacklist and h not in sysheaders: sysheaders.append(h) 64 | else: 65 | expand(source, pending + [x]) 66 | print >>out, "//===== EXPANDED: %s =====\n" %x 67 | history[x] = 1 68 | 69 | expand('mxnet0.cc', []) 70 | 71 | f = open('mxnet.cc', 'wb') 72 | print >>f, ''' 73 | #if defined(__MACH__) 74 | #include 75 | #include 76 | #endif 77 | 78 | #if !defined(__WIN32__) 79 | #include 80 | #include 81 | 82 | #if !defined(__ANDROID__) 83 | #include 84 | #endif 85 | 86 | #endif 87 | ''' 88 | 89 | for k in sorted(sysheaders): 90 | print >>f, "#include <%s>" % k 91 | 92 | print >>f, '' 93 | print >>f, out.getvalue() 94 | 95 | for x in sources: 96 | if x not in history: print 'Not processed:', x 97 | 98 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gomxnet 2 | Amalgamation and go binding for mxnet https://github.com/dmlc/mxnet 3 | 4 | 5 | ## Go binding for predictor 6 | * Use ```go get github.com/jdeng/gomxnet``` to install. If your openblas version happens to be 0.2.14_1 you are all set. 7 | * Update ```predict.go```: update ```openblas``` library path accordingly (the first two lines). 8 | ``` 9 | ... 10 | //#cgo CXXFLAGS: -std=c++11 -I/usr/local/Cellar/openblas/0.2.14_1/include 11 | //#cgo LDFLAGS: -L /usr/local/Cellar/openblas/0.2.14_1/lib/ -lopenblas 12 | //#include 13 | //#include "mxnet.h" 14 | import "C" 15 | import "unsafe" 16 | ... 17 | ``` 18 | * ```go build``` could be slow when importing this library due to recompiling mxnet.cc every time. One workaround is manually building and linking mxnet.a. You will need to remove mxnet.cc to prevent go build compiling it. See below 19 | ``` 20 | cd $GOPATH/src/github.com/jdeng/gomxnet/amalgamation 21 | sh build.sh 22 | ls mxnet.a 23 | rm ../mxnet.cc 24 | ``` 25 | And add one line to ```predict.go```. It needs to be an absolute path. 26 | ``` 27 | //#cgo LDFLAGS: //src/github.com/jdeng/gomxnet/amalgamation/mxnet.a -lstdc++ 28 | ``` 29 | * Build the sample ```example/main.go``` with ```go build``` in the example directory. You need to install a dependent library with ```go get github.com/disintegration/imaging```. 30 | ** Download the model file package from [https://github.com/dmlc/mxnet-model-gallery] and update the path in ```main.go```. Build with ```go build```. Try with ```./example cat15.jpg```, the program should be able to recognize the cat. 31 | * Tested with golang 1.5.1 on Mac OS X Yosemite. 32 | * Sample usage (from ```src/main.go```) 33 | ``` 34 | // read model files into memory 35 | symbol, _ := ioutil.ReadFile("./Inception-symbol.json") 36 | params, _ := ioutil.ReadFile("./Inception-0009.params") 37 | 38 | // create predictor with model, device and input node config 39 | batch := 1 40 | pred, _ := gomxnet.NewPredictor(gomxnet.Model{symbol, params}, gomxnet.Device{gomxnet.CPU, 0}, []gomxnet.InputNode{{"data", []uint32{batch, 3, 224, 224}}}) 41 | 42 | // get input vector from 224 * 224 image(s) 43 | input, _ := gomxnet.InputFrom([]image.Image{img}, gomxnet.ImageMean{117.0, 117.0, 117.0}) 44 | 45 | // feed forward 46 | pred.Forward("data", input) 47 | 48 | // get the first output node. length for each image is len(output) / batch 49 | output, _ := pred.GetOutput(0) 50 | 51 | // free the predictor 52 | pred.Free() 53 | 54 | ``` 55 | ## mxnet amalgamation (this is optional. A pre-generated mxnet.cc is already in the directory.) 56 | * Check out mxnet, e.g., in ~/Sources/, update submodules and build 57 | * Generate ```mxnet.cc``` in ```amalgamation``` directory using ```amalgamation/gen.sh``` (content shown below). You may need to update the first two lines to point to your mxnet and openblas locations. 58 | ``` 59 | export MXNET_ROOT=~/Source/mxnet 60 | export OPENBLAS_ROOT=/usr/local/Cellar/openblas/0.2.14_1 61 | rm -f ./mxnet 62 | echo "Linking $MXNET_ROOT to ./mxnet" 63 | ln -s $MXNET_ROOT ./mxnet 64 | echo "Generating deps from $MXNET_ROOT to mxnet0.d with mxnet0.cc" 65 | g++ -MD -MF mxnet0.d -std=c++11 -Wall -I ./mxnet/ -I ./mxnet/mshadow/ -I ./mxnet/dmlc-core/include -I ./mxnet/include -I$OPENBLAS_ROOT/include -c mxnet0.cc 66 | 67 | echo "Generating amalgamation to mxnet.cc" 68 | python ./expand.py 69 | 70 | cp mxnet.cc ../ 71 | echo "Done" 72 | ``` 73 | 74 | # TODO 75 | * Merge openblas into the amalgamation? 76 | * Add Train API? 77 | 78 | -------------------------------------------------------------------------------- /predict.go: -------------------------------------------------------------------------------- 1 | package gomxnet 2 | 3 | //#cgo CXXFLAGS: -std=c++11 -I/usr/local/Cellar/openblas/0.2.14_1/include 4 | //#cgo LDFLAGS: -L /usr/local/Cellar/openblas/0.2.14_1/lib/ -lopenblas 5 | ////// #cgo LDFLAGS: /Users/jack/Source/go/src/github.com/jdeng/gomxnet/mxnet.a -lstdc++ 6 | //#include 7 | //#include "mxnet.h" 8 | import "C" 9 | import "unsafe" 10 | 11 | import "fmt" 12 | 13 | const ( 14 | CPU = iota + 1 15 | GPU 16 | CPU_PINNED 17 | ) 18 | 19 | type Predictor struct { 20 | handle C.PredictorHandle 21 | outputSize uint32 22 | } 23 | 24 | type Model struct { 25 | Symbol []byte // json 26 | Params []byte // network 27 | } 28 | 29 | type Device struct { 30 | Type int 31 | Id int 32 | } 33 | 34 | type InputNode struct { 35 | Key string 36 | Shape []uint32 37 | } 38 | 39 | func NewPredictor(model Model, dev Device, input []InputNode) (*Predictor, error) { 40 | shapeInd := []uint32{0} 41 | shapeData := []uint32{} 42 | 43 | var b *C.char 44 | keys := C.malloc(C.size_t(len(input)) * C.size_t(unsafe.Sizeof(b))) 45 | defer C.free(unsafe.Pointer(keys)) 46 | 47 | for i := 0; i < len(input); i++ { 48 | element := (**C.char)(unsafe.Pointer(uintptr(keys) + uintptr(i)*unsafe.Sizeof(b))) 49 | *element = C.CString(input[i].Key) 50 | shapeInd = append(shapeInd, uint32(len(input[i].Shape))) 51 | shapeData = append(shapeData, input[i].Shape...) 52 | } 53 | 54 | var handle C.PredictorHandle 55 | n, err := C.MXPredCreate((*C.char)(unsafe.Pointer(&model.Symbol[0])), (*C.char)(unsafe.Pointer(&model.Params[0])), C.size_t(len(model.Params)), C.int(dev.Type), C.int(dev.Id), C.mx_uint(len(input)), (**C.char)(keys), (*C.mx_uint)(&shapeInd[0]), (*C.mx_uint)(&shapeData[0]), &handle) 56 | 57 | for i := 0; i < len(input); i++ { 58 | element := (**C.char)(unsafe.Pointer(uintptr(keys) + uintptr(i)*unsafe.Sizeof(b))) 59 | C.free(unsafe.Pointer(*element)) 60 | } 61 | 62 | if err != nil { 63 | return nil, err 64 | } else if n < 0 { 65 | return nil, GetLastError() 66 | } 67 | 68 | return &Predictor{handle, 0}, nil 69 | } 70 | 71 | func (p *Predictor) Free() { 72 | if p.handle != nil { 73 | C.MXPredFree(p.handle) 74 | p.handle = nil 75 | } 76 | } 77 | 78 | func (p *Predictor) Forward(key string, data []float32) error { 79 | if data != nil { 80 | k := C.CString(key) 81 | defer C.free(unsafe.Pointer(k)) 82 | if n, err := C.MXPredSetInput(p.handle, k, (*C.mx_float)(&data[0]), C.mx_uint(len(data))); err != nil { 83 | return err 84 | } else if n < 0 { 85 | return GetLastError() 86 | } 87 | } 88 | 89 | if n, err := C.MXPredForward(p.handle); err != nil { 90 | return err 91 | } else if n < 0 { 92 | return GetLastError() 93 | } 94 | return nil 95 | } 96 | 97 | func (p *Predictor) GetOutput(index uint32) ([]float32, error) { 98 | if p.outputSize == 0 { 99 | var shapeData *C.mx_uint 100 | var shapeDim C.mx_uint 101 | if n, err := C.MXPredGetOutputShape(p.handle, C.mx_uint(index), (**C.mx_uint)(&shapeData), (*C.mx_uint)(&shapeDim)); err != nil { 102 | return nil, err 103 | } else if n < 0 { 104 | return nil, GetLastError() 105 | } 106 | 107 | var size uint32 = 1 108 | for i := 0; i < int(shapeDim); i++ { 109 | n := *(*C.mx_uint)(unsafe.Pointer(uintptr(unsafe.Pointer(shapeData)) + uintptr(i)*unsafe.Sizeof(size))) 110 | size *= uint32(n) 111 | } 112 | 113 | p.outputSize = size 114 | } 115 | 116 | size := p.outputSize 117 | data := make([]C.mx_float, size) 118 | if n, err := C.MXPredGetOutput(p.handle, C.mx_uint(index), (*C.mx_float)(&data[0]), C.mx_uint(size)); err != nil { 119 | return nil, err 120 | } else if n < 0 { 121 | return nil, GetLastError() 122 | } 123 | out := make([]float32, size) 124 | for i := 0; i < int(size); i++ { 125 | out[i] = float32(data[i]) 126 | } 127 | return out, nil 128 | } 129 | 130 | func GetLastError() error { 131 | if err := C.MXGetLastError(); err != nil { 132 | return fmt.Errorf(C.GoString(err)) 133 | } 134 | return nil 135 | } 136 | -------------------------------------------------------------------------------- /mxnet.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2015 by Contributors 3 | * \file c_predict_api.h 4 | * \brief C predict API of mxnet, contains a minimum API to run prediction. 5 | * This file is self-contained, and do not dependent on any other files. 6 | */ 7 | #ifndef MXNET_C_PREDICT_API_H_ 8 | #define MXNET_C_PREDICT_API_H_ 9 | 10 | #ifdef __cplusplus 11 | #define MXNET_EXTERN_C extern "C" 12 | #else 13 | #define MXNET_EXTERN_C 14 | #endif 15 | 16 | #ifdef _WIN32 17 | #ifdef MXNET_EXPORTS 18 | #define MXNET_DLL MXNET_EXTERN_C __declspec(dllexport) 19 | #else 20 | #define MXNET_DLL MXNET_EXTERN_C __declspec(dllimport) 21 | #endif 22 | #else 23 | #define MXNET_DLL MXNET_EXTERN_C 24 | #endif 25 | 26 | /*! \brief manually define unsigned int */ 27 | typedef unsigned int mx_uint; 28 | /*! \brief manually define float */ 29 | typedef float mx_float; 30 | /*! \brief handle to Predictor */ 31 | typedef void *PredictorHandle; 32 | /*! \brief handle to NDArray list */ 33 | typedef void *NDListHandle; 34 | 35 | /*! 36 | * \brief Get the last error happeneed. 37 | * \return The last error happened at the predictor. 38 | */ 39 | MXNET_DLL const char* MXGetLastError(); 40 | /*! 41 | * \brief create a predictor 42 | * \param symbol_json_str The JSON string of the symbol. 43 | * \param param_bytes The in-memory raw bytes of parameter ndarray file. 44 | * \param param_size The size of parameter ndarray file. 45 | * \param dev_type The device type, 1: cpu, 2:gpu 46 | * \param dev_id The device id of the predictor. 47 | * \param num_input_nodes Number of input nodes to the net, 48 | * For feedforward net, this is 1. 49 | * \param input_keys The name of input argument. 50 | * For feedforward net, this is {"data"} 51 | * \param input_shape_indptr Index pointer of shapes of each input node. 52 | * The length of this array = num_input_nodes + 1. 53 | * For feedforward net that takes 4 dimensional input, this is {0, 4}. 54 | * \param input_shape_data A flatted data of shapes of each input node. 55 | * For feedforward net that takes 4 dimensional input, this is the shape data. 56 | * \param out The created predictor handle. 57 | * \return 0 when success, -1 when failure. 58 | */ 59 | MXNET_DLL int MXPredCreate(const char* symbol_json_str, 60 | const char* param_bytes, 61 | size_t param_size, 62 | int dev_type, int dev_id, 63 | mx_uint num_input_nodes, 64 | const char** input_keys, 65 | const mx_uint* input_shape_indptr, 66 | const mx_uint* input_shape_data, 67 | PredictorHandle* out); 68 | /*! 69 | * \brief Get the shape of output node. 70 | * The returned shape_data and shape_ndim is only valid before next call to MXPred function. 71 | * \param handle The handle of the predictor. 72 | * \param index The index of output node, set to 0 if there is only one output. 73 | * \param shape_data Used to hold pointer to the shape data 74 | * \param shape_ndim Used to hold shape dimension. 75 | * \return 0 when success, -1 when failure. 76 | */ 77 | MXNET_DLL int MXPredGetOutputShape(PredictorHandle handle, 78 | mx_uint index, 79 | mx_uint** shape_data, 80 | mx_uint* shape_ndim); 81 | /*! 82 | * \brief Set the input data of predictor. 83 | * \param handle The predictor handle. 84 | * \param key The name of input node to set. 85 | * For feedforward net, this is "data". 86 | * \param data The pointer to the data to be set, with the shape specified in MXPredCreate. 87 | * \param size The size of data array, used for safety check. 88 | * \return 0 when success, -1 when failure. 89 | */ 90 | MXNET_DLL int MXPredSetInput(PredictorHandle handle, 91 | const char* key, 92 | const mx_float* data, 93 | mx_uint size); 94 | /*! 95 | * \brief Run a forward pass to get the output 96 | * \param handle The handle of the predictor. 97 | * \return 0 when success, -1 when failure. 98 | */ 99 | MXNET_DLL int MXPredForward(PredictorHandle handle); 100 | /*! 101 | * \brief Get the output value of prediction. 102 | * \param handle The handle of the predictor. 103 | * \param index The index of output node, set to 0 if there is only one output. 104 | * \param data User allocated data to hold the output. 105 | * \param size The size of data array, used for safe checking. 106 | * \return 0 when success, -1 when failure. 107 | */ 108 | MXNET_DLL int MXPredGetOutput(PredictorHandle handle, 109 | mx_uint index, 110 | mx_float* data, 111 | mx_uint size); 112 | /*! 113 | * \brief Free a predictor handle. 114 | * \param handle The handle of the predictor. 115 | * \return 0 when success, -1 when failure. 116 | */ 117 | MXNET_DLL int MXPredFree(PredictorHandle handle); 118 | /*! 119 | * \brief Create a NDArray List by loading from ndarray file. 120 | * This can be used to load mean image file. 121 | * \param nd_file_bytes The byte contents of nd file to be loaded. 122 | * \param nd_file_size The size of the nd file to be loaded. 123 | * \param out The out put NDListHandle 124 | * \param out_length Length of the list. 125 | * \return 0 when success, -1 when failure. 126 | */ 127 | MXNET_DLL int MXNDListCreate(const char* nd_file_bytes, 128 | size_t nd_file_size, 129 | NDListHandle *out, 130 | mx_uint* out_length); 131 | /*! 132 | * \brief Get an element from list 133 | * \param handle The handle to the NDArray 134 | * \param index The index in the list 135 | * \param out_key The output key of the item 136 | * \param out_data The data region of the item 137 | * \param out_shape The shape of the item. 138 | * \param out_ndim The number of dimension in the shape. 139 | * \return 0 when success, -1 when failure. 140 | */ 141 | MXNET_DLL int MXNDListGet(NDListHandle handle, 142 | mx_uint index, 143 | const char** out_key, 144 | const mx_float** out_data, 145 | const mx_uint** out_shape, 146 | mx_uint* out_ndim); 147 | /*! 148 | * \brief Free a predictor handle. 149 | * \param handle The handle of the predictor. 150 | * \return 0 when success, -1 when failure. 151 | */ 152 | MXNET_DLL int MXNDListFree(NDListHandle handle); 153 | 154 | #endif // MXNET_C_PREDICT_API_H_ 155 | -------------------------------------------------------------------------------- /amalgamation/c_predict_api.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2015 by Contributors 3 | * \file c_predict_api.h 4 | * \brief C predict API of mxnet, contains a minimum API to run prediction. 5 | * This file is self-contained, and do not dependent on any other files. 6 | */ 7 | #ifndef MXNET_C_PREDICT_API_H_ 8 | #define MXNET_C_PREDICT_API_H_ 9 | 10 | #ifdef __cplusplus 11 | #define MXNET_EXTERN_C extern "C" 12 | #else 13 | #define MXNET_EXTERN_C 14 | #endif 15 | 16 | #ifdef _WIN32 17 | #ifdef MXNET_EXPORTS 18 | #define MXNET_DLL MXNET_EXTERN_C __declspec(dllexport) 19 | #else 20 | #define MXNET_DLL MXNET_EXTERN_C __declspec(dllimport) 21 | #endif 22 | #else 23 | #define MXNET_DLL MXNET_EXTERN_C 24 | #endif 25 | 26 | /*! \brief manually define unsigned int */ 27 | typedef unsigned int mx_uint; 28 | /*! \brief manually define float */ 29 | typedef float mx_float; 30 | /*! \brief handle to Predictor */ 31 | typedef void *PredictorHandle; 32 | /*! \brief handle to NDArray list */ 33 | typedef void *NDListHandle; 34 | 35 | /*! 36 | * \brief Get the last error happeneed. 37 | * \return The last error happened at the predictor. 38 | */ 39 | MXNET_DLL const char* MXGetLastError(); 40 | /*! 41 | * \brief create a predictor 42 | * \param symbol_json_str The JSON string of the symbol. 43 | * \param param_bytes The in-memory raw bytes of parameter ndarray file. 44 | * \param param_size The size of parameter ndarray file. 45 | * \param dev_type The device type, 1: cpu, 2:gpu 46 | * \param dev_id The device id of the predictor. 47 | * \param num_input_nodes Number of input nodes to the net, 48 | * For feedforward net, this is 1. 49 | * \param input_keys The name of input argument. 50 | * For feedforward net, this is {"data"} 51 | * \param input_shape_indptr Index pointer of shapes of each input node. 52 | * The length of this array = num_input_nodes + 1. 53 | * For feedforward net that takes 4 dimensional input, this is {0, 4}. 54 | * \param input_shape_data A flatted data of shapes of each input node. 55 | * For feedforward net that takes 4 dimensional input, this is the shape data. 56 | * \param out The created predictor handle. 57 | * \return 0 when success, -1 when failure. 58 | */ 59 | MXNET_DLL int MXPredCreate(const char* symbol_json_str, 60 | const char* param_bytes, 61 | size_t param_size, 62 | int dev_type, int dev_id, 63 | mx_uint num_input_nodes, 64 | const char** input_keys, 65 | const mx_uint* input_shape_indptr, 66 | const mx_uint* input_shape_data, 67 | PredictorHandle* out); 68 | /*! 69 | * \brief Get the shape of output node. 70 | * The returned shape_data and shape_ndim is only valid before next call to MXPred function. 71 | * \param handle The handle of the predictor. 72 | * \param index The index of output node, set to 0 if there is only one output. 73 | * \param shape_data Used to hold pointer to the shape data 74 | * \param shape_ndim Used to hold shape dimension. 75 | * \return 0 when success, -1 when failure. 76 | */ 77 | MXNET_DLL int MXPredGetOutputShape(PredictorHandle handle, 78 | mx_uint index, 79 | mx_uint** shape_data, 80 | mx_uint* shape_ndim); 81 | /*! 82 | * \brief Set the input data of predictor. 83 | * \param handle The predictor handle. 84 | * \param key The name of input node to set. 85 | * For feedforward net, this is "data". 86 | * \param data The pointer to the data to be set, with the shape specified in MXPredCreate. 87 | * \param size The size of data array, used for safety check. 88 | * \return 0 when success, -1 when failure. 89 | */ 90 | MXNET_DLL int MXPredSetInput(PredictorHandle handle, 91 | const char* key, 92 | const mx_float* data, 93 | mx_uint size); 94 | /*! 95 | * \brief Run a forward pass to get the output 96 | * \param handle The handle of the predictor. 97 | * \return 0 when success, -1 when failure. 98 | */ 99 | MXNET_DLL int MXPredForward(PredictorHandle handle); 100 | /*! 101 | * \brief Get the output value of prediction. 102 | * \param handle The handle of the predictor. 103 | * \param index The index of output node, set to 0 if there is only one output. 104 | * \param data User allocated data to hold the output. 105 | * \param size The size of data array, used for safe checking. 106 | * \return 0 when success, -1 when failure. 107 | */ 108 | MXNET_DLL int MXPredGetOutput(PredictorHandle handle, 109 | mx_uint index, 110 | mx_float* data, 111 | mx_uint size); 112 | /*! 113 | * \brief Free a predictor handle. 114 | * \param handle The handle of the predictor. 115 | * \return 0 when success, -1 when failure. 116 | */ 117 | MXNET_DLL int MXPredFree(PredictorHandle handle); 118 | /*! 119 | * \brief Create a NDArray List by loading from ndarray file. 120 | * This can be used to load mean image file. 121 | * \param nd_file_bytes The byte contents of nd file to be loaded. 122 | * \param nd_file_size The size of the nd file to be loaded. 123 | * \param out The out put NDListHandle 124 | * \param out_length Length of the list. 125 | * \return 0 when success, -1 when failure. 126 | */ 127 | MXNET_DLL int MXNDListCreate(const char* nd_file_bytes, 128 | size_t nd_file_size, 129 | NDListHandle *out, 130 | mx_uint* out_length); 131 | /*! 132 | * \brief Get an element from list 133 | * \param handle The handle to the NDArray 134 | * \param index The index in the list 135 | * \param out_key The output key of the item 136 | * \param out_data The data region of the item 137 | * \param out_shape The shape of the item. 138 | * \param out_ndim The number of dimension in the shape. 139 | * \return 0 when success, -1 when failure. 140 | */ 141 | MXNET_DLL int MXNDListGet(NDListHandle handle, 142 | mx_uint index, 143 | const char** out_key, 144 | const mx_float** out_data, 145 | const mx_uint** out_shape, 146 | mx_uint* out_ndim); 147 | /*! 148 | * \brief Free a predictor handle. 149 | * \param handle The handle of the predictor. 150 | * \return 0 when success, -1 when failure. 151 | */ 152 | MXNET_DLL int MXNDListFree(NDListHandle handle); 153 | 154 | #endif // MXNET_C_PREDICT_API_H_ 155 | --------------------------------------------------------------------------------