├── .clang-format ├── .devcontainer └── devcontainer.json ├── .gitignore ├── .vscode └── launch.json ├── CMakeLists.txt ├── LICENSE ├── README.md ├── build.sh ├── dockerfiles └── Dockerfile ├── inputs ├── image_0.ubyte ├── image_1.ubyte ├── image_10.ubyte ├── image_11.ubyte ├── image_12.ubyte ├── image_13.ubyte ├── image_14.ubyte ├── image_15.ubyte ├── image_16.ubyte ├── image_17.ubyte ├── image_18.ubyte ├── image_19.ubyte ├── image_2.ubyte ├── image_20.ubyte ├── image_21.ubyte ├── image_22.ubyte ├── image_23.ubyte ├── image_24.ubyte ├── image_25.ubyte ├── image_26.ubyte ├── image_27.ubyte ├── image_28.ubyte ├── image_29.ubyte ├── image_3.ubyte ├── image_30.ubyte ├── image_31.ubyte ├── image_32.ubyte ├── image_33.ubyte ├── image_34.ubyte ├── image_35.ubyte ├── image_36.ubyte ├── image_37.ubyte ├── image_38.ubyte ├── image_39.ubyte ├── image_4.ubyte ├── image_40.ubyte ├── image_41.ubyte ├── image_42.ubyte ├── image_43.ubyte ├── image_44.ubyte ├── image_45.ubyte ├── image_46.ubyte ├── image_47.ubyte ├── image_48.ubyte ├── image_49.ubyte ├── image_5.ubyte ├── image_50.ubyte ├── image_51.ubyte ├── image_52.ubyte ├── image_53.ubyte ├── image_54.ubyte ├── image_55.ubyte ├── image_56.ubyte ├── image_57.ubyte ├── image_58.ubyte ├── image_59.ubyte ├── image_6.ubyte ├── image_60.ubyte ├── image_61.ubyte ├── image_62.ubyte ├── image_63.ubyte ├── image_64.ubyte ├── image_65.ubyte ├── image_66.ubyte ├── image_67.ubyte ├── image_68.ubyte ├── image_69.ubyte ├── image_7.ubyte ├── image_70.ubyte ├── image_71.ubyte ├── image_72.ubyte ├── image_73.ubyte ├── image_74.ubyte ├── image_75.ubyte ├── image_76.ubyte ├── image_77.ubyte ├── image_78.ubyte ├── image_79.ubyte ├── image_8.ubyte ├── image_80.ubyte ├── image_81.ubyte ├── image_82.ubyte ├── image_83.ubyte ├── image_84.ubyte ├── image_85.ubyte ├── image_86.ubyte ├── image_87.ubyte ├── image_88.ubyte ├── image_89.ubyte ├── image_9.ubyte ├── image_90.ubyte ├── image_91.ubyte ├── image_92.ubyte ├── image_93.ubyte ├── image_94.ubyte ├── image_95.ubyte ├── image_96.ubyte ├── image_97.ubyte ├── image_98.ubyte └── image_99.ubyte ├── model_repository └── mnist.yaml ├── models ├── mnist_conv_ffn.onnx ├── mnist_ffn.onnx └── mnist_ffn_complex.onnx ├── server ├── go.mod ├── inference_wrapper.cpp ├── inference_wrapper.h ├── libengine_lib.so ├── main.go └── onnx-ml.pb.h ├── src ├── CMakeLists.txt ├── attribute.cpp ├── attribute.h ├── cpu_allocator.h ├── cpu_provider.cpp ├── cpu_provider.h ├── cuda_allocator.h ├── cuda_memory_pool.h ├── cuda_provider.cpp ├── cuda_provider.h ├── cuda_provider_unoptimized.cpp ├── cuda_provider_unoptimized.h ├── device.h ├── execution_provider.h ├── gemm_cpu.cpp ├── gemm_cpu.h ├── graph.cpp ├── graph.h ├── inference_session.cpp ├── inference_session.h ├── input_loader.cpp ├── input_loader.h ├── kernels.cu ├── kernels.h ├── main.cpp ├── memory_allocator.h ├── model_config.cpp ├── model_config.h ├── node.cpp ├── node.h ├── onnx-ml.proto ├── onnx_helper.cpp ├── onnx_helper.h ├── operators.cpp ├── operators.h ├── optypes.h ├── pinned_cpu_allocator.h ├── tensor.cpp ├── tensor.h └── test │ ├── CMakeLists.txt │ ├── gemm_bench.cpp │ ├── gemm_cuda.cpp │ ├── gemm_cuda_test.cpp │ ├── gemm_test.cpp │ ├── node_test.cpp │ ├── operators_test.cpp │ └── tensor_test.cpp └── utils ├── infer.py ├── infer_server.py ├── matrix.py └── mnist.py /.clang-format: -------------------------------------------------------------------------------- 1 | # Use Google style as the base 2 | BasedOnStyle: Google 3 | 4 | # Override specific rules 5 | IndentWidth: 4 # Use 4 spaces 
for indentation 6 | PointerAlignment: Left # Left-align pointer and reference symbols 7 | 8 | # Additional rules can be added if necessary 9 | Language: Cpp # Set the language as C++ (this should also apply to CUDA) 10 | TabWidth: 4 # Ensure tabs are 4 spaces wide 11 | UseTab: Never # Use spaces instead of tabs 12 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Onnx Runtime - Eager Mode", 3 | "build": { 4 | "dockerfile": "../dockerfiles/Dockerfile", 5 | "context": ".." 6 | }, 7 | "customizations": { 8 | // Configure properties specific to VS Code. 9 | "vscode": { 10 | // Set *default* container specific settings.json values on container create. 11 | "settings": { 12 | "python.languageServer": "Default", 13 | "cmake.ignoreCMakeListsMissing": true, 14 | "python.defaultInterpreterPath": "/usr/bin/python3", 15 | "python.linting.enabled": true, 16 | "python.linting.pylintEnabled": true 17 | }, 18 | // Add the IDs of extensions you want installed when the container is created. 19 | "extensions": [ 20 | "ms-python.vscode-pylance", 21 | "ms-python.python", 22 | "ms-vscode.cpptools-extension-pack", 23 | "ms-vscode.cmake-tools" 24 | ] 25 | } 26 | }, 27 | // Use 'forwardPorts' to make a list of ports inside the container available locally. 28 | // "forwardPorts": [], 29 | // Uncomment the next line to run commands after the container is created - for example installing curl. 30 | // "postCreateCommand": "apt-get update && apt-get install -y curl", 31 | // Uncomment when using a ptrace-based debugger like C++, Go, and Rust 32 | "runArgs": [ 33 | "--cap-add=SYS_PTRACE", 34 | "--security-opt", 35 | "seccomp=unconfined", 36 | "--gpus=all" 37 | ], 38 | // Comment out to connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root. 39 | // "remoteUser": "vscode", 40 | "features": { 41 | "git": "os-provided", 42 | "python": "os-provided", 43 | }, 44 | "postCreateCommand": "nvidia-smi" 45 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | .cache 3 | models/bert.onnx 4 | models/mnist-*.onnx 5 | models/resnet10.onnx 6 | utils/venv 7 | perf.data* 8 | server/server 9 | profiles/* -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "CUDA C++: Launch", 9 | "type": "cuda-gdb", 10 | "request": "launch", 11 | "program": "${workspaceFolder}/build/src/test/gemm_cuda_test" 12 | }, 13 | { 14 | "name": "CUDA C++: Attach", 15 | "type": "cuda-gdb", 16 | "request": "attach" 17 | } 18 | ] 19 | } -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.22) 2 | project(engine LANGUAGES CXX CUDA) 3 | 4 | set(CMAKE_C_COMPILER clang) 5 | set(CMAKE_CXX_COMPILER clang++) 6 | 7 | set(CMAKE_CXX_STANDARD 20) 8 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 9 | set(CMAKE_CXX_EXTENSIONS OFF) 10 | set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF) 11 | 12 | set(CMAKE_CUDA_STANDARD 17) 13 | set(CMAKE_CUDA_STANDARD_REQUIRED ON) 14 | 15 | set(CMAKE_CXX_FLAGS_DEBUG "-g -O0") 16 | set(CMAKE_CXX_FLAGS_RELEASE "-O3") 17 | set(CMAKE_CUDA_FLAGS_DEBUG "-g -G -O0") 18 | set(CMAKE_CUDA_FLAGS_RELEASE "-O3 --use_fast_math") 19 | 20 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wpedantic -Wno-deprecated-declarations -Wno-unused-function") 21 | set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall,-Wextra,-Wno-deprecated-declarations,-Wno-unused-function") 22 | 23 | find_package(Protobuf REQUIRED) 24 | find_package(benchmark REQUIRED) 25 | find_package(GTest REQUIRED) 26 | find_package(yaml-cpp REQUIRED) 27 | find_package(CUDAToolkit REQUIRED) 28 | 29 | # Assuming that CUDA is required. 30 | set(CMAKE_CUDA_ARCHITECTURES native) 31 | add_definitions(-DUSE_CUDA) 32 | 33 | include_directories(${PROTOBUF_INCLUDE_DIRS}) 34 | include_directories(${YAML_CPP_INCLUDE_DIR}) 35 | 36 | add_subdirectory(src) 37 | add_subdirectory(src/test) 38 | 39 | message(STATUS "") 40 | message(STATUS "Configuration summary:") 41 | message(STATUS " CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") 42 | message(STATUS " Build type: ${CMAKE_BUILD_TYPE}") 43 | message(STATUS " Tensor Cores: ${USE_TENSOR_CORES}") 44 | if(IPO_SUPPORTED) 45 | message(STATUS " LTO support: YES") 46 | else() 47 | message(STATUS " LTO support: NO (${IPO_ERROR})") 48 | endif() 49 | message(STATUS "") 50 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Michal Pitr 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # C++ Inference Engine from scratch 2 | 3 | I am developing this project to learn C++ and get hands-on experience with inference engines. 4 | 5 | ## How to build 6 | 7 | 1. Clone the project: `git clone git@github.com:MichalPitr/inference_engine.git` 8 | 2. `cd inference_engine` 9 | 3. `sh build.sh` 10 | 11 | CMake will complain if you are missing any of the system dependencies: protobuf, gtest, google benchmark, and yaml-cpp. 12 | 13 | ## How to run a simple example 14 | 15 | This starts up an HTTP server and uses Python to send requests. You can also do the equivalent with curl from the command line; see the example request below for the exact payload format. 16 | 17 | 1. Build as explained above. 18 | 2. `cd server` 19 | 3. `go run main.go` 20 | 4. Open another terminal. 21 | 5. `cd utils` 22 | 6. `source venv/bin/activate` 23 | 7. `python infer_server.py` 24 | 25 | ## Backlog: 26 | 27 | * Optimize CUDA kernels. The Gemm kernel is very naive at the moment. 28 | * Add dynamic batching to the Go server. 29 | * Add graph optimizations. 30 | * Add input validation to the Go server. 31 | * Optimize memory allocator usage - available memory should be checked during loading; total memory usage can be estimated fairly accurately. 32 | * Improve error handling. 33 | * Explore NVTX profiling. 34 | 35 | ## Contributing 36 | 37 | This project wasn't designed with external contributions in mind, but if you fancy contributing, improvements are welcome!
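## Example request

The Go server (`server/main.go`) exposes a single `POST /infer` endpoint on port 8080 that takes a JSON body with a `data` array and returns a `result` array. Below is a minimal, hypothetical Python client; it is not part of the repository and not identical to `utils/infer_server.py`. It assumes the default MNIST config from `model_repository/mnist.yaml` (a flattened 1x1x28x28 input, i.e. 784 floats) and requires the third-party `requests` package.

```python
# Hypothetical client sketch: sends one inference request to the Go server.
# Assumes the server is running locally on :8080 and the model expects
# 784 floats (a flattened 28x28 image), per model_repository/mnist.yaml.
import requests

payload = {"data": [0.0] * 784}  # dummy image; real samples live in inputs/*.ubyte
resp = requests.post("http://localhost:8080/infer", json=payload, timeout=10)
resp.raise_for_status()
print(resp.json()["result"])  # 10 scores, one per MNIST digit
```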
38 | 39 | ## Blog posts 40 | 41 | I enjoy writing technical blog posts and I've written some about this project: 42 | 43 | * [Initial design](https://michalpitr.substack.com/p/build-your-own-inference-engine-from) 44 | * [Profiling-driven optimizations](https://michalpitr.substack.com/p/inference-engine-optimizing-performance) 45 | * [Adding CUDA execution provider](https://michalpitr.substack.com/p/inference-engine-accelerating-with) 46 | 47 | -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e # Exit immediately if a command exits with a non-zero status 3 | 4 | # Define directories using absolute paths 5 | export PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" 6 | SRC_DIR="${PROJECT_ROOT}/src" 7 | BUILD_DIR="${PROJECT_ROOT}/build" 8 | SERVER_DIR="${PROJECT_ROOT}/server" 9 | GENERATED_HEADERS_DIR="${BUILD_DIR}/src" 10 | 11 | echo "Project root: ${PROJECT_ROOT}" 12 | 13 | # Step 1: Create build directory if it doesn't exist 14 | mkdir -p "${BUILD_DIR}" 15 | 16 | # Step 2: Run CMake 17 | /usr/bin/cmake --build "${BUILD_DIR}" --config Release --target engine_exe -j 14 -- 18 | 19 | # Step 3: Create a symbolic link for the shared library 20 | ln -sf "${GENERATED_HEADERS_DIR}/libengine_lib.so" "${SERVER_DIR}/libengine_lib.so" 21 | echo "Created symlink for libengine_lib.so" 22 | 23 | # Step 4: Create a symbolic link for the onnx-ml.pb.h header 24 | ln -sf "${GENERATED_HEADERS_DIR}/onnx-ml.pb.h" "${SERVER_DIR}/onnx-ml.pb.h" 25 | echo "Created symlink for onnx-ml.pb.h" 26 | 27 | # Step 5: Build the Go server 28 | cd "${SERVER_DIR}" 29 | go build 30 | 31 | echo "Build process completed successfully!" 
-------------------------------------------------------------------------------- /dockerfiles/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CUDA_VERSION=12.6.1 2 | ARG CUDNN_VERSION=9.5.0.50 3 | ARG OS=ubuntu24.04 4 | 5 | # Start with NVIDIA CUDA base image 6 | FROM nvcr.io/nvidia/cuda:${CUDA_VERSION}-devel-${OS} 7 | 8 | # Avoid interactive dialog during package installation 9 | ENV DEBIAN_FRONTEND=noninteractive 10 | 11 | # Install essential development tools and libraries 12 | RUN apt-get update && apt-get install -y \ 13 | build-essential \ 14 | cmake \ 15 | git \ 16 | vim \ 17 | gdb \ 18 | openssh-client \ 19 | curl \ 20 | protobuf-compiler \ 21 | libbenchmark-dev \ 22 | libgtest-dev \ 23 | libyaml-cpp-dev \ 24 | clang \ 25 | clang-format \ 26 | clang-tidy \ 27 | libc++-dev \ 28 | libc++abi-dev \ 29 | python3 \ 30 | python3-pip \ 31 | && rm -rf /var/lib/apt/lists/* 32 | 33 | # Set environment variables for CUDA 34 | ENV PATH=/usr/local/cuda/bin:${PATH} 35 | ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH} 36 | 37 | # Set clang as the default C/C++ compiler 38 | ENV CC=clang 39 | ENV CXX=clang++ 40 | 41 | WORKDIR /code 42 | 43 | CMD ["/bin/bash"] -------------------------------------------------------------------------------- /inputs/image_0.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_0.ubyte -------------------------------------------------------------------------------- /inputs/image_1.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_1.ubyte -------------------------------------------------------------------------------- /inputs/image_10.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_10.ubyte -------------------------------------------------------------------------------- /inputs/image_11.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_11.ubyte -------------------------------------------------------------------------------- /inputs/image_12.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_12.ubyte -------------------------------------------------------------------------------- /inputs/image_13.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_13.ubyte -------------------------------------------------------------------------------- /inputs/image_14.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_14.ubyte -------------------------------------------------------------------------------- /inputs/image_15.ubyte: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_15.ubyte -------------------------------------------------------------------------------- /inputs/image_16.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_16.ubyte -------------------------------------------------------------------------------- /inputs/image_17.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_17.ubyte -------------------------------------------------------------------------------- /inputs/image_18.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_18.ubyte -------------------------------------------------------------------------------- /inputs/image_19.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_19.ubyte -------------------------------------------------------------------------------- /inputs/image_2.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_2.ubyte -------------------------------------------------------------------------------- /inputs/image_20.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_20.ubyte -------------------------------------------------------------------------------- /inputs/image_21.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_21.ubyte -------------------------------------------------------------------------------- /inputs/image_22.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_22.ubyte -------------------------------------------------------------------------------- /inputs/image_23.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_23.ubyte -------------------------------------------------------------------------------- /inputs/image_24.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_24.ubyte -------------------------------------------------------------------------------- /inputs/image_25.ubyte: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_25.ubyte -------------------------------------------------------------------------------- /inputs/image_26.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_26.ubyte -------------------------------------------------------------------------------- /inputs/image_27.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_27.ubyte -------------------------------------------------------------------------------- /inputs/image_28.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_28.ubyte -------------------------------------------------------------------------------- /inputs/image_29.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_29.ubyte -------------------------------------------------------------------------------- /inputs/image_3.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_3.ubyte -------------------------------------------------------------------------------- /inputs/image_30.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_30.ubyte -------------------------------------------------------------------------------- /inputs/image_31.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_31.ubyte -------------------------------------------------------------------------------- /inputs/image_32.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_32.ubyte -------------------------------------------------------------------------------- /inputs/image_33.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_33.ubyte -------------------------------------------------------------------------------- /inputs/image_34.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_34.ubyte -------------------------------------------------------------------------------- /inputs/image_35.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_35.ubyte 
-------------------------------------------------------------------------------- /inputs/image_36.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_36.ubyte -------------------------------------------------------------------------------- /inputs/image_37.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_37.ubyte -------------------------------------------------------------------------------- /inputs/image_38.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_38.ubyte -------------------------------------------------------------------------------- /inputs/image_39.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_39.ubyte -------------------------------------------------------------------------------- /inputs/image_4.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_4.ubyte -------------------------------------------------------------------------------- /inputs/image_40.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_40.ubyte -------------------------------------------------------------------------------- /inputs/image_41.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_41.ubyte -------------------------------------------------------------------------------- /inputs/image_42.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_42.ubyte -------------------------------------------------------------------------------- /inputs/image_43.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_43.ubyte -------------------------------------------------------------------------------- /inputs/image_44.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_44.ubyte -------------------------------------------------------------------------------- /inputs/image_45.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_45.ubyte -------------------------------------------------------------------------------- /inputs/image_46.ubyte: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_46.ubyte -------------------------------------------------------------------------------- /inputs/image_47.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_47.ubyte -------------------------------------------------------------------------------- /inputs/image_48.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_48.ubyte -------------------------------------------------------------------------------- /inputs/image_49.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_49.ubyte -------------------------------------------------------------------------------- /inputs/image_5.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_5.ubyte -------------------------------------------------------------------------------- /inputs/image_50.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_50.ubyte -------------------------------------------------------------------------------- /inputs/image_51.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_51.ubyte -------------------------------------------------------------------------------- /inputs/image_52.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_52.ubyte -------------------------------------------------------------------------------- /inputs/image_53.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_53.ubyte -------------------------------------------------------------------------------- /inputs/image_54.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_54.ubyte -------------------------------------------------------------------------------- /inputs/image_55.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_55.ubyte -------------------------------------------------------------------------------- /inputs/image_56.ubyte: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_56.ubyte -------------------------------------------------------------------------------- /inputs/image_57.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_57.ubyte -------------------------------------------------------------------------------- /inputs/image_58.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_58.ubyte -------------------------------------------------------------------------------- /inputs/image_59.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_59.ubyte -------------------------------------------------------------------------------- /inputs/image_6.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_6.ubyte -------------------------------------------------------------------------------- /inputs/image_60.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_60.ubyte -------------------------------------------------------------------------------- /inputs/image_61.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_61.ubyte -------------------------------------------------------------------------------- /inputs/image_62.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_62.ubyte -------------------------------------------------------------------------------- /inputs/image_63.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_63.ubyte -------------------------------------------------------------------------------- /inputs/image_64.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_64.ubyte -------------------------------------------------------------------------------- /inputs/image_65.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_65.ubyte -------------------------------------------------------------------------------- /inputs/image_66.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_66.ubyte 
-------------------------------------------------------------------------------- /inputs/image_67.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_67.ubyte -------------------------------------------------------------------------------- /inputs/image_68.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_68.ubyte -------------------------------------------------------------------------------- /inputs/image_69.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_69.ubyte -------------------------------------------------------------------------------- /inputs/image_7.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_7.ubyte -------------------------------------------------------------------------------- /inputs/image_70.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_70.ubyte -------------------------------------------------------------------------------- /inputs/image_71.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_71.ubyte -------------------------------------------------------------------------------- /inputs/image_72.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_72.ubyte -------------------------------------------------------------------------------- /inputs/image_73.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_73.ubyte -------------------------------------------------------------------------------- /inputs/image_74.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_74.ubyte -------------------------------------------------------------------------------- /inputs/image_75.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_75.ubyte -------------------------------------------------------------------------------- /inputs/image_76.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_76.ubyte -------------------------------------------------------------------------------- /inputs/image_77.ubyte: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_77.ubyte -------------------------------------------------------------------------------- /inputs/image_78.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_78.ubyte -------------------------------------------------------------------------------- /inputs/image_79.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_79.ubyte -------------------------------------------------------------------------------- /inputs/image_8.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_8.ubyte -------------------------------------------------------------------------------- /inputs/image_80.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_80.ubyte -------------------------------------------------------------------------------- /inputs/image_81.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_81.ubyte -------------------------------------------------------------------------------- /inputs/image_82.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_82.ubyte -------------------------------------------------------------------------------- /inputs/image_83.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_83.ubyte -------------------------------------------------------------------------------- /inputs/image_84.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_84.ubyte -------------------------------------------------------------------------------- /inputs/image_85.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_85.ubyte -------------------------------------------------------------------------------- /inputs/image_86.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_86.ubyte -------------------------------------------------------------------------------- /inputs/image_87.ubyte: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_87.ubyte -------------------------------------------------------------------------------- /inputs/image_88.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_88.ubyte -------------------------------------------------------------------------------- /inputs/image_89.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_89.ubyte -------------------------------------------------------------------------------- /inputs/image_9.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_9.ubyte -------------------------------------------------------------------------------- /inputs/image_90.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_90.ubyte -------------------------------------------------------------------------------- /inputs/image_91.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_91.ubyte -------------------------------------------------------------------------------- /inputs/image_92.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_92.ubyte -------------------------------------------------------------------------------- /inputs/image_93.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_93.ubyte -------------------------------------------------------------------------------- /inputs/image_94.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_94.ubyte -------------------------------------------------------------------------------- /inputs/image_95.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_95.ubyte -------------------------------------------------------------------------------- /inputs/image_96.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_96.ubyte -------------------------------------------------------------------------------- /inputs/image_97.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_97.ubyte 
-------------------------------------------------------------------------------- /inputs/image_98.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_98.ubyte -------------------------------------------------------------------------------- /inputs/image_99.ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/inputs/image_99.ubyte -------------------------------------------------------------------------------- /model_repository/mnist.yaml: -------------------------------------------------------------------------------- 1 | model_path: ./models/mnist_ffn_complex.onnx 2 | model_format: onnx 3 | execution_provider: CUDA 4 | batch_size: 32 5 | inputs: 6 | - name: "onnx::Flatten_0" 7 | shape: [1, 1, 28, 28] 8 | data_type: FLOAT32 9 | outputs: 10 | - name: "21" 11 | shape: [1, 10] 12 | data_type: FLOAT32 -------------------------------------------------------------------------------- /models/mnist_conv_ffn.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/models/mnist_conv_ffn.onnx -------------------------------------------------------------------------------- /models/mnist_ffn.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/models/mnist_ffn.onnx -------------------------------------------------------------------------------- /models/mnist_ffn_complex.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalPitr/inference_engine/bbf67f8a39bf04b56d00748679e85c05ce31aa62/models/mnist_ffn_complex.onnx -------------------------------------------------------------------------------- /server/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/MichalPitr/inference_engine/server 2 | 3 | go 1.22.0 4 | -------------------------------------------------------------------------------- /server/inference_wrapper.cpp: -------------------------------------------------------------------------------- 1 | // inference_wrapper.cpp 2 | #include "inference_wrapper.h" 3 | 4 | #include <cstring> 5 | #include <numeric> 6 | 7 | #include "cpu_provider.h" 8 | #include "cuda_provider.h" 9 | #include "cuda_provider_unoptimized.h" 10 | #include "inference_session.h" 11 | #include "model_config.h" 12 | 13 | struct InferenceSessionWrapper { 14 | ModelConfig config; 15 | InferenceSession session; 16 | std::unique_ptr<ExecutionProvider> provider; 17 | }; 18 | 19 | extern "C" { 20 | 21 | InferenceSessionWrapper* create_session(const char* config_path) { 22 | auto wrapper = new InferenceSessionWrapper(); 23 | wrapper->config = ModelConfig(config_path); 24 | wrapper->session.load_model(wrapper->config); 25 | 26 | Device device = wrapper->config.get_device(); 27 | if (device == Device::CPU) { 28 | wrapper->provider = std::make_unique<CpuProvider>(); 29 | } else if (device == Device::CUDA) { 30 | wrapper->provider = std::make_unique<CudaProvider>(); 31 | } else if (device == Device::CUDA_SLOW) { 32 | wrapper->provider = std::make_unique<CudaProviderUnoptimized>(); 33 | } else { 34 | delete wrapper; 35 | return nullptr; 36 | } 37 | 38 |
wrapper->session.set_execution_provider(std::move(wrapper->provider)); 39 | return wrapper; 40 | } 41 | 42 | uint64_t input_size(InferenceSessionWrapper* wrapper) { 43 | auto shape = wrapper->config.get_inputs()[0].shape; 44 | return std::accumulate(shape.begin(), shape.end(), 1ULL, 45 | std::multiplies<uint64_t>()); 46 | } 47 | 48 | uint64_t output_size(InferenceSessionWrapper* wrapper) { 49 | auto shape = wrapper->config.get_outputs()[0].shape; 50 | return std::accumulate(shape.begin(), shape.end(), 1ULL, 51 | std::multiplies<uint64_t>()); 52 | } 53 | 54 | void destroy_session(InferenceSessionWrapper* wrapper) { delete wrapper; } 55 | 56 | int initialize_provider(InferenceSessionWrapper* wrapper) { 57 | try { 58 | wrapper->session.initialize_provider(); 59 | return 0; 60 | } catch (const std::exception& e) { 61 | return -1; 62 | } 63 | } 64 | 65 | InferenceResult run_inference(InferenceSessionWrapper* wrapper, 66 | float* input_data, uint64_t input_size) { 67 | Tensor<float> input(input_data, {1, input_size}); 68 | wrapper->session.set_input(wrapper->config.get_inputs()[0].name, 69 | std::move(input)); 70 | wrapper->session.run(); 71 | Tensor<float> output = 72 | wrapper->session.get_output(wrapper->config.get_outputs()[0].name); 73 | InferenceResult result; 74 | result.size = output.size(); 75 | result.data = (float*)malloc(result.size * sizeof(float)); 76 | std::memcpy(result.data, output.data(), result.size * sizeof(float)); 77 | 78 | return result; 79 | } 80 | 81 | void free_result(InferenceResult result) { free(result.data); } 82 | } -------------------------------------------------------------------------------- /server/inference_wrapper.h: -------------------------------------------------------------------------------- 1 | #ifndef INFERENCE_WRAPPER_H 2 | #define INFERENCE_WRAPPER_H 3 | 4 | #include <stdint.h> 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | typedef struct InferenceSessionWrapper InferenceSessionWrapper; 11 | 12 | typedef struct { 13 | float* data; 14 | int size; 15 | } InferenceResult; 16 | 17 | InferenceSessionWrapper* create_session(const char* config_path); 18 | void destroy_session(InferenceSessionWrapper* session); 19 | 20 | uint64_t input_size(InferenceSessionWrapper* session); 21 | uint64_t output_size(InferenceSessionWrapper* session); 22 | int initialize_provider(InferenceSessionWrapper* session); 23 | 24 | InferenceResult run_inference(InferenceSessionWrapper* session, 25 | float* input_data, uint64_t input_size); 26 | void free_result(InferenceResult result); 27 | 28 | #ifdef __cplusplus 29 | } 30 | #endif 31 | 32 | #endif // INFERENCE_WRAPPER_H -------------------------------------------------------------------------------- /server/libengine_lib.so: -------------------------------------------------------------------------------- 1 | /home/michal/code/inference_engine/build/src/libengine_lib.so -------------------------------------------------------------------------------- /server/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | /* 4 | #include <stdlib.h> 5 | #include <stdint.h> 6 | #cgo CXXFLAGS: -std=c++23 -I.
-I${SRCDIR}/../src -I/usr/local/cuda-12.4/targets/x86_64-linux/include 7 | #cgo LDFLAGS: -L${SRCDIR} -lengine_lib -L/usr/local/cuda-12.4/targets/x86_64-linux/lib -lcudart -Wl,-rpath,${SRCDIR}:/usr/local/cuda-12.4/targets/x86_64-linux/lib 8 | #include "inference_wrapper.h" 9 | */ 10 | import "C" 11 | 12 | import ( 13 | "encoding/json" 14 | "fmt" 15 | "log" 16 | "net/http" 17 | "runtime" 18 | "sync" 19 | "unsafe" 20 | ) 21 | 22 | var ( 23 | session *C.InferenceSessionWrapper 24 | sessionMutex sync.Mutex 25 | ) 26 | 27 | type InferenceRequest struct { 28 | Data []float32 `json:"data"` 29 | } 30 | 31 | type InferenceResponse struct { 32 | Result []float32 `json:"result"` 33 | } 34 | 35 | func initSession(configPath string) error { 36 | sessionMutex.Lock() 37 | defer sessionMutex.Unlock() 38 | 39 | cConfigPath := C.CString(configPath) 40 | defer C.free(unsafe.Pointer(cConfigPath)) 41 | 42 | session = C.create_session(cConfigPath) 43 | if session == nil { 44 | return fmt.Errorf("failed to create inference session") 45 | } 46 | 47 | result := C.initialize_provider(session) 48 | if result != 0 { 49 | C.destroy_session(session) 50 | return fmt.Errorf("failed to initialize provider") 51 | } 52 | 53 | return nil 54 | } 55 | 56 | func cleanupSession() { 57 | sessionMutex.Lock() 58 | defer sessionMutex.Unlock() 59 | 60 | if session != nil { 61 | C.destroy_session(session) 62 | session = nil 63 | } 64 | } 65 | 66 | func inferenceHandler(w http.ResponseWriter, r *http.Request) { 67 | if r.Method != http.MethodPost { 68 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 69 | return 70 | } 71 | 72 | var req InferenceRequest 73 | if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 74 | http.Error(w, err.Error(), http.StatusBadRequest) 75 | return 76 | } 77 | 78 | if len(req.Data) != int(C.input_size(session)) { 79 | http.Error(w, "Input data has incorrect size", http.StatusBadRequest) 80 | return 81 | } 82 | 83 | cData := (*C.float)(unsafe.Pointer(&req.Data[0])) 84 | cSize := C.uint64_t(len(req.Data)) 85 | 86 | sessionMutex.Lock() 87 | cResult := C.run_inference(session, cData, cSize) 88 | sessionMutex.Unlock() 89 | defer C.free_result(cResult) 90 | 91 | resultSlice := (*[1 << 30]C.float)(unsafe.Pointer(cResult.data))[:cResult.size:cResult.size] 92 | goResult := make([]float32, cResult.size) 93 | for i, v := range resultSlice { 94 | goResult[i] = float32(v) 95 | } 96 | 97 | response := InferenceResponse{Result: goResult} 98 | 99 | w.Header().Set("Content-Type", "application/json") 100 | json.NewEncoder(w).Encode(response) 101 | } 102 | 103 | func main() { 104 | runtime.LockOSThread() 105 | defer runtime.UnlockOSThread() 106 | 107 | configPath := "/home/michal/code/inference_engine/model_repository/mnist.yaml" 108 | if err := initSession(configPath); err != nil { 109 | log.Fatalf("Failed to initialize inference session: %v", err) 110 | } 111 | defer cleanupSession() 112 | 113 | http.HandleFunc("/infer", inferenceHandler) 114 | 115 | log.Println("Server starting on :8080") 116 | log.Fatal(http.ListenAndServe(":8080", nil)) 117 | } 118 | -------------------------------------------------------------------------------- /server/onnx-ml.pb.h: -------------------------------------------------------------------------------- 1 | /home/michal/code/inference_engine/build/src/onnx-ml.pb.h -------------------------------------------------------------------------------- /src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | # Generate the .h and .cxx
files 3 | PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS onnx-ml.proto) 4 | 5 | # Force shared library 6 | set(BUILD_SHARED_LIBS ON) 7 | 8 | # Add the engine library 9 | add_library(engine_lib 10 | gemm_cpu.cpp 11 | kernels.cu 12 | attribute.cpp 13 | tensor.cpp 14 | input_loader.cpp 15 | operators.cpp 16 | graph.cpp 17 | node.cpp 18 | onnx_helper.cpp 19 | model_config.cpp 20 | inference_session.cpp 21 | cpu_provider.cpp 22 | cuda_provider.cpp 23 | cuda_provider_unoptimized.cpp 24 | ${PROTO_SRCS} 25 | ${PROTO_HDRS}) 26 | 27 | target_include_directories(engine_lib PUBLIC 28 | ${PROTOBUF_INCLUDE_DIRS} 29 | ${CMAKE_CURRENT_BINARY_DIR} 30 | ${YAML_CPP_INCLUDE_DIR} 31 | ${CUDAToolkit_INCLUDE_DIRS}) 32 | 33 | target_link_libraries(engine_lib PUBLIC ${PROTOBUF_LIBRARIES} yaml-cpp CUDA::cudart) 34 | 35 | add_executable(engine_exe main.cpp) 36 | 37 | target_link_libraries(engine_exe PRIVATE engine_lib) 38 | -------------------------------------------------------------------------------- /src/attribute.cpp: -------------------------------------------------------------------------------- 1 | #include "attribute.h" 2 | 3 | Attribute::Attribute(const onnx::AttributeProto &attrProto) 4 | : name(attrProto.name()) { 5 | switch (attrProto.type()) { 6 | case onnx::AttributeProto::INT: { 7 | value = attrProto.i(); 8 | break; 9 | } 10 | case onnx::AttributeProto::FLOAT: { 11 | value = attrProto.f(); 12 | break; 13 | } 14 | case onnx::AttributeProto::INTS: { 15 | std::vector<int64_t> ints(attrProto.ints_size()); 16 | for (int i = 0; i < attrProto.ints_size(); ++i) { 17 | ints[i] = attrProto.ints(i); 18 | } 19 | value = ints; 20 | break; 21 | } 22 | default: 23 | throw std::runtime_error("Unsupported attribute type: " + 24 | std::to_string(attrProto.type())); 25 | } 26 | } 27 | 28 | const std::string &Attribute::getName() const { return name; } 29 | const Attribute::AttributeValue &Attribute::getValue() const { return value; } -------------------------------------------------------------------------------- /src/attribute.h: -------------------------------------------------------------------------------- 1 | #ifndef ATTRIBUTE_H 2 | #define ATTRIBUTE_H 3 | 4 | #include <string> 5 | #include <variant> 6 | #include <vector> 7 | 8 | #include "onnx-ml.pb.h" 9 | 10 | class Attribute { 11 | public: 12 | using AttributeValue = std::variant<int64_t, float, std::vector<int64_t>>; 13 | 14 | Attribute(const onnx::AttributeProto &attrProto); 15 | 16 | const std::string &getName() const; 17 | const AttributeValue &getValue() const; 18 | 19 | private: 20 | std::string name; 21 | AttributeValue value; 22 | }; 23 | 24 | #endif // ATTRIBUTE_H -------------------------------------------------------------------------------- /src/cpu_allocator.h: -------------------------------------------------------------------------------- 1 | #ifndef CPU_ALLOCATOR_H 2 | #define CPU_ALLOCATOR_H 3 | 4 | #include <cstddef> 5 | #include <cstdlib> 6 | 7 | #include "memory_allocator.h" 8 | 9 | class CpuAllocator : public Allocator { 10 | public: 11 | void* allocate(size_t size) override { return malloc(size); } 12 | 13 | void deallocate(void* ptr) override { free(ptr); } 14 | DeviceType getDeviceType() const override { return DeviceType::CPU; }; 15 | }; 16 | 17 | #endif -------------------------------------------------------------------------------- /src/cpu_provider.cpp: -------------------------------------------------------------------------------- 1 | #include "cpu_provider.h" 2 | 3 | #include <stdexcept> 4 | #include <vector> 5 | 6 | #include "onnx_helper.h" 7 | #include "operators.h" 8 | #include "optypes.h" 9 | 10 | Tensor<float> CpuProvider::evaluateNode( 11 | const Node
*node, const std::vector *> &inputs) { 12 | switch (node->getOpType()) { 13 | case OpType::Gemm: 14 | return gemm(node, inputs); 15 | case OpType::Flatten: 16 | return flatten(node, inputs); 17 | case OpType::Relu: 18 | return relu(node, inputs); 19 | case OpType::Add: 20 | return add(node, inputs); 21 | default: 22 | throw std::runtime_error("Unsupported operation type"); 23 | } 24 | } 25 | 26 | Tensor CpuProvider::gemm(const Node *node, 27 | const std::vector *> &inputs) { 28 | float alpha = node->getAttribute("alpha").value_or(1.0); 29 | float beta = node->getAttribute("beta").value_or(1.0); 30 | int transA = node->getAttribute("transA").value_or(0); 31 | int transB = node->getAttribute("transB").value_or(0); 32 | 33 | if (inputs.size() != 3) { 34 | throw std::runtime_error("Gemm operation expects 3 inputs"); 35 | } 36 | const Tensor &A = *inputs[0]; 37 | const Tensor &B = *inputs[1]; 38 | const Tensor &bias = *inputs[2]; 39 | return CpuOperators::gemm(A, B, bias, transA, transB, alpha, beta); 40 | } 41 | 42 | Tensor CpuProvider::flatten(const Node *node, 43 | const std::vector *> &inputs) { 44 | if (inputs.size() != 1) { 45 | throw std::runtime_error("Flatten operation expects 1 input"); 46 | } 47 | auto axisOpt = node->getAttribute("axis"); 48 | if (!axisOpt) { 49 | throw std::runtime_error("Axis missing for flatten operation"); 50 | } 51 | return CpuOperators::flatten(*inputs[0], axisOpt.value()); 52 | } 53 | 54 | Tensor CpuProvider::relu([[maybe_unused]] const Node *node, 55 | const std::vector *> &inputs) { 56 | if (inputs.size() != 1) { 57 | throw std::runtime_error("Relu operation expects 1 input"); 58 | } 59 | return CpuOperators::relu(*inputs[0]); 60 | } 61 | 62 | Tensor CpuProvider::add([[maybe_unused]] const Node *node, 63 | const std::vector *> &inputs) { 64 | if (inputs.size() != 2) { 65 | throw std::runtime_error("Add operation expects 2 inputs"); 66 | } 67 | return CpuOperators::add(*inputs[0], *inputs[1]); 68 | } 69 | -------------------------------------------------------------------------------- /src/cpu_provider.h: -------------------------------------------------------------------------------- 1 | #ifndef CPU_PROVIDER_H 2 | #define CPU_PROVIDER_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "execution_provider.h" 9 | #include "graph.h" 10 | #include "tensor.h" 11 | 12 | class CpuProvider : public ExecutionProvider { 13 | public: 14 | Tensor evaluateNode( 15 | const Node *node, const std::vector *> &inputs) override; 16 | void transferWeightsToDevice( 17 | [[maybe_unused]] std::unordered_map> 18 | &weights) override { 19 | // No-op for CPU provider. 
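// Weights loaded from the ONNX initializers are already host-resident, so
// there is nothing to copy for the CPU path; see CudaProvider for the
// device-transfer counterpart.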
20 | return; 21 | } 22 | 23 | private: 24 | // Operators 25 | Tensor gemm(const Node *node, 26 | const std::vector *> &inputs); 27 | Tensor flatten(const Node *node, 28 | const std::vector *> &inputs); 29 | Tensor relu(const Node *node, 30 | const std::vector *> &inputs); 31 | Tensor add(const Node *node, 32 | const std::vector *> &inputs); 33 | }; 34 | 35 | #endif // CPU_PROVIDER_H -------------------------------------------------------------------------------- /src/cuda_allocator.h: -------------------------------------------------------------------------------- 1 | #ifndef CUDA_ALLOCATOR_H 2 | #define CUDA_ALLOCATOR_H 3 | 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include "cuda_memory_pool.h" 10 | #include "memory_allocator.h" 11 | 12 | class CudaAllocator : public Allocator { 13 | public: 14 | CudaAllocator(size_t pool_size = 100 * 1024 * 1024) 15 | : pool_(std::make_unique(pool_size)) {} 16 | 17 | void* allocate(size_t size) override { 18 | if (pool_) { 19 | return pool_->allocate(size); 20 | } 21 | // fallback if no pool configured. 22 | void* ptr = nullptr; 23 | cudaMalloc(&ptr, size); 24 | return ptr; 25 | } 26 | void deallocate(void* ptr) override { 27 | if (pool_) { 28 | return pool_->deallocate(ptr); 29 | } 30 | // fallback if no pool configured. 31 | cudaFree(ptr); 32 | } 33 | DeviceType getDeviceType() const override { return DeviceType::CUDA; } 34 | 35 | private: 36 | std::unique_ptr pool_; 37 | }; 38 | 39 | #endif -------------------------------------------------------------------------------- /src/cuda_memory_pool.h: -------------------------------------------------------------------------------- 1 | #ifndef CUDA_MEMORY_POOL_H 2 | #define CUDA_MEMORY_POOL_H 3 | 4 | #include 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | class CudaMemoryPool { 14 | private: 15 | void* pool_start; 16 | size_t pool_size; 17 | struct Block { 18 | size_t offset; 19 | size_t size; 20 | bool is_free; 21 | Block(size_t off, size_t sz, bool free) 22 | : offset(off), size(sz), is_free(free) {} 23 | }; 24 | std::vector blocks; 25 | 26 | public: 27 | CudaMemoryPool(size_t size) : pool_size(size) { 28 | cudaError_t err = cudaMalloc(&pool_start, size); 29 | if (err != cudaSuccess) { 30 | throw std::runtime_error("Failed to allocate CUDA memory pool"); 31 | } 32 | blocks.emplace_back(0, size, true); 33 | } 34 | 35 | ~CudaMemoryPool() { cudaFree(pool_start); } 36 | 37 | void* allocate(size_t size) { 38 | auto best_fit = blocks.end(); 39 | size_t smallest_fit = pool_size + 1; 40 | 41 | for (auto it = blocks.begin(); it != blocks.end(); ++it) { 42 | if (it->is_free && it->size >= size) { 43 | if (it->size < smallest_fit) { 44 | best_fit = it; 45 | smallest_fit = it->size; 46 | } 47 | } 48 | } 49 | 50 | if (best_fit == blocks.end()) { 51 | throw std::runtime_error("Out of memory in CUDA pool"); 52 | } 53 | 54 | size_t alloc_offset = best_fit->offset; 55 | size_t remaining_size = best_fit->size - size; 56 | 57 | best_fit->is_free = false; 58 | best_fit->size = size; 59 | 60 | if (remaining_size > 0) { 61 | blocks.insert(best_fit + 1, 62 | Block(alloc_offset + size, remaining_size, true)); 63 | } 64 | 65 | return static_cast(pool_start) + alloc_offset; 66 | } 67 | 68 | void deallocate(void* ptr) { 69 | size_t offset = 70 | static_cast(ptr) - static_cast(pool_start); 71 | auto it = std::find_if( 72 | blocks.begin(), blocks.end(), 73 | [offset](const Block& b) { return b.offset == offset; }); 74 | 75 | if (it == blocks.end()) { 76 | throw 
std::runtime_error("Invalid pointer for deallocation"); 77 | } 78 | 79 | it->is_free = true; 80 | 81 | // Coalesce with previous block if free 82 | if (it != blocks.begin()) { 83 | auto prev = it - 1; 84 | if (prev->is_free) { 85 | prev->size += it->size; 86 | blocks.erase(it); 87 | it = prev; 88 | } 89 | } 90 | 91 | // Coalesce with next block if free 92 | if (it != blocks.end() - 1) { 93 | auto next = it + 1; 94 | if (next->is_free) { 95 | it->size += next->size; 96 | blocks.erase(next); 97 | } 98 | } 99 | } 100 | }; 101 | 102 | #endif -------------------------------------------------------------------------------- /src/cuda_provider.cpp: -------------------------------------------------------------------------------- 1 | #include "cuda_provider.h" 2 | 3 | #include 4 | #include 5 | 6 | #include "cuda_memory_pool.h" 7 | #include "device.h" 8 | #include "kernels.h" 9 | #include "onnx_helper.h" 10 | #include "operators.h" 11 | #include "optypes.h" 12 | 13 | template 14 | void validate_gemm_inputs(const Tensor &A, const Tensor &B, 15 | const Tensor &bias, bool transA, bool transB); 16 | 17 | CudaProvider::CudaProvider() : allocator_(std::make_shared()) {} 18 | 19 | Tensor CudaProvider::evaluateNode( 20 | const Node *node, const std::vector *> &inputs) { 21 | // TODO: find a way to move the input Tensor to device. Then we can skip 22 | // this entirely. 23 | for (auto input : inputs) { 24 | input->to(DeviceType::CUDA, allocator_); 25 | } 26 | 27 | switch (node->getOpType()) { 28 | case OpType::Gemm: 29 | return gemm(node, inputs); 30 | case OpType::Flatten: 31 | return flatten(node, inputs); 32 | case OpType::Relu: 33 | return relu(node, inputs); 34 | case OpType::Add: 35 | return add(node, inputs); 36 | default: 37 | throw std::runtime_error("Unsupported operation type"); 38 | } 39 | } 40 | 41 | Tensor CudaProvider::gemm(const Node *node, 42 | const std::vector *> &inputs) { 43 | if (inputs.size() != 3) { 44 | throw std::runtime_error("Gemm operation expects 3 inputs"); 45 | } 46 | 47 | float alpha = node->getAttribute("alpha").value_or(1.0); 48 | float beta = node->getAttribute("beta").value_or(1.0); 49 | int transA = node->getAttribute("transA").value_or(0); 50 | int transB = node->getAttribute("transB").value_or(0); 51 | 52 | const Tensor &A = *inputs[0]; 53 | const Tensor &B = *inputs[1]; 54 | const Tensor &bias = *inputs[2]; 55 | 56 | validate_gemm_inputs(A, B, bias, transA, transB); 57 | 58 | // Calculate output dimensions depending on transpositions. 59 | uint64_t N = transA ? A.shape()[1] : A.shape()[0]; 60 | uint64_t M = transB ? B.shape()[1] : B.shape()[0]; 61 | uint64_t K = transB ? 
B.shape()[0] : B.shape()[1]; 62 | 63 | std::vector dims{N, K}; 64 | 65 | Tensor out{std::move(dims), allocator_}; 66 | 67 | assert(A.device() == DeviceType::CUDA); 68 | const float *AData = A.data(); 69 | const float *BData = B.data(); 70 | const float *BiasData = bias.data(); 71 | 72 | gemm_cuda_naive(AData, BData, BiasData, out.data(), N, M, K, transA, transB, 73 | alpha, beta); 74 | 75 | return out; 76 | } 77 | 78 | Tensor CudaProvider::flatten( 79 | const Node *node, const std::vector *> &inputs) { 80 | if (inputs.size() != 1) { 81 | throw std::runtime_error("Flatten operation expects 1 input"); 82 | } 83 | auto axisOpt = node->getAttribute("axis"); 84 | if (!axisOpt) { 85 | throw std::runtime_error("Axis missing for flatten operation"); 86 | } 87 | return CudaOperators::flatten(*inputs[0], axisOpt.value()); 88 | } 89 | 90 | Tensor CudaProvider::relu([[maybe_unused]] const Node *node, 91 | const std::vector *> &inputs) { 92 | if (inputs.size() != 1) { 93 | throw std::runtime_error("Relu operation expects 1 input"); 94 | } 95 | 96 | const Tensor &in = *inputs[0]; 97 | Tensor out(in.shape(), allocator_); 98 | relu_cuda(in.data(), out.data(), out.size()); 99 | return out; 100 | } 101 | 102 | Tensor CudaProvider::add([[maybe_unused]] const Node *node, 103 | const std::vector *> &inputs) { 104 | if (inputs.size() != 2) { 105 | throw std::runtime_error("Add operation expects 2 inputs"); 106 | } 107 | const Tensor &A = *inputs[0]; 108 | const Tensor &B = *inputs[1]; 109 | Tensor out(A.shape(), allocator_); 110 | add_cuda(A.data(), B.data(), out.data(), out.size()); 111 | return out; 112 | } 113 | -------------------------------------------------------------------------------- /src/cuda_provider.h: -------------------------------------------------------------------------------- 1 | #ifndef CUDA_PROVIDER_H 2 | #define CUDA_PROVIDER_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "cuda_allocator.h" 9 | #include "cuda_memory_pool.h" 10 | #include "execution_provider.h" 11 | #include "graph.h" 12 | #include "tensor.h" 13 | 14 | class CudaProvider : public ExecutionProvider { 15 | public: 16 | CudaProvider(); 17 | ~CudaProvider() override = default; 18 | 19 | Tensor evaluateNode( 20 | const Node *node, const std::vector *> &inputs) override; 21 | 22 | void transferWeightsToDevice( 23 | std::unordered_map> &weights) override { 24 | for (auto &[name, tensor] : weights) { 25 | tensor.to(DeviceType::CUDA, allocator_); 26 | } 27 | } 28 | 29 | private: 30 | std::shared_ptr allocator_; 31 | 32 | // Operators 33 | Tensor gemm(const Node *node, 34 | const std::vector *> &inputs); 35 | Tensor flatten(const Node *node, 36 | const std::vector *> &inputs); 37 | Tensor relu(const Node *node, 38 | const std::vector *> &inputs); 39 | Tensor add(const Node *node, 40 | const std::vector *> &inputs); 41 | }; 42 | 43 | #endif // CUDA_PROVIDER_H -------------------------------------------------------------------------------- /src/cuda_provider_unoptimized.cpp: -------------------------------------------------------------------------------- 1 | #include "cuda_provider_unoptimized.h" 2 | 3 | #include 4 | #include 5 | 6 | #include "device.h" 7 | #include "kernels.h" 8 | #include "onnx_helper.h" 9 | #include "operators.h" 10 | #include "optypes.h" 11 | 12 | template 13 | void validate_gemm_inputs(const Tensor &A, const Tensor &B, 14 | const Tensor &bias, bool transA, bool transB); 15 | 16 | CudaProviderUnoptimized::CudaProviderUnoptimized() {} 17 | 18 | Tensor CudaProviderUnoptimized::evaluateNode( 19 | const 
Node *node, const std::vector *> &inputs) { 20 | switch (node->getOpType()) { 21 | case OpType::Gemm: 22 | return gemm(node, inputs); 23 | case OpType::Flatten: 24 | return flatten(node, inputs); 25 | case OpType::Relu: 26 | return relu(node, inputs); 27 | case OpType::Add: 28 | return add(node, inputs); 29 | default: 30 | throw std::runtime_error("Unsupported operation type"); 31 | } 32 | } 33 | 34 | Tensor CudaProviderUnoptimized::gemm( 35 | const Node *node, const std::vector *> &inputs) { 36 | if (inputs.size() != 3) { 37 | throw std::runtime_error("Gemm operation expects 3 inputs"); 38 | } 39 | 40 | float alpha = node->getAttribute("alpha").value_or(1.0); 41 | float beta = node->getAttribute("beta").value_or(1.0); 42 | int transA = node->getAttribute("transA").value_or(0); 43 | int transB = node->getAttribute("transB").value_or(0); 44 | 45 | const Tensor &A = *inputs[0]; 46 | const Tensor &B = *inputs[1]; 47 | const Tensor &bias = *inputs[2]; 48 | 49 | validate_gemm_inputs(A, B, bias, transA, transB); 50 | 51 | // Calculate output dimensions depending on transpositions. 52 | uint64_t N = transA ? A.shape()[1] : A.shape()[0]; 53 | uint64_t M = transB ? B.shape()[1] : B.shape()[0]; 54 | uint64_t K = transB ? B.shape()[0] : B.shape()[1]; 55 | 56 | std::vector dims{N, K}; 57 | 58 | Tensor out{std::move(dims)}; 59 | 60 | const float *AData = A.data(); 61 | const float *BData = B.data(); 62 | const float *BiasData = bias.data(); 63 | 64 | gemm_cuda_unoptimized(AData, BData, BiasData, out.data(), N, M, K, transA, 65 | transB, alpha, beta); 66 | return out; 67 | } 68 | 69 | Tensor CudaProviderUnoptimized::flatten( 70 | const Node *node, const std::vector *> &inputs) { 71 | if (inputs.size() != 1) { 72 | throw std::runtime_error("Flatten operation expects 1 input"); 73 | } 74 | auto axisOpt = node->getAttribute("axis"); 75 | if (!axisOpt) { 76 | throw std::runtime_error("Axis missing for flatten operation"); 77 | } 78 | return CpuOperators::flatten(*inputs[0], axisOpt.value()); 79 | } 80 | 81 | Tensor CudaProviderUnoptimized::relu( 82 | [[maybe_unused]] const Node *node, 83 | const std::vector *> &inputs) { 84 | if (inputs.size() != 1) { 85 | throw std::runtime_error("Relu operation expects 1 input"); 86 | } 87 | 88 | const Tensor &in = *inputs[0]; 89 | Tensor out(in.shape()); 90 | relu_cuda_unoptimized(in.data(), out.data(), out.size()); 91 | return out; 92 | } 93 | 94 | Tensor CudaProviderUnoptimized::add( 95 | [[maybe_unused]] const Node *node, 96 | const std::vector *> &inputs) { 97 | if (inputs.size() != 2) { 98 | throw std::runtime_error("Add operation expects 2 inputs"); 99 | } 100 | const Tensor &A = *inputs[0]; 101 | const Tensor &B = *inputs[1]; 102 | Tensor out(A.shape()); 103 | add_cuda_unoptimized(A.data(), B.data(), out.data(), out.size()); 104 | return out; 105 | } 106 | -------------------------------------------------------------------------------- /src/cuda_provider_unoptimized.h: -------------------------------------------------------------------------------- 1 | #ifndef CUDA_PROVIDER_UNOPTIMIZED_H 2 | #define CUDA_PROVIDER_UNOPTIMIZED_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "execution_provider.h" 9 | #include "graph.h" 10 | #include "tensor.h" 11 | 12 | class CudaProviderUnoptimized : public ExecutionProvider { 13 | public: 14 | CudaProviderUnoptimized(); 15 | ~CudaProviderUnoptimized() override = default; 16 | 17 | Tensor evaluateNode( 18 | const Node *node, const std::vector *> &inputs) override; 19 | 20 | void transferWeightsToDevice( 21 | 
[[maybe_unused]] std::unordered_map> 22 | &weights) override { 23 | // No-op for unoptimized cuda provider. 24 | return; 25 | } 26 | 27 | private: 28 | // Operators 29 | Tensor gemm(const Node *node, 30 | const std::vector *> &inputs); 31 | Tensor flatten(const Node *node, 32 | const std::vector *> &inputs); 33 | Tensor relu(const Node *node, 34 | const std::vector *> &inputs); 35 | Tensor add(const Node *node, 36 | const std::vector *> &inputs); 37 | }; 38 | 39 | #endif // CUDA_PROVIDER_UNOPTIMIZED_H -------------------------------------------------------------------------------- /src/device.h: -------------------------------------------------------------------------------- 1 | #ifndef DEVICE_H 2 | #define DEVICE_H 3 | 4 | enum class DeviceType { CPU, CUDA }; 5 | 6 | #endif -------------------------------------------------------------------------------- /src/execution_provider.h: -------------------------------------------------------------------------------- 1 | #ifndef EXECUTION_PROVIDER_H 2 | #define EXECUTION_PROVIDER_H 3 | 4 | #include "node.h" 5 | #include "tensor.h" 6 | 7 | class ExecutionProvider { 8 | public: 9 | ExecutionProvider() = default; 10 | virtual ~ExecutionProvider() = default; 11 | virtual void transferWeightsToDevice( 12 | std::unordered_map>& weights) = 0; 13 | virtual Tensor evaluateNode( 14 | const Node* node, const std::vector*>& inputs) = 0; 15 | }; 16 | 17 | #endif // EXECUTION_PROVIDER_H -------------------------------------------------------------------------------- /src/gemm_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "gemm_cpu.h" 2 | 3 | #include 4 | 5 | // gemm_cpu returns out = A * B + bias 6 | // A is (n, m) 7 | // B is (m, k) 8 | // bias is assumed to be (n, 1) and broadcasted 9 | // out is (n, k) 10 | void gemm_cpu(const float* A, const float* B, const float* bias, float* out, 11 | const int n, const int m, const int k, const bool transA, 12 | const bool transB, const float alpha, const float beta) { 13 | for (int r = 0; r < n; ++r) { 14 | for (int c = 0; c < k; ++c) { 15 | float res = 0; 16 | for (int i = 0; i < m; ++i) { 17 | float aVal = transA ? A[i * n + r] : A[r * m + i]; 18 | float bVal = transB ? 
B[c * m + i] : B[i * k + c]; 19 | res += aVal * bVal; 20 | } 21 | out[r * k + c] = res * alpha; 22 | } 23 | } 24 | 25 | // Apply bias term 26 | for (int r = 0; r < n; ++r) { 27 | for (int c = 0; c < k; ++c) { 28 | out[r * k + c] += bias[r] * beta; 29 | } 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /src/gemm_cpu.h: -------------------------------------------------------------------------------- 1 | #ifndef GEMM_H 2 | #define GEMM_H 3 | 4 | void gemm_cpu(const float* A, const float* B, const float* bias, float* out, 5 | const int m, const int n, const int k, const bool transA, 6 | const bool transB, const float alpha, const float beta); 7 | 8 | #endif // GEMM_H -------------------------------------------------------------------------------- /src/graph.cpp: -------------------------------------------------------------------------------- 1 | #include "graph.h" 2 | 3 | Graph::Graph(const onnx::GraphProto& graphProto) { 4 | nodeMap_.reserve(graphProto.node_size()); 5 | for (const auto& nodeProto : graphProto.node()) { 6 | addNode(std::make_unique(nodeProto)); 7 | } 8 | 9 | inputs_.reserve(graphProto.input_size()); 10 | for (const auto& inputProto : graphProto.input()) { 11 | inputs_.push_back(inputProto.name()); 12 | } 13 | 14 | outputs_.reserve(graphProto.output_size()); 15 | for (const auto& outputProto : graphProto.output()) { 16 | outputs_.push_back(outputProto.name()); 17 | } 18 | } 19 | 20 | void Graph::addNode(std::unique_ptr node) { 21 | std::string nodeName = node->getName(); 22 | Node* nodePtr = node.get(); 23 | nodeMap_[nodeName].node = std::move(node); 24 | 25 | updateEdges(nodePtr); 26 | } 27 | 28 | void Graph::updateEdges(Node* node) { 29 | addIncomingEdges(node); 30 | addOutgoingEdges(node); 31 | } 32 | 33 | void Graph::addIncomingEdges(Node* node) { 34 | for (const auto& inputName : node->getInputs()) { 35 | for (auto& [_, info] : nodeMap_) { 36 | if (std::find(info.node->getOutputs().begin(), 37 | info.node->getOutputs().end(), 38 | inputName) != info.node->getOutputs().end()) { 39 | info.children.push_back(node); 40 | nodeMap_[node->getName()].parents.push_back(info.node.get()); 41 | } 42 | } 43 | } 44 | } 45 | 46 | void Graph::addOutgoingEdges(Node* node) { 47 | for (const auto& outputName : node->getOutputs()) { 48 | for (auto& [_, info] : nodeMap_) { 49 | if (std::find(info.node->getInputs().begin(), 50 | info.node->getInputs().end(), 51 | outputName) != info.node->getInputs().end()) { 52 | nodeMap_[node->getName()].children.push_back(info.node.get()); 53 | info.parents.push_back(node); 54 | } 55 | } 56 | } 57 | } 58 | 59 | void Graph::replaceNode(Node* oldNode, std::unique_ptr newNode) { 60 | // Note: This assumes that the new node has the same connections as the old 61 | // one. If this is not the case, parents/children need to updated manually. 
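// (Hypothetical example: replacing a Gemm node with a fused Gemm+Relu node
//  that keeps the same input/output names needs no edge changes; a replacement
//  that renames its inputs or outputs would also need the edges rebuilt, e.g.
//  with updateEdges.)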
62 | std::string name = oldNode->getName(); 63 | auto& info = nodeMap_[name]; 64 | info.node = std::move(newNode); 65 | } 66 | 67 | const std::string& Graph::getInputName(std::size_t index) const { 68 | return inputs_.at(index); 69 | } 70 | 71 | const std::string& Graph::getOutputName(std::size_t index) const { 72 | return outputs_.at(index); 73 | } 74 | 75 | std::vector Graph::getTopologicallySortedNodes() { 76 | if (!sortedNodes_.empty()) { 77 | return sortedNodes_; 78 | } 79 | 80 | std::unordered_set visited; 81 | std::stack stack; 82 | 83 | for (const auto& [_, info] : nodeMap_) { 84 | if (info.parents.empty() || isInputNode(info.node.get())) { 85 | topologicalSortUtil(info.node.get(), visited, stack); 86 | } 87 | } 88 | 89 | sortedNodes_.reserve(stack.size()); 90 | while (!stack.empty()) { 91 | sortedNodes_.push_back(stack.top()); 92 | stack.pop(); 93 | } 94 | 95 | return sortedNodes_; 96 | } 97 | 98 | bool Graph::isInputNode(Node* node) const { 99 | return std::any_of(node->getInputs().begin(), node->getInputs().end(), 100 | [this](const std::string& input) { 101 | return std::find(inputs_.begin(), inputs_.end(), 102 | input) != inputs_.end(); 103 | }); 104 | } 105 | 106 | void Graph::topologicalSortUtil(Node* node, std::unordered_set& visited, 107 | std::stack& stack) { 108 | visited.insert(node); 109 | 110 | for (Node* child : nodeMap_[node->getName()].children) { 111 | if (visited.find(child) == visited.end()) { 112 | topologicalSortUtil(child, visited, stack); 113 | } 114 | } 115 | 116 | stack.push(node); 117 | } 118 | 119 | void Graph::printGraph() const { 120 | for (const auto& [name, info] : nodeMap_) { 121 | std::cout << "Node " << name << ": \n"; 122 | for (const auto& child : info.children) { 123 | std::cout << " child " << child->getName() << "\n"; 124 | } 125 | } 126 | } -------------------------------------------------------------------------------- /src/graph.h: -------------------------------------------------------------------------------- 1 | #ifndef GRAPH_H 2 | #define GRAPH_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "node.h" 9 | #include "onnx-ml.pb.h" 10 | 11 | class Graph { 12 | public: 13 | struct NodeInfo { 14 | std::unique_ptr node; 15 | std::vector children; 16 | std::vector parents; 17 | }; 18 | 19 | Graph() = default; 20 | Graph(const onnx::GraphProto& graphProto); 21 | 22 | const std::string& getInputName(std::size_t index) const; 23 | const std::string& getOutputName(std::size_t index) const; 24 | void printGraph() const; 25 | 26 | std::vector getTopologicallySortedNodes(); 27 | void addNode(std::unique_ptr node); 28 | void replaceNode(Node* oldNode, std::unique_ptr newNode); 29 | 30 | private: 31 | void updateEdges(Node* node); 32 | void addIncomingEdges(Node* node); 33 | bool isInputNode(Node* node) const; 34 | void addOutgoingEdges(Node* node); 35 | void topologicalSortUtil(Node* node, std::unordered_set& visited, 36 | std::stack& stack); 37 | std::vector inputs_; 38 | std::vector outputs_; 39 | std::unordered_map nodeMap_; 40 | std::vector sortedNodes_; 41 | }; 42 | 43 | #endif -------------------------------------------------------------------------------- /src/inference_session.cpp: -------------------------------------------------------------------------------- 1 | #include "inference_session.h" 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "device.h" 8 | 9 | void validate_model(const onnx::ModelProto& model, const ModelConfig& config); 10 | std::unordered_map> load_weights( 11 | const onnx::ModelProto& model); 12 | 
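// Typical call sequence for InferenceSession (a usage sketch mirroring
// main.cpp; the config path and tensor names here are illustrative only):
//
//   InferenceSession session;
//   session.load_model(ModelConfig("model_repository/mnist.yaml"));
//   session.set_execution_provider(std::make_unique<CpuProvider>());
//   session.initialize_provider();  // copies weights to the provider's device
//   session.set_input("input", std::move(input_tensor));
//   session.run();
//   auto output = session.get_output("output");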
13 | void InferenceSession::set_execution_provider( 14 | std::unique_ptr engine) { 15 | engine_ = std::move(engine); 16 | } 17 | 18 | void InferenceSession::initialize_provider() { 19 | pinned_allocator_ = std::make_shared(); 20 | engine_->transferWeightsToDevice(weights_); 21 | } 22 | 23 | void InferenceSession::set_input(const std::string& name, Tensor input) { 24 | weights_.insert_or_assign(name, std::move(input)); 25 | } 26 | 27 | Tensor InferenceSession::get_output(const std::string& name) { 28 | auto it = weights_.find(name); 29 | if (it == weights_.end()) { 30 | throw std::runtime_error("Output not found: " + name); 31 | } 32 | auto res = std::move(it->second); 33 | weights_.erase(it); 34 | res.to(DeviceType::CPU, pinned_allocator_); 35 | return res; 36 | } 37 | 38 | void InferenceSession::run() { 39 | for (const auto node : graph_->getTopologicallySortedNodes()) { 40 | auto inputs = prepare_node_inputs(node); 41 | auto output = engine_->evaluateNode(node, inputs); 42 | 43 | if (output.size() != 0) { 44 | weights_.insert_or_assign(node->getOutputs()[0], std::move(output)); 45 | } else { 46 | throw std::runtime_error("Got empty output after inference loop."); 47 | } 48 | } 49 | } 50 | 51 | std::vector*> InferenceSession::prepare_node_inputs( 52 | const Node* node) { 53 | std::vector*> inputs; 54 | const auto& input_names = node->getInputs(); 55 | inputs.reserve(input_names.size()); 56 | 57 | for (const auto& input_name : input_names) { 58 | auto it = weights_.find(input_name); 59 | if (it == weights_.end()) { 60 | throw std::runtime_error("Input not found: " + input_name); 61 | } 62 | inputs.push_back(&it->second); 63 | } 64 | return inputs; 65 | } 66 | 67 | void InferenceSession::load_model(const ModelConfig& config) { 68 | onnx::ModelProto model; 69 | { 70 | std::ifstream input(config.get_model_path(), std::ios::binary); 71 | if (!input || !model.ParseFromIstream(&input)) { 72 | throw std::runtime_error( 73 | "Failed to load or parse the ONNX model: " + 74 | config.get_model_path()); 75 | } 76 | } 77 | 78 | if (!model.has_graph() || model.graph().node_size() == 0) { 79 | throw std::runtime_error("Invalid ONNX model: missing graph or nodes"); 80 | } 81 | 82 | validate_model(model, config); 83 | 84 | weights_ = load_weights(model); 85 | graph_ = std::make_unique(model.graph()); 86 | } 87 | 88 | void validate_model(const onnx::ModelProto& model, const ModelConfig& config) { 89 | if ((std::size_t)model.graph().input_size() != config.get_inputs().size()) { 90 | throw std::runtime_error( 91 | "Mismatch in number of inputs between model and config"); 92 | } 93 | for (int i = 0; i < model.graph().input_size(); ++i) { 94 | const auto& model_input = model.graph().input(i); 95 | const auto& config_input = config.get_inputs()[i]; 96 | if (model_input.name() != config_input.name) { 97 | throw std::runtime_error( 98 | "Mismatch in input names between model and config. Got: " + 99 | model_input.name() + ", but expected: " + config_input.name); 100 | } 101 | } 102 | 103 | if ((std::size_t)model.graph().output_size() != 104 | config.get_outputs().size()) { 105 | throw std::runtime_error( 106 | "Mismatch in number of outputs between model and config"); 107 | } 108 | 109 | for (int i = 0; i < model.graph().output_size(); ++i) { 110 | const auto& model_output = model.graph().output(i); 111 | const auto& config_output = config.get_outputs()[i]; 112 | if (model_output.name() != config_output.name) { 113 | throw std::runtime_error( 114 | "Mismatch in output names between model and config. 
Got: " + 115 | model_output.name() + ", but expected: " + config_output.name); 116 | } 117 | } 118 | } 119 | 120 | std::unordered_map> load_weights( 121 | const onnx::ModelProto& model) { 122 | std::unordered_map> weights; 123 | for (const auto& initializer : model.graph().initializer()) { 124 | if (initializer.data_type() != onnx::TensorProto::FLOAT) { 125 | throw std::runtime_error("Unsupported initializer data type"); 126 | } 127 | 128 | const auto& raw_data = initializer.raw_data(); 129 | const float* data_ptr = reinterpret_cast(raw_data.data()); 130 | 131 | std::vector shape(initializer.dims().begin(), 132 | initializer.dims().end()); 133 | 134 | weights.emplace(initializer.name(), 135 | Tensor{data_ptr, std::move(shape)}); 136 | } 137 | return weights; 138 | } -------------------------------------------------------------------------------- /src/inference_session.h: -------------------------------------------------------------------------------- 1 | #ifndef INFERENCE_SESSION_H 2 | #define INFERENCE_SESSION_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "execution_provider.h" 9 | #include "graph.h" 10 | #include "model_config.h" 11 | #include "pinned_cpu_allocator.h" 12 | #include "tensor.h" 13 | 14 | class InferenceSession { 15 | public: 16 | void load_model(const ModelConfig& config); 17 | void set_execution_provider(std::unique_ptr engine); 18 | void initialize_provider(); 19 | void set_input(const std::string& name, Tensor input); 20 | Tensor get_output(const std::string& name); 21 | void run(); 22 | 23 | private: 24 | std::vector*> prepare_node_inputs(const Node* node); 25 | 26 | std::shared_ptr pinned_allocator_; 27 | std::unique_ptr engine_; 28 | std::unique_ptr graph_; 29 | std::unordered_map> weights_; 30 | }; 31 | 32 | #endif 33 | -------------------------------------------------------------------------------- /src/input_loader.cpp: -------------------------------------------------------------------------------- 1 | #include "input_loader.h" 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | Tensor load_input(const std::string& filename, 8 | const ModelConfig& config) { 9 | std::ifstream file(filename, std::ios::binary); 10 | if (!file) { 11 | throw std::runtime_error("Error opening file: " + filename); 12 | } 13 | 14 | const auto& input_shape = config.get_inputs()[0].shape; 15 | size_t expected_size = 1; 16 | for (const auto& dim : input_shape) { 17 | expected_size *= dim; 18 | } 19 | std::vector bytes(expected_size); 20 | 21 | file.read(reinterpret_cast(bytes.data()), expected_size); 22 | 23 | if (file.gcount() != static_cast(expected_size)) { 24 | throw std::runtime_error("Unexpected file size: " + filename); 25 | } 26 | 27 | std::vector floatValues; 28 | floatValues.reserve(expected_size); 29 | 30 | for (unsigned char byte : bytes) { 31 | floatValues.push_back(static_cast(byte)); 32 | } 33 | return Tensor{std::move(floatValues.data()), input_shape}; 34 | } 35 | -------------------------------------------------------------------------------- /src/input_loader.h: -------------------------------------------------------------------------------- 1 | #ifndef INPUT_LOADER_H 2 | #define INPUT_LOADER_H 3 | 4 | #include 5 | 6 | #include "model_config.h" 7 | #include "tensor.h" 8 | 9 | Tensor load_input(const std::string& filename, 10 | const ModelConfig& config); 11 | 12 | #endif // INPUT_LOADER_H -------------------------------------------------------------------------------- /src/kernels.cu: 
-------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | const std::size_t BLOCK_SIZE{16}; 8 | 9 | #define CEIL_DIV(M, N) (((M) + (N) - 1) / (N)) 10 | 11 | /* 12 | Tensors are of shape: 13 | A: (n, m) 14 | B: (m, k) 15 | C: (n, k) 16 | */ 17 | __global__ void gemm_kernel_naive(const float *A, const float *B, 18 | const float *bias, float *out, int n, int m, 19 | int k, bool transA, bool transB, float alpha, 20 | float beta) { 21 | const uint col = blockIdx.x * blockDim.x + threadIdx.x; 22 | const uint row = blockIdx.y * blockDim.y + threadIdx.y; 23 | 24 | if (row < n && col < k) { 25 | float res = 0.0f; 26 | 27 | for (int i = 0; i < m; ++i) { 28 | float aVal = transA ? A[i * n + row] : A[row * m + i]; 29 | float bVal = transB ? B[col * m + i] : B[i * k + col]; 30 | res += aVal * bVal; 31 | } 32 | out[row * k + col] = res * alpha + bias[col] * beta; 33 | } 34 | } 35 | 36 | template 37 | __global__ void gemm_kernel_tiled(const float *A, const float *B, 38 | const float *bias, float *out, int n, int m, 39 | int k, bool transA, bool transB, float alpha, 40 | float beta) { 41 | __shared__ float As[BLOCKSIZE][BLOCKSIZE]; 42 | __shared__ float Bs[BLOCKSIZE][BLOCKSIZE]; 43 | 44 | uint bx = blockIdx.x; 45 | uint by = blockIdx.y; 46 | uint tx = threadIdx.x; 47 | uint ty = threadIdx.y; 48 | 49 | int row = by * BLOCKSIZE + ty; 50 | int col = bx * BLOCKSIZE + tx; 51 | 52 | // Calcualtes single entry of C per block. 53 | float res = 0.0f; 54 | 55 | for (int blkIdx = 0; blkIdx < CEIL_DIV(m, BLOCKSIZE); ++blkIdx) { 56 | // Collaboratively load tile 57 | if (row < n && blkIdx * BLOCKSIZE + tx < m) { 58 | // handle transpose. 59 | As[ty][tx] = transA ? A[(blkIdx * BLOCKSIZE + tx) * n + row] 60 | : A[row * m + (blkIdx * BLOCKSIZE + tx)]; 61 | } else { 62 | // Out of bounds defaults to 0. 63 | As[ty][tx] = 0.0f; 64 | } 65 | 66 | if (blkIdx * BLOCKSIZE + ty < m && col < k) { 67 | // Handle transpose. 68 | Bs[ty][tx] = transB ? B[col * m + (blkIdx * BLOCKSIZE + ty)] 69 | : B[(blkIdx * BLOCKSIZE + ty) * k + col]; 70 | } else { 71 | // Out of bounds defaults to 0. 72 | Bs[ty][tx] = 0.0f; 73 | } 74 | 75 | __syncthreads(); 76 | 77 | // Matmul over tile 78 | for (int i = 0; i < BLOCKSIZE; ++i) { 79 | res += As[ty][i] * Bs[i][tx]; 80 | } 81 | 82 | __syncthreads(); 83 | } 84 | 85 | if (row < n && col < k) { 86 | out[row * k + col] = res * alpha + bias[col] * beta; 87 | } 88 | } 89 | 90 | /* 91 | Tensors are of shape: 92 | A: (n, m) 93 | B: (m, k) 94 | C: (n, k) 95 | */ 96 | template 97 | __global__ void gemm_kernel_tiled_1D(const float *A, const float *B, 98 | const float *bias, float *out, int N, 99 | int M, int K, bool transA, bool transB, 100 | float alpha, float beta) { 101 | __shared__ float As[BLOCKSIZE][BLOCKSIZE]; 102 | __shared__ float Bs[BLOCKSIZE][BLOCKSIZE]; 103 | 104 | // Block x y coordinates 105 | uint bx = blockIdx.x; 106 | uint by = blockIdx.y; 107 | 108 | // thread x y coordinates. 109 | uint tx = threadIdx.x % BLOCKSIZE; 110 | uint ty = threadIdx.x / BLOCKSIZE; 111 | 112 | // output row column. 113 | int row = by * BLOCKSIZE + ty; 114 | int col = bx * BLOCKSIZE + tx; 115 | 116 | // Calcualtes single entry of C per block. 117 | float res = 0.0f; 118 | 119 | for (int blkIdx = 0; blkIdx < CEIL_DIV(M, BLOCKSIZE); ++blkIdx) { 120 | // Collaboratively load tile 121 | if (row < N && blkIdx * BLOCKSIZE + tx < M) { 122 | As[ty][tx] = transA ? 
A[(blkIdx * BLOCKSIZE + tx) * N + row] 123 | : A[row * M + (blkIdx * BLOCKSIZE + tx)]; 124 | } else { 125 | // Out of bounds defaults to 0. 126 | As[ty][tx] = 0.0f; 127 | } 128 | 129 | if (blkIdx * BLOCKSIZE + ty < M && col < K) { 130 | Bs[ty][tx] = transB ? B[col * M + (blkIdx * BLOCKSIZE + ty)] 131 | : B[(blkIdx * BLOCKSIZE + ty) * K + col]; 132 | } else { 133 | // Out of bounds defaults to 0. 134 | Bs[ty][tx] = 0.0f; 135 | } 136 | 137 | __syncthreads(); 138 | 139 | // Matmul over tile 140 | for (int i = 0; i < BLOCKSIZE; ++i) { 141 | res += As[ty][i] * Bs[i][tx]; 142 | } 143 | 144 | __syncthreads(); 145 | } 146 | 147 | if (row < N && col < K) { 148 | out[row * K + col] = res * alpha + bias[col] * beta; 149 | } 150 | } 151 | 152 | /** Difference from simple tiled approach: 153 | 154 | This implementation only handles non-transposed matrices as indicated by the 155 | _nn suffix. 156 | 157 | 1. Tiles are not (BLOCKSIZE, BLOCKSIZE), but rather (BN, BM) and (BM, 158 | BK). Important is that each thread has to process (BN*BK) / ((BN*BM + 159 | BM+BK)/2) = (2*BN*BK)/(BN*BM + BM+BK) results. So then supposing we use 160 | (64, 8), (8, 64), we get 2*64*64 / (16*64) = 8. 161 | 162 | 2. Each thread still only loads one value, but then computes TM 163 | outputs. Each thread calculates an (TM,1) vector of C as a column. So thread 164 | (0,0,0) will calculate C[0][0], C[1][0], ... C[TM][0]. 165 | 166 | Tensors are of shape: 167 | A: (n, m) 168 | B: (m, k) 169 | C: (n, k) 170 | bias: (1, k) 171 | */ 172 | template 173 | __global__ void gemm_kernel_tiled_1D_blocktiling_nn(const float *A, 174 | const float *B, 175 | const float *bias, float *C, 176 | int N, int M, int K, 177 | float alpha, float beta) { 178 | __shared__ float As[BN][BM]; // 64, 8 179 | __shared__ float Bs[BM][BK]; // 8, 64 180 | 181 | // Block x y coordinates 182 | uint bx = blockIdx.x; // (0 - K div BK) 183 | uint by = blockIdx.y; // (0 - N div BN) 184 | 185 | // Thread for 186 | const uint threadCol = threadIdx.x % BK; // range: 0-64 187 | const uint threadRow = threadIdx.x / BK; // range: 0-8 188 | 189 | assert(BN * BM == blockDim.x); 190 | assert(BM * BK == blockDim.x); 191 | const uint innerColA = threadIdx.x % BM; // range: 0 - 8 192 | const uint innerRowA = threadIdx.x / BM; // range: 0 - 64 193 | const uint innerColB = threadIdx.x % BK; // range: 0 - 64 194 | const uint innerRowB = threadIdx.x / BK; // range: 0 - 8 195 | 196 | // Each thread block processes one block row of A and block column of B. 197 | // This means we need to correctly account for offsets depending on in which 198 | // thread block this kernel is. 199 | 200 | // This says how many rows * width of A to skip. 201 | // by * BN gives the size of the block tile and M is the row width. 202 | const uint A_baseoffset = by * BN * M; 203 | 204 | // This says how many columns * height of B to sip. 205 | // bx * BK gives the number of columns to skip. Using this offset, whenever 206 | // we index into a row, we'll automatically skip the first offset columns. 207 | const uint B_baseoffset = bx * BK; 208 | 209 | // Each thread calculates TM results. 210 | float threadResults[TM] = {0.0}; 211 | 212 | // Outer loop over block tiles. Block tile in A is of shape (BN, BM) and 213 | // block tile in B is of shape (BM, BK). Supposing A and B are both 214 | // (128, 128), and (BN=64, BM=8), (BM=8, BK=64). Each thread block will 215 | // calculate a 64x64 block of C. This is done by sliding a (64x8) 216 | // blocktile across A and B in lock-step. 
Partial results are 217 | // accummulated and once fully slided, written to C. 218 | for (uint blockTileOffset = 0; blockTileOffset < M; blockTileOffset += BM) { 219 | const bool validRowA = (by * BN + innerRowA) < N; 220 | const bool validColA = (blockTileOffset + innerColA) < M; 221 | const uint A_idx = 222 | A_baseoffset + blockTileOffset + innerRowA * M + innerColA; 223 | if (validRowA && validColA && A_idx < N * M) { 224 | As[innerRowA][innerColA] = A[A_idx]; 225 | } else { 226 | As[innerRowA][innerColA] = 0; 227 | } 228 | 229 | const bool validRowB = (blockTileOffset + innerRowB) < M; 230 | const bool validColB = (bx * BK + innerColB) < K; 231 | const uint B_idx = 232 | B_baseoffset + blockTileOffset * K + innerColB + innerRowB * K; 233 | if (validRowB && validColB && B_idx < M * K) { 234 | Bs[innerRowB][innerColB] = B[B_idx]; 235 | } else { 236 | Bs[innerRowB][innerColB] = 0; 237 | } 238 | 239 | __syncthreads(); 240 | // Calculate per-thread results. Each thread calculates TM dot products. 241 | // This is done by taking TM rows of As and multiplying them with a 242 | // single column of Bs. 243 | for (uint resIdx = 0; resIdx < TM; ++resIdx) { 244 | for (uint dotIdx = 0; dotIdx < BM; ++dotIdx) { 245 | threadResults[resIdx] += 246 | As[threadRow * TM + resIdx][dotIdx] * Bs[dotIdx][threadCol]; 247 | } 248 | } 249 | __syncthreads(); 250 | } 251 | 252 | // write out the results 253 | for (uint resIdx = 0; resIdx < TM; ++resIdx) { 254 | const uint rowC = by * BN + threadRow * TM + resIdx; 255 | const uint colC = bx * BK + threadCol; 256 | 257 | if (rowC < N && colC < K) { 258 | C[rowC * K + colC] = 259 | alpha * threadResults[resIdx] + beta * bias[colC]; 260 | } 261 | } 262 | } 263 | 264 | template 265 | __global__ void gemm_kernel_tiled_1D_blocktiling_nt(const float *A, 266 | const float *B, 267 | const float *bias, float *C, 268 | int N, int M, int K, 269 | float alpha, float beta) { 270 | __shared__ float As[BN][BM]; // 64, 8 271 | __shared__ float Bs[BM][BK]; // 8, 64 272 | 273 | // Block coordinates remain the same 274 | uint bx = blockIdx.x; // (0 - K div BK) 275 | uint by = blockIdx.y; // (0 - N div BN) 276 | 277 | const uint threadCol = threadIdx.x % BK; 278 | const uint threadRow = threadIdx.x / BK; 279 | 280 | assert(BN * BM == blockDim.x); 281 | assert(BM * BK == blockDim.x); 282 | const uint innerColA = threadIdx.x % BM; 283 | const uint innerRowA = threadIdx.x / BM; 284 | const uint innerColB = threadIdx.x % BK; 285 | const uint innerRowB = threadIdx.x / BK; 286 | 287 | // A's loading pattern remains the same 288 | const uint A_baseoffset = by * BN * M; 289 | 290 | // B's base offset changes since B is transposed 291 | // Instead of columns, we're now skipping rows in the transposed matrix 292 | const uint B_baseoffset = 293 | bx * BK * M; // Note: multiplied by M instead of 1 294 | 295 | float threadResults[TM] = {0.0}; 296 | 297 | for (uint blockTileOffset = 0; blockTileOffset < M; blockTileOffset += BM) { 298 | // A loading remains identical 299 | const bool validRowA = (by * BN + innerRowA) < N; 300 | const bool validColA = (blockTileOffset + innerColA) < M; 301 | const uint A_idx = 302 | A_baseoffset + blockTileOffset + innerRowA * M + innerColA; 303 | if (validRowA && validColA && A_idx < N * M) { 304 | As[innerRowA][innerColA] = A[A_idx]; 305 | } else { 306 | As[innerRowA][innerColA] = 0; 307 | } 308 | 309 | // B loading pattern changes for transposed case 310 | const bool validRowB = (blockTileOffset + innerRowB) < M; 311 | const bool validColB = (bx * BK + 
innerColB) < K; 312 | // Key change: For transposed B, we access it as B[k][m] instead of 313 | // B[m][k] 314 | const uint B_idx = 315 | B_baseoffset + blockTileOffset + innerColB * M + innerRowB; 316 | if (validRowB && validColB && B_idx < M * K) { 317 | Bs[innerRowB][innerColB] = B[B_idx]; 318 | } else { 319 | Bs[innerRowB][innerColB] = 0; 320 | } 321 | 322 | __syncthreads(); 323 | 324 | // The multiplication logic remains identical since we've loaded 325 | // the data into shared memory in the same format 326 | for (uint resIdx = 0; resIdx < TM; ++resIdx) { 327 | for (uint dotIdx = 0; dotIdx < BM; ++dotIdx) { 328 | threadResults[resIdx] += 329 | As[threadRow * TM + resIdx][dotIdx] * Bs[dotIdx][threadCol]; 330 | } 331 | } 332 | __syncthreads(); 333 | } 334 | 335 | // Result writing remains the same 336 | for (uint resIdx = 0; resIdx < TM; ++resIdx) { 337 | const uint rowC = by * BN + threadRow * TM + resIdx; 338 | const uint colC = bx * BK + threadCol; 339 | 340 | if (rowC < N && colC < K) { 341 | C[rowC * K + colC] = 342 | alpha * threadResults[resIdx] + beta * bias[colC]; 343 | } 344 | } 345 | } 346 | 347 | void gemm_cuda_tiled(const float *A, const float *B, const float *bias, 348 | float *out, int n, int m, int k, bool transA, bool transB, 349 | float alpha, float beta) { 350 | const uint blockSize{16}; 351 | dim3 blockDim(blockSize, blockSize); 352 | dim3 gridDim(CEIL_DIV(k, blockSize), CEIL_DIV(n, blockSize)); 353 | 354 | gemm_kernel_tiled<<>>( 355 | A, B, bias, out, n, m, k, transA, transB, alpha, beta); 356 | } 357 | 358 | void gemm_cuda_tiled_1D(const float *A, const float *B, const float *bias, 359 | float *out, int n, int m, int k, bool transA, 360 | bool transB, float alpha, float beta) { 361 | const uint blockSize{16}; 362 | dim3 blockDim(blockSize * blockSize); 363 | dim3 gridDim(CEIL_DIV(k, blockSize), CEIL_DIV(n, blockSize)); 364 | 365 | gemm_kernel_tiled_1D<<>>( 366 | A, B, bias, out, n, m, k, transA, transB, alpha, beta); 367 | } 368 | 369 | void gemm_tiled_1D_blocktiling(const float *A, const float *B, 370 | const float *bias, float *out, int n, int m, 371 | int k, bool transA, bool transB, float alpha, 372 | float beta) { 373 | const uint BN{64}; 374 | const uint BM{8}; 375 | const uint BK{64}; 376 | const uint TM{8}; 377 | dim3 gridDim(CEIL_DIV(k, BK), CEIL_DIV(n, BN)); 378 | dim3 blockDim((BN * BK) / TM); 379 | if (!transA && !transB) { 380 | gemm_kernel_tiled_1D_blocktiling_nn 381 | <<>>(A, B, bias, out, n, m, k, alpha, beta); 382 | } else if (!transA && transB) { 383 | gemm_kernel_tiled_1D_blocktiling_nt 384 | <<>>(A, B, bias, out, n, m, k, alpha, beta); 385 | } 386 | 387 | else { 388 | throw std::runtime_error( 389 | "Only NN, NT blocktiling kernels are implemented."); 390 | } 391 | } 392 | 393 | void gemm_cuda_naive(const float *A, const float *B, const float *bias, 394 | float *out, int n, int m, int k, bool transA, bool transB, 395 | float alpha, float beta) { 396 | dim3 blockDim(BLOCK_SIZE, BLOCK_SIZE, 1); 397 | dim3 gridDim((k + blockDim.x - 1) / blockDim.x, 398 | (n + blockDim.y - 1) / blockDim.y); 399 | 400 | gemm_kernel_naive<<>>(A, B, bias, out, n, m, k, transA, 401 | transB, alpha, beta); 402 | } 403 | 404 | void gemm_cuda_unoptimized(const float *A, const float *B, const float *bias, 405 | float *out, int n, int m, int k, bool transA, 406 | bool transB, float alpha, float beta) { 407 | float *d_A, *d_B, *d_bias, *d_out; 408 | 409 | cudaMalloc((void **)&d_A, n * m * sizeof(float)); 410 | cudaMalloc((void **)&d_B, m * k * sizeof(float)); 411 | 
cudaMalloc((void **)&d_bias, k * sizeof(float)); 412 | cudaMalloc((void **)&d_out, n * k * sizeof(float)); 413 | 414 | cudaMemcpy(d_A, A, n * m * sizeof(float), cudaMemcpyHostToDevice); 415 | cudaMemcpy(d_B, B, m * k * sizeof(float), cudaMemcpyHostToDevice); 416 | cudaMemcpy(d_bias, bias, k * sizeof(float), cudaMemcpyHostToDevice); 417 | 418 | dim3 blockSize(BLOCK_SIZE, BLOCK_SIZE); 419 | dim3 gridSize((k + blockSize.x - 1) / blockSize.x, 420 | (n + blockSize.y - 1) / blockSize.y); 421 | 422 | gemm_kernel_naive<<<gridSize, blockSize>>>(d_A, d_B, d_bias, d_out, n, m, k, 423 | transA, transB, alpha, beta); 424 | 425 | cudaMemcpy(out, d_out, n * k * sizeof(float), cudaMemcpyDeviceToHost); 426 | 427 | cudaFree(d_A); 428 | cudaFree(d_B); 429 | cudaFree(d_bias); 430 | cudaFree(d_out); 431 | } 432 | 433 | __global__ void relu_kernel(const float *in, float *out, int n) { 434 | int idx = blockIdx.x * blockDim.x + threadIdx.x; 435 | if (idx < n) { 436 | out[idx] = in[idx] < 0 ? 0 : in[idx]; 437 | } 438 | } 439 | 440 | void relu_cuda(const float *in, float *out, int n) { 441 | relu_kernel<<>>(in, out, n); 442 | } 443 | 444 | void relu_cuda_unoptimized(const float *in, float *out, int n) { 445 | float *d_in, *d_out; 446 | cudaMalloc((void **)&d_in, n * sizeof(float)); 447 | cudaMalloc((void **)&d_out, n * sizeof(float)); 448 | 449 | cudaMemcpy(d_in, in, n * sizeof(float), cudaMemcpyHostToDevice); 450 | 451 | relu_kernel<<>>(d_in, d_out, n); 452 | 453 | cudaMemcpy(out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost); 454 | cudaFree(d_in); 455 | cudaFree(d_out); 456 | } 457 | 458 | __global__ void add_kernel(const float *A, const float *B, float *out, int n) { 459 | int idx = blockIdx.x * blockDim.x + threadIdx.x; 460 | if (idx < n) { 461 | out[idx] = A[idx] + B[idx]; 462 | } 463 | } 464 | 465 | void add_cuda(const float *A, const float *B, float *out, int n) { 466 | add_kernel<<>>(A, B, out, n); 467 | } 468 | 469 | void add_cuda_unoptimized(const float *A, const float *B, float *out, int n) { 470 | float *d_A, *d_B, *d_out; 471 | cudaMalloc((void **)&d_A, n * sizeof(float)); 472 | cudaMalloc((void **)&d_B, n * sizeof(float)); 473 | cudaMalloc((void **)&d_out, n * sizeof(float)); 474 | 475 | cudaMemcpy(d_A, A, n * sizeof(float), cudaMemcpyHostToDevice); 476 | cudaMemcpy(d_B, B, n * sizeof(float), cudaMemcpyHostToDevice); 477 | 478 | add_kernel<<>>(d_A, d_B, d_out, n); 479 | 480 | cudaMemcpy(out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost); 481 | cudaFree(d_A); 482 | cudaFree(d_B); 483 | cudaFree(d_out); 484 | } -------------------------------------------------------------------------------- /src/kernels.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNELS_H 2 | #define KERNELS_H 3 | 4 | void gemm_cuda_tiled(const float* A, const float* B, const float* bias, 5 | float* out, int n, int m, int k, bool transA, bool transB, 6 | float alpha, float beta); 7 | 8 | void gemm_cuda_tiled_1D(const float* A, const float* B, const float* bias, 9 | float* out, int n, int m, int k, bool transA, 10 | bool transB, float alpha, float beta); 11 | 12 | void gemm_tiled_1D_blocktiling(const float* A, const float* B, 13 | const float* bias, float* out, int n, int m, 14 | int k, bool transA, bool transB, float alpha, 15 | float beta); 16 | 17 | void gemm_cuda_naive(const float* A, const float* B, const float* bias, 18 | float* out, int n, int m, int k, bool transA, bool transB, 19 | float alpha, float beta); 20 | 21 | void gemm_cuda_unoptimized(const float* A, const float* B, const float*
bias, 22 | float* out, int n, int m, int k, bool transA, 23 | bool transB, float alpha, float beta); 24 | 25 | void relu_cuda(const float* in, float* out, int n); 26 | 27 | void relu_cuda_unoptimized(const float* in, float* out, int n); 28 | 29 | void add_cuda(const float* A, const float* B, float* out, int n); 30 | 31 | void add_cuda_unoptimized(const float* A, const float* B, float* out, int n); 32 | 33 | #endif // KERNELS_H 34 | -------------------------------------------------------------------------------- /src/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "cpu_provider.h" 7 | #include "cuda_provider.h" 8 | #include "cuda_provider_unoptimized.h" 9 | #include "inference_session.h" 10 | #include "input_loader.h" 11 | #include "model_config.h" 12 | 13 | int main(int argc, char** argv) { 14 | if (argc != 2) { 15 | std::cerr << "Usage: " << argv[0] << " " << std::endl; 16 | return 1; 17 | } 18 | 19 | ModelConfig config(argv[1]); 20 | std::cout << "Model path: " << config.get_model_path() << std::endl; 21 | std::cout << "Device: " 22 | << (config.get_device() == Device::CPU ? "CPU" : "CUDA") 23 | << std::endl; 24 | std::cout << "Batch size: " << config.get_batch_size() << std::endl; 25 | 26 | InferenceSession session; 27 | session.load_model(config); 28 | 29 | std::unique_ptr provider; 30 | Device device = config.get_device(); 31 | if (device == Device::CPU) { 32 | provider = std::make_unique(); 33 | } else if (device == Device::CUDA) { 34 | provider = std::make_unique(); 35 | } else if (device == Device::CUDA_SLOW) { 36 | provider = std::make_unique(); 37 | } else { 38 | throw std::runtime_error("Unknown device type"); 39 | } 40 | 41 | session.set_execution_provider(std::move(provider)); 42 | 43 | // Moves model weights to device memory. 44 | auto start = std::chrono::high_resolution_clock::now(); 45 | session.initialize_provider(); 46 | auto end = std::chrono::high_resolution_clock::now(); 47 | auto duration = 48 | std::chrono::duration_cast(end - start); 49 | 50 | std::cout << "initialize_provider took: " << duration.count() 51 | << " microseconds" << std::endl; 52 | 53 | std::string file = "./inputs/image_"; 54 | 55 | // Preload all inputs into memory 56 | int loops{100}; 57 | int inferences{1}; 58 | int total_inferences{loops * inferences}; 59 | std::vector> inputs; 60 | inputs.reserve(total_inferences); 61 | for (int j = 0; j < loops; ++j) { 62 | for (int i = 0; i < inferences; ++i) { 63 | std::ostringstream oss; 64 | oss << file << i << ".ubyte"; 65 | std::string formattedString = oss.str(); 66 | inputs.push_back(load_input(formattedString, config)); 67 | } 68 | } 69 | 70 | // Create mini batches. Batch size is configured via yaml file. 
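// Illustrative arithmetic: with the 100 inputs preloaded above and a YAML
// batch_size of 4, the loop below produces 25 mini batches, each stacked into
// one tensor with Tensor::stack; if total_inferences is not a multiple of the
// batch size, the final batch is simply smaller.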
71 | std::vector> mini_batches; 72 | int i = 0; 73 | while (i < total_inferences) { 74 | std::vector> batch; 75 | for (int j = 0; j < config.get_batch_size() && i + j < total_inferences; 76 | ++j) { 77 | batch.push_back(inputs[i + j]); 78 | } 79 | if (batch.size() > 0) { 80 | mini_batches.push_back(Tensor::stack(batch)); 81 | } 82 | i += batch.size(); 83 | } 84 | 85 | // Inference 86 | std::vector> res; 87 | res.reserve(total_inferences); 88 | start = std::chrono::high_resolution_clock::now(); 89 | for (auto batch : mini_batches) { 90 | session.set_input(config.get_inputs()[0].name, std::move(batch)); 91 | session.run(); 92 | res.push_back(session.get_output(config.get_outputs()[0].name)); 93 | } 94 | end = std::chrono::high_resolution_clock::now(); 95 | 96 | for (auto v : res) { 97 | std::cout << "Out: " << v.toString() << "\n"; 98 | } 99 | 100 | duration = 101 | std::chrono::duration_cast(end - start); 102 | 103 | std::cout << "Loop took: " << duration.count() << " microseconds" 104 | << std::endl; 105 | std::cout << "Avg inference duration: " 106 | << duration.count() / total_inferences << " microseconds" 107 | << std::endl; 108 | 109 | auto duration_s = 110 | std::chrono::duration_cast(duration); 111 | std::cout << "throughput = " << 1000 * total_inferences / duration_s.count() 112 | << "\n"; 113 | 114 | return 0; 115 | } 116 | -------------------------------------------------------------------------------- /src/memory_allocator.h: -------------------------------------------------------------------------------- 1 | #ifndef MEMORY_ALLOCATOR_H 2 | #define MEMORY_ALLOCATOR_H 3 | 4 | #include 5 | 6 | #include "device.h" 7 | 8 | class Allocator { 9 | public: 10 | virtual void* allocate(size_t size) = 0; 11 | virtual void deallocate(void* ptr) = 0; 12 | virtual DeviceType getDeviceType() const = 0; 13 | virtual ~Allocator() = default; 14 | }; 15 | 16 | #endif -------------------------------------------------------------------------------- /src/model_config.cpp: -------------------------------------------------------------------------------- 1 | 2 | // model_config.cpp 3 | #include "model_config.h" 4 | 5 | #include 6 | #include 7 | 8 | ModelConfig::ModelConfig(const std::string& config_file) { 9 | parse_config_file(config_file); 10 | } 11 | 12 | void ModelConfig::parse_config_file(const std::string& config_file) { 13 | try { 14 | YAML::Node config = YAML::LoadFile(config_file); 15 | 16 | model_path = config["model_path"].as(); 17 | model_format = config["model_format"].as(); 18 | device = 19 | string_to_device(config["execution_provider"].as()); 20 | 21 | batch_size = config["batch_size"].as(1); 22 | 23 | if (config["inputs"]) { 24 | for (const auto& input : config["inputs"]) { 25 | inputs.push_back(parse_tensor_config(input)); 26 | } 27 | } 28 | 29 | if (config["outputs"]) { 30 | for (const auto& output : config["outputs"]) { 31 | outputs.push_back(parse_tensor_config(output)); 32 | } 33 | } 34 | 35 | } catch (const YAML::Exception& e) { 36 | throw std::runtime_error("Error parsing YAML config file: " + 37 | std::string(e.what())); 38 | } 39 | 40 | // Validate that all required fields are set 41 | if (model_path.empty() || model_format.empty() || inputs.empty() || 42 | outputs.empty()) { 43 | throw std::runtime_error( 44 | "Invalid config file: missing required fields"); 45 | } 46 | } 47 | 48 | DataType ModelConfig::string_to_data_type(const std::string& str) { 49 | if (str == "FLOAT32") return DataType::FLOAT32; 50 | throw std::runtime_error("Unknown data type: " + str); 51 | } 52 | 53 | 
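// Sketch of the YAML layout that parse_config_file expects (field names match
// the parsing code above; the values shown are placeholders, not a real config):
//
//   model_path: models/my_model.onnx
//   model_format: onnx
//   execution_provider: CUDA   # one of CPU, CUDA, CUDA_SLOW
//   batch_size: 4
//   inputs:
//     - name: input
//       shape: [1, 28, 28]
//       data_type: FLOAT32
//   outputs:
//     - name: output
//       shape: [1, 10]
//       data_type: FLOAT32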
Device ModelConfig::string_to_device(const std::string& str) { 54 | if (str == "CPU") return Device::CPU; 55 | if (str == "CUDA") return Device::CUDA; 56 | if (str == "CUDA_SLOW") return Device::CUDA_SLOW; 57 | throw std::runtime_error("Unknown device: " + str); 58 | } 59 | 60 | TensorConfig ModelConfig::parse_tensor_config(const YAML::Node& node) { 61 | TensorConfig tensor_config; 62 | tensor_config.name = node["name"].as(); 63 | tensor_config.shape = node["shape"].as>(); 64 | tensor_config.data_type = 65 | string_to_data_type(node["data_type"].as()); 66 | return tensor_config; 67 | } -------------------------------------------------------------------------------- /src/model_config.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef MODEL_CONFIG_H 3 | #define MODEL_CONFIG_H 4 | 5 | #include 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | enum class DataType { 12 | FLOAT32, 13 | }; 14 | 15 | enum class Device { 16 | CPU, 17 | CUDA, 18 | CUDA_SLOW, 19 | }; 20 | 21 | struct TensorConfig { 22 | std::string name; 23 | std::vector shape; 24 | DataType data_type; 25 | }; 26 | 27 | class ModelConfig { 28 | public: 29 | ModelConfig() = default; 30 | ModelConfig(const std::string& config_file); 31 | 32 | std::string get_model_path() const { return model_path; } 33 | std::string get_model_format() const { return model_format; } 34 | Device get_device() const { return device; } 35 | const std::vector& get_inputs() const { return inputs; } 36 | const std::vector& get_outputs() const { return outputs; } 37 | int get_batch_size() const { return batch_size; } 38 | 39 | private: 40 | std::string model_path; 41 | std::string model_format; 42 | Device device; 43 | std::vector inputs; 44 | std::vector outputs; 45 | int batch_size; 46 | 47 | void parse_config_file(const std::string& config_file); 48 | DataType string_to_data_type(const std::string& str); 49 | Device string_to_device(const std::string& str); 50 | TensorConfig parse_tensor_config(const YAML::Node& node); 51 | }; 52 | 53 | #endif -------------------------------------------------------------------------------- /src/node.cpp: -------------------------------------------------------------------------------- 1 | #include "node.h" 2 | 3 | #include "tensor.h" 4 | 5 | OpType onnxOpTypeConverter(const std::string opType); 6 | 7 | Node::Node(const std::string &name, OpType opType) 8 | : name(name), opType(opType) {} 9 | 10 | Node::Node(const onnx::NodeProto &nodeProto) 11 | : name(nodeProto.name()), 12 | opType(onnxOpTypeConverter(nodeProto.op_type())), 13 | inputs(nodeProto.input().begin(), nodeProto.input().end()), 14 | outputs(nodeProto.output().begin(), nodeProto.output().end()) { 15 | for (const auto &attrProto : nodeProto.attribute()) { 16 | attributes.emplace(attrProto.name(), Attribute(attrProto)); 17 | } 18 | } 19 | 20 | const std::string &Node::getName() const { return name; } 21 | 22 | OpType Node::getOpType() const { return opType; } 23 | 24 | const std::vector &Node::getInputs() const { return inputs; } 25 | 26 | const std::vector &Node::getOutputs() const { return outputs; } 27 | 28 | void Node::addInput(std::string input) { inputs.push_back(input); } 29 | 30 | void Node::addOutput(std::string output) { outputs.push_back(output); } 31 | 32 | template 33 | std::optional Node::getAttribute(const std::string &name) const { 34 | auto it = attributes.find(name); 35 | if (it != attributes.end() && 36 | std::holds_alternative(it->second.getValue())) { 37 | return std::get(it->second.getValue()); 
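// (Usage sketch: callers such as cpu_provider.cpp read attributes with
//  node->getAttribute<float>("alpha").value_or(1.0); a missing name or a
//  mismatched value type yields std::nullopt.)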
38 | } 39 | return std::nullopt; 40 | } 41 | 42 | OpType onnxOpTypeConverter(const std::string opType) { 43 | if (opType == "Gemm") { 44 | return OpType::Gemm; 45 | } 46 | if (opType == "Add") { 47 | return OpType::Add; 48 | } else if (opType == "Relu") { 49 | return OpType::Relu; 50 | } else if (opType == "Flatten") { 51 | return OpType::Flatten; 52 | } else if (opType == "Conv") { 53 | return OpType::Conv; 54 | } else if (opType == "MaxPool") { 55 | return OpType::MaxPool; 56 | } 57 | 58 | throw std::runtime_error("Unknown operation type: " + opType); 59 | } 60 | 61 | template std::optional Node::getAttribute(const std::string &) const; 62 | template std::optional Node::getAttribute(const std::string &) const; 63 | template std::optional> Node::getAttribute( 64 | const std::string &) const; -------------------------------------------------------------------------------- /src/node.h: -------------------------------------------------------------------------------- 1 | #ifndef NODE_H 2 | #define NODE_H 3 | 4 | #include 5 | #include 6 | 7 | #include "attribute.h" 8 | #include "onnx-ml.pb.h" 9 | #include "optypes.h" 10 | #include "tensor.h" 11 | 12 | class Node { 13 | public: 14 | Node(const std::string &name, const OpType optype); 15 | Node(const onnx::NodeProto &nodeProto); 16 | 17 | const std::string &getName() const; 18 | OpType getOpType() const; 19 | const std::vector &getInputs() const; 20 | const std::vector &getOutputs() const; 21 | const std::unordered_map &getAttributes() const; 22 | template 23 | std::optional getAttribute(const std::string &name) const; 24 | 25 | void addInput(std::string input); 26 | void addOutput(std::string output); 27 | 28 | private: 29 | std::string name; 30 | OpType opType; 31 | // Sorted list of inputs as expected the by the corresponding opType. 32 | std::vector inputs; 33 | std::vector outputs; 34 | std::unordered_map attributes; 35 | }; 36 | 37 | #endif -------------------------------------------------------------------------------- /src/onnx-ml.proto: -------------------------------------------------------------------------------- 1 | // 2 | // WARNING: This file is automatically generated! Please edit onnx.in.proto. 3 | // 4 | 5 | 6 | // SPDX-License-Identifier: Apache-2.0 7 | 8 | 9 | syntax = "proto2"; 10 | 11 | package onnx; 12 | 13 | // Overview 14 | // 15 | // ONNX is an open specification that is comprised of the following components: 16 | // 17 | // 1) A definition of an extensible computation graph model. 18 | // 2) Definitions of standard data types. 19 | // 3) Definitions of built-in operators. 20 | // 21 | // This document describes the syntax of models and their computation graphs, 22 | // as well as the standard data types. Together, they are referred to as the ONNX 23 | // Intermediate Representation, or 'IR' for short. 24 | // 25 | // The normative semantic specification of the ONNX IR is found in docs/IR.md. 26 | // Definitions of the built-in neural network operators may be found in docs/Operators.md. 27 | // Definitions of the built-in classical machine learning operators may be found in 28 | // docs/Operators-ml.md. 29 | 30 | // Notes 31 | // 32 | // Protobuf compatibility 33 | // 34 | // To simplify framework compatibility, ONNX is defined using the subset of protobuf 35 | // that is compatible with both protobuf v2 and v3. This means that we do not use any 36 | // protobuf features that are only available in one of the two versions. 
37 | // 38 | // Here are the most notable contortions we have to carry out to work around 39 | // these limitations: 40 | // 41 | // - No 'map' (added protobuf 3.0). We instead represent mappings as lists 42 | // of key-value pairs, where order does not matter and duplicates 43 | // are not allowed. 44 | 45 | 46 | // Versioning 47 | // 48 | // ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md 49 | // 50 | // To be compatible with both proto2 and proto3, we will use a version number 51 | // that is not defined by the default value but an explicit enum number. 52 | enum Version { 53 | // proto3 requires the first enum value to be zero. 54 | // We add this just to appease the compiler. 55 | _START_VERSION = 0; 56 | // The version field is always serialized and we will use it to store the 57 | // version that the graph is generated from. This helps us set up version 58 | // control. 59 | // For the IR, we are using simple numbers starting with 0x00000001, 60 | // which was the version we published on Oct 10, 2017. 61 | IR_VERSION_2017_10_10 = 0x0000000000000001; 62 | 63 | // IR_VERSION 2 published on Oct 30, 2017 64 | // - Added type discriminator to AttributeProto to support proto3 users 65 | IR_VERSION_2017_10_30 = 0x0000000000000002; 66 | 67 | // IR VERSION 3 published on Nov 3, 2017 68 | // - For operator versioning: 69 | // - Added new message OperatorSetIdProto 70 | // - Added opset_import in ModelProto 71 | // - For vendor extensions, added domain in NodeProto 72 | IR_VERSION_2017_11_3 = 0x0000000000000003; 73 | 74 | // IR VERSION 4 published on Jan 22, 2019 75 | // - Relax constraint that initializers should be a subset of graph inputs 76 | // - Add type BFLOAT16 77 | IR_VERSION_2019_1_22 = 0x0000000000000004; 78 | 79 | // IR VERSION 5 published on March 18, 2019 80 | // - Add message TensorAnnotation. 81 | // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. 82 | IR_VERSION_2019_3_18 = 0x0000000000000005; 83 | 84 | // IR VERSION 6 published on Sep 19, 2019 85 | // - Add support for sparse tensor constants stored in model. 86 | // - Add message SparseTensorProto 87 | // - Add sparse initializers 88 | IR_VERSION_2019_9_19 = 0x0000000000000006; 89 | 90 | // IR VERSION 7 published on May 8, 2020 91 | // - Add support to allow function body graph to rely on multiple external opreator sets. 92 | // - Add a list to promote inference graph's initializers to global and 93 | // mutable variables. Global variables are visible in all graphs of the 94 | // stored models. 95 | // - Add message TrainingInfoProto to store initialization 96 | // method and training algorithm. The execution of TrainingInfoProto 97 | // can modify the values of mutable variables. 98 | // - Implicitly add inference graph into each TrainingInfoProto's algorithm. 99 | IR_VERSION_2020_5_8 = 0x0000000000000007; 100 | 101 | // IR VERSION 8 published on July 30, 2021 102 | // Introduce TypeProto.SparseTensor 103 | // Introduce TypeProto.Optional 104 | // Added a list of FunctionProtos local to the model 105 | // Deprecated since_version and operator status from FunctionProto 106 | IR_VERSION_2021_7_30 = 0x0000000000000008; 107 | 108 | // IR VERSION 9 published on May 5, 2023 109 | // Added AttributeProto to FunctionProto so that default attribute values can be set. 110 | // Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ. 
111 | IR_VERSION_2023_5_5 = 0x0000000000000009; 112 | 113 | // IR VERSION 10 published on TBD 114 | // Added UINT4, INT4. 115 | IR_VERSION = 0x000000000000000A; 116 | } 117 | 118 | // Attributes 119 | // 120 | // A named attribute containing either singular float, integer, string, graph, 121 | // and tensor values, or repeated float, integer, string, graph, and tensor values. 122 | // An AttributeProto MUST contain the name field, and *only one* of the 123 | // following content fields, effectively enforcing a C/C++ union equivalent. 124 | message AttributeProto { 125 | reserved 12, 16 to 19; 126 | reserved "v"; 127 | 128 | // Note: this enum is structurally identical to the OpSchema::AttrType 129 | // enum defined in schema.h. If you rev one, you likely need to rev the other. 130 | enum AttributeType { 131 | UNDEFINED = 0; 132 | FLOAT = 1; 133 | INT = 2; 134 | STRING = 3; 135 | TENSOR = 4; 136 | GRAPH = 5; 137 | SPARSE_TENSOR = 11; 138 | TYPE_PROTO = 13; 139 | 140 | FLOATS = 6; 141 | INTS = 7; 142 | STRINGS = 8; 143 | TENSORS = 9; 144 | GRAPHS = 10; 145 | SPARSE_TENSORS = 12; 146 | TYPE_PROTOS = 14; 147 | } 148 | 149 | // The name field MUST be present for this version of the IR. 150 | optional string name = 1; // namespace Attribute 151 | 152 | // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. 153 | // In this case, this AttributeProto does not contain data, and it's a reference of attribute 154 | // in parent scope. 155 | // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. 156 | optional string ref_attr_name = 21; 157 | 158 | // A human-readable documentation for this attribute. Markdown is allowed. 159 | optional string doc_string = 13; 160 | 161 | // The type field MUST be present for this version of the IR. 162 | // For 0.0.1 versions of the IR, this field was not defined, and 163 | // implementations needed to use has_field heuristics to determine 164 | // which value field was in use. For IR_VERSION 0.0.2 or later, this 165 | // field MUST be set and match the f|i|s|t|... field in use. This 166 | // change was made to accommodate proto3 implementations. 167 | optional AttributeType type = 20; // discriminator that indicates which field below is in use 168 | 169 | // Exactly ONE of the following fields must be present for this version of the IR 170 | optional float f = 2; // float 171 | optional int64 i = 3; // int 172 | optional bytes s = 4; // UTF-8 string 173 | optional TensorProto t = 5; // tensor value 174 | optional GraphProto g = 6; // graph 175 | optional SparseTensorProto sparse_tensor = 22; // sparse tensor value 176 | // Do not use field below, it's deprecated. 177 | // optional ValueProto v = 12; // value - subsumes everything but graph 178 | optional TypeProto tp = 14; // type proto 179 | 180 | repeated float floats = 7; // list of floats 181 | repeated int64 ints = 8; // list of ints 182 | repeated bytes strings = 9; // list of UTF-8 strings 183 | repeated TensorProto tensors = 10; // list of tensors 184 | repeated GraphProto graphs = 11; // list of graph 185 | repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors 186 | repeated TypeProto type_protos = 15;// list of type protos 187 | } 188 | 189 | // Defines information on value, including the name, the type, and 190 | // the shape of the value. 191 | message ValueInfoProto { 192 | // This field MUST be present in this version of the IR. 
193 | optional string name = 1; // namespace Value 194 | // This field MUST be present in this version of the IR for 195 | // inputs and outputs of the top-level graph. 196 | optional TypeProto type = 2; 197 | // A human-readable documentation for this value. Markdown is allowed. 198 | optional string doc_string = 3; 199 | // Named metadata values; keys should be distinct. 200 | repeated StringStringEntryProto metadata_props = 4; 201 | } 202 | 203 | // Nodes 204 | // 205 | // Computation graphs are made up of a DAG of nodes, which represent what is 206 | // commonly called a "layer" or "pipeline stage" in machine learning frameworks. 207 | // 208 | // For example, it can be a node of type "Conv" that takes in an image, a filter 209 | // tensor and a bias tensor, and produces the convolved output. 210 | message NodeProto { 211 | repeated string input = 1; // namespace Value 212 | repeated string output = 2; // namespace Value 213 | 214 | // An optional identifier for this node in a graph. 215 | // This field MAY be absent in this version of the IR. 216 | optional string name = 3; // namespace Node 217 | 218 | // The symbolic identifier of the Operator to execute. 219 | optional string op_type = 4; // namespace Operator 220 | // The domain of the OperatorSet that specifies the operator named by op_type. 221 | optional string domain = 7; // namespace Domain 222 | // Overload identifier, used only to map this to a model-local function. 223 | optional string overload = 8; 224 | 225 | // Additional named attributes. 226 | repeated AttributeProto attribute = 5; 227 | 228 | // A human-readable documentation for this node. Markdown is allowed. 229 | optional string doc_string = 6; 230 | 231 | // Named metadata values; keys should be distinct. 232 | repeated StringStringEntryProto metadata_props = 9; 233 | } 234 | 235 | // Training information 236 | // TrainingInfoProto stores information for training a model. 237 | // In particular, this defines two functionalities: an initialization-step 238 | // and a training-algorithm-step. Initialization resets the model 239 | // back to its original state as if no training has been performed. 240 | // Training algorithm improves the model based on input data. 241 | // 242 | // The semantics of the initialization-step is that the initializers 243 | // in ModelProto.graph and in TrainingInfoProto.algorithm are first 244 | // initialized as specified by the initializers in the graph, and then 245 | // updated by the "initialization_binding" in every instance in 246 | // ModelProto.training_info. 247 | // 248 | // The field "algorithm" defines a computation graph which represents a 249 | // training algorithm's step. After the execution of a 250 | // TrainingInfoProto.algorithm, the initializers specified by "update_binding" 251 | // may be immediately updated. If the targeted training algorithm contains 252 | // consecutive update steps (such as block coordinate descent methods), 253 | // the user needs to create a TrainingInfoProto for each step. 254 | message TrainingInfoProto { 255 | // This field describes a graph to compute the initial tensors 256 | // upon starting the training process. Initialization graph has no input 257 | // and can have multiple outputs. Usually, trainable tensors in neural 258 | // networks are randomly initialized. 
To achieve that, for each tensor, 259 | // the user can put a random number operator such as RandomNormal or 260 | // RandomUniform in TrainingInfoProto.initialization.node and assign its 261 | // random output to the specific tensor using "initialization_binding". 262 | // This graph can also set the initializers in "algorithm" in the same 263 | // TrainingInfoProto; a use case is resetting the number of training 264 | // iteration to zero. 265 | // 266 | // By default, this field is an empty graph and its evaluation does not 267 | // produce any output. Thus, no initializer would be changed by default. 268 | optional GraphProto initialization = 1; 269 | 270 | // This field represents a training algorithm step. Given required inputs, 271 | // it computes outputs to update initializers in its own or inference graph's 272 | // initializer lists. In general, this field contains loss node, gradient node, 273 | // optimizer node, increment of iteration count. 274 | // 275 | // An execution of the training algorithm step is performed by executing the 276 | // graph obtained by combining the inference graph (namely "ModelProto.graph") 277 | // and the "algorithm" graph. That is, the actual 278 | // input/initializer/output/node/value_info/sparse_initializer list of 279 | // the training graph is the concatenation of 280 | // "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer" 281 | // and "algorithm.input/initializer/output/node/value_info/sparse_initializer" 282 | // in that order. This combined graph must satisfy the normal ONNX conditions. 283 | // Now, let's provide a visualization of graph combination for clarity. 284 | // Let the inference graph (i.e., "ModelProto.graph") be 285 | // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d 286 | // and the "algorithm" graph be 287 | // tensor_d -> Add -> tensor_e 288 | // The combination process results 289 | // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e 290 | // 291 | // Notice that an input of a node in the "algorithm" graph may reference the 292 | // output of a node in the inference graph (but not the other way round). Also, inference 293 | // node cannot reference inputs of "algorithm". With these restrictions, inference graph 294 | // can always be run independently without training information. 295 | // 296 | // By default, this field is an empty graph and its evaluation does not 297 | // produce any output. Evaluating the default training step never 298 | // update any initializers. 299 | optional GraphProto algorithm = 2; 300 | 301 | // This field specifies the bindings from the outputs of "initialization" to 302 | // some initializers in "ModelProto.graph.initializer" and 303 | // the "algorithm.initializer" in the same TrainingInfoProto. 304 | // See "update_binding" below for details. 305 | // 306 | // By default, this field is empty and no initializer would be changed 307 | // by the execution of "initialization". 308 | repeated StringStringEntryProto initialization_binding = 3; 309 | 310 | // Gradient-based training is usually an iterative procedure. In one gradient 311 | // descent iteration, we apply 312 | // 313 | // x = x - r * g 314 | // 315 | // where "x" is the optimized tensor, "r" stands for learning rate, and "g" is 316 | // gradient of "x" with respect to a chosen loss. 
To avoid adding assignments 317 | // into the training graph, we split the update equation into 318 | // 319 | // y = x - r * g 320 | // x = y 321 | // 322 | // The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To 323 | // tell that "y" should be assigned to "x", the field "update_binding" may 324 | // contain a key-value pair of strings, "x" (key of StringStringEntryProto) 325 | // and "y" (value of StringStringEntryProto). 326 | // For a neural network with multiple trainable (mutable) tensors, there can 327 | // be multiple key-value pairs in "update_binding". 328 | // 329 | // The initializers appears as keys in "update_binding" are considered 330 | // mutable variables. This implies some behaviors 331 | // as described below. 332 | // 333 | // 1. We have only unique keys in all "update_binding"s so that two 334 | // variables may not have the same name. This ensures that one 335 | // variable is assigned up to once. 336 | // 2. The keys must appear in names of "ModelProto.graph.initializer" or 337 | // "TrainingInfoProto.algorithm.initializer". 338 | // 3. The values must be output names of "algorithm" or "ModelProto.graph.output". 339 | // 4. Mutable variables are initialized to the value specified by the 340 | // corresponding initializer, and then potentially updated by 341 | // "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s. 342 | // 343 | // This field usually contains names of trainable tensors 344 | // (in ModelProto.graph), optimizer states such as momentums in advanced 345 | // stochastic gradient methods (in TrainingInfoProto.graph), 346 | // and number of training iterations (in TrainingInfoProto.graph). 347 | // 348 | // By default, this field is empty and no initializer would be changed 349 | // by the execution of "algorithm". 350 | repeated StringStringEntryProto update_binding = 4; 351 | } 352 | 353 | // Models 354 | // 355 | // ModelProto is a top-level file/container format for bundling a ML model and 356 | // associating its computation graph with metadata. 357 | // 358 | // The semantics of the model are described by the associated GraphProto's. 359 | message ModelProto { 360 | // The version of the IR this model targets. See Version enum above. 361 | // This field MUST be present. 362 | optional int64 ir_version = 1; 363 | 364 | // The OperatorSets this model relies on. 365 | // All ModelProtos MUST have at least one entry that 366 | // specifies which version of the ONNX OperatorSet is 367 | // being imported. 368 | // 369 | // All nodes in the ModelProto's graph will bind against the operator 370 | // with the same-domain/same-op_type operator with the HIGHEST version 371 | // in the referenced operator sets. 372 | repeated OperatorSetIdProto opset_import = 8; 373 | 374 | // The name of the framework or tool used to generate this model. 375 | // This field SHOULD be present to indicate which implementation/tool/framework 376 | // emitted the model. 377 | optional string producer_name = 2; 378 | 379 | // The version of the framework or tool used to generate this model. 380 | // This field SHOULD be present to indicate which implementation/tool/framework 381 | // emitted the model. 382 | optional string producer_version = 3; 383 | 384 | // Domain name of the model. 385 | // We use reverse domain names as name space indicators. 
For example: 386 | // `com.facebook.fair` or `com.microsoft.cognitiveservices` 387 | // 388 | // Together with `model_version` and GraphProto.name, this forms the unique identity of 389 | // the graph. 390 | optional string domain = 4; 391 | 392 | // The version of the graph encoded. See Version enum below. 393 | optional int64 model_version = 5; 394 | 395 | // A human-readable documentation for this model. Markdown is allowed. 396 | optional string doc_string = 6; 397 | 398 | // The parameterized graph that is evaluated to execute the model. 399 | optional GraphProto graph = 7; 400 | 401 | // Named metadata values; keys should be distinct. 402 | repeated StringStringEntryProto metadata_props = 14; 403 | 404 | // Training-specific information. Sequentially executing all stored 405 | // `TrainingInfoProto.algorithm`s and assigning their outputs following 406 | // the corresponding `TrainingInfoProto.update_binding`s is one training 407 | // iteration. Similarly, to initialize the model 408 | // (as if training hasn't happened), the user should sequentially execute 409 | // all stored `TrainingInfoProto.initialization`s and assigns their outputs 410 | // using `TrainingInfoProto.initialization_binding`s. 411 | // 412 | // If this field is empty, the training behavior of the model is undefined. 413 | repeated TrainingInfoProto training_info = 20; 414 | 415 | // A list of function protos local to the model. 416 | // 417 | // The (domain, name, overload) tuple must be unique across the function protos in this list. 418 | // In case of any conflicts the behavior (whether the model local functions are given higher priority, 419 | // or standard operator sets are given higher priotity or this is treated as error) is defined by 420 | // the runtimes. 421 | // 422 | // The operator sets imported by FunctionProto should be compatible with the ones 423 | // imported by ModelProto and other model local FunctionProtos. 424 | // Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto 425 | // or by 2 FunctionProtos then versions for the operator set may be different but, 426 | // the operator schema returned for op_type, domain, version combination 427 | // for both the versions should be same for every node in the function body. 428 | // 429 | // One FunctionProto can reference other FunctionProto in the model, however, recursive reference 430 | // is not allowed. 431 | repeated FunctionProto functions = 25; 432 | }; 433 | 434 | // StringStringEntryProto follows the pattern for cross-proto-version maps. 435 | // See https://developers.google.com/protocol-buffers/docs/proto3#maps 436 | message StringStringEntryProto { 437 | optional string key = 1; 438 | optional string value = 2; 439 | }; 440 | 441 | message TensorAnnotation { 442 | optional string tensor_name = 1; 443 | // pairs to annotate tensor specified by above. 444 | // The keys used in the mapping below must be pre-defined in ONNX spec. 445 | // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as 446 | // quantization parameter keys. 447 | repeated StringStringEntryProto quant_parameter_tensor_names = 2; 448 | } 449 | 450 | 451 | 452 | // Graphs 453 | // 454 | // A graph defines the computational logic of a model and is comprised of a parameterized 455 | // list of nodes that form a directed acyclic graph based on their inputs and outputs. 456 | // This is the equivalent of the "network" or "graph" in many deep learning 457 | // frameworks. 
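// Editorial example (illustrative textproto sketch, not taken from the models
// shipped in this repository): a minimal graph with a single Gemm node could
// be encoded roughly as
//
//   name: "tiny_gemm"
//   node { op_type: "Gemm"  input: "x"  input: "W"  input: "b"  output: "y" }
//   input  { name: "x" }
//   output { name: "y" }
//   initializer { name: "W"  dims: 784  dims: 10  data_type: 1 }  // 1 = FLOAT
//   initializer { name: "b"  dims: 10  data_type: 1 }
//
// Type information on the inputs/outputs is omitted here for brevity; the
// fields used above are defined by the message that follows.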
458 | message GraphProto { 459 | // The nodes in the graph, sorted topologically. 460 | repeated NodeProto node = 1; 461 | 462 | // The name of the graph. 463 | optional string name = 2; // namespace Graph 464 | 465 | // A list of named tensor values, used to specify constant inputs of the graph. 466 | // Each initializer (both TensorProto as well SparseTensorProto) MUST have a name. 467 | // The name MUST be unique across both initializer and sparse_initializer, 468 | // but the name MAY also appear in the input list. 469 | repeated TensorProto initializer = 5; 470 | 471 | // Initializers (see above) stored in sparse format. 472 | repeated SparseTensorProto sparse_initializer = 15; 473 | 474 | // A human-readable documentation for this graph. Markdown is allowed. 475 | optional string doc_string = 10; 476 | 477 | // The inputs and outputs of the graph. 478 | repeated ValueInfoProto input = 11; 479 | repeated ValueInfoProto output = 12; 480 | 481 | // Information for the values in the graph. The ValueInfoProto.name's 482 | // must be distinct. It is optional for a value to appear in value_info list. 483 | repeated ValueInfoProto value_info = 13; 484 | 485 | // This field carries information to indicate the mapping among a tensor and its 486 | // quantization parameter tensors. For example: 487 | // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated, 488 | // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. 489 | repeated TensorAnnotation quantization_annotation = 14; 490 | 491 | // Named metadata values; keys should be distinct. 492 | repeated StringStringEntryProto metadata_props = 16; 493 | 494 | reserved 3, 4, 6 to 9; 495 | reserved "ir_version", "producer_version", "producer_tag", "domain"; 496 | } 497 | 498 | // Tensors 499 | // 500 | // A serialized tensor value. 501 | message TensorProto { 502 | enum DataType { 503 | UNDEFINED = 0; 504 | // Basic types. 505 | FLOAT = 1; // float 506 | UINT8 = 2; // uint8_t 507 | INT8 = 3; // int8_t 508 | UINT16 = 4; // uint16_t 509 | INT16 = 5; // int16_t 510 | INT32 = 6; // int32_t 511 | INT64 = 7; // int64_t 512 | STRING = 8; // string 513 | BOOL = 9; // bool 514 | 515 | // IEEE754 half-precision floating-point format (16 bits wide). 516 | // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. 517 | FLOAT16 = 10; 518 | 519 | DOUBLE = 11; 520 | UINT32 = 12; 521 | UINT64 = 13; 522 | COMPLEX64 = 14; // complex with float32 real and imaginary components 523 | COMPLEX128 = 15; // complex with float64 real and imaginary components 524 | 525 | // Non-IEEE floating-point format based on IEEE754 single-precision 526 | // floating-point number truncated to 16 bits. 527 | // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. 528 | BFLOAT16 = 16; 529 | 530 | // Non-IEEE floating-point format based on papers 531 | // FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433, 532 | // 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf. 533 | // Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear. 534 | // The computation usually happens inside a block quantize / dequantize 535 | // fused by the runtime. 
536 | FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf 537 | FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero 538 | FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients 539 | FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, not inf, mostly used for gradients, no negative zero 540 | 541 | // 4-bit data-types 542 | UINT4 = 21; // Unsigned integer in range [0, 15] 543 | INT4 = 22; // Signed integer in range [-8, 7], using two's-complement representation 544 | 545 | // Future extensions go here. 546 | } 547 | 548 | // The shape of the tensor. 549 | repeated int64 dims = 1; 550 | 551 | // The data type of the tensor. 552 | // This field MUST have a valid TensorProto.DataType value 553 | optional int32 data_type = 2; 554 | 555 | // For very large tensors, we may want to store them in chunks, in which 556 | // case the following fields will specify the segment that is stored in 557 | // the current TensorProto. 558 | message Segment { 559 | optional int64 begin = 1; 560 | optional int64 end = 2; 561 | } 562 | optional Segment segment = 3; 563 | 564 | // Tensor content must be organized in row-major order. 565 | // 566 | // Depending on the data_type field, exactly one of the fields below with 567 | // name ending in _data is used to store the elements of the tensor. 568 | 569 | // For float and complex64 values 570 | // Complex64 tensors are encoded as a single array of floats, 571 | // with the real components appearing in odd numbered positions, 572 | // and the corresponding imaginary component appearing in the 573 | // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] 574 | // is encoded as [1.0, 2.0 ,3.0 ,4.0] 575 | // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. 576 | repeated float float_data = 4 [packed = true]; 577 | 578 | // For int32, uint8, int8, uint16, int16, uint4, int4, bool, float8 and float16 values 579 | // float16 and float8 values must be bit-wise converted to an uint16_t prior 580 | // to writing to the buffer. 581 | // uint4 and int4 values must be packed to 4bitx2 prior to writing to the buffer, the first element is stored in 582 | // the 4 LSB and the second element is stored in the 4 MSB. 583 | // When this field is present, the data_type field MUST be 584 | // INT32, INT16, INT8, INT4, UINT16, UINT8, UINT4, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ 585 | repeated int32 int32_data = 5 [packed = true]; 586 | 587 | // For strings. 588 | // Each element of string_data is a UTF-8 encoded Unicode 589 | // string. No trailing null, no leading BOM. The protobuf "string" 590 | // scalar type is not used to match ML community conventions. 591 | // When this field is present, the data_type field MUST be STRING 592 | repeated bytes string_data = 6; 593 | 594 | // For int64. 595 | // When this field is present, the data_type field MUST be INT64 596 | repeated int64 int64_data = 7 [packed = true]; 597 | 598 | // Optionally, a name for the tensor. 599 | optional string name = 8; // namespace Value 600 | 601 | // A human-readable documentation for this tensor. Markdown is allowed. 602 | optional string doc_string = 12; 603 | 604 | // Serializations can either use one of the fields above, or use this 605 | // raw bytes field. The only exception is the string case, where one is 606 | // required to store the content in the repeated bytes string_data field. 
607 | // 608 | // When this raw_data field is used to store tensor value, elements MUST 609 | // be stored in as fixed-width, little-endian order. 610 | // Floating-point data types MUST be stored in IEEE 754 format. 611 | // Complex64 elements must be written as two consecutive FLOAT values, real component first. 612 | // Complex128 elements must be written as two consecutive DOUBLE values, real component first. 613 | // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). 614 | // uint4 and int4 values must be packed to 4bitx2, the first element is stored in the 4 LSB and the second element is stored in the 4 MSB. 615 | // 616 | // Note: the advantage of specific field rather than the raw_data field is 617 | // that in some cases (e.g. int data), protobuf does a better packing via 618 | // variable length storage, and may lead to smaller binary footprint. 619 | // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED 620 | optional bytes raw_data = 9; 621 | 622 | // Data can be stored inside the protobuf file using type-specific fields or raw_data. 623 | // Alternatively, raw bytes data can be stored in an external file, using the external_data field. 624 | // external_data stores key-value pairs describing data location. Recognized keys are: 625 | // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX 626 | // protobuf model was stored 627 | // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string. 628 | // Offset values SHOULD be multiples 4096 (page size) to enable mmap support. 629 | // - "length" (optional) - number of bytes containing data. Integer stored as string. 630 | // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key. 631 | repeated StringStringEntryProto external_data = 13; 632 | 633 | // Location of the data for this tensor. MUST be one of: 634 | // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field. 635 | // - EXTERNAL - data stored in an external location as described by external_data field. 636 | enum DataLocation { 637 | DEFAULT = 0; 638 | EXTERNAL = 1; 639 | } 640 | 641 | // If value not set, data is stored in raw_data (if set) otherwise in type-specified field. 642 | optional DataLocation data_location = 14; 643 | 644 | // For double 645 | // Complex128 tensors are encoded as a single array of doubles, 646 | // with the real components appearing in odd numbered positions, 647 | // and the corresponding imaginary component appearing in the 648 | // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] 649 | // is encoded as [1.0, 2.0 ,3.0 ,4.0] 650 | // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 651 | repeated double double_data = 10 [packed = true]; 652 | 653 | // For uint64 and uint32 values 654 | // When this field is present, the data_type field MUST be 655 | // UINT32 or UINT64 656 | repeated uint64 uint64_data = 11 [packed = true]; 657 | 658 | // Named metadata values; keys should be distinct. 659 | repeated StringStringEntryProto metadata_props = 16; 660 | } 661 | 662 | // A serialized sparse-tensor value 663 | message SparseTensorProto { 664 | // The sequence of non-default values are encoded as a tensor of shape [NNZ]. 665 | // The default-value is zero for numeric tensors, and empty-string for string tensors. 
666 | // values must have a non-empty name present which serves as a name for SparseTensorProto 667 | // when used in sparse_initializer list. 668 | optional TensorProto values = 1; 669 | 670 | // The indices of the non-default values, which may be stored in one of two formats. 671 | // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value 672 | // corresponding to the j-th index of the i-th value (in the values tensor). 673 | // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value 674 | // must be the linearized-index of the i-th value (in the values tensor). 675 | // The linearized-index can be converted into an index tuple (k_1,...,k_rank) 676 | // using the shape provided below. 677 | // The indices must appear in ascending order without duplication. 678 | // In the first format, the ordering is lexicographic-ordering: 679 | // e.g., index-value [1,4] must appear before [2,1] 680 | optional TensorProto indices = 2; 681 | 682 | // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank] 683 | repeated int64 dims = 3; 684 | } 685 | 686 | // Defines a tensor shape. A dimension can be either an integer value 687 | // or a symbolic variable. A symbolic variable represents an unknown 688 | // dimension. 689 | message TensorShapeProto { 690 | message Dimension { 691 | oneof value { 692 | int64 dim_value = 1; 693 | string dim_param = 2; // namespace Shape 694 | }; 695 | // Standard denotation can optionally be used to denote tensor 696 | // dimensions with standard semantic descriptions to ensure 697 | // that operations are applied to the correct axis of a tensor. 698 | // Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition 699 | // for pre-defined dimension denotations. 700 | optional string denotation = 3; 701 | }; 702 | repeated Dimension dim = 1; 703 | } 704 | 705 | // Types 706 | // 707 | // The standard ONNX data types. 708 | message TypeProto { 709 | 710 | message Tensor { 711 | // This field MUST NOT have the value of UNDEFINED 712 | // This field MUST have a valid TensorProto.DataType value 713 | // This field MUST be present for this version of the IR. 714 | optional int32 elem_type = 1; 715 | optional TensorShapeProto shape = 2; 716 | } 717 | 718 | // repeated T 719 | message Sequence { 720 | // The type and optional shape of each element of the sequence. 721 | // This field MUST be present for this version of the IR. 722 | optional TypeProto elem_type = 1; 723 | }; 724 | 725 | // map 726 | message Map { 727 | // This field MUST have a valid TensorProto.DataType value 728 | // This field MUST be present for this version of the IR. 729 | // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING 730 | optional int32 key_type = 1; 731 | // This field MUST be present for this version of the IR. 732 | optional TypeProto value_type = 2; 733 | }; 734 | 735 | // wrapper for Tensor, Sequence, or Map 736 | message Optional { 737 | // The type and optional shape of the element wrapped. 738 | // This field MUST be present for this version of the IR. 739 | // Possible values correspond to OptionalProto.DataType enum 740 | optional TypeProto elem_type = 1; 741 | }; 742 | 743 | 744 | message SparseTensor { 745 | // This field MUST NOT have the value of UNDEFINED 746 | // This field MUST have a valid TensorProto.DataType value 747 | // This field MUST be present for this version of the IR. 
748 | optional int32 elem_type = 1; 749 | optional TensorShapeProto shape = 2; 750 | } 751 | 752 | 753 | message Opaque { 754 | // When missing, the domain is the same as the model's. 755 | optional string domain = 1; 756 | // The name is optional but significant when provided. 757 | optional string name = 2; 758 | // parameters that help defining the type 759 | // DEPRECATED do not use. 760 | // repeated TypeProto parameters = 3; 761 | } 762 | 763 | 764 | oneof value { 765 | // The type of a tensor. 766 | Tensor tensor_type = 1; 767 | 768 | // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values 769 | // as input and output to graphs and nodes. These types are needed to naturally 770 | // support classical ML operators. DNN operators SHOULD restrict their input 771 | // and output types to tensors. 772 | 773 | // The type of a sequence. 774 | Sequence sequence_type = 4; 775 | 776 | // The type of a map. 777 | Map map_type = 5; 778 | 779 | // The type of an optional. 780 | Optional optional_type = 9; 781 | 782 | 783 | // Type of the sparse tensor 784 | SparseTensor sparse_tensor_type = 8; 785 | 786 | 787 | Opaque opaque_type = 7; 788 | 789 | } 790 | 791 | // An optional denotation can be used to denote the whole 792 | // type with a standard semantic description as to what is 793 | // stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition 794 | // for pre-defined type denotations. 795 | optional string denotation = 6; 796 | } 797 | 798 | // Operator Sets 799 | // 800 | // OperatorSets are uniquely identified by a (domain, opset_version) pair. 801 | message OperatorSetIdProto { 802 | // The domain of the operator set being identified. 803 | // The empty string ("") or absence of this field implies the operator 804 | // set that is defined as part of the ONNX specification. 805 | // This field MUST be present in this version of the IR when referring to any other operator set. 806 | optional string domain = 1; 807 | 808 | // The version of the operator set being identified. 809 | // This field MUST be present in this version of the IR. 810 | optional int64 version = 2; 811 | } 812 | 813 | // Operator/function status. 814 | enum OperatorStatus { 815 | EXPERIMENTAL = 0; 816 | STABLE = 1; 817 | } 818 | 819 | message FunctionProto { 820 | // The name of the function, similar to op_type in NodeProto. 821 | // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model. 822 | optional string name = 1; 823 | 824 | // Deprecated since IR Version 8 825 | // optional int64 since_version = 2; 826 | reserved 2; 827 | reserved "since_version"; 828 | 829 | // Deprecated since IR Version 8 830 | // optional OperatorStatus status = 3; 831 | reserved 3; 832 | reserved "status"; 833 | 834 | // The inputs and outputs of the function. 835 | repeated string input = 4; 836 | repeated string output = 5; 837 | 838 | // The attribute parameters of the function. 839 | // It is for function parameters without default values. 840 | repeated string attribute = 6; 841 | 842 | // The attribute protos of the function. 843 | // It is for function attributes with default values. 844 | // A function attribute shall be represented either as 845 | // a string attribute or an AttributeProto, not both. 846 | repeated AttributeProto attribute_proto = 11; 847 | 848 | // The nodes in the function. 849 | repeated NodeProto node = 7; 850 | // A human-readable documentation for this function. Markdown is allowed. 
851 | optional string doc_string = 8; 852 | 853 | // The OperatorSets this function body (graph) relies on. 854 | // 855 | // All nodes in the function body (graph) will bind against the operator 856 | // with the same-domain/same-op_type operator with the HIGHEST version 857 | // in the referenced operator sets. This means at most one version can be relied 858 | // for one domain. 859 | // 860 | // The operator sets imported by FunctionProto should be compatible with the ones 861 | // imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto 862 | // and ModelProto then versions for the operator set may be different but, 863 | // the operator schema returned for op_type, domain, version combination 864 | // for both the versions should be same. 865 | 866 | repeated OperatorSetIdProto opset_import = 9; 867 | 868 | // The domain which this function belongs to. 869 | // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model. 870 | optional string domain = 10; 871 | 872 | // The overload identifier of the function. 873 | // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model. 874 | optional string overload = 13; 875 | 876 | // Information for the values in the function. The ValueInfoProto.name's 877 | // must be distinct and refer to names in the function (including inputs, 878 | // outputs, and intermediate values). It is optional for a value to appear 879 | // in value_info list. 880 | repeated ValueInfoProto value_info = 12; 881 | 882 | // Named metadata values; keys should be distinct. 883 | repeated StringStringEntryProto metadata_props = 14; 884 | } 885 | 886 | // For using protobuf-lite 887 | option optimize_for = LITE_RUNTIME; 888 | 889 | -------------------------------------------------------------------------------- /src/onnx_helper.cpp: -------------------------------------------------------------------------------- 1 | #include "onnx_helper.h" 2 | 3 | template 4 | std::tuple getAttr(const onnx::NodeProto& node, 5 | const std::string& attrName) { 6 | for (const auto& attr : node.attribute()) { 7 | if (attr.name() == attrName) { 8 | if (std::is_same::value) { 9 | return {true, attr.f()}; 10 | } else if (std::is_same::value) { 11 | return {true, attr.i()}; 12 | } 13 | } 14 | } 15 | return {false, {}}; 16 | } 17 | 18 | template std::tuple getAttr(const onnx::NodeProto& node, 19 | const std::string& attrName); 20 | template std::tuple getAttr(const onnx::NodeProto& node, 21 | const std::string& attrName); 22 | -------------------------------------------------------------------------------- /src/onnx_helper.h: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_HELPER_H 2 | #define ONNX_HELPER_H 3 | 4 | #include 5 | 6 | #include "onnx-ml.pb.h" 7 | 8 | template 9 | std::tuple getAttr(const onnx::NodeProto& node, 10 | const std::string& attrName); 11 | 12 | #endif -------------------------------------------------------------------------------- /src/operators.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "operators.h" 3 | 4 | #include 5 | 6 | #include 7 | 8 | #include "gemm_cpu.h" 9 | #include "kernels.h" 10 | 11 | //-----------------// 12 | // CPU operators // 13 | //-----------------// 14 | 15 | template 16 | Tensor CpuOperators::gemm(const Tensor& A, const Tensor& B, 17 | const Tensor& bias, bool transA, bool transB, 18 | float alpha, float beta) { 19 | validate_gemm_inputs(A, B, bias, transA, 
transB);
20 | 
21 |     // Calculate output dimensions depending on transpositions.
22 |     uint64_t N = transA ? A.shape()[1] : A.shape()[0];
23 |     uint64_t M = transB ? B.shape()[1] : B.shape()[0];
24 |     uint64_t K = transB ? B.shape()[0] : B.shape()[1];
25 |     std::vector<uint64_t> dims{N, K};
26 | 
27 |     Tensor<T> out{std::move(dims)};
28 | 
29 |     // assert(A.device() == DeviceType::CPU);
30 |     const T* AData = A.data();
31 |     const T* BData = B.data();
32 |     const T* BiasData = bias.data();
33 |     gemm_cpu(AData, BData, BiasData, out.data(), N, M, K, transA, transB, alpha,
34 |              beta);
35 | 
36 |     return out;
37 | }
38 | 
39 | template <typename T>
40 | Tensor<T> CpuOperators<T>::flatten(Tensor<T>& tensor, uint64_t axis) {
41 |     return base_flatten(tensor, axis);
42 | }
43 | 
44 | template <typename T>
45 | Tensor<T> CpuOperators<T>::relu(const Tensor<T>& tensor) {
46 |     Tensor<T> output(tensor);
47 |     T* raw = output.data();
48 |     for (std::size_t i = 0; i < output.size(); ++i) {
49 |         raw[i] = std::max(static_cast<T>(0), raw[i]);
50 |     }
51 |     return output;
52 | }
53 | 
54 | template <typename T>
55 | Tensor<T> CpuOperators<T>::add(const Tensor<T>& A, const Tensor<T>& B) {
56 |     assert(A.shape() == B.shape());
57 |     Tensor<T> output(A);
58 |     T* raw = output.data();
59 |     const T* b_raw = B.data();
60 |     for (std::size_t i = 0; i < output.size(); ++i) {
61 |         raw[i] += b_raw[i];
62 |     }
63 |     return output;
64 | }
65 | 
66 | //------------------//
67 | // CUDA operators //
68 | //------------------//
69 | 
70 | template <typename T>
71 | Tensor<T> CudaOperators<T>::gemm(const Tensor<T>& A, const Tensor<T>& B,
72 |                                  const Tensor<T>& bias, bool transA,
73 |                                  bool transB, float alpha, float beta) {
74 |     validate_gemm_inputs(A, B, bias, transA, transB);
75 | 
76 |     // Calculate output dimensions depending on transpositions.
77 |     uint64_t N = transA ? A.shape()[1] : A.shape()[0];
78 |     uint64_t M = transB ? B.shape()[1] : B.shape()[0];
79 |     uint64_t K = transB ? B.shape()[0] : B.shape()[1];
80 | 
81 |     std::vector<uint64_t> dims{N, K};
82 | 
83 |     Tensor<T> out{std::move(dims)};
84 | 
85 |     assert(A.device() == DeviceType::CUDA);
86 |     const T* AData = A.data();
87 |     const T* BData = B.data();
88 |     const T* BiasData = bias.data();
89 | 
90 |     gemm_cuda_tiled(AData, BData, BiasData, out.data(), N, M, K, transA, transB,
91 |                     alpha, beta);
92 | 
93 |     return out;
94 | }
95 | 
96 | template <typename T>
97 | Tensor<T> CudaOperators<T>::flatten(Tensor<T>& tensor, uint64_t axis) {
98 |     return base_flatten(tensor, axis);
99 | }
100 | 
101 | // template <typename T>
102 | // Tensor<T> CudaOperators<T>::relu(const Tensor<T>& tensor) {
103 | //     Tensor<T> output(tensor);
104 | //     relu_cuda(output.data(), output.size());
105 | //     return output;
106 | // }
107 | 
108 | // template <typename T>
109 | // Tensor<T> CudaOperators<T>::add(const Tensor<T>& A, const Tensor<T>& B) {
110 | //     assert(A.shape() == B.shape());
111 | //     Tensor<T> output(A);
112 | //     add_cuda(output.data(), B.data(), output.size());
113 | //     return output;
114 | // }
115 | 
116 | //------------------//
117 | // shared //
118 | //------------------//
119 | 
120 | template <typename T>
121 | void validate_gemm_inputs(const Tensor<T>& A, const Tensor<T>& B,
122 |                           const Tensor<T>& bias, bool transA, bool transB) {
123 |     if (A.shape().size() != 2 || B.shape().size() != 2 ||
124 |         bias.shape().size() == 0) {
125 |         std::cerr << "A dims: " << A.shape().size() << " B dims "
126 |                   << B.shape().size() << " C dims " << bias.shape().size()
127 |                   << std::endl;
128 |         std::cerr << "A.shape: " << A.stringShape() << std::endl;
129 |         std::cerr << "B.shape: " << B.stringShape() << std::endl;
130 |         std::cerr << "bias.shape: " << bias.stringShape() << std::endl;
131 | 
132 |         throw std::invalid_argument("Invalid dimensions for Gemm inputs.");
133 |     }
134 |     if (!transA && !transB && A.shape()[1] != B.shape()[0]) {
135 |         std::cerr << "A.shape: " << A.stringShape() << std::endl;
136 |         std::cerr << "B.shape: " << B.stringShape() << std::endl;
137 |         throw std::invalid_argument(
138 |             "Matrix dimensions are not compatible for multiplication in Gemm.");
139 |     }
140 |     if (transA && !transB && A.shape()[0] != B.shape()[0]) {
141 |         std::cerr << "A.shape: " << A.stringShape() << std::endl;
142 |         std::cerr << "B.shape: " << B.stringShape() << std::endl;
143 |         throw std::invalid_argument(
144 |             "Matrix dimensions are not compatible for multiplication in Gemm.");
145 |     }
146 |     if (transB && !transA && A.shape()[1] != B.shape()[1]) {
147 |         std::cerr << "A.shape: " << A.stringShape() << std::endl;
148 |         std::cerr << "B.shape: " << B.stringShape() << std::endl;
149 |         throw std::invalid_argument(
150 |             "Matrix dimensions are not compatible for multiplication in Gemm.");
151 |     }
152 |     if (transA && transB && A.shape()[0] != B.shape()[1]) {
153 |         std::cerr << "A.shape: " << A.stringShape() << std::endl;
154 |         std::cerr << "B.shape: " << B.stringShape() << std::endl;
155 |         throw std::invalid_argument(
156 |             "Matrix dimensions are not compatible for multiplication in Gemm.");
157 |     }
158 | }
159 | 
160 | template <typename T>
161 | Tensor<T> base_flatten(Tensor<T>& tensor, uint64_t axis) {
162 |     assert(axis <= tensor.shape().size());
163 | 
164 |     uint64_t dimBefore = 1;
165 |     for (std::size_t i = 0; i < axis; ++i) {
166 |         dimBefore *= tensor.shape()[i];
167 |     }
168 | 
169 |     uint64_t dimAfter = 1;
170 |     for (std::size_t i = axis; i < tensor.shape().size(); ++i) {
171 |         dimAfter *= tensor.shape()[i];
172 |     }
173 |     tensor.setShape({dimBefore, dimAfter});
174 |     return tensor;
175 | }
176 | 
177 | template class CpuOperators<float>;
178 | template class CudaOperators<float>;
179 | 
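A minimal usage sketch for the operator wrappers above (editorial illustration only, not part of the original sources; it assumes the Tensor<float> data-plus-shape constructor declared in tensor.h and that gemm_cpu broadcasts a length-K bias):

#include "operators.h"
#include "tensor.h"

int main() {
    // A is 2x3, B is 3x2, bias is a length-2 vector (assumed broadcast by gemm_cpu).
    const float a[] = {1, 2, 3, 4, 5, 6};
    const float b[] = {1, 0, 0, 1, 1, 1};
    const float c[] = {0.5f, -0.5f};
    Tensor<float> A{a, {2, 3}};
    Tensor<float> B{b, {3, 2}};
    Tensor<float> bias{c, {2}};

    // y = 1.0 * A * B + 1.0 * bias, no transposition, followed by ReLU.
    Tensor<float> y = CpuOperators<float>::gemm(A, B, bias, false, false, 1.0f, 1.0f);
    Tensor<float> out = CpuOperators<float>::relu(y);
    out.print();  // prints Tensor((2, 2)) [...]
    return 0;
}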
-------------------------------------------------------------------------------- /src/operators.h: -------------------------------------------------------------------------------- 1 | #ifndef OPERATORS_H 2 | #define OPERATORS_H 3 | 4 | #include "tensor.h" 5 | 6 | template 7 | class CpuOperators { 8 | public: 9 | static Tensor gemm(const Tensor& A, const Tensor& B, 10 | const Tensor& bias, bool transA, bool transB, 11 | float alpha, float beta); 12 | static Tensor flatten(Tensor& tensor, uint64_t axis); 13 | static Tensor relu(const Tensor& tensor); 14 | static Tensor add(const Tensor& A, const Tensor& B); 15 | }; 16 | 17 | template 18 | class CudaOperators { 19 | public: 20 | static Tensor gemm(const Tensor& A, const Tensor& B, 21 | const Tensor& bias, bool transA, bool transB, 22 | float alpha, float beta); 23 | static Tensor flatten(Tensor& tensor, uint64_t axis); 24 | static Tensor relu(const Tensor& tensor); 25 | static Tensor add(const Tensor& A, const Tensor& B); 26 | }; 27 | 28 | #endif // OPERATORS_H -------------------------------------------------------------------------------- /src/optypes.h: -------------------------------------------------------------------------------- 1 | #ifndef OPTYPES_H 2 | #define OPTYPES_H 3 | 4 | enum class OpType { 5 | Input, // Input to the graph 6 | Output, // Output of the graph 7 | Add, // Add two tensors 8 | Gemm, // General Matrix Multiplication 9 | Flatten, // Flatten an input 10 | Relu, // Rectified Linear Unit 11 | Conv, // Convolutional Layer 12 | MaxPool, // Max Pooling Layer 13 | Constant, // Constant node 14 | }; 15 | 16 | #endif -------------------------------------------------------------------------------- /src/pinned_cpu_allocator.h: -------------------------------------------------------------------------------- 1 | #ifndef PINNED_CPU_ALLOCATOR_H 2 | #define PINNED_CPU_ALLOCATOR_H 3 | 4 | #include 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include "memory_allocator.h" 12 | class PinnedMemoryPool { 13 | public: 14 | PinnedMemoryPool(size_t size) : total_size_(size) { 15 | // Use cudaMallocHost instead of cudaMalloc 16 | cudaMallocHost(&pool_, total_size_); 17 | free_blocks_.push_back({0, total_size_}); 18 | } 19 | 20 | ~PinnedMemoryPool() { 21 | // Use cudaFreeHost instead of cudaFree 22 | cudaFreeHost(pool_); 23 | } 24 | 25 | void* allocate(size_t size) { 26 | auto it = std::find_if( 27 | free_blocks_.begin(), free_blocks_.end(), 28 | [size](const auto& block) { return block.second >= size; }); 29 | 30 | if (it == free_blocks_.end()) { 31 | throw std::runtime_error("Pinned memory pool exhausted"); 32 | } 33 | 34 | size_t offset = it->first; 35 | size_t block_size = it->second; 36 | 37 | if (block_size > size) { 38 | // Split the block 39 | it->first += size; 40 | it->second -= size; 41 | } else { 42 | // Use the entire block 43 | free_blocks_.erase(it); 44 | } 45 | 46 | return static_cast(pool_) + offset; 47 | } 48 | 49 | void deallocate(void* ptr) { 50 | size_t offset = static_cast(ptr) - static_cast(pool_); 51 | 52 | auto it = std::lower_bound( 53 | free_blocks_.begin(), free_blocks_.end(), offset, 54 | [](const auto& block, size_t off) { return block.first < off; }); 55 | 56 | size_t size = (it != free_blocks_.end()) ? 
it->first - offset 57 | : total_size_ - offset; 58 | 59 | if (it != free_blocks_.begin() && 60 | (it - 1)->first + (it - 1)->second == offset) { 61 | // Merge with previous block 62 | auto prev = it - 1; 63 | prev->second += size; 64 | if (it != free_blocks_.end() && offset + size == it->first) { 65 | // Merge with next block too 66 | prev->second += it->second; 67 | free_blocks_.erase(it); 68 | } 69 | } else if (it != free_blocks_.end() && offset + size == it->first) { 70 | // Merge with next block 71 | it->first = offset; 72 | it->second += size; 73 | } else { 74 | // Insert new block 75 | free_blocks_.insert(it, {offset, size}); 76 | } 77 | } 78 | 79 | void reset() { 80 | free_blocks_.clear(); 81 | free_blocks_.push_back({0, total_size_}); 82 | } 83 | 84 | private: 85 | void* pool_; 86 | size_t total_size_; 87 | std::vector> free_blocks_; // offset, size 88 | }; 89 | 90 | class PinnedCpuAllocator : public Allocator { 91 | public: 92 | PinnedCpuAllocator(size_t pool_size = 100 * 1024 * 1024) // 100 MiB 93 | : pool_(std::make_unique(pool_size)) {} 94 | 95 | void* allocate(size_t size) override { 96 | if (pool_) { 97 | return pool_->allocate(size); 98 | } 99 | void* ptr = nullptr; 100 | cudaMallocHost(&ptr, size); // Allocate pinned memory 101 | return ptr; 102 | } 103 | 104 | void deallocate(void* ptr) override { 105 | if (pool_) { 106 | return pool_->deallocate(ptr); 107 | } 108 | cudaFreeHost(&ptr); 109 | } 110 | 111 | DeviceType getDeviceType() const override { return DeviceType::CPU; } 112 | 113 | private: 114 | std::unique_ptr pool_; 115 | }; 116 | 117 | #endif -------------------------------------------------------------------------------- /src/tensor.cpp: -------------------------------------------------------------------------------- 1 | #include "tensor.h" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "cpu_allocator.h" 11 | #include "cuda_allocator.h" 12 | 13 | #ifdef USE_CUDA 14 | #include 15 | #endif 16 | 17 | template 18 | Tensor::Tensor(const std::vector& shape, 19 | std::shared_ptr alloc) 20 | : data_(nullptr), 21 | shape_(shape), 22 | size_(std::accumulate(shape.begin(), shape.end(), 1ULL, 23 | std::multiplies())), 24 | allocator_(std::move(alloc)) { 25 | allocateAndCopy(nullptr); 26 | } 27 | 28 | template 29 | Tensor::Tensor(const T* data, const std::vector& shape, 30 | std::shared_ptr alloc) 31 | : data_(nullptr), 32 | shape_(shape), 33 | size_(std::accumulate(shape.begin(), shape.end(), 1ULL, 34 | std::multiplies())), 35 | allocator_(std::move(alloc)) { 36 | allocateAndCopy(data); 37 | } 38 | 39 | template 40 | Tensor::Tensor(const Tensor& other) 41 | : data_(nullptr), 42 | shape_(other.shape_), 43 | size_(other.size_), 44 | allocator_(other.allocator_) { 45 | allocateAndCopy(other.data_); 46 | } 47 | 48 | template 49 | Tensor::Tensor(Tensor&& other) noexcept 50 | : data_(other.data_), 51 | shape_(std::move(other.shape_)), 52 | size_(other.size_), 53 | allocator_(std::move(other.allocator_)) { 54 | other.data_ = nullptr; 55 | other.size_ = 0; 56 | } 57 | 58 | template 59 | Tensor& Tensor::operator=(const Tensor& other) { 60 | if (this != &other) { 61 | freeMemory(); 62 | data_ = nullptr; 63 | shape_ = other.shape_; 64 | size_ = other.size_; 65 | allocator_ = other.allocator_; 66 | allocateAndCopy(other.data_); 67 | } 68 | return *this; 69 | } 70 | 71 | template 72 | Tensor& Tensor::operator=(Tensor&& other) noexcept { 73 | if (this != &other) { 74 | freeMemory(); 75 | data_ = other.data_; 76 | shape_ = 
-------------------------------------------------------------------------------- /src/tensor.cpp: --------------------------------------------------------------------------------
1 | #include "tensor.h"
2 | 
3 | #include <cstring>
4 | #include <functional>
5 | #include <iomanip>
6 | #include <iostream>
7 | #include <numeric>
8 | #include <sstream>
9 | 
10 | #include "cpu_allocator.h"
11 | #include "cuda_allocator.h"
12 | 
13 | #ifdef USE_CUDA
14 | #include <cuda_runtime.h>
15 | #endif
16 | 
17 | template <typename T>
18 | Tensor<T>::Tensor(const std::vector<uint64_t>& shape,
19 |                   std::shared_ptr<Allocator> alloc)
20 |     : data_(nullptr),
21 |       shape_(shape),
22 |       size_(std::accumulate(shape.begin(), shape.end(), 1ULL,
23 |                             std::multiplies<uint64_t>())),
24 |       allocator_(std::move(alloc)) {
25 |     allocateAndCopy(nullptr);
26 | }
27 | 
28 | template <typename T>
29 | Tensor<T>::Tensor(const T* data, const std::vector<uint64_t>& shape,
30 |                   std::shared_ptr<Allocator> alloc)
31 |     : data_(nullptr),
32 |       shape_(shape),
33 |       size_(std::accumulate(shape.begin(), shape.end(), 1ULL,
34 |                             std::multiplies<uint64_t>())),
35 |       allocator_(std::move(alloc)) {
36 |     allocateAndCopy(data);
37 | }
38 | 
39 | template <typename T>
40 | Tensor<T>::Tensor(const Tensor& other)
41 |     : data_(nullptr),
42 |       shape_(other.shape_),
43 |       size_(other.size_),
44 |       allocator_(other.allocator_) {
45 |     allocateAndCopy(other.data_);
46 | }
47 | 
48 | template <typename T>
49 | Tensor<T>::Tensor(Tensor&& other) noexcept
50 |     : data_(other.data_),
51 |       shape_(std::move(other.shape_)),
52 |       size_(other.size_),
53 |       allocator_(std::move(other.allocator_)) {
54 |     other.data_ = nullptr;
55 |     other.size_ = 0;
56 | }
57 | 
58 | template <typename T>
59 | Tensor<T>& Tensor<T>::operator=(const Tensor& other) {
60 |     if (this != &other) {
61 |         freeMemory();
62 |         data_ = nullptr;
63 |         shape_ = other.shape_;
64 |         size_ = other.size_;
65 |         allocator_ = other.allocator_;
66 |         allocateAndCopy(other.data_);
67 |     }
68 |     return *this;
69 | }
70 | 
71 | template <typename T>
72 | Tensor<T>& Tensor<T>::operator=(Tensor&& other) noexcept {
73 |     if (this != &other) {
74 |         freeMemory();
75 |         data_ = other.data_;
76 |         shape_ = std::move(other.shape_);
77 |         size_ = other.size_;
78 |         allocator_ = std::move(other.allocator_);
79 |         other.data_ = nullptr;
80 |         other.size_ = 0;
81 |     }
82 |     return *this;
83 | }
84 | 
85 | template <typename T>
86 | Tensor<T>::~Tensor() {
87 |     freeMemory();
88 | }
89 | 
90 | template <typename T>
91 | void Tensor<T>::allocateAndCopy(const T* src) {
92 |     data_ = static_cast<T*>(allocator_->allocate(sizeof(T) * size_));
93 | 
94 |     if (src) {
95 |         if (allocator_->getDeviceType() == DeviceType::CPU) {
96 |             std::memcpy(data_, src, size_ * sizeof(T));
97 |         }
98 | #ifdef USE_CUDA
99 |         else if (allocator_->getDeviceType() == DeviceType::CUDA) {
100 |             cudaMemcpyAsync(data_, src, size_ * sizeof(T),
101 |                             cudaMemcpyHostToDevice);
102 |         }
103 | #endif
104 |         else {
105 |             throw std::runtime_error("Unsupported device type");
106 |         }
107 |     }
108 | }
109 | 
110 | template <typename T>
111 | void Tensor<T>::freeMemory() {
112 |     if (data_) {
113 |         allocator_->deallocate(data_);
114 |         data_ = nullptr;
115 |     }
116 | }
117 | 
118 | template <typename T>
119 | void Tensor<T>::to(DeviceType newDevice,
120 |                    std::shared_ptr<Allocator> newAllocator) {
121 |     if (newDevice == allocator_->getDeviceType()) return;
122 | 
123 |     if (newAllocator->getDeviceType() != newDevice) {
124 |         throw std::runtime_error(
125 |             "Provided allocator does not match the requested device type");
126 |     }
127 | 
128 |     T* newData = static_cast<T*>(newAllocator->allocate(sizeof(T) * size_));
129 | 
130 |     if (allocator_->getDeviceType() == DeviceType::CPU &&
131 |         newDevice == DeviceType::CUDA) {
132 |         cudaMemcpy(newData, data_, size_ * sizeof(T), cudaMemcpyHostToDevice);
133 |     } else if (allocator_->getDeviceType() == DeviceType::CUDA &&
134 |                newDevice == DeviceType::CPU) {
135 |         cudaMemcpy(newData, data_, size_ * sizeof(T), cudaMemcpyDeviceToHost);
136 |     } else {
137 |         throw std::runtime_error("Unsupported device transition");
138 |     }
139 | 
140 |     allocator_->deallocate(data_);
141 |     data_ = newData;
142 |     allocator_ = newAllocator;
143 | }
144 | 
145 | template <typename T>
146 | void Tensor<T>::setShape(const std::vector<uint64_t>& newShape) {
147 |     size_t size = std::accumulate(newShape.begin(), newShape.end(), 1ULL,
148 |                                   std::multiplies<uint64_t>());
149 |     if (size != size_) {
150 |         throw std::runtime_error("Expected setShape to match current shape.");
151 |     }
152 |     shape_ = newShape;
153 | }
154 | 
155 | template <typename T>
156 | std::string Tensor<T>::stringShape() const {
157 |     std::ostringstream oss;
158 |     oss << "(";
159 |     for (size_t i = 0; i < shape_.size(); ++i) {
160 |         oss << shape_[i];
161 |         if (i < shape_.size() - 1) oss << ", ";
162 |     }
163 |     oss << ")";
164 |     return oss.str();
165 | }
166 | 
167 | template <typename T>
168 | void Tensor<T>::printShape() const {
169 |     std::cout << "Tensor Shape: " << stringShape() << std::endl;
170 | }
171 | 
172 | template <typename T>
173 | std::string Tensor<T>::toString() const {
174 |     std::ostringstream oss;
175 |     oss << "Tensor(" << stringShape() << ") [";
176 | 
177 |     if (device() == DeviceType::CPU) {
178 |         for (uint64_t i = 0; i < size_; ++i) {
179 |             oss << std::setprecision(4) << data_[i];
180 |             if (i < size_ - 1) oss << ", ";
181 |             if (i > 0 && i % 10 == 0) oss << "\n ";
182 |         }
183 |     } else {
184 |         oss << "GPU data";
185 |     }
186 | 
187 |     oss << "]";
188 |     return oss.str();
189 | }
190 | 
191 | template <typename T>
192 | void Tensor<T>::print() const {
193 |     std::cout << toString() << std::endl;
194 | }
195 | 
196 | // Explicit instantiation for float
197 | template class Tensor<float>;
-------------------------------------------------------------------------------- /src/tensor.h: -------------------------------------------------------------------------------- 1 | 
#ifndef TENSOR_H 2 | #define TENSOR_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "cpu_allocator.h" 10 | #include "device.h" 11 | #include "memory_allocator.h" 12 | 13 | template 14 | class Tensor { 15 | public: 16 | // Constructor with shape and optional allocator 17 | Tensor(const std::vector& shape, 18 | std::shared_ptr alloc = std::make_shared()); 19 | 20 | // Constructor with data, shape, and optional allocator 21 | Tensor(const T* data, const std::vector& shape, 22 | std::shared_ptr alloc = std::make_shared()); 23 | 24 | // Copy constructor 25 | Tensor(const Tensor& other); 26 | 27 | // Move constructor 28 | Tensor(Tensor&& other) noexcept; 29 | 30 | // Copy assignment operator 31 | Tensor& operator=(const Tensor& other); 32 | 33 | // Move assignment operator 34 | Tensor& operator=(Tensor&& other) noexcept; 35 | 36 | // Destructor 37 | ~Tensor(); 38 | 39 | // Method to move tensor to another device 40 | void to(DeviceType newDevice, std::shared_ptr newAllocator = 41 | std::make_shared()); 42 | 43 | // Modifications 44 | void setShape(const std::vector& newShape); 45 | 46 | // Getter methods 47 | const std::vector& shape() const { return shape_; } 48 | size_t size() const { return size_; } 49 | DeviceType device() const { return allocator_->getDeviceType(); } 50 | T* data() { return data_; } 51 | const T* data() const { return data_; } 52 | 53 | // Print methods 54 | std::string stringShape() const; 55 | void printShape() const; 56 | std::string toString() const; 57 | void print() const; 58 | 59 | static Tensor stack(const std::vector>& tensors) { 60 | if (tensors.empty()) { 61 | throw std::invalid_argument("Cannot stack empty vector of tensors"); 62 | } 63 | 64 | // Check if all tensors have the same shape 65 | const auto& firstShape = tensors[0].shape(); 66 | for (size_t i = 1; i < tensors.size(); ++i) { 67 | if (tensors[i].shape() != firstShape) { 68 | throw std::invalid_argument( 69 | "All tensors must have the same shape"); 70 | } 71 | } 72 | 73 | // Calculate new shape 74 | std::vector newShape = firstShape; 75 | newShape[0] = tensors.size() * firstShape[0]; 76 | 77 | // Create new tensor with the calculated shape 78 | Tensor result(newShape, tensors[0].allocator_); 79 | 80 | // Copy data from input tensors to the new tensor. Only supports 81 | // stacking on CPU. 
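        // Layout note: each input occupies tensors[i].size() contiguous elements,
        // so input i starts at element offset i * tensors[0].size() in result.data_.
        // The memcpy below assumes host-resident data; a device() == DeviceType::CPU
        // check on each input would be a reasonable guard here.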
82 | size_t offset = 0; 83 | for (const auto& tensor : tensors) { 84 | std::memcpy(result.data_ + offset, tensor.data_, 85 | tensor.size_ * sizeof(T)); 86 | offset += tensor.size_; 87 | } 88 | 89 | return result; 90 | } 91 | 92 | private: 93 | T* data_; 94 | std::vector shape_; 95 | size_t size_; 96 | std::shared_ptr allocator_; 97 | 98 | void allocateAndCopy(const T* data); 99 | void freeMemory(); 100 | }; 101 | 102 | #endif // TENSOR_H -------------------------------------------------------------------------------- /src/test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | enable_testing() 2 | 3 | # Unit tests 4 | foreach(test_name gemm tensor node operators) 5 | add_executable(${test_name}_test ${test_name}_test.cpp) 6 | target_link_libraries(${test_name}_test PRIVATE engine_lib GTest::GTest GTest::Main) 7 | add_test(NAME ${test_name}_test COMMAND ${test_name}_test) 8 | endforeach() 9 | 10 | add_executable(gemm_cuda_test gemm_cuda_test.cpp ../kernels.cu) 11 | target_link_libraries(gemm_cuda_test PRIVATE engine_lib GTest::GTest GTest::Main ${CUDA_LIBRARIES}) 12 | target_include_directories(gemm_cuda_test PRIVATE ${CUDA_INCLUDE_DIRS}) 13 | set_target_properties(gemm_cuda_test PROPERTIES CUDA_SEPARABLE_COMPILATION ON CUDA_DEBUG_MODE ON) 14 | add_test(NAME gemm_cuda_test COMMAND gemm_cuda_test) 15 | 16 | # Benchmakrs 17 | 18 | # GEMM Benchmark 19 | add_executable(gemm_bench gemm_bench.cpp) 20 | target_link_libraries(gemm_bench 21 | PRIVATE 22 | engine_lib 23 | benchmark::benchmark 24 | ${CUDA_LIBRARIES} 25 | ) 26 | target_include_directories(gemm_bench PRIVATE ${CUDA_INCLUDE_DIRS}) 27 | 28 | # Set CUDA properties for gemm_bench 29 | set_target_properties(gemm_bench PROPERTIES 30 | CUDA_SEPARABLE_COMPILATION ON 31 | ) 32 | -------------------------------------------------------------------------------- /src/test/gemm_bench.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | #include "../kernels.h" 8 | 9 | // Helper function to initialize a matrix with random values 10 | void initializeMatrix(std::vector& matrix, int size) { 11 | for (int i = 0; i < size; ++i) { 12 | matrix[i] = static_cast(rand()) / RAND_MAX; 13 | } 14 | } 15 | 16 | // Benchmark function 17 | static void BM_GEMM(benchmark::State& state) { 18 | const int n = state.range(0); 19 | const int m = state.range(1); 20 | const int k = state.range(2); 21 | const int type = state.range(3); 22 | 23 | // Allocate and initialize host matrices 24 | std::vector h_A(n * m), h_B(m * k), h_bias(k), h_out(n * k); 25 | initializeMatrix(h_A, n * m); 26 | initializeMatrix(h_B, m * k); 27 | initializeMatrix(h_bias, k); 28 | 29 | // Allocate device memory 30 | float *d_A, *d_B, *d_bias, *d_out; 31 | cudaMalloc(&d_A, n * m * sizeof(float)); 32 | cudaMalloc(&d_B, m * k * sizeof(float)); 33 | cudaMalloc(&d_bias, k * sizeof(float)); 34 | cudaMalloc(&d_out, n * k * sizeof(float)); 35 | 36 | // Copy data to device 37 | cudaMemcpy(d_A, h_A.data(), n * m * sizeof(float), cudaMemcpyHostToDevice); 38 | cudaMemcpy(d_B, h_B.data(), m * k * sizeof(float), cudaMemcpyHostToDevice); 39 | cudaMemcpy(d_bias, h_bias.data(), k * sizeof(float), 40 | cudaMemcpyHostToDevice); 41 | 42 | // Benchmark loop 43 | for (auto _ : state) { 44 | // Create a CUDA event to measure GPU time 45 | cudaEvent_t start, stop; 46 | cudaEventCreate(&start); 47 | cudaEventCreate(&stop); 48 | 49 | // Record the start event 50 | 
cudaEventRecord(start, nullptr); 51 | if (type == 0) { 52 | gemm_cuda_naive(d_A, d_B, d_bias, d_out, n, m, k, false, false, 53 | 1.0f, 1.0f); 54 | } else if (type == 1) { 55 | gemm_cuda_tiled(d_A, d_B, d_bias, d_out, n, m, k, false, false, 56 | 1.0f, 1.0f); 57 | } else if (type == 2) { 58 | gemm_cuda_tiled_1D(d_A, d_B, d_bias, d_out, n, m, k, false, false, 59 | 1.0f, 1.0f); 60 | } else if (type == 3) { 61 | gemm_tiled_1D_blocktiling(d_A, d_B, d_bias, d_out, n, m, k, false, 62 | false, 1.0f, 1.0f); 63 | } 64 | 65 | // Record the stop event 66 | cudaEventRecord(stop, nullptr); 67 | cudaEventSynchronize(stop); 68 | 69 | // Calculate the elapsed time in milliseconds 70 | float milliseconds = 0; 71 | cudaEventElapsedTime(&milliseconds, start, stop); 72 | 73 | // Destroy the events 74 | cudaEventDestroy(start); 75 | cudaEventDestroy(stop); 76 | 77 | // Report the time 78 | state.SetIterationTime(milliseconds / 1000.0); 79 | } 80 | 81 | // Calculate and report throughput 82 | state.SetBytesProcessed(int64_t(state.iterations()) * n * m * k * 83 | sizeof(float)); 84 | state.SetItemsProcessed(int64_t(state.iterations()) * n * m * k); 85 | 86 | // Free device memory 87 | cudaFree(d_A); 88 | cudaFree(d_B); 89 | cudaFree(d_bias); 90 | cudaFree(d_out); 91 | } 92 | 93 | // Define the benchmark 94 | BENCHMARK(BM_GEMM) 95 | ->Args({4092, 4092, 4092, 0}) // n, m, k, (0 for naive) 96 | ->Args({4092, 4092, 4092, 1}) // n, m, k, (1 for tiled) 97 | ->Args({4092, 4092, 4092, 2}) // n, m, k, (2 for tiled_1D) 98 | ->Args({4092, 4092, 4092, 3}) // n, m, k, (3 for 1D block tiling) 99 | ->Unit(benchmark::kMillisecond) 100 | ->UseManualTime(); 101 | 102 | BENCHMARK_MAIN(); -------------------------------------------------------------------------------- /src/test/gemm_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | #include "../kernels.h" 7 | 8 | // Helper function to initialize a matrix with random values 9 | void initializeRandomMatrix(float* matrix, int rows, int cols) { 10 | std::random_device rd; 11 | std::mt19937 gen(rd()); 12 | std::uniform_real_distribution<> dis(-1.0, 1.0); 13 | 14 | for (int i = 0; i < rows * cols; ++i) { 15 | matrix[i] = dis(gen); 16 | } 17 | } 18 | 19 | // Reference CPU implementation of matrix multiplication 20 | void gemm_cpu_reference(const float* A, const float* B, const float* bias, 21 | float* C, int n, int m, int k, bool transA, bool transB, 22 | float alpha, float beta) { 23 | for (int i = 0; i < n; ++i) { 24 | for (int j = 0; j < k; ++j) { 25 | float sum = 0.0f; 26 | for (int l = 0; l < m; ++l) { 27 | int a_idx = transA ? l * n + i : i * m + l; 28 | int b_idx = transB ? j * m + l : l * k + j; 29 | sum += A[a_idx] * B[b_idx]; 30 | } 31 | C[i * k + j] = alpha * sum + beta * (bias ? 
bias[j] : 0.0f); 32 | } 33 | } 34 | } 35 | 36 | void runGemmTest(int n, int m, int k, bool transA, bool transB, float alpha, 37 | float beta) { 38 | size_t sizeA = n * m; 39 | size_t sizeB = m * k; 40 | size_t sizeC = n * k; 41 | 42 | // Allocate host memory 43 | std::vector h_A(sizeA); 44 | std::vector h_B(sizeB); 45 | std::vector h_bias(k); 46 | std::vector h_C_cuda(sizeC); 47 | std::vector h_C_cpu(sizeC); 48 | 49 | // Initialize matrices 50 | initializeRandomMatrix(h_A.data(), n, m); 51 | initializeRandomMatrix(h_B.data(), m, k); 52 | initializeRandomMatrix(h_bias.data(), 1, k); 53 | 54 | // Allocate device memory 55 | float *d_A, *d_B, *d_bias, *d_C; 56 | cudaMalloc(&d_A, sizeA * sizeof(float)); 57 | cudaMalloc(&d_B, sizeB * sizeof(float)); 58 | cudaMalloc(&d_bias, k * sizeof(float)); 59 | cudaMalloc(&d_C, sizeC * sizeof(float)); 60 | 61 | // Copy data to device 62 | cudaMemcpy(d_A, h_A.data(), sizeA * sizeof(float), cudaMemcpyHostToDevice); 63 | cudaMemcpy(d_B, h_B.data(), sizeB * sizeof(float), cudaMemcpyHostToDevice); 64 | cudaMemcpy(d_bias, h_bias.data(), k * sizeof(float), 65 | cudaMemcpyHostToDevice); 66 | 67 | // Run CUDA kernel 68 | gemm_tiled_1D_blocktiling(d_A, d_B, d_bias, d_C, n, m, k, transA, transB, 69 | alpha, beta); 70 | 71 | // Copy result back to host 72 | cudaMemcpy(h_C_cuda.data(), d_C, sizeC * sizeof(float), 73 | cudaMemcpyDeviceToHost); 74 | 75 | // Run CPU reference implementation 76 | gemm_cpu_reference(h_A.data(), h_B.data(), h_bias.data(), h_C_cpu.data(), n, 77 | m, k, transA, transB, alpha, beta); 78 | 79 | // Free device memory 80 | cudaFree(d_A); 81 | cudaFree(d_B); 82 | cudaFree(d_bias); 83 | cudaFree(d_C); 84 | } 85 | 86 | int main() { runGemmTest(64, 64, 64, false, false, 1.0f, 1.0f); } -------------------------------------------------------------------------------- /src/test/gemm_cuda_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | #include "../kernels.h" 8 | 9 | #define TEST_GEMM(test_name, n, m, k) \ 10 | TEST_F(GemmCudaTest, test_name##_NoTranspose) { \ 11 | runGemmTest(n, m, k, false, false, 1.0f, 1.0f); \ 12 | } \ 13 | TEST_F(GemmCudaTest, test_name##_TransposeB) { \ 14 | runGemmTest(n, m, k, false, true, 1.0f, 1.0f); \ 15 | } 16 | 17 | // Helper function to initialize a matrix with random values 18 | void initializeRandomMatrix(float* matrix, int rows, int cols) { 19 | std::random_device rd; 20 | std::mt19937 gen(rd()); 21 | std::uniform_real_distribution<> dis(-1.0, 1.0); 22 | 23 | for (int i = 0; i < rows * cols; ++i) { 24 | matrix[i] = dis(gen); 25 | } 26 | } 27 | 28 | // Helper function to compare matrices 29 | void compareMatrices(const float* A, const float* B, int size, 30 | float tolerance = 1e-5) { 31 | for (int i = 0; i < size; ++i) { 32 | EXPECT_NEAR( 33 | A[i], B[i], 34 | tolerance); // Allow small tolerance for floating-point errors 35 | } 36 | } 37 | 38 | // Reference CPU implementation of matrix multiplication 39 | void gemm_cpu_reference(const float* A, const float* B, const float* bias, 40 | float* C, int n, int m, int k, bool transA, bool transB, 41 | float alpha, float beta) { 42 | for (int i = 0; i < n; ++i) { 43 | for (int j = 0; j < k; ++j) { 44 | float sum = 0.0f; 45 | for (int l = 0; l < m; ++l) { 46 | int a_idx = transA ? l * n + i : i * m + l; 47 | int b_idx = transB ? j * m + l : l * k + j; 48 | sum += A[a_idx] * B[b_idx]; 49 | } 50 | C[i * k + j] = alpha * sum + beta * (bias ? 
bias[j] : 0.0f); 51 | } 52 | } 53 | } 54 | 55 | class GemmCudaTest : public ::testing::Test { 56 | protected: 57 | void SetUp() override { 58 | // CUDA initialization if needed 59 | } 60 | 61 | void TearDown() override { 62 | // CUDA cleanup if needed 63 | } 64 | 65 | void runGemmTest(int n, int m, int k, bool transA, bool transB, float alpha, 66 | float beta) { 67 | size_t sizeA = n * m; 68 | size_t sizeB = m * k; 69 | size_t sizeC = n * k; 70 | 71 | // Allocate host memory 72 | std::vector h_A(sizeA); 73 | std::vector h_B(sizeB); 74 | std::vector h_bias(k); 75 | std::vector h_C_cuda(sizeC); 76 | std::vector h_C_cpu(sizeC); 77 | 78 | // Initialize matrices 79 | initializeRandomMatrix(h_A.data(), n, m); 80 | initializeRandomMatrix(h_B.data(), m, k); 81 | initializeRandomMatrix(h_bias.data(), 1, k); 82 | 83 | // Allocate device memory 84 | float *d_A, *d_B, *d_bias, *d_C; 85 | cudaMalloc(&d_A, sizeA * sizeof(float)); 86 | cudaMalloc(&d_B, sizeB * sizeof(float)); 87 | cudaMalloc(&d_bias, k * sizeof(float)); 88 | cudaMalloc(&d_C, sizeC * sizeof(float)); 89 | 90 | // Copy data to device 91 | cudaMemcpy(d_A, h_A.data(), sizeA * sizeof(float), 92 | cudaMemcpyHostToDevice); 93 | cudaMemcpy(d_B, h_B.data(), sizeB * sizeof(float), 94 | cudaMemcpyHostToDevice); 95 | cudaMemcpy(d_bias, h_bias.data(), k * sizeof(float), 96 | cudaMemcpyHostToDevice); 97 | 98 | // Run CUDA kernel 99 | gemm_tiled_1D_blocktiling(d_A, d_B, d_bias, d_C, n, m, k, transA, 100 | transB, alpha, beta); 101 | 102 | // Copy result back to host 103 | cudaMemcpy(h_C_cuda.data(), d_C, sizeC * sizeof(float), 104 | cudaMemcpyDeviceToHost); 105 | 106 | // Run CPU reference implementation 107 | gemm_cpu_reference(h_A.data(), h_B.data(), h_bias.data(), 108 | h_C_cpu.data(), n, m, k, transA, transB, alpha, 109 | beta); 110 | 111 | // Compare results 112 | compareMatrices(h_C_cuda.data(), h_C_cpu.data(), sizeC); 113 | 114 | // std::cout << "cuda:\n"; 115 | // for (int y = 0; y < k; ++y) { 116 | // for (int x = 0; x < n; ++x) { 117 | // std::cout << h_C_cuda[y * n + x] << " "; 118 | // } 119 | // std::cout << "\n"; 120 | // } 121 | 122 | // std::cout << "cpu:\n"; 123 | // for (int y = 0; y < k; ++y) { 124 | // for (int x = 0; x < n; ++x) { 125 | // std::cout << h_C_cpu[y * n + x] << " "; 126 | // } 127 | // std::cout << "\n"; 128 | // } 129 | 130 | // Free device memory 131 | cudaFree(d_A); 132 | cudaFree(d_B); 133 | cudaFree(d_bias); 134 | cudaFree(d_C); 135 | } 136 | }; 137 | TEST_GEMM(OneSquare, 1, 1, 1) 138 | TEST_GEMM(TwoSquare, 2, 2, 2) 139 | TEST_GEMM(FourSquare, 4, 4, 4) 140 | TEST_GEMM(FiveSquare, 5, 5, 5) 141 | TEST_GEMM(SixteenSquare, 16, 16, 16) 142 | TEST_GEMM(ThirtyTwoSquare, 32, 32, 32) 143 | TEST_GEMM(SixtyFourSquare, 64, 64, 64) 144 | TEST_GEMM(LargeSquare, 65, 65, 65) 145 | TEST_GEMM(LargerSquare, 128, 128, 128) 146 | TEST_GEMM(BigSquare, 1028, 1028, 1028) 147 | TEST_GEMM(SmallNonSquare, 2, 4, 8) 148 | TEST_GEMM(MediumNonSquare, 31, 15, 43) 149 | TEST_GEMM(NonSquare, 6, 4, 2) 150 | TEST_GEMM(BigNonSquare, 333, 1027, 717) 151 | 152 | int main(int argc, char** argv) { 153 | ::testing::InitGoogleTest(&argc, argv); 154 | return RUN_ALL_TESTS(); 155 | } -------------------------------------------------------------------------------- /src/test/gemm_test.cpp: -------------------------------------------------------------------------------- 1 | #include "../gemm_cpu.h" 2 | #include "../kernels.h" 3 | #include "gtest/gtest.h" 4 | 5 | TEST(GemmTest, NoBias) { 6 | const int n = 2; 7 | const int m = 2; 8 | const int k = 1; // Must be 1 
for this gemm implementation 9 | 10 | float A[n * m] = {1.0, 1.0, 1.0, 1.0}; 11 | float B[m * k] = {2.0, 3.0}; 12 | float C[n] = {}; // 0 Bias vector 13 | 14 | float expected[n] = {5.0, 5.0}; 15 | float out[n]; 16 | 17 | gemm_cpu(A, B, C, out, n, m, k, false, false, 1, 1); 18 | 19 | for (int i = 0; i < n; ++i) { 20 | EXPECT_NEAR(out[i], expected[i], 21 | 1e-5); // Allow small tolerance for floating-point errors 22 | } 23 | } 24 | 25 | TEST(GemmTest, MatrixVectorMultiplication) { 26 | const int n = 2; 27 | const int m = 2; 28 | const int k = 1; 29 | 30 | float A[n * m] = {1.0, 2.0, 3.0, 4.0}; 31 | float B[m * k] = {5.0, 6.0}; 32 | float C[n] = {7.0, 8.0}; // Bias vector 33 | 34 | float expected[n] = {24.0, 47.0}; 35 | float out[n]; 36 | 37 | gemm_cpu(A, B, C, out, n, m, k, false, false, 1, 1); 38 | 39 | for (int i = 0; i < n; ++i) { 40 | EXPECT_NEAR(out[i], expected[i], 41 | 1e-5); // Allow small tolerance for floating-point errors 42 | } 43 | } 44 | 45 | TEST(GemmTest, DifferentDimensions) { 46 | const int n = 2; 47 | const int m = 3; 48 | const int k = 1; 49 | 50 | float A[n * m] = {1.0, 1.0, 1.0, 2.0, 2.0, 2.0}; 51 | float B[m * k] = {2.0, 3.0, 4.0}; 52 | float C[n] = {1.0, 1.0}; 53 | 54 | float expected[n] = {10.0, 19.0}; 55 | float out[n]; 56 | 57 | gemm_cpu(A, B, C, out, n, m, k, false, false, 1, 1); 58 | 59 | for (int i = 0; i < n; ++i) { 60 | EXPECT_NEAR(out[i], expected[i], 1e-5); 61 | } 62 | } 63 | 64 | // Generated with python. 65 | TEST(GemmTest, LargeMatrix) { 66 | const int n = 10; 67 | const int m = 5; 68 | const int k = 1; 69 | 70 | float A[n * m] = { 71 | 0.37406192933141147, 0.06745353143398714, 0.1360688163646061, 72 | 0.7562505073146693, 0.5428856927950624, 0.31954021763742935, 73 | 0.8946841711756393, 0.7743437586367583, 0.14352502492052077, 74 | 0.41265414487951024, 0.8296688316398949, 0.27251306465433534, 75 | 0.14776552565735301, 0.3950728124771392, 0.7687820220977069, 76 | 0.16392780073148927, 0.7402754740346736, 0.7243936043094837, 77 | 0.8633784327013307, 0.3487963299151927, 0.6669684460103615, 78 | 0.3531235616889036, 0.9760224780679175, 0.7263405530192332, 79 | 0.7516497260607778, 0.28503760873098205, 0.46137578945845437, 80 | 0.20012238575099484, 0.10529354612903685, 0.9239620151968889, 81 | 0.6608644427436974, 0.5586669905827026, 0.9818996653233024, 82 | 0.8885926029218781, 0.984107985981659, 0.0027823158763550238, 83 | 0.4801631350272688, 0.34023762153957926, 0.9783240486545588, 84 | 0.4044755719976537, 0.23219208762135335, 0.7583056863535441, 85 | 0.5877358671639532, 0.9278162173103481, 0.418900146932737, 86 | 0.2351060992425743, 0.5169877034859311, 0.4746289284042273, 87 | 0.4689403087502675, 0.08573114728236486}; 88 | float B[m * k] = {0.6388001442740947, 0.7295867686529213, 89 | 0.022577391563295524, 0.3353350568167005, 90 | 0.1839017387537819}; 91 | float C[n] = {0.9533634522173151, 0.6163495493570798, 0.183166574752734, 92 | 0.4766469611276871, 0.2677701486587737, 0.2026184257272713, 93 | 0.8402237268446139, 0.6291182992377615, 0.6915170468341604, 94 | 0.4516003466136026}; 95 | 96 | float expected[n] = {1.5980344793090224, 1.6147210692751712, 97 | 1.1891793450983736, 1.4914794802764628, 98 | 1.3552971503375555, 0.9310592425482995, 99 | 2.1711042265782834, 1.391348102552378, 100 | 1.794526493404922, 1.1627076600419188}; 101 | float out[n]; 102 | 103 | gemm_cpu(A, B, C, out, n, m, k, false, false, 1, 1); 104 | 105 | for (int i = 0; i < n; ++i) { 106 | EXPECT_NEAR(out[i], expected[i], 1e-5); 107 | } 108 | } 109 | 110 | TEST(GemmTest, AllZeroes) { 111 | 
const int n = 4; 112 | const int m = 3; 113 | const int k = 1; 114 | 115 | float A[n * m] = {}; 116 | float B[m * k] = {}; 117 | float C[n] = {}; 118 | 119 | float expected[n] = {}; 120 | float out[n]; 121 | 122 | gemm_cpu(A, B, C, out, n, m, k, false, false, 1, 1); 123 | 124 | for (int i = 0; i < n; ++i) { 125 | EXPECT_NEAR(out[i], expected[i], 1e-5); 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /src/test/node_test.cpp: -------------------------------------------------------------------------------- 1 | #include "../node.h" 2 | 3 | #include "gtest/gtest.h" 4 | 5 | TEST(NodeTest, Constructor) { 6 | // Tensor setup 7 | Node node("relu1", OpType::Relu); 8 | std::vector inputs{"A", "B", "bias"}; 9 | node.addInput(inputs[0]); 10 | node.addInput(inputs[1]); 11 | node.addInput(inputs[2]); 12 | 13 | EXPECT_EQ(node.getName(), "relu1"); 14 | EXPECT_EQ(node.getInputs(), inputs); 15 | } 16 | -------------------------------------------------------------------------------- /src/test/operators_test.cpp: -------------------------------------------------------------------------------- 1 | #include "../operators.h" 2 | 3 | #include "../tensor.h" 4 | #include "gtest/gtest.h" 5 | 6 | TEST(OperatorsTest, flatten) { 7 | const float data[] = {1, 2, 3, 4, 5, 6, 7, 8}; 8 | std::vector shape = {1, 1, 2, 4}; 9 | Tensor t1(data, shape); 10 | 11 | // Note: flatten modifies the original tensor in-place. 12 | Tensor t2 = CpuOperators::flatten(t1, uint64_t{3}); 13 | std::vector expectedShape1{2, 4}; 14 | EXPECT_EQ(expectedShape1, t2.shape()); 15 | 16 | Tensor t3 = CpuOperators::flatten(t2, uint64_t{0}); 17 | std::vector expectedShape2{1, 8}; 18 | EXPECT_EQ(expectedShape2, t3.shape()); 19 | } 20 | 21 | TEST(OperatorsTest, relu) { 22 | const float data1[] = {-1.0f, 0, 1, -1, 1, -0.5f, -0, 0.5f}; 23 | std::vector shape1 = {2, 4}; 24 | Tensor t1(data1, shape1); 25 | auto t2 = CpuOperators::relu(t1); 26 | std::vector expected{0, 0, 1, 0, 1, 0, 0, 0.5}; 27 | 28 | for (std::size_t i = 0; i < t2.size(); ++i) { 29 | EXPECT_EQ(t2.data()[i], expected[i]); 30 | } 31 | 32 | const float data2[] = {1, 1, 1, 1}; 33 | std::vector shape2 = {2, 2}; 34 | Tensor t3(data2, shape2); 35 | 36 | auto t4 = CpuOperators::relu(t3); 37 | std::vector expected2{1, 1, 1, 1}; 38 | for (std::size_t i = 0; i < t4.size(); ++i) { 39 | EXPECT_EQ(t4.data()[i], expected2[i]); 40 | } 41 | } 42 | 43 | TEST(OperatorsTest, gemmMatrixVector) { 44 | const float dataA[] = {1, 2, 3, 4}; 45 | std::vector shapeA = {2, 2}; 46 | Tensor A(dataA, shapeA); 47 | 48 | const float dataB[] = {1, 1}; 49 | std::vector shapeB = {2, 1}; 50 | Tensor B(dataB, shapeB); 51 | 52 | const float dataBias[] = {1, 1}; 53 | std::vector shapeBias = {2, 1}; 54 | Tensor bias(dataBias, shapeBias); 55 | 56 | auto res = CpuOperators::gemm(A, B, bias, false, false, 1, 1); 57 | 58 | std::vector expectShape{2, 1}; 59 | EXPECT_EQ(expectShape, res.shape()); 60 | std::vector expectData{4, 8}; 61 | for (std::size_t i = 0; i < res.size(); ++i) { 62 | EXPECT_EQ(res.data()[i], expectData[i]); 63 | } 64 | } 65 | 66 | TEST(OperatorsTest, gemmMatrixMatrix) { 67 | const float dataA[] = {1, 2, 3, 4}; 68 | std::vector shapeA = {2, 2}; 69 | Tensor A(dataA, shapeA); 70 | 71 | const float dataB[] = {1, 1, 1, 1}; 72 | std::vector shapeB = {2, 2}; 73 | Tensor B(dataB, shapeB); 74 | 75 | const float dataBias[] = {1, 2}; 76 | std::vector shapeBias = {2, 1}; 77 | Tensor bias(dataBias, shapeBias); 78 | 79 | auto res = CpuOperators::gemm(A, B, bias, false, false, 1, 
1); 80 | 81 | std::vector expectShape{2, 2}; 82 | EXPECT_EQ(expectShape, res.shape()); 83 | 84 | std::vector expectData{4, 4, 9, 9}; 85 | for (std::size_t i = 0; i < res.size(); ++i) { 86 | EXPECT_EQ(res.data()[i], expectData[i]); 87 | } 88 | } 89 | 90 | TEST(OperatorsTest, gemmMatrixMatrixTransA) { 91 | const float dataA[] = {1, 2}; 92 | std::vector shapeA = {2, 1}; 93 | Tensor A(dataA, shapeA); 94 | 95 | const float dataB[] = {1, 1, 1, 1}; 96 | std::vector shapeB = {2, 2}; 97 | Tensor B(dataB, shapeB); 98 | 99 | const float dataBias[] = {0, 0}; 100 | std::vector shapeBias = {2}; 101 | Tensor bias(dataBias, shapeBias); 102 | 103 | auto res = CpuOperators::gemm(A, B, bias, true, false, 1, 1); 104 | 105 | std::vector expectShape{1, 2}; 106 | EXPECT_EQ(expectShape, res.shape()); 107 | 108 | std::vector expectData{3, 3}; 109 | for (std::size_t i = 0; i < res.size(); ++i) { 110 | EXPECT_EQ(res.data()[i], expectData[i]); 111 | } 112 | } 113 | 114 | TEST(OperatorsTest, gemmMatrixMatrixTransB) { 115 | const float dataA[] = {1, 1, 1, 1}; 116 | std::vector shapeA = {2, 2}; 117 | Tensor A(dataA, shapeA); 118 | 119 | const float dataB[] = {1, 2}; 120 | std::vector shapeB = {1, 2}; 121 | Tensor B(dataB, shapeB); 122 | 123 | const float dataBias[] = {0, 0}; 124 | std::vector shapeBias = {2}; 125 | Tensor bias(dataBias, shapeBias); 126 | 127 | auto res = CpuOperators::gemm(A, B, bias, false, true, 1, 1); 128 | 129 | std::vector expectShape{2, 1}; 130 | EXPECT_EQ(expectShape, res.shape()); 131 | 132 | std::vector expectData{3, 3}; 133 | for (std::size_t i = 0; i < res.size(); ++i) { 134 | EXPECT_EQ(res.data()[i], expectData[i]); 135 | } 136 | } -------------------------------------------------------------------------------- /src/test/tensor_test.cpp: -------------------------------------------------------------------------------- 1 | #include "../tensor.h" 2 | 3 | #include "gtest/gtest.h" 4 | 5 | #ifdef USE_CUDA 6 | #include 7 | 8 | #include "../cuda_allocator.h" 9 | #endif 10 | 11 | TEST(TensorTest, ConstructorWithShape) { 12 | Tensor tensor({2, 3, 4}); 13 | EXPECT_EQ(tensor.shape(), (std::vector{2, 3, 4})); 14 | EXPECT_EQ(tensor.size(), 24); 15 | EXPECT_EQ(tensor.device(), DeviceType::CPU); 16 | EXPECT_NE(tensor.data(), nullptr); 17 | } 18 | 19 | TEST(TensorTest, ConstructorWithData) { 20 | std::vector data(24, 1.0f); 21 | Tensor tensor(data.data(), {2, 3, 4}); 22 | EXPECT_EQ(tensor.shape(), (std::vector{2, 3, 4})); 23 | EXPECT_EQ(tensor.size(), 24); 24 | EXPECT_EQ(tensor.device(), DeviceType::CPU); 25 | EXPECT_NE(tensor.data(), nullptr); 26 | 27 | // Check if data was correctly copied 28 | for (size_t i = 0; i < 24; ++i) { 29 | EXPECT_FLOAT_EQ(tensor.data()[i], 1.0f); 30 | } 31 | } 32 | 33 | TEST(TensorTest, CopyConstructor) { 34 | Tensor original({2, 3, 4}); 35 | for (size_t i = 0; i < original.size(); ++i) { 36 | original.data()[i] = static_cast(i); 37 | } 38 | 39 | Tensor copy(original); 40 | EXPECT_EQ(copy.shape(), original.shape()); 41 | EXPECT_EQ(copy.size(), original.size()); 42 | EXPECT_EQ(copy.device(), original.device()); 43 | EXPECT_NE(copy.data(), original.data()); // Ensure deep copy 44 | 45 | for (size_t i = 0; i < copy.size(); ++i) { 46 | EXPECT_FLOAT_EQ(copy.data()[i], original.data()[i]); 47 | } 48 | } 49 | 50 | TEST(TensorTest, MoveConstructor) { 51 | Tensor original({2, 3, 4}); 52 | float* originalData = original.data(); 53 | 54 | Tensor moved(std::move(original)); 55 | EXPECT_EQ(moved.shape(), (std::vector{2, 3, 4})); 56 | EXPECT_EQ(moved.size(), 24); 57 | EXPECT_EQ(moved.device(), 
DeviceType::CPU); 58 | EXPECT_EQ(moved.data(), originalData); 59 | 60 | // Check that the original tensor has been properly moved from 61 | EXPECT_EQ(original.data(), nullptr); 62 | EXPECT_TRUE(original.shape().empty()); 63 | EXPECT_EQ(original.size(), 0); 64 | } 65 | 66 | TEST(TensorTest, CopyAssignment) { 67 | Tensor original({2, 3, 4}); 68 | for (size_t i = 0; i < original.size(); ++i) { 69 | original.data()[i] = static_cast(i); 70 | } 71 | 72 | Tensor copy({1, 1, 1}); // Different size 73 | copy = original; 74 | 75 | EXPECT_EQ(copy.shape(), original.shape()); 76 | EXPECT_EQ(copy.size(), original.size()); 77 | EXPECT_EQ(copy.device(), original.device()); 78 | EXPECT_NE(copy.data(), original.data()); // Ensure deep copy 79 | 80 | for (size_t i = 0; i < copy.size(); ++i) { 81 | EXPECT_FLOAT_EQ(copy.data()[i], original.data()[i]); 82 | } 83 | } 84 | 85 | TEST(TensorTest, MoveAssignment) { 86 | Tensor original({2, 3, 4}); 87 | float* originalData = original.data(); 88 | 89 | Tensor moved({1, 1, 1}); // Different size 90 | moved = std::move(original); 91 | 92 | EXPECT_EQ(moved.shape(), (std::vector{2, 3, 4})); 93 | EXPECT_EQ(moved.size(), 24); 94 | EXPECT_EQ(moved.device(), DeviceType::CPU); 95 | EXPECT_EQ(moved.data(), originalData); 96 | 97 | // Check that the original tensor has been properly moved from 98 | EXPECT_EQ(original.data(), nullptr); 99 | EXPECT_TRUE(original.shape().empty()); 100 | EXPECT_EQ(original.size(), 0); 101 | } 102 | 103 | #ifdef USE_CUDA 104 | TEST(TensorTest, DeviceTransferCPUtoCUDA) { 105 | Tensor cpuTensor({2, 3, 4}); 106 | for (size_t i = 0; i < cpuTensor.size(); ++i) { 107 | cpuTensor.data()[i] = static_cast(i); 108 | } 109 | auto cudaAllocator = std::make_shared(); 110 | cpuTensor.to(DeviceType::CUDA, cudaAllocator); 111 | EXPECT_EQ(cpuTensor.device(), DeviceType::CUDA); 112 | 113 | // Create a new CPU tensor to check data 114 | Tensor checkTensor({2, 3, 4}); 115 | cudaMemcpy(checkTensor.data(), cpuTensor.data(), 116 | cpuTensor.size() * sizeof(float), cudaMemcpyDeviceToHost); 117 | 118 | for (size_t i = 0; i < checkTensor.size(); ++i) { 119 | EXPECT_FLOAT_EQ(checkTensor.data()[i], static_cast(i)); 120 | } 121 | } 122 | 123 | TEST(TensorTest, DeviceTransferCUDAtoCPU) { 124 | Tensor cudaTensor({2, 3, 4}, std::make_shared()); 125 | std::vector init_data(cudaTensor.size()); 126 | for (size_t i = 0; i < init_data.size(); ++i) { 127 | init_data[i] = static_cast(i); 128 | } 129 | cudaMemcpy(cudaTensor.data(), init_data.data(), 130 | cudaTensor.size() * sizeof(float), cudaMemcpyHostToDevice); 131 | 132 | cudaTensor.to(DeviceType::CPU); 133 | EXPECT_EQ(cudaTensor.device(), DeviceType::CPU); 134 | 135 | for (size_t i = 0; i < cudaTensor.size(); ++i) { 136 | EXPECT_FLOAT_EQ(cudaTensor.data()[i], static_cast(i)); 137 | } 138 | } 139 | #endif 140 | 141 | TEST(TensorTest, InvalidDeviceTransfer) { 142 | Tensor tensor({2, 3, 4}); 143 | EXPECT_THROW(tensor.to(static_cast(99)), std::runtime_error); 144 | } 145 | -------------------------------------------------------------------------------- /utils/infer.py: -------------------------------------------------------------------------------- 1 | import onnxruntime as rt 2 | import numpy as np 3 | import struct 4 | 5 | def load_mnist_image(filename): 6 | """Loads a single MNIST image from a .ubyte file.""" 7 | with open(filename, 'rb') as f: 8 | # Read image data as bytes 9 | image_data = f.read(28 * 28) 10 | # Convert to numpy array, reshape, and normalize 11 | image = np.frombuffer(image_data, dtype=np.uint8).reshape(1, 1, 28, 
28).astype(np.float32) 12 | return image 13 | 14 | # Load the ONNX model 15 | sess = rt.InferenceSession("../models/mnist_ffn_complex.onnx") 16 | 17 | # Get input and output names 18 | input_name = sess.get_inputs()[0].name 19 | output_name = sess.get_outputs()[0].name 20 | print("input name:", input_name) 21 | print("output name:", output_name) 22 | 23 | 24 | # Load and preprocess the MNIST image (now as float32) 25 | image = load_mnist_image("../inputs/image_3.ubyte") 26 | 27 | # Run inference 28 | result = sess.run([output_name], {input_name: image}) 29 | 30 | # Get predicted class 31 | predicted_class = np.argmax(result) 32 | 33 | with np.printoptions(precision=3, suppress=True): 34 | print(result) 35 | 36 | print("Predicted class:", predicted_class) -------------------------------------------------------------------------------- /utils/infer_server.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | import os 4 | import numpy as np 5 | 6 | def load_mnist_image(filename): 7 | """Loads a single MNIST image from a .ubyte file.""" 8 | with open(filename, 'rb') as f: 9 | # Read image data as bytes 10 | image_data = f.read(28 * 28) 11 | # Convert to numpy array, reshape, and normalize 12 | image = np.frombuffer(image_data, dtype=np.uint8).reshape(1, 1, 28, 28).astype(np.float32) 13 | return image 14 | 15 | 16 | def process_image(filename, url): 17 | image = load_mnist_image(filename) 18 | flattened_data = image.flatten().tolist() 19 | 20 | data = {"data": flattened_data} 21 | 22 | response = requests.post(url, json=data) 23 | 24 | if response.status_code == 200: 25 | result = response.json()["result"] 26 | print(f"Image is the number {np.argmax(result)}:", result) 27 | else: 28 | print(f"Error for {filename}:", response.status_code, response.text) 29 | print("Response content:", response.content) 30 | 31 | if __name__ == "__main__": 32 | inputs_folder = "/home/michal/code/inference_engine/inputs" 33 | url = "http://localhost:8080/infer" 34 | 35 | # List all files in the inputs folder 36 | files = os.listdir(inputs_folder) 37 | 38 | for file in sorted(files)[:1]: 39 | for i in range(10): 40 | full_path = os.path.join(inputs_folder, file) 41 | print(f"Processing file: {file}") 42 | process_image(full_path, url) 43 | print("------------------------") 44 | 45 | print("All images processed.") -------------------------------------------------------------------------------- /utils/matrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | n = 10 4 | m = 5 5 | k = 1 6 | A = np.random.rand(n, m) 7 | B = np.random.rand(m, k) 8 | C = np.random.rand(n, 1) 9 | out = A@B + C 10 | 11 | 12 | flattened = [element for row in A for element in row] 13 | # Convert elements to strings and join them with commas 14 | output = ", ".join(map(str, flattened)) 15 | print("A:") 16 | print(output) 17 | print() 18 | 19 | 20 | flattened = [element for row in B for element in row] 21 | # Convert elements to strings and join them with commas 22 | output = ", ".join(map(str, flattened)) 23 | print("B:") 24 | print(output) 25 | print() 26 | 27 | 28 | flattened = [element for row in C for element in row] 29 | # Convert elements to strings and join them with commas 30 | output = ", ".join(map(str, flattened)) 31 | print("C:") 32 | print(output) 33 | print() 34 | 35 | 36 | flattened = [element for row in out for element in row] 37 | # Convert elements to strings and join them with commas 38 | 
output = ", ".join(map(str, flattened)) 39 | print("out:") 40 | print(output) 41 | print() 42 | 43 | 44 | -------------------------------------------------------------------------------- /utils/mnist.py: -------------------------------------------------------------------------------- 1 | import struct 2 | from PIL import Image 3 | 4 | def extract_mnist_image(input_file, output_file, image_index): 5 | with open(input_file, 'rb') as f_in, open(output_file, 'wb') as f_out: 6 | # Read the MNIST header (we'll skip it for now) 7 | struct.unpack('>IIII', f_in.read(16)) 8 | 9 | # Seek to the start of the desired image 10 | image_offset = 16 + image_index * 784 11 | f_in.seek(image_offset) 12 | 13 | # Read the image data (784 bytes) 14 | image_data = f_in.read(784) 15 | 16 | # Write the image data to the output file 17 | f_out.write(image_data) 18 | 19 | 20 | def view_mnist_image(filename): 21 | with open(filename, 'rb') as f: 22 | image_data = f.read() 23 | 24 | image = Image.frombytes('L', (28, 28), image_data) # L means grayscale 25 | image.show() # Display the image 26 | # Or, save the image: image.save('image.png') 27 | 28 | 29 | if __name__ == "__main__": 30 | # Example usage: Extract the 10th image 31 | input_file = '../inputs/t10k-images.idx3-ubyte' 32 | 33 | for image_idx in range(100): 34 | output_file = f'../inputs/image_{image_idx}.ubyte' 35 | extract_mnist_image(input_file, output_file, image_idx) --------------------------------------------------------------------------------