├── 3rd_party
│   └── onnxruntime
│       └── .keep
├── CMakeLists.txt
├── LICENSE
├── README.md
├── demo
│   └── cli_demo.cpp
├── include
│   ├── httplib.h
│   ├── json.hpp
│   ├── llm.hpp
│   ├── ortwrapper.hpp
│   └── tokenizer.hpp
├── resource
│   ├── logo.png
│   └── prompt.txt
└── src
    ├── llm.cpp
    ├── llmconfig.hpp
    └── tokenizer.cpp

/3rd_party/onnxruntime/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangzhaode/onnx-llm/7458c3a068b9e7e9ad1918ed1f9f0db14eec9e31/3rd_party/onnxruntime/.keep
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.5)
project(onnx-llm)

option(BUILD_FOR_ANDROID "Build for android with minimal memory mode." OFF)
option(LLM_SUPPORT_VISION "Llm model supports vision input." OFF)
option(DUMP_PROFILE_INFO "Dump profile info when chatting." OFF)
option(BUILD_JNI "Build JNI for android app." OFF)

if (DUMP_PROFILE_INFO)
    add_definitions(-DDUMP_PROFILE_INFO)
endif()

# propagate the vision option to the compiler; src/llm.cpp checks this macro
if (LLM_SUPPORT_VISION)
    add_definitions(-DLLM_SUPPORT_VISION)
endif()

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++")
endif()

set(ONNXRUNTIME_PATH ${CMAKE_SOURCE_DIR}/3rd_party/onnxruntime)

link_directories(${ONNXRUNTIME_PATH}/lib)
include_directories(${ONNXRUNTIME_PATH}/include
                    ${CMAKE_SOURCE_DIR}/include)

FILE(GLOB SRCS ${CMAKE_CURRENT_LIST_DIR}/src/*.cpp)

add_library(llm STATIC ${SRCS})
target_link_libraries(llm onnxruntime)

add_executable(cli_demo ${CMAKE_SOURCE_DIR}/demo/cli_demo.cpp)
target_link_libraries(cli_demo llm)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
![onnx-llm](resource/logo.png)

# onnx-llm
[![License](https://img.shields.io/github/license/wangzhaode/onnx-llm)](LICENSE)

## Build

### Steps

1. Download an `onnxruntime` release package from [here](https://github.com/microsoft/onnxruntime/releases).
2. Extract the package to `onnx-llm/3rd_party/onnxruntime`.
3. Compile the project.

### Example
```bash
wget https://github.com/microsoft/onnxruntime/releases/download/v1.19.2/onnxruntime-osx-arm64-1.19.2.tgz
tar -xvf onnxruntime-osx-arm64-1.19.2.tgz
mv onnxruntime-osx-arm64-1.19.2/* 3rd_party/onnxruntime/
mkdir build && cd build
cmake ..
make -j
```

## Usage

1. Export the model with [llm-export](https://github.com/wangzhaode/llm-export).
2. The usage of onnx-llm is the same as [mnn-llm](https://github.com/wangzhaode/mnn-llm).

```bash
(base) ➜  build git:(main) ✗ ./cli_demo qwen2-0.5b-instruct/config.json ../resource/prompt.txt
model path is ../../llm-export/model/config.json
load tokenizer
tokenizer_type = 3
load tokenizer Done
load ../../llm-export/model/llm.onnx ... Load Module Done!
prompt file is ../resource/prompt.txt
Hello! How can I assist you today?
我是来自阿里云的超大规模语言模型，我叫通义千问。
很抱歉，作为AI助手，我无法实时获取和显示当前的天气信息。建议您查看当地的气象预报或应用中的天气查询功能来获取准确的信息。

#################################
prompt tokens num = 36
decode tokens num = 64
prefill time = 0.32 s
 decode time = 2.00 s
prefill speed = 112.66 tok/s
 decode speed = 32.07 tok/s
##################################
```

## Reference
- [mnn-llm](https://github.com/wangzhaode/mnn-llm)
- [onnxruntime](https://github.com/microsoft/onnxruntime)
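
## API sketch

A minimal sketch of embedding the library in your own program (illustrative, based on `include/llm.hpp`; the config path is an example):

```cpp
#include "llm.hpp"
#include <iostream>
#include <memory>

int main() {
    std::unique_ptr<Llm> llm(Llm::createLLM("qwen2-0.5b-instruct/config.json"));
    llm->load();
    // stream the reply through a callback instead of writing to std::cout directly
    LlmStreamBuffer stream_buf([](const char* str, size_t len) {
        std::cout.write(str, len).flush();
    });
    std::ostream os(&stream_buf);
    llm->response("Hello", &os, "\n");
    return 0;
}
```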
--------------------------------------------------------------------------------
/demo/cli_demo.cpp:
--------------------------------------------------------------------------------
//
//  cli_demo.cpp
//
//  Created by MNN on 2023/03/24.
//  ZhaodeWang
//

#include "llm.hpp"
#include <cstdio>
#include <fstream>
#include <iostream>

void benchmark(Llm* llm, std::string prompt_file) {
    std::cout << "prompt file is " << prompt_file << std::endl;
    std::ifstream prompt_fs(prompt_file);
    std::vector<std::string> prompts;
    std::string prompt;
    while (std::getline(prompt_fs, prompt)) {
        // prompts starting with '#' are ignored
        if (prompt.substr(0, 1) == "#") {
            continue;
        }
        // replace escaped "\n" sequences with real newlines
        std::string::size_type pos = 0;
        while ((pos = prompt.find("\\n", pos)) != std::string::npos) {
            prompt.replace(pos, 2, "\n");
            pos += 1;
        }
        prompts.push_back(prompt);
    }
    int prompt_len = 0;
    int decode_len = 0;
    int64_t prefill_time = 0;
    int64_t decode_time = 0;
    for (size_t i = 0; i < prompts.size(); i++) {
        llm->response(prompts[i]);
        prompt_len += llm->prompt_len_;
        decode_len += llm->gen_seq_len_;
        prefill_time += llm->prefill_us_;
        decode_time += llm->decode_us_;
    }
    float prefill_s = prefill_time / 1e6;
    float decode_s = decode_time / 1e6;
    printf("\n#################################\n");
    printf("prompt tokens num = %d\n", prompt_len);
    printf("decode tokens num = %d\n", decode_len);
    printf("prefill time = %.2f s\n", prefill_s);
    printf(" decode time = %.2f s\n", decode_s);
    printf("prefill speed = %.2f tok/s\n", prompt_len / prefill_s);
    printf(" decode speed = %.2f tok/s\n", decode_len / decode_s);
    printf("##################################\n");
}

int main(int argc, const char* argv[]) {
    if (argc < 2) {
        std::cout << "Usage: " << argv[0] << " model_dir <prompt.txt>" << std::endl;
        return 0;
    }
    std::string model_dir = argv[1];
    std::cout << "model path is " << model_dir << std::endl;
    std::unique_ptr<Llm> llm(Llm::createLLM(model_dir));
    llm->load();
    if (argc < 3) {
        llm->chat();
        return 0;
    }
    std::string prompt_file = argv[2];
    benchmark(llm.get(), prompt_file);
    return 0;
}
--------------------------------------------------------------------------------
/include/llm.hpp:
--------------------------------------------------------------------------------
//
//  llm.hpp
//
//  Created by MNN on 2023/08/25.
//  ZhaodeWang
//

#ifndef LLM_hpp
#define LLM_hpp

#include <fstream>
#include <functional>
#include <iostream>
#include <memory>
#include <streambuf>
#include <string>
#include <unordered_map>
#include <vector>

#include "ortwrapper.hpp"
#include "tokenizer.hpp"
#include "json.hpp"

using namespace Ort;
using json = nlohmann::json;
class Tokenizer;
class Pipeline;
class LlmConfig;

// Llm start
// llm stream buffer with callback
class LlmStreamBuffer : public std::streambuf {
public:
    using CallBack = std::function<void(const char* str, size_t len)>;
    LlmStreamBuffer(CallBack callback) : callback_(callback) {}

protected:
    virtual std::streamsize xsputn(const char* s, std::streamsize n) override {
        if (callback_) {
            callback_(s, n);
        }
        return n;
    }

private:
    CallBack callback_ = nullptr;
};

enum PROMPT_TYPE {
    SYSTEM = 0,
    ATTACHMENT = 1,
    USER = 2,
    ASSISTANT = 3,
    OTHER = 4
};

struct Prompt {
    PROMPT_TYPE type;
    std::string str;
    std::vector<int> tokens;
};

class Llm {
public:
    using PromptItem = std::pair<std::string, std::string>; // <role, content>
    Llm(std::shared_ptr<LlmConfig> config) : config_(config) {}
    virtual ~Llm();
    void chat();
    void reset();
    static Llm* createLLM(const std::string& config_path);
    virtual void load();
    Value forward(const std::vector<int>& input_ids);
    int sample(Value& logits, const std::vector<int>& pre_ids);
    std::string apply_prompt_template(const std::string& user_content) const;
    std::string apply_chat_template(const std::vector<PromptItem>& chat_prompts) const;
    std::string response(const std::string& user_content, std::ostream* os = &std::cout, const char* end_with = nullptr);
    std::string response(const std::vector<PromptItem>& chat_prompts, std::ostream* os = &std::cout, const char* end_with = nullptr);
    void generate_init();
    std::string generate(const std::vector<int>& input_ids, std::ostream* os, const char* end_with);
    std::vector<int> generate(const std::vector<int>& input_ids, int max_new_tokens = -1);
    void print_speed();
    // config function
    std::string dump_config();
    bool set_config(const std::string& content);
    friend class Pipeline;
public:
    // forward info
    int prompt_len_ = 0;
    int gen_seq_len_ = 0;
    int all_seq_len_ = 0;
    std::vector<int> history_ids_;
    // time
    int64_t prefill_us_ = 0;
    int64_t decode_us_ = 0;
    bool is_single_ = true;
    bool attention_fused_ = true;
protected:
    std::shared_ptr<LlmConfig> config_;
    std::shared_ptr<Tokenizer> tokenizer_;
    std::vector<int> key_value_shape_ = {};
    Value past_key_values_ {nullptr};
    std::shared_ptr<RuntimeManager> runtime_manager_;
    std::shared_ptr<Module> module_;
    void init_runtime();
    std::string decode(int id);
    bool is_stop(int token_id);
    virtual std::vector<int> tokenizer(const std::string& query);
    virtual Value embedding(const std::vector<int>& input_ids);
    virtual Value gen_attention_mask(int seq_len);
    virtual Value gen_position_ids(int seq_len);
};
// Llm end

#endif // LLM_hpp
--------------------------------------------------------------------------------
/include/ortwrapper.hpp:
--------------------------------------------------------------------------------
//
//  ortwrapper.hpp
//
//  Created by zhaode on 2024/10/09.
//  ZhaodeWang
//

#ifndef ORTWRAPPER_hpp
#define ORTWRAPPER_hpp

#include <memory>
#include <onnxruntime_cxx_api.h>

namespace Ort {

class RuntimeManager {
public:
    RuntimeManager() {
        env_.reset(new Ort::Env(ORT_LOGGING_LEVEL_WARNING, "onnx-llm"));
        options_.reset(new Ort::SessionOptions());
        options_->SetIntraOpNumThreads(1);
        options_->SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
        allocator_.reset(new Ort::AllocatorWithDefaultOptions());
    }
    ~RuntimeManager() {}
    const Ort::Env& env() const {
        return *env_;
    }
    const Ort::SessionOptions& options() const {
        return *options_;
    }
    const Ort::AllocatorWithDefaultOptions& allocator() const {
        return *allocator_;
    }
private:
    std::unique_ptr<Ort::Env> env_;
    std::unique_ptr<Ort::SessionOptions> options_;
    std::unique_ptr<Ort::AllocatorWithDefaultOptions> allocator_;
};

class Module {
public:
    Module(std::shared_ptr<RuntimeManager> runtime, const std::string& path) {
        session_.reset(new Ort::Session(runtime->env(), path.c_str(), runtime->options()));
        input_count_ = session_->GetInputCount();
        output_count_ = session_->GetOutputCount();
        for (size_t i = 0; i < input_count_; i++) {
            input_strs_.push_back(session_->GetInputNameAllocated(i, runtime->allocator()));
            input_names_.push_back(input_strs_[i].get());
        }
        for (size_t i = 0; i < output_count_; i++) {
            output_strs_.push_back(session_->GetOutputNameAllocated(i, runtime->allocator()));
            output_names_.push_back(output_strs_[i].get());
        }
    }
    std::vector<Value> onForward(const std::vector<Value>& inputs) {
        auto outputs = session_->Run(Ort::RunOptions{nullptr},
                                     input_names_.data(), inputs.data(), inputs.size(),
                                     output_names_.data(), output_names_.size());
        return outputs;
    }
private:
    std::unique_ptr<Ort::Session> session_;
    size_t input_count_, output_count_;
    std::vector<Ort::AllocatedStringPtr> input_strs_, output_strs_;
    std::vector<const char*> input_names_, output_names_;
};

template <typename T = float>
static Value _Input(const std::vector<int>& shape, std::shared_ptr<RuntimeManager> rtmgr) {
    std::vector<int64_t> shape_int64(shape.begin(), shape.end());
    return Value::CreateTensor<T>(rtmgr->allocator(), shape_int64.data(), shape_int64.size());
}

} // namespace Ort

#endif /* ORTWRAPPER_hpp */
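
// Illustrative usage of the wrapper (not part of the original header; the
// model path and tensor shape below are hypothetical):
//
//   auto rtmgr = std::make_shared<Ort::RuntimeManager>();
//   Ort::Module module(rtmgr, "model.onnx");
//   std::vector<Ort::Value> inputs;
//   inputs.emplace_back(Ort::_Input<float>({1, 16}, rtmgr));  // allocate a [1, 16] float tensor
//   auto outputs = module.onForward(inputs);                  // one session run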
--------------------------------------------------------------------------------
/include/tokenizer.hpp:
--------------------------------------------------------------------------------
//
//  tokenizer.hpp
//
//  Created by MNN on 2023/09/25.
//  ZhaodeWang
//

#ifndef TOKENIZER_hpp
#define TOKENIZER_hpp

#include <cstdint>
#include <cstring>
#include <fstream>
#include <string>
#include <unordered_map>
// #include <string_view>
#include <vector>

// std::string_view impl in c++11 start
class string_view_ {
public:
    string_view_() : data_(nullptr), size_(0) {}
    string_view_(const char* data) : data_(data), size_(std::strlen(data)) {}
    string_view_(const char* data, std::size_t size) : data_(data), size_(size) {}
    string_view_(const std::string& str) : data_(str.data()), size_(str.size()) {}
    constexpr string_view_(const string_view_&) noexcept = default;
    string_view_& operator=(const string_view_&) noexcept = default;
    const char& operator[](size_t pos) const { return data_[pos]; }
    constexpr const char* data() const noexcept { return data_; }
    constexpr std::size_t size() const noexcept { return size_; }
    constexpr bool empty() const { return size_ == 0; }
    std::string to_string() const { return std::string(data_, size_); }
    bool operator==(const string_view_& other) const noexcept {
        return size_ == other.size_ && strncmp(data_, other.data_, size_) == 0;
    }
    void remove_prefix(size_t n) {
        if (n < size_) {
            data_ += n;
            size_ -= n;
        } else {
            data_ = "";
            size_ = 0;
        }
    }
private:
    const char* data_;
    std::size_t size_ = 0;
};

namespace std {
template<>
class hash<string_view_> {
public:
    size_t operator()(const string_view_& sv) const {
        size_t result = 0;
        for (size_t i = 0; i < sv.size(); ++i) {
            result = (result * 31) + static_cast<size_t>(sv[i]);
        }
        return result;
    }
};
}
// std::string_view impl in c++11 end

class Tokenizer {
public:
    static constexpr int MAGIC_NUMBER = 430;
    enum TokenizerType {
        SENTENCEPIECE = 0,
        TIKTOKEN = 1,
        BERT = 2,
        HUGGINGFACE = 3
    };
    Tokenizer() = default;
    virtual ~Tokenizer() = default;
    static Tokenizer* createTokenizer(const std::string& filename);
    bool is_stop(int token);
    bool is_special(int token);
    std::vector<int> encode(const std::string& str);
    virtual std::string decode(int id) = 0;
protected:
    virtual void load_special(std::ifstream& file);
    virtual bool load_vocab(std::ifstream& file) = 0;
    virtual void encode(const std::string& str, std::vector<int>& ids) = 0;
    std::vector<int> special_tokens_;
    std::vector<int> stop_tokens_;
    std::vector<int> prefix_tokens_;
};

class Sentencepiece : public Tokenizer {
public:
    Sentencepiece() = default;
    virtual std::string decode(int id) override;
protected:
    virtual bool load_vocab(std::ifstream& file) override;
    virtual void encode(const std::string& str, std::vector<int>& ids) override;
private:
    enum ModelType {
        UNIGRAM = 1,
        BPE = 2,
        WORD = 3,
        CHAR = 4
    };
    enum PieceType {
        NORMAL = 1,
        UNKNOWN = 2,
        CONTROL = 3,
        USER_DEFINED = 4,
        UNUSED = 5,
        BYTE = 6
    };
    struct SentencePiece {
        std::string piece;
        float score;
        PieceType type = PieceType::NORMAL;
        SentencePiece() {}
        SentencePiece(const std::string& p, float s, PieceType t) : piece(p), score(s), type(t) {}
    };
    using EncodeResult = std::vector<std::pair<string_view_, int>>;
private:
    // model train type
    ModelType type_ = BPE;
    // byte fall back enable
    bool byte_fall_back_ = true;
    // unknown id.
    int unk_id_ = 0;
    // pieces from model
    std::vector<SentencePiece> sentence_pieces_;
    // piece -> id map for normal pieces
    std::unordered_map<std::string, int> pieces_;
    // piece -> id map for control, unknown, and byte pieces
    std::unordered_map<std::string, int> reserved_id_map_;
private:
    float get_score(int id) const;
    bool is_unused(int id) const;
    bool is_control(int id) const;
    int piece_to_id(const std::string& w) const;
    std::string byte_to_piece(unsigned char c) const;
    EncodeResult bpe_encode(string_view_ str, float alpha = 0.f);
};

class Tiktoken : public Tokenizer {
public:
    Tiktoken() = default;
    virtual std::string decode(int id) override;
protected:
    virtual bool load_vocab(std::ifstream& file) override;
    virtual void encode(const std::string& str, std::vector<int>& ids) override;
    std::unordered_map<std::string, int> encoder_;
    std::vector<std::string> decoder_;
};

class BertTokenizer : public Tiktoken {
public:
    BertTokenizer() = default;
protected:
    virtual void encode(const std::string& str, std::vector<int>& ids) override;
private:
    std::vector<int> word_piece(const std::string& token);
};

class HuggingfaceTokenizer : public Tokenizer {
    struct hash_pair_wstring {
        size_t operator()(const std::pair<std::wstring, std::wstring>& p) const {
            auto hash1 = std::hash<std::wstring>{}(p.first);
            auto hash2 = std::hash<std::wstring>{}(p.second);
            // If hash1 == hash2, their XOR is zero.
            return (hash1 != hash2) ? hash1 ^ hash2 : hash1;
        }
    };
    using BPERanks = std::unordered_map<std::pair<std::wstring, std::wstring>, int, hash_pair_wstring>;
public:
    HuggingfaceTokenizer() = default;
    virtual std::string decode(int id) override;
protected:
    virtual bool load_vocab(std::ifstream& file) override;
    virtual void encode(const std::string& str, std::vector<int>& ids) override;
private:
    void bpe(const std::wstring& token, const BPERanks& bpe_ranks, std::vector<std::wstring>* result);
    BPERanks bpe_ranks_;
    std::unordered_map<uint8_t, wchar_t> b2u_;
    std::unordered_map<wchar_t, uint8_t> u2b_;
    std::unordered_map<std::string, int> encoder_;
    std::vector<std::string> decoder_;
};

#endif // TOKENIZER_hpp
--------------------------------------------------------------------------------
/resource/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangzhaode/onnx-llm/7458c3a068b9e7e9ad1918ed1f9f0db14eec9e31/resource/logo.png
--------------------------------------------------------------------------------
/resource/prompt.txt:
--------------------------------------------------------------------------------
Hello
你好，请问你是谁？
请问今天的天气如何？
--------------------------------------------------------------------------------
/src/llm.cpp:
--------------------------------------------------------------------------------
//
//  llm.cpp
//
//  Created by MNN on 2023/08/25.
//  ZhaodeWang
//

#include <chrono>
#include <cstdio>
#include <fstream>
#include <limits>
#include <numeric>
#include <unordered_set>

#include "llm.hpp"
#include "llmconfig.hpp"
#include "tokenizer.hpp"

#ifdef LLM_SUPPORT_VISION
#include "httplib.h"
#endif

// Llm start
std::string Llm::dump_config() {
    return config_->config_.dump();
}

bool Llm::set_config(const std::string& content) {
    config_->config_.merge_patch(content.c_str());
    return true;
}

void Llm::init_runtime() {
    runtime_manager_.reset(new RuntimeManager());
}

void Llm::load() {
    init_runtime();
    // init module status
    key_value_shape_ = config_->key_value_shape();
    is_single_ = config_->is_single();
    attention_fused_ = config_->attention_fused();
    {
        // open and close the embedding file once as an early existence check
        std::ifstream embedding_bin(config_->embedding_file());
        embedding_bin.close();
    }
    // 1. load vocab
    printf("load tokenizer\n");
    tokenizer_.reset(Tokenizer::createTokenizer(config_->tokenizer_file()));
    printf("load tokenizer Done\n");
    // 2. load model
    int layer_nums = config_->layer_nums();
    key_value_shape_.insert(key_value_shape_.begin(), layer_nums);
    std::string model_path = config_->llm_model();
    printf("load %s ... ", model_path.c_str());
    module_.reset(new Module(runtime_manager_, model_path));
    printf("Load Module Done!\n");
}

Value Llm::forward(const std::vector<int>& input_ids) {
    int seq_len = input_ids.size();
    std::vector<Value> inputs;
    inputs.emplace_back(embedding(input_ids));
    inputs.emplace_back(gen_attention_mask(seq_len));
    inputs.emplace_back(gen_position_ids(seq_len));
    inputs.emplace_back(std::move(past_key_values_));
    auto outputs = module_->onForward(inputs);
    auto logits = std::move(outputs[0]);
    past_key_values_ = std::move(outputs[1]);
    all_seq_len_ += seq_len;
    gen_seq_len_++;
    return logits;
}
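
// The exported llm.onnx is assumed to take four inputs and return two
// outputs (a sketch of the dataflow implemented by forward() above):
//
//   inputs_embeds   [seq_len, 1, hidden]          from embedding()
//   attention_mask  [1, 1, seq_len, kv_seq_len]   from gen_attention_mask()
//   position_ids    [1, seq_len] ([1, 2, seq_len] for chatglm)
//   past_key_values key_value_shape_ with layer_nums prepended, reused across calls
//
//   logits          [..., vocab]                  consumed by sample()
//   presents        updated KV cache, fed back in as past_key_values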
int Llm::sample(Value& logits, const std::vector<int>& pre_ids) {
    std::unordered_set<int> ids_set(pre_ids.begin(), pre_ids.end());
    auto scores = logits.GetTensorMutableData<float>();
    auto shape = logits.GetTensorTypeAndShapeInfo().GetShape();
    auto size = std::accumulate(shape.begin(), shape.end(), int64_t(1), std::multiplies<int64_t>());
    // repetition penalty: dampen the scores of tokens that already appeared
    const float repetition_penalty = 1.1;
    for (auto id : ids_set) {
        float score = scores[id];
        scores[id] = score < 0 ? score * repetition_penalty : score / repetition_penalty;
    }
    // argmax
    float max_score = std::numeric_limits<float>::lowest();
    int token_id = 0;
    for (int i = 0; i < size; i++) {
        float score = scores[i];
        if (score > max_score) {
            max_score = score;
            token_id = i;
        }
    }
    return token_id;
}

static std::string apply_template(std::string prompt_template, const std::string& content, const std::string& role = "") {
    if (prompt_template.empty()) return content;
    if (!role.empty()) {
        const std::string placeholder = "%r";
        size_t start_pos = prompt_template.find(placeholder);
        if (start_pos == std::string::npos) return content;
        prompt_template.replace(start_pos, placeholder.length(), role);
    }
    const std::string placeholder = "%s";
    size_t start_pos = prompt_template.find(placeholder);
    if (start_pos == std::string::npos) return content;
    prompt_template.replace(start_pos, placeholder.length(), content);
    return prompt_template;
}

std::string Llm::apply_prompt_template(const std::string& user_content) const {
    auto chat_prompt = config_->prompt_template();
    return apply_template(chat_prompt, user_content);
}

std::string Llm::apply_chat_template(const std::vector<PromptItem>& chat_prompts) const {
    auto chat_template = config_->chat_template();
    std::string prompt_result;
    auto iter = chat_prompts.begin();
    for (; iter != chat_prompts.end() - 1; ++iter) {
        prompt_result += apply_template(chat_template, iter->second, iter->first);
    }
    if (iter->first == "user") {
        prompt_result += apply_prompt_template(iter->second);
    } else {
        prompt_result += apply_template(chat_template, iter->second, iter->first);
    }
    return prompt_result;
}

void Llm::chat() {
    std::vector<PromptItem> history;
    history.push_back(std::make_pair("system", "You are a helpful assistant."));
    while (true) {
        std::cout << "\nQ: ";
        std::string user_str;
        std::getline(std::cin, user_str);
        if (user_str.empty()) {
            continue;
        }
        if (user_str == "/exit") {
            break;
        }
        if (user_str == "/reset") {
            history.resize(1);
            std::cout << "\nA: reset done." << std::endl;
            continue;
        }
        std::cout << "\nA: " << std::flush;
        history.emplace_back(std::make_pair("user", user_str));
        auto assistant_str = response(history);
        history.emplace_back(std::make_pair("assistant", assistant_str));
        std::cout << std::endl;
    }
}
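
// Illustrative template expansion (values are examples, not repo defaults),
// assuming a Qwen-style config.json:
//   "prompt_template": "<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n"
//   "chat_template":   "<|im_start|>%r\n%s<|im_end|>\n"
// apply_chat_template({{"system", "You are a helpful assistant."}, {"user", "Hi"}})
// then produces:
//   <|im_start|>system\nYou are a helpful assistant.<|im_end|>\n
//   <|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n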
void Llm::reset() {
    history_ids_.clear();
    all_seq_len_ = 0;
}

void Llm::generate_init() {
    // init status
    gen_seq_len_ = 0;
    prefill_us_ = 0;
    decode_us_ = 0;
    past_key_values_ = _Input(key_value_shape_, runtime_manager_);
    if (!config_->reuse_kv()) {
        all_seq_len_ = 0;
        history_ids_.clear();
    }
}

std::vector<int> Llm::generate(const std::vector<int>& input_ids, int max_new_tokens) {
    generate_init();
    std::vector<int> output_ids, all_ids = input_ids;
    prompt_len_ = static_cast<int>(input_ids.size());
    if (max_new_tokens < 0) { max_new_tokens = config_->max_new_tokens(); }
    // prefill
    auto logits = forward(input_ids);
    int token = sample(logits, all_ids);
    output_ids.push_back(token);
    all_ids.push_back(token);
    // decode
    while (gen_seq_len_ < max_new_tokens) {
        logits = forward({token});
        token = sample(logits, all_ids);
        if (is_stop(token)) { break; }
        output_ids.push_back(token);
        all_ids.push_back(token);
    }
    return output_ids;
}

std::string Llm::generate(const std::vector<int>& input_ids, std::ostream* os, const char* end_with) {
    prompt_len_ = static_cast<int>(input_ids.size());
    history_ids_.insert(history_ids_.end(), input_ids.begin(), input_ids.end()); // push to history_ids_
    auto st = std::chrono::system_clock::now();
    auto logits = forward(input_ids);
    int token = sample(logits, history_ids_);
    auto et = std::chrono::system_clock::now();
    std::string output_str = decode(token);
    prefill_us_ = std::chrono::duration_cast<std::chrono::microseconds>(et - st).count();
    *os << output_str << std::flush;
    while (gen_seq_len_ < config_->max_new_tokens()) {
        st = std::chrono::system_clock::now();
        history_ids_.push_back(token);
        logits = forward({token});
        token = sample(logits, history_ids_);
        et = std::chrono::system_clock::now();
        decode_us_ += std::chrono::duration_cast<std::chrono::microseconds>(et - st).count();
        if (is_stop(token)) {
            *os << end_with << std::flush;
            break;
        }
        auto word = decode(token);
        *os << word << std::flush;
        output_str += word;
    }
#ifdef DUMP_PROFILE_INFO
    print_speed();
#endif
    return output_str;
}

std::vector<int> Llm::tokenizer(const std::string& query) {
    auto prompt = apply_prompt_template(query);
    auto input_ids = tokenizer_->encode(prompt);
    return input_ids;
}

std::string Llm::response(const std::string& user_content, std::ostream* os, const char* end_with) {
    generate_init();
    if (!end_with) { end_with = "\n"; }
    std::vector<int> input_ids;
    if (config_->reuse_kv()) {
        auto prompt = apply_prompt_template(user_content);
        if (all_seq_len_ > 0) {
            prompt = "<|im_end|>\n" + prompt;
        }
        input_ids = tokenizer_->encode(prompt);
    } else {
        input_ids = tokenizer(user_content);
    }
    return generate(input_ids, os, end_with);
}
std::string Llm::response(const std::vector<PromptItem>& chat_prompts, std::ostream* os, const char* end_with) {
    if (chat_prompts.empty()) { return ""; }
    generate_init();
    if (!end_with) { end_with = "\n"; }
    auto prompt = apply_chat_template(chat_prompts);
    if (config_->reuse_kv() && all_seq_len_ > 0) {
        prompt = "<|im_end|>\n" + prompt;
    }
    // std::cout << "# prompt : " << prompt << std::endl;
    auto input_ids = tokenizer_->encode(prompt);
    // printf("input_ids (%lu): ", input_ids.size()); for (auto id : input_ids) printf("%d, ", id); printf("\n");
    return generate(input_ids, os, end_with);
}

void Llm::print_speed() {
    auto prefill_s = prefill_us_ * 1e-6;
    auto decode_s = decode_us_ * 1e-6;
    auto total_s = prefill_s + decode_s;
    printf("\n#################################\n");
    printf(" total tokens num = %d\n", prompt_len_ + gen_seq_len_);
    printf("prompt tokens num = %d\n", prompt_len_);
    printf("output tokens num = %d\n", gen_seq_len_);
    printf("  total time = %.2f s\n", total_s);
    printf("prefill time = %.2f s\n", prefill_s);
    printf(" decode time = %.2f s\n", decode_s);
    printf("  total speed = %.2f tok/s\n", (prompt_len_ + gen_seq_len_) / total_s);
    printf("prefill speed = %.2f tok/s\n", prompt_len_ / prefill_s);
    printf(" decode speed = %.2f tok/s\n", gen_seq_len_ / decode_s);
    printf("   chat speed = %.2f tok/s\n", gen_seq_len_ / total_s);
    printf("##################################\n");
}


Llm::~Llm() {
    module_.reset();
    runtime_manager_.reset();
}

Value Llm::embedding(const std::vector<int>& input_ids) {
    // disk embedding to save memory
    int hidden_size = config_->hidden_size();
    int seq_len = static_cast<int>(input_ids.size());
    auto inputs_embeds = _Input({seq_len, 1, hidden_size}, runtime_manager_);
    size_t size = hidden_size * sizeof(int16_t);
    FILE* file = fopen(config_->embedding_file().c_str(), "rb");
    std::unique_ptr<int16_t[]> buffer(new int16_t[hidden_size]);
    for (int i = 0; i < seq_len; i++) {
        // each row of the file stores one token embedding as bf16
        fseek(file, input_ids[i] * size, SEEK_SET);
        size_t bytes_read = fread(buffer.get(), 1, size, file);
        (void)bytes_read;
        // widen bf16 to fp32 in place (little-endian): low half zero, high half from disk
        auto ptr = inputs_embeds.GetTensorMutableData<int16_t>() + i * hidden_size * 2;
        for (int j = 0; j < hidden_size; j++) {
            ptr[j * 2] = 0;
            ptr[j * 2 + 1] = buffer[j];
        }
    }
    fclose(file);
    return std::move(inputs_embeds);
}

std::string Llm::decode(int id) {
    std::string word = tokenizer_->decode(id);
    // Fix utf-8 garbled characters: map byte tokens like "<0x20>" back to raw bytes
    if (word.length() == 6 && word[0] == '<' && word[word.length()-1] == '>' && word[1] == '0' && word[2] == 'x') {
        int num = std::stoi(word.substr(3, 2), nullptr, 16);
        word = static_cast<char>(num);
    }
    return word;
}
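
// Mask layout sketch for the default ("int") causal mask with seq_len = 2 and
// all_seq_len_ = 3 (three cached tokens), so kv_seq_len = 5:
//
//   row 0 (query position 3): 1 1 1 1 0   -> may attend to positions <= 3
//   row 1 (query position 4): 1 1 1 1 1   -> may attend to positions <= 4
//
// The "float" variant stores 0 where attention is allowed and the lowest float
// value where it is masked; "glm" and "glm2" follow their models' schemes.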
Value Llm::gen_attention_mask(int seq_len) {
    int kv_seq_len = all_seq_len_ + seq_len;
    if (seq_len == 1) {
        kv_seq_len = seq_len;
    }
    if (config_->attention_mask() == "float") {
        auto attention_mask = _Input<float>({1, 1, seq_len, kv_seq_len}, runtime_manager_);
        auto ptr = attention_mask.GetTensorMutableData<float>();
        for (int i = 0; i < seq_len; i++) {
            for (int j = 0; j < kv_seq_len; j++) {
                int row = i + all_seq_len_;
                ptr[kv_seq_len * i + j] = (j > row) * std::numeric_limits<float>::lowest();
            }
        }
        return attention_mask;
    } else {
        auto attention_mask = _Input<int64_t>({1, 1, seq_len, kv_seq_len}, runtime_manager_);
        auto ptr = attention_mask.GetTensorMutableData<int64_t>();
        if (config_->attention_mask() == "glm") {
            // chatglm
            for (int i = 0; i < seq_len * kv_seq_len; i++) {
                ptr[i] = 0;
            }
            if (seq_len > 1) {
                for (int i = 1; i < seq_len; i++) {
                    ptr[seq_len * i - 1] = 1;
                }
            }
        } else {
            bool is_glm2 = config_->attention_mask() == "glm2";
            for (int i = 0; i < seq_len; i++) {
                for (int j = 0; j < kv_seq_len; j++) {
                    int row = i + all_seq_len_;
                    // index rows by kv_seq_len, matching the float branch above
                    ptr[kv_seq_len * i + j] = is_glm2 ? j > row : j <= row;
                }
            }
        }
        return attention_mask;
    }
}

Value Llm::gen_position_ids(int seq_len) {
    if (config_->attention_mask() == "glm") {
        // chatglm
        auto position_ids = _Input<int64_t>({1, 2, seq_len}, runtime_manager_);
        auto ptr = position_ids.GetTensorMutableData<int64_t>();
        if (seq_len == 1) {
            ptr[0] = all_seq_len_ - gen_seq_len_ - 2;
            ptr[1] = gen_seq_len_ + 1;
        } else {
            for (int i = 0; i < seq_len - 1; i++) {
                ptr[i] = i;
                ptr[seq_len + i] = 0;
            }
            ptr[seq_len - 1] = seq_len - 2;
            ptr[2 * seq_len - 1] = 1;
        }
        return position_ids;
    } else {
        bool is_glm2 = config_->attention_mask() == "glm2";
        auto position_ids = _Input<int64_t>({1, seq_len}, runtime_manager_);
        auto ptr = position_ids.GetTensorMutableData<int64_t>();
        if (seq_len == 1) {
            ptr[0] = is_glm2 ? gen_seq_len_ : all_seq_len_;
        } else {
            for (int i = 0; i < seq_len; i++) {
                ptr[i] = i + all_seq_len_;
            }
        }
        return position_ids;
    }
}

bool Llm::is_stop(int token_id) {
    return tokenizer_->is_stop(token_id);
}

Llm* Llm::createLLM(const std::string& config_path) {
    std::shared_ptr<LlmConfig> config(new LlmConfig(config_path));
    Llm* llm = nullptr;
    if (config->is_visual()) {
        // vision models are not wired up yet; llm stays nullptr here
        // llm = new Lvlm(config);
    } else {
        llm = new Llm(config);
    }
    return llm;
}
--------------------------------------------------------------------------------
/src/llmconfig.hpp:
--------------------------------------------------------------------------------
//
//  llmconfig.hpp
//
//  Created by MNN on 2024/07/19.
//  ZhaodeWang
//
#pragma once

#include "llm.hpp"

static inline bool has_suffix(const std::string& str, const std::string& suffix) {
    return str.size() >= suffix.size() &&
           str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
}

static inline std::string base_dir(const std::string& path) {
    size_t pos = path.find_last_of("/\\");
    if (pos == std::string::npos) {
        return "./";
    } else {
        return path.substr(0, pos + 1);
    }
}

static inline std::string file_name(const std::string& path) {
    size_t pos = path.find_last_of("/\\");
    if (pos == std::string::npos) {
        return path;
    } else {
        return path.substr(pos + 1);
    }
}

class LlmConfig {
public:
    std::string base_dir_;
    json config_, llm_config_;
    LlmConfig() {}
    LlmConfig(const std::string& path) {
        // load config
        if (has_suffix(path, ".json")) {
            std::ifstream config_file(path);
            if (config_file.is_open()) {
                config_ = json::parse(config_file);
            } else {
                std::cerr << "Unable to open config file: " << path << std::endl;
            }
            base_dir_ = base_dir(path);
        } else {
            // compatibility with the original usage
            if (has_suffix(path, ".mnn")) {
                auto model_name = file_name(path);
                config_ = {
                    {"llm_model", model_name},
                    {"llm_weight", model_name + ".weight"}
                };
                base_dir_ = base_dir(path);
            } else {
                config_ = {};
                base_dir_ = path;
            }
        }
        // using config's base_dir
        base_dir_ = config_.value("base_dir", base_dir_);
        // load llm_config for model info
        std::ifstream llm_config_file(llm_config());
        if (llm_config_file.is_open()) {
            llm_config_ = json::parse(llm_config_file);
        } else {
            std::cerr << "Unable to open llm_config file: " << llm_config() << std::endl;
        }
    }

#define DEFINE_CONFIG_PATH_ACCESSOR(name, defaultValue) \
    std::string name() const { return base_dir_ + config_.value(#name, defaultValue); }

#define DEFINE_CONFIG_ACCESSOR(name, type, defaultValue) \
    type name() const { return config_.value(#name, defaultValue); }

#define DEFINE_LLM_CONFIG_ACCESSOR(name, type, defaultValue) \
    type name() const { return llm_config_.value(#name, defaultValue); }

    // < model file config start
    DEFINE_CONFIG_PATH_ACCESSOR(llm_config, "llm_config.json")
    DEFINE_CONFIG_PATH_ACCESSOR(llm_model, "llm.mnn")
    DEFINE_CONFIG_PATH_ACCESSOR(llm_weight, "llm.mnn.weight")
    DEFINE_CONFIG_PATH_ACCESSOR(lm_model, "lm.mnn")
    DEFINE_CONFIG_PATH_ACCESSOR(embedding_model, "embedding.mnn")
    DEFINE_CONFIG_PATH_ACCESSOR(embedding_file, "embeddings_bf16.bin")
    DEFINE_CONFIG_PATH_ACCESSOR(tokenizer_file, "tokenizer.txt")
    DEFINE_CONFIG_PATH_ACCESSOR(visual_model, "visual.mnn")
    // model file config end >

    // < generate config start
    DEFINE_CONFIG_ACCESSOR(max_new_tokens, int, 512)
    DEFINE_CONFIG_ACCESSOR(reuse_kv, bool, false)
    DEFINE_CONFIG_ACCESSOR(backend_type, std::string, "cpu")
    DEFINE_CONFIG_ACCESSOR(thread_num, int, 4)
    DEFINE_CONFIG_ACCESSOR(precision, std::string, "low")
    DEFINE_CONFIG_ACCESSOR(power, std::string, "normal")
    DEFINE_CONFIG_ACCESSOR(memory, std::string, "low")
    DEFINE_CONFIG_ACCESSOR(quant_qkv, int, 0)
    DEFINE_CONFIG_ACCESSOR(kvcache_limit, int, -1)
    DEFINE_CONFIG_ACCESSOR(use_mmap, bool, false)
    DEFINE_CONFIG_ACCESSOR(kvcache_mmap, bool, false)
    DEFINE_CONFIG_ACCESSOR(tmp_path, std::string, "")
    // generate config end >

    // < llm model config start
    DEFINE_LLM_CONFIG_ACCESSOR(is_single, bool, true)
    DEFINE_LLM_CONFIG_ACCESSOR(is_visual, bool, false)
    DEFINE_LLM_CONFIG_ACCESSOR(hidden_size, int, 4096)
    DEFINE_LLM_CONFIG_ACCESSOR(layer_nums, int, 32)
    // std::vector<int> contains a comma, which would be parsed as a macro
    // argument separator, so key_value_shape uses a comma-free alias
    using IntVec = std::vector<int>;
    DEFINE_LLM_CONFIG_ACCESSOR(key_value_shape, IntVec, IntVec{})
    DEFINE_LLM_CONFIG_ACCESSOR(attention_mask, std::string, "int")
    DEFINE_LLM_CONFIG_ACCESSOR(attention_fused, bool, true)
    DEFINE_LLM_CONFIG_ACCESSOR(chat_template, std::string, "")
    DEFINE_LLM_CONFIG_ACCESSOR(prompt_template, std::string, "")
    // llm model config end >
};
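
// A hypothetical config.json consumed by this class (key names match the
// accessors above; values are illustrative, not repo defaults):
//
//   {
//     "llm_model": "llm.onnx",
//     "llm_config": "llm_config.json",
//     "embedding_file": "embeddings_bf16.bin",
//     "tokenizer_file": "tokenizer.txt",
//     "max_new_tokens": 512,
//     "reuse_kv": false
//   }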
--------------------------------------------------------------------------------
/src/tokenizer.cpp:
--------------------------------------------------------------------------------
//
//  tokenizer.cpp
//
//  Created by MNN on 2023/09/25.
//  ZhaodeWang
//

#include "tokenizer.hpp"
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <functional>
#include <iostream>
#include <memory>
#include <queue>
#include <random>
#include <sstream>

// base64
static const std::string base64_chars =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "abcdefghijklmnopqrstuvwxyz"
    "0123456789+/";

static inline bool is_base64(unsigned char c) {
    return (isalnum(c) || (c == '+') || (c == '/'));
}

static inline size_t one_char_len(const char *src) {
    // number of bytes in the UTF-8 character whose first byte is *src
    return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"[(*src & 0xFF) >> 4];
}

static std::string base64_decode(const std::string& str) {
    int in_len = str.size();
    int i = 0;
    int j = 0;
    int in_ = 0;
    unsigned char char_array_4[4], char_array_3[3];
    std::string ret;

    while (in_len-- && (str[in_] != '=') && is_base64(str[in_])) {
        char_array_4[i++] = str[in_]; in_++;
        if (i == 4) {
            for (i = 0; i < 4; i++) {
                char_array_4[i] = base64_chars.find(char_array_4[i]);
            }
            char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
            char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
            char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
            for (i = 0; (i < 3); i++) {
                ret.push_back(char_array_3[i]);
            }
            i = 0;
        }
    }
    if (i) {
        for (j = i; j < 4; j++) {
            char_array_4[j] = 0;
        }
        for (j = 0; j < 4; j++) {
            char_array_4[j] = base64_chars.find(char_array_4[j]);
        }
        char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
        char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
        char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
        for (j = 0; (j < i - 1); j++) {
            ret.push_back(char_array_3[j]);
        }
    }
    return ret;
}

static inline void to_lower_case(std::string& str) {
    for (auto &c : str) {
        if (c >= 'A' && c <= 'Z') {
            c = std::tolower(static_cast<unsigned char>(c));
        }
    }
}

Tokenizer* Tokenizer::createTokenizer(const std::string& filename) {
    Tokenizer* tokenizer = nullptr;
    // check file
    std::ifstream tok_file(filename);
    if (!tok_file.good()) {
        printf("Failed: can't load tokenizer from: %s.\n", filename.c_str());
        return tokenizer;
    }
    // check tokenizer info
    std::string line;
    std::getline(tok_file, line);
    std::istringstream line_str(line);
    int magic_number, tokenizer_type;
    line_str >> magic_number;
    if (magic_number != MAGIC_NUMBER) {
        printf("Failed: magic number is wrong from: %s.\n", filename.c_str());
        return tokenizer;
    }
    line_str >> tokenizer_type;
    printf("tokenizer_type = %d\n", tokenizer_type);
    // create tokenizer
    switch (tokenizer_type)
    {
        case SENTENCEPIECE:
            tokenizer = new Sentencepiece();
            break;
        case TIKTOKEN:
            tokenizer = new Tiktoken();
            break;
        case BERT:
            tokenizer = new BertTokenizer();
            break;
        case HUGGINGFACE:
            tokenizer = new HuggingfaceTokenizer();
            break;
        default:
            return tokenizer;
    }
    // load special tokens
    tokenizer->load_special(tok_file);
    // load vocabs
    tokenizer->load_vocab(tok_file);
    tok_file.close();
    return tokenizer;
}

bool Tokenizer::is_stop(int token) {
    return std::find(stop_tokens_.begin(), stop_tokens_.end(), token) != stop_tokens_.end();
}

bool Tokenizer::is_special(int token) {
    return std::find(special_tokens_.begin(), special_tokens_.end(), token) != special_tokens_.end();
}

void Tokenizer::load_special(std::ifstream& tok_file) {
    std::string line;
    std::getline(tok_file, line);
    std::istringstream line_str(line);
    int special_num, stop_num, prefix_num;
    line_str >> special_num >> stop_num >> prefix_num;
    std::getline(tok_file, line);
    std::istringstream special_line(line);
    if (special_num) {
        // load special tokens
        special_tokens_.resize(special_num);
        for (int i = 0; i < special_num; i++) {
            special_line >> special_tokens_[i];
        }
    }
    if (stop_num) {
        // load stop tokens
        stop_tokens_.resize(stop_num);
        for (int i = 0; i < stop_num; i++) {
            special_line >> stop_tokens_[i];
        }
    }
    if (prefix_num) {
        // load prefix tokens
        prefix_tokens_.resize(prefix_num);
        for (int i = 0; i < prefix_num; i++) {
            special_line >> prefix_tokens_[i];
        }
    }
}

std::vector<int> Tokenizer::encode(const std::string& str) {
    std::vector<int> ids = prefix_tokens_;
    if (!special_tokens_.empty()) {
        std::string text = str;
        size_t start = 0;
        for (size_t i = 0; i < text.length(); ++i) {
            for (auto special_id : special_tokens_) {
                const auto& token = decode(special_id);
                if (token.empty()) continue;
                if (i + token.length() <= text.length() && text.substr(i, token.length()) == token) {
                    if (i > start) {
                        encode(text.substr(start, i - start), ids);
                    }
                    ids.push_back(special_id);
                    start = i + token.length();
                    i = start - 1;
                    break;
                }
            }
        }
        if (start < text.length()) {
            encode(text.substr(start), ids);
        }
    } else {
        encode(str, ids);
    }
    return ids;
}
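
// Layout of tokenizer.txt as read by the loaders above (reconstructed from the
// parsing code; all token strings are base64-encoded):
//
//   line 1: <MAGIC_NUMBER=430> <tokenizer_type>
//   line 2: <special_num> <stop_num> <prefix_num>
//   line 3: the special, stop, and prefix token ids, space separated
//   line 4: <vocab_len>
//   then one vocab entry per line, e.g. "<base64 piece> <score> <type>" for
//   SENTENCEPIECE, or just "<base64 token>" for TIKTOKEN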
bool Sentencepiece::load_vocab(std::ifstream& tok_file) {
    std::string line, token;
    std::getline(tok_file, line);
    int vocab_len = std::stoi(line);
    float score;
    int type;
    sentence_pieces_.resize(vocab_len);
    for (int index = 0; index < vocab_len; index++) {
        std::getline(tok_file, line);
        std::istringstream line_str(line);
        line_str >> token >> score >> type;
        token = base64_decode(token);
        auto piece_type = static_cast<PieceType>(type);
        SentencePiece piece = {token, score, piece_type};
        sentence_pieces_[index] = std::move(piece);
        if (piece_type == PieceType::NORMAL) {
            pieces_.insert({token, index});
        } else {
            reserved_id_map_.insert({token, index});
            if (piece_type == PieceType::UNKNOWN) {
                unk_id_ = index;
            }
        }
    }
    return true;
}

int Sentencepiece::piece_to_id(const std::string& piece) const {
    auto it = reserved_id_map_.find(piece);
    if (it != reserved_id_map_.end()) {
        return it->second;
    }
    auto it2 = pieces_.find(piece);
    if (it2 != pieces_.end()) {
        return it2->second;
    }
    return unk_id_;
}

std::string Sentencepiece::byte_to_piece(unsigned char c) const {
    const int len = ::snprintf(nullptr, 0, "<0x%02X>", c);
    std::string s;
    s.resize(len);
    ::snprintf(&s[0], s.size() + 1, "<0x%02X>", c);
    return s;
}

// ref: https://github.com/google/sentencepiece/blob/master/src/bpe_model.cc
Sentencepiece::EncodeResult Sentencepiece::bpe_encode(string_view_ normalized, float alpha) {
    // util class begin
    struct SymbolPair {
        int left;     // left index of this pair
        int right;    // right index of this pair
        float score;  // score of this pair. large is better.
        size_t size;  // length of this piece
    };

    class SymbolPairComparator {
    public:
        bool operator()(SymbolPair *h1, SymbolPair *h2) {
            return (h1->score < h2->score || (h1->score == h2->score && h1->left > h2->left));
        }
    };

    struct Symbol {
        int prev;             // prev index of this symbol. -1 for BOS.
        int next;             // next index of this symbol. -1 for EOS.
        bool freeze = false;  // this symbol is never merged.
        string_view_ piece;
    };
    // util class end

    using Agenda = std::priority_queue<SymbolPair*, std::vector<SymbolPair*>, SymbolPairComparator>;
    Agenda agenda;
    std::vector<Symbol> symbols;
    symbols.reserve(normalized.size());
    // Reverse merge rules. key: merged symbol, value: pair of original symbols.
    std::unordered_map<string_view_, std::pair<string_view_, string_view_>> rev_merge;
    // SymbolPair holder.
    std::vector<std::unique_ptr<SymbolPair>> symbol_pair_holder;
    // Lookup new symbol pair at [left, right] and inserts it to agenda.
    auto MaybeAddNewSymbolPair = [this, &symbol_pair_holder, &symbols, &agenda, &rev_merge](int left, int right) {
        if (left == -1 || right == -1 || symbols[left].freeze || symbols[right].freeze) {
            return;
        }
        const string_view_ piece(symbols[left].piece.data(), symbols[left].piece.size() + symbols[right].piece.size());
        std::string piece_str(piece.to_string());
        const auto it = pieces_.find(piece_str);
        if (it == pieces_.end()) {
            return;
        }
        symbol_pair_holder.emplace_back(new SymbolPair);
        auto *h = symbol_pair_holder.back().get();
        h->left = left;
        h->right = right;
        h->score = get_score(it->second);
        h->size = piece.size();
        agenda.push(h);

        // Makes `rev_merge` for resegmentation.
        if (is_unused(it->second)) {
            rev_merge[piece] = std::make_pair(symbols[left].piece, symbols[right].piece);
        }
    };
// ref: https://github.com/google/sentencepiece/blob/master/src/bpe_model.cc
Sentencepiece::EncodeResult Sentencepiece::bpe_encode(string_view_ normalized, float alpha) {
    // util class begin
    struct SymbolPair {
        int left;     // left index of this pair
        int right;    // right index of this pair
        float score;  // score of this pair. larger is better.
        size_t size;  // length of this piece
    };

    class SymbolPairComparator {
    public:
        bool operator()(SymbolPair* h1, SymbolPair* h2) const {
            return (h1->score < h2->score || (h1->score == h2->score && h1->left > h2->left));
        }
    };

    struct Symbol {
        int prev;             // prev index of this symbol. -1 for BOS.
        int next;             // next index of this symbol. -1 for EOS.
        bool freeze = false;  // this symbol is never merged.
        string_view_ piece;
    };
    // util class end

    using Agenda = std::priority_queue<SymbolPair*, std::vector<SymbolPair*>, SymbolPairComparator>;
    Agenda agenda;
    std::vector<Symbol> symbols;
    symbols.reserve(normalized.size());
    // Reverse merge rules. key: merged symbol, value: pair of original symbols.
    std::unordered_map<string_view_, std::pair<string_view_, string_view_>> rev_merge;
    // SymbolPair holder.
    std::vector<std::unique_ptr<SymbolPair>> symbol_pair_holder;
    // Looks up the new symbol pair at [left, right] and inserts it into the agenda.
    auto MaybeAddNewSymbolPair = [this, &symbol_pair_holder, &symbols, &agenda, &rev_merge](int left, int right) {
        if (left == -1 || right == -1 || symbols[left].freeze || symbols[right].freeze) {
            return;
        }
        const string_view_ piece(symbols[left].piece.data(), symbols[left].piece.size() + symbols[right].piece.size());
        std::string piece_str(piece.to_string());
        const auto it = pieces_.find(piece_str);
        if (it == pieces_.end()) {
            return;
        }
        symbol_pair_holder.emplace_back(new SymbolPair);
        auto* h = symbol_pair_holder.back().get();
        h->left = left;
        h->right = right;
        h->score = get_score(it->second);
        h->size = piece.size();
        agenda.push(h);

        // Makes `rev_merge` for resegmentation.
        if (is_unused(it->second)) {
            rev_merge[piece] = std::make_pair(symbols[left].piece, symbols[right].piece);
        }
    };
    // Splits the input into a character sequence.
    int index = 0;
    while (!normalized.empty()) {
        Symbol s;
        // const int mblen = matcher_->PrefixMatch(normalized, &s.freeze);
        int mblen = std::min<int>(normalized.size(), one_char_len(normalized.data()));
        s.piece = string_view_(normalized.data(), mblen);
        s.prev = index == 0 ? -1 : index - 1;
        normalized.remove_prefix(mblen);
        s.next = normalized.empty() ? -1 : index + 1;
        ++index;
        symbols.emplace_back(s);
    }

    if (symbols.empty()) {
        return {};
    }
    // Lookup all bigrams.
    for (size_t i = 1; i < symbols.size(); ++i) {
        MaybeAddNewSymbolPair(i - 1, i);
    }

    // BPE-dropout: https://arxiv.org/pdf/1910.13267.pdf
    std::mt19937 rand_gen;
    auto skip_merge = [&]() {
        if (alpha <= 0.0) return false;
        if (alpha >= 1.0) return true;
        std::uniform_real_distribution<> gen(0.0, 1.0);
        return gen(rand_gen) < alpha;
    };

    // Main loop.
    while (!agenda.empty()) {
        SymbolPair* top = agenda.top();
        agenda.pop();

        // `top` is no longer available.
        if (symbols[top->left].piece.empty() || symbols[top->right].piece.empty() ||
            symbols[top->left].piece.size() + symbols[top->right].piece.size() != top->size) {
            continue;
        }

        if (skip_merge()) continue;
        // Replaces symbols with `top` rule.
        symbols[top->left].piece = string_view_(
            symbols[top->left].piece.data(),
            symbols[top->left].piece.size() + symbols[top->right].piece.size());

        // Updates prev/next pointers.
        symbols[top->left].next = symbols[top->right].next;
        if (symbols[top->right].next >= 0) {
            symbols[symbols[top->right].next].prev = top->left;
        }
        symbols[top->right].piece = string_view_("");

        // Adds new symbol pairs which appear after the symbol replacement.
        MaybeAddNewSymbolPair(symbols[top->left].prev, top->left);
        MaybeAddNewSymbolPair(top->left, symbols[top->left].next);
    }

    std::function<void(string_view_, EncodeResult*)> resegment;
    resegment = [this, &resegment, &rev_merge](string_view_ w, EncodeResult* output) -> void {
        std::string w_str(w.to_string());
        const int id = piece_to_id(w_str);
        // std::cout << "piece: " << w << ", id = " << id << std::endl;
        if (id == -1 || !is_unused(id)) {
            output->emplace_back(w, id);
            return;
        }
        const auto p = rev_merge.find(w);
        if (p == rev_merge.end()) {
            // This block will never be called, as `rev_merge` stores all the
            // resegmentation info for unused ids.
            output->emplace_back(w, id);
            return;
        }
        // Recursively resegment left and right symbols.
        resegment(p->second.first, output);
        resegment(p->second.second, output);
    };
    EncodeResult output;
    for (int index = 0; index != -1; index = symbols[index].next) {
        resegment(symbols[index].piece, &output);
    }
    return output;
}
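// Illustrative trace of the agenda loop above, under an assumed toy vocab
// where "ab" and "abc" are pieces (not part of any real model):
//
//   input "abc"  ->  symbols [a][b][c]
//   the bigram pass pushes ("a","b") and ("b","c") if they are pieces;
//   popping the best-scoring pair merges in place, e.g. [ab][c],
//   then MaybeAddNewSymbolPair probes the new neighbor pair ("ab","c").
//   When the agenda drains, the surviving prev/next linked list is the
//   final segmentation.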
void Sentencepiece::encode(const std::string& str, std::vector<int>& ids) {
    auto result = bpe_encode(str);
    for (const auto& p : result) {
        const string_view_ w = p.first;  // piece
        const int id = p.second;         // id
        const bool is_unk = (id == unk_id_);
        if (is_unk && byte_fall_back_) {
            // Decomposes an unknown piece into UTF-8 bytes
            for (size_t i = 0; i < w.size(); ++i) {
                // Create a byte piece
                const char b = w[i];
                const auto piece = byte_to_piece(b);
                auto sp_id = piece_to_id(piece);
                ids.push_back(sp_id);
            }
        } else {
            ids.push_back(id);
        }
    }
}

std::string Sentencepiece::decode(int id) {
    auto piece = sentence_pieces_[id].piece;
    auto pos = piece.find("▁");
    if (pos != std::string::npos) {
        // "▁" (U+2581) is three bytes in UTF-8; replace exactly those bytes
        piece.replace(pos, 3, " ");
    }
    return piece;
}

float Sentencepiece::get_score(int id) const {
    return sentence_pieces_[id].score;
}

bool Sentencepiece::is_unused(int id) const {
    return sentence_pieces_[id].type == PieceType::UNUSED;
}

bool Sentencepiece::is_control(int id) const {
    return sentence_pieces_[id].type == PieceType::CONTROL;
}

bool Tiktoken::load_vocab(std::ifstream& tok_file) {
    std::string line;
    std::getline(tok_file, line);
    int vocab_len = std::stoi(line);
    // load vocab
    decoder_.resize(vocab_len);
    for (int i = 0; i < vocab_len; i++) {
        std::getline(tok_file, line);
        auto token = base64_decode(line);
        encoder_.insert({token, i});
        decoder_[i] = token;
    }
    return true;
}

void Tiktoken::encode(const std::string& str, std::vector<int>& ids) {
    if (str.empty()) {
        return;
    }
    size_t i = 0;
    while (i < str.size()) {
        // Attempt to match the longest possible symbol
        size_t longest_match_len = 0;
        std::string longest_match;

        // Check substrings of decreasing length; the first hit is the longest
        for (size_t len = str.size() - i; len > 0; --len) {
            std::string token = str.substr(i, len);
            auto it = encoder_.find(token);
            if (it != encoder_.end()) {
                longest_match_len = len;
                longest_match = it->first;
                break;
            }
        }

        if (!longest_match.empty()) {
            ids.push_back(encoder_.at(longest_match));
            i += longest_match_len;
        } else {
            // No matching symbol: either the encoding is wrong or the input
            // contains characters the encoder does not know how to handle
            std::cerr << "Error: No encoding found for the sequence starting at position " << i << std::endl;
            return;
        }
    }
}

std::string Tiktoken::decode(int id) {
    if (id < 0 || id >= static_cast<int>(decoder_.size())) {
        return "";
    }
    return decoder_[id];
}
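// Note: Tiktoken::encode() above is a greedy longest-prefix match over the
// vocabulary rather than true byte-pair merging, and the substring scan
// makes it O(n^2) in the input length in the worst case. Under an assumed
// toy vocab containing "he", "hell" and "hello" (illustrative only),
// encoding "hello" emits the single id of "hello".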
std::vector<int> BertTokenizer::word_piece(const std::string& token) {
    auto it = encoder_.find(token);
    if (it != encoder_.end()) {
        return {it->second};
    }
    std::vector<int> ids;
    std::string current = token;
    while (!current.empty()) {
        int match_id = -1;
        size_t match_pos = 0;
        for (int len = current.size(); len > 0; --len) {
            std::string candidate = current.substr(0, len);
            // pieces after the first carry the "##" continuation prefix
            if (!ids.empty()) {
                candidate = "##" + candidate;
            }
            auto it = encoder_.find(candidate);
            if (it != encoder_.end()) {
                match_id = it->second;
                match_pos = len;
                break;
            }
        }
        // no piece matches: emit [UNK] (id 100 in the standard BERT vocab)
        if (match_id == -1) {
            ids.push_back(100);
            break;
        }
        ids.push_back(match_id);
        // drop the matched prefix and continue with the remainder
        current = current.substr(match_pos);
    }
    return ids;
}

void BertTokenizer::encode(const std::string& str, std::vector<int>& ids) {
    std::vector<std::string> tokens;
    std::string current_token;
    size_t i = 0;
    while (i < str.size()) {
        current_token.clear();
        unsigned char c = static_cast<unsigned char>(str[i]);
        // handle multi-byte UTF-8 characters
        if ((c & 0x80) != 0) {
            unsigned char mask = 0xE0; // 1110 0000: lead byte of a 3-byte char
            if ((c & mask) == mask) {
                // keep 3-byte sequences (e.g. CJK) as single tokens
                current_token = str.substr(i, 3);
                i += 3;
            } else {
                // other multi-byte lead/continuation bytes are skipped
                ++i;
                continue;
            }
        }
        // handle continuous sequence of letters and digits
        else if (std::isalnum(c)) {
            while (i < str.size() && std::isalnum(static_cast<unsigned char>(str[i]))) {
                current_token += std::tolower(str[i]);
                ++i;
            }
        }
        // handle punctuation and symbols
        else if (std::ispunct(c)) {
            current_token = str[i];
            ++i;
        }
        // handle space, tab, newline
        else if (std::isspace(c)) {
            ++i;
            continue;
        }
        // handle any other single-byte characters
        else {
            current_token = str[i];
            ++i;
        }
        if (!current_token.empty()) {
            tokens.push_back(current_token);
        }
    }

    for (const auto& token : tokens) {
        for (auto id : word_piece(token)) {
            ids.push_back(id);
        }
    }
}

std::wstring utf8_to_wstring(const std::string& str) {
    std::wstring_convert<std::codecvt_utf8<wchar_t>> myconv;
    return myconv.from_bytes(str);
}

std::string wstring_to_utf8(const std::wstring& str) {
    std::wstring_convert<std::codecvt_utf8<wchar_t>> myconv;
    return myconv.to_bytes(str);
}

// Given a token as a UTF-8 string, encode each byte into a wchar_t
void byte_encode_token(const std::string& token,
                       const std::unordered_map<uint8_t, wchar_t>& b2u,
                       std::wstring* result) {
    result->resize(0);
    for (char c : token) {
        wchar_t wc = b2u.at(uint8_t(c));
        result->push_back(wc);
    }
}
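// Classic WordPiece example for word_piece() above, assuming a vocabulary
// that contains "un", "##aff" and "##able" (illustrative only):
//
//   word_piece("unaffable") -> ids of ["un", "##aff", "##able"]
//
// Only the first matched prefix of a word is looked up without the "##"
// continuation marker; every later piece must carry it.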
bool HuggingfaceTokenizer::load_vocab(std::ifstream& tok_file) {
    std::string line, token;
    // get nums
    int vocab_len, merge_len;
    std::getline(tok_file, line);
    std::istringstream line_str(line);
    line_str >> vocab_len >> merge_len;
    // load vocab
    decoder_.resize(vocab_len);
    for (int i = 0; i < vocab_len; i++) {
        std::getline(tok_file, line);
        encoder_.insert({line, i});
        decoder_[i] = line;
    }
    // load merge rules
    for (int i = 0; i < merge_len; i++) {
        std::getline(tok_file, line);
        auto d = line.find(" ");
        bpe_ranks_.insert({{utf8_to_wstring(line.substr(0, d)),
                            utf8_to_wstring(line.substr(d + 1))}, i});
    }
    // bytes_to_unicode
    auto _insert_range = [=](int start, int end) {
        for (int c = start; c <= end; c++) {
            b2u_.insert({uint8_t(c), wchar_t(c)});
        }
    };

    b2u_.clear();
    _insert_range(L'!', L'~');
    _insert_range(L'¡', L'¬');
    _insert_range(L'®', L'ÿ');

    int n = 0;
    for (int b = 0; b < 256; b++) {
        if (b2u_.find(uint8_t(b)) == b2u_.end()) {
            b2u_.insert({uint8_t(b), wchar_t(256 + n)});
            n++;
        }
    }
    for (auto e : b2u_) {
        u2b_.insert({e.second, e.first});
    }
    return true;
}

void get_pairs(const std::wstring& word, std::vector<std::pair<std::wstring, std::wstring>>* pairs) {
    pairs->clear();

    if (word.size() < 2) return;

    wchar_t previous = word[0];
    for (size_t i = 1; i < word.size(); i++) {
        pairs->push_back({std::wstring(1, previous), std::wstring(1, word[i])});
        previous = word[i];
    }
}

void HuggingfaceTokenizer::bpe(const std::wstring& token, const BPERanks& bpe_ranks, std::vector<std::wstring>* result) {
    std::set<int> merged;  // records indices in pairs that were merged.
    auto _left = [](int i, std::set<int>& merged) {
        for (int j = i - 1; j >= -1; j--) {
            if (merged.find(j) == merged.end()) return j;
        }
        return -1;
    };
    auto _right = [](int i, int cap, std::set<int>& merged) {
        for (int j = i + 1; j < cap; j++) {
            if (merged.find(j) == merged.end()) return j;
        }
        return cap;
    };

    std::vector<std::pair<std::wstring, std::wstring>> pairs;
    get_pairs(token, &pairs);

    while (true) {
        int min_score = INT_MAX;
        int to_merge = -1;  // index into pairs.

        for (int i = 0; i < static_cast<int>(pairs.size()); ++i) {
            if (merged.find(i) == merged.end()) {  // pair i is not merged.
                auto iter = bpe_ranks.find(pairs[i]);
                int score = iter != bpe_ranks.end() ? iter->second : INT_MAX;
                if (score < min_score) {
                    min_score = score;
                    to_merge = i;
                }
            }
        }

        if (to_merge == -1) break;

        merged.insert(to_merge);
        std::wstring merge_into = pairs[to_merge].first + pairs[to_merge].second;

        int l = _left(to_merge, merged);
        if (l >= 0) pairs[l].second = merge_into;
        int r = _right(to_merge, static_cast<int>(pairs.size()), merged);
        if (r < static_cast<int>(pairs.size())) pairs[r].first = merge_into;
    }  // end while (true)

    if (merged.size() == pairs.size()) {
        result->push_back(token);
    } else {
        for (int i = 0; i < static_cast<int>(pairs.size()); ++i) {
            if (merged.find(i) == merged.end()) {
                if (_left(i, merged) < 0) result->push_back(pairs[i].first);
                result->push_back(pairs[i].second);
            }
        }
    }
}

void HuggingfaceTokenizer::encode(const std::string& str, std::vector<int>& ids) {
    std::regex re("('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\\s\\w]+|\\s+)");
    std::string input = str;
    std::vector<std::string> result;
    std::string token;
    std::smatch match;
    while (std::regex_search(input, match, re)) {
        token = match.str(0);
        input = match.suffix().str();
        std::wstring wtoken;
        for (char c : token) {
            wtoken.push_back(b2u_.at(uint8_t(c)));
        }

        std::vector<std::wstring> bpe_tokens;
        bpe(wtoken, bpe_ranks_, &bpe_tokens);

        for (const auto& ws : bpe_tokens) {
            result.push_back(wstring_to_utf8(ws));
        }
    }
    for (const auto& s : result) {
        ids.push_back(encoder_.at(s));
    }
}
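// The b2u_/u2b_ tables built in load_vocab() implement GPT-2's
// bytes_to_unicode trick: printable bytes map to themselves, and the
// remaining bytes are shifted into unused code points starting at 256, so
// every byte string round-trips through a valid unicode string. For
// example, the space byte 0x20 maps to U+0120 ('Ġ'), which is why GPT-2
// style vocabularies spell " hello" as "Ġhello".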
std::string HuggingfaceTokenizer::decode(int id) {
    // printf("decode id = %d, %lu, %s#\n", id, decoder_.size(), decoder_.at(id).c_str());
    if (id < 0 || id >= static_cast<int>(decoder_.size())) {
        return "";
    }
    std::wstring w = utf8_to_wstring(decoder_.at(id));
    std::string r;
    for (wchar_t c : w) {
        if (u2b_.find(c) != u2b_.end()) {
            r.push_back(char(u2b_.at(c)));
        }
    }
    return r;
}
--------------------------------------------------------------------------------