├── .github └── workflows │ ├── build.yml │ ├── formatting.yml │ ├── main.yml │ ├── mlbridge_build.sh │ ├── upload_pypi.yml │ └── wheel.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CMakeLists.txt ├── CompilerInterface ├── compilerinterface │ ├── BaseCompilerInterface.py │ ├── GrpcCompilerInterface.py │ ├── PipeCompilerInterface.py │ ├── SerDes.py │ ├── __init__.py │ └── log_reader.py ├── fetch_version.py ├── pyproject.toml └── setup.cfg ├── LICENSE ├── MLModelRunner ├── C │ ├── CMakeLists.txt │ ├── ONNXModelRunnerCWrapper.cpp │ └── PipeModelRunnerCWrapper.cpp ├── CMakeLists.txt ├── ONNXModelRunner │ ├── CMakeLists.txt │ ├── ONNXModelRunner.cpp │ ├── README.md │ ├── agent.cpp │ └── onnx.cpp ├── PipeModelRunner.cpp ├── Utils │ ├── CMakeLists.txt │ └── MLConfig.cpp └── gRPCModelRunner │ ├── CMakeLists.txt │ └── README.md ├── README.md ├── SerDes ├── CMakeLists.txt ├── TensorSpec.cpp ├── bitstreamSerDes.cpp ├── jsonSerDes.cpp ├── protobufSerDes.cpp └── tensorflowSerDes.cpp ├── cmake └── modules │ ├── FindOnnxruntime.cmake │ └── TensorFlowCompile.cmake ├── docs └── Doxyfile ├── images └── component-ml-compiler-bridge.png ├── include ├── MLModelRunner │ ├── C │ │ ├── ONNXModelRunner.h │ │ └── PipeModelRunner.h │ ├── MLModelRunner.h │ ├── ONNXModelRunner │ │ ├── ONNXModelRunner.h │ │ ├── agent.h │ │ ├── environment.h │ │ ├── onnx.h │ │ └── utils.h │ ├── PipeModelRunner.h │ ├── TFModelRunner.h │ ├── Utils │ │ ├── DataTypes.h │ │ ├── Debug.h │ │ └── MLConfig.h │ └── gRPCModelRunner.h └── SerDes │ ├── TensorSpec.h │ ├── baseSerDes.h │ ├── bitstreamSerDes.h │ ├── jsonSerDes.h │ ├── protobufSerDes.h │ └── tensorflowSerDes.h ├── mlbridge.yml ├── test ├── CMakeLists.txt ├── MLBridgeTest.cpp ├── hello-mlbridge.py ├── include │ ├── HelloMLBridge_Env.h │ └── ProtosInclude.h ├── mlbridge-test.py ├── mlbridge-test.sh ├── onnx │ └── dummy_model.onnx └── protos │ ├── MLBridgeTest_bool.proto │ ├── MLBridgeTest_char.proto │ ├── MLBridgeTest_double.proto │ ├── MLBridgeTest_float.proto │ ├── MLBridgeTest_int.proto │ ├── MLBridgeTest_long.proto │ ├── MLBridgeTest_vec_double.proto │ ├── MLBridgeTest_vec_float.proto │ ├── MLBridgeTest_vec_int.proto │ ├── MLBridgeTest_vec_long.proto │ └── helloMLBridge.proto └── tools.cpp /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build and Tests 2 | 3 | on: 4 | workflow_dispatch: 5 | branches: [ main ] 6 | push: 7 | branches: [ main ] 8 | paths: 9 | - '**.cpp' 10 | - '**.h' 11 | - '**.py' 12 | pull_request: 13 | branches: [ main ] 14 | 15 | jobs: 16 | build: 17 | 18 | runs-on: ubuntu-20.04 19 | defaults: 20 | run: 21 | shell: bash -l {0} 22 | steps: 23 | - name: Install LLVM-10 24 | run: | 25 | wget https://apt.llvm.org/llvm.sh 26 | sudo bash llvm.sh 10 27 | - uses: actions/checkout@v3 28 | - name: Setup Conda dependencies 29 | uses: conda-incubator/setup-miniconda@v2 30 | with: 31 | channels: conda-forge,bioconda,defaults 32 | auto-activate-base: true 33 | activate-environment: mlbridge 34 | environment-file: mlbridge.yml 35 | - name: Install other dependencies and build 36 | run: bash .github/workflows/mlbridge_build.sh Release 37 | - name: Setup Conda dependencies 38 | uses: conda-incubator/setup-miniconda@v2 39 | with: 40 | channels: conda-forge,bioconda,defaults 41 | auto-activate-base: true 42 | activate-environment: mlbridge 43 | environment-file: mlbridge.yml 44 | - name: Run Tests 45 | run: | 46 | conda init 47 | conda activate mlbridge 48 | pip install compilerinterface 49 | cd 
$GITHUB_WORKSPACE/test
50 |           bash mlbridge-test.sh
51 |     - uses: actions/upload-artifact@v2
52 |       with:
53 |         name: MLCompilerBridge
54 |         path: |
55 |           install/lib
56 |           install/include
57 |
--------------------------------------------------------------------------------
/.github/workflows/formatting.yml:
--------------------------------------------------------------------------------
 1 | name: pre-commit checks
 2 | on:
 3 |   workflow_dispatch:
 4 |     branches: [ main ]
 5 |   pull_request:
 6 |     branches: [ main ]
 7 | jobs:
 8 |   pre-commit:
 9 |     runs-on: ubuntu-latest
10 |     steps:
11 |       - uses: actions/checkout@v2
12 |       - uses: actions/setup-python@v2
13 |       - uses: pre-commit/action@v2.0.3
14 |
--------------------------------------------------------------------------------
/.github/workflows/main.yml:
--------------------------------------------------------------------------------
 1 | # This is a basic workflow to help you get started with Actions
 2 |
 3 | name: Doxygen Action
 4 |
 5 | # Controls when the action will run. Triggers the workflow on push
 6 | # events, but only for the main branch
 7 | on:
 8 |   push:
 9 |     branches: [ main ]
10 |
11 |
12 |
13 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel
14 | jobs:
15 |   # This workflow contains a single job called "build"
16 |   build:
17 |     # The type of runner that the job will run on
18 |     runs-on: ubuntu-latest
19 |
20 |     # Steps represent a sequence of tasks that will be executed as part of the job
21 |     steps:
22 |       # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
23 |       - uses: actions/checkout@v2
24 |
25 |       - name: Doxygen Action
26 |         uses: mattnotmitt/doxygen-action@v1.1.0
27 |         with:
28 |           # Path to Doxyfile
29 |           doxyfile-path: "./docs/Doxyfile" # default is ./Doxyfile
30 |           # Working directory
31 |           working-directory: "." # default is .
32 |
33 |       - name: Deploy
34 |         uses: peaceiris/actions-gh-pages@v3
35 |         with:
36 |           github_token: ${{ secrets.GITHUB_TOKEN }}
37 |           # By default the Doxyfile builds the documentation into the html directory.
38 |           # Change this directory if the Doxyfile changes it.
39 |           publish_dir: ./html
40 |
--------------------------------------------------------------------------------
/.github/workflows/mlbridge_build.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 |
 3 | # Build script for GitHub Actions
 4 |
 5 | # Setup gRPC
 6 | export MY_INSTALL_DIR=$HOME/.local
 7 | mkdir -p $MY_INSTALL_DIR
 8 | export PATH="$MY_INSTALL_DIR/bin:$PATH"
 9 | git clone --recurse-submodules -b v1.58.0 --depth 1 --shallow-submodules https://github.com/grpc/grpc
10 |
11 | cd grpc
12 | mkdir -p cmake/build
13 | pushd cmake/build
14 | cmake -DgRPC_INSTALL=ON \
15 |     -DgRPC_BUILD_TESTS=OFF \
16 |     -DCMAKE_INSTALL_PREFIX=$MY_INSTALL_DIR \
17 |     -DABSL_PROPAGATE_CXX_STD=OFF \
18 |     -DCMAKE_CXX_STANDARD=17 \
19 |     -G Ninja \
20 |     ../..
21 | ninja
22 | ninja install
23 | popd
24 |
25 | # Setup ONNXRuntime
26 | cd $HOME
27 | wget https://github.com/microsoft/onnxruntime/releases/download/v1.16.3/onnxruntime-linux-x64-1.16.3.tgz
28 | tar -xzf onnxruntime-linux-x64-1.16.3.tgz
29 |
30 | # Setup MLCompilerBridge
31 | REPO_DIR=$GITHUB_WORKSPACE
32 | BUILD=$1
33 |
34 | if [[ -z "$BUILD" ]]; then
35 |     echo "Please provide a build type"
36 |     echo "exiting..."
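    # Expected invocation (as wired up in build.yml): bash mlbridge_build.sh Release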
37 | exit 1 38 | fi 39 | 40 | cmake \ 41 | -G Ninja \ 42 | -S $REPO_DIR \ 43 | -B $REPO_DIR/build_${BUILD,,} \ 44 | -DONNXRUNTIME_ROOTDIR=$HOME/onnxruntime-linux-x64-1.16.3 \ 45 | -DTENSORFLOW_AOT_PATH=$CONDA/envs/mlbridge/lib/python3.10/site-packages/tensorflow \ 46 | -DCMAKE_BUILD_TYPE="$BUILD" \ 47 | -DCMAKE_INSTALL_PREFIX=$REPO_DIR/install \ 48 | -DPROTOS_DIRECTORY=$REPO_DIR/test/protos \ 49 | -DPYTHON_UTILITIES_DIRECTORY=$REPO_DIR/test \ 50 | -DMLBRIDGE_ENABLE_TEST=ON 51 | 52 | cd $REPO_DIR/build_${BUILD,,} 53 | ninja install 54 | -------------------------------------------------------------------------------- /.github/workflows/upload_pypi.yml: -------------------------------------------------------------------------------- 1 | name: Upload to PyPI 2 | 3 | on: 4 | release: 5 | types: 6 | - published 7 | workflow_dispatch: 8 | inputs: 9 | pypi_repo: 10 | description: 'Repo to upload to pypi' 11 | default: 'pypi' 12 | required: true 13 | type: choice 14 | options: 15 | - testpypi 16 | - pypi 17 | 18 | jobs: 19 | build_wheels: 20 | uses: ./.github/workflows/wheel.yml 21 | 22 | upload_pypi: 23 | permissions: 24 | id-token: write 25 | needs: [build_wheels] 26 | runs-on: ubuntu-latest 27 | steps: 28 | - uses: actions/download-artifact@v3 29 | with: 30 | name: artifact 31 | path: ./CompilerInterface/dist 32 | 33 | # - name: Publish package to TestPyPI 34 | # uses: pypa/gh-action-pypi-publish@v1.8.5 35 | # with: 36 | # repository-url: https://test.pypi.org/legacy/ 37 | # packages-dir: ./CompilerInterface/dist 38 | # if: ${{ github.event.inputs.pypi_repo != 'pypi' }} 39 | 40 | - name: Publish package to PyPI 41 | uses: pypa/gh-action-pypi-publish@v1.8.5 42 | with: 43 | repository-url: https://upload.pypi.org/legacy/ 44 | packages-dir: ./CompilerInterface/dist 45 | # if: ${{ github.event.inputs.pypi_repo == 'pypi' }} 46 | -------------------------------------------------------------------------------- /.github/workflows/wheel.yml: -------------------------------------------------------------------------------- 1 | name: Build wheels 2 | 3 | on: [push, workflow_dispatch, workflow_call] 4 | 5 | jobs: 6 | build_wheels: 7 | name: Build wheels on ${{ matrix.os }} 8 | runs-on: ${{ matrix.os }} 9 | strategy: 10 | matrix: 11 | os: [ubuntu-20.04] 12 | 13 | steps: 14 | - uses: actions/checkout@v3 15 | 16 | - name: Build wheels 17 | run: | 18 | cd $GITHUB_WORKSPACE/CompilerInterface 19 | python fetch_version.py 20 | cp ../README.md ./ 21 | pip wheel . 
-w ./dist 22 | pip install dist/compilerinterface*.whl 23 | 24 | - uses: actions/upload-artifact@v3 25 | with: 26 | name: artifact 27 | path: ./CompilerInterface/dist 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | build* 3 | install* 4 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.0.1 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | exclude: test-suite/oracle 9 | - repo: https://github.com/psf/black 10 | rev: 22.3.0 11 | hooks: 12 | - id: black 13 | - repo: https://github.com/pocc/pre-commit-hooks 14 | rev: v1.1.1 15 | hooks: 16 | - id: clang-format 17 | args: [-i, --version=14] 18 | # - repo: https://github.com/dfm/black_nbconvert 19 | # rev: v0.4.0 20 | # hooks: 21 | # - id: black_nbconvert 22 | - repo: https://github.com/maxwinterstein/shfmt-py 23 | rev: v3.4.3.1 24 | hooks: 25 | - id: shfmt 26 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10) 2 | 3 | project(MLCompilerBridge VERSION 0.0.2) 4 | add_compile_options("$<$:-UNDEBUG>") 5 | set(protobuf_MODULE_COMPATIBLE TRUE) 6 | find_package(Protobuf CONFIG REQUIRED) 7 | set(protobuf_MODULE_COMPATIBLE TRUE) 8 | find_package(Protobuf CONFIG REQUIRED) 9 | 10 | set(CMAKE_MODULE_PATH 11 | "${CMAKE_CURRENT_SOURCE_DIR}/cmake" 12 | "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules" 13 | ${CMAKE_MODULE_PATH} 14 | ) 15 | message("CMAKE Module path: ${CMAKE_MODULE_PATH}") 16 | set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib CACHE PATH "Output directory for static libraries") 17 | 18 | include_directories(${Protobuf_INCLUDE_DIRS} ${CMAKE_CURRENT_SOURCE_DIR}/include) 19 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-rtti -fPIC") 20 | set (CMAKE_CXX_STANDARD 17) 21 | 22 | option(LLVM_MLBRIDGE "MLCompilerBridge install for LLVM" OFF) 23 | 24 | if (NOT CMAKE_BUILD_TYPE) 25 | set(CMAKE_BUILD_TYPE Release) 26 | endif() 27 | 28 | string( TOLOWER "${CMAKE_BUILD_TYPE}" MLBRIDGE_CMAKE_BUILD_TYPE) 29 | if(MLBRIDGE_CMAKE_BUILD_TYPE MATCHES debug) 30 | option(MLBRIDGE_DEBUG_MODE "Enable debug mode" ON) 31 | else() 32 | option(MLBRIDGE_DEBUG_MODE "Enable debug mode" OFF) 33 | endif() 34 | 35 | if(NOT LLVM_MLBRIDGE) 36 | find_package(LLVM 10.0.0 REQUIRED CONFIG) 37 | include_directories(${LLVM_INCLUDE_DIRS}) 38 | link_directories(${LLVM_LIBRARY_DIR}) 39 | endif() 40 | 41 | if(MLBRIDGE_DEBUG_MODE) 42 | add_compile_definitions(DEBUG_MODE) 43 | endif() 44 | 45 | if(LLVM_MLBRIDGE) 46 | include(AddLLVM) 47 | include(HandleLLVMOptions) 48 | include(LLVMDistributionSupport) 49 | endif() 50 | 51 | add_subdirectory(MLModelRunner) 52 | add_subdirectory(SerDes) 53 | if(MLBRIDGE_ENABLE_TEST) 54 | add_subdirectory(test) 55 | endif() 56 | 57 | 58 | if(LLVM_MLBRIDGE) 59 | add_llvm_library(LLVMMLBridge 60 | tools.cpp 61 | 62 | ADDITIONAL_HEADER_DIRS 63 | ${CMAKE_CURRENT_SOURCE_DIR}/include 64 | 65 | LINK_LIBS 66 | ModelRunnerLib 67 | SerDesLib 68 | ) 69 | 70 | target_include_directories(LLVMMLBridge SYSTEM PUBLIC ${Protobuf_INCLUDE_DIRS} ${TENSORFLOW_AOT_PATH}/include) 71 | 
target_include_directories(LLVMMLBridge PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) 72 | install(TARGETS LLVMMLBridge DESTINATION lib) 73 | add_custom_command(TARGET LLVMMLBridge 74 | POST_BUILD 75 | COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/CompilerInterface ${CMAKE_CURRENT_BINARY_DIR}/CompilerInterface 76 | ) 77 | 78 | else() 79 | llvm_map_components_to_libnames(llvm_libs support core irreader analysis TransformUtils) 80 | 81 | add_library(MLCompilerBridge STATIC tools.cpp) 82 | target_link_libraries(MLCompilerBridge PUBLIC SerDesLib ModelRunnerLib ONNXModelRunnerLib ${llvm_libs}) 83 | set_target_properties(MLCompilerBridge PROPERTIES ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) 84 | set_property(TARGET MLCompilerBridge PROPERTY POSITION_INDEPENDENT_CODE 1) 85 | install(TARGETS MLCompilerBridge DESTINATION lib) 86 | 87 | 88 | add_library(MLCompilerBridgeC STATIC $) 89 | target_link_libraries(MLCompilerBridgeC PUBLIC SerDesCLib ModelRunnerCLib ONNXModelRunnerLib ${llvm_libs}) 90 | target_include_directories(MLCompilerBridgeC PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include ${LLVM_INCLUDE_DIRS}) 91 | target_compile_features(MLCompilerBridgeC PRIVATE cxx_std_17) 92 | target_compile_definitions(MLCompilerBridgeC PRIVATE C_LIBRARY) 93 | set_property(TARGET MLCompilerBridgeC PROPERTY POSITION_INDEPENDENT_CODE 1) 94 | set_target_properties(MLCompilerBridgeC PROPERTIES ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) 95 | install(TARGETS MLCompilerBridgeC DESTINATION lib) 96 | if(MLBRIDGE_ENABLE_TEST) 97 | add_executable(MLCompilerBridgeTest $) 98 | target_link_libraries(MLCompilerBridgeTest PUBLIC MLCompilerBridge) 99 | set_target_properties(MLCompilerBridgeTest PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) 100 | install(TARGETS MLCompilerBridgeTest DESTINATION bin) 101 | endif() 102 | endif(LLVM_MLBRIDGE) 103 | 104 | install(DIRECTORY include/ DESTINATION include) 105 | install(DIRECTORY CompilerInterface DESTINATION MLModelRunner/CompilerInterface) 106 | file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/CompilerInterface DESTINATION ${CMAKE_BINARY_DIR}/MLModelRunner/) 107 | -------------------------------------------------------------------------------- /CompilerInterface/compilerinterface/BaseCompilerInterface.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # 3 | # Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM 4 | # Exceptions. See the LICENSE file for license information. 5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | # 7 | # ------------------------------------------------------------------------------ 8 | ## 9 | ## @file 10 | ## This file contains the abstract class for compiler interface. It specifies the 11 | ## methods for communication with compiler. It also initializes the correct 12 | ## SerDes object for serialization and deserialization of data. 13 | ## 14 | # ------------------------------------------------------------------------------ 15 | 16 | 17 | from abc import ABC, abstractmethod 18 | from .SerDes import SerDes 19 | 20 | ## This base class specifies methods for communication with compiler. 21 | class BaseCompilerInterface(ABC): 22 | ## Initializes correct SerDes object 23 | def __init__(self, data_format=None): 24 | self.serdes_obj = SerDes(data_format) 25 | 26 | ## Places data for next request into a buffer after serialization. 
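    # Illustrative use with the JSON SerDes (payload is a placeholder):
    #   iface.populate_buffer([1, 2.5, "s"])
    # serializes the payload now and holds it until the next evaluate() call.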
27 | # @param Unserialized data for next query to compiler 28 | def populate_buffer(self, data): 29 | self.serdes_obj.serializeData(data) 30 | 31 | @abstractmethod 32 | ## Sends query to compiler and returns deserialized result. 33 | def evaluate(self): 34 | pass 35 | -------------------------------------------------------------------------------- /CompilerInterface/compilerinterface/GrpcCompilerInterface.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # 3 | # Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM 4 | # Exceptions. See the LICENSE file for license information. 5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | # 7 | # ------------------------------------------------------------------------------ 8 | ## 9 | ## @file 10 | ## Compiler interface for gRPC. This class implements methods for communication 11 | ## with compiler using gRPC. 12 | ## 13 | # ------------------------------------------------------------------------------ 14 | 15 | 16 | from .BaseCompilerInterface import BaseCompilerInterface 17 | import time 18 | 19 | import grpc 20 | from concurrent import futures 21 | 22 | # requires grpc stub evaluate, service_server obj and adder method 23 | 24 | ## This class implements methods for communication with compiler using gRPC. 25 | class GrpcCompilerInterface(BaseCompilerInterface): 26 | ## Initializes GrpcCompilerInterface object. 27 | # @param mode Can be 'client' or 'server'. 28 | # @param stub_class gRPC stub class used in 'client' mode 29 | # @param hostip 30 | # @param hostport 31 | # @param add_server_method used in 'server' mode 32 | # @param grpc_service_obj used in 'server' mode 33 | def __init__( 34 | self, 35 | mode, 36 | stub_class=None, 37 | hostip="127.0.0.1", 38 | hostport=50051, 39 | add_server_method=None, 40 | grpc_service_obj=None, 41 | ): 42 | super().__init__("protobuf") 43 | self.mode = mode 44 | self.host = hostip 45 | self.server_port = hostport 46 | 47 | if self.mode == "client": 48 | self.channel = grpc.insecure_channel( 49 | "{}:{}".format(self.host, self.server_port) 50 | ) 51 | print("Setting stub", stub_class) 52 | self.stub = stub_class(self.channel) 53 | 54 | elif self.mode == "server": 55 | self.grpc_service_obj = grpc_service_obj 56 | self.add_server_method = add_server_method 57 | self.start_server() 58 | 59 | def __del__(self): 60 | pass 61 | 62 | ## Sends query to compiler and returns deserialized result. 
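    # Illustrative client-side flow (the stub name is hypothetical, generated
    # from a .proto such as test/protos/helloMLBridge.proto):
    #   iface = GrpcCompilerInterface(mode="client", stub_class=helloMLBridgeStub)
    #   iface.populate_buffer(request_msg)  # a protobuf message
    #   reply = iface.evaluate()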
63 |     def evaluate(self, mode=None):
64 |         out = self.serdes_obj.getOutputBuffer()
65 |         return self.stub.queryCompiler(out)
66 |
67 |     ## Starts gRPC server
68 |     def start_server(self):
69 |         if self.mode == "server":
70 |             server = grpc.server(
71 |                 futures.ThreadPoolExecutor(max_workers=20),
72 |                 options=[
73 |                     ("grpc.max_send_message_length", 200 * 1024 * 1024),  # 200MB
74 |                     ("grpc.max_receive_message_length", 200 * 1024 * 1024),  # 200MB
75 |                 ],
76 |             )
77 |
78 |             self.add_server_method(self.grpc_service_obj, server)
79 |
80 |             retries = 0
81 |             max_retries = 30
82 |             wait_seconds = 0.2
83 |             retry_wait_backoff_exponent = 1.2
84 |
85 |             while retries < max_retries:
86 |                 added_port = server.add_insecure_port(
87 |                     "{}:{}".format(self.host, self.server_port)
88 |                 )
89 |
90 |                 if str(added_port) == str(self.server_port):
91 |                     server.start()
92 |                     print("Server Running")
93 |                     server.wait_for_termination()
94 |                     break
95 |                 else:
96 |                     retries += 1
97 |                     print(
98 |                         "The port",
99 |                         self.server_port,
100 |                         "is already in use, retrying! attempt:",
101 |                         retries,
102 |                     )
103 |
104 |                     time.sleep(wait_seconds)
105 |                     wait_seconds *= retry_wait_backoff_exponent
106 |
--------------------------------------------------------------------------------
/CompilerInterface/compilerinterface/PipeCompilerInterface.py:
--------------------------------------------------------------------------------
 1 | # ------------------------------------------------------------------------------
 2 | #
 3 | # Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
 4 | # Exceptions. See the LICENSE file for license information.
 5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 6 | #
 7 | # ------------------------------------------------------------------------------
 8 | ##
 9 | ## @file
10 | ## Compiler interface for pipes. This class implements methods for communication
11 | ## with compiler using pipes.
12 | ##
13 | # ------------------------------------------------------------------------------
14 |
15 | from .BaseCompilerInterface import BaseCompilerInterface
16 | import os
17 | import io
18 |
19 | ## This class implements methods for communication with compiler using pipes.
20 | class PipeCompilerInterface(BaseCompilerInterface):
21 |     ## Initializes PipeCompilerInterface object.
22 |     # @param data_format Data format for serialization
23 |     # @param pipe_name Name for pipe file
24 |     def __init__(self, data_format=None, pipe_name=None):
25 |         super().__init__(data_format)
26 |         self.pipe_name = pipe_name
27 |         self.to_compiler = None
28 |         self.from_compiler = None
29 |         self.tc_buffer = None
30 |         self.fc_buffer = None
31 |         self.buffer = None
32 |         self.init_pipes()
33 |
34 |     def __del__(self):
35 |         self.close_pipes()
36 |         self.remove_pipes()
37 |
38 |     ## Sends query to compiler and returns deserialized result.
39 |     def evaluate(self, mode=None):
40 |         out = self.serdes_obj.getOutputBuffer()
41 |         if out is not None:
42 |             self.tc_buffer.write(out)
43 |             self.tc_buffer.flush()
44 |
45 |         if mode == "exit":
46 |             return None
47 |
48 |         result = self.serdes_obj.deserializeData(self.fc_buffer)
49 |
50 |         return result
51 |
52 |     ## Creates pipe files for communication.
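    # Given pipe_name "/tmp/mlbridge" (illustrative), this creates the FIFOs
    # /tmp/mlbridge.in (written by Python, read by the compiler) and
    # /tmp/mlbridge.out (written by the compiler, read by Python).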
53 | def init_pipes(self): 54 | self.to_compiler = self.pipe_name + ".in" 55 | self.from_compiler = self.pipe_name + ".out" 56 | if os.path.exists(self.to_compiler): 57 | os.remove(self.to_compiler) 58 | if os.path.exists(self.from_compiler): 59 | os.remove(self.from_compiler) 60 | 61 | os.mkfifo(self.to_compiler, 0o666) 62 | os.mkfifo(self.from_compiler, 0o666) 63 | 64 | ## Resets the buffered reader/writers. 65 | def reset_pipes(self): 66 | self.tc_buffer = io.BufferedWriter(io.FileIO(self.to_compiler, "wb")) 67 | self.fc_buffer = io.BufferedReader(io.FileIO(self.from_compiler, "rb")) 68 | 69 | ## Closes the buffered reader/writers. 70 | def close_pipes(self): 71 | if self.fc_buffer is not None: 72 | self.tc_buffer.close() 73 | self.fc_buffer.close() 74 | self.tc_buffer = None 75 | self.fc_buffer = None 76 | 77 | ## Deletes the pipe files. 78 | def remove_pipes(self): 79 | os.remove(self.to_compiler) 80 | os.remove(self.from_compiler) 81 | -------------------------------------------------------------------------------- /CompilerInterface/compilerinterface/SerDes.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # 3 | # Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM 4 | # Exceptions. See the LICENSE file for license information. 5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | # 7 | # ------------------------------------------------------------------------------ 8 | ## 9 | ## @file 10 | ## SerDes for JSON and bitstream data. 11 | ## 12 | # ------------------------------------------------------------------------------ 13 | 14 | import json 15 | from . import log_reader 16 | import ctypes 17 | import struct 18 | 19 | 20 | class NpEncoder(json.JSONEncoder): 21 | def default(self, obj): 22 | if isinstance(obj, ctypes.c_long): 23 | return obj.value 24 | if isinstance(obj, ctypes.c_double): 25 | return obj.value 26 | return super(NpEncoder, self).default(obj) 27 | 28 | 29 | ## Class for serialization and deserialization in various formats for communication. 
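# Illustrative round trip through the JSON format, i.e. an 8-byte
# little-endian length header followed by the JSON body:
#   sd = SerDes("json")
#   sd.serializeData([1, 2, 3])
#   buf = sd.getOutputBuffer()  # header + b'{"out": [1, 2, 3]}'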
30 | class SerDes:
31 |     ## Constructor for SerDes object
32 |     # @param data_format Format for serialization and deserialization
33 |     def __init__(self, data_format):
34 |         self.buffer = None
35 |         self.data_format = data_format
36 |         self.read_stream_iter = None
37 |
38 |         self.serMap = {
39 |             "json": self.serializeJson,
40 |             "bytes": self.serializeBytes,
41 |             "protobuf": self.serializeProtobuf,
42 |         }
43 |         self.desMap = {
44 |             "json": self.deserializeJson,
45 |             "bytes": self.deserializeBytes,
46 |             "protobuf": self.deserializeProtobuf,
47 |         }
48 |
49 |     ## Deserializes data for specified data format
50 |     # @param rawdata Datastream receiving serialized data
51 |     # @return Deserialized data
52 |     def deserializeData(self, rawdata):
53 |         return self.desMap[self.data_format](rawdata)
54 |
55 |     ## Deserializes and returns JSON data
56 |     def deserializeJson(self, datastream):
57 |         hdr = datastream.read(8)
58 |         size = int.from_bytes(hdr, "little")
59 |         data = datastream.read(size)
60 |         return json.loads(data)
61 |
62 |     ## Deserializes and returns bitstream data
63 |     def deserializeBytes(self, datastream):
64 |         if self.read_stream_iter is None:
65 |             self.read_stream_iter = log_reader.read_stream2(
66 |                 datastream
67 |             )  # try to make it indep
68 |         hdr = datastream.read(8)
69 |         context, observation_id, features, score = next(self.read_stream_iter)
70 |         return features
71 |
72 |     # Not implemented
73 |     def deserializeProtobuf(self, datastream):
74 |         raise NotImplementedError
75 |
76 |     ## Serializes data and places it in a buffer
77 |     def serializeData(self, data):
78 |         self.serMap[self.data_format](data)
79 |
80 |     ## Serializes data to JSON
81 |     def serializeJson(self, data):
82 |         msg = json.dumps({"out": data}, cls=NpEncoder).encode("utf-8")
83 |         hdr = len(msg).to_bytes(8, "little")
84 |         self.buffer = hdr + msg
85 |
86 |     ## Serializes data to bitstream
87 |     def serializeBytes(self, data):
88 |         def _pack(data):
89 |             if isinstance(data, int):
90 |                 return struct.pack("i", data)
91 |             elif isinstance(data, float):
92 |                 return struct.pack("f", data)
93 |             elif isinstance(data, str) and len(data) == 1:
94 |                 return struct.pack("c", data.encode())  # "c" expects bytes in Python 3
95 |             elif isinstance(data, ctypes.c_double):
96 |                 return struct.pack("d", data.value)
97 |             elif isinstance(data, ctypes.c_long):
98 |                 return struct.pack("l", data.value)
99 |             elif isinstance(data, list):
100 |                 return b"".join([_pack(x) for x in data])
101 |
102 |         msg = _pack(data)
103 |         hdr = len(msg).to_bytes(8, "little")
104 |         self.buffer = hdr + msg
105 |
106 |     # Implemented outside for now
107 |     def serializeProtobuf(self, data):
108 |         self.buffer = data
109 |
110 |     ## Returns value in buffer and empties it
111 |     # @return Data from output buffer
112 |     def getOutputBuffer(self):
113 |         out = self.buffer
114 |         self.buffer = None
115 |         return out
116 |
--------------------------------------------------------------------------------
/CompilerInterface/compilerinterface/__init__.py:
--------------------------------------------------------------------------------
 1 | # ------------------------------------------------------------------------------
 2 | #
 3 | # Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
 4 | # Exceptions. See the LICENSE file for license information.
5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | # 7 | # ------------------------------------------------------------------------------ 8 | 9 | from .BaseCompilerInterface import BaseCompilerInterface 10 | from .PipeCompilerInterface import PipeCompilerInterface 11 | from .GrpcCompilerInterface import GrpcCompilerInterface 12 | from .SerDes import SerDes 13 | -------------------------------------------------------------------------------- /CompilerInterface/compilerinterface/log_reader.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # 3 | # Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM 4 | # Exceptions. See the LICENSE file for license information. 5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | # 7 | # ------------------------------------------------------------------------------ 8 | ## 9 | ## @file 10 | ## Reader for training log. Used by Pipe compiler interface with bitstream data. 11 | ## 12 | # ------------------------------------------------------------------------------ 13 | 14 | import ctypes 15 | import dataclasses 16 | import io 17 | import json 18 | import sys 19 | from typing import List, Optional 20 | from functools import reduce 21 | import operator 22 | 23 | _element_types = { 24 | "float": ctypes.c_float, 25 | "double": ctypes.c_double, 26 | "int8_t": ctypes.c_int8, 27 | "uint8_t": ctypes.c_uint8, 28 | "int16_t": ctypes.c_int16, 29 | "uint16_t": ctypes.c_uint16, 30 | "int32_t": ctypes.c_int32, 31 | "uint32_t": ctypes.c_uint32, 32 | "int64_t": ctypes.c_int64, 33 | "uint64_t": ctypes.c_uint64, 34 | } 35 | 36 | 37 | @dataclasses.dataclass(frozen=True) 38 | class TensorSpec: 39 | name: str 40 | port: int 41 | shape: List[int] 42 | element_type: type 43 | 44 | @staticmethod 45 | def from_dict(d: dict): 46 | name = d["name"] 47 | port = d["port"] 48 | shape = [int(e) for e in d["shape"]] 49 | element_type_str = d["type"] 50 | if element_type_str not in _element_types: 51 | raise ValueError(f"uknown type: {element_type_str}") 52 | return TensorSpec( 53 | name=name, 54 | port=port, 55 | shape=shape, 56 | element_type=_element_types[element_type_str], 57 | ) 58 | 59 | 60 | class TensorValue: 61 | def __init__(self, spec: TensorSpec, buffer: bytes): 62 | self._spec = spec 63 | self._buffer = buffer 64 | self._view = ctypes.cast(self._buffer, ctypes.POINTER(self._spec.element_type)) 65 | # self._len = math.prod(self._spec.shape) 66 | self._len = reduce(operator.mul, self._spec.shape, 1) 67 | # self._view = numpy.frombuffer(self._buffer, float) 68 | # print("Value of", self._spec.name, "is:", self._view) 69 | 70 | def spec(self) -> TensorSpec: 71 | return self._spec 72 | 73 | def __len__(self) -> int: 74 | return self._len 75 | 76 | def __getitem__(self, index): 77 | if index < 0 or index >= self._len: 78 | raise IndexError(f"Index {index} out of range [0..{self._len})") 79 | return self._view[index] 80 | 81 | 82 | def read_tensor(fs: io.BufferedReader, ts: TensorSpec) -> TensorValue: 83 | size = reduce(operator.mul, ts.shape, 1) * ctypes.sizeof(ts.element_type) 84 | # size = math.prod(ts.shape) * ctypes.sizeof(ts.element_type) 85 | data = fs.read(size) 86 | return TensorValue(ts, data) 87 | 88 | 89 | def pretty_print_tensor_value(tv: TensorValue): 90 | print(f'{tv.spec().name}: {",".join([str(v) for v in tv])}') 91 | 92 | 93 | def read_header(f: io.BufferedReader): 94 | line = f.readline() 95 | 
header = json.loads(line) 96 | tensor_specs = [TensorSpec.from_dict(ts) for ts in header["features"]] 97 | score_spec = TensorSpec.from_dict(header["score"]) if "score" in header else None 98 | advice_spec = TensorSpec.from_dict(header["advice"]) if "advice" in header else None 99 | return tensor_specs, score_spec, advice_spec 100 | 101 | 102 | def read_one_observation( 103 | context: Optional[str], 104 | event_str: str, 105 | f: io.BufferedReader, 106 | tensor_specs: List[TensorSpec], 107 | score_spec: Optional[TensorSpec], 108 | ): 109 | features = [] 110 | for ts in tensor_specs: 111 | features.append(read_tensor(f, ts)) 112 | f.readline() 113 | return context, None, features, None 114 | 115 | 116 | def read_stream(fname: str): 117 | with io.BufferedReader(io.FileIO(fname, "rb")) as f: 118 | tensor_specs, score_spec, _ = read_header(f) 119 | context = None 120 | while True: 121 | event_str = f.readline() 122 | if not event_str: 123 | break 124 | context, observation_id, features, score = read_one_observation( 125 | context, event_str, f, tensor_specs, score_spec 126 | ) 127 | yield context, observation_id, features, score 128 | 129 | 130 | def read_stream2(f: io.BufferedReader): 131 | context = None 132 | while True: 133 | tensor_specs, score_spec, _ = read_header(f) 134 | # event_str = f.readline() 135 | # print("Event: ", event_str) 136 | # if not event_str: 137 | # break 138 | context, observation_id, features, score = read_one_observation( 139 | context, "", f, tensor_specs, score_spec 140 | ) 141 | yield context, observation_id, features, score 142 | 143 | 144 | def main(args): 145 | last_context = None 146 | for ctx, obs_id, features, score in read_stream(args[1]): 147 | if last_context != ctx: 148 | print(f"context: {ctx}") 149 | last_context = ctx 150 | print(f"observation: {obs_id}") 151 | for fv in features: 152 | pretty_print_tensor_value(fv) 153 | if score: 154 | pretty_print_tensor_value(score) 155 | 156 | 157 | if __name__ == "__main__": 158 | main(sys.argv) 159 | -------------------------------------------------------------------------------- /CompilerInterface/fetch_version.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # 3 | # Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM 4 | # Exceptions. See the LICENSE file for license information. 
5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | # 7 | # ------------------------------------------------------------------------------ 8 | 9 | import re 10 | 11 | version_regex = re.compile(r"^project\(MLCompilerBridge VERSION (?P[^)]+)\)$") 12 | toml_field_regex = r'version[ ]*=[ ]*"(.*)"' 13 | 14 | VERSION = "" 15 | with open("../CMakeLists.txt", "r") as f: 16 | for line in f: 17 | vmatch = version_regex.match(line) # Not using walrus because Python3.6 18 | if vmatch: 19 | VERSION = vmatch.group("version") 20 | break 21 | 22 | print("Version detected =", VERSION) 23 | lines = [] 24 | with open("./pyproject.toml", "r") as f: 25 | lines = f.readlines() 26 | 27 | with open("./pyproject.toml", "w") as f: 28 | for line in lines: 29 | if re.search(toml_field_regex, line): 30 | new_text = f'version = "{VERSION}"\n' 31 | f.write(new_text) 32 | else: 33 | f.write(line) 34 | -------------------------------------------------------------------------------- /CompilerInterface/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "compilerinterface" 3 | version = "0.0.1" 4 | authors = [ {name = "The authors of \"The Next 700 ML-Enabled Compiler Optimizations\"" }] 5 | description = "Communication framework for ML-Enabled Compiler Optimizations." 6 | license = { text= "Apache License v2.0 with LLVM Exceptions"} 7 | readme = "README.md" 8 | requires-python = ">=3.7" 9 | classifiers = [ 10 | "Programming Language :: Python :: 3", 11 | "Programming Language :: Python :: 3.7", 12 | "Programming Language :: Python :: 3.8", 13 | "Programming Language :: Python :: 3.9", 14 | "Programming Language :: Python :: 3.10", 15 | "Programming Language :: Python :: 3.11", 16 | "Intended Audience :: Developers", 17 | "Intended Audience :: Science/Research", 18 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 19 | "Topic :: Software Development :: Compilers", 20 | "Development Status :: 4 - Beta", 21 | "Operating System :: POSIX :: Linux", 22 | ] 23 | 24 | [build-system] 25 | requires = [ 26 | "setuptools>=42", 27 | "wheel" 28 | ] 29 | build-backend = "setuptools.build_meta" 30 | -------------------------------------------------------------------------------- /CompilerInterface/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | url = https://github.com/IITH-Compilers/ML-Compiler-Bridge/ 3 | -------------------------------------------------------------------------------- /MLModelRunner/C/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(ModelRunnerCWrapper OBJECT PipeModelRunnerCWrapper.cpp ONNXModelRunnerCWrapper.cpp) 2 | target_include_directories(ModelRunnerCWrapper PUBLIC ${TENSORFLOW_AOT_PATH}/include) 3 | -------------------------------------------------------------------------------- /MLModelRunner/C/ONNXModelRunnerCWrapper.cpp: -------------------------------------------------------------------------------- 1 | //=== MLModelRunner/C/ONNXModelRunner.cpp - C API for ONNXModelRunner -----===// 2 | // 3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM 4 | // Exceptions. See the LICENSE file for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // 7 | //===----------------------------------------------------------------------===// 8 | /// 9 | /// \file 10 | /// This file defines the C APIs for ONNXModelRunner. 
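/// Typical single-agent use from C (the model path is illustrative):
///   ONNXModelRunner *omr = createSingleAgentOMR("policy.onnx");
///   int action = singleAgentEvaluate(omr, features, num_features);
///   destroyONNXModelRunner(omr);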
11 | /// 12 | //===----------------------------------------------------------------------===// 13 | 14 | #include "MLModelRunner/C/ONNXModelRunner.h" 15 | #include "MLModelRunner/ONNXModelRunner/agent.h" 16 | #include "MLModelRunner/ONNXModelRunner/utils.h" 17 | #include "MLModelRunner/Utils/Debug.h" 18 | #include "llvm/ADT/SmallVector.h" 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | using namespace MLBridge; 26 | struct ONNXModelRunner { 27 | Environment *env; 28 | Agent *agent; 29 | std::map agents; 30 | ONNXModelRunner(Environment *env, std::map &&agents) 31 | : env(env), agents(agents) {} 32 | 33 | ONNXModelRunner(Agent *agent) : agent(agent) {} 34 | }; 35 | 36 | struct Environment { 37 | private: 38 | // Function pointer to the step and reset functions 39 | float *(*stepFunc)(Action action); 40 | float *(*resetFunc)(); 41 | int numFeatures; 42 | std::string nextAgent; 43 | bool done; 44 | 45 | public: 46 | Environment() : stepFunc(nullptr), resetFunc(nullptr) {} 47 | // EnvironmentImpl(float *(*stepFunc)(Action action), float *(*resetFunc)()) 48 | // : stepFunc(stepFunc), resetFunc(resetFunc) {} 49 | 50 | void setNumFeatures(int numFeatures) { this->numFeatures = numFeatures; } 51 | 52 | void setStepFunc(float *(*stepFunc)(Action action)) { 53 | this->stepFunc = stepFunc; 54 | } 55 | 56 | void setResetFunc(float *(*resetFunc)()) { this->resetFunc = resetFunc; } 57 | 58 | void setNextAgent(char *agentName) { nextAgent = agentName; } 59 | 60 | std::string getNextAgent() { return nextAgent; } 61 | 62 | Observation step(Action action) { 63 | assert(stepFunc != nullptr && 64 | "Step function is null! Define step function on env"); 65 | float *stepRes = stepFunc(action); 66 | return llvm::SmallVector(stepRes, stepRes + numFeatures); 67 | } 68 | 69 | Observation reset() { 70 | assert(resetFunc != nullptr && 71 | "Reset function is null! Define reset function on env"); 72 | float *resetRes = resetFunc(); 73 | return llvm::SmallVector(resetRes, resetRes + numFeatures); 74 | } 75 | 76 | bool checkDone() { return done; } 77 | void setDone() { done = true; } 78 | void resetDone() { done = false; } 79 | }; 80 | 81 | Environment *createEnvironment() { return new Environment(); } 82 | 83 | void env_setDone(Environment *env) { env->setDone(); } 84 | 85 | void env_resetDone(Environment *env) { env->resetDone(); } 86 | 87 | bool env_checkDone(Environment *env) { return env->checkDone(); } 88 | 89 | void env_setNumFeatures(Environment *env, int numFeatures) { 90 | env->setNumFeatures(numFeatures); 91 | } 92 | 93 | void env_setStepFunc(Environment *env, float *(*stepFunc)(Action action)) { 94 | env->setStepFunc(stepFunc); 95 | } 96 | 97 | void env_setResetFunc(Environment *env, float *(*resetFunc)()) { 98 | env->setResetFunc(resetFunc); 99 | } 100 | 101 | void env_setNextAgent(Environment *env, char *agentName) { 102 | env->setNextAgent(agentName); 103 | } 104 | 105 | ONNXModelRunner *createONNXModelRunner(Environment *env, int numAgents, ...) 
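// The varargs are (name, path) C-string pairs, and numAgents counts the
// vararg entries rather than the agents, so registering two agents looks
// like (paths are illustrative):
//   createONNXModelRunner(env, 4, "loop", "loop.onnx", "func", "func.onnx");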
{ 106 | assert(env != nullptr && "Environment is null!"); 107 | 108 | va_list args; 109 | va_start(args, numAgents); 110 | std::map agents; 111 | 112 | for (int i = 0; i < numAgents; i += 2) { 113 | char *agentName = va_arg(args, char *); 114 | char *agentPath = va_arg(args, char *); 115 | agents[agentName] = new Agent(agentPath); 116 | } 117 | 118 | va_end(args); 119 | 120 | ONNXModelRunner *obj = new ONNXModelRunner(env, std::move(agents)); 121 | return obj; 122 | } 123 | 124 | ONNXModelRunner *createSingleAgentOMR(char *agent_path) { 125 | Agent *agent = new Agent(agent_path); 126 | ONNXModelRunner *obj = new ONNXModelRunner(agent); 127 | return obj; 128 | } 129 | 130 | void evaluate(ONNXModelRunner *omr) { 131 | auto x = omr->env->reset(); 132 | 133 | while (true) { 134 | Action action; 135 | // current agent 136 | // auto current_agent = omr->agents[omr->env->getNextAgent()]; 137 | Agent *current_agent = omr->agent; 138 | action = current_agent->computeAction(x); 139 | MLBRIDGE_DEBUG(std::cout << "Action: " << action << "\n"); 140 | x = omr->env->step(action); 141 | if (omr->env->checkDone()) { 142 | MLBRIDGE_DEBUG(std::cout << "Done🎉\n"); 143 | break; 144 | } 145 | } 146 | } 147 | 148 | int singleAgentEvaluate(ONNXModelRunner *obj, float *inp, int inp_size) { 149 | Observation obs(inp, inp + inp_size); 150 | Action action = obj->agent->computeAction(obs); 151 | MLBRIDGE_DEBUG(std::cout << "action :: " << action << "\n"); 152 | return action; 153 | } 154 | 155 | void destroyEnvironment(Environment *env) { delete env; } 156 | 157 | void destroyONNXModelRunner(ONNXModelRunner *omr) { delete omr; } 158 | -------------------------------------------------------------------------------- /MLModelRunner/C/PipeModelRunnerCWrapper.cpp: -------------------------------------------------------------------------------- 1 | //=== MLModelRunner/C/PipeModelRunner.cpp - C API for PipeModelRunner -----===// 2 | // 3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM 4 | // Exceptions. See the LICENSE file for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // 7 | //===----------------------------------------------------------------------===// 8 | /// 9 | /// \file 10 | /// This file defines the C APIs for PipeModelRunner. 
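/// Typical use from C (pipe names and SerDes kind are illustrative):
///   PipeModelRunnerWrapper *r =
///       createPipeModelRunner("bridge.out", "bridge.in", /*SerDesType=*/0);
///   populateFloatFeatures(r, "obs", data, len);
///   int advice = evaluateIntFeatures(r);
///   destroyPipeModelRunner(r);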
11 | /// 12 | //===----------------------------------------------------------------------===// 13 | 14 | #include "MLModelRunner/C/PipeModelRunner.h" 15 | 16 | #include "MLModelRunner/MLModelRunner.h" 17 | #include "MLModelRunner/PipeModelRunner.h" 18 | 19 | #include 20 | 21 | using namespace MLBridge; 22 | struct PipeModelRunnerWrapper { 23 | MLModelRunner *model; 24 | }; 25 | 26 | PipeModelRunnerWrapper *createPipeModelRunner(const char *OutboundName, 27 | const char *InboundName, 28 | int SerDesType) { 29 | PipeModelRunnerWrapper *obj = new PipeModelRunnerWrapper(); 30 | obj->model = 31 | new PipeModelRunner(OutboundName, InboundName, (SerDesKind)SerDesType); 32 | return obj; 33 | } 34 | 35 | void populateFloatFeatures(PipeModelRunnerWrapper *obj, const char *name, 36 | const float *data, const int size) { 37 | auto dataVec = std::vector(data, data + size); 38 | std::pair> p1(name, dataVec); 39 | obj->model->populateFeatures(p1); 40 | } 41 | 42 | void populateIntFeatures(PipeModelRunnerWrapper *obj, const char *name, 43 | const int *data, const int size) { 44 | auto dataVec = std::vector(data, data + size); 45 | std::pair> p1(name, dataVec); 46 | obj->model->populateFeatures(p1); 47 | } 48 | 49 | int evaluateIntFeatures(PipeModelRunnerWrapper *obj) { 50 | return obj->model->evaluate(); 51 | } 52 | 53 | float evaluateFloatFeatures(PipeModelRunnerWrapper *obj) { 54 | return obj->model->evaluate(); 55 | } 56 | 57 | void destroyPipeModelRunner(PipeModelRunnerWrapper *obj) { 58 | delete obj->model; 59 | delete obj; 60 | } 61 | -------------------------------------------------------------------------------- /MLModelRunner/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(Utils) 2 | set(PROTOS_DIRECTORY "" CACHE PATH "Path to the directory containing the proto files") 3 | if(NOT PROTOS_DIRECTORY STREQUAL "") 4 | add_subdirectory(gRPCModelRunner) 5 | endif() 6 | add_subdirectory(ONNXModelRunner) 7 | add_subdirectory(C) 8 | 9 | # # For up-to-date instructions for installing the TFLite dependency, refer to 10 | # # the bot setup script: https://github.com/google/ml-compiler-opt/blob/main/buildbot/buildbot_init.sh 11 | # set(LLVM_HAVE_TFLITE "" CACHE BOOL "Use tflite") 12 | # if (LLVM_HAVE_TFLITE) 13 | # find_package(tensorflow-lite REQUIRED) 14 | # endif() 15 | 16 | # # For up-to-date instructions for installing the Tensorflow dependency, refer to 17 | # # the bot setup script: https://github.com/google/ml-compiler-opt/blob/main/buildbot/buildbot_init.sh 18 | # # Specifically, assuming python3 is installed: 19 | # # python3 -m pip install --upgrade pip && python3 -m pip install --user tf_nightly==2.3.0.dev20200528 20 | # # Then set TENSORFLOW_AOT_PATH to the package install - usually it's ~/.local/lib/python3.7/site-packages/tensorflow 21 | # # 22 | set(TENSORFLOW_AOT_PATH "" CACHE PATH "Path to TensorFlow pip install dir") 23 | 24 | if (NOT TENSORFLOW_AOT_PATH STREQUAL "") 25 | include_directories(${TENSORFLOW_AOT_PATH}/include) 26 | add_subdirectory(${TENSORFLOW_AOT_PATH}/xla_aot_runtime_src 27 | ${CMAKE_ARCHIVE_OUTPUT_DIRECTORY}/tf_runtime) 28 | if(LLVM_MLBRIDGE) 29 | install(TARGETS tf_xla_runtime EXPORT LLVMExports 30 | ARCHIVE DESTINATION lib${LLVM_LIBDIR_SUFFIX} COMPONENT tf_xla_runtime) 31 | set_property(GLOBAL APPEND PROPERTY LLVM_EXPORTS tf_xla_runtime) 32 | else() 33 | install(TARGETS tf_xla_runtime ARCHIVE DESTINATION ${CMAKE_ARCHIVE_OUTPUT_DIRECTORY}) 34 | endif() 35 | endif() 36 | 37 | add_library(ModelRunnerCLib 
OBJECT PipeModelRunner.cpp) 38 | target_link_libraries(ModelRunnerCLib PUBLIC ModelRunnerCUtils ONNXModelRunnerLib) 39 | target_compile_definitions(ModelRunnerCLib PRIVATE C_LIBRARY) 40 | 41 | if(LLVM_MLBRIDGE) 42 | add_llvm_library(ModelRunnerLib PipeModelRunner.cpp) 43 | else() 44 | add_library(ModelRunnerLib OBJECT PipeModelRunner.cpp) 45 | endif(LLVM_MLBRIDGE) 46 | 47 | target_link_libraries(ModelRunnerLib PUBLIC ModelRunnerUtils ONNXModelRunnerLib) 48 | 49 | if(NOT PROTOS_DIRECTORY STREQUAL "") 50 | target_link_libraries(ModelRunnerLib PUBLIC gRPCModelRunnerLib) 51 | endif() 52 | set_property(TARGET ModelRunnerLib PROPERTY POSITION_INDEPENDENT_CODE 1) 53 | -------------------------------------------------------------------------------- /MLModelRunner/ONNXModelRunner/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -DORT_NO_EXCEPTIONS") 2 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -Wextra") 3 | 4 | find_package(Onnxruntime REQUIRED) 5 | 6 | #onnxruntime providers 7 | option(onnxruntime_USE_CUDA "Build with CUDA support" OFF) 8 | option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF) 9 | option(onnxruntime_USE_NNAPI_BUILTIN "Build with builtin NNAPI lib for Android NNAPI support" OFF) 10 | option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) 11 | option(onnxruntime_USE_NUPHAR "Build with Nuphar" OFF) 12 | option(onnxruntime_USE_TENSORRT "Build with TensorRT support" OFF) 13 | 14 | set(ONNXRUNTIME_ROOTDIR "" CACHE PATH "Directory that contains ONNXRuntime" ) 15 | if(NOT ONNXRUNTIME_ROOTDIR) 16 | message( FATAL_ERROR "Set path to Onnx runtime in -DONNXRUNTIME_ROOTDIR variable" ) 17 | endif() 18 | 19 | if(onnxruntime_USE_CUDA) 20 | add_definitions(-DUSE_CUDA) 21 | endif() 22 | if(onnxruntime_USE_OPENVINO) 23 | add_definitions(-DUSE_OPENVINO) 24 | endif() 25 | if(onnxruntime_USE_NNAPI_BUILTIN) 26 | add_definitions(-DUSE_NNAPI) 27 | endif() 28 | if(onnxruntime_USE_DNNL) 29 | add_definitions(-DUSE_DNNL) 30 | endif() 31 | if(onnxruntime_USE_NUPHAR) 32 | add_definitions(-DUSE_NUPHAR) 33 | endif() 34 | if(onnxruntime_USE_TENSORRT) 35 | add_definitions(-DUSE_TENSORRT) 36 | endif() 37 | if(onnxruntime_USE_DML) 38 | message("Enabling DML") 39 | add_definitions(-DUSE_DML) 40 | endif() 41 | 42 | if(LLVM_MLBRIDGE) 43 | add_llvm_component_library(ONNXModelRunnerLib 44 | onnx.cpp 45 | agent.cpp 46 | ONNXModelRunner.cpp 47 | ) 48 | else() 49 | add_library(ONNXModelRunnerLib OBJECT onnx.cpp 50 | agent.cpp 51 | ONNXModelRunner.cpp 52 | ) 53 | endif(LLVM_MLBRIDGE) 54 | 55 | target_link_libraries(ONNXModelRunnerLib PRIVATE Onnxruntime::Onnxruntime) 56 | target_include_directories(ONNXModelRunnerLib PUBLIC ${TENSORFLOW_AOT_PATH}/include) 57 | -------------------------------------------------------------------------------- /MLModelRunner/ONNXModelRunner/ONNXModelRunner.cpp: -------------------------------------------------------------------------------- 1 | //===- ONNXModelRunner.cpp - ONNX Runner ------------------------*- C++ -*-===// 2 | // 3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM 4 | // Exceptions. See the LICENSE file for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // 7 | //===----------------------------------------------------------------------===// 8 | /// 9 | /// \file 10 | /// This file defines the ONNXModelRunner class to support ML model inference 11 | /// via ONNX. 
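/// The runner drives the attached Environment through reset()/step(),
/// choosing each action with the Agent selected by getNextAgent(), until
/// checkDone() reports that the episode is over.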
12 | ///
13 | //===----------------------------------------------------------------------===//
14 |
15 | #include "MLModelRunner/ONNXModelRunner/ONNXModelRunner.h"
16 | #include "SerDes/baseSerDes.h"
17 |
18 | using namespace llvm;
19 | namespace MLBridge {
20 |
21 | ONNXModelRunner::ONNXModelRunner(Environment *env,
22 |                                  std::map<std::string, Agent *> agents,
23 |                                  LLVMContext *Ctx)
24 |     : MLModelRunner(Kind::ONNX, Ctx), env(env), agents(agents) {}
25 |
26 | void ONNXModelRunner::addAgent(Agent *agent, std::string name) {
27 |   if (agents.find(name) == agents.end()) {
28 |     agents[name] = agent;
29 |   } else {
30 |     // throw error
31 |     std::cerr << "ERROR: Agent with the name " << name
32 |               << " already exists. Please give a different name!\n";
33 |     exit(1);
34 |   }
35 | }
36 |
37 | void ONNXModelRunner::computeAction(Observation &obs) {
38 |   while (true) {
39 |     Action action;
40 |     // current agent
41 |     auto current_agent = this->agents[this->env->getNextAgent()];
42 |     action = current_agent->computeAction(obs);
43 |     this->env->step(action);
44 |     if (this->env->checkDone()) {
45 |       std::cout << "Done🎉\n";
46 |       break;
47 |     }
48 |   }
49 | }
50 |
51 | void *ONNXModelRunner::evaluateUntyped() {
52 |   Observation &obs = env->reset();
53 |   computeAction(obs);
54 |   return new int(0);
55 | }
56 |
57 | } // namespace MLBridge
58 |
--------------------------------------------------------------------------------
/MLModelRunner/ONNXModelRunner/README.md:
--------------------------------------------------------------------------------
 1 | # ONNX Model Runner (In-Process Model Runner)
 2 |
 3 | An ONNX-based Model Runner that supports integration of trained models with the compiler during inference.
 4 |
 5 | ## Trained model integration
 6 |
 7 | Integration of a trained model happens in two steps.
 8 |
 9 | * Step 1: The model trained in some native environment needs to be exported in ONNX format.
10 |
11 | * Step 2: The model can then be queried anywhere in the compiler by creating an instance of the provided ONNXModel class with the path to the ONNX model.
12 |
13 | ## Example Usage:
14 |
15 | #### RL Agent Class
16 |
17 | The example below shows how an RL agent can create an instance of the ONNXModel class and query it when computing an action.
18 |
19 | ```C++
20 | #include "agent.h"
21 | #include "onnxUtil.h"
22 |
23 | Agent::Agent(std::string modelPath, int inputSize) {
24 |   this->model = new ONNXModel(modelPath.c_str());
25 |   this->inputSize = inputSize;
26 | }
27 |
28 | unsigned Agent::computeAction(Observation &input) {
29 |   ...
30 |   this->model->run(input, model_output);
31 |   ...
32 |   return argmaxVal;
33 | }
34 | ```
35 |
36 | #### RL InferenceEngine Class
37 |
38 | The example below shows a demo class for querying the RL model, which can be called from a pass. The DriverService class is defined by inheriting from the InferenceEngine base class. It also defines supporting functions, e.g. getPredictions, which query the ONNXModel and return the response.
39 |
40 | ```C++
41 | #include "inference-engine.h"
42 |
43 | class DriverService : public InferenceEngine {
44 |
45 |   DriverService(MultiAgentEnv *env) {
46 |     setEnvironment(env);
47 |     addAgent(new Agent(agentModelPath, agentObsSize), "Agent name");
48 |   }
49 |
50 |   DriverService(std::string modelPath, int inputSize) {
51 |     // Create an instance of ONNXModel
52 |     this->model = new ONNXModel(modelPath.c_str());
53 |     this->inputSize = inputSize;
54 |   }
55 |
56 |   void getPredictions(PassData &passData, OptInfo &predictions) {
57 |     // Logic to query the model and return predictions goes here
58 |     ...
59 |   }
60 | };
61 | ```
62 |
63 | #### Integration with LLVM Pass
64 |
65 | The example below shows how to call the InferenceEngine from an LLVM pass.
66 |
67 | ```C++
68 | #include "environment.h"
69 | #include "driver-service.h"
70 |
71 | struct Hello : public FunctionPass, Environment {
72 |
73 |   bool runOnFunction(Function &F) override {
74 |     ...
75 |     // Creates an instance of DriverService with InferenceEngine as its parent class
76 |     InferenceEngine *inference_driver = new DriverService(env);
77 |     inference_driver->getPredictions(passData, predictions);
78 |     ...
79 |   }
80 | };
81 | ```
82 |
--------------------------------------------------------------------------------
/MLModelRunner/ONNXModelRunner/agent.cpp:
--------------------------------------------------------------------------------
 1 | //===- agent.cpp - RL Agent/Model for ONNX Runner --------------*- C++ -*-===//
 2 | //
 3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
 4 | // Exceptions. See the LICENSE file for license information.
 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 6 | //
 7 | //===----------------------------------------------------------------------===//
 8 | ///
 9 | /// \file
10 | /// This file defines the Agent class, which is a wrapper around the ONNXModel
11 | /// class.
12 | ///
13 | //===----------------------------------------------------------------------===//
14 |
15 | #include "MLModelRunner/ONNXModelRunner/agent.h"
16 | #include "MLModelRunner/Utils/Debug.h"
17 | #include <algorithm>
18 | #include <cassert>
19 | #include <iostream>
20 |
21 | namespace MLBridge {
22 |
23 | Agent::Agent(std::string modelPath) {
24 |   // Create model object here
25 |   this->model = new ONNXModel(modelPath.c_str());
26 | }
27 |
28 | unsigned Agent::computeAction(Observation &input) {
29 |   // Call model on input
30 |   assert(input.size() > 0);
31 |   // inline capacity reconstructed; any capacity works here
32 |   llvm::SmallVector<float, 100> model_input(input.begin(), input.end());
33 |   llvm::SmallVector<float, 100> model_output;
34 |
35 |   this->model->run(model_input, model_output);
36 |
37 |   // select action from model output
38 |   auto max = std::max_element(model_output.begin(),
39 |                               model_output.end()); // [2, 4)
40 |   int argmaxVal = std::distance(model_output.begin(), max);
41 |
42 |   MLBRIDGE_DEBUG(
43 |       std::cout << "---------------MODEL OUTPUT VECTOR:----------------\n";
44 |       for (auto e
45 |            : model_output) { std::cout << e << " "; } std::cout
46 |           << "\nmax value and index are " << *max << " " << argmaxVal << "\n";);
47 |   return argmaxVal;
48 | }
49 |
50 | } // namespace MLBridge
51 |
--------------------------------------------------------------------------------
/MLModelRunner/ONNXModelRunner/onnx.cpp:
--------------------------------------------------------------------------------
 1 | //===- onnx.cpp - ONNX Interface with CPP Runtime --------------*- C++ -*-===//
 2 | //
 3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
 4 | // Exceptions.
See the LICENSE file for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // 7 | //===----------------------------------------------------------------------===// 8 | /// 9 | /// \file 10 | /// This file defines the ONNXModel class, which is a wrapper around the ONNX 11 | /// C++ interface. 12 | /// 13 | //===----------------------------------------------------------------------===// 14 | 15 | #include "MLModelRunner/ONNXModelRunner/onnx.h" 16 | #include "onnxruntime_cxx_api.h" 17 | 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | ONNXModel::ONNXModel(const char *model_path) : model_path(model_path) { 25 | env = new Ort::Env(ORT_LOGGING_LEVEL_WARNING, "test"); 26 | session = new Ort::Session(*env, model_path, Ort::SessionOptions{nullptr}); 27 | } 28 | 29 | Ort::Value ONNXModel::getInputValue(llvm::SmallVector &input, 30 | int inputIdx) { 31 | auto typeInfo = session->GetInputTypeInfo(inputIdx); 32 | auto tensorInfo = typeInfo.GetTensorTypeAndShapeInfo(); 33 | auto inputDims = tensorInfo.GetShape(); 34 | std::replace_if( 35 | inputDims.begin(), inputDims.end(), [](int64_t &i) { return i < 0; }, 1); 36 | 37 | size_t inputTensorSize = std::accumulate(inputDims.begin(), inputDims.end(), 38 | 1, std::multiplies()); 39 | assert(inputTensorSize == input.size()); 40 | 41 | auto memory_info = 42 | Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); 43 | auto inputTmp = Ort::Value::CreateTensor( 44 | memory_info, input.data(), inputTensorSize, inputDims.data(), 45 | inputDims.size()); 46 | auto inputTensor = &inputTmp; 47 | assert(inputTensor->IsTensor()); 48 | return inputTmp; 49 | } 50 | 51 | void ONNXModel::run(llvm::SmallVector &input, 52 | llvm::SmallVector &output) { 53 | Ort::AllocatorWithDefaultOptions allocator; 54 | 55 | int input_count = session->GetInputCount(); 56 | llvm::SmallVector, 10> inputList; 57 | inputList.push_back(input); 58 | llvm::SmallVector dummy_input; 59 | dummy_input.insert(dummy_input.end(), 1); 60 | for (int i = 1; i < input_count; i++) { 61 | inputList.push_back(dummy_input); 62 | } 63 | 64 | llvm::SmallVector inputNameList; 65 | for (int i = 0; i < input_count; i++) { 66 | auto inputName = session->GetInputNameAllocated(i, allocator); 67 | auto inputNameStr = inputName.get(); 68 | inputNameList.push_back(inputNameStr); 69 | } 70 | 71 | std::vector input_final; 72 | llvm::SmallVector inputNameStr_final; 73 | 74 | for (int i = 0; i < input_count; i++) { 75 | input_final.insert(input_final.end(), getInputValue(inputList[i], i)); 76 | inputNameStr_final.insert(inputNameStr_final.end(), 77 | inputNameList[i].c_str()); 78 | } 79 | 80 | auto outputName = session->GetOutputNameAllocated(0, allocator); 81 | auto outputNameStr = outputName.get(); 82 | 83 | auto outputTensors = 84 | session->Run(Ort::RunOptions{nullptr}, inputNameStr_final.data(), 85 | input_final.data(), input_count, &outputNameStr, 1); 86 | 87 | assert(outputTensors.size() == 1 && outputTensors.front().IsTensor()); 88 | 89 | auto outputDims = 90 | outputTensors.front().GetTensorTypeAndShapeInfo().GetShape()[1]; 91 | 92 | auto outVal = outputTensors.front().GetTensorMutableData(); 93 | 94 | output = llvm::SmallVector(outVal, outVal + outputDims); 95 | std::replace_if( 96 | output.begin(), output.end(), [](double x) { return std::isnan(x); }, 97 | -1.17549e+038); 98 | } 99 | -------------------------------------------------------------------------------- /MLModelRunner/PipeModelRunner.cpp: 
--------------------------------------------------------------------------------
1 | //===- PipeModelRunner.cpp - Pipe based Model Runner ------------*- C++ -*-===//
2 | //
3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4 | // Exceptions. See the LICENSE file for license information.
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | //
7 | // (Preliminary version adopted from InteractiveModelRunner.cpp of LLVM 17.X)
8 | //
9 | //===----------------------------------------------------------------------===//
10 | ///
11 | /// \file
12 | /// This file defines the PipeModelRunner class to interface with an external ML
13 | /// model during training and inference. The model is assumed to be running as
14 | /// an external process and the communication is done via 2 file descriptors
15 | /// using pipes.
16 | ///
17 | //===----------------------------------------------------------------------===//
18 | 
19 | #include "MLModelRunner/PipeModelRunner.h"
20 | #include "MLModelRunner/MLModelRunner.h"
21 | #include "MLModelRunner/Utils/Debug.h"
22 | #include <chrono>
23 | #include <cstring>
24 | #include <iostream>
25 | #include <memory>
26 | #include <thread>
27 | 
28 | #define DEBUG_TYPE "pipe-model-runner"
29 | 
30 | using namespace llvm;
31 | 
32 | namespace MLBridge {
33 | PipeModelRunner::PipeModelRunner(StringRef OutboundName, StringRef InboundName,
34 |                                  SerDesKind SerDesType, LLVMContext *Ctx)
35 |     : MLModelRunner(Kind::Pipe, SerDesType, Ctx),
36 |       InEC(sys::fs::openFileForRead(InboundName, Inbound)) {
37 |   this->InboundName = InboundName.str();
38 |   if (InEC) {
39 |     int max_retries = 30, attempts = 0;
40 |     double wait_seconds = 0.2, backoff_exp = 1.2;
41 | 
42 |     while (attempts < max_retries) {
43 |       InEC = sys::fs::openFileForRead(InboundName, Inbound);
44 |       if (InEC) {
45 |         attempts++;
46 |         std::cout << "Cannot open inbound file, retrying! attempt: " << attempts
47 |                   << std::endl;
48 |         std::this_thread::sleep_for(
49 |             std::chrono::duration<double>(wait_seconds));
50 |         wait_seconds *= backoff_exp;
51 |       } else {
52 |         break;
53 |       }
54 |     }
55 |     if (InEC) {
56 |       auto message = "Cannot open inbound file: " + InEC.message();
57 |       if (this->Ctx)
58 |         this->Ctx->emitError(message);
59 |       else
60 |         std::cerr << message << std::endl;
61 |       return;
62 |     }
63 |   }
64 |   {
65 |     OutStream = std::make_unique<raw_fd_ostream>(OutboundName, OutEC);
66 |     if (OutEC) {
67 |       auto message = "Cannot open outbound file: " + OutEC.message();
68 |       if (this->Ctx)
69 |         this->Ctx->emitError(message);
70 |       else
71 |         std::cerr << message << std::endl;
72 |       return;
73 |     }
74 |   }
75 | }
76 | 
77 | PipeModelRunner::~PipeModelRunner() {
78 |   // close the file descriptors
79 |   sys::fs::file_t FDAsOSHandle = sys::fs::convertFDToNativeFile(Inbound);
80 |   sys::fs::closeFile(FDAsOSHandle);
81 | 
82 |   OutStream->close();
83 | }
84 | 
85 | std::string PipeModelRunner::readNBytes(size_t N) {
86 |   std::string OutputBuffer(N, '\0');
87 |   char *Buff = OutputBuffer.data();
88 |   size_t InsPoint = 0;
89 |   const size_t Limit = N;
90 |   while (InsPoint < Limit) {
91 |     auto ReadOrErr = ::sys::fs::readNativeFile(
92 |         sys::fs::convertFDToNativeFile(Inbound),
93 |         {Buff + InsPoint, OutputBuffer.size() - InsPoint});
94 |     if (ReadOrErr.takeError()) {
95 |       if (this->Ctx)
96 |         this->Ctx->emitError("Failed reading from inbound file");
97 |       else
98 |         std::cerr << "Failed reading from inbound file" << std::endl;
99 |       break;
100 |     }
101 |     InsPoint += *ReadOrErr;
102 |   }
103 |   return OutputBuffer;
104 | }
105 | 
106 | void PipeModelRunner::send(void *data) {
107 |   // TODO: send data size first (a hack to send protobuf data properly)
108 |   auto dataString = reinterpret_cast<std::string *>(data);
109 |   size_t message_length = dataString->size();
110 |   const char *message_length_ptr =
111 |       reinterpret_cast<const char *>(&message_length);
112 |   MLBRIDGE_DEBUG(std::cout << "Message length: " << message_length << "\n");
113 |   MLBRIDGE_DEBUG(std::cout << "DataString.size(): " << dataString->size()
114 |                            << "\n");
115 |   OutStream->write(message_length_ptr, sizeof(size_t));
116 |   OutStream->write(dataString->data(), dataString->size());
117 |   OutStream->flush();
118 | }
119 | 
120 | void *PipeModelRunner::receive() {
121 |   MLBRIDGE_DEBUG(std::cout << "In PipeModelRunner receive...\n");
122 |   auto hdr = readNBytes(8);
123 |   MLBRIDGE_DEBUG(std::cout << "Read header...\n");
124 |   size_t MessageLength = 0;
125 |   memcpy(&MessageLength, hdr.data(), sizeof(MessageLength));
126 |   // Read message
127 |   auto OutputBuffer = new std::string(readNBytes(MessageLength));
128 |   MLBRIDGE_DEBUG(std::cout << "OutputBuffer size: " << OutputBuffer->size()
129 |                            << "\n";
130 |                  std::cout << "OutputBuffer: " << *OutputBuffer << "\n");
131 |   return OutputBuffer;
132 | }
133 | 
134 | void *PipeModelRunner::evaluateUntyped() {
135 |   MLBRIDGE_DEBUG(std::cout << "In PipeModelRunner evaluateUntyped...\n");
136 |   auto *data = SerDes->getSerializedData();
137 |   send(data);
138 |   auto *reply = receive();
139 |   MLBRIDGE_DEBUG(
140 |       std::cout << "In PipeModelRunner::evaluateUntyped() received data...\n");
141 |   return SerDes->deserializeUntyped(reply);
142 | }
143 | 
144 | } // namespace MLBridge
145 | 
--------------------------------------------------------------------------------
/MLModelRunner/Utils/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | if(LLVM_MLBRIDGE)
2 |   add_llvm_component_library(ModelRunnerUtils
3 |     MLConfig.cpp
4 |   )
5 |   target_include_directories(ModelRunnerUtils PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
6 | 
7 | else()
8 |   add_library(ModelRunnerCUtils OBJECT MLConfig.cpp)
9 |   add_library(ModelRunnerUtils OBJECT MLConfig.cpp)
10 | endif(LLVM_MLBRIDGE)
11 | 
--------------------------------------------------------------------------------
/MLModelRunner/Utils/MLConfig.cpp:
--------------------------------------------------------------------------------
1 | //===- MLConfig.cpp - Set ML Config Paths -----------------------*- C++ -*-===//
2 | //
3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4 | // Exceptions. See the LICENSE file for license information.
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | //
7 | //===----------------------------------------------------------------------===//
8 | 
9 | #include "MLModelRunner/Utils/MLConfig.h"
10 | 
--------------------------------------------------------------------------------
/MLModelRunner/gRPCModelRunner/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | set(protobuf_MODULE_COMPATIBLE TRUE)
2 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
3 | find_package(Protobuf CONFIG REQUIRED)
4 | set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf)
5 | set(_REFLECTION gRPC::grpc++_reflection)
6 | if(CMAKE_CROSSCOMPILING)
7 |   find_program(_PROTOBUF_PROTOC protoc)
8 | else()
9 |   set(_PROTOBUF_PROTOC $<TARGET_FILE:protobuf::protoc>)
10 | endif()
11 | 
12 | # Find gRPC installation
13 | # Looks for gRPCConfig.cmake file installed by gRPC's cmake installation.
14 | 
15 | find_package(gRPC 1.58.0 EXACT CONFIG REQUIRED)
16 | message(STATUS "Using gRPC ${gRPC_VERSION}")
17 | 
18 | set(_GRPC_GRPCPP gRPC::grpc++)
19 | if(CMAKE_CROSSCOMPILING)
20 |   find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)
21 |   find_program(_GRPC_PYTHON_PLUGIN_EXECUTABLE grpc_python_plugin)
22 | else()
23 |   set(_GRPC_CPP_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_cpp_plugin>)
24 |   set(_GRPC_PYTHON_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_python_plugin>)
25 | endif()
26 | 
27 | file(GLOB proto_list ${PROTOS_DIRECTORY}/*.proto)
28 | set(proto_dir ${PROTOS_DIRECTORY})
29 | file(MAKE_DIRECTORY ${CMAKE_BINARY_DIR}/include/grpc)
30 | 
31 | 
32 | foreach(proto ${proto_list})
33 |   get_filename_component(proto_name ${proto} NAME_WLE)
34 |   file(GLOB cc_file ${CMAKE_CURRENT_SOURCE_DIR}/Service/${proto_name}/*.cc)
35 |   set(cc_files ${cc_files} ${cc_file})
36 |   file(MAKE_DIRECTORY ${CMAKE_BINARY_DIR}/include/grpc/${proto_name})
37 |   set(header_files ${header_files} "${CMAKE_BINARY_DIR}/include/grpc/${proto_name}")
38 |   set(proto_srcs_list ${proto_srcs_list} "${CMAKE_BINARY_DIR}/include/grpc/${proto_name}/${proto_name}.pb.cc")
39 |   set(grpc_srcs_list ${grpc_srcs_list} "${CMAKE_BINARY_DIR}/include/grpc/${proto_name}/${proto_name}.grpc.pb.cc")
40 |   add_custom_command(
41 |     OUTPUT "${CMAKE_BINARY_DIR}/include/grpc/${proto_name}/${proto_name}.pb.cc"
42 |            "${CMAKE_BINARY_DIR}/include/grpc/${proto_name}/${proto_name}.pb.h"
43 |            "${CMAKE_BINARY_DIR}/include/grpc/${proto_name}/${proto_name}.grpc.pb.cc"
44 |            "${CMAKE_BINARY_DIR}/include/grpc/${proto_name}/${proto_name}.grpc.pb.h"
45 |     COMMAND ${_PROTOBUF_PROTOC}
46 |     ARGS --grpc_out "${CMAKE_BINARY_DIR}/include/grpc/${proto_name}"
47 |       --cpp_out "${CMAKE_BINARY_DIR}/include/grpc/${proto_name}"
48 |       -I "${proto_dir}"
49 |       --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}"
50 |       "${proto}"
51 |     DEPENDS "${proto}"
52 |   )
53 | endforeach()
54 | 
55 | set( PYTHON_UTILITIES_DIRECTORY "" CACHE PATH "Path to the directory containing the python utilities")
56 | 
57 | if(PYTHON_UTILITIES_DIRECTORY STREQUAL "")
58 |   set(PYTHON_UTILITIES_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/Python-Utilities)
59 | endif()
60 | 
61 | file(MAKE_DIRECTORY ${PYTHON_UTILITIES_DIRECTORY})
62 | foreach(proto ${proto_list})
63 |   get_filename_component(proto_name ${proto} NAME_WLE)
64 |   set(proto_python_srcs_list ${proto_python_srcs_list} "${PYTHON_UTILITIES_DIRECTORY}/${proto_name}_pb2.py")
65 | 
66 |   add_custom_command(
67 |     OUTPUT "${PYTHON_UTILITIES_DIRECTORY}/${proto_name}_pb2.py"
68 |     COMMAND ${_PROTOBUF_PROTOC}
69 |     ARGS --grpc_out "${PYTHON_UTILITIES_DIRECTORY}/"
70 |       --python_out "${PYTHON_UTILITIES_DIRECTORY}/"
71 |       -I "${proto_dir}"
72 |       --plugin=protoc-gen-grpc="${_GRPC_PYTHON_PLUGIN_EXECUTABLE}"
73 |       "${proto}"
74 |     DEPENDS "${proto}"
75 |   )
76 | endforeach()
77 | 
78 | # Building the library
79 | if(LLVM_MLBRIDGE)
80 |   add_llvm_component_library(gRPCModelRunnerLib
81 |     ${cc_files}
82 |     ${proto_srcs_list}
83 |     ${grpc_srcs_list}
84 |     ${proto_python_srcs_list}
85 |   )
86 | else()
87 |   add_library(gRPCModelRunnerLib
88 |     ${cc_files}
89 |     ${proto_srcs_list}
90 |     ${grpc_srcs_list}
91 |     ${proto_python_srcs_list}
92 |   )
93 | endif(LLVM_MLBRIDGE)
94 | 
95 | target_link_libraries(gRPCModelRunnerLib
96 |   PRIVATE ${_REFLECTION}
97 |   ${_GRPC_GRPCPP}
98 |   ${_PROTOBUF_LIBPROTOBUF}
99 | )
100 | 
101 | target_include_directories(gRPCModelRunnerLib
102 |   PRIVATE ${CMAKE_BINARY_DIR}/include/Service
103 |   PUBLIC ${Protobuf_INCLUDE_DIRS}
104 | )
105 | 
--------------------------------------------------------------------------------
/MLModelRunner/gRPCModelRunner/README.md:
--------------------------------------------------------------------------------
1 | # LLVM-GRPC: Register Allocation
2 | A gRPC framework for communication between LLVM and a Python ML workload for optimized register allocation.
3 | 
4 | ## Prerequisites
5 | * Building gRPC from source: please follow [`Build GRPC with cmake`](https://grpc.io/docs/languages/cpp/quickstart/) v1.34 (protobuf v3.13) to build gRPC from source.
6 |   * In the above tutorial, setting `-DCMAKE_INSTALL_PREFIX` may not be necessary and the default install prefix can be used.
7 | * A proper build of LLVM from source, along with Clang.
8 | * The following dependencies will be required for Python:
9 | `pip install grpcio-tools`
10 | 
11 | ## Building LLVM-GRPC
12 | * `mkdir build && cd build`
13 | * `cmake -DLLVM_BUILD_DIR= ../`
14 | * `Ex: cmake -DLLVM_BUILD_DIR=/home/cs20mtech01002/llvm-project/build ../`
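
Once `grpcio-tools` is installed, the Python stubs for a service can also be generated by hand, outside of the CMake flow above. An illustrative invocation for the test service shipped in this repository (the output directory is an assumption): `python -m grpc_tools.protoc -I test/protos --python_out=. --grpc_python_out=. test/protos/helloMLBridge.proto`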
15 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ML-Compiler-Bridge
2 | 
3 | `ML-Compiler-Bridge` is a compiler-agnostic library to aid in ML-enabled compiler optimizations. ML-Compiler-Bridge supports both training and inference scenarios. The library exposes Python and C/C++ APIs to interface
4 | Python-based ML models with a C/C++ compiler. This design allows ML model development within a traditional Python framework while making end-to-end integration with an optimizing compiler possible and efficient.
5 | 
6 | This repo contains the source code and relevant information described in our paper, ["The Next 700 ML-Enabled Compiler Optimizations"](https://doi.org/10.1145/3640537.3641580) ([arxiv](https://arxiv.org/abs/2311.10800)).
7 | Please see [here](https://iith-compilers.github.io/ML-Compiler-Bridge/) for documentation and other details.
8 | 
9 | > The Next 700 ML-Enabled Compiler Optimizations, S. VenkataKeerthy, Siddharth Jain, Umesh Kalvakuntla, Pranav Sai Gorantla, Rajiv Shailesh Chitale, Eugene Brevdo, Albert Cohen, Mircea Trofin and Ramakrishna Upadrasta. CC 2024.
10 | 
11 | [![Build and Tests](https://github.com/IITH-Compilers/MLCompilerBridge/actions/workflows/build.yml/badge.svg)](https://github.com/IITH-Compilers/MLCompilerBridge/actions/workflows/build.yml)
12 | [![Doxygen Action](https://github.com/IITH-Compilers/MLCompilerBridge/actions/workflows/main.yml/badge.svg)](https://github.com/IITH-Compilers/MLCompilerBridge/actions/workflows/main.yml)
13 | [![pre-commit checks](https://github.com/IITH-Compilers/MLCompilerBridge/actions/workflows/formatting.yml/badge.svg)](https://github.com/IITH-Compilers/MLCompilerBridge/actions/workflows/formatting.yml)
14 | [![Upload to Pypi](https://github.com/IITH-Compilers/MLCompilerBridge/actions/workflows/upload_pypi.yml/badge.svg)](https://github.com/IITH-Compilers/MLCompilerBridge/actions/workflows/upload_pypi.yml)
15 | 
16 | 
17 | ![Image](https://github.com/IITH-Compilers/ML-Compiler-Bridge/raw/main/images/component-ml-compiler-bridge.png)
18 | 
19 | ## Features
20 | * **Unified Framework:** Comes with a suite of two inter-process and two in-process model runners and three serialization-deserialization mechanisms to support interleaved and non-interleaved communication between models and the compiler.
21 | * **Multi-language Support:** Exposes C++ and C APIs to interface model runners and serializers with compilers, and Python APIs to interface inter-process model runners with ML models.
22 | * **Compiler and ML-Framework Independence:** Provides compiler- and ML-framework-independent APIs, and supports easier integration with compilers like LLVM, MLIR, and Pluto, and ML frameworks like TensorFlow, PyTorch, JAX, etc.
23 | * **Deeper Integration:** Enables deeper integration of ML models within the compiler in a framework-independent manner to support easier inference for ML-driven compiler optimizations.
24 | 
25 | ## Requirements
26 | * cmake (>= 3.10)
27 | * GNU Make (4.2.1)
28 | * LLVM (10.X) - [src](https://github.com/llvm/llvm-project/tree/release/10.x), [release](https://releases.llvm.org/download.html#10.0.1)
29 | * Python (3.10), C++17
30 | * gRPC v1.58 and protobuf v23.4 - for gRPC Model Runner
31 |   * Building gRPC from source: please follow [`Build GRPC with cmake`](https://grpc.io/docs/languages/cpp/quickstart/) v1.58 (protobuf v23.4) to build gRPC from source.
32 |   * In the above tutorial, setting `-DCMAKE_INSTALL_PREFIX` may not be necessary and the default install prefix can be used.
33 |   * The following dependencies will be required for Python: `pip install grpcio-tools`.
34 | * [ONNXRuntime](https://github.com/microsoft/onnxruntime/releases) v1.13.1
35 | * TensorFlow - for TF Model Runner (AOT flow)
36 |   * Tested with TensorFlow 2.13.0
37 | * Other python requirements are available in [mlbridge.yml](./mlbridge.yml)
38 | * Conda/Anaconda based virtual environment is assumed
39 | 
40 | (Experiments are done on an Ubuntu 20.04 machine)
41 | 
42 | ## Setup
43 | `ML-Compiler-Bridge` can be built as a stand-alone library to generate `.a` files that can in turn be linked with any compiler.
44 | 1. `mkdir build && cd build`
45 | 2. `cmake [-DCMAKE_BUILD_TYPE=Release|Debug] [-DCMAKE_INSTALL_PREFIX=] [-DMLBRIDGE_ENABLE_TEST=ON|OFF] -DONNXRUNTIME_ROOTDIR= -DPROTOS_DIRECTORY= -DTENSORFLOW_AOT_PATH= ../` (see the filled-in example below)
46 | 3. `make -j [&& make install]`
47 | 4. `pip install compilerinterface`
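
As a concrete illustration of step 2, a Release configuration might look like the following; the paths are examples only (the ONNXRuntime path mirrors the one fetched by the CI build script): `cmake -DCMAKE_BUILD_TYPE=Release -DMLBRIDGE_ENABLE_TEST=ON -DONNXRUNTIME_ROOTDIR=$HOME/onnxruntime-linux-x64-1.16.3 -DPROTOS_DIRECTORY=../test/protos ../`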
48 | 
49 | This process would generate `libMLCompilerBridge.a` and `libMLCompilerBridgeC.a` libraries under the `build/lib` directory, and the required headers under the `build/include` directory. `libMLCompilerBridgeC.a` exposes C APIs for use with C-based compilers like Pluto, whereas `libMLCompilerBridge.a` exposes C++ APIs that can be used with any compiler written in C++.
50 | 
51 | Python end points are available under [`CompilerInterface`](./CompilerInterface/). They can be downloaded as a [`package`](https://pypi.org/project/compilerinterface/) from pypi.
52 | 
53 | To ensure correctness, run `make verify-all`. This requires enabling tests in cmake (`-DMLBRIDGE_ENABLE_TEST=ON`), with `PROTOS_DIRECTORY` pointing to `test/protos`.
54 | 
55 | ### Using ML-Compiler-Bridge with LLVM and MLIR
56 | 
57 | `ML-Compiler-Bridge` can be integrated and built along with the LLVM project. This can be done by adding this repository as a new project and setting the `LLVM_MLBRIDGE` option to `ON`.
58 | 
59 | You can check the [CMakeLists.txt](https://github.com/IITH-Compilers/ml-llvm-project/blob/mlbridge-lib/llvm/CMakeLists.txt) of the [`ml-llvm-project`](https://github.com/IITH-Compilers/ml-llvm-project/tree/mlbridge-lib) repository, which demonstrates such an integration.
60 | 
61 | The passes that need to make use of this library can then simply link with `LLVMMLBridge`.
62 | 
63 | An example `CMakeLists.txt` of an LLVM pass that uses the library is shown below.
64 | 
65 | ```CMakeLists.txt
66 | add_llvm_component_library(LLVMMLPass
67 |   pass.cpp
68 |   ml.cpp
69 | 
70 |   ADDITIONAL_HEADER_DIRS
71 |   ${CMAKE_CURRENT_SOURCE_DIR}/includes
72 | 
73 |   DEPENDS
74 |   LLVMMLBridge
75 |   intrinsics_gen
76 | )
77 | target_link_libraries(LLVMMLPass PRIVATE LLVMMLBridge)
78 | 
79 | ```
80 | To use the TensorFlow AOT Model Runner, invoke the `tf_find_and_compile` function exposed in [`cmake/modules/TensorFlowCompile.cmake`](cmake/modules/TensorFlowCompile.cmake) from the `CMakeLists.txt` of your pass with appropriate arguments. An example of integrating a TF AOT model with the inlining pass is shown [here](https://github.com/IITH-Compilers/ml-llvm-project/blob/tfmodel/llvm/lib/Analysis/CMakeLists.txt).
81 | 
82 | ## Artifacts
83 | Libraries are autogenerated for every relevant check-in via GitHub Actions. The generated artifacts are tagged along with the successful runs of the [`Publish`]() action.
84 | 
85 | ## Citation
86 | ```
87 | @inproceedings{venkatakeerthy-2024-MLCompilerBridge,
88 |   author = {VenkataKeerthy, S. and Jain, Siddharth and Kalvakuntla, Umesh and Gorantla, Pranav Sai and Chitale, Rajiv Shailesh and Brevdo, Eugene and Cohen, Albert and Trofin, Mircea and Upadrasta, Ramakrishna},
89 |   title = {The Next 700 ML-Enabled Compiler Optimizations},
90 |   year = {2024},
91 |   isbn = {9798400705076},
92 |   publisher = {Association for Computing Machinery},
93 |   address = {New York, NY, USA},
94 |   url = {https://doi.org/10.1145/3640537.3641580},
95 |   doi = {10.1145/3640537.3641580},
96 |   booktitle = {Proceedings of the 33rd ACM SIGPLAN International Conference on Compiler Construction},
97 |   pages = {238–249},
98 |   numpages = {12},
99 |   keywords = {Machine Learning for Compiler Optimizations, ONNX, Pipes, TensorFlow AOT, gRPC},
100 |   location = {Edinburgh, United Kingdom},
101 |   series = {CC 2024}
102 | }
103 | ```
104 | 
105 | ## Contributions
106 | Please feel free to raise issues to file a bug, pose a question, or initiate any related discussions.
Pull requests are welcome :) 107 | 108 | ## License 109 | ML-Compiler-Bridge is released under Apache 2.0 license with LLVM Exceptions. See LICENSE file for more details. 110 | -------------------------------------------------------------------------------- /SerDes/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(protobuf_MODULE_COMPATIBLE TRUE) 2 | find_package(Protobuf CONFIG REQUIRED) 3 | 4 | if(LLVM_MLBRIDGE) 5 | add_llvm_library(SerDesLib 6 | TensorSpec.cpp 7 | jsonSerDes.cpp 8 | bitstreamSerDes.cpp 9 | protobufSerDes.cpp 10 | tensorflowSerDes.cpp 11 | ) 12 | target_compile_definitions(SerDesLib PRIVATE LLVM_MLBRIDGE) 13 | else() 14 | add_library(SerDesLib OBJECT TensorSpec.cpp jsonSerDes.cpp bitstreamSerDes.cpp protobufSerDes.cpp tensorflowSerDes.cpp) 15 | 16 | add_library(SerDesCLib OBJECT TensorSpec.cpp jsonSerDes.cpp bitstreamSerDes.cpp) 17 | endif() 18 | target_include_directories(SerDesLib PUBLIC ${TENSORFLOW_AOT_PATH}/include) 19 | target_link_libraries(SerDesLib PRIVATE tf_xla_runtime) 20 | -------------------------------------------------------------------------------- /SerDes/TensorSpec.cpp: -------------------------------------------------------------------------------- 1 | //===- TensorSpec.cpp - tensor type abstraction ---------------------------===// 2 | // 3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM 4 | // Exceptions. See the LICENSE file for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // 7 | // (Preliminary version adopted from TensorSpec.cpp of LLVM 12.X) 8 | // 9 | //===----------------------------------------------------------------------===// 10 | /// 11 | /// \file 12 | /// Implementation file for the abstraction of a tensor type, and JSON loading 13 | /// utils. 
14 | /// 15 | //===----------------------------------------------------------------------===// 16 | 17 | #include "SerDes/TensorSpec.h" 18 | #include "llvm/ADT/None.h" 19 | #include "llvm/ADT/StringExtras.h" 20 | #include "llvm/Support/Debug.h" 21 | #include "llvm/Support/JSON.h" 22 | 23 | #include 24 | #include 25 | #include 26 | 27 | using namespace llvm; 28 | 29 | namespace MLBridge { 30 | 31 | #define TFUTILS_GETDATATYPE_IMPL(T, E) \ 32 | template <> TensorType TensorSpec::getDataType() { return TensorType::E; } 33 | 34 | SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL) 35 | 36 | #undef TFUTILS_GETDATATYPE_IMPL 37 | 38 | static std::array(TensorType::Total)> 39 | TensorTypeNames{"INVALID", 40 | #define TFUTILS_GETNAME_IMPL(T, _) #T, 41 | SUPPORTED_TENSOR_TYPES(TFUTILS_GETNAME_IMPL) 42 | #undef TFUTILS_GETNAME_IMPL 43 | }; 44 | 45 | StringRef toString(TensorType TT) { 46 | return TensorTypeNames[static_cast(TT)]; 47 | } 48 | 49 | void TensorSpec::toJSON(json::OStream &OS) const { 50 | OS.object([&]() { 51 | OS.attribute("name", name()); 52 | OS.attribute("type", toString(type())); 53 | OS.attribute("port", port()); 54 | OS.attributeArray("shape", [&]() { 55 | for (size_t D : shape()) 56 | OS.value(static_cast(D)); 57 | }); 58 | }); 59 | } 60 | 61 | TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type, 62 | size_t ElementSize, const std::vector &Shape) 63 | : Name(Name), Port(Port), Type(Type), Shape(Shape), 64 | ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1, 65 | std::multiplies())), 66 | ElementSize(ElementSize) {} 67 | 68 | llvm::Optional getTensorSpecFromJSON(LLVMContext &Ctx, 69 | const json::Value &Value) { 70 | auto EmitError = 71 | [&](const llvm::Twine &Message) -> llvm::Optional { 72 | std::string S; 73 | llvm::raw_string_ostream OS(S); 74 | OS << Value; 75 | Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S); 76 | return None; 77 | }; 78 | // FIXME: accept a Path as a parameter, and use it for error reporting. 
79 | #ifdef LLVM_MLBRIDGE 80 | json::Path::Root Root("tensor_spec"); 81 | json::ObjectMapper Mapper(Value, Root); 82 | #else 83 | json::ObjectMapper Mapper(Value); 84 | #endif 85 | if (!Mapper) 86 | return EmitError("Value is not a dict"); 87 | 88 | std::string TensorName; 89 | int TensorPort = -1; 90 | std::string TensorType; 91 | std::vector TensorShape; 92 | 93 | if (!Mapper.map("name", TensorName)) 94 | return EmitError("'name' property not present or not a string"); 95 | if (!Mapper.map("type", TensorType)) 96 | return EmitError("'type' property not present or not a string"); 97 | if (!Mapper.map("port", TensorPort)) 98 | return EmitError("'port' property not present or not an int"); 99 | if (!Mapper.map>("shape", TensorShape)) 100 | return EmitError("'shape' property not present or not an int array"); 101 | 102 | #define PARSE_TYPE(T, E) \ 103 | if (TensorType == #T) \ 104 | return TensorSpec::createSpec(TensorName, TensorShape, TensorPort); 105 | SUPPORTED_TENSOR_TYPES(PARSE_TYPE) 106 | #undef PARSE_TYPE 107 | return None; 108 | } 109 | 110 | std::string tensorValueToString(const char *Buffer, const TensorSpec &Spec) { 111 | switch (Spec.type()) { 112 | #define _IMR_DBG_PRINTER(T, N) \ 113 | case TensorType::N: { \ 114 | const T *TypedBuff = reinterpret_cast(Buffer); \ 115 | auto R = llvm::make_range(TypedBuff, TypedBuff + Spec.getElementCount()); \ 116 | return llvm::join( \ 117 | llvm::map_range(R, [](T V) { return std::to_string(V); }), ","); \ 118 | } 119 | SUPPORTED_TENSOR_TYPES(_IMR_DBG_PRINTER) 120 | #undef _IMR_DBG_PRINTER 121 | case TensorType::Total: 122 | case TensorType::Invalid: 123 | llvm_unreachable("invalid tensor type"); 124 | } 125 | // To appease warnings about not all control paths returning a value. 126 | return ""; 127 | } 128 | 129 | } // namespace MLBridge 130 | -------------------------------------------------------------------------------- /SerDes/bitstreamSerDes.cpp: -------------------------------------------------------------------------------- 1 | //===- bitstreamSerDes.cpp - Serializer for Bitstream -----------*- C++ -*-===// 2 | // 3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM 4 | // Exceptions. See the LICENSE file for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // 7 | //===----------------------------------------------------------------------===// 8 | /// 9 | /// \file 10 | /// This file defines the BitstreamSerDes class, which is a serializer for 11 | /// Bitstream format. 
12 | /// 13 | //===----------------------------------------------------------------------===// 14 | 15 | #include "SerDes/bitstreamSerDes.h" 16 | #include "MLModelRunner/Utils/Debug.h" 17 | #include "llvm/Support/JSON.h" 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | 26 | #define DEBUG_TYPE "bitstream-serdes" 27 | 28 | namespace MLBridge { 29 | void BitstreamSerDes::setFeature(const std::string &name, const int value) { 30 | auto *valuePtr = new int(value); 31 | featuresint[name] = valuePtr; 32 | tensorSpecs.push_back(TensorSpec::createSpec(name, {1})); 33 | rawData.push_back(valuePtr); 34 | } 35 | 36 | void BitstreamSerDes::setFeature(const std::string &name, const long value) { 37 | auto *valuePtr = new long(value); 38 | featureslong[name] = valuePtr; 39 | tensorSpecs.push_back(TensorSpec::createSpec(name, {1})); 40 | rawData.push_back(valuePtr); 41 | } 42 | 43 | void BitstreamSerDes::setFeature(const std::string &name, const float value) { 44 | auto *valuePtr = new float(value); 45 | featuresfloat[name] = valuePtr; 46 | tensorSpecs.push_back(TensorSpec::createSpec(name, {1})); 47 | rawData.push_back(valuePtr); 48 | } 49 | 50 | void BitstreamSerDes::setFeature(const std::string &name, const double value) { 51 | auto *valuePtr = new double(value); 52 | featuresdouble[name] = valuePtr; 53 | tensorSpecs.push_back(TensorSpec::createSpec(name, {1})); 54 | rawData.push_back(valuePtr); 55 | } 56 | 57 | void BitstreamSerDes::setFeature(const std::string &name, 58 | const std::string value) { 59 | auto *valuePtr = new std::string(value); 60 | featuresstring[name] = valuePtr; 61 | long size = value.length(); 62 | tensorSpecs.push_back(TensorSpec::createSpec(name, {size})); 63 | rawData.push_back((void *)valuePtr->c_str()); 64 | } 65 | 66 | void BitstreamSerDes::setFeature(const std::string &name, const bool value) { 67 | auto *valuePtr = new bool(value); 68 | featuresbool[name] = valuePtr; 69 | tensorSpecs.push_back(TensorSpec::createSpec(name, {1})); 70 | rawData.push_back(valuePtr); 71 | } 72 | 73 | void BitstreamSerDes::setFeature(const std::string &name, 74 | const std::vector &value) { 75 | auto *valuePtr = new std::vector(value); 76 | featuresVectorint[name] = valuePtr; 77 | tensorSpecs.push_back( 78 | TensorSpec::createSpec(name, {static_cast(valuePtr->size())})); 79 | rawData.push_back(valuePtr->data()); 80 | } 81 | 82 | void BitstreamSerDes::setFeature(const std::string &name, 83 | const std::vector &value) { 84 | auto *valuePtr = new std::vector(value); 85 | featuresVectorlong[name] = valuePtr; 86 | tensorSpecs.push_back(TensorSpec::createSpec( 87 | name, {static_cast(valuePtr->size())})); 88 | rawData.push_back(valuePtr->data()); 89 | } 90 | 91 | void BitstreamSerDes::setFeature(const std::string &name, 92 | const std::vector &value) { 93 | auto *valuePtr = new std::vector(value); 94 | featuresVectorfloat[name] = valuePtr; 95 | tensorSpecs.push_back(TensorSpec::createSpec( 96 | name, {static_cast(valuePtr->size())})); 97 | rawData.push_back(valuePtr->data()); 98 | } 99 | 100 | void BitstreamSerDes::setFeature(const std::string &name, 101 | const std::vector &value) { 102 | auto *valuePtr = new std::vector(value); 103 | featuresVectordouble[name] = valuePtr; 104 | tensorSpecs.push_back(TensorSpec::createSpec( 105 | name, {static_cast(valuePtr->size())})); 106 | rawData.push_back(valuePtr->data()); 107 | } 108 | 109 | void BitstreamSerDes::setFeature(const std::string &name, 110 | const std::vector &value) { 111 | llvm_unreachable("Currently 
std::vector not supported"); 112 | } 113 | 114 | void BitstreamSerDes::setFeature(const std::string &name, 115 | const std::vector &value) { 116 | llvm_unreachable("Currently std::vector not supported"); 117 | } 118 | 119 | void *BitstreamSerDes::getSerializedData() { 120 | std::unique_ptr OS = 121 | std::make_unique(Buffer); 122 | llvm::json::OStream J(*OS); 123 | J.object([&]() { 124 | J.attributeArray("features", [&]() { 125 | for (size_t I = 0; I < tensorSpecs.size(); ++I) { 126 | tensorSpecs[I].toJSON(J); 127 | } 128 | }); 129 | }); 130 | J.flush(); 131 | OS->write("\n", 1); 132 | MLBRIDGE_DEBUG(std::cout << "rawData.size(): " << rawData.size() << "\n"); 133 | for (size_t I = 0; I < rawData.size(); ++I) { 134 | OS->write(reinterpret_cast(rawData[I]), 135 | tensorSpecs[I].getTotalTensorBufferSize()); 136 | } 137 | OS->write("\n", 1); 138 | OS->flush(); 139 | auto *out = new std::string(Buffer); 140 | cleanDataStructures(); 141 | return out; 142 | } 143 | 144 | void *BitstreamSerDes::deserializeUntyped(void *data) { 145 | // set the message length 146 | auto *res = reinterpret_cast(data); 147 | this->MessageLength = res->size(); 148 | return res->data(); 149 | } 150 | } // namespace MLBridge 151 | -------------------------------------------------------------------------------- /SerDes/jsonSerDes.cpp: -------------------------------------------------------------------------------- 1 | //===- jsonstreamSerDes.cpp - Serializer for JSON ---------------*- C++ -*-===// 2 | // 3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM 4 | // Exceptions. See the LICENSE file for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // 7 | //===----------------------------------------------------------------------===// 8 | /// 9 | /// \file 10 | /// This file defines the JsonSerDes class, which is a wrapper around the JSON 11 | /// C++ interface to serialize and deserialize data to and from JSON. 
12 | /// 13 | //===----------------------------------------------------------------------===// 14 | 15 | #include "SerDes/jsonSerDes.h" 16 | #include "MLModelRunner/Utils/DataTypes.h" 17 | #include "MLModelRunner/Utils/Debug.h" 18 | #include "SerDes/baseSerDes.h" 19 | #include "llvm/Support/Debug.h" 20 | #include "llvm/Support/JSON.h" 21 | #include 22 | #include 23 | 24 | #define DEBUG_TYPE "json-serdes" 25 | 26 | using namespace llvm; 27 | 28 | namespace MLBridge { 29 | void *JsonSerDes::getSerializedData() { 30 | auto tempJO = J; 31 | auto data = json::Value(std::move(tempJO)); 32 | auto *ret = new std::string(); 33 | llvm::raw_string_ostream OS(*ret); 34 | json::OStream(OS).value(data); 35 | cleanDataStructures(); 36 | return ret; 37 | } 38 | 39 | void *JsonSerDes::deserializeUntyped(void *data) { 40 | MLBRIDGE_DEBUG(std::cout << "In JsonSerDes deserializeUntyped...\n"); 41 | auto dataStr = static_cast(data); 42 | MLBRIDGE_DEBUG(std::cout << "dataStr: " << *dataStr << "\n"); 43 | Expected valueOrErr = json::parse(*dataStr); 44 | if (!valueOrErr) { 45 | auto *ret = new std::string(); 46 | llvm::raw_string_ostream SOS(*ret); 47 | SOS << "Error parsing JSON: " << valueOrErr.takeError() << "\n"; 48 | std::cerr << &ret << "\n"; 49 | exit(1); 50 | } 51 | json::Object *ret = valueOrErr->getAsObject(); 52 | auto val = ret->get("out"); 53 | MLBRIDGE_DEBUG(std::cout << "Got the final array...\n"; 54 | std::cout << "End JsonSerDes deserializeUntyped...\n"); 55 | return desJson(val); 56 | } 57 | 58 | void *JsonSerDes::desJson(json::Value *V) { 59 | switch (V->kind()) { 60 | case json::Value::Kind::Null: 61 | return nullptr; 62 | case json::Value::Kind::Number: { 63 | if (auto x = V->getAsInteger()) { 64 | IntegerType *ret = new IntegerType(); 65 | *ret = x.getValue(); 66 | this->MessageLength = sizeof(IntegerType); 67 | return ret; 68 | } else if (auto x = V->getAsNumber()) { 69 | RealType *ret = new RealType(); 70 | *ret = x.getValue(); 71 | this->MessageLength = sizeof(RealType); 72 | return ret; 73 | } else { 74 | std::cerr << "Error in desJson: Number is not int, or double\n"; 75 | exit(1); 76 | } 77 | } 78 | case json::Value::Kind::String: { 79 | std::string *ret = new std::string(); 80 | *ret = V->getAsString()->str(); 81 | this->MessageLength = ret->size() * sizeof(char); 82 | return ret->data(); 83 | } 84 | case json::Value::Kind::Boolean: { 85 | bool *ret = new bool(); 86 | *ret = V->getAsBoolean().getValue(); 87 | this->MessageLength = sizeof(bool); 88 | return ret; 89 | } 90 | case json::Value::Kind::Array: { 91 | // iterate over array and find its type 92 | // assume all elements are of same type and return vector of that type 93 | // if not, return nullptr 94 | auto arr = V->getAsArray(); 95 | 96 | auto it = arr->begin(); 97 | auto first = it; 98 | switch (first->kind()) { 99 | case json::Value::Kind::Number: { 100 | if (auto x = first->getAsInteger()) { 101 | std::vector *ret = new std::vector(); 102 | for (auto it : *arr) { 103 | ret->push_back(it.getAsInteger().getValue()); 104 | } 105 | this->MessageLength = ret->size() * sizeof(IntegerType); 106 | return ret->data(); 107 | } else if (auto x = first->getAsNumber()) { 108 | std::vector *ret = new std::vector(); 109 | for (auto it : *arr) { 110 | ret->push_back(it.getAsNumber().getValue()); 111 | } 112 | this->MessageLength = ret->size() * sizeof(RealType); 113 | return ret->data(); 114 | } else { 115 | std::cerr << "Error in desJson: Number is not int, or double\n"; 116 | exit(1); 117 | } 118 | } 119 | case 
json::Value::Kind::String: { 120 | std::vector *ret = new std::vector(); 121 | for (auto it : *arr) { 122 | ret->push_back(it.getAsString()->str()); 123 | } 124 | this->MessageLength = ret->size() * sizeof(std::string); 125 | return ret->data(); 126 | } 127 | case json::Value::Kind::Boolean: { 128 | std::vector *ret = new std::vector(); 129 | for (auto it : *arr) { 130 | ret->push_back(it.getAsBoolean().getValue()); 131 | } 132 | this->MessageLength = ret->size() * sizeof(uint8_t); 133 | return ret->data(); 134 | } 135 | default: { 136 | std::cerr << "Error in desJson: Array is not of supported type\n"; 137 | exit(1); 138 | } 139 | } 140 | } 141 | } 142 | return nullptr; 143 | } 144 | } // namespace MLBridge 145 | -------------------------------------------------------------------------------- /SerDes/protobufSerDes.cpp: -------------------------------------------------------------------------------- 1 | //===- protobufSerDes.cpp - Protobuf Serializer for gRPC -------*- C++ -*-===// 2 | // 3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM 4 | // Exceptions. See the LICENSE file for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // 7 | //===----------------------------------------------------------------------===// 8 | /// 9 | /// \file 10 | /// This file defines the ProtobufSerDes class, which is a wrapper around the 11 | /// protobuf C++ interface to support gRPC communication between the client and 12 | /// server. The protobuf C++ interface is used to serialize and deserialize 13 | /// messages. 14 | /// 15 | //===----------------------------------------------------------------------===// 16 | 17 | #include "SerDes/protobufSerDes.h" 18 | #include "google/protobuf/descriptor.h" 19 | #include "google/protobuf/message.h" 20 | 21 | #include 22 | #include 23 | #include 24 | #include 25 | 26 | namespace MLBridge { 27 | inline void ProtobufSerDes::setFeature(const std::string &name, 28 | const int value) { 29 | Request->GetReflection()->SetInt32( 30 | Request, Request->GetDescriptor()->FindFieldByName(name), value); 31 | } 32 | 33 | inline void ProtobufSerDes::setFeature(const std::string &name, 34 | const long value) { 35 | Request->GetReflection()->SetInt64( 36 | Request, Request->GetDescriptor()->FindFieldByName(name), value); 37 | } 38 | 39 | inline void ProtobufSerDes::setFeature(const std::string &name, 40 | const float value) { 41 | Request->GetReflection()->SetFloat( 42 | Request, Request->GetDescriptor()->FindFieldByName(name), value); 43 | } 44 | 45 | inline void ProtobufSerDes::setFeature(const std::string &name, 46 | const double value) { 47 | Request->GetReflection()->SetDouble( 48 | Request, Request->GetDescriptor()->FindFieldByName(name), value); 49 | } 50 | 51 | inline void ProtobufSerDes::setFeature(const std::string &name, 52 | const std::string value) { 53 | Request->GetReflection()->SetString( 54 | Request, Request->GetDescriptor()->FindFieldByName(name), value); 55 | } 56 | 57 | inline void ProtobufSerDes::setFeature(const std::string &name, 58 | const bool value) { 59 | Request->GetReflection()->SetBool( 60 | Request, Request->GetDescriptor()->FindFieldByName(name), value); 61 | } 62 | 63 | inline void ProtobufSerDes::setFeature(const std::string &name, 64 | const std::vector &value) { 65 | auto ref = Request->GetReflection()->MutableRepeatedField( 66 | Request, Request->GetDescriptor()->FindFieldByName(name)); 67 | ref->Add(value.begin(), value.end()); 68 | } 69 | 70 | inline void 
ProtobufSerDes::setFeature(const std::string &name, 71 | const std::vector &value) { 72 | auto ref = Request->GetReflection()->MutableRepeatedField( 73 | Request, Request->GetDescriptor()->FindFieldByName(name)); 74 | ref->Add(value.begin(), value.end()); 75 | } 76 | 77 | inline void ProtobufSerDes::setFeature(const std::string &name, 78 | const std::vector &value) { 79 | auto ref = Request->GetReflection()->MutableRepeatedField( 80 | Request, Request->GetDescriptor()->FindFieldByName(name)); 81 | ref->Add(value.begin(), value.end()); 82 | } 83 | 84 | inline void ProtobufSerDes::setFeature(const std::string &name, 85 | const std::vector &value) { 86 | auto ref = Request->GetReflection()->MutableRepeatedField( 87 | Request, Request->GetDescriptor()->FindFieldByName(name)); 88 | ref->Add(value.begin(), value.end()); 89 | } 90 | 91 | void ProtobufSerDes::setFeature(const std::string &name, 92 | const std::vector &value) { 93 | auto reflection = Request->GetReflection(); 94 | auto descriptor = Request->GetDescriptor(); 95 | auto field = descriptor->FindFieldByName(name); 96 | for (auto &v : value) { 97 | reflection->AddString(Request, field, v); 98 | } 99 | } 100 | 101 | inline void ProtobufSerDes::setFeature(const std::string &name, 102 | const std::vector &value) { 103 | auto ref = Request->GetReflection()->MutableRepeatedField( 104 | Request, Request->GetDescriptor()->FindFieldByName(name)); 105 | ref->Add(value.begin(), value.end()); 106 | } 107 | 108 | void *ProtobufSerDes::getSerializedData() { 109 | std::string *data = new std::string(); 110 | Request->SerializeToString(data); 111 | cleanDataStructures(); 112 | return data; 113 | } 114 | 115 | void ProtobufSerDes::setFeature(const std::string &name, 116 | const google::protobuf::Message *value) { 117 | auto reflection = Request->GetReflection(); 118 | auto descriptor = Request->GetDescriptor(); 119 | auto field = descriptor->FindFieldByName(name); 120 | reflection->MutableMessage(Request, field)->CopyFrom(*value); 121 | } 122 | 123 | void ProtobufSerDes::setFeature( 124 | const std::string &name, 125 | const std::vector &value) { 126 | // set repeated field of messages in this->Request 127 | auto reflection = Request->GetReflection(); 128 | auto descriptor = Request->GetDescriptor(); 129 | auto field = descriptor->FindFieldByName(name); 130 | for (auto &v : value) { 131 | reflection->AddMessage(Request, field)->CopyFrom(*v); 132 | } 133 | } 134 | 135 | inline void ProtobufSerDes::setRequest(void *Request) { 136 | this->Request = reinterpret_cast(Request); 137 | } 138 | 139 | inline void ProtobufSerDes::setResponse(void *Response) { 140 | this->Response = reinterpret_cast(Response); 141 | } 142 | 143 | void *ProtobufSerDes::deserializeUntyped(void *data) { 144 | Request->Clear(); // todo: find correct place to clear request for protobuf 145 | // serdes 146 | Response = reinterpret_cast(data); 147 | 148 | const Descriptor *descriptor = Response->GetDescriptor(); 149 | const Reflection *reflection = Response->GetReflection(); 150 | const FieldDescriptor *field = descriptor->field(0); 151 | 152 | if (field->label() == FieldDescriptor::LABEL_REPEATED) { 153 | if (field->type() == FieldDescriptor::Type::TYPE_INT32) { 154 | auto &ref = reflection->GetRepeatedField(*Response, field); 155 | std::vector *ret = new std::vector(ref.begin(), ref.end()); 156 | this->MessageLength = ref.size() * sizeof(int32_t); 157 | return ret->data(); 158 | } 159 | if (field->type() == FieldDescriptor::Type::TYPE_INT64) { 160 | auto &ref = 
reflection->GetRepeatedField(*Response, field); 161 | std::vector *ret = 162 | new std::vector(ref.begin(), ref.end()); 163 | this->MessageLength = ref.size() * sizeof(int64_t); 164 | return ret->data(); 165 | } 166 | if (field->type() == FieldDescriptor::Type::TYPE_FLOAT) { 167 | auto ref = reflection->GetRepeatedField(*Response, field); 168 | std::vector *ret = new std::vector(ref.begin(), ref.end()); 169 | this->MessageLength = ref.size() * sizeof(float); 170 | return ret->data(); 171 | } 172 | if (field->type() == FieldDescriptor::Type::TYPE_DOUBLE) { 173 | auto ref = reflection->GetRepeatedField(*Response, field); 174 | std::vector *ret = 175 | new std::vector(ref.begin(), ref.end()); 176 | this->MessageLength = ref.size() * sizeof(double); 177 | return ret->data(); 178 | } 179 | if (field->type() == FieldDescriptor::Type::TYPE_STRING) { 180 | // yet to be tested 181 | std::cerr << "vector deserialization yet to be done\n"; 182 | exit(1); 183 | std::vector *ptr = new std::vector(); 184 | 185 | /* 186 | ISSUE: error: static assertion failed: We only support non-string scalars 187 | in RepeatedField. FIX: ?? 188 | */ 189 | // auto ref = reflection->GetRepeatedField(*Response, field); 190 | // for (auto &v : ref) { 191 | // ptr->push_back(v); 192 | // } 193 | return ptr; 194 | } 195 | if (field->type() == FieldDescriptor::Type::TYPE_BOOL) { 196 | auto ref = reflection->GetRepeatedField(*Response, field); 197 | std::vector *ptr = new std::vector( 198 | ref.mutable_data(), ref.mutable_data() + ref.size()); 199 | return ptr; 200 | } 201 | } 202 | 203 | if (field->type() == FieldDescriptor::Type::TYPE_INT32) { 204 | int32_t value = reflection->GetInt32(*Response, field); 205 | int32_t *ptr = new int32_t(value); 206 | this->MessageLength = sizeof(int32_t); 207 | return ptr; 208 | } 209 | if (field->type() == FieldDescriptor::Type::TYPE_INT64) { 210 | int64_t value = reflection->GetInt64(*Response, field); 211 | int64_t *ptr = new int64_t(value); 212 | this->MessageLength = sizeof(int64_t); 213 | return ptr; 214 | } 215 | if (field->type() == FieldDescriptor::Type::TYPE_FLOAT) { 216 | float value = reflection->GetFloat(*Response, field); 217 | float *ptr = new float(value); 218 | this->MessageLength = sizeof(float); 219 | return ptr; 220 | } 221 | if (field->type() == FieldDescriptor::Type::TYPE_DOUBLE) { 222 | double value = reflection->GetDouble(*Response, field); 223 | double *ptr = new double(value); 224 | this->MessageLength = sizeof(double); 225 | return ptr; 226 | } 227 | if (field->type() == FieldDescriptor::Type::TYPE_STRING) { 228 | std::string value = reflection->GetString(*Response, field); 229 | std::string *ptr = new std::string(value); 230 | this->MessageLength = value.length(); 231 | return ptr; 232 | } 233 | if (field->type() == FieldDescriptor::Type::TYPE_BOOL) { 234 | bool value = reflection->GetBool(*Response, field); 235 | bool *ptr = new bool(value); 236 | this->MessageLength = sizeof(bool); 237 | return ptr; 238 | } 239 | 240 | std::cerr << "Unknown type in protobuf serializer\n"; 241 | exit(1); 242 | } 243 | 244 | void ProtobufSerDes::cleanDataStructures() { 245 | Request->Clear(); 246 | Response->Clear(); 247 | } 248 | } // namespace MLBridge 249 | -------------------------------------------------------------------------------- /SerDes/tensorflowSerDes.cpp: -------------------------------------------------------------------------------- 1 | //===- tensorflowSerDes.cpp - Serializer support for TF ---------*- C++ -*-===// 2 | // 3 | // Part of the MLCompilerBridge 
Project, under the Apache License v2.0 with LLVM 4 | // Exceptions. See the LICENSE file for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // 7 | //===----------------------------------------------------------------------===// 8 | /// 9 | /// \file 10 | /// This file defines the TensorflowSerDes class, to support interfacing with 11 | /// Tensorflow AOT models via TFModelRunner. 12 | /// 13 | //===----------------------------------------------------------------------===// 14 | 15 | #include "SerDes/tensorflowSerDes.h" 16 | #include "SerDes/baseSerDes.h" 17 | 18 | // #define EXCEPT_LONG(M) M(int) M(float) M(double) M(std::string) M(bool) 19 | namespace MLBridge { 20 | #define SET_FEATURE(TYPE, _) \ 21 | void TensorflowSerDes::setFeature(const std::string &Name, \ 22 | const TYPE Value) { \ 23 | std::string prefix = "feed_"; \ 24 | const int Index = CompiledModel->LookupArgIndex(prefix + Name); \ 25 | if (Index >= 0) \ 26 | *reinterpret_cast(CompiledModel->arg_data(Index)) = Value; \ 27 | } 28 | SUPPORTED_TYPES(SET_FEATURE) 29 | #undef SET_FEATURE 30 | 31 | void TensorflowSerDes::setFeature(const std::string &Name, 32 | const std::vector &Value) { 33 | std::string prefix = "feed_"; 34 | const int Index = CompiledModel->LookupArgIndex(prefix + Name); 35 | std::copy(Value.begin(), Value.end(), 36 | static_cast(CompiledModel->arg_data(Index))); 37 | } 38 | 39 | void TensorflowSerDes::setFeature(const std::string &Name, 40 | const std::vector &Value) { 41 | std::string prefix = "feed_"; 42 | const int Index = CompiledModel->LookupArgIndex(prefix + Name); 43 | std::copy(Value.begin(), Value.end(), 44 | static_cast(CompiledModel->arg_data(Index))); 45 | } 46 | 47 | void TensorflowSerDes::setFeature(const std::string &Name, 48 | const std::vector &Value) { 49 | std::string prefix = "feed_"; 50 | const int Index = CompiledModel->LookupArgIndex(prefix + Name); 51 | std::copy(Value.begin(), Value.end(), 52 | static_cast(CompiledModel->arg_data(Index))); 53 | } 54 | 55 | void TensorflowSerDes::setFeature(const std::string &Name, 56 | const std::vector &Value) { 57 | std::string prefix = "feed_"; 58 | const int Index = CompiledModel->LookupArgIndex(prefix + Name); 59 | std::copy(Value.begin(), Value.end(), 60 | static_cast(CompiledModel->arg_data(Index))); 61 | } 62 | 63 | void TensorflowSerDes::setFeature(const std::string &Name, 64 | const std::vector &Value) { 65 | std::string prefix = "feed_"; 66 | const int Index = CompiledModel->LookupArgIndex(prefix + Name); 67 | std::copy(Value.begin(), Value.end(), 68 | static_cast(CompiledModel->arg_data(Index))); 69 | } 70 | 71 | void TensorflowSerDes::setFeature(const std::string &Name, 72 | const std::vector &Value) { 73 | std::string prefix = "feed_"; 74 | const int Index = CompiledModel->LookupArgIndex(prefix + Name); 75 | std::copy(Value.begin(), Value.end(), 76 | static_cast(CompiledModel->arg_data(Index))); 77 | } 78 | } // namespace MLBridge 79 | -------------------------------------------------------------------------------- /cmake/modules/FindOnnxruntime.cmake: -------------------------------------------------------------------------------- 1 | find_library(ONNXRUNTIME_LIBRARY 2 | NAMES onnxruntime 3 | PATHS 4 | ${ONNXRUNTIME_ROOTDIR}/lib 5 | ) 6 | 7 | find_path(ONNXRUNTIME_INCLUDE_DIR 8 | NAMES onnxruntime_cxx_api.h 9 | PATHS ${ONNXRUNTIME_ROOTDIR}/include 10 | ) 11 | if(ONNXRUNTIME_LIBRARY AND ONNXRUNTIME_INCLUDE_DIR) 12 | set(ONNXRUNTIME_FOUND TRUE) 13 | else() 14 | set(ONNXRUNTIME_FOUND FALSE) 15 | 
endif() 16 | 17 | include(FindPackageHandleStandardArgs) 18 | find_package_handle_standard_args(Onnxruntime 19 | REQUIRED_VARS ONNXRUNTIME_LIBRARY 20 | ) 21 | 22 | if(ONNXRUNTIME_FOUND) 23 | add_library(Onnxruntime::Onnxruntime SHARED IMPORTED) 24 | set_target_properties(Onnxruntime::Onnxruntime PROPERTIES 25 | IMPORTED_LOCATION "${ONNXRUNTIME_LIBRARY}" 26 | INTERFACE_INCLUDE_DIRECTORIES "${ONNXRUNTIME_INCLUDE_DIR}" 27 | ) 28 | endif() 29 | set(ONNXRUNTIME_FOUND TRUE CACHE BOOL "Set to TRUE if onnxruntime is found") 30 | -------------------------------------------------------------------------------- /cmake/modules/TensorFlowCompile.cmake: -------------------------------------------------------------------------------- 1 | function(tf_get_absolute_path path base final_path) 2 | if (IS_ABSOLUTE ${path}) 3 | set(${final_path} ${path} PARENT_SCOPE) 4 | else() 5 | set(${final_path} ${base}/${path} PARENT_SCOPE) 6 | endif() 7 | endfunction() 8 | 9 | function(tf_get_model model final_path) 10 | string(FIND ${model} "http:" pos_http) 11 | string(FIND ${model} "https:" pos_https) 12 | if (${pos_http} EQUAL 0 OR ${pos_https} EQUAL 0) 13 | message("Downloading model " ${model}) 14 | string(FIND ${model} "/" fname_start REVERSE) 15 | math(EXPR fname_start "${fname_start}+1") 16 | string(SUBSTRING ${model} ${fname_start}+1 -1 fname) 17 | message("Model archive: " ${fname}) 18 | file(DOWNLOAD ${model} ${CMAKE_CURRENT_BINARY_DIR}/${fname}) 19 | file(ARCHIVE_EXTRACT INPUT 20 | ${CMAKE_CURRENT_BINARY_DIR}/${fname} 21 | DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/${fname}_model) 22 | set(${final_path} ${CMAKE_CURRENT_BINARY_DIR}/${fname}_model/model PARENT_SCOPE) 23 | else() 24 | tf_get_absolute_path(${model} ${CMAKE_CURRENT_BINARY_DIR} model_path) 25 | set(${final_path} ${model_path} PARENT_SCOPE) 26 | endif() 27 | endfunction() 28 | 29 | # Generate a mock model for tests. 30 | function(generate_mock_model generator output) 31 | tf_get_absolute_path(${generator} ${CMAKE_CURRENT_SOURCE_DIR} generator_absolute_path) 32 | tf_get_absolute_path(${output} ${CMAKE_CURRENT_BINARY_DIR} output_absolute_path) 33 | message(WARNING "Autogenerated mock models should not be used in production builds.") 34 | execute_process(COMMAND ${Python3_EXECUTABLE} 35 | ${generator_absolute_path} 36 | ${output_absolute_path} 37 | WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} 38 | ) 39 | endfunction() 40 | 41 | # Run the tensorflow compiler (saved_model_cli) on the saved model in the 42 | # ${model} directory, looking for the ${tag_set} tag set, and the SignatureDef 43 | # ${signature_def_key}. 44 | # Produce a pair of files called ${fname}.h and ${fname}.o in the 45 | # ${CMAKE_CURRENT_BINARY_DIR}. The generated header will define a C++ class 46 | # called ${cpp_class} - which may be a namespace-qualified class name. 47 | function(tf_compile model tag_set signature_def_key fname cpp_class hdr_file obj_file) 48 | tf_get_absolute_path(${model} ${CMAKE_CURRENT_BINARY_DIR} LLVM_ML_MODELS_ABSOLUTE) 49 | add_custom_command(OUTPUT ${obj_file} ${hdr_file} 50 | COMMAND ${TENSORFLOW_AOT_COMPILER} aot_compile_cpu 51 | --multithreading false 52 | --dir ${LLVM_ML_MODELS_ABSOLUTE} 53 | --tag_set ${tag_set} 54 | --signature_def_key ${signature_def_key} 55 | --output_prefix ${prefix} 56 | --cpp_class ${cpp_class} 57 | --target_triple ${LLVM_HOST_TRIPLE} 58 | ) 59 | 60 | # Aggregate the objects so that results of different tf_compile calls may be 61 | # grouped into one target. 
62 | set(GENERATED_OBJS ${GENERATED_OBJS} ${obj_file} PARENT_SCOPE) 63 | set_source_files_properties(${obj_file} PROPERTIES 64 | GENERATED 1 EXTERNAL_OBJECT 1) 65 | 66 | set(GENERATED_HEADERS ${GENERATED_HEADERS} ${hdr_file} PARENT_SCOPE) 67 | set_source_files_properties(${hdr_file} PROPERTIES 68 | GENERATED 1) 69 | 70 | endfunction() 71 | 72 | function(tf_find_and_compile compiler_path model default_url default_path test_model_generator tag_set signature_def_key fname cpp_class opath) 73 | set(TENSORFLOW_AOT_COMPILER ${compiler_path}) 74 | set(prefix ${opath}/${fname}) 75 | set(obj_file ${prefix}.o) 76 | set(hdr_file ${prefix}.h) 77 | string(TOUPPER ${fname} fname_allcaps) 78 | set(override_header ${LLVM_OVERRIDE_MODEL_HEADER_${fname_allcaps}}) 79 | set(override_object ${LLVM_OVERRIDE_MODEL_OBJECT_${fname_allcaps}}) 80 | # If the user specified overrides, that indicates intent to use AOT and we 81 | # don't care what the model path is 82 | if (EXISTS "${override_header}" AND EXISTS "${override_object}") 83 | configure_file(${override_header} ${hdr_file} COPYONLY) 84 | configure_file(${override_object} ${obj_file} COPYONLY) 85 | message(STATUS "Using provided header " ${hdr_file} " and object " ${obj_file} " 86 | files for model " ${fname}) 87 | set(GENERATED_OBJS ${GENERATED_OBJS} ${obj_file}) 88 | set(GENERATED_HEADERS ${GENERATED_HEADERS} ${hdr_file}) 89 | elseif("${model}" STREQUAL "none") 90 | message(STATUS "Will skip enabling mlgo for ${fname}") 91 | return() 92 | else() 93 | if ("${model}" STREQUAL "download") 94 | # Crash if the user wants to download a model but a URL is set to "TO_BE_UPDATED" 95 | if ("${default_url}" STREQUAL "") 96 | message(FATAL_ERROR "Model path was set to 'download' but there is no" 97 | " model url currently specified in cmake. You can generate a model" 98 | " using, for example, the tools at http://github.com/google/ml-compiler-opt." 99 | " Some reference models are also periodically released there.") 100 | endif() 101 | 102 | set(model ${default_url}) 103 | endif() 104 | 105 | if ("${model}" STREQUAL "autogenerate") 106 | set(model ${default_path}-autogenerated) 107 | generate_mock_model(${test_model_generator} ${model}) 108 | endif() 109 | 110 | tf_get_model(${model} LLVM_ML_MODELS_ABSOLUTE) 111 | tf_compile(${LLVM_ML_MODELS_ABSOLUTE} ${tag_set} ${signature_def_key} ${fname} ${cpp_class} ${hdr_file} ${obj_file}) 112 | endif() 113 | 114 | set(GeneratedMLSources ${GeneratedMLSources} ${GENERATED_OBJS} ${GENERATED_HEADERS} PARENT_SCOPE) 115 | set(MLDeps ${MLDeps} tf_xla_runtime PARENT_SCOPE) 116 | set(MLLinkDeps ${MLLinkDeps} tf_xla_runtime PARENT_SCOPE) 117 | add_compile_definitions(LLVM_HAVE_TF_AOT_${fname_allcaps}) 118 | endfunction() 119 | -------------------------------------------------------------------------------- /images/component-ml-compiler-bridge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IITH-Compilers/ML-Compiler-Bridge/53bf989f6e34125a2be6dd6b832cb852fb6922ce/images/component-ml-compiler-bridge.png -------------------------------------------------------------------------------- /include/MLModelRunner/C/ONNXModelRunner.h: -------------------------------------------------------------------------------- 1 | //=== MLModelRunner/C/ONNXModelRunner.h - C API for ONNXModelRunner - C++ -===// 2 | // 3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM 4 | // Exceptions. See the LICENSE file for license information. 
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // 7 | //===---------------------------------------------------------------------===// 8 | /// 9 | /// \file 10 | /// This file defines the C APIs for ONNXModelRunner. 11 | /// This is a wrapper around the ONNXModelRunner class that provides an 12 | /// interface for the MLCompilerBridge to interact with the ONNX models during 13 | /// inference. 14 | /// 15 | /// Usage for single agent: 16 | /// 1. Create an ONNXModelRunner object using createSingleAgentOMR 17 | /// 2. Evaluate the features using singleAgentEvaluate 18 | /// 3. Destroy the instance of ONNXModelRunner using destroyONNXModelRunner 19 | /// 20 | /// Usage for multiple agents: 21 | /// 1. Create an Environment object using createEnvironment 22 | /// 2. Set the number of features using env_setNumFeatures 23 | /// 3. Set the step function using env_setStepFunc 24 | /// 4. Set the reset function using env_setResetFunc 25 | /// 5. Set the next agent using env_setNextAgent 26 | /// 6. Create an ONNXModelRunner object using createONNXModelRunner 27 | /// 7. Evaluate the features using evaluate 28 | /// 8. Destroy the instance of ONNXModelRunner using destroyONNXModelRunner 29 | /// 9. Destroy the instance of Environment using destroyEnvironment 30 | /// 31 | /// Using Environment: 32 | /// 1. Create an Environment object using createEnvironment 33 | /// 2. Set the number of features using env_setNumFeatures 34 | /// 3. Set the step function using env_setStepFunc 35 | /// 4. Set the reset function using env_setResetFunc 36 | /// 5. Set the next agent using env_setNextAgent 37 | /// 6. Destroy the instance of Environment using destroyEnvironment after 38 | /// calling 39 | /// destroyONNXModelRunner. 40 | /// 41 | /// Internally, the ONNXModelRunner will call the step function to get the next 42 | /// action and the reset function to reset the environment. The step function 43 | /// should return a pointer to an array of floats. The reset function should 44 | /// return a pointer to an array of floats. 
45 | ///
46 | //===---------------------------------------------------------------------===//
47 | 
48 | #ifndef ONNX_MODEL_RUNNER_WRAPPER_H
49 | #define ONNX_MODEL_RUNNER_WRAPPER_H
50 | 
51 | typedef struct ONNXModelRunner ONNXModelRunner;
52 | typedef struct Environment Environment;
53 | typedef signed Action;
54 | 
55 | #ifdef __cplusplus
56 | extern "C" {
57 | #endif
58 | 
59 | Environment *createEnvironment();
60 | void env_setDone(Environment *env);
61 | void env_resetDone(Environment *env);
62 | bool env_checkDone(Environment *env);
63 | void env_setNumFeatures(Environment *env, int numFeatures);
64 | void env_setStepFunc(Environment *env, float *(*stepFunc)(Action action));
65 | void env_setResetFunc(Environment *env, float *(*resetFunc)());
66 | void env_setNextAgent(Environment *env, char *agentName);
67 | 
68 | ONNXModelRunner *createONNXModelRunner(Environment *env, int numAgents, ...);
69 | ONNXModelRunner *createSingleAgentOMR(char *agent_path);
70 | void evaluate(ONNXModelRunner *obj);
71 | int singleAgentEvaluate(ONNXModelRunner *obj, float *inp, int inp_size);
72 | void destroyEnvironment(Environment *env);
73 | void destroyONNXModelRunner(ONNXModelRunner *obj);
74 | 
75 | #ifdef __cplusplus
76 | }
77 | #endif
78 | 
79 | #endif // ONNX_MODEL_RUNNER_WRAPPER_H
80 | 
--------------------------------------------------------------------------------
/include/MLModelRunner/C/PipeModelRunner.h:
--------------------------------------------------------------------------------
1 | //===--- MLModelRunner/C/PipeModelRunner.h - C API for PipeModelRunner ---===//
2 | //
3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4 | // Exceptions. See the LICENSE file for license information.
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | //
7 | //===---------------------------------------------------------------------===//
8 | ///
9 | /// \file
10 | /// This file defines the C API for PipeModelRunner.
11 | /// This is a C wrapper around the C++ PipeModelRunner class that lets C-based
12 | /// clients of the MLCompilerBridge interact with the PipeModelRunner
13 | /// class.
14 | ///
15 | /// Usage:
16 | /// 1. Create an instance of PipeModelRunnerWrapper using createPipeModelRunner
17 | /// 2. Populate the features using populateXXXFeatures functions
18 | /// 3. Evaluate the features using evaluateXXXFeatures functions
19 | /// 4. Destroy the instance of PipeModelRunnerWrapper using
20 |    destroyPipeModelRunner
21 | ///
22 | //===---------------------------------------------------------------------===//
23 | 
24 | #ifndef PIPE_MODEL_RUNNER_WRAPPER_H
25 | #define PIPE_MODEL_RUNNER_WRAPPER_H
26 | 
27 | #ifdef __cplusplus
28 | extern "C" {
29 | #endif
30 | 
31 | // Define an opaque pointer type for PipeModelRunnerWrapper
32 | typedef struct PipeModelRunnerWrapper PipeModelRunnerWrapper;
33 | 
34 | /// Creates an instance of PipeModelRunnerWrapper
35 | PipeModelRunnerWrapper *createPipeModelRunner(const char *outBoundName,
36 |                                               const char *inBoundName,
37 |                                               int serDesType);
38 | 
39 | /// Populates the features of PipeModelRunnerWrapper
40 | void populateFloatFeatures(PipeModelRunnerWrapper *obj, const char *name,
41 |                            const float *data, const int size);
42 | void populateIntFeatures(PipeModelRunnerWrapper *obj, const char *name,
43 |                          const int *data, const int size);
44 | 
45 | /// Evaluates the features of PipeModelRunnerWrapper
46 | int evaluateIntFeatures(PipeModelRunnerWrapper *obj);
47 | float evaluateFloatFeatures(PipeModelRunnerWrapper *obj);
48 | 
49 | /// Destroys an instance of PipeModelRunnerWrapper
50 | void destroyPipeModelRunner(PipeModelRunnerWrapper *obj);
51 | 
52 | #ifdef __cplusplus
53 | }
54 | #endif
55 | 
56 | #endif // PIPE_MODEL_RUNNER_WRAPPER_H
57 | 
--------------------------------------------------------------------------------
/include/MLModelRunner/MLModelRunner.h:
--------------------------------------------------------------------------------
1 | //===- MLModelRunner.h ---- ML model runner interface -----------*- C++ -*-===//
2 | //
3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4 | // Exceptions. See the LICENSE file for license information.
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | //
7 | // (Preliminary version adopted from MLModelRunner.h of LLVM 17.X)
8 | //
9 | //===----------------------------------------------------------------------===//
10 | ///
11 | /// \file
12 | /// The MLModelRunner class is the main interface for interacting with the
13 | /// ML models. The MLCompilerBridge uses the MLModelRunner class to set the
14 | /// features to be sent to the model and get the result back from the model.
15 | ///
16 | /// This class internally uses the SerDes class to serialize and deserialize the
17 | /// features and result.
18 | ///
19 | /// The MLModelRunner class is an abstract class and cannot be instantiated.
20 | ///
21 | /// Supporting new Model Runners:
22 | /// 1. Create a new class inheriting the MLModelRunner class.
23 | /// 2. Override evaluateUntyped() method to call the model and get the result.
24 | ///
25 | /// Using any of the existing Model Runners (see the sketch below):
26 | /// 1. Instantiate the model runner object with the appropriate arguments.
27 | /// 2. Call populateFeatures() to set the features to be sent to the model.
28 | /// 3. Call evaluate() to send the features and receive the result from the
29 | ///    model.
30 | /// Similar flows apply for both training and inference.
31 | ///
32 | ///
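/// A minimal usage sketch (not part of the original header; the pipe names,
/// SerDes choice, and feature values are hypothetical, and PipeModelRunner is
/// declared in MLModelRunner/PipeModelRunner.h):
/// \code
/// MLBridge::PipeModelRunner Runner("/tmp/mlbridge.out", "/tmp/mlbridge.in",
///                                  MLBridge::SerDesKind::Json);
/// std::pair<std::string, int> Feature("unroll_count", 4); // hypothetical
/// Runner.populateFeatures(Feature);
/// int Advice = Runner.evaluate<int>();
/// \endcode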
33 | ///
34 | //===----------------------------------------------------------------------===//
35 | 
36 | #ifndef ML_MODEL_RUNNER_H
37 | #define ML_MODEL_RUNNER_H
38 | 
39 | #include "SerDes/baseSerDes.h"
40 | #include "SerDes/bitstreamSerDes.h"
41 | #include "SerDes/jsonSerDes.h"
42 | 
43 | #include <cassert>
44 | #include <cstdlib>
45 | #include <cstring>
46 | #include <memory>
47 | #include <type_traits>
48 | 
49 | #ifndef C_LIBRARY
50 | #include "SerDes/protobufSerDes.h"
51 | #include "SerDes/tensorflowSerDes.h"
52 | #endif
53 | namespace MLBridge {
54 | 
55 | /// MLModelRunner - The main interface for interacting with the ML models.
56 | class MLModelRunner {
57 | public:
58 |   // Disallows copy and assign.
59 |   MLModelRunner(const MLModelRunner &) = delete;
60 |   MLModelRunner &operator=(const MLModelRunner &) = delete;
61 |   virtual ~MLModelRunner() = default;
62 | 
63 |   /// Main user-facing method for interacting with the ML models
64 |   template <typename T>
65 |   typename std::enable_if<std::is_fundamental<T>::value, T>::type evaluate() {
66 |     return *reinterpret_cast<T *>(evaluateUntyped());
67 |   }
68 | 
69 |   /// Main user-facing method for interacting with the ML models
70 |   template <typename T>
71 |   typename std::enable_if<
72 |       std::is_fundamental<typename std::remove_pointer<T>::type>::value,
73 |       void>::type
74 |   evaluate(T &data, size_t &dataSize) {
75 |     using BaseType = typename std::remove_pointer<T>::type;
76 |     void *res = evaluateUntyped();
77 |     T ret = static_cast<T>(malloc(SerDes->getMessageLength()));
78 |     memcpy(ret, res, SerDes->getMessageLength());
79 |     dataSize = SerDes->getMessageLength() / sizeof(BaseType);
80 |     data = ret;
81 |   }
82 | 
83 |   /// Type of the MLModelRunner
84 |   enum class Kind : int { Unknown, Pipe, gRPC, ONNX, TFAOT };
85 | 
86 |   Kind getKind() const { return Type; }
87 |   SerDesKind getSerDesKind() const { return SerDesType; }
88 | 
89 |   virtual void requestExit() = 0;
90 | 
91 |   /// User-facing interface for setting the features to be sent to the model.
92 |   /// The features are passed as a list of key-value pairs.
93 |   /// The key is the name of the feature and the value is the value of the
94 |   /// feature. The value can be a scalar or a vector.
95 | 
96 |   template <typename T, typename... Types>
97 |   void populateFeatures(const std::pair<std::string, T> &var1,
98 |                         const std::pair<std::string, Types> &...var2) {
99 |     SerDes->setFeature(var1.first, var1.second);
100 |     populateFeatures(var2...);
101 |   }
102 | 
103 |   template <typename T, typename... Types>
104 |   void populateFeatures(const std::pair<std::string, T> &&var1,
105 |                         const std::pair<std::string, Types> &&...var2) {
106 |     SerDes->setFeature(var1.first, var1.second);
107 |     populateFeatures(var2...);
108 |   }
109 | 
110 |   void populateFeatures() {}
111 | 
112 |   /// Mainly used in the case of gRPC where the request object is
113 |   /// not known explicitly.
114 |   void setRequest(void *request) { SerDes->setRequest(request); }
115 | 
116 |   /// Mainly used in the case of gRPC where the response object is
117 |   /// not known explicitly.
118 |   void setResponse(void *response) { SerDes->setResponse(response); }
119 | 
120 | protected:
121 |   MLModelRunner(Kind Type, SerDesKind SerDesType,
122 |                 llvm::LLVMContext *Ctx = nullptr)
123 |       : Ctx(Ctx), Type(Type), SerDesType(SerDesType) {
124 |     assert(Type != Kind::Unknown);
125 |     initSerDes();
126 |   }
127 | 
128 |   MLModelRunner(Kind Type, llvm::LLVMContext *Ctx = nullptr)
129 |       : Ctx(Ctx), Type(Type), SerDesType(SerDesKind::Unknown) {
130 |     SerDes = nullptr;
131 |   };
132 | 
133 |   /// Should be implemented by the derived class to call the model and get the
134 |   /// result.
135 |   virtual void *evaluateUntyped() = 0;
136 | 
137 |   llvm::LLVMContext *Ctx;
138 |   const Kind Type;
139 |   const SerDesKind SerDesType;
140 | 
141 | protected:
142 |   std::unique_ptr<BaseSerDes> SerDes;
143 | 
144 | private:
145 |   void initSerDes() {
146 |     switch (SerDesType) {
147 |     case SerDesKind::Json:
148 |       SerDes = std::make_unique<JsonSerDes>();
149 |       break;
150 |     case SerDesKind::Bitstream:
151 |       SerDes = std::make_unique<BitstreamSerDes>();
152 |       break;
153 | #ifndef C_LIBRARY
154 |     case SerDesKind::Protobuf:
155 |       SerDes = std::make_unique<ProtobufSerDes>();
156 |       break;
157 |     case SerDesKind::Tensorflow:
158 |       SerDes = std::make_unique<TensorflowSerDes>();
159 |       break;
160 | #endif
161 |     case SerDesKind::Unknown:
162 |       SerDes = nullptr;
163 |       break;
164 |     }
165 |   }
166 | };
167 | } // namespace MLBridge
168 | 
169 | #endif // ML_MODEL_RUNNER_H
170 | 
--------------------------------------------------------------------------------
/include/MLModelRunner/ONNXModelRunner/ONNXModelRunner.h:
--------------------------------------------------------------------------------
1 | //=== MLModelRunner/ONNXModelRunner/ONNXModelRunner.h - C++ --------------===//
2 | //
3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4 | // Exceptions. See the LICENSE file for license information.
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | //
7 | //===----------------------------------------------------------------------===//
8 | ///
9 | /// \file
10 | /// ONNXModelRunner class supporting communication via ONNX C++ Runtime.
11 | /// Only inference is supported.
12 | ///
13 | /// This class interfaces with Environment and Agent classes to support
14 | /// ML/RL model inference via ONNXModel.
15 | ///
16 | /// Usage:
17 | /// 1. Construct an ONNXModelRunner object with the environment and the agents.
18 | ///    Environment and agents are created by the user by inheriting from the
19 | ///    Environment class and using the Agent class.
20 | /// 2. Multiple agents can be added to the ONNXModelRunner object using the
21 | ///    addAgent() method. The agents are identified by a unique name.
22 | /// 3. Call evaluate() to get the result from the model.
23 | ///
24 | /// Internally the ONNXModelRunner object will call the step() method of the
25 | /// environment to get the next observation and the computeAction() method of
26 | /// the agent to get the action corresponding to the observation.
27 | ///
28 | //===----------------------------------------------------------------------===//
29 | 
30 | #ifndef ONNX_MODELRUNNER_H
31 | #define ONNX_MODELRUNNER_H
32 | 
33 | #include "MLModelRunner/MLModelRunner.h"
34 | #include "MLModelRunner/ONNXModelRunner/agent.h"
35 | #include "MLModelRunner/ONNXModelRunner/environment.h"
36 | 
37 | namespace MLBridge {
38 | 
39 | /// ONNXModelRunner is the main user facing class that interfaces with the
40 | /// Environment and Agent classes to support ML/RL model inference via
41 | /// ONNXModel.
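/// A minimal sketch of the usage steps above (not part of the original header;
/// MyEnvironment is a hypothetical user-defined subclass of Environment, and
/// the model path is hypothetical):
/// \code
/// MyEnvironment Env; // user-defined Environment subclass
/// std::map<std::string, MLBridge::Agent *> Agents;
/// Agents["agent1"] = new MLBridge::Agent("path/to/agent1.onnx");
/// MLBridge::ONNXModelRunner Runner(&Env, Agents);
/// auto Result = Runner.evaluate<int>();
/// \endcode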
42 | class ONNXModelRunner : public MLModelRunner {
43 | public:
44 |   ONNXModelRunner(MLBridge::Environment *env,
45 |                   std::map<std::string, Agent *> agents,
46 |                   llvm::LLVMContext *Ctx = nullptr);
47 | 
48 |   void setEnvironment(MLBridge::Environment *_env) { env = _env; }
49 |   MLBridge::Environment *getEnvironment() { return env; }
50 | 
51 |   void addAgent(Agent *agent, std::string name);
52 | 
53 |   void requestExit() override {}
54 | 
55 | private:
56 |   MLBridge::Environment *env;
57 |   std::map<std::string, Agent *> agents;
58 |   void computeAction(Observation &obs);
59 |   void *evaluateUntyped() override;
60 | };
61 | } // namespace MLBridge
62 | #endif // ONNX_MODELRUNNER_H
63 | 
--------------------------------------------------------------------------------
/include/MLModelRunner/ONNXModelRunner/agent.h:
--------------------------------------------------------------------------------
1 | //=== MLModelRunner/ONNXModelRunner/agent.h - Agent Model Helper - C++ -===//
2 | //
3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4 | // Exceptions. See the LICENSE file for license information.
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | //
7 | //===---------------------------------------------------------------------===//
8 | ///
9 | /// \file
10 | /// Agent class to support ML/RL model inference via ONNX
11 | ///
12 | /// Usage:
13 | /// 1. Construct an agent object with the path to the ONNX model
14 | /// 2. Call computeAction() to get the action from the model
15 | ///
16 | //===----------------------------------------------------------------------===//
17 | 
18 | #ifndef ONNX_MODELRUNNER_AGENT_H
19 | #define ONNX_MODELRUNNER_AGENT_H
20 | 
21 | #include "MLModelRunner/ONNXModelRunner/onnx.h"
22 | #include "MLModelRunner/ONNXModelRunner/utils.h"
23 | 
24 | #include <memory>
25 | #include <string>
26 | 
27 | namespace MLBridge {
28 | /// Agent is a wrapper around the ONNXModel class, interfaces with the
29 | /// Environment class to support ML/RL model inference via ONNXModel.
30 | class Agent {
31 |   ONNXModel *model;
32 | 
33 | public:
34 |   Agent(std::string model_path);
35 |   /// Runs the ONNX model on the input Observation and returns the output
36 |   unsigned computeAction(Observation &obs);
37 | };
38 | } // namespace MLBridge
39 | 
40 | #endif // ONNX_MODELRUNNER_AGENT_H
41 | 
--------------------------------------------------------------------------------
/include/MLModelRunner/ONNXModelRunner/environment.h:
--------------------------------------------------------------------------------
1 | //=== MLModelRunner/ONNXModelRunner/environment.h - ONNX Environment C++ -===//
2 | //
3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4 | // Exceptions. See the LICENSE file for license information.
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | //
7 | //===---------------------------------------------------------------------===//
8 | ///
9 | /// \file
10 | /// Base Environment class to support ONNX based inference of RL models. This
11 | /// class is used to define the environment for the agents to interact with.
12 | ///
13 | /// The Environment should be defined by the compiler pass that is using the
14 | /// MLCompilerBridge. The environment should be defined by inheriting from this
15 | /// class and implementing the step() and reset() methods.
16 | ///
17 | /// step() and reset() are typical methods used in RL environments.
18 | ///
19 | /// The step() method takes an action as input and returns the observation
20 | /// corresponding to the next state. The reset() method returns the initial
21 | /// observation.
22 | ///
23 | /// Usage:
24 | /// 1. Create an environment class inheriting from MLBridge::Environment
25 | /// 2. Implement step() and reset() methods
26 | ///
27 | /// Example:
28 | /// \code
29 | /// class MyEnvironment : public MLBridge::Environment {
30 | /// public:
31 | ///   Observation &step(Action action) override {
32 | ///     // Implement the step function here
33 | ///   }
34 | ///   Observation &reset() override {
35 | ///     // Implement the reset function here
36 | ///   }
37 | /// };
38 | /// \endcode
39 | ///
40 | /// This environment can then be used by the ONNXModelRunner to interact with
41 | /// the agents. getNextAgent() and setNextAgent() methods can be used to set the
42 | /// next agent to interact with. These methods are used in step() and reset() to
43 | /// get the next agent to interact with in case of a multi-agent environment.
44 | ///
45 | //===----------------------------------------------------------------------===//
46 | 
47 | #ifndef ONNX_MODELRUNNER_ENVIRONMENT_H
48 | #define ONNX_MODELRUNNER_ENVIRONMENT_H
49 | 
50 | #include "MLModelRunner/ONNXModelRunner/agent.h"
51 | #include <map>
52 | 
53 | typedef signed Action;
54 | 
55 | namespace MLBridge {
56 | class Environment {
57 |   bool done = false;
58 |   std::string nextAgent = "";
59 | 
60 | protected:
61 |   std::map<std::string, Observation> obsMap;
62 | 
63 | public:
64 |   /// CheckDone returns true if the termination condition is met at the end of
65 |   /// the episode.
66 |   bool checkDone() { return done == true; };
67 | 
68 |   /// SetDone sets the termination condition to true.
69 |   void setDone() { done = true; }
70 |   void resetDone() { done = false; }
71 | 
72 |   /// GetNextAgent returns the name/ID of the next agent to interact with.
73 |   std::string getNextAgent() { return nextAgent; };
74 | 
75 |   /// SetNextAgent sets the name of the next agent to interact with.
76 |   void setNextAgent(std::string name) { nextAgent = name; }
77 | 
78 |   /// Step function takes an action as input and returns the observation
79 |   /// corresponding to the next state. This method should be implemented by the
80 |   /// user. Typically this method should call setDone() if the termination
81 |   /// condition is met. setNextAgent() can be called in this method to set the
82 |   /// next agent to interact with.
83 |   virtual Observation &step(Action action) = 0;
84 | 
85 |   /// Reset function returns the initial observation. This method should be
86 |   /// implemented by the user. This method can internally call setNextAgent() to
87 |   /// set the next agent to interact with.
88 |   virtual Observation &reset() = 0;
89 | };
90 | } // namespace MLBridge
91 | #endif // ONNX_MODELRUNNER_ENVIRONMENT_H
92 | 
--------------------------------------------------------------------------------
/include/MLModelRunner/ONNXModelRunner/onnx.h:
--------------------------------------------------------------------------------
1 | //=== MLModelRunner/ONNXModelRunner/onnx.h --- ONNX C++ Interface -----===//
2 | //
3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4 | // Exceptions. See the LICENSE file for license information.
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | //
7 | //===---------------------------------------------------------------------===//
8 | ///
9 | /// \file
10 | /// This file defines the ONNXModel class, which is a wrapper around the ONNX
11 | /// C++ interface.
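/// A minimal usage sketch (not part of the original header; the model path is
/// hypothetical, and the SmallVector element type and inline size follow the
/// Observation typedef in utils.h):
/// \code
/// ONNXModel Model("path/to/model.onnx");
/// llvm::SmallVector<float, 100> Input = {0.5f, 1.5f};
/// llvm::SmallVector<float, 100> Output;
/// Model.run(Input, Output); // Output holds the model's result
/// \endcode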
12 | ///
13 | //===---------------------------------------------------------------------===//
14 | 
15 | #ifndef ONNX_MODELRUNNER_ONNX_H
16 | #define ONNX_MODELRUNNER_ONNX_H
17 | 
18 | #include "llvm/ADT/SmallVector.h"
19 | 
20 | #include <string>
21 | #include <vector>
22 | 
23 | namespace Ort {
24 | class Env;
25 | class Session;
26 | class Value;
27 | } // namespace Ort
28 | 
29 | class ONNXModel {
30 |   Ort::Env *env;
31 |   const char *model_path;
32 |   Ort::Session *session;
33 |   Ort::Value getInputValue(llvm::SmallVector<float, 100> &input, int inputIdx);
34 | 
35 | public:
36 |   ONNXModel(const char *model_path);
37 | 
38 |   /// Runs the ONNX model on the input and returns the output
39 |   void run(llvm::SmallVector<float, 100> &input,
40 |            llvm::SmallVector<float, 100> &output);
41 | };
42 | 
43 | #endif // ONNX_MODELRUNNER_ONNX_H
44 | 
--------------------------------------------------------------------------------
/include/MLModelRunner/ONNXModelRunner/utils.h:
--------------------------------------------------------------------------------
1 | //=== MLModelRunner/ONNXModelRunner/utils.h - C++ ------------------------===//
2 | //
3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4 | // Exceptions. See the LICENSE file for license information.
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | //
7 | //===----------------------------------------------------------------------===//
8 | 
9 | #ifndef ONNX_MODELRUNNER_UTILS_H
10 | #define ONNX_MODELRUNNER_UTILS_H
11 | 
12 | #include "llvm/ADT/SmallVector.h"
13 | 
14 | namespace MLBridge {
15 | typedef llvm::SmallVector<float, 100> Observation;
16 | 
17 | } // namespace MLBridge
18 | 
19 | #endif // ONNX_MODELRUNNER_UTILS_H
20 | 
--------------------------------------------------------------------------------
/include/MLModelRunner/PipeModelRunner.h:
--------------------------------------------------------------------------------
1 | //===- PipeModelRunner.h ---- PipeModelRunner ------*- C++ -*-------------===//
2 | //
3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4 | // Exceptions. See the LICENSE file for license information.
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | //
7 | // (Preliminary version adopted from InteractiveModelRunner.h of LLVM 17.X)
8 | //
9 | //===----------------------------------------------------------------------===//
10 | ///
11 | /// \file
12 | /// PipeModelRunner class supporting communication via OS pipes between the
13 | /// compiler and an external ML agent.
14 | ///
15 | /// Usage:
16 | /// 1. Create a PipeModelRunner object with the names of the pipes, and the
17 | ///    serialization technique.
18 | /// 2. Populate the features to be sent to the model.
19 | /// 3. Call evaluate() to get the result back from the model.
20 | ///
21 | /// This supports both training and inference. Supports interleaved
22 | /// communication.
23 | ///
24 | //===----------------------------------------------------------------------===//
25 | 
26 | #ifndef PipeModelRunner_H
27 | #define PipeModelRunner_H
28 | 
29 | #include "MLModelRunner/MLModelRunner.h"
30 | #include "SerDes/TensorSpec.h"
31 | #include "SerDes/baseSerDes.h"
32 | #include "llvm/Support/FileSystem.h"
33 | #include <memory>
34 | #include <string>
35 | #include <system_error>
36 | 
37 | namespace MLBridge {
38 | 
39 | /// A MLModelRunner that asks for advice from an external agent, or host. It
40 | /// uses 2 files - ideally named pipes - one to send data to that agent, and
41 | /// one to receive advice.
42 | /// The compiler will send observations; the host is expected to reply with a
43 | /// tensor value after each observation as a binary buffer that's conforming to
44 | /// the shape of the advice. Interleaved, the data closely resembles the
45 | /// training log for a log where we don't capture the reward signal.
46 | ///
47 | /// Note that the correctness of the received data is the responsibility of the
48 | /// host. In particular, if insufficient data were sent, the compiler will block
49 | /// when waiting for an advice.
50 | ///
51 | /// Note that the host can either open the pipes RW, or open first the pipe to
52 | /// the compiler - i.e. the "Inbound" - and then the "Outbound", to avoid
53 | /// deadlock. This is because the compiler first tries to open the inbound
54 | /// (which will hang until there's a writer on the other end).
55 | class PipeModelRunner : public MLModelRunner {
56 | public:
57 |   PipeModelRunner(llvm::StringRef OutboundName, llvm::StringRef InboundName,
58 |                   SerDesKind Kind, llvm::LLVMContext *Ctx = nullptr);
59 | 
60 |   static bool classof(const MLModelRunner *R) {
61 |     return R->getKind() == MLModelRunner::Kind::Pipe;
62 |   }
63 | 
64 |   void requestExit() override {}
65 |   virtual ~PipeModelRunner();
66 | 
67 | private:
68 |   void send(void *data);
69 |   void *receive();
70 |   void *evaluateUntyped() override;
71 |   std::string readNBytes(size_t N);
72 |   // This must be declared before InEC if we want to initialize it in the
73 |   // ctor initializer list.
74 |   std::string InboundName;
75 |   int Inbound = -1;
76 |   std::error_code OutEC;
77 |   std::error_code InEC;
78 |   std::unique_ptr<llvm::raw_fd_ostream> OutStream;
79 | };
80 | } // namespace MLBridge
81 | #endif // PipeModelRunner_H
82 | 
--------------------------------------------------------------------------------
/include/MLModelRunner/TFModelRunner.h:
--------------------------------------------------------------------------------
1 | //===- TFModelRunner.h ---- TF precompiled model runner ------*- C++-*----===//
2 | //
3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4 | // Exceptions. See the LICENSE file for license information.
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | //
7 | // (Preliminary version adopted from ReleaseModeModelRunner.h of LLVM 17.X)
8 | //
9 | //===----------------------------------------------------------------------===//
10 | ///
11 | /// \file
12 | /// This file implements a model runner wrapping an AOT compiled ML model.
13 | /// Only inference is supported.
14 | ///
15 | //===----------------------------------------------------------------------===//
16 | 
17 | #ifndef TFMODELRUNNER_H
18 | #define TFMODELRUNNER_H
19 | 
20 | #include "MLModelRunner/MLModelRunner.h"
21 | #include "SerDes/TensorSpec.h"
22 | 
23 | #include <memory>
24 | #include <string>
25 | 
26 | namespace MLBridge {
27 | 
28 | /// TFModelRunner - TF Compiled model implementation of the
29 | /// MLModelRunner. It uses an AOT-compiled SavedModel for efficient execution.
30 | template <class TGen> class TFModelRunner final : public MLModelRunner {
31 | public:
32 |   /// FeatureNames' type should be an indexed collection of std::string, like
33 |   /// std::array or std::vector, that has a size() method.
34 |   TFModelRunner(llvm::StringRef DecisionName, llvm::LLVMContext &Ctx,
35 |                 llvm::StringRef FeedPrefix = "feed_",
36 |                 llvm::StringRef FetchPrefix = "fetch_")
37 |       : MLModelRunner(MLModelRunner::Kind::TFAOT, SerDesKind::Tensorflow, &Ctx),
38 |         CompiledModel(std::make_unique<TGen>()) {
39 | 
40 |     SerDes->setRequest(CompiledModel.get());
41 | 
42 |     assert(CompiledModel && "The CompiledModel should be valid");
43 | 
44 |     ResultIndex = CompiledModel->LookupResultIndex(FetchPrefix.str() +
45 |                                                    DecisionName.str());
46 |     assert(ResultIndex >= 0 && "Cannot find DecisionName in inlining model");
47 |   }
48 |   TFModelRunner(llvm::StringRef DecisionName,
49 |                 llvm::StringRef FeedPrefix = "feed_",
50 |                 llvm::StringRef FetchPrefix = "fetch_")
51 |       : MLModelRunner(MLModelRunner::Kind::TFAOT, SerDesKind::Tensorflow),
52 |         CompiledModel(std::make_unique<TGen>()) {
53 | 
54 |     SerDes->setRequest(CompiledModel.get());
55 | 
56 |     assert(CompiledModel && "The CompiledModel should be valid");
57 | 
58 |     ResultIndex = CompiledModel->LookupResultIndex(FetchPrefix.str() +
59 |                                                    DecisionName.str());
60 |     assert(ResultIndex >= 0 && "Cannot find DecisionName in inlining model");
61 |   }
62 |   virtual ~TFModelRunner() = default;
63 | 
64 |   virtual void requestExit() override {
65 |     llvm_unreachable("requestExit() is not supported in TFModelRunner");
66 |   }
67 | 
68 |   static bool classof(const MLModelRunner *R) {
69 |     return R->getKind() == MLModelRunner::Kind::TFAOT;
70 |   }
71 | 
72 | private:
73 |   void *evaluateUntyped() override {
74 |     CompiledModel->Run();
75 |     return CompiledModel->result_data(ResultIndex);
76 |   }
77 | 
78 |   int32_t ResultIndex = -1;
79 |   std::unique_ptr<TGen> CompiledModel;
80 | };
81 | 
82 | /// A mock class satisfying the interface expected by TFModelRunner for
83 | /// its `TGen` parameter. Useful to avoid conditional compilation complexity, as
84 | /// a compile-time replacement for a real AOT-ed model.
85 | class NoopSavedModelImpl final {
86 | #define NOOP_MODEL_ERRMSG \
87 |   "The mock AOT-ed saved model is a compile-time stub and should not be " \
88 |   "called."
89 | 
90 | public:
91 |   NoopSavedModelImpl() = default;
92 |   int LookupArgIndex(const std::string &) {
93 |     llvm_unreachable(NOOP_MODEL_ERRMSG);
94 |   }
95 |   int LookupResultIndex(const std::string &) {
96 |     llvm_unreachable(NOOP_MODEL_ERRMSG);
97 |   }
98 |   void Run() { llvm_unreachable(NOOP_MODEL_ERRMSG); }
99 |   void *result_data(int) { llvm_unreachable(NOOP_MODEL_ERRMSG); }
100 |   void *arg_data(int) { llvm_unreachable(NOOP_MODEL_ERRMSG); }
101 | #undef NOOP_MODEL_ERRMSG
102 | };
103 | 
104 | template <class T> bool isEmbeddedModelEvaluatorValid() { return true; }
105 | 
106 | template <> inline bool isEmbeddedModelEvaluatorValid<NoopSavedModelImpl>() {
107 |   return false;
108 | }
109 | } // namespace MLBridge
110 | 
111 | #endif // TFMODELRUNNER_H
112 | 
--------------------------------------------------------------------------------
/include/MLModelRunner/Utils/DataTypes.h:
--------------------------------------------------------------------------------
1 | //=== MLModelRunner/Utils/DataTypes.h - Supported Data Types - C++ -------===//
2 | //
3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4 | // Exceptions. See the LICENSE file for license information.
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | //
7 | //===----------------------------------------------------------------------===//
8 | ///
9 | /// \file
10 | /// This file defines the bit widths of integral and floating point types
11 | /// supported by the MLCompilerBridge.
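/// For example (illustrative, not part of the original header; the
/// MLBRIDGE_EXTENDED_TYPES configuration macro is described below):
/// \code
/// MLBridge::IntegerType Count = 1;  // int by default, long if extended
/// MLBridge::RealType Reward = 0.5;  // float by default, double if extended
/// \endcode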
12 | ///
13 | /// The bit widths of floating and integral types supported by the
14 | /// MLCompilerBridge can be configured by defining the MLBRIDGE_EXTENDED_TYPES
15 | /// macro in the CMakeLists.txt file.
16 | ///
17 | //===----------------------------------------------------------------------===//
18 | 
19 | #ifndef MLBRIDGE_DATATYPES_H
20 | #define MLBRIDGE_DATATYPES_H
21 | 
22 | namespace MLBridge {
23 | 
24 | #ifdef MLBRIDGE_EXTENDED_TYPES
25 | using RealType = double;
26 | using IntegerType = long;
27 | #else
28 | using RealType = float;
29 | using IntegerType = int;
30 | #endif
31 | 
32 | } // namespace MLBridge
33 | 
34 | #endif
35 | 
--------------------------------------------------------------------------------
/include/MLModelRunner/Utils/Debug.h:
--------------------------------------------------------------------------------
1 | //=== MLModelRunner/Utils/Debug.h - Debug definitions with support - C++ --===//
2 | //
3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4 | // Exceptions. See the LICENSE file for license information.
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | //
7 | //===----------------------------------------------------------------------===//
8 | ///
9 | /// \file
10 | /// This file defines the debug macros for the MLCompilerBridge.
11 | ///
12 | //===----------------------------------------------------------------------===//
13 | 
14 | #ifndef MLBRIDGE_DEBUG_H
15 | #define MLBRIDGE_DEBUG_H
16 | 
17 | namespace MLBridge {
18 | 
19 | #ifdef DEBUG_MODE
20 | #define MLBRIDGE_DEBUG(X) \
21 |   do { \
22 |     X; \
23 |   } while (false)
24 | #else
25 | #define MLBRIDGE_DEBUG(X)
26 | #endif
27 | 
28 | } // namespace MLBridge
29 | 
30 | #endif
31 | 
--------------------------------------------------------------------------------
/include/MLModelRunner/Utils/MLConfig.h:
--------------------------------------------------------------------------------
1 | //=== MLModelRunner/Utils/MLConfig.h -MLConfig class definition -*- C++ -*-===//
2 | //
3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4 | // Exceptions. See the LICENSE file for license information.
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | //
7 | //===----------------------------------------------------------------------===//
8 | ///
9 | /// \file
10 | /// This file defines the MLConfig class, which is a wrapper around the MLConfig
11 | /// command line option for passing information like path of the models and
12 | /// other configuration to the compiler passes.
13 | ///
14 | //===----------------------------------------------------------------------===//
15 | 
16 | #ifndef MLBRIDGE_CONFIG_H
17 | #define MLBRIDGE_CONFIG_H
18 | 
19 | #include "llvm/Support/CommandLine.h"
20 | 
21 | namespace MLBridge {
22 | namespace MLConfig {
23 | extern llvm::cl::opt<std::string> mlconfig;
24 | } // namespace MLConfig
25 | } // namespace MLBridge
26 | 
27 | #endif
28 | 
--------------------------------------------------------------------------------
/include/MLModelRunner/gRPCModelRunner.h:
--------------------------------------------------------------------------------
1 | //=== MLModelRunner/gRPCModelRunner.h - gRPCModelRunner class - C++ -*-----===//
2 | //
3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4 | // Exceptions. See the LICENSE file for license information.
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | //
7 | //===---------------------------------------------------------------------===//
8 | ///
9 | /// \file
10 | /// gRPCModelRunner class supporting communication via gRPC. This class is used
11 | /// to communicate with the gRPC server and send/receive data to/from the model.
12 | /// Supports interleaved communication with the model.
13 | ///
14 | /// There are two ways to use this class:
15 | /// 1. Training mode - gRPC Server: In this mode, the gRPCModelRunner object is
16 | /// created with the server address and the service object. The service object
17 | /// is used to create the server and the server waits for the client to connect.
18 | /// Once the client connects, the server waits for the client to send the
19 | /// request and then responds with the result. In Training mode, the
20 | /// GrpcCompilerInterface class from GrpcCompilerInterface.py acts as the client
21 | /// and sends the request to the server.
22 | ///
23 | /// 2. Inference mode - gRPC Client: In this mode, the gRPCModelRunner object is
24 | /// created with the server address, request and response objects. The request
25 | /// object is used to send the features to the model and the response object is
26 | /// used to receive the result from the model.
27 | ///
28 | /// In Inference mode, the compiler pass using this class acts as the client and
29 | /// sends the request to the server implemented by using
30 | /// GrpcCompilerInterface.py.
31 | ///
32 | /// Interfacing with the model using protobuf (.proto) files:
33 | /// Users should define the service, stub, request and response by writing a
34 | /// .proto file. The service and stub objects are generated using the protoc
35 | /// compiler. The request and response objects are generated using the protoc
36 | /// compiler or created by the user. The build process will automatically
37 | /// compile .proto files and generate the relevant stubs and request/response
38 | /// objects in both C++ and Python, which are used by gRPCModelRunner and
39 | /// GrpcCompilerInterface respectively.
40 | ///
41 | /// In Inference mode, users should override the `getAdvice()` RPC method in the
42 | /// Python model code to process the request and send the response back to the
43 | /// client. This method is invoked through the stub in the evaluateUntyped()
44 | /// method of the gRPCModelRunner class to get the result from the model after
45 | /// populating the features in the request object.
46 | ///
47 | /// In Training mode, users should override the RPC function/service that they
48 | /// declare in the .proto file in the Compiler pass which is using the
49 | /// gRPCModelRunner. This function is called by the gRPC client (Python model)
50 | /// that is using the GrpcCompilerInterface class to send the request to the
51 | /// server.
52 | ///
53 | /// Usage:
54 | /// 1. Create a .proto file with the service and message definitions
55 | /// 2. Generate the stubs using protoc
56 | /// 3. Create a gRPCModelRunner object with the server address, stub, request
57 | ///    and response
58 | /// 4. Populate the features to be sent to the model
59 | /// 5. Call evaluate() to get the result back from the model (see the
60 |    client-mode sketch later in this file)
61 | // ===----------------------------------------------------------------------===//
62 | 
63 | #ifndef GRPC_MODELRUNNER_H
64 | #define GRPC_MODELRUNNER_H
65 | 
66 | #include "MLModelRunner/MLModelRunner.h"
67 | 
68 | #include <grpcpp/grpcpp.h>
69 | #include <google/protobuf/text_format.h>
70 | #include <chrono>
71 | #include <future>
72 | #include <iostream>
73 | #include <thread>
74 | 
75 | namespace MLBridge {
76 | /// This class is used to create the grpc model runner object. grpc model runner
77 | /// requires service, stub, request and response objects to communicate with the
78 | /// model.
79 | template <class Client, class Stub, class Request, class Response>
80 | class gRPCModelRunner : public MLModelRunner {
81 | public:
82 |   /// For server mode
83 |   gRPCModelRunner(std::string server_address, grpc::Service *s,
84 |                   llvm::LLVMContext *Ctx = nullptr)
85 |       : MLModelRunner(MLModelRunner::Kind::gRPC, SerDesKind::Protobuf, Ctx),
86 |         server_address(server_address), request(nullptr), response(nullptr),
87 |         server_mode(true) {
88 |     RunService(s);
89 |   }
90 | 
91 |   /// For client mode
92 |   gRPCModelRunner(std::string server_address, Request *request,
93 |                   Response *response, llvm::LLVMContext *Ctx = nullptr)
94 |       : MLModelRunner(MLModelRunner::Kind::gRPC, SerDesKind::Protobuf, Ctx),
95 |         server_address(server_address), request(request), response(response),
96 |         server_mode(false) {
97 |     SetStub();
98 |   }
99 | 
100 |   void requestExit() override {
101 |     std::string input;
102 |     std::cin >> input;
103 |     if (input == "Terminate") {
104 |       this->exit_requested->set_value();
105 |     } else {
106 |       std::cout << "Problem while closing server\n";
107 |     }
108 |   }
109 | 
110 | private:
111 |   /// checks whether a port number is available or not
112 |   bool isPortAvailable(std::string addr) {
113 |     int max_retries = 30, attempts = 0;
114 |     double wait_seconds = 0.2, backoff_exp = 1.2;
115 | 
116 |     int idx = addr.find(":");
117 |     int port = stoi(addr.substr(idx + 1, addr.size() - idx - 1));
118 | 
119 |     while (attempts < max_retries) {
120 |       std::string command = "lsof -i :" + std::to_string(port);
121 |       FILE *pipe = popen(command.c_str(), "r");
122 |       if (!pipe) {
123 |         std::cerr << "Error executing command: " << std::strerror(errno)
124 |                   << std::endl;
125 |         return false;
126 |       }
127 | 
128 |       char buffer[256];
129 |       std::string result = "";
130 |       while (!feof(pipe)) {
131 |         if (fgets(buffer, 256, pipe) != nullptr)
132 |           result += buffer;
133 |       }
134 |       pclose(pipe);
135 | 
136 |       if (result.empty()) {
137 |         return true;
138 |       }
139 |       attempts++;
140 |       std::cout << "Port is unavailable, retrying! attempt: " << attempts
141 |                 << std::endl;
142 |       std::this_thread::sleep_for(std::chrono::duration<double>(wait_seconds));
143 |       wait_seconds *= backoff_exp;
144 |     }
145 | 
146 |     std::cout << "Port is unavailable now!" << std::endl;
147 |     return false;
148 |   }
149 | 
150 |   std::promise<void> *exit_requested;
151 | 
152 |   /// This method is used to send the request to the model and get the result.
153 |   /// Used only in client mode during inference.
154 |   void *evaluateUntyped() override {
155 |     assert(!server_mode &&
156 |            "evaluateUntyped not implemented for gRPCModelRunner; "
157 |            "Override gRPC method instead");
158 |     assert(request != nullptr && "Request cannot be null");
159 | 
160 |     int max_retries = 30, attempts = 0;
161 |     double retries_wait_secs = 0.2;
162 |     int deadline_time = 10000;
163 |     int deadline_max_retries = 30, deadline_attpts = 0;
164 |     double retry_wait_backoff_exponent = 1.5;
165 | 
166 |     // setting a deadline
167 |     auto deadline = std::chrono::system_clock::now() +
168 |                     std::chrono::milliseconds(deadline_time);
169 | 
170 |     while (attempts < max_retries && deadline_attpts < deadline_max_retries) {
171 |       grpc::ClientContext grpcCtx;
172 |       request = getRequest();
173 |       grpc::Status status;
174 |       grpcCtx.set_deadline(deadline);
175 | 
176 |       status = stub_->getAdvice(&grpcCtx, *request, response);
177 | 
178 |       if (status.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
179 |         deadline_attpts++;
180 |         int ext_deadline = 2 * deadline_time;
181 |         deadline_time = ext_deadline;
182 |         std::cout << "Deadline Exceeded for Request! sending the message again "
183 |                      "with extended deadline : "
184 |                   << deadline_time << "\n";
185 |         deadline = std::chrono::system_clock::now() +
186 |                    std::chrono::milliseconds(deadline_time);
187 |       } else if (status.error_code() == grpc::StatusCode::UNAVAILABLE) {
188 |         attempts++;
189 |         std::cout << "Server is unavailable, retrying! attempt: " << attempts
190 |                   << "\n";
191 |         std::this_thread::sleep_for(
192 |             std::chrono::duration<double>(retries_wait_secs));
193 |         retries_wait_secs *= retry_wait_backoff_exponent;
194 |       } else {
195 |         request->Clear();
196 |         if (!status.ok()) {
197 |           if (Ctx)
198 |             Ctx->emitError("gRPC failed: " + status.error_message());
199 |           else
200 |             std::cerr << "gRPC failed: " << status.error_message() << std::endl;
201 |         }
202 |         // auto *action = new int(); // Hard wired for PosetRL case, should be
203 |         // fixed *action = response->action(); return action;
204 |         return SerDes->deserializeUntyped(response);
205 |       }
206 |     }
207 | 
208 |     std::cout << "Server is unavailable now!!!\n";
209 |     return new int(-1);
210 |   }
211 | 
212 |   Stub *stub_;
213 |   std::string server_address;
214 |   Request *request;
215 |   Response *response;
216 |   bool server_mode;
217 | 
218 |   /// This method is used to create the server and start listening. Used in
219 |   /// server mode.
220 |   int RunService(grpc::Service *s) {
221 |     exit_requested = new std::promise<void>();
222 |     grpc::ServerBuilder builder;
223 |     // if (!this->isPortAvailable(server_address)) return -1;
224 |     builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
225 |     builder.RegisterService(s);
226 |     std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
227 |     std::cout << "Server Listening on " << server_address << std::endl;
228 |     auto serveFn = [&]() { server->Wait(); };
229 |     std::thread serving_thread(serveFn);
230 |     auto f = exit_requested->get_future();
231 |     this->requestExit();
232 |     f.wait();
233 |     server->Shutdown();
234 |     serving_thread.join();
235 |     std::cout << "Server Shut Down Successfully" << std::endl;
236 |     return 0;
237 |   }
238 | 
239 |   /// This method is used to create the stub. Used in client mode.
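/// For reference, a client-mode runner is typically wired up as in the
/// following illustrative sketch (not part of the original header; the
/// proto-generated service, stub, request, and response types come from a
/// hypothetical helloMLBridge.proto):
/// \code
/// HelloMLBridgeRequest Request;   // hypothetical protoc-generated message
/// HelloMLBridgeResponse Response; // hypothetical protoc-generated message
/// MLBridge::gRPCModelRunner<HelloMLBridgeService,
///                           HelloMLBridgeService::Stub,
///                           HelloMLBridgeRequest, HelloMLBridgeResponse>
///     Runner("0.0.0.0:5050", &Request, &Response);
/// Runner.populateFeatures(std::make_pair(std::string("tensor"), 1.0f));
/// int Advice = Runner.evaluate<int>();
/// \endcode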
240 |   int SetStub() {
241 |     std::shared_ptr<grpc::Channel> channel =
242 |         grpc::CreateChannel(server_address, grpc::InsecureChannelCredentials());
243 |     auto Stub_temp = Client::NewStub(channel);
244 |     stub_ = Stub_temp.release();
245 |     return 0;
246 |   }
247 | 
248 |   Request *getRequest() { return (Request *)SerDes->getRequest(); }
249 | 
250 |   Response *getResponse() { return (Response *)SerDes->getResponse(); }
251 | 
252 |   void printMessage(const google::protobuf::Message *message) {
253 |     std::string s;
254 |     if (google::protobuf::TextFormat::PrintToString(*message, &s)) {
255 |       std::cout << "Your message: " << s << std::endl;
256 |     } else {
257 |       std::cerr << "Message not valid (partial content: "
258 |                 << request->ShortDebugString() << ")\n";
259 |     }
260 |   }
261 | };
262 | } // namespace MLBridge
263 | 
264 | #endif // GRPC_MODELRUNNER_H
265 | 
--------------------------------------------------------------------------------
/include/SerDes/TensorSpec.h:
--------------------------------------------------------------------------------
1 | //===- TensorSpec.h - type descriptor for a tensor --------------*- C++ -*-===//
2 | //
3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4 | // Exceptions. See the LICENSE file for license information.
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | //
7 | // (Preliminary version adopted from TensorSpec.h of LLVM 17.X)
8 | //
9 | //===----------------------------------------------------------------------===//
10 | 
11 | #ifndef MLBRIDGE_TENSORSPEC_H
12 | #define MLBRIDGE_TENSORSPEC_H
13 | 
14 | #include "llvm/ADT/Optional.h"
15 | #include "llvm/IR/LLVMContext.h"
16 | #include "llvm/Support/JSON.h"
17 | #include <numeric>
18 | #include <string>
19 | #include <vector>
20 | 
21 | namespace MLBridge {
22 | /// TensorSpec encapsulates the specification of a tensor: its dimensions, or
23 | /// "shape" (row-major), its type (see TensorSpec::getDataType specializations
24 | /// for supported types), its name and port (see "TensorFlow: Large-Scale
25 | /// Machine Learning on Heterogeneous Distributed Systems", section 4.2, para 2:
26 | /// https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45166.pdf)
27 | ///
28 | /// Known tensor types. The left part is the C type, the right is a name we
29 | /// can use to identify the type (to implement TensorSpec equality checks), and
30 | /// to use, if needed, when mapping to an underlying evaluator's type system.
31 | /// The main requirement is that the C type we use has the same size and
32 | /// encoding (e.g. endian-ness) as the one used by the evaluator.
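/// A TensorSpec is typically created through createSpec, e.g. (illustrative,
/// not part of the original header; the tensor name and shape are
/// hypothetical):
/// \code
/// MLBridge::TensorSpec Spec =
///     MLBridge::TensorSpec::createSpec<float>("reward", {1});
/// size_t Bytes = Spec.getTotalTensorBufferSize(); // 1 * sizeof(float)
/// \endcode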
33 | #define SUPPORTED_TENSOR_TYPES(M) \
34 |   M(float, Float) \
35 |   M(double, Double) \
36 |   M(int8_t, Int8) \
37 |   M(uint8_t, UInt8) \
38 |   M(int16_t, Int16) \
39 |   M(uint16_t, UInt16) \
40 |   M(int32_t, Int32) \
41 |   M(uint32_t, UInt32) \
42 |   M(int64_t, Int64) \
43 |   M(uint64_t, UInt64)
44 | 
45 | enum class TensorType {
46 |   Invalid,
47 | #define _TENSOR_TYPE_ENUM_MEMBERS(_, Name) Name,
48 |   SUPPORTED_TENSOR_TYPES(_TENSOR_TYPE_ENUM_MEMBERS)
49 | #undef _TENSOR_TYPE_ENUM_MEMBERS
50 |   Total
51 | };
52 | 
53 | class TensorSpec final {
54 | public:
55 |   template <typename T>
56 |   static TensorSpec createSpec(const std::string &Name,
57 |                                const std::vector<int64_t> &Shape,
58 |                                int Port = 6020) {
59 |     return TensorSpec(Name, Port, getDataType<T>(), sizeof(T), Shape);
60 |   }
61 | 
62 |   const std::string &name() const { return Name; }
63 |   int port() const { return Port; }
64 |   TensorType type() const { return Type; }
65 |   const std::vector<int64_t> &shape() const { return Shape; }
66 |   void setShape(const std::vector<int64_t> &NewShape) {
67 |     Shape = NewShape;
68 |     ElementCount = std::accumulate(Shape.begin(), Shape.end(), 1,
69 |                                    std::multiplies<int64_t>());
70 |   }
71 |   bool operator==(const TensorSpec &Other) const {
72 |     return Name == Other.Name && Port == Other.Port && Type == Other.Type &&
73 |            Shape == Other.Shape;
74 |   }
75 | 
76 |   bool operator!=(const TensorSpec &Other) const { return !(*this == Other); }
77 | 
78 |   /// Get the number of elements in a tensor with this shape.
79 |   size_t getElementCount() const { return ElementCount; }
80 |   /// Get the size, in bytes, of one element.
81 |   size_t getElementByteSize() const { return ElementSize; }
82 |   /// Get the total size of a memory buffer needed to store the whole tensor.
83 |   size_t getTotalTensorBufferSize() const { return ElementCount * ElementSize; }
84 | 
85 |   template <typename T> bool isElementType() const {
86 |     return getDataType<T>() == Type;
87 |   }
88 | 
89 |   TensorSpec(const std::string &NewName, const TensorSpec &Other)
90 |       : TensorSpec(NewName, Other.Port, Other.Type, Other.ElementSize,
91 |                    Other.Shape) {}
92 | 
93 |   void toJSON(llvm::json::OStream &OS) const;
94 | 
95 | private:
96 |   TensorSpec(const std::string &Name, int Port, TensorType Type,
97 |              size_t ElementSize, const std::vector<int64_t> &Shape);
98 | 
99 |   template <typename T> static TensorType getDataType();
100 | 
101 |   std::string Name;
102 |   int Port = 0;
103 |   TensorType Type = TensorType::Invalid;
104 |   std::vector<int64_t> Shape;
105 |   size_t ElementCount = 0;
106 |   size_t ElementSize = 0;
107 | };
108 | 
109 | /// For debugging.
110 | std::string tensorValueToString(const char *Buffer, const TensorSpec &Spec);
111 | 
112 | /// Construct a TensorSpec from a JSON dictionary of the form:
113 | /// { "name": <string>,
114 | ///   "port": <int>,
115 | ///   "type": <string>,
116 | ///   "shape": <array of ints> }
117 | /// For the "type" field, see the C++ primitive types used in
118 | /// SUPPORTED_TENSOR_TYPES.
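/// For example, a 1-D float tensor named "reward" might be described by the
/// following JSON (illustrative; the field values are hypothetical):
/// \code
/// { "name": "reward", "port": 0, "type": "float", "shape": [1] }
/// \endcode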
119 | llvm::Optional<TensorSpec>
120 | getTensorSpecFromJSON(llvm::LLVMContext &Ctx, const llvm::json::Value &Value);
121 | 
122 | #define TFUTILS_GETDATATYPE_DEF(T, Name) \
123 |   template <> TensorType TensorSpec::getDataType<T>();
124 | SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_DEF)
125 | 
126 | #undef TFUTILS_GETDATATYPE_DEF
127 | } // namespace MLBridge
128 | 
129 | #endif
130 | 
--------------------------------------------------------------------------------
/include/SerDes/baseSerDes.h:
--------------------------------------------------------------------------------
1 | //=== SerDes/baseSerDes.h - Base for serialization and deserialization C++ ===//
2 | //
3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4 | // Exceptions. See the LICENSE file for license information.
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | //
7 | //===----------------------------------------------------------------------===//
8 | ///
9 | /// \file
10 | /// Supporting new SerDes:
11 | /// 1. Create a new class which inherits from BaseSerDes.
12 | /// 2. Implement the setFeature(), getSerializedData(), cleanDataStructures()
13 | ///    and deserializeUntyped() methods.
14 | /// 3. Add the new SerDes to the enum class SerDesKind in this class.
15 | ///
16 | //===----------------------------------------------------------------------===//
17 | 
18 | #ifndef BASE_SERDES_H
19 | #define BASE_SERDES_H
20 | 
21 | #include "MLModelRunner/Utils/Debug.h"
22 | #include "google/protobuf/extension_set.h"
23 | #include "google/protobuf/message.h"
24 | #include "llvm/Support/raw_ostream.h"
25 | 
26 | #include <cassert>
27 | #include <iostream>
28 | #include <string>
29 | #include <vector>
30 | 
31 | // TYPE, NAME
32 | #define SUPPORTED_TYPES(M) \
33 |   M(int, int) \
34 |   M(long, long) \
35 |   M(float, float) \
36 |   M(double, double) \
37 |   M(std::string, string) \
38 |   M(bool, bool)
39 | 
40 | namespace MLBridge {
41 | /// This is the base class for SerDes. It defines the interface for the
42 | /// serialization and deserialization of the data structures used for the
43 | /// communication by the MLModelRunner.
44 | /// Currently, (int, float) or (long, double), std::string, and bool are
45 | /// supported. Vectors of these types are supported as well.
46 | enum class SerDesKind : int { Unknown, Json, Bitstream, Protobuf, Tensorflow };
47 | class BaseSerDes {
48 | public:
49 |   SerDesKind getKind() const { return Type; }
50 | 
51 |   /// setFeature() is used to set the features of the data structure used for
52 |   /// communication. The features are set as key-value pairs. The key is a
53 |   /// string and the value can be any of the supported types.
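/// For example (illustrative sketch, not part of the original header; the
/// feature names and values are hypothetical):
/// \code
/// void populate(MLBridge::BaseSerDes &SD) {
///   SD.setFeature("loop_depth", 3);                           // scalar
///   SD.setFeature("block_freqs", std::vector<float>{0.5f, 0.25f}); // vector
/// }
/// \endcode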
54 | #define SET_FEATURE(TYPE, _) \
55 |   virtual void setFeature(const std::string &, const TYPE) = 0; \
56 |   virtual void setFeature(const std::string &, const std::vector<TYPE> &){};
57 |   SUPPORTED_TYPES(SET_FEATURE)
58 | #undef SET_FEATURE
59 | 
60 |   virtual void setFeature(const std::string &name,
61 |                           const google::protobuf::Message *value){};
62 |   virtual void
63 |   setFeature(const std::string &name,
64 |              const std::vector<google::protobuf::Message *> &value){};
65 | 
66 |   // a hack to set the request and response structures in protobuf serializer
67 |   virtual void setRequest(void *Request) {
68 |     MLBRIDGE_DEBUG(std::cout << "In BaseSerializer setRequest...\n");
69 |   };
70 |   virtual void setResponse(void *Response){};
71 |   virtual void *getSerializedData() = 0;
72 |   virtual void *deserializeUntyped(void *data) = 0;
73 |   size_t getMessageLength() { return MessageLength; }
74 |   virtual void *getRequest() { return nullptr; };
75 |   virtual void *getResponse() { return nullptr; };
76 | 
77 | protected:
78 |   BaseSerDes(SerDesKind Type) : Type(Type) {
79 |     assert(Type != SerDesKind::Unknown);
80 |   }
81 |   virtual void cleanDataStructures() = 0;
82 |   const SerDesKind Type;
83 |   void *RequestVoid;
84 |   void *ResponseVoid;
85 |   size_t MessageLength;
86 | };
87 | } // namespace MLBridge
88 | 
89 | #endif
90 | 
--------------------------------------------------------------------------------
/include/SerDes/bitstreamSerDes.h:
--------------------------------------------------------------------------------
1 | //=== SerDes/bitstreamSerDes.h -Bitstream Serialization/Deserialization-C++===//
2 | //
3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4 | // Exceptions. See the LICENSE file for license information.
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | //
7 | //===----------------------------------------------------------------------===//
8 | ///
9 | /// \file
10 | /// Bitstream Serialization/Deserialization which sends header information
11 | /// followed by the raw data.
12 | ///
13 | //===----------------------------------------------------------------------===//
14 | 
15 | #ifndef BITSTREAM_SERIALIZER_H
16 | #define BITSTREAM_SERIALIZER_H
17 | 
18 | #include "SerDes/TensorSpec.h"
19 | #include "SerDes/baseSerDes.h"
20 | #include <cstdlib>
21 | #include <map>
22 | #include <memory>
23 | #include <string>
24 | #include <vector>
25 | 
26 | namespace MLBridge {
27 | /// BitstreamSerDes - Bitstream Serialization/Deserialization which sends header
28 | /// information followed by the raw data.
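/// For example (illustrative sketch, not part of the original header; the
/// exact wire format is defined by the implementation in bitstreamSerDes.cpp):
/// \code
/// MLBridge::BitstreamSerDes SD;
/// SD.setFeature("reward", 0.5f);          // hypothetical feature
/// void *Buf = SD.getSerializedData();     // header describing the tensors,
///                                         // followed by their raw bytes
/// \endcode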
29 | class BitstreamSerDes : public BaseSerDes {
30 | public:
31 |   BitstreamSerDes() : BaseSerDes(SerDesKind::Bitstream) {
32 |     Buffer = "";
33 |     tensorSpecs = std::vector<TensorSpec>();
34 |     rawData = std::vector<void *>();
35 | 
36 | #define TEMPORARY_STORAGE_INIT(TYPE, NAME) \
37 |   features##NAME = {}; \
38 |   featuresVector##NAME = {};
39 |     SUPPORTED_TYPES(TEMPORARY_STORAGE_INIT)
40 | #undef TEMPORARY_STORAGE_INIT
41 |   };
42 | #define SET_FEATURE(TYPE, _) \
43 |   void setFeature(const std::string &, const TYPE) override; \
44 |   void setFeature(const std::string &, const std::vector<TYPE> &) override;
45 |   SUPPORTED_TYPES(SET_FEATURE)
46 | #undef SET_FEATURE
47 | 
48 |   void *getSerializedData() override;
49 | 
50 |   void cleanDataStructures() override {
51 |     Buffer = "";
52 |     tensorSpecs = std::vector<TensorSpec>();
53 |     rawData = std::vector<void *>();
54 | 
55 | #define TEMPORARY_STORAGE_CLEAN(TYPE, NAME) \
56 |   for (auto &it : features##NAME) { \
57 |     delete it.second; \
58 |   } \
59 |   features##NAME.clear(); \
60 |   features##NAME = {}; \
61 |   for (auto &it : featuresVector##NAME) { \
62 |     delete it.second; \
63 |   } \
64 |   featuresVector##NAME.clear(); \
65 |   featuresVector##NAME = {};
66 |     SUPPORTED_TYPES(TEMPORARY_STORAGE_CLEAN)
67 | #undef TEMPORARY_STORAGE_CLEAN
68 |   }
69 | 
70 | private:
71 |   void *deserializeUntyped(void *) override;
72 |   std::vector<TensorSpec> tensorSpecs;
73 |   std::vector<void *> rawData;
74 |   std::string Buffer;
75 | 
76 | #define TEMPORARY_STORAGE_DEF(TYPE, NAME) \
77 |   std::map<std::string, TYPE *> features##NAME; \
78 |   std::map<std::string, std::vector<TYPE> *> featuresVector##NAME;
79 |   SUPPORTED_TYPES(TEMPORARY_STORAGE_DEF)
80 | #undef TEMPORARY_STORAGE_DEF
81 | };
82 | } // namespace MLBridge
83 | 
84 | #endif
85 | 
--------------------------------------------------------------------------------
/include/SerDes/jsonSerDes.h:
--------------------------------------------------------------------------------
1 | //=== SerDes/jsonSerDes.h -Json Serialization/Deserialization ---*- C++ ---===//
2 | //
3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4 | // Exceptions. See the LICENSE file for license information.
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | //
7 | //===----------------------------------------------------------------------===//
8 | ///
9 | /// \file
10 | /// Json Serialization/Deserialization using LLVM's json library.
11 | ///
12 | //===----------------------------------------------------------------------===//
13 | 
14 | #ifndef JSON_SERIALIZER_H
15 | #define JSON_SERIALIZER_H
16 | 
17 | #include "SerDes/baseSerDes.h"
18 | #include "llvm/Support/JSON.h"
19 | #include <map>
20 | #include <string>
21 | 
22 | namespace MLBridge {
23 | /// JsonSerDes - Json Serialization/Deserialization using LLVM's json library.
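/// For example, setFeature("action", 1) followed by
/// setFeature("priors", std::vector<float>{0.5f, 0.5f}) produces a JSON
/// object along the lines of (illustrative; feature names are hypothetical):
/// \code
/// {"action": 1, "priors": [0.5, 0.5]}
/// \endcode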
24 | class JsonSerDes : public BaseSerDes {
25 | public:
26 |   JsonSerDes() : BaseSerDes(SerDesKind::Json){};
27 | 
28 |   static bool classof(const BaseSerDes *S) {
29 |     return S->getKind() == SerDesKind::Json;
30 |   }
31 | 
32 | #define SET_FEATURE(TYPE, _) \
33 |   void setFeature(const std::string &name, const TYPE value) override { \
34 |     J[name] = value; \
35 |   } \
36 |   void setFeature(const std::string &name, const std::vector<TYPE> &value) \
37 |       override { \
38 |     J[name] = llvm::json::Array(value); \
39 |   }
40 |   SUPPORTED_TYPES(SET_FEATURE)
41 | #undef SET_FEATURE
42 | 
43 |   void *getSerializedData() override;
44 | 
45 |   void cleanDataStructures() override { J = llvm::json::Object(); }
46 | 
47 | private:
48 |   void *deserializeUntyped(void *data) override;
49 |   void *desJson(llvm::json::Value *V);
50 | 
51 | private:
52 |   llvm::json::Object J;
53 | };
54 | } // namespace MLBridge
55 | 
56 | #endif
57 | 
--------------------------------------------------------------------------------
/include/SerDes/protobufSerDes.h:
--------------------------------------------------------------------------------
1 | //=== SerDes/protobufSerDes.h - Protobuf Serialization/Deserialization C++-===//
2 | //
3 | // Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
4 | // Exceptions. See the LICENSE file for license information.
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | //
7 | //===----------------------------------------------------------------------===//
8 | ///
9 | /// \file
10 | /// Protobuf Serialization/Deserialization to support gRPC communication.
11 | ///
12 | //===----------------------------------------------------------------------===//
13 | 
14 | #ifndef PROTOBUF_SERIALIZER_H
15 | #define PROTOBUF_SERIALIZER_H
16 | 
17 | #include "SerDes/baseSerDes.h"
18 | #include "google/protobuf/extension_set.h"
19 | #include "google/protobuf/message.h"
20 | 
21 | using namespace google::protobuf;
22 | 
23 | namespace MLBridge {
24 | /// ProtobufSerDes - Protobuf Serialization/Deserialization to support gRPC
25 | /// communication.

#endif
--------------------------------------------------------------------------------
/include/SerDes/protobufSerDes.h:
--------------------------------------------------------------------------------
//=== SerDes/protobufSerDes.h - Protobuf Serialization/Deserialization C++-===//
//
// Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
// Exceptions. See the LICENSE file for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// Protobuf Serialization/Deserialization to support gRPC communication.
///
//===----------------------------------------------------------------------===//

#ifndef PROTOBUF_SERIALIZER_H
#define PROTOBUF_SERIALIZER_H

#include "SerDes/baseSerDes.h"
#include "google/protobuf/extension_set.h"
#include "google/protobuf/message.h"

using namespace google::protobuf;

namespace MLBridge {
/// ProtobufSerDes - Protobuf Serialization/Deserialization to support gRPC
/// communication.
class ProtobufSerDes : public BaseSerDes {
public:
  ProtobufSerDes() : BaseSerDes(SerDesKind::Protobuf){};

  static bool classof(const BaseSerDes *S) {
    return S->getKind() == SerDesKind::Protobuf;
  }

  void setRequest(void *Request) override;
  void setResponse(void *Response) override;

  void *getRequest() override { return Request; }

  void *getResponse() override { return Response; }

#define SET_FEATURE(TYPE, _)                                                   \
  virtual void setFeature(const std::string &, const TYPE) override;          \
  virtual void setFeature(const std::string &, const std::vector<TYPE> &)     \
      override;
  SUPPORTED_TYPES(SET_FEATURE)
#undef SET_FEATURE

  void setFeature(const std::string &name,
                  const google::protobuf::Message *value) override;
  void
  setFeature(const std::string &name,
             const std::vector<google::protobuf::Message *> &value) override;

  void *getSerializedData() override;
  void cleanDataStructures() override;

  Message *getMessage() { return Response; };

private:
  void *deserializeUntyped(void *data) override;
  Message *Response;
  Message *Request;
};
} // namespace MLBridge

#endif
--------------------------------------------------------------------------------
/include/SerDes/tensorflowSerDes.h:
--------------------------------------------------------------------------------
//=== SerDes/tensorflowSerDes.h - SerDes for TF support ---*- C++ ---------===//
//
// Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
// Exceptions. See the LICENSE file for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// Serialization/Deserialization to support TF AOT models.
///
//===----------------------------------------------------------------------===//

#ifndef TENSORFLOW_SERIALIZER_H
#define TENSORFLOW_SERIALIZER_H

#include "SerDes/baseSerDes.h"
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"

namespace MLBridge {
/// TensorflowSerDes - Serialization/Deserialization to support TF AOT models.
class TensorflowSerDes : public BaseSerDes {
public:
  TensorflowSerDes() : BaseSerDes(SerDesKind::Tensorflow) {}

  static bool classof(const BaseSerDes *S) {
    return S->getKind() == SerDesKind::Tensorflow;
  }

#define SET_FEATURE(TYPE, _)                                                   \
  void setFeature(const std::string &, const TYPE) override;                   \
  void setFeature(const std::string &, const std::vector<TYPE> &) override;
  SUPPORTED_TYPES(SET_FEATURE)
#undef SET_FEATURE

  void setRequest(void *request) override {
    CompiledModel =
        reinterpret_cast<tensorflow::XlaCompiledCpuFunction *>(request);
  }

  void *getSerializedData() override { return nullptr; };
  void cleanDataStructures() override{};

private:
  void *deserializeUntyped(void *data) override { return nullptr; };
  tensorflow::XlaCompiledCpuFunction *CompiledModel;
};
} // namespace MLBridge

#endif
--------------------------------------------------------------------------------
/test/CMakeLists.txt:
--------------------------------------------------------------------------------
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-rtti")

add_library(MLBridgeCPPTest OBJECT MLBridgeTest.cpp)
file(GLOB MODEL_OBJECTS ${CMAKE_CURRENT_SOURCE_DIR}/tf_models/*.o)

foreach(MODEL_OBJECT ${MODEL_OBJECTS})
  target_link_libraries(MLBridgeCPPTest PRIVATE ${MODEL_OBJECT})
endforeach()
target_link_libraries(MLBridgeCPPTest PRIVATE MLCompilerBridge)
target_include_directories(MLBridgeCPPTest PRIVATE ${CMAKE_BINARY_DIR}/include ${TENSORFLOW_AOT_PATH}/include ${CMAKE_CURRENT_SOURCE_DIR}/include)
target_link_libraries(MLBridgeCPPTest PRIVATE tf_xla_runtime)
--------------------------------------------------------------------------------
/test/MLBridgeTest.cpp:
--------------------------------------------------------------------------------
//===----------------------------------------------------------------------===//
//
// Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
// Exceptions. See the LICENSE file for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "HelloMLBridge_Env.h"
#include "MLModelRunner/MLModelRunner.h"
#include "MLModelRunner/ONNXModelRunner/ONNXModelRunner.h"
#include "MLModelRunner/PipeModelRunner.h"
#include "MLModelRunner/TFModelRunner.h"
#include "MLModelRunner/Utils/DataTypes.h"
#include "MLModelRunner/Utils/MLConfig.h"
#include "MLModelRunner/gRPCModelRunner.h"
#include "ProtosInclude.h"
#include "llvm/Support/CommandLine.h"
#include <cmath>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

#define debug_out                                                              \
  if (!silent)                                                                 \
  std::cout
using namespace grpc;

static llvm::cl::opt<std::string>
    cl_server_address("test-server-address", llvm::cl::Hidden,
                      llvm::cl::desc("Server address, format <ip>:<port>"),
                      llvm::cl::init(""));

static llvm::cl::opt<std::string>
    cl_pipe_name("test-pipe-name", llvm::cl::Hidden, llvm::cl::init(""),
                 llvm::cl::desc("Name for pipe file"));

static llvm::cl::opt<std::string>
    cl_onnx_path("onnx-model-path", llvm::cl::Hidden, llvm::cl::init(""),
                 llvm::cl::desc("Path to onnx model"));

static llvm::cl::opt<std::string> cl_test_config(
    "test-config", llvm::cl::Hidden,
    llvm::cl::desc("Method for communication with python model"));

static llvm::cl::opt<bool>
    silent("silent", llvm::cl::Hidden, llvm::cl::init(false),
           llvm::cl::desc("Only print errors when set to true"));

namespace {
std::string basename;
SerDesKind SerDesType;

std::string test_config;
std::string pipe_name;
std::string server_address;
std::string onnx_path;
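
// All tests below follow the same round trip: the C++ side sends a feature
// named "request_<type>" with a known value, the Python peer
// (test/mlbridge-test.py) answers with the matching entry of its returned_data
// table, and the reply is checked against the expected value with a 10e-6
// tolerance. For the int test, for example, 11 is sent and 12 is expected
// back.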

// Send value of type T1. Test received value of type T2 against the expected
// value.
template <typename T1, typename T2, typename MLRunnerTy>
void testPrimitive(MLRunnerTy &MLRunner, std::string label, T1 value,
                   T2 expected) {
  std::pair<std::string, T1> p("request_" + label, value);
  MLRunner->populateFeatures(p);
  T2 out = MLRunner->template evaluate<T2>();
  debug_out << " " << label << " reply: " << out << "\n";
  if (std::abs(out - expected) > 10e-6) {
    std::cerr << "Error: Expected " << label << " reply: " << expected
              << ", Received: " << out << "\n";
    exit(1);
  }
}

template <typename T1, typename T2, typename MLRunnerTy>
void testVector(MLRunnerTy &MLRunner, std::string label, std::vector<T1> value,
                std::vector<T2> expected) {
  std::pair<std::string, std::vector<T1>> p("request_" + label, value);
  MLRunner->populateFeatures(p);
  T2 *out;
  size_t size;
  MLRunner->template evaluate<T2 *>(out, size);
  std::vector<T2> reply(out, out + size);
  debug_out << " " << label << " reply: ";
  int i = 0;
  for (auto x : reply) {
    debug_out << x << " ";
    if (std::abs(x - expected[i]) > 10e-6) {
      std::cerr << "Error: Expected " << label << " reply: " << expected[i]
                << ", Received: " << x << "\n";
      exit(1);
    }
    i++;
  }
  debug_out << "\n";
}

int testPipeBytes() {
  if (pipe_name == "") {
    std::cerr << "Pipe name must be specified via --test-pipe-name=<name>\n";
    exit(1);
  }
  basename = "./" + pipe_name;
  SerDesType = SerDesKind::Bitstream;
  auto MLRunner = std::make_unique<PipeModelRunner>(
      basename + ".out", basename + ".in", SerDesType, nullptr);
  testPrimitive(MLRunner, "int", 11, 12);
  testPrimitive(MLRunner, "long", 1234567890l, 1234567891l);
  testPrimitive(MLRunner, "float", 3.14f, 4.14f);
  testPrimitive(MLRunner, "double", 0.123456789123456789, 1.123456789123456789);
  testPrimitive(MLRunner, "char", 'a', 'b');
  testPrimitive(MLRunner, "bool", true, false);
  testVector(MLRunner, "vec_int", std::vector<int>{11, 22, 33},
             std::vector<int>{12, 23, 34});
  testVector(MLRunner, "vec_long", std::vector<long>{123456780, 222, 333},
             std::vector<long>{123456780, 123456781, 123456782});
  testVector(MLRunner, "vec_float", std::vector<float>{11.1, 22.2, 33.3},
             std::vector<float>{1.11, 2.22, -3.33, 0});
  testVector(MLRunner, "vec_double",
             std::vector<double>{-1.1111111111, -2.2222222222, -3.3333333333},
             std::vector<double>{1.12345678912345670, -1.12345678912345671});
  return 0;
}

int testPipeJSON() {
  if (pipe_name == "") {
    std::cerr << "Pipe name must be specified via --test-pipe-name=<name>\n";
    exit(1);
  }
  basename = "./" + pipe_name;
  SerDesType = SerDesKind::Json;
  auto MLRunner = std::make_unique<PipeModelRunner>(
      basename + ".out", basename + ".in", SerDesType, nullptr);
  testPrimitive(MLRunner, "int", 11, 12);
  testPrimitive(MLRunner, "long", 1234567890l, 12345l);
  testPrimitive(MLRunner, "float", 3.14f, 4.14f);

  // FIXME: doesn't work if expected value is double
  testPrimitive(MLRunner, "double", 0.123456789123456789,
                1.123456789123456789f);
  testPrimitive(MLRunner, "char", 'a', 'b');
  testPrimitive(MLRunner, "bool", true, false);
  testVector(MLRunner, "vec_int", std::vector<int>{11, 22, 33},
             std::vector<int>{12, 23, 34});

  // FIXME: doesn't work if expected value is long
  testVector(MLRunner, "vec_long", std::vector<long>{123456780, 222, 333},
             std::vector<int>{6780, 6781, 6782});
  testVector(MLRunner, "vec_float", std::vector<float>{11.1, 22.2, 33.3},
             std::vector<float>{1.11, 2.22, -3.33, 0});

  // FIXME: doesn't work if expected value is double
  testVector(MLRunner, "vec_double",
             std::vector<double>{-1.1111111111, -2.2222222222, -3.3333333333},
             std::vector<float>{1.12345678912345670, -1.12345678912345671});
  return 0;
}

void increment_port(int delta) {
  int split = server_address.find(":");
  int port = std::stoi(server_address.substr(split + 1));
  server_address =
      server_address.substr(0, split) + ":" + std::to_string(port + delta);
}

int testGRPC() {
#define gRPCModelRunnerInit(datatype)                                          \
  increment_port(1);                                                           \
  MLBridgeTestgRPC_##datatype::Reply response;                                 \
  MLBridgeTestgRPC_##datatype::Request request;                                \
  auto MLRunner = std::make_unique<                                            \
      gRPCModelRunner<MLBridgeTestgRPC_##datatype::MLBridgeTestService,        \
                      MLBridgeTestgRPC_##datatype::MLBridgeTestService::Stub,  \
                      MLBridgeTestgRPC_##datatype::Request,                    \
                      MLBridgeTestgRPC_##datatype::Reply>>(                    \
      server_address, &request, &response, nullptr);                           \
  MLRunner->setRequest(&request);                                              \
  MLRunner->setResponse(&response)

  if (server_address == "") {
    std::cerr << "Server Address must be specified via "
                 "--test-server-address=\"<ip>:<port>\"\n";
    exit(1);
  }
  {
    gRPCModelRunnerInit(int);
    testPrimitive(MLRunner, "int", 11, 12);
  }
  {
    gRPCModelRunnerInit(long);
    testPrimitive(MLRunner, "long", 1234567890l, 1234567891l);
  }
  {
    gRPCModelRunnerInit(float);
    testPrimitive(MLRunner, "float", 3.14f, 4.14f);
  }
  {
    gRPCModelRunnerInit(double);
    testPrimitive(MLRunner, "double", 0.123456789123456789,
                  1.123456789123456789);
  }
  // The char datatype is not exercised over gRPC; bump the port past the slot
  // reserved for it so the remaining tests line up with their servers.
  increment_port(1);
  {
    gRPCModelRunnerInit(bool);
    testPrimitive(MLRunner, "bool", true, false);
  }
  {
    gRPCModelRunnerInit(vec_int);
    testVector(MLRunner, "vec_int", std::vector<int>{11, 22, 33},
               std::vector<int>{12, 23, 34});
  }
  {
    gRPCModelRunnerInit(vec_long);
    testVector(MLRunner, "vec_long", std::vector<long>{123456780, 222, 333},
               std::vector<long>{123456780, 123456781, 123456782});
  }
  {
    gRPCModelRunnerInit(vec_float);
    testVector(MLRunner, "vec_float", std::vector<float>{11.1, 22.2, 33.3},
               std::vector<float>{1.11, 2.22, -3.33, 0});
  }
  {
    gRPCModelRunnerInit(vec_double);
    testVector(MLRunner, "vec_double",
               std::vector<double>{-1.1111111111, -2.2222222222, -3.3333333333},
               std::vector<double>{1.12345678912345670, -1.12345678912345671});
  }
#undef gRPCModelRunnerInit
  return 0;
}

class ONNXTest : public MLBridgeTestEnv {
public:
  int run(int expectedAction) {
    onnx_path = cl_onnx_path.getValue();
    if (onnx_path == "") {
      std::cerr << "ONNX model path must be specified via "
                   "--onnx-model-path=<path>\n";
      exit(1);
    }
    FeatureVector.clear();
    int n = 100;
    for (int i = 0; i < n; i++) {
      float delta = (float)(i - expectedAction) / n;
      FeatureVector.push_back(delta * delta);
    }

    Agent *agent = new Agent(onnx_path);
    std::map<std::string, Agent *> agents;
    agents["agent"] = agent;
    auto MLRunner = std::make_unique<ONNXModelRunner>(this, agents, nullptr);
    MLRunner->evaluate<int>();
    if (lastAction != expectedAction) {
      std::cerr << "Error: Expected action: " << expectedAction
                << ", Computed action: " << lastAction << "\n";
      exit(1);
    }
    return 0;
  }
};

} // namespace
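// Example invocations, mirroring test/mlbridge-test.sh (for the pipe and gRPC
// configs the Python peer from test/mlbridge-test.py must already be running):
//
//   MLCompilerBridgeTest --test-config=pipe-bytes --test-pipe-name=mlbridgepipe
//   MLCompilerBridgeTest --test-config=grpc --test-server-address="0.0.0.0:50155"
//   MLCompilerBridgeTest --test-config=onnx --onnx-model-path=test/onnx/dummy_model.onnx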

int main(int argc, char **argv) {
  llvm::cl::ParseCommandLineOptions(argc, argv);
  test_config = cl_test_config.getValue();

  if (test_config == "pipe-bytes") {
    pipe_name = cl_pipe_name.getValue();
    testPipeBytes();
  } else if (test_config == "pipe-json") {
    pipe_name = cl_pipe_name.getValue();
    testPipeJSON();
  } else if (test_config == "grpc") {
    server_address = cl_server_address.getValue();
    testGRPC();
  } else if (test_config == "onnx") {
    ONNXTest t;
    t.run(20);
  } else
    std::cerr << "--test-config must be provided from [pipe-bytes, pipe-json, "
                 "grpc, onnx]\n";
  return 0;
}
--------------------------------------------------------------------------------
/test/hello-mlbridge.py:
--------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
#
# Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
# Exceptions. See the LICENSE file for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
# ------------------------------------------------------------------------------

import argparse
import sys
import torch, torch.nn as nn

sys.path.append("../CompilerInterface")
from PipeCompilerInterface import PipeCompilerInterface
from GrpcCompilerInterface import GrpcCompilerInterface

sys.path.append("../MLModelRunner/gRPCModelRunner/Python-Utilities")
import helloMLBridge_pb2, helloMLBridge_pb2_grpc, grpc
from concurrent import futures

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use_pipe", type=bool, default=False, help="Use pipe or not", required=False
)
parser.add_argument(
    "--data_format",
    type=str,
    choices=["json", "protobuf", "bytes"],
    help="Data format to use for communication",
)
parser.add_argument(
    "--pipe_name",
    type=str,
    help="Pipe Name",
)
parser.add_argument(
    "--use_grpc",
    action="store_true",
    help="Use grpc communication",
    required=False,
    default=False,
)
parser.add_argument(
    "--server_port",
    type=int,
    help="Server Port",
    default=5050,
)
args = parser.parse_args()


class DummyModel(nn.Module):
    def __init__(self, input_dim=10):
        nn.Module.__init__(self)
        self.fc1 = nn.Linear(input_dim, 1)

    def forward(self, input):
        x = self.fc1(input)
        return x


def run_pipe_communication(data_format, pipe_name):
    compiler_interface = PipeCompilerInterface(data_format, "/tmp/" + pipe_name)
    print("PipeCompilerInterface init...")
    compiler_interface.reset_pipes()

    while True:
        try:
            data = compiler_interface.evaluate()
            if data_format == "json":
                data = data["tensor"]
                # print("Data: ", data["tensor"])
            elif data_format == "bytes":
                data = [x for x in data[0]]
                # print("Data: ", [x for x in data[0]])
            print("data: ", data)
            model = DummyModel(input_dim=len(data))
            action = model(torch.Tensor(data))
            compiler_interface.populate_buffer(3)
        except Exception as e:
            print("*******Exception*******", e)
            compiler_interface.reset_pipes()
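
# Example launches (flag spellings follow the argparse definitions above; the
# pipe is created under /tmp/<pipe_name> by run_pipe_communication):
#
#   python hello-mlbridge.py --use_pipe=True --data_format=json --pipe_name=hellopipe
#   python hello-mlbridge.py --use_grpc --server_port=5050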

class service_server(helloMLBridge_pb2_grpc.HelloMLBridgeService):
    def __init__(self):
        # self.serdes = SerDes.SerDes(data_format, pipe_name)
        # self.serdes.init()
        pass

    def getAdvice(self, request, context):
        try:
            print("Entered getAdvice")
            print("Data: ", request.tensor)
            reply = helloMLBridge_pb2.ActionRequest(action=1)
            return reply
        except:
            reply = helloMLBridge_pb2.ActionRequest(action=-1)
            return reply


def test_func():
    data = 3.24
    import struct

    print(data, type(data))
    byte_data = struct.pack("f", data)
    print(byte_data, len(byte_data))

    print("decoding...")
    decoded = struct.unpack("f", byte_data)[0]

    print(decoded, type(decoded))


if __name__ == "__main__":
    if args.use_pipe:
        run_pipe_communication(args.data_format, args.pipe_name)
    elif args.use_grpc:
        compiler_interface = GrpcCompilerInterface(
            mode="server",
            add_server_method=helloMLBridge_pb2_grpc.add_HelloMLBridgeServiceServicer_to_server,
            grpc_service_obj=service_server(),
            hostport=args.server_port,
        )
        compiler_interface.start_server()
--------------------------------------------------------------------------------
/test/include/HelloMLBridge_Env.h:
--------------------------------------------------------------------------------
//===----------------------------------------------------------------------===//
//
// Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
// Exceptions. See the LICENSE file for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "MLModelRunner/ONNXModelRunner/environment.h"
#include "MLModelRunner/ONNXModelRunner/utils.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

using namespace MLBridge;
class MLBridgeTestEnv : public Environment {
  Observation CurrObs;

public:
  MLBridgeTestEnv() { setNextAgent("agent"); };
  Observation &reset() override;
  Observation &step(Action) override;
  Action lastAction;

protected:
  std::vector<float> FeatureVector;
};

Observation &MLBridgeTestEnv::step(Action Action) {
  CurrObs.clear();
  std::copy(FeatureVector.begin(), FeatureVector.end(),
            std::back_inserter(CurrObs));
  lastAction = Action;
  setDone();
  return CurrObs;
}

Observation &MLBridgeTestEnv::reset() {
  std::copy(FeatureVector.begin(), FeatureVector.end(),
            std::back_inserter(CurrObs));
  return CurrObs;
}
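// How this is driven (an assumption based on the usage in MLBridgeTest.cpp):
// ONNXModelRunner::evaluate() calls reset() for the first Observation, feeds
// it to the "agent" ONNX model, and hands the chosen Action to step(); step()
// records it in lastAction and calls setDone(), so each test episode is
// exactly one step long.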
--------------------------------------------------------------------------------
/test/include/ProtosInclude.h:
--------------------------------------------------------------------------------
//===----------------------------------------------------------------------===//
//
// Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
// Exceptions. See the LICENSE file for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "grpc/MLBridgeTest_bool/MLBridgeTest_bool.grpc.pb.h"
#include "grpc/MLBridgeTest_bool/MLBridgeTest_bool.pb.h"
#include "grpc/MLBridgeTest_char/MLBridgeTest_char.grpc.pb.h"
#include "grpc/MLBridgeTest_char/MLBridgeTest_char.pb.h"
#include "grpc/MLBridgeTest_double/MLBridgeTest_double.grpc.pb.h"
#include "grpc/MLBridgeTest_double/MLBridgeTest_double.pb.h"
#include "grpc/MLBridgeTest_float/MLBridgeTest_float.grpc.pb.h"
#include "grpc/MLBridgeTest_float/MLBridgeTest_float.pb.h"
#include "grpc/MLBridgeTest_int/MLBridgeTest_int.grpc.pb.h"
#include "grpc/MLBridgeTest_int/MLBridgeTest_int.pb.h"
#include "grpc/MLBridgeTest_long/MLBridgeTest_long.grpc.pb.h"
#include "grpc/MLBridgeTest_long/MLBridgeTest_long.pb.h"
#include "grpc/MLBridgeTest_vec_double/MLBridgeTest_vec_double.grpc.pb.h"
#include "grpc/MLBridgeTest_vec_double/MLBridgeTest_vec_double.pb.h"
#include "grpc/MLBridgeTest_vec_float/MLBridgeTest_vec_float.grpc.pb.h"
#include "grpc/MLBridgeTest_vec_float/MLBridgeTest_vec_float.pb.h"
#include "grpc/MLBridgeTest_vec_int/MLBridgeTest_vec_int.grpc.pb.h"
#include "grpc/MLBridgeTest_vec_int/MLBridgeTest_vec_int.pb.h"
#include "grpc/MLBridgeTest_vec_long/MLBridgeTest_vec_long.grpc.pb.h"
#include "grpc/MLBridgeTest_vec_long/MLBridgeTest_vec_long.pb.h"
--------------------------------------------------------------------------------
/test/mlbridge-test.py:
--------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
#
# Part of the MLCompilerBridge Project, under the Apache License v2.0 with LLVM
# Exceptions. See the LICENSE file for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
# ------------------------------------------------------------------------------
import argparse
import numpy as np
import ctypes

import sys
import os
import torch, torch.nn as nn
import torch.onnx
import subprocess
import time

BUILD_DIR = "../build_release"
sys.path.extend(
    [
        f"{BUILD_DIR}/MLModelRunner/gRPCModelRunner/Python-Utilities",
    ]
)
from compilerinterface import PipeCompilerInterface, GrpcCompilerInterface

FAIL = 1
SUCCESS = 0

parser = argparse.ArgumentParser()
parser.add_argument("--use_pipe", default=False, help="Use pipe or not", required=False)
parser.add_argument(
    "--data_format",
    type=str,
    choices=["json", "protobuf", "bytes"],
    help="Data format to use for communication",
)
parser.add_argument(
    "--pipe_name",
    type=str,
    help="Pipe Name",
)
parser.add_argument(
    "--silent", type=bool, help="Only prints errors when set to true", default=False
)
parser.add_argument(
    "--use_grpc",
    action="store_true",
    help="Use grpc communication",
    required=False,
    default=False,
)
parser.add_argument(
    "--server_port",
    type=int,
    help="Server Port",
)
parser.add_argument(
    "--test_number",
    type=int,
    help="Datatype number for test",
    default=0,
)
parser.add_argument(
    "--export_onnx",
    help="Export onnx test model",
    required=False,
    default=False,
)
args = parser.parse_args()

if args.test_number <= 1:
    import MLBridgeTest_int_pb2, MLBridgeTest_int_pb2_grpc

    MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = (
        MLBridgeTest_int_pb2,
        MLBridgeTest_int_pb2_grpc,
    )
elif args.test_number == 2:
    import MLBridgeTest_long_pb2, MLBridgeTest_long_pb2_grpc

    MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = (
        MLBridgeTest_long_pb2,
        MLBridgeTest_long_pb2_grpc,
    )
elif args.test_number == 3:
    import MLBridgeTest_float_pb2, MLBridgeTest_float_pb2_grpc

    MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = (
        MLBridgeTest_float_pb2,
        MLBridgeTest_float_pb2_grpc,
    )
elif args.test_number == 4:
    import MLBridgeTest_double_pb2, MLBridgeTest_double_pb2_grpc

    MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = (
        MLBridgeTest_double_pb2,
        MLBridgeTest_double_pb2_grpc,
    )
elif args.test_number == 5:
    import MLBridgeTest_char_pb2, MLBridgeTest_char_pb2_grpc

    MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = (
        MLBridgeTest_char_pb2,
        MLBridgeTest_char_pb2_grpc,
    )
elif args.test_number == 6:
    import MLBridgeTest_bool_pb2, MLBridgeTest_bool_pb2_grpc

    MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = (
        MLBridgeTest_bool_pb2,
        MLBridgeTest_bool_pb2_grpc,
    )
elif args.test_number == 7:
    import MLBridgeTest_vec_int_pb2, MLBridgeTest_vec_int_pb2_grpc

    MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = (
        MLBridgeTest_vec_int_pb2,
        MLBridgeTest_vec_int_pb2_grpc,
    )
elif args.test_number == 8:
    import MLBridgeTest_vec_long_pb2, MLBridgeTest_vec_long_pb2_grpc

    MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = (
        MLBridgeTest_vec_long_pb2,
        MLBridgeTest_vec_long_pb2_grpc,
    )
elif args.test_number == 9:
    import MLBridgeTest_vec_float_pb2, MLBridgeTest_vec_float_pb2_grpc

    MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = (
        MLBridgeTest_vec_float_pb2,
        MLBridgeTest_vec_float_pb2_grpc,
    )
elif args.test_number == 10:
    import MLBridgeTest_vec_double_pb2, MLBridgeTest_vec_double_pb2_grpc

    MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = (
        MLBridgeTest_vec_double_pb2,
        MLBridgeTest_vec_double_pb2_grpc,
    )


class DummyModel(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)

    def forward(self, input):
        return 2 - input


def export_onnx_model(input_dim=100):
    onnx_filename = "./onnx/dummy_model.onnx"
    dummy_value = torch.randn(1, input_dim)
    torch.onnx.export(
        DummyModel(),
        dummy_value,
        onnx_filename,
        input_names=["obs"],
        verbose=True,
        export_params=True,
    )
    print(f"Model exported to {onnx_filename}")


expected_type = {
    1: "int",
    2: "long",
    3: "float",
    4: "double",
    5: "char",
    6: "bool",
    7: "vec_int",
    8: "vec_long",
    9: "vec_float",
    10: "vec_double",
}

expected_data = {
    1: 11,
    2: 1234567890,
    3: 3.14,
    4: 0.123456789123456789,
    5: ord("a"),
    6: True,
    7: [11, 22, 33],
    8: [123456780, 222, 333],
    9: [11.1, 22.2, 33.3],
    10: [-1.1111111111, -2.2222222222, -3.3333333333],
}

returned_data = {
    1: 12,
    2: 1234567891,
    3: 4.14,
    4: 1.123456789123456789,
    5: ord("b"),
    6: False,
    7: [12, 23, 34],
    8: [123456780, 123456781, 123456782],
    9: [1.11, 2.22, -3.33, 0],
    10: [1.12345678912345670, -1.12345678912345671],
}

if args.use_pipe and args.data_format == "bytes":
    returned_data.update(
        {
            2: ctypes.c_long(1234567891),
            4: ctypes.c_double(1.123456789123456789),
            8: [
                ctypes.c_long(123456780),
                ctypes.c_long(123456781),
                ctypes.c_long(123456782),
            ],
            10: [
                ctypes.c_double(1.12345678912345670),
                ctypes.c_double(-1.12345678912345671),
            ],
        }
    )
if args.use_pipe and args.data_format == "json":
    returned_data.update(
        {
            2: ctypes.c_long(12345),
            4: ctypes.c_double(1.123456789123456789),
            8: [ctypes.c_long(6780), ctypes.c_long(6781), ctypes.c_long(6782)],
            10: [
                ctypes.c_double(1.12345678912345670),
                ctypes.c_double(-1.12345678912345671),
            ],
        }
    )

status = SUCCESS


# test index vs received data
def checkData(index, data):
    global status
    if not args.silent:
        print(" ", expected_type[index], "request:", data)

    if isinstance(expected_data[index], list):
        for e, d in zip(expected_data[index], data):
            if abs(e - d) > 10e-6:
                print(
                    f"Error: Expected {expected_type[index]} request: {expected_data[index]}, Received: {data}"
                )
                status = FAIL
                # raise Exception(f"Mismatch in {expected_type[i]}")

    elif abs(data - expected_data[index]) > 10e-6:
        print(
            f"Error: Expected {expected_type[index]} request: {expected_data[index]}, Received: {data}"
        )
        status = FAIL
        # raise Exception(f"Mismatch in {expected_type[i]}")
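
# For example, with test_number 3 (float): expected_data[3] == 3.14, so
# checkData(3, 3.1400001) passes (|3.1400001 - 3.14| < 10e-6), while
# checkData(3, 3.15) prints an error and latches status to FAIL.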

def run_pipe_communication(data_format, pipe_name):
    compiler_interface = PipeCompilerInterface(data_format, "./" + pipe_name)
    if not args.silent:
        print("PipeCompilerInterface init...")
    compiler_interface.reset_pipes()

    i = 0
    while True:
        i += 1
        try:
            data = compiler_interface.evaluate()
            if data_format == "json":
                key = list(data)[0]
                data = data[key]
            elif data_format == "bytes":
                data = [x for x in data[0]]
                if len(data) == 1:
                    data = data[0]

            checkData(i, data)

            compiler_interface.populate_buffer(returned_data[i])

            if i == len(expected_type):
                data = compiler_interface.evaluate(mode="exit")
                exit(status)
        except Exception as e:
            print("*******Exception*******", e)
            compiler_interface.reset_pipes()


class service_server(MLBridgeTest_pb2_grpc.MLBridgeTestService):
    def __init__(self):
        pass

    def getAdvice(self, request, context):
        try:
            request_type = [var for var in dir(request) if "request" in var]
            data = getattr(request, request_type[0])
            checkData(args.test_number, data)
            if status == FAIL:
                os.system("touch mlbridge-grpc-fail.txt")
            reply = MLBridgeTest_pb2.Reply(action=returned_data[args.test_number])
            return reply
        except:
            reply = MLBridgeTest_pb2.Reply(action=-1)
            return reply


def run_grpc_communication():
    # parent with test_number 0 spawns different servers
    if args.test_number == 0:
        process_list = []
        for i in range(1, len(expected_type) + 1):
            p = subprocess.Popen(
                f"python mlbridge-test.py --use_grpc --server_port={args.server_port} --silent={args.silent} --test_number={i}".split(),
            )
            process_list.append(p)

        time.sleep(10)
        global status
        for p in process_list:
            if os.path.isfile("mlbridge-grpc-fail.txt"):
                status = FAIL
                os.system("rm mlbridge-grpc-fail.txt")
            p.terminate()
        exit(status)

    # servers serve different datatypes
    else:
        compiler_interface = GrpcCompilerInterface(
            mode="server",
            add_server_method=MLBridgeTest_pb2_grpc.add_MLBridgeTestServiceServicer_to_server,
            grpc_service_obj=service_server(),
            hostport=args.server_port + args.test_number,
        )
        compiler_interface.start_server()


if __name__ == "__main__":
    if args.use_pipe:
        run_pipe_communication(args.data_format, args.pipe_name)
    elif args.use_grpc:
        run_grpc_communication()
    elif args.export_onnx:
        export_onnx_model()
--------------------------------------------------------------------------------
/test/mlbridge-test.sh:
--------------------------------------------------------------------------------
#!/usr/bin/bash
BLUE='\033[0;34m'
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
BOLD='\033[1m'

REPO_DIR=$GITHUB_WORKSPACE
BUILD_DIR=$REPO_DIR/build_release
SERVER_FILE=$REPO_DIR/test/mlbridge-test.py

export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

STATUS=0
SERVER_PID=0

run_test() {
  sleep 5
  "$@"
  if [ $? != 0 ]; then
    STATUS=1
    echo -e "$(tput bold)${RED}[Test Failed] Error detected by compiler side.${NC}"
    kill $SERVER_PID
    SERVER_PID=0
  fi

  if [ $SERVER_PID != 0 ]; then
    wait $SERVER_PID
    if [ $? != 0 ]; then
      STATUS=1
      echo -e "$(tput bold)${RED}[Test Failed] Error detected by model side.${NC}"
    fi
  fi

  if [ $STATUS == 0 ]; then
    echo -e "${GREEN}${BOLD}[Test Passed] Datatypes transmitted successfully.${NC}"
  fi

  SERVER_PID=0
}
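
# Usage: run_test <compiler-side command...>. The Python server is started in
# the background first (its PID is captured in SERVER_PID); run_test waits five
# seconds for it to come up, runs the compiler-side binary, and marks the suite
# failed if either side exits non-zero.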

# source deactivate
# source activate ml_loopdist_env

echo -e "${BLUE}${BOLD}Testing MLBridge [pipe-bytes]${NC}"
python $SERVER_FILE --use_pipe=True --data_format=bytes --pipe_name=mlbridgepipe --silent=True &
SERVER_PID=$!
run_test $BUILD_DIR/bin/MLCompilerBridgeTest --test-config=pipe-bytes --test-pipe-name=mlbridgepipe --silent

echo -e "${BLUE}${BOLD}Testing MLBridge [pipe-json]${NC}"
python $SERVER_FILE --use_pipe=True --data_format=json --pipe_name=mlbridgepipe2 --silent=True &
SERVER_PID=$!
run_test $BUILD_DIR/bin/MLCompilerBridgeTest --test-config=pipe-json --test-pipe-name=mlbridgepipe2 --silent

echo -e "${BLUE}${BOLD}Testing MLBridge [grpc]${NC}"
python $SERVER_FILE --use_grpc --server_port=50155 --silent=True &
SERVER_PID=$!
run_test $BUILD_DIR/bin/MLCompilerBridgeTest --test-config=grpc --test-server-address="0.0.0.0:50155" --silent

echo -e "${BLUE}${BOLD}Testing MLBridge [onnx]${NC}"
run_test $BUILD_DIR/bin/MLCompilerBridgeTest --test-config=onnx --onnx-model-path=$REPO_DIR/test/onnx/dummy_model.onnx

exit $STATUS
--------------------------------------------------------------------------------
/test/onnx/dummy_model.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IITH-Compilers/ML-Compiler-Bridge/53bf989f6e34125a2be6dd6b832cb852fb6922ce/test/onnx/dummy_model.onnx
--------------------------------------------------------------------------------
/test/protos/MLBridgeTest_bool.proto:
--------------------------------------------------------------------------------
syntax = "proto3";

package MLBridgeTestgRPC_bool;

service MLBridgeTestService {
  rpc getAdvice(Request) returns (Reply) {}
}

message Request { bool request_bool = 1; }
message Reply { bool action = 1; }
--------------------------------------------------------------------------------
/test/protos/MLBridgeTest_char.proto:
--------------------------------------------------------------------------------
syntax = "proto3";

package MLBridgeTestgRPC_char;

service MLBridgeTestService {
  rpc getAdvice(Request) returns (Reply) {}
}

message Request { string request_char = 1; }
message Reply { string action = 1; }
--------------------------------------------------------------------------------
/test/protos/MLBridgeTest_double.proto:
--------------------------------------------------------------------------------
syntax = "proto3";

package MLBridgeTestgRPC_double;

service MLBridgeTestService {
  rpc getAdvice(Request) returns (Reply) {}
}

message Request { double request_double = 1; }
message Reply { double action = 1; }
--------------------------------------------------------------------------------
/test/protos/MLBridgeTest_float.proto:
--------------------------------------------------------------------------------
syntax = "proto3";

package MLBridgeTestgRPC_float;

service MLBridgeTestService {
  rpc getAdvice(Request) returns (Reply) {}
}

message Request { float request_float = 1; }
message Reply { float action = 1; }
--------------------------------------------------------------------------------
/test/protos/MLBridgeTest_int.proto:
--------------------------------------------------------------------------------
syntax = "proto3";

package MLBridgeTestgRPC_int;

service MLBridgeTestService {
  rpc getAdvice(Request) returns (Reply) {}
}

message Request { int32 request_int = 1; }
message Reply { int32 action = 1; }
--------------------------------------------------------------------------------
/test/protos/MLBridgeTest_long.proto:
--------------------------------------------------------------------------------
syntax = "proto3";

package MLBridgeTestgRPC_long;

service MLBridgeTestService {
  rpc getAdvice(Request) returns (Reply) {}
}

message Request { int64 request_long = 1; }
message Reply { int64 action = 1; }
--------------------------------------------------------------------------------
/test/protos/MLBridgeTest_vec_double.proto:
--------------------------------------------------------------------------------
syntax = "proto3";

package MLBridgeTestgRPC_vec_double;

service MLBridgeTestService {
  rpc getAdvice(Request) returns (Reply) {}
}

message Request { repeated double request_vec_double = 1; }
message Reply { repeated double action = 1; }
--------------------------------------------------------------------------------
/test/protos/MLBridgeTest_vec_float.proto:
--------------------------------------------------------------------------------
syntax = "proto3";

package MLBridgeTestgRPC_vec_float;

service MLBridgeTestService {
  rpc getAdvice(Request) returns (Reply) {}
}

message Request { repeated float request_vec_float = 1; }
message Reply { repeated float action = 1; }
--------------------------------------------------------------------------------
/test/protos/MLBridgeTest_vec_int.proto:
--------------------------------------------------------------------------------
syntax = "proto3";

package MLBridgeTestgRPC_vec_int;

service MLBridgeTestService {
  rpc getAdvice(Request) returns (Reply) {}
}

message Request { repeated int32 request_vec_int = 1; }
message Reply { repeated int32 action = 1; }
--------------------------------------------------------------------------------
/test/protos/MLBridgeTest_vec_long.proto:
--------------------------------------------------------------------------------
syntax = "proto3";

package MLBridgeTestgRPC_vec_long;

service MLBridgeTestService {
  rpc getAdvice(Request) returns (Reply) {}
}

message Request { repeated int64 request_vec_long = 1; }
message Reply { repeated int64 action = 1; }
--------------------------------------------------------------------------------
/test/protos/helloMLBridge.proto:
--------------------------------------------------------------------------------
syntax = "proto3";

package helloMLBridgegRPC_a;

service HelloMLBridgeService {
  rpc getAdvice(Request) returns (Reply) {}
}

message Request { int32 data = 1; }

message Reply { int32 action = 1; }
--------------------------------------------------------------------------------
/tools.cpp:
--------------------------------------------------------------------------------
void dummy() {}
--------------------------------------------------------------------------------