├── .gitignore ├── CMakeLists.txt ├── llama └── llama2 │ ├── 13B │ └── params.json │ ├── 7B │ └── params.json │ └── tokenizer.model ├── orangepi_install.md ├── prebuild ├── defs.hpp ├── libnpu_ops.so └── npu_ops.h ├── prompts ├── chat_example.txt └── readme.md ├── readme.md ├── scripts ├── benchmark_qwen2.5.sh ├── convert_llama2_weight.py ├── convert_llama_awq_4bit.py ├── convert_qwen2_awq_weight.py ├── convert_qwen2_weight.py ├── example_chat_llama2_13B_awq_4bit_orangepi.sh ├── example_chat_llama2_7B_awq_4bit_orangepi.sh ├── example_chat_llama2_7B_fp16_orangepi.sh ├── example_chat_qwen2.5_3B_bf16_orangepi.sh ├── example_orangepi_debug.sh ├── example_orangepi_msprof.sh ├── example_orangepi_msprof_awq_4bit.sh ├── example_orangepi_profiling.sh ├── example_orangepi_profiling_awq_4bit.sh ├── example_text_completion_deepseek_r1_qwen2.5_1.5B_bf16_orangepi.sh ├── example_text_completion_deepseek_r1_qwen2.5_14B_bf16_orangepi.sh ├── example_text_completion_deepseek_r1_qwen2.5_7B_bf16_orangepi.sh ├── example_text_completion_llama2_13B_awq_4bit_orangepi.sh ├── example_text_completion_llama2_13B_awq_4bit_orangepi_2.sh ├── example_text_completion_llama2_7B_awq_4bit_orangepi.sh ├── example_text_completion_llama2_7B_awq_4bit_orangepi_2.sh ├── example_text_completion_llama2_7B_fp16_orangepi.sh ├── example_text_completion_llama2_7B_fp16_orangepi_2.sh ├── example_text_completion_qwen2.5_14B_awq_orangepi.sh ├── example_text_completion_qwen2.5_32B_awq_orangepi.sh ├── example_text_completion_qwen2.5_3B_awq_orangepi.sh ├── example_text_completion_qwen2.5_3B_bf16_orangepi.sh ├── example_text_completion_qwen2.5_7B_awq_orangepi.sh └── example_text_completion_qwen2.5_7B_bf16_orangepi.sh ├── src ├── .gitignore ├── CMakeLists.txt ├── acl_util.hpp ├── base64.h ├── device.cpp ├── device.hpp ├── device_cpu.cpp ├── device_gpu.cpp ├── device_npu.cpp ├── llama2_layer_cpu.cpp ├── llama2_layer_cpu.hpp ├── llama2_layer_npu.cpp ├── llama2_layer_npu.hpp ├── llama2_main.cpp ├── llama2_model.cpp ├── llama2_model.hpp ├── model_base.cpp ├── model_base.hpp ├── profiling.cpp ├── profiling.hpp ├── qwen2_model.cpp ├── qwen2_model.hpp ├── tiktoken.h ├── tokenizer.cpp ├── tokenizer.hpp ├── unordered_dense.h ├── util.cpp └── util.h └── tests ├── CMakeLists.txt ├── embedding_main.cpp ├── embedding_test.cpp ├── embedding_test.h ├── flash_attn_main.cpp ├── flash_attn_test.cpp ├── flash_attn_test.h ├── gemm_awq_4bit_main.cpp ├── gemm_awq_4bit_test.cpp ├── gemm_awq_4bit_test.h ├── gemm_main.cpp ├── gemm_test.cpp ├── gemm_test.h ├── npu_op_test_util.cpp ├── npu_op_test_util.h ├── npu_operator_test.cpp ├── rms_norm_layer_main.cpp ├── rms_norm_layer_test.cpp ├── rms_norm_layer_test.h ├── rope_single_layer_main.cpp ├── rope_single_layer_test.cpp └── rope_single_layer_test.h /.gitignore: -------------------------------------------------------------------------------- 1 | log.txt 2 | cpu_log.txt 3 | test_tiny_orangepi.sh 4 | *.data 5 | *.bin 6 | *.json 7 | build/ 8 | model_output/ 9 | profiling_output/ 10 | 11 | .clangd 12 | 13 | # Prerequisites 14 | *.d 15 | 16 | # Compiled Object files 17 | *.slo 18 | *.lo 19 | *.o 20 | *.obj 21 | 22 | # Precompiled Headers 23 | *.gch 24 | *.pch 25 | 26 | # Compiled Dynamic libraries 27 | *.so 28 | *.dylib 29 | *.dll 30 | 31 | # Fortran module files 32 | *.mod 33 | *.smod 34 | 35 | # Compiled Static libraries 36 | *.lai 37 | *.la 38 | *.a 39 | *.lib 40 | 41 | # Executables 42 | *.exe 43 | *.out 44 | *.app 45 | 46 | # Byte-compiled / optimized / DLL files 47 | __pycache__/ 48 | *.py[cod] 49 | *$py.class 50 | 51 | # 
C extensions 52 | *.so 53 | 54 | # Distribution / packaging 55 | .Python 56 | build/ 57 | develop-eggs/ 58 | dist/ 59 | downloads/ 60 | eggs/ 61 | .eggs/ 62 | lib/ 63 | lib64/ 64 | parts/ 65 | sdist/ 66 | var/ 67 | wheels/ 68 | share/python-wheels/ 69 | *.egg-info/ 70 | .installed.cfg 71 | *.egg 72 | MANIFEST 73 | 74 | # PyInstaller 75 | # Usually these files are written by a python script from a template 76 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 77 | *.manifest 78 | *.spec 79 | 80 | .vscode/ 81 | test.sh 82 | 83 | # Installer logs 84 | pip-log.txt 85 | pip-delete-this-directory.txt 86 | 87 | # Unit test / coverage reports 88 | htmlcov/ 89 | .tox/ 90 | .nox/ 91 | .coverage 92 | .coverage.* 93 | .cache 94 | nosetests.xml 95 | coverage.xml 96 | *.cover 97 | *.py,cover 98 | .hypothesis/ 99 | .pytest_cache/ 100 | cover/ 101 | 102 | # Translations 103 | *.mo 104 | *.pot 105 | 106 | # Django stuff: 107 | *.log 108 | local_settings.py 109 | db.sqlite3 110 | db.sqlite3-journal 111 | 112 | # Flask stuff: 113 | instance/ 114 | .webassets-cache 115 | 116 | # Scrapy stuff: 117 | .scrapy 118 | 119 | # Sphinx documentation 120 | docs/_build/ 121 | 122 | # PyBuilder 123 | .pybuilder/ 124 | target/ 125 | 126 | # Jupyter Notebook 127 | .ipynb_checkpoints 128 | 129 | # IPython 130 | profile_default/ 131 | ipython_config.py 132 | 133 | # pyenv 134 | # For a library or package, you might want to ignore these files since the code is 135 | # intended to run in multiple environments; otherwise, check them in: 136 | # .python-version 137 | 138 | # pipenv 139 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 140 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 141 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 142 | # install all needed dependencies. 143 | #Pipfile.lock 144 | 145 | # poetry 146 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 147 | # This is especially recommended for binary packages to ensure reproducibility, and is more 148 | # commonly ignored for libraries. 149 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 150 | #poetry.lock 151 | 152 | # pdm 153 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 154 | #pdm.lock 155 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 156 | # in version control. 157 | # https://pdm.fming.dev/#use-with-ide 158 | .pdm.toml 159 | 160 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 161 | __pypackages__/ 162 | 163 | # Celery stuff 164 | celerybeat-schedule 165 | celerybeat.pid 166 | 167 | # SageMath parsed files 168 | *.sage.py 169 | 170 | # Environments 171 | .env 172 | .venv 173 | env/ 174 | venv/ 175 | ENV/ 176 | env.bak/ 177 | venv.bak/ 178 | 179 | # Spyder project settings 180 | .spyderproject 181 | .spyproject 182 | 183 | # Rope project settings 184 | .ropeproject 185 | 186 | # mkdocs documentation 187 | /site 188 | 189 | # mypy 190 | .mypy_cache/ 191 | .dmypy.json 192 | dmypy.json 193 | 194 | # Pyre type checker 195 | .pyre/ 196 | 197 | # pytype static type analyzer 198 | .pytype/ 199 | 200 | # Cython debug symbols 201 | cython_debug/ 202 | 203 | # PyCharm 204 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 205 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 206 | # and can be added to the global gitignore or merged into this file. For a more nuclear 207 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 208 | #.idea/ 209 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.2) 2 | 3 | project(llm_simple CXX) 4 | cmake_policy(SET CMP0128 OLD) 5 | 6 | set(CMAKE_CXX_STANDARD 20) 7 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++20") 8 | 9 | add_subdirectory(src) 10 | add_subdirectory(tests) -------------------------------------------------------------------------------- /llama/llama2/13B/params.json: -------------------------------------------------------------------------------- 1 | {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": -1} 2 | -------------------------------------------------------------------------------- /llama/llama2/7B/params.json: -------------------------------------------------------------------------------- 1 | {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": -1} -------------------------------------------------------------------------------- /llama/llama2/tokenizer.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lenLRX/llm_simple/c8466b22f27f2026d63747bcf0856a0e7f61e126/llama/llama2/tokenizer.model -------------------------------------------------------------------------------- /orangepi_install.md: -------------------------------------------------------------------------------- 1 | # Install dependencies (as root) 2 | ```apt install libeigen3-dev libsentencepiece-dev libboost-program-options-dev libboost-system-dev libboost-filesystem-dev libgtest-dev libspdlog-dev nlohmann-json3-dev libre2-dev``` 3 | # Install Python dependencies 4 | ```pip install transformers ml_dtypes``` 5 | # Build 6 | ``` 7 | mkdir build 8 | cd build 9 | cmake .. 
-DCMAKE_BUILD_TYPE=Release 10 | make -j4 11 | ``` 12 | -------------------------------------------------------------------------------- /prebuild/defs.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <cstdint> 4 | #include <cstddef> 5 | 6 | enum DeviceType { DEV_CPU = 0, DEV_GPU, DEV_NPU }; 7 | 8 | enum DataType : int32_t { 9 | DT_UINT8 = 0, 10 | DT_INT8, 11 | DT_UINT32, 12 | DT_INT32, 13 | DT_FLOAT16, 14 | DT_BFLOAT16, 15 | DT_FLOAT32, 16 | DT_UINT64, 17 | DT_INT64, 18 | }; 19 | 20 | inline static size_t SizeOfTensor(size_t size, DataType dt) { 21 | switch (dt) { 22 | case DT_INT8: 23 | case DT_UINT8: 24 | return size * sizeof(uint8_t); 25 | break; 26 | case DT_INT32: 27 | case DT_UINT32: 28 | return size * sizeof(uint32_t); 29 | case DT_FLOAT16: 30 | case DT_BFLOAT16: 31 | return size * sizeof(uint16_t); 32 | case DT_FLOAT32: 33 | return size * sizeof(uint32_t); 34 | case DT_UINT64: 35 | case DT_INT64: 36 | return size * sizeof(uint64_t); 37 | default: 38 | break; 39 | } 40 | return -1; 41 | } 42 | -------------------------------------------------------------------------------- /prebuild/libnpu_ops.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lenLRX/llm_simple/c8466b22f27f2026d63747bcf0856a0e7f61e126/prebuild/libnpu_ops.so -------------------------------------------------------------------------------- /prebuild/npu_ops.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "acl/acl.h" 3 | 4 | #include "defs.hpp" 5 | 6 | void npu_flash_attn_layer(void *output_dev, void *q_dev, void *k_dev, 7 | void *v_dev, int m, int n, int offset, int head_num, 8 | int head_dim, DataType dt, aclrtStream &stream); 9 | 10 | void npu_flash_attn_gqa_layer(void *output_dev, void *q_dev, void *k_dev, 11 | void *v_dev, int m, int n, int offset, 12 | int group_size, int kv_head_num, int head_dim, 13 | DataType dt, aclrtStream &stream); 14 | 15 | void npu_flash_attn_opt_prefill_layer(void *output_dev, void *q_dev, 16 | void *k_dev, void *v_dev, int m, int n, 17 | int offset, int head_num, int head_dim, 18 | DataType dt, aclrtStream &stream); 19 | 20 | void npu_embedding_layer(void *output_dev, void *weight_dev, void *index_dev, 21 | int seqlen, int hidden_dim, DataType dt, 22 | aclrtStream &stream); 23 | void npu_gather_layer(void *output_dev, void *data_dev, void *index_dev, 24 | int index_num, int last_dim, DataType index_dt, 25 | DataType data_dt, aclrtStream &stream); 26 | 27 | void npu_split_qkv_layer(void *output_q_dev, void *output_k_dev, 28 | void *output_v_dev, void *input_qkv, int batch, 29 | int q_dim, int k_dim, int v_dim, DataType dt, 30 | aclrtStream &stream); 31 | 32 | void npu_rmsnorm_layer(void *output_dev, void *weight_dev, void *input_dev, 33 | int cur_size, int hidden_dim, float eps, DataType dt, 34 | aclrtStream &stream); 35 | 36 | void npu_softmax_layer(void *output_dev, void *input_dev, int n_head_cur_size, 37 | int cur_size, DataType dt, aclrtStream &stream); 38 | 39 | void npu_rope_layer(void *output_q_dev, void *output_k_dev, void *freqs_cis_dev, 40 | void *input_q_dev, void *input_k_dev, int start_pos, 41 | int cur_size, int n_heads, int hidden_dim, 42 | bool is_neox_style, DataType dt, aclrtStream &stream); 43 | 44 | void npu_rope_single_layer(void *output_x_dev, void *freqs_cis_dev, 45 | void *input_x_dev, int start_pos, int cur_size, 46 | int n_heads, int hidden_dim, bool is_neox_style, 47 | DataType 
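/* Added annotation: in the declarations above and below, every *_dev argument points to a buffer in NPU device memory, dt selects the element type (DataType from defs.hpp), and each wrapper presumably just enqueues the corresponding kernel from the prebuilt libnpu_ops.so on the supplied aclrtStream, so callers would synchronize the stream before reading results back on the host. */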
dt, aclrtStream &stream); 48 | 49 | void npu_rope_layer_vllm(void *output_q_dev, void *output_k_dev, 50 | void *freqs_cis_dev, void *input_q_dev, 51 | void *input_k_dev, void *positions_dev, 52 | int total_token_num, int n_heads, int hidden_dim, 53 | bool is_neox_style, DataType dt, aclrtStream &stream); 54 | 55 | void npu_batch_matmul_layer(void *output_dev, void *lhs_dev, void *rhs_dev, 56 | int batch, int m, int n, int k, float scale, 57 | DataType dt, aclrtStream &stream); 58 | 59 | void npu_batch_matmul_trans_v_layer(void *output_dev, void *lhs_dev, 60 | void *rhs_dev, int batch, int m, int n, 61 | int k, float scale, DataType dt, 62 | aclrtStream &stream); 63 | 64 | void npu_batch_matmul_causual_layer(void *output_dev, void *lhs_dev, 65 | void *rhs_dev, int batch, int m, int n, 66 | int k, int causual_offset, float scale, 67 | DataType dt, aclrtStream &stream); 68 | 69 | void npu_batch_matmul_qk_trans_causual_layer(void *output_dev, void *lhs_dev, 70 | void *rhs_dev, int batch, int m, 71 | int n, int k, int causual_offset, 72 | float scale, DataType dt, 73 | aclrtStream &stream); 74 | 75 | void npu_silu_mul_layer(void *output_dev, void *w1_dev, void *w3_dev, 76 | int total_size, DataType dt, aclrtStream &stream); 77 | 78 | void npu_silu_mul_layer_vllm(void *output_dev, void *input_dev, int first_dim, 79 | int last_dim, DataType dt, aclrtStream &stream); 80 | 81 | void npu_add_layer(void *output_dev, void *lhs, void *rhs, int total_size, 82 | DataType dt, aclrtStream &stream); 83 | 84 | void npu_matmul_layer(void *output_dev, void *lhs_dev, void *rhs_dev, int m, 85 | int n, int k, DataType dt, aclrtStream &stream); 86 | 87 | void npu_matmul_nz_layer(void *output_dev, void *lhs_dev, void *rhs_dev, int m, 88 | int n, int k, DataType dt, aclrtStream &stream); 89 | 90 | void npu_matmul_bias_nz_layer(void *output_dev, void *lhs_dev, void *rhs_dev, 91 | void *bias_dev, int m, int n, int k, DataType dt, 92 | aclrtStream &stream); 93 | 94 | void npu_mamtul_weight_transpose_layer(void *output_dev, void *input, int n, 95 | int k, DataType dt, aclrtStream &stream); 96 | 97 | void npu_matmul_nz_awq_4bit_layer(void *output_dev, void *lhs_dev, 98 | void *weight_dev, void *zero_dev, 99 | void *scale_dev, int m, int n, int k, 100 | DataType dt, aclrtStream &stream); 101 | 102 | void npu_matmul_nz_awq_4bit_bias_layer(void *output_dev, void *lhs_dev, 103 | void *weight_dev, void *zero_dev, 104 | void *scale_dev, void *bias_dev, int m, 105 | int n, int k, DataType dt, 106 | aclrtStream &stream); 107 | -------------------------------------------------------------------------------- /prompts/chat_example.txt: -------------------------------------------------------------------------------- 1 | Text transcript of a never ending dialog, where [[USER_NAME]] interacts with an AI assistant named [[AI_NAME]]. 2 | [[AI_NAME]] is helpful, kind, honest, friendly, good at writing and never fails to answer [[USER_NAME]]'s requests immediately and with details and precision. 3 | There are no annotations like (30 seconds passed...) or (to himself), just what [[USER_NAME]] and [[AI_NAME]] say aloud to each other. 4 | The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long. 5 | The transcript only includes text, it does not include markup like HTML and Markdown. 6 | 7 | [[USER_NAME]]: Hello, [[AI_NAME]]! 8 | [[AI_NAME]]: Hello [[USER_NAME]]! How may I help you today? 9 | [[USER_NAME]]: What year is it? 10 | [[AI_NAME]]: We are in [[DATE_YEAR]]. 
11 | [[USER_NAME]]: Please tell me the largest city in Europe. 12 | [[AI_NAME]]: The largest city in Europe is Moscow, the capital of Russia. 13 | [[USER_NAME]]: What can you tell me about Moscow? 14 | [[AI_NAME]]: Moscow, on the Moskva River in western Russia, is the nation's cosmopolitan capital. In its historic core is the Kremlin, a complex that's home to the president and tsarist treasures in the Armoury. Outside its walls is Red Square, Russia’s symbolic center. 15 | [[USER_NAME]]: What is a cat? 16 | [[AI_NAME]]: A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae. 17 | [[USER_NAME]]: How do I pass command line arguments to a Node.js program? 18 | [[AI_NAME]]: The arguments are stored in process.argv. 19 | 20 | argv[0] is the path to the Node. js executable. 21 | argv[1] is the path to the script file. 22 | argv[2] is the first argument passed to the script. 23 | argv[3] is the second argument passed to the script and so on. 24 | [[USER_NAME]]: Name a color. 25 | [[AI_NAME]]: Blue. 26 | [[USER_NAME]]: What time is it? 27 | [[AI_NAME]]: It is [[DATE_TIME]]. 28 | [[USER_NAME]]: -------------------------------------------------------------------------------- /prompts/readme.md: -------------------------------------------------------------------------------- 1 | chat_example.txt from llama.cpp -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Orange Pi LLM Inference 2 | LLM inference on OrangePi. Currently tested hardware: Orange Pi 20T 24GB 3 | ## Installation 4 | [Installation guide](orangepi_install.md) 5 | 6 | ## QWen2 7 | Supports *Qwen2ForCausalLM* models 8 | ### Model download 9 | It is recommended to download directly from ModelScope via git (git lfs required), e.g. DeepSeek-R1-Distill-Qwen-1.5B: 10 | ```git clone https://www.modelscope.cn/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B.git``` 11 | ### Weight conversion 12 | #### BF16 model (using Qwen2.5-3B-Instruct as an example; replace the paths with your own) 13 | ```python3 /data/llm_simple/scripts/convert_qwen2_weight.py --input_model_path /ssd/models/Qwen2.5-3B-Instruct --output_dir /ssd/models/Qwen2.5-3B-Instruct_converted ``` 14 | #### AWQ model (using Qwen2.5-14B-Instruct-AWQ as an example; replace the paths with your own) 15 | ```python3 /data/llm_simple/scripts/convert_qwen2_awq_weight.py --input_model_path /ssd/models/Qwen2.5-14B-Instruct-AWQ --output_dir /ssd/models/Qwen2.5-14B-Instruct-AWQ_converted``` 16 | ### Running 17 | *Replace the paths in the script with your own* 18 | 19 | ```bash scripts/example_text_completion_deepseek_r1_qwen2.5_1.5B_bf16_orangepi.sh``` 20 | 21 | ### Performance (256 input tokens / 256 output tokens) 22 | |Model size|ttft(ms)|decode(ms/token)| 23 | |---|---|---| 24 | |1.5B|461|142| 25 | |3B|776|284| 26 | |7B|3215|881| 27 | |3B-AWQ|3215|113| 28 | |7B-AWQ|2358|206| 29 | |14B-AWQ|8181|653| 30 | 31 | 32 | ## LLAMA2 33 | ### Weight conversion 34 | #### LLAMA2-7B FP16 (supports the official llama release format, containing the tokenizer.model, params.json and consolidated.00.pth files) 35 | ```python3 scripts/convert_llama2_weight.py --input_dir --model_size 7B --output_dir ``` 36 | #### LLAMA2-7B-AWQ 4bit 37 | Weight download link: [model.safetensors](https://huggingface.co/TheBloke/Llama-2-7B-AWQ/blob/main/model.safetensors) 38 | ```python3 scripts/convert_llama_awq_4bit.py --input_safetensor --output_dir ``` 39 | #### LLAMA2-13B-AWQ 4bit 40 | Weight download link: [model.safetensors](https://huggingface.co/TheBloke/Llama-2-13B-AWQ/resolve/main/model.safetensors) 41 | ```python3 scripts/convert_llama_awq_4bit.py --input_safetensor --output_dir ``` 42 | 43 | ### Running 44 | *Copy the converted weight folder, config file, and tokenizer file to the device and update the corresponding paths in the bash scripts* 45 | 46 | 1. 
```bash scripts/example_chat_llama2_7B_fp16_orangepi.sh``` 47 | 2. ```bash scripts/example_text_completion_llama2_7B_fp16_orangepi.sh``` 48 | 3. ```bash scripts/example_chat_llama2_7B_awq_4bit_orangepi.sh``` 49 | 4. ```bash scripts/example_text_completion_llama2_7B_awq_4bit_orangepi.sh``` 50 | 5. ```bash scripts/example_chat_llama2_13B_awq_4bit_orangepi.sh``` 51 | 6. ```bash scripts/example_text_completion_llama2_13B_awq_4bit_orangepi.sh``` 52 | 53 | ### Performance 54 | |Scenario|ttft(ms)|decode(ms/token)| 55 | |---|---|---| 56 | |llama2-7B-AWQ-4bit|886|176.7| 57 | |llama2-7B-FP16|4498|568.4| 58 | |llama2-13B-AWQ-4bit|1819|320.1| 59 | -------------------------------------------------------------------------------- /scripts/benchmark_qwen2.5.sh: -------------------------------------------------------------------------------- 1 | ./build/src/llama2_main \ 2 | --config=$1/config.json \ 3 | --tokenizer=$1/ \ 4 | --weight=$1_converted \ 5 | --model_type=qwen2 \ 6 | --device_type=npu \ 7 | --max_seq_len=8192 \ 8 | --max_gen_token=1024 \ 9 | --temperature=0.6 \ 10 | --debug_print=false \ 11 | --log_level=info \ 12 | --rope_is_neox_style=true \ 13 | --benchmark \ 14 | --benchmark_input_seq_length=$2 \ 15 | --benchmark_output_seq_length=$3 16 | -------------------------------------------------------------------------------- /scripts/convert_llama2_weight.py: -------------------------------------------------------------------------------- 1 | # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py 2 | 3 | import argparse 4 | import json 5 | import os 6 | import shutil 7 | 8 | import torch 9 | import numpy as np 10 | 11 | 12 | NUM_SHARDS = { 13 | "7B": 1, 14 | "13B": 2, 15 | "30B": 4, 16 | "65B": 8, 17 | } 18 | 19 | 20 | def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256): 21 | return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of) 22 | 23 | 24 | def read_json(path): 25 | with open(path, "r") as f: 26 | return json.load(f) 27 | 28 | 29 | def write_json(text, path): 30 | with open(path, "w") as f: 31 | json.dump(text, f) 32 | 33 | 34 | 35 | def main(): 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument( 38 | "--input_dir", 39 | help="Location of LLaMA weights, which contains tokenizer.model and model folders", 40 | ) 41 | parser.add_argument( 42 | "--model_size", 43 | choices=["7B", "13B", "30B", "65B"], 44 | help="'f' models correspond to the finetuned versions, and are specific to the Llama2 official release. 
For more details on Llama2, checkout the original repo: https://huggingface.co/meta-llama", 45 | ) 46 | parser.add_argument( 47 | "--output_dir", 48 | help="Location to write HF model and tokenizer", 49 | ) 50 | 51 | args = parser.parse_args() 52 | 53 | input_base_path = args.input_dir 54 | model_size = args.model_size 55 | output_path = args.output_dir 56 | model_path = output_path 57 | 58 | spm_path = os.path.join(input_base_path, "tokenizer.model") 59 | 60 | 61 | input_base_path = os.path.join(input_base_path, model_size) 62 | os.makedirs(model_path, exist_ok=True) 63 | tmp_model_path = os.path.join(model_path, "tmp") 64 | os.makedirs(tmp_model_path, exist_ok=True) 65 | 66 | 67 | params = read_json(os.path.join(input_base_path, "params.json")) 68 | 69 | num_shards = NUM_SHARDS[model_size] 70 | params = params.get("model", params) 71 | n_layers = params["n_layers"] 72 | n_heads = params["n_heads"] 73 | n_heads_per_shard = n_heads // num_shards 74 | dim = params["dim"] 75 | dims_per_head = dim // n_heads 76 | base = params.get("rope_theta", 10000.0) 77 | inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) 78 | if base > 10000.0: 79 | max_position_embeddings = 16384 80 | else: 81 | # Depending on the Llama version, the default max_position_embeddings has different values. 82 | max_position_embeddings = 4096 83 | vocab_size = 32000 84 | 85 | if params.get("n_kv_heads", None) is not None: 86 | num_key_value_heads = params["n_kv_heads"] # for GQA / MQA 87 | num_local_key_value_heads = n_heads_per_shard // num_key_value_heads 88 | key_value_dim = dim // num_key_value_heads 89 | else: # compatibility with other checkpoints 90 | num_key_value_heads = n_heads 91 | num_local_key_value_heads = n_heads_per_shard 92 | key_value_dim = dim 93 | 94 | def permute(w, n_heads=n_heads, dim1=dim, dim2=dim): 95 | return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) 96 | 97 | print(f"Fetching all parameters from the checkpoint at {input_base_path}.") 98 | # Load weights 99 | if num_shards == 1: 100 | # Not sharded 101 | # (The sharded implementation would also work, but this is simpler.) 
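# Added note: the wq/wk weights in the per-layer loop below pass through permute() (defined above) because Meta's checkpoints store each head's RoPE dimensions interleaved, while the HF-style q_proj/k_proj layout written out here keeps the two rotary halves contiguous; the view/transpose/reshape in permute() performs exactly that reordering.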
102 | loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu") 103 | else: 104 | # Sharded 105 | loaded = [ 106 | torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu") 107 | for i in range(num_shards) 108 | ] 109 | param_count = 0 110 | for layer_i in range(n_layers): 111 | if num_shards == 1: 112 | # Unsharded 113 | state_dict = { 114 | f"model.layers.{layer_i}.self_attn.q_proj.weight": permute( 115 | loaded[f"layers.{layer_i}.attention.wq.weight"] 116 | ), 117 | f"model.layers.{layer_i}.self_attn.k_proj.weight": permute( 118 | loaded[f"layers.{layer_i}.attention.wk.weight"] 119 | ), 120 | f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"], 121 | f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"], 122 | f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"], 123 | f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"], 124 | f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"], 125 | f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"], 126 | f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"], 127 | } 128 | else: 129 | # Sharded 130 | # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share 131 | # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is 132 | # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned. 133 | 134 | state_dict = { 135 | f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][ 136 | f"layers.{layer_i}.attention_norm.weight" 137 | ].clone(), 138 | f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][ 139 | f"layers.{layer_i}.ffn_norm.weight" 140 | ].clone(), 141 | } 142 | state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute( 143 | torch.cat( 144 | [ 145 | loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim) 146 | for i in range(num_shards) 147 | ], 148 | dim=0, 149 | ).reshape(dim, dim) 150 | ) 151 | state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute( 152 | torch.cat( 153 | [ 154 | loaded[i][f"layers.{layer_i}.attention.wk.weight"].view( 155 | num_local_key_value_heads, dims_per_head, dim 156 | ) 157 | for i in range(num_shards) 158 | ], 159 | dim=0, 160 | ).reshape(key_value_dim, dim), 161 | num_key_value_heads, 162 | key_value_dim, 163 | dim, 164 | ) 165 | state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( 166 | [ 167 | loaded[i][f"layers.{layer_i}.attention.wv.weight"].view( 168 | num_local_key_value_heads, dims_per_head, dim 169 | ) 170 | for i in range(num_shards) 171 | ], 172 | dim=0, 173 | ).reshape(key_value_dim, dim) 174 | 175 | state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( 176 | [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1 177 | ) 178 | state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( 179 | [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 180 | ) 181 | state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( 182 | 
[loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1 183 | ) 184 | state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( 185 | [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 186 | ) 187 | 188 | state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq 189 | for k, v in state_dict.items(): 190 | param_count += v.numel() 191 | v.numpy().tofile(os.path.join(output_path, f"{k}.bin")) 192 | 193 | if num_shards == 1: 194 | # Unsharded 195 | state_dict = { 196 | "model.embed_tokens.weight": loaded["tok_embeddings.weight"], 197 | "model.norm.weight": loaded["norm.weight"], 198 | "lm_head.weight": loaded["output.weight"], 199 | } 200 | else: 201 | state_dict = { 202 | "model.norm.weight": loaded[0]["norm.weight"], 203 | "model.embed_tokens.weight": torch.cat( 204 | [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1 205 | ), 206 | "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0), 207 | } 208 | 209 | for k, v in state_dict.items(): 210 | param_count += v.numel() 211 | v.numpy().tofile(os.path.join(output_path, f"{k}.bin")) 212 | shutil.rmtree(tmp_model_path) 213 | 214 | 215 | if __name__ == '__main__': 216 | main() 217 | 218 | -------------------------------------------------------------------------------- /scripts/convert_llama_awq_4bit.py: -------------------------------------------------------------------------------- 1 | from safetensors.numpy import load_file 2 | import numpy as np 3 | import argparse 4 | import os 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument( 10 | "--input_safetensor", 11 | required=True, 12 | help="Location of LLaMA weights, which contains tokenizer.model and model folders", 13 | ) 14 | parser.add_argument( 15 | "--output_dir", 16 | required=True, 17 | help="Location to write converted model", 18 | ) 19 | 20 | args = parser.parse_args() 21 | 22 | input_path = args.input_safetensor 23 | output_path = args.output_dir 24 | os.makedirs(output_path, exist_ok=True) 25 | 26 | loaded = load_file(input_path) 27 | for k, v in loaded.items(): 28 | tname = f"{k}.bin" 29 | print(f"writting {tname} shape: {v.shape} dtype: {v.dtype}") 30 | if "qweight" in k: 31 | # (k, n//8) -> (k, n//2) 32 | v = v.view("uint8") 33 | k_dim, n_dim = v.shape 34 | print(f"weight k:{k_dim}, n:{n_dim}") 35 | v = v.reshape(k_dim, n_dim, 1) 36 | v = np.repeat(v, 2, axis=-1) 37 | v[..., 0] = v[..., 0] & 0xf 38 | v[..., 1] = (v[..., 1] >> 4) & 0xf 39 | n_dim = n_dim * 2 40 | v = v.reshape(k_dim, n_dim//8, 2, 4) 41 | v = np.transpose(v, (0, 1, 3, 2)) 42 | # transpose to (k, n) 43 | #v = np.transpose(v, (1, 0)) 44 | v = v.reshape(k_dim//16, 16, n_dim) 45 | v = np.transpose(v, (0, 2, 1)) 46 | print(f"new shape: {v.shape}") 47 | d1 = v.size // 512 48 | v = v.reshape(d1, 4, 64, 2) 49 | v = np.transpose(v, (0, 2, 1, 3)) 50 | v = (v + 8)&0xf 51 | v[..., 0] = v[..., 0] | (v[...,1] << 4) 52 | v = np.ascontiguousarray(v[..., 0]) 53 | print(f"weight output shape {v.shape}, dtype: {v.dtype}") 54 | if "qzeros" in k: 55 | # (k//128,n//8) -> (k//128, n//2) 56 | v = v.view("uint8") 57 | k_dim, n_dim = v.shape 58 | v = v.reshape(k_dim, n_dim, 1) 59 | v = np.repeat(v, 2, axis=-1) 60 | v[..., 0] = v[..., 0] & 0xf 61 | v[..., 1] = (v[..., 1] >> 4) & 0xf 62 | v = v.astype("float16") 63 | v = v - 8.0 64 | n_dim = n_dim * 2 65 | v = v.reshape(k_dim, n_dim//8, 2, 4) 66 | v = np.transpose(v, (0, 1, 3, 2)) 67 | v = 
v.reshape(k_dim, n_dim) 68 | #v = np.transpose(v, (1, 0)) 69 | print(f"new shape: {v.shape}") 70 | v = np.ascontiguousarray(v) 71 | if "scales" in k: 72 | # (k//128,n) 73 | #v = np.transpose(v, (1, 0)) 74 | print(f"new shape: {v.shape}") 75 | v = np.ascontiguousarray(v) 76 | 77 | v.tofile(os.path.join(output_path, tname)) 78 | 79 | 80 | if __name__ == '__main__': 81 | main() 82 | -------------------------------------------------------------------------------- /scripts/convert_qwen2_awq_weight.py: -------------------------------------------------------------------------------- 1 | from ml_dtypes import bfloat16 2 | from safetensors.numpy import load_file 3 | import numpy as np 4 | import argparse 5 | import os 6 | import json 7 | import re 8 | 9 | 10 | def change_weight_layout(t): 11 | return t 12 | n, k = t.shape 13 | t = t.reshape(n//16,16,k) 14 | t = np.transpose(t, (0, 2, 1)) 15 | t = np.ascontiguousarray(t) 16 | return t 17 | 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument( 23 | "--input_model_path", 24 | required=True, 25 | help="Location of Qwen Git repo", 26 | ) 27 | parser.add_argument( 28 | "--output_dir", 29 | required=True, 30 | help="Location to write converted model", 31 | ) 32 | 33 | args = parser.parse_args() 34 | 35 | input_path = args.input_model_path 36 | output_path = args.output_dir 37 | os.makedirs(output_path, exist_ok=True) 38 | 39 | index_json_path = os.path.join(input_path, "model.safetensors.index.json") 40 | 41 | safetensor_map = {} 42 | 43 | if os.path.exists(index_json_path): 44 | weight_config_f = open(index_json_path) 45 | weight_config = json.load(weight_config_f) 46 | 47 | weight_map = weight_config["weight_map"] 48 | st_set = set() 49 | 50 | for k, wst in weight_map.items(): 51 | st_set.add(wst) 52 | 53 | for st_name in st_set: 54 | st = load_file(os.path.join(input_path, st_name)) 55 | for k, t in st.items(): 56 | safetensor_map[k] = t 57 | else: 58 | # only one file 59 | safetensor_map = load_file(os.path.join(input_path, "model.safetensors")) 60 | 61 | 62 | for k, t in safetensor_map.items(): 63 | print(k, t.shape, t.dtype) 64 | if "qweight" in k: 65 | # (k, n//8) -> (k, n//2) 66 | t = t.view("uint8") 67 | k_dim, n_dim = t.shape 68 | print(f"qweight k:{k_dim}, n:{n_dim}") 69 | t = t.reshape(k_dim, n_dim, 1) 70 | t = np.repeat(t, 2, axis=-1) 71 | t[..., 0] = t[..., 0] & 0xf 72 | t[..., 1] = (t[..., 1] >> 4) & 0xf 73 | n_dim = n_dim * 2 74 | t = t.reshape(k_dim, n_dim//8, 2, 4) 75 | t = np.transpose(t, (0, 1, 3, 2)) 76 | # transpose to (k, n) 77 | #v = np.transpose(v, (1, 0)) 78 | t = t.reshape(k_dim//16, 16, n_dim) 79 | t = np.transpose(t, (0, 2, 1)) 80 | print(f"new shape: {t.shape}") 81 | d1 = t.size // 512 82 | t = t.reshape(d1, 4, 64, 2) 83 | t = np.transpose(t, (0, 2, 1, 3)) 84 | t = (t + 8)&0xf 85 | t[..., 0] = t[..., 0] | (t[...,1] << 4) 86 | t = np.ascontiguousarray(t[..., 0]) 87 | print(f"qweight output shape {t.shape}, dtype: {t.dtype}") 88 | if "qzeros" in k: 89 | # (k//128,n//8) -> (k//128, n//2) 90 | t = t.view("uint8") 91 | k_dim, n_dim = t.shape 92 | t = t.reshape(k_dim, n_dim, 1) 93 | t = np.repeat(t, 2, axis=-1) 94 | t[..., 0] = t[..., 0] & 0xf 95 | t[..., 1] = (t[..., 1] >> 4) & 0xf 96 | t = t.astype("float16") 97 | t = t - 8.0 98 | n_dim = n_dim * 2 99 | t = t.reshape(k_dim, n_dim//8, 2, 4) 100 | t = np.transpose(t, (0, 1, 3, 2)) 101 | t = t.reshape(k_dim, n_dim) 102 | #v = np.transpose(v, (1, 0)) 103 | print(f"qzeros new shape: {t.shape}") 104 | t = np.ascontiguousarray(t) 105 | if "scales" 
in k: 106 | # (k//128,n) 107 | #v = np.transpose(v, (1, 0)) 108 | print(f"new qscale shape: {t.shape}") 109 | t = np.ascontiguousarray(t) 110 | 111 | 112 | if k == "lm_head.weight": 113 | t = change_weight_layout(t) 114 | t.tofile(os.path.join(output_path, k + ".bin")) 115 | elif k == "model.norm.weight": 116 | t.tofile(os.path.join(output_path, k + ".bin")) 117 | elif k == "model.embed_tokens.weight": 118 | t.tofile(os.path.join(output_path, k + ".bin")) 119 | else: 120 | m = re.match("model.*layernorm.*", k) 121 | if m: 122 | t.tofile(os.path.join(output_path, k + ".bin")) 123 | continue 124 | m = re.match(".*bias", k) 125 | if m: 126 | t = t.astype("float32") 127 | t.tofile(os.path.join(output_path, k + ".bin")) 128 | continue 129 | 130 | t = change_weight_layout(t) 131 | t.tofile(os.path.join(output_path, k + ".bin")) 132 | 133 | 134 | if __name__ == '__main__': 135 | main() 136 | -------------------------------------------------------------------------------- /scripts/convert_qwen2_weight.py: -------------------------------------------------------------------------------- 1 | from ml_dtypes import bfloat16 2 | from safetensors.numpy import load_file 3 | import numpy as np 4 | import argparse 5 | import os 6 | import json 7 | import re 8 | 9 | 10 | def change_weight_layout(t): 11 | return t 12 | n, k = t.shape 13 | t = t.reshape(n//16,16,k) 14 | t = np.transpose(t, (0, 2, 1)) 15 | t = np.ascontiguousarray(t) 16 | return t 17 | 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument( 23 | "--input_model_path", 24 | required=True, 25 | help="Location of Qwen Git repo", 26 | ) 27 | parser.add_argument( 28 | "--output_dir", 29 | required=True, 30 | help="Location to write converted model", 31 | ) 32 | 33 | args = parser.parse_args() 34 | 35 | input_path = args.input_model_path 36 | output_path = args.output_dir 37 | os.makedirs(output_path, exist_ok=True) 38 | 39 | index_json_path = os.path.join(input_path, "model.safetensors.index.json") 40 | 41 | safetensor_map = {} 42 | 43 | if os.path.exists(index_json_path): 44 | weight_config_f = open(index_json_path) 45 | weight_config = json.load(weight_config_f) 46 | 47 | weight_map = weight_config["weight_map"] 48 | st_set = set() 49 | 50 | for k, wst in weight_map.items(): 51 | st_set.add(wst) 52 | 53 | for st_name in st_set: 54 | st = load_file(os.path.join(input_path, st_name)) 55 | for k, t in st.items(): 56 | safetensor_map[k] = t 57 | else: 58 | # only one file 59 | safetensor_map = load_file(os.path.join(input_path, "model.safetensors")) 60 | 61 | 62 | for k, t in safetensor_map.items(): 63 | print(k, t.shape, t.dtype) 64 | if k == "lm_head.weight": 65 | t = change_weight_layout(t) 66 | t.tofile(os.path.join(output_path, k + ".bin")) 67 | elif k == "model.norm.weight": 68 | t.tofile(os.path.join(output_path, k + ".bin")) 69 | elif k == "model.embed_tokens.weight": 70 | t.tofile(os.path.join(output_path, k + ".bin")) 71 | else: 72 | m = re.match("model.*layernorm.*", k) 73 | if m: 74 | t.tofile(os.path.join(output_path, k + ".bin")) 75 | continue 76 | m = re.match(".*bias", k) 77 | if m: 78 | t = t.astype("float32") 79 | t.tofile(os.path.join(output_path, k + ".bin")) 80 | continue 81 | 82 | t = change_weight_layout(t) 83 | t.tofile(os.path.join(output_path, k + ".bin")) 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /scripts/example_chat_llama2_13B_awq_4bit_orangepi.sh: 
-------------------------------------------------------------------------------- 1 | PROMPT_TEMPLATE=./prompts/chat_example.txt 2 | PROMPT_FILE=$(mktemp -t llamacpp_prompt.XXXXXXX.txt) 3 | 4 | DATE_TIME=$(date +%H:%M) 5 | DATE_YEAR=$(date +%Y) 6 | USER_NAME="user" 7 | AI_NAME="orange_pi" 8 | 9 | sed -e "s/\[\[USER_NAME\]\]/$USER_NAME/g" \ 10 | -e "s/\[\[AI_NAME\]\]/$AI_NAME/g" \ 11 | -e "s/\[\[DATE_TIME\]\]/$DATE_TIME/g" \ 12 | -e "s/\[\[DATE_YEAR\]\]/$DATE_YEAR/g" \ 13 | $PROMPT_TEMPLATE > $PROMPT_FILE 14 | 15 | ./build/src/llama2_main \ 16 | --config=llama/llama2/13B/params.json \ 17 | --tokenizer=llama/llama2/tokenizer.model \ 18 | --weight=/data/Llama-2-13B-AWQ/llama2_13B_awq_4bit \ 19 | --quant_method=awq_4bit \ 20 | --quant_group_size=128 \ 21 | --device_type=npu \ 22 | --max_seq_len=2048 \ 23 | --log_level=info \ 24 | --rope_is_neox_style=true \ 25 | --reverse_promt="${USER_NAME}:" \ 26 | --i \ 27 | --prompt_file=$PROMPT_FILE 28 | -------------------------------------------------------------------------------- /scripts/example_chat_llama2_7B_awq_4bit_orangepi.sh: -------------------------------------------------------------------------------- 1 | PROMPT_TEMPLATE=./prompts/chat_example.txt 2 | PROMPT_FILE=$(mktemp -t llamacpp_prompt.XXXXXXX.txt) 3 | 4 | DATE_TIME=$(date +%H:%M) 5 | DATE_YEAR=$(date +%Y) 6 | USER_NAME="user" 7 | AI_NAME="orange_pi" 8 | 9 | sed -e "s/\[\[USER_NAME\]\]/$USER_NAME/g" \ 10 | -e "s/\[\[AI_NAME\]\]/$AI_NAME/g" \ 11 | -e "s/\[\[DATE_TIME\]\]/$DATE_TIME/g" \ 12 | -e "s/\[\[DATE_YEAR\]\]/$DATE_YEAR/g" \ 13 | $PROMPT_TEMPLATE > $PROMPT_FILE 14 | 15 | ./build/src/llama2_main \ 16 | --config=llama/llama2/7B/params.json \ 17 | --tokenizer=llama/llama2/tokenizer.model \ 18 | --weight=/data/llama2_7b_awq/llama2_7b_awq_4bit/ \ 19 | --quant_method=awq_4bit \ 20 | --quant_group_size=128 \ 21 | --device_type=npu \ 22 | --max_seq_len=2048 \ 23 | --log_level=info \ 24 | --rope_is_neox_style=true \ 25 | --reverse_promt="${USER_NAME}:" \ 26 | --i \ 27 | --prompt_file=$PROMPT_FILE 28 | -------------------------------------------------------------------------------- /scripts/example_chat_llama2_7B_fp16_orangepi.sh: -------------------------------------------------------------------------------- 1 | PROMPT_TEMPLATE=./prompts/chat_example.txt 2 | PROMPT_FILE=$(mktemp -t llamacpp_prompt.XXXXXXX.txt) 3 | 4 | DATE_TIME=$(date +%H:%M) 5 | DATE_YEAR=$(date +%Y) 6 | USER_NAME="user" 7 | AI_NAME="orange_pi" 8 | 9 | sed -e "s/\[\[USER_NAME\]\]/$USER_NAME/g" \ 10 | -e "s/\[\[AI_NAME\]\]/$AI_NAME/g" \ 11 | -e "s/\[\[DATE_TIME\]\]/$DATE_TIME/g" \ 12 | -e "s/\[\[DATE_YEAR\]\]/$DATE_YEAR/g" \ 13 | $PROMPT_TEMPLATE > $PROMPT_FILE 14 | 15 | ./build/src/llama2_main \ 16 | --config=llama/llama2/7B/params.json \ 17 | --tokenizer=llama/llama2/tokenizer.model \ 18 | --weight=/data/llama2/7B/model_output \ 19 | --device_type=npu \ 20 | --max_seq_len=2048 \ 21 | --log_level=info \ 22 | --reverse_promt="${USER_NAME}:" \ 23 | --i \ 24 | --prompt_file=$PROMPT_FILE 25 | -------------------------------------------------------------------------------- /scripts/example_chat_qwen2.5_3B_bf16_orangepi.sh: -------------------------------------------------------------------------------- 1 | ./build/src/llama2_main \ 2 | --config=/ssd/models/Qwen2.5-3B-Instruct/config.json \ 3 | --tokenizer=/ssd/models/Qwen2.5-3B-Instruct/ \ 4 | --weight=/ssd/models/Qwen2.5-3B-Instruct_converted \ 5 | --model_type=qwen2 \ 6 | --device_type=npu \ 7 | --max_seq_len=2048 \ 8 | --max_gen_token=1024 \ 9 | --temperature=0.6 \ 10 | 
--debug_print=false \ 11 | --log_level=info \ 12 | --rope_is_neox_style=true \ 13 | --i \ 14 | --prompt=" " 15 | -------------------------------------------------------------------------------- /scripts/example_orangepi_debug.sh: -------------------------------------------------------------------------------- 1 | #gdb --args 2 | ./build/src/llama2_main \ 3 | --config=/data/llama2/7B/params.json \ 4 | --tokenizer=/data/llama2/tokenizer.model \ 5 | --weight=/data/llama2/7B/model_output \ 6 | --device_type=npu \ 7 | --max_seq_len=128 \ 8 | --log_level=debug \ 9 | --debug_print=true \ 10 | --prompt="Translate English to French: 11 | sea otter => loutre de mer 12 | peppermint => menthe poivrée 13 | plush girafe => girafe peluche 14 | cheese =>" -------------------------------------------------------------------------------- /scripts/example_orangepi_msprof.sh: -------------------------------------------------------------------------------- 1 | #gdb --args 2 | msprof --aic-mode=sample-based --output=./profiling_output --application="./build/src/llama2_main \ 3 | --config=/data/llama2/7B/params.json \ 4 | --tokenizer=/data/llama2/tokenizer.model \ 5 | --weight=/data/llama2/7B/model_output \ 6 | --device_type=npu \ 7 | --prompt=\"Once upon\"" -------------------------------------------------------------------------------- /scripts/example_orangepi_msprof_awq_4bit.sh: -------------------------------------------------------------------------------- 1 | #gdb --args 2 | msprof --aic-mode=sample-based --output=./profiling_output --application="./build/src/llama2_main \ 3 | --config=/data/llama2/7B/params.json \ 4 | --tokenizer=/data/llama2/tokenizer.model \ 5 | --weight=/data/llama2_7b_awq/llama2_7b_awq_4bit/ \ 6 | --device_type=npu \ 7 | --quant_method=awq_4bit \ 8 | --quant_group_size=128 \ 9 | --max_seq_len=128 \z 10 | --prompt=\"Once upon\"" -------------------------------------------------------------------------------- /scripts/example_orangepi_profiling.sh: -------------------------------------------------------------------------------- 1 | #gdb --args 2 | ./build/src/llama2_main \ 3 | --log_level=info \ 4 | --profiling_output=llama_full_prof.json \ 5 | --config=/data/llama2/7B/params.json \ 6 | --tokenizer=/data/llama2/tokenizer.model \ 7 | --weight=/data/llama2/7B/model_output \ 8 | --device_type=npu \ 9 | --max_seq_len=128 \ 10 | --prompt="Once upon" -------------------------------------------------------------------------------- /scripts/example_orangepi_profiling_awq_4bit.sh: -------------------------------------------------------------------------------- 1 | #gdb --args 2 | ./build/src/llama2_main \ 3 | --log_level=info \ 4 | --profiling_output=llama_full_prof.json \ 5 | --config=/data/llama2/7B/params.json \ 6 | --tokenizer=/data/llama2/tokenizer.model \ 7 | --weight=/data/llama2_7b_awq/llama2_7b_awq_4bit/ \ 8 | --device_type=npu \ 9 | --quant_method=awq_4bit \ 10 | --quant_group_size=128 \ 11 | --max_seq_len=128 \ 12 | --prompt="Once upon" -------------------------------------------------------------------------------- /scripts/example_text_completion_deepseek_r1_qwen2.5_1.5B_bf16_orangepi.sh: -------------------------------------------------------------------------------- 1 | ./build/src/llama2_main \ 2 | --config=/ssd/models/DeepSeek-R1-Distill-Qwen-1.5B/config.json \ 3 | --tokenizer=/ssd/models/DeepSeek-R1-Distill-Qwen-1.5B/ \ 4 | --weight=/ssd/models/DeepSeek-R1-Distill-Qwen-1.5B_converted \ 5 | --model_type=qwen2 \ 6 | --device_type=npu \ 7 | --max_seq_len=8192 \ 8 | 
--max_gen_token=8192 \ 9 | --temperature=0.6 \ 10 | --debug_print=false \ 11 | --log_level=info \ 12 | --rope_is_neox_style=true \ 13 | --prompt="深度优化指南请求:通用矩阵乘法(GEMM)的全栈性能工程实践 14 | 15 | 请以高性能计算专家的身份,系统论述GEMM优化的完整技术体系。要求从晶体管层面到算法层进行跨抽象层分析,包含以下维度: 16 | 17 | 一、硬件感知优化基础(展开以下每个子项) 18 | 19 | 现代CPU内存层级解剖 20 | 21 | 详述如何通过cache blocking适应L1/L2/L3缓存行 22 | 23 | 示例:针对不同缓存容量(如8MB L3)的分块策略计算公式 24 | 25 | 数据预取模式设计(软件预取指令的最佳插入距离) 26 | 27 | SIMD指令工程实践 28 | 29 | AVX-512与ARM SVE的寄存器压力对比 30 | 31 | 汇编级循环展开策略(展示8x8分块的双缓冲汇编模板) 32 | 33 | FMA指令流水线冒险规避技巧 34 | 35 | GPU架构深度适配 36 | 37 | CUDA warp-level同步优化(共享内存bank conflict量化分析) 38 | 39 | 全局内存合并访问模式设计(展示2D tile的访存对齐公式) 40 | 41 | Tensor Core编程范式(WMMA API使用陷阱与性能调优日志) 42 | 43 | 二、算法革新路线(需数学推导) 44 | 45 | 复杂分治策略 46 | 47 | Strassen算法在实践中的递归终止条件选择(给出浮点误差传播模型) 48 | 49 | Winograd变换的数值稳定性改进方案 50 | 51 | 基于分块秩的近似算法误差界证明 52 | 53 | 稀疏化与量化 54 | 55 | 结构化稀疏的硬件友好模式设计(NVIDIA A100稀疏特性适配) 56 | 57 | 混合精度训练中的动态缩放因子推导 58 | 59 | 低秩近似与GEMM的耦合优化(给出SVD截断误差分析) 60 | 61 | 三、编译工程化进阶 62 | 63 | LLVM中间表示调优 64 | 65 | Polly循环优化编译指示实战 66 | 67 | MLIR GEMM方言生成技术路线 68 | 69 | 自动向量化失败的补救模式 70 | 71 | 自动调参系统设计 72 | 73 | 遗传算法参数空间剪枝策略 74 | 75 | 贝叶斯优化中的协方差矩阵自适应 76 | 77 | 多目标优化Pareto前沿的筛选标准 78 | 79 | 四、异构计算协同 80 | 81 | 多芯片负载均衡 82 | 83 | CPU-GPU流水线深度分析(计算/通信重叠的数学模型) 84 | 85 | 基于RDMA的跨设备零拷贝实现 86 | 87 | 异构内存一致性模型解决方案 88 | 89 | 五、验证方法论 90 | 91 | Roofline模型深度应用 92 | 93 | 实测不同架构的运算强度阈值 94 | 95 | 性能偏离度诊断流程图 96 | 97 | 瓶颈定位的热力图分析法 98 | 99 | 技术写作要求: 100 | 101 | 每个优化点需提供理论依据(附复杂度公式推导) 102 | 103 | 关键路径给出CUDA/C++代码片段及编译器内联汇编示例 104 | 105 | 包含主流硬件实测数据(如A100 vs Xeon Platinum对比表格) 106 | 107 | 讨论商业化实现差异(对比oneDNN vs cuBLAS设计哲学) 108 | 109 | 最后给出优化决策树(包含分支判断条件) 110 | 111 | 请采用学术论文写作规范,分章节编号至三级标题,使用LaTeX公式描述关键技术指标,总输出保持工程技术文档的严谨性同时具备可操作性。" 112 | -------------------------------------------------------------------------------- /scripts/example_text_completion_deepseek_r1_qwen2.5_14B_bf16_orangepi.sh: -------------------------------------------------------------------------------- 1 | ./build/src/llama2_main \ 2 | --config=/ssd/models/DeepSeek-R1-Distill-Qwen-14B/config.json \ 3 | --tokenizer=/ssd/models/DeepSeek-R1-Distill-Qwen-14B/ \ 4 | --weight=/ssd/models/DeepSeek-R1-Distill-Qwen-14B_converted \ 5 | --model_type=qwen2 \ 6 | --device_type=npu \ 7 | --max_seq_len=8192 \ 8 | --max_gen_token=8192 \ 9 | --temperature=0.6 \ 10 | --debug_print=false \ 11 | --log_level=info \ 12 | --rope_is_neox_style=true \ 13 | --prompt="深度优化指南请求:通用矩阵乘法(GEMM)的全栈性能工程实践 14 | 15 | 请以高性能计算专家的身份,系统论述GEMM优化的完整技术体系。要求从晶体管层面到算法层进行跨抽象层分析,包含以下维度: 16 | 17 | 一、硬件感知优化基础(展开以下每个子项) 18 | 19 | 现代CPU内存层级解剖 20 | 21 | 详述如何通过cache blocking适应L1/L2/L3缓存行 22 | 23 | 示例:针对不同缓存容量(如8MB L3)的分块策略计算公式 24 | 25 | 数据预取模式设计(软件预取指令的最佳插入距离) 26 | 27 | SIMD指令工程实践 28 | 29 | AVX-512与ARM SVE的寄存器压力对比 30 | 31 | 汇编级循环展开策略(展示8x8分块的双缓冲汇编模板) 32 | 33 | FMA指令流水线冒险规避技巧 34 | 35 | GPU架构深度适配 36 | 37 | CUDA warp-level同步优化(共享内存bank conflict量化分析) 38 | 39 | 全局内存合并访问模式设计(展示2D tile的访存对齐公式) 40 | 41 | Tensor Core编程范式(WMMA API使用陷阱与性能调优日志) 42 | 43 | 二、算法革新路线(需数学推导) 44 | 45 | 复杂分治策略 46 | 47 | Strassen算法在实践中的递归终止条件选择(给出浮点误差传播模型) 48 | 49 | Winograd变换的数值稳定性改进方案 50 | 51 | 基于分块秩的近似算法误差界证明 52 | 53 | 稀疏化与量化 54 | 55 | 结构化稀疏的硬件友好模式设计(NVIDIA A100稀疏特性适配) 56 | 57 | 混合精度训练中的动态缩放因子推导 58 | 59 | 低秩近似与GEMM的耦合优化(给出SVD截断误差分析) 60 | 61 | 三、编译工程化进阶 62 | 63 | LLVM中间表示调优 64 | 65 | Polly循环优化编译指示实战 66 | 67 | MLIR GEMM方言生成技术路线 68 | 69 | 自动向量化失败的补救模式 70 | 71 | 自动调参系统设计 72 | 73 | 遗传算法参数空间剪枝策略 74 | 75 | 贝叶斯优化中的协方差矩阵自适应 76 | 77 | 多目标优化Pareto前沿的筛选标准 78 | 79 | 四、异构计算协同 80 | 81 | 多芯片负载均衡 82 | 83 | CPU-GPU流水线深度分析(计算/通信重叠的数学模型) 84 | 85 | 基于RDMA的跨设备零拷贝实现 86 
| 87 | 异构内存一致性模型解决方案 88 | 89 | 五、验证方法论 90 | 91 | Roofline模型深度应用 92 | 93 | 实测不同架构的运算强度阈值 94 | 95 | 性能偏离度诊断流程图 96 | 97 | 瓶颈定位的热力图分析法 98 | 99 | 技术写作要求: 100 | 101 | 每个优化点需提供理论依据(附复杂度公式推导) 102 | 103 | 关键路径给出CUDA/C++代码片段及编译器内联汇编示例 104 | 105 | 包含主流硬件实测数据(如A100 vs Xeon Platinum对比表格) 106 | 107 | 讨论商业化实现差异(对比oneDNN vs cuBLAS设计哲学) 108 | 109 | 最后给出优化决策树(包含分支判断条件) 110 | 111 | 请采用学术论文写作规范,分章节编号至三级标题,使用LaTeX公式描述关键技术指标,总输出保持工程技术文档的严谨性同时具备可操作性。" 112 | -------------------------------------------------------------------------------- /scripts/example_text_completion_deepseek_r1_qwen2.5_7B_bf16_orangepi.sh: -------------------------------------------------------------------------------- 1 | ./build/src/llama2_main \ 2 | --config=/ssd/models/DeepSeek-R1-Distill-Qwen-7B/config.json \ 3 | --tokenizer=/ssd/models/DeepSeek-R1-Distill-Qwen-7B/ \ 4 | --weight=/ssd/models/DeepSeek-R1-Distill-Qwen-7B_converted \ 5 | --model_type=qwen2 \ 6 | --device_type=npu \ 7 | --max_seq_len=8192 \ 8 | --max_gen_token=8192 \ 9 | --temperature=0.6 \ 10 | --debug_print=false \ 11 | --log_level=info \ 12 | --rope_is_neox_style=true \ 13 | --prompt="深度优化指南请求:通用矩阵乘法(GEMM)的全栈性能工程实践 14 | 15 | 请以高性能计算专家的身份,系统论述GEMM优化的完整技术体系。要求从晶体管层面到算法层进行跨抽象层分析,包含以下维度: 16 | 17 | 一、硬件感知优化基础(展开以下每个子项) 18 | 19 | 现代CPU内存层级解剖 20 | 21 | 详述如何通过cache blocking适应L1/L2/L3缓存行 22 | 23 | 示例:针对不同缓存容量(如8MB L3)的分块策略计算公式 24 | 25 | 数据预取模式设计(软件预取指令的最佳插入距离) 26 | 27 | SIMD指令工程实践 28 | 29 | AVX-512与ARM SVE的寄存器压力对比 30 | 31 | 汇编级循环展开策略(展示8x8分块的双缓冲汇编模板) 32 | 33 | FMA指令流水线冒险规避技巧 34 | 35 | GPU架构深度适配 36 | 37 | CUDA warp-level同步优化(共享内存bank conflict量化分析) 38 | 39 | 全局内存合并访问模式设计(展示2D tile的访存对齐公式) 40 | 41 | Tensor Core编程范式(WMMA API使用陷阱与性能调优日志) 42 | 43 | 二、算法革新路线(需数学推导) 44 | 45 | 复杂分治策略 46 | 47 | Strassen算法在实践中的递归终止条件选择(给出浮点误差传播模型) 48 | 49 | Winograd变换的数值稳定性改进方案 50 | 51 | 基于分块秩的近似算法误差界证明 52 | 53 | 稀疏化与量化 54 | 55 | 结构化稀疏的硬件友好模式设计(NVIDIA A100稀疏特性适配) 56 | 57 | 混合精度训练中的动态缩放因子推导 58 | 59 | 低秩近似与GEMM的耦合优化(给出SVD截断误差分析) 60 | 61 | 三、编译工程化进阶 62 | 63 | LLVM中间表示调优 64 | 65 | Polly循环优化编译指示实战 66 | 67 | MLIR GEMM方言生成技术路线 68 | 69 | 自动向量化失败的补救模式 70 | 71 | 自动调参系统设计 72 | 73 | 遗传算法参数空间剪枝策略 74 | 75 | 贝叶斯优化中的协方差矩阵自适应 76 | 77 | 多目标优化Pareto前沿的筛选标准 78 | 79 | 四、异构计算协同 80 | 81 | 多芯片负载均衡 82 | 83 | CPU-GPU流水线深度分析(计算/通信重叠的数学模型) 84 | 85 | 基于RDMA的跨设备零拷贝实现 86 | 87 | 异构内存一致性模型解决方案 88 | 89 | 五、验证方法论 90 | 91 | Roofline模型深度应用 92 | 93 | 实测不同架构的运算强度阈值 94 | 95 | 性能偏离度诊断流程图 96 | 97 | 瓶颈定位的热力图分析法 98 | 99 | 技术写作要求: 100 | 101 | 每个优化点需提供理论依据(附复杂度公式推导) 102 | 103 | 关键路径给出CUDA/C++代码片段及编译器内联汇编示例 104 | 105 | 包含主流硬件实测数据(如A100 vs Xeon Platinum对比表格) 106 | 107 | 讨论商业化实现差异(对比oneDNN vs cuBLAS设计哲学) 108 | 109 | 最后给出优化决策树(包含分支判断条件) 110 | 111 | 请采用学术论文写作规范,分章节编号至三级标题,使用LaTeX公式描述关键技术指标,总输出保持工程技术文档的严谨性同时具备可操作性。" 112 | -------------------------------------------------------------------------------- /scripts/example_text_completion_llama2_13B_awq_4bit_orangepi.sh: -------------------------------------------------------------------------------- 1 | #gdb --args \ 2 | ./build/src/llama2_main \ 3 | --config=llama/llama2/13B/params.json \ 4 | --tokenizer=llama/llama2/tokenizer.model \ 5 | --weight=/data/Llama-2-13B-AWQ/llama2_13B_awq_4bit \ 6 | --device_type=npu \ 7 | --max_seq_len=2048 \ 8 | --max_gen_token=256 \ 9 | --log_level=info \ 10 | --debug_print=false \ 11 | --quant_method=awq_4bit \ 12 | --quant_group_size=128 \ 13 | --rope_is_neox_style=true \ 14 | --prompt="Translate English to French: 15 | sea otter => loutre de mer 16 | peppermint => menthe poivrée 17 | plush girafe => girafe peluche 18 | cheese =>" 19 | 
-------------------------------------------------------------------------------- /scripts/example_text_completion_llama2_13B_awq_4bit_orangepi_2.sh: -------------------------------------------------------------------------------- 1 | #gdb --args \ 2 | ./build/src/llama2_main \ 3 | --config=llama/llama2/13B/params.json \ 4 | --tokenizer=llama/llama2/tokenizer.model \ 5 | --weight=/data/Llama-2-13B-AWQ/llama2_13B_awq_4bit \ 6 | --device_type=npu \ 7 | --max_seq_len=2048 \ 8 | --max_gen_token=256 \ 9 | --log_level=info \ 10 | --debug_print=false \ 11 | --quant_method=awq_4bit \ 12 | --quant_group_size=128 \ 13 | --rope_is_neox_style=true \ 14 | --prompt="You are a virtual tour guide from 1901. You have tourists visiting Eiffel Tower. Describe Eiffel Tower to your audience. Begin with 15 | 1. Why it was built 16 | 2. Then by how long it took them to build 17 | 3. Where were the materials sourced to build 18 | 4. Number of people it took to build 19 | 5. End it with the number of people visiting the Eiffel tour annually in the 1900's, the amount of time it completes a full tour and why so many people visit this place each year. 20 | Make your tour funny by including 1 or 2 funny jokes at the end of the tour." 21 | -------------------------------------------------------------------------------- /scripts/example_text_completion_llama2_7B_awq_4bit_orangepi.sh: -------------------------------------------------------------------------------- 1 | #gdb --args \ 2 | ./build/src/llama2_main \ 3 | --config=llama/llama2/7B/params.json \ 4 | --tokenizer=llama/llama2/tokenizer.model \ 5 | --weight=/data/llama2_7b_awq/llama2_7b_awq_4bit/ \ 6 | --device_type=npu \ 7 | --max_seq_len=2048 \ 8 | --max_gen_token=256 \ 9 | --log_level=info \ 10 | --debug_print=false \ 11 | --quant_method=awq_4bit \ 12 | --quant_group_size=128 \ 13 | --rope_is_neox_style=true \ 14 | --prompt="Translate English to French: 15 | sea otter => loutre de mer 16 | peppermint => menthe poivrée 17 | plush girafe => girafe peluche 18 | cheese =>" 19 | -------------------------------------------------------------------------------- /scripts/example_text_completion_llama2_7B_awq_4bit_orangepi_2.sh: -------------------------------------------------------------------------------- 1 | #gdb --args \ 2 | ./build/src/llama2_main \ 3 | --config=/data/llama2/7B/params.json \ 4 | --tokenizer=/data/llama2_7b_awq/tokenizer.model \ 5 | --weight=/data/llama2_7b_awq/llama2_7b_awq_4bit/ \ 6 | --device_type=npu \ 7 | --max_seq_len=2048 \ 8 | --max_gen_token=256 \ 9 | --log_level=info \ 10 | --quant_method=awq_4bit \ 11 | --quant_group_size=128 \ 12 | --prompt="You are a virtual tour guide from 1901. You have tourists visiting Eiffel Tower. Describe Eiffel Tower to your audience. Begin with 13 | 1. Why it was built 14 | 2. Then by how long it took them to build 15 | 3. Where were the materials sourced to build 16 | 4. Number of people it took to build 17 | 5. End it with the number of people visiting the Eiffel tour annually in the 1900's, the amount of time it completes a full tour and why so many people visit this place each year. 18 | Make your tour funny by including 1 or 2 funny jokes at the end of the tour." 
19 | -------------------------------------------------------------------------------- /scripts/example_text_completion_llama2_7B_fp16_orangepi.sh: -------------------------------------------------------------------------------- 1 | #gdb --args \ 2 | ./build/src/llama2_main \ 3 | --config=llama/llama2/7B/params.json \ 4 | --tokenizer=llama/llama2/tokenizer.model \ 5 | --weight=/data/llama2/7B/model_output \ 6 | --device_type=npu \ 7 | --max_seq_len=2048 \ 8 | --log_level=info \ 9 | --prompt="Translate English to French: 10 | sea otter => loutre de mer 11 | peppermint => menthe poivrée 12 | plush girafe => girafe peluche 13 | cheese =>" 14 | -------------------------------------------------------------------------------- /scripts/example_text_completion_llama2_7B_fp16_orangepi_2.sh: -------------------------------------------------------------------------------- 1 | #gdb --args \ 2 | ./build/src/llama2_main \ 3 | --config=/data/llama2/7B/params.json \ 4 | --tokenizer=/data/llama2/tokenizer.model \ 5 | --weight=/data/llama2/7B/model_output \ 6 | --device_type=npu \ 7 | --max_seq_len=2048 \ 8 | --max_gen_token=256 \ 9 | --log_level=info \ 10 | --prompt="You are a virtual tour guide from 1901. You have tourists visiting Eiffel Tower. Describe Eiffel Tower to your audience. Begin with 11 | 1. Why it was built 12 | 2. Then by how long it took them to build 13 | 3. Where were the materials sourced to build 14 | 4. Number of people it took to build 15 | 5. End it with the number of people visiting the Eiffel tour annually in the 1900's, the amount of time it completes a full tour and why so many people visit this place each year. 16 | Make your tour funny by including 1 or 2 funny jokes at the end of the tour." 17 | -------------------------------------------------------------------------------- /scripts/example_text_completion_qwen2.5_14B_awq_orangepi.sh: -------------------------------------------------------------------------------- 1 | ./build/src/llama2_main \ 2 | --config=/ssd/models/Qwen2.5-14B-Instruct-AWQ/config.json \ 3 | --tokenizer=/ssd/models/Qwen2.5-14B-Instruct-AWQ/ \ 4 | --weight=/ssd/models/Qwen2.5-14B-Instruct-AWQ_converted \ 5 | --model_type=qwen2 \ 6 | --device_type=npu \ 7 | --max_seq_len=8192 \ 8 | --max_gen_token=256 \ 9 | --temperature=0.6 \ 10 | --debug_print=false \ 11 | --log_level=info \ 12 | --rope_is_neox_style=true \ 13 | --prompt="深度优化指南请求:通用矩阵乘法(GEMM)的全栈性能工程实践 14 | 15 | 请以高性能计算专家的身份,系统论述GEMM优化的完整技术体系。要求从晶体管层面到算法层进行跨抽象层分析,包含以下维度: 16 | 17 | 一、硬件感知优化基础(展开以下每个子项) 18 | 19 | 现代CPU内存层级解剖 20 | 21 | 详述如何通过cache blocking适应L1/L2/L3缓存行 22 | 23 | 示例:针对不同缓存容量(如8MB L3)的分块策略计算公式 24 | 25 | 数据预取模式设计(软件预取指令的最佳插入距离) 26 | 27 | SIMD指令工程实践 28 | 29 | AVX-512与ARM SVE的寄存器压力对比 30 | 31 | 汇编级循环展开策略(展示8x8分块的双缓冲汇编模板) 32 | 33 | FMA指令流水线冒险规避技巧 34 | 35 | GPU架构深度适配 36 | 37 | CUDA warp-level同步优化(共享内存bank conflict量化分析) 38 | 39 | 全局内存合并访问模式设计(展示2D tile的访存对齐公式) 40 | 41 | Tensor Core编程范式(WMMA API使用陷阱与性能调优日志) 42 | 43 | 二、算法革新路线(需数学推导) 44 | 45 | 复杂分治策略 46 | 47 | Strassen算法在实践中的递归终止条件选择(给出浮点误差传播模型) 48 | 49 | Winograd变换的数值稳定性改进方案 50 | 51 | 基于分块秩的近似算法误差界证明 52 | 53 | 稀疏化与量化 54 | 55 | 结构化稀疏的硬件友好模式设计(NVIDIA A100稀疏特性适配) 56 | 57 | 混合精度训练中的动态缩放因子推导 58 | 59 | 低秩近似与GEMM的耦合优化(给出SVD截断误差分析) 60 | 61 | 三、编译工程化进阶 62 | 63 | LLVM中间表示调优 64 | 65 | Polly循环优化编译指示实战 66 | 67 | MLIR GEMM方言生成技术路线 68 | 69 | 自动向量化失败的补救模式 70 | 71 | 自动调参系统设计 72 | 73 | 遗传算法参数空间剪枝策略 74 | 75 | 贝叶斯优化中的协方差矩阵自适应 76 | 77 | 多目标优化Pareto前沿的筛选标准 78 | 79 | 四、异构计算协同 80 | 81 | 多芯片负载均衡 82 | 83 | CPU-GPU流水线深度分析(计算/通信重叠的数学模型) 84 | 85 | 基于RDMA的跨设备零拷贝实现 86 | 87 | 
异构内存一致性模型解决方案 88 | 89 | 五、验证方法论 90 | 91 | Roofline模型深度应用 92 | 93 | 实测不同架构的运算强度阈值 94 | 95 | 性能偏离度诊断流程图 96 | 97 | 瓶颈定位的热力图分析法 98 | 99 | 技术写作要求: 100 | 101 | 每个优化点需提供理论依据(附复杂度公式推导) 102 | 103 | 关键路径给出CUDA/C++代码片段及编译器内联汇编示例 104 | 105 | 包含主流硬件实测数据(如A100 vs Xeon Platinum对比表格) 106 | 107 | 讨论商业化实现差异(对比oneDNN vs cuBLAS设计哲学) 108 | 109 | 最后给出优化决策树(包含分支判断条件) 110 | 111 | 请采用学术论文写作规范,分章节编号至三级标题,使用LaTeX公式描述关键技术指标,总输出保持工程技术文档的严谨性同时具备可操作性。" 112 | -------------------------------------------------------------------------------- /scripts/example_text_completion_qwen2.5_32B_awq_orangepi.sh: -------------------------------------------------------------------------------- 1 | ./build/src/llama2_main \ 2 | --config=/ssd/models/Qwen2.5-32B-Instruct-AWQ/config.json \ 3 | --tokenizer=/ssd/models/Qwen2.5-32B-Instruct-AWQ/ \ 4 | --weight=/ssd/models/Qwen2.5-32B-Instruct-AWQ_converted \ 5 | --model_type=qwen2 \ 6 | --device_type=npu \ 7 | --max_seq_len=8192 \ 8 | --max_gen_token=256 \ 9 | --temperature=0.6 \ 10 | --debug_print=false \ 11 | --log_level=info \ 12 | --rope_is_neox_style=true \ 13 | --prompt="深度优化指南请求:通用矩阵乘法(GEMM)的全栈性能工程实践 14 | 15 | 请以高性能计算专家的身份,系统论述GEMM优化的完整技术体系。要求从晶体管层面到算法层进行跨抽象层分析,包含以下维度: 16 | 17 | 一、硬件感知优化基础(展开以下每个子项) 18 | 19 | 现代CPU内存层级解剖 20 | 21 | 详述如何通过cache blocking适应L1/L2/L3缓存行 22 | 23 | 示例:针对不同缓存容量(如8MB L3)的分块策略计算公式 24 | 25 | 数据预取模式设计(软件预取指令的最佳插入距离) 26 | 27 | SIMD指令工程实践 28 | 29 | AVX-512与ARM SVE的寄存器压力对比 30 | 31 | 汇编级循环展开策略(展示8x8分块的双缓冲汇编模板) 32 | 33 | FMA指令流水线冒险规避技巧 34 | 35 | GPU架构深度适配 36 | 37 | CUDA warp-level同步优化(共享内存bank conflict量化分析) 38 | 39 | 全局内存合并访问模式设计(展示2D tile的访存对齐公式) 40 | 41 | Tensor Core编程范式(WMMA API使用陷阱与性能调优日志) 42 | 43 | 二、算法革新路线(需数学推导) 44 | 45 | 复杂分治策略 46 | 47 | Strassen算法在实践中的递归终止条件选择(给出浮点误差传播模型) 48 | 49 | Winograd变换的数值稳定性改进方案 50 | 51 | 基于分块秩的近似算法误差界证明 52 | 53 | 稀疏化与量化 54 | 55 | 结构化稀疏的硬件友好模式设计(NVIDIA A100稀疏特性适配) 56 | 57 | 混合精度训练中的动态缩放因子推导 58 | 59 | 低秩近似与GEMM的耦合优化(给出SVD截断误差分析) 60 | 61 | 三、编译工程化进阶 62 | 63 | LLVM中间表示调优 64 | 65 | Polly循环优化编译指示实战 66 | 67 | MLIR GEMM方言生成技术路线 68 | 69 | 自动向量化失败的补救模式 70 | 71 | 自动调参系统设计 72 | 73 | 遗传算法参数空间剪枝策略 74 | 75 | 贝叶斯优化中的协方差矩阵自适应 76 | 77 | 多目标优化Pareto前沿的筛选标准 78 | 79 | 四、异构计算协同 80 | 81 | 多芯片负载均衡 82 | 83 | CPU-GPU流水线深度分析(计算/通信重叠的数学模型) 84 | 85 | 基于RDMA的跨设备零拷贝实现 86 | 87 | 异构内存一致性模型解决方案 88 | 89 | 五、验证方法论 90 | 91 | Roofline模型深度应用 92 | 93 | 实测不同架构的运算强度阈值 94 | 95 | 性能偏离度诊断流程图 96 | 97 | 瓶颈定位的热力图分析法 98 | 99 | 技术写作要求: 100 | 101 | 每个优化点需提供理论依据(附复杂度公式推导) 102 | 103 | 关键路径给出CUDA/C++代码片段及编译器内联汇编示例 104 | 105 | 包含主流硬件实测数据(如A100 vs Xeon Platinum对比表格) 106 | 107 | 讨论商业化实现差异(对比oneDNN vs cuBLAS设计哲学) 108 | 109 | 最后给出优化决策树(包含分支判断条件) 110 | 111 | 请采用学术论文写作规范,分章节编号至三级标题,使用LaTeX公式描述关键技术指标,总输出保持工程技术文档的严谨性同时具备可操作性。" 112 | -------------------------------------------------------------------------------- /scripts/example_text_completion_qwen2.5_3B_awq_orangepi.sh: -------------------------------------------------------------------------------- 1 | ./build/src/llama2_main \ 2 | --config=/ssd/models/Qwen2.5-3B-Instruct-AWQ/config.json \ 3 | --tokenizer=/ssd/models/Qwen2.5-3B-Instruct-AWQ/ \ 4 | --weight=/ssd/models/Qwen2.5-3B-Instruct-AWQ_converted \ 5 | --model_type=qwen2 \ 6 | --device_type=npu \ 7 | --max_seq_len=8192 \ 8 | --max_gen_token=256 \ 9 | --temperature=0.6 \ 10 | --debug_print=false \ 11 | --log_level=info \ 12 | --rope_is_neox_style=true \ 13 | --prompt="深度优化指南请求:通用矩阵乘法(GEMM)的全栈性能工程实践 14 | 15 | 请以高性能计算专家的身份,系统论述GEMM优化的完整技术体系。要求从晶体管层面到算法层进行跨抽象层分析,包含以下维度: 16 | 17 | 一、硬件感知优化基础(展开以下每个子项) 18 | 19 | 现代CPU内存层级解剖 20 | 21 | 详述如何通过cache blocking适应L1/L2/L3缓存行 22 | 23 | 示例:针对不同缓存容量(如8MB 
L3)的分块策略计算公式 24 | 25 | 数据预取模式设计(软件预取指令的最佳插入距离) 26 | 27 | SIMD指令工程实践 28 | 29 | AVX-512与ARM SVE的寄存器压力对比 30 | 31 | 汇编级循环展开策略(展示8x8分块的双缓冲汇编模板) 32 | 33 | FMA指令流水线冒险规避技巧 34 | 35 | GPU架构深度适配 36 | 37 | CUDA warp-level同步优化(共享内存bank conflict量化分析) 38 | 39 | 全局内存合并访问模式设计(展示2D tile的访存对齐公式) 40 | 41 | Tensor Core编程范式(WMMA API使用陷阱与性能调优日志) 42 | 43 | 二、算法革新路线(需数学推导) 44 | 45 | 复杂分治策略 46 | 47 | Strassen算法在实践中的递归终止条件选择(给出浮点误差传播模型) 48 | 49 | Winograd变换的数值稳定性改进方案 50 | 51 | 基于分块秩的近似算法误差界证明 52 | 53 | 稀疏化与量化 54 | 55 | 结构化稀疏的硬件友好模式设计(NVIDIA A100稀疏特性适配) 56 | 57 | 混合精度训练中的动态缩放因子推导 58 | 59 | 低秩近似与GEMM的耦合优化(给出SVD截断误差分析) 60 | 61 | 三、编译工程化进阶 62 | 63 | LLVM中间表示调优 64 | 65 | Polly循环优化编译指示实战 66 | 67 | MLIR GEMM方言生成技术路线 68 | 69 | 自动向量化失败的补救模式 70 | 71 | 自动调参系统设计 72 | 73 | 遗传算法参数空间剪枝策略 74 | 75 | 贝叶斯优化中的协方差矩阵自适应 76 | 77 | 多目标优化Pareto前沿的筛选标准 78 | 79 | 四、异构计算协同 80 | 81 | 多芯片负载均衡 82 | 83 | CPU-GPU流水线深度分析(计算/通信重叠的数学模型) 84 | 85 | 基于RDMA的跨设备零拷贝实现 86 | 87 | 异构内存一致性模型解决方案 88 | 89 | 五、验证方法论 90 | 91 | Roofline模型深度应用 92 | 93 | 实测不同架构的运算强度阈值 94 | 95 | 性能偏离度诊断流程图 96 | 97 | 瓶颈定位的热力图分析法 98 | 99 | 技术写作要求: 100 | 101 | 每个优化点需提供理论依据(附复杂度公式推导) 102 | 103 | 关键路径给出CUDA/C++代码片段及编译器内联汇编示例 104 | 105 | 包含主流硬件实测数据(如A100 vs Xeon Platinum对比表格) 106 | 107 | 讨论商业化实现差异(对比oneDNN vs cuBLAS设计哲学) 108 | 109 | 最后给出优化决策树(包含分支判断条件) 110 | 111 | 请采用学术论文写作规范,分章节编号至三级标题,使用LaTeX公式描述关键技术指标,总输出保持工程技术文档的严谨性同时具备可操作性。" 112 | -------------------------------------------------------------------------------- /scripts/example_text_completion_qwen2.5_3B_bf16_orangepi.sh: -------------------------------------------------------------------------------- 1 | ./build/src/llama2_main \ 2 | --config=/ssd/models/Qwen2.5-3B-Instruct/config.json \ 3 | --tokenizer=/ssd/models/Qwen2.5-3B-Instruct/ \ 4 | --weight=/ssd/models/Qwen2.5-3B-Instruct_converted \ 5 | --model_type=qwen2 \ 6 | --device_type=npu \ 7 | --max_seq_len=8192 \ 8 | --max_gen_token=256 \ 9 | --temperature=0.6 \ 10 | --debug_print=false \ 11 | --log_level=info \ 12 | --rope_is_neox_style=true \ 13 | --prompt="深度优化指南请求:通用矩阵乘法(GEMM)的全栈性能工程实践 14 | 15 | 请以高性能计算专家的身份,系统论述GEMM优化的完整技术体系。要求从晶体管层面到算法层进行跨抽象层分析,包含以下维度: 16 | 17 | 一、硬件感知优化基础(展开以下每个子项) 18 | 19 | 现代CPU内存层级解剖 20 | 21 | 详述如何通过cache blocking适应L1/L2/L3缓存行 22 | 23 | 示例:针对不同缓存容量(如8MB L3)的分块策略计算公式 24 | 25 | 数据预取模式设计(软件预取指令的最佳插入距离) 26 | 27 | SIMD指令工程实践 28 | 29 | AVX-512与ARM SVE的寄存器压力对比 30 | 31 | 汇编级循环展开策略(展示8x8分块的双缓冲汇编模板) 32 | 33 | FMA指令流水线冒险规避技巧 34 | 35 | GPU架构深度适配 36 | 37 | CUDA warp-level同步优化(共享内存bank conflict量化分析) 38 | 39 | 全局内存合并访问模式设计(展示2D tile的访存对齐公式) 40 | 41 | Tensor Core编程范式(WMMA API使用陷阱与性能调优日志) 42 | 43 | 二、算法革新路线(需数学推导) 44 | 45 | 复杂分治策略 46 | 47 | Strassen算法在实践中的递归终止条件选择(给出浮点误差传播模型) 48 | 49 | Winograd变换的数值稳定性改进方案 50 | 51 | 基于分块秩的近似算法误差界证明 52 | 53 | 稀疏化与量化 54 | 55 | 结构化稀疏的硬件友好模式设计(NVIDIA A100稀疏特性适配) 56 | 57 | 混合精度训练中的动态缩放因子推导 58 | 59 | 低秩近似与GEMM的耦合优化(给出SVD截断误差分析) 60 | 61 | 三、编译工程化进阶 62 | 63 | LLVM中间表示调优 64 | 65 | Polly循环优化编译指示实战 66 | 67 | MLIR GEMM方言生成技术路线 68 | 69 | 自动向量化失败的补救模式 70 | 71 | 自动调参系统设计 72 | 73 | 遗传算法参数空间剪枝策略 74 | 75 | 贝叶斯优化中的协方差矩阵自适应 76 | 77 | 多目标优化Pareto前沿的筛选标准 78 | 79 | 四、异构计算协同 80 | 81 | 多芯片负载均衡 82 | 83 | CPU-GPU流水线深度分析(计算/通信重叠的数学模型) 84 | 85 | 基于RDMA的跨设备零拷贝实现 86 | 87 | 异构内存一致性模型解决方案 88 | 89 | 五、验证方法论 90 | 91 | Roofline模型深度应用 92 | 93 | 实测不同架构的运算强度阈值 94 | 95 | 性能偏离度诊断流程图 96 | 97 | 瓶颈定位的热力图分析法 98 | 99 | 技术写作要求: 100 | 101 | 每个优化点需提供理论依据(附复杂度公式推导) 102 | 103 | 关键路径给出CUDA/C++代码片段及编译器内联汇编示例 104 | 105 | 包含主流硬件实测数据(如A100 vs Xeon Platinum对比表格) 106 | 107 | 讨论商业化实现差异(对比oneDNN vs cuBLAS设计哲学) 108 | 109 | 最后给出优化决策树(包含分支判断条件) 110 | 111 | 
请采用学术论文写作规范,分章节编号至三级标题,使用LaTeX公式描述关键技术指标,总输出保持工程技术文档的严谨性同时具备可操作性。" 112 | -------------------------------------------------------------------------------- /scripts/example_text_completion_qwen2.5_7B_awq_orangepi.sh: -------------------------------------------------------------------------------- 1 | ./build/src/llama2_main \ 2 | --config=/ssd/models/Qwen2.5-7B-Instruct-AWQ/config.json \ 3 | --tokenizer=/ssd/models/Qwen2.5-7B-Instruct-AWQ/ \ 4 | --weight=/ssd/models/Qwen2.5-7B-Instruct-AWQ_converted \ 5 | --model_type=qwen2 \ 6 | --device_type=npu \ 7 | --max_seq_len=8192 \ 8 | --max_gen_token=256 \ 9 | --temperature=0.6 \ 10 | --debug_print=false \ 11 | --log_level=info \ 12 | --rope_is_neox_style=true \ 13 | --prompt="深度优化指南请求:通用矩阵乘法(GEMM)的全栈性能工程实践 14 | 15 | 请以高性能计算专家的身份,系统论述GEMM优化的完整技术体系。要求从晶体管层面到算法层进行跨抽象层分析,包含以下维度: 16 | 17 | 一、硬件感知优化基础(展开以下每个子项) 18 | 19 | 现代CPU内存层级解剖 20 | 21 | 详述如何通过cache blocking适应L1/L2/L3缓存行 22 | 23 | 示例:针对不同缓存容量(如8MB L3)的分块策略计算公式 24 | 25 | 数据预取模式设计(软件预取指令的最佳插入距离) 26 | 27 | SIMD指令工程实践 28 | 29 | AVX-512与ARM SVE的寄存器压力对比 30 | 31 | 汇编级循环展开策略(展示8x8分块的双缓冲汇编模板) 32 | 33 | FMA指令流水线冒险规避技巧 34 | 35 | GPU架构深度适配 36 | 37 | CUDA warp-level同步优化(共享内存bank conflict量化分析) 38 | 39 | 全局内存合并访问模式设计(展示2D tile的访存对齐公式) 40 | 41 | Tensor Core编程范式(WMMA API使用陷阱与性能调优日志) 42 | 43 | 二、算法革新路线(需数学推导) 44 | 45 | 复杂分治策略 46 | 47 | Strassen算法在实践中的递归终止条件选择(给出浮点误差传播模型) 48 | 49 | Winograd变换的数值稳定性改进方案 50 | 51 | 基于分块秩的近似算法误差界证明 52 | 53 | 稀疏化与量化 54 | 55 | 结构化稀疏的硬件友好模式设计(NVIDIA A100稀疏特性适配) 56 | 57 | 混合精度训练中的动态缩放因子推导 58 | 59 | 低秩近似与GEMM的耦合优化(给出SVD截断误差分析) 60 | 61 | 三、编译工程化进阶 62 | 63 | LLVM中间表示调优 64 | 65 | Polly循环优化编译指示实战 66 | 67 | MLIR GEMM方言生成技术路线 68 | 69 | 自动向量化失败的补救模式 70 | 71 | 自动调参系统设计 72 | 73 | 遗传算法参数空间剪枝策略 74 | 75 | 贝叶斯优化中的协方差矩阵自适应 76 | 77 | 多目标优化Pareto前沿的筛选标准 78 | 79 | 四、异构计算协同 80 | 81 | 多芯片负载均衡 82 | 83 | CPU-GPU流水线深度分析(计算/通信重叠的数学模型) 84 | 85 | 基于RDMA的跨设备零拷贝实现 86 | 87 | 异构内存一致性模型解决方案 88 | 89 | 五、验证方法论 90 | 91 | Roofline模型深度应用 92 | 93 | 实测不同架构的运算强度阈值 94 | 95 | 性能偏离度诊断流程图 96 | 97 | 瓶颈定位的热力图分析法 98 | 99 | 技术写作要求: 100 | 101 | 每个优化点需提供理论依据(附复杂度公式推导) 102 | 103 | 关键路径给出CUDA/C++代码片段及编译器内联汇编示例 104 | 105 | 包含主流硬件实测数据(如A100 vs Xeon Platinum对比表格) 106 | 107 | 讨论商业化实现差异(对比oneDNN vs cuBLAS设计哲学) 108 | 109 | 最后给出优化决策树(包含分支判断条件) 110 | 111 | 请采用学术论文写作规范,分章节编号至三级标题,使用LaTeX公式描述关键技术指标,总输出保持工程技术文档的严谨性同时具备可操作性。" 112 | -------------------------------------------------------------------------------- /scripts/example_text_completion_qwen2.5_7B_bf16_orangepi.sh: -------------------------------------------------------------------------------- 1 | ./build/src/llama2_main \ 2 | --config=/ssd/models/Qwen2.5-7B-Instruct/config.json \ 3 | --tokenizer=/ssd/models/Qwen2.5-7B-Instruct/ \ 4 | --weight=/ssd/models/Qwen2.5-7B-Instruct_converted \ 5 | --model_type=qwen2 \ 6 | --device_type=npu \ 7 | --max_seq_len=8192 \ 8 | --max_gen_token=256 \ 9 | --temperature=0.0 \ 10 | --debug_print=false \ 11 | --log_level=info \ 12 | --rope_is_neox_style=true \ 13 | --prompt="深度优化指南请求:通用矩阵乘法(GEMM)的全栈性能工程实践 14 | 15 | 请以高性能计算专家的身份,系统论述GEMM优化的完整技术体系。要求从晶体管层面到算法层进行跨抽象层分析,包含以下维度: 16 | 17 | 一、硬件感知优化基础(展开以下每个子项) 18 | 19 | 现代CPU内存层级解剖 20 | 21 | 详述如何通过cache blocking适应L1/L2/L3缓存行 22 | 23 | 示例:针对不同缓存容量(如8MB L3)的分块策略计算公式 24 | 25 | 数据预取模式设计(软件预取指令的最佳插入距离) 26 | 27 | SIMD指令工程实践 28 | 29 | AVX-512与ARM SVE的寄存器压力对比 30 | 31 | 汇编级循环展开策略(展示8x8分块的双缓冲汇编模板) 32 | 33 | FMA指令流水线冒险规避技巧 34 | 35 | GPU架构深度适配 36 | 37 | CUDA warp-level同步优化(共享内存bank conflict量化分析) 38 | 39 | 全局内存合并访问模式设计(展示2D tile的访存对齐公式) 40 | 41 | Tensor Core编程范式(WMMA API使用陷阱与性能调优日志) 42 | 43 | 二、算法革新路线(需数学推导) 44 | 45 | 复杂分治策略 46 | 47 
| Strassen算法在实践中的递归终止条件选择(给出浮点误差传播模型) 48 | 49 | Winograd变换的数值稳定性改进方案 50 | 51 | 基于分块秩的近似算法误差界证明 52 | 53 | 稀疏化与量化 54 | 55 | 结构化稀疏的硬件友好模式设计(NVIDIA A100稀疏特性适配) 56 | 57 | 混合精度训练中的动态缩放因子推导 58 | 59 | 低秩近似与GEMM的耦合优化(给出SVD截断误差分析) 60 | 61 | 三、编译工程化进阶 62 | 63 | LLVM中间表示调优 64 | 65 | Polly循环优化编译指示实战 66 | 67 | MLIR GEMM方言生成技术路线 68 | 69 | 自动向量化失败的补救模式 70 | 71 | 自动调参系统设计 72 | 73 | 遗传算法参数空间剪枝策略 74 | 75 | 贝叶斯优化中的协方差矩阵自适应 76 | 77 | 多目标优化Pareto前沿的筛选标准 78 | 79 | 四、异构计算协同 80 | 81 | 多芯片负载均衡 82 | 83 | CPU-GPU流水线深度分析(计算/通信重叠的数学模型) 84 | 85 | 基于RDMA的跨设备零拷贝实现 86 | 87 | 异构内存一致性模型解决方案 88 | 89 | 五、验证方法论 90 | 91 | Roofline模型深度应用 92 | 93 | 实测不同架构的运算强度阈值 94 | 95 | 性能偏离度诊断流程图 96 | 97 | 瓶颈定位的热力图分析法 98 | 99 | 技术写作要求: 100 | 101 | 每个优化点需提供理论依据(附复杂度公式推导) 102 | 103 | 关键路径给出CUDA/C++代码片段及编译器内联汇编示例 104 | 105 | 包含主流硬件实测数据(如A100 vs Xeon Platinum对比表格) 106 | 107 | 讨论商业化实现差异(对比oneDNN vs cuBLAS设计哲学) 108 | 109 | 最后给出优化决策树(包含分支判断条件) 110 | 111 | 请采用学术论文写作规范,分章节编号至三级标题,使用LaTeX公式描述关键技术指标,总输出保持工程技术文档的严谨性同时具备可操作性。" 112 | -------------------------------------------------------------------------------- /src/.gitignore: -------------------------------------------------------------------------------- 1 | npu_ops/ -------------------------------------------------------------------------------- /src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | if (DEFINED ACL_PATH) 2 | message(STATUS "user set ACL_PATH: ${ACL_PATH}") 3 | elseif (EXISTS /usr/local/Ascend/acllib/) 4 | set(ACL_PATH "/usr/local/Ascend/acllib") 5 | message(STATUS "set ACL_PATH: /usr/local/Ascend/acllib") 6 | elseif (EXISTS /usr/local/Ascend/ascend-toolkit/latest/acllib) 7 | set(ACL_PATH "/usr/local/Ascend/ascend-toolkit/latest/acllib") 8 | message(STATUS "set ACL_PATH to default path: /usr/local/Ascend/ascend-toolkit/latest/acllib") 9 | elseif (EXISTS /usr/local/Ascend/nnrt/latest/acllib) 10 | set(ACL_PATH "/usr/local/Ascend/nnrt/latest/acllib") 11 | message(STATUS "set ACL_PATH to default path: /usr/local/Ascend/nnrt/latest/acllib") 12 | else () 13 | set(ACL_PATH "/home/HwHiAiUser/Ascend/acllib") 14 | message(STATUS "set ACL_PATH to default path: /home/HwHiAiUser/Ascend/acllib") 15 | endif() 16 | 17 | set(Python_FIND_VIRTUALENV FIRST) 18 | 19 | find_program(Python_EXECUTABLE 20 | NAMES python3 python 21 | DOC "Path to Python executable" 22 | ) 23 | if(NOT Python_EXECUTABLE) 24 | message(FATAL_ERROR "Python executable not found in PATH!") 25 | endif() 26 | 27 | 28 | find_package(Python REQUIRED COMPONENTS Interpreter Development) 29 | if(NOT Python_FOUND) 30 | message(FATAL_ERROR "Python development libraries not found") 31 | endif() 32 | message("using python: ${Python_EXECUTABLE}") 33 | message("using python lib: ${Python_LIBRARY_DIRS}") 34 | 35 | find_package (Eigen3 REQUIRED NO_MODULE) 36 | 37 | file(GLOB SOURCE_FILES *.cpp) 38 | 39 | add_executable(llama2_main ${SOURCE_FILES}) 40 | message("${EIGEN3_INCLUDE_DIR}") 41 | target_include_directories(llama2_main PUBLIC ${EIGEN3_INCLUDE_DIR} ${ACL_PATH}/include ../prebuild ${Python_INCLUDE_DIRS}) 42 | target_link_directories(llama2_main PUBLIC ${ACL_PATH}/lib64 ../prebuild ${Python_LIBRARY_DIRS}) 43 | target_link_libraries(llama2_main sentencepiece boost_program_options boost_system boost_filesystem fmt ascendcl runtime npu_ops re2 ${Python_LIBRARIES}) 44 | -------------------------------------------------------------------------------- /src/acl_util.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "acl/acl.h" 4 | 
#include <spdlog/spdlog.h> 5 | 6 | #define CHECK_ACL(x) \ 7 | do { \ 8 | aclError __ret = x; \ 9 | if (__ret != ACL_ERROR_NONE) { \ 10 | spdlog::error("{}:{} aclError: {}", __FILE__, __LINE__, __ret); \ 11 | } \ 12 | } while (0); 13 | 14 | 15 | -------------------------------------------------------------------------------- /src/base64.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace base64 { 7 | 8 | static auto pos_of_char(const unsigned char chr) -> size_t { 9 | if (chr >= 'A' && chr <= 'Z') return chr - 'A'; 10 | else if (chr >= 'a' && chr <= 'z') return chr - 'a' + ('Z' - 'A') + 1; 11 | else if (chr >= '0' && chr <= '9') return chr - '0' + ('Z' - 'A') + ('z' - 'a') + 2; 12 | else if (chr == '+' || chr == '-') return 62; 13 | else if (chr == '/' || chr == '_') return 63; 14 | else throw std::runtime_error("Input is not valid base64-encoded data."); 15 | } 16 | 17 | inline auto decode(std::string_view s) -> std::string { 18 | if (s.empty()) throw std::runtime_error("empty input"); 19 | size_t length = s.length(); 20 | size_t idx = 0; 21 | 22 | std::string out; 23 | out.reserve(length / 4 * 3); 24 | 25 | while (idx < length) { 26 | size_t pos_of_char_1 = pos_of_char(s.at(idx + 1)); 27 | out.push_back(static_cast<char>(((pos_of_char(s.at(idx+0))) << 2 ) + ((pos_of_char_1 & 0x30) >> 4))); 28 | if ((idx + 2 < length) && s.at(idx + 2) != '=' && s.at(idx + 2) != '.') { 29 | size_t pos_of_char_2 = pos_of_char(s.at(idx + 2)); 30 | out.push_back(static_cast<char>(((pos_of_char_1 & 0x0f) << 4) + ((pos_of_char_2 & 0x3c) >> 2))); 31 | if ((idx + 3 < length) && s.at(idx + 3) != '=' && s.at(idx + 3) != '.') { 32 | out.push_back(static_cast<char>(((pos_of_char_2 & 0x03) << 6) + pos_of_char(s.at(idx+3)))); 33 | } 34 | } 35 | idx += 4; 36 | } 37 | return out; 38 | } 39 | 40 | } // namespace base64 41 | -------------------------------------------------------------------------------- /src/device.cpp: -------------------------------------------------------------------------------- 1 | #include <fstream> 2 | #include "device.hpp" 3 | #include "acl_util.hpp" 4 | 5 | 6 | Tensor::~Tensor() { 7 | switch (dev_type) 8 | { 9 | case DEV_NPU: 10 | NPUAllocator::Deallocate(data_ptr); 11 | break; 12 | case DEV_CPU: 13 | CPUAllocator::Deallocate(data_ptr); 14 | break; 15 | 16 | default: 17 | break; 18 | } 19 | } 20 | 21 | 22 | std::shared_ptr<Tensor> Tensor::to(DeviceType to_dev) { 23 | switch (dev_type) { 24 | case DEV_NPU: 25 | switch (to_dev) { 26 | case DEV_NPU: 27 | return shared_from_this(); 28 | case DEV_CPU: 29 | { 30 | auto result = Tensor::MakeCPUTensor(data_size, data_type); 31 | CHECK_ACL(aclrtMemcpy(result->data_ptr, SizeOfTensor(data_size, data_type), data_ptr, 32 | SizeOfTensor(data_size, data_type), ACL_MEMCPY_DEVICE_TO_HOST)); 33 | return result; 34 | } 35 | } 36 | break; 37 | case DEV_CPU: 38 | switch (to_dev) { 39 | case DEV_NPU: 40 | { 41 | auto result = Tensor::MakeNPUTensor(data_size, data_type); 42 | CHECK_ACL(aclrtMemcpy(result->data_ptr, SizeOfTensor(data_size, data_type), data_ptr, 43 | SizeOfTensor(data_size, data_type), ACL_MEMCPY_HOST_TO_DEVICE)); 44 | return result; 45 | } 46 | case DEV_CPU: 47 | return shared_from_this(); 48 | } 49 | break; 50 | } 51 | return std::shared_ptr<Tensor>(); 52 | } 53 | 54 | void Tensor::to_file(const char* path) { 55 | void* curr_data_ptr = data_ptr; 56 | std::shared_ptr<Tensor> temp_ref; 57 | if (dev_type != DEV_CPU) { 58 | temp_ref = this->to(DEV_CPU); 59 | curr_data_ptr = temp_ref->data_ptr; 60 | } 61 | 62 | std::ofstream
ofs(path, std::ios::binary); 63 | ofs.write((const char*)curr_data_ptr, SizeOfTensor(data_size, data_type)); 64 | } 65 | -------------------------------------------------------------------------------- /src/device.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "defs.hpp" 8 | 9 | class CPUAllocator { 10 | public: 11 | static CPUAllocator& GetInstance(); 12 | static void* Allocate(size_t size); 13 | static void Deallocate(void* ptr); 14 | }; 15 | 16 | class NPUAllocatorEntry { 17 | public: 18 | NPUAllocatorEntry(size_t s, void* p); 19 | bool operator < (const NPUAllocatorEntry& other); 20 | size_t size; 21 | void* ptr; 22 | }; 23 | 24 | class NPUAllocator { 25 | public: 26 | static NPUAllocator& GetInstance(); 27 | static void* Allocate(size_t size); 28 | static void Deallocate(void* ptr); 29 | private: 30 | void* AllocateImpl(size_t size); 31 | void DeallocateImpl(void* ptr); 32 | 33 | std::list freelist; 34 | std::unordered_map ptr_size; 35 | size_t dev_mem_max{8*1024ULL*1024ULL*1024ULL}; 36 | size_t dev_mem_max_entry_num{128}; 37 | size_t max_record{0}; 38 | size_t allocated_bytes{0}; 39 | }; 40 | 41 | class Tensor : public std::enable_shared_from_this { 42 | public: 43 | Tensor() = default; 44 | ~Tensor(); 45 | static std::shared_ptr MakeCPUTensor(size_t size, DataType dtype); 46 | static std::shared_ptr MakeNPUTensor(size_t size, DataType dtype); 47 | 48 | std::shared_ptr to(DeviceType to_dev); 49 | void to_file(const char* path); 50 | 51 | void* data_ptr{nullptr}; 52 | size_t data_size; 53 | DataType data_type; 54 | DeviceType dev_type; 55 | }; 56 | 57 | -------------------------------------------------------------------------------- /src/device_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "device.hpp" 3 | 4 | 5 | CPUAllocator& CPUAllocator::GetInstance() { 6 | static CPUAllocator instance; 7 | return instance; 8 | } 9 | 10 | void* CPUAllocator::Allocate(size_t size) { 11 | return malloc(size); 12 | } 13 | 14 | void CPUAllocator::Deallocate(void* ptr) { 15 | free(ptr); 16 | } 17 | 18 | 19 | 20 | std::shared_ptr Tensor::MakeCPUTensor(size_t size, DataType dtype) { 21 | auto result = std::make_shared(); 22 | result->data_size = size; 23 | result->data_ptr = CPUAllocator::Allocate(SizeOfTensor(size, dtype)); 24 | result->data_type = dtype; 25 | result->dev_type = DEV_CPU; 26 | return result; 27 | } 28 | -------------------------------------------------------------------------------- /src/device_gpu.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lenLRX/llm_simple/c8466b22f27f2026d63747bcf0856a0e7f61e126/src/device_gpu.cpp -------------------------------------------------------------------------------- /src/device_npu.cpp: -------------------------------------------------------------------------------- 1 | #include "acl/acl.h" 2 | #include 3 | 4 | #include "device.hpp" 5 | #include "acl_util.hpp" 6 | 7 | 8 | NPUAllocatorEntry::NPUAllocatorEntry(size_t s, void* p): size(s), ptr(p) { 9 | 10 | } 11 | 12 | bool NPUAllocatorEntry::operator < (const NPUAllocatorEntry& other){ 13 | return this->size < other.size; 14 | } 15 | 16 | 17 | NPUAllocator& NPUAllocator::GetInstance() { 18 | static NPUAllocator instance; 19 | return instance; 20 | } 21 | 22 | void* NPUAllocator::Allocate(size_t size) { 23 | return GetInstance().AllocateImpl(size); 24 
| } 25 | 26 | void* NPUAllocator::AllocateImpl(size_t size) { 27 | void *dev_mem = nullptr; 28 | for (auto it = freelist.begin(); it != freelist.end();) { 29 | if (it->size >= size) { 30 | spdlog::debug("NPUAllocator get {} size {} from pool", it->ptr, it->size); 31 | dev_mem = it->ptr; 32 | it = freelist.erase(it); 33 | break; 34 | } 35 | else { 36 | ++it; 37 | } 38 | } 39 | if (dev_mem != nullptr) { 40 | return dev_mem; 41 | } 42 | CHECK_ACL(aclrtMalloc(&dev_mem, size, ACL_MEM_MALLOC_HUGE_FIRST)); 43 | spdlog::debug("aclrtMalloc {} size {}", dev_mem, size); 44 | allocated_bytes+=size; 45 | ptr_size[dev_mem] = size; 46 | 47 | for (auto it = freelist.begin(); it != freelist.end() && allocated_bytes > dev_mem_max;) { 48 | spdlog::debug("aclrtFree {} size {}", it->ptr, it->size); 49 | CHECK_ACL(aclrtFree(it->ptr)); 50 | allocated_bytes -= it->size; 51 | ptr_size.erase(it->ptr); 52 | it = freelist.erase(it); 53 | } 54 | 55 | return dev_mem; 56 | } 57 | 58 | void NPUAllocator::Deallocate(void* ptr) { 59 | GetInstance().DeallocateImpl(ptr); 60 | } 61 | 62 | void NPUAllocator::DeallocateImpl(void* ptr) { 63 | auto release_size = ptr_size.at(ptr); 64 | spdlog::debug("NPUAllocator return {} size {} to pool", ptr, release_size); 65 | freelist.emplace_back(release_size, ptr); 66 | freelist.sort(); 67 | } 68 | 69 | 70 | 71 | std::shared_ptr Tensor::MakeNPUTensor(size_t size, DataType dtype) { 72 | auto result = std::make_shared(); 73 | result->data_size = size; 74 | result->data_ptr = NPUAllocator::Allocate(SizeOfTensor(size, dtype)); 75 | result->data_type = dtype; 76 | result->dev_type = DEV_NPU; 77 | return result; 78 | } -------------------------------------------------------------------------------- /src/llama2_layer_cpu.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "llama2_model.hpp" 4 | 5 | class EmbeddingLayerCPUImpl : public EmbeddingLayerImpl { 6 | public: 7 | virtual ~EmbeddingLayerCPUImpl(); 8 | virtual std::shared_ptr Forward(std::shared_ptr input, 9 | InferenceCtx &ctx) override; 10 | virtual bool Init(ModelBase *model, const std::string &weight_path) override; 11 | virtual void UnInit() override; 12 | }; 13 | 14 | class Llamma2TransformerLayerCPUImpl : public Llamma2TransformerLayerImpl { 15 | public: 16 | virtual ~Llamma2TransformerLayerCPUImpl(); 17 | std::shared_ptr Forward(std::shared_ptr input, 18 | std::shared_ptr mask, 19 | InferenceCtx &ctx) override; 20 | virtual bool Init(ModelBase *model, int layer_no) override; 21 | virtual void UnInit() override; 22 | 23 | RMSNormLayer pre_norm; 24 | RMSNormLayer post_norm; 25 | MatmulLayer q_proj; 26 | MatmulLayer k_proj; 27 | MatmulLayer v_proj; 28 | MatmulLayer o_proj; 29 | 30 | MatmulLayer gate_proj; // w1 31 | MatmulLayer down_proj; // w2 32 | MatmulLayer up_proj; // w3 33 | 34 | RoPELayer rope_emb; 35 | SoftmaxLayer softmax; 36 | 37 | std::shared_ptr k_cache; 38 | std::shared_ptr v_cache; 39 | }; 40 | 41 | class RMSNormLayerCPUImpl : public RMSNormLayerImpl { 42 | public: 43 | virtual ~RMSNormLayerCPUImpl(); 44 | virtual std::shared_ptr Forward(std::shared_ptr input, 45 | InferenceCtx &ctx) override; 46 | virtual bool Init(ModelBase *model, int layer_no, bool pre_norm, 47 | bool last_norm) override; 48 | virtual void UnInit() override; 49 | }; 50 | 51 | class ArgMaxLayerCPUImpl : public ArgMaxLayerImpl { 52 | public: 53 | virtual ~ArgMaxLayerCPUImpl(); 54 | virtual std::shared_ptr Forward(std::shared_ptr input, 55 | InferenceCtx &ctx) override; 56 | virtual 
bool Init(ModelBase *model) override; 57 | virtual void UnInit() override; 58 | }; 59 | 60 | class SoftmaxLayerCPUImpl : public SoftmaxLayerImpl { 61 | public: 62 | virtual ~SoftmaxLayerCPUImpl(); 63 | virtual std::shared_ptr Forward(std::shared_ptr input, 64 | InferenceCtx &ctx) override; 65 | virtual bool Init(ModelBase *model) override; 66 | virtual void UnInit() override; 67 | }; 68 | 69 | class CausualMaskLayerCPUImpl : public CausualMaskLayerImpl { 70 | public: 71 | virtual ~CausualMaskLayerCPUImpl(); 72 | virtual std::shared_ptr Forward(InferenceCtx &ctx) override; 73 | virtual bool Init(ModelBase *model) override; 74 | virtual void UnInit() override; 75 | }; 76 | 77 | class MatmulLayerCPUImpl : public MatmulLayerImpl { 78 | public: 79 | virtual ~MatmulLayerCPUImpl(); 80 | virtual std::shared_ptr Forward(std::shared_ptr input, 81 | InferenceCtx &ctx) override; 82 | virtual bool Init(ModelBase *model, const std::string &weight_path, size_t n, 83 | size_t k) override; 84 | virtual void UnInit() override; 85 | }; 86 | 87 | class RoPELayerCPUImpl : public RoPELayerImpl { 88 | public: 89 | virtual ~RoPELayerCPUImpl(); 90 | virtual std::tuple, std::shared_ptr> 91 | Forward(std::shared_ptr input_q, std::shared_ptr input_k, 92 | InferenceCtx &ctx) override; 93 | virtual bool Init(ModelBase *model, const std::string &weight_path) override; 94 | virtual void UnInit() override; 95 | }; 96 | 97 | class SampleTopPLayerCPUImpl : public SampleTopPLayerImpl { 98 | public: 99 | virtual ~SampleTopPLayerCPUImpl(); 100 | virtual int Forward(std::shared_ptr input, 101 | InferenceCtx &ctx) override; 102 | virtual bool Init(ModelBase *model) override; 103 | virtual void UnInit() override; 104 | }; 105 | -------------------------------------------------------------------------------- /src/llama2_layer_npu.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "llama2_model.hpp" 4 | 5 | class EmbeddingLayerNPUImpl : public EmbeddingLayerImpl { 6 | public: 7 | virtual ~EmbeddingLayerNPUImpl(); 8 | virtual std::shared_ptr Forward(std::shared_ptr input, 9 | InferenceCtx &ctx) override; 10 | virtual bool Init(ModelBase *model, const std::string& weight_path) override; 11 | virtual void UnInit() override; 12 | }; 13 | 14 | class RMSNormLayerNPUImpl : public RMSNormLayerImpl { 15 | public: 16 | virtual ~RMSNormLayerNPUImpl(); 17 | virtual std::shared_ptr Forward(std::shared_ptr input, 18 | InferenceCtx &ctx) override; 19 | virtual bool Init(ModelBase *model, int layer_no, bool pre_norm, 20 | bool last_norm) override; 21 | virtual void UnInit() override; 22 | }; 23 | 24 | class SoftmaxLayerNPUImpl : public SoftmaxLayerImpl { 25 | public: 26 | virtual ~SoftmaxLayerNPUImpl(); 27 | virtual std::shared_ptr Forward(std::shared_ptr input, 28 | InferenceCtx &ctx) override; 29 | virtual bool Init(ModelBase *model) override; 30 | virtual void UnInit() override; 31 | }; 32 | 33 | class RoPELayerNPUImpl : public RoPELayerImpl { 34 | public: 35 | virtual ~RoPELayerNPUImpl(); 36 | virtual std::tuple, std::shared_ptr> 37 | Forward(std::shared_ptr input_q, std::shared_ptr input_k, 38 | InferenceCtx &ctx) override; 39 | virtual bool Init(ModelBase *model, const std::string &weight_path) override; 40 | virtual void UnInit() override; 41 | }; 42 | 43 | class Llamma2TransformerLayerNPUImpl : public Llamma2TransformerLayerImpl { 44 | public: 45 | virtual ~Llamma2TransformerLayerNPUImpl(); 46 | std::shared_ptr Forward(std::shared_ptr input, 47 | std::shared_ptr mask, 48 
| InferenceCtx &ctx) override; 49 | virtual bool Init(ModelBase *model, int layer_no) override; 50 | virtual void UnInit() override; 51 | 52 | RMSNormLayer pre_norm; 53 | RMSNormLayer post_norm; 54 | MatmulLayer q_proj; 55 | MatmulLayer k_proj; 56 | MatmulLayer v_proj; 57 | MatmulLayer o_proj; 58 | 59 | MatmulLayer gate_proj; // w1 60 | MatmulLayer down_proj; // w2 61 | MatmulLayer up_proj; // w3 62 | 63 | RoPELayer rope_emb; 64 | SoftmaxLayer softmax; 65 | 66 | std::shared_ptr<Tensor> k_cache; 67 | std::shared_ptr<Tensor> v_cache; 68 | }; 69 | 70 | class Qwen2TransformerLayerNPUImpl : public Qwen2TransformerLayerImpl { 71 | public: 72 | virtual ~Qwen2TransformerLayerNPUImpl(); 73 | std::shared_ptr<Tensor> Forward(std::shared_ptr<Tensor> input, 74 | std::shared_ptr<Tensor> mask, 75 | InferenceCtx &ctx) override; 76 | virtual bool Init(ModelBase *model, int layer_no) override; 77 | virtual void UnInit() override; 78 | 79 | RMSNormLayer pre_norm; 80 | RMSNormLayer post_norm; 81 | MatmulLayer q_proj; 82 | MatmulLayer k_proj; 83 | MatmulLayer v_proj; 84 | MatmulLayer o_proj; 85 | 86 | MatmulLayer gate_proj; // w1 87 | MatmulLayer down_proj; // w2 88 | MatmulLayer up_proj; // w3 89 | 90 | RoPELayer rope_emb; 91 | SoftmaxLayer softmax; 92 | 93 | std::shared_ptr<Tensor> k_cache; 94 | std::shared_ptr<Tensor> v_cache; 95 | }; 96 | 97 | class MatmulLayerNPUImpl : public MatmulLayerImpl { 98 | public: 99 | virtual ~MatmulLayerNPUImpl(); 100 | virtual std::shared_ptr<Tensor> Forward(std::shared_ptr<Tensor> input, 101 | InferenceCtx &ctx) override; 102 | virtual bool Init(ModelBase *model, const std::string &weight_path, size_t n, 103 | size_t k) override; 104 | 105 | virtual bool InitWithBias(ModelBase *model, const std::string &weight_path, 106 | const std::string &bias_path, size_t n, 107 | size_t k) override; 108 | 109 | virtual bool InitAWQ(ModelBase *model, const std::string &weight_path, 110 | const std::string &zero_path, 111 | const std::string &scale_path, size_t n, size_t k, 112 | QuantType quant_type) override; 113 | virtual bool AddBias(const std::string &bias_path) override; 114 | virtual void UnInit() override; 115 | }; 116 | -------------------------------------------------------------------------------- /src/llama2_main.cpp: -------------------------------------------------------------------------------- 1 | #include <Python.h> 2 | #include <boost/filesystem.hpp> 3 | #include <boost/program_options.hpp> 4 | #include <fstream> 5 | #include <iostream> 6 | #include <map> 7 | #include <spdlog/spdlog.h> 8 | #include <sstream> 9 | #include <string> 10 | 11 | #include "acl_util.hpp" 12 | #include "defs.hpp" 13 | #include "llama2_model.hpp" 14 | #include "model_base.hpp" 15 | #include "qwen2_model.hpp" 16 | #include "tokenizer.hpp" 17 | 18 | namespace po = boost::program_options; 19 | 20 | static std::map<std::string, spdlog::level::level_enum> log_level_name_to_enum{ 21 | {"trace", spdlog::level::trace}, {"debug", spdlog::level::debug}, 22 | {"info", spdlog::level::info}, {"warning", spdlog::level::warn}, 23 | {"error", spdlog::level::err}, {"critical", spdlog::level::critical}, 24 | {"off", spdlog::level::off}}; 25 | 26 | int main(int argc, char **argv) { 27 | Py_Initialize(); 28 | PyImport_ImportModule("site"); 29 | 30 | // import the sys module 31 | PyObject *sys_module = PyImport_ImportModule("sys"); 32 | if (!sys_module) { 33 | PyErr_Print(); 34 | return 1; 35 | } 36 | 37 | // get sys.path 38 | PyObject *path = PyObject_GetAttrString(sys_module, "path"); 39 | if (path && PyList_Check(path)) { 40 | Py_ssize_t size = PyList_Size(path); 41 | std::cout << "sys.path:" << std::endl; 42 | for (Py_ssize_t i = 0; i < size; ++i) { 43 | PyObject *item = PyList_GetItem(path, i); 44 | const char *path_item = PyUnicode_AsUTF8(item); 45 | std::cout <<
" " << path_item << std::endl; 46 | } 47 | } else { 48 | PyErr_Print(); 49 | } 50 | 51 | // 获取sys.prefix 52 | PyObject *prefix = PyObject_GetAttrString(sys_module, "prefix"); 53 | if (prefix && PyUnicode_Check(prefix)) { 54 | const char *prefix_str = PyUnicode_AsUTF8(prefix); 55 | std::cout << "sys.prefix: " << prefix_str << std::endl; 56 | } else { 57 | PyErr_Print(); 58 | } 59 | 60 | // 清理引用 61 | Py_XDECREF(path); 62 | Py_XDECREF(prefix); 63 | Py_DECREF(sys_module); 64 | 65 | ModelBase *model = nullptr; 66 | ModelConfig model_config; 67 | std::string str_device_type; 68 | std::string str_prompt; 69 | std::string prompt_file_path; 70 | std::string str_level; 71 | std::string profiling_output_path; 72 | std::string reverse_promt; 73 | std::string str_quant_method; 74 | int benchmark_input_seq_length = 0; 75 | int benchmark_output_seq_length = 0; 76 | try { 77 | po::options_description desc("llama2 inference options"); 78 | desc.add_options() // 79 | ("help", "produce help message") // 80 | ("model_type", 81 | po::value(&model_config.model_type) 82 | ->default_value("llama2"), 83 | "model_type supported: [llama2, qwen2], default:llama2") // 84 | ("max_seq_len", 85 | po::value(&model_config.max_seq_len)->default_value(2048), 86 | "max sequence length of tokens. default:2048") // 87 | ("max_gen_token", 88 | po::value(&model_config.max_gen_len)->default_value(2048), 89 | "max generate of tokens. default:2048") // 90 | ("tokenizer", 91 | po::value(&model_config.tok_path)->required(), 92 | "path to tokenizer") // 93 | ("weight", po::value(&model_config.model_path)->required(), 94 | "path to model weight") // 95 | ("config", 96 | po::value(&model_config.config_path)->required(), 97 | "path to model config") // 98 | ("device_type", po::value(&str_device_type)->required(), 99 | "device type, cpu/gpu") // 100 | ("prompt", po::value(&str_prompt), "prompt str") // 101 | ("prompt_file", po::value(&prompt_file_path), 102 | "prompt file") // 103 | ("log_level", po::value(&str_level), 104 | "log level:[trace,debug,info,warning,error,critical,off]") // 105 | ("profiling_output", po::value(&profiling_output_path), 106 | "profiling_output_file xx.json") // 107 | ("debug_print", 108 | po::value(&model_config.debug_print)->default_value(false), 109 | "print tensor value to debug") // 110 | ("temperature", 111 | po::value(&model_config.temperature)->default_value(0.6), 112 | "sample temperature, default: 0.6") // 113 | ("top_p", po::value(&model_config.top_p)->default_value(0.9), 114 | "sample top_p, default: 0.9") // 115 | ("reverse_promt", po::value(&reverse_promt), 116 | "reverse_promt in interactive mode") // 117 | ("i", "interactive mode") // 118 | ("quant_method", po::value(&str_quant_method), 119 | "quant_method: current support: awq_4bit") // 120 | ("quant_group_size", 121 | po::value(&model_config.quant_group_size)->default_value(-1), 122 | "group size in quant") // 123 | ("rope_is_neox_style", 124 | po::value(&model_config.rope_is_neox_style) 125 | ->default_value(false), 126 | "rope embedding style, defalut: false") // 127 | ("benchmark_input_seq_length", 128 | po::value(&benchmark_input_seq_length), 129 | "benchmark input_seq length") // 130 | ("benchmark_output_seq_length", 131 | po::value(&benchmark_output_seq_length), 132 | "benchmark output_seq length") // 133 | ("benchmark", "performance benchmark"); 134 | 135 | po::variables_map vm; 136 | po::store(po::parse_command_line(argc, argv, desc), vm); 137 | 138 | if (vm.count("help")) { 139 | std::cout << desc << "\n"; 140 | return 1; 141 | } 142 | 
143 | po::notify(vm); 144 | 145 | if (vm.count("log_level")) { 146 | if (!log_level_name_to_enum.count(str_level)) { 147 | std::cout << "invalid log_level:" << str_level << "\n"; 148 | return 1; 149 | } 150 | spdlog::set_level(log_level_name_to_enum[str_level]); 151 | } else { 152 | // default level is info 153 | spdlog::set_level(spdlog::level::info); 154 | } 155 | 156 | if (vm.count("prompt_file")) { 157 | if (vm.count("prompt")) { 158 | spdlog::warn("prompt_file overwrite prompt string"); 159 | } 160 | std::ifstream prompt_file(prompt_file_path.c_str()); 161 | if (!prompt_file) { 162 | spdlog::critical("failed to open prompt_file {}", prompt_file_path); 163 | return 1; 164 | } 165 | std::stringstream ss; 166 | ss << prompt_file.rdbuf(); 167 | str_prompt = ss.str(); 168 | } 169 | 170 | if (model_config.model_type == "llama2") { 171 | model = new Llama2Model(); 172 | model_config.data_type = DT_FLOAT16; 173 | } else if (model_config.model_type == "qwen2") { 174 | model = new Qwen2Model(); 175 | } else { 176 | spdlog::critical("invalid model_type type {}", model_config.model_type); 177 | return 1; 178 | } 179 | 180 | if (str_device_type == "cpu") { 181 | model_config.device_type = DEV_CPU; 182 | } else if (str_device_type == "gpu") { 183 | model_config.device_type = DEV_GPU; 184 | } else if (str_device_type == "npu") { 185 | CHECK_ACL(aclInit(nullptr)); 186 | model_config.device_type = DEV_NPU; 187 | } else { 188 | spdlog::critical("invalid device type {}", str_device_type); 189 | return 1; 190 | } 191 | 192 | if (boost::filesystem::exists(model_config.tok_path)) { 193 | } else { 194 | spdlog::error("invalid tokenizer path {}", model_config.tok_path); 195 | return 1; 196 | } 197 | 198 | if (boost::filesystem::exists(model_config.model_path) && 199 | boost::filesystem::is_directory(model_config.model_path)) { 200 | } else { 201 | spdlog::error("invalid model_weight path {}", model_config.model_path); 202 | return 1; 203 | } 204 | 205 | if (boost::filesystem::exists(model_config.config_path) && 206 | boost::filesystem::is_regular_file(model_config.config_path)) { 207 | } else { 208 | spdlog::error("invalid config path {}", model_config.config_path); 209 | return 1; 210 | } 211 | 212 | if (str_quant_method == "awq_4bit") { 213 | model_config.q_type = QuantType::AWQ_4B; 214 | } 215 | 216 | model->config = model_config; 217 | 218 | if (!model->Init()) { 219 | spdlog::error("failed to init model"); 220 | return 1; 221 | } 222 | 223 | if (vm.count("profiling_output")) { 224 | if (!model->profiler.StartLogging(profiling_output_path)) { 225 | spdlog::error("failed to init profiler, check {} is writable", 226 | profiling_output_path); 227 | return 1; 228 | } 229 | model->is_profiling = true; 230 | } 231 | 232 | // interactive mode 233 | if (vm.count("i")) { 234 | model->Chat(str_prompt, reverse_promt); 235 | } 236 | // benchmark mode 237 | if (vm.count("benchmark")) { 238 | model->Benchmark(benchmark_input_seq_length, benchmark_output_seq_length); 239 | } 240 | // text completion mode 241 | else { 242 | model->TextCompletion(str_prompt); 243 | } 244 | 245 | if (vm.count("profiling_output")) { 246 | model->profiler.Finish(); 247 | } 248 | } catch (std::exception &e) { 249 | spdlog::error("{}", e.what()); 250 | Py_Finalize(); 251 | return 1; 252 | } 253 | delete model; 254 | Py_Finalize(); 255 | } 256 | -------------------------------------------------------------------------------- /src/llama2_model.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 
2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "acl_util.hpp" 10 | 11 | #include "device.hpp" 12 | #include "model_base.hpp" 13 | #include "profiling.hpp" 14 | #include "tokenizer.hpp" 15 | 16 | class Llama2Model : public ModelBase { 17 | public: 18 | Tokenizer tokenizer; 19 | 20 | virtual bool Init() override; 21 | 22 | void Chat(const std::string &input_seq, const std::string &reverse_prompt) override; 23 | void TextCompletion(const std::string &input_seq) override; 24 | void Benchmark(int input_seq_len, int output_seq_len) override; 25 | std::string GetCurrTokenString(size_t prev_string_size, 26 | const std::vector &tokens); 27 | 28 | int multiple_of; 29 | 30 | EmbeddingLayer embedding_layer; 31 | CausualMaskLayer causual_mask_layer; 32 | std::vector transformer_layers; 33 | RMSNormLayer last_norm; 34 | MatmulLayer last_mm; 35 | ArgMaxLayer argmax_layer; 36 | SampleTopPLayer top_p_layer; 37 | }; 38 | -------------------------------------------------------------------------------- /src/model_base.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "acl_util.hpp" 10 | 11 | #include "defs.hpp" 12 | #include "device.hpp" 13 | #include "profiling.hpp" 14 | #include "tokenizer.hpp" 15 | 16 | enum class QuantType { NoQuant, AWQ_4B }; 17 | 18 | class ModelConfig { 19 | public: 20 | std::string tok_path; 21 | std::string model_path; 22 | std::string config_path; 23 | nlohmann::json config; 24 | std::string model_type; 25 | DeviceType device_type; 26 | DataType data_type; 27 | int max_seq_len; 28 | int max_gen_len; 29 | float norm_eps; 30 | float temperature{0.0f}; 31 | float top_p{0.0f}; 32 | QuantType q_type{QuantType::NoQuant}; 33 | int quant_group_size{-1}; 34 | bool rope_is_neox_style; 35 | bool debug_print{false}; 36 | }; 37 | 38 | class ModelBase; 39 | 40 | class InferenceCtx { 41 | public: 42 | InferenceCtx(ModelBase *model, size_t cur_pos, size_t prev_pos); 43 | ModelBase *model; 44 | int cur_pos; 45 | int prev_pos; 46 | int cur_size; 47 | aclrtStream npu_stream{nullptr}; 48 | }; 49 | 50 | class EmbeddingLayerImpl { 51 | public: 52 | virtual ~EmbeddingLayerImpl(); 53 | // N -> [N, hidden_dim] 54 | virtual std::shared_ptr Forward(std::shared_ptr input, 55 | InferenceCtx &ctx) = 0; 56 | virtual bool Init(ModelBase *model, const std::string& weight_path); 57 | 58 | virtual void UnInit(); 59 | 60 | size_t nwords_size; 61 | size_t weight_size; 62 | size_t hidden_dim; 63 | uint8_t *embedding_weight{nullptr}; 64 | }; 65 | 66 | class EmbeddingLayer { 67 | public: 68 | ~EmbeddingLayer(); 69 | // N -> [N, hidden_dim] 70 | std::shared_ptr Forward(std::shared_ptr input, 71 | InferenceCtx &ctx); 72 | bool Init(ModelBase *model, const std::string& weight_path); 73 | 74 | void UnInit(); 75 | 76 | EmbeddingLayerImpl *impl{nullptr}; 77 | }; 78 | 79 | class RMSNormLayerImpl { 80 | public: 81 | virtual ~RMSNormLayerImpl(); 82 | virtual std::shared_ptr Forward(std::shared_ptr input, 83 | InferenceCtx &ctx) = 0; 84 | virtual bool Init(ModelBase *model, int layer_no, bool pre_norm, 85 | bool last_norm); 86 | virtual bool Init(ModelBase *model, const std::string& weight_path); 87 | virtual void UnInit(); 88 | 89 | size_t hidden_dim; 90 | size_t weight_size; 91 | float eps; 92 | uint8_t *norm_weight{nullptr}; 93 | }; 94 | 95 | class RMSNormLayer { 96 | public: 97 | ~RMSNormLayer(); 98 | virtual std::shared_ptr Forward(std::shared_ptr 
input, 99 | InferenceCtx &ctx); 100 | bool Init(ModelBase *model, int layer_no, bool pre_norm, bool last_norm); 101 | bool Init(ModelBase *model, const std::string& weight_path); 102 | void UnInit(); 103 | RMSNormLayerImpl *impl{nullptr}; 104 | }; 105 | 106 | class RoPELayerImpl { 107 | public: 108 | virtual ~RoPELayerImpl(); 109 | virtual std::tuple, std::shared_ptr> 110 | Forward(std::shared_ptr input_q, std::shared_ptr input_k, 111 | InferenceCtx &ctx) = 0; 112 | virtual bool Init(ModelBase *model, const std::string &weight_path); 113 | virtual void UnInit(); 114 | 115 | size_t hidden_dim; 116 | size_t n_heads; 117 | size_t head_dim; 118 | size_t rope_dim; 119 | size_t weight_size; 120 | bool rope_is_neox_style; 121 | float *freqs_cis{nullptr}; 122 | }; 123 | 124 | class RoPELayer { 125 | public: 126 | ~RoPELayer(); 127 | virtual std::tuple, std::shared_ptr> 128 | Forward(std::shared_ptr input_q, std::shared_ptr input_k, 129 | InferenceCtx &ctx); 130 | bool Init(ModelBase *model, const std::string &weight_path); 131 | void UnInit(); 132 | RoPELayerImpl *impl{nullptr}; 133 | }; 134 | 135 | class ArgMaxLayerImpl { 136 | public: 137 | virtual ~ArgMaxLayerImpl(); 138 | virtual std::shared_ptr Forward(std::shared_ptr input, 139 | InferenceCtx &ctx) = 0; 140 | virtual bool Init(ModelBase *model); 141 | virtual void UnInit(); 142 | 143 | size_t hidden_dim; 144 | DataType dt; 145 | }; 146 | 147 | class ArgMaxLayer { 148 | public: 149 | ~ArgMaxLayer(); 150 | virtual std::shared_ptr Forward(std::shared_ptr input, 151 | InferenceCtx &ctx); 152 | bool Init(ModelBase *model); 153 | void UnInit(); 154 | ArgMaxLayerImpl *impl{nullptr}; 155 | }; 156 | 157 | class SampleTopPLayerImpl { 158 | public: 159 | virtual ~SampleTopPLayerImpl(); 160 | virtual int Forward(std::shared_ptr input, InferenceCtx &ctx) = 0; 161 | virtual bool Init(ModelBase *model); 162 | virtual void UnInit(); 163 | 164 | float temperature; 165 | float top_p; 166 | size_t vocab_size; 167 | }; 168 | 169 | class SampleTopPLayer { 170 | public: 171 | ~SampleTopPLayer(); 172 | virtual int Forward(std::shared_ptr input, InferenceCtx &ctx); 173 | bool Init(ModelBase *model); 174 | void UnInit(); 175 | SampleTopPLayerImpl *impl{nullptr}; 176 | }; 177 | 178 | class SoftmaxLayerImpl { 179 | public: 180 | virtual ~SoftmaxLayerImpl(); 181 | virtual std::shared_ptr Forward(std::shared_ptr input, 182 | InferenceCtx &ctx) = 0; 183 | virtual bool Init(ModelBase *model); 184 | virtual void UnInit(); 185 | 186 | size_t hidden_dim; 187 | size_t n_heads; 188 | float eps; 189 | }; 190 | 191 | class SoftmaxLayer { 192 | public: 193 | ~SoftmaxLayer(); 194 | virtual std::shared_ptr Forward(std::shared_ptr input, 195 | InferenceCtx &ctx); 196 | bool Init(ModelBase *model); 197 | void UnInit(); 198 | SoftmaxLayerImpl *impl{nullptr}; 199 | }; 200 | 201 | class CausualMaskLayerImpl { 202 | public: 203 | virtual ~CausualMaskLayerImpl(); 204 | virtual std::shared_ptr Forward(InferenceCtx &ctx) = 0; 205 | virtual bool Init(ModelBase *model); 206 | virtual void UnInit(); 207 | }; 208 | 209 | class CausualMaskLayer { 210 | public: 211 | ~CausualMaskLayer(); 212 | virtual std::shared_ptr Forward(InferenceCtx &ctx); 213 | bool Init(ModelBase *model); 214 | void UnInit(); 215 | CausualMaskLayerImpl *impl{nullptr}; 216 | }; 217 | 218 | class MatmulLayerImpl { 219 | public: 220 | virtual ~MatmulLayerImpl(); 221 | virtual std::shared_ptr Forward(std::shared_ptr input, 222 | InferenceCtx &ctx) = 0; 223 | virtual bool Init(ModelBase *model, const std::string &weight_path, 
size_t n, 224 | size_t k); 225 | virtual bool InitWithBias(ModelBase *model, const std::string &weight_path, 226 | const std::string &bias, size_t n, size_t k); 227 | virtual bool InitAWQ(ModelBase *model, const std::string &weight_path, 228 | const std::string &zero_path, 229 | const std::string &scale_path, size_t n, size_t k, 230 | QuantType quant_type); 231 | virtual bool AddBias(const std::string &bias_path); 232 | virtual void UnInit(); 233 | 234 | size_t n; 235 | size_t k; 236 | size_t weight_size; 237 | size_t bias_size; 238 | size_t zero_size; 239 | size_t scale_size; 240 | uint8_t *weight{nullptr}; 241 | uint8_t *qzeros{nullptr}; 242 | uint8_t *qscales{nullptr}; 243 | uint8_t *bias{nullptr}; 244 | 245 | DataType dtype; 246 | QuantType qtype{QuantType::NoQuant}; 247 | }; 248 | 249 | class MatmulLayer { 250 | public: 251 | ~MatmulLayer(); 252 | std::shared_ptr Forward(std::shared_ptr input, 253 | InferenceCtx &ctx); 254 | bool Init(ModelBase *model, const std::string &weight_path, size_t n, 255 | size_t k); 256 | bool InitWithBias(ModelBase *model, const std::string &weight_path, 257 | const std::string &bias, size_t n, size_t k); 258 | bool InitAWQ(ModelBase *model, const std::string &weight_path, 259 | const std::string &zero_path, const std::string &scale_path, 260 | size_t n, size_t k, QuantType quant_type); 261 | bool AddBias(const std::string &bias_path); 262 | void UnInit(); 263 | MatmulLayerImpl *impl{nullptr}; 264 | }; 265 | 266 | class Llamma2TransformerLayerImpl { 267 | public: 268 | virtual ~Llamma2TransformerLayerImpl() = default; 269 | virtual std::shared_ptr Forward(std::shared_ptr input, 270 | std::shared_ptr mask, 271 | InferenceCtx &ctx) = 0; 272 | virtual bool Init(ModelBase *model, int layer_no); 273 | virtual void UnInit(); 274 | 275 | size_t ffn_hidden; 276 | size_t hidden_dim; 277 | size_t head_dim; 278 | size_t n_heads; 279 | size_t max_seq_len; 280 | }; 281 | 282 | class Llamma2TransformerLayer { 283 | public: 284 | ~Llamma2TransformerLayer(); 285 | std::shared_ptr Forward(std::shared_ptr input, 286 | std::shared_ptr mask, 287 | InferenceCtx &ctx); 288 | bool Init(ModelBase *model, int layer_no); 289 | void UnInit(); 290 | 291 | Llamma2TransformerLayerImpl *impl{nullptr}; 292 | }; 293 | 294 | class Qwen2TransformerLayerImpl { 295 | public: 296 | virtual ~Qwen2TransformerLayerImpl() = default; 297 | virtual std::shared_ptr Forward(std::shared_ptr input, 298 | std::shared_ptr mask, 299 | InferenceCtx &ctx) = 0; 300 | virtual bool Init(ModelBase *model, int layer_no); 301 | virtual void UnInit(); 302 | 303 | DataType dtype; 304 | int ffn_hidden; 305 | int hidden_dim; 306 | int head_dim; 307 | int n_heads; 308 | int n_kv_heads; 309 | int max_seq_len; 310 | }; 311 | 312 | class Qwen2TransformerLayer { 313 | public: 314 | ~Qwen2TransformerLayer(); 315 | std::shared_ptr Forward(std::shared_ptr input, 316 | std::shared_ptr mask, 317 | InferenceCtx &ctx); 318 | bool Init(ModelBase *model, int layer_no); 319 | void UnInit(); 320 | 321 | Qwen2TransformerLayerImpl *impl{nullptr}; 322 | }; 323 | 324 | class ModelBase { 325 | public: 326 | ModelBase() = default; 327 | virtual ~ModelBase() = default; 328 | 329 | virtual bool Init() = 0; 330 | 331 | bool InitFreqCIS(const float theta=10000.0f, const DataType dt=DT_FLOAT32); 332 | 333 | virtual void Chat(const std::string &input_seq, const std::string &reverse_prompt) = 0; 334 | virtual void TextCompletion(const std::string &input_seq) = 0; 335 | virtual void Benchmark(int input_seq_len, int output_seq_len) = 0; 336 | 337 | 
338 | ModelConfig config; 339 | int hidden_dim; 340 | int head_dim; 341 | int n_heads; 342 | int n_kv_heads; 343 | int n_layers; 344 | float norm_eps; 345 | 346 | int n_words; 347 | int pad_id; 348 | 349 | aclrtStream model_stream; 350 | 351 | float *freq_cis{nullptr}; 352 | bool is_profiling{false}; 353 | AppProfiler profiler; 354 | }; 355 | -------------------------------------------------------------------------------- /src/profiling.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include "profiling.hpp" 8 | 9 | 10 | 11 | bool AppProfiler::StartLogging(const std::string &log_path) { 12 | profile_file.open(log_path.c_str()); 13 | if (!profile_file) { 14 | return false; 15 | } 16 | profile_file << "[\n"; 17 | first_record = true; 18 | auto start_tp = std::chrono::steady_clock::now(); 19 | auto start_us_tp = std::chrono::time_point_cast(start_tp); 20 | start_us_count = start_us_tp.time_since_epoch().count(); 21 | return true; 22 | } 23 | 24 | void AppProfiler::RecordEvent(json jevent) { 25 | if (first_record) { 26 | first_record = false; 27 | } else { 28 | profile_file << ",\n"; 29 | } 30 | jevent["ts"] = jevent["ts"].get() - start_us_count; 31 | profile_file << jevent; 32 | } 33 | 34 | void AppProfiler::Finish() { 35 | profile_file << "]\n"; 36 | profile_file.close(); 37 | } 38 | 39 | AppProfileGuard::AppProfileGuard(const char *name, 40 | const char *info, 41 | aclrtStream stream, 42 | AppProfiler* profiler, 43 | const char *fname, int lineno, 44 | bool is_profiling) 45 | : record_name(name), record_info(info), record_file_name(fname), record_file_lineno(lineno), 46 | stream(stream), profiler(profiler), is_profiling(is_profiling) { 47 | if (is_profiling) { 48 | CHECK_ACL(aclrtCreateEvent(&event)); 49 | CHECK_ACL(aclrtRecordEvent(event, stream)); 50 | AddBeginRecord(); 51 | } 52 | } 53 | 54 | AppProfileGuard::~AppProfileGuard() { 55 | if (is_profiling) { 56 | AddEndRecord(); 57 | } 58 | } 59 | 60 | void AppProfileGuard::AddBeginRecord() { 61 | auto current_tp = std::chrono::steady_clock::now(); 62 | start_us = 63 | std::chrono::time_point_cast(current_tp); 64 | } 65 | 66 | void AppProfileGuard::AddEndRecord() { 67 | AddRecord(record_name.c_str(), record_info.c_str(), stream, profiler, record_file_name, record_file_lineno); 68 | } 69 | 70 | void AppProfileGuard::AddRecord(const char *name, const char *info, aclrtStream stream, AppProfiler* profiler, const char *fname, int lineno) const { 71 | aclrtEvent end_event; 72 | CHECK_ACL(aclrtCreateEvent(&end_event)); 73 | CHECK_ACL(aclrtRecordEvent(end_event, stream)); 74 | CHECK_ACL(aclrtSynchronizeStream(stream)); 75 | float duration_ms; 76 | CHECK_ACL(aclrtEventElapsedTime(&duration_ms, event, end_event)); 77 | CHECK_ACL(aclrtDestroyEvent(event)); 78 | CHECK_ACL(aclrtDestroyEvent(end_event)); 79 | 80 | auto current_tp = std::chrono::steady_clock::now(); 81 | auto us_tp = 82 | std::chrono::time_point_cast(current_tp); 83 | auto start_count = start_us.time_since_epoch().count(); 84 | auto end_count = us_tp.time_since_epoch().count(); 85 | 86 | json record{{"name", name}, 87 | {"ph", "X"}, 88 | {"pid", 0}, 89 | {"ts", start_count}, 90 | {"dur", duration_ms * 1000}, 91 | {"args", {{"file", fname}, {"lineno", lineno}, {"info", info}}}}; 92 | 93 | profiler->RecordEvent(record); 94 | } -------------------------------------------------------------------------------- /src/profiling.hpp: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "acl_util.hpp" 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | 10 | using json = nlohmann::json; 11 | 12 | class AppProfiler { 13 | public: 14 | bool StartLogging(const std::string &log_path); 15 | void RecordEvent(json jevent); 16 | void Finish(); 17 | private: 18 | std::ofstream profile_file; 19 | bool first_record; 20 | uint64_t start_us_count; 21 | }; 22 | 23 | class AppProfileGuard { 24 | public: 25 | AppProfileGuard(const char *name, const char *info, aclrtStream stream, AppProfiler* profiler, const char *fname, int lineno, bool is_profiling); 26 | ~AppProfileGuard(); 27 | void AddBeginRecord(); 28 | void AddEndRecord(); 29 | 30 | private: 31 | void AddRecord(const char *name, const char *info, aclrtStream stream, AppProfiler* profiler, const char *fname, int lineno) const; 32 | std::string record_name; 33 | std::string record_info; 34 | aclrtStream stream; 35 | AppProfiler* profiler; 36 | const char *record_file_name; 37 | int record_file_lineno; 38 | bool is_profiling; 39 | std::chrono::time_point 40 | start_us; 41 | aclrtEvent event; 42 | }; 43 | 44 | #define _CONCAT_(x, y) x##y 45 | #define __CONCAT__(x, y) _CONCAT_(x, y) 46 | 47 | #define APP_PROFILE(name, info, stream, profiler, isprofiling) \ 48 | AppProfileGuard __CONCAT__(temp_perf_obj_, __LINE__)(name, info, stream, profiler, __FILE__, __LINE__, isprofiling) 49 | -------------------------------------------------------------------------------- /src/qwen2_model.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "acl_util.hpp" 11 | #include "device.hpp" 12 | #include "model_base.hpp" 13 | #include "profiling.hpp" 14 | #include "tokenizer.hpp" 15 | #include "util.h" 16 | 17 | // qwen.cpp/qwen.h 18 | class BaseStreamer { 19 | public: 20 | virtual ~BaseStreamer() = default; 21 | virtual auto put(const std::vector &output_ids) -> void = 0; 22 | virtual auto end() -> void = 0; 23 | }; 24 | 25 | class StreamerGroup : public BaseStreamer { 26 | public: 27 | StreamerGroup(std::vector> streamers) 28 | : streamers_(std::move(streamers)) {} 29 | auto put(const std::vector &output_ids) -> void override; 30 | auto end() -> void override; 31 | 32 | private: 33 | std::vector> streamers_; 34 | }; 35 | 36 | class TextStreamer : public BaseStreamer { 37 | public: 38 | TextStreamer(std::ostream &os, Qwen2HFTokenizer *tokenizer) 39 | : os_(os), tokenizer_(tokenizer), is_prompt_(true), print_len_(0) {} 40 | auto put(const std::vector &output_ids) -> void override; 41 | auto end() -> void override; 42 | 43 | private: 44 | std::ostream &os_; 45 | Qwen2HFTokenizer *tokenizer_; 46 | bool is_prompt_; 47 | std::vector token_cache_; 48 | int print_len_; 49 | }; 50 | 51 | class PerfStreamer : public BaseStreamer { 52 | public: 53 | PerfStreamer() 54 | : start_us_(0), prompt_us_(0), end_us_(0), num_prompt_tokens_(0), 55 | num_output_tokens_(0) {} 56 | 57 | auto put(const std::vector &output_ids) -> void override; 58 | auto end() -> void override { end_us_ = get_current_us(); } 59 | 60 | auto reset() -> void; 61 | auto to_string() -> std::string const; 62 | 63 | auto num_prompt_tokens() const -> int64_t { return num_prompt_tokens_; } 64 | auto prompt_total_time_us() const -> int64_t { 65 | return prompt_us_ - start_us_; 66 | } 67 | auto prompt_token_time_us() const -> int64_t { 68 | 
return num_prompt_tokens() ? prompt_total_time_us() / num_prompt_tokens() 69 | : 0; 70 | } 71 | auto num_output_tokens() const -> int64_t { return num_output_tokens_; } 72 | auto output_total_time_us() const -> int64_t { return end_us_ - prompt_us_; } 73 | auto output_token_time_us() const -> int64_t { 74 | return num_output_tokens() ? output_total_time_us() / num_output_tokens() 75 | : 0; 76 | } 77 | 78 | private: 79 | int64_t start_us_; 80 | int64_t prompt_us_; 81 | int64_t end_us_; 82 | int64_t num_prompt_tokens_; 83 | int64_t num_output_tokens_; 84 | }; 85 | 86 | class Qwen2Model : public ModelBase { 87 | public: 88 | //QwenTokenizer qwen_tokenizer; 89 | Qwen2HFTokenizer qwen_tokenizer; 90 | 91 | virtual bool Init() override; 92 | 93 | void Chat(const std::string &input_seq, 94 | const std::string &reverse_prompt) override; 95 | void TextCompletion(const std::string &input_seq) override; 96 | void Benchmark(int input_seq_len, int output_seq_len) override; 97 | int GenerateNextToken(const std::vector &input_ids, 98 | InferenceCtx &ctx, int n_past); 99 | std::vector Generate(const std::vector &input_tokens, 100 | InferenceCtx &ctx, BaseStreamer *streamer); 101 | std::string Generate(const std::vector &history, 102 | InferenceCtx &ctx, BaseStreamer *streamer); 103 | 104 | bool tie_word_embeddings; 105 | int intermediate_size; 106 | int generate_limit; 107 | EmbeddingLayer embedding_layer; 108 | CausualMaskLayer causual_mask_layer; 109 | std::vector transformer_layers; 110 | RMSNormLayer last_norm; 111 | MatmulLayer last_mm; 112 | ArgMaxLayer argmax_layer; 113 | SampleTopPLayer top_p_layer; 114 | }; 115 | -------------------------------------------------------------------------------- /src/tiktoken.h: -------------------------------------------------------------------------------- 1 | // from https://github.com/QwenLM/qwen.cpp/blob/master/tiktoken.h 2 | #pragma once 3 | 4 | #include 5 | #include "unordered_dense.h" 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | namespace tiktoken { 17 | 18 | static auto _byte_pair_merge( 19 | const std::string &piece, 20 | const ankerl::unordered_dense::map &ranks, 21 | std::function func 22 | ) -> std::vector { 23 | std::vector> parts; 24 | parts.reserve(piece.size() + 1); 25 | for (auto idx = 0U; idx < piece.size() + 1; ++idx) { 26 | parts.emplace_back(idx, std::numeric_limits::max()); 27 | } 28 | 29 | auto get_rank = [&piece, &ranks]( 30 | const std::vector> &parts, 31 | int start_idx, 32 | int skip 33 | ) -> std::optional { 34 | if (start_idx + skip + 2 < parts.size()) { 35 | auto s = parts[start_idx].first; 36 | auto e = parts[start_idx + skip + 2].first; 37 | auto key = piece.substr(s, e - s); 38 | auto iter = ranks.find(key); 39 | if (iter != ranks.end()) { 40 | return iter->second; 41 | } 42 | } 43 | return std::nullopt; 44 | }; 45 | 46 | for (auto i = 0U; i < parts.size() - 2; ++i) { 47 | auto rank = get_rank(parts, i, 0); 48 | if (rank) { 49 | assert(*rank != std::numeric_limits::max()); 50 | parts[i].second = *rank; 51 | } 52 | } 53 | 54 | while (true) { 55 | if (parts.size() == 1) break; 56 | 57 | auto min_rank = std::make_pair(std::numeric_limits::max(), 0); 58 | for (auto i = 0U; i < parts.size() - 1; ++i) { 59 | auto rank = parts[i].second; 60 | if (rank < min_rank.first) { 61 | min_rank = { rank, i }; 62 | } 63 | } 64 | 65 | if (min_rank.first != std::numeric_limits::max()) { 66 | auto i = min_rank.second; 67 | auto rank = get_rank(parts, i, 1); 68 | if (rank) { 69 
| parts[i].second = *rank; 70 | } else { 71 | parts[i].second = std::numeric_limits::max(); 72 | } 73 | if (i > 0) { 74 | auto rank = get_rank(parts, i - 1, 1); 75 | if (rank) { 76 | parts[i - 1].second = *rank; 77 | } else { 78 | parts[i - 1].second = std::numeric_limits::max(); 79 | } 80 | } 81 | 82 | parts.erase(parts.begin() + (i + 1)); 83 | } else { 84 | break; 85 | } 86 | } 87 | std::vector out; 88 | out.reserve(parts.size() - 1); 89 | for (auto i = 0U; i < parts.size() - 1; ++i) { 90 | out.push_back(func(parts[i].first, parts[i + 1].first)); 91 | } 92 | return out; 93 | } 94 | 95 | static auto byte_pair_encode( 96 | const std::string &piece, 97 | const ankerl::unordered_dense::map &ranks 98 | ) -> std::vector { 99 | if (piece.size() == 1) { 100 | return {ranks.at(piece)}; 101 | } 102 | 103 | auto func = [&piece, &ranks](int start, int stop) -> int { 104 | std::string key = piece.substr(start, stop - start); 105 | return ranks.at(key); 106 | }; 107 | 108 | return _byte_pair_merge(piece, ranks, func); 109 | } 110 | 111 | class tiktoken { 112 | public: 113 | tiktoken() = default; 114 | tiktoken( 115 | ankerl::unordered_dense::map encoder, 116 | ankerl::unordered_dense::map special_encoder, 117 | const std::string &pattern 118 | ) { 119 | regex_ = std::make_unique("(" + pattern + ")"); 120 | 121 | std::string special_pattern; 122 | for (const auto &item : special_encoder) { 123 | if (!special_pattern.empty()) { 124 | special_pattern += "|"; 125 | } 126 | special_pattern += re2::RE2::QuoteMeta(item.first); 127 | } 128 | if (special_pattern.empty()) { 129 | special_regex_ = nullptr; 130 | } else { 131 | special_regex_ = std::make_unique("(" + special_pattern + ")"); 132 | } 133 | 134 | encoder_ = std::move(encoder); 135 | special_tokens_encoder = std::move(special_encoder); 136 | 137 | for (const auto &[k, v] : encoder_) { 138 | decoder_.emplace(v, k); 139 | } 140 | //assert(encoder_.size() != decoder_.size() && "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"); 141 | 142 | for (const auto &[k, v] : special_tokens_encoder) { 143 | special_tokens_decoder.emplace(v, k); 144 | } 145 | } 146 | 147 | auto encode_ordinary(const std::string &text) const -> std::vector { 148 | return _encode_ordinary_native(text); 149 | } 150 | 151 | auto encode(const std::string &text) const -> std::vector { 152 | return _encode_native(text, special_tokens_encoder).first; 153 | } 154 | 155 | auto encode_single_piece(const std::string &text) const -> std::vector { 156 | auto iter = encoder_.find(text); 157 | if (iter != encoder_.end()) { 158 | return {iter->second}; 159 | } 160 | return byte_pair_encode(text, encoder_); 161 | } 162 | 163 | auto decode(const std::vector &tokens) const -> std::string { 164 | return _decode_native(tokens); 165 | } 166 | 167 | private: 168 | auto split_with_allowed_special_token( 169 | re2::StringPiece &input, 170 | const ankerl::unordered_dense::map &allowed_special 171 | ) const -> std::pair, re2::StringPiece> { 172 | if (special_regex_ == nullptr) return { std::nullopt, input }; 173 | 174 | auto start = input.begin(); 175 | std::string special; 176 | while (true) { 177 | if (!re2::RE2::FindAndConsume(&input, *special_regex_, &special)) { 178 | break; 179 | } 180 | 181 | if (allowed_special.count(special) == 1) { 182 | return { std::move(special), re2::StringPiece(start, input.begin() - start - special.size()) }; 183 | } 184 | } 185 | 186 | return { std::nullopt, input }; 187 | } 188 | 189 | auto _encode_ordinary_native(const 
std::string &text) const -> std::vector { 190 | std::vector ret; 191 | re2::StringPiece input(text); 192 | 193 | std::string piece; 194 | while (re2::RE2::FindAndConsume(&input, *regex_, &piece)) { 195 | auto iter = encoder_.find(piece); 196 | if (iter != encoder_.end()) { 197 | ret.push_back(iter->second); 198 | continue; 199 | } 200 | auto tokens = byte_pair_encode(piece, encoder_); 201 | ret.insert(ret.end(), tokens.begin(), tokens.end()); 202 | } 203 | return ret; 204 | } 205 | 206 | auto _encode_native( 207 | const std::string &text, 208 | const ankerl::unordered_dense::map &allowed_special 209 | ) const -> std::pair, int> { 210 | std::vector ret; 211 | int last_piece_token_len = 0; 212 | re2::StringPiece input(text); 213 | 214 | while (true) { 215 | auto [special, sub_input] = split_with_allowed_special_token(input, allowed_special); 216 | std::string piece; 217 | while (re2::RE2::FindAndConsume(&sub_input, *regex_, &piece)) { 218 | auto iter = encoder_.find(piece); 219 | if (iter != encoder_.end()) { 220 | last_piece_token_len = 1; 221 | ret.push_back(iter->second); 222 | continue; 223 | } 224 | auto tokens = byte_pair_encode(piece, encoder_); 225 | last_piece_token_len = tokens.size(); 226 | ret.insert(ret.end(), tokens.begin(), tokens.end()); 227 | } 228 | 229 | if (special) { 230 | int token = special_tokens_encoder.at(*special); 231 | ret.push_back(token); 232 | last_piece_token_len = 0; 233 | } else { 234 | break; 235 | } 236 | } 237 | 238 | return { ret, last_piece_token_len }; 239 | } 240 | 241 | auto _decode_native(const std::vector &tokens) const -> std::string { 242 | std::string ret; 243 | ret.reserve(tokens.size() * 2); 244 | for (auto token : tokens) { 245 | std::string token_bytes; 246 | auto iter = decoder_.find(token); 247 | if (iter != decoder_.end()) { 248 | token_bytes = iter->second; 249 | } else { 250 | iter = special_tokens_decoder.find(token); 251 | if (iter != special_tokens_decoder.end()) { 252 | token_bytes = iter->second; 253 | } else { 254 | throw std::runtime_error("unknown token: " + std::to_string(token)); 255 | } 256 | } 257 | ret += token_bytes; 258 | } 259 | return ret; 260 | } 261 | 262 | ankerl::unordered_dense::map encoder_; 263 | ankerl::unordered_dense::map special_tokens_encoder; 264 | ankerl::unordered_dense::map decoder_; 265 | ankerl::unordered_dense::map special_tokens_decoder; 266 | std::unique_ptr regex_; 267 | std::unique_ptr special_regex_; 268 | }; 269 | 270 | } // namespace tiktoken 271 | -------------------------------------------------------------------------------- /src/tokenizer.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "base64.h" 6 | #include "tokenizer.hpp" 7 | 8 | bool Tokenizer::Init(const std::string &token_model_path) { 9 | const auto status = processor.Load(token_model_path); 10 | if (!status.ok()) { 11 | spdlog::critical("failed to init tokenizer from path {}", token_model_path); 12 | return false; 13 | } 14 | 15 | n_words = processor.GetPieceSize(); 16 | bos_id = processor.bos_id(); 17 | eos_id = processor.eos_id(); 18 | pad_id = processor.pad_id(); 19 | 20 | spdlog::info("initialized tokenizer from {}, nwords: {}, bos_id: {}, eos_id: " 21 | "{}, pad_id: {}", 22 | token_model_path, n_words, bos_id, eos_id, pad_id); 23 | 24 | return true; 25 | } 26 | 27 | std::vector Tokenizer::Encode(const std::string &text, bool bos, 28 | bool eos) { 29 | std::vector result; 30 | processor.Encode(text, &result); 31 | if (bos) { 32 | 
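// the sentencepiece processor returns bare piece ids, so the model's BOS/EOS ids are only added here when the caller requests them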
result.insert(result.begin(), bos_id); 33 | } 34 | if (eos) { 35 | result.push_back(eos_id); 36 | } 37 | return result; 38 | } 39 | 40 | std::string Tokenizer::Decode(const std::vector &ids) { 41 | std::string result; 42 | auto status = processor.Decode(ids, &result); 43 | if (!status.ok()) { 44 | spdlog::critical("failed to Decode {}", status.error_message()); 45 | } 46 | return result; 47 | } 48 | 49 | static std::pair _parse(const std::string &line) { 50 | auto pos = line.find(" "); 51 | if (pos == std::string::npos) { 52 | throw std::runtime_error("invalid encoder line: " + line); 53 | } 54 | 55 | auto token = base64::decode({line.data(), pos}); 56 | int rank = 0; 57 | try { 58 | rank = std::stoul(line.substr(pos + 1)); 59 | } catch (const std::exception &) { 60 | throw std::runtime_error("invalid encoder rank: " + line); 61 | } 62 | 63 | return {std::move(token), rank}; 64 | } 65 | 66 | void QwenTokenizer::Init(const std::string &tiktoken_path) { 67 | std::ifstream file(tiktoken_path); 68 | if (!file) { 69 | throw std::runtime_error("failed to open encoder file: " + tiktoken_path); 70 | } 71 | 72 | ankerl::unordered_dense::map encoder; 73 | std::string line; 74 | while (std::getline(file, line)) { 75 | auto [token, rank] = _parse(line); 76 | 77 | if (!encoder.emplace(std::move(token), rank).second) { 78 | throw std::runtime_error("duplicate item: " + line); 79 | } 80 | } 81 | 82 | std::vector special_tokens_s{"<|endoftext|>", "<|im_start|>", 83 | "<|im_end|>"}; 84 | char buffer[14]; 85 | for (size_t i = 0; i < 205; i++) { 86 | snprintf(buffer, 14, "<|extra_%zu|>", i); 87 | special_tokens_s.push_back(buffer); 88 | } 89 | size_t encoder_size = encoder.size(); 90 | ankerl::unordered_dense::map special_tokens; 91 | special_tokens.reserve(special_tokens_s.size()); 92 | for (size_t i = 0; i < special_tokens_s.size(); i++) { 93 | special_tokens[special_tokens_s[i]] = encoder_size + i; 94 | } 95 | 96 | tokenizer = tiktoken::tiktoken(std::move(encoder), special_tokens, PAT_STR); 97 | eos_token_id = encoder_size + 0; 98 | im_start_id = encoder_size + 1; 99 | im_end_id = encoder_size + 2; 100 | } 101 | 102 | auto QwenTokenizer::build_prompt(const std::vector &history) const 103 | -> std::string { 104 | if (!(history.size() % 2 == 1)) { 105 | spdlog::critical("invalid history size {}", history.size()); 106 | } 107 | 108 | std::ostringstream oss_prompt; 109 | oss_prompt << "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"; 110 | for (size_t i = 0; i < history.size() - 1; i += 2) { 111 | oss_prompt << "\n<|im_start|>user\n" 112 | << history[i] << "<|im_end|>\n<|im_start|>" << history[i + 1] 113 | << "<|im_end|>"; 114 | } 115 | oss_prompt << "\n<|im_start|>user\n" 116 | << history.back() << "<|im_end|>\n<|im_start|>assistant\n"; 117 | 118 | return oss_prompt.str(); 119 | } 120 | 121 | auto QwenTokenizer::encode(const std::string &text, int max_length) const 122 | -> std::vector { 123 | auto ids = tokenizer.encode(text); 124 | if ((int)ids.size() > max_length) { 125 | ids.erase(ids.begin(), ids.end() - max_length); 126 | } 127 | return ids; 128 | } 129 | 130 | auto QwenTokenizer::decode(const std::vector &ids) const -> std::string { 131 | std::vector normal_ids(ids); 132 | normal_ids.erase(std::remove_if(normal_ids.begin(), normal_ids.end(), 133 | [this](int id) { return is_special_id(id); }), 134 | normal_ids.end()); 135 | auto text = tokenizer.decode(normal_ids); 136 | return text; 137 | } 138 | 139 | auto QwenTokenizer::encode_history(const std::vector &history, 140 | int max_length) 
const -> std::vector { 141 | std::string prompt = build_prompt(history); 142 | std::vector input_ids = encode(prompt, max_length); 143 | return input_ids; 144 | } 145 | 146 | auto QwenTokenizer::is_special_id(int id) const -> bool { 147 | return id == eos_token_id || id == im_start_id || id == im_end_id; 148 | } 149 | 150 | void Qwen2HFTokenizer::from_pretrained(const std::string &tokenizer_dir) { 151 | py_transformers_module = 152 | PyImport_ImportModule("transformers"); 153 | if (py_transformers_module == nullptr) { 154 | PyErr_Print(); 155 | throw std::exception(); 156 | } 157 | 158 | 159 | 160 | nlohmann::json tokenizer_config; 161 | 162 | auto cfg_json = boost::filesystem::path(tokenizer_dir) / "tokenizer_config.json"; 163 | 164 | 165 | std::ifstream config_fs(cfg_json.c_str()); 166 | if (!config_fs) { 167 | spdlog::error("failed to open tokenizer config {}", cfg_json.c_str()); 168 | throw std::exception(); 169 | } 170 | 171 | config_fs >> tokenizer_config; 172 | 173 | auto tokenizer_class = tokenizer_config["tokenizer_class"].get(); 174 | spdlog::info("using tokenizer_class {}", tokenizer_class); 175 | 176 | py_tokenizer_clz = 177 | PyObject_GetAttrString(py_transformers_module, tokenizer_class.c_str()); 178 | if (py_tokenizer_clz == nullptr) { 179 | PyErr_Print(); 180 | throw std::exception(); 181 | } 182 | 183 | PyObject *init_args = PyTuple_New(1); 184 | PyTuple_SetItem(init_args, 0, PyUnicode_FromString(tokenizer_dir.c_str())); 185 | 186 | PyObject *kwargs = PyDict_New(); 187 | PyDict_SetItemString(kwargs, "trust_remote_code", Py_True); 188 | PyDict_SetItemString(kwargs, "local_files_only", Py_True); 189 | 190 | PyObject *py_tokenizer = 191 | PyObject_Call(PyObject_GetAttrString(py_tokenizer_clz, "from_pretrained"), 192 | init_args, kwargs); 193 | 194 | Py_DECREF(init_args); 195 | Py_DECREF(kwargs); 196 | 197 | if (py_tokenizer == nullptr) { 198 | PyErr_Print(); 199 | throw std::exception(); 200 | } 201 | 202 | py_encode_func = PyObject_GetAttrString(py_tokenizer, "encode"); 203 | if (py_encode_func == nullptr) { 204 | PyErr_Print(); 205 | throw std::exception(); 206 | } 207 | 208 | py_decode_func = PyObject_GetAttrString(py_tokenizer, "decode"); 209 | if (py_decode_func == nullptr) { 210 | PyErr_Print(); 211 | throw std::exception(); 212 | } 213 | 214 | PyObject* py_eos = PyObject_GetAttrString(py_tokenizer, "eos_token"); 215 | if (py_eos == nullptr) { 216 | PyErr_Print(); 217 | throw std::exception(); 218 | } 219 | std::string eos_str = PyUnicode_AsUTF8(py_eos); 220 | 221 | nlohmann::json js_tok; 222 | std::ifstream tok_fs( 223 | (boost::filesystem::path(tokenizer_dir) / "tokenizer.json").c_str()); 224 | tok_fs >> js_tok; 225 | auto add_tokens = js_tok["added_tokens"]; 226 | 227 | for (const auto &d : add_tokens) { 228 | if (d["content"].get() == eos_str) { 229 | eos_token_id = d["id"].get(); 230 | } 231 | if (d["content"].get() == "<|im_start|>") { 232 | im_start_id = d["id"].get(); 233 | } 234 | if (d["content"].get() == "<|im_end|>") { 235 | im_end_id = d["id"].get(); 236 | } 237 | } 238 | } 239 | 240 | std::vector Qwen2HFTokenizer::encode(const std::string &text, 241 | int max_length) const { 242 | PyObject *text_args = PyTuple_New(1); 243 | PyTuple_SetItem(text_args, 0, PyUnicode_FromString(text.c_str())); 244 | PyObject *result = PyObject_CallObject(py_encode_func, text_args); 245 | Py_DECREF(text_args); 246 | 247 | if (result == nullptr) { 248 | PyErr_Print(); 249 | throw std::exception(); 250 | } 251 | 252 | std::vector ids; 253 | if (PyList_Check(result)) { 254 | 
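// copy the returned Python list of token ids into a std::vector; PyList_GetItem yields borrowed references, so the items need no extra DECREF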
Py_ssize_t size = PyList_Size(result); 255 | ids.reserve(size); 256 | for (Py_ssize_t i = 0; i < size; ++i) { 257 | PyObject *item = PyList_GetItem(result, i); 258 | ids.push_back(PyLong_AsLong(item)); 259 | } 260 | } 261 | 262 | return ids; 263 | } 264 | 265 | std::string Qwen2HFTokenizer::decode(const std::vector &ids) { 266 | PyObject *id_list = PyList_New(ids.size()); 267 | for (size_t i = 0; i < ids.size(); ++i) { 268 | PyList_SetItem(id_list, i, PyLong_FromLong(ids[i])); 269 | } 270 | 271 | PyObject *id_list_args = PyTuple_New(1); 272 | PyTuple_SetItem(id_list_args, 0, id_list); 273 | 274 | PyObject *result = PyObject_CallObject(py_decode_func, id_list_args); 275 | Py_DECREF(id_list_args); 276 | 277 | if (result == nullptr) { 278 | PyErr_Print(); 279 | throw std::exception(); 280 | } 281 | 282 | if (!PyUnicode_Check(result)) { 283 | Py_DECREF(result); 284 | throw std::exception(); 285 | } 286 | const char *str_result = PyUnicode_AsUTF8(result); 287 | std::string decoded_str(str_result); 288 | Py_DECREF(result); 289 | return decoded_str; 290 | } 291 | 292 | auto Qwen2HFTokenizer::build_prompt( 293 | const std::vector &history) const -> std::string { 294 | if (!(history.size() % 2 == 1)) { 295 | spdlog::critical("invalid history size {}", history.size()); 296 | } 297 | 298 | std::ostringstream oss_prompt; 299 | oss_prompt << "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"; 300 | for (size_t i = 0; i < history.size() - 1; i += 2) { 301 | oss_prompt << "\n<|im_start|>user\n" 302 | << history[i] << "<|im_end|>\n<|im_start|>" << history[i + 1] 303 | << "<|im_end|>"; 304 | } 305 | oss_prompt << "\n<|im_start|>user\n" 306 | << history.back() << "<|im_end|>\n<|im_start|>assistant\n"; 307 | 308 | return oss_prompt.str(); 309 | } 310 | 311 | auto Qwen2HFTokenizer::encode_history(const std::vector &history, 312 | int max_length) const 313 | -> std::vector { 314 | std::string prompt = build_prompt(history); 315 | std::vector input_ids = encode(prompt, max_length); 316 | return input_ids; 317 | } 318 | 319 | auto Qwen2HFTokenizer::is_special_id(int id) const -> bool { 320 | return id == eos_token_id || id == im_start_id || id == im_end_id; 321 | } 322 | -------------------------------------------------------------------------------- /src/tokenizer.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "tiktoken.h" 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | class Tokenizer { 10 | public: 11 | Tokenizer() = default; 12 | ~Tokenizer() = default; 13 | bool Init(const std::string &token_model_path); 14 | std::vector Encode(const std::string &text, bool bos, bool eos); 15 | std::string Decode(const std::vector &ids); 16 | // private: 17 | sentencepiece::SentencePieceProcessor processor; 18 | 19 | int32_t n_words; 20 | int32_t bos_id; 21 | int32_t eos_id; 22 | int32_t pad_id; 23 | }; 24 | 25 | // from https://github.com/QwenLM/qwen.cpp/blob/master/qwen.h 26 | static const std::string PAT_STR = 27 | R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?:$|[^\S])|\s+)"; 28 | 29 | struct QwenConfig { 30 | // common attributes 31 | // ggml_type dtype; 32 | int vocab_size; 33 | int hidden_size; 34 | int num_attention_heads; 35 | int num_kv_heads; 36 | int num_hidden_layers; 37 | int intermediate_size; 38 | // for sequence generation 39 | int max_length; 40 | // for tokenizer 41 | int eos_token_id; 42 | int pad_token_id; 43 | int im_start_id; 44 | int im_end_id; 45 | }; 46 | 
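// Both tokenizer classes below expose the same encode / decode / encode_history interface.
// build_prompt() lays the chat history out in Qwen's ChatML-style format; for a
// single-turn history = {"Hello"} the generated prompt is:
//
//   <|im_start|>system
//   You are a helpful assistant.<|im_end|>
//   <|im_start|>user
//   Hello<|im_end|>
//   <|im_start|>assistant
//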
47 | class QwenTokenizer { 48 | public: 49 | QwenTokenizer() = default; 50 | ~QwenTokenizer() = default; 51 | void Init(const std::string &tiktoken_path); 52 | 53 | auto encode(const std::string &text, int max_length) const 54 | -> std::vector; 55 | 56 | auto decode(const std::vector &ids) const -> std::string; 57 | 58 | auto encode_history(const std::vector &history, 59 | int max_length) const -> std::vector; 60 | 61 | auto build_prompt(const std::vector &history) const 62 | -> std::string; 63 | 64 | auto is_special_id(int id) const -> bool; 65 | 66 | tiktoken::tiktoken tokenizer; 67 | int eos_token_id; 68 | int im_start_id; 69 | int im_end_id; 70 | }; 71 | 72 | class Qwen2HFTokenizer { 73 | public: 74 | Qwen2HFTokenizer() = default; 75 | void from_pretrained(const std::string &tokenizer_dir); 76 | std::vector encode(const std::string &text, int max_length = -1) const; 77 | std::string decode(const std::vector &ids); 78 | 79 | auto encode_history(const std::vector &history, 80 | int max_length) const -> std::vector; 81 | 82 | auto build_prompt(const std::vector &history) const 83 | -> std::string; 84 | 85 | auto is_special_id(int id) const -> bool; 86 | 87 | int eos_token_id; 88 | int im_start_id; 89 | int im_end_id; 90 | 91 | private: 92 | PyObject *py_tokenizer{nullptr}; 93 | PyObject *py_tokenizer_clz{nullptr}; 94 | PyObject *py_transformers_module{nullptr}; 95 | PyObject *py_encode_func{nullptr}; 96 | PyObject *py_decode_func{nullptr}; 97 | }; 98 | -------------------------------------------------------------------------------- /src/util.cpp: -------------------------------------------------------------------------------- 
1 | #include 2 | #include 3 | 4 | #include "util.h" 5 | 6 | bool LoadBinaryFile(const char *path, void *buffer, size_t data_size) { 7 | std::ifstream ifs(path, std::ios::binary); 8 | if (!ifs) { 9 | spdlog::critical("failed to open {}", path); 10 | return false; 11 | } 12 | 13 | ifs.read((char *)buffer, data_size); 14 | 15 | if (ifs.fail()) { 16 | spdlog::critical("failed to read {} of size {}", path, data_size); 17 | return false; 18 | } 19 | 20 | spdlog::debug("loaded binary {} size {}", path, data_size); 21 | 22 | return true; 23 | } 24 | 25 | int64_t get_current_us() { 26 | auto now = std::chrono::system_clock::now(); 27 | 28 | auto micros = std::chrono::duration_cast( 29 | now.time_since_epoch()); 30 | return micros.count(); 31 | } 32 | 33 | uint16_t fp32_to_bfloat16(uint32_t fp32_bits) { 34 | // Step 1: split the FP32 bits into their fields 35 | const uint32_t sign = (fp32_bits >> 31) & 0x1; // sign bit 36 | const uint32_t exponent = (fp32_bits >> 23) & 0xFF; // exponent field 37 | const uint32_t mantissa = fp32_bits & 0x007FFFFF; // raw mantissa (without the implicit leading 1) 38 | 39 | // Step 2: round the mantissa (24 bits including the implicit leading 1) 40 | const uint32_t full_mantissa = 41 | mantissa | 0x00800000; // add the implicit 1 to form the 24-bit significand 42 | const uint32_t trunc_bits = full_mantissa & 0x00FFFFFF; // keep exactly 24 significand bits 43 | 44 | // split into the high 8 bits to keep and the low 16 bits used for rounding 45 | const uint32_t mant_high = (trunc_bits >> 16) & 0xFF; // high 8 bits (including the implicit 1) 46 | const uint32_t mant_low = trunc_bits & 0xFFFF; // low 16 bits drive the rounding decision 47 | 48 | // Round to Nearest, Ties to Even (RNTE) 49 | uint32_t rounded_mant = mant_high; 50 | if (mant_low > 0x8000) { // above the midpoint: round up 51 | rounded_mant += 1; 52 | } else if (mant_low == 0x8000) { // exactly at the midpoint: round to even 53 | if ((mant_high & 0x1) != 0) { // odd, so round up 54 | rounded_mant += 1; 55 | } 56 | } // below the midpoint: truncate 57 | 58 | // Step 3: handle mantissa carry overflow (e.g. 0xFF -> 0x100) 59 | uint32_t new_exponent = exponent; 60 | if (rounded_mant > 0xFF) { // the carry overflowed the mantissa 61 | new_exponent += 1; // bump the exponent 62 | rounded_mant = 0x80; // reset the mantissa to the implicit 1 + 0x00 (i.e. 0x80) 63 | } 64 | 65 | // Step 4: assemble the BFloat16 bits 66 | // if the exponent overflows (>0xFF) the result is infinity (simplified here to saturation) 67 | const uint32_t bf16_exponent = (new_exponent > 0xFF) ? 0xFF : new_exponent; 68 | const uint32_t bf16_mantissa = 69 | (rounded_mant & 0x7F); // keep the low 7 bits (drop the implicit 1) 70 | 71 | return ((sign << 15) | (bf16_exponent << 7) | bf16_mantissa); 72 | } 73 | 74 | bool has_awq_quantization(const nlohmann::json &j) { 75 | return j.contains("quantization_config") && 76 | j["quantization_config"].is_object() && 77 | j["quantization_config"].value("quant_method", "") == "awq"; 78 | } 79 | -------------------------------------------------------------------------------- /src/util.h: -------------------------------------------------------------------------------- 
1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | #include "defs.hpp" 17 | 18 | bool LoadBinaryFile(const char *path, void *buffer, size_t data_size); 19 | 20 | template 21 | std::string TensorShapeToStr(TensorType &tensor) { 22 | auto &d = tensor.dimensions(); 23 | std::stringstream ss; 24 | ss << "("; 25 | for (int i = 0; i < d.size(); ++i) { 26 | ss << d[i]; 27 | if (i < d.size() - 1) { 28 | ss << ","; 29 | } 30 | } 31 | ss << ")"; 32 | return ss.str(); 33 | } 34 | 35 | int64_t get_current_us(); 36 | 37 | uint16_t fp32_to_bfloat16(uint32_t fp32_bits); 38 | 39 | bool has_awq_quantization(const nlohmann::json &j); 40 | 41 | template 42 | std::string 43 | print_tensor_typed(EigenTy *ptr, 44 | const Eigen::array &tensor_dim, 45 | const Eigen::array &print_extend) { 46 | 47 | Eigen::TensorMap< 48 | Eigen::Tensor> 49 | t_map(static_cast(ptr), tensor_dim); 50 | 51 | Eigen::array print_offsets{}; 52 | Eigen::Tensor print_slice = 53 | t_map.slice(print_offsets, print_extend); 54 | std::stringstream ss; 55 | ss << print_slice; 56 | return ss.str(); 57 | } 58 | 59 | template 60 | std::string print_tensor(void *ptr, DataType dtype, 61 | const Eigen::array &tensor_dim, 62 | const Eigen::array &print_extend) { 63 | switch (dtype) { 64 | case DT_FLOAT16: 65 | return print_tensor_typed( 66 | static_cast(ptr), tensor_dim, print_extend); 67 | break; 68 | case DT_BFLOAT16: 69 | return print_tensor_typed( 70 | static_cast(ptr), tensor_dim, print_extend); 71 | break; 72 | case DT_FLOAT32: 73 | return print_tensor_typed(static_cast(ptr), 74 | tensor_dim, print_extend); 75 | break; 76 | default: 77 | break; 78 | } 79 | 80 | return fmt::format("print_tensor unsupported dtype: {}", dtype); 81 | } 82 | 83 | template 84 | std::string print_tensor(void *ptr, DataType dtype, 85 | const std::array &tensor_dim_s64, 86 | const std::array &print_extend_s64) { 87 | Eigen::array tensor_dim; 88 | Eigen::array print_extend; 89 | std::transform(tensor_dim_s64.begin(), tensor_dim_s64.end(), 90 | tensor_dim.begin(), 91 | [](size_t x) { return static_cast(x); }); 92 | std::transform(print_extend_s64.begin(), print_extend_s64.end(), 93 | print_extend.begin(), 94 | [](size_t x) { return static_cast(x); }); 95 | return print_tensor<2>(ptr, dtype, tensor_dim, print_extend); 96 | } 97 | 98 | 99 | -------------------------------------------------------------------------------- /tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | if (DEFINED ACL_PATH) 2 | message(STATUS "user set ACL_PATH: ${ACL_PATH}") 3 | elseif (EXISTS /usr/local/Ascend/acllib/) 4 | set(ACL_PATH "/usr/local/Ascend/acllib") 5 | message(STATUS "set 
ACL_PATH: /usr/local/Ascend/acllib") 6 | elseif (EXISTS /usr/local/Ascend/ascend-toolkit/latest/acllib) 7 | set(ACL_PATH "/usr/local/Ascend/ascend-toolkit/latest/acllib") 8 | message(STATUS "set ACL_PATH to default path: /usr/local/Ascend/ascend-toolkit/latest/acllib") 9 | elseif (EXISTS /usr/local/Ascend/nnrt/latest/acllib) 10 | set(ACL_PATH "/usr/local/Ascend/nnrt/latest/acllib") 11 | message(STATUS "set ACL_PATH to default path: /usr/local/Ascend/nnrt/latest/acllib") 12 | else () 13 | set(ACL_PATH "/home/HwHiAiUser/Ascend/acllib") 14 | message(STATUS "set ACL_PATH to default path: /home/HwHiAiUser/Ascend/acllib") 15 | endif() 16 | 17 | 18 | find_package (Eigen3 REQUIRED NO_MODULE) 19 | 20 | enable_testing() 21 | find_package(GTest REQUIRED) 22 | 23 | add_executable(op_test npu_operator_test.cpp rms_norm_layer_test.cpp gemm_awq_4bit_test.cpp npu_op_test_util.cpp) 24 | target_include_directories(op_test PUBLIC ${EIGEN3_INCLUDE_DIR} ${ACL_PATH}/include ../prebuild) 25 | target_link_directories(op_test PUBLIC ${ACL_PATH}/lib64 ../prebuild) 26 | target_link_libraries(op_test GTest::gtest GTest::gtest_main sentencepiece boost_program_options boost_system boost_filesystem fmt ascendcl runtime npu_ops) 27 | 28 | add_test(NPU_operator_test op_test) 29 | 30 | add_executable(op_test_rms_norm rms_norm_layer_main.cpp rms_norm_layer_test.cpp npu_op_test_util.cpp) 31 | target_include_directories(op_test_rms_norm PUBLIC ${EIGEN3_INCLUDE_DIR} ${ACL_PATH}/include ../prebuild) 32 | target_link_directories(op_test_rms_norm PUBLIC ${ACL_PATH}/lib64 ../prebuild) 33 | target_link_libraries(op_test_rms_norm sentencepiece boost_program_options boost_system boost_filesystem fmt ascendcl runtime npu_ops) 34 | 35 | add_executable(op_test_gemm_awq_4bit gemm_awq_4bit_main.cpp gemm_awq_4bit_test.cpp npu_op_test_util.cpp) 36 | target_include_directories(op_test_gemm_awq_4bit PUBLIC ${EIGEN3_INCLUDE_DIR} ${ACL_PATH}/include ../prebuild) 37 | target_link_directories(op_test_gemm_awq_4bit PUBLIC ${ACL_PATH}/lib64 ../prebuild) 38 | target_link_libraries(op_test_gemm_awq_4bit sentencepiece boost_program_options boost_system boost_filesystem fmt ascendcl runtime npu_ops) 39 | 40 | add_executable(op_test_flash_attn flash_attn_main.cpp flash_attn_test.cpp npu_op_test_util.cpp) 41 | target_include_directories(op_test_flash_attn PUBLIC ${EIGEN3_INCLUDE_DIR} ${ACL_PATH}/include ../prebuild) 42 | target_link_directories(op_test_flash_attn PUBLIC ${ACL_PATH}/lib64 ../prebuild) 43 | target_link_libraries(op_test_flash_attn sentencepiece boost_program_options boost_system boost_filesystem fmt ascendcl runtime npu_ops) 44 | 45 | add_executable(op_test_gemm gemm_main.cpp gemm_test.cpp npu_op_test_util.cpp) 46 | target_include_directories(op_test_gemm PUBLIC ${EIGEN3_INCLUDE_DIR} ${ACL_PATH}/include ../prebuild) 47 | target_link_directories(op_test_gemm PUBLIC ${ACL_PATH}/lib64 ../prebuild) 48 | target_link_libraries(op_test_gemm sentencepiece boost_program_options boost_system boost_filesystem fmt ascendcl runtime npu_ops) 49 | 50 | add_executable(op_test_rope_single rope_single_layer_main.cpp rope_single_layer_test.cpp npu_op_test_util.cpp) 51 | target_include_directories(op_test_rope_single PUBLIC ${EIGEN3_INCLUDE_DIR} ${ACL_PATH}/include ../prebuild) 52 | target_link_directories(op_test_rope_single PUBLIC ${ACL_PATH}/lib64 ../prebuild) 53 | target_link_libraries(op_test_rope_single sentencepiece boost_program_options boost_system boost_filesystem fmt ascendcl runtime npu_ops) 54 | 55 | add_executable(op_test_embedding 
embedding_main.cpp embedding_test.cpp npu_op_test_util.cpp) 56 | target_include_directories(op_test_embedding PUBLIC ${EIGEN3_INCLUDE_DIR} ${ACL_PATH}/include ../prebuild) 57 | target_link_directories(op_test_embedding PUBLIC ${ACL_PATH}/lib64 ../prebuild) 58 | target_link_libraries(op_test_embedding sentencepiece boost_program_options boost_system boost_filesystem fmt ascendcl runtime npu_ops) 59 | 60 | -------------------------------------------------------------------------------- /tests/embedding_main.cpp: -------------------------------------------------------------------------------- 1 | #include "embedding_test.h" 2 | 3 | namespace po = boost::program_options; 4 | 5 | 6 | int main(int argc, char **argv) { 7 | aclrtContext context; 8 | int32_t deviceId{0}; 9 | int max_index_num; 10 | int vocab_size; 11 | int hidden_dim; 12 | int test_size; 13 | 14 | // clang-format off 15 | po::options_description desc("RopeSingleLayer options"); 16 | desc.add_options() 17 | ("help", "produce help message") 18 | ("max_index_num", po::value(&max_index_num)->default_value(4096), "max_index_num. default:4096") 19 | ("hidden_dim", po::value(&hidden_dim)->default_value(2048), "hidden_dim. default:2048") 20 | ("vocab_size", po::value(&vocab_size)->default_value(151936), "vocab_size. default:151936") 21 | ("test_size", po::value(&test_size)->default_value(32), "test_size. default:32"); 22 | 23 | // clang-format on 24 | po::variables_map vm; 25 | po::store(po::parse_command_line(argc, argv, desc), vm); 26 | 27 | if (vm.count("help")) { 28 | std::cout << desc << "\n"; 29 | return 1; 30 | } 31 | po::notify(vm); 32 | 33 | 34 | CHECK_ACL(aclInit(nullptr)); 35 | CHECK_ACL(aclrtSetDevice(deviceId)); 36 | CHECK_ACL(aclrtCreateContext(&context, deviceId)); 37 | 38 | EmbeddingOpTest op_test; 39 | op_test.Init(max_index_num, vocab_size, hidden_dim); 40 | bool test_result = op_test.Run(test_size); 41 | op_test.CleanUp(); 42 | 43 | CHECK_ACL(aclrtDestroyContext(context)); 44 | CHECK_ACL(aclrtResetDevice(deviceId)); 45 | CHECK_ACL(aclFinalize()); 46 | 47 | 48 | 49 | if (test_result) { 50 | spdlog::info("test success"); 51 | } else { 52 | spdlog::error("test failed"); 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /tests/embedding_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include "embedding_test.h" 17 | #include "npu_op_test_util.h" 18 | #include "npu_ops.h" 19 | 20 | void EmbeddingOpTest::Init(size_t max_index_num, size_t vocab_size, 21 | size_t hidden_dim) { 22 | this->max_index_num = max_index_num; 23 | this->vocab_size = vocab_size; 24 | this->hidden_dim = hidden_dim; 25 | 26 | host_input = new int32_t[max_index_num]; 27 | host_weight = new uint16_t[vocab_size * hidden_dim]; 28 | host_output = new uint16_t[max_index_num * hidden_dim]; 29 | golden_u16 = new uint16_t[max_index_num * hidden_dim]; 30 | 31 | CHECK_ACL(aclrtMalloc((void **)&dev_index_s32, 32 | max_index_num * sizeof(int32_t), 33 | ACL_MEM_MALLOC_HUGE_FIRST)); 34 | CHECK_ACL(aclrtMalloc((void **)&dev_weight_u16, 35 | vocab_size * hidden_dim * sizeof(uint16_t), 36 | ACL_MEM_MALLOC_HUGE_FIRST)); 37 | CHECK_ACL(aclrtMalloc((void **)&dev_output_u16, 38 | max_index_num * hidden_dim * sizeof(uint16_t), 39 | ACL_MEM_MALLOC_HUGE_FIRST)); 40 | 41 | make_random_bytes((void *)host_weight, 42 
| vocab_size * hidden_dim * sizeof(uint16_t)); 43 | CHECK_ACL(aclrtMemcpy( 44 | dev_weight_u16, vocab_size * hidden_dim * sizeof(uint16_t), host_weight, 45 | vocab_size * hidden_dim * sizeof(uint16_t), ACL_MEMCPY_HOST_TO_DEVICE)); 46 | } 47 | 48 | bool EmbeddingOpTest::Run(size_t test_size) { 49 | spdlog::info("{} vocab_size {} hidden_dim {} test_size {}", 50 | __PRETTY_FUNCTION__, vocab_size, hidden_dim, test_size); 51 | 52 | std::mt19937 generator(std::random_device{}()); 53 | std::uniform_int_distribution distribution(0, vocab_size - 1); 54 | 55 | for (std::size_t i = 0; i < test_size; ++i) { 56 | host_input[i] = static_cast(distribution(generator)); 57 | } 58 | 59 | CHECK_ACL(aclrtMemcpy(dev_index_s32, test_size * sizeof(int32_t), host_input, 60 | test_size * sizeof(int32_t), 61 | ACL_MEMCPY_HOST_TO_DEVICE)); 62 | 63 | npu_embedding_layer(dev_output_u16, dev_weight_u16, dev_index_s32, test_size, 64 | hidden_dim, DT_FLOAT16, stream); 65 | 66 | CHECK_ACL(aclrtSynchronizeStream(stream)); 67 | 68 | CHECK_ACL(aclrtMemcpy( 69 | host_output, test_size * hidden_dim * sizeof(uint16_t), dev_output_u16, 70 | test_size * hidden_dim * sizeof(uint16_t), ACL_MEMCPY_DEVICE_TO_HOST)); 71 | 72 | for (int i = 0; i < test_size; ++i) { 73 | int32_t idx = host_input[i]; 74 | for (int j = 0; j < hidden_dim; ++j) { 75 | auto a = host_weight[idx * hidden_dim + j]; 76 | auto b = host_output[i * hidden_dim + j]; 77 | if (a != b) { 78 | std::cout << "all_close failed, index " << idx << " output [" << i 79 | << "," << j << "] :" << a << " vs " << b << std::endl; 80 | return false; 81 | } 82 | } 83 | } 84 | 85 | return true; 86 | } 87 | 88 | void EmbeddingOpTest::CleanUp() { 89 | 90 | delete[] host_input; 91 | delete[] host_output; 92 | delete[] host_weight; 93 | delete[] golden_u16; 94 | 95 | CHECK_ACL(aclrtFree(dev_index_s32)); 96 | CHECK_ACL(aclrtFree(dev_output_u16)); 97 | CHECK_ACL(aclrtFree(dev_weight_u16)); 98 | } 99 | -------------------------------------------------------------------------------- /tests/embedding_test.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "npu_op_test_util.h" 4 | 5 | class EmbeddingOpTest : public OpTestBase { 6 | public: 7 | void Init(size_t max_index_num, size_t vocab_size, size_t hidden_size); 8 | bool Run(size_t test_size); 9 | void CleanUp(); 10 | 11 | 12 | size_t max_index_num{0}; 13 | size_t vocab_size{0}; 14 | size_t hidden_dim{0}; 15 | 16 | void *dev_index_s32{nullptr}; 17 | void *dev_weight_u16{nullptr}; 18 | void *dev_output_u16{nullptr}; 19 | 20 | int32_t *host_input{nullptr}; 21 | uint16_t *host_weight{nullptr}; 22 | uint16_t *host_output{nullptr}; 23 | uint16_t *golden_u16{nullptr}; 24 | }; 25 | -------------------------------------------------------------------------------- /tests/flash_attn_main.cpp: -------------------------------------------------------------------------------- 1 | #include "flash_attn_test.h" 2 | 3 | namespace po = boost::program_options; 4 | 5 | int main(int argc, char **argv) { 6 | aclrtContext context; 7 | int32_t deviceId{0}; 8 | int m; 9 | int n; 10 | int offset; 11 | int head_num; 12 | int head_dim; 13 | 14 | po::options_description desc("GemmAWQ4BitOpTest options"); 15 | desc.add_options()("help", "produce help message") // 16 | ("m", po::value(&m)->default_value(2048), "m. default:2048") // 17 | ("n", po::value(&n)->default_value(2048), "n. default:2048") // 18 | ("offset", po::value(&offset)->default_value(0), 19 | "offset. 
default:0") // 20 | ("head_num", po::value(&head_num)->default_value(32), 21 | "head_num. default:32") // 22 | ("head_dim", po::value(&head_dim)->default_value(128), 23 | "head_dim. default:128"); 24 | 25 | po::variables_map vm; 26 | po::store(po::parse_command_line(argc, argv, desc), vm); 27 | 28 | if (vm.count("help")) { 29 | std::cout << desc << "\n"; 30 | return 1; 31 | } 32 | po::notify(vm); 33 | 34 | CHECK_ACL(aclInit(nullptr)); 35 | CHECK_ACL(aclrtSetDevice(deviceId)); 36 | CHECK_ACL(aclrtCreateContext(&context, deviceId)); 37 | 38 | FlashAttentionOpTest op_test; 39 | op_test.Init(m, n, head_num, head_dim); 40 | bool test_result = op_test.Run(m, n, offset); 41 | op_test.CleanUp(); 42 | 43 | CHECK_ACL(aclrtDestroyContext(context)); 44 | CHECK_ACL(aclrtResetDevice(deviceId)); 45 | CHECK_ACL(aclFinalize()); 46 | 47 | if (test_result) { 48 | spdlog::info("test success"); 49 | } else { 50 | spdlog::error("test failed"); 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /tests/flash_attn_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include "defs.hpp" 17 | #include "flash_attn_test.h" 18 | #include "npu_op_test_util.h" 19 | #include "npu_ops.h" 20 | 21 | void FlashAttentionOpTest::Init(size_t max_m, size_t max_n, size_t head_num, 22 | size_t head_dim) { 23 | this->max_m = max_m; 24 | this->max_n = max_n; 25 | this->head_dim = head_dim; 26 | this->head_num = head_num; 27 | this->hidden_dim = head_dim * head_num; 28 | 29 | max_q_buffer_size = max_m * hidden_dim; 30 | max_k_buffer_size = max_n * hidden_dim; 31 | max_qk_buffer_size = max_m * max_n * head_num; 32 | max_pv_buffer_size = max_m * hidden_dim; 33 | max_softmax_buffer_size = max_m * max_n * head_num; 34 | max_v_buffer_size = max_n * hidden_dim; 35 | max_o_buffer_size = max_m * hidden_dim; 36 | 37 | host_q = new float[max_q_buffer_size]; 38 | host_k = new float[max_k_buffer_size]; 39 | host_qk = new float[max_qk_buffer_size]; 40 | host_pv = new float[max_pv_buffer_size]; 41 | host_mask = new float[max_m * max_n]; 42 | host_softmax_output = new float[max_softmax_buffer_size]; 43 | host_v = new float[max_v_buffer_size]; 44 | host_o = new float[max_o_buffer_size]; 45 | 46 | host_q_f16 = new Eigen::half[max_q_buffer_size]; 47 | host_k_f16 = new Eigen::half[max_k_buffer_size]; 48 | host_v_f16 = new Eigen::half[max_v_buffer_size]; 49 | host_o_f16 = new Eigen::half[max_o_buffer_size]; 50 | 51 | golden_output = new float[max_o_buffer_size]; 52 | 53 | CHECK_ACL(aclrtMalloc((void **)&dev_q_f16, 54 | max_q_buffer_size * sizeof(aclFloat16), 55 | ACL_MEM_MALLOC_HUGE_FIRST)); 56 | CHECK_ACL(aclrtMalloc((void **)&dev_k_f16, 57 | max_k_buffer_size * sizeof(aclFloat16), 58 | ACL_MEM_MALLOC_HUGE_FIRST)); 59 | CHECK_ACL(aclrtMalloc((void **)&dev_v_f16, 60 | max_v_buffer_size * sizeof(aclFloat16), 61 | ACL_MEM_MALLOC_HUGE_FIRST)); 62 | CHECK_ACL(aclrtMalloc((void **)&dev_o_f16, 63 | max_o_buffer_size * sizeof(aclFloat16), 64 | ACL_MEM_MALLOC_HUGE_FIRST)); 65 | } 66 | 67 | bool FlashAttentionOpTest::Run(size_t m, size_t n, size_t offset) { 68 | spdlog::info( 69 | "FlashAttentionOpTest::Run {{m={},n={},offset={},hidden_dim={}}}", m, n, 70 | offset, hidden_dim); 71 | 72 | size_t q_element_cnt = m * hidden_dim; 73 | size_t k_element_cnt = n * hidden_dim; 74 | size_t v_element_cnt = n * 
hidden_dim; 75 | size_t o_element_cnt = m * n; 76 | 77 | make_random_float(host_q, q_element_cnt); 78 | make_random_float(host_k, k_element_cnt); 79 | make_random_float(host_v, v_element_cnt); 80 | 81 | Eigen::TensorMap> 82 | input_q_map((float *)host_q, m, head_num, head_dim); 83 | Eigen::TensorMap> 84 | input_k_map((float *)host_k, n, head_num, head_dim); 85 | Eigen::TensorMap> 86 | input_v_map((float *)host_v, n, head_num, head_dim); 87 | 88 | Eigen::TensorMap< 89 | Eigen::Tensor> 90 | input_q_fp16_map((Eigen::half *)host_q_f16, m, head_num, head_dim); 91 | input_q_fp16_map = input_q_map.cast(); 92 | input_q_map = input_q_fp16_map.cast(); 93 | 94 | Eigen::TensorMap< 95 | Eigen::Tensor> 96 | input_k_fp16_map((Eigen::half *)host_k_f16, n, head_num, head_dim); 97 | input_k_fp16_map = input_k_map.cast(); 98 | input_k_map = input_k_fp16_map.cast(); 99 | 100 | Eigen::TensorMap< 101 | Eigen::Tensor> 102 | input_v_fp16_map((Eigen::half *)host_v_f16, n, head_num, head_dim); 103 | input_v_fp16_map = input_v_map.cast(); 104 | input_v_map = input_v_fp16_map.cast(); 105 | 106 | float qk_scale = 1 / sqrtf(static_cast(head_dim)); 107 | 108 | // (bs, nh, seqlen, hd) @ (bs, nh, hd, cache_len+seqlen) => bs, nh, seqlen, 109 | // cache_len+seqlen 110 | Eigen::TensorMap> 111 | q_matmul_k_map(static_cast(host_qk), head_num, m, n); 112 | 113 | Eigen::Tensor q_emb_trans( 114 | head_num, m, head_dim); 115 | q_emb_trans = input_q_map.shuffle(Eigen::array({1, 0, 2})); 116 | Eigen::Tensor k_emb_trans( 117 | head_num, head_dim, n); 118 | k_emb_trans = input_k_map.shuffle(Eigen::array({1, 2, 0})); 119 | 120 | // todo make mask 121 | Eigen::TensorMap> 122 | mask_map(static_cast(host_mask), m, n); 123 | 124 | for (int row = 0; row < m; ++row) { 125 | for (int col = 0; col < n; ++col) { 126 | if (col > (row + offset)) { 127 | host_mask[row * n + col] = -std::numeric_limits::infinity(); 128 | } else { 129 | host_mask[row * n + col] = 0.0f; 130 | } 131 | } 132 | } 133 | 134 | // tensor contraction does not support batch matmul 135 | // need a for loop to bmm 136 | // https://gitlab.com/libeigen/eigen/-/issues/2449 137 | Eigen::array, 1> qk_product_dims = { 138 | Eigen::IndexPair(1, 0)}; 139 | for (int i = 0; i < head_num; ++i) { 140 | q_matmul_k_map.chip<0>(i) = q_emb_trans.chip<0>(i).contract( 141 | k_emb_trans.chip<0>(i), qk_product_dims) + 142 | mask_map; 143 | } 144 | 145 | q_matmul_k_map = (q_matmul_k_map * q_matmul_k_map.constant(qk_scale)); 146 | 147 | auto hs = head_num * m; 148 | 149 | Eigen::TensorMap> 150 | softmax_input_map(static_cast(host_qk), hs, n); 151 | 152 | Eigen::TensorMap> 153 | softmax_output_map(static_cast(host_softmax_output), head_num, m, 154 | n); 155 | 156 | auto softmax_input_max = softmax_input_map.maximum(Eigen::array{1}) 157 | .eval() 158 | .reshape(Eigen::array{hs, 1}) 159 | .broadcast(Eigen::array{1, n}); 160 | 161 | auto softmax_input_diff = 162 | (softmax_input_map - softmax_input_max).exp().eval(); 163 | 164 | auto softmax_input_sum = softmax_input_diff.sum(Eigen::array{1}) 165 | .eval() 166 | .reshape(Eigen::array{hs, 1}) 167 | .broadcast(Eigen::array{1, n}); 168 | 169 | softmax_output_map = (softmax_input_diff / softmax_input_sum) 170 | .reshape(std::array{head_num, m, n}); 171 | 172 | // (seq_length, n_heads, head_dim) -> (n_heads, seq_length, head_dim) 173 | auto vmap_trans = input_v_map.shuffle(Eigen::array({1, 0, 2})); 174 | 175 | Eigen::array, 1> output_product_dims = { 176 | Eigen::IndexPair(1, 0)}; 177 | // tmp_output: (n_heads, seq_length, head_dim) 178 | Eigen::TensorMap> 
179 | tmp_output_map(static_cast(host_pv), head_num, m, head_dim); 180 | for (int i = 0; i < head_num; ++i) { 181 | tmp_output_map.chip<0>(i) = softmax_output_map.chip<0>(i).contract( 182 | vmap_trans.chip<0>(i), output_product_dims); 183 | } 184 | 185 | Eigen::TensorMap> 186 | tmp_output_tensor_map(static_cast(golden_output), m, hidden_dim); 187 | 188 | // tmp_output: (n_heads, seq_length, head_dim) -> (seq_length, n_heads, 189 | // head_dim) 190 | tmp_output_tensor_map = 191 | tmp_output_map.shuffle(Eigen::array({1, 0, 2})) 192 | .reshape(std::array{m, hidden_dim}); 193 | 194 | CHECK_ACL(aclrtMemcpy(dev_q_f16, m * hidden_dim * sizeof(aclFloat16), 195 | host_q_f16, m * hidden_dim * sizeof(aclFloat16), 196 | ACL_MEMCPY_HOST_TO_DEVICE)); 197 | CHECK_ACL(aclrtMemcpy(dev_k_f16, n * hidden_dim * sizeof(aclFloat16), 198 | host_k_f16, n * hidden_dim * sizeof(aclFloat16), 199 | ACL_MEMCPY_HOST_TO_DEVICE)); 200 | CHECK_ACL(aclrtMemcpy(dev_v_f16, n * hidden_dim * sizeof(aclFloat16), 201 | host_v_f16, n * hidden_dim * sizeof(aclFloat16), 202 | ACL_MEMCPY_HOST_TO_DEVICE)); 203 | CHECK_ACL(aclrtMemset(dev_o_f16, m * hidden_dim * sizeof(aclFloat16), 0, 204 | m * hidden_dim * sizeof(aclFloat16))); 205 | spdlog::info("launch kernel"); 206 | npu_flash_attn_layer(dev_o_f16, dev_q_f16, dev_k_f16, dev_v_f16, m, n, offset, 207 | head_num, head_dim, DT_FLOAT16, stream); 208 | 209 | CHECK_ACL(aclrtSynchronizeStream(stream)); 210 | 211 | CHECK_ACL(aclrtMemcpy(host_o_f16, m * hidden_dim * sizeof(aclFloat16), 212 | dev_o_f16, m * hidden_dim * sizeof(aclFloat16), 213 | ACL_MEMCPY_DEVICE_TO_HOST)); 214 | 215 | write_binary("fa_input_q_f16.bin", host_q_f16, 216 | m * hidden_dim * sizeof(aclFloat16)); 217 | write_binary("fa_input_k_f16.bin", host_k_f16, 218 | n * hidden_dim * sizeof(aclFloat16)); 219 | write_binary("fa_input_v_f16.bin", host_v_f16, 220 | n * hidden_dim * sizeof(aclFloat16)); 221 | 222 | write_binary("fa_out_f16.bin", host_o_f16, 223 | m * hidden_dim * sizeof(aclFloat16)); 224 | write_binary("fa_golden_out_fp32.bin", golden_output, 225 | m * hidden_dim * sizeof(aclFloat16)); 226 | 227 | Eigen::TensorMap> 228 | output_fp32_map((float *)host_o, m, hidden_dim); 229 | 230 | Eigen::TensorMap< 231 | Eigen::Tensor> 232 | output_fp16_map((Eigen::half *)host_o_f16, m, hidden_dim); 233 | output_fp32_map = output_fp16_map.cast(); 234 | 235 | return all_close(host_o, golden_output, m * hidden_dim); 236 | } 237 | 238 | void FlashAttentionOpTest::CleanUp() { 239 | delete[] host_q; 240 | delete[] host_k; 241 | delete[] host_qk; 242 | delete[] host_pv; 243 | delete[] host_mask; 244 | delete[] host_softmax_output; 245 | delete[] host_v; 246 | delete[] host_o; 247 | delete[] golden_output; 248 | delete[] host_q_f16; 249 | delete[] host_k_f16; 250 | delete[] host_v_f16; 251 | delete[] host_o_f16; 252 | 253 | CHECK_ACL(aclrtFree(dev_q_f16)); 254 | CHECK_ACL(aclrtFree(dev_k_f16)); 255 | CHECK_ACL(aclrtFree(dev_v_f16)); 256 | CHECK_ACL(aclrtFree(dev_o_f16)); 257 | } 258 | -------------------------------------------------------------------------------- /tests/flash_attn_test.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "npu_op_test_util.h" 3 | 4 | class FlashAttentionOpTest : public OpTestBase { 5 | public: 6 | void Init(size_t max_m, size_t max_n, size_t head_num, size_t head_dim); 7 | bool Run(size_t m, size_t n, size_t offset); 8 | void CleanUp(); 9 | 10 | size_t max_m{0}; 11 | size_t max_n{0}; 12 | 13 | size_t head_dim{0}; 14 | size_t head_num{0}; 15 | 
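// derived in Init() as head_num * head_dim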
size_t hidden_dim{0}; 16 | 17 | size_t max_q_buffer_size; 18 | size_t max_k_buffer_size; 19 | size_t max_qk_buffer_size; 20 | size_t max_pv_buffer_size; 21 | size_t max_softmax_buffer_size; 22 | 23 | size_t max_v_buffer_size; 24 | size_t max_o_buffer_size; 25 | 26 | void *dev_q_f16{nullptr}; 27 | void *dev_k_f16{nullptr}; 28 | void *dev_v_f16{nullptr}; 29 | void *dev_o_f16{nullptr}; 30 | 31 | float *host_q{nullptr}; 32 | float *host_k{nullptr}; 33 | float *host_qk{nullptr}; 34 | float *host_pv{nullptr}; 35 | float *host_mask{nullptr}; 36 | float *host_softmax_output{nullptr}; 37 | 38 | 39 | float *host_v{nullptr}; 40 | float *host_o{nullptr}; 41 | Eigen::half *host_q_f16{nullptr}; 42 | Eigen::half *host_k_f16{nullptr}; 43 | Eigen::half *host_v_f16{nullptr}; 44 | Eigen::half *host_o_f16{nullptr}; 45 | 46 | float *golden_output{nullptr}; 47 | 48 | }; 49 | 50 | -------------------------------------------------------------------------------- /tests/gemm_awq_4bit_main.cpp: -------------------------------------------------------------------------------- 1 | #include "gemm_awq_4bit_test.h" 2 | 3 | namespace po = boost::program_options; 4 | 5 | int main(int argc, char **argv) { 6 | aclrtContext context; 7 | int32_t deviceId{0}; 8 | int m; 9 | int n; 10 | int k; 11 | bool bias; 12 | 13 | po::options_description desc("GemmAWQ4BitOpTest options"); 14 | desc.add_options()("help", "produce help message") // 15 | ("m", po::value(&m)->default_value(2048), "m. default:2048") // 16 | ("n", po::value(&n)->default_value(2048), "n. default:2048") // 17 | ("k", po::value(&k)->default_value(2048), "k. default:2048") 18 | ("bias", po::value(&bias)->default_value(false), "bias. default:false"); 19 | 20 | 21 | po::variables_map vm; 22 | po::store(po::parse_command_line(argc, argv, desc), vm); 23 | 24 | if (vm.count("help")) { 25 | std::cout << desc << "\n"; 26 | return 1; 27 | } 28 | po::notify(vm); 29 | 30 | CHECK_ACL(aclInit(nullptr)); 31 | CHECK_ACL(aclrtSetDevice(deviceId)); 32 | CHECK_ACL(aclrtCreateContext(&context, deviceId)); 33 | 34 | GemmAWQ4BitOpTest op_test; 35 | op_test.Init(m, n, k, bias); 36 | bool test_result = op_test.Run(m, n, k); 37 | op_test.CleanUp(); 38 | 39 | CHECK_ACL(aclrtDestroyContext(context)); 40 | CHECK_ACL(aclrtResetDevice(deviceId)); 41 | CHECK_ACL(aclFinalize()); 42 | 43 | if (test_result) { 44 | spdlog::info("test success"); 45 | } else { 46 | spdlog::error("test failed"); 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /tests/gemm_awq_4bit_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include "gemm_awq_4bit_test.h" 17 | #include "npu_op_test_util.h" 18 | #include "npu_ops.h" 19 | 20 | void GemmAWQ4BitOpTest::Init(size_t max_m, size_t max_n, size_t max_k, 21 | bool bias) { 22 | this->max_m = max_m; 23 | this->max_n = max_n; 24 | this->max_k = max_k; 25 | this->bias = bias; 26 | 27 | max_lhs_buffer_size = max_m * max_k; 28 | max_weight_buffer_size = max_k * max_n; 29 | max_zero_buffer_size = max_k * max_n / group_size; 30 | max_scale_buffer_size = max_k * max_n / group_size; 31 | max_output_buffer_size = max_k * max_n; 32 | max_bias_buffer_size = max_n; 33 | 34 | host_lhs = new float[max_lhs_buffer_size]; 35 | host_rhs = new float[max_weight_buffer_size]; 36 | host_rhs_nz = new float[max_weight_buffer_size]; 
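// host_rhs_nz keeps the same weights repacked into the 16-element NZ tile layout consumed by npu_matmul_nz_awq_4bit_layer below;
// the zero/scale buffers hold one AWQ zero point and scale per group_size (128) input channels for each output column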
37 | host_zero = new float[max_zero_buffer_size]; 38 | host_scale = new float[max_scale_buffer_size]; 39 | host_output = new float[max_output_buffer_size]; 40 | host_bias = new float[max_bias_buffer_size]; 41 | host_lhs_f16 = new Eigen::half[max_lhs_buffer_size]; 42 | host_weight_s4 = new uint8_t[max_weight_buffer_size / 2]; 43 | host_zero_f16 = new Eigen::half[max_zero_buffer_size]; 44 | host_scale_f16 = new Eigen::half[max_scale_buffer_size]; 45 | host_output_f16 = new Eigen::half[max_output_buffer_size]; 46 | golden_fp32 = new float[max_output_buffer_size]; 47 | 48 | CHECK_ACL(aclrtMalloc((void **)&dev_lhs_f16, 49 | max_lhs_buffer_size * sizeof(aclFloat16), 50 | ACL_MEM_MALLOC_HUGE_FIRST)); 51 | CHECK_ACL(aclrtMalloc((void **)&dev_weight_s4, 52 | max_weight_buffer_size / 2 * sizeof(uint8_t), 53 | ACL_MEM_MALLOC_HUGE_FIRST)); 54 | CHECK_ACL(aclrtMalloc((void **)&dev_zero_fp16, 55 | max_zero_buffer_size * sizeof(aclFloat16), 56 | ACL_MEM_MALLOC_HUGE_FIRST)); 57 | CHECK_ACL(aclrtMalloc((void **)&dev_scale_fp16, 58 | max_scale_buffer_size * sizeof(aclFloat16), 59 | ACL_MEM_MALLOC_HUGE_FIRST)); 60 | CHECK_ACL(aclrtMalloc((void **)&dev_output_f16, 61 | max_output_buffer_size * sizeof(aclFloat16), 62 | ACL_MEM_MALLOC_HUGE_FIRST)); 63 | CHECK_ACL(aclrtMalloc((void **)&dev_bias_f32, 64 | max_bias_buffer_size * sizeof(float), 65 | ACL_MEM_MALLOC_HUGE_FIRST)); 66 | } 67 | 68 | bool GemmAWQ4BitOpTest::Run(size_t m, size_t n, size_t k) { 69 | spdlog::info("GemmAWQ4BitOpTest::Run {{{},{},{}}}, bias={}", m, n, k, bias); 70 | 71 | size_t n1 = n / 16; 72 | size_t lhs_element_cnt = m * k; 73 | size_t rhs_element_cnt = n * k; 74 | size_t num_group = k / group_size; 75 | size_t zero_scale_cnt = n * num_group; 76 | size_t output_element_cnt = m * n; 77 | make_random_float(host_lhs, lhs_element_cnt); 78 | make_random_float_uint4(host_rhs, rhs_element_cnt); 79 | make_random_float_uint4(host_zero, zero_scale_cnt); 80 | make_random_float(host_scale, zero_scale_cnt); 81 | make_random_float(host_bias, n); 82 | 83 | Eigen::TensorMap> 84 | bias_fp32_map((float *)host_bias, 1, n); 85 | 86 | Eigen::TensorMap> 87 | golden_fp32_map((float *)golden_fp32, m, n); 88 | 89 | Eigen::TensorMap> 90 | input_lhs_map((float *)host_lhs, m, k); 91 | 92 | Eigen::TensorMap< 93 | Eigen::Tensor> 94 | input_lhs_fp16_map((Eigen::half *)host_lhs_f16, m, k); 95 | input_lhs_fp16_map = input_lhs_map.cast(); 96 | input_lhs_map = input_lhs_fp16_map.cast(); 97 | 98 | Eigen::TensorMap> 99 | input_rhs_map((float *)host_rhs, num_group, group_size, n); 100 | 101 | Eigen::TensorMap> 102 | input_rhs_nz_map((float *)host_rhs_nz, k / 16, n, 16); 103 | input_rhs_nz_map = 104 | input_rhs_map.reshape(Eigen::array{k / 16, 16, n}) 105 | .shuffle(Eigen::array({0, 2, 1})); 106 | 107 | Eigen::TensorMap> 108 | input_zero_fp32_map((float *)host_zero, num_group, 1, n); 109 | 110 | Eigen::TensorMap< 111 | Eigen::Tensor> 112 | input_zero_fp16_map((Eigen::half *)host_zero_f16, num_group, 1, n); 113 | 114 | input_zero_fp16_map = input_zero_fp32_map.cast(); 115 | input_zero_fp32_map = input_zero_fp16_map.cast(); 116 | 117 | // move weight offset to zero 118 | input_zero_fp16_map = input_zero_fp16_map - 119 | input_zero_fp32_map.constant(8.0f).cast(); 120 | 121 | Eigen::TensorMap> 122 | input_scale_fp32_map((float *)host_scale, num_group, 1, n); 123 | 124 | Eigen::TensorMap< 125 | Eigen::Tensor> 126 | input_scale_fp16_map((Eigen::half *)host_scale_f16, num_group, 1, n); 127 | 128 | input_scale_fp16_map = 129 | (input_scale_fp32_map / 
input_scale_fp32_map.constant(16.0f)) 130 | .cast(); 131 | input_scale_fp32_map = input_scale_fp16_map.cast(); 132 | 133 | auto float_to_u4 = [](float x) -> uint8_t { 134 | return (static_cast(x) + 8) & 0xf; 135 | }; 136 | 137 | // (x, 4, 64, 2) -> (x, 64, 4, 2) 138 | for (int i1 = 0; i1 < rhs_element_cnt / 512; ++i1) { 139 | int i1_stride_u8 = 4 * 64 * 2; 140 | int i1_stride_s4 = 4 * 64; 141 | for (int i2 = 0; i2 < 4; ++i2) { 142 | int i2_stride_u8 = 64 * 2; 143 | int i2_stride_s4 = 1; 144 | for (int i3 = 0; i3 < 64; ++i3) { 145 | int i3_stride_u8 = 2; 146 | int i3_stride_s4 = 4; 147 | int u8_offset = 148 | i1 * i1_stride_u8 + i2 * i2_stride_u8 + i3 * i3_stride_u8; 149 | host_weight_s4[i1 * i1_stride_s4 + i2 * i2_stride_s4 + 150 | i3 * i3_stride_s4] = 151 | (float_to_u4(host_rhs_nz[u8_offset])) | 152 | (float_to_u4(host_rhs_nz[u8_offset + 1]) << 4); 153 | } 154 | } 155 | } 156 | 157 | CHECK_ACL(aclrtMemcpy(dev_lhs_f16, lhs_element_cnt * sizeof(aclFloat16), 158 | host_lhs_f16, lhs_element_cnt * sizeof(aclFloat16), 159 | ACL_MEMCPY_HOST_TO_DEVICE)); 160 | CHECK_ACL(aclrtMemcpy(dev_weight_s4, rhs_element_cnt * sizeof(uint8_t) / 2, 161 | host_weight_s4, rhs_element_cnt * sizeof(uint8_t) / 2, 162 | ACL_MEMCPY_HOST_TO_DEVICE)); 163 | CHECK_ACL(aclrtMemcpy(dev_zero_fp16, zero_scale_cnt * sizeof(aclFloat16), 164 | host_zero_f16, zero_scale_cnt * sizeof(aclFloat16), 165 | ACL_MEMCPY_HOST_TO_DEVICE)); 166 | CHECK_ACL(aclrtMemcpy(dev_scale_fp16, zero_scale_cnt * sizeof(aclFloat16), 167 | host_scale_f16, zero_scale_cnt * sizeof(aclFloat16), 168 | ACL_MEMCPY_HOST_TO_DEVICE)); 169 | if (bias) { 170 | CHECK_ACL(aclrtMemcpy(dev_bias_f32, n * sizeof(float), host_bias, 171 | n * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE)); 172 | npu_matmul_nz_awq_4bit_bias_layer( 173 | dev_output_f16, dev_lhs_f16, dev_weight_s4, dev_zero_fp16, 174 | dev_scale_fp16, dev_bias_f32, m, n, k, DT_FLOAT16, stream); 175 | } else { 176 | 177 | npu_matmul_nz_awq_4bit_layer(dev_output_f16, dev_lhs_f16, dev_weight_s4, 178 | dev_zero_fp16, dev_scale_fp16, m, n, k, 179 | DT_FLOAT16, stream); 180 | } 181 | Eigen::array brc_dim = {1, group_size, 1}; 182 | 183 | auto tmp_expr = ((input_rhs_map - input_zero_fp32_map.broadcast(brc_dim)) 184 | .cast() 185 | .cast() * 186 | (input_scale_fp32_map.broadcast(brc_dim))) 187 | .reshape(Eigen::array{k, n}); 188 | 189 | Eigen::array, 1> product_dims = { 190 | Eigen::IndexPair(1, 0)}; 191 | golden_fp32_map = input_lhs_map.contract( 192 | tmp_expr.cast().cast(), product_dims); 193 | 194 | if (bias) { 195 | Eigen::array brc_dim = {m, 1}; 196 | 197 | golden_fp32_map = golden_fp32_map + bias_fp32_map.broadcast(brc_dim); 198 | } 199 | 200 | CHECK_ACL(aclrtSynchronizeStream(stream)); 201 | 202 | CHECK_ACL(aclrtMemcpy( 203 | host_output_f16, output_element_cnt * sizeof(aclFloat16), dev_output_f16, 204 | output_element_cnt * sizeof(aclFloat16), ACL_MEMCPY_DEVICE_TO_HOST)); 205 | 206 | Eigen::TensorMap> 207 | output_fp32_map((float *)host_output, m, n); 208 | 209 | Eigen::TensorMap< 210 | Eigen::Tensor> 211 | output_fp16_map((Eigen::half *)host_output_f16, m, n); 212 | output_fp32_map = output_fp16_map.cast(); 213 | 214 | return all_close(host_output, golden_fp32, output_element_cnt); 215 | } 216 | 217 | void GemmAWQ4BitOpTest::CleanUp() { 218 | delete[] host_lhs; 219 | delete[] host_rhs; 220 | delete[] host_zero; 221 | delete[] host_scale; 222 | delete[] host_output; 223 | delete[] host_lhs_f16; 224 | delete[] host_weight_s4; 225 | delete[] host_zero_f16; 226 | delete[] host_scale_f16; 227 | delete[] 
host_output_f16; 228 | delete[] host_bias; 229 | delete[] golden_fp32; 230 | 231 | CHECK_ACL(aclrtFree(dev_lhs_f16)); 232 | CHECK_ACL(aclrtFree(dev_weight_s4)); 233 | CHECK_ACL(aclrtFree(dev_zero_fp16)); 234 | CHECK_ACL(aclrtFree(dev_scale_fp16)); 235 | CHECK_ACL(aclrtFree(dev_output_f16)); 236 | CHECK_ACL(aclrtFree(dev_bias_f32)); 237 | } 238 | -------------------------------------------------------------------------------- /tests/gemm_awq_4bit_test.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "npu_op_test_util.h" 3 | 4 | class GemmAWQ4BitOpTest : public OpTestBase { 5 | public: 6 | void Init(size_t max_m, size_t max_n, size_t max_k, bool bias); 7 | bool Run(size_t m, size_t n, size_t k); 8 | void CleanUp(); 9 | 10 | size_t max_m{0}; 11 | size_t max_n{0}; 12 | size_t max_k{0}; 13 | bool bias{false}; 14 | 15 | size_t max_lhs_buffer_size{0}; 16 | size_t max_weight_buffer_size{0}; 17 | size_t max_zero_buffer_size{0}; 18 | size_t max_scale_buffer_size{0}; 19 | size_t max_output_buffer_size{0}; 20 | size_t max_bias_buffer_size{0}; 21 | 22 | size_t max_ffn_hidden{0}; 23 | size_t group_size{128}; 24 | 25 | void *dev_lhs_f16{nullptr}; 26 | void *dev_weight_s4{nullptr}; 27 | void *dev_zero_fp16{nullptr}; 28 | void *dev_scale_fp16{nullptr}; 29 | void *dev_output_f16{nullptr}; 30 | void *dev_bias_f32{nullptr}; 31 | 32 | float *host_lhs{nullptr}; 33 | float *host_rhs{nullptr}; 34 | float *host_rhs_nz{nullptr}; 35 | float *host_zero{nullptr}; 36 | float *host_scale{nullptr}; 37 | float *host_output{nullptr}; 38 | float *host_bias{nullptr}; 39 | Eigen::half *host_lhs_f16{nullptr}; 40 | uint8_t *host_weight_s4{nullptr}; 41 | Eigen::half *host_zero_f16{nullptr}; 42 | Eigen::half *host_scale_f16{nullptr}; 43 | Eigen::half *host_output_f16{nullptr}; 44 | float *golden_fp32{nullptr}; 45 | }; 46 | -------------------------------------------------------------------------------- /tests/gemm_main.cpp: -------------------------------------------------------------------------------- 1 | #include "gemm_test.h" 2 | #include 3 | 4 | namespace po = boost::program_options; 5 | 6 | template bool TestFn(int m, int n, int k, bool bias) { 7 | aclrtContext context; 8 | int32_t deviceId{0}; 9 | 10 | CHECK_ACL(aclInit(nullptr)); 11 | CHECK_ACL(aclrtSetDevice(deviceId)); 12 | CHECK_ACL(aclrtCreateContext(&context, deviceId)); 13 | 14 | GemmOpTest op_test; 15 | op_test.Init(m, n, k, bias); 16 | bool test_result = op_test.Run(m, n, k); 17 | op_test.CleanUp(); 18 | 19 | CHECK_ACL(aclrtDestroyContext(context)); 20 | CHECK_ACL(aclrtResetDevice(deviceId)); 21 | CHECK_ACL(aclFinalize()); 22 | return test_result; 23 | } 24 | 25 | int main(int argc, char **argv) { 26 | aclrtContext context; 27 | int32_t deviceId{0}; 28 | int m; 29 | int n; 30 | int k; 31 | bool bias; 32 | std::string dtype_str; 33 | 34 | po::options_description desc("GemmAWQ4BitOpTest options"); 35 | desc.add_options()("help", "produce help message") // 36 | ("m", po::value(&m)->default_value(2048), "m. default:2048") // 37 | ("n", po::value(&n)->default_value(2048), "n. default:2048") // 38 | ("k", po::value(&k)->default_value(2048), "k. default:2048") // 39 | ("bias", po::value(&bias)->default_value(false), 40 | "bias, default:false") // 41 | ("dtype", po::value(&dtype_str)->required(), 42 | "dtype. 
float16 or bfloat16"); 43 | 44 | po::variables_map vm; 45 | po::store(po::parse_command_line(argc, argv, desc), vm); 46 | 47 | if (vm.count("help")) { 48 | std::cout << desc << "\n"; 49 | return 1; 50 | } 51 | po::notify(vm); 52 | 53 | bool test_result = false; 54 | 55 | if (dtype_str == "float16") { 56 | test_result = TestFn(m, n, k, bias); 57 | } else if (dtype_str == "bfloat16") { 58 | test_result = TestFn(m, n, k, bias); 59 | } 60 | 61 | if (test_result) { 62 | spdlog::info("test success"); 63 | } else { 64 | spdlog::error("test failed"); 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /tests/gemm_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include "gemm_test.h" 17 | #include "npu_op_test_util.h" 18 | #include "npu_ops.h" 19 | 20 | template 21 | void GemmOpTest::Init(size_t max_m, size_t max_n, size_t max_k, 22 | bool bias) { 23 | this->max_m = max_m; 24 | this->max_n = max_n; 25 | this->max_k = max_k; 26 | this->bias = bias; 27 | 28 | max_lhs_buffer_size = max_m * max_k; 29 | max_rhs_buffer_size = max_k * max_n; 30 | max_bias_buffer_size = max_n; 31 | max_output_buffer_size = max_m * max_n; 32 | 33 | host_lhs = new float[max_lhs_buffer_size]; 34 | host_rhs = new float[max_rhs_buffer_size]; 35 | host_bias = new float[max_bias_buffer_size]; 36 | host_output = new float[max_output_buffer_size]; 37 | host_lhs_b16 = new EigenTy[max_lhs_buffer_size]; 38 | host_rhs_b16 = new EigenTy[max_rhs_buffer_size]; 39 | host_rhs_nz_b16 = new EigenTy[max_rhs_buffer_size]; 40 | host_output_b16 = new EigenTy[max_output_buffer_size]; 41 | host_golden_b16 = new EigenTy[max_output_buffer_size]; 42 | golden_fp32 = new float[max_output_buffer_size]; 43 | 44 | CHECK_ACL(aclrtMalloc((void **)&dev_lhs, 45 | max_lhs_buffer_size * sizeof(aclFloat16), 46 | ACL_MEM_MALLOC_HUGE_FIRST)); 47 | CHECK_ACL(aclrtMalloc((void **)&dev_rhs, 48 | max_rhs_buffer_size * sizeof(aclFloat16), 49 | ACL_MEM_MALLOC_HUGE_FIRST)); 50 | CHECK_ACL(aclrtMalloc((void **)&dev_bias, 51 | max_bias_buffer_size * sizeof(float), 52 | ACL_MEM_MALLOC_HUGE_FIRST)); 53 | CHECK_ACL(aclrtMalloc((void **)&dev_output, 54 | max_output_buffer_size * sizeof(aclFloat16), 55 | ACL_MEM_MALLOC_HUGE_FIRST)); 56 | } 57 | 58 | template 59 | bool GemmOpTest::Run(size_t m, size_t n, size_t k) { 60 | spdlog::info("{} {{m={},n={},k={},bias={}}} dtype {}", __PRETTY_FUNCTION__, m, 61 | n, k, bias, GetDataType()); 62 | 63 | size_t n1 = n / 16; 64 | size_t lhs_element_cnt = m * k; 65 | size_t rhs_element_cnt = n * k; 66 | size_t output_element_cnt = m * n; 67 | size_t bias_element_cnt = n; 68 | make_random_float(host_lhs, lhs_element_cnt); 69 | make_random_float(host_bias, bias_element_cnt); 70 | make_random_float(host_rhs, rhs_element_cnt); 71 | 72 | Eigen::TensorMap> 73 | bias_fp32_map((float *)host_bias, 1, n); 74 | 75 | Eigen::TensorMap> 76 | golden_fp32_map((float *)golden_fp32, m, n); 77 | 78 | Eigen::TensorMap> 79 | input_lhs_map((float *)host_lhs, m, k); 80 | 81 | Eigen::TensorMap< 82 | Eigen::Tensor> 83 | input_lhs_b16_map((EigenTy *)host_lhs_b16, m, k); 84 | input_lhs_b16_map = input_lhs_map.cast(); 85 | input_lhs_map = input_lhs_b16_map.template cast(); 86 | 87 | Eigen::TensorMap> 88 | input_rhs_map((float *)host_rhs, k, n); 89 | 90 | Eigen::TensorMap< 91 | Eigen::Tensor> 92 | 
input_rhs_b16_map((EigenTy *)host_rhs_b16, k, n); 93 | input_rhs_b16_map = input_rhs_map.cast(); 94 | input_rhs_map = input_rhs_b16_map.template cast(); 95 | 96 | Eigen::TensorMap< 97 | Eigen::Tensor> 98 | input_rhs_b16_nz_map((EigenTy *)host_rhs_nz_b16, n / 16, k, 16); 99 | input_rhs_b16_nz_map = 100 | input_rhs_b16_map.reshape(Eigen::array{k, n / 16, 16}) 101 | .shuffle(Eigen::array({1, 0, 2})); 102 | 103 | // TODO init bias 104 | CHECK_ACL(aclrtMemcpy(dev_lhs, lhs_element_cnt * sizeof(aclFloat16), 105 | host_lhs_b16, lhs_element_cnt * sizeof(aclFloat16), 106 | ACL_MEMCPY_HOST_TO_DEVICE)); 107 | CHECK_ACL(aclrtMemcpy(dev_rhs, rhs_element_cnt * sizeof(aclFloat16), 108 | host_rhs_nz_b16, rhs_element_cnt * sizeof(aclFloat16), 109 | ACL_MEMCPY_HOST_TO_DEVICE)); 110 | CHECK_ACL(aclrtMemcpy(dev_bias, bias_element_cnt * sizeof(float), host_bias, 111 | bias_element_cnt * sizeof(float), 112 | ACL_MEMCPY_HOST_TO_DEVICE)) 113 | 114 | aclrtEvent start_event, end_event; 115 | CHECK_ACL(aclrtCreateEvent(&start_event)); 116 | CHECK_ACL(aclrtCreateEvent(&end_event)); 117 | CHECK_ACL(aclrtRecordEvent(start_event, this->stream)); 118 | 119 | if (bias) { 120 | npu_matmul_bias_nz_layer(dev_output, dev_lhs, dev_rhs, dev_bias, m, n, k, 121 | GetDataType(), this->stream); 122 | } else { 123 | npu_matmul_nz_layer(dev_output, dev_lhs, dev_rhs, m, n, k, 124 | GetDataType(), this->stream); 125 | } 126 | CHECK_ACL(aclrtRecordEvent(end_event, this->stream)); 127 | CHECK_ACL(aclrtSynchronizeStream(this->stream)); 128 | float duration_ms; 129 | CHECK_ACL(aclrtEventElapsedTime(&duration_ms, start_event, end_event)); 130 | CHECK_ACL(aclrtDestroyEvent(start_event)); 131 | CHECK_ACL(aclrtDestroyEvent(end_event)); 132 | 133 | spdlog::info("kernel duration {}ms", duration_ms); 134 | 135 | Eigen::array, 1> product_dims = { 136 | Eigen::IndexPair(1, 0)}; 137 | 138 | golden_fp32_map = input_lhs_map.contract(input_rhs_map, product_dims); 139 | if (bias) { 140 | Eigen::array brc_dim = {m, 1}; 141 | 142 | golden_fp32_map = golden_fp32_map + bias_fp32_map.broadcast(brc_dim); 143 | } 144 | 145 | Eigen::TensorMap< 146 | Eigen::Tensor> 147 | golden_b16_map((EigenTy *)host_golden_b16, m, n); 148 | golden_b16_map = golden_fp32_map.cast(); 149 | golden_fp32_map = golden_b16_map.template cast(); 150 | 151 | CHECK_ACL(aclrtSynchronizeStream(this->stream)); 152 | 153 | CHECK_ACL(aclrtMemcpy( 154 | host_output_b16, output_element_cnt * sizeof(aclFloat16), dev_output, 155 | output_element_cnt * sizeof(aclFloat16), ACL_MEMCPY_DEVICE_TO_HOST)); 156 | 157 | Eigen::TensorMap> 158 | output_fp32_map((float *)host_output, m, n); 159 | 160 | Eigen::TensorMap< 161 | Eigen::Tensor> 162 | output_b16_map((EigenTy *)host_output_b16, m, n); 163 | output_fp32_map = output_b16_map.template cast(); 164 | 165 | return all_close(host_output, golden_fp32, output_element_cnt, 0.01, 0.01); 166 | } 167 | 168 | template void GemmOpTest::CleanUp() { 169 | delete[] host_lhs; 170 | delete[] host_rhs; 171 | delete[] host_bias; 172 | delete[] host_output; 173 | delete[] host_lhs_b16; 174 | delete[] host_rhs_b16; 175 | delete[] host_rhs_nz_b16; 176 | delete[] host_output_b16; 177 | delete[] host_golden_b16; 178 | delete[] golden_fp32; 179 | 180 | CHECK_ACL(aclrtFree(dev_lhs)); 181 | CHECK_ACL(aclrtFree(dev_rhs)); 182 | CHECK_ACL(aclrtFree(dev_bias)); 183 | CHECK_ACL(aclrtFree(dev_output)); 184 | } 185 | 186 | template class GemmOpTest; 187 | 188 | template class GemmOpTest; 189 | -------------------------------------------------------------------------------- 
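Note: the GEMM test above prepares the right-hand-side weights on the host in the "NZ" (fractal) layout before copying them to the device: the row-major (k, n) matrix is reshaped to (k, n/16, 16) and its first two axes are swapped, giving (n/16, k, 16) so that 16-wide column blocks are stored contiguously. The loop-based sketch below reproduces that repacking without Eigen so the target memory order is explicit. It is an illustration only — the helper name pack_nz and the use of float (rather than the half/bfloat16 element types used by the tests) are placeholders, not part of this repository — and it assumes n is a multiple of 16, as the tests require.

#include <cstddef>
#include <vector>

// Repack a row-major (k, n) matrix into NZ layout with shape (n/16, k, 16):
// dst[n1][ki][n0] = src[ki][n1 * 16 + n0], i.e. the same result as the
// reshape({k, n/16, 16}) followed by shuffle({1, 0, 2}) used in gemm_test.cpp.
std::vector<float> pack_nz(const std::vector<float> &src, std::size_t k, std::size_t n) {
  std::vector<float> dst(k * n);
  const std::size_t n_blocks = n / 16;  // number of 16-wide column blocks
  for (std::size_t n1 = 0; n1 < n_blocks; ++n1) {
    for (std::size_t ki = 0; ki < k; ++ki) {
      for (std::size_t n0 = 0; n0 < 16; ++n0) {
        dst[(n1 * k + ki) * 16 + n0] = src[ki * n + n1 * 16 + n0];
      }
    }
  }
  return dst;
}

The AWQ 4-bit test follows the same idea but additionally interleaves pairs of int4 values into single bytes, which is what the (x, 4, 64, 2) -> (x, 64, 4, 2) repacking loop in gemm_awq_4bit_test.cpp implements.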
/tests/gemm_test.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "npu_ops.h" 4 | #include "npu_op_test_util.h" 5 | 6 | template 7 | class GemmOpTest : public OpTestBase> { 8 | public: 9 | void Init(size_t max_m, size_t max_n, size_t max_k, bool bias); 10 | bool Run(size_t m, size_t n, size_t k); 11 | void CleanUp(); 12 | 13 | size_t max_m{0}; 14 | size_t max_n{0}; 15 | size_t max_k{0}; 16 | bool bias{false}; 17 | 18 | size_t max_lhs_buffer_size{0}; 19 | size_t max_rhs_buffer_size{0}; 20 | size_t max_bias_buffer_size{0}; 21 | size_t max_output_buffer_size{0}; 22 | 23 | 24 | void *dev_lhs{nullptr}; 25 | void *dev_rhs{nullptr}; 26 | void *dev_bias{nullptr}; 27 | void *dev_output{nullptr}; 28 | 29 | float *host_lhs{nullptr}; 30 | float *host_rhs{nullptr}; 31 | float *host_bias{nullptr}; 32 | float *host_output{nullptr}; 33 | EigenTy *host_lhs_b16{nullptr}; 34 | EigenTy *host_rhs_b16{nullptr}; 35 | EigenTy *host_rhs_nz_b16{nullptr}; 36 | 37 | EigenTy *host_output_b16{nullptr}; 38 | EigenTy *host_golden_b16{nullptr}; 39 | float *golden_fp32{nullptr}; 40 | }; 41 | -------------------------------------------------------------------------------- /tests/npu_op_test_util.cpp: -------------------------------------------------------------------------------- 1 | #include "npu_op_test_util.h" 2 | 3 | void make_random_float(float *buffer, size_t size) { 4 | std::random_device rd; 5 | std::mt19937 gen(rd()); 6 | std::uniform_real_distribution<> dis(-1.0, 1.0); 7 | 8 | for (size_t i = 0; i < size; ++i) { 9 | buffer[i] = dis(gen); 10 | } 11 | } 12 | 13 | void make_random_float_uint4(float *buffer, size_t size) { 14 | std::random_device rd; 15 | std::mt19937 gen(rd()); 16 | std::uniform_int_distribution<> dis(0, 15); 17 | 18 | for (size_t i = 0; i < size; ++i) { 19 | buffer[i] = static_cast(dis(gen)); 20 | } 21 | } 22 | 23 | void make_random_bytes(void* ptr, std::size_t size) { 24 | unsigned char* byte_ptr = static_cast(ptr); 25 | std::mt19937 generator(std::random_device{}()); 26 | std::uniform_int_distribution distribution(0, 255); 27 | 28 | for (std::size_t i = 0; i < size; ++i) { 29 | byte_ptr[i] = static_cast(distribution(generator)); 30 | } 31 | } 32 | 33 | bool all_close(float *output_buffer, float *golden_buffer, size_t size, 34 | float abs_err, float relative_err) { 35 | for (size_t i = 0; i < size; ++i) { 36 | float a = output_buffer[i]; 37 | float b = golden_buffer[i]; 38 | 39 | float abs_diff = std::fabs(a - b); 40 | float max_abs_val = std::max(std::fabs(a), std::fabs(b)); 41 | if (abs_diff > abs_err && (abs_diff / max_abs_val) > relative_err) { 42 | std::cout << "all_close failed, output [" << i << "] :" << a << " vs " 43 | << b << std::endl; 44 | return false; 45 | } 46 | } 47 | return true; 48 | } 49 | 50 | bool all_close2(float *output_buffer, float *golden_buffer, size_t size, 51 | float abs_err, float relative_err) { 52 | size_t failed = 0; 53 | for (size_t i = 0; i < size; ++i) { 54 | float a = output_buffer[i]; 55 | float b = golden_buffer[i]; 56 | 57 | float abs_diff = std::fabs(a - b); 58 | float max_abs_val = std::max(std::fabs(a), std::fabs(b)); 59 | if (abs_diff > abs_err && (abs_diff / max_abs_val) > relative_err) { 60 | failed += 1; 61 | } 62 | } 63 | return (static_cast(failed) / static_cast(size)) < 64 | relative_err; 65 | } 66 | 67 | bool all_close_inf(float *output_buffer, float *golden_buffer, size_t size) { 68 | for (size_t i = 0; i < size; ++i) { 69 | float a = output_buffer[i]; 70 | float b = 
golden_buffer[i]; 71 | 72 | if (std::isinf(b)) { 73 | if (std::isinf(a) && std::signbit(a) == std::signbit(b)) { 74 | continue; 75 | } 76 | std::cout << "all_close failed, output [" << i << "] :" << a << " vs " 77 | << b << std::endl; 78 | return false; 79 | } 80 | 81 | float abs_diff = std::fabs(a - b); 82 | float max_abs_val = std::max(std::fabs(a), std::fabs(b)); 83 | 84 | if (abs_diff > 0.001f && (abs_diff / max_abs_val) > 0.001f) { 85 | std::cout << "all_close failed, output [" << i << "] :" << a << " vs " 86 | << b << std::endl; 87 | return false; 88 | } 89 | } 90 | return true; 91 | } 92 | 93 | void read_binary(const char *path, void *data, size_t size) { 94 | std::ifstream ifs(path, std::ios::binary); 95 | if (!ifs) { 96 | std::cout << "failed to open " << path << std::endl; 97 | } 98 | ifs.read((char *)data, size); 99 | } 100 | 101 | void write_binary(const char *path, void *data, size_t size) { 102 | std::ofstream ofs(path, std::ios::binary); 103 | if (!ofs) { 104 | std::cout << "failed to open " << path << std::endl; 105 | } 106 | ofs.write((char *)data, size); 107 | } 108 | 109 | void InitFreqCIS(float *freq_cis, int head_dim, int max_seq_len) { 110 | const float theta = 10000.0f; 111 | int freq_len = head_dim / 2; 112 | float *freq = new float[freq_len]; 113 | 114 | for (int i = 0; i < freq_len; ++i) { 115 | freq[i] = 116 | 1.0f / 117 | (powf(theta, static_cast(i * 2) / static_cast(head_dim))); 118 | } 119 | 120 | float *t = new float[max_seq_len]; 121 | for (int i = 0; i < max_seq_len; ++i) { 122 | t[i] = static_cast(i); 123 | } 124 | 125 | float *freq_outer = new float[freq_len * max_seq_len]; 126 | 127 | // max_seq_len row, freq_len column 128 | for (int i = 0; i < max_seq_len; ++i) { 129 | for (int j = 0; j < freq_len; ++j) { 130 | freq_outer[i * freq_len + j] = t[i] * freq[j]; 131 | } 132 | } 133 | 134 | for (int i = 0; i < max_seq_len * freq_len; ++i) { 135 | freq_cis[i * 2] = std::cos(freq_outer[i]); 136 | freq_cis[i * 2 + 1] = std::sin(freq_outer[i]); 137 | } 138 | 139 | delete[] freq; 140 | delete[] t; 141 | delete[] freq_outer; 142 | } 143 | 144 | 145 | 146 | -------------------------------------------------------------------------------- /tests/npu_op_test_util.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #include "Eigen/src/Core/arch/Default/BFloat16.h" 20 | #include "defs.hpp" 21 | #include "npu_ops.h" 22 | 23 | #define CHECK_ACL(x) \ 24 | do { \ 25 | aclError __ret = x; \ 26 | if (__ret != ACL_ERROR_NONE) { \ 27 | std::cerr << __FILE__ << ":" << __LINE__ << " aclError:" << __ret \ 28 | << std::endl; \ 29 | } \ 30 | } while (0); 31 | 32 | void make_random_float(float *buffer, size_t size); 33 | 34 | void make_random_float_uint4(float *buffer, size_t size); 35 | 36 | void make_random_bytes(void* ptr, std::size_t size); 37 | 38 | bool all_close(float *output_buffer, float *golden_buffer, size_t size, 39 | float abs_err = 0.001f, float relative_err = 0.001f); 40 | 41 | bool all_close2(float *output_buffer, float *golden_buffer, size_t size, 42 | float abs_err = 0.001f, float relative_err = 0.001f); 43 | 44 | bool all_close_inf(float *output_buffer, float *golden_buffer, size_t size); 45 | 46 | void read_binary(const char *path, void *data, size_t size); 47 | void write_binary(const char *path, void *data, size_t 
size); 48 | 49 | class OpTestTensor { 50 | public: 51 | void *host_buffer{nullptr}; 52 | void *dev_buffer{nullptr}; 53 | }; 54 | 55 | template class OpTestBase { 56 | public: 57 | template void Init(Args... args) { 58 | CHECK_ACL(aclrtCreateStream(&stream)); 59 | static_cast(this)->Init(args...); 60 | } 61 | 62 | template bool Run(Args... args) { 63 | return static_cast(this)->Run(args...); 64 | } 65 | 66 | void CleanUp() { 67 | static_cast(this)->CleanUp(); 68 | CHECK_ACL(aclrtDestroyStream(stream)); 69 | } 70 | 71 | aclrtStream stream = nullptr; 72 | }; 73 | 74 | template constexpr DataType GetDataType(); 75 | 76 | template <> constexpr DataType GetDataType() { return DT_FLOAT16; } 77 | 78 | template <> constexpr DataType GetDataType() { 79 | return DT_BFLOAT16; 80 | } 81 | 82 | void InitFreqCIS(float *freq_cis, int head_dim, int max_seq_len); 83 | 84 | 85 | -------------------------------------------------------------------------------- /tests/rms_norm_layer_main.cpp: -------------------------------------------------------------------------------- 1 | #include "rms_norm_layer_test.h" 2 | 3 | namespace po = boost::program_options; 4 | 5 | int main(int argc, char **argv) { 6 | aclrtContext context; 7 | int32_t deviceId{0}; 8 | int first_dim; 9 | int last_dim; 10 | float eps; 11 | 12 | po::options_description desc("RMSNormOpTest options"); 13 | desc.add_options()("help", "produce help message")( 14 | "first_dim", po::value(&first_dim)->default_value(2048), 15 | "first_dim. default:2048")("last_dim", 16 | po::value(&last_dim)->default_value(2048), 17 | "last_dim. default:2048")( 18 | "eps", po::value(&eps)->default_value(1e-5), "eps. default:1e-5"); 19 | po::variables_map vm; 20 | po::store(po::parse_command_line(argc, argv, desc), vm); 21 | 22 | if (vm.count("help")) { 23 | std::cout << desc << "\n"; 24 | return 1; 25 | } 26 | po::notify(vm); 27 | 28 | CHECK_ACL(aclInit(nullptr)); 29 | CHECK_ACL(aclrtSetDevice(deviceId)); 30 | CHECK_ACL(aclrtCreateContext(&context, deviceId)); 31 | 32 | RMSNormOpTest op_test; 33 | op_test.Init(first_dim, last_dim); 34 | bool test_result = op_test.Run(first_dim, last_dim, eps); 35 | op_test.CleanUp(); 36 | 37 | CHECK_ACL(aclrtDestroyContext(context)); 38 | CHECK_ACL(aclrtResetDevice(deviceId)); 39 | CHECK_ACL(aclFinalize()); 40 | 41 | if (test_result) { 42 | spdlog::info("test success"); 43 | } else { 44 | spdlog::error("test failed"); 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /tests/rms_norm_layer_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include "npu_op_test_util.h" 16 | #include "npu_ops.h" 17 | #include "rms_norm_layer_test.h" 18 | 19 | void RMSNormOpTest::Init(size_t max_first_dim, size_t max_last_dim) { 20 | this->max_first_dim = max_first_dim; 21 | this->max_last_dim = max_last_dim; 22 | 23 | size_t max_buffer_size = max_first_dim * max_last_dim; 24 | 25 | host_input = new float[max_buffer_size]; 26 | host_output = new float[max_buffer_size]; 27 | host_input_f16 = new Eigen::half[max_buffer_size]; 28 | host_output_f16 = new Eigen::half[max_buffer_size]; 29 | golden_fp32 = new float[max_buffer_size]; 30 | weight_fp32 = new float[max_last_dim]; 31 | weight_fp16 = new Eigen::half[max_last_dim]; 32 | 33 | CHECK_ACL(aclrtMalloc((void **)&dev_input_f16, 34 | max_buffer_size * 
sizeof(aclFloat16), 35 | ACL_MEM_MALLOC_HUGE_FIRST)); 36 | CHECK_ACL(aclrtMalloc((void **)&dev_output_f16, 37 | max_buffer_size * sizeof(aclFloat16), 38 | ACL_MEM_MALLOC_HUGE_FIRST)); 39 | CHECK_ACL(aclrtMalloc((void **)&dev_weight_f16, 40 | max_last_dim * sizeof(aclFloat16), 41 | ACL_MEM_MALLOC_HUGE_FIRST)); 42 | } 43 | 44 | bool RMSNormOpTest::Run(size_t first_dim, size_t last_dim, float eps) { 45 | spdlog::info("RMSNormOpTest::Run {{{},{}}} eps:{}", first_dim, last_dim, eps); 46 | 47 | if (last_dim % 16 != 0) { 48 | spdlog::critical("last dim {} is not aligned to 16!", last_dim); 49 | return false; 50 | } 51 | 52 | if (first_dim > max_first_dim) { 53 | spdlog::critical("{} > {}", first_dim, max_first_dim); 54 | return false; 55 | } 56 | 57 | if (last_dim > max_last_dim) { 58 | spdlog::critical("{} > {}", last_dim, max_last_dim); 59 | return false; 60 | } 61 | 62 | int total_element_cnt = first_dim * last_dim; 63 | make_random_float(host_input, total_element_cnt); 64 | make_random_float(weight_fp32, last_dim); 65 | 66 | Eigen::TensorMap> 67 | golden_fp32_map((float *)golden_fp32, first_dim, last_dim); 68 | 69 | Eigen::TensorMap> 70 | input_map((float *)host_input, first_dim, last_dim); 71 | 72 | Eigen::TensorMap< 73 | Eigen::Tensor> 74 | input_fp16_map((Eigen::half *)host_input_f16, first_dim, last_dim); 75 | input_fp16_map = input_map.cast(); 76 | input_map = input_fp16_map.cast(); 77 | 78 | Eigen::TensorMap> 79 | weight_map((float *)weight_fp32, 1, last_dim); 80 | 81 | Eigen::TensorMap< 82 | Eigen::Tensor> 83 | weight_fp16_map((Eigen::half *)weight_fp16, 1, last_dim); 84 | weight_fp16_map = weight_map.cast(); 85 | weight_map = weight_fp16_map.cast(); 86 | 87 | std::array mean_dims = {1}; 88 | Eigen::Tensor mean = 89 | (input_map * input_map) 90 | .mean(mean_dims) 91 | .eval() 92 | .reshape(std::array{(long)first_dim, 1L}); 93 | Eigen::Tensor 94 | sqrt_mean_add_eps = 95 | (mean + mean.constant(eps)) 96 | .sqrt() 97 | .eval() 98 | .reshape(std::array{(long int)first_dim, 1}); 99 | golden_fp32_map = 100 | (input_map / 101 | sqrt_mean_add_eps.broadcast(std::array{1, last_dim}) * 102 | weight_map.broadcast(std::array{first_dim, 1})); 103 | 104 | CHECK_ACL(aclrtMemcpy(dev_input_f16, total_element_cnt * sizeof(aclFloat16), 105 | host_input_f16, total_element_cnt * sizeof(aclFloat16), 106 | ACL_MEMCPY_HOST_TO_DEVICE)); 107 | 108 | CHECK_ACL(aclrtMemcpy(dev_weight_f16, last_dim * sizeof(aclFloat16), weight_fp16, 109 | last_dim * sizeof(aclFloat16), ACL_MEMCPY_HOST_TO_DEVICE)); 110 | 111 | npu_rmsnorm_layer(dev_output_f16, dev_weight_f16, dev_input_f16, first_dim, 112 | last_dim, eps, DT_FLOAT16, stream); 113 | CHECK_ACL(aclrtSynchronizeStream(stream)); 114 | 115 | CHECK_ACL(aclrtMemcpy(host_output_f16, total_element_cnt * sizeof(aclFloat16), 116 | dev_output_f16, total_element_cnt * sizeof(aclFloat16), 117 | ACL_MEMCPY_DEVICE_TO_HOST)); 118 | 119 | Eigen::TensorMap> 120 | output_fp32_map((float *)host_output, first_dim, last_dim); 121 | 122 | Eigen::TensorMap< 123 | Eigen::Tensor> 124 | output_fp16_map((Eigen::half *)host_output_f16, first_dim, last_dim); 125 | output_fp32_map = output_fp16_map.cast(); 126 | return all_close(host_output, golden_fp32, total_element_cnt); 127 | } 128 | 129 | void RMSNormOpTest::CleanUp() { 130 | 131 | delete[] host_input; 132 | delete[] host_output; 133 | delete[] host_input_f16; 134 | delete[] host_output_f16; 135 | delete[] golden_fp32; 136 | delete[] weight_fp32; 137 | delete[] weight_fp16; 138 | 139 | CHECK_ACL(aclrtFree(dev_input_f16)); 140 | 
CHECK_ACL(aclrtFree(dev_output_f16)); 141 | CHECK_ACL(aclrtFree(dev_weight_f16)); 142 | } 143 | 144 | 145 | -------------------------------------------------------------------------------- /tests/rms_norm_layer_test.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "npu_op_test_util.h" 4 | 5 | class RMSNormOpTest : public OpTestBase { 6 | public: 7 | void Init(size_t max_first_dim, size_t max_last_dim); 8 | bool Run(size_t first_dim, size_t last_dim, float eps); 9 | void CleanUp(); 10 | 11 | size_t max_first_dim{0}; 12 | size_t max_last_dim{0}; 13 | 14 | void *dev_input_f16{nullptr}; 15 | void *dev_output_f16{nullptr}; 16 | void *dev_weight_f16{nullptr}; 17 | 18 | float *host_input{nullptr}; 19 | float *host_output{nullptr}; 20 | Eigen::half *host_input_f16{nullptr}; 21 | Eigen::half *host_output_f16{nullptr}; 22 | float *golden_fp32{nullptr}; 23 | float *weight_fp32{nullptr}; 24 | Eigen::half *weight_fp16{nullptr}; 25 | }; 26 | -------------------------------------------------------------------------------- /tests/rope_single_layer_main.cpp: -------------------------------------------------------------------------------- 1 | #include "rope_single_layer_test.h" 2 | 3 | namespace po = boost::program_options; 4 | 5 | template bool TestFn(size_t max_seq_length, size_t head_dim, 6 | size_t head_num, size_t seq_len, size_t offset, bool is_neox) { 7 | aclrtContext context; 8 | int32_t deviceId{0}; 9 | 10 | CHECK_ACL(aclInit(nullptr)); 11 | CHECK_ACL(aclrtSetDevice(deviceId)); 12 | CHECK_ACL(aclrtCreateContext(&context, deviceId)); 13 | 14 | RoPESingleOpTest op_test; 15 | op_test.Init(max_seq_length, head_dim, head_num, is_neox); 16 | bool test_result = op_test.Run(offset, seq_len); 17 | op_test.CleanUp(); 18 | 19 | CHECK_ACL(aclrtDestroyContext(context)); 20 | CHECK_ACL(aclrtResetDevice(deviceId)); 21 | CHECK_ACL(aclFinalize()); 22 | return test_result; 23 | } 24 | 25 | int main(int argc, char **argv) { 26 | aclrtContext context; 27 | int32_t deviceId{0}; 28 | int max_seq_length; 29 | int head_dim; 30 | int head_num; 31 | int seq_len; 32 | int offset; 33 | bool is_neox; 34 | std::string dtype_str; 35 | 36 | // clang-format off 37 | po::options_description desc("RopeSingleLayer options"); 38 | desc.add_options() 39 | ("help", "produce help message") 40 | ("max_seq_length", po::value(&max_seq_length)->default_value(4096), "max_seq_length. default:4096") 41 | ("head_dim", po::value(&head_dim)->default_value(128), "head_dim. default:128") 42 | ("head_num", po::value(&head_num)->default_value(32), "head_num. default:32") 43 | ("seq_len", po::value(&seq_len)->default_value(32), "seq_len. default:32") 44 | ("is_neox", po::value(&is_neox)->default_value(true), "is_neox, default:true") 45 | ("offset", po::value(&offset)->default_value(0), "offset, default:0") 46 | ("dtype", po::value(&dtype_str)->required(), "dtype. 
float16 or bfloat16"); 47 | 48 | // clang-format on 49 | po::variables_map vm; 50 | po::store(po::parse_command_line(argc, argv, desc), vm); 51 | 52 | if (vm.count("help")) { 53 | std::cout << desc << "\n"; 54 | return 1; 55 | } 56 | po::notify(vm); 57 | 58 | bool test_result = false; 59 | 60 | if (dtype_str == "float16") { 61 | test_result = TestFn(max_seq_length, head_dim, head_num, seq_len, offset, is_neox); 62 | } else if (dtype_str == "bfloat16") { 63 | test_result = TestFn(max_seq_length, head_dim, head_num, seq_len, offset, is_neox);; 64 | } 65 | 66 | 67 | if (test_result) { 68 | spdlog::info("test success"); 69 | } else { 70 | spdlog::error("test failed"); 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /tests/rope_single_layer_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include "npu_op_test_util.h" 16 | #include "npu_ops.h" 17 | #include "rope_single_layer_test.h" 18 | 19 | template 20 | void RoPESingleOpTest::Init(size_t max_seq_len, size_t head_dim, 21 | size_t head_num, bool is_neox) { 22 | this->max_seq_len = max_seq_len; 23 | this->head_dim = head_dim; 24 | this->head_num = head_num; 25 | 26 | int freq_cis_size = head_dim * max_seq_len; 27 | host_freq_cis = new float[freq_cis_size]; 28 | InitFreqCIS(host_freq_cis, head_dim, max_seq_len); 29 | 30 | size_t max_buffer_size = head_num * head_dim * max_seq_len; 31 | 32 | host_input = new float[max_buffer_size]; 33 | host_output = new float[max_buffer_size]; 34 | host_input_f16 = new EigenTy[max_buffer_size]; 35 | host_output_f16 = new EigenTy[max_buffer_size]; 36 | golden_fp32 = new float[max_buffer_size]; 37 | 38 | CHECK_ACL(aclrtMalloc((void **)&dev_input_f16, 39 | max_buffer_size * sizeof(EigenTy), 40 | ACL_MEM_MALLOC_HUGE_FIRST)); 41 | CHECK_ACL(aclrtMalloc((void **)&dev_output_f16, 42 | max_buffer_size * sizeof(EigenTy), 43 | ACL_MEM_MALLOC_HUGE_FIRST)); 44 | CHECK_ACL(aclrtMalloc((void **)&dev_freq_cis, freq_cis_size * sizeof(float), 45 | ACL_MEM_MALLOC_HUGE_FIRST)); 46 | CHECK_ACL(aclrtMemcpy(dev_freq_cis, freq_cis_size * sizeof(float), 47 | host_freq_cis, freq_cis_size * sizeof(float), 48 | ACL_MEMCPY_HOST_TO_DEVICE)); 49 | } 50 | 51 | template 52 | bool RoPESingleOpTest::Run(size_t offset, size_t seq_len) { 53 | spdlog::info("{} max_seq_len {} head_num {} head_dim {} offset {} seq_len {}", 54 | __PRETTY_FUNCTION__, max_seq_len, head_num, head_dim, offset, 55 | seq_len); 56 | 57 | make_random_float(host_input, head_dim * head_num * seq_len); 58 | 59 | Eigen::TensorMap> 60 | golden_fp32_map((float *)golden_fp32, seq_len, head_num, head_dim); 61 | 62 | Eigen::TensorMap> 63 | input_map((float *)host_input, seq_len, head_num, head_dim); 64 | 65 | Eigen::TensorMap> 66 | freq_cis_map((float *)host_freq_cis + offset * head_dim, seq_len, 1, 67 | head_dim); 68 | 69 | Eigen::TensorMap< 70 | Eigen::Tensor> 71 | input_fp16_map((EigenTy *)host_input_f16, seq_len, head_num, head_dim); 72 | input_fp16_map = input_map.cast(); 73 | input_map = input_fp16_map.template cast(); 74 | 75 | int freq_len = head_dim / 2; 76 | int hidden_dim = head_num * head_dim; 77 | 78 | for (int s = 0; s < seq_len; ++s) { 79 | for (int n = 0; n < head_num; ++n) { 80 | for (int f = 0; f < freq_len; ++f) { 81 | float fc = host_freq_cis[(s + offset) * freq_len * 2 + 2 * f]; 82 | float fd = 
host_freq_cis[(s + offset) * freq_len * 2 + 2 * f + 1]; 83 | 84 | int hidden_offset = s * hidden_dim + n * head_dim; 85 | 86 | float qa = host_input[hidden_offset + (is_neox ? f : (2 * f))]; 87 | float qb = host_input[hidden_offset + (is_neox ? (freq_len + f) : (2 * f + 1))]; 88 | 89 | 90 | golden_fp32[hidden_offset + (is_neox ? f : (2 * f))] = qa * fc - qb * fd; 91 | golden_fp32[hidden_offset + (is_neox ? (freq_len + f) : (2 * f + 1))] = qa * fd + qb * fc; 92 | 93 | } 94 | } 95 | } 96 | 97 | golden_fp32_map = golden_fp32_map.template cast().template cast(); 98 | 99 | CHECK_ACL(aclrtMemcpy( 100 | dev_input_f16, seq_len * head_num * head_dim * sizeof(EigenTy), 101 | host_input_f16, seq_len * head_num * head_dim * sizeof(EigenTy), 102 | ACL_MEMCPY_HOST_TO_DEVICE)); 103 | 104 | npu_rope_single_layer(dev_output_f16, dev_freq_cis, dev_input_f16, offset, 105 | seq_len, head_num, head_dim * head_num, is_neox, 106 | GetDataType(), this->stream); 107 | CHECK_ACL(aclrtSynchronizeStream(this->stream)); 108 | 109 | CHECK_ACL(aclrtMemcpy( 110 | host_output_f16, seq_len * head_num * head_dim * sizeof(EigenTy), 111 | dev_output_f16, seq_len * head_num * head_dim * sizeof(EigenTy), 112 | ACL_MEMCPY_DEVICE_TO_HOST)); 113 | 114 | Eigen::TensorMap> 115 | output_fp32_map((float *)host_output, seq_len, head_num, head_dim); 116 | 117 | Eigen::TensorMap< 118 | Eigen::Tensor> 119 | output_fp16_map((EigenTy *)host_output_f16, seq_len, head_num, head_dim); 120 | output_fp32_map = output_fp16_map.template cast(); 121 | return all_close(host_output, golden_fp32, seq_len * head_num * head_dim); 122 | } 123 | 124 | template void RoPESingleOpTest::CleanUp() { 125 | 126 | delete[] host_input; 127 | delete[] host_output; 128 | delete[] host_input_f16; 129 | delete[] host_output_f16; 130 | delete[] host_freq_cis; 131 | delete[] golden_fp32; 132 | 133 | CHECK_ACL(aclrtFree(dev_input_f16)); 134 | CHECK_ACL(aclrtFree(dev_output_f16)); 135 | CHECK_ACL(aclrtFree(dev_freq_cis)); 136 | } 137 | 138 | template class RoPESingleOpTest; 139 | template class RoPESingleOpTest; 140 | -------------------------------------------------------------------------------- /tests/rope_single_layer_test.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "npu_op_test_util.h" 4 | 5 | template 6 | class RoPESingleOpTest : public OpTestBase> { 7 | public: 8 | void Init(size_t max_seq_len, size_t head_dim, size_t head_num, bool is_neox); 9 | bool Run(size_t offset, size_t test_size); 10 | void CleanUp(); 11 | 12 | size_t max_seq_len{0}; 13 | size_t head_dim{0}; 14 | size_t head_num{0}; 15 | bool is_neox{true}; 16 | 17 | void *dev_input_f16{nullptr}; 18 | void *dev_output_f16{nullptr}; 19 | void *dev_freq_cis{nullptr}; 20 | 21 | float *host_input{nullptr}; 22 | float *host_output{nullptr}; 23 | EigenTy *host_input_f16{nullptr}; 24 | EigenTy *host_output_f16{nullptr}; 25 | float *golden_fp32{nullptr}; 26 | float *host_freq_cis{nullptr}; 27 | }; 28 | --------------------------------------------------------------------------------
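For reference, the boost::program_options blocks above define the command-line surface of the operator test binaries. Assuming the build produces one executable per *_main.cpp source and keeps the source file's base name (an assumption — the actual target names are set in tests/CMakeLists.txt, which is not reproduced here), typical invocations from the build directory would look like:

# GEMM: 2048x2048x2048, fp16, no bias (--dtype is required)
./gemm_main --m 2048 --n 2048 --k 2048 --bias false --dtype float16

# RMSNorm: 2048x2048 with eps 1e-5
./rms_norm_layer_main --first_dim 2048 --last_dim 2048 --eps 1e-5

# RoPE single layer: 32 heads of dim 128, 32 tokens starting at offset 0
./rope_single_layer_main --max_seq_length 4096 --head_dim 128 --head_num 32 --seq_len 32 --offset 0 --is_neox true --dtype bfloat16

Each binary fills random inputs, runs the corresponding NPU kernel, recomputes a golden result on the host with Eigen, and logs "test success" or "test failed" via spdlog depending on whether the all_close comparison passes.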