├── .github └── workflows │ ├── auto_format_docs.yml │ ├── auto_format_python.yml │ └── auto_release.yml ├── LICENSE ├── README.md ├── axengine ├── __init__.py ├── _axclrt.py ├── _axclrt_capi.py ├── _axclrt_types.py ├── _axe.py ├── _axe_capi.py ├── _axe_types.py ├── _base_session.py ├── _node.py ├── _providers.py └── _session.py ├── examples ├── classification.py └── yolov5.py └── setup.py /.github/workflows/auto_format_docs.yml: -------------------------------------------------------------------------------- 1 | name: Format Code 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | format: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - name: Checkout code 17 | uses: actions/checkout@v2 18 | 19 | - name: Set up Node.js 20 | uses: actions/setup-node@v2 21 | with: 22 | node-version: "16" 23 | 24 | - name: Install Prettier 25 | run: | 26 | npm install --save-dev prettier 27 | 28 | - name: Format YAML and Markdown files 29 | run: | 30 | npx prettier --write "**/*.yml" "**/*.md" 31 | 32 | - name: Check for changes 33 | id: check_changes 34 | run: | 35 | git diff --quiet && echo "changed=false" >> "$GITHUB_OUTPUT" || echo "changed=true" >> "$GITHUB_OUTPUT" 36 | 37 | - name: Commit changes 38 | if: steps.check_changes.outputs.changed == 'true' 39 | run: | 40 | git config --local user.name "github-actions" 41 | git config --local user.email "github-actions@github.com" 42 | git add . 
43 | git commit -m "Format YAML and Markdown files with Prettier" || echo "No changes to commit" 44 | git push 45 | -------------------------------------------------------------------------------- /.github/workflows/auto_format_python.yml: -------------------------------------------------------------------------------- 1 | name: Auto Format Python Code with Black 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | format: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - name: Checkout code 17 | uses: actions/checkout@v2 18 | 19 | - name: Set up Python 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: "3.8" 23 | 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install black 28 | 29 | - name: Run Black 30 | run: | 31 | black . 32 | 33 | - name: Check for changes 34 | id: check_changes 35 | run: | 36 | git diff --quiet && echo "changed=false" >> "$GITHUB_OUTPUT" || echo "changed=true" >> "$GITHUB_OUTPUT" 37 | 38 | - name: Commit changes 39 | if: steps.check_changes.outputs.changed == 'true' 40 | run: | 41 | git config --local user.name "github-actions" 42 | git config --local user.email "github-actions@github.com" 43 | git add . 
44 | git commit -m "Format Python code with Black" || echo "No changes to commit" 45 | git push 46 | -------------------------------------------------------------------------------- /.github/workflows/auto_release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - name: Checkout code 13 | uses: actions/checkout@v2 14 | 15 | - name: Set up Python 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: "3.8" 19 | 20 | - name: Install build dependencies 21 | run: | 22 | python -m pip install --upgrade pip setuptools wheel 23 | 24 | - name: Build the package 25 | run: | 26 | python setup.py bdist_wheel 27 | 28 | - name: Upload the package 29 | uses: actions/upload-artifact@v4 30 | with: 31 | name: python-package 32 | path: dist/*.whl 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, AXERA 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 
18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyAXEngine 2 | 3 | [![License](https://img.shields.io/badge/license-BSD--3--Clause-blue.svg)](https://raw.githubusercontent.com/AXERA-TECH/pyaxengine/main/LICENSE) 4 | 5 | ## 简介 6 | 7 | **PyAXEngine** 基于 cffi 模块实现了 Axera NPU Runtime 的 Python API,其 Python API 与 ONNXRuntime 高度兼(相)容(似),并同时支持开发板和M.2算力卡形态,方便开源社区开发者使用 8 | Python 脚本快速构建 NPU 推理脚本 9 | 10 | 支持芯片 11 | 12 | - AX650N 13 | - AX630C 14 | 15 | 环境版本 16 | 17 | - python >= 3.8 18 | - cffi >= 1.0.0 19 | - ml-dtypes >= 0.1.0 20 | - numpy >= 1.22.0 21 | 22 | *需要注意的是,如果您的开发环境是算力卡,那么更建议您优先考虑使用 [pyAXCL](https://github.com/AXERA-TECH/pyaxcl) 进行项目开发;pyAXCL 项目完整包含了算力卡形态的全部 API,更适合用于正式部署;PyAXEngine 项目更适合算法工程师进行快速原型验证,且用于计算卡环境时,PyAXEngine 不能调用编解码等模块(不是 PyAXEngine 的设计目标)。* 23 | 24 | *AX650 SDK 2.18,AX620E SDK 3.12 以前的版本不支持 bf16,llm 模型会有返回 unknown 的 dtype问题,请注意升级* 25 | 26 | *如果您评估认为不知道如何升级 SDK,也可以提交 issue 索要下载,不需要更新完整 SDK,只更新 libax_engine.so 即可* 27 | 28 | ## 快速上手 29 | 30 | 基于社区开发板 **爱芯派Pro(AX650N)** 进行展示 31 | 32 | ### 获取 wheel 包并安装 33 | 34 | - 
[下载链接](https://github.com/AXERA-TECH/pyaxengine/releases/latest) 35 | - 将 `axengine-x.x.x-py3-none-any.whl` 拷贝到开发板上,执行 `pip install axengine-x.x.x-py3-none-any.whl` 安装 36 | 37 | ### 简单示例 38 | 39 | 当前示例需要分别依赖 PIL 和 OpenCV,可以用 `pip install pillow opencv-python-headless` 安装。其中 `opencv-python-headless` 是 OpenCV 的 headless 版本,不依赖 GUI(非 headless 的版本需要依赖 OpenGL ES,运行环境中并没有)。 40 | 41 | 42 | 将 [classification.py](https://github.com/AXERA-TECH/pyaxengine/blob/main/examples/classification.py) 拷贝到开发板上并执行。 43 | 44 | ```bash 45 | root@ax650:~/samples# python3 classification.py -m /opt/data/npu/models/mobilenetv2.axmodel -i /opt/data/npu/images/cat.jpg 46 | [INFO] Available providers: ['AXCLRTExecutionProvider', 'AxEngineExecutionProvider'] 47 | [INFO] Using provider: AxEngineExecutionProvider 48 | [INFO] Chip type: ChipType.MC50 49 | [INFO] VNPU type: VNPUType.DISABLED 50 | [INFO] Engine version: 2.10.1s 51 | [INFO] Model type: 0 (single core) 52 | [INFO] Compiler version: 1.2-patch2 7e6b2b5f 53 | ------------------------------------------------------ 54 | Top 5 Predictions: 55 | Class Index: 282, Score: 9.774 56 | Class Index: 278, Score: 8.981 57 | Class Index: 277, Score: 8.453 58 | Class Index: 281, Score: 8.321 59 | Class Index: 287, Score: 7.924 60 | ------------------------------------------------------ 61 | min = 0.890 ms max = 22.417 ms avg = 1.119 ms 62 | ------------------------------------------------------ 63 | ``` 64 | 65 | 示例也演示了如何选择计算设备:这意味着既可以在 **AX650/AX630C** 等开发板上运行,也可以在 AX650 M.2 算力卡上运行。 66 | 67 | 切换计算设备的方式是通过 `-p` 参数指定,如 `-p AxEngineExecutionProvider` 表示使用开发板上的 NPU 进行推理,而 `-p AXCLRTExecutionProvider` 表示使用 M.2 算力卡进行推理。 68 | 注意:在使用 M.2 算力卡进行推理时,需要将算力卡插入宿主机上,并且已经安装驱动,详见: [axcl](https://axcl-docs.readthedocs.io/zh-cn/latest/)。 69 | 70 | ```bash 71 | root@ax650:~/samples# python3 classification.py -m /opt/data/npu/models/mobilenetv2.axmodel -i /opt/data/npu/images/cat.jpg -p AXCLRTExecutionProvider 72 | [INFO] Available providers: 
['AXCLRTExecutionProvider', 'AxEngineExecutionProvider'] 73 | [INFO] Using provider: AXCLRTExecutionProvider 74 | [INFO] SOC Name: AX650N 75 | [INFO] VNPU type: VNPUType.DISABLED 76 | [INFO] Compiler version: 1.2-patch2 7e6b2b5f 77 | ------------------------------------------------------ 78 | Top 5 Predictions: 79 | Class Index: 282, Score: 9.774 80 | Class Index: 278, Score: 8.981 81 | Class Index: 277, Score: 8.453 82 | Class Index: 281, Score: 8.321 83 | Class Index: 287, Score: 7.924 84 | ------------------------------------------------------ 85 | min = 1.587 ms max = 12.624 ms avg = 1.718 ms 86 | ------------------------------------------------------ 87 | root@ax650:~/samples# python3 classification.py -m /opt/data/npu/models/mobilenetv2.axmodel -i /opt/data/npu/images/cat.jpg -p AxEngineExecutionProvider 88 | [INFO] Available providers: ['AXCLRTExecutionProvider', 'AxEngineExecutionProvider'] 89 | [INFO] Using provider: AxEngineExecutionProvider 90 | [INFO] Chip type: ChipType.MC50 91 | [INFO] VNPU type: VNPUType.DISABLED 92 | [INFO] Engine version: 2.10.1s 93 | [INFO] Model type: 0 (single core) 94 | [INFO] Compiler version: 1.2-patch2 7e6b2b5f 95 | ------------------------------------------------------ 96 | Top 5 Predictions: 97 | Class Index: 282, Score: 9.774 98 | Class Index: 278, Score: 8.981 99 | Class Index: 277, Score: 8.453 100 | Class Index: 281, Score: 8.321 101 | Class Index: 287, Score: 7.924 102 | ------------------------------------------------------ 103 | min = 0.897 ms max = 22.542 ms avg = 1.125 ms 104 | ------------------------------------------------------ 105 | ``` 106 | 107 | ## 社区贡献者 108 | 109 | - [zylo117](https://github.com/zylo117): 提供了基于 cffi 的 AXCL Runtime Python API 实现 110 | - [nnn](https://github.com/nnn112358),[HongJie Li](https://github.com/techshoww) 和 [Shinichi Tanaka](https://github.com/s1tnk) 报告 cffi 的使用问题,[Shinichi Tanaka](https://github.com/s1tnk) 提供了解决方案 111 | 112 | 113 | ## 关联项目 114 | 115 | - 
[ax-samples](https://github.com/AXERA-TECH/ax-samples) 116 | - [ax-llm](https://github.com/AXERA-TECH/ax-llm) 117 | - [Pulsar2](https://pulsar2-docs.readthedocs.io/zh-cn/latest/) 118 | - [AXCL](https://axcl-docs.readthedocs.io/zh-cn/latest/) 119 | - [pyAXCL](https://github.com/AXERA-TECH/pyaxcl) 120 | 121 | ## 技术讨论 122 | 123 | - Github issues 124 | - QQ 群: 139953715 125 | -------------------------------------------------------------------------------- /axengine/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved. 2 | # 3 | # This source file is the property of Axera Semiconductor Co., Ltd. and 4 | # may not be copied or distributed in any isomorphic form without the prior 5 | # written consent of Axera Semiconductor Co., Ltd. 6 | # 7 | 8 | # thanks to community contributors list below: 9 | # zylo117: https://github.com/zylo117, first implementation of the axclrt backend 10 | 11 | from ._providers import axengine_provider_name, axclrt_provider_name 12 | from ._providers import get_all_providers, get_available_providers 13 | 14 | # check if axclrt is installed, or is a supported chip(e.g. AX650, AX620E etc.) 15 | _available_providers = get_available_providers() 16 | if not _available_providers: 17 | raise ImportError( 18 | f"No providers found. Please make sure you have installed one of the following: {get_all_providers()}") 19 | print("[INFO] Available providers: ", _available_providers) 20 | 21 | from ._node import NodeArg 22 | from ._session import SessionOptions, InferenceSession 23 | -------------------------------------------------------------------------------- /axengine/_axclrt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved. 2 | # 3 | # This source file is the property of Axera Semiconductor Co., Ltd. 
and 4 | # may not be copied or distributed in any isomorphic form without the prior 5 | # written consent of Axera Semiconductor Co., Ltd. 6 | # 7 | # first implementation of AXCLRTSession contributed by zylo117 8 | 9 | import atexit 10 | import os 11 | import time 12 | from typing import Any, Sequence 13 | 14 | import ml_dtypes as mldt 15 | import numpy as np 16 | 17 | from ._axclrt_capi import axclrt_cffi, axclrt_lib 18 | from ._axclrt_types import VNPUType, ModelType 19 | from ._base_session import Session, SessionOptions 20 | from ._node import NodeArg 21 | 22 | __all__: ["AXCLRTSession"] 23 | 24 | _is_axclrt_initialized = False 25 | _is_axclrt_engine_initialized = False 26 | _all_model_instances = [] 27 | 28 | 29 | def _transform_dtype(dtype): 30 | if dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_UINT8): 31 | return np.dtype(np.uint8) 32 | elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_INT8): 33 | return np.dtype(np.int8) 34 | elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_UINT16): 35 | return np.dtype(np.uint16) 36 | elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_INT16): 37 | return np.dtype(np.int16) 38 | elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_UINT32): 39 | return np.dtype(np.uint32) 40 | elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_INT32): 41 | return np.dtype(np.int32) 42 | elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_FP32): 43 | return np.dtype(np.float32) 44 | elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_BF16): 45 | return np.dtype(mldt.bfloat16) 46 | else: 47 | raise ValueError(f"Unsupported data type '{dtype}'.") 48 | 49 | def _initialize_axclrt(): 50 | global _is_axclrt_initialized 51 | ret = axclrt_lib.axclInit([]) 52 | if ret != 0: 53 | raise RuntimeError(f"Failed to 
initialize axcl runtime. {ret}.") 54 | _is_axclrt_initialized = True 55 | 56 | 57 | def _finalize_axclrt(): 58 | global _is_axclrt_initialized, _is_axclrt_engine_initialized 59 | for model_instance in _all_model_instances: 60 | model_instance._unload() 61 | if _is_axclrt_engine_initialized: 62 | axclrt_lib.axclrtEngineFinalize() 63 | _is_axclrt_engine_initialized = False 64 | if _is_axclrt_initialized: 65 | axclrt_lib.axclFinalize() 66 | _is_axclrt_initialized = False 67 | 68 | 69 | _initialize_axclrt() 70 | atexit.register(_finalize_axclrt) 71 | 72 | 73 | def _get_vnpu_type() -> VNPUType: 74 | vnpu_type = axclrt_cffi.new("axclrtEngineVNpuKind *") 75 | ret = axclrt_lib.axclrtEngineGetVNpuKind(vnpu_type) 76 | if ret != 0: 77 | raise RuntimeError("Failed to get VNPU attribute.") 78 | return VNPUType(vnpu_type[0]) 79 | 80 | 81 | def _get_version(): 82 | major, minor, patch = axclrt_cffi.new('int32_t *'), axclrt_cffi.new('int32_t *'), axclrt_cffi.new( 83 | 'int32_t *') 84 | axclrt_lib.axclrtGetVersion(major, minor, patch) 85 | return f'{major[0]}.{minor[0]}.{patch[0]}' 86 | 87 | 88 | class AXCLRTSession(Session): 89 | def __init__( 90 | self, 91 | path_or_bytes: str | bytes | os.PathLike, 92 | sess_options: SessionOptions | None = None, 93 | provider_options: dict[Any, Any] | None = None, 94 | **kwargs, 95 | ) -> None: 96 | super().__init__() 97 | 98 | self._device_index = 0 99 | self._io = None 100 | self._model_id = None 101 | 102 | if provider_options is not None and "device_id" in provider_options[0]: 103 | self._device_index = provider_options[0].get("device_id", 0) 104 | 105 | lst = axclrt_cffi.new("axclrtDeviceList *") 106 | ret = axclrt_lib.axclrtGetDeviceList(lst) 107 | if ret != 0 or lst.num == 0: 108 | raise RuntimeError(f"Get AXCL device failed 0x{ret:08x}, find total {lst.num} device.") 109 | 110 | if self._device_index >= lst.num: 111 | raise RuntimeError(f"Device index {self._device_index} is out of range, total {lst.num} device.") 112 | 113 | 
self._device_id = lst.devices[self._device_index] 114 | ret = axclrt_lib.axclrtSetDevice(self._device_id) 115 | if ret != 0 or lst.num == 0: 116 | raise RuntimeError(f"Set AXCL device failed 0x{ret:08x}.") 117 | 118 | global _is_axclrt_engine_initialized 119 | vnpu_type = axclrt_cffi.cast( 120 | "axclrtEngineVNpuKind", VNPUType.DISABLED.value 121 | ) 122 | # try to initialize NPU as disabled 123 | ret = axclrt_lib.axclrtEngineInit(vnpu_type) 124 | # if failed, try to get vnpu type 125 | if 0 != ret: 126 | vnpu = axclrt_cffi.new("axclrtEngineVNpuKind *") 127 | ret = axclrt_lib.axclrtEngineGetVNpuKind(vnpu) 128 | # if failed, that means the NPU is not available 129 | if ret != 0: 130 | raise RuntimeError(f"axclrtEngineInit as {vnpu.value} failed 0x{ret:08x}.") 131 | # if success, that means the NPU is already initialized as vnpu.value 132 | # so the initialization is failed. 133 | # this means the other users maybe uninitialized the NPU suddenly 134 | # and the app would be terminated unexpectedly at that moment. 135 | # but we can't do anything to fix this issue, just print a warning message. 136 | # it because the api looks like onnxruntime, so there no window avoid this. 137 | # such as the life. 
138 | else: 139 | print(f"[WARNING] Failed to initialize NPU as {vnpu_type}, NPU is already initialized as {vnpu.value}.") 140 | # initialize NPU successfully, mark the flag to ensure the engine will be finalized 141 | else: 142 | _is_axclrt_engine_initialized = True 143 | 144 | self.soc_name = axclrt_cffi.string(axclrt_lib.axclrtGetSocName()).decode() 145 | print(f"[INFO] SOC Name: {self.soc_name}") 146 | 147 | self._thread_context = axclrt_cffi.new("axclrtContext *") 148 | ret = axclrt_lib.axclrtGetCurrentContext(self._thread_context) 149 | if ret != 0: 150 | raise RuntimeError("axclrtGetCurrentContext failed") 151 | 152 | # model handle, context, info, io 153 | self._model_id = axclrt_cffi.new("uint64_t *") 154 | self._context_id = axclrt_cffi.new("uint64_t *") 155 | 156 | # get vnpu type 157 | self._vnpu_type = _get_vnpu_type() 158 | print(f"[INFO] VNPU type: {self._vnpu_type}") 159 | 160 | # load model 161 | ret = self._load(path_or_bytes) 162 | if 0 != ret: 163 | raise RuntimeError("Failed to load model.") 164 | print(f"[INFO] Compiler version: {self._get_model_tool_version()}") 165 | 166 | # get model info 167 | self._info = self._get_info() 168 | self._shape_count = self._get_shape_count() 169 | self._inputs = self._get_inputs() 170 | self._outputs = self._get_outputs() 171 | 172 | # prepare io 173 | self._io = self._prepare_io() 174 | 175 | _all_model_instances.append(self) 176 | 177 | def __del__(self): 178 | self._unload() 179 | _all_model_instances.remove(self) 180 | 181 | def _load(self, path_or_bytes): 182 | # model buffer, almost copied from onnx runtime 183 | if isinstance(path_or_bytes, (str, os.PathLike)): 184 | _model_path = axclrt_cffi.new("char[]", path_or_bytes.encode('utf-8')) 185 | ret = axclrt_lib.axclrtEngineLoadFromFile(_model_path, self._model_id) 186 | if ret != 0: 187 | raise RuntimeError("axclrtEngineLoadFromFile failed.") 188 | elif isinstance(path_or_bytes, bytes): 189 | _model_buffer = axclrt_cffi.new("char[]", path_or_bytes) 190 | 
_model_buffer_size = len(path_or_bytes) 191 | 192 | dev_mem_ptr = axclrt_cffi.new('void **', axclrt_cffi.NULL) 193 | ret = axclrt_lib.axclrtMalloc(dev_mem_ptr, _model_buffer_size, axclrt_lib.AXCL_MEM_MALLOC_NORMAL_ONLY) 194 | if ret != 0: 195 | raise RuntimeError("axclrtMalloc failed.") 196 | 197 | ret = axclrt_lib.axclrtMemcpy(dev_mem_ptr[0], _model_buffer, _model_buffer_size, axclrt_lib.AXCL_MEMCPY_HOST_TO_DEVICE) 198 | if ret != 0: 199 | axclrt_lib.axclrtFree(dev_mem_ptr[0]) 200 | raise RuntimeError("axclrtMemcpy failed.") 201 | 202 | ret = axclrt_lib.axclrtEngineLoadFromMem(dev_mem_ptr[0], _model_buffer_size, self._model_id) 203 | axclrt_lib.axclrtFree(dev_mem_ptr[0]) 204 | if ret != 0: 205 | raise RuntimeError("axclrtEngineLoadFromMem failed.") 206 | else: 207 | raise TypeError(f"Unable to load model from type '{type(path_or_bytes)}'") 208 | 209 | ret = axclrt_lib.axclrtEngineCreateContext(self._model_id[0], self._context_id) 210 | if ret != 0: 211 | raise RuntimeError("axclrtEngineCreateContext failed") 212 | return ret 213 | 214 | def _unload(self): 215 | if self._io is not None: 216 | dev_size = axclrt_cffi.new("uint64_t *") 217 | dev_prt = axclrt_cffi.new("void **") 218 | for i in range(axclrt_lib.axclrtEngineGetNumInputs(self._info[0])): 219 | axclrt_lib.axclrtEngineGetInputBufferByIndex(self._io[0], i, dev_prt, dev_size) 220 | axclrt_lib.axclrtFree(dev_prt[0]) 221 | for i in range(axclrt_lib.axclrtEngineGetNumOutputs(self._info[0])): 222 | axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io[0], i, dev_prt, dev_size) 223 | axclrt_lib.axclrtFree(dev_prt[0]) 224 | axclrt_lib.axclrtEngineDestroyIO(self._io[0]) 225 | self._io = None 226 | if self._model_id[0] is not None and self._model_id[0] != 0: 227 | axclrt_lib.axclrtEngineUnload(self._model_id[0]) 228 | self._model_id[0] = 0 229 | 230 | def _get_model_tool_version(self): 231 | model_tool_version = axclrt_lib.axclrtEngineGetModelCompilerVersion(self._model_id[0]) 232 | return 
axclrt_cffi.string(model_tool_version).decode() 233 | 234 | def _get_info(self): 235 | io_info = axclrt_cffi.new("axclrtEngineIOInfo *") 236 | ret = axclrt_lib.axclrtEngineGetIOInfo(self._model_id[0], io_info) 237 | if ret != 0: 238 | raise RuntimeError("axclrtEngineGetIOInfo failed.") 239 | return io_info 240 | 241 | def _get_shape_count(self): 242 | count = axclrt_cffi.new("int32_t *") 243 | ret = axclrt_lib.axclrtEngineGetShapeGroupsCount(self._info[0], count) 244 | if ret != 0: 245 | axclrt_lib.axclrtEngineUnload(self._model_id[0]) 246 | raise RuntimeError("axclrtEngineGetShapeGroupsCount failed.") 247 | return count[0] 248 | 249 | def _get_inputs(self): 250 | inputs = [] 251 | for group in range(self._shape_count): 252 | one_group_io = [] 253 | for index in range(axclrt_lib.axclrtEngineGetNumInputs(self._info[0])): 254 | cffi_name = axclrt_lib.axclrtEngineGetInputNameByIndex(self._info[0], index) 255 | name = axclrt_cffi.string(cffi_name).decode("utf-8") 256 | 257 | cffi_dtype = axclrt_cffi.new("axclrtEngineDataType *") 258 | ret = axclrt_lib.axclrtEngineGetInputDataType(self._info[0], index, cffi_dtype) 259 | if ret != 0: 260 | raise RuntimeError("axclrtEngineGetInputDataType failed.") 261 | dtype = _transform_dtype(cffi_dtype[0]) 262 | 263 | cffi_dims = axclrt_cffi.new("axclrtEngineIODims *") 264 | ret = axclrt_lib.axclrtEngineGetInputDims(self._info[0], group, index, cffi_dims) 265 | if ret != 0: 266 | raise RuntimeError("axclrtEngineGetInputDims failed.") 267 | shape = [cffi_dims.dims[i] for i in range(cffi_dims.dimCount)] 268 | 269 | meta = NodeArg(name, dtype, shape) 270 | one_group_io.append(meta) 271 | inputs.append(one_group_io) 272 | return inputs 273 | 274 | def _get_outputs(self): 275 | outputs = [] 276 | for group in range(self._shape_count): 277 | one_group_io = [] 278 | for index in range(axclrt_lib.axclrtEngineGetNumOutputs(self._info[0])): 279 | cffi_name = axclrt_lib.axclrtEngineGetOutputNameByIndex(self._info[0], index) 280 | name = 
axclrt_cffi.string(cffi_name).decode("utf-8") 281 | 282 | cffi_dtype = axclrt_cffi.new("axclrtEngineDataType *") 283 | ret = axclrt_lib.axclrtEngineGetOutputDataType(self._info[0], index, cffi_dtype) 284 | if ret != 0: 285 | raise RuntimeError("axclrtEngineGetOutputDataType failed.") 286 | dtype = _transform_dtype(cffi_dtype[0]) 287 | 288 | cffi_dims = axclrt_cffi.new("axclrtEngineIODims *") 289 | ret = axclrt_lib.axclrtEngineGetOutputDims(self._info[0], group, index, cffi_dims) 290 | if ret != 0: 291 | raise RuntimeError("axclrtEngineGetOutputDims failed.") 292 | shape = [cffi_dims.dims[i] for i in range(cffi_dims.dimCount)] 293 | 294 | meta = NodeArg(name, dtype, shape) 295 | one_group_io.append(meta) 296 | outputs.append(one_group_io) 297 | return outputs 298 | 299 | def _prepare_io(self): 300 | _io = axclrt_cffi.new("axclrtEngineIO *") 301 | ret = axclrt_lib.axclrtEngineCreateIO(self._info[0], _io) 302 | if ret != 0: 303 | raise RuntimeError(f"axclrtEngineCreateIO failed 0x{ret:08x}.") 304 | for i in range(axclrt_lib.axclrtEngineGetNumInputs(self._info[0])): 305 | max_size = 0 306 | for group in range(self._shape_count): 307 | size = axclrt_lib.axclrtEngineGetInputSizeByIndex(self._info[0], group, i) 308 | max_size = max(max_size, size) 309 | dev_ptr = axclrt_cffi.new("void **") 310 | ret = axclrt_lib.axclrtMalloc(dev_ptr, max_size, axclrt_lib.AXCL_MEM_MALLOC_NORMAL_ONLY) 311 | if 0 != ret or dev_ptr[0] == axclrt_cffi.NULL: 312 | raise RuntimeError(f"axclrtMalloc failed 0x{ret:08x} for input {i}.") 313 | ret = axclrt_lib.axclrtEngineSetInputBufferByIndex(_io[0], i, dev_ptr[0], max_size) 314 | if 0 != ret: 315 | raise RuntimeError(f"axclrtEngineSetInputBufferByIndex failed 0x{ret:08x} for input {i}.") 316 | for i in range(axclrt_lib.axclrtEngineGetNumOutputs(self._info[0])): 317 | max_size = 0 318 | for group in range(self._shape_count): 319 | size = axclrt_lib.axclrtEngineGetOutputSizeByIndex(self._info[0], group, i) 320 | max_size = max(max_size, size) 321 | 
dev_ptr = axclrt_cffi.new("void **") 322 | ret = axclrt_lib.axclrtMalloc(dev_ptr, max_size, axclrt_lib.AXCL_MEM_MALLOC_NORMAL_ONLY) 323 | if 0 != ret or dev_ptr[0] == axclrt_cffi.NULL: 324 | raise RuntimeError(f"axclrtMalloc failed 0x{ret:08x} for output {i}.") 325 | ret = axclrt_lib.axclrtEngineSetOutputBufferByIndex(_io[0], i, dev_ptr[0], max_size) 326 | if 0 != ret: 327 | raise RuntimeError(f"axclrtEngineSetOutputBufferByIndex failed 0x{ret:08x} for output {i}.") 328 | return _io 329 | 330 | def run( 331 | self, 332 | output_names: list[str], 333 | input_feed: dict[str, np.ndarray], 334 | run_options=None, 335 | shape_group: int = 0 336 | ): 337 | self._validate_input(input_feed) 338 | self._validate_output(output_names) 339 | 340 | ret = axclrt_lib.axclrtSetCurrentContext(self._thread_context[0]) 341 | if ret != 0: 342 | raise RuntimeError("axclrtSetCurrentContext failed") 343 | 344 | if None is output_names: 345 | output_names = [o.name for o in self.get_outputs(shape_group)] 346 | 347 | if (shape_group > self._shape_count - 1) or (shape_group < 0): 348 | raise ValueError(f"Invalid shape group: {shape_group}") 349 | 350 | # fill model io 351 | dev_prt = axclrt_cffi.new("void **") 352 | dev_size = axclrt_cffi.new("uint64_t *") 353 | for key, npy in input_feed.items(): 354 | for i, one in enumerate(self.get_inputs(shape_group)): 355 | if one.name == key: 356 | assert ( 357 | list(one.shape) == list(npy.shape) and one.dtype == npy.dtype 358 | ), f"model inputs({key}) expect shape {one.shape} and dtype {one.dtype}, howerver gets input with shape {npy.shape} and dtype {npy.dtype}" 359 | 360 | if not (npy.flags.c_contiguous or npy.flags.f_contiguous): 361 | npy = np.ascontiguousarray(npy) 362 | npy_ptr = axclrt_cffi.cast("void *", npy.ctypes.data) 363 | ret = axclrt_lib.axclrtEngineGetInputBufferByIndex(self._io[0], i, dev_prt, dev_size) 364 | if 0 != ret: 365 | raise RuntimeError(f"axclrtEngineGetInputBufferByIndex failed for input {i}.") 366 | ret = 
axclrt_lib.axclrtMemcpy(dev_prt[0], npy_ptr, npy.nbytes, axclrt_lib.AXCL_MEMCPY_HOST_TO_DEVICE) 367 | if 0 != ret: 368 | raise RuntimeError(f"axclrtMemcpy failed for input {i}.") 369 | 370 | # execute model 371 | ret = axclrt_lib.axclrtEngineExecute(self._model_id[0], self._context_id[0], shape_group, self._io[0]) 372 | 373 | # get output 374 | outputs = [] 375 | if 0 == ret: 376 | for i in range(len(self.get_outputs(shape_group))): 377 | ret = axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io[0], i, dev_prt, dev_size) 378 | if 0 != ret: 379 | raise RuntimeError(f"axclrtEngineGetOutputBufferByIndex failed for output {i}.") 380 | buffer_addr = dev_prt[0] 381 | npy_size = self.get_outputs(shape_group)[i].dtype.itemsize * np.prod(self.get_outputs(shape_group)[i].shape) 382 | npy = np.zeros(self.get_outputs(shape_group)[i].shape, dtype=self.get_outputs(shape_group)[i].dtype) 383 | npy_ptr = axclrt_cffi.cast("void *", npy.ctypes.data) 384 | ret = axclrt_lib.axclrtMemcpy(npy_ptr, buffer_addr, npy_size, axclrt_lib.AXCL_MEMCPY_DEVICE_TO_HOST) 385 | if 0 != ret: 386 | raise RuntimeError(f"axclrtMemcpy failed for output {i}.") 387 | name = self.get_outputs(shape_group)[i].name 388 | if name in output_names: 389 | outputs.append(npy) 390 | return outputs 391 | else: 392 | raise RuntimeError(f"axclrtEngineExecute failed 0x{ret:08x}") 393 | -------------------------------------------------------------------------------- /axengine/_axclrt_capi.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved. 2 | # 3 | # This source file is the property of Axera Semiconductor Co., Ltd. and 4 | # may not be copied or distributed in any isomorphic form without the prior 5 | # written consent of Axera Semiconductor Co., Ltd. 
#

import ctypes.util

from cffi import FFI

# BUG FIX: this was ``__all__: [...]`` — a bare annotation statement that
# never actually defines ``__all__``.  It must be an assignment.
__all__ = ["axclrt_cffi", "axclrt_lib"]

# Single FFI instance that carries every AXCL runtime declaration below.
axclrt_cffi = FFI()

# axcl_base.h
axclrt_cffi.cdef(
    """
#define AXCL_MAX_DEVICE_COUNT 256
typedef int32_t axclError;
typedef void *axclrtContext;
"""
)

# axcl_rt_type.h
axclrt_cffi.cdef(
    """
typedef struct axclrtDeviceList {
    uint32_t num;
    int32_t devices[AXCL_MAX_DEVICE_COUNT];
} axclrtDeviceList;

typedef enum axclrtMemMallocPolicy {
    AXCL_MEM_MALLOC_HUGE_FIRST,
    AXCL_MEM_MALLOC_HUGE_ONLY,
    AXCL_MEM_MALLOC_NORMAL_ONLY
} axclrtMemMallocPolicy;

typedef enum axclrtMemcpyKind {
    AXCL_MEMCPY_HOST_TO_HOST,
    AXCL_MEMCPY_HOST_TO_DEVICE,     //!< host vir -> device phy
    AXCL_MEMCPY_DEVICE_TO_HOST,     //!< host vir <- device phy
    AXCL_MEMCPY_DEVICE_TO_DEVICE,
    AXCL_MEMCPY_HOST_PHY_TO_DEVICE, //!< host phy -> device phy
    AXCL_MEMCPY_DEVICE_TO_HOST_PHY, //!< host phy <- device phy
} axclrtMemcpyKind;
"""
)

# axcl_rt_engine_type.h
axclrt_cffi.cdef(
    """
#define AXCLRT_ENGINE_MAX_DIM_CNT 32
typedef void* axclrtEngineIOInfo;
typedef void* axclrtEngineIO;

typedef enum axclrtEngineVNpuKind {
    AXCL_VNPU_DISABLE = 0,
    AXCL_VNPU_ENABLE = 1,
    AXCL_VNPU_BIG_LITTLE = 2,
    AXCL_VNPU_LITTLE_BIG = 3,
} axclrtEngineVNpuKind;

typedef enum axclrtEngineDataType {
    AXCL_DATA_TYPE_NONE = 0,
    AXCL_DATA_TYPE_INT4 = 1,
    AXCL_DATA_TYPE_UINT4 = 2,
    AXCL_DATA_TYPE_INT8 = 3,
    AXCL_DATA_TYPE_UINT8 = 4,
    AXCL_DATA_TYPE_INT16 = 5,
    AXCL_DATA_TYPE_UINT16 = 6,
    AXCL_DATA_TYPE_INT32 = 7,
    AXCL_DATA_TYPE_UINT32 = 8,
    AXCL_DATA_TYPE_INT64 = 9,
    AXCL_DATA_TYPE_UINT64 = 10,
    AXCL_DATA_TYPE_FP4 = 11,
    AXCL_DATA_TYPE_FP8 = 12,
    AXCL_DATA_TYPE_FP16 = 13,
    AXCL_DATA_TYPE_BF16 = 14,
    AXCL_DATA_TYPE_FP32 = 15,
    AXCL_DATA_TYPE_FP64 = 16,
} axclrtEngineDataType;

typedef enum axclrtEngineDataLayout {
    AXCL_DATA_LAYOUT_NONE = 0,
    AXCL_DATA_LAYOUT_NHWC = 0,
    AXCL_DATA_LAYOUT_NCHW = 1,
} axclrtEngineDataLayout;

typedef struct axclrtEngineIODims {
    int32_t dimCount;
    int32_t dims[AXCLRT_ENGINE_MAX_DIM_CNT];
} axclrtEngineIODims;
"""
)

# axcl.h
axclrt_cffi.cdef(
    """
axclError axclInit(const char *config);
axclError axclFinalize();
"""
)

# axcl_rt.h
axclrt_cffi.cdef(
    """
axclError axclrtGetVersion(int32_t *major, int32_t *minor, int32_t *patch);
const char *axclrtGetSocName();
"""
)

# axcl_rt_device.h
axclrt_cffi.cdef(
    """
axclError axclrtGetDeviceList(axclrtDeviceList *deviceList);
axclError axclrtSetDevice(int32_t deviceId);
axclError axclrtResetDevice(int32_t deviceId);
"""
)

# axcl_rt_context.h
axclrt_cffi.cdef(
    """
axclError axclrtCreateContext(axclrtContext *context, int32_t deviceId);
axclError axclrtDestroyContext(axclrtContext context);
axclError axclrtSetCurrentContext(axclrtContext context);
axclError axclrtGetCurrentContext(axclrtContext *context);
axclError axclrtGetDefaultContext(axclrtContext *context, int32_t deviceId);
"""
)

# axcl_rt_engine.h
axclrt_cffi.cdef(
    """
axclError axclrtEngineInit(axclrtEngineVNpuKind npuKind);
axclError axclrtEngineGetVNpuKind(axclrtEngineVNpuKind *npuKind);
axclError axclrtEngineFinalize();

axclError axclrtEngineLoadFromFile(const char *modelPath, uint64_t *modelId);
axclError axclrtEngineLoadFromMem(const void *model, uint64_t modelSize, uint64_t *modelId);
const char* axclrtEngineGetModelCompilerVersion(uint64_t modelId);
axclError axclrtEngineUnload(uint64_t modelId);

axclError axclrtEngineGetIOInfo(uint64_t modelId, axclrtEngineIOInfo *ioInfo);
axclError axclrtEngineGetShapeGroupsCount(axclrtEngineIOInfo ioInfo, int32_t *count);

uint32_t axclrtEngineGetNumInputs(axclrtEngineIOInfo ioInfo);
uint32_t axclrtEngineGetNumOutputs(axclrtEngineIOInfo ioInfo);

uint64_t axclrtEngineGetInputSizeByIndex(axclrtEngineIOInfo ioInfo, uint32_t group, uint32_t index);
uint64_t axclrtEngineGetOutputSizeByIndex(axclrtEngineIOInfo ioInfo, uint32_t group, uint32_t index);

axclError axclrtEngineGetInputDims(axclrtEngineIOInfo ioInfo, uint32_t group, uint32_t index, axclrtEngineIODims *dims);
axclError axclrtEngineGetOutputDims(axclrtEngineIOInfo ioInfo, uint32_t group, uint32_t index, axclrtEngineIODims *dims);

const char *axclrtEngineGetInputNameByIndex(axclrtEngineIOInfo ioInfo, uint32_t index);
const char *axclrtEngineGetOutputNameByIndex(axclrtEngineIOInfo ioInfo, uint32_t index);

int32_t axclrtEngineGetInputDataType(axclrtEngineIOInfo ioInfo, uint32_t index, axclrtEngineDataType *type);
int32_t axclrtEngineGetOutputDataType(axclrtEngineIOInfo ioInfo, uint32_t index, axclrtEngineDataType *type);

int32_t axclrtEngineGetInputDataLayout(axclrtEngineIOInfo ioInfo, uint32_t index, axclrtEngineDataLayout *layout);
int32_t axclrtEngineGetOutputDataLayout(axclrtEngineIOInfo ioInfo, uint32_t index, axclrtEngineDataLayout *layout);

axclError axclrtEngineCreateIO(axclrtEngineIOInfo ioInfo, axclrtEngineIO *io);
axclError axclrtEngineDestroyIO(axclrtEngineIO io);

axclError axclrtEngineSetInputBufferByIndex(axclrtEngineIO io, uint32_t index, const void *dataBuffer, uint64_t size);
axclError axclrtEngineSetOutputBufferByIndex(axclrtEngineIO io, uint32_t index, const void *dataBuffer, uint64_t size);
axclError axclrtEngineGetInputBufferByIndex(axclrtEngineIO io, uint32_t index, void **dataBuffer, uint64_t *size);
axclError axclrtEngineGetOutputBufferByIndex(axclrtEngineIO io, uint32_t index, void **dataBuffer, uint64_t *size);

axclError axclrtEngineCreateContext(uint64_t modelId, uint64_t *contextId);

axclError axclrtEngineExecute(uint64_t modelId, uint64_t contextId, uint32_t group, axclrtEngineIO io);
"""
)

# axcl_rt_memory.h
axclrt_cffi.cdef(
    """
axclError axclrtMalloc(void **devPtr, size_t size, axclrtMemMallocPolicy policy);
axclError axclrtMallocCached(void **devPtr, size_t size, axclrtMemMallocPolicy policy);
axclError axclrtMemcpy(void *dstPtr, const void *srcPtr, size_t count, axclrtMemcpyKind kind);
axclError axclrtFree(void *devPtr);
axclError axclrtMemFlush(void *devPtr, size_t size);
"""
)

# Locate and open the shared library; fail loudly at import time so that a
# missing runtime is reported before any session is constructed.
rt_name = "axcl_rt"
rt_path = ctypes.util.find_library(rt_name)
assert (
    rt_path is not None
), f"Failed to find library {rt_name}. Please ensure it is installed and in the library path."

axclrt_lib = axclrt_cffi.dlopen(rt_path)
assert axclrt_lib is not None, f"Failed to load library {rt_path}. Please ensure it is installed and in the library path."

# ---------------------------------------------------------------------------
# axengine/_axclrt_types.py
# ---------------------------------------------------------------------------
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
#
# This source file is the property of Axera Semiconductor Co., Ltd. and
# may not be copied or distributed in any isomorphic form without the prior
# written consent of Axera Semiconductor Co., Ltd.
#

from enum import Enum


class VNPUType(Enum):
    """Virtual-NPU partition modes supported by the AXCL runtime."""

    DISABLED = 0
    ENABLED = 1
    BIG_LITTLE = 2
    LITTLE_BIG = 3


class ModelType(Enum):
    """How many NPU cores a compiled model occupies."""

    SINGLE = 0
    DUAL = 1
    TRIPLE = 2


# ---------------------------------------------------------------------------
# axengine/_axe.py
# ---------------------------------------------------------------------------
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
#
# This source file is the property of Axera Semiconductor Co., Ltd. and
# may not be copied or distributed in any isomorphic form without the prior
# written consent of Axera Semiconductor Co., Ltd.
#

import atexit
import os
from typing import Any, Sequence

import ml_dtypes as mldt
import numpy as np

from ._axe_capi import sys_lib, engine_cffi, engine_lib
from ._axe_types import VNPUType, ModelType, ChipType
from ._base_session import Session, SessionOptions
from ._node import NodeArg

# BUG FIX: was ``__all__: [...]`` — a bare annotation that never defines
# ``__all__``.  It must be an assignment.
__all__ = ["AXEngineSession"]

_is_sys_initialized = False
_is_engine_initialized = False


def _transform_dtype(dtype):
    """Map an AX_ENGINE_DATA_TYPE_T value to the matching numpy dtype.

    Raises ValueError for packed/unsupported engine types.
    """
    if dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_UINT8):
        return np.dtype(np.uint8)
    elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_SINT8):
        return np.dtype(np.int8)
    elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_UINT16):
        return np.dtype(np.uint16)
    elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_SINT16):
        return np.dtype(np.int16)
    elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_UINT32):
        return np.dtype(np.uint32)
    elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_SINT32):
        return np.dtype(np.int32)
    elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_FLOAT32):
        return np.dtype(np.float32)
    elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_BFLOAT16):
        return np.dtype(mldt.bfloat16)
    else:
        raise ValueError(f"Unsupported data type '{dtype}'.")


def _check_cffi_func_exists(lib, func_name):
    """Return True if *lib* exposes *func_name* (cffi raises AttributeError)."""
    try:
        getattr(lib, func_name)
        return True
    except AttributeError:
        return False


def _get_chip_type():
    # Chip generation is inferred from which symbols the installed
    # libax_engine exports; there is no direct query API.
    if not _check_cffi_func_exists(engine_lib, "AX_ENGINE_SetAffinity"):
        return ChipType.M57H
    elif not _check_cffi_func_exists(engine_lib, "AX_ENGINE_GetTotalOps"):
        return ChipType.MC50
    else:
        return ChipType.MC20E


def _get_version():
    """Return the engine library version string."""
    engine_version = engine_lib.AX_ENGINE_GetVersion()
    return engine_cffi.string(engine_version).decode("utf-8")


def _get_vnpu_type() -> VNPUType:
    """Query the hardware VNPU mode the engine was initialized with."""
    vnpu_type = engine_cffi.new("AX_ENGINE_NPU_ATTR_T *")
    ret = engine_lib.AX_ENGINE_GetVNPUAttr(vnpu_type)
    if 0 != ret:
        raise RuntimeError("Failed to get VNPU attribute.")
    return VNPUType(vnpu_type.eHardMode)


def _initialize_engine():
    """Initialize AX_SYS and the NPU engine once at import time."""
    global _is_sys_initialized, _is_engine_initialized

    ret = sys_lib.AX_SYS_Init()
    if ret != 0:
        raise RuntimeError("Failed to initialize ax sys.")
    _is_sys_initialized = True

    # disabled mode by default
    vnpu_type = engine_cffi.new("AX_ENGINE_NPU_ATTR_T *")
    ret = engine_lib.AX_ENGINE_GetVNPUAttr(vnpu_type)
    if 0 != ret:
        # this means the NPU was not initialized
        vnpu_type.eHardMode = engine_cffi.cast(
            "AX_ENGINE_NPU_MODE_T", VNPUType.DISABLED.value
        )
    ret = engine_lib.AX_ENGINE_Init(vnpu_type)
    if ret != 0:
        raise RuntimeError("Failed to initialize ax sys engine.")
    _is_engine_initialized = True

    print(f"[INFO] Chip type: {_get_chip_type()}")
    print(f"[INFO] VNPU type: {_get_vnpu_type()}")
    print(f"[INFO] Engine version: {_get_version()}")


def _finalize_engine():
    """Tear down the engine and AX_SYS in reverse initialization order."""
    global _is_sys_initialized, _is_engine_initialized

    if _is_engine_initialized:
        engine_lib.AX_ENGINE_Deinit()
    if _is_sys_initialized:
        sys_lib.AX_SYS_Deinit()


_initialize_engine()
atexit.register(_finalize_engine)


class AXEngineSession(Session):
    """On-device inference session backed by libax_engine (onnxruntime-like API)."""

    def __init__(
        self,
        path_or_bytes: str | bytes | os.PathLike,
        sess_options: SessionOptions | None = None,
        provider_options: dict[Any, Any] | None = None,
        **kwargs,
    ) -> None:
        super().__init__()

        self._chip_type = _get_chip_type()
        self._vnpu_type = _get_vnpu_type()

        # handle, context, info, io
        self._handle = engine_cffi.new("uint64_t **")
        self._context = engine_cffi.new("uint64_t **")
        self._io = engine_cffi.new("AX_ENGINE_IO_T *")

        # model buffer, almost copied from onnx runtime
        if isinstance(path_or_bytes, (str, os.PathLike)):
            self._model_name = os.path.splitext(os.path.basename(path_or_bytes))[0]
            with open(path_or_bytes, "rb") as f:
                data = f.read()
            self._model_buffer = engine_cffi.new("char[]", data)
            self._model_buffer_size = len(data)
        elif isinstance(path_or_bytes, bytes):
            # BUG FIX: the bytes path previously never set self._model_name,
            # so _load() crashed with AttributeError.  Use a stable default.
            self._model_name = "AXEngineSession"
            self._model_buffer = engine_cffi.new("char[]", path_or_bytes)
            self._model_buffer_size = len(path_or_bytes)
        else:
            raise TypeError(f"Unable to load model from type '{type(path_or_bytes)}'")

        # get model type
        self._model_type = self._get_model_type()
        if self._chip_type is ChipType.MC20E:
            if self._model_type is ModelType.FULL:
                print(f"[INFO] Model type: {self._model_type.value} (full core)")
            if self._model_type is ModelType.HALF:
                print(f"[INFO] Model type: {self._model_type.value} (half core)")
        if self._chip_type is ChipType.MC50:
            if self._model_type is ModelType.SINGLE:
                print(f"[INFO] Model type: {self._model_type.value} (single core)")
            if self._model_type is ModelType.DUAL:
                print(f"[INFO] Model type: {self._model_type.value} (dual core)")
            if self._model_type is ModelType.TRIPLE:
                print(f"[INFO] Model type: {self._model_type.value} (triple core)")
        if self._chip_type is ChipType.M57H:
            print(f"[INFO] Model type: {self._model_type.value} (single core)")

        # check model type
        if self._chip_type is ChipType.MC50:
            # all types (single or dual or triple) of model are allowed in vnpu mode disabled
            # only single core model is allowed in vnpu mode enabled
            # only triple core model is NOT allowed in vnpu mode big-little or little-big
            if self._vnpu_type is VNPUType.ENABLED:
                if self._model_type is not ModelType.SINGLE:
                    raise ValueError(
                        f"Model type '{self._model_type}' is not allowed when vnpu is inited as {self._vnpu_type}."
                    )
            if (
                self._vnpu_type is VNPUType.BIG_LITTLE
                or self._vnpu_type is VNPUType.LITTLE_BIG
            ):
                if self._model_type is ModelType.TRIPLE:
                    raise ValueError(
                        f"Model type '{self._model_type}' is not allowed when vnpu is inited as {self._vnpu_type}."
                    )
        if self._chip_type is ChipType.MC20E:
            # all types of full or half core model are allowed in vnpu mode disabled
            # only half core model is allowed in vnpu mode enabled
            if self._vnpu_type is VNPUType.ENABLED:
                if self._model_type is ModelType.FULL:
                    raise ValueError(
                        f"Model type '{self._model_type}' is not allowed when vnpu is inited as {self._vnpu_type}."
                    )
        # if self._chip_type is ChipType.M57H:
        #   there only one type of model will be compiled, so no need to check

        # load model
        ret = self._load()
        if 0 != ret:
            raise RuntimeError("Failed to load model.")
        print(f"[INFO] Compiler version: {self._get_model_tool_version()}")

        # get shape group count (older libraries lack the group API)
        try:
            self._shape_count = self._get_shape_count()
        except AttributeError as e:
            print(f"[WARNING] {e}")
            self._shape_count = 1

        # get model shape
        self._info = self._get_info()
        self._inputs = self._get_inputs()
        self._outputs = self._get_outputs()

        # fill model io
        self._align = 128
        self._cmm_token = engine_cffi.new("AX_S8[]", b"PyEngine")
        self._io[0].nInputSize = len(self.get_inputs())
        self._io[0].nOutputSize = len(self.get_outputs())
        _inputs = engine_cffi.new(
            "AX_ENGINE_IO_BUFFER_T[{}]".format(self._io[0].nInputSize)
        )
        _outputs = engine_cffi.new(
            "AX_ENGINE_IO_BUFFER_T[{}]".format(self._io[0].nOutputSize)
        )
        # keep the cffi arrays alive for the lifetime of the session
        self._io_buffers = (_inputs, _outputs)
        self._io[0].pInputs = _inputs
        self._io[0].pOutputs = _outputs

        # each buffer is sized for the largest shape group so one allocation
        # serves every group
        self._io_inputs_pool = []
        for i in range(len(self.get_inputs())):
            max_buf = 0
            for j in range(self._shape_count):
                max_buf = max(max_buf, self._info[j][0].pInputs[i].nSize)
            self._io[0].pInputs[i].nSize = max_buf
            phy = engine_cffi.new("AX_U64*")
            vir = engine_cffi.new("AX_VOID**")
            self._io_inputs_pool.append((phy, vir))
            ret = sys_lib.AX_SYS_MemAllocCached(
                phy, vir, self._io[0].pInputs[i].nSize, self._align, self._cmm_token
            )
            if 0 != ret:
                raise RuntimeError("Failed to allocate memory for input.")
            self._io[0].pInputs[i].phyAddr = phy[0]
            self._io[0].pInputs[i].pVirAddr = vir[0]

        self._io_outputs_pool = []
        for i in range(len(self.get_outputs())):
            max_buf = 0
            for j in range(self._shape_count):
                max_buf = max(max_buf, self._info[j][0].pOutputs[i].nSize)
            self._io[0].pOutputs[i].nSize = max_buf
            phy = engine_cffi.new("AX_U64*")
            vir = engine_cffi.new("AX_VOID**")
            self._io_outputs_pool.append((phy, vir))
            ret = sys_lib.AX_SYS_MemAllocCached(
                phy, vir, self._io[0].pOutputs[i].nSize, self._align, self._cmm_token
            )
            if 0 != ret:
                raise RuntimeError("Failed to allocate memory for output.")
            self._io[0].pOutputs[i].phyAddr = phy[0]
            self._io[0].pOutputs[i].pVirAddr = vir[0]

    def __del__(self):
        # BUG FIX: __init__ may have raised before self._handle existed, in
        # which case the original __del__ raised AttributeError.
        if getattr(self, "_handle", None) is not None:
            self._unload()

    def _get_model_type(self) -> ModelType:
        """Ask the engine what core-occupancy type the model buffer is."""
        model_type = engine_cffi.new("AX_ENGINE_MODEL_TYPE_T *")
        ret = engine_lib.AX_ENGINE_GetModelType(
            self._model_buffer, self._model_buffer_size, model_type
        )
        if 0 != ret:
            raise RuntimeError("Failed to get model type.")
        return ModelType(model_type[0])

    def _get_model_tool_version(self):
        """Return the compiler (pulsar) version the model was built with."""
        model_tool_version = engine_lib.AX_ENGINE_GetModelToolsVersion(
            self._handle[0]
        )
        return engine_cffi.string(model_tool_version).decode("utf-8")

    def _load(self):
        """Create the engine handle and one context; returns the C status."""
        extra = engine_cffi.new("AX_ENGINE_HANDLE_EXTRA_T *")
        extra_name = engine_cffi.new("char[]", self._model_name.encode("utf-8"))
        extra.pName = extra_name

        # for onnx runtime do not support one model multiple context running in multi-thread as far as I know, so
        # the engine handle and context will create only once
        ret = engine_lib.AX_ENGINE_CreateHandleV2(
            self._handle, self._model_buffer, self._model_buffer_size, extra
        )
        if 0 == ret:
            ret = engine_lib.AX_ENGINE_CreateContextV2(
                self._handle[0], self._context
            )
        return ret

    def _get_info(self):
        """Fetch the AX_ENGINE_IO_INFO_T for every shape group."""
        total_info = []
        if 1 == self._shape_count:
            info = engine_cffi.new("AX_ENGINE_IO_INFO_T **")
            ret = engine_lib.AX_ENGINE_GetIOInfo(self._handle[0], info)
            if 0 != ret:
                raise RuntimeError("Failed to get model shape.")
            total_info.append(info)
        else:
            for i in range(self._shape_count):
                info = engine_cffi.new("AX_ENGINE_IO_INFO_T **")
                ret = engine_lib.AX_ENGINE_GetGroupIOInfo(
                    self._handle[0], i, info
                )
                if 0 != ret:
                    raise RuntimeError(f"Failed to get model the {i}th shape.")
                total_info.append(info)
        return total_info

    def _get_shape_count(self):
        """Return the number of shape groups compiled into the model."""
        count = engine_cffi.new("AX_U32 *")
        ret = engine_lib.AX_ENGINE_GetGroupIOInfoCount(self._handle[0], count)
        if 0 != ret:
            raise RuntimeError("Failed to get model shape group.")
        return count[0]

    def _unload(self):
        """Destroy the engine handle and release CMM buffers (idempotent).

        BUG FIX: the original tested ``self._handle[0] is not None``, which is
        always true for a cdata pointer (it destroyed NULL handles), and it
        never freed the buffers allocated with AX_SYS_MemAllocCached (leak).
        """
        if self._handle[0]:  # cffi NULL pointers are falsy
            engine_lib.AX_ENGINE_DestroyHandle(self._handle[0])
            self._handle[0] = engine_cffi.NULL
        for phy, vir in getattr(self, "_io_inputs_pool", []):
            sys_lib.AX_SYS_MemFree(phy[0], vir[0])
        for phy, vir in getattr(self, "_io_outputs_pool", []):
            sys_lib.AX_SYS_MemFree(phy[0], vir[0])
        self._io_inputs_pool = []
        self._io_outputs_pool = []

    def _get_io(self, io_type: str):
        """Build per-group NodeArg lists; io_type is 'Input' or 'Output'."""
        io_info = []
        for group in range(self._shape_count):
            one_group_io = []
            for index in range(getattr(self._info[group][0], f'n{io_type}Size')):
                current_io = getattr(self._info[group][0], f'p{io_type}s')[index]
                name = engine_cffi.string(current_io.pName).decode("utf-8")
                shape = [current_io.pShape[i] for i in range(current_io.nShapeSize)]
                dtype = _transform_dtype(current_io.eDataType)
                meta = NodeArg(name, dtype, shape)
                one_group_io.append(meta)
            io_info.append(one_group_io)
        return io_info

    def _get_inputs(self):
        return self._get_io('Input')

    def _get_outputs(self):
        return self._get_io('Output')

    def run(
        self,
        output_names: list[str],
        input_feed: dict[str, np.ndarray],
        run_options=None,
        shape_group: int = 0
    ):
        """Execute the model and return the requested outputs as numpy arrays.

        output_names of None means "all outputs, in model order".
        """
        self._validate_input(input_feed)
        self._validate_output(output_names)

        # validate the group before it is used to look up any metadata
        # (the original only checked it after calling get_outputs with it)
        if (shape_group > self._shape_count - 1) or (shape_group < 0):
            raise ValueError(f"Invalid shape group: {shape_group}")

        # hoist the metadata lookups out of the loops below
        inputs_meta = self.get_inputs(shape_group)
        outputs_meta = self.get_outputs(shape_group)

        if None is output_names:
            output_names = [o.name for o in outputs_meta]

        # copy the feed into the pre-allocated CMM input buffers
        for key, npy in input_feed.items():
            for i, one in enumerate(inputs_meta):
                if one.name == key:
                    assert (
                        list(one.shape) == list(npy.shape) and one.dtype == npy.dtype
                    ), f"model inputs({key}) expect shape {one.shape} and dtype {one.dtype}, however gets input with shape {npy.shape} and dtype {npy.dtype}"

                    if not (npy.flags.c_contiguous or npy.flags.f_contiguous):
                        npy = np.ascontiguousarray(npy)
                    npy_ptr = engine_cffi.cast("void *", npy.ctypes.data)

                    engine_cffi.memmove(
                        self._io[0].pInputs[i].pVirAddr, npy_ptr, npy.nbytes
                    )
                    sys_lib.AX_SYS_MflushCache(
                        self._io[0].pInputs[i].phyAddr,
                        self._io[0].pInputs[i].pVirAddr,
                        self._io[0].pInputs[i].nSize,
                    )
                    break

        # execute model
        if self._shape_count > 1:
            ret = engine_lib.AX_ENGINE_RunGroupIOSync(
                self._handle[0], self._context[0], shape_group, self._io
            )
        else:
            ret = engine_lib.AX_ENGINE_RunSyncV2(
                self._handle[0], self._context[0], self._io
            )
        if 0 != ret:
            raise RuntimeError("Failed to run model.")

        # invalidate the CPU cache for every output, then copy the requested
        # ones out of CMM (outputs are returned in model order)
        outputs = []
        for i, meta in enumerate(outputs_meta):
            sys_lib.AX_SYS_MinvalidateCache(
                self._io[0].pOutputs[i].phyAddr,
                self._io[0].pOutputs[i].pVirAddr,
                self._io[0].pOutputs[i].nSize,
            )
            npy_size = meta.dtype.itemsize * np.prod(meta.shape)
            npy = np.frombuffer(
                engine_cffi.buffer(
                    self._io[0].pOutputs[i].pVirAddr, npy_size
                ),
                dtype=meta.dtype,
            ).reshape(meta.shape).copy()
            if meta.name in output_names:
                outputs.append(npy)
        return outputs


# ---------------------------------------------------------------------------
# axengine/_axe_capi.py
# ---------------------------------------------------------------------------
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
#
# This source file is the property of Axera Semiconductor Co., Ltd. and
# may not be copied or distributed in any isomorphic form without the prior
# written consent of Axera Semiconductor Co., Ltd.
#

import ctypes.util
import platform

from cffi import FFI

# BUG FIX: was ``__all__: [...]`` (no-op annotation); must be an assignment.
__all__ = ["sys_lib", "sys_cffi", "engine_lib", "engine_cffi"]

sys_cffi = FFI()

# ax_base_type.h
sys_cffi.cdef(
    """
typedef int AX_S32;
typedef unsigned int AX_U32;
typedef unsigned long long int AX_U64;
typedef signed char AX_S8;
typedef void AX_VOID;
"""
)

# ax_sys_api.h
sys_cffi.cdef(
    """
AX_S32 AX_SYS_Init(AX_VOID);
AX_S32 AX_SYS_Deinit(AX_VOID);
AX_S32 AX_SYS_MemAllocCached(AX_U64 *phyaddr, AX_VOID **pviraddr, AX_U32 size, AX_U32 align, const AX_S8 *token);
AX_S32 AX_SYS_MemFree(AX_U64 phyaddr, AX_VOID *pviraddr);
AX_S32 AX_SYS_MflushCache(AX_U64 phyaddr, AX_VOID *pviraddr, AX_U32 size);
AX_S32 AX_SYS_MinvalidateCache(AX_U64 phyaddr, AX_VOID *pviraddr, AX_U32 size);
"""
)

sys_name = "ax_sys"
sys_path = ctypes.util.find_library(sys_name)
assert (
    sys_path is not None
), f"Failed to find library {sys_name}. Please ensure it is installed and in the library path."

sys_lib = sys_cffi.dlopen(sys_path)
assert sys_lib is not None, f"Failed to load library {sys_path}. Please ensure it is installed and in the library path."

# FFI instance carrying every libax_engine declaration.
engine_cffi = FFI()

# Architecture-independent declarations, applied in order.  Each entry is a
# (header comment, C source) pair mirroring the original SDK headers.
_ENGINE_COMMON_CDEFS = (
    # ax_base_type.h
    """
typedef unsigned long long int AX_U64;
typedef unsigned int AX_U32;
typedef unsigned char AX_U8;
typedef int AX_S32;
typedef signed char AX_S8;
typedef char AX_CHAR;
typedef void AX_VOID;

typedef enum {
    AX_FALSE = 0,
    AX_TRUE = 1,
} AX_BOOL;
""",
    # ax_engine_type.h, base type
    """
typedef AX_U32 AX_ENGINE_NPU_SET_T;
""",
    # ax_engine_type.h, enum
    """
typedef enum _AX_ENGINE_TENSOR_LAYOUT_E
{
    AX_ENGINE_TENSOR_LAYOUT_UNKNOWN = 0,
    AX_ENGINE_TENSOR_LAYOUT_NHWC = 1,
    AX_ENGINE_TENSOR_LAYOUT_NCHW = 2,
} AX_ENGINE_TENSOR_LAYOUT_T;

typedef enum
{
    AX_ENGINE_MT_PHYSICAL = 0,
    AX_ENGINE_MT_VIRTUAL = 1,
    AX_ENGINE_MT_OCM = 2,
} AX_ENGINE_MEMORY_TYPE_T;

typedef enum
{
    AX_ENGINE_DT_UNKNOWN = 0,
    AX_ENGINE_DT_UINT8 = 1,
    AX_ENGINE_DT_UINT16 = 2,
    AX_ENGINE_DT_FLOAT32 = 3,
    AX_ENGINE_DT_SINT16 = 4,
    AX_ENGINE_DT_SINT8 = 5,
    AX_ENGINE_DT_SINT32 = 6,
    AX_ENGINE_DT_UINT32 = 7,
    AX_ENGINE_DT_FLOAT64 = 8,
    AX_ENGINE_DT_BFLOAT16 = 9,
    AX_ENGINE_DT_UINT10_PACKED = 100,
    AX_ENGINE_DT_UINT12_PACKED = 101,
    AX_ENGINE_DT_UINT14_PACKED = 102,
    AX_ENGINE_DT_UINT16_PACKED = 103,
} AX_ENGINE_DATA_TYPE_T;

typedef enum
{
    AX_ENGINE_CS_FEATUREMAP = 0,
    AX_ENGINE_CS_RAW8 = 12,
    AX_ENGINE_CS_RAW10 = 1,
    AX_ENGINE_CS_RAW12 = 2,
    AX_ENGINE_CS_RAW14 = 11,
    AX_ENGINE_CS_RAW16 = 3,
    AX_ENGINE_CS_NV12 = 4,
    AX_ENGINE_CS_NV21 = 5,
    AX_ENGINE_CS_RGB = 6,
    AX_ENGINE_CS_BGR = 7,
    AX_ENGINE_CS_RGBA = 8,
    AX_ENGINE_CS_GRAY = 9,
    AX_ENGINE_CS_YUV444 = 10,
} AX_ENGINE_COLOR_SPACE_T;
""",
    # ax_engine_type.h, architecturally agnostic struct
    """
typedef enum {
    AX_ENGINE_VIRTUAL_NPU_DISABLE = 0,
} AX_ENGINE_NPU_MODE_T;

typedef enum {
    AX_ENGINE_MODEL_TYPE0 = 0,
} AX_ENGINE_MODEL_TYPE_T;

typedef struct {
    AX_ENGINE_NPU_MODE_T eHardMode;
    AX_U32 reserve[8];
} AX_ENGINE_NPU_ATTR_T;

typedef struct _AX_ENGINE_IO_META_EX_T
{
    AX_ENGINE_COLOR_SPACE_T eColorSpace;
    AX_U64 u64Reserved[18];
} AX_ENGINE_IO_META_EX_T;

typedef struct {
    AX_ENGINE_NPU_SET_T nNpuSet;
    AX_S8* pName;
    AX_U32 reserve[8];
} AX_ENGINE_HANDLE_EXTRA_T;

typedef struct _AX_ENGINE_CMM_INFO_T
{
    AX_U32 nCMMSize;
} AX_ENGINE_CMM_INFO_T;

typedef struct _AX_ENGINE_IO_SETTING_T
{
    AX_U32 nWbtIndex;
    AX_U64 u64Reserved[7];
}AX_ENGINE_IO_SETTING_T;
""",
)

for _decl in _ENGINE_COMMON_CDEFS:
    engine_cffi.cdef(_decl)

# check architecture, 32bit or 64bit — the reserved-field padding of the IO
# structs differs between the two ABIs
arch = platform.architecture()[0]

# ax_engine_type.h, struct
if arch == "64bit":
    engine_cffi.cdef(
        """
typedef struct _AX_ENGINE_IO_META_T
{
    AX_CHAR* pName;
    AX_S32* pShape;
    AX_U8 nShapeSize;
    AX_ENGINE_TENSOR_LAYOUT_T eLayout;
    AX_ENGINE_MEMORY_TYPE_T eMemoryType;
    AX_ENGINE_DATA_TYPE_T eDataType;
    AX_ENGINE_IO_META_EX_T* pExtraMeta;
    AX_U32 nSize;
    AX_U32 nQuantizationValue;
    AX_S32* pStride;
    AX_U64 u64Reserved[9];
} AX_ENGINE_IO_META_T;

typedef struct _AX_ENGINE_IO_INFO_T
{
    AX_ENGINE_IO_META_T* pInputs;
    AX_U32 nInputSize;
    AX_ENGINE_IO_META_T* pOutputs;
    AX_U32 nOutputSize;
    AX_U32 nMaxBatchSize;
    AX_BOOL bDynamicBatchSize;
    AX_U64 u64Reserved[11];
} AX_ENGINE_IO_INFO_T;

typedef struct _AX_ENGINE_IO_BUFFER_T
{
    AX_U64 phyAddr;
    AX_VOID* pVirAddr;
    AX_U32 nSize;
    AX_S32* pStride;
    AX_U8 nStrideSize;
    AX_U64 u64Reserved[11];
} AX_ENGINE_IO_BUFFER_T;

typedef struct _AX_ENGINE_IO_T
{
    AX_ENGINE_IO_BUFFER_T* pInputs;
    AX_U32 nInputSize;
    AX_ENGINE_IO_BUFFER_T* pOutputs;
    AX_U32 nOutputSize;
    AX_U32 nBatchSize;
    AX_ENGINE_IO_SETTING_T* pIoSetting;
    AX_U64 u64Reserved[10];
} AX_ENGINE_IO_T;
"""
    )
else:
    engine_cffi.cdef(
        """
typedef struct _AX_ENGINE_IO_META_T
{
    AX_CHAR* pName;
    AX_S32* pShape;
    AX_U8 nShapeSize;
    AX_ENGINE_TENSOR_LAYOUT_T eLayout;
    AX_ENGINE_MEMORY_TYPE_T eMemoryType;
    AX_ENGINE_DATA_TYPE_T eDataType;
    AX_ENGINE_IO_META_EX_T* pExtraMeta;
    AX_U32 nSize;
    AX_U32 nQuantizationValue;
    AX_S32* pStride;
    AX_U64 u64Reserved[11];
} AX_ENGINE_IO_META_T;

typedef struct _AX_ENGINE_IO_INFO_T
{
    AX_ENGINE_IO_META_T* pInputs;
    AX_U32 nInputSize;
    AX_ENGINE_IO_META_T* pOutputs;
    AX_U32 nOutputSize;
    AX_U32 nMaxBatchSize;
    AX_BOOL bDynamicBatchSize;
    AX_U64 u64Reserved[13];
} AX_ENGINE_IO_INFO_T;

typedef struct _AX_ENGINE_IO_BUFFER_T
{
    AX_U64 phyAddr;
    AX_VOID* pVirAddr;
    AX_U32 nSize;
    AX_S32* pStride;
    AX_U8 nStrideSize;
    AX_U64 u64Reserved[13];
} AX_ENGINE_IO_BUFFER_T;

typedef struct _AX_ENGINE_IO_T
{
    AX_ENGINE_IO_BUFFER_T* pInputs;
    AX_U32 nInputSize;
    AX_ENGINE_IO_BUFFER_T* pOutputs;
    AX_U32 nOutputSize;
    AX_U32 nBatchSize;
    AX_ENGINE_IO_SETTING_T* pIoSetting;
    AX_U64 u64Reserved[12];
} AX_ENGINE_IO_T;
"""
    )

# ax_engine_api.h
engine_cffi.cdef(
    """
const AX_CHAR* AX_ENGINE_GetVersion(AX_VOID);

AX_VOID AX_ENGINE_NPUReset(AX_VOID);
AX_S32 AX_ENGINE_Init(AX_ENGINE_NPU_ATTR_T* pNpuAttr);
AX_S32 AX_ENGINE_GetVNPUAttr(AX_ENGINE_NPU_ATTR_T* pNpuAttr);
AX_S32 AX_ENGINE_Deinit(AX_VOID);

AX_S32 AX_ENGINE_GetModelType(const AX_VOID* pData, AX_U32 nDataSize, AX_ENGINE_MODEL_TYPE_T* pModelType);

AX_S32 AX_ENGINE_CreateHandleV2(uint64_t** pHandle, const AX_VOID* pData, AX_U32 nDataSize, AX_ENGINE_HANDLE_EXTRA_T* pExtraParam);
AX_S32 AX_ENGINE_DestroyHandle(uint64_t* nHandle);

AX_S32 AX_ENGINE_GetIOInfo(uint64_t* nHandle, AX_ENGINE_IO_INFO_T** pIO);
AX_S32 AX_ENGINE_GetGroupIOInfoCount(uint64_t* nHandle, AX_U32* pCount);
AX_S32 AX_ENGINE_GetGroupIOInfo(uint64_t* nHandle, AX_U32 nIndex, AX_ENGINE_IO_INFO_T** pIO);

AX_S32 AX_ENGINE_GetHandleModelType(uint64_t* nHandle, AX_ENGINE_MODEL_TYPE_T* pModelType);

AX_S32 AX_ENGINE_CreateContextV2(uint64_t* nHandle, uint64_t** pContext);

AX_S32 AX_ENGINE_RunSyncV2(uint64_t* handle, uint64_t* context, AX_ENGINE_IO_T* pIO);
AX_S32 AX_ENGINE_RunGroupIOSync(uint64_t* handle, uint64_t* context, AX_U32 nIndex, AX_ENGINE_IO_T* pIO);

AX_S32 AX_ENGINE_SetAffinity(uint64_t* nHandle, AX_ENGINE_NPU_SET_T nNpuSet);
AX_S32 AX_ENGINE_GetAffinity(uint64_t* nHandle, AX_ENGINE_NPU_SET_T* pNpuSet);

AX_S32 AX_ENGINE_GetCMMUsage(uint64_t* nHandle, AX_ENGINE_CMM_INFO_T* pCMMInfo);

const AX_CHAR* AX_ENGINE_GetModelToolsVersion(uint64_t* nHandle);

// internal use api, remember no question
AX_S32 AX_ENGINE_GetTotalOps();
"""
)

# Locate and open libax_engine; fail at import time if it is missing.
engine_name = "ax_engine"
engine_path = ctypes.util.find_library(engine_name)
assert (
    engine_path is not None
), f"Failed to find library {engine_name}. Please ensure it is installed and in the library path."

engine_lib = engine_cffi.dlopen(engine_path)
assert engine_lib is not None, f"Failed to load library {engine_path}. Please ensure it is installed and in the library path."

# ---------------------------------------------------------------------------
# axengine/_axe_types.py
# ---------------------------------------------------------------------------
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
#
# This source file is the property of Axera Semiconductor Co., Ltd. and
# may not be copied or distributed in any isomorphic form without the prior
# written consent of Axera Semiconductor Co., Ltd.
#

from enum import Enum


class VNPUType(Enum):
    """Virtual-NPU partition modes of the on-device engine."""

    DISABLED = 0
    ENABLED = 1
    BIG_LITTLE = 2
    LITTLE_BIG = 3


class ModelType(Enum):
    """Core-occupancy type of a compiled model.

    NOTE: values are deliberately duplicated, so SINGLE is an *alias* of HALF
    and DUAL an alias of FULL (standard Enum aliasing).  _axe.py relies on
    this and disambiguates by chip type; do not renumber.
    """

    HALF = 0  # for MC20E, which means chip is AX630C(x), or AX620Q(x)
    FULL = 1  # for MC20E
    SINGLE = 0  # for MC50, which means chip is AX650A or AX650N, and M57H
    DUAL = 1  # for MC50
    TRIPLE = 2  # for MC50


class ChipType(Enum):
    """Chip generations the engine backend can run on."""

    MC20E = 0
    MC50 = 1
    M57H = 2


# ---------------------------------------------------------------------------
# axengine/_base_session.py
# ---------------------------------------------------------------------------
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
#
# This source file is the property of Axera Semiconductor Co., Ltd. and
# may not be copied or distributed in any isomorphic form without the prior
# written consent of Axera Semiconductor Co., Ltd.
#

from abc import ABC, abstractmethod

import numpy as np

from ._node import NodeArg


class SessionOptions:
    """Placeholder for onnxruntime-style session options (currently unused)."""
    pass


class Session(ABC):
    """Common base for inference sessions; subclasses fill _shape_count,
    _inputs and _outputs (one NodeArg list per shape group)."""

    def __init__(self) -> None:
        self._shape_count = 0
        self._inputs = []
        self._outputs = []

    def _validate_input(self, feed_input_names: dict[str, np.ndarray]):
        """Raise ValueError if any required model input is absent from the feed."""
        missing_input_names = [
            i.name for i in self.get_inputs() if i.name not in feed_input_names
        ]
        if missing_input_names:
            raise ValueError(
                f"Required inputs ({missing_input_names}) are missing from input feed ({feed_input_names}).")

    def _validate_output(self, output_names: list[str]):
        """Raise ValueError for any requested name the model does not produce."""
        if output_names is not None:
            # build the set of known names once instead of rebuilding the
            # list for every requested name (was O(n*m))
            known_names = {o.name for o in self.get_outputs()}
            for name in output_names:
                if name not in known_names:
                    raise ValueError(f"Output name '{name}' is not in model outputs name list.")

    def get_inputs(self, shape_group: int = 0) -> list[NodeArg]:
        """Return the input metadata for one shape group.

        BUG FIX: valid groups are 0 .. _shape_count - 1.  The original used
        ``>``, letting shape_group == _shape_count through to an IndexError,
        and negative groups silently wrapped around via list indexing.
        """
        if shape_group >= self._shape_count or shape_group < 0:
            raise ValueError(f"Shape group '{shape_group}' is out of range, total {self._shape_count}.")
        selected_info = self._inputs[shape_group]
        return selected_info

    def get_outputs(self, shape_group: int = 0) -> list[NodeArg]:
        """Return the output metadata for one shape group (same bounds as get_inputs)."""
        if shape_group >= self._shape_count or shape_group < 0:
            raise ValueError(f"Shape group '{shape_group}' is out of range, total {self._shape_count}.")
        selected_info = self._outputs[shape_group]
        return selected_info

    @abstractmethod
    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, np.ndarray],
        run_options=None
    ) -> list[np.ndarray]:
        """Execute the model; None output_names means all outputs."""
        pass


# ---------------------------------------------------------------------------
# axengine/_node.py
# ---------------------------------------------------------------------------
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
# (axengine/_node.py, continued)
# This source file is the property of Axera Semiconductor Co., Ltd. and
# may not be copied or distributed in any isomorphic form without the prior
# written consent of Axera Semiconductor Co., Ltd.
#


class NodeArg(object):
    """Describes one model input or output tensor."""

    def __init__(self, name, dtype, shape):
        self.name = name    # tensor name as reported by the model
        self.dtype = dtype  # element data type
        self.shape = shape  # sequence of dimensions

    def __repr__(self):
        # Added for debuggability; attribute access is unchanged.
        return f"NodeArg(name={self.name!r}, dtype={self.dtype!r}, shape={self.shape!r})"


# (axengine/_providers.py)
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
#
# This source file is the property of Axera Semiconductor Co., Ltd. and
# may not be copied or distributed in any isomorphic form without the prior
# written consent of Axera Semiconductor Co., Ltd.
#

import ctypes.util as cutil

providers = []
axengine_provider_name = 'AxEngineExecutionProvider'
axclrt_provider_name = 'AXCLRTExecutionProvider'

_axengine_lib_name = 'ax_engine'
_axclrt_lib_name = 'axcl_rt'

# check if axcl_rt is installed, so if available, it's the default provider
# (probed first, so it sorts first in `providers`)
if cutil.find_library(_axclrt_lib_name) is not None:
    providers.append(axclrt_provider_name)

# check if ax_engine is installed
if cutil.find_library(_axengine_lib_name) is not None:
    providers.append(axengine_provider_name)


def get_all_providers():
    """Return every provider this package knows about, available or not."""
    return [axengine_provider_name, axclrt_provider_name]


def get_available_providers():
    """Return the providers whose native library was found at import time."""
    return providers


# (axengine/_session.py)
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
#
# This source file is the property of Axera Semiconductor Co., Ltd. and
# may not be copied or distributed in any isomorphic form without the prior
# written consent of Axera Semiconductor Co., Ltd.
#

import os
from typing import Any, Sequence

import numpy as np

from ._base_session import SessionOptions
from ._node import NodeArg
from ._providers import axclrt_provider_name, axengine_provider_name
from ._providers import get_available_providers


class InferenceSession:
    """ONNXRuntime-style facade that selects a provider and delegates to it.

    Provider selection: ``None`` picks the first available provider; a string
    selects that provider; a sequence of names or ``(name, options)`` tuples
    selects the first available entry.

    Note: invalid `providers` entries now raise TypeError instead of
    AssertionError — `assert` is stripped under ``python -O`` and must not be
    used for input validation.
    """

    def __init__(
            self,
            path_or_bytes: str | bytes | os.PathLike,
            sess_options: SessionOptions | None = None,
            providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
            provider_options: Sequence[dict[Any, Any]] | None = None, **kwargs,
    ) -> None:
        self._sess = None
        self._sess_options = sess_options
        self._provider = None
        self._provider_options = None
        self._available_providers = get_available_providers()

        # the providers should be available at least one, checked in __init__.py
        if providers is None:
            # using first available provider as default
            self._provider = self._available_providers[0]
        elif isinstance(providers, str):
            # if only one provider is specified
            if providers not in self._available_providers:
                raise ValueError(f"Selected provider: '{providers}' is not available.")
            self._provider = providers
        elif isinstance(providers, (list, tuple)):
            # fixed: accept any list/tuple, matching the `Sequence` signature
            # (previously only `list` was recognized); the first available
            # entry wins.
            self._resolve_provider_list(providers)

        # FIXME: can we remove this check?
        if self._provider is None:
            raise ValueError(f"No available provider found in {providers}.")
        print(f"[INFO] Using provider: {self._provider}")

        if self._provider == axclrt_provider_name:
            from ._axclrt import AXCLRTSession
            self._sess = AXCLRTSession(path_or_bytes, sess_options, provider_options, **kwargs)
        if self._provider == axengine_provider_name:
            from ._axe import AXEngineSession
            self._sess = AXEngineSession(path_or_bytes, sess_options, provider_options, **kwargs)
        if self._sess is None:
            raise RuntimeError(f"Create session failed with provider: {self._provider}")

    def _resolve_provider_list(self, providers) -> None:
        """Pick the first available provider from names / (name, options) tuples.

        Sets ``self._provider`` (and ``self._provider_options`` for a tuple
        entry); raises ValueError when every requested provider is
        unavailable, or warns when only some are.
        """
        _unavailable_provider = []
        for p in providers:
            if isinstance(p, str):
                if p not in self._available_providers:
                    _unavailable_provider.append(p)
                elif self._provider is None:
                    self._provider = p
            elif isinstance(p, tuple):
                if len(p) != 2:
                    raise TypeError(f"Invalid provider type: {p}. Must be tuple with 2 elements.")
                name, options = p
                if not isinstance(name, str):
                    raise TypeError(f"Invalid provider type: {type(name)}. Must be str.")
                if not isinstance(options, dict):
                    raise TypeError(f"Invalid provider type: {type(options)}. Must be dict.")
                if name not in self._available_providers:
                    _unavailable_provider.append(name)
                elif self._provider is None:
                    self._provider = name
                    # FIXME: check provider options
                    self._provider_options = options
            else:
                raise TypeError(f"Invalid provider type: {type(p)}. Must be str or tuple.")
        if _unavailable_provider:
            if self._provider is None:
                raise ValueError(f"Selected provider(s): {_unavailable_provider} is(are) not available.")
            else:
                print(f"[WARNING] Selected provider(s): {_unavailable_provider} is(are) not available.")

    # add to support 'with' statement
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # not suppress exceptions
        return False

    def get_session_options(self):
        """
        Return the session options. See :class:`axengine.SessionOptions`.
        """
        return self._sess_options

    def get_providers(self):
        """
        Return list of registered execution providers.

        NOTE(review): despite the wording above (kept for API-doc parity),
        this returns the single selected provider name, not a list — confirm
        before relying on the type.
        """
        return self._provider

    def get_inputs(self, shape_group: int = 0) -> list[NodeArg]:
        """Return the input NodeArgs of the selected shape group."""
        return self._sess.get_inputs(shape_group)

    def get_outputs(self, shape_group: int = 0) -> list[NodeArg]:
        """Return the output NodeArgs of the selected shape group."""
        return self._sess.get_outputs(shape_group)

    def run(
            self,
            output_names: list[str] | None,
            input_feed: dict[str, np.ndarray],
            run_options=None,
            shape_group: int = 0
    ) -> list[np.ndarray]:
        """Execute the model on the backend session and return its outputs."""
        return self._sess.run(output_names, input_feed, run_options, shape_group)


# (examples/classification.py)
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
#
# This source file is the property of Axera Semiconductor Co., Ltd. and
# may not be copied or distributed in any isomorphic form without the prior
# written consent of Axera Semiconductor Co., Ltd.
#

import argparse
import os
import re
import sys
import time

import numpy as np
from PIL import Image

import axengine as axe
from axengine import axclrt_provider_name, axengine_provider_name


def load_model(model_path: str | os.PathLike, selected_provider: str, selected_device_id: int = 0):
    """Create an InferenceSession; 'AUTO' lets the engine pick the first available provider."""
    if selected_provider == 'AUTO':
        # Use AUTO to let the pyengine choose the first available provider
        return axe.InferenceSession(model_path)

    providers = []
    if selected_provider == axclrt_provider_name:
        provider_options = {"device_id": selected_device_id}
        providers.append((axclrt_provider_name, provider_options))
    if selected_provider == axengine_provider_name:
        providers.append(axengine_provider_name)

    return axe.InferenceSession(model_path, providers=providers)


def preprocess_image(
        image_path: str | os.PathLike,
        # fixed: the annotations were the invalid form "(int, int)"
        middle_step_size: tuple[int, int] = (256, 256),
        final_step_size: tuple[int, int] = (224, 224)
) -> np.ndarray:
    """ImageNet-style preprocessing: center-square crop, resize, center crop.

    Returns a uint8 array of shape (1, H, W, C) — NHWC with a batch dim.
    """
    # Load the image
    img = Image.open(image_path).convert("RGB")

    # Get original dimensions
    original_width, original_height = img.size

    # Determine the shorter side and crop the largest centered square
    crop_area = min(original_width, original_height)
    crop_x = (original_width - crop_area) // 2
    crop_y = (original_height - crop_area) // 2

    # Crop the center square
    img = img.crop((crop_x, crop_y, crop_x + crop_area, crop_y + crop_area))

    # Resize the square to the intermediate size (e.g. 256x256)
    img = img.resize(middle_step_size)

    # Crop the center patch at the final network size (e.g. 224x224)
    crop_x = (middle_step_size[0] - final_step_size[0]) // 2
    crop_y = (middle_step_size[1] - final_step_size[1]) // 2
    img = img.crop((crop_x, crop_y, crop_x + final_step_size[0], crop_y + final_step_size[1]))

    # Convert to a uint8 array; layout stays HWC — the model expects NHWC,
    # so no transpose is performed (the old comment claiming NCHW was wrong).
    img_array = np.array(img).astype("uint8")
    img_array = np.expand_dims(img_array, axis=0)  # add batch dimension
    return img_array


def get_top_k_predictions(output: list[np.ndarray], k: int = 5):
    """Return the indices and scores of the k highest-scoring classes."""
    top_k_indices = np.argsort(output[0].flatten())[-k:][::-1]
    top_k_scores = output[0].flatten()[top_k_indices]
    return top_k_indices, top_k_scores


def main(model_path, image_path, middle_step_size, final_step_size, k, repeat_times, selected_provider,
         selected_device_id):
    """Run classification `repeat_times` times and report top-k and latency."""
    # Load the model
    session = load_model(model_path, selected_provider, selected_device_id)

    # Preprocess the image
    input_tensor = preprocess_image(image_path, middle_step_size, final_step_size)

    # Get input name and run inference (timing each iteration in ms)
    input_name = session.get_inputs()[0].name
    time_costs = []
    output = None
    for i in range(repeat_times):
        t1 = time.time()
        output = session.run(None, {input_name: input_tensor})
        t2 = time.time()
        time_costs.append((t2 - t1) * 1000)

    # Get top k predictions
    top_k_indices, top_k_scores = get_top_k_predictions(output, k)

    # Print the results
    print(" ------------------------------------------------------")
    print(f" Top {k} Predictions:")
    for i in range(k):
        print(f"   Class Index: {top_k_indices[i]:>3}, Score: {top_k_scores[i]:.3f}")

    print(" ------------------------------------------------------")
    print(
        f" min = {min(time_costs):.3f} ms  max = {max(time_costs):.3f} ms  avg = {sum(time_costs) / len(time_costs):.3f} ms"
    )
    print(" ------------------------------------------------------")


def parse_size(size_str):
    """argparse type: parse '"height,width"' into an (int, int) tuple."""
    pattern = r'^\s*\d+\s*,\s*\d+\s*$'
    if not re.match(pattern, size_str):
        # fixed grammar in the user-facing message ("should looks" -> "should look")
        raise argparse.ArgumentTypeError(R'params should look like: "height,width", such as: "256,256"')

    height, width = map(int, size_str.split(','))
    return height, width


class ExampleParser(argparse.ArgumentParser):
    """ArgumentParser that prints usage examples on error, then exits."""

    def error(self, message):
        self.print_usage(sys.stderr)
        print(f"\nError: {message}")
        print("\nExample usage:")
        print("  python3 classification.py -m <model_file> -i <image_file>")
        print("  python3 classification.py -m /opt/data/npu/models/mobilenetv2.axmodel -i /opt/data/npu/images/cat.jpg")
        print(
            f"  python3 classification.py -m /opt/data/npu/models/mobilenetv2.axmodel -i /opt/data/npu/images/cat.jpg -p {axengine_provider_name}")
        print(
            f"  python3 classification.py -m /opt/data/npu/models/mobilenetv2.axmodel -i /opt/data/npu/images/cat.jpg -p {axclrt_provider_name}")
        sys.exit(1)


if __name__ == "__main__":
    ap = ExampleParser()
    ap.add_argument('-m', '--model-path', type=str, help='model path', required=True)
    ap.add_argument('-i', '--image-path', type=str, help='image path', required=True)
    ap.add_argument(
        '-s',
        '--resize-size',
        type=parse_size,
        help=R'imagenet resize size: "height,width", such as: "256,256"',
        default='256,256',
    )
    ap.add_argument(
        '-c',
        '--crop-size',
        type=parse_size,
        help=R'imagenet crop size: "height,width", such as: "224,224"',
        default='224,224',
    )
    ap.add_argument(
        '-k',
        '--top-k',
        type=int,
        help='top k predictions',
        default=5
    )
    ap.add_argument('-r', '--repeat', type=int, help='repeat times', default=100)
    ap.add_argument(
        '-p',
        '--provider',
        type=str,
        choices=["AUTO", f"{axclrt_provider_name}", f"{axengine_provider_name}"],
        help=f'"AUTO", "{axclrt_provider_name}", "{axengine_provider_name}"',
        default='AUTO'
    )
    ap.add_argument(
        '-d',
        '--device-id',
        type=int,
        help=R'axclrt device index, depends on how many cards inserted',
        default=0
    )
    args = ap.parse_args()

    model_file = args.model_path
    image_file = args.image_path

    # check if the model and image exist
    assert os.path.exists(model_file), f"model file path {model_file} does not exist"
    assert os.path.exists(image_file), f"image file path {image_file} does not exist"

    resize_size = args.resize_size
    crop_size = args.crop_size
    top_k = args.top_k

    repeat = args.repeat

    provider = args.provider
    device_id = args.device_id

    main(model_file, image_file, resize_size, crop_size, top_k, repeat, provider, device_id)


# (examples/yolov5.py)
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
#
# This source file is the property of Axera Semiconductor Co., Ltd. and
# may not be copied or distributed in any isomorphic form without the prior
# written consent of Axera Semiconductor Co., Ltd.
#

import argparse
import colorsys
import os
import random
import sys
import time

import cv2
import numpy as np

import axengine as axe
from axengine import axclrt_provider_name, axengine_provider_name

# Post-processing configuration for a standard COCO-trained YOLOv5 model.
CONF_THRESH = 0.45
IOU_THRESH = 0.45
STRIDES = [8, 16, 32]
ANCHORS = [
    [10, 13, 16, 30, 33, 23],
    [30, 61, 62, 45, 59, 119],
    [116, 90, 156, 198, 373, 326],
]
NUM_OUTPUTS = 85  # 4 box + 1 objectness + 80 classes
INPUT_SHAPE = [640, 640]

CLASS_NAMES = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
    "truck", "boat", "traffic light", "fire hydrant", "stop sign",
    "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag",
    "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite",
    "baseball bat", "baseball glove", "skateboard", "surfboard",
    "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot",
    "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant",
    "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote",
    "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
    "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
    "hair drier", "toothbrush",
]

# COCO category ids are non-contiguous (some ids are unused in the dataset).
COCO_CATEGORIES = {
    "person": 1, "bicycle": 2, "car": 3, "motorcycle": 4, "airplane": 5,
    "bus": 6, "train": 7, "truck": 8, "boat": 9, "traffic light": 10,
    "fire hydrant": 11, "stop sign": 13, "parking meter": 14, "bench": 15,
    "bird": 16, "cat": 17, "dog": 18, "horse": 19, "sheep": 20, "cow": 21,
    "elephant": 22, "bear": 23, "zebra": 24, "giraffe": 25, "backpack": 27,
    "umbrella": 28, "handbag": 31, "tie": 32, "suitcase": 33, "frisbee": 34,
    "skis": 35, "snowboard": 36, "sports ball": 37, "kite": 38,
    "baseball bat": 39, "baseball glove": 40, "skateboard": 41,
    "surfboard": 42, "tennis racket": 43, "bottle": 44, "wine glass": 46,
    "cup": 47, "fork": 48, "knife": 49, "spoon": 50, "bowl": 51,
    "banana": 52, "apple": 53, "sandwich": 54, "orange": 55, "broccoli": 56,
    "carrot": 57, "hot dog": 58, "pizza": 59, "donut": 60, "cake": 61,
    "chair": 62, "couch": 63, "potted plant": 64, "bed": 65,
    "dining table": 67, "toilet": 70, "tv": 72, "laptop": 73, "mouse": 74,
    "remote": 75, "keyboard": 76, "cell phone": 77, "microwave": 78,
    "oven": 79, "toaster": 80, "sink": 81, "refrigerator": 82, "book": 84,
    "clock": 85, "vase": 86, "scissors": 87, "teddy bear": 88,
    "hair drier": 89, "toothbrush": 90,
}


def letterbox_yolov5(
        im,
        new_shape=(640, 640),
        color=(114, 114, 114),
        auto=True,
        scaleFill=False,
        scaleup=True,
        stride=32,
):
    """Resize and pad an image to `new_shape` preserving aspect ratio.

    Returns (image, (ratio_w, ratio_h), (pad_w, pad_h)).
    """
    # Resize and pad image while meeting stride-multiple constraints
    shape = im.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better val mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    im = cv2.copyMakeBorder(
        im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
    )  # add border
    return im, ratio, (dw, dh)


def pre_processing(image_raw, img_shape):
    """Letterbox, BGR->RGB, add batch dim; return (input tensor, original hw)."""
    img = letterbox_yolov5(image_raw, img_shape, stride=32, auto=False)[0]
    img = img[:, :, ::-1]  # BGR -> RGB
    img = img[np.newaxis, ...]  # NHWC with batch dim
    origin_shape = image_raw.shape[0:2]
    return img, origin_shape


def draw_bbox(image, bboxes, classes=None, show_label=True, threshold=0.1):
    """Draw detections on `image` (in place) and return it.

    bboxes: [x_min, y_min, x_max, y_max, probability, cls_id] format coordinates.
    """
    if classes is None:  # fixed: was `classes == None`
        classes = {v: k for k, v in COCO_CATEGORIES.items()}

    num_classes = len(classes)
    image_h, image_w, _ = image.shape
    # One distinct hue per class, shuffled deterministically (seed 0).
    hsv_tuples = [(1.0 * x / num_classes, 1.0, 1.0) for x in range(num_classes)]
    colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
    colors = list(
        map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors)
    )

    random.seed(0)
    random.shuffle(colors)
    random.seed(None)

    for i, bbox in enumerate(bboxes):
        coor = np.array(bbox[:4], dtype=np.int32)
        fontScale = 0.5
        score = bbox[4]
        if score < threshold:
            continue
        class_ind = int(bbox[5])
        bbox_color = colors[class_ind]
        bbox_thick = int(0.6 * (image_h + image_w) / 600)
        c1, c2 = (coor[0], coor[1]), (coor[2], coor[3])
        cv2.rectangle(image, c1, c2, bbox_color, bbox_thick)
        print(
            f"  {class_ind:>3}: {CLASS_NAMES[class_ind]:<10}: {coor}, score: {score*100:3.2f}%"
        )
        if show_label:
            bbox_mess = "%s: %.2f" % (CLASS_NAMES[class_ind], score)
            t_size = cv2.getTextSize(
                bbox_mess, 0, fontScale, thickness=bbox_thick // 2
            )[0]
            cv2.rectangle(image, c1, (c1[0] + t_size[0], c1[1] - t_size[1] - 3), bbox_color, -1)

            cv2.putText(
                image,
                bbox_mess,
                (c1[0], c1[1] - 2),
                cv2.FONT_HERSHEY_SIMPLEX,
                fontScale,
                (0, 0, 0),
                bbox_thick // 2,
                lineType=cv2.LINE_AA,
            )

    return image


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y


def bboxes_iou(boxes1, boxes2):
    """calculate the Intersection Over Union value"""
    boxes1 = np.array(boxes1)
    boxes2 = np.array(boxes2)

    boxes1_area = (boxes1[..., 2] - boxes1[..., 0]) * (boxes1[..., 3] - boxes1[..., 1])
    boxes2_area = (boxes2[..., 2] - boxes2[..., 0]) * (boxes2[..., 3] - boxes2[..., 1])

    left_up = np.maximum(boxes1[..., :2], boxes2[..., :2])
    right_down = np.minimum(boxes1[..., 2:], boxes2[..., 2:])

    inter_section = np.maximum(right_down - left_up, 0.0)
    inter_area = inter_section[..., 0] * inter_section[..., 1]
    union_area = boxes1_area + boxes2_area - inter_area
    # eps floor avoids returning exactly 0 / dividing by a zero union
    ious = np.maximum(1.0 * inter_area / union_area, np.finfo(np.float32).eps)

    return ious


def nms(proposals, iou_threshold, conf_threshold, multi_label=False):
    """Confidence filtering + per-class non-maximum suppression.

    :param proposals: (N, 5 + num_classes) rows of (cx, cy, w, h, obj, cls_scores...)
    :return: (M, 6) array of (xmin, ymin, xmax, ymax, score, class)

    Note: soft-nms, https://arxiv.org/pdf/1704.04503.pdf
          https://github.com/bharatsingh430/soft-nms
    """
    xc = proposals[..., 4] > conf_threshold
    proposals = proposals[xc]
    proposals[:, 5:] *= proposals[:, 4:5]  # class score *= objectness
    bboxes = xywh2xyxy(proposals[:, :4])
    if multi_label:
        mask = proposals[:, 5:] > conf_threshold
        nonzero_indices = np.argwhere(mask)
        # fixed: original checked `size < 0`, which is never true, so the
        # empty case fell through and crashed later in np.vstack.
        if nonzero_indices.size == 0:
            return np.zeros((0, 6), dtype=np.float32)
        i, j = nonzero_indices.T
        bboxes = np.hstack(
            (bboxes[i], proposals[i, j + 5][:, None], j[:, None].astype(float))
        )
    else:
        confidences = proposals[:, 5:]
        conf = confidences.max(axis=1, keepdims=True)
        j = confidences.argmax(axis=1)[:, None]

        new_x_parts = [bboxes, conf, j.astype(float)]
        bboxes = np.hstack(new_x_parts)

        mask = conf.reshape(-1) > conf_threshold
        bboxes = bboxes[mask]

    classes_in_img = list(set(bboxes[:, 5]))
    bboxes = bboxes[bboxes[:, 4].argsort()[::-1][:300]]  # top-300 by score
    best_bboxes = []

    for cls in classes_in_img:
        cls_mask = bboxes[:, 5] == cls
        cls_bboxes = bboxes[cls_mask]

        while len(cls_bboxes) > 0:
            max_ind = np.argmax(cls_bboxes[:, 4])
            best_bbox = cls_bboxes[max_ind]
            best_bboxes.append(best_bbox)
            cls_bboxes = np.concatenate(
                [cls_bboxes[:max_ind], cls_bboxes[max_ind + 1:]]
            )
            iou = bboxes_iou(best_bbox[np.newaxis, :4], cls_bboxes[:, :4])
            weight = np.ones((len(iou),), dtype=np.float32)

            iou_mask = iou > iou_threshold
            weight[iou_mask] = 0.0

            cls_bboxes[:, 4] = cls_bboxes[:, 4] * weight
            score_mask = cls_bboxes[:, 4] > 0.0
            cls_bboxes = cls_bboxes[score_mask]
    # fixed: np.vstack([]) raises ValueError when nothing survives
    if not best_bboxes:
        return np.zeros((0, 6), dtype=np.float32)
    best_bboxes = np.vstack(best_bboxes)
    best_bboxes = best_bboxes[best_bboxes[:, 4].argsort()[::-1]]
    return best_bboxes


def sigmoid(x):
    """Numpy logistic sigmoid."""
    return 1.0 / (np.exp(-x) + 1.0)


def gen_proposals(outputs):
    """Decode raw NHWC head outputs into (N, 85) candidate rows.

    Each of the 3 heads is decoded with its stride/anchor set; box centers
    and sizes use the YOLOv5 v4+ decoding formulas.
    NOTE(review): assumes square feature maps (output_size taken from dim 1).
    """
    new_pred = []
    anchor_grid = np.array(ANCHORS).reshape(-1, 1, 1, 3, 2)
    for i, pred in enumerate(outputs):
        pred = sigmoid(pred)
        n, h, w, c = pred.shape

        pred = pred.reshape(n, h, w, 3, 85)
        conv_shape = pred.shape
        output_size = conv_shape[1]
        conv_raw_dxdy = pred[..., 0:2]
        conv_raw_dwdh = pred[..., 2:4]
        xy_grid = np.meshgrid(np.arange(output_size), np.arange(output_size))
        xy_grid = np.expand_dims(np.stack(xy_grid, axis=-1), axis=2)

        xy_grid = np.tile(np.expand_dims(xy_grid, axis=0), [1, 1, 1, 3, 1])
        xy_grid = xy_grid.astype(np.float32)
        pred_xy = (conv_raw_dxdy * 2.0 - 0.5 + xy_grid) * STRIDES[i]
        pred_wh = (conv_raw_dwdh * 2) ** 2 * anchor_grid[i]
        pred[:, :, :, :, 0:4] = np.concatenate([pred_xy, pred_wh], axis=-1)

        new_pred.append(np.reshape(pred, (-1, np.shape(pred)[-1])))

    return np.concatenate(new_pred, axis=0)


def clip_coords(boxes, shape):
    """Clamp xyxy boxes (in place) to image bounds `shape` = (h, w)."""
    boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1])  # x1, x2
    boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0])  # y1, y2


def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    """Rescale xyxy coords (in place) from letterboxed to original image space."""
    if ratio_pad is None:
        gain = min(
            img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]
        )  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (
            img1_shape[0] - img0_shape[0] * gain
        ) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]
    coords[:, [1, 3]] -= pad[1]
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords


def post_processing(outputs, origin_shape, input_shape):
    """Decode heads, run NMS, and map boxes back to the original image."""
    proposals = gen_proposals(outputs)
    pred = nms(
        proposals, IOU_THRESH, CONF_THRESH, multi_label=True
    )  # set multi_label to true for testing map and then cost more time.
    pred[:, :4] = scale_coords(input_shape, pred[:, :4], origin_shape)
    return pred


def detect_yolov5(model_path, image_path, save_path, repeat_times, selected_provider='AUTO', selected_device_id=0):
    """Run YOLOv5 inference `repeat_times` times, save annotated image, print latency."""
    if selected_provider == 'AUTO':
        # Use AUTO to let the pyengine choose the first available provider
        session = axe.InferenceSession(model_path)
    else:
        providers = []
        if selected_provider == axclrt_provider_name:
            provider_options = {"device_id": selected_device_id}
            providers.append((axclrt_provider_name, provider_options))
        if selected_provider == axengine_provider_name:
            providers.append(axengine_provider_name)
        session = axe.InferenceSession(model_path, providers=providers)

    image_data = cv2.imread(image_path)
    inputs, origin_shape = pre_processing(image_data, (640, 640))
    inputs = np.ascontiguousarray(inputs)

    print(" ------------------------------------------------------")
    time_costs = []
    results = None
    for i in range(repeat_times):
        t1 = time.time()
        results = session.run(None, {"images": inputs})
        t2 = time.time()
        time_costs.append((t2 - t1) * 1000)

    det = post_processing(results, origin_shape, (640, 640))
    ret_image = draw_bbox(image_data, det)
    cv2.imwrite(save_path, ret_image)

    print(" ------------------------------------------------------")
    print(
        f" min = {min(time_costs):.3f} ms  max = {max(time_costs):.3f} ms  avg = {sum(time_costs) / len(time_costs):.3f} ms"
    )
    print(" ------------------------------------------------------")


class ExampleParser(argparse.ArgumentParser):
    """ArgumentParser that prints usage examples on error, then exits."""

    def error(self, message):
        self.print_usage(sys.stderr)
        print(f"\nError: {message}")
        print("\nExample usage:")
        print("  python3 yolov5.py -m <model_file> -i <image_file>")
        print("  python3 yolov5.py -m /opt/data/npu/models/yolov5s.axmodel -i /opt/data/npu/images/dog.jpg")
        print(
            f"  python3 yolov5.py -m /opt/data/npu/models/yolov5s.axmodel -i /opt/data/npu/images/dog.jpg -p {axengine_provider_name}"
        )
        print(
            f"  python3 yolov5.py -m /opt/data/npu/models/yolov5s.axmodel -i /opt/data/npu/images/dog.jpg -p {axclrt_provider_name}"
        )
        sys.exit(1)


if __name__ == "__main__":
    ap = ExampleParser(description="YOLOv5 example")
    ap.add_argument('-m', '--model-path', type=str, help='model path', required=True)
    ap.add_argument('-i', '--image-path', type=str, help='image path', required=True)
    ap.add_argument(
        '-s', "--save-path", type=str, default="YOLOv5_OUT.jpg", help="detected output image save path"
    )
    ap.add_argument('-r', '--repeat', type=int, help='repeat times', default=10)
    ap.add_argument(
        '-p',
        '--provider',
        type=str,
        choices=["AUTO", f"{axclrt_provider_name}", f"{axengine_provider_name}"],
        help=f'"AUTO", "{axclrt_provider_name}", "{axengine_provider_name}"',
        default='AUTO'
    )
    ap.add_argument(
        '-d',
        '--device-id',
        type=int,
        help=R'axclrt device index, depends on how many cards inserted',
        default=0
    )
    args = ap.parse_args()

    model_file = args.model_path
    image_file = args.image_path

    # check if the model and image exist
    assert os.path.exists(model_file), f"model file path {model_file} does not exist"
    assert os.path.exists(image_file), f"image file path {image_file} does not exist"

    save_path = args.save_path

    repeat = args.repeat

    provider = args.provider
    device_id = args.device_id

    detect_yolov5(model_file, image_file, save_path, repeat, provider, device_id)


# (setup.py)
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
#
# This source file is the property of Axera Semiconductor Co., Ltd. and
# may not be copied or distributed in any isomorphic form without the prior
# written consent of Axera Semiconductor Co., Ltd.
#

from setuptools import setup

setup(
    name="axengine",
    version="0.1.3",
    classifiers=[
        # fixed: "Development Status :: 1 - Alpha" is not a valid trove
        # classifier (level 1 is "Planning"; "Alpha" is level 3) — PyPI
        # rejects uploads that declare unknown classifiers.
        "Development Status :: 3 - Alpha",
        "License :: OSI Approved :: BSD License",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "Programming Language :: Python :: 3.13",
        "Programming Language :: Python :: Implementation :: PyPy",
    ],
    packages=["axengine"],
    ext_modules=[],
    # runtime and build-time requirements are intentionally identical
    install_requires=["cffi>=1.0.0", "ml-dtypes>=0.1.0", "numpy>=1.22"],
    setup_requires=["cffi>=1.0.0", "ml-dtypes>=0.1.0", "numpy>=1.22"],
)