├── .gitignore ├── README.md ├── SConscript ├── examples ├── mnist.c ├── mnist.h ├── mnist_model.c ├── mnist_sm.c └── model │ ├── mnist-keras.ipynb │ ├── mnist-lg.onnx │ └── mnist-sm.onnx └── src ├── add.c ├── conv2d.c ├── dense.c ├── info.c ├── matmul.c ├── maxpool.c ├── model.c ├── onnx.h ├── relu.c ├── softmax.c └── transpose.c /.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Object files 5 | *.o 6 | *.ko 7 | *.obj 8 | *.elf 9 | 10 | # Linker output 11 | *.ilk 12 | *.map 13 | *.exp 14 | 15 | # Precompiled Headers 16 | *.gch 17 | *.pch 18 | 19 | # Libraries 20 | *.lib 21 | *.a 22 | *.la 23 | *.lo 24 | 25 | # Shared objects (inc. Windows DLLs) 26 | *.dll 27 | *.so 28 | *.so.* 29 | *.dylib 30 | 31 | # Executables 32 | *.exe 33 | *.out 34 | *.app 35 | *.i*86 36 | *.x86_64 37 | *.hex 38 | 39 | # Debug files 40 | *.dSYM/ 41 | *.su 42 | *.idb 43 | *.pdb 44 | 45 | # Kernel Module Compile Results 46 | *.mod* 47 | *.cmd 48 | .tmp_versions/ 49 | modules.order 50 | Module.symvers 51 | Mkfile.old 52 | dkms.conf 53 | 54 | # Project 55 | .vscode/ 56 | .sconsign.dblite 57 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![](https://raw.githubusercontent.com/onnx/onnx/master/docs/ONNX_logo_main.png) 2 | 3 | # onnx-backend 4 | 5 | **通用神经网络模型 onnx 在 RT-Thread 上的后端** 6 | 7 | [ONNX](https://onnx.ai/) (Open Neural Network Exchange) 是机器学习模型的通用格式,可以帮助大家方便地融合不同机器学习框架的模型。 8 | 9 | 如果能在 RT-Thread 上解析并运行 onnx 的模型,那么就可以在 [RT-Thread](https://www.rt-thread.org) 上运行几乎所有主流机器学习框架了,例如 Tensorflow, Keras, Pytorch, Caffe2, mxnet, 因为它们生成的模型都可以转换为 onnx。 10 | 11 | ## 支持的算子 12 | 13 | - Conv2D 14 | - Relu 15 | - Maxpool 16 | - Softmax 17 | - Matmul 18 | - Add 19 | - Flatten 20 | - Transpose 21 | 22 | ## 手写体例程 23 | 24 | 当前只有一个手写体识别的例程:利用 Keras 训练一个卷积神经网络模型,保存为 onnx 模型,再在 RT-Thread 上解析模型进行 
inference,当前在 STM32F407 上测试通过。 25 | 26 | 不过这个例程分成了 3 个小的 demo,放在 examples 目录下,用来更直观地展示 onnx-backend 的工作流程,最小的 demo 只需要 16KB 内存就可以了,因此在 STM32F103C8T6 上也可以运行: 27 | 28 | | 例程文件 | 说明 | 29 | | ------------- | ---------------------------------------- | 30 | | mnist.c | 纯手动构建模型,模型参数保存在 mnist.h | 31 | | mnist_sm.c | 纯手动构建模型,模型参数从 onnx 文件加载 | 32 | | mnist_model.c | 自动从 onnx 文件加载模型和参数 | 33 | 34 | #### Keras 模型结构 35 | 36 | ``` 37 | _________________________________________________________________ 38 | Layer (type) Output Shape Param # 39 | ================================================================= 40 | conv2d_5 (Conv2D) (None, 28, 28, 2) 20 41 | _________________________________________________________________ 42 | max_pooling2d_5 (MaxPooling2 (None, 14, 14, 2) 0 43 | _________________________________________________________________ 44 | dropout_5 (Dropout) (None, 14, 14, 2) 0 45 | _________________________________________________________________ 46 | conv2d_6 (Conv2D) (None, 14, 14, 2) 38 47 | _________________________________________________________________ 48 | max_pooling2d_6 (MaxPooling2 (None, 7, 7, 2) 0 49 | _________________________________________________________________ 50 | dropout_6 (Dropout) (None, 7, 7, 2) 0 51 | _________________________________________________________________ 52 | flatten_3 (Flatten) (None, 98) 0 53 | _________________________________________________________________ 54 | dense_5 (Dense) (None, 4) 396 55 | _________________________________________________________________ 56 | dense_6 (Dense) (None, 10) 50 57 | ================================================================= 58 | Total params: 504 59 | Trainable params: 504 60 | Non-trainable params: 0 61 | _________________________________________________________________ 62 | 63 | ``` 64 | 65 | ``` 66 | msh />onnx_mnist 1 67 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 68 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 69 | 
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 70 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 71 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 72 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@@@@@ 73 | @@@@@@@@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@@@ 74 | @@@@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@ 75 | @@@@@@@@@@@@@@@@@@@@ @@@@@@@@ @@@@@@@@@@@@@@ 76 | @@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@ @@@@@@@@@@@@@@ 77 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @@@@@@ @@@@@@@@@@@@@@ 78 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @@@@ @@@@@@@@@@@@@@ 79 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@@@ 80 | @@@@@@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@@@@@ 81 | @@@@@@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@@@@@@@ 82 | @@@@@@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@@@@@@@@@ 83 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@@@@@@@ 84 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@@@@@@@ 85 | @@@@@@@@@@@@ @@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@@@@@@@ 86 | @@@@@@@@@@ @@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@@@@@@@ 87 | @@@@@@@@@@ @@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@@@@@@@ 88 | @@@@@@@@@@ @@@@@@@@@@@@ @@@@@@@@@@@@@@@@@@@@@@ 89 | @@@@@@@@@@ @@@@@@@@@@@@@@@@@@@@@@@@ 90 | @@@@@@@@@@@@ @@@@@@@@@@@@@@@@@@@@@@@@@@ 91 | @@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 92 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 93 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 94 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 95 | 96 | Predictions: 97 | 0.007383 0.000000 0.057510 0.570970 0.000000 0.105505 0.000000 0.000039 0.257576 0.001016 98 | 99 | The number is 3 100 | 101 | ``` 102 | 103 | ``` 104 | msh />onnx_mnist 0 105 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 106 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 107 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 108 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 109 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 110 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 
111 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 112 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 113 | @@@@@@@@@@ @@@@@@@@@@@@@@@@@@@@ 114 | @@@@@@ @@@@@@@@@@@@@@@@@@ 115 | @@@@ @@@@@@@@@@@@@@ @@@@@@@@@@@@@@@@ 116 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@@@ 117 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@ 118 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@ 119 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@ 120 | @@@@@@@@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@ 121 | @@@@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@ 122 | @@@@@@@@@@@@@@@@ @@@@ @@@@@@@@@@@@@@ 123 | @@@@@@@@@@@@@@ @@@@@@@@@@@@@@ @@@@@@@@@@@@@@ 124 | @@@@@@@@@@@@ @@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@ 125 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@ 126 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@@@ 127 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@@@ 128 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@@@@@ 129 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@@@@@@@ 130 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@@@@@@@ 131 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 132 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 133 | 134 | Predictions: 135 | 0.000498 0.000027 0.017220 0.028220 0.000643 0.002182 0.000000 0.753116 0.026616 136 | 137 | The number is 7 138 | 139 | ``` 140 | 141 | ## 注意事项 142 | 143 | 由于 onnx 的模型是 Google Protobuf v3 的格式,所以这个后端依赖于2个软件包,默认也会选中这两个软件包: 144 | 145 | - protobuf-c 146 | - onnx-parser 147 | 148 | 149 | ## Todo List 150 | 151 | - 模型量化 152 | - 解析更加复杂的模型,生成计算图, 153 | - 针对不同算子进行硬件加速。 154 | 155 | 156 | ## 联系方式 157 | 158 | - 维护:Wu Han 159 | - 主页:http://wuhanstudio.cc 160 | - 邮箱:wuhanstudio@qq.com 161 | 162 | -------------------------------------------------------------------------------- /SConscript: -------------------------------------------------------------------------------- 1 | from building import * 2 | import rtconfig 3 | 4 | # get current directory 5 | cwd = GetCurrentDir() 6 | # The set of 
source files associated with this SConscript file. 7 | src = Glob('src/*.c') 8 | 9 | if GetDepend('ONNX_BACKEND_USING_MNIST_EXAMPLE'): 10 | src += Glob('examples/mnist.c') 11 | 12 | if GetDepend('ONNX_BACKEND_USING_MNIST_SMALL_EXAMPLE'): 13 | src += Glob('examples/mnist_sm.c') 14 | 15 | if GetDepend('ONNX_BACKEND_USING_MNIST_MODEL_EXAMPLE'): 16 | src += Glob('examples/mnist_model.c') 17 | 18 | path = [cwd + '/src'] 19 | path += [cwd + '/examples'] 20 | 21 | LOCAL_CCFLAGS = '' 22 | 23 | group = DefineGroup('onnx-backend', src, depend = ['PKG_USING_ONNX_BACKEND'], CPPPATH = path, LOCAL_CCFLAGS = LOCAL_CCFLAGS) 24 | 25 | Return('group') 26 | -------------------------------------------------------------------------------- /examples/mnist.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "mnist.h" 9 | #include "onnx.h" 10 | /* mnist: msh command that runs the hand-built MNIST CNN (weights from mnist.h) on img[argv[1]] (image 0 unless exactly one argument is given), then prints the class scores and the predicted digit. Returns 0 on success, -1 when the image index is outside img[]. */ 11 | int mnist(int argc, char const *argv[]) 12 | { 13 | int img_index = 0; 14 | if(argc == 2) 15 | { 16 | img_index = atoi(argv[1]); 17 | } if(img_index < 0 || img_index >= TOTAL_IMAGE) { /* img[] holds only TOTAL_IMAGE images; an unchecked atoi() value would index past it */ printf("Invalid image index %d, expected 0 to %d\n", img_index, TOTAL_IMAGE - 1); return -1; } 18 | print_img(img[img_index]); 19 | 20 | // 1. Conv2D 21 | int64_t shapeW3[] = {2, 1, 3, 3}; 22 | int64_t dimW3 = 4; 23 | int64_t permW3_t[] = { 0, 2, 3, 1}; 24 | float* W3_t = transpose(W3, shapeW3, dimW3, permW3_t); 25 | 26 | float* conv1 = (float*) malloc(sizeof(float)*28*28*2); 27 | memset(conv1, 0, sizeof(float)*28*28*2); /* clear the whole buffer (sizeof(float)*N bytes), not just sizeof(size_t) bytes */ 28 | conv2D(img[img_index], 28, 28, 1, W3_t, 2, 3, 3, 1, 1, 1, 1, B3, conv1, 28, 28); /* pass the transposed weights, matching the second conv2D call which uses W2_t */ 29 | 30 | free(W3_t); 31 | 32 | // 2. Relu 33 | float* relu1 = (float*) malloc(sizeof(float)*28*28*2); 34 | relu(conv1, 28*28*2, relu1); 35 | 36 | free(conv1); 37 | 38 | // 3. Maxpool 39 | float* maxpool1 = (float*) malloc(sizeof(float)*14*14*2); 40 | memset(maxpool1, 0, sizeof(float)*14*14*2); 41 | maxpool(relu1, 28, 28, 2, 2, 2, 0, 0, 2, 2, 14, 14, maxpool1); 42 | 43 | free(relu1); 44 | 45 | // 4.
Conv2D 46 | int64_t shapeW2[] = {2, 2, 3, 3}; 47 | int64_t dimW2 = 4; 48 | int64_t perm_t[] = { 0, 2, 3, 1}; 49 | float* W2_t = transpose(W2, shapeW2, dimW2, perm_t); 50 | 51 | float* conv2 = (float*) malloc(sizeof(float)*14*14*2); 52 | memset(conv2, 0, sizeof(float)*14*14*2); 53 | conv2D(maxpool1, 14, 14, 2, W2_t, 2, 3, 3, 1, 1, 1, 1, B2, conv2, 14, 14); 54 | 55 | free(W2_t); 56 | free(maxpool1); 57 | 58 | // 5. Relu 59 | float* relu2 = (float*) malloc(sizeof(float)*14*14*2); 60 | relu(conv2, 14*14*2, relu2); 61 | 62 | free(conv2); 63 | 64 | // 6. Maxpool 65 | float* maxpool2 = (float*) malloc(sizeof(float)*7*7*2); 66 | memset(maxpool2, 0, sizeof(float)*7*7*2); 67 | maxpool(relu2, 14, 14, 2, 2, 2, 0, 0, 2, 2, 7, 7, maxpool2); 68 | 69 | free(relu2); 70 | 71 | // Flatten NOT REQUIRED 72 | 73 | // 7. Dense 74 | int64_t shapeW1[] = {98, 4}; 75 | int64_t dimW1 = 2; 76 | int64_t permW1_t[] = { 1, 0}; 77 | float* W1_t = transpose(W1, shapeW1, dimW1, permW1_t); 78 | 79 | float* dense1 = (float*) malloc(sizeof(float)*4); 80 | memset(dense1, 0, sizeof(float)*4); 81 | dense(maxpool2, W1_t, 98, 4, B1, dense1); 82 | 83 | free(W1_t); 84 | free(maxpool2); 85 | 86 | // 8. Dense 87 | int64_t shapeW[] = {4, 10}; 88 | int64_t dimW = 2; 89 | int64_t permW_t[] = { 1, 0}; 90 | float* W_t = transpose(W, shapeW, dimW, permW_t); 91 | 92 | float* dense2 = (float*) malloc(sizeof(float)*10); 93 | memset(dense2, 0, sizeof(float)*10); 94 | dense(dense1, W_t, 4, 10, B, dense2); 95 | 96 | free(W_t); 97 | free(dense1); 98 | 99 | // 9. Softmax 100 | float* output = (float*) malloc(sizeof(float)*10); 101 | memset(output, 0, sizeof(float)*10); 102 | softmax(dense2, 10, output); 103 | 104 | // 10.
Result 105 | float max = 0; 106 | int max_index = 0; 107 | printf("\nPredictions: \n"); 108 | for(int i = 0; i < 10; i++) 109 | { 110 | printf("%f ", output[i]); 111 | if(output[i] > max) 112 | { 113 | max = output[i]; 114 | max_index = i; 115 | } 116 | } 117 | printf("\n"); 118 | printf("\nThe number is %d\n", max_index); 119 | 120 | free(dense2); 121 | free(output); 122 | 123 | return 0; 124 | } 125 | MSH_CMD_EXPORT(mnist, mnist simple example) 126 | -------------------------------------------------------------------------------- /examples/mnist.h: -------------------------------------------------------------------------------- 1 | #ifndef __MNIST_H__ 2 | #define __MNIST_H__ 3 | 4 | #include 5 | #include 6 | 7 | #define IMG0 {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3803922, 0.37647063, 0.3019608, 0.46274513, 0.2392157, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3529412, 0.5411765, 0.9215687, 0.9215687, 0.9215687, 0.9215687, 0.9215687, 0.9215687, 
0.9843138, 0.9843138, 0.9725491, 0.9960785, 0.9607844, 0.9215687, 0.74509805, 0.08235294, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.54901963, 0.9843138, 0.9960785, 0.9960785, 0.9960785, 0.9960785, 0.9960785, 0.9960785, 0.9960785, 0.9960785, 0.9960785, 0.9960785, 0.9960785, 0.9960785, 0.9960785, 0.9960785, 0.7411765, 0.09019608, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8862746, 0.9960785, 0.81568635, 0.7803922, 0.7803922, 0.7803922, 0.7803922, 0.54509807, 0.2392157, 0.2392157, 0.2392157, 0.2392157, 0.2392157, 0.5019608, 0.8705883, 0.9960785, 0.9960785, 0.7411765, 0.08235294, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.14901961, 0.32156864, 0.050980397, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.13333334, 0.8352942, 0.9960785, 0.9960785, 0.45098042, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.32941177, 0.9960785, 0.9960785, 0.9176471, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.32941177, 0.9960785, 0.9960785, 0.9176471, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4156863, 0.6156863, 0.9960785, 0.9960785, 0.95294124, 0.20000002, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.098039225, 0.45882356, 0.8941177, 0.8941177, 0.8941177, 0.9921569, 0.9960785, 0.9960785, 0.9960785, 0.9960785, 0.94117653, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.26666668, 0.4666667, 0.86274517, 0.9960785, 0.9960785, 0.9960785, 0.9960785, 0.9960785, 0.9960785, 0.9960785, 0.9960785, 0.9960785, 0.5568628, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.14509805, 0.73333335, 0.9921569, 0.9960785, 0.9960785, 0.9960785, 0.8745099, 0.8078432, 0.8078432, 0.29411766, 0.26666668, 0.8431373, 0.9960785, 0.9960785, 0.45882356, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4431373, 0.8588236, 0.9960785, 0.9490197, 0.89019614, 0.45098042, 0.34901962, 0.121568635, 0.0, 0.0, 0.0, 0.0, 0.7843138, 0.9960785, 0.9450981, 0.16078432, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6627451, 0.9960785, 0.6901961, 0.24313727, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.18823531, 0.9058824, 0.9960785, 0.9176471, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.07058824, 0.48627454, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.32941177, 0.9960785, 0.9960785, 0.6509804, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.54509807, 0.9960785, 0.9333334, 0.22352943, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8235295, 0.9803922, 0.9960785, 0.65882355, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9490197, 0.9960785, 0.93725497, 0.22352943, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.34901962, 0.9843138, 0.9450981, 0.3372549, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.019607844, 0.8078432, 0.96470594, 0.6156863, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.015686275, 0.45882356, 0.27058825, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0} 8 | #define IMG0_LABEL 7 9 | 10 | #define IMG1 {} 11 | #define IMG1_LABEL 3 12 | 13 | #define TOTAL_IMAGE 2 14 | 15 | static const float img[][784] = {IMG0, IMG1}; 16 | static const signed char label[] = {IMG0_LABEL, IMG1_LABEL}; 17 | 18 | static const float W3[] = 
{-0.3233681, -0.4261553, -0.6519891, 0.79061985, -0.2210753, 0.037107922, 0.3984157, 0.22128074, 0.7975414, 0.2549885, 0.3076058, 0.62500215, -0.58958095, 0.20375429, -0.06477713, -1.566038, -0.37670124, -0.6443057}; 19 | static const float B3[] = {-0.829373, -0.14096421}; 20 | 21 | static const float W2[] = {0.0070440695, 0.23192555, 0.036849476, -0.14687373, -0.15593372, 0.0044246824, 0.27322513, -0.027562773, 0.23404223, -0.6354651, -0.55645454, -0.77057034, 0.15603222, 0.71015775, 0.23954256, 1.8201442, -0.018377468, 1.5745461, 1.7230825, -0.59662616, 1.3997843, 0.33511618, 0.56846994, 0.3797911, 0.035079807, -0.18287429, -0.032232445, 0.006910181, -0.0026898328, -0.0057844054, 0.29354542, 0.13796881, 0.3558416, 0.0022847173, 0.0025906325, -0.022641085}; 22 | static const float B2[] = {-0.11655525, -0.0036503011}; 23 | 24 | static const float W1[] = {0.15791991, -0.22649878, 0.021204736, 0.025593571, 0.008755621, -0.775102, -0.41594088, -0.12580238, -0.3963741, 0.33545518, -0.631953, -0.028754484, -0.50668705, -0.3574023, -3.7807872, -0.8261617, 0.102246165, 0.571127, -0.6256297, 0.06698781, 0.55969477, 0.25374785, -3.075965, -0.6959133, 0.2531965, 0.31739804, -0.8664238, 0.12750633, 0.83136076, 0.2666574, -2.5865922, -0.572031, 0.29743987, 0.16238026, -0.99154145, 0.077973805, 0.8913329, 0.16854058, -2.5247803, -0.5639109, 0.41671264, -0.10801031, -1.0229865, 0.2062031, 0.39889312, -0.16026731, -1.9185526, -0.48375717, 0.057339806, -1.2573057, -0.23117211, 1.051854, -0.7981992, -1.6263007, -0.26003376, -0.07649365, -0.4646075, 0.755821, 0.13187818, 0.24743222, -1.5276812, 0.1636555, -0.075465426, -0.058517877, -0.33852127, 1.3052516, 0.14443535, 0.44080895, -0.31031442, 0.15416017, 0.0053661224, -0.03175326, -0.15991405, 0.66121936, 0.0832211, 0.2651985, -0.038445678, 0.18054117, -0.0073251156, 0.054193687, -0.014296916, 0.30657783, 0.006181963, 0.22319937, 0.030315898, 0.12695274, -0.028179673, 0.11189027, 0.035358384, 0.046855893, -0.026528472, 0.26450494, 
0.069981076, 0.107152134, -0.030371506, 0.09524366, 0.24802336, -0.36496836, -0.102762334, 0.49609017, 0.04002767, 0.020934932, -0.054773595, 0.05412083, -0.071876526, -1.5381132, -0.2356421, 1.5890793, -0.023087852, -0.24933836, 0.018771818, 0.08040064, 0.051946845, 0.6141782, 0.15780787, 0.12887044, -0.8691056, 1.3761537, 0.43058, 0.13476849, -0.14973496, 0.4542634, 0.13077497, 0.23117822, 0.003657386, 0.42742714, 0.23396699, 0.09209521, -0.060258932, 0.4642852, 0.10395402, 0.25047097, -0.05326261, 0.21466804, 0.11694269, 0.22402634, 0.12639907, 0.23495848, 0.12770525, 0.3324459, 0.0140223345, 0.106348366, 0.10877733, 0.30522102, 0.31412345, -0.07164018, 0.13483422, 0.45414954, 0.054698735, 0.07451815, 0.097312905, 0.27480683, 0.4866108, -0.43636885, -0.13586079, 0.5724732, 0.13595985, -0.0074526076, 0.11859829, 0.24481037, -0.37537888, -0.46877658, -0.5648533, 0.86578417, 0.3407381, -0.17214134, 0.040683553, 0.3630519, 0.089548275, -0.4989473, 0.47688767, 0.021731026, 0.2856471, 0.6174715, 0.7059148, -0.30635756, -0.5705427, -0.20692639, 0.041900065, 0.23040071, -0.1790487, -0.023751246, 0.14114629, 0.02345284, -0.64177734, -0.069909826, -0.08587972, 0.16460821, -0.53466517, -0.10163383, -0.13119817, 0.14908728, -0.63503706, -0.098961875, -0.23248474, 0.15406314, -0.48586813, -0.1904713, -0.20466608, 0.10629631, -0.5291871, -0.17358926, -0.36273107, 0.12225631, -0.38659447, -0.24787207, -0.25225234, 0.102635615, -0.14507034, -0.10110793, 0.043757595, -0.17158166, -0.031343404, -0.30139172, -0.09401665, 0.06986169, -0.54915506, 0.66843456, 0.14574362, -0.737502, 0.7700305, -0.4125441, 0.10115133, 0.05281194, 0.25467375, 0.22757779, -0.030224197, -0.0832025, -0.66385627, 0.51225215, -0.121023245, -0.3340579, -0.07505331, -0.09820366, -0.016041134, -0.03187605, -0.43589246, 0.094394326, -0.04983066, -0.0777906, -0.12822862, -0.089667186, -0.07014707, -0.010794195, -0.29095307, -0.01319235, -0.039757702, -0.023403417, -0.15530063, -0.052093383, -0.1477549, 
-0.07557954, -0.2686017, -0.035220042, -0.095615104, -0.015471024, -0.03906604, 0.024237331, -0.19604297, -0.19998372, -0.20302829, -0.04267139, -0.18774728, -0.045169186, -0.010131819, 0.14829905, -0.117015064, -0.4180649, -0.20680964, -0.024034742, -0.15787442, -0.055698488, -0.09037726, 0.40253848, -0.35745984, -0.786149, -0.0799551, 0.16205557, -0.14461482, -0.2749642, 0.2683253, 0.6881363, -0.064145364, 0.11361358, 0.59981894, 1.2947721, -1.2500908, 0.6082035, 0.12344158, 0.15808935, -0.17505693, 0.03425684, 0.39107767, 0.23190938, -0.7568858, 0.20042256, 0.079169095, 0.014275463, -0.12135842, 0.008516737, 0.26897284, 0.05706199, -0.52615446, 0.12489152, 0.08065737, -0.038548164, -0.08894516, 7.250979E-4, 0.28635752, -0.010820533, -0.39301336, 0.11144395, 0.06563818, -0.033744805, -0.07450528, -0.027328406, 0.3002447, 0.0029921278, -0.47954947, -0.04527057, -0.010289918, 0.039380465, -0.09236952, -0.1924659, 0.15401903, 0.21237805, -0.38984418, -0.37384143, -0.20648403, 0.29201767, -0.1299253, -0.36048025, -0.5544466, 0.45723814, -0.35266167, -0.94797707, -1.2481197, 0.88701195, 0.33620682, 0.0035414647, -0.22769359, 1.4563162, 0.54950374, 0.38396382, -0.41196275, 0.3758704, 0.17687413, 0.038129736, 0.16358295, 0.70515764, 0.055063568, 0.6445265, -0.2072113, 0.14618243, 0.10311305, 0.1971523, 0.174206, 0.36578146, -0.09782787, 0.5229244, -0.18459272, -0.0013945608, 0.08863555, 0.24184574, 0.15541393, 0.1722381, -0.10531331, 0.38215113, -0.30659106, -0.16298945, 0.11549875, 0.30750987, 0.1586183, -0.017728966, -0.050216004, 0.26232007, -1.2994286, -0.22700997, 0.108534105, 0.7447398, -0.39803517, 0.016863048, 0.10067235, -0.16355589, -0.64953077, -0.5674107, 0.017935256, 0.98968256, -1.395801, 0.44127485, 0.16644385, -0.19195901}; 25 | static const float B1[] = {1.2019119, -1.1770505, 2.1698284, -1.9615222}; 26 | 27 | static const float W[] = {0.55808353, 0.78707385, -0.040990848, -0.122510895, -0.41261443, -0.036044, 0.1691557, -0.14711425, -0.016407091, 
-0.28058195, 0.018765535, 0.062936015, 0.49562064, 0.33931744, -0.47547337, -0.1405672, -0.88271654, 0.18359914, 0.020887045, -0.13782434, -0.052250575, 0.67922074, -0.28022966, -0.31278887, 0.44416663, -0.26106882, -0.32219923, 1.0321393, -0.1444394, 0.5221766, 0.057590708, -0.96547794, -0.3051688, 0.16859075, -0.5320585, 0.42684716, -0.5434046, 0.014693736, 0.26795483, 0.15921915}; 28 | static const float B[] = {0.041442648, 1.461427, 0.07154641, -1.2774754, 0.80927604, -1.6933714, -0.29740578, -0.11774022, 0.3292682, 0.6596958}; 29 | 30 | // ASCII lib from (https://www.jianshu.com/p/1f58a0ebf5d9) 31 | static const char codeLib[] = "@B%8&WM#*oahkbdpqwmZO0QLCJUYXzcvunxrjft/\\|()1{}[]?-_+~<>i!lI;:,\"^`'. "; 32 | /* print_img: print a 28x28 image (row-major, 784 floats, indexed y*28+x) as ASCII art, two characters per pixel. Pixels above 0.6 use the last character of codeLib; all other pixels use codeLib[0] ('@'). */ static void print_img(const float * buf) 33 | { 34 | for(int y = 0; y < 28; y++) 35 | { 36 | for (int x = 0; x < 28; x++) 37 | { 38 | int index = 0; 39 | if(buf[y*28+x] > 0.6f) index = (int)(sizeof(codeLib) - 2); /* last character before the terminating NUL; a fixed 69 would select the NUL of the 69-character string above */ 40 | 41 | printf("%c",codeLib[index]); 42 | printf("%c",codeLib[index]); 43 | } 44 | printf("\n"); 45 | } 46 | } 47 | 48 | #endif //__MNIST_H__ 49 | -------------------------------------------------------------------------------- /examples/mnist_model.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "mnist.h" 7 | #include "onnx.h" 8 | 9 | #define MNIST_TEST_IMAGE 1 10 | #define ONNX_MODEL_NAME "mnist-sm.onnx" 11 | 12 | #define THREAD_PRIORITY 8 13 | #define THREAD_STACK_SIZE 5120 14 | #define THREAD_TIMESLICE 5 15 | 16 | static rt_thread_t tid1 = RT_NULL; 17 | 18 | static void mnist_model_entry(void* parameter) 19 | { 20 | // 0. Load Model 21 | Onnx__ModelProto* model = onnx_load_model(ONNX_MODEL_NAME); 22 | if(model == NULL) 23 | { 24 | printf("Failed to load model %s\n", ONNX_MODEL_NAME); 25 | return; 26 | } 27 | 28 | // 1.
Initialize input 29 | int64_t* shapeInput = (int64_t*) malloc(sizeof(int64_t)*3); 30 | shapeInput[0] = 28; shapeInput[1] = 28; shapeInput[2] = 1; 31 | 32 | float* input = (float*) malloc(sizeof(int64_t)*28*28); 33 | memcpy(input, img[MNIST_TEST_IMAGE], sizeof(float)*28*28); 34 | 35 | print_img(input); 36 | printf("\n"); 37 | 38 | // 2. Run Model 39 | float* output = onnx_model_run(model, input, shapeInput); 40 | 41 | // 3. Print Result 42 | float max = 0; 43 | int max_index = 0; 44 | printf("\nPredictions: \n"); 45 | for(int i = 0; i < 10; i++) 46 | { 47 | printf("%f ", output[i]); 48 | if(output[i] > max) 49 | { 50 | max = output[i]; 51 | max_index = i; 52 | } 53 | } 54 | printf("\n"); 55 | printf("\nThe number is %d\n", max_index); 56 | 57 | // 4. Free model 58 | free(shapeInput); 59 | free(output); 60 | onnx__model_proto__free_unpacked(model, NULL); 61 | } 62 | 63 | static void mnist_model(int argc, char *argv[]) 64 | { 65 | tid1 = rt_thread_create("tonnx_model", 66 | mnist_model_entry, RT_NULL, 67 | THREAD_STACK_SIZE, 68 | THREAD_PRIORITY, THREAD_TIMESLICE); 69 | 70 | if (tid1 != RT_NULL) 71 | { 72 | 73 | rt_thread_startup(tid1); 74 | } 75 | else 76 | { 77 | rt_kprintf("Failed to start onnx thread\n"); 78 | } 79 | } 80 | MSH_CMD_EXPORT(mnist_model, load mnist onnx model from file); 81 | -------------------------------------------------------------------------------- /examples/mnist_sm.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "mnist.h" 7 | #include "onnx.h" 8 | 9 | #define MNIST_TEST_IMAGE 1 10 | #define ONNX_MODEL_NAME "/mnist-sm.onnx" 11 | 12 | #define THREAD_PRIORITY 8 13 | #define THREAD_STACK_SIZE 5120 14 | #define THREAD_TIMESLICE 5 15 | 16 | static rt_thread_t tid1 = RT_NULL; 17 | 18 | static void mnist_sm_entry(void* parameter) 19 | { 20 | // Load Model 21 | Onnx__ModelProto* model = onnx_load_model(ONNX_MODEL_NAME); 22 | if(model == NULL) 23 | 
{ 24 | printf("Failed to load model %s\n", ONNX_MODEL_NAME); 25 | return; 26 | } 27 | 28 | // Set input image: NWHC 29 | print_img(img[MNIST_TEST_IMAGE]); 30 | 31 | // 0. Initialize input shape 32 | int64_t* shapeInput = (int64_t*) malloc(sizeof(int64_t)*3); 33 | int64_t* shapeOutput = (int64_t*) malloc(sizeof(int64_t)*3); if(shapeInput == NULL || shapeOutput == NULL) { /* both shape buffers are written below; bail out before dereferencing a failed allocation */ printf("Failed to allocate shape buffers\n"); free(shapeInput); free(shapeOutput); onnx__model_proto__free_unpacked(model, NULL); return; } 34 | shapeInput[0] = 28; 35 | shapeInput[1] = 28; 36 | shapeInput[2] = 1; 37 | 38 | // 1. Transpose 39 | // float* input = transpose_layer(model->graph, img[img_index], shapeInput, shapeOutput, "Transpose6"); 40 | // memcpy(shapeInput, shapeOutput, sizeof(int64_t)*3); 41 | 42 | // 2. Conv2D 43 | float* conv1 = conv2D_layer(model->graph, img[MNIST_TEST_IMAGE], shapeInput, shapeOutput, "conv2d_5"); 44 | memcpy(shapeInput, shapeOutput, sizeof(int64_t)*3); 45 | // free(input); 46 | 47 | // 3. Relu 48 | float* relu1 = relu_layer(model->graph, conv1, shapeInput, shapeOutput, "Relu1"); 49 | memcpy(shapeInput, shapeOutput, sizeof(int64_t)*3); 50 | free(conv1); 51 | 52 | // 4. Maxpool 53 | float* maxpool1 = maxpool_layer(model->graph, relu1, shapeInput, shapeOutput, "max_pooling2d_5"); 54 | memcpy(shapeInput, shapeOutput, sizeof(int64_t)*3); 55 | free(relu1); 56 | 57 | // 5. Conv2D 58 | float* conv2 = conv2D_layer(model->graph, maxpool1, shapeInput, shapeOutput, "conv2d_6"); 59 | memcpy(shapeInput, shapeOutput, sizeof(int64_t)*3); 60 | free(maxpool1); 61 | 62 | // 6. Relu 63 | float* relu2 = relu_layer(model->graph, conv2, shapeInput, shapeOutput, "Relu"); 64 | memcpy(shapeInput, shapeOutput, sizeof(int64_t)*3); 65 | free(conv2); 66 | 67 | // 7. Maxpool 68 | float* maxpool2 = maxpool_layer(model->graph, relu2, shapeInput, shapeOutput, "max_pooling2d_6"); 69 | memcpy(shapeInput, shapeOutput, sizeof(int64_t)*3); 70 | free(relu2); 71 | 72 | // 8.
Transpose 73 | // float* maxpool2_t = transpose_layer(model->graph, maxpool2, shapeInput, shapeOutput, "Transpose1"); 74 | // memcpy(shapeInput, shapeOutput, sizeof(int64_t)*3); 75 | // free(maxpool2); 76 | 77 | // 9. Flatten 78 | shapeInput[1] = shapeInput[0] * shapeInput[1] * shapeInput[2]; 79 | shapeInput[2] = 1; 80 | shapeInput[0] = 1; 81 | 82 | // 10. Dense 83 | float* matmul1 = matmul_layer(model->graph, maxpool2, shapeInput, shapeOutput, "dense_5"); 84 | memcpy(shapeInput, shapeOutput, sizeof(int64_t)*3); 85 | free(maxpool2); 86 | 87 | // 11. Add 88 | float* dense1 = add_layer(model->graph, matmul1, shapeInput, shapeOutput, "Add1"); 89 | memcpy(shapeInput, shapeOutput, sizeof(int64_t)*3); 90 | free(matmul1); 91 | 92 | // 12. Dense 93 | float* matmul2 = matmul_layer(model->graph, dense1, shapeInput, shapeOutput, "dense_6"); 94 | memcpy(shapeInput, shapeOutput, sizeof(int64_t)*3); 95 | free(dense1); 96 | 97 | // 13. Add 98 | float* dense2 = add_layer(model->graph, matmul2, shapeInput, shapeOutput, "Add"); 99 | memcpy(shapeInput, shapeOutput, sizeof(int64_t)*3); 100 | free(matmul2); 101 | 102 | // 14. Softmax 103 | float* output = softmax_layer(model->graph, dense2, shapeInput, shapeOutput, "Softmax"); 104 | free(dense2); 105 | 106 | // 15.
Identity 107 | // Do Nothing Here 108 | if(output == NULL) { /* NOTE(review): assumes the *_layer helpers report failure with NULL -- confirm in src/ */ printf("Failed to run model %s\n", ONNX_MODEL_NAME); free(shapeInput); free(shapeOutput); onnx__model_proto__free_unpacked(model, NULL); return; } 109 | // Result 110 | float max = 0; 111 | int max_index = 0; 112 | printf("\nPredictions: \n"); 113 | for(int i = 0; i < 10; i++) 114 | { 115 | printf("%f ", output[i]); 116 | if(output[i] > max) 117 | { 118 | max = output[i]; 119 | max_index = i; 120 | } 121 | } 122 | printf("\n"); 123 | printf("\nThe number is %d\n", max_index); 124 | 125 | // Free model 126 | free(shapeInput); 127 | free(shapeOutput); 128 | free(output); 129 | onnx__model_proto__free_unpacked(model, NULL); 130 | 131 | return; 132 | } 133 | /* mnist_sm: msh command that starts mnist_sm_entry in its own RT-Thread thread */ 134 | static void mnist_sm(int argc, char const *argv[]) 135 | { 136 | 137 | tid1 = rt_thread_create("tmnist_sm", 138 | mnist_sm_entry, RT_NULL, 139 | THREAD_STACK_SIZE, 140 | THREAD_PRIORITY, THREAD_TIMESLICE); 141 | 142 | if (tid1 != RT_NULL) 143 | { 144 | 145 | rt_thread_startup(tid1); 146 | } 147 | else 148 | { 149 | rt_kprintf("Failed to start mnist-sm thread\n"); 150 | } 151 | 152 | } 153 | MSH_CMD_EXPORT(mnist_sm, mnist small model) 154 | -------------------------------------------------------------------------------- /examples/model/mnist-keras.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 导入库" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "#coding:utf-8\n", 17 | "from tensorflow.examples.tutorials.mnist import input_data\n", 18 | "\n", 19 | "import numpy as np\n", 20 | "np.set_printoptions(suppress=True)\n", 21 | "\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "%matplotlib inline" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "# 导入数据集" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 3, 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | "name": "stderr", 40 | "output_type": "stream", 41 | "text": [ 42 |
"W0829 10:37:44.431263 12720 deprecation.py:323] From d:\\anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", 43 | "Instructions for updating:\n", 44 | "Please use tf.data to implement this functionality.\n" 45 | ] 46 | }, 47 | { 48 | "name": "stdout", 49 | "output_type": "stream", 50 | "text": [ 51 | "Extracting MNIST_data/train-images-idx3-ubyte.gz\n" 52 | ] 53 | }, 54 | { 55 | "name": "stderr", 56 | "output_type": "stream", 57 | "text": [ 58 | "W0829 10:37:44.709223 12720 deprecation.py:323] From d:\\anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", 59 | "Instructions for updating:\n", 60 | "Please use tf.data to implement this functionality.\n", 61 | "W0829 10:37:44.722234 12720 deprecation.py:323] From d:\\anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", 62 | "Instructions for updating:\n", 63 | "Please use tf.one_hot on tensors.\n", 64 | "W0829 10:37:44.785238 12720 deprecation.py:323] From d:\\anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", 65 | "Instructions for updating:\n", 66 | "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n" 67 | ] 68 | }, 69 | { 70 | "name": "stdout", 71 | "output_type": "stream", 72 | "text": [ 73 | 
"Extracting MNIST_data/train-labels-idx1-ubyte.gz\n", 74 | "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n", 75 | "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n" 76 | ] 77 | } 78 | ], 79 | "source": [ 80 | "mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True) #MNIST数据输入" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 4, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "X_train = mnist.train.images\n", 90 | "y_train = mnist.train.labels\n", 91 | "X_test = mnist.test.images\n", 92 | "y_test = mnist.test.labels\n", 93 | "\n", 94 | "# 输入图像大小是 28x28 大小\n", 95 | "X_train = X_train.reshape([-1, 28, 28, 1])\n", 96 | "X_test = X_test.reshape([-1, 28, 28, 1])" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 8, 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "data": { 106 | "text/plain": [ 107 | "" 108 | ] 109 | }, 110 | "execution_count": 8, 111 | "metadata": {}, 112 | "output_type": "execute_result" 113 | }, 114 | { 115 | "data": { 116 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAANgElEQVR4nO3dXaxV9ZnH8d9vEKKxjS+jMowwUvC1zgVVJBonE8dK43iDTaz2JFaqzZxqcAKmJmMck3rhRTMZiiYmNTSS0kmlqWlVNM0MLyEhhFgFwxyw2Oo0WCgERBQO0dgRn7k4y8kRz1r7sNfaL+c8309ysvdez15rPdnhx1p7//def0eEAEx+f9HrBgB0B2EHkiDsQBKEHUiCsANJnNbNndnmo3+gwyLCYy2vdWS3fbPt39l+y/ZDdbYFoLPc7ji77SmSfi9poaR9kl6VNBARv61YhyM70GGdOLIvkPRWRPwhIv4s6eeSFtXYHoAOqhP2CyXtHfV4X7HsM2wP2t5me1uNfQGoqc4HdGOdKnzuND0iVkpaKXEaD/RSnSP7PkmzRj2eKWl/vXYAdEqdsL8q6RLbX7I9TdI3Ja1tpi0ATWv7ND4iPrZ9v6T/kjRF0qqIeL2xzgA0qu2ht7Z2xnt2oOM68qUaABMHYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiCJtudnlyTbeyQNSzoh6eOImN9EUwCaVyvshX+IiMMNbAdAB3EaDyRRN+whaZ3t7bYHx3qC7UHb22xvq7kvADU4Itpf2f7riNhv+wJJ6yX9c0Rsrnh++zsDMC4R4bGW1zqyR8T+4vaQpOckLaizPQCd03bYbZ9p+4uf3pf0NUm7mmoMQLPqfBo/XdJztj/dzjMR8Z+NdAWgcbXes5/yznjPDnRcR96zA5g4CDuQBGEHkiDsQBKEHUiiiR/CoMfuvvvu0lqr0ZZ33323sn7FFVdU1rdu3VpZ37JlS2Ud3cORHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSmDTj7AMDA5X1q666qrJeNVbd784+++y21z1x4kRlfdq0aZX1Dz/8sLL+wQcflNZ27txZue7tt99eWX/nnXcq6/gsjuxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kMSEurrs8uXLS2tLly6tXHfKlCl1do0e2LRpU2W91XcrDh482GQ7EwZXlwWSI+xAEoQdSIKwA0kQdiAJwg4kQdiBJCbUOPvevXtLazNnzqxcd2hoqLLe6nfZndTq2urPP/98lzo5dQsXLqys33XXXaW12bNn19p3q3H4O+64o7Q2mX8L3/Y4u+1Vtg/Z3jVq2bm219t+s7g9p8lmATRvPKfxP5F080nLHpK0MSIukbSxeAygj7UMe0RslnTkpMWLJK0u7q+WdGvDfQFoWLvXoJseEQckKSIO2L6g7Im2ByUNtrkfAA3p+AUnI2KlpJVS/Q/oALSv3aG3g7ZnSFJxe6i5lgB0QrthXytpcXF/saQXmmkHQKe0HGe3vUbSDZLOk3RQ0vclPS/pF5L+RtIfJX0jIk7+EG+sbdU6jb/00ktLa1deeWXluhs2bKisDw8Pt9UTqs2ZM6e09tJLL1Wu22pu+FYefPDB0lrVtREmurJx9pbv2SOi7AoBX63VEYCu4uuyQBKEHUiCsANJEHYgCcIOJDGhfuKKyeW2226rrD/77LO1tn/48OHS2vnnn19r2/2MS0kDyRF2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEh2fEQa53XfffaW1a665pqP7Pv3000trV199deW627dvb7qdnuPIDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcN3
4SWDGjBmltTvvvLNy3WXLljXdzmdU9WaPeXnzrjh27Fhl/ayzzupSJ81r+7rxtlfZPmR716hlj9r+k+0dxd8tTTYLoHnjOY3/iaSbx1i+IiLmFX+/brYtAE1rGfaI2CzpSBd6AdBBdT6gu9/2UHGaf07Zk2wP2t5me1uNfQGoqd2w/0jSXEnzJB2QtLzsiRGxMiLmR8T8NvcFoAFthT0iDkbEiYj4RNKPJS1oti0ATWsr7LZHj6d8XdKusucC6A8tf89ue42kGySdZ3ufpO9LusH2PEkhaY+k73awx0nvpptuqqy3+u314OBgaW3OnDlt9TTZrVq1qtctdF3LsEfEwBiLn+5ALwA6iK/LAkkQdiAJwg4kQdiBJAg7kASXkm7AxRdfXFl/6qmnKus33nhjZb2TPwV9++23K+vvvfdere0/8sgjpbWPPvqoct0nn3yysn7ZZZe11ZMk7d+/v+11JyqO7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOPs4/TAAw+U1pYsWVK57ty5cyvrx48fr6y///77lfXHH3+8tNZqPHnr1q2V9Vbj8J109OjRWusPDw+X1l588cVa256IOLIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMs4/TddddV1prNY6+du3ayvry5aUT6kiSNm/eXFmfqObNm1dZv+iii2ptv+r38m+88UatbU9EHNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnG2cfp3nvvLa0NDQ1VrvvYY4813c6k0Op6+9OnT6+1/Q0bNtRaf7JpeWS3Pcv2Jtu7bb9ue2mx/Fzb622/Wdye0/l2AbRrPKfxH0v6XkRcIelaSUtsf1nSQ5I2RsQlkjYWjwH0qZZhj4gDEfFacX9Y0m5JF0paJGl18bTVkm7tVJMA6jul9+y2Z0v6iqTfSJoeEQekkf8QbF9Qss6gpMF6bQKoa9xht/0FSb+UtCwijo13ssGIWClpZbGNaKdJAPWNa+jN9lSNBP1nEfGrYvFB2zOK+gxJhzrTIoAmtDyye+QQ/rSk3RHxw1GltZIWS/pBcftCRzrsE0eOHCmtMbTWnmuvvbbW+q0usf3EE0/U2v5kM57T+OslfUvSTts7imUPayTkv7D9HUl/lPSNzrQIoAktwx4RWySVvUH/arPtAOgUvi4LJEHYgSQIO5AEYQeSIOxAEvzEFR21c+fO0trll19ea9vr1q2rrL/88su1tj/ZcGQHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQYZ0dHzZ49u7R22mnV//yOHj1aWV+xYkU7LaXFkR1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmCcHbUMDAxU1s8444zS2vDwcOW6g4PVs4bxe/VTw5EdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JwRFQ/wZ4l6aeS/krSJ5JWRsQTth+V9E+S3ime+nBE/LrFtqp3hr4zderUyvorr7xSWa+6NvyaNWsq173nnnsq6xhbRIw56/J4vlTzsaTvRcRrtr8oabvt9UVtRUT8e1NNAuic8czPfkDSgeL+sO3dki7sdGMAmnVK79ltz5b0FUm/KRbdb3vI9irb55SsM2h7m+1ttToFUMu4w277C5J+KWlZRByT9CNJcyXN08iRf/lY60XEyoiYHxHzG+gXQJvGFXbbUzUS9J9FxK8kKSIORsSJiPhE0o8lLehcmwDqahl225b0tKTdEfHDUctnjHra1yXtar49AE0Zz6fx10v6lqSdtncUyx6WNGB7nqSQtEfSdzvSIXqq1dDsM888U1nfsWNHaW39+vWlNTRvPJ/Gb5E01rhd5Zg6gP7CN+iAJAg7kARhB5Ig7EAShB1IgrADSbT8iWujO+MnrkDHlf3ElSM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTR7SmbD0t6e9Tj84pl/ahfe+vXviR6a1eTvV1UVujql2o+t3N7W79em65fe+vXviR6a1e3euM0HkiCsANJ9Dr
sK3u8/yr92lu/9iXRW7u60ltP37MD6J5eH9kBdAlhB5LoSdht32z7d7bfsv1QL3ooY3uP7Z22d/R6frpiDr1DtneNWnau7fW23yxux5xjr0e9PWr7T8Vrt8P2LT3qbZbtTbZ3237d9tJieU9fu4q+uvK6df09u+0pkn4vaaGkfZJelTQQEb/taiMlbO+RND8iev4FDNt/L+m4pJ9GxN8Wy/5N0pGI+EHxH+U5EfEvfdLbo5KO93oa72K2ohmjpxmXdKukb6uHr11FX7erC69bL47sCyS9FRF/iIg/S/q5pEU96KPvRcRmSUdOWrxI0uri/mqN/GPpupLe+kJEHIiI14r7w5I+nWa8p69dRV9d0YuwXyhp76jH+9Rf872HpHW2t9se7HUzY5geEQekkX88ki7ocT8nazmNdzedNM1437x27Ux/Xlcvwj7W9bH6afzv+oi4StI/SlpSnK5ifMY1jXe3jDHNeF9od/rzunoR9n2SZo16PFPS/h70MaaI2F/cHpL0nPpvKuqDn86gW9we6nE//6+fpvEea5px9cFr18vpz3sR9lclXWL7S7anSfqmpLU96ONzbJ9ZfHAi22dK+pr6byrqtZIWF/cXS3qhh718Rr9M4102zbh6/Nr1fPrziOj6n6RbNPKJ/P9I+tde9FDS1xxJ/138vd7r3iSt0chp3f9q5IzoO5L+UtJGSW8Wt+f2UW//IWmnpCGNBGtGj3r7O428NRyStKP4u6XXr11FX1153fi6LJAE36ADkiDsQBKEHUiCsANJEHYgCcIOJEHYgST+D0dqK8VlJwIwAAAAAElFTkSuQmCC\n", 117 | "text/plain": [ 118 | "
" 119 | ] 120 | }, 121 | "metadata": { 122 | "needs_background": "light" 123 | }, 124 | "output_type": "display_data" 125 | } 126 | ], 127 | "source": [ 128 | "plt.imshow(X_train[0].reshape((28, 28)), cmap='gray')" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 9, 134 | "metadata": {}, 135 | "outputs": [ 136 | { 137 | "data": { 138 | "text/plain": [ 139 | "" 140 | ] 141 | }, 142 | "execution_count": 9, 143 | "metadata": {}, 144 | "output_type": "execute_result" 145 | }, 146 | { 147 | "data": { 148 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAOHUlEQVR4nO3dS4xc5ZnG8ecBko2ThXEbsIixkwiZGUYaggwaCWMFRTGXje1FRjEXMRqgIxSkALMYcxFBMrbQaMiAN4aOQHFGgSjyRVgRKEFWZDMb5BsDhjaBQYztYPkCixCxyIDfWfRx1DF9vtOu26n2+/9Jrao6b52q1+V++pyqr875HBECcPY7p+0GAAwGYQeSIOxAEoQdSIKwA0mcN8gns81H/0CfRYSnWt7Vlt32Dbbfsf2e7dXdPBaA/nKn4+y2z5X0e0nflXRY0i5JqyLi7cI6bNmBPuvHlv1qSe9FxPsR8WdJv5S0vIvHA9BH3YT9YkmHJt0+XC37K7ZHbe+2vbuL5wLQpW4+oJtqV+ELu+kRMSZpTGI3HmhTN1v2w5LmT7r9NUkfdtcOgH7pJuy7JF1q++u2vyzp+5K29aYtAL3W8W58RHxm+x5Jv5F0rqTnIuKtnnUGoKc6Hnrr6Ml4zw70XV++VANg5iDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgiY6nbMbZYcGCBcX6nXfeWaw/9NBDxXpplmB7yslG/2J8fLxYf/jhh4v1rVu3FuvZdBV22x9I+kTS55I+i4jFvWgKQO/1Yst+XUSc6MHjAOgj3rMDSXQb9pD0W9t7bI9OdQfbo7Z3297d5XMB6EK3u/HXRMSHti+Q9IrtAxGxc/IdImJM0pgk2a7/tAZAX3W1ZY+ID6vLY5K2Srq6F00B6L2Ow257lu2vnrouaZmk/b1qDEBvuTQOWlzR/oYmtubSxNuB5yNibcM67Mb3wdy5c2trDzzwQHHdW265pVifM2dOsd40Vt7NOHvT7+ahQ4eK9auuuqq2duLE2TuAFBFTvrAdv2ePiPcl/X3HHQEYKIbegCQIO5AEYQeSIOxAEoQdSKLjobeOnoyht440HUa6Zs2a2lrT/2+/h7+OHz9erJeMjIwU6wsXLizW33777dra5Zdf3klLM0Ld0BtbdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2GWDXrl3F+pVXXllb63acvTRWLUnXXXddsd7NoaRLliwp1nfs2FGsl/7t55139p5FnXF2IDnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCcfYhcNlllxXrTePsH330UW2t6XjypnHw++67r1i/995
7i/V169bV1g4ePFhct0nT7+7Jkydra3fffXdx3bGxsY56GgaMswPJEXYgCcIOJEHYgSQIO5AEYQeSIOxAEoyzzwBN4/ClsfJupyYeHR0t1jds2FCsl6ZN3rt3b3HdlStXFuubNm0q1ku/2xdddFFx3Zk8pXPH4+y2n7N9zPb+ScvOt/2K7Xery9m9bBZA701nN/5nkm44bdlqSdsj4lJJ26vbAIZYY9gjYqekj09bvFzSxur6RkkretwXgB7r9ERcF0bEEUmKiCO2L6i7o+1RSeU3fgD6ru9n3YuIMUljEh/QAW3qdOjtqO15klRdHutdSwD6odOwb5N0e3X9dkkv9qYdAP3SuBtv+wVJ35Y0YvuwpB9LelzSr2zfIemgpO/1s8nsDhw40NpzNx0P/8477xTrpWPtm46VX726PMjTdM77fn7/YCZqDHtErKopfafHvQDoI74uCyRB2IEkCDuQBGEHkiDsQBJn77y1iSxdurS21nR4bNPQ2vj4eLG+aNGiYv21116rrc2dO7e4btPh102933jjjcV6NmzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtnPAjfffHNt7a677iqu23SYaNNYd9P6pbH0bg5RlaT169cX602nqs6GLTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME4+1mu2ym5+7n+q6++Wlz3/vvvL9YZRz8zbNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnG2c8Czz//fG1twYIFxXVHRkaK9abzzs+aNatYL3nkkUeKdcbRe6txy277OdvHbO+ftOxR23+w/Xr1c1N/2wTQrensxv9M0g1TLP+PiLii+nmpt20B6LXGsEfETkkfD6AXAH3UzQd099h+o9rNn113J9ujtnfb3t3FcwHoUqdh3yDpm5KukHRE0hN1d4yIsYhYHBGLO3wuAD3QUdgj4mhEfB4RJyX9VNLVvW0LQK91FHbb8ybdXClpf919AQwHT+O84C9I+rakEUlHJf24un2FpJD0gaQfRMSRxiezuzs4GgPXNM7+2GOPFesrVqyore3bt6+4btP86k3nlc8qIqY8IX/jl2oiYtUUi5/tuiMAA8XXZYEkCDuQBGEHkiDsQBKEHUiiceitp082g4feSlMPHz9+fICdzCwvv/xybe36668vrtt0Kuknn3yyo57OdnVDb2zZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJTiVdWbp0abH+xBO1J+PRgQMHiuvedtttHfV0Nli7dm1tbdmyZcV1Fy1a1Ot2UmPLDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJpBlnLx2PLklPP/10sX7s2LHaWuZx9KYpm5955pnamj3lYdfoE7bsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5BEmnH2lStXFutNx07v2LGjl+3MGE1TNm/evLlYL72uTXMWNJ0nAGemcctue77t39ket/2W7R9Vy8+3/Yrtd6vL2f1vF0CnprMb/5mkf4mIv5H0D5J+aPtvJa2WtD0iLpW0vboNYEg1hj0ijkTE3ur6J5LGJV0sabmkjdXdNkpa0a8mAXTvjN6z214o6VuSXpN0YUQckSb+INi+oGadUUmj3bUJoFvTDrvtr0jaLOneiPjjdA9iiIgxSWPVY8zYiR2BmW5aQ2+2v6SJoP8iIrZUi4/anlfV50mqPywMQOsat+ye2IQ/K2k8In4yqbRN0u2SHq8uX+xLhz2yc+fOYv2cc8p/90qnmr711luL646Pjxfre/bsKdabLFiwoLZ27bXXFtdtGpJcsaL8UUzTHl5peO2pp54qrttUx5mZzm78NZJuk/Sm7derZQ9qIuS/sn2HpIOSvtefFgH0QmPYI+K/JNX9+f5Ob9sB0C98XRZIgrADSRB2IAnCDiRB2IEk3HSYYU+fbIi/Qbdp06ZivTTe3M1YsyTt27evWG9yySWX1NbmzJlTXLfb3pvWL03ZvH79+uK6J06cKNYxtYiY8j+
FLTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME4e6VpSueXXnqptrZ48eLiuidPnizW+znW3bTup59+Wqw3nc553bp1xfrWrVuLdfQe4+xAcoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7NM0MjJSW1uzZk1Xjz06Wp4da8uWLcV6N8d9N52bnWmTZx7G2YHkCDuQBGEHkiDsQBKEHUiCsANJEHYgicZxdtvzJf1c0kWSTkoai4inbD8q6S5Jx6u7PhgR9Qd9a2aPswMzRd04+3TCPk/SvIjYa/urkvZIWiHpHyX9KSL+fbpNEHag/+rCPp352Y9IOlJd/8T2uKSLe9segH47o/fsthdK+pak16pF99h+w/ZztmfXrDNqe7ft3V11CqAr0/5uvO2vSNohaW1EbLF9oaQTkkLSGk3s6v9zw2OwGw/0Wcfv2SXJ9pck/VrSbyLiJ1PUF0r6dUT8XcPjEHagzzo+EMYTpy59VtL45KBXH9ydslLS/m6bBNA/0/k0fomkVyW9qYmhN0l6UNIqSVdoYjf+A0k/qD7MKz0WW3agz7raje8Vwg70H8ezA8kRdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmg84WSPnZD0v5Nuj1TLhtGw9jasfUn01qle9ragrjDQ49m/8OT27ohY3FoDBcPa27D2JdFbpwbVG7vxQBKEHUii7bCPtfz8JcPa27D2JdFbpwbSW6vv2QEMTttbdgADQtiBJFoJu+0bbL9j+z3bq9vooY7tD2y/afv1tuenq+bQO2Z7/6Rl59t+xfa71eWUc+y11Nujtv9QvXav276ppd7m2/6d7XHbb9n+UbW81deu0NdAXreBv2e3fa6k30v6rqTDknZJWhURbw+0kRq2P5C0OCJa/wKG7aWS/iTp56em1rL9b5I+jojHqz+UsyPiX4ekt0d1htN496m3umnG/0ktvna9nP68E21s2a+W9F5EvB8Rf5b0S0nLW+hj6EXETkkfn7Z4uaSN1fWNmvhlGbia3oZCRByJiL3V9U8knZpmvNXXrtDXQLQR9oslHZp0+7CGa773kPRb23tsj7bdzBQuPDXNVnV5Qcv9nK5xGu9BOm2a8aF57TqZ/rxbbYR9qqlphmn875qIuFLSjZJ+WO2uYno2SPqmJuYAPCLpiTabqaYZ3yzp3oj4Y5u9TDZFXwN53doI+2FJ8yfd/pqkD1voY0oR8WF1eUzSVk287RgmR0/NoFtdHmu5n7+IiKMR8XlEnJT0U7X42lXTjG+W9IuI2FItbv21m6qvQb1ubYR9l6RLbX/d9pclfV/Sthb6+ALbs6oPTmR7lqRlGr6pqLdJur26frukF1vs5a8MyzTeddOMq+XXrvXpzyNi4D+SbtLEJ/L/I+mhNnqo6esbkv67+nmr7d4kvaCJ3br/08Qe0R2S5kjaLund6vL8IertPzUxtfcbmgjWvJZ6W6KJt4ZvSHq9+rmp7deu0NdAXje+LgskwTfogCQIO5AEYQeSIOxAEoQdSIKwA0kQdiCJ/wftgrMNjgT54AAAAABJRU5ErkJggg==\n", 149 | "text/plain": [ 150 | "
" 151 | ] 152 | }, 153 | "metadata": { 154 | "needs_background": "light" 155 | }, 156 | "output_type": "display_data" 157 | } 158 | ], 159 | "source": [ 160 | "plt.imshow(X_train[1].reshape((28, 28)), cmap='gray')" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "# 构建模型" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 14, 173 | "metadata": {}, 174 | "outputs": [ 175 | { 176 | "name": "stderr", 177 | "output_type": "stream", 178 | "text": [ 179 | "Using TensorFlow backend.\n" 180 | ] 181 | } 182 | ], 183 | "source": [ 184 | "# Importing the Keras libraries and packages\n", 185 | "# Importing the Keras libraries and packages\n", 186 | "from keras.models import Sequential\n", 187 | "from keras.layers import Dense\n", 188 | "from keras.layers import Conv2D\n", 189 | "from keras.layers import MaxPooling2D\n", 190 | "from keras.layers import Dropout\n", 191 | "from keras.layers import Flatten" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 25, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "def build_classifier():\n", 201 | " # Initialising the CNN\n", 202 | " classifier = Sequential()\n", 203 | "\n", 204 | " # Adding the first CNN layer and some Dropout regularisation\n", 205 | " classifier.add(Conv2D(filters = 2, kernel_size = 3, strides = 1, padding = \"SAME\", activation = \"relu\", input_shape = (28, 28, 1)))\n", 206 | " classifier.add(MaxPooling2D(pool_size=(2, 2), padding='SAME'))\n", 207 | " classifier.add(Dropout(0.3))\n", 208 | "\n", 209 | " classifier.add(Conv2D(filters = 2, kernel_size = 3, strides = 1, padding = \"SAME\", activation = \"relu\"))\n", 210 | " classifier.add(MaxPooling2D(pool_size=(2, 2), padding='SAME'))\n", 211 | " classifier.add(Dropout(0.3))\n", 212 | "\n", 213 | " classifier.add(Flatten())\n", 214 | " classifier.add(Dense(kernel_initializer=\"uniform\", units = 4))\n", 215 | "\n", 216 | " # Adding the output 
layer\n", 217 | " classifier.add(Dense(kernel_initializer=\"uniform\", units = 10, activation=\"softmax\"))\n", 218 | " classifier.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics=['accuracy'])\n", 219 | "\n", 220 | " return classifier\n" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 26, 226 | "metadata": {}, 227 | "outputs": [ 228 | { 229 | "name": "stdout", 230 | "output_type": "stream", 231 | "text": [ 232 | "_________________________________________________________________\n", 233 | "Layer (type) Output Shape Param # \n", 234 | "=================================================================\n", 235 | "conv2d_3 (Conv2D) (None, 28, 28, 2) 20 \n", 236 | "_________________________________________________________________\n", 237 | "max_pooling2d_3 (MaxPooling2 (None, 14, 14, 2) 0 \n", 238 | "_________________________________________________________________\n", 239 | "dropout_3 (Dropout) (None, 14, 14, 2) 0 \n", 240 | "_________________________________________________________________\n", 241 | "conv2d_4 (Conv2D) (None, 14, 14, 2) 38 \n", 242 | "_________________________________________________________________\n", 243 | "max_pooling2d_4 (MaxPooling2 (None, 7, 7, 2) 0 \n", 244 | "_________________________________________________________________\n", 245 | "dropout_4 (Dropout) (None, 7, 7, 2) 0 \n", 246 | "_________________________________________________________________\n", 247 | "flatten_2 (Flatten) (None, 98) 0 \n", 248 | "_________________________________________________________________\n", 249 | "dense_3 (Dense) (None, 4) 396 \n", 250 | "_________________________________________________________________\n", 251 | "dense_4 (Dense) (None, 10) 50 \n", 252 | "=================================================================\n", 253 | "Total params: 504\n", 254 | "Trainable params: 504\n", 255 | "Non-trainable params: 0\n", 256 | "_________________________________________________________________\n" 257 | ] 258 | } 
259 | ], 260 | "source": [ 261 | "classifier = build_classifier()\n", 262 | "classifier.summary()" 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "metadata": {}, 268 | "source": [ 269 | "# 训练模型" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 27, 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "from keras.callbacks import ModelCheckpoint\n", 279 | "checkpointer = ModelCheckpoint(filepath='minions.hdf5', verbose=1, save_best_only=True, monitor='val_loss',mode='min')" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": 28, 285 | "metadata": {}, 286 | "outputs": [ 287 | { 288 | "name": "stdout", 289 | "output_type": "stream", 290 | "text": [ 291 | "Train on 55000 samples, validate on 10000 samples\n", 292 | "Epoch 1/50\n", 293 | "55000/55000 [==============================] - 10s 178us/step - loss: 1.3555 - acc: 0.5203 - val_loss: 0.8780 - val_acc: 0.7317\n", 294 | "\n", 295 | "Epoch 00001: val_loss improved from inf to 0.87800, saving model to minions.hdf5\n", 296 | "Epoch 2/50\n", 297 | "55000/55000 [==============================] - 9s 171us/step - loss: 1.0732 - acc: 0.6216 - val_loss: 0.8069 - val_acc: 0.7405\n", 298 | "\n", 299 | "Epoch 00002: val_loss improved from 0.87800 to 0.80690, saving model to minions.hdf5\n", 300 | "Epoch 3/50\n", 301 | "55000/55000 [==============================] - 9s 172us/step - loss: 1.0243 - acc: 0.6383 - val_loss: 0.7807 - val_acc: 0.7610\n", 302 | "\n", 303 | "Epoch 00003: val_loss improved from 0.80690 to 0.78067, saving model to minions.hdf5\n", 304 | "Epoch 4/50\n", 305 | "55000/55000 [==============================] - 9s 172us/step - loss: 0.9878 - acc: 0.6533 - val_loss: 0.7382 - val_acc: 0.7770\n", 306 | "\n", 307 | "Epoch 00004: val_loss improved from 0.78067 to 0.73821, saving model to minions.hdf5\n", 308 | "Epoch 5/50\n", 309 | "55000/55000 [==============================] - 9s 169us/step - loss: 0.9613 - acc: 0.6623 - val_loss: 
0.7169 - val_acc: 0.7922\n", 310 | "\n", 311 | "Epoch 00005: val_loss improved from 0.73821 to 0.71689, saving model to minions.hdf5\n", 312 | "Epoch 6/50\n", 313 | "55000/55000 [==============================] - 10s 173us/step - loss: 0.9434 - acc: 0.6650 - val_loss: 0.6947 - val_acc: 0.7987\n", 314 | "\n", 315 | "Epoch 00006: val_loss improved from 0.71689 to 0.69465, saving model to minions.hdf5\n", 316 | "Epoch 7/50\n", 317 | "55000/55000 [==============================] - 9s 171us/step - loss: 0.9078 - acc: 0.6775 - val_loss: 0.6258 - val_acc: 0.8173\n", 318 | "\n", 319 | "Epoch 00007: val_loss improved from 0.69465 to 0.62578, saving model to minions.hdf5\n", 320 | "Epoch 8/50\n", 321 | "55000/55000 [==============================] - 9s 169us/step - loss: 0.8136 - acc: 0.7133 - val_loss: 0.5246 - val_acc: 0.8476\n", 322 | "\n", 323 | "Epoch 00008: val_loss improved from 0.62578 to 0.52461, saving model to minions.hdf5\n", 324 | "Epoch 9/50\n", 325 | "55000/55000 [==============================] - 10s 174us/step - loss: 0.7273 - acc: 0.7461 - val_loss: 0.4556 - val_acc: 0.8652\n", 326 | "\n", 327 | "Epoch 00009: val_loss improved from 0.52461 to 0.45560, saving model to minions.hdf5\n", 328 | "Epoch 10/50\n", 329 | "55000/55000 [==============================] - 10s 174us/step - loss: 0.6914 - acc: 0.7601 - val_loss: 0.4300 - val_acc: 0.8750\n", 330 | "\n", 331 | "Epoch 00010: val_loss improved from 0.45560 to 0.43003, saving model to minions.hdf5\n", 332 | "Epoch 11/50\n", 333 | "55000/55000 [==============================] - 10s 177us/step - loss: 0.6735 - acc: 0.7685 - val_loss: 0.4067 - val_acc: 0.8877\n", 334 | "\n", 335 | "Epoch 00011: val_loss improved from 0.43003 to 0.40667, saving model to minions.hdf5\n", 336 | "Epoch 12/50\n", 337 | "55000/55000 [==============================] - 9s 168us/step - loss: 0.6610 - acc: 0.7749 - val_loss: 0.4092 - val_acc: 0.8806\n", 338 | "\n", 339 | "Epoch 00012: val_loss did not improve from 0.40667\n", 340 | "Epoch 
13/50\n", 341 | "55000/55000 [==============================] - 9s 166us/step - loss: 0.6581 - acc: 0.7785 - val_loss: 0.3992 - val_acc: 0.8874\n", 342 | "\n", 343 | "Epoch 00013: val_loss improved from 0.40667 to 0.39921, saving model to minions.hdf5\n", 344 | "Epoch 14/50\n", 345 | "55000/55000 [==============================] - 9s 168us/step - loss: 0.6510 - acc: 0.7780 - val_loss: 0.3958 - val_acc: 0.8838\n", 346 | "\n", 347 | "Epoch 00014: val_loss improved from 0.39921 to 0.39576, saving model to minions.hdf5\n", 348 | "Epoch 15/50\n", 349 | "55000/55000 [==============================] - 9s 170us/step - loss: 0.6450 - acc: 0.7811 - val_loss: 0.4030 - val_acc: 0.8782\n", 350 | "\n", 351 | "Epoch 00015: val_loss did not improve from 0.39576\n", 352 | "Epoch 16/50\n", 353 | "55000/55000 [==============================] - 9s 169us/step - loss: 0.6391 - acc: 0.7837 - val_loss: 0.3956 - val_acc: 0.8850\n", 354 | "\n", 355 | "Epoch 00016: val_loss improved from 0.39576 to 0.39564, saving model to minions.hdf5\n", 356 | "Epoch 17/50\n", 357 | "55000/55000 [==============================] - 9s 167us/step - loss: 0.6345 - acc: 0.7848 - val_loss: 0.3838 - val_acc: 0.8887\n", 358 | "\n", 359 | "Epoch 00017: val_loss improved from 0.39564 to 0.38377, saving model to minions.hdf5\n", 360 | "Epoch 18/50\n", 361 | "55000/55000 [==============================] - 9s 169us/step - loss: 0.6337 - acc: 0.7848 - val_loss: 0.3913 - val_acc: 0.8818\n", 362 | "\n", 363 | "Epoch 00018: val_loss did not improve from 0.38377\n", 364 | "Epoch 19/50\n", 365 | "55000/55000 [==============================] - 9s 168us/step - loss: 0.6338 - acc: 0.7870 - val_loss: 0.3845 - val_acc: 0.8875\n", 366 | "\n", 367 | "Epoch 00019: val_loss did not improve from 0.38377\n", 368 | "Epoch 20/50\n", 369 | "55000/55000 [==============================] - 9s 168us/step - loss: 0.6345 - acc: 0.7847 - val_loss: 0.3817 - val_acc: 0.8917\n", 370 | "\n", 371 | "Epoch 00020: val_loss improved from 0.38377 to 
0.38166, saving model to minions.hdf5\n", 372 | "Epoch 21/50\n", 373 | "55000/55000 [==============================] - 9s 171us/step - loss: 0.6208 - acc: 0.7914 - val_loss: 0.3702 - val_acc: 0.8914\n", 374 | "\n", 375 | "Epoch 00021: val_loss improved from 0.38166 to 0.37016, saving model to minions.hdf5\n", 376 | "Epoch 22/50\n", 377 | "55000/55000 [==============================] - 10s 175us/step - loss: 0.6217 - acc: 0.7905 - val_loss: 0.3771 - val_acc: 0.8924\n", 378 | "\n", 379 | "Epoch 00022: val_loss did not improve from 0.37016\n", 380 | "Epoch 23/50\n", 381 | "55000/55000 [==============================] - 10s 175us/step - loss: 0.6203 - acc: 0.7882 - val_loss: 0.3732 - val_acc: 0.8911\n", 382 | "\n", 383 | "Epoch 00023: val_loss did not improve from 0.37016\n", 384 | "Epoch 24/50\n", 385 | "55000/55000 [==============================] - 9s 171us/step - loss: 0.6259 - acc: 0.7895 - val_loss: 0.3805 - val_acc: 0.8883\n", 386 | "\n", 387 | "Epoch 00024: val_loss did not improve from 0.37016\n", 388 | "Epoch 25/50\n", 389 | "55000/55000 [==============================] - 9s 168us/step - loss: 0.6204 - acc: 0.7907 - val_loss: 0.3785 - val_acc: 0.8908\n", 390 | "\n", 391 | "Epoch 00025: val_loss did not improve from 0.37016\n", 392 | "Epoch 26/50\n", 393 | "55000/55000 [==============================] - 9s 169us/step - loss: 0.6194 - acc: 0.7916 - val_loss: 0.3764 - val_acc: 0.8924\n", 394 | "\n", 395 | "Epoch 00026: val_loss did not improve from 0.37016\n", 396 | "Epoch 27/50\n", 397 | "55000/55000 [==============================] - 9s 165us/step - loss: 0.6291 - acc: 0.7876 - val_loss: 0.3873 - val_acc: 0.8820\n", 398 | "\n", 399 | "Epoch 00027: val_loss did not improve from 0.37016\n", 400 | "Epoch 28/50\n", 401 | "55000/55000 [==============================] - 10s 175us/step - loss: 0.6165 - acc: 0.7912 - val_loss: 0.3765 - val_acc: 0.8919\n", 402 | "\n", 403 | "Epoch 00028: val_loss did not improve from 0.37016\n", 404 | "Epoch 29/50\n", 405 | 
"55000/55000 [==============================] - 10s 178us/step - loss: 0.6232 - acc: 0.7895 - val_loss: 0.3781 - val_acc: 0.8919\n", 406 | "\n", 407 | "Epoch 00029: val_loss did not improve from 0.37016\n", 408 | "Epoch 30/50\n", 409 | "55000/55000 [==============================] - 10s 173us/step - loss: 0.6226 - acc: 0.7904 - val_loss: 0.3909 - val_acc: 0.8823\n", 410 | "\n", 411 | "Epoch 00030: val_loss did not improve from 0.37016\n", 412 | "Epoch 31/50\n", 413 | "55000/55000 [==============================] - 10s 177us/step - loss: 0.6230 - acc: 0.7898 - val_loss: 0.3971 - val_acc: 0.8787\n", 414 | "\n", 415 | "Epoch 00031: val_loss did not improve from 0.37016\n", 416 | "Epoch 32/50\n", 417 | "55000/55000 [==============================] - 9s 168us/step - loss: 0.6197 - acc: 0.7899 - val_loss: 0.3747 - val_acc: 0.8887\n", 418 | "\n", 419 | "Epoch 00032: val_loss did not improve from 0.37016\n", 420 | "Epoch 33/50\n", 421 | "55000/55000 [==============================] - 9s 171us/step - loss: 0.6168 - acc: 0.7918 - val_loss: 0.3770 - val_acc: 0.8871\n", 422 | "\n", 423 | "Epoch 00033: val_loss did not improve from 0.37016\n", 424 | "Epoch 34/50\n", 425 | "55000/55000 [==============================] - 9s 164us/step - loss: 0.6146 - acc: 0.7914 - val_loss: 0.3842 - val_acc: 0.8843\n", 426 | "\n", 427 | "Epoch 00034: val_loss did not improve from 0.37016\n", 428 | "Epoch 35/50\n", 429 | "55000/55000 [==============================] - 9s 169us/step - loss: 0.6208 - acc: 0.7908 - val_loss: 0.3780 - val_acc: 0.8909\n", 430 | "\n", 431 | "Epoch 00035: val_loss did not improve from 0.37016\n", 432 | "Epoch 36/50\n", 433 | "55000/55000 [==============================] - 9s 164us/step - loss: 0.6172 - acc: 0.7935 - val_loss: 0.3705 - val_acc: 0.8917\n", 434 | "\n", 435 | "Epoch 00036: val_loss did not improve from 0.37016\n", 436 | "Epoch 37/50\n", 437 | "55000/55000 [==============================] - 9s 172us/step - loss: 0.6171 - acc: 0.7930 - val_loss: 0.3812 - 
val_acc: 0.8872\n", 438 | "\n", 439 | "Epoch 00037: val_loss did not improve from 0.37016\n", 440 | "Epoch 38/50\n", 441 | "55000/55000 [==============================] - 9s 169us/step - loss: 0.6223 - acc: 0.7928 - val_loss: 0.3910 - val_acc: 0.8887\n", 442 | "\n", 443 | "Epoch 00038: val_loss did not improve from 0.37016\n", 444 | "Epoch 39/50\n", 445 | "55000/55000 [==============================] - 10s 175us/step - loss: 0.6133 - acc: 0.7920 - val_loss: 0.3759 - val_acc: 0.8900\n", 446 | "\n", 447 | "Epoch 00039: val_loss did not improve from 0.37016\n", 448 | "Epoch 40/50\n", 449 | "55000/55000 [==============================] - 10s 177us/step - loss: 0.6144 - acc: 0.7934 - val_loss: 0.3715 - val_acc: 0.8932\n" 450 | ] 451 | }, 452 | { 453 | "name": "stdout", 454 | "output_type": "stream", 455 | "text": [ 456 | "\n", 457 | "Epoch 00040: val_loss did not improve from 0.37016\n", 458 | "Epoch 41/50\n", 459 | "55000/55000 [==============================] - 9s 166us/step - loss: 0.6202 - acc: 0.7912 - val_loss: 0.3847 - val_acc: 0.8881\n", 460 | "\n", 461 | "Epoch 00041: val_loss did not improve from 0.37016\n", 462 | "Epoch 42/50\n", 463 | "55000/55000 [==============================] - 9s 170us/step - loss: 0.6162 - acc: 0.7936 - val_loss: 0.3791 - val_acc: 0.8903\n", 464 | "\n", 465 | "Epoch 00042: val_loss did not improve from 0.37016\n", 466 | "Epoch 43/50\n", 467 | "55000/55000 [==============================] - 10s 174us/step - loss: 0.6174 - acc: 0.7920 - val_loss: 0.3777 - val_acc: 0.8899\n", 468 | "\n", 469 | "Epoch 00043: val_loss did not improve from 0.37016\n", 470 | "Epoch 44/50\n", 471 | "55000/55000 [==============================] - 9s 173us/step - loss: 0.6199 - acc: 0.7920 - val_loss: 0.3805 - val_acc: 0.8934\n", 472 | "\n", 473 | "Epoch 00044: val_loss did not improve from 0.37016\n", 474 | "Epoch 45/50\n", 475 | "55000/55000 [==============================] - 9s 166us/step - loss: 0.6163 - acc: 0.7916 - val_loss: 0.3764 - val_acc: 0.8931\n", 
476 | "\n", 477 | "Epoch 00045: val_loss did not improve from 0.37016\n", 478 | "Epoch 46/50\n", 479 | "55000/55000 [==============================] - 10s 174us/step - loss: 0.6176 - acc: 0.7918 - val_loss: 0.3700 - val_acc: 0.8965\n", 480 | "\n", 481 | "Epoch 00046: val_loss improved from 0.37016 to 0.37002, saving model to minions.hdf5\n", 482 | "Epoch 47/50\n", 483 | "55000/55000 [==============================] - 10s 182us/step - loss: 0.6180 - acc: 0.7921 - val_loss: 0.3797 - val_acc: 0.8936\n", 484 | "\n", 485 | "Epoch 00047: val_loss did not improve from 0.37002\n", 486 | "Epoch 48/50\n", 487 | "55000/55000 [==============================] - 10s 183us/step - loss: 0.6133 - acc: 0.7931 - val_loss: 0.3770 - val_acc: 0.8926\n", 488 | "\n", 489 | "Epoch 00048: val_loss did not improve from 0.37002\n", 490 | "Epoch 49/50\n", 491 | "55000/55000 [==============================] - 11s 196us/step - loss: 0.6175 - acc: 0.7912 - val_loss: 0.3755 - val_acc: 0.8888\n", 492 | "\n", 493 | "Epoch 00049: val_loss did not improve from 0.37002\n", 494 | "Epoch 50/50\n", 495 | "55000/55000 [==============================] - 10s 174us/step - loss: 0.6087 - acc: 0.7945 - val_loss: 0.3695 - val_acc: 0.8977\n", 496 | "\n", 497 | "Epoch 00050: val_loss improved from 0.37002 to 0.36951, saving model to minions.hdf5\n" 498 | ] 499 | } 500 | ], 501 | "source": [ 502 | "history = classifier.fit(X_train, y_train, epochs = 50, batch_size = 50, validation_data=(X_test, y_test), callbacks=[checkpointer])" 503 | ] 504 | }, 505 | { 506 | "cell_type": "markdown", 507 | "metadata": {}, 508 | "source": [ 509 | "# 查看训练过程" 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": 19, 515 | "metadata": {}, 516 | "outputs": [], 517 | "source": [ 518 | "def plot_history(history) :\n", 519 | " SMALL_SIZE = 20\n", 520 | " MEDIUM_SIZE = 22\n", 521 | " BIGGER_SIZE = 24\n", 522 | "\n", 523 | " plt.rc('font', size=SMALL_SIZE) # controls default text sizes\n", 524 | " plt.rc('axes', 
titlesize=SMALL_SIZE) # fontsize of the axes title\n", 525 | " plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels\n", 526 | " plt.rc('xtick', labelsize=SMALL_SIZE) # fontsize of the tick labels\n", 527 | " plt.rc('ytick', labelsize=SMALL_SIZE) # fontsize of the tick labels\n", 528 | " plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize\n", 529 | " plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title\n", 530 | "\n", 531 | " fig = plt.figure()\n", 532 | " fig.set_size_inches(15,10)\n", 533 | " plt.plot(history['loss'])\n", 534 | " plt.plot(history['val_loss'])\n", 535 | " plt.title('Model Loss')\n", 536 | " plt.xlabel('epoch')\n", 537 | " plt.ylabel('loss')\n", 538 | " plt.legend(['train', 'test'],loc='upper left')\n", 539 | " plt.show()" 540 | ] 541 | }, 542 | { 543 | "cell_type": "code", 544 | "execution_count": 29, 545 | "metadata": {}, 546 | "outputs": [ 547 | { 548 | "data": { 549 | "image/png": "\n", 550 | "text/plain": [ 551 | "
" 552 | ] 553 | }, 554 | "metadata": { 555 | "needs_background": "light" 556 | }, 557 | "output_type": "display_data" 558 | } 559 | ], 560 | "source": [ 561 | "plot_history(history.history)" 562 | ] 563 | }, 564 | { 565 | "cell_type": "markdown", 566 | "metadata": {}, 567 | "source": [ 568 | "# 保存模型" 569 | ] 570 | }, 571 | { 572 | "cell_type": "markdown", 573 | "metadata": {}, 574 | "source": [ 575 | "保存为 Keras 模型" 576 | ] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "execution_count": 30, 581 | "metadata": {}, 582 | "outputs": [], 583 | "source": [ 584 | "classifier.save(\"mnist.h5\")" 585 | ] 586 | }, 587 | { 588 | "cell_type": "markdown", 589 | "metadata": {}, 590 | "source": [ 591 | "保存为 onnx 模型" 592 | ] 593 | }, 594 | { 595 | "cell_type": "code", 596 | "execution_count": 31, 597 | "metadata": {}, 598 | "outputs": [], 599 | "source": [ 600 | "import onnx\n", 601 | "import keras2onnx" 602 | ] 603 | }, 604 | { 605 | "cell_type": "code", 606 | "execution_count": 32, 607 | "metadata": {}, 608 | "outputs": [], 609 | "source": [ 610 | "onnx_model = keras2onnx.convert_keras(classifier, 'mnist')\n", 611 | "onnx.save_model(onnx_model, 'mnist.onnx')" 612 | ] 613 | }, 614 | { 615 | "cell_type": "markdown", 616 | "metadata": {}, 617 | "source": [ 618 | "# 加载训练好的模型" 619 | ] 620 | }, 621 | { 622 | "cell_type": "code", 623 | "execution_count": 33, 624 | "metadata": {}, 625 | "outputs": [], 626 | "source": [ 627 | "import onnxruntime as rt\n", 628 | "sess = rt.InferenceSession(\"mnist.onnx\")" 629 | ] 630 | }, 631 | { 632 | "cell_type": "code", 633 | "execution_count": 34, 634 | "metadata": {}, 635 | "outputs": [], 636 | "source": [ 637 | "input_name = sess.get_inputs()[0].name\n", 638 | "output_name = sess.get_outputs()[0].name" 639 | ] 640 | }, 641 | { 642 | "cell_type": "code", 643 | "execution_count": 36, 644 | "metadata": {}, 645 | "outputs": [], 646 | "source": [ 647 | "res = sess.run([output_name], {input_name: X_test})\n", 648 | "res = np.array(res)" 649 | ] 650 | 
}, 651 | { 652 | "cell_type": "code", 653 | "execution_count": 44, 654 | "metadata": {}, 655 | "outputs": [ 656 | { 657 | "name": "stdout", 658 | "output_type": "stream", 659 | "text": [ 660 | "[0.00000001 0.00000004 0.0000312 0.00062566 0. 0.0000002\n", 661 | " 0. 0.99688894 0.00009035 0.00236361]\n" 662 | ] 663 | }, 664 | { 665 | "data": { 666 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQgAAAEFCAYAAAAfaHkhAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAPIElEQVR4nO3de6xdZZnH8e9DayigLZcBMZFrA0hiEBUVW4VSkhkGoxBbJpjxkozSjGkwDEJMRo1o/MP4hwhMUtJMHMZApiRtxIzTQJO2UKSQCSQKKra2iEiQKJZwL1B95o+9zng87GedfS77cnq+n2RnudezLu95y/n57rXevU5kJpLUzSHDboCk0WVASCoZEJJKBoSkkgEhqWRASCoZEJJKfQ2IiHh7RHwvIp6KiFcj4vGI+G5EHNXP80qaHdGviVIRsRTYCRwH/BD4JfB+4AJgF7A8M//Yl5NLmh2Z2ZcXcBeQwJUT1n+nWX9zj8dJX7589fdV/f71ZQQREacCe4HHgaWZ+edxtbcAvwMCOC4zX5rkWLPfQEl/JTOj2/p+XYNY2Sy3jA+HpiEvAPcBhwPn9un8kmZBvwLijGa5u6j/qlme3qfzS5oFC/t03CXN8rmiPrb+yG7FiFgDrJntRkmamn4FxGTGPu90vb6QmeuB9eA1CGmY+vURY2yEsKSoL56wnaQR1K+A2NUsq2sMpzXL6hqFpBHQr9ucS4E9tN/mPAQ41tuc0vAN9DZnZu4FtgAnA2snlL8OHAF8f7JwkDRcg5xq/SjwATpTrXcDy3qZau0IQuq/agTRt4AAiIgTgG8AFwHH0PlocQfw9czc1+MxDAipz4YSELPBgJD6b9BTrSUdBAwISSUDQlLJgJBUMiAklQwISSUDQlLJgJBUMiAklQwISSUDQlLJgJBUMiAklQwISSUDQlLJgJBUMiAklQwISSUDQlLJgJBUMiAklQwISSUDQlLJgJBUMiAklQwISSUDQlLJgJBUMiAklQwISSUDQlLJgJBUMiAklQwISSUDQlLJgJBUMiAklQwISSUDQlLJgJBUMiAklRYOuwEHs9WrV5e1K664onXfp556qrW+f//+1vptt93WWn/66afL2p49e1r31fzhCEJSyYCQVDIgJJUMCEklA0JSyYCQVDIgJJUiM4fdhlYRMdoNbPHYY4+VtZNPPnlwDenihRdeKGs///nPB9iS0fLkk0+WtW9/+9ut+z744IOz3ZyByczott4RhKSSASGpZEBIKhkQkkoGhKSSASGpZEBIKvk8iD5qe+bDWWed1brvo48+2lo/88wzW+vvec97WusrVqwoa+eee27rvr/97W9b6yeccEJrfSYOHDjQWv/DH/7QWn/b29427XM/8cQTrfW5PA+i4ghCUsmAkFQyICSVDAhJJQNCUqmngIiI1RFxU0TcGxHPR0RGxK2T7LMsIjZHxL6IeDkiHo6IqyJiwew0XVK/9Xqb8yvAu4AXgSeBd7RtHBGXAJuA/cDtwD7go8D1wHLgsmm2V9IA9fQ8iIi4gE4w7AHOB7YDt2XmJ7tsu7jZbgmwPDMfbNYvArYBHwQ+kZkbemrgHH4exCg76qijytrZZ5/duu9DDz3UWn/f+943
rTb1YrK/B7J79+7W+mTzS44++uiytnbt2tZ9161b11ofZTN6HkRmbs/MX2VvT5dZDRwLbBgLh+YY++mMRAA+38t5JQ1XPy5SrmyWd3ap7QBeBpZFxKF9OLekWdSPgDijWb5hrJeZB4Bf07n2cWofzi1pFvXjuxhLmuVzRX1s/ZHVASJiDbBmNhslaeqG8WWtsYsh5fWMzFwPrAcvUkrD1I+PGGMjhCVFffGE7SSNqH4ExK5mefrEQkQsBE4BDgD1M+EljYR+fMTYBvwjcBHwXxNq5wGHAzsy89U+nFs9evbZZ8va9u3bZ3TsrVu3zmj/mVi1alVrvW3+B8AjjzxS1m6//fZptWku68cIYiPwDHB5RJwztrKZKPXN5u3cnVEizSM9jSAi4lLg0ubt8c3ygxFxS/O/n8nMawAy8/mIuIJOUNwdERvoTLX+GJ1boBvpTL+WNOJ6/YhxNvCZCetO5S9zGX4DXDNWyMw7IuJ84MvAKmARnenXVwM39jgjU9KQ9RQQmXkdcN1UDpyZ9wEXT71JkkaFz4OQVDIgJJV6+rr3MDmTUuMdd9xxrfW225S97L969eqytmnTptZ957IZfd1b0vxkQEgqGRCSSgaEpJIBIalkQEgqGRCSSsN4opQ0bZM9ev7YY49trbd9zR1g165drfX5xhGEpJIBIalkQEgqGRCSSgaEpJIBIalkQEgq+TwIjZzly5eXtW3btrXu+6Y3vam1vmLFitb6jh07WusHK58HIWnKDAhJJQNCUsmAkFQyICSVDAhJJQNCUsnnQWjkXHxx/RcbJ5vnsHXr1tb6/fffP602zVeOICSVDAhJJQNCUsmAkFQyICSVDAhJJQNCUsl5EBq4ww47rLV+0UUXlbXXXnutdd+vfe1rrfXXX3+9ta6/5ghCUsmAkFQyICSVDAhJJQNCUsmAkFTyNqcG7tprr22tv/vd7y5rd955Z+u+O3funFab1J0jCEklA0JSyYCQVDIgJJUMCEklA0JSyYCQVIrMHHYbWkXEaDdQb/CRj3yktX7HHXe01l966aWy1vZVcIAHHnigta7uMjO6rXcEIalkQEgqGRCSSgaEpJIBIalkQEgqGRCSSj4PQlN2zDHHtNZvvPHG1vqCBQta65s3by5rznMYLEcQkkoGhKSSASGpZEBIKhkQkko9BUREHBMRn4uIH0TEnoh4JSKei4gfR8RnI6LrcSJiWURsjoh9EfFyRDwcEVdFRPtlbEkjodfbnJcB64DfAduBJ4C3Ah8H/h34+4i4LMd9dzwiLgE2AfuB24F9wEeB64HlzTEljbCengcRESuBI4D/ycw/j1t/PPC/wAnA6szc1KxfDOwBlgDLM/PBZv0iYBvwQeATmbmhh3P7PIgBm2yewmRzEd773ve21vfu3dtab3vmw2T7anpm9DyIzNyWmf89Phya9U8DNzdvV4wrrQaOBTaMhUOz/X7gK83bz/fWdEnDMhsXKV9vlgfGrVvZLLv9GaQdwMvAsog4dBbOL6lPZhQQEbEQ+HTzdnwYnNEsd0/cJzMPAL+mc/3j1JmcX1J/zfS7GN8C3glszsy7xq1f0iyfK/YbW39kt2JErAHWzLBtkmZo2gEREV8Avgj8EvjUVHdvll0vQGbmemB9cx4vUkpDMq2PGBGxFrgB+AVwQWbum7DJ2AhhCd0tnrCdpBE05RFERFxFZy7Dz4ALM/P3XTbbBZwDnA48NGH/hcApdC5qPjbV86v/li5d2lqf7DbmZK6++urWurcyR8eURhAR8SU64fATOiOHbuEAnbkOAN1uaJ8HHA7szMxXp3J+SYPVc0BExFfpXJR8iM7I4ZmWzTcCzwCXR8Q5446xCPhm83bd1JsraZB6+ogREZ8BvgH8CbgX+ELEGyZePZ6ZtwBk5vMRcQWdoLg7IjbQmWr9MTq3QDfSmX4taYT1eg3ilGa5ALiq2OYe4JaxN5l5R0ScD3wZWAUsojP9+mrgxhz1v/knqbeAyMzrgOumevDM
vA+4eKr7SRoNPg9CUsmAkFTysffz1EknnVTWtmzZMqNjX3vtta31H/3oRzM6vgbHEYSkkgEhqWRASCoZEJJKBoSkkgEhqWRASCo5D2KeWrOmfqLfiSeeOKNj33PPPa11v4YzdziCkFQyICSVDAhJJQNCUsmAkFQyICSVDAhJJedBHKQ+9KEPtdavvPLKAbVEc5kjCEklA0JSyYCQVDIgJJUMCEklA0JSyYCQVHIexEHqwx/+cGv9zW9+87SPvXfv3tb6iy++OO1ja7Q4gpBUMiAklQwISSUDQlLJgJBUMiAklbzNqTf46U9/2lq/8MILW+v79u2bzeZoiBxBSCoZEJJKBoSkkgEhqWRASCoZEJJKBoSkUoz6n2KPiNFuoHQQyMzott4RhKSSASGpZEBIKhkQkkoGhKSSASGpZEBIKs2F50E8A/xm3Pu/adapd/bZ9MyXfjupKoz8RKmJIuLBzDxn2O2YS+yz6bHf/IghqYUBIak0FwNi/bAbMAfZZ9Mz7/ttzl2DkDQ4c3EEIWlADAhJpTkREBHx9oj4XkQ8FRGvRsTjEfHdiDhq2G0bpohYHRE3RcS9EfF8RGRE3DrJPssiYnNE7IuIlyPi4Yi4KiIWDKrdwxQRx0TE5yLiBxGxJyJeiYjnIuLHEfHZiOj6OzFf+23kr0FExFJgJ3Ac8EPgl8D7gQuAXcDyzPzj8Fo4PBHxE+BdwIvAk8A7gNsy85PF9pcAm4D9wO3APuCjwBnAxsy8bBDtHqaI+GdgHfA7YDvwBPBW4OPAEjr9c1mO+8WY1/2WmSP9Au4CErhywvrvNOtvHnYbh9g3FwCnAQGsaPrj1mLbxcDvgVeBc8atX0QngBO4fNg/0wD6bCWdX+5DJqw/nk5YJLDKfmt+zmE3YJJ/zFObf4Bfd/kHfQud/+d8CThi2G0d9quHgPinpv6fXWorm9o9w/45htyH/9r0w032W+c16tcgVjbLLZn55/GFzHwBuA84HDh30A2bg8b68s4utR3Ay8CyiDh0cE0aOa83ywPj1s3rfhv1gDijWe4u6r9qlqcPoC1zXdmXmXmAzihtIZ1R27wTEQuBTzdvx4fBvO63UQ+IJc3yuaI+tv7IAbRlrrMv230LeCewOTPvGrd+XvfbqAfEZMYe1T3at2LmhnnblxHxBeCLdO6QfWqquzfLg7LfRj0gxtJ5SVFfPGE71ezLLiJiLXAD8AvggszcN2GTed1vox4Qu5pldY3htGZZXaPQX5R92Xz+PoXOxbnHBtmoYYqIq4B/A35GJxye7rLZvO63UQ+I7c3ybyfOcIuItwDLgVeABwbdsDloW7O8qEvtPDp3g3Zm5quDa9LwRMSXgOuBn9AJh98Xm87rfhvpgMjMvcAW4GRg7YTy14EjgO9n5ksDbtpctJHO49Muj4j/f0pSRCwCvtm8XTeMhg1aRHyVzkXJh4ALM7PtsXLzut/m4lTrR4EP0JlFuBtYlvN3qvWlwKXN2+OBv6Mz1L23WfdMZl4zYfuNdKYMb6AzZfhjNFOGgX/IUf8PYoYi4jPALcCfgJvofu3g8cy8Zdw+87ffhj1Tq8cZbicA/0Fn/vxrdB5iewNw9LDbNuR+uY7O1fPq9XiXfZYDm4Fn6Xw8ewT4F2DBsH+eEemzBO623zqvkR9BSBqekb4GIWm4DAhJJQNCUsmAkFQyICSVDAhJJQNCUsmAkFQyICSVDAhJpf8Duxsu2QgN1noAAAAASUVORK5CYII=\n", 667 | "text/plain": [ 668 | "
" 669 | ] 670 | }, 671 | "metadata": { 672 | "needs_background": "light" 673 | }, 674 | "output_type": "display_data" 675 | } 676 | ], 677 | "source": [ 678 | "plt.imshow(X_test[0].reshape((28, 28)), cmap='gray')\n", 679 | "print(res[0][0])" 680 | ] 681 | }, 682 | { 683 | "cell_type": "code", 684 | "execution_count": null, 685 | "metadata": {}, 686 | "outputs": [], 687 | "source": [] 688 | } 689 | ], 690 | "metadata": { 691 | "kernelspec": { 692 | "display_name": "Python 3", 693 | "language": "python", 694 | "name": "python3" 695 | }, 696 | "language_info": { 697 | "codemirror_mode": { 698 | "name": "ipython", 699 | "version": 3 700 | }, 701 | "file_extension": ".py", 702 | "mimetype": "text/x-python", 703 | "name": "python", 704 | "nbconvert_exporter": "python", 705 | "pygments_lexer": "ipython3", 706 | "version": "3.7.3" 707 | } 708 | }, 709 | "nbformat": 4, 710 | "nbformat_minor": 2 711 | } 712 | -------------------------------------------------------------------------------- /examples/model/mnist-lg.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuhanstudio/onnx-backend/56abccbf30dbd6436c406008aadf4345b4faad58/examples/model/mnist-lg.onnx -------------------------------------------------------------------------------- /examples/model/mnist-sm.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuhanstudio/onnx-backend/56abccbf30dbd6436c406008aadf4345b4faad58/examples/model/mnist-sm.onnx -------------------------------------------------------------------------------- /src/add.c: -------------------------------------------------------------------------------- 1 | #include "onnx.h" 2 | 3 | void add(const float *input, // pointer to vector 4 | const float *bias, // pointer to matrix 5 | const uint16_t dim_vec, // length of the vector 6 | float *output) 7 | { 8 | for (int i = 0; i < dim_vec; i++) 9 | { 10 | output[i] = 
input[i] + bias[i]; 11 | } 12 | } 13 | 14 | float* add_layer(Onnx__GraphProto* graph, const float *input, int64_t* shapeInput, int64_t* shapeOutput, const char* layer_name) 15 | { 16 | assert(graph != NULL && input != NULL && layer_name != "" ); 17 | 18 | Onnx__NodeProto* node = onnx_graph_get_node_by_name(graph, layer_name); 19 | const char* bias = node->input[1]; 20 | 21 | float* B = onnx_graph_get_weights_by_name(graph, bias); 22 | int64_t* shapeB = onnx_graph_get_dims_by_name(graph, bias); 23 | if(shapeB == NULL) 24 | { 25 | return NULL; 26 | } 27 | 28 | float* output = (float*) malloc(sizeof(float)*shapeB[0]); 29 | memset(output, 0, sizeof(sizeof(float)*shapeB[0])); 30 | add(input, B, shapeB[0], output); 31 | 32 | memcpy(shapeInput, shapeOutput, sizeof(int64_t)*3); 33 | 34 | return output; 35 | } 36 | -------------------------------------------------------------------------------- /src/conv2d.c: -------------------------------------------------------------------------------- 1 | #include "onnx.h" 2 | 3 | void conv2D(const float *input, // input image 4 | const uint16_t dim_im_in_x, // input image dimention x 5 | const uint16_t dim_im_in_y, // input image dimention y 6 | const uint16_t ch_im_in, // number of input image channels 7 | const float *weight, // kernel weights 8 | const uint16_t ch_im_out, // number of filters, i.e., output image channels 9 | const uint16_t dim_kernel_x, // filter kernel size x 10 | const uint16_t dim_kernel_y, // filter kernel size y 11 | const uint16_t padding_x, // padding sizes x 12 | const uint16_t padding_y, // padding sizes y 13 | const uint16_t stride_x, // stride x 14 | const uint16_t stride_y, // stride y 15 | const float *bias, // bias 16 | float *output, // output image 17 | const uint16_t dim_im_out_x, // output image dimension x 18 | const uint16_t dim_im_out_y // output image dimension y 19 | ) 20 | { 21 | int i, j, k, l, m, n; 22 | float conv_out = 0.0f; 23 | int in_row, in_col; 24 | 25 | // For each filter 26 | for 
(i = 0; i < ch_im_out; i++) 27 | { 28 | // For each image dimension 29 | for (j = 0; j < dim_im_out_y; j++) 30 | { 31 | for (k = 0; k < dim_im_out_x; k++) 32 | { 33 | conv_out = bias[i]; 34 | // For each kernel dimension 35 | for (m = 0; m < dim_kernel_y; m++) 36 | { 37 | for (n = 0; n < dim_kernel_x; n++) 38 | { 39 | // if-for implementation 40 | in_row = stride_y * j + m - padding_y; 41 | in_col = stride_x * k + n - padding_x; 42 | if (in_row >= 0 && in_col >= 0 && in_row < dim_im_in_y && in_col < dim_im_in_x) 43 | { 44 | // For each input channel 45 | for (l = 0; l < ch_im_in; l++) 46 | { 47 | conv_out += input[(in_row * dim_im_in_x + in_col) * ch_im_in + l] * 48 | weight[i * ch_im_in * dim_kernel_y * dim_kernel_x + (m * dim_kernel_x + n) * ch_im_in + 49 | l]; 50 | } 51 | } 52 | } 53 | } 54 | output[i + (j * dim_im_out_x + k) * ch_im_out] = conv_out; 55 | } 56 | } 57 | } 58 | } 59 | 60 | float* conv2D_layer(Onnx__GraphProto* graph, const float *input, int64_t* shapeInput, int64_t* shapeOutput, const char* layer_name) 61 | { 62 | assert(graph != NULL && input != NULL && layer_name != "" ); 63 | 64 | Onnx__NodeProto* node = onnx_graph_get_node_by_name(graph, layer_name); 65 | if(node == NULL) 66 | { 67 | // layer not found 68 | return NULL; 69 | } 70 | const char* weight = node->input[1]; 71 | const char* bias = node->input[2]; 72 | 73 | // Get weight shape 74 | int64_t* shapeW = onnx_graph_get_dims_by_name(graph, weight); 75 | if(shapeW == NULL) 76 | { 77 | return NULL; 78 | } 79 | int64_t dimW = onnx_graph_get_dim_by_name(graph, weight); 80 | if(dimW < 0) 81 | { 82 | return NULL; 83 | } 84 | 85 | // Get weights 86 | // NCWH --> NWHC 87 | int64_t permW_t[] = { 0, 2, 3, 1}; 88 | float* W = onnx_graph_get_weights_by_name(graph, weight); 89 | if(W == NULL) 90 | { 91 | return NULL; 92 | } 93 | float* W_t = transpose(W, shapeW, dimW, permW_t); 94 | 95 | // Get bias 96 | float* B = onnx_graph_get_weights_by_name(graph, bias); 97 | if(B == NULL) 98 | { 99 | return NULL; 
100 | } 101 | 102 | float* output = (float*) malloc(sizeof(float)*shapeW[0]*shapeInput[W_INDEX]*shapeInput[H_INDEX]); 103 | memset(output, 0, sizeof(sizeof(float)*shapeW[0]*shapeInput[W_INDEX]*shapeInput[H_INDEX])); 104 | conv2D(input, shapeInput[W_INDEX], shapeInput[H_INDEX], shapeW[1], W_t, shapeW[0], shapeW[2], shapeW[3], 1, 1, 1, 1, B, output, shapeInput[W_INDEX], shapeInput[H_INDEX]); 105 | 106 | shapeOutput[W_INDEX] = shapeInput[W_INDEX]; 107 | shapeOutput[H_INDEX] = shapeInput[H_INDEX]; 108 | shapeOutput[C_INDEX] = shapeW[0]; 109 | 110 | free(W_t); 111 | 112 | return output; 113 | } 114 | -------------------------------------------------------------------------------- /src/dense.c: -------------------------------------------------------------------------------- 1 | #include "onnx.h" 2 | 3 | void dense(const float *input, // pointer to vector 4 | const float *weight, // pointer to matrix 5 | const uint16_t dim_vec, // length of the vector 6 | const uint16_t num_of_rows, // numCol of A 7 | const float *bias, 8 | float *output) // output operand 9 | { 10 | for (int i = 0; i < num_of_rows; i++) 11 | { 12 | float ip_out = bias[i]; 13 | for (int j = 0; j < dim_vec; j++) 14 | { 15 | ip_out += input[j] * weight[i * dim_vec + j]; 16 | } 17 | output[i] = ip_out; 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /src/info.c: -------------------------------------------------------------------------------- 1 | #include "onnx.h" 2 | 3 | void onnx_tensor_info(const float* A, int64_t* shape, int64_t dim) 4 | { 5 | int elem = 1; 6 | for(int i = 0; i < dim; i++) 7 | { 8 | elem = elem * shape[i]; 9 | } 10 | 11 | printf("Array size: %d\n", elem); 12 | for(int i = 0; i < elem; i++) 13 | { 14 | printf( "%f ", A[i] ); 15 | int split = 1; 16 | for(int j = dim-1; j > 0; j--) 17 | { 18 | split = split * shape[j]; 19 | if( (i+1) % split == 0) 20 | { 21 | printf("\n"); 22 | } 23 | } 24 | } 25 | } 26 | 
-------------------------------------------------------------------------------- /src/matmul.c: -------------------------------------------------------------------------------- 1 | #include "onnx.h" 2 | 3 | // Matrix-vector product: output[i] = sum over j of input[j] * weight[i * dim_vec + j] 4 | // for i < num_of_rows, i.e. weight is row-major [num_of_rows][dim_vec]. 5 | void matmul(const float *input, // pointer to vector 6 | const float *weight, // pointer to row-major matrix [num_of_rows][dim_vec] 7 | const uint16_t dim_vec, // length of the vector 8 | const uint16_t num_of_rows, // number of matrix rows = length of output 9 | float *output) 10 | { 11 | for (int i = 0; i < num_of_rows; i++) 12 | { 13 | float ip_out = 0; 14 | for (int j = 0; j < dim_vec; j++) 15 | { 16 | ip_out += input[j] * weight[i * dim_vec + j]; 17 | } 18 | output[i] = ip_out; 19 | } 20 | } 21 | 22 | // Runs the ONNX "MatMul" node named layer_name on an input vector of 23 | // shapeInput[1] elements. The weight initializer (node->input[1], shape 24 | // [in, out]) is transposed to [out, in] so each output is one contiguous dot 25 | // product. Returns a newly malloc'd buffer of shapeW[1] floats owned by the 26 | // caller, or NULL on failure. 27 | float* matmul_layer(Onnx__GraphProto* graph, const float *input, int64_t* shapeInput, int64_t* shapeOutput, const char* layer_name) 28 | { 29 | assert(graph != NULL && input != NULL && layer_name != NULL && layer_name[0] != '\0'); 30 | 31 | Onnx__NodeProto* node = onnx_graph_get_node_by_name(graph, layer_name); 32 | if(node == NULL || node->n_input < 2) 33 | { 34 | // layer not found, or no weight input 35 | return NULL; 36 | } 37 | const char* weight = node->input[1]; 38 | 39 | int64_t* shapeW = onnx_graph_get_dims_by_name(graph, weight); 40 | if(shapeW == NULL) 41 | { 42 | return NULL; 43 | } 44 | int64_t dimW = onnx_graph_get_dim_by_name(graph, weight); 45 | if(dimW != 2) 46 | { 47 | // permW_t below has exactly 2 entries 48 | return NULL; 49 | } 50 | 51 | // The input length must match the weight's first (input) dimension; checked 52 | // at run time because an assert disappears under NDEBUG. 53 | if(shapeW[0] != shapeInput[1]) 54 | { 55 | return NULL; 56 | } 57 | 58 | int64_t permW_t[] = {1, 0}; 59 | float* W = onnx_graph_get_weights_by_name(graph, weight); 60 | if(W == NULL) 61 | { 62 | return NULL; 63 | } 64 | float* W_t = transpose(W, shapeW, dimW, permW_t); 65 | if(W_t == NULL) 66 | { 67 | // No memory 68 | return NULL; 69 | } 70 | 71 | // matmul() writes every element, so no zero-initialisation is needed. 72 | float* output = (float*) malloc(sizeof(float)*shapeW[1]); 73 | if(output == NULL) 74 | { 75 | // No memory 76 | free(W_t); 77 | return NULL; 78 | } 79 | matmul(input, W_t, shapeW[0], shapeW[1], output); 80 | 81 | shapeOutput[0] = shapeInput[0]; 82 | shapeOutput[1] = shapeW[1]; 83 | 84 | free(W_t); 85 | 86 | return output; 87 | } 88 |
-------------------------------------------------------------------------------- /src/maxpool.c:
-------------------------------------------------------------------------------- 1 | #include "onnx.h" 2 | 3 | // 2D max pooling on a channel-last image: input (y, x, c) is at 4 | // c + ch_im_in * (x + y * dim_im_in_x) and output (y, x, c) at 5 | // c + ch_im_in * (x + y * dim_im_out_x). Window taps outside the input 6 | // (padding) are ignored. 7 | void maxpool(const float *input, 8 | const uint16_t dim_im_in_x, // input image dimension x or W 9 | const uint16_t dim_im_in_y, // input image dimension y or H 10 | const uint16_t ch_im_in, // number of input image channels 11 | const uint16_t dim_kernel_x, // window kernel size 12 | const uint16_t dim_kernel_y, // window kernel size 13 | const uint16_t padding_x, // padding sizes 14 | const uint16_t padding_y, // padding sizes 15 | const uint16_t stride_x, // stride 16 | const uint16_t stride_y, // stride 17 | const uint16_t dim_im_out_x, // output image dimension x or W 18 | const uint16_t dim_im_out_y, // output image dimension y or H 19 | float *output) 20 | { 21 | int16_t i_ch_in, i_x, i_y; 22 | int16_t k_x, k_y; 23 | 24 | for (i_ch_in = 0; i_ch_in < ch_im_in; i_ch_in++) 25 | { 26 | for (i_y = 0; i_y < dim_im_out_y; i_y++) 27 | { 28 | for (i_x = 0; i_x < dim_im_out_x; i_x++) 29 | { 30 | // Start from the most negative float: FLT_MIN is the smallest positive 31 | // float, so it would turn all-negative windows into ~0. 32 | float max = -FLT_MAX; 33 | for (k_y = i_y * stride_y - padding_y; k_y < i_y * stride_y - padding_y + dim_kernel_y; k_y++) 34 | { 35 | for (k_x = i_x * stride_x - padding_x; k_x < i_x * stride_x - padding_x + dim_kernel_x; k_x++) 36 | { 37 | if (k_y >= 0 && k_x >= 0 && k_y < dim_im_in_y && k_x < dim_im_in_x) 38 | { 39 | if (input[i_ch_in + ch_im_in * (k_x + k_y * dim_im_in_x)] > max) 40 | { 41 | max = input[i_ch_in + ch_im_in * (k_x + k_y * dim_im_in_x)]; 42 | } 43 | } 44 | } 45 | } 46 | output[i_ch_in + ch_im_in * (i_x + i_y * dim_im_out_x)] = max; 47 | } 48 | } 49 | } 50 | } 51 | 52 | // Runs the ONNX "MaxPool" node named layer_name using its kernel_shape and 53 | // strides attributes; padding is not read and stays 0. 54 | // Returns a newly malloc'd buffer owned by the caller, or NULL on failure. 55 | float* maxpool_layer(Onnx__GraphProto* graph, float* input, int64_t* shapeInput, int64_t* shapeOutput, const char* layer_name) 56 | { 57 | assert(graph != NULL && input != NULL && layer_name != NULL && layer_name[0] != '\0'); 58 | 59 | Onnx__NodeProto* node = onnx_graph_get_node_by_name(graph, layer_name); 60 | if(node == NULL) 61 | { 62 | // layer not found 63 | return NULL; 64 | } 65 | 66 | uint16_t kernel_x = 1; 67 | uint16_t kernel_y = 1; 68 | uint16_t padding_x = 0; 69 | uint16_t padding_y = 0; 70 | uint16_t stride_x = 1; 71 | uint16_t stride_y = 1; 72 | 73 | // ONNX lists 2D kernel_shape/strides outermost axis first: [y (H), x (W)]. 74 | // x is the innermost spatial axis in maxpool()'s indexing. 75 | for(int i = 0; i < node->n_attribute; i++) 76 | { 77 | if( strcmp(node->attribute[i]->name, "kernel_shape") == 0 && node->attribute[i]->n_ints >= 2 ) 78 | { 79 | kernel_y = node->attribute[i]->ints[0]; 80 | kernel_x = node->attribute[i]->ints[1]; 81 | } 82 | if( strcmp(node->attribute[i]->name, "strides") == 0 && node->attribute[i]->n_ints >= 2 ) 83 | { 84 | stride_y = node->attribute[i]->ints[0]; 85 | stride_x = node->attribute[i]->ints[1]; 86 | } 87 | } 88 | 89 | if(stride_x == 0 || stride_y == 0 || kernel_x > shapeInput[W_INDEX] + 2 * padding_x || kernel_y > shapeInput[H_INDEX] + 2 * padding_y) 90 | { 91 | // Invalid attributes for this input (division by zero / negative output size) 92 | return NULL; 93 | } 94 | 95 | uint16_t out_x = (shapeInput[W_INDEX] - kernel_x + 2 * padding_x) / stride_x + 1; 96 | uint16_t out_y = (shapeInput[H_INDEX] - kernel_y + 2 * padding_y) / stride_y + 1; 97 | 98 | // maxpool() writes every element, so no zero-initialisation is needed. 99 | float* output = (float*) malloc(sizeof(float)*out_x*out_y*shapeInput[C_INDEX]); 100 | if(output == NULL) 101 | { 102 | // No memory 103 | return NULL; 104 | } 105 | maxpool(input, shapeInput[W_INDEX], shapeInput[H_INDEX], shapeInput[C_INDEX], kernel_x, kernel_y, padding_x, padding_y, stride_x, stride_y, out_x, out_y, output); 106 | 107 | shapeOutput[W_INDEX] = out_x; 108 | shapeOutput[H_INDEX] = out_y; 109 | shapeOutput[C_INDEX] = shapeInput[C_INDEX]; 110 | 111 | return output; 112 | } 113 |
-------------------------------------------------------------------------------- /src/model.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include "onnx.h" 3 | 4 | // Runs the model: starts at the node that consumes the graph's first input and 5 | // follows each node's first output to the next node until none is found. 6 | // Takes ownership of input (must be heap-allocated): it is either freed or 7 | // returned. shapeInput (3 elements) is updated in place to the output shape. 8 | // Returns the final output (owned by the caller), or NULL if a layer fails or 9 | // an unsupported operator is met. 10 | float* onnx_model_run(Onnx__ModelProto* model, float* input, int64_t* shapeInput) 11 | { 12 | int64_t* shapeOutput = (int64_t*) malloc(sizeof(int64_t)*3); 13 | if(shapeOutput == NULL) 14 | { 15 | // No memory 16 | free(input); 17 | return NULL; 18 | } 19 | shapeOutput[0] = -1; shapeOutput[1] = -1; shapeOutput[2] = -1; 20 | 21 | Onnx__NodeProto* node = onnx_graph_get_node_by_input(model->graph, model->graph->input[0]->name); 22 | 23 | int i = 0; 24 | float* output = NULL; 25 | while(node != NULL) 26 | { 27 | printf("[%2d] %-10s %-20s ", i++, node->op_type, node->name); 28 | if(strcmp(node->op_type, "Conv") == 0) 29 | { 30 | output = conv2D_layer(model->graph, input, shapeInput, shapeOutput, node->name); 31 | } 32 | else if(strcmp(node->op_type, "Relu") == 0) 33 | { 34 | output = relu_layer(model->graph, input, shapeInput, shapeOutput, node->name); 35 | } 36 | else if(strcmp(node->op_type, "MaxPool") == 0) 37 | { 38 | output = maxpool_layer(model->graph, input, shapeInput, shapeOutput, node->name); 39 | } 40 | else if(strcmp(node->op_type, "Softmax") == 0) 41 | { 42 | output = softmax_layer(model->graph, input, shapeInput, shapeOutput, node->name); 43 | } 44 | else if(strcmp(node->op_type, "MatMul") == 0) 45 | { 46 | output = matmul_layer(model->graph, input, shapeInput, shapeOutput, node->name); 47 | } 48 | else if(strcmp(node->op_type, "Add") == 0) 49 | { 50 | output = add_layer(model->graph, input, shapeInput, shapeOutput, node->name); 51 | } 52 | else if(strcmp(node->op_type, "Identity") == 0) 53 | { 54 | node = onnx_graph_get_node_by_input(model->graph, node->output[0]); 55 | printf("\n"); 56 | 57 | continue; 58 | } 59 | else if(strcmp(node->op_type, "Transpose") == 0) 60 | { 61 | // Passed through unchanged; presumably the data is already in the layout 62 | // the layers expect (see ONNX_USE_NWHC) - verify for new models. 63 | node = onnx_graph_get_node_by_input(model->graph, node->output[0]); 64 | printf("\n"); 65 | 66 | continue; 67 | } 68 | else if(strcmp(node->op_type, "Reshape") == 0) 69 | { 70 | // Flatten [d0, d1, d2] to [1, d0*d1*d2, 1]; the data itself is unchanged. 71 | // Computed from shapeInput so it is also correct when Reshape is the first node. 72 | shapeOutput[1] = shapeInput[0] * shapeInput[1] * shapeInput[2]; 73 | shapeOutput[2] = 1; 74 | shapeOutput[0] = 1; 75 | printf("[%2" PRId64 ", %2" PRId64 ", %2" PRId64 "] --> [%2" PRId64 ", %2" PRId64 ", %2" PRId64 "]\n", shapeInput[0], shapeInput[1], shapeInput[2], shapeOutput[0], shapeOutput[1], shapeOutput[2]); 76 | 77 | memcpy(shapeInput, shapeOutput, sizeof(int64_t)*3); 78 | 79 | node = onnx_graph_get_node_by_input(model->graph, node->output[0]); 80 | continue; 81 | } 82 | else 83 | { 84 | printf("Unsupported operand: %s\n", node->op_type); 85 | // No output was produced for this node; stop rather than reuse a stale pointer. 86 | free(input); 87 | free(shapeOutput); 88 | return NULL; 89 | } 90 | 91 | if(output == NULL) 92 | { 93 | // The layer failed (missing weights, bad attributes or no memory) 94 | printf("failed\n"); 95 | free(input); 96 | free(shapeOutput); 97 | return NULL; 98 | } 99 | printf("[%2" PRId64 ", %2" PRId64 ", %2" PRId64 "] --> [%2" PRId64 ", %2" PRId64 ", %2" PRId64 "]\n", shapeInput[0], shapeInput[1], shapeInput[2], shapeOutput[0], shapeOutput[1], shapeOutput[2]); 100 | 101 | free(input); 102 | input = output; 103 | memcpy(shapeInput, shapeOutput, sizeof(int64_t)*3); 104 | 105 | node = onnx_graph_get_node_by_input(model->graph, node->output[0]); 106 | } 107 | output = input; 108 | free(shapeOutput); 109 | 110 | return output; 111 | } 112 |
-------------------------------------------------------------------------------- /src/onnx.h: -------------------------------------------------------------------------------- 1 | #ifndef __ONNX_H__ 2 | #define __ONNX_H__ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | 12 | #define ONNX_USE_NWHC 13 | 14 | #ifdef ONNX_USE_NWHC 15 | // NWHC 16 | #define W_INDEX 0 17 | #define H_INDEX 1 18 | #define C_INDEX 2 19 | #else 20 | // NCWH 21 | #define C_INDEX 0 22 | #define W_INDEX 1 23 | #define H_INDEX 2 24 | #endif 25 | 26 | // Model 27 | void onnx_tensor_info(const float* A, int64_t* shape, int64_t dim); 28 | float* onnx_model_run(Onnx__ModelProto* model, float* input, int64_t* shapeInput); 29 | 30 | // Layers 31 | float* conv2D_layer(Onnx__GraphProto* graph, const float *input, int64_t* shapeInput, int64_t* shapeOutput, const char* layer_name); 32 | float* relu_layer(Onnx__GraphProto* graph, const float *input, int64_t* shapeInput, int64_t* shapeOutput, const char* layer_name); 33 | float* maxpool_layer(Onnx__GraphProto* graph, float* input, int64_t* shapeInput, int64_t* shapeOutput, const char* layer_name); 34 | float* matmul_layer(Onnx__GraphProto* graph, const float *input, int64_t* shapeInput, int64_t* shapeOutput, const char* layer_name); 35 | float* add_layer(Onnx__GraphProto* graph, const float *input, int64_t* shapeInput, int64_t* shapeOutput, const char* layer_name); 36 | float* softmax_layer(Onnx__GraphProto* graph, const float *input, int64_t* shapeInput, int64_t* shapeOutput, const char* layer_name); 37 | 38 | // Operators 39 | float* transpose(const float* A, int64_t* shape, int64_t dim, int64_t* perm); 40 | 41 | void conv2D(const float *input, // input image 42 | const uint16_t dim_im_in_x, // input image dimention x 43 | const uint16_t dim_im_in_y, // input image dimention y 44 | const uint16_t ch_im_in, // number of input image channels 45 | const float *weight, // kernel weights 46 | const uint16_t ch_im_out, // number of filters, i.e., output image channels 47 | const uint16_t dim_kernel_x, // filter kernel size x 48 | const uint16_t dim_kernel_y, // filter kernel size y 49 | const uint16_t padding_x, // padding sizes x 50 | const uint16_t padding_y, // padding sizes y 51 | const uint16_t stride_x, // stride x 52 | const uint16_t stride_y, // stride y 53 | const float *bias, // bias 54 | float *output, // output image 55 | const uint16_t dim_im_out_x, // output image dimension x 56 | const uint16_t dim_im_out_y // output image dimension y 57 | ); 58 | 59 | void relu(const float *input, uint32_t size, float* output); 60 | 61 | void maxpool(const float *input, 62 | const uint16_t dim_im_in_x, // input image dimension x or W 63 | const uint16_t dim_im_in_y, // input image dimension y or H 64 | const uint16_t ch_im_in, // number of input image channels 65 | const uint16_t dim_kernel_x, // window kernel size 66 | const uint16_t dim_kernel_y, // window kernel size 67 | const uint16_t padding_x, // padding sizes 68 | const uint16_t padding_y, // padding sizes 69 | const uint16_t stride_x, // stride 70 | const uint16_t stride_y, // stride 71 | const uint16_t dim_im_out_x, // output image dimension x or W 72 | const uint16_t dim_im_out_y, // output image dimension y or H 73 | float *output); 74 | 75 | void matmul(const float *input, // pointer to vector 76 | const float *weight, // pointer to matrix 77 | const uint16_t dim_vec, // length of the vector 78 | const uint16_t num_of_rows, // numCol of A 79 | float *output); 80 | 81 | void add(const float *input, // pointer to vector 82 | const float *bias, // pointer to matrix 83 | const uint16_t dim_vec, // length of the vector 84 | float *output); 85 | 86 | void dense(const float *input, // pointer to vector 87 | const float *weight, // pointer to matrix 88 | const uint16_t dim_vec, // length of the vector 89 | const uint16_t num_of_rows, // numCol of A 90 | const float *bias, 91 | float *output); 92 | 93 | void softmax(const float *input, const uint32_t dim_vec, float *output); 94 | 95 | #endif // __ONNX_H__ 96 |
-------------------------------------------------------------------------------- /src/relu.c: -------------------------------------------------------------------------------- 1 | #include "onnx.h" 2 | 3 | // Element-wise ReLU: output[i] = max(input[i], 0) for i < size. 4 | void relu(const float *input, uint32_t size, float* output) 5 | { 6 | for (uint32_t i = 0; i < size; i++) 7 | { 8 | // Read-then-write per element (no memcpy), so output may be the input buffer 9 | output[i] = (input[i] < 0) ? 0.0f : input[i]; 10 | } 11 | } 12 | 13 | // Runs the ONNX "Relu" node on a tensor of shapeInput[0] * shapeInput[1] * 14 | // shapeInput[2] elements; the output shape equals the input shape. 15 | // Returns a newly malloc'd buffer owned by the caller, or NULL on failure. 16 | float* relu_layer(Onnx__GraphProto* graph, const float *input, int64_t* shapeInput, int64_t* shapeOutput, const char* layer_name) 17 | { 18 | assert(graph != NULL && input != NULL && layer_name != NULL && layer_name[0] != '\0'); 19 | 20 | int64_t len = shapeInput[0] * shapeInput[1] * shapeInput[2]; 21 | if(len <= 0) 22 | { 23 | // Unknown or empty shape 24 | return NULL; 25 | } 26 | 27 | float* output = (float*) malloc(sizeof(float)*len); 28 | if(output == NULL) 29 | { 30 | // No memory 31 | return NULL; 32 | } 33 | relu(input, len, output); 34 | 35 | // ReLU is element-wise: copy the input shape to the output shape. 36 | memcpy(shapeOutput, shapeInput, sizeof(int64_t)*3); 37 | 38 | return output; 39 | } 40 |
-------------------------------------------------------------------------------- /src/softmax.c: -------------------------------------------------------------------------------- 1 | #include "onnx.h" 2 | 3 | // Softmax over a vector: output[i] = exp(input[i]) / sum_j exp(input[j]). 4 | // The maximum is subtracted before exponentiating: mathematically the same 5 | // result, but expf() can no longer overflow to inf and produce NaN outputs. 6 | void softmax(const float *input, const uint32_t dim_vec, float *output) 7 | { 8 | if (dim_vec == 0) 9 | { 10 | return; 11 | } 12 | 13 | float max = input[0]; 14 | for (uint32_t i = 1; i < dim_vec; i++) 15 | { 16 | if (input[i] > max) 17 | { 18 | max = input[i]; 19 | } 20 | } 21 | 22 | float sum = 0.0f; 23 | for (uint32_t i = 0; i < dim_vec; i++) 24 | { 25 | output[i] = expf(input[i] - max); 26 | sum = sum + output[i]; 27 | } 28 | 29 | for (uint32_t i = 0; i < dim_vec; i++) 30 | { 31 | output[i] = output[i] / sum; 32 | } 33 | } 34 | 35 | // Runs the ONNX "Softmax" node, treating the input as a vector of 36 | // shapeInput[1] elements; the output shape equals the input shape. 37 | // Returns a newly malloc'd buffer owned by the caller, or NULL on failure. 38 | float* softmax_layer(Onnx__GraphProto* graph, const float *input, int64_t* shapeInput, int64_t* shapeOutput, const char* layer_name) 39 | { 40 | assert(graph != NULL && input != NULL && layer_name != NULL && layer_name[0] != '\0' && shapeInput[1] > 0); 41 | 42 | float* output = (float*) malloc(sizeof(float)*shapeInput[1]); 43 | if(output == NULL) 44 | { 45 | // No memory 46 | return NULL; 47 | } 48 | softmax(input, shapeInput[1], output); 49 | 50 | // Softmax does not change the shape: copy the input shape to the output shape. 51 | memcpy(shapeOutput, shapeInput, sizeof(int64_t)*3); 52 | 53 | return output; 54 | } 55 |
-------------------------------------------------------------------------------- /src/transpose.c: -------------------------------------------------------------------------------- 1 | #include "onnx.h" 2 | 3 | // Returns a newly malloc'd copy of the row-major dim-dimensional tensor A with 4 | // its axes permuted: axis i of the result is axis perm[i] of A (perm holds dim 5 | // entries). Returns NULL if memory cannot be allocated; the caller frees the result. 6 | float* transpose(const float* A, int64_t* shape, int64_t dim, int64_t* perm) 7 | { 8 | // Get array size 9 | int elem = 1; 10 | for(int i = 0; i < dim; i++) 11 | { 12 | elem = elem * shape[i]; 13 | } 14 | 15 | // Allocate the result plus index scratch buffers once (not per element), 16 | // and release everything on any allocation failure. 17 | float* B = malloc(sizeof(float) * elem); 18 | int* shapeB = malloc(sizeof(int) * dim); 19 | int* indexA = malloc(sizeof(int) * dim); 20 | int* indexB = malloc(sizeof(int) * dim); 21 | if(B == NULL || shapeB == NULL || indexA == NULL || indexB == NULL) 22 | { 23 | free(B); 24 | free(shapeB); 25 | free(indexA); 26 | free(indexB); 27 | return NULL; 28 | } 29 | for(int i = 0; i < dim; i++) 30 | { 31 | shapeB[i] = shape[perm[i]]; 32 | } 33 | 34 | // Transpose 35 | for(int src = 0; src < elem; src++) 36 | { 37 | // Unravel src into A's multi-index, then permute it 38 | // A[1][0][3] -> B[3][1][0] 39 | int temp = src; 40 | for(int i = dim-1; i >= 0; i--) 41 | { 42 | indexA[i] = temp % shape[i]; 43 | temp = temp / shape[i]; 44 | } 45 | for(int i = 0; i < dim; i++) 46 | { 47 | indexB[i] = indexA[perm[i]]; 48 | } 49 | 50 | // Ravel the permuted index into B's flat offset 51 | // #15 A[1][0][3] -> B[3][1][0] #21 52 | int dst = 0; 53 | temp = 1; 54 | for(int i = dim - 1; i >= 0; i--) 55 | { 56 | dst = dst + indexB[i] * temp; 57 | temp = temp * shapeB[i]; 58 | } 59 | 60 | B[dst] = A[src]; 61 | } 62 | 63 | free(indexA); 64 | free(indexB); 65 | free(shapeB); 66 | 67 | return B; 68 | } 69 | 70 | // Runs an ONNX "Transpose" node on a 3-D tensor. The node's 4-D perm attribute 71 | // is reduced to 3-D by dropping perm[0] (presumably the batch axis) and 72 | // subtracting 1 from the remaining entries. 73 | // Returns a newly malloc'd buffer owned by the caller, or NULL on failure. 74 | float* transpose_layer(Onnx__GraphProto* graph, const float *input, int64_t* shapeInput, int64_t* shapeOutput, const char* layer_name) 75 | { 76 | assert(graph != NULL && input != NULL && layer_name != NULL && layer_name[0] != '\0'); 77 | 78 | Onnx__NodeProto* node = onnx_graph_get_node_by_name(graph, layer_name); 79 | if(node == NULL || node->n_attribute < 1 || node->attribute[0]->n_ints < 4) 80 | { 81 | // layer not found, or no 4-D perm attribute 82 | return NULL; 83 | } 84 | 85 | int64_t perm_t[3]; 86 | int64_t* perm = node->attribute[0]->ints; 87 | for(int i = 0; i < 3; i++) 88 | { 89 | perm_t[i] = perm[i + 1] - 1; 90 | if(perm_t[i] < 0 || perm_t[i] > 2) 91 | { 92 | // Only permutations that keep axis 0 first can be mapped to 3-D 93 | return NULL; 94 | } 95 | } 96 | 97 | float* output = transpose(input, shapeInput, 3, perm_t); 98 | if(output == NULL) 99 | { 100 | // No memory 101 | return NULL; 102 | } 103 | 104 | shapeOutput[0] = shapeInput[perm_t[0]]; 105 | shapeOutput[1] = shapeInput[perm_t[1]]; 106 | shapeOutput[2] = shapeInput[perm_t[2]]; 107 | 108 | return output; 109 | } 110 |