├── .dockerignore ├── .gitignore ├── CMakeLists.txt ├── Dockerfile ├── Dockerfile.debug ├── Dockerfile.test ├── Jenkinsfile ├── README.md ├── Tutorial.md ├── Tutorial.zh-CN.md ├── example_client ├── cpp │ ├── CMakeLists.txt │ ├── benchmark.cc │ ├── grpc_client.h │ └── main.cc └── python │ ├── CMakeLists.txt │ ├── README.md │ └── gpt2ml_example.py ├── grpc_server ├── grpc_server.cc ├── grpc_server.h ├── model_loader.cc ├── model_loader.h ├── model_manager.h ├── thread_pool.cc └── thread_pool.h ├── java ├── pom.xml └── src │ └── main │ ├── java │ └── org │ │ └── onnx │ │ └── inference │ │ └── ONNXGRPCClient.java │ └── proto │ └── inference.proto ├── proposal.md ├── tests ├── CMakeLists.txt ├── grpc-test.cc └── models │ └── mnist │ ├── config │ ├── grpc_config.txt │ ├── img0.data │ ├── img1.data │ ├── model.onnx │ ├── model.so │ └── val_map.txt └── utils ├── CMakeLists.txt ├── inference.proto ├── onnx.proto └── onnx_reader.cc /.dockerignore: -------------------------------------------------------------------------------- 1 | Dockerfile 2 | Dockerfile.base 3 | Dockerfile.* 4 | build/ 5 | .git -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | cmake/* 2 | build/ 3 | Jenkinsfile 4 | LICENSE 5 | results/ 6 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | cmake_minimum_required(VERSION 3.14) 3 | set(CMAKE_DEBUG_POSTFIX d) 4 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -W -Wall -pthread") 5 | set(CMAKE_CXX_STANDARD 17) 6 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 7 | 8 | 9 | project(onnx-mlir-grpc-serving) 10 | 11 | 12 | 13 | # cmake -DCMAKE_BUILD_TYPE=Release -DONNX_COMPILER_DIR:STRING=/aivol/mlperf/AIU_bk/onnx-mlir-build -DCMAKE_PREFIX_PATH=/aivol/grpc_install .. 
14 | if(NOT CMAKE_BUILD_TYPE) 15 | set(CMAKE_BUILD_TYPE "Release") 16 | endif() 17 | 18 | 19 | 20 | message("-- CMAKE_HOST_SYSTEM_PROCESSOR: ${CMAKE_HOST_SYSTEM_PROCESSOR}") 21 | if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "s390x" ) 22 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBIGENDIAN=1") 23 | ELSE() 24 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBIGENDIAN=0") 25 | ENDIF() 26 | 27 | 28 | add_subdirectory(utils ${CMAKE_CURRENT_BINARY_DIR}/utils) 29 | include_directories(${CMAKE_CURRENT_BINARY_DIR}) 30 | # For onnx-mlir 31 | find_library(CRuntime 32 | NAMES cruntime 33 | PATHS ${ONNX_COMPILER_DIR}/lib) 34 | message(STATUS "CRuntime: ${CRuntime}") 35 | 36 | include_directories("${CMAKE_CURRENT_BINARY_DIR}") 37 | 38 | add_executable(grpc_server "grpc_server/grpc_server.cc" "grpc_server/thread_pool.cc" "grpc_server/model_loader.cc") 39 | target_link_libraries(grpc_server 40 | ${CRuntime} 41 | hw_grpc_proto 42 | ${_REFLECTION} 43 | ${_GRPC_GRPCPP} 44 | ) 45 | target_include_directories(grpc_server PRIVATE 46 | ${ONNX_COMPILER_DIR}/include) 47 | 48 | add_subdirectory(tests ${CMAKE_CURRENT_BINARY_DIR}/tests EXCLUDE_FROM_ALL) 49 | add_subdirectory(example_client/cpp ${CMAKE_CURRENT_BINARY_DIR}/cpp) 50 | add_subdirectory(example_client/python ${CMAKE_CURRENT_BINARY_DIR}/python) 51 | 52 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:22.04 2 | 3 | COPY --from=onnxmlirczar/onnx-mlir:latest /usr/local/bin/ /usr/local/bin/ 4 | COPY --from=onnxmlirczar/onnx-mlir:latest /usr/local/lib/ /usr/local/lib/ 5 | COPY --from=onnxmlirczar/onnx-mlir:latest /usr/local/lib64/ /usr/local/lib64/ 6 | COPY --from=onnxmlirczar/onnx-mlir:latest /usr/local/include/ /usr/local/include/ 7 | 8 | RUN apt-get update \ 9 | && apt-get install -y build-essential autoconf libtool pkg-config cmake git maven libssl-dev clang 10 | 11 | ARG WORK_DIR=/workdir 12 | WORKDIR ${WORK_DIR} 13 | 14 | RUN git clone -b v1.57.0 https://github.com/grpc/grpc \ 15 | && cd grpc; git submodule update --init \ 16 | && cd grpc;mkdir -p cmake/build;cd cmake/build;cmake -DCMAKE_BUILD_TYPE=Release -DgRPC_SSL_PROVIDER=package ../.. \ 17 | && cd grpc/cmake/build; make -j8;make install \ 18 | && cd /workdir; rm -rf grpc 19 | 20 | 21 | COPY . onnx-mlir-serving 22 | RUN cd onnx-mlir-serving \ 23 | && mkdir -p cmake/build; cd cmake/build \ 24 | && cmake -DCMAKE_BUILD_TYPE=Release ../.. \ 25 | && make -j8 \ 26 | && rm -rf /root/.cache 27 | 28 | 29 | ENTRYPOINT ["/bin/bash"] -------------------------------------------------------------------------------- /Dockerfile.debug: -------------------------------------------------------------------------------- 1 | FROM onnx/aigrpc-server:ff 2 | RUN apt-get install -y gdb 3 | RUN cd aigrpc-server \ 4 | && mkdir -p cmake/debug; cd cmake/debug \ 5 | && cmake -DCMAKE_BUILD_TYPE=Debug ../.. \ 6 | && cmake --build . 
 \
 && cd ../../tests \
 && for dir in models/*/; do echo $dir; ls $dir/model.onnx|xargs onnx-mlir; done; cp -r models ../cmake/debug/

ENTRYPOINT ["/bin/bash"]
--------------------------------------------------------------------------------
/Dockerfile.test:
--------------------------------------------------------------------------------
FROM onnx/aigrpc-server:ff

RUN apt-get install -y wget \
    && wget -O googletest.tar.gz https://github.com/google/googletest/archive/release-1.11.0.tar.gz \
    && tar xf googletest.tar.gz;mv googletest-release-1.11.0 googletest \
    && cd googletest;cmake -DBUILD_SHARED_LIBS=ON .;make;make install \
    && rm -rf googletest \
    && ldconfig


RUN cd aigrpc-server/cmake/build/tests;make \
    && cd ../../../tests; for dir in models/*/; do echo $dir; ls ${dir}model.onnx|xargs onnx-mlir; done \
    && cp -r models ../cmake/build/ \
    && cp ../cmake/build/tests/grpc-test ../cmake/build \
    && cd ../java; mvn verify


ENTRYPOINT ["/bin/bash"]
--------------------------------------------------------------------------------
/Jenkinsfile:
--------------------------------------------------------------------------------
pipeline {
    agent any
    stages {
        stage('Build Server Image') {
            steps {
                sh "docker build -f Dockerfile -t onnx/aigrpc-server ."
            }
        }
        stage('Build Debug Image') {
            steps {
                sh "docker build -f Dockerfile.debug -t onnx/aigrpc-debug ."
            }
        }
        stage('utest') {
            steps {
                sh "docker run -v ${PWD}/results:/results onnx/aigrpc-server -c 'cd /workdir/aigrpc-server/cmake/build;./grpc-test --gtest_output=xml:/results/'"
            }
        }
    }
    post {
        always {
            junit './results/**'
        }
        success {
            echo 'This will run only if successful'
        }
        failure {
            echo 'This will run only if failed'
        }
        unstable {
            echo 'This will run only if the run was marked as unstable'
        }
    }
}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ONNX-MLIR Serving

This project implements a GRPC server, written in C++, that serves [onnx-mlir](https://onnx.ai/onnx-mlir/) compiled models. Thanks to the C++ implementation, ONNX-MLIR Serving has very low latency overhead and high throughput.

ONNX-MLIR Serving provides dynamic batch aggregation and a worker pool to fully utilize the AI accelerators on the machine.

## [ONNX-MLIR](https://onnx.ai/onnx-mlir/)
ONNX-MLIR is compiler technology to transform a valid Open Neural Network Exchange (ONNX) graph into code that implements the graph with minimum runtime support. It implements the ONNX standard and is based on the underlying LLVM/MLIR compiler technology.

## Build

There are two ways to build this project.
+ [Build ONNX-MLIR Serving on local environment](#build-onnx-mlir-serving-on-local-environment)
+ [Build ONNX-MLIR Serving on Docker environment](#build-onnx-mlir-serving-on-docker-environment) (Recommended)

### Build ONNX-MLIR Serving on local environment


#### **Prerequisite**


##### 1. GRPC installed

[Build GRPC from Source](https://github.com/grpc/grpc/blob/master/BUILDING.md#build-from-source)

**GRPC installation dir example**: grpc/cmake/install


##### 2. ONNX-MLIR is built

Copy the include files from the onnx-mlir source tree to the onnx-mlir build dir.

```
ls onnx-mlir-serving/onnx-mlir-build/*
onnx-mlir-serving/onnx-mlir-build/include:
benchmark  CMakeLists.txt  google  onnx  onnx-mlir  OnnxMlirCompiler.h  OnnxMlirRuntime.h  rapidcheck  rapidcheck.h

onnx-mlir-serving/onnx-mlir-build/lib:
libcruntime.a
```

#### **Build ONNX-MLIR Serving**

```
cmake -DCMAKE_BUILD_TYPE=Release -DGRPC_DIR:STRING={GRPC_SRC_DIR} -DONNX_COMPILER_DIR:STRING={ONNX_MLIR_BUILD_DIR} -DCMAKE_PREFIX_PATH={GRPC_INSTALL_DIR} ../..
make -j
```

### Build ONNX-MLIR Serving on Docker environment


Build the AI GRPC server and client image:
```
docker build -t onnx/aigrpc-server .
```


## **Run ONNX-MLIR Server and Client**

### Server:
```
./grpc_server -h
usage: grpc_server [options]
  -w arg    wait time for batch size, default is 0
  -b arg    server side batch size, default is 1
  -n arg    thread number, default is 1

./grpc_server
```
### Add more models

Build the models directory:
```
cd cmake/build
mkdir models
```
Example models directory:
```
models
└── mnist
    ├── config
    ├── model.so
    └── model.onnx
```

#### config
Describes the model configuration; it can be generated using utils/OnnxReader.
Example of the mnist config:
```
input {
  name: "Input3"
  type {
    tensor_type {
      elem_type: 1
      shape {
        dim {
          dim_value: 1
        }
        dim {
          dim_value: 1
        }
        dim {
          dim_value: 28
        }
        dim {
          dim_value: 28
        }
      }
    }
  }
}
output {
  name: "Plus214_Output_0"
  type {
    tensor_type {
      elem_type: 1
      shape {
        dim {
          dim_value: 1
        }
        dim {
          dim_value: 10
        }
      }
    }
  }
}
max_batch_size: 1
```

### Inference request

See utils/inference.proto and utils/onnx.proto.
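As a concrete illustration of the request format, the sketch below hand-builds a single-tensor request for the mnist model and sends it with the `InferenceClient` helper from `example_client/cpp/grpc_client.h`. The 1x1x28x28 float shape comes from the mnist config above; the zero-filled pixel data and the `main` wrapper are placeholders, and the field accessors are the standard protobuf-generated ones for the `model_name` and `tensor` fields used by the example clients.

```cpp
#include <iostream>
#include <vector>

#include "grpc_client.h"  // example_client/cpp; pulls in utils/inference.grpc.pb.h

int main() {
  // A request carries the model name plus one onnx.TensorProto per model input.
  inference::InferenceRequest request;
  request.set_model_name("mnist");

  onnx::TensorProto* t = request.add_tensor();
  t->set_data_type(onnx::TensorProto_DataType_FLOAT);
  for (int64_t d : {1, 1, 28, 28}) t->add_dims(d);
  for (int i = 0; i < 1 * 1 * 28 * 28; ++i) t->add_float_data(0.0f);  // placeholder pixels

  // Send the request and read back the first output tensor's float data.
  InferenceClient client(
      grpc::CreateChannel("localhost:50051", grpc::InsecureChannelCredentials()));
  std::vector<float> scores = client.Inference(&request);
  std::cout << "result size: " << scores.size() << std::endl;
  return 0;
}
```

The example clients under `example_client/` build the same message from the `test_data_set_*` protobuf files shipped with a model instead of filling the tensor by hand.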
#### Use Batching
There are two places to set the batch size:
1. `max_batch_size` in the model config file
2. The `-b [batch size]` option when starting `grpc_server`

Situation 1: `grpc_server` is started without `-b`; the default batch size is 1, which means no batching.
Situation 2: `grpc_server -b` with batch_size > 1, and model A's config has max_batch_size > 1; queries to model A use the minimum of the two batch sizes.
Situation 3: `grpc_server -b` with batch_size > 1, and model B's config has max_batch_size = 1 (the default when generated); queries to model B are not batched.
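To make the three situations concrete, here is a small sketch of the effective batch size; it only restates the rule above and is not taken from the server source (the function name is illustrative):

```cpp
#include <algorithm>

// Batching happens only when both the server-side batch size (-b) and the
// model's max_batch_size are greater than 1; otherwise requests are served
// one at a time.
int effectiveBatchSize(int server_batch_size, int model_max_batch_size) {
  if (server_batch_size > 1 && model_max_batch_size > 1)
    return std::min(server_batch_size, model_max_batch_size);
  return 1;  // no batching
}
```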
### example client:
```
example_client/cpp or example_client/python
```


## Example

See [grpc-test.cc](./tests/grpc-test.cc)

- TEST_F is the simplest example of serving the mnist model.

--------------------------------------------------------------------------------
/Tutorial.md:
--------------------------------------------------------------------------------
# Inference MNIST Model with ONNX-MLIR-Serving

## Tutorial

This tutorial demonstrates how to perform inference on the MNIST model using the ONNX-MLIR-Serving framework. Follow the steps below to set up the environment and perform the inference.

### Step 1: Build ONNX-MLIR-Serving

1. Start by building the ONNX-MLIR-Serving Docker image. Navigate to the `onnx-mlir-serving` directory and run the following command:

```shell
cd onnx-mlir-serving
sudo docker build -t onnx/onnx-mlir-server .
```

### Step 2: Obtain the MNIST ONNX Model

1. Download the MNIST model and its associated input data sets from the ONNX Model Zoo. Use the following commands to download and extract the model:

```shell
wget https://github.com/onnx/models/raw/main/vision/classification/mnist/model/mnist-12.tar.gz
tar xvzf mnist-12.tar.gz
```
This will create the `mnist-12` directory, which contains the model and its associated data sets.

### Step 3: Compile the ONNX Model with ONNX-MLIR

1. Start a Docker container using the `onnx/onnx-mlir-server` image. Use the following command to run the container and open a shell inside it:
(The prompt "container#" denotes the shell inside the container.)

```shell
sudo docker run -it --name serving onnx/onnx-mlir-server
Emulate Docker CLI using podman. Create /etc/containers/nodocker to quiet msg.
container#
```

2. In another shell session, copy the `mnist-12` directory from the host to the running container using the following command:

```shell
sudo docker cp mnist-12/ serving:/workdir
```

3. Return to the shell environment of the container and navigate to the `mnist-12` directory. Use the following command to compile the `mnist-12.onnx` model into executable code:

```shell
container# cd mnist-12/
container# onnx-mlir -O3 --maccel=NNPA mnist-12.onnx
```

This will generate a shared library file named `mnist-12.so` in the `mnist-12` directory.

4. Verify that the files were generated successfully:

```shell
container# ls
mnist-12.onnx  mnist-12.so  test_data_set_0
```

### Step 4: Configure the Model for GRPC Serving

1. Prepare a directory structure for the model inside the `models` directory of the `grpc_server`:

```
grpc_server
├── models
│   └── mnist
│       ├── config
│       ├── model.so
│       └── model.onnx
```

2. Copy the model's ONNX and compiled `.so` files to the `models/mnist` directory inside the container:

```shell
container# cd /workdir/onnx-mlir-serving/cmake/build
container# mkdir -p models/mnist
container# cd models/mnist/
container# cp /workdir/mnist-12/mnist-12.onnx model.onnx
container# cp /workdir/mnist-12/mnist-12.so model.so
```

3. Generate a configuration file for the model based on its ONNX file:

```shell
container# ../../utils/OnnxReader ./model.onnx
container# ls
config  model.onnx  model.so
```

The `config` file is in protobuf text format and contains information about the model's input, output, and maximum batch size.

### Step 5: Start the GRPC Server

1. Return to the `grpc_server` build directory and start the server:

```shell
container# cd /workdir/onnx-mlir-serving/cmake/build
container# ./grpc_server
wait time 0ns
batch max size 1
thread number 1
Server listening on 0.0.0.0:50051
```

### Step 6: Send an Inference Request with a Client

1. Start a client to send an inference request for the MNIST model.
Use the provided C++ client example `Client` with the following command: 108 | 109 | ```shell 110 | container# cd /workdir/onnx-mlir-serving/cmake/build/cpp 111 | container# ./Client /workdir/mnist-12 mnist 112 | ``` 113 | 114 | The output will be an array representing the possibilities of each digit (0-9) based on the input. The inference result will be the digit with the highest possibility. 115 | 116 | Example output: 117 | 118 | ``` 119 | result size: 10 120 | -48.8125 121 | -4.6875 122 | -21.9062 123 | -13.3906 124 | 75.25 125 | -8.10938 126 | -52.3125 127 | 22.625 128 | -16.6875 129 | 48.0625 130 | ``` 131 | 132 | In this example, the highest possibility is 75.25 for the 5th number (digit 4), indicating that digit 4 is the predicted result. 133 | 134 | Note: Please make sure to adjust the paths and commands as necessary based on your specific setup and environment. 135 | 136 | This program sends first record of mnist dataset (input*.pb) to get the inference result. 137 | 138 | If you like, you can update example client code to read other datasets or your own record. 139 | 140 | #### Client Example Code 141 | 142 | ```C 143 | #include "grpc_client.h" 144 | 145 | // ./main 146 | int main(int argc, char** argv) { 147 | 148 | Dataset ds(argv[1], argv[2]); 149 | const char* address = "localhost:50051"; 150 | if(argc > 3) 151 | address = argv[3]; 152 | InferenceClient client(grpc::CreateChannel(address, grpc::InsecureChannelCredentials())); 153 | // ds.getInput(0) just get first record only. 154 | std::vector out = client.Inference(ds.getInput(0)); 155 | std::cout << "result size: " << out.size() << std::endl; 156 | for(size_t i = 0; i< out.size(); i++) 157 | std::cout << out[i] << std::endl; 158 | 159 | return 0; 160 | } 161 | ``` 162 | -------------------------------------------------------------------------------- /Tutorial.zh-CN.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # 使用ONNX-MLIR-Serving推理MNIST模型 4 | 5 | ## 教程 6 | 7 | 这个教程演示了如何使用ONNX-MLIR-Serving框架对MNIST模型进行推理。 8 | 按照以下步骤设置环境并执行推断。 9 | 10 | ### 步骤1:构建ONNX-MLIR-Serving 11 | 12 | 1. 首先构建ONNX-MLIR-Serving Docker镜像。进入`onnx-mlir-serving`目录并运行以下命令: 13 | 14 | ```shell 15 | cd onnx-mlir-serving 16 | sudo docker build -t onnx/onnx-mlir-server . 17 | ``` 18 | 19 | ### 步骤2:获取MNIST ONNX模型 20 | 21 | 1. 从ONNX Model Zoo下载MNIST模型及其相关输入数据集。使用以下命令下载并提取模型: 22 | 23 | ```shell 24 | wget https://github.com/onnx/models/raw/main/vision/classification/mnist/model/mnist-12.tar.gz 25 | 26 | tar xvzf mnist-12.tar.gz 27 | 28 | ``` 29 | 这将创建`mnist-12`目录,其中包含模型及其关联的数据集。 30 | 31 | ### 第三步:使用ONNX-MLIR编译ONNX模型 32 | 33 | 1. 使用 `onnx/onnx-mlir-server` 镜像启动一个 Docker 容器。使用以下命令运行容器并打开其内部的 shell 终端: 34 | ( "container#" 的意思是容器里面的shell) 35 | 36 | ```shell 37 | sudo docker run -it --name serving onnx/onnx-mlir-server 38 | ``` 39 | 40 | 2. 在另一个终端会话中,使用以下命令将主机上的`mnist-12`目录复制到正在运行的容器中: 41 | 42 | ``` 43 | sudo docker cp mnist-12/ serving:/workdir 44 | ``` 45 | 46 | 3. 返回容器的 shell 环境并导航到 `mnist-12` 目录。使用以下命令将 `mnist-12.onnx` 模型编译为可执行代码: 47 | 48 | ```shell 49 | container# cd mnist-12/ 50 | container# onnx-mlir -O3 --maccel=NNPA mnist-12.onnx 51 | ``` 52 | 这将在`mnist-12`目录中生成一个名为`mnist-12.so`的共享库文件。 53 | 54 | 4. 
验证文件是否成功生成: 55 | 56 | ```shell 57 | container# ls 58 | mnist-12.onnx mnist-12.so test_data_set_0 59 | mnist-12.onnx mnist-12.so test_data_set_0 60 | ``` 61 | 62 | ### 步骤 4:配置用于 GRPC 服务的模型 63 | 在`grpc_server`的`models`目录中创建一个模型的目录结构。 64 | 65 | ``` 66 | grpc_server 67 | ├── models 68 | │ └── mnist 69 | │ ├── config 70 | │ ├── model.so 71 | │ └── model.onnx 72 | ``` 73 | 74 | 1. 将模型的ONNX和编译的`.so`文件复制到容器内的`models/mnist`目录中: 75 | 76 | ```shell 77 | container#cd / workdir / onnx-mlir-serving / cmake / build 78 | container# mkdir -p models/mnist 79 | container# cd models/mnist/ 80 | container#cp /workdir/mnist-12/mnist-12.onnx model.onnx 81 | container# cp /workdir/mnist-12/mnist-12.so model.so 82 | ``` 83 | 2.根据ONNX文件为模型生成一个配置文件: 84 | 85 | ```shell 86 | container# ../../utils/OnnxReader ./model.onnx 87 | container# ls 88 | config model.onnx model.so 89 | ``` 90 | 91 | `config`文件应采用JSON格式,并包含有关模型的输入、输出和最大批处理大小的信息。 92 | 93 | ### 步骤5:启动GRPC服务器 94 | 95 | 1. 返回到`grpc_server`目录并启动服务器: 96 | ```shell 97 | container# cd /workdir/onnx-mlir-serving/cmake/build 98 | container# ./grpc_server 99 | wait time 0ns 100 | batch max size 1 101 | thread number 1 102 | Server listening on 0.0.0.0:50051 103 | ``` 104 | 105 | ### 步骤 6:使用客户端发送推理请求 106 | 107 | 开始一个客户端发送推论请求,对MNIST模型进行推论。使用提供的C++客户端示例`Client`,执行以下命令: 108 | ```shell 109 | container# cd /workdir/onnx-mlir-serving/cmake/build/cpp 110 | container# ./Client /workdir/mnist-12 mnist 111 | ``` 112 | 113 | 输出将是一个表示基于输入的每个数字(0-9)可能性的数组。推理结果将是具有最高可能性的数字。 114 | 115 | ``` 116 | 结果大小:10 117 | -48.8125 118 | -4.6875 119 | -21.9062 120 | -13.3906 121 | 75.25 122 | -8.10938 123 | -52.3125 124 | -52.3125 125 | 22.625 126 | -16.6875 127 | 48.0625 128 | ``` 129 | 在这个例子中,第5个数字(第4位数)的最高可能性为75.25,表明数字4是预测结果。 130 | 131 | 注意:请根据您的特定设置和环境调整路径和命令。 132 | 133 | 该程序发送mnist数据集的第一条记录(input*.pb)以获取推理结果。 134 | 如果你愿意,你可以更新示例客户端代码以读取其他数据集或你自己的记录。 135 | 136 | ### 客户端示例代码 137 | 138 | ```C 139 | #include "grpc_client.h" 140 | 141 | // ./main 142 | int main(int argc, char** argv) { 143 | 144 | Dataset ds(argv[1], argv[2]); 145 | const char* address = "localhost:50051"; 146 | if(argc > 3) 147 | address = argv[3]; 148 | InferenceClient client(grpc::CreateChannel(address, grpc::InsecureChannelCredentials())); 149 | // ds.getInput(0) just get first record only. 
150 | std::vector out = client.Inference(ds.getInput(0)); 151 | std::cout << "result size: " << out.size() << std::endl; 152 | for(size_t i = 0; i< out.size(); i++) 153 | std::cout << out[i] << std::endl; 154 | 155 | return 0; 156 | } 157 | ``` 158 | -------------------------------------------------------------------------------- /example_client/cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | project(grpc_cpp_client) 2 | 3 | include_directories(${CMAKE_CURRENT_SOURCE_DIR}) 4 | include_directories(${CMAKE_CURRENT_BINARY_DIR}) 5 | 6 | add_executable(Benchmark "${CMAKE_CURRENT_SOURCE_DIR}/benchmark.cc") 7 | target_link_libraries(Benchmark 8 | hw_grpc_proto 9 | ${_REFLECTION} 10 | ${_GRPC_GRPCPP} 11 | ) 12 | 13 | add_executable(Client "${CMAKE_CURRENT_SOURCE_DIR}/main.cc") 14 | target_link_libraries(Client 15 | hw_grpc_proto 16 | ${_REFLECTION} 17 | ${_GRPC_GRPCPP} 18 | ) 19 | 20 | -------------------------------------------------------------------------------- /example_client/cpp/benchmark.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "grpc_client.h" 4 | 5 | 6 | class MultiClientSimlate{ 7 | 8 | double _recordStart; 9 | double _recordEnd; 10 | double _totalTime; 11 | double _actuTime; 12 | std::vector _client_threads; 13 | std::mutex _log_mutex; 14 | Dataset &ds; 15 | char _host[200]; 16 | std::ofstream ofs; 17 | std::vector latencys; 18 | std::unordered_map> _resultMap; 19 | 20 | public: 21 | MultiClientSimlate(char* host, Dataset &ds, int threadNum, double recordStart, double recordEnd, double totalTime, char* logPrefix):ds(ds){ 22 | 23 | _recordStart = recordStart*1000; 24 | _recordEnd = recordEnd*1000; 25 | _totalTime = totalTime*1000; 26 | 27 | // char* _host = host; 28 | memcpy(_host, host, strlen(host)); 29 | 30 | char logname[200]; 31 | sprintf(logname, "build/log_%s.txt", logPrefix); 32 | 33 | ofs.open(logname, std::ios::out); 34 | 35 | // std::unordered_map> resultMap; 36 | 37 | high_resolution_clock::time_point startTime = high_resolution_clock::now(); 38 | createThread(threadNum); 39 | for (std::thread& thread : _client_threads) { 40 | //thread.detach(); 41 | if(thread.joinable()) 42 | thread.join(); 43 | } 44 | high_resolution_clock::time_point endTime = high_resolution_clock::now(); 45 | _actuTime = std::chrono::duration(endTime - startTime).count(); 46 | 47 | 48 | ofs.close(); 49 | 50 | } 51 | 52 | void createThread(int threadNum){ 53 | 54 | high_resolution_clock::time_point startTime = high_resolution_clock::now(); 55 | for(size_t i = 0; i timerRecord; 70 | std::map> resultMap; 71 | while(true){ 72 | int index = count % imageSize; 73 | 74 | if(!isStartRecord){ 75 | double d = std::chrono::duration(high_resolution_clock::now() - startTime).count(); 76 | if(d > _recordStart){ 77 | isStartRecord = true; 78 | } 79 | 80 | }else{ 81 | double d = std::chrono::duration(high_resolution_clock::now() - startTime).count(); 82 | if(d > _recordEnd){ 83 | isEndRecord = true; 84 | } 85 | 86 | } 87 | 88 | if(!isRun){ 89 | double d = std::chrono::duration(high_resolution_clock::now() - startTime).count(); 90 | if(d > 5000){ 91 | isRun = true; 92 | }else{ 93 | std::this_thread::sleep_for(std::chrono::milliseconds(10)); 94 | } 95 | } 96 | 97 | if(isRun){ 98 | Timer t; 99 | t.start = high_resolution_clock::now(); 100 | std::vector out = client.Inference(ds.getInput(index)); 101 | t.end = high_resolution_clock::now(); 102 | resultMap.emplace(index,out); 103 | 
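// Keep latency samples only while inside the measurement window
// (i.e. after _recordStart and before _recordEnd).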
if(isStartRecord && !isEndRecord){ 104 | timerRecord.push_back(t); 105 | } 106 | count ++; 107 | totalcount ++; 108 | } 109 | 110 | 111 | if(isEndRecord){ 112 | double d = std::chrono::duration(high_resolution_clock::now() - startTime).count(); 113 | if( d > _totalTime ){ 114 | for(Timer t:timerRecord){ 115 | std::lock_guard log_lock(_log_mutex); 116 | ofs << threadIndex << " "; 117 | double latency = std::chrono::duration(t.end -t.start).count(); 118 | latencys.emplace_back(latency); 119 | ofs << std::to_string(latency) << std::endl; 120 | } 121 | 122 | for (auto iter = resultMap.begin(); iter != resultMap.end(); ++iter){ 123 | std::lock_guard log_lock(_log_mutex); 124 | _resultMap.emplace(iter->first, iter->second); 125 | } 126 | 127 | break; 128 | } 129 | } 130 | 131 | 132 | } 133 | },startTime, i); 134 | } 135 | } 136 | 137 | char Bin2Hex(uint8_t four_bits) { 138 | char number = '0' + four_bits; 139 | char letter = ('A' - 10) + four_bits; 140 | return four_bits < 10 ? number : letter; 141 | } 142 | 143 | const std::string ArgValueTransform(const std::vector* data) { 144 | if (data == nullptr) { 145 | return "\"\""; 146 | } 147 | std::string hex; 148 | hex.reserve(data->size() + 2); 149 | hex.push_back('"'); 150 | for (auto b : *data) { 151 | hex.push_back(Bin2Hex(b >> 4)); 152 | hex.push_back(Bin2Hex(b & 0x0F)); 153 | } 154 | hex.push_back('"'); 155 | return hex; 156 | } 157 | 158 | double meanLatency(){ 159 | double sum = std::accumulate(latencys.begin(), latencys.end(), 0.0); 160 | // std::cout << sum << " " << latencys.size() << std::endl; 161 | double mean = sum / latencys.size(); 162 | 163 | std::cout << "total time: " << _actuTime << std::endl; 164 | std::cout << "qps: " << latencys.size() / _actuTime * 1000 << std::endl; 165 | 166 | 167 | char logname[200]; 168 | sprintf(logname, "build/log_acc%s.json", "2"); 169 | std::ofstream ofs_acc; 170 | size_t float_size = sizeof(float); 171 | ofs_acc.open(logname, std::ios::out); 172 | ofs_acc << "[" << std::endl; 173 | for (auto iter = _resultMap.begin(); iter != _resultMap.end(); ++iter){ 174 | 175 | uint8_t* src_begin = reinterpret_cast(iter->second.data()); 176 | uint8_t* src_end = src_begin + iter->second.size()*float_size; 177 | std::vector* data = new std::vector(src_begin, src_end); 178 | 179 | ofs_acc << "{\"qsl_idx\":\"" << iter->first << "\",\"data\":"<< ArgValueTransform(data) << "}," << std::endl; 180 | } 181 | ofs_acc << "]"; 182 | return mean; 183 | } 184 | }; 185 | 186 | // ./Benchmark ccf1 /aivol/inputs/ccf1_inputs 1 187 | // ./Benchmark model_name inputs threadNum 188 | int main(int argc, char** argv) { 189 | 190 | Dataset ds(argv[2], argv[1]); 191 | int threadNum = 64; 192 | char* logPrefix = argv[1]; 193 | 194 | if(argc > 3){ 195 | threadNum = std::stoi(argv[3]); 196 | } 197 | 198 | char* host = "localhost:50051"; 199 | if((host = getenv("grpc_server"))) 200 | std::cout << "Using " << host << std::endl; 201 | else 202 | std::cout << "Using " << host << std::endl; 203 | 204 | 205 | std::cout << "number of threads: " << threadNum << std::endl; 206 | MultiClientSimlate s(host, ds, threadNum, 0, 60, 60, logPrefix); 207 | 208 | InferenceClient* client = new InferenceClient(grpc::CreateChannel(host, grpc::InsecureChannelCredentials())); 209 | client->printStatistics(); 210 | 211 | double latency = s.meanLatency(); 212 | std::cout << "mean latency: " << latency << std::endl; 213 | return 0; 214 | } 215 | -------------------------------------------------------------------------------- /example_client/cpp/grpc_client.h: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | 16 | #include 17 | #include 18 | 19 | #include "utils/inference.grpc.pb.h" 20 | 21 | using grpc::Channel; 22 | using grpc::ClientAsyncResponseReader; 23 | using grpc::ClientContext; 24 | using grpc::CompletionQueue; 25 | using grpc::Status; 26 | using inference::InferenceService; 27 | using inference::InferenceResponse; 28 | using inference::InferenceRequest; 29 | using inference::PrintStatisticsRequest; 30 | using inference::PrintStatisticsResponse; 31 | using std::chrono::high_resolution_clock; 32 | 33 | struct Timer { 34 | high_resolution_clock::time_point start; 35 | high_resolution_clock::time_point end; 36 | }; 37 | 38 | class Dataset { 39 | public: 40 | std::vector requestsInput; 41 | std::string modelName; 42 | bool invalid = false; 43 | Dataset(std::string imagePath_, std::string modelName): 44 | modelName(modelName) 45 | { 46 | std::string imagePath = imagePath_; 47 | 48 | try{ 49 | std::filesystem::path p = imagePath; 50 | for (const auto &entry : std::filesystem::directory_iterator(p)) 51 | { 52 | std::string path = entry.path().string(); 53 | 54 | std::regex pattern(".*test_data_set_[0-9].*"); 55 | 56 | if (std::regex_match(path, pattern)){ 57 | // std::cout << path << std::endl; 58 | std::vector paths; 59 | 60 | for (const auto &inner_entry : std::filesystem::directory_iterator(entry.path())){ 61 | std::string inner_path = inner_entry.path().string(); 62 | if (inner_path.find("_input") == std::string::npos && inner_path.find("input") != std::string::npos) 63 | { 64 | // std::cout << inner_path << std::endl; 65 | paths.push_back(inner_path); 66 | } 67 | } 68 | std::sort(paths.begin(), paths.end()); 69 | InferenceRequest* request = new InferenceRequest(); 70 | request->set_model_name(modelName); 71 | for (std::string p_: paths){ 72 | // std::cout << p << std::endl; 73 | addTensor(p_.c_str(),request); 74 | } 75 | requestsInput.push_back(request); 76 | } 77 | } 78 | }catch (const std::filesystem::filesystem_error& ex) { 79 | std::cout << "Input directory cannot open directory: No such file or directory" < buffer(size); 95 | input.read(buffer.data(), size); // read raw data 96 | 97 | 98 | onnx::TensorProto* tensor = request->add_tensor(); 99 | tensor->ParseFromArray(buffer.data(), size); 100 | input.close(); 101 | } 102 | 103 | InferenceRequest* getInput(int index){ 104 | return requestsInput[index]; 105 | } 106 | 107 | int getImageCount(){ 108 | return requestsInput.size(); 109 | } 110 | 111 | }; 112 | 113 | class InferenceClient { 114 | public: 115 | explicit InferenceClient(std::shared_ptr channel) 116 | : stub_(InferenceService::NewStub(channel)) {} 117 | 118 | void printStatistics(){ 119 | PrintStatisticsRequest request; 120 | PrintStatisticsResponse response; 121 | ClientContext context; 122 | CompletionQueue cq; 123 | Status status; 124 | std::unique_ptr > rpc( 125 | stub_->PrepareAsyncPrintStatistics(&context, request, &cq)); 126 | rpc->StartCall(); 127 | rpc->Finish(&response, &status, (void*)1); 128 | void* got_tag; 129 | bool ok = false; 130 | GPR_ASSERT(cq.Next(&got_tag, &ok)); 131 | GPR_ASSERT(got_tag == (void*)1); 132 | } 133 | 134 | std::vector Inference(InferenceRequest* request_){ 135 | 136 | InferenceRequest request = *request_; 137 | 138 | const onnx::TensorProto& tensor = request.tensor(0); 139 | 140 | 
InferenceResponse reply; 141 | ClientContext context; 142 | CompletionQueue cq; 143 | Status status; 144 | std::unique_ptr> response_reader(stub_->PrepareAsyncInference(&context, request, &cq)); 145 | response_reader->StartCall(); 146 | response_reader->Finish(&reply, &status, (void*)1); 147 | void* got_tag; 148 | bool ok = false; 149 | 150 | GPR_ASSERT(cq.Next(&got_tag, &ok)); 151 | GPR_ASSERT(got_tag == (void*)1); 152 | GPR_ASSERT(ok); 153 | std::vector out(reply.tensor(0).float_data().begin(),reply.tensor(0).float_data().end()); 154 | return out; 155 | } 156 | 157 | private: 158 | std::unique_ptr stub_; 159 | }; -------------------------------------------------------------------------------- /example_client/cpp/main.cc: -------------------------------------------------------------------------------- 1 | #include "grpc_client.h" 2 | 3 | // ./main 4 | int main(int argc, char** argv) { 5 | 6 | Dataset ds(argv[1], argv[2]); 7 | const char* address = "localhost:50051"; 8 | if(argc > 3) 9 | address = argv[3]; 10 | InferenceClient client(grpc::CreateChannel(address, grpc::InsecureChannelCredentials())); 11 | std::vector out = client.Inference(ds.getInput(0)); 12 | std::cout << "result size: " << out.size() << std::endl; 13 | for(size_t i = 0; i< out.size(); i++) 14 | std::cout << out[i] << std::endl; 15 | 16 | return 0; 17 | } 18 | -------------------------------------------------------------------------------- /example_client/python/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | project(grpc_python_client) 2 | 3 | 4 | find_package(PythonInterp) 5 | 6 | set(onnx_proto_grpc_py "${CMAKE_CURRENT_BINARY_DIR}/onnx.pb2_grpc.py") 7 | set(onnx_proto_py "${CMAKE_CURRENT_BINARY_DIR}/onnx.pb2.py") 8 | set(hw_proto_py "${CMAKE_CURRENT_BINARY_DIR}/inference.pb2.py") 9 | set(hw_proto_grpc_py "${CMAKE_CURRENT_BINARY_DIR}/inference.pb2_grpc.py") 10 | 11 | 12 | execute_process(COMMAND ${PYTHON_EXECUTABLE} -m grpc_tools.protoc --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" --python_out= "${CMAKE_CURRENT_BINARY_DIR}" -I "${hw_proto_path}" "${hw_proto}" "${onnx_proto}" ) 13 | 14 | 15 | -------------------------------------------------------------------------------- /example_client/python/README.md: -------------------------------------------------------------------------------- 1 | ## Prepare for server side 2 | 3 | ``` 4 | mkdir models/gpt2lm 5 | copy model.onnx to models/gpt2lm 6 | ./utils/OnnxReader models/gpt2lm/model.onnx 7 | ./grpc_server 8 | ``` 9 | 10 | ## Prepare for python client side 11 | 12 | ``` 13 | cd example_client/python 14 | python3 -m grpc_tools.protoc -I ../../utils --python_out=. --grpc_python_out=. 
../../utils/inference.proto ../../utils/onnx.proto 15 | python3 gpt2ml_example.py 16 | ``` -------------------------------------------------------------------------------- /example_client/python/gpt2ml_example.py: -------------------------------------------------------------------------------- 1 | import grpc 2 | 3 | import inference_pb2_grpc 4 | import inference_pb2 5 | import onnx_pb2 6 | 7 | from transformers import GPT2Model, GPT2Tokenizer 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | import numpy as np 12 | 13 | def post(out): 14 | logits = out[0,0, -1, :] 15 | log_probs = F.softmax(torch.tensor(logits), dim=-1) 16 | _, prev = torch.topk(log_probs, k=1, dim=-1) 17 | return prev.tolist() 18 | 19 | 20 | def run_gpt2(host,maxlen): 21 | text="Tell me about IBM" 22 | print(text, end="") 23 | tokenizer = GPT2Tokenizer.from_pretrained('gpt2') 24 | tokens_ = tokenizer.encode(text) 25 | tokens = np.array(tokenizer.encode(text)) 26 | for i in range(maxlen): 27 | 28 | options = [('grpc.max_send_message_length', 512 * 1024 * 1024), ('grpc.max_receive_message_length', 512 * 1024 * 1024)] 29 | 30 | with grpc.insecure_channel(host,options = options) as channel: 31 | stub = inference_pb2_grpc.InferenceServiceStub(channel) 32 | tensor = onnx_pb2.TensorProto(data_type=7, dims=[1,1,len(tokens_)], int64_data=tokens_) 33 | response = stub.Inference(inference_pb2.InferenceRequest(tensor=[tensor], model_name='gpt2lm')) 34 | 35 | rdata = response.tensor[0].float_data 36 | 37 | prev = post(np.reshape(rdata, response.tensor[0].dims)) 38 | print(tokenizer.decode(prev), end="") 39 | tokens_.append(prev[0]) 40 | 41 | 42 | if __name__ == '__main__': 43 | run_gpt2('10.1.20.99:50051', 15) -------------------------------------------------------------------------------- /grpc_server/grpc_server.cc: -------------------------------------------------------------------------------- 1 | // #pragma GCC diagnostic ignored "-Wdelete-non-virtual-dtor" 2 | #include "grpc_server.h" 3 | #include 4 | 5 | std::chrono::high_resolution_clock::time_point originTime = std::chrono::high_resolution_clock::now(); 6 | OnnxMlirModelLoader modelLoder; 7 | 8 | 9 | void CallData::Proceed(void *modelManager){ 10 | if (status_ == CREATE) { 11 | // Make this instance progress to the PROCESS state. 12 | status_ = PROCESS; 13 | switch (s_type_){ 14 | case CallData::inference: 15 | service_->RequestInference(&ctx_, &request_, &responder_, cq_, cq_, this); 16 | break; 17 | case CallData::printStatistics: 18 | service_->RequestPrintStatistics(&ctx_, &printStatisticsRequest_, &printStatisticsResponder_, cq_, cq_, this); 19 | break; 20 | default: 21 | break; 22 | } 23 | // service_->RequestInference(&ctx_, &request_, &responder_, cq_, cq_,this); 24 | } else if (status_ == PROCESS) { 25 | 26 | switch (s_type_){ 27 | case CallData::inference: 28 | new CallData(service_, cq_,CallData::inference); 29 | static_cast(modelManager)->AddModel(this); 30 | now = high_resolution_clock::now(); 31 | break; 32 | case CallData::printStatistics: 33 | new CallData(service_, cq_,CallData::printStatistics); 34 | static_cast(modelManager)->PrintLogs(); 35 | status_ = FINISH; 36 | printStatisticsResponder_.Finish(printStatisticsReply_, Status::OK, this); 37 | break; 38 | default: 39 | break; 40 | } 41 | } else { 42 | GPR_ASSERT(status_ == FINISH); 43 | // Once in the FINISH state, deallocate ourselves (CallData). 
44 | delete this; 45 | } 46 | } 47 | 48 | void ServerImpl::Run(){ 49 | std::string server_address("0.0.0.0:50051"); 50 | ServerBuilder builder; 51 | builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); 52 | builder.RegisterService(&service_); 53 | cq_ = builder.AddCompletionQueue(); 54 | server_ = builder.BuildAndStart(); 55 | std::cout << "Server listening on " << server_address << std::endl; 56 | 57 | new CallData(&service_, cq_.get(),CallData::printStatistics); 58 | new CallData(&service_, cq_.get(),CallData::inference); 59 | HandleRpcs(0); 60 | // for(int i = 0; i < 2; i++){ 61 | // async_threads.emplace_back([this](int i){ 62 | // HandleRpcs(i); 63 | // },i); 64 | // } 65 | } 66 | 67 | void ServerImpl::HandleRpcs(int i){ 68 | // Spawn a new CallData instance to serve new clients. 69 | 70 | void* tag; // uniquely identifies a request. 71 | bool ok; 72 | while (true) { 73 | 74 | lock_guard lock(mtx_); 75 | GPR_ASSERT(cq_->Next(&tag, &ok)); 76 | 77 | // GPR_ASSERT(ok); 78 | if(ok){ 79 | static_cast(tag)->Proceed(&modelManager_); 80 | } 81 | 82 | } 83 | } 84 | 85 | void printHelp(){ 86 | std::cout << "usage: server [options]" << std::endl; 87 | std::cout << "-w arg " << "wait time for batch size" << std::endl; 88 | std::cout << "-b arg " << "server side batch size" << std::endl; 89 | std::cout << "-n arg " << "thread number" << std::endl; 90 | } 91 | 92 | int main(int argc, char** argv) { 93 | // std::AIInfrenceThreadPool tpool(5); 94 | int wait = 0; 95 | int batch_size = 1; 96 | int threadNum = 1; 97 | 98 | int argIndex = 1; 99 | 100 | while(argIndex < argc){ 101 | char *curArg = argv[argIndex]; 102 | if(strcmp(curArg, "-h") == 0){ 103 | printHelp(); 104 | return 0; 105 | } 106 | if(strcmp(curArg, "-w") == 0){ 107 | wait = std::stoi(argv[argIndex + 1]); 108 | argIndex = argIndex + 2; 109 | continue; 110 | } 111 | if(strcmp(curArg, "-b") == 0){ 112 | batch_size = std::stoi(argv[argIndex + 1]); 113 | argIndex = argIndex + 2; 114 | continue; 115 | } 116 | if(strcmp(curArg, "-n") == 0){ 117 | threadNum = std::stoi(argv[argIndex + 1]); 118 | argIndex = argIndex + 2; 119 | continue; 120 | } 121 | printHelp(); 122 | return 0; 123 | } 124 | 125 | std::cout << "wait time " << wait << " ns" << std::endl; 126 | std::cout << "batch max size " << batch_size << std::endl; 127 | std::cout << "thread number " << threadNum << std::endl; 128 | ServerImpl server(batch_size, threadNum, wait); 129 | server.Run(); 130 | 131 | return 0; 132 | } 133 | -------------------------------------------------------------------------------- /grpc_server/grpc_server.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef GRPC_SERVER_H 3 | #define GRPC_SERVER_H 4 | 5 | #include 6 | #include 7 | // #include 8 | // #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | // #include 16 | #include 17 | #include 18 | 19 | // #include 20 | // #include 21 | #include 22 | #include 23 | 24 | 25 | #include 26 | #include 27 | 28 | #include "utils/inference.grpc.pb.h" 29 | // #include "onnx.pb.h" 30 | 31 | 32 | #include "model_manager.h" 33 | 34 | using grpc::Server; 35 | using grpc::ServerAsyncResponseWriter; 36 | using grpc::ServerBuilder; 37 | using grpc::ServerCompletionQueue; 38 | using grpc::ServerContext; 39 | using grpc::Status; 40 | using inference::InferenceRequest; 41 | using inference::InferenceResponse; 42 | using inference::InferenceService; 43 | using inference::PrintStatisticsRequest; 44 | using 
inference::PrintStatisticsResponse; 45 | 46 | // #include "OnnxMlirRuntime.h" 47 | // extern "C"{ 48 | // OMTensorList *run_main_graph(OMTensorList *); 49 | // } 50 | 51 | using std::atomic; 52 | using std::condition_variable; 53 | using std::lock_guard; 54 | using std::mutex; 55 | using std::queue; 56 | using std::string; 57 | using std::thread; 58 | using std::unique_lock; 59 | using std::vector; 60 | using std::chrono::high_resolution_clock; 61 | 62 | 63 | class CallData : public AbstractCallData 64 | { 65 | 66 | public: 67 | enum ServiceType 68 | { 69 | inference = 0, 70 | printStatistics = 1 71 | }; 72 | 73 | public: 74 | CallData(InferenceService::AsyncService *service, ServerCompletionQueue *cq, ServiceType s_type) 75 | : service_(service), cq_(cq), responder_(&ctx_), printStatisticsResponder_(&ctx_), s_type_(s_type), status_(CREATE) 76 | { 77 | Proceed(NULL); 78 | } 79 | 80 | void Proceed(void *threadpool); 81 | 82 | InferenceRequest& getRequestData() 83 | { 84 | return request_; 85 | } 86 | 87 | 88 | void sendBack() 89 | { 90 | status_ = FINISH; 91 | responder_.Finish(reply_, Status::OK, this); 92 | } 93 | 94 | 95 | onnx::TensorProto* AddOutputTensor(){ 96 | onnx::TensorProto* tensor = reply_.add_tensor(); 97 | return tensor; 98 | } 99 | 100 | private: 101 | InferenceService::AsyncService *service_; 102 | ServerCompletionQueue *cq_; 103 | ServerContext ctx_; 104 | InferenceRequest request_; 105 | PrintStatisticsRequest printStatisticsRequest_; 106 | InferenceResponse reply_; 107 | PrintStatisticsResponse printStatisticsReply_; 108 | ServerAsyncResponseWriter responder_; 109 | ServerAsyncResponseWriter printStatisticsResponder_; 110 | ServiceType s_type_; 111 | enum CallStatus 112 | { 113 | CREATE, 114 | PROCESS, 115 | FINISH 116 | }; 117 | CallStatus status_; // The current serving state. 118 | }; 119 | 120 | class ServerImpl final 121 | { 122 | public: 123 | ~ServerImpl() 124 | { 125 | server_->Shutdown(); 126 | // Always shutdown the completion queue after the server. 
127 | cq_->Shutdown(); 128 | } 129 | 130 | ServerImpl(int batch_size, int threadNum_, int wait) : modelManager_(batch_size, threadNum_, wait) 131 | { 132 | } 133 | 134 | void Run(); 135 | 136 | private: 137 | void HandleRpcs(int i); 138 | 139 | std::shared_ptr cq_; 140 | InferenceService::AsyncService service_; 141 | OnnxMlirModelManager modelManager_; 142 | std::shared_ptr server_; 143 | vector async_threads; 144 | mutex mtx_; 145 | }; 146 | 147 | #endif 148 | -------------------------------------------------------------------------------- /grpc_server/model_loader.cc: -------------------------------------------------------------------------------- 1 | #include "model_loader.h" 2 | 3 | int check_endianness() 4 | { 5 | union 6 | { 7 | char c; 8 | int i; 9 | } u; 10 | 11 | // Assign an int value with a known byte pattern to the union 12 | u.i = 0x01020304; 13 | 14 | // Check the value of the char member of the union 15 | if (u.c == 0x01) 16 | { 17 | return 1; // Big-endian 18 | } 19 | 20 | if (u.c == 0x04) 21 | { 22 | return -1; // Little-endian 23 | } 24 | 25 | return 0; // Unknown endianness 26 | } 27 | 28 | 29 | uint64_t swap_uint64(uint64_t n) 30 | { 31 | // Swap lower and upper 32 bits 32 | n = ((n & 0x00000000FFFFFFFF) << 32) | ((n & 0xFFFFFFFF00000000) >> 32); 33 | // Swap adjacent 16 bits 34 | n = ((n & 0x0000FFFF0000FFFF) << 16) | ((n & 0xFFFF0000FFFF0000) >> 16); 35 | // Swap adjacent 8 bits 36 | n = ((n & 0x00FF00FF00FF00FF) << 8) | ((n & 0xFF00FF00FF00FF00) >> 8); 37 | return n; 38 | } 39 | 40 | uint32_t swap_uint32(uint32_t n) 41 | { 42 | // Swap adjacent bytes using bit shifts and masks 43 | n = ((n & 0x00FF00FF) << 8) | ((n & 0xFF00FF00) >> 8); 44 | 45 | // Swap non-adjacent bytes using bit shifts and masks 46 | n = (n << 16) | (n >> 16); 47 | // Return the swapped integer 48 | return n; 49 | } 50 | 51 | uint16_t swap_uint16(uint16_t n) 52 | { 53 | // Swap adjacent bytes using bit shifts and masks 54 | return (n << 8) | (n >> 8); 55 | } 56 | 57 | 58 | float swap_float32(float n) 59 | { 60 | // Define a union type that can hold both float32 and int32 61 | union 62 | { 63 | float f; 64 | uint32_t i; 65 | } u; 66 | 67 | // Assign n to the float member of the union 68 | u.f = n; 69 | 70 | // Swap the bytes of the int member of the union 71 | u.i = swap_uint32(u.i); // Use the function from previous example 72 | 73 | // Return the float member of the union 74 | return u.f; 75 | } 76 | 77 | inline void buildTensorProto(void *prediction,int bufferSize, OM_DATA_TYPE type, onnx::TensorProto* tensor_proto){ 78 | 79 | int64_t typeSize = getDataTypeSize(type); 80 | onnx::TensorProto_DataType tensor_pType = OM_TO_ONNX_DATA_TYPE.at(type); 81 | tensor_proto->set_data_type(tensor_pType); 82 | switch(tensor_pType){ 83 | case(onnx::TensorProto_DataType_FLOAT): 84 | case(onnx::TensorProto_DataType_COMPLEX64): 85 | tensor_proto->mutable_float_data()->Add((float*)prediction, (float*)prediction + bufferSize/typeSize); 86 | break; 87 | case(onnx::TensorProto_DataType_UINT8): 88 | case(onnx::TensorProto_DataType_INT8): 89 | case(onnx::TensorProto_DataType_UINT16): 90 | case(onnx::TensorProto_DataType_INT16): 91 | case(onnx::TensorProto_DataType_FLOAT16): 92 | case(onnx::TensorProto_DataType_INT32): 93 | case(onnx::TensorProto_DataType_BOOL): 94 | tensor_proto->mutable_int32_data()->Add((int32_t*)prediction, (int32_t*)prediction + bufferSize/typeSize); 95 | break; 96 | case(onnx::TensorProto_DataType_INT64): 97 | tensor_proto->mutable_int64_data()->Add((int64_t*)prediction, (int64_t*)prediction + 
bufferSize/typeSize); 98 | break; 99 | case(onnx::TensorProto_DataType_STRING): 100 | tensor_proto->mutable_string_data()->Add((char*)prediction, (char*)prediction + bufferSize/typeSize); 101 | break; 102 | case(onnx::TensorProto_DataType_DOUBLE): 103 | case(onnx::TensorProto_DataType_COMPLEX128): 104 | tensor_proto->mutable_double_data()->Add((double*)prediction, (double*)prediction + bufferSize/typeSize); 105 | break; 106 | case(onnx::TensorProto_DataType_UINT32): 107 | case(onnx::TensorProto_DataType_UINT64): 108 | tensor_proto->mutable_uint64_data()->Add((uint64_t*)prediction, (uint64_t*)prediction + bufferSize/typeSize); 109 | break; 110 | case(onnx::TensorProto_DataType_BFLOAT16): 111 | tensor_proto->set_raw_data(prediction, bufferSize); 112 | break; 113 | } 114 | 115 | return; 116 | } 117 | 118 | 119 | inline int64_t getTensorProtoData(const onnx::TensorProto& tensor){ 120 | int64_t data_size; 121 | 122 | if(tensor.raw_data().size() > 0){ 123 | int64_t typeSize = getDataTypeSize(ONNX_DATA_TYPE_TO_OM.at(onnx::TensorProto_DataType(tensor.data_type()))); 124 | return tensor.raw_data().size() / typeSize; 125 | } 126 | 127 | switch(tensor.data_type()){ 128 | case(onnx::TensorProto_DataType_FLOAT): 129 | case(onnx::TensorProto_DataType_COMPLEX64): 130 | data_size = tensor.float_data_size(); 131 | break; 132 | case(onnx::TensorProto_DataType_UINT8): 133 | case(onnx::TensorProto_DataType_INT8): 134 | case(onnx::TensorProto_DataType_UINT16): 135 | case(onnx::TensorProto_DataType_INT16): 136 | case(onnx::TensorProto_DataType_FLOAT16): 137 | case(onnx::TensorProto_DataType_INT32): 138 | case(onnx::TensorProto_DataType_BOOL): 139 | data_size = tensor.int32_data_size(); 140 | break; 141 | case(onnx::TensorProto_DataType_INT64): 142 | data_size = tensor.int64_data_size(); 143 | break; 144 | case(onnx::TensorProto_DataType_STRING): 145 | data_size = tensor.string_data_size(); 146 | break; 147 | case(onnx::TensorProto_DataType_DOUBLE): 148 | case(onnx::TensorProto_DataType_COMPLEX128): 149 | data_size = tensor.double_data_size(); 150 | break; 151 | case(onnx::TensorProto_DataType_UINT32): 152 | case(onnx::TensorProto_DataType_UINT64): 153 | data_size = tensor.uint64_data_size(); 154 | break; 155 | case(onnx::TensorProto_DataType_BFLOAT16): 156 | data_size = tensor.raw_data().size()/sizeof(2); 157 | break; 158 | } 159 | 160 | return data_size; 161 | } 162 | 163 | 164 | inline void copyTensorData(const onnx::TensorProto &tensor, void* dst, int64_t index, int64_t length){ 165 | 166 | void* src; 167 | bool fromRaw = false; 168 | switch(tensor.data_type()){ 169 | case(onnx::TensorProto_DataType_FLOAT): 170 | case(onnx::TensorProto_DataType_COMPLEX64): 171 | if (tensor.float_data_size() == 0) 172 | { 173 | fromRaw = true; 174 | break; 175 | } 176 | src = (void*)(&tensor.float_data().data()[index]); 177 | break; 178 | case(onnx::TensorProto_DataType_UINT8): 179 | case(onnx::TensorProto_DataType_INT8): 180 | case(onnx::TensorProto_DataType_UINT16): 181 | case(onnx::TensorProto_DataType_INT16): 182 | case(onnx::TensorProto_DataType_FLOAT16): 183 | case(onnx::TensorProto_DataType_INT32): 184 | case(onnx::TensorProto_DataType_BOOL): 185 | if (tensor.int32_data_size() == 0) 186 | { 187 | fromRaw = true; 188 | break; 189 | } 190 | src = (void*)(&tensor.int32_data().data()[index]); 191 | break; 192 | case(onnx::TensorProto_DataType_INT64): 193 | if (tensor.int64_data_size() == 0) 194 | { 195 | fromRaw = true; 196 | break; 197 | } 198 | src = (void*)(&tensor.int64_data().data()[index]); 199 | break; 200 | 
case(onnx::TensorProto_DataType_STRING): 201 | if (tensor.string_data_size() == 0) 202 | { 203 | fromRaw = true; 204 | break; 205 | } 206 | src = (void*)(&tensor.string_data().data()[index]); 207 | break; 208 | case(onnx::TensorProto_DataType_DOUBLE): 209 | case(onnx::TensorProto_DataType_COMPLEX128): 210 | if (tensor.double_data_size() == 0) 211 | { 212 | fromRaw = true; 213 | break; 214 | } 215 | src = (void*)(&tensor.double_data().data()[index]); 216 | break; 217 | case(onnx::TensorProto_DataType_UINT32): 218 | case(onnx::TensorProto_DataType_UINT64): 219 | if (tensor.uint64_data_size() == 0) 220 | { 221 | fromRaw = true; 222 | break; 223 | } 224 | src = (void*)(&tensor.uint64_data().data()[index]); 225 | break; 226 | case(onnx::TensorProto_DataType_BFLOAT16): 227 | 228 | src = (void*)(&tensor.raw_data().c_str()[index]); 229 | break; 230 | } 231 | 232 | if (fromRaw) 233 | { 234 | src = (void *)(&tensor.raw_data().c_str()[0]); 235 | if(check_endianness() == -1){ 236 | memcpy(dst, src, length); 237 | return; 238 | } 239 | int64_t *dst_64, *src_64; 240 | int32_t *dst_32, *src_32; 241 | int16_t *dst_16, *src_16; 242 | switch (tensor.data_type()) 243 | { 244 | case (onnx::TensorProto_DataType_FLOAT): 245 | case(onnx::TensorProto_DataType_COMPLEX64): 246 | case (onnx::TensorProto_DataType_INT32): 247 | dst_32 = (int32_t *)dst; 248 | src_32 = (int32_t *)src; 249 | for (int i = 0; i < length / 4; i++) 250 | { 251 | dst_32[i] = swap_uint32(src_32[i]); 252 | } 253 | break; 254 | case (onnx::TensorProto_DataType_DOUBLE): 255 | case (onnx::TensorProto_DataType_COMPLEX128): 256 | case (onnx::TensorProto_DataType_INT64): 257 | dst_64 = (int64_t *)dst; 258 | src_64 = (int64_t *)src; 259 | for (int i = 0; i < length / 8; i++) 260 | { 261 | dst_64[i] = swap_uint64(src_64[i]); 262 | } 263 | break; 264 | case(onnx::TensorProto_DataType_UINT16): 265 | case(onnx::TensorProto_DataType_INT16): 266 | case(onnx::TensorProto_DataType_FLOAT16): 267 | dst_16 = (int16_t *)dst; 268 | src_16 = (int16_t *)src; 269 | for (int i = 0; i < length / 8; i++) 270 | { 271 | dst_16[i] = swap_uint16(src_16[i]); 272 | } 273 | break; 274 | default: 275 | memcpy(dst, src, length); 276 | } 277 | } 278 | else 279 | { 280 | memcpy(dst, src, length); 281 | } 282 | } 283 | 284 | 285 | bool OnnxMlirModelLoader::LoadModel(char *model_path) 286 | { 287 | void *handle = dlopen(model_path, RTLD_LAZY); 288 | if (!handle) 289 | { 290 | std::cout << "Did not find model " << model_path << std::endl; 291 | return false; 292 | } 293 | success = true; 294 | dll_run_main_graph = (OMTensorList * (*)(OMTensorList *)) 295 | dlsym(handle, "run_main_graph"); 296 | assert(!dlerror() && "failed to load entry point"); 297 | dll_omTensorCreate = 298 | (OMTensor * (*)(void *, int64_t *, int64_t, OM_DATA_TYPE)) 299 | dlsym(handle, "omTensorCreate"); 300 | assert(!dlerror() && "failed to load omTensorCreate"); 301 | dll_omTensorListCreate = (OMTensorList * (*)(OMTensor **, int)) 302 | dlsym(handle, "omTensorListCreate"); 303 | assert(!dlerror() && "failed to load omTensorListCreate"); 304 | dll_omTensorListGetOmtByIndex = 305 | (OMTensor * (*)(OMTensorList *, int64_t)) dlsym(handle, "omTensorListGetOmtByIndex"); 306 | dll_omTensorGetDataPtr = (void *(*)(OMTensor *))dlsym(handle, "omTensorGetDataPtr"); 307 | 308 | dll_omTensorListDestroy = 309 | (void (*)(OMTensorList *))dlsym(handle, "omTensorListDestroy"); 310 | assert(!dlerror() && "failed to load omTensorListDestroy"); 311 | dll_omTensorDestroy = 312 | (void (*)(OMTensor *))dlsym(handle, "omTensorDestroy"); 
313 | return true; 314 | } 315 | 316 | OMTensor *OnnxMlirModelLoader::RunModel(void *x1Data, int64_t *shape, int64_t rank, OM_DATA_TYPE type) 317 | { 318 | OMTensor *x1 = dll_omTensorCreate(x1Data, shape, rank, type); 319 | OMTensor *list[1] = {x1}; 320 | OMTensorList *input = dll_omTensorListCreate(list, 1); 321 | OMTensorList *outputList = dll_run_main_graph(input); 322 | 323 | OMTensor *y = dll_omTensorListGetOmtByIndex(outputList, 0); 324 | omTensorDestroy(x1); 325 | return y; 326 | } 327 | 328 | OMTensorList *OnnxMlirModelLoader:: 329 | RunModel(OMTensor **list, int count) 330 | { 331 | OMTensorList *input = dll_omTensorListCreate(list, count); 332 | OMTensorList *outputList = dll_run_main_graph(input); 333 | return outputList; 334 | } 335 | 336 | OnnxMlirModel::OnnxMlirModel(const char *_model_name) 337 | { 338 | max_batchsize = -1; 339 | 340 | strcpy(model_name, _model_name); 341 | char model_path[70]; 342 | sprintf(model_path, "./models/%s/model.so", model_name); 343 | 344 | if (!loader.LoadModel(model_path)) 345 | { 346 | std::cout << "create failed" << std::endl; 347 | model_name[0] = 0; 348 | } 349 | 350 | char model_onnx[70]; 351 | sprintf(model_onnx, "./models/%s/config", model_name); 352 | ReadModelConfigFile(model_onnx); 353 | } 354 | 355 | void to_vector(const ::google::protobuf::RepeatedPtrField< ::onnx::ValueInfoProto > &info, std::vector *infos){ 356 | // std::vector infos; 357 | for (auto input_data: info) 358 | { 359 | TensorInfo input_info; 360 | input_info.data_type = input_data.type().tensor_type().elem_type(); 361 | auto shape = input_data.type().tensor_type().shape(); 362 | input_info.batch_dim = -1; 363 | 364 | if (shape.dim_size() != 0) 365 | { 366 | int size = shape.dim_size(); 367 | for (int i = 0; i < size; ++i) 368 | { 369 | auto dim = shape.dim(i); 370 | switch (dim.value_case()) 371 | { 372 | case onnx::TensorShapeProto_Dimension::ValueCase::kDimParam: 373 | input_info.shape.emplace_back(-1); 374 | input_info.batch_dim = i; 375 | break; 376 | case onnx::TensorShapeProto_Dimension::ValueCase::kDimValue: 377 | input_info.shape.emplace_back(dim.dim_value()); 378 | break; 379 | default: 380 | assert(false && "should never happen"); 381 | } 382 | } 383 | } 384 | infos->emplace_back(input_info); 385 | } 386 | } 387 | 388 | void OnnxMlirModel::ReadModelConfigFile(char *file_path){ 389 | inference::ModelConfig modelConfig; 390 | 391 | int fd = open(file_path, O_RDONLY); 392 | FileInputStream* input_stream = new FileInputStream(fd); 393 | 394 | google::protobuf::TextFormat::Parse(input_stream, &modelConfig); 395 | input_stream->Close(); 396 | close(fd); 397 | 398 | to_vector(modelConfig.input(), &inputs); 399 | to_vector(modelConfig.output(), &outputs); 400 | max_batchsize = modelConfig.max_batch_size(); 401 | } 402 | 403 | void OnnxMlirModel::ReadConfigFile(char *fileName) 404 | { 405 | std::ifstream fp2(fileName); 406 | if (!fp2.is_open()) 407 | { 408 | printf("read model config error\n"); 409 | return; 410 | } 411 | char modelName[20]; 412 | int rank; 413 | int typeSize; 414 | fp2 >> modelName; 415 | fp2 >> typeName >> typeSize; 416 | fp2 >> rank; 417 | int64_t *shape = (int64_t *)malloc(rank * sizeof(int64_t)); 418 | int index = 0; 419 | while (index < rank) 420 | { 421 | fp2 >> shape[index]; 422 | index++; 423 | } 424 | fp2 >> max_batchsize; 425 | fp2 >> batch_dim; 426 | fp2.close(); 427 | } 428 | 429 | bool OnnxMlirModel::CheckInputData(AbstractCallData *data){ 430 | bool match = false; 431 | do{ 432 | int input_size = data->getRequestData().tensor_size(); 
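// Validate the request against the model config: the tensor count, rank,
// element type, and element count of every input must match; otherwise the
// request is answered immediately without running inference.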
433 | if (inputs.size() != input_size) 434 | break; 435 | for(size_t count= 0; count < input_size; count ++){ 436 | 437 | const onnx::TensorProto& tensor = data->getRequestData().tensor(count); 438 | size_t dim_size = tensor.dims_size(); 439 | if(dim_size != inputs[count].shape.size()) 440 | break; 441 | 442 | if(tensor.data_type() != inputs[count].data_type){ 443 | break; 444 | } 445 | 446 | int64_t data_length = 1; 447 | for(size_t i=0; isendBack(); 470 | return ; 471 | } 472 | std::unique_lock lock{lock_}; 473 | inference_data.push(data); 474 | } 475 | 476 | bool OnnxMlirModel::Ready(int wait, int max_batchsize_) 477 | { 478 | bool check = false; 479 | int size = 0; 480 | double d = 0; 481 | 482 | 483 | { 484 | std::lock_guard lock(lock_); 485 | 486 | size = inference_data.size(); 487 | if((max_batchsize_ > 1 && size >= max_batchsize_) || (size >=max_batchsize)) 488 | { 489 | check = true; 490 | } 491 | else if (size > 0 && wait != 0) 492 | { 493 | high_resolution_clock::time_point pnow = inference_data.front()->now; 494 | high_resolution_clock::time_point now = high_resolution_clock::now(); 495 | d = std::chrono::duration(now - pnow).count(); 496 | double w = (double)wait; 497 | check = d >= w; 498 | 499 | } 500 | else 501 | { 502 | check = true; 503 | } 504 | } 505 | return check; 506 | } 507 | 508 | 509 | void OnnxMlirModel::Add_log(LogInfo info, std::stringstream& log_stream) 510 | { 511 | 512 | log_stream << std::this_thread::get_id() << "," << info.key << "," << info.inference_size << ","; 513 | log_stream << Calulate_duration(info.end, info.start) << ","; 514 | log_stream << Calulate_duration(info.start, originTime) << ","; 515 | log_stream << Calulate_duration(info.end, originTime) << std::endl; 516 | 517 | } 518 | 519 | Task OnnxMlirModel::Perpare_and_run(AbstractCallData *callData) 520 | { 521 | 522 | if (!CheckInputData(callData)){ 523 | return [this, callData](std::function log) 524 | { 525 | callData->sendBack(); 526 | }; 527 | } 528 | 529 | return [this, callData](std::function log) 530 | { 531 | 532 | std::stringstream log_stream; 533 | high_resolution_clock::time_point pnow = callData->now; 534 | high_resolution_clock::time_point now = high_resolution_clock::now(); 535 | 536 | Add_log({pnow, now, "wake up", 1}, log_stream); 537 | 538 | size_t input_size = callData->getRequestData().tensor_size(); 539 | OMTensor* tensorlist[input_size]; 540 | for(size_t index=0; index < input_size; index++){ 541 | 542 | const onnx::TensorProto& tensor = callData->getRequestData().tensor(index); 543 | int64_t rank = tensor.dims_size(); 544 | const int64_t *shape = tensor.dims().data(); 545 | OM_DATA_TYPE type = ONNX_DATA_TYPE_TO_OM.at(onnx::TensorProto_DataType(tensor.data_type())); 546 | OMTensor *omTensor = omTensorCreateEmpty(const_cast(shape), rank, type); 547 | void *data = omTensorGetDataPtr(omTensor); 548 | int64_t buffsize = omTensorGetBufferSize(omTensor); 549 | copyTensorData(tensor, data, 0, buffsize); 550 | tensorlist[index] = omTensor; 551 | } 552 | 553 | 554 | OMTensorList *yList = loader.RunModel(tensorlist,input_size); 555 | int result_size = omTensorListGetSize(yList); 556 | 557 | for(size_t index = 0; index< result_size; index++){ 558 | 559 | OMTensor* y = omTensorListGetOmtByIndex(yList, index); 560 | int buffsize = omTensorGetBufferSize(y); 561 | int rank = omTensorGetRank(y); 562 | const int64_t *shape = omTensorGetShape(y); 563 | void *prediction = (void*)omTensorGetDataPtr(y); 564 | OM_DATA_TYPE type = omTensorGetDataType(y); 565 | 566 | onnx::TensorProto* tensor = 
callData->AddOutputTensor(); 567 | buildTensorProto(prediction, buffsize, type, tensor); 568 | tensor->mutable_dims()->Add(shape, shape+rank); 569 | omTensorDestroy(y); 570 | } 571 | 572 | callData->sendBack(); 573 | 574 | for(auto x: tensorlist){ 575 | omTensorDestroy(x); 576 | } 577 | 578 | high_resolution_clock::time_point now1 = high_resolution_clock::now(); 579 | Add_log({now, now1, "inference", 1}, log_stream); 580 | }; 581 | } 582 | 583 | 584 | Task OnnxMlirModel::Perpare_and_run(int64_t maxBatchsize) 585 | { 586 | 587 | return [this, maxBatchsize](std::function log) 588 | { 589 | int count = 0; 590 | 591 | std::vector my_queue; 592 | { 593 | std::unique_lock lock{lock_}; 594 | { 595 | int totalsize = inference_data.size(); 596 | while (count < maxBatchsize && count < totalsize) 597 | { 598 | my_queue.push_back(inference_data.front()); 599 | inference_data.pop(); 600 | count++; 601 | } 602 | } 603 | } 604 | 605 | 606 | int64_t batchsize = count; 607 | if (batchsize < 1) 608 | { 609 | return; 610 | } 611 | 612 | std::stringstream log_stream; 613 | high_resolution_clock::time_point pnow = my_queue[0]->now; 614 | high_resolution_clock::time_point now = high_resolution_clock::now(); 615 | 616 | Add_log({pnow, now, "wake up", batchsize}, log_stream); 617 | 618 | pnow = high_resolution_clock::now(); 619 | 620 | // merage 621 | int input_size = my_queue[0]->getRequestData().tensor_size(); 622 | 623 | OMTensor* tensorlist[input_size]; 624 | for(size_t index = 0; index< input_size; index++){ 625 | TensorInfo& info = inputs[index]; 626 | int64_t batch_dim = 0; 627 | if(info.batch_dim >= 0){ 628 | batch_dim = info.batch_dim; 629 | } 630 | 631 | const onnx::TensorProto& tensor = my_queue[0]->getRequestData().tensor(index); 632 | int64_t rank = tensor.dims_size(); //input.rank; 633 | const int64_t *single_shape = tensor.dims().data(); //input.shape; 634 | OM_DATA_TYPE type = ONNX_DATA_TYPE_TO_OM.at(onnx::TensorProto_DataType(tensor.data_type())); 635 | 636 | int64_t typeSize = getDataTypeSize(type); 637 | 638 | int64_t shape[rank]; 639 | 640 | // int totalsize = 1; 641 | for (int64_t i = 0; i < rank; i++) 642 | { 643 | shape[i] = my_queue[0]->getRequestData().tensor(index).dims(i); //single_shape[i]; 644 | } 645 | 646 | shape[batch_dim] = batchsize; 647 | 648 | OMTensor *omTensor = omTensorCreateEmpty(shape, rank, type); 649 | void *xData = omTensorGetDataPtr(omTensor); 650 | 651 | 652 | int before = 1; 653 | for (int64_t i = 0; i < batch_dim; i++) 654 | { 655 | before *= shape[i]; 656 | } 657 | int after = 1; 658 | for (int64_t i = batch_dim + 1; i < rank; i++) 659 | { 660 | after *= shape[i]; 661 | } 662 | 663 | size_t b = 0; 664 | for(AbstractCallData* callData: my_queue){ 665 | 666 | const onnx::TensorProto& tensor = callData->getRequestData().tensor(index); 667 | 668 | 669 | for (int j = 0; j < before; j++) 670 | { 671 | void * dst = xData + (j * batchsize + b) * after*typeSize; 672 | copyTensorData(tensor, dst, j * after, after*typeSize); 673 | } 674 | b++; 675 | 676 | } 677 | tensorlist[index] = omTensor; 678 | } 679 | 680 | 681 | now = high_resolution_clock::now(); 682 | Add_log({pnow, now, "merge", batchsize}, log_stream); 683 | OMTensorList *yList = loader.RunModel(tensorlist,input_size); 684 | int result_size = omTensorListGetSize(yList); 685 | for(size_t index = 0; index< result_size; index++){ 686 | 687 | // get batch dim 688 | TensorInfo info = outputs[index]; 689 | int64_t batch_dim = 0; 690 | if(info.batch_dim >= 0){ 691 | batch_dim = info.batch_dim; 692 | } 693 | 694 | OMTensor* 
y = omTensorListGetOmtByIndex(yList, index); 695 | int buffsize = omTensorGetBufferSize(y); 696 | int rank = omTensorGetRank(y); 697 | const int64_t *resultShape = omTensorGetShape(y); 698 | int64_t shape[omTensorGetRank(y)]; 699 | uint8_t *prediction = (uint8_t*)omTensorGetDataPtr(y); 700 | OM_DATA_TYPE type = omTensorGetDataType(y); 701 | int64_t typeSize = getDataTypeSize(type); 702 | int singleBufferSize = buffsize / batchsize; 703 | 704 | 705 | int before = 1; 706 | for (int64_t i = 0; i < batch_dim; i++) 707 | { 708 | shape[i]=resultShape[i]; 709 | before *= shape[i]; 710 | } 711 | int after = 1; 712 | for (int64_t i = batch_dim + 1; i < rank; i++) 713 | { 714 | shape[i]=resultShape[i]; 715 | after *= shape[i]; 716 | } 717 | shape[batch_dim] =1; 718 | 719 | size_t b = 0; 720 | for(AbstractCallData* callData: my_queue){ 721 | onnx::TensorProto* tensor = callData->AddOutputTensor(); 722 | uint8_t single_result[singleBufferSize]; 723 | for (int j = 0; j < before; j++) 724 | { 725 | uint8_t * src = &prediction[(j * batchsize + b) * after*typeSize]; 726 | uint8_t * dst = &single_result[j * after*typeSize]; 727 | memcpy(dst, src, after*typeSize); 728 | } 729 | b++; 730 | 731 | buildTensorProto(single_result, singleBufferSize, type, tensor); 732 | tensor->mutable_dims()->Add(shape, shape+rank); 733 | } 734 | omTensorDestroy(y); 735 | } 736 | 737 | for(AbstractCallData* callData: my_queue){ 738 | callData->sendBack(); 739 | } 740 | 741 | now = high_resolution_clock::now(); 742 | Add_log({pnow, now, "inference", batchsize}, log_stream); 743 | for(auto x: tensorlist){ 744 | omTensorDestroy(x); 745 | } 746 | 747 | log(log_stream.str()); 748 | }; 749 | } 750 | -------------------------------------------------------------------------------- /grpc_server/model_loader.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #pragma GCC diagnostic ignored "-Wunused-parameter" 3 | #ifndef ONNXMLIR_MODEL_LOADER_H 4 | #define ONNXMLIR_MODEL_LOADER_H 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include "OnnxMlirRuntime.h" 23 | #include "utils/inference.grpc.pb.h" 24 | 25 | using google::protobuf::io::FileOutputStream; 26 | using google::protobuf::io::FileInputStream; 27 | 28 | using inference::InferenceRequest; 29 | using std::chrono::high_resolution_clock; 30 | using Task = std::function)>; 31 | 32 | extern std::chrono::high_resolution_clock::time_point originTime; 33 | 34 | typedef struct tensorInfo_ 35 | { 36 | int32_t data_type; 37 | std::vector shape; 38 | int64_t batch_dim; 39 | } TensorInfo; 40 | 41 | 42 | //TODO bfloat not support for onnx-mlir. 
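The batched Perpare_and_run above merges per-request tensors by computing "before"/"after" products around the batch dimension and copying each request into its slot of the merged buffer. A minimal standalone sketch of that index arithmetic on plain floats (illustrative only; the real code goes through copyTensorData and OMTensor buffers):

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Copy one request's tensor (batch size 1 along batch_dim) into slot `b`
// of a merged batch of size `batchsize`. `shape` is the single-request shape.
void PlaceInBatch(const float *single, float *merged,
                  const std::vector<int64_t> &shape,
                  int64_t batch_dim, int64_t batchsize, int64_t b) {
  int64_t before = 1, after = 1;
  for (int64_t i = 0; i < batch_dim; ++i) before *= shape[i];
  for (int64_t i = batch_dim + 1; i < static_cast<int64_t>(shape.size()); ++i)
    after *= shape[i];

  // Merged layout is [dims before batch_dim..., batchsize, dims after...],
  // so block (j, b) starts at offset (j * batchsize + b) * after.
  for (int64_t j = 0; j < before; ++j)
    std::memcpy(merged + (j * batchsize + b) * after,
                single + j * after,
                after * sizeof(float));
}
```

Splitting the model output back into per-request tensors uses the same offsets in the other direction, which is what the result loop above does with its singleBufferSize staging buffers.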
43 | const std::map ONNX_DATA_TYPE_TO_OM = { 44 | {onnx::TensorProto_DataType_BOOL, ONNX_TYPE_BOOL}, // bool -> BOOL // char -> INT8 (platform dependent, can be UINT8) 45 | {onnx::TensorProto_DataType_INT8, ONNX_TYPE_INT8}, // int8_t -> INT8 46 | {onnx::TensorProto_DataType_UINT8, ONNX_TYPE_UINT8}, // uint8_t -> UINT8, unsigned char -> UNIT 8 47 | {onnx::TensorProto_DataType_INT16, ONNX_TYPE_INT16}, // int16_t -> INT16, short -> INT16 48 | {onnx::TensorProto_DataType_UINT16, ONNX_TYPE_UINT16}, // uint16_t -> UINT16, unsigned short -> UINT16 49 | {onnx::TensorProto_DataType_INT32, ONNX_TYPE_INT32}, // int32_t -> INT32, int -> INT32 50 | {onnx::TensorProto_DataType_UINT32, ONNX_TYPE_UINT32}, // uint32_t -> UINT32, unsigned int -> UINT32 51 | {onnx::TensorProto_DataType_INT64, ONNX_TYPE_INT64}, // int64_t -> INT64, long -> INT64 52 | {onnx::TensorProto_DataType_UINT64, ONNX_TYPE_UINT64}, // uint64_t -> UINT64, unsigned long -> UINT64 53 | {onnx::TensorProto_DataType_FLOAT, ONNX_TYPE_FLOAT}, // float -> FLOAT 54 | {onnx::TensorProto_DataType_DOUBLE, ONNX_TYPE_DOUBLE}, // double -> DOUBLE 55 | {onnx::TensorProto_DataType_STRING, ONNX_TYPE_STRING}, // const char * -> STRING 56 | {onnx::TensorProto_DataType_COMPLEX64, ONNX_TYPE_COMPLEX64}, // _Complex float -> COMPLEX64 57 | {onnx::TensorProto_DataType_COMPLEX128, ONNX_TYPE_COMPLEX128}, // _Complex double -> COMPLEX128 58 | }; 59 | 60 | const std::map OM_TO_ONNX_DATA_TYPE = { 61 | {ONNX_TYPE_BOOL, onnx::TensorProto_DataType_BOOL,}, // bool -> BOOL // char -> INT8 (platform dependent, can be UINT8) 62 | {ONNX_TYPE_INT8, onnx::TensorProto_DataType_INT8}, // int8_t -> INT8 63 | {ONNX_TYPE_UINT8, onnx::TensorProto_DataType_UINT8}, // uint8_t -> UINT8, unsigned char -> UNIT 8 64 | {ONNX_TYPE_INT16, onnx::TensorProto_DataType_INT16}, // int16_t -> INT16, short -> INT16 65 | {ONNX_TYPE_UINT16, onnx::TensorProto_DataType_UINT16}, // uint16_t -> UINT16, unsigned short -> UINT16 66 | {ONNX_TYPE_INT32, onnx::TensorProto_DataType_INT32}, // int32_t -> INT32, int -> INT32 67 | {ONNX_TYPE_UINT32, onnx::TensorProto_DataType_UINT32}, // uint32_t -> UINT32, unsigned int -> UINT32 68 | {ONNX_TYPE_INT64, onnx::TensorProto_DataType_INT64}, // int64_t -> INT64, long -> INT64 69 | {ONNX_TYPE_UINT64, onnx::TensorProto_DataType_UINT64}, // uint64_t -> UINT64, unsigned long -> UINT64 70 | {ONNX_TYPE_FLOAT, onnx::TensorProto_DataType_FLOAT}, // float -> FLOAT 71 | {ONNX_TYPE_DOUBLE, onnx::TensorProto_DataType_DOUBLE, }, // double -> DOUBLE 72 | {ONNX_TYPE_STRING, onnx::TensorProto_DataType_STRING}, // const char * -> STRING 73 | {ONNX_TYPE_COMPLEX64, onnx::TensorProto_DataType_COMPLEX64}, // _Complex float -> COMPLEX64 74 | {ONNX_TYPE_COMPLEX128, onnx::TensorProto_DataType_COMPLEX128}, // _Complex double -> COMPLEX128 75 | }; 76 | 77 | 78 | class AbstractCallData 79 | { 80 | public: 81 | virtual InferenceRequest &getRequestData() = 0; 82 | // virtual std::vector getInputsData() = 0; 83 | virtual void sendBack() = 0; 84 | // virtual void AddInputs(TensorData data) = 0; 85 | virtual onnx::TensorProto* AddOutputTensor() = 0; 86 | high_resolution_clock::time_point now; 87 | }; 88 | 89 | class OnnxMlirModelLoader 90 | { 91 | public: 92 | bool LoadModel(char *model_path); 93 | 94 | 95 | OMTensor *(*dll_omTensorCreate)(void *, int64_t *, int64_t, OM_DATA_TYPE); 96 | OMTensor *RunModel(void *x1Data, int64_t *shape, int64_t rank, OM_DATA_TYPE type); 97 | OMTensorList *RunModel(OMTensor **list, int); 98 | 99 | bool success{false}; 100 | 101 | private: 102 | OMTensorList 
*(*dll_run_main_graph)(OMTensorList *); 103 | const char *(*dll_omInputSignature)(); 104 | const char *(*dll_omOutputSignature)(); 105 | OMTensorList *(*dll_omTensorListCreate)(OMTensor **, int); 106 | OMTensor *(*dll_omTensorListGetOmtByIndex)(OMTensorList *, int64_t); 107 | void *(*dll_omTensorGetDataPtr)(OMTensor *); 108 | void (*dll_omTensorDestroy)(OMTensor *tensor); 109 | void (*dll_omTensorListDestroy)(OMTensorList *); 110 | }; 111 | 112 | typedef struct logInfo_ 113 | { 114 | high_resolution_clock::time_point start; 115 | high_resolution_clock::time_point end; 116 | std::string key; 117 | int64_t inference_size; 118 | } LogInfo; 119 | 120 | 121 | 122 | class OnnxMlirModel 123 | { 124 | public: 125 | OnnxMlirModel(const char *_model_name); 126 | 127 | void ReadConfigFile(char *fileName); 128 | 129 | void ReadModelConfigFile(char *file_path); 130 | 131 | bool CheckInputData(AbstractCallData *data); 132 | 133 | void AddInferenceData(AbstractCallData *data); 134 | 135 | bool Ready(int wait, int max_batchsize_); 136 | 137 | std::string Calulate_duration(high_resolution_clock::time_point time1, high_resolution_clock::time_point time2) 138 | { 139 | return std::to_string(std::chrono::duration(time1 - time2).count()); 140 | } 141 | 142 | void Add_log(LogInfo info, std::stringstream& log_stream); 143 | 144 | Task Perpare_and_run(AbstractCallData *data); 145 | 146 | Task Perpare_and_run(int64_t batchsize_); 147 | 148 | char model_name[50]; 149 | OnnxMlirModelLoader loader; 150 | std::queue inference_data; 151 | int max_batchsize = -1; 152 | int batch_dim = -1; 153 | char typeName[5]; 154 | std::vector inputs; 155 | std::vector outputs; 156 | 157 | private: 158 | std::mutex lock_; 159 | }; 160 | 161 | #endif -------------------------------------------------------------------------------- /grpc_server/model_manager.h: -------------------------------------------------------------------------------- 1 | #include "model_loader.h" 2 | #include "thread_pool.h" 3 | 4 | class OnnxMlirModelManager 5 | { 6 | public: 7 | OnnxMlirModelManager(int batch_size, int thread_num, int wait_time) : 8 | tpool_(thread_num), 9 | checkBatchingThread_([this]{ checkBatching(); }) 10 | { 11 | batch_size_ = batch_size; 12 | wait_time_ = wait_time; 13 | } 14 | 15 | ~OnnxMlirModelManager() 16 | { 17 | run_ = 0; 18 | if (checkBatchingThread_.joinable()) 19 | { 20 | checkBatchingThread_.join(); 21 | } 22 | } 23 | 24 | int AddModel(AbstractCallData *data) 25 | { 26 | const char *model_name = data->getRequestData().model_name().c_str(); 27 | OnnxMlirModel *model = NULL; 28 | { 29 | lock_guard lock(lock_); 30 | model = Get_model(model_name); 31 | } 32 | 33 | if (model == NULL) 34 | { 35 | data->sendBack(); 36 | return 0; 37 | } 38 | 39 | // no batching, add task to thread pool right now 40 | if (model->max_batchsize <= 1 || batch_size_ == 1) 41 | { 42 | tpool_.AddTask(model->Perpare_and_run(data)); 43 | } 44 | // else add data to inference queue, wait batching 45 | else 46 | { 47 | model->AddInferenceData(data); 48 | } 49 | 50 | return 1; 51 | } 52 | 53 | void PrintLogs() 54 | { 55 | tpool_.PrintLogs(); 56 | } 57 | 58 | private: 59 | void checkBatching() 60 | { 61 | while (run_) 62 | { 63 | { 64 | lock_guard lock(lock_); 65 | for (size_t i = 0; i < models_.size(); i++) 66 | { 67 | OnnxMlirModel *model = models_.at(i); 68 | if (model->max_batchsize > 0 && model->Ready(wait_time_, batch_size_)) 69 | { 70 | tpool_.AddTask(model->Perpare_and_run(batch_size_)); 71 | } 72 | } 73 | } 74 | 
std::this_thread::sleep_for(std::chrono::nanoseconds((int)(10000))); 75 | } 76 | } 77 | 78 | OnnxMlirModel *Get_model(const char *model_name) 79 | { 80 | OnnxMlirModel *model = NULL; 81 | // get model from exist model queue 82 | for (size_t i = 0; i < models_.size(); i++) 83 | { 84 | if (strcmp(model_name, models_[i]->model_name) == 0) 85 | { 86 | model = models_[i]; 87 | return model; 88 | } 89 | } 90 | 91 | // create new model when not find 92 | model = new OnnxMlirModel(model_name); 93 | if (model->model_name[0] == 0) 94 | { 95 | return NULL; 96 | } 97 | models_.emplace_back(model); 98 | 99 | return model; 100 | } 101 | 102 | std::vector models_; 103 | AIInfrenceThreadPool tpool_; 104 | std::thread checkBatchingThread_; 105 | std::mutex lock_; 106 | int run_ = 1; 107 | int batch_size_; 108 | int wait_time_; 109 | }; -------------------------------------------------------------------------------- /grpc_server/thread_pool.cc: -------------------------------------------------------------------------------- 1 | #include "thread_pool.h" 2 | 3 | void AIInfrenceThreadPool::AddTask(Task task) 4 | { 5 | lock_guard lock(lock_); 6 | tasks_.push(task); 7 | task_cv_.notify_one(); 8 | } 9 | 10 | void AIInfrenceThreadPool::AddThread(int _size) 11 | { 12 | 13 | int size = std::min(_size, THREADPOOL_MAX_NUM); 14 | while(size > 0) 15 | { 16 | pool_.emplace_back([this](int cpuindex){ 17 | int cpu = (cpuindex-1) * 12; 18 | int total_cpu_num = sysconf(_SC_NPROCESSORS_CONF); 19 | cpu = cpu % total_cpu_num + cpu / total_cpu_num; 20 | cpu_set_t mask; 21 | // cpu_set_t get; 22 | CPU_ZERO(&mask); 23 | CPU_SET(*&cpu,&mask); 24 | char tname[20]; 25 | sprintf(tname, "Inference thread %d", cpu); 26 | 27 | prctl(PR_SET_NAME, tname); 28 | if(sched_setaffinity(0,sizeof(cpu_set_t),&mask)==-1) 29 | { 30 | printf("warning: could not set CPU affinity, continuing...\n"); 31 | } 32 | while (run_) 33 | { 34 | 35 | Task task; 36 | { 37 | unique_lock lock{ lock_ }; 38 | task_cv_.wait(lock, [this]{ return !run_ || !tasks_.empty(); }); 39 | 40 | if (!run_ && tasks_.empty()) 41 | return; 42 | 43 | task = move(tasks_.front()); 44 | tasks_.pop(); 45 | } 46 | 47 | idl_thread_num_--; 48 | task(to_log); 49 | idl_thread_num_++; 50 | } }, size); 51 | size--; 52 | idl_thread_num_++; 53 | } 54 | } 55 | 56 | void AIInfrenceThreadPool::PrintLogs() 57 | { 58 | std::cout << "print log " << std::endl; 59 | std::cout << log_stream_.str() << std::endl; 60 | log_stream_.clear(); 61 | } -------------------------------------------------------------------------------- /grpc_server/thread_pool.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef AI_INFERENCE_THREAD_POOL_H 3 | #define AI_INFERENCE_THREAD_POOL_H 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | #include "model_loader.h" 24 | 25 | using std::atomic; 26 | using std::condition_variable; 27 | using std::lock_guard; 28 | using std::mutex; 29 | using std::queue; 30 | using std::string; 31 | using std::thread; 32 | using std::unique_lock; 33 | using std::vector; 34 | // using std::chrono::high_resolution_clock; 35 | 36 | #define THREADPOOL_MAX_NUM 20 37 | 38 | extern std::chrono::high_resolution_clock::time_point originTime; 39 | extern OnnxMlirModelLoader modelLoder; 40 | 41 | class AIInfrenceThreadPool 42 | { 43 | using Task = std::function)>; 44 | 
vector pool_; 45 | queue tasks_; 46 | mutex lock_; 47 | mutex log_mutex_; 48 | condition_variable task_cv_; 49 | atomic run_{true}; 50 | atomic idl_thread_num_{0}; 51 | int wait_ = 0; 52 | std::stringstream log_stream_; 53 | int batch_size_ = 1; 54 | 55 | public: 56 | AIInfrenceThreadPool(int size) 57 | { 58 | AddThread(size); 59 | } 60 | ~AIInfrenceThreadPool() 61 | { 62 | run_ = false; 63 | task_cv_.notify_all(); 64 | for (thread &thread : pool_) 65 | { 66 | if (thread.joinable()) 67 | thread.join(); 68 | } 69 | } 70 | 71 | void AddThread(int size); 72 | void AddCallData(AbstractCallData *data); 73 | void AddTask(Task task); 74 | void PrintLogs(); 75 | int IdlCount() { return idl_thread_num_; } 76 | int ThreadCount() { return pool_.size(); } 77 | std::function to_log = [this](std::string c) 78 | { 79 | lock_guard lock(log_mutex_); 80 | log_stream_ << c; 81 | }; 82 | }; 83 | 84 | #endif -------------------------------------------------------------------------------- /java/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 4.0.0 7 | org.onnx 8 | onnx-grpc-client 9 | jar 10 | 12 | 1.0.0-SNAPSHOT 13 | onnx-grpc-client 14 | https://github.com/grpc/grpc-java 15 | 16 | 17 | UTF-8 18 | 1.47.0 19 | 3.19.2 20 | 3.19.2 21 | 22 | 1.7 23 | 1.7 24 | 25 | 26 | 27 | 28 | 29 | io.grpc 30 | grpc-bom 31 | ${grpc.version} 32 | pom 33 | import 34 | 35 | 36 | 37 | 38 | 39 | 40 | io.grpc 41 | grpc-netty-shaded 42 | runtime 43 | 44 | 45 | io.grpc 46 | grpc-protobuf 47 | 48 | 49 | io.grpc 50 | grpc-stub 51 | 52 | 53 | com.google.protobuf 54 | protobuf-java-util 55 | ${protobuf.version} 56 | 57 | 58 | com.google.code.gson 59 | gson 60 | 2.9.0 61 | 62 | 63 | org.apache.tomcat 64 | annotations-api 65 | 6.0.53 66 | provided 67 | 68 | 69 | io.grpc 70 | grpc-testing 71 | test 72 | 73 | 74 | junit 75 | junit 76 | 4.12 77 | test 78 | 79 | 80 | org.mockito 81 | mockito-core 82 | 3.4.0 83 | test 84 | 85 | 86 | 87 | 88 | 89 | 90 | kr.motd.maven 91 | os-maven-plugin 92 | 1.6.2 93 | 94 | 95 | 96 | 97 | org.xolstice.maven.plugins 98 | protobuf-maven-plugin 99 | 0.6.1 100 | 101 | com.google.protobuf:protoc:${protoc.version}:exe:${os.detected.classifier} 102 | grpc-java 103 | io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier} 104 | 105 | 106 | 107 | 108 | compile 109 | compile-custom 110 | 111 | 112 | 113 | 114 | 115 | org.apache.maven.plugins 116 | maven-enforcer-plugin 117 | 1.4.1 118 | 119 | 120 | enforce 121 | 122 | enforce 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | -------------------------------------------------------------------------------- /java/src/main/java/org/onnx/inference/ONNXGRPCClient.java: -------------------------------------------------------------------------------- 1 | 2 | 3 | package org.onnx.inference; 4 | 5 | import io.grpc.Channel; 6 | import io.grpc.ManagedChannel; 7 | import io.grpc.ManagedChannelBuilder; 8 | import io.grpc.StatusRuntimeException; 9 | import java.io.IOException; 10 | import java.nio.ByteBuffer; 11 | import java.nio.FloatBuffer; 12 | import java.nio.charset.Charset; 13 | import java.nio.charset.StandardCharsets; 14 | import java.nio.file.Files; 15 | import java.nio.file.Path; 16 | import java.nio.file.Paths; 17 | import java.util.ArrayList; 18 | import java.util.LinkedHashMap; 19 | import java.util.List; 20 | import java.util.Map; 21 | import java.util.logging.Level; 22 | import java.util.logging.Logger; 23 | import inference.*; 24 | import 
java.util.concurrent.TimeUnit; 25 | /** 26 | * ONNX GRPC Client 27 | */ 28 | public class ONNXGRPCClient { 29 | 30 | private static final Logger logger = Logger.getLogger(ONNXGRPCClient.class.getName()); 31 | 32 | private final InferenceServiceGrpc.InferenceServiceBlockingStub blockingStub; 33 | 34 | private Inference.InferenceRequest.Builder builder; 35 | 36 | public ONNXGRPCClient(Channel channel) { 37 | 38 | blockingStub = InferenceServiceGrpc.newBlockingStub(channel); 39 | } 40 | 41 | public void setupRequestBuilder(String model, List shapes) { 42 | builder = Inference.InferenceRequest.newBuilder().setModelName(model); 43 | builder.addAllShape(shapes); 44 | } 45 | 46 | 47 | /** Say hello to server. */ 48 | public void inference(List data) { 49 | Inference.InferenceRequest request = builder.addAllData(data).build(); 50 | Inference.InferenceResponse response; 51 | try { 52 | response = blockingStub.inference(request); 53 | } catch (StatusRuntimeException e) { 54 | logger.log(Level.WARNING, "Inference request failed: {0}", e.getStatus()); 55 | return; 56 | } 57 | logger.info("Inference result: " + response.getDataList()); 58 | } 59 | 60 | 61 | 62 | static class ImageDataset { 63 | private Map imageNames = new LinkedHashMap<>(); 64 | private String model; 65 | public String getModel() { 66 | return model; 67 | } 68 | 69 | private int rank; 70 | public int getRank() { 71 | return rank; 72 | } 73 | 74 | private List shapes; 75 | public List getShapes() { 76 | return shapes; 77 | } 78 | 79 | private String datasetPath; 80 | private List[] images; 81 | 82 | public List[] getImages() { 83 | return images; 84 | } 85 | 86 | public ImageDataset(String datasetPath) throws IOException { 87 | this.datasetPath = datasetPath; 88 | readImageList(datasetPath); 89 | images = new List[imageNames.size()]; 90 | int index = 0; 91 | for (String imageName:imageNames.keySet()) { 92 | images[index] = loadImage(imageName, imageNames.get(imageName)); 93 | index++; 94 | } 95 | } 96 | 97 | private List loadImage(String imageName, Integer integer) throws IOException { 98 | logger.info("Load image:"+imageName); 99 | Path imagePath = Paths.get(datasetPath, imageName); 100 | try { 101 | byte[] imageBytes = Files.readAllBytes(imagePath); 102 | List floatList = new ArrayList(); 103 | // ByteBuffer buffer = ByteBuffer.wrap(imageBytes); 104 | // List floatList = new ArrayList(); 105 | // while (buffer.hasRemaining()) { 106 | // floatList.add(buffer.getDouble()); 107 | // } 108 | for (int i=0; i lines = Files.readAllLines(mapFilePath, charset); 133 | for(String line: lines) { 134 | String[] tokens = line.split("\\s+"); 135 | if (tokens.length >1){ 136 | imageNames.put(tokens[0], Integer.parseInt(tokens[1])); 137 | } 138 | } 139 | 140 | } catch (IOException ex) { 141 | logger.warning("Failed to load input dataset map"); 142 | throw ex; 143 | } 144 | try { 145 | List lines = Files.readAllLines(configFilePath, charset); 146 | model = lines.get(0); 147 | rank = Integer.parseInt(lines.get(1)); 148 | shapes = new ArrayList(); 149 | String[] tokens = lines.get(2).split("\\s+"); 150 | for (int i=0; i=1) { 175 | String datasetPath = args[0]; 176 | ImageDataset dataset = new ImageDataset(datasetPath); 177 | ManagedChannel channel = ManagedChannelBuilder.forTarget(server) 178 | .usePlaintext() 179 | .build(); 180 | try { 181 | ONNXGRPCClient client = new ONNXGRPCClient(channel); 182 | client.setupRequestBuilder(dataset.getModel(), dataset.getShapes()); 183 | for ( List image : dataset.getImages()) { 184 | client.inference(image); 185 | 
} 186 | 187 | } finally { 188 | // ManagedChannels use resources like threads and TCP connections. To prevent leaking these 189 | // resources the channel should be shut down when it will no longer be used. If it may be used 190 | // again leave it running. 191 | channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS); 192 | } 193 | } else { 194 | logger.warning("Please provide the dataset path as a parameter"); 195 | } 196 | 197 | // Create a communication channel to the server, known as a Channel. Channels are thread-safe 198 | // and reusable. It is common to create channels at the beginning of your application and reuse 199 | // them until the application shuts down. 200 | 201 | } 202 | } 203 | 204 | -------------------------------------------------------------------------------- /java/src/main/proto/inference.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package inference; 4 | 5 | service InferenceService { 6 | rpc Inference (InferenceRequest) returns (InferenceResponse) {}; 7 | rpc PrintStatistics (PrintStatisticsRequest) returns (PrintStatisticsResponse){}; 8 | } 9 | message InferenceRequest { 10 | repeated int64 shape = 1; 11 | repeated float data = 2; 12 | string model_name = 3; 13 | } 14 | message InferenceResponse{ 15 | repeated float data = 1; 16 | } 17 | message PrintStatisticsRequest { 18 | 19 | } 20 | message PrintStatisticsResponse { 21 | 22 | } -------------------------------------------------------------------------------- /proposal.md: -------------------------------------------------------------------------------- 1 | # ONNX Serving 2 | 3 | ## Serving Tool Proposal 4 | 5 | ONNX Serving is a project written in C++ to serve onnx-mlir compiled models with GRPC and other protocols. Benefiting from its C++ implementation, ONNX Serving has very low latency overhead and high throughput. ONNX Serving provides dynamic batch aggregation and a worker pool to fully utilize the AI accelerators on the machine. 6 | 7 | Currently there is no high-performance open-source serving solution for onnx-mlir compiled models, so IBM wants to contribute an open-source project to the ONNX community that helps users deploy their onnx-mlir compiled models in production environments. 8 | 9 | ## Proposal 10 | 11 | Contribute ONNX Serving to https://github.com/onnx/onnx-serving 12 | 13 | Community contributions to enhance onnx-serving with broader hardware and platform support are welcome. 14 | 15 | Questions: 16 | 17 | 18 | ## Rules for all repos and Requirements for new, contributed repos 19 | 20 | Rules for all repos 21 | 22 | 1. Must be owned and managed by one of the ONNX SIGs (Architecture & Infra) 23 | 24 | 2. Must be actively maintained (Qin Yue Chen, Fei Fei Li) 25 | 26 | 3. Must adopt the ONNX Code of Conduct (check) 27 | 28 | 4. Must adopt the standard ONNX license(s) (already Apache-2.0 License) 29 | 30 | 5. Must adopt the ONNX CLA bot (check) 31 | 32 | 6. Must adopt all ONNX automation (like LGTM) (check) 33 | 34 | 7. Must have CI or other automation in place for repos containing code to ensure quality (CI and unit tests are already implemented; more test cases and a coverage scan tool are still needed) 35 | 36 | 8. All OWNERS must be members of standing, as defined by the ability to vote in Steering Committee elections. (check) 37 | 38 | Requirements for new, contributed repos 39 | 40 | We are happy to accept contributions as repos under the ONNX organization of new projects that meet the following requirements: 41 | 42 | 1.
Project is closely related to ONNX (onnx-mlir) 43 | 44 | 2. Adds value to the ONNX ecosystem (serving onnx-mlir compiled models) 45 | 46 | 3. Determined to need a new repo rather than a folder in an existing repo (no) 47 | 48 | 4. All contributors must have signed the ONNX CLA (check) 49 | 50 | 5. Licenses of dependencies must be acceptable (check) 51 | 52 | 6. Commitment to maintain the repo (Qin Yue Chen, Fei Fei Li) 53 | 54 | 7. Approval of the SIG that will own the repo 55 | 56 | 8. Approval of the Steering Committee -------------------------------------------------------------------------------- /tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | project(grpc_server_test) 2 | 3 | 4 | get_filename_component(utils "../utils" ABSOLUTE) 5 | get_filename_component(utils_path "${utils}" PATH) 6 | 7 | message(STATUS "utils path: ${utils_path}") 8 | 9 | add_executable(grpc-test grpc-test.cc) 10 | target_link_libraries(grpc-test 11 | gtest_main 12 | gtest 13 | hw_grpc_proto 14 | ${_REFLECTION} 15 | ${_GRPC_GRPCPP} 16 | ${_PROTOBUF_LIBPROTOBUF}) 17 | target_include_directories(grpc-test PRIVATE ${utils_path}) 18 | add_test(NAME grpc-test COMMAND grpc-test) 19 | -------------------------------------------------------------------------------- /tests/grpc-test.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "gtest/gtest.h" 4 | #include "utils/grpc_client.h" 5 | #include 6 | #include 7 | using namespace std::chrono_literals; 8 | namespace { 9 | 10 | class ServerTest : public testing::Test { 11 | protected: 12 | // SetUp() is run immediately before a test starts. 13 | // Here we launch the gRPC server and give it time to come up. 14 | void SetUp() override { 15 | std::cout << "Start server in the simplest way" << std::endl; 16 | system("bash -c './grpc_server' &"); 17 | std::this_thread::sleep_for(2s); 18 | 19 | } 20 | // TearDown() is invoked immediately after a test finishes. Here we 21 | // stop the server.
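The fixture above exercises the server through the InferenceClient helper in utils/grpc_client.h (not included in this section). For orientation, here is a hedged sketch of the equivalent call using the stub generated from utils/inference.proto; the "mnist" model name and 1x1x28x28 float input match tests/models/mnist/config, and the function name RunMnist is purely illustrative:

```cpp
#include <vector>
#include <grpcpp/grpcpp.h>
#include "inference.grpc.pb.h"

std::vector<float> RunMnist(const std::vector<float> &image) {
  auto channel = grpc::CreateChannel("localhost:50051",
                                     grpc::InsecureChannelCredentials());
  std::unique_ptr<inference::InferenceService::Stub> stub =
      inference::InferenceService::NewStub(channel);

  inference::InferenceRequest request;
  request.set_model_name("mnist");
  onnx::TensorProto *t = request.add_tensor();
  t->set_data_type(onnx::TensorProto_DataType_FLOAT);
  for (int64_t d : {1, 1, 28, 28}) t->add_dims(d);
  for (float v : image) t->add_float_data(v);

  inference::InferenceResponse response;
  grpc::ClientContext ctx;
  grpc::Status status = stub->Inference(&ctx, request, &response);
  if (!status.ok() || response.tensor_size() == 0) return {};

  // Assumes the server fills float_data in the response; it may also use raw_data.
  const onnx::TensorProto &out = response.tensor(0);
  return std::vector<float>(out.float_data().begin(), out.float_data().end());
}
```

Note that the Java client under java/ speaks to the older float-array InferenceRequest defined in java/src/main/proto/inference.proto, not this TensorProto-based message.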
22 | void TearDown() override { 23 | std::cout << "Stop server" << std::endl; 24 | auto a = system("pkill -e grpc_server"); 25 | } 26 | 27 | 28 | // Gets the time when the test finishes 29 | 30 | }; 31 | 32 | 33 | 34 | 35 | TEST_F(ServerTest, mnist0) { 36 | system("wait 5"); 37 | Dataset ds("./models/mnist"); 38 | InferenceClient client(grpc::CreateChannel("localhost:50051", grpc::InsecureChannelCredentials())); 39 | std::vector out_vector = client.Inference(ds.getImageData(0), ds.shape, ds.rank, ds.modelName); 40 | for (auto value:out_vector) 41 | std::cout << value << std::endl; 42 | auto maxPosition = max_element(out_vector.begin(),out_vector.end()) - out_vector.begin(); 43 | EXPECT_EQ(4, maxPosition); 44 | out_vector = client.Inference(ds.getImageData(0), ds.shape, ds.rank, ds.modelName); 45 | for (auto value:out_vector) 46 | std::cout << value << std::endl; 47 | maxPosition = max_element(out_vector.begin(),out_vector.end()) - out_vector.begin(); 48 | EXPECT_EQ(4, maxPosition); 49 | // auto out_vector2 = client.Inference(&ds, 1); 50 | // auto maxPosition2 = max_element(out_vector2.begin(),out_vector2.end()) - out_vector2.begin(); 51 | // EXPECT_EQ(4, maxPosition2); 52 | } 53 | } // namespace -------------------------------------------------------------------------------- /tests/models/mnist/config: -------------------------------------------------------------------------------- 1 | input { 2 | name: "image" 3 | type { 4 | tensor_type { 5 | elem_type: 1 6 | shape { 7 | dim { 8 | dim_value: 1 9 | } 10 | dim { 11 | dim_value: 1 12 | } 13 | dim { 14 | dim_value: 28 15 | } 16 | dim { 17 | dim_value: 28 18 | } 19 | } 20 | } 21 | } 22 | } 23 | output { 24 | name: "prediction" 25 | type { 26 | tensor_type { 27 | elem_type: 1 28 | shape { 29 | dim { 30 | dim_value: 1 31 | } 32 | dim { 33 | dim_value: 10 34 | } 35 | } 36 | } 37 | } 38 | } 39 | max_batch_size: 5 40 | -------------------------------------------------------------------------------- /tests/models/mnist/grpc_config.txt: -------------------------------------------------------------------------------- 1 | mnist 2 | f 4 3 | 4 4 | 1 1 28 28 -------------------------------------------------------------------------------- /tests/models/mnist/img0.data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/onnx-mlir-serving/21d1640ffe0af914933f82ebe768cb0ac0c38caf/tests/models/mnist/img0.data -------------------------------------------------------------------------------- /tests/models/mnist/img1.data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/onnx-mlir-serving/21d1640ffe0af914933f82ebe768cb0ac0c38caf/tests/models/mnist/img1.data -------------------------------------------------------------------------------- /tests/models/mnist/model.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/onnx-mlir-serving/21d1640ffe0af914933f82ebe768cb0ac0c38caf/tests/models/mnist/model.onnx -------------------------------------------------------------------------------- /tests/models/mnist/model.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/onnx-mlir-serving/21d1640ffe0af914933f82ebe768cb0ac0c38caf/tests/models/mnist/model.so -------------------------------------------------------------------------------- /tests/models/mnist/val_map.txt: 
-------------------------------------------------------------------------------- 1 | img0.data 4 2 | -------------------------------------------------------------------------------- /utils/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | project(onnx-parser) 2 | 3 | if(DEFINED GRPC_DIR) 4 | message(STATUS "GPRC_DIR: ${GRPC_DIR}") 5 | add_subdirectory(${GRPC_DIR} ${CMAKE_CURRENT_BINARY_DIR}/grpc EXCLUDE_FROM_ALL) 6 | set(_PROTOBUF_LIBPROTOBUF libprotobuf) 7 | set(_REFLECTION grpc++_reflection) 8 | set(_GRPC_GRPCPP grpc++) 9 | else() 10 | find_package(Protobuf CONFIG REQUIRED) 11 | include_directories(${Protobuf_INCLUDE_DIR}) 12 | find_package(gRPC CONFIG REQUIRED) 13 | set(_REFLECTION gRPC::grpc++_reflection) 14 | set(_GRPC_GRPCPP gRPC::grpc++) 15 | endif() 16 | if(DEFINED ONNX_COMPILER_DIR) 17 | message(STATUS "ONNX_COMPILER_DIR: ${ONNX_COMPILER_DIR}") 18 | endif() 19 | 20 | find_program(_PROTOBUF_PROTOC protoc) 21 | find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin) 22 | 23 | message(STATUS "_PROTOBUF_PROTOC: ${_PROTOBUF_PROTOC}") 24 | message(STATUS "Protobuf_INCLUDE_DIR: ${Protobuf_INCLUDE_DIR}") 25 | message(STATUS "_GRPC_CPP_PLUGIN_EXECUTABLE: ${_GRPC_CPP_PLUGIN_EXECUTABLE}") 26 | 27 | get_filename_component(onnx_proto "${CMAKE_CURRENT_SOURCE_DIR}/onnx.proto" ABSOLUTE) 28 | get_filename_component(hw_proto "${CMAKE_CURRENT_SOURCE_DIR}/inference.proto" ABSOLUTE) 29 | get_filename_component(hw_proto_path "${hw_proto}" PATH) 30 | 31 | # Generated sources 32 | set(onnx_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/onnx.pb.cc") 33 | set(onnx_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/onnx.pb.h") 34 | set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/inference.pb.cc") 35 | set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/inference.pb.h") 36 | set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/inference.grpc.pb.cc") 37 | set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/inference.grpc.pb.h") 38 | add_custom_command( 39 | OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}" "${onnx_proto_srcs}" "${onnx_proto_hdrs}" 40 | COMMAND ${_PROTOBUF_PROTOC} 41 | ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" 42 | --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" 43 | -I "${hw_proto_path}" 44 | --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" 45 | "${hw_proto}" "${onnx_proto}" 46 | DEPENDS "${hw_proto}") 47 | 48 | # Include generated *.pb.h files 49 | include_directories("${CMAKE_CURRENT_BINARY_DIR}") 50 | 51 | # hw_grpc_proto 52 | add_library(hw_grpc_proto ${onnx_proto_srcs} ${onnx_proto_hdrs} ${hw_grpc_srcs} ${hw_grpc_hdrs} ${hw_proto_srcs} ${hw_proto_hdrs}) 53 | target_link_libraries(hw_grpc_proto 54 | ${_REFLECTION} 55 | ${_GRPC_GRPCPP} 56 | ) 57 | 58 | find_package(Protobuf REQUIRED) 59 | message(STATUS "Protobuf: ${Protobuf_INCLUDE_DIR}") 60 | include_directories(${Protobuf_INCLUDE_DIR}) 61 | 62 | add_executable(OnnxReader "onnx_reader.cc") 63 | target_link_libraries(OnnxReader 64 | hw_grpc_proto 65 | ${protobuf} 66 | ) -------------------------------------------------------------------------------- /utils/inference.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package inference; 4 | 5 | import "onnx.proto"; 6 | 7 | service InferenceService { 8 | rpc Inference (InferenceRequest) returns (InferenceResponse) {}; 9 | rpc PrintStatistics (PrintStatisticsRequest) returns (PrintStatisticsResponse){}; 10 | } 11 | message InferenceRequest2 { 12 | repeated int64 shape = 1; 
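Stepping back to utils/CMakeLists.txt above: it also builds an OnnxReader tool from onnx_reader.cc, which is not included in this dump. A minimal sketch of what such a reader could do with the generated onnx.pb.h (assumed behavior: dump the IR version and graph input names), purely illustrative:

```cpp
#include <fstream>
#include <iostream>
#include "onnx.pb.h"

int main(int argc, char **argv) {
  if (argc < 2) {
    std::cerr << "usage: OnnxReader <model.onnx>\n";
    return 1;
  }
  onnx::ModelProto model;
  std::ifstream in(argv[1], std::ios::binary);
  if (!in || !model.ParseFromIstream(&in)) {
    std::cerr << "failed to parse " << argv[1] << "\n";
    return 1;
  }
  std::cout << "ir_version: " << model.ir_version() << "\n";
  for (const onnx::ValueInfoProto &vi : model.graph().input())
    std::cout << "graph input: " << vi.name() << "\n";
  return 0;
}
```

The ModelProto, GraphProto, and ValueInfoProto messages used here are the ones declared in utils/onnx.proto below.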
13 | bytes data = 2; 14 | string model_name = 3; 15 | } 16 | message InferenceResponse2{ 17 | repeated float data = 1; 18 | } 19 | message inputInfo { 20 | repeated int64 shape = 1; 21 | bytes data = 2; 22 | } 23 | 24 | message PrintStatisticsRequest { 25 | 26 | } 27 | message PrintStatisticsResponse { 28 | 29 | } 30 | 31 | message InferenceRequest { 32 | repeated onnx.TensorProto tensor = 1; 33 | string model_name = 3; 34 | } 35 | 36 | message InferenceResponse{ 37 | repeated onnx.TensorProto tensor = 1; 38 | } 39 | 40 | message ModelConfig { 41 | repeated onnx.ValueInfoProto input = 1; 42 | repeated onnx.ValueInfoProto output = 2; 43 | int64 batch_dim = 3; 44 | int64 max_batch_size = 4; 45 | int64 max_batch_delay_microseconds = 5; 46 | } 47 | -------------------------------------------------------------------------------- /utils/onnx.proto: -------------------------------------------------------------------------------- 1 | // 2 | // WARNING: This file is automatically generated! Please edit onnx.in.proto. 3 | // 4 | 5 | 6 | // SPDX-License-Identifier: Apache-2.0 7 | 8 | 9 | syntax = "proto3"; 10 | 11 | package onnx; 12 | 13 | // Overview 14 | // 15 | // ONNX is an open specification that is comprised of the following components: 16 | // 17 | // 1) A definition of an extensible computation graph model. 18 | // 2) Definitions of standard data types. 19 | // 3) Definitions of built-in operators. 20 | // 21 | // This document describes the syntax of models and their computation graphs, 22 | // as well as the standard data types. Together, they are referred to as the ONNX 23 | // Intermediate Representation, or 'IR' for short. 24 | // 25 | // The normative semantic specification of the ONNX IR is found in docs/IR.md. 26 | // Definitions of the built-in neural network operators may be found in docs/Operators.md. 27 | 28 | // Notes 29 | // 30 | // Protobuf compatibility 31 | // 32 | // To simplify framework compatibility, ONNX is defined using the subset of protobuf 33 | // that is compatible with both protobuf v2 and v3. This means that we do not use any 34 | // protobuf features that are only available in one of the two versions. 35 | // 36 | // Here are the most notable contortions we have to carry out to work around 37 | // these limitations: 38 | // 39 | // - No 'map' (added protobuf 3.0). We instead represent mappings as lists 40 | // of key-value pairs, where order does not matter and duplicates 41 | // are not allowed. 42 | 43 | 44 | // Versioning 45 | // 46 | // ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md 47 | // 48 | // To be compatible with both proto2 and proto3, we will use a version number 49 | // that is not defined by the default value but an explicit enum number. 50 | enum Version { 51 | // proto3 requires the first enum value to be zero. 52 | // We add this just to appease the compiler. 53 | _START_VERSION = 0; 54 | // The version field is always serialized and we will use it to store the 55 | // version that the graph is generated from. This helps us set up version 56 | // control. 57 | // For the IR, we are using simple numbers starting with 0x00000001, 58 | // which was the version we published on Oct 10, 2017. 
59 | IR_VERSION_2017_10_10 = 0x0000000000000001; 60 | 61 | // IR_VERSION 2 published on Oct 30, 2017 62 | // - Added type discriminator to AttributeProto to support proto3 users 63 | IR_VERSION_2017_10_30 = 0x0000000000000002; 64 | 65 | // IR VERSION 3 published on Nov 3, 2017 66 | // - For operator versioning: 67 | // - Added new message OperatorSetIdProto 68 | // - Added opset_import in ModelProto 69 | // - For vendor extensions, added domain in NodeProto 70 | IR_VERSION_2017_11_3 = 0x0000000000000003; 71 | 72 | // IR VERSION 4 published on Jan 22, 2019 73 | // - Relax constraint that initializers should be a subset of graph inputs 74 | // - Add type BFLOAT16 75 | IR_VERSION_2019_1_22 = 0x0000000000000004; 76 | 77 | // IR VERSION 5 published on March 18, 2019 78 | // - Add message TensorAnnotation. 79 | // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. 80 | IR_VERSION_2019_3_18 = 0x0000000000000005; 81 | 82 | // IR VERSION 6 published on Sep 19, 2019 83 | // - Add support for sparse tensor constants stored in model. 84 | // - Add message SparseTensorProto 85 | // - Add sparse initializers 86 | IR_VERSION_2019_9_19 = 0x0000000000000006; 87 | 88 | // IR VERSION 7 published on May 8, 2020 89 | // - Add support to allow function body graph to rely on multiple external opreator sets. 90 | // - Add a list to promote inference graph's initializers to global and 91 | // mutable variables. Global variables are visible in all graphs of the 92 | // stored models. 93 | // - Add message TrainingInfoProto to store initialization 94 | // method and training algorithm. The execution of TrainingInfoProto 95 | // can modify the values of mutable variables. 96 | // - Implicitly add inference graph into each TrainingInfoProto's algorithm. 97 | IR_VERSION_2020_5_8 = 0x0000000000000007; 98 | 99 | // IR VERSION 8 published on 100 | // Introduce TypeProto.SparseTensor 101 | // Introduce TypeProto.Optional 102 | // Added a list of FunctionProtos local to the model 103 | // Deprecated since_version and operator status from FunctionProto 104 | IR_VERSION = 0x0000000000000008; 105 | 106 | } 107 | 108 | // Attributes 109 | // 110 | // A named attribute containing either singular float, integer, string, graph, 111 | // and tensor values, or repeated float, integer, string, graph, and tensor values. 112 | // An AttributeProto MUST contain the name field, and *only one* of the 113 | // following content fields, effectively enforcing a C/C++ union equivalent. 114 | message AttributeProto { 115 | 116 | // Note: this enum is structurally identical to the OpSchema::AttrType 117 | // enum defined in schema.h. If you rev one, you likely need to rev the other. 118 | enum AttributeType { 119 | UNDEFINED = 0; 120 | FLOAT = 1; 121 | INT = 2; 122 | STRING = 3; 123 | TENSOR = 4; 124 | GRAPH = 5; 125 | SPARSE_TENSOR = 11; 126 | TYPE_PROTO = 13; 127 | 128 | FLOATS = 6; 129 | INTS = 7; 130 | STRINGS = 8; 131 | TENSORS = 9; 132 | GRAPHS = 10; 133 | SPARSE_TENSORS = 12; 134 | TYPE_PROTOS = 14; 135 | } 136 | 137 | // The name field MUST be present for this version of the IR. 138 | string name = 1; // namespace Attribute 139 | 140 | // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. 141 | // In this case, this AttributeProto does not contain data, and it's a reference of attribute 142 | // in parent scope. 143 | // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. 
144 | string ref_attr_name = 21; 145 | 146 | // A human-readable documentation for this attribute. Markdown is allowed. 147 | string doc_string = 13; 148 | 149 | // The type field MUST be present for this version of the IR. 150 | // For 0.0.1 versions of the IR, this field was not defined, and 151 | // implementations needed to use has_field heuristics to determine 152 | // which value field was in use. For IR_VERSION 0.0.2 or later, this 153 | // field MUST be set and match the f|i|s|t|... field in use. This 154 | // change was made to accommodate proto3 implementations. 155 | AttributeType type = 20; // discriminator that indicates which field below is in use 156 | 157 | // Exactly ONE of the following fields must be present for this version of the IR 158 | float f = 2; // float 159 | int64 i = 3; // int 160 | bytes s = 4; // UTF-8 string 161 | TensorProto t = 5; // tensor value 162 | GraphProto g = 6; // graph 163 | SparseTensorProto sparse_tensor = 22; // sparse tensor value 164 | // Do not use field below, it's deprecated. 165 | // optional ValueProto v = 12; // value - subsumes everything but graph 166 | TypeProto tp = 14; // type proto 167 | 168 | repeated float floats = 7; // list of floats 169 | repeated int64 ints = 8; // list of ints 170 | repeated bytes strings = 9; // list of UTF-8 strings 171 | repeated TensorProto tensors = 10; // list of tensors 172 | repeated GraphProto graphs = 11; // list of graph 173 | repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors 174 | repeated TypeProto type_protos = 15;// list of type protos 175 | } 176 | 177 | // Defines information on value, including the name, the type, and 178 | // the shape of the value. 179 | message ValueInfoProto { 180 | // This field MUST be present in this version of the IR. 181 | string name = 1; // namespace Value 182 | // This field MUST be present in this version of the IR for 183 | // inputs and outputs of the top-level graph. 184 | TypeProto type = 2; 185 | // A human-readable documentation for this value. Markdown is allowed. 186 | string doc_string = 3; 187 | } 188 | 189 | // Nodes 190 | // 191 | // Computation graphs are made up of a DAG of nodes, which represent what is 192 | // commonly called a "layer" or "pipeline stage" in machine learning frameworks. 193 | // 194 | // For example, it can be a node of type "Conv" that takes in an image, a filter 195 | // tensor and a bias tensor, and produces the convolved output. 196 | message NodeProto { 197 | repeated string input = 1; // namespace Value 198 | repeated string output = 2; // namespace Value 199 | 200 | // An optional identifier for this node in a graph. 201 | // This field MAY be absent in ths version of the IR. 202 | string name = 3; // namespace Node 203 | 204 | // The symbolic identifier of the Operator to execute. 205 | string op_type = 4; // namespace Operator 206 | // The domain of the OperatorSet that specifies the operator named by op_type. 207 | string domain = 7; // namespace Domain 208 | 209 | // Additional named attributes. 210 | repeated AttributeProto attribute = 5; 211 | 212 | // A human-readable documentation for this node. Markdown is allowed. 213 | string doc_string = 6; 214 | } 215 | 216 | // Training information 217 | // TrainingInfoProto stores information for training a model. 218 | // In particular, this defines two functionalities: an initialization-step 219 | // and a training-algorithm-step. Initialization resets the model 220 | // back to its original state as if no training has been performed. 
221 | // Training algorithm improves the model based on input data. 222 | // 223 | // The semantics of the initialization-step is that the initializers 224 | // in ModelProto.graph and in TrainingInfoProto.algorithm are first 225 | // initialized as specified by the initializers in the graph, and then 226 | // updated by the "initialization_binding" in every instance in 227 | // ModelProto.training_info. 228 | // 229 | // The field "algorithm" defines a computation graph which represents a 230 | // training algorithm's step. After the execution of a 231 | // TrainingInfoProto.algorithm, the initializers specified by "update_binding" 232 | // may be immediately updated. If the targeted training algorithm contains 233 | // consecutive update steps (such as block coordinate descent methods), 234 | // the user needs to create a TrainingInfoProto for each step. 235 | message TrainingInfoProto { 236 | // This field describes a graph to compute the initial tensors 237 | // upon starting the training process. Initialization graph has no input 238 | // and can have multiple outputs. Usually, trainable tensors in neural 239 | // networks are randomly initialized. To achieve that, for each tensor, 240 | // the user can put a random number operator such as RandomNormal or 241 | // RandomUniform in TrainingInfoProto.initialization.node and assign its 242 | // random output to the specific tensor using "initialization_binding". 243 | // This graph can also set the initializers in "algorithm" in the same 244 | // TrainingInfoProto; a use case is resetting the number of training 245 | // iteration to zero. 246 | // 247 | // By default, this field is an empty graph and its evaluation does not 248 | // produce any output. Thus, no initializer would be changed by default. 249 | GraphProto initialization = 1; 250 | 251 | // This field represents a training algorithm step. Given required inputs, 252 | // it computes outputs to update initializers in its own or inference graph's 253 | // initializer lists. In general, this field contains loss node, gradient node, 254 | // optimizer node, increment of iteration count. 255 | // 256 | // An execution of the training algorithm step is performed by executing the 257 | // graph obtained by combining the inference graph (namely "ModelProto.graph") 258 | // and the "algorithm" graph. That is, the actual the actual 259 | // input/initializer/output/node/value_info/sparse_initializer list of 260 | // the training graph is the concatenation of 261 | // "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer" 262 | // and "algorithm.input/initializer/output/node/value_info/sparse_initializer" 263 | // in that order. This combined graph must satisfy the normal ONNX conditions. 264 | // Now, let's provide a visualization of graph combination for clarity. 265 | // Let the inference graph (i.e., "ModelProto.graph") be 266 | // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d 267 | // and the "algorithm" graph be 268 | // tensor_d -> Add -> tensor_e 269 | // The combination process results 270 | // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e 271 | // 272 | // Notice that an input of a node in the "algorithm" graph may reference the 273 | // output of a node in the inference graph (but not the other way round). Also, inference 274 | // node cannot reference inputs of "algorithm". With these restrictions, inference graph 275 | // can always be run independently without training information. 
276 | // 277 | // By default, this field is an empty graph and its evaluation does not 278 | // produce any output. Evaluating the default training step never 279 | // update any initializers. 280 | GraphProto algorithm = 2; 281 | 282 | // This field specifies the bindings from the outputs of "initialization" to 283 | // some initializers in "ModelProto.graph.initializer" and 284 | // the "algorithm.initializer" in the same TrainingInfoProto. 285 | // See "update_binding" below for details. 286 | // 287 | // By default, this field is empty and no initializer would be changed 288 | // by the execution of "initialization". 289 | repeated StringStringEntryProto initialization_binding = 3; 290 | 291 | // Gradient-based training is usually an iterative procedure. In one gradient 292 | // descent iteration, we apply 293 | // 294 | // x = x - r * g 295 | // 296 | // where "x" is the optimized tensor, "r" stands for learning rate, and "g" is 297 | // gradient of "x" with respect to a chosen loss. To avoid adding assignments 298 | // into the training graph, we split the update equation into 299 | // 300 | // y = x - r * g 301 | // x = y 302 | // 303 | // The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To 304 | // tell that "y" should be assigned to "x", the field "update_binding" may 305 | // contain a key-value pair of strings, "x" (key of StringStringEntryProto) 306 | // and "y" (value of StringStringEntryProto). 307 | // For a neural network with multiple trainable (mutable) tensors, there can 308 | // be multiple key-value pairs in "update_binding". 309 | // 310 | // The initializers appears as keys in "update_binding" are considered 311 | // mutable variables. This implies some behaviors 312 | // as described below. 313 | // 314 | // 1. We have only unique keys in all "update_binding"s so that two 315 | // variables may not have the same name. This ensures that one 316 | // variable is assigned up to once. 317 | // 2. The keys must appear in names of "ModelProto.graph.initializer" or 318 | // "TrainingInfoProto.algorithm.initializer". 319 | // 3. The values must be output names of "algorithm" or "ModelProto.graph.output". 320 | // 4. Mutable variables are initialized to the value specified by the 321 | // corresponding initializer, and then potentially updated by 322 | // "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s. 323 | // 324 | // This field usually contains names of trainable tensors 325 | // (in ModelProto.graph), optimizer states such as momentums in advanced 326 | // stochastic gradient methods (in TrainingInfoProto.graph), 327 | // and number of training iterations (in TrainingInfoProto.graph). 328 | // 329 | // By default, this field is empty and no initializer would be changed 330 | // by the execution of "algorithm". 331 | repeated StringStringEntryProto update_binding = 4; 332 | } 333 | 334 | // Models 335 | // 336 | // ModelProto is a top-level file/container format for bundling a ML model and 337 | // associating its computation graph with metadata. 338 | // 339 | // The semantics of the model are described by the associated GraphProto's. 340 | message ModelProto { 341 | // The version of the IR this model targets. See Version enum above. 342 | // This field MUST be present. 343 | int64 ir_version = 1; 344 | 345 | // The OperatorSets this model relies on. 346 | // All ModelProtos MUST have at least one entry that 347 | // specifies which version of the ONNX OperatorSet is 348 | // being imported. 
349 | //
350 | // All nodes in the ModelProto's graph will bind against the
351 | // same-domain/same-op_type operator with the HIGHEST version
352 | // in the referenced operator sets.
353 | repeated OperatorSetIdProto opset_import = 8;
354 |
355 | // The name of the framework or tool used to generate this model.
356 | // This field SHOULD be present to indicate which implementation/tool/framework
357 | // emitted the model.
358 | string producer_name = 2;
359 |
360 | // The version of the framework or tool used to generate this model.
361 | // This field SHOULD be present to indicate which implementation/tool/framework
362 | // emitted the model.
363 | string producer_version = 3;
364 |
365 | // Domain name of the model.
366 | // We use reverse domain names as name space indicators. For example:
367 | // `com.facebook.fair` or `com.microsoft.cognitiveservices`
368 | //
369 | // Together with `model_version` and GraphProto.name, this forms the unique identity of
370 | // the graph.
371 | string domain = 4;
372 |
373 | // The version of the graph encoded. See Version enum below.
374 | int64 model_version = 5;
375 |
376 | // A human-readable documentation for this model. Markdown is allowed.
377 | string doc_string = 6;
378 |
379 | // The parameterized graph that is evaluated to execute the model.
380 | GraphProto graph = 7;
381 |
382 | // Named metadata values; keys should be distinct.
383 | repeated StringStringEntryProto metadata_props = 14;
384 |
385 | // Training-specific information. Sequentially executing all stored
386 | // `TrainingInfoProto.algorithm`s and assigning their outputs following
387 | // the corresponding `TrainingInfoProto.update_binding`s is one training
388 | // iteration. Similarly, to initialize the model
389 | // (as if training hasn't happened), the user should sequentially execute
390 | // all stored `TrainingInfoProto.initialization`s and assign their outputs
391 | // using `TrainingInfoProto.initialization_binding`s.
392 | //
393 | // If this field is empty, the training behavior of the model is undefined.
394 | repeated TrainingInfoProto training_info = 20;
395 |
396 | // A list of function protos local to the model.
397 | //
398 | // The name of the function, "FunctionProto.name", should be unique within the domain "FunctionProto.domain".
399 | // In case of any conflict, the behavior (whether the model-local functions are given higher priority,
400 | // or the standard operator sets are given higher priority, or this is treated as an error) is defined by
401 | // the runtime.
402 | //
403 | // The operator sets imported by FunctionProto should be compatible with the ones
404 | // imported by ModelProto and other model local FunctionProtos.
405 | // For example, if the same operator set, say 'A', is imported by a FunctionProto and the ModelProto,
406 | // or by two FunctionProtos, the versions of that operator set may differ, but
407 | // the operator schema returned for an (op_type, domain, version) combination
408 | // should be the same for both versions, for every node in the function body.
409 | //
410 | // One FunctionProto can reference other FunctionProtos in the model; however, recursive references
411 | // are not allowed.
412 | repeated FunctionProto functions = 25;
413 | };
414 |
415 | // StringStringEntryProto follows the pattern for cross-proto-version maps.
416 | // See https://developers.google.com/protocol-buffers/docs/proto3#maps
417 | message StringStringEntryProto {
418 | string key = 1;
419 | string value = 2;
420 | };
421 |
422 | message TensorAnnotation {
423 | string tensor_name = 1;
424 | // <key, value> pairs to annotate the tensor specified by <tensor_name> above.
425 | // The keys used in the mapping below must be pre-defined in the ONNX spec.
426 | // For example, for the 8-bit linear quantization case, 'SCALE_TENSOR' and 'ZERO_POINT_TENSOR' will be pre-defined as
427 | // quantization parameter keys.
428 | repeated StringStringEntryProto quant_parameter_tensor_names = 2;
429 | }
430 |
431 |
432 |
433 | // Graphs
434 | //
435 | // A graph defines the computational logic of a model and is comprised of a parameterized
436 | // list of nodes that form a directed acyclic graph based on their inputs and outputs.
437 | // This is the equivalent of the "network" or "graph" in many deep learning
438 | // frameworks.
439 | message GraphProto {
440 | // The nodes in the graph, sorted topologically.
441 | repeated NodeProto node = 1;
442 |
443 | // The name of the graph.
444 | string name = 2; // namespace Graph
445 |
446 | // A list of named tensor values, used to specify constant inputs of the graph.
447 | // Each initializer (both TensorProto as well as SparseTensorProto) MUST have a name.
448 | // The name MUST be unique across both initializer and sparse_initializer,
449 | // but the name MAY also appear in the input list.
450 | repeated TensorProto initializer = 5;
451 |
452 | // Initializers (see above) stored in sparse format.
453 | repeated SparseTensorProto sparse_initializer = 15;
454 |
455 | // A human-readable documentation for this graph. Markdown is allowed.
456 | string doc_string = 10;
457 |
458 | // The inputs and outputs of the graph.
459 | repeated ValueInfoProto input = 11;
460 | repeated ValueInfoProto output = 12;
461 |
462 | // Information for the values in the graph. The ValueInfoProto.name's
463 | // must be distinct. It is optional for a value to appear in the value_info list.
464 | repeated ValueInfoProto value_info = 13;
465 |
466 | // This field carries information to indicate the mapping between a tensor and its
467 | // quantization parameter tensors. For example:
468 | // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
469 | // which means tensor 'a_scale' and tensor 'a_zero_point' are the scale and zero point of tensor 'a' in the model.
470 | repeated TensorAnnotation quantization_annotation = 14;
471 |
472 | reserved 3, 4, 6 to 9;
473 | reserved "ir_version", "producer_version", "producer_tag", "domain";
474 | }
475 |
476 | // Tensors
477 | //
478 | // A serialized tensor value.
479 | message TensorProto {
480 | enum DataType {
481 | UNDEFINED = 0;
482 | // Basic types.
483 | FLOAT = 1; // float
484 | UINT8 = 2; // uint8_t
485 | INT8 = 3; // int8_t
486 | UINT16 = 4; // uint16_t
487 | INT16 = 5; // int16_t
488 | INT32 = 6; // int32_t
489 | INT64 = 7; // int64_t
490 | STRING = 8; // string
491 | BOOL = 9; // bool
492 |
493 | // IEEE754 half-precision floating-point format (16 bits wide).
494 | // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
495 | FLOAT16 = 10; 496 | 497 | DOUBLE = 11; 498 | UINT32 = 12; 499 | UINT64 = 13; 500 | COMPLEX64 = 14; // complex with float32 real and imaginary components 501 | COMPLEX128 = 15; // complex with float64 real and imaginary components 502 | 503 | // Non-IEEE floating-point format based on IEEE754 single-precision 504 | // floating-point number truncated to 16 bits. 505 | // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. 506 | BFLOAT16 = 16; 507 | 508 | // Future extensions go here. 509 | } 510 | 511 | // The shape of the tensor. 512 | repeated int64 dims = 1; 513 | 514 | // The data type of the tensor. 515 | // This field MUST have a valid TensorProto.DataType value 516 | int32 data_type = 2; 517 | 518 | // For very large tensors, we may want to store them in chunks, in which 519 | // case the following fields will specify the segment that is stored in 520 | // the current TensorProto. 521 | message Segment { 522 | int64 begin = 1; 523 | int64 end = 2; 524 | } 525 | Segment segment = 3; 526 | 527 | // Tensor content must be organized in row-major order. 528 | // 529 | // Depending on the data_type field, exactly one of the fields below with 530 | // name ending in _data is used to store the elements of the tensor. 531 | 532 | // For float and complex64 values 533 | // Complex64 tensors are encoded as a single array of floats, 534 | // with the real components appearing in odd numbered positions, 535 | // and the corresponding imaginary component appearing in the 536 | // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] 537 | // is encoded as [1.0, 2.0 ,3.0 ,4.0] 538 | // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. 539 | repeated float float_data = 4 [packed = true]; 540 | 541 | // For int32, uint8, int8, uint16, int16, bool, and float16 values 542 | // float16 values must be bit-wise converted to an uint16_t prior 543 | // to writing to the buffer. 544 | // When this field is present, the data_type field MUST be 545 | // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 546 | repeated int32 int32_data = 5 [packed = true]; 547 | 548 | // For strings. 549 | // Each element of string_data is a UTF-8 encoded Unicode 550 | // string. No trailing null, no leading BOM. The protobuf "string" 551 | // scalar type is not used to match ML community conventions. 552 | // When this field is present, the data_type field MUST be STRING 553 | repeated bytes string_data = 6; 554 | 555 | // For int64. 556 | // When this field is present, the data_type field MUST be INT64 557 | repeated int64 int64_data = 7 [packed = true]; 558 | 559 | // Optionally, a name for the tensor. 560 | string name = 8; // namespace Value 561 | 562 | // A human-readable documentation for this tensor. Markdown is allowed. 563 | string doc_string = 12; 564 | 565 | // Serializations can either use one of the fields above, or use this 566 | // raw bytes field. The only exception is the string case, where one is 567 | // required to store the content in the repeated bytes string_data field. 568 | // 569 | // When this raw_data field is used to store tensor value, elements MUST 570 | // be stored in as fixed-width, little-endian order. 571 | // Floating-point data types MUST be stored in IEEE 754 format. 572 | // Complex64 elements must be written as two consecutive FLOAT values, real component first. 573 | // Complex128 elements must be written as two consecutive DOUBLE values, real component first. 
574 | // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
575 | //
576 | // Note: the advantage of the type-specific fields rather than the raw_data field is
577 | // that in some cases (e.g. int data), protobuf does a better packing via
578 | // variable length storage, and may lead to a smaller binary footprint.
579 | // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
580 | bytes raw_data = 9;
581 |
582 | // Data can be stored inside the protobuf file using type-specific fields or raw_data.
583 | // Alternatively, raw bytes data can be stored in an external file, using the external_data field.
584 | // external_data stores key-value pairs describing data location. Recognized keys are:
585 | // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
586 | // protobuf model was stored
587 | // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
588 | // Offset values SHOULD be multiples of 4096 (page size) to enable mmap support.
589 | // - "length" (optional) - number of bytes containing data. Integer stored as string.
590 | // - "checksum" (optional) - SHA1 digest of the file specified under the 'location' key.
591 | repeated StringStringEntryProto external_data = 13;
592 |
593 | // Location of the data for this tensor. MUST be one of:
594 | // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.
595 | // - EXTERNAL - data stored in an external location as described by external_data field.
596 | enum DataLocation {
597 | DEFAULT = 0;
598 | EXTERNAL = 1;
599 | }
600 |
601 | // If value not set, data is stored in raw_data (if set) otherwise in type-specified field.
602 | DataLocation data_location = 14;
603 |
604 | // For double
605 | // Complex128 tensors are encoded as a single array of doubles,
606 | // with the real components appearing in odd numbered positions,
607 | // and the corresponding imaginary component appearing in the
608 | // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
609 | // is encoded as [1.0, 2.0, 3.0, 4.0])
610 | // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
611 | repeated double double_data = 10 [packed = true];
612 |
613 | // For uint64 and uint32 values
614 | // When this field is present, the data_type field MUST be
615 | // UINT32 or UINT64
616 | repeated uint64 uint64_data = 11 [packed = true];
617 | }
618 |
619 | // A serialized sparse-tensor value
620 | message SparseTensorProto {
621 | // The sequence of non-default values is encoded as a tensor of shape [NNZ].
622 | // The default-value is zero for numeric tensors, and empty-string for string tensors.
623 | // values must have a non-empty name present which serves as a name for SparseTensorProto
624 | // when used in the sparse_initializer list.
625 | TensorProto values = 1;
626 |
627 | // The indices of the non-default values, which may be stored in one of two formats.
628 | // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value
629 | // corresponding to the j-th index of the i-th value (in the values tensor).
630 | // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value
631 | // must be the linearized-index of the i-th value (in the values tensor).
632 | // The linearized-index can be converted into an index tuple (k_1,...,k_rank)
633 | // using the shape provided below.
634 | // The indices must appear in ascending order without duplication. 635 | // In the first format, the ordering is lexicographic-ordering: 636 | // e.g., index-value [1,4] must appear before [2,1] 637 | TensorProto indices = 2; 638 | 639 | // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank] 640 | repeated int64 dims = 3; 641 | } 642 | 643 | // Defines a tensor shape. A dimension can be either an integer value 644 | // or a symbolic variable. A symbolic variable represents an unknown 645 | // dimension. 646 | message TensorShapeProto { 647 | message Dimension { 648 | oneof value { 649 | int64 dim_value = 1; 650 | string dim_param = 2; // namespace Shape 651 | }; 652 | // Standard denotation can optionally be used to denote tensor 653 | // dimensions with standard semantic descriptions to ensure 654 | // that operations are applied to the correct axis of a tensor. 655 | // Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition 656 | // for pre-defined dimension denotations. 657 | string denotation = 3; 658 | }; 659 | repeated Dimension dim = 1; 660 | } 661 | 662 | // Types 663 | // 664 | // The standard ONNX data types. 665 | message TypeProto { 666 | 667 | message Tensor { 668 | // This field MUST NOT have the value of UNDEFINED 669 | // This field MUST have a valid TensorProto.DataType value 670 | // This field MUST be present for this version of the IR. 671 | int32 elem_type = 1; 672 | TensorShapeProto shape = 2; 673 | } 674 | 675 | // repeated T 676 | message Sequence { 677 | // The type and optional shape of each element of the sequence. 678 | // This field MUST be present for this version of the IR. 679 | TypeProto elem_type = 1; 680 | }; 681 | 682 | // map 683 | message Map { 684 | // This field MUST have a valid TensorProto.DataType value 685 | // This field MUST be present for this version of the IR. 686 | // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING 687 | int32 key_type = 1; 688 | // This field MUST be present for this version of the IR. 689 | TypeProto value_type = 2; 690 | }; 691 | 692 | // wrapper for Tensor, Sequence, or Map 693 | message Optional { 694 | // The type and optional shape of the element wrapped. 695 | // This field MUST be present for this version of the IR. 696 | // Possible values correspond to OptionalProto.DataType enum 697 | TypeProto elem_type = 1; 698 | }; 699 | 700 | 701 | message SparseTensor { 702 | // This field MUST NOT have the value of UNDEFINED 703 | // This field MUST have a valid TensorProto.DataType value 704 | // This field MUST be present for this version of the IR. 705 | int32 elem_type = 1; 706 | TensorShapeProto shape = 2; 707 | } 708 | 709 | 710 | oneof value { 711 | // The type of a tensor. 712 | Tensor tensor_type = 1; 713 | 714 | // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values 715 | // as input and output to graphs and nodes. These types are needed to naturally 716 | // support classical ML operators. DNN operators SHOULD restrict their input 717 | // and output types to tensors. 718 | 719 | // The type of a sequence. 720 | Sequence sequence_type = 4; 721 | 722 | // The type of a map. 723 | Map map_type = 5; 724 | 725 | // The type of an optional. 
726 | Optional optional_type = 9;
727 |
728 |
729 | // Type of the sparse tensor
730 | SparseTensor sparse_tensor_type = 8;
731 |
732 | }
733 |
734 | // An optional denotation can be used to denote the whole
735 | // type with a standard semantic description as to what is
736 | // stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition
737 | // for pre-defined type denotations.
738 | string denotation = 6;
739 | }
740 |
741 | // Operator Sets
742 | //
743 | // OperatorSets are uniquely identified by a (domain, opset_version) pair.
744 | message OperatorSetIdProto {
745 | // The domain of the operator set being identified.
746 | // The empty string ("") or absence of this field implies the operator
747 | // set that is defined as part of the ONNX specification.
748 | // This field MUST be present in this version of the IR when referring to any other operator set.
749 | string domain = 1;
750 |
751 | // The version of the operator set being identified.
752 | // This field MUST be present in this version of the IR.
753 | int64 version = 2;
754 | }
755 |
756 | // Operator/function status.
757 | enum OperatorStatus {
758 | EXPERIMENTAL = 0;
759 | STABLE = 1;
760 | }
761 |
762 | message FunctionProto {
763 | // The name of the function, with usage similar to op_type in OperatorProto.
764 | // Combined with FunctionProto.domain, this forms the unique identity of
765 | // the FunctionProto.
766 | string name = 1;
767 |
768 | // Deprecated since IR Version 8
769 | // optional int64 since_version = 2;
770 | reserved 2;
771 | reserved "since_version";
772 |
773 | // Deprecated since IR Version 8
774 | // optional OperatorStatus status = 3;
775 | reserved 3;
776 | reserved "status";
777 |
778 | // The inputs and outputs of the function.
779 | repeated string input = 4;
780 | repeated string output = 5;
781 |
782 | // The attributes of the function.
783 | repeated string attribute = 6;
784 |
785 | // The nodes in the function.
786 | repeated NodeProto node = 7;
787 | // A human-readable documentation for this function. Markdown is allowed.
788 | string doc_string = 8;
789 |
790 | // The OperatorSets this function body (graph) relies on.
791 | //
792 | // All nodes in the function body (graph) will bind against the
793 | // same-domain/same-op_type operator with the HIGHEST version
794 | // in the referenced operator sets. This means at most one version can be relied
795 | // upon for one domain.
796 | //
797 | // The operator sets imported by FunctionProto should be compatible with the ones
798 | // imported by ModelProto. For example, if the same operator set, say 'A', is imported by a FunctionProto
799 | // and the ModelProto, the versions of that operator set may differ, but
800 | // the operator schema returned for an (op_type, domain, version) combination
801 | // should be the same for both versions.
802 |
803 | repeated OperatorSetIdProto opset_import = 9;
804 |
805 | // The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
806 | // the FunctionProto.
807 | string domain = 10;
808 | }
809 |
810 |
811 | // For using protobuf-lite
812 | // option optimize_for = LITE_RUNTIME;
--------------------------------------------------------------------------------
/utils/onnx_reader.cc:
--------------------------------------------------------------------------------
1 | #include <iostream>
2 | #include <fstream>
3 | #include <string>
4 | #include <vector>
5 | #include "inference.pb.h"
6 | #include "onnx.pb.h"
7 |
8 | #include <fcntl.h>
9 | #include <sys/stat.h>
10 | #include <sys/types.h>
11 | #include <unistd.h>
12 | #include <google/protobuf/io/zero_copy_stream_impl.h>
13 | #include <google/protobuf/text_format.h>
14 |
15 | using google::protobuf::io::FileOutputStream;
16 | using google::protobuf::io::FileInputStream;
17 |
18 |
19 | void createModelConfig(const char* file_path, const char* out_put_file){
20 | std::ifstream input(file_path,std::ios::ate | std::ios::binary); // open file and move current position in file to the end
21 | std::streamsize size = input.tellg(); // get current position in file
22 | input.seekg(0,std::ios::beg); // move to start of file
23 | std::vector<char> buffer(size);
24 | input.read(buffer.data(),size); // read raw data
25 | onnx::ModelProto model;
26 | model.ParseFromArray(buffer.data(),size); // parse protobuf
27 | const auto& graph = model.graph();
28 |
29 | inference::ModelConfig modelConfig;
30 |
31 | std::vector<std::string> initializer_name;
32 | for(const auto& initializer: graph.initializer()){
33 | initializer_name.emplace_back(initializer.name());
34 | }
35 |
36 | for(const auto& input_data: graph.input()){
37 | bool match = false;
38 | for(const std::string& n: initializer_name){
39 | if(input_data.name().compare(n)==0){
40 | match = true;
41 | break;
42 | }
43 | }
44 | if(!match){
45 | auto input = modelConfig.add_input();
46 | input->CopyFrom(input_data);
47 | }
48 |
49 | }
50 |
51 | for(const auto& output_data: graph.output()){
52 | bool match = false;
53 | for(const std::string& n: initializer_name){
54 | if(output_data.name().compare(n)==0){
55 | match = true;
56 | break;
57 | }
58 | }
59 | if(!match){
60 | auto output = modelConfig.add_output();
61 | output->CopyFrom(output_data);
62 | }
63 | }
64 |
65 | modelConfig.set_max_batch_size(1);
66 |
67 |
68 | int fd = open(out_put_file, O_WRONLY | O_CREAT | O_TRUNC, 0644);
69 | FileOutputStream outputfile(fd);
70 | google::protobuf::TextFormat::Print(modelConfig, &outputfile);
71 | outputfile.Flush();
72 | close(fd);
73 |
74 | }
75 |
76 |
77 | int main(int argc, char** argv) {
78 | std::string input_path = argv[1];
79 | auto found = input_path.find_last_of("/\\");
80 | std::string output_path = "config";
81 | if(found != std::string::npos){
82 | output_path = input_path.substr(0,found) + "/config";
83 | }
84 | std::cout << "input path: " << input_path << '\n';
85 | std::cout << "output path: " << output_path << '\n';
86 | createModelConfig(input_path.c_str(), output_path.c_str());
87 |
88 | }
--------------------------------------------------------------------------------
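Note: the ModelConfig that onnx_reader writes is plain protobuf text format, so it can be parsed back with the same protobuf APIs used above. The sketch below is illustrative only; it is not part of the repository and does not claim to mirror how grpc_server/model_loader.cc consumes the file. The file name read_config.cc and the default path "config" are assumptions; only the ModelConfig fields already used above (input, output, max_batch_size) are relied on.

// read_config.cc (hypothetical helper, not in the repo): parse the text-format
// config produced by onnx_reader and print the recorded graph inputs/outputs.
#include <fcntl.h>
#include <unistd.h>
#include <iostream>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include "inference.pb.h"

int main(int argc, char** argv) {
  const char* path = (argc > 1) ? argv[1] : "config";  // e.g. tests/models/mnist/config
  int fd = open(path, O_RDONLY);
  if (fd < 0) {
    std::cerr << "cannot open " << path << '\n';
    return 1;
  }
  google::protobuf::io::FileInputStream in(fd);
  inference::ModelConfig config;
  bool ok = google::protobuf::TextFormat::Parse(&in, &config);  // inverse of TextFormat::Print above
  close(fd);
  if (!ok) {
    std::cerr << "failed to parse " << path << '\n';
    return 1;
  }
  for (const auto& vi : config.input())
    std::cout << "input: " << vi.name() << '\n';
  for (const auto& vi : config.output())
    std::cout << "output: " << vi.name() << '\n';
  std::cout << "max_batch_size: " << config.max_batch_size() << '\n';
  return 0;
}

If compiled, it would need the same generated-header include paths and protobuf libraries as the utils target.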