├── .dockerignore ├── .gitignore ├── .gitmodules ├── .vscode ├── launch.json ├── setting.json └── tasks.json ├── CMakeLists.txt ├── Dockerfile ├── LICENSE ├── README.md ├── apps ├── CMakeLists.txt ├── tests │ ├── CMakeLists.txt │ ├── asr.cpp │ ├── llm.cpp │ ├── paraformer.cpp │ └── tts.cpp └── web_server.cpp ├── config.example.yaml ├── docker-compose.yml ├── extern └── onnxruntime-linux-x64-1.21.0 │ ├── GIT_COMMIT_ID │ ├── LICENSE │ ├── Privacy.md │ ├── README.md │ ├── ThirdPartyNotices.txt │ ├── VERSION_NUMBER │ ├── include │ ├── core │ │ └── providers │ │ │ ├── custom_op_context.h │ │ │ └── resource.h │ ├── cpu_provider_factory.h │ ├── onnxruntime_c_api.h │ ├── onnxruntime_cxx_api.h │ ├── onnxruntime_cxx_inline.h │ ├── onnxruntime_float16.h │ ├── onnxruntime_lite_custom_op.h │ ├── onnxruntime_run_options_config_keys.h │ ├── onnxruntime_session_options_config_keys.h │ └── provider_options.h │ └── lib │ ├── cmake │ └── onnxruntime │ │ ├── onnxruntimeConfig.cmake │ │ ├── onnxruntimeConfigVersion.cmake │ │ ├── onnxruntimeTargets-release.cmake │ │ └── onnxruntimeTargets.cmake │ ├── libonnxruntime.so │ ├── libonnxruntime.so.1 │ ├── libonnxruntime.so.1.21.0 │ ├── libonnxruntime_providers_shared.so │ └── pkgconfig │ └── libonnxruntime.pc ├── include ├── funasr │ └── paraformer │ │ ├── com-define.h │ │ ├── commonfunc.h │ │ ├── model.h │ │ ├── paraformer-online.h │ │ ├── paraformer.h │ │ ├── phone-set.h │ │ ├── seg-dict.h │ │ ├── utils.h │ │ └── vocab.h └── xz-cpp-server │ ├── asr │ ├── base.h │ ├── bytedancev2.h │ └── paraformer.h │ ├── common │ ├── logger.h │ ├── request.h │ ├── setting.h │ ├── threadsafe_queue.hpp │ └── tools.h │ ├── connection.h │ ├── llm │ ├── base.h │ ├── cozev3.h │ ├── dify.h │ └── openai.h │ ├── precomp.h │ ├── server.h │ ├── silero_vad │ ├── silero_vad.h │ └── vad.h │ └── tts │ ├── base.h │ ├── bytedancev3.h │ └── edge.h ├── models └── silero_vad.onnx ├── run.sh ├── scripts ├── install_deps.sh └── tests │ ├── concurr_ws_client.py │ ├── connect.py 
│ ├── https_echo.py │ ├── vad.py │ └── wss_echo.py ├── src ├── CMakeLists.txt ├── asr │ ├── CMakeLists.txt │ ├── base.cpp │ ├── bytedancev2.cpp │ └── paraformer.cpp ├── common │ ├── CMakeLists.txt │ ├── logger.cpp │ ├── request.cpp │ ├── setting.cpp │ └── tools.cpp ├── connection.cpp ├── llm │ ├── CMakeLists.txt │ ├── base.cpp │ ├── cozev3.cpp │ ├── dify.cpp │ └── openai.cpp ├── paraformer │ ├── CMakeLists.txt │ ├── model.cpp │ ├── paraformer-online.cpp │ ├── paraformer.cpp │ ├── phone-set.cpp │ ├── seg-dict.cpp │ ├── utils.cpp │ └── vocab.cpp ├── server.cpp ├── silero_vad │ ├── CMakeLists.txt │ ├── silero_vad.cpp │ └── vad.cpp └── tts │ ├── CMakeLists.txt │ ├── base.cpp │ ├── bytedancev3.cpp │ └── edge.cpp └── tests ├── CMakeLists.txt ├── common ├── test_request.cpp └── test_setting.cpp └── test_find_last_segment.cpp /.dockerignore: -------------------------------------------------------------------------------- 1 | .git 2 | __pycache__ 3 | *.pyc 4 | Dockerfile 5 | tmp 6 | build 7 | .vscode 8 | .cache 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | .cache 3 | __pycache__/ 4 | tmp/ 5 | config.yaml 6 | models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "extern/Catch2"] 2 | path = extern/Catch2 3 | url = https://github.com/catchorg/Catch2.git 4 | [submodule "extern/yaml-cpp"] 5 | path = extern/yaml-cpp 6 | url = https://github.com/jbeder/yaml-cpp.git 7 | [submodule "extern/kaldi-native-fbank"] 8 | path = extern/kaldi-native-fbank 9 | url = https://github.com/csukuangfj/kaldi-native-fbank.git 10 | -------------------------------------------------------------------------------- /.vscode/launch.json: 
-------------------------------------------------------------------------------- 1 | { 2 | // 使用 IntelliSense 了解相关属性。 3 | // 悬停以查看现有属性的描述。 4 | // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "type": "lldb", 9 | "request": "launch", 10 | "name": "Debug Server", 11 | "program": "${workspaceFolder}/build/apps/web_server", 12 | "args": [], 13 | "cwd": "${workspaceFolder}", 14 | "preLaunchTask": "build" 15 | }, 16 | { 17 | "type": "lldb", 18 | "request": "launch", 19 | "name": "Debug File", 20 | "program": "${workspaceFolder}/build/${input:exeName}", 21 | "args": [], 22 | "cwd": "${workspaceFolder}", 23 | "preLaunchTask": "build" 24 | }, 25 | { 26 | "type": "cmake", 27 | "request": "launch", 28 | "name": "CMake项目调试", 29 | "cmakeDebugType": "configure", 30 | "clean": false, 31 | "configureAll": false 32 | } 33 | ], 34 | "inputs": [ 35 | { 36 | "id": "exeName", 37 | "type": "promptString", 38 | "description": "Enter executable relative build dir." 39 | } 40 | ] 41 | } -------------------------------------------------------------------------------- /.vscode/setting.json: -------------------------------------------------------------------------------- 1 | { 2 | "cmake.sourceDirectory": "${workspaceFolder}", 3 | "cmake.buildDirectory": "${workspaceFolder}/build", 4 | "cmake.configureOnOpen": true 5 | } -------------------------------------------------------------------------------- /.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "2.0.0", 3 | "tasks": [ 4 | { 5 | "label": "cmake", 6 | "type": "shell", 7 | "command": "cmake", 8 | "args": [ 9 | ".." 
10 | ], 11 | "options": { 12 | "cwd": "${workspaceFolder}/build" 13 | } 14 | }, 15 | { 16 | "label": "make", 17 | "type": "shell", 18 | "command": "make", 19 | "args": [], 20 | "options": { 21 | "cwd": "${workspaceFolder}/build" 22 | } 23 | }, 24 | { 25 | "label": "build", 26 | "dependsOn": ["cmake", "make"], 27 | "dependsOrder": "sequence" 28 | } 29 | ] 30 | } -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.16) 2 | 3 | project( 4 | xiaozhi-server 5 | VERSION 0.1 6 | DESCRIPTION "xiaozhi cpp server" 7 | LANGUAGES CXX 8 | ) 9 | 10 | set(CMAKE_CXX_STANDARD 20) # 默认使用 C++20 11 | set(CMAKE_CXX_EXTENSIONS OFF) # 关闭编译器扩展(如 GNU 扩展) 12 | set(CMAKE_CXX_STANDARD_REQUIRED ON) # 强制要求 C++20,避免降级 13 | 14 | include_directories("${PROJECT_SOURCE_DIR}/include") 15 | find_package(Boost COMPONENTS log log_setup iostreams random json REQUIRED) 16 | find_package(OpenSSL REQUIRED) 17 | find_library(OGG_LIBRARY NAMES ogg) 18 | find_library(OPUS_LIBRARY NAMES opus) 19 | find_library(ONNXRUNTIME_LIB NAMES onnxruntime PATHS "extern/onnxruntime-linux-x64-1.21.0/lib" NO_DEFAULT_PATH) 20 | include_directories("${PROJECT_SOURCE_DIR}/extern/onnxruntime-linux-x64-1.21.0/include") 21 | 22 | add_subdirectory(extern/kaldi-native-fbank) 23 | include_directories("${PROJECT_SOURCE_DIR}/extern/kaldi-native-fbank") 24 | add_subdirectory(extern/yaml-cpp) 25 | 26 | add_library(precomp INTERFACE) 27 | target_link_libraries(precomp INTERFACE Boost::log Boost::json yaml-cpp::yaml-cpp) 28 | target_precompile_headers(precomp INTERFACE "${PROJECT_SOURCE_DIR}/include/xz-cpp-server/precomp.h") 29 | 30 | add_subdirectory(apps) 31 | add_subdirectory(src) 32 | 33 | include(CTest) 34 | add_subdirectory(extern/Catch2) 35 | add_subdirectory(tests) 36 | 37 | file(GLOB HEADER_LIST CONFIGURE_DEPENDS "${PROJECT_SOURCE_DIR}/include/xz-cpp-server/**/*.h" 
"${PROJECT_SOURCE_DIR}/include/xz-cpp-server/**/*.hpp") 38 | source_group( 39 | TREE "${PROJECT_SOURCE_DIR}/include" 40 | PREFIX "Header Files" 41 | FILES ${HEADER_LIST} 42 | ) -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:1-labs 2 | FROM ubuntu:24.04 AS builder 3 | WORKDIR /app 4 | 5 | COPY --parents CMakeLists.txt src extern include apps scripts tests /app/ 6 | 7 | RUN apt-get update && apt-get install -y \ 8 | cmake \ 9 | g++ && \ 10 | bash scripts/install_deps.sh && \ 11 | mkdir -p build && \ 12 | cd build && \ 13 | cmake -DKALDI_NATIVE_FBANK_BUILD_TESTS=OFF -DKALDI_NATIVE_FBANK_BUILD_PYTHON=OFF -DCMAKE_BUILD_TYPE=Release .. && \ 14 | make web_server 15 | 16 | FROM ubuntu:24.04 17 | WORKDIR /app 18 | COPY extern/onnxruntime-linux-x64-1.21.0/lib /app/lib 19 | ENV LD_LIBRARY_PATH=/app/lib:$LD_LIBRARY_PATH 20 | 21 | COPY scripts/install_deps.sh /app/scripts/install_deps.sh 22 | RUN bash scripts/install_deps.sh 23 | 24 | COPY --from=builder /app/build/lib /app/lib 25 | COPY --from=builder /app/build/apps/web_server /app/web_server 26 | 27 | EXPOSE 8000 28 | CMD [ "/app/web_server" ] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 daxpot 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this 
permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 小智 ESP-32 C++后端服务 (xz-cpp-server) 2 | 3 | 本项目为开源智能硬件项目 [xiaozhi-esp32](https://github.com/78/xiaozhi-esp32) 4 | 提供后端服务。根据 [小智通信协议](https://ccnphfhqs21z.feishu.cn/wiki/M0XiwldO9iJwHikpXD5cEx71nKh) 使用 `C++` 实现。 5 | 6 | 项目根据Python版本的[xiaozhi-esp32-server](https://github.com/xinnan-tech/xiaozhi-esp32-server)改版 7 | 8 | 基于BOOST 1.83, C++20协程开发。 9 | 10 | --- 11 | 12 | ## 功能清单 ✨ 13 | 14 | ### 已实现 ✅ 15 | 16 | - **通信协议** 17 | 基于 `xiaozhi-esp32` 协议,通过 WebSocket 实现数据交互。 18 | - **对话交互** 19 | 支持唤醒对话、手动对话及实时打断。长时间无对话时自动休眠 20 | - **语音识别** 21 | 支持国语(默认使用 FunASR的paraformer-zh-streaming模型)。 22 | - **LLM 模块** 23 | 支持灵活切换 LLM 模块,默认使用 ChatGLMLLM,也可选用Coze, Dify或其他类似openai的接口。 24 | - **TTS 模块** 25 | 支持 EdgeTTS(默认)、火山引擎豆包 TTS 接口,满足语音合成需求。 26 | 27 | ## 本项目支持的平台/组件列表 📋 28 | 29 | ### LLM 语言模型 30 | 31 | | 类型 | 平台名称 | 使用方式 | 收费模式 | 备注 | 32 | |:---:|:------------------:|:---------------------:|:-----------:|:-----------------------------------------------------------------------------------------------------------------------:| 33 | | LLM | 智谱(ChatGLMLLM) | openai 接口调用 | 免费 | 虽然免费,仍需[点击申请密钥](https://bigmodel.cn/usercenter/proj-mgmt/apikeys) | 34 | | LLM | DifyLLM | dify 接口调用 | 免费/消耗 token | 本地化部署,注意配置提示词需在 Dify 控制台设置 | 35 |
| LLM | CozeLLM | coze 接口调用 | 消耗 token | 需提供 bot_id、user_id 及个人令牌 | 36 | 37 | 实际上,任何支持 openai 接口调用的 LLM 均可接入使用。 38 | 39 | ### VAD 语音活动检测 40 | 41 | | 类型 | 平台名称 | 使用方式 | 收费模式 | 备注 | 42 | |:---:|:---------:|:----:|:----:|:--:| 43 | | VAD | SileroVAD | 本地使用 | 免费 | | 44 | 45 | --- 46 | 47 | ### ASR 语音识别 48 | 49 | | 类型 | 平台名称 | 使用方式 | 收费模式 | 备注 | 50 | |:---:|:---------:|:----:|:----:|:--:| 51 | | ASR | FunASR | 本地使用 | 免费 | | 52 | | ASR | BytedanceASRV2 | 接口调用 | 收费 | | 53 | 54 | --- 55 | 56 | 57 | ### TTS 语音合成 58 | 59 | | 类型 | 平台名称 | 使用方式 | 收费模式 | 备注 | 60 | |:---:|:----------------------:|:----:|:--------:|:-------------------------------------------------------------------------:| 61 | | TTS | EdgeTTS | 接口调用 | 免费 | 默认 TTS,基于微软语音合成技术 | 62 | | TTS | 火山引擎豆包 TTS (BytedanceTTSV3) | 接口调用 | 消耗 token | [点击创建密钥](https://console.volcengine.com/speech/service/10007) | 63 | 64 | --- 65 | 66 | --- 67 | 68 | ## 使用方式 69 | 70 | 参看[install_deps.sh](https://github.com/daxpot/xiaozhi-cpp-server/blob/master/scripts/install_deps.sh)安装依赖和构建 71 | 72 | 启动命令 73 | ``` 74 | # 修改config.example.yaml为config.yaml,并填写自己的配置信息 75 | 76 | ./build/apps/web_server 77 | # 或者指定config path 78 | ./build/apps/web_server config.yaml 79 | ``` 80 | -------------------------------------------------------------------------------- /apps/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_executable(web_server web_server.cpp) 2 | target_link_libraries(web_server PUBLIC server) 3 | 4 | add_subdirectory(tests) -------------------------------------------------------------------------------- /apps/tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB_RECURSE cpp_files *.cpp) 2 | 3 | foreach(cpp_file ${cpp_files}) 4 | get_filename_component(file_name ${cpp_file} NAME_WE) 5 | add_executable("tests_${file_name}" ${cpp_file}) 6 | target_link_libraries("tests_${file_name}" PRIVATE precomp) 7 | 
target_link_libraries("tests_${file_name}" PRIVATE common) 8 | endforeach() 9 | 10 | target_link_libraries(tests_asr PUBLIC asr) 11 | 12 | target_link_libraries(tests_tts PUBLIC tts) 13 | 14 | target_link_libraries(tests_llm PUBLIC llm) 15 | 16 | target_link_libraries(tests_paraformer PUBLIC paraformer) 17 | target_link_libraries(tests_paraformer PUBLIC ${OPUS_LIBRARY}) 18 | -------------------------------------------------------------------------------- /apps/tests/asr.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | net::awaitable test() { 8 | auto executor = co_await net::this_coro::executor; 9 | auto setting = xiaozhi::Setting::getSetting(); 10 | auto asr = xiaozhi::asr::createASR(executor); 11 | for(size_t index=0; index <= 92; index++) { 12 | std::ifstream file(std::format("tmp/example/opus_data_{}.opus", index), std::ifstream::binary); 13 | std::vector data((std::istreambuf_iterator(file)), std::istreambuf_iterator()); 14 | file.close(); 15 | beast::flat_buffer buffer; 16 | // 准备写入数据 17 | auto writable = buffer.prepare(data.size()); // 分配空间 18 | net::buffer_copy(writable, net::buffer(data)); // 复制数据 19 | buffer.commit(data.size()); // 提交数据 20 | co_await asr->detect_opus(buffer); 21 | } 22 | auto text = co_await asr->detect_opus(std::nullopt); 23 | BOOST_LOG_TRIVIAL(info) << "asr detect:" << text; 24 | } 25 | 26 | int main() { 27 | init_logging("DEBUG"); 28 | boost::asio::io_context ioc; 29 | net::co_spawn(ioc, test(), std::bind_front(tools::on_spawn_complete, "Test asr")); 30 | ioc.run(); 31 | } -------------------------------------------------------------------------------- /apps/tests/llm.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | net::awaitable test() { 6 | auto executor = co_await net::this_coro::executor; 7 | auto llm = xiaozhi::llm::createLLM(executor); 8 
| auto session_id = co_await llm->create_session(); 9 | BOOST_LOG_TRIVIAL(info) << "session_id:" << session_id; 10 | boost::json::array dialogue = { 11 | boost::json::object{{"role", "system"}, {"content", R"(你是一个叫小智/小志的台湾女孩,说话机车,声音好听,习惯简短表达,爱用网络梗。 12 | 请注意,要像一个人一样说话,请不要回复表情符号、代码、和xml标签。 13 | 现在我正在和你进行语音聊天,我们开始吧。 14 | 如果用户希望结束对话,请在最后说“拜拜”或“再见”。)"}}, 15 | boost::json::object{{"role", "user"}, {"content", "你好,小智。今天天气怎么样"}} 16 | }; 17 | co_await llm->response(dialogue, [](const std::string_view text) { 18 | BOOST_LOG_TRIVIAL(info) << text; 19 | }); 20 | } 21 | 22 | int main() { 23 | init_logging("DEBUG"); 24 | boost::asio::io_context ioc; 25 | net::co_spawn(ioc, test(), std::bind_front(tools::on_spawn_complete, "Test llm")); 26 | ioc.run(); 27 | } -------------------------------------------------------------------------------- /apps/tests/paraformer.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | int main() { 7 | init_logging("DEBUG"); 8 | std::map model_path = { 9 | {"model-dir", "models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"}, 10 | {"quantize", "true"} 11 | }; 12 | auto offline_model = funasr::CreateModel(model_path, 1); 13 | std::vector chunk_size = {5, 10, 5}; 14 | auto online_model = funasr::CreateModel(offline_model, chunk_size); 15 | int error; 16 | OpusDecoder* decoder = opus_decoder_create(16000, 1, &error); 17 | std::string full_result; 18 | std::vector pcm_data(960*10); 19 | for(size_t index=0; index <= 92; index++) { 20 | std::ifstream file(std::format("tmp/example/opus_data_{}.opus", index), std::ifstream::binary); 21 | std::vector data((std::istreambuf_iterator(file)), std::istreambuf_iterator()); 22 | file.close(); 23 | int frame_size = opus_decode_float(decoder, data.data(), data.size(), pcm_data.data() + (index%10)*960, 960, 0); 24 | if(frame_size < 0) { 25 | BOOST_LOG_TRIVIAL(error) << "Opus decode error:" << 
opus_strerror(frame_size); 26 | return -1; 27 | } 28 | if(index % 10 == 9) { 29 | bool is_final = (index / 10 == 8); // 索引从 0 到 92,共 93 段 30 | std::string result = online_model->Forward(pcm_data.data(), pcm_data.size(), is_final); 31 | full_result += result; 32 | BOOST_LOG_TRIVIAL(info) << "Chunk " << index << " result: " << result; 33 | } 34 | } 35 | BOOST_LOG_TRIVIAL(info) << "Full result: " << full_result; 36 | 37 | delete offline_model; 38 | delete online_model; 39 | } -------------------------------------------------------------------------------- /apps/tests/tts.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | bool write_binary_to_file(const std::string& filename, const std::vector& data) { 7 | // 以二进制模式打开文件 8 | std::ofstream out_file(filename, std::ios::binary); 9 | 10 | // 检查文件是否成功打开 11 | if (!out_file.is_open()) { 12 | BOOST_LOG_TRIVIAL(error) << "Failed to open file: " << filename; 13 | return false; 14 | } 15 | 16 | // 写入数据 17 | for(auto& item : data) { 18 | out_file.write(item.data(), item.size()); 19 | } 20 | 21 | // 检查写入是否成功 22 | if (!out_file.good()) { 23 | BOOST_LOG_TRIVIAL(error) << "Error occurred while writing to file: " << filename; 24 | return false; 25 | } 26 | 27 | // 关闭文件(析构时会自动关闭,但显式关闭是好习惯) 28 | out_file.close(); 29 | return true; 30 | } 31 | 32 | net::awaitable test() { 33 | auto executor = co_await net::this_coro::executor; 34 | auto asr = xiaozhi::tts::createTTS(executor); 35 | auto audio = co_await asr->text_to_speak("你好小智,我是你的朋友"); 36 | // write_binary_to_file("../tmp/test.opus", audio); 37 | // std::cout << "test end" << std::endl; 38 | } 39 | 40 | int main() { 41 | init_logging("DEBUG"); 42 | boost::asio::io_context ioc; 43 | net::co_spawn(ioc, test(), std::bind_front(tools::on_spawn_complete, "Test tts")); 44 | ioc.run(); 45 | } -------------------------------------------------------------------------------- /apps/web_server.cpp: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | 6 | int main(int argc, char* argv[]) { 7 | auto setting = xiaozhi::Setting::getSetting(argc > 1 ? argv[1] : 0); 8 | init_logging(setting->config["log"]["log_level"].as()); 9 | auto server = xiaozhi::Server(setting); 10 | server.run(); 11 | return 0; 12 | } -------------------------------------------------------------------------------- /config.example.yaml: -------------------------------------------------------------------------------- 1 | # 如果您是一名开发者,建议阅读以下内容。如果不是开发者,可以忽略这部分内容。 2 | # 在开发中,将【config.example.yaml】复制一份,改成【config.yaml】 3 | 4 | #io context并发线程数 5 | threads: 4 6 | # 服务器基础配置(Basic server configuration) 7 | server: 8 | # 服务器监听地址和端口(Server listening address and port) 9 | ip: 0.0.0.0 10 | port: 8000 11 | # 认证配置 12 | auth: 13 | # 是否启用认证 14 | enabled: false 15 | # 设备的token,可以在编译固件的环节,写入你自己定义的token 16 | # 固件上的token和以下的token如果能对应,才能连接本服务端 17 | tokens: 18 | - token: "your-token1" # 设备1的token 19 | name: "your-device-name1" # 设备1标识 20 | - token: "your-token2" # 设备2的token 21 | name: "your-device-name2" # 设备2标识 22 | 23 | log: 24 | # 设置日志等级:INFO、DEBUG 25 | log_level: INFO 26 | # 设置日志路径 27 | log_dir: tmp 28 | # 设置日志文件 29 | log_file: "server.log" 30 | 31 | welcome: 32 | type: hello 33 | transport: websocket 34 | audio_params: 35 | sample_rate: 24000 36 | prompt: | 37 | 你是一个叫小智/小志的台湾女孩,说话机车,声音好听,习惯简短表达,爱用网络梗。 38 | 请注意,要像一个人一样说话,请不要回复表情符号、代码、和xml标签。 39 | 当前时间是:{date_time},现在我正在和你进行语音聊天,我们开始吧。 40 | 如果用户希望结束对话,请在最后说“拜拜”或“再见”。 41 | 42 | # 没有语音输入多久后断开连接(秒),默认2分钟,即120秒 43 | close_connection_no_voice_time: 120 44 | 45 | CMD_exit: 46 | - "退出" 47 | - "关闭" 48 | - "拜拜" 49 | - "再见" 50 | 51 | # 具体处理时选择的模块(The module selected for specific processing) 52 | selected_module: 53 | ASR: Paraformer 54 | VAD: SileroVAD 55 | # 将根据配置名称对应的type调用实际的LLM适配器 56 | LLM: ChatGLMLLM 57 | # TTS将根据配置名称对应的type调用实际的TTS适配器 58 | TTS: EdgeTTS 59 | 60 | ASR: 61 | Paraformer: 62 | 
#使用命令:funasr-export ++model=paraformer-zh-streaming ++quantize=true 导出onnx模型,导出的目录填到下面去 63 | model_dir: funasr的paraformer online模型目录 64 | quantize: true #使用量化模型 65 | thread_num: 4 66 | BytedanceASRV2: 67 | appid: 你的火山引擎语音识别服务appid 68 | access_token: 你的火山引擎语音识别服务access_token 69 | cluster: volcengine_input_common 70 | VAD: 71 | SileroVAD: 72 | threshold: 0.5 73 | model_path: models/silero_vad.onnx 74 | min_silence_duration_ms: 700 # 如果说话停顿比较长,可以把这个值设置大一些 75 | 76 | LLM: 77 | ChatGLMLLM: 78 | # glm-4-flash 是免费的,但是还是需要注册填写api_key的 79 | # 可在这里找到你的api key https://bigmodel.cn/usercenter/proj-mgmt/apikeys 80 | type: openai 81 | model_name: glm-4-flash 82 | url: https://open.bigmodel.cn/api/paas/v4/ 83 | api_key: 你的chat-glm api key 84 | DifyLLM: 85 | # 建议使用本地部署的dify接口,国内部分区域访问dify公有云接口可能会受限 86 | # 如果使用DifyLLM,配置文件里prompt(提示词)是无效的,需要在dify控制台设置提示词 87 | base_url: https://api.dify.cn/v1 88 | api_key: 你的DifyLLM api key 89 | CozeLLMV3: 90 | bot_id: 你的bot_id 91 | user_id: 你的user_id 92 | personal_access_token: 你的coze个人令牌 93 | TTS: 94 | # 当前支持的type为edge、doubao,可自行适配 95 | EdgeTTS: 96 | # 定义TTS API类型,EdgeTTS支持的是24000采样率,需要将welcome.audio_params.sample_rate设置为24000 97 | voice: zh-CN-XiaoxiaoNeural 98 | BytedanceTTSV3: 99 | voice: zh_female_wanwanxiaohe_moon_bigtts 100 | appid: 你的火山引擎语音合成服务appid 101 | access_token: 你的火山引擎语音合成服务access_token -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | # Web Server 服务 3 | webserver: 4 | build: 5 | context: .
6 | dockerfile: Dockerfile 7 | ports: 8 | - "8000:8000" 9 | restart: unless-stopped 10 | volumes: 11 | - ./config.yaml:/app/config.yaml 12 | - ./models:/app/models -------------------------------------------------------------------------------- /extern/onnxruntime-linux-x64-1.21.0/GIT_COMMIT_ID: -------------------------------------------------------------------------------- 1 | e0b66cad282043d4377cea5269083f17771b6dfc 2 | -------------------------------------------------------------------------------- /extern/onnxruntime-linux-x64-1.21.0/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /extern/onnxruntime-linux-x64-1.21.0/Privacy.md: -------------------------------------------------------------------------------- 1 | # Privacy 2 | 3 | ## Data Collection 4 | The software may collect information about you and your use of the software and send it to Microsoft. Microsoft may use this information to provide services and improve our products and services. You may turn off the telemetry as described in the repository. There are also some features in the software that may enable you and Microsoft to collect data from users of your applications. If you use these features, you must comply with applicable law, including providing appropriate notices to users of your applications together with a copy of Microsoft's privacy statement. Our privacy statement is located at https://go.microsoft.com/fwlink/?LinkID=824704. You can learn more about data collection and use in the help documentation and our privacy statement. Your use of the software operates as your consent to these practices. 5 | 6 | *** 7 | 8 | ### Private Builds 9 | No data collection is performed when using your private builds built from source code. 10 | 11 | ### Official Builds 12 | ONNX Runtime does not maintain any independent telemetry collection mechanisms outside of what is provided by the platforms it supports. However, where applicable, ONNX Runtime will take advantage of platform-supported telemetry systems to collect trace events with the goal of improving product quality. 13 | 14 | Currently telemetry is only implemented for Windows builds and is turned **ON** by default in the official builds distributed in their respective package management repositories ([see here](../README.md#binaries)). This may be expanded to cover other platforms in the future. Data collection is implemented via 'Platform Telemetry' per vendor platform providers (see [telemetry.h](../onnxruntime/core/platform/telemetry.h)). 
15 | 16 | #### Technical Details 17 | The Windows provider uses the [TraceLogging](https://docs.microsoft.com/en-us/windows/win32/tracelogging/trace-logging-about) API for its implementation. This enables ONNX Runtime trace events to be collected by the operating system, and based on user consent, this data may be periodically sent to Microsoft servers following GDPR and privacy regulations for anonymity and data access controls. 18 | 19 | Windows ML and onnxruntime C APIs allow Trace Logging to be turned on/off (see [API pages](../README.md#api-documentation) for details). 20 | For information on how to enable and disable telemetry, see [C API: Telemetry](./C_API.md#telemetry). 21 | There are equivalent APIs in the C#, Python, and Java language bindings as well. 22 | -------------------------------------------------------------------------------- /extern/onnxruntime-linux-x64-1.21.0/README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | **ONNX Runtime is a cross-platform inference and training machine-learning accelerator**. 4 | 5 | **ONNX Runtime inference** can enable faster customer experiences and lower costs, supporting models from deep learning frameworks such as PyTorch and TensorFlow/Keras as well as classical machine learning libraries such as scikit-learn, LightGBM, XGBoost, etc. ONNX Runtime is compatible with different hardware, drivers, and operating systems, and provides optimal performance by leveraging hardware accelerators where applicable alongside graph optimizations and transforms. [Learn more →](https://www.onnxruntime.ai/docs/#onnx-runtime-for-inferencing) 6 | 7 | **ONNX Runtime training** can accelerate the model training time on multi-node NVIDIA GPUs for transformer models with a one-line addition for existing PyTorch training scripts. [Learn more →](https://www.onnxruntime.ai/docs/#onnx-runtime-for-training) 8 | 9 | ## Get Started & Resources 10 | 11 | * **General Information**: [onnxruntime.ai](https://onnxruntime.ai) 12 | 13 | * **Usage documentation and tutorials**: [onnxruntime.ai/docs](https://onnxruntime.ai/docs) 14 | 15 | * **YouTube video tutorials**: [youtube.com/@ONNXRuntime](https://www.youtube.com/@ONNXRuntime) 16 | 17 | * [**Upcoming Release Roadmap**](https://onnxruntime.ai/roadmap) 18 | 19 | * **Companion sample repositories**: 20 | - ONNX Runtime Inferencing: [microsoft/onnxruntime-inference-examples](https://github.com/microsoft/onnxruntime-inference-examples) 21 | - ONNX Runtime Training: [microsoft/onnxruntime-training-examples](https://github.com/microsoft/onnxruntime-training-examples) 22 | 23 | ## Builtin Pipeline Status 24 | 25 | |System|Inference|Training| 26 | |---|---|---| 27 | |Windows|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20CPU%20CI%20Pipeline?label=Windows+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=9)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20GPU%20CUDA%20CI%20Pipeline?label=Windows+GPU+CUDA)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=218)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20GPU%20TensorRT%20CI%20Pipeline?label=Windows+GPU+TensorRT)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=47)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20GPU%20WebGPU%20CI%20Pipeline?label=Windows+GPU+WebGPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=228)|| 28 | |Linux|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20CPU%20CI%20Pipeline?label=Linux+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=11)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20CPU%20Minimal%20Build%20E2E%20CI%20Pipeline?label=Linux+CPU+Minimal+Build)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=64)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20GPU%20CI%20Pipeline?label=Linux+GPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=12)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20GPU%20TensorRT%20CI%20Pipeline?label=Linux+GPU+TensorRT)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=45)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20OpenVINO%20CI%20Pipeline?label=Linux+OpenVINO)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=55)|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/orttraining-linux-ci-pipeline?label=Linux+CPU+Training)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=86)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/orttraining-linux-gpu-ci-pipeline?label=Linux+GPU+Training)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=84)| 29 | |Mac|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/MacOS%20CI%20Pipeline?label=MacOS+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=13)|| 30 | |Android|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Android%20CI%20Pipeline?label=Android)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=53)|| 31 | |iOS|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/iOS%20CI%20Pipeline?label=iOS)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=134)|| 32 | |Web|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/ONNX%20Runtime%20Web%20CI%20Pipeline?label=Web)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=161)|| 33 | |Other|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/onnxruntime-binary-size-checks-ci-pipeline?repoName=microsoft%2Fonnxruntime&label=Binary+Size+Check)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=187&repoName=microsoft%2Fonnxruntime)|| 34 | 35 | This project is tested with [BrowserStack](https://www.browserstack.com/home). 36 | 37 | ## Third-party Pipeline Status 38 | 39 | |System|Inference|Training| 40 | |---|---|---| 41 | |Linux|[![Build Status](https://github.com/Ascend/onnxruntime/actions/workflows/build-and-test.yaml/badge.svg)](https://github.com/Ascend/onnxruntime/actions/workflows/build-and-test.yaml)|| 42 | 43 | ## Releases 44 | 45 | The current release and past releases can be found here: https://github.com/microsoft/onnxruntime/releases. 
46 | 47 | For details on the upcoming release, including release dates, announcements, features, and guidance on submitting feature requests, please visit the release roadmap: https://onnxruntime.ai/roadmap. 48 | 49 | ## Data/Telemetry 50 | 51 | Windows distributions of this project may collect usage data and send it to Microsoft to help improve our products and services. See the [privacy statement](docs/Privacy.md) for more details. 52 | 53 | ## Contributions and Feedback 54 | 55 | We welcome contributions! Please see the [contribution guidelines](CONTRIBUTING.md). 56 | 57 | For feature requests or bug reports, please file a [GitHub Issue](https://github.com/Microsoft/onnxruntime/issues). 58 | 59 | For general discussion or questions, please use [GitHub Discussions](https://github.com/microsoft/onnxruntime/discussions). 60 | 61 | ## Code of Conduct 62 | 63 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 64 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 65 | or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 66 | 67 | ## License 68 | 69 | This project is licensed under the [MIT License](LICENSE). 70 | -------------------------------------------------------------------------------- /extern/onnxruntime-linux-x64-1.21.0/VERSION_NUMBER: -------------------------------------------------------------------------------- 1 | 1.21.0 2 | -------------------------------------------------------------------------------- /extern/onnxruntime-linux-x64-1.21.0/include/core/providers/custom_op_context.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | #pragma once 5 | 6 | // CustomOpContext defines an interface allowing a custom op to access ep-specific resources. 7 | struct CustomOpContext { 8 | CustomOpContext() = default; 9 | virtual ~CustomOpContext() {}; 10 | }; -------------------------------------------------------------------------------- /extern/onnxruntime-linux-x64-1.21.0/include/core/providers/resource.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | enum ResourceOffset { 7 | cpu_resource_offset = 0, 8 | cuda_resource_offset = 10000, 9 | dml_resource_offset = 20000, 10 | rocm_resource_offset = 30000, 11 | // offsets for other ort eps 12 | custom_ep_resource_offset = 10000000, 13 | // offsets for customized eps 14 | }; -------------------------------------------------------------------------------- /extern/onnxruntime-linux-x64-1.21.0/include/cpu_provider_factory.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 3 | 4 | #include "onnxruntime_c_api.h" 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | /** 11 | * \param use_arena zero: false. non-zero: true. 12 | */ 13 | ORT_EXPORT 14 | ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena) 15 | ORT_ALL_ARGS_NONNULL; 16 | 17 | #ifdef __cplusplus 18 | } 19 | #endif 20 | -------------------------------------------------------------------------------- /extern/onnxruntime-linux-x64-1.21.0/include/onnxruntime_run_options_config_keys.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | #pragma once 5 | 6 | /* 7 | * This file defines RunOptions Config Keys and format of the Config Values. 8 | * 9 | * The Naming Convention for a RunOptions Config Key, 10 | * "[Area][.[SubArea1].[SubArea2]...].[Keyname]" 11 | * Such as "ep.cuda.use_arena" 12 | * The Config Key cannot be empty 13 | * The maximum length of the Config Key is 128 14 | * 15 | * The string format of a RunOptions Config Value is defined individually for each Config. 16 | * The maximum length of the Config Value is 1024 17 | */ 18 | 19 | // Key for enabling shrinkages of user listed device memory arenas. 20 | // Expects a list of semi-colon separated key value pairs separated by colon in the following format: 21 | // "device_0:device_id_0;device_1:device_id_1" 22 | // No white-spaces allowed in the provided list string. 23 | // Currently, the only supported devices are : "cpu", "gpu" (case sensitive). 24 | // If "cpu" is included in the list, DisableCpuMemArena() API must not be called (i.e.) arena for cpu should be enabled. 25 | // Example usage: "cpu:0;gpu:0" (or) "gpu:0" 26 | // By default, the value for this key is empty (i.e.) no memory arenas are shrunk 27 | static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memory.enable_memory_arena_shrinkage"; 28 | 29 | // Set to '1' to not synchronize execution providers with CPU at the end of session run. 30 | // Per default it will be set to '0' 31 | // Taking CUDA EP as an example, it omit triggering cudaStreamSynchronize on the compute stream. 32 | static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers"; 33 | 34 | // Set HTP performance mode for QNN HTP backend before session run. 35 | // options for HTP performance mode: "burst", "balanced", "default", "high_performance", 36 | // "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver", 37 | // "sustained_high_performance". Default to "default". 
38 | static const char* const kOrtRunOptionsConfigQnnPerfMode = "qnn.htp_perf_mode"; 39 | 40 | // Set HTP performance mode for QNN HTP backend post session run. 41 | static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_mode_post_run"; 42 | 43 | // Set RPC control latency for QNN HTP backend 44 | static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency"; 45 | 46 | // Set graph annotation id for CUDA EP. Use with enable_cuda_graph=true. 47 | // The value should be an integer. If the value is not set, the default value is 0 and 48 | // ORT session only captures one cuda graph before another capture is requested. 49 | // If the value is set to -1, cuda graph capture/replay is disabled in that run. 50 | // User are not expected to set the value to 0 as it is reserved for internal use. 51 | static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "gpu_graph_id"; 52 | -------------------------------------------------------------------------------- /extern/onnxruntime-linux-x64-1.21.0/include/provider_options.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT License. 
3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | namespace onnxruntime { 11 | 12 | // data types for execution provider options 13 | 14 | using ProviderOptions = std::unordered_map; 15 | using ProviderOptionsVector = std::vector; 16 | using ProviderOptionsMap = std::unordered_map; 17 | 18 | } // namespace onnxruntime 19 | -------------------------------------------------------------------------------- /extern/onnxruntime-linux-x64-1.21.0/lib/cmake/onnxruntime/onnxruntimeConfig.cmake: -------------------------------------------------------------------------------- 1 | 2 | ####### Expanded from @PACKAGE_INIT@ by configure_package_config_file() ####### 3 | ####### Any changes to this file will be overwritten by the next CMake run #### 4 | ####### The input file was PROJECT_CONFIG_FILE ######## 5 | 6 | get_filename_component(PACKAGE_PREFIX_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE) 7 | 8 | macro(set_and_check _var _file) 9 | set(${_var} "${_file}") 10 | if(NOT EXISTS "${_file}") 11 | message(FATAL_ERROR "File or directory ${_file} referenced by variable ${_var} does not exist !") 12 | endif() 13 | endmacro() 14 | 15 | macro(check_required_components _NAME) 16 | foreach(comp ${${_NAME}_FIND_COMPONENTS}) 17 | if(NOT ${_NAME}_${comp}_FOUND) 18 | if(${_NAME}_FIND_REQUIRED_${comp}) 19 | set(${_NAME}_FOUND FALSE) 20 | endif() 21 | endif() 22 | endforeach() 23 | endmacro() 24 | 25 | #################################################################################### 26 | include("${CMAKE_CURRENT_LIST_DIR}/onnxruntimeTargets.cmake") 27 | -------------------------------------------------------------------------------- /extern/onnxruntime-linux-x64-1.21.0/lib/cmake/onnxruntime/onnxruntimeConfigVersion.cmake: -------------------------------------------------------------------------------- 1 | # This is a basic version file for the Config-mode of find_package(). 
2 | # It is used by write_basic_package_version_file() as input file for configure_file() 3 | # to create a version-file which can be installed along a config.cmake file. 4 | # 5 | # The created file sets PACKAGE_VERSION_EXACT if the current version string and 6 | # the requested version string are exactly the same and it sets 7 | # PACKAGE_VERSION_COMPATIBLE if the current version is >= requested version, 8 | # but only if the requested major version is the same as the current one. 9 | # The variable CVF_VERSION must be set before calling configure_file(). 10 | 11 | 12 | set(PACKAGE_VERSION "1.21.0") 13 | 14 | if(PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION) 15 | set(PACKAGE_VERSION_COMPATIBLE FALSE) 16 | else() 17 | 18 | if("1.21.0" MATCHES "^([0-9]+)\\.") 19 | set(CVF_VERSION_MAJOR "${CMAKE_MATCH_1}") 20 | if(NOT CVF_VERSION_MAJOR VERSION_EQUAL 0) 21 | string(REGEX REPLACE "^0+" "" CVF_VERSION_MAJOR "${CVF_VERSION_MAJOR}") 22 | endif() 23 | else() 24 | set(CVF_VERSION_MAJOR "1.21.0") 25 | endif() 26 | 27 | if(PACKAGE_FIND_VERSION_RANGE) 28 | # both endpoints of the range must have the expected major version 29 | math (EXPR CVF_VERSION_MAJOR_NEXT "${CVF_VERSION_MAJOR} + 1") 30 | if (NOT PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR 31 | OR ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND NOT PACKAGE_FIND_VERSION_MAX_MAJOR STREQUAL CVF_VERSION_MAJOR) 32 | OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND NOT PACKAGE_FIND_VERSION_MAX VERSION_LESS_EQUAL CVF_VERSION_MAJOR_NEXT))) 33 | set(PACKAGE_VERSION_COMPATIBLE FALSE) 34 | elseif(PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR 35 | AND ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND PACKAGE_VERSION VERSION_LESS_EQUAL PACKAGE_FIND_VERSION_MAX) 36 | OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION_MAX))) 37 | set(PACKAGE_VERSION_COMPATIBLE TRUE) 38 | else() 39 | set(PACKAGE_VERSION_COMPATIBLE FALSE) 40 | 
endif() 41 | else() 42 | if(PACKAGE_FIND_VERSION_MAJOR STREQUAL CVF_VERSION_MAJOR) 43 | set(PACKAGE_VERSION_COMPATIBLE TRUE) 44 | else() 45 | set(PACKAGE_VERSION_COMPATIBLE FALSE) 46 | endif() 47 | 48 | if(PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION) 49 | set(PACKAGE_VERSION_EXACT TRUE) 50 | endif() 51 | endif() 52 | endif() 53 | 54 | 55 | # if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it: 56 | if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "" OR "8" STREQUAL "") 57 | return() 58 | endif() 59 | 60 | # check that the installed version has the same 32/64bit-ness as the one which is currently searching: 61 | if(NOT CMAKE_SIZEOF_VOID_P STREQUAL "8") 62 | math(EXPR installedBits "8 * 8") 63 | set(PACKAGE_VERSION "${PACKAGE_VERSION} (${installedBits}bit)") 64 | set(PACKAGE_VERSION_UNSUITABLE TRUE) 65 | endif() 66 | -------------------------------------------------------------------------------- /extern/onnxruntime-linux-x64-1.21.0/lib/cmake/onnxruntime/onnxruntimeTargets-release.cmake: -------------------------------------------------------------------------------- 1 | #---------------------------------------------------------------- 2 | # Generated CMake target import file for configuration "Release". 3 | #---------------------------------------------------------------- 4 | 5 | # Commands may need to know the format version. 
6 | set(CMAKE_IMPORT_FILE_VERSION 1) 7 | 8 | # Import target "onnxruntime::onnxruntime" for configuration "Release" 9 | set_property(TARGET onnxruntime::onnxruntime APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE) 10 | set_target_properties(onnxruntime::onnxruntime PROPERTIES 11 | IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib64/libonnxruntime.so.1.21.0" 12 | IMPORTED_SONAME_RELEASE "libonnxruntime.so.1" 13 | ) 14 | 15 | list(APPEND _cmake_import_check_targets onnxruntime::onnxruntime ) 16 | list(APPEND _cmake_import_check_files_for_onnxruntime::onnxruntime "${_IMPORT_PREFIX}/lib64/libonnxruntime.so.1.21.0" ) 17 | 18 | # Commands beyond this point should not need to know the version. 19 | set(CMAKE_IMPORT_FILE_VERSION) 20 | -------------------------------------------------------------------------------- /extern/onnxruntime-linux-x64-1.21.0/lib/cmake/onnxruntime/onnxruntimeTargets.cmake: -------------------------------------------------------------------------------- 1 | # Generated by CMake 2 | 3 | if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.8) 4 | message(FATAL_ERROR "CMake >= 2.8.3 required") 5 | endif() 6 | if(CMAKE_VERSION VERSION_LESS "2.8.3") 7 | message(FATAL_ERROR "CMake >= 2.8.3 required") 8 | endif() 9 | cmake_policy(PUSH) 10 | cmake_policy(VERSION 2.8.3...3.29) 11 | #---------------------------------------------------------------- 12 | # Generated CMake target import file. 13 | #---------------------------------------------------------------- 14 | 15 | # Commands may need to know the format version. 16 | set(CMAKE_IMPORT_FILE_VERSION 1) 17 | 18 | # Protect against multiple inclusion, which would fail when already imported targets are added once more. 
19 | set(_cmake_targets_defined "") 20 | set(_cmake_targets_not_defined "") 21 | set(_cmake_expected_targets "") 22 | foreach(_cmake_expected_target IN ITEMS onnxruntime::onnxruntime) 23 | list(APPEND _cmake_expected_targets "${_cmake_expected_target}") 24 | if(TARGET "${_cmake_expected_target}") 25 | list(APPEND _cmake_targets_defined "${_cmake_expected_target}") 26 | else() 27 | list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}") 28 | endif() 29 | endforeach() 30 | unset(_cmake_expected_target) 31 | if(_cmake_targets_defined STREQUAL _cmake_expected_targets) 32 | unset(_cmake_targets_defined) 33 | unset(_cmake_targets_not_defined) 34 | unset(_cmake_expected_targets) 35 | unset(CMAKE_IMPORT_FILE_VERSION) 36 | cmake_policy(POP) 37 | return() 38 | endif() 39 | if(NOT _cmake_targets_defined STREQUAL "") 40 | string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}") 41 | string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}") 42 | message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n") 43 | endif() 44 | unset(_cmake_targets_defined) 45 | unset(_cmake_targets_not_defined) 46 | unset(_cmake_expected_targets) 47 | 48 | 49 | # Compute the installation prefix relative to this file. 
50 | get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH) 51 | get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) 52 | get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) 53 | get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) 54 | if(_IMPORT_PREFIX STREQUAL "/") 55 | set(_IMPORT_PREFIX "") 56 | endif() 57 | 58 | # Create imported target onnxruntime::onnxruntime 59 | add_library(onnxruntime::onnxruntime SHARED IMPORTED) 60 | 61 | set_target_properties(onnxruntime::onnxruntime PROPERTIES 62 | INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include/onnxruntime" 63 | ) 64 | 65 | # Load information for each installed configuration. 66 | file(GLOB _cmake_config_files "${CMAKE_CURRENT_LIST_DIR}/onnxruntimeTargets-*.cmake") 67 | foreach(_cmake_config_file IN LISTS _cmake_config_files) 68 | include("${_cmake_config_file}") 69 | endforeach() 70 | unset(_cmake_config_file) 71 | unset(_cmake_config_files) 72 | 73 | # Cleanup temporary variables. 74 | set(_IMPORT_PREFIX) 75 | 76 | # Loop over all imported files and verify that they actually exist 77 | foreach(_cmake_target IN LISTS _cmake_import_check_targets) 78 | if(CMAKE_VERSION VERSION_LESS "3.28" 79 | OR NOT DEFINED _cmake_import_check_xcframework_for_${_cmake_target} 80 | OR NOT IS_DIRECTORY "${_cmake_import_check_xcframework_for_${_cmake_target}}") 81 | foreach(_cmake_file IN LISTS "_cmake_import_check_files_for_${_cmake_target}") 82 | if(NOT EXISTS "${_cmake_file}") 83 | message(FATAL_ERROR "The imported target \"${_cmake_target}\" references the file 84 | \"${_cmake_file}\" 85 | but this file does not exist. Possible reasons include: 86 | * The file was deleted, renamed, or moved to another location. 87 | * An install or uninstall procedure did not complete successfully. 88 | * The installation package was faulty and contained 89 | \"${CMAKE_CURRENT_LIST_FILE}\" 90 | but not all the files it references. 
91 | ") 92 | endif() 93 | endforeach() 94 | endif() 95 | unset(_cmake_file) 96 | unset("_cmake_import_check_files_for_${_cmake_target}") 97 | endforeach() 98 | unset(_cmake_target) 99 | unset(_cmake_import_check_targets) 100 | 101 | # This file does not depend on other imported targets which have 102 | # been exported from the same project but in a separate export set. 103 | 104 | # Commands beyond this point should not need to know the version. 105 | set(CMAKE_IMPORT_FILE_VERSION) 106 | cmake_policy(POP) 107 | -------------------------------------------------------------------------------- /extern/onnxruntime-linux-x64-1.21.0/lib/libonnxruntime.so: -------------------------------------------------------------------------------- 1 | libonnxruntime.so.1 -------------------------------------------------------------------------------- /extern/onnxruntime-linux-x64-1.21.0/lib/libonnxruntime.so.1: -------------------------------------------------------------------------------- 1 | libonnxruntime.so.1.21.0 -------------------------------------------------------------------------------- /extern/onnxruntime-linux-x64-1.21.0/lib/libonnxruntime.so.1.21.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daxpot/xiaozhi-cpp-server/eae257d6e94594b2f4112ed94c33ba4adf33869c/extern/onnxruntime-linux-x64-1.21.0/lib/libonnxruntime.so.1.21.0 -------------------------------------------------------------------------------- /extern/onnxruntime-linux-x64-1.21.0/lib/libonnxruntime_providers_shared.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daxpot/xiaozhi-cpp-server/eae257d6e94594b2f4112ed94c33ba4adf33869c/extern/onnxruntime-linux-x64-1.21.0/lib/libonnxruntime_providers_shared.so -------------------------------------------------------------------------------- /extern/onnxruntime-linux-x64-1.21.0/lib/pkgconfig/libonnxruntime.pc: 
-------------------------------------------------------------------------------- 1 | prefix=/usr/local 2 | bindir=${prefix}/bin 3 | mandir=${prefix}/share/man 4 | docdir=${prefix}/share/doc/onnxruntime 5 | libdir=${prefix}/lib64 6 | includedir=${prefix}/include/onnxruntime 7 | 8 | Name: onnxruntime 9 | Description: ONNX runtime 10 | URL: https://github.com/microsoft/onnxruntime 11 | Version: 1.21.0 12 | Libs: -L${libdir} -lonnxruntime 13 | Cflags: -I${includedir} 14 | -------------------------------------------------------------------------------- /include/funasr/paraformer/com-define.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace funasr { 4 | #define S_BEGIN 0 5 | #define S_MIDDLE 1 6 | #define S_END 2 7 | #define S_ALL 3 8 | #define S_ERR 4 9 | 10 | #ifndef MODEL_SAMPLE_RATE 11 | #define MODEL_SAMPLE_RATE 16000 12 | #endif 13 | 14 | // parser option 15 | #define MODEL_DIR "model-dir" 16 | #define OFFLINE_MODEL_DIR "model-dir" 17 | #define ONLINE_MODEL_DIR "online-model-dir" 18 | #define LM_DIR "lm-dir" 19 | #define GLOB_BEAM "global-beam" 20 | #define LAT_BEAM "lattice-beam" 21 | #define AM_SCALE "am-scale" 22 | // #define FST_HOTWORD "fst-hotword" 23 | #define FST_INC_WTS "fst-inc-wts" 24 | #define VAD_DIR "vad-dir" 25 | #define PUNC_DIR "punc-dir" 26 | #define QUANTIZE "quantize" 27 | #define VAD_QUANT "vad-quant" 28 | #define PUNC_QUANT "punc-quant" 29 | #define ASR_MODE "mode" 30 | 31 | #define WAV_PATH "wav-path" 32 | #define WAV_SCP "wav-scp" 33 | #define TXT_PATH "txt-path" 34 | #define THREAD_NUM "thread-num" 35 | #define PORT_ID "port-id" 36 | #define HOTWORD_SEP " " 37 | #define AUDIO_FS "audio-fs" 38 | 39 | #define MODEL_PARA "Paraformer" 40 | #define MODEL_SVS "SenseVoiceSmall" 41 | 42 | // #define VAD_MODEL_PATH "vad-model" 43 | // #define VAD_CMVN_PATH "vad-cmvn" 44 | // #define VAD_CONFIG_PATH "vad-config" 45 | // #define AM_MODEL_PATH "am-model" 46 | // #define AM_CMVN_PATH 
"am-cmvn" 47 | // #define AM_CONFIG_PATH "am-config" 48 | // #define PUNC_MODEL_PATH "punc-model" 49 | // #define PUNC_CONFIG_PATH "punc-config" 50 | 51 | #define MODEL_NAME "model.onnx" 52 | // hotword embedding compile model 53 | #define MODEL_EB_NAME "model_eb.onnx" 54 | #define TORCH_MODEL_EB_NAME "model_eb.torchscript" 55 | #define QUANT_MODEL_NAME "model_quant.onnx" 56 | #define VAD_CMVN_NAME "am.mvn" 57 | #define VAD_CONFIG_NAME "config.yaml" 58 | 59 | // gpu models 60 | #define INFER_GPU "gpu" 61 | #define BATCHSIZE "batch-size" 62 | #define TORCH_MODEL_NAME "model.torchscript" 63 | #define TORCH_QUANT_MODEL_NAME "model_quant.torchscript" 64 | #define BLADE_MODEL_NAME "model_blade.torchscript" 65 | #define BLADEDISC "bladedisc" 66 | 67 | #define AM_CMVN_NAME "am.mvn" 68 | #define AM_CONFIG_NAME "config.yaml" 69 | #define LM_CONFIG_NAME "config.yaml" 70 | #define PUNC_CONFIG_NAME "config.yaml" 71 | #define MODEL_SEG_DICT "seg_dict" 72 | #define TOKEN_PATH "tokens.json" 73 | #define HOTWORD "hotword" 74 | // #define NN_HOTWORD "nn-hotword" 75 | 76 | #define ITN_DIR "itn-dir" 77 | #define ITN_TAGGER_NAME "zh_itn_tagger.fst" 78 | #define ITN_VERBALIZER_NAME "zh_itn_verbalizer.fst" 79 | 80 | #define ENCODER_NAME "model.onnx" 81 | #define QUANT_ENCODER_NAME "model_quant.onnx" 82 | #define DECODER_NAME "decoder.onnx" 83 | #define QUANT_DECODER_NAME "decoder_quant.onnx" 84 | 85 | #define LM_FST_RES "TLG.fst" 86 | #define LEX_PATH "lexicon.txt" 87 | 88 | // vad 89 | #ifndef VAD_SILENCE_DURATION 90 | #define VAD_SILENCE_DURATION 800 91 | #endif 92 | 93 | #ifndef VAD_MAX_LEN 94 | #define VAD_MAX_LEN 15000 95 | #endif 96 | 97 | #ifndef VAD_SPEECH_NOISE_THRES 98 | #define VAD_SPEECH_NOISE_THRES 0.9 99 | #endif 100 | 101 | #ifndef VAD_LFR_M 102 | #define VAD_LFR_M 5 103 | #endif 104 | 105 | #ifndef VAD_LFR_N 106 | #define VAD_LFR_N 1 107 | #endif 108 | 109 | // asr 110 | #ifndef PARA_LFR_M 111 | #define PARA_LFR_M 7 112 | #endif 113 | 114 | #ifndef PARA_LFR_N 115 | 
#define PARA_LFR_N 6 116 | #endif 117 | 118 | #ifndef ONLINE_STEP 119 | #define ONLINE_STEP 9600 120 | #endif 121 | 122 | // punc 123 | #define UNK_CHAR "" 124 | #define TOKEN_LEN 20 125 | 126 | #define CANDIDATE_NUM 6 127 | #define UNKNOW_INDEX 0 128 | #define NOTPUNC "_" 129 | #define NOTPUNC_INDEX 1 130 | #define COMMA_INDEX 2 131 | #define PERIOD_INDEX 3 132 | #define QUESTION_INDEX 4 133 | #define DUN_INDEX 5 134 | #define CACHE_POP_TRIGGER_LIMIT 200 135 | 136 | #define JIEBA_DICT "jieba.c.dict" 137 | #define JIEBA_USERDICT "jieba_usr_dict" 138 | #define JIEBA_HMM_MODEL "jieba.hmm" 139 | 140 | } // namespace funasr -------------------------------------------------------------------------------- /include/funasr/paraformer/commonfunc.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include "model.h" 4 | 5 | namespace funasr { 6 | typedef struct 7 | { 8 | std::string msg; 9 | std::string stamp; 10 | std::string stamp_sents; 11 | std::string tpass_msg; 12 | float snippet_time; 13 | }FUNASR_RECOG_RESULT; 14 | 15 | typedef struct 16 | { 17 | std::vector>* segments; 18 | float snippet_time; 19 | }FUNASR_VAD_RESULT; 20 | 21 | typedef struct 22 | { 23 | std::string msg; 24 | std::vector arr_cache; 25 | }FUNASR_PUNC_RESULT; 26 | 27 | 28 | #define ORTSTRING(str) str 29 | #define ORTCHAR(str) str 30 | 31 | inline void GetInputName(Ort::Session* session, std::string& inputName,int nIndex=0) { 32 | size_t numInputNodes = session->GetInputCount(); 33 | if (numInputNodes > 0) { 34 | Ort::AllocatorWithDefaultOptions allocator; 35 | { 36 | auto t = session->GetInputNameAllocated(nIndex, allocator); 37 | inputName = t.get(); 38 | } 39 | } 40 | } 41 | 42 | inline void GetInputNames(Ort::Session* session, std::vector &m_strInputNames, 43 | std::vector &m_szInputNames) { 44 | Ort::AllocatorWithDefaultOptions allocator; 45 | size_t numNodes = session->GetInputCount(); 46 | m_strInputNames.resize(numNodes); 47 | 
m_szInputNames.resize(numNodes); 48 | for (size_t i = 0; i != numNodes; ++i) { 49 | auto t = session->GetInputNameAllocated(i, allocator); 50 | m_strInputNames[i] = t.get(); 51 | m_szInputNames[i] = m_strInputNames[i].c_str(); 52 | } 53 | } 54 | 55 | inline void GetOutputName(Ort::Session* session, std::string& outputName, int nIndex = 0) { 56 | size_t numOutputNodes = session->GetOutputCount(); 57 | if (numOutputNodes > 0) { 58 | Ort::AllocatorWithDefaultOptions allocator; 59 | { 60 | auto t = session->GetOutputNameAllocated(nIndex, allocator); 61 | outputName = t.get(); 62 | } 63 | } 64 | } 65 | 66 | inline void GetOutputNames(Ort::Session* session, std::vector &m_strOutputNames, 67 | std::vector &m_szOutputNames) { 68 | Ort::AllocatorWithDefaultOptions allocator; 69 | size_t numNodes = session->GetOutputCount(); 70 | m_strOutputNames.resize(numNodes); 71 | m_szOutputNames.resize(numNodes); 72 | for (size_t i = 0; i != numNodes; ++i) { 73 | auto t = session->GetOutputNameAllocated(i, allocator); 74 | m_strOutputNames[i] = t.get(); 75 | m_szOutputNames[i] = m_strOutputNames[i].c_str(); 76 | } 77 | } 78 | 79 | template 80 | inline static size_t Argmax(ForwardIterator first, ForwardIterator last) { 81 | return std::distance(first, std::max_element(first, last)); 82 | } 83 | } // namespace funasr -------------------------------------------------------------------------------- /include/funasr/paraformer/model.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include "com-define.h" 6 | #include 7 | #include 8 | #include 9 | 10 | namespace funasr { 11 | class Model { 12 | public: 13 | virtual ~Model(){}; 14 | virtual void StartUtterance() = 0; 15 | virtual void EndUtterance() = 0; 16 | virtual void Reset() = 0; 17 | virtual std::string GreedySearch(float* in, int n_len, int64_t token_nums, bool is_stamp=false, std::vector us_alphas={0}, std::vector us_cif_peak={0}){return 
"";}; 18 | virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){}; 19 | virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){}; 20 | virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, 21 | const std::string &am_config, const std::string &token_file, const std::string &online_token_file, int thread_num){}; 22 | virtual void InitLm(const std::string &lm_file, const std::string &lm_config, const std::string &lex_file){}; 23 | virtual void InitFstDecoder(){}; 24 | virtual std::string Forward(float *din, int len, bool input_finished, const std::vector> &hw_emb={{0.0}}, void* wfst_decoder=nullptr){return "";}; 25 | virtual std::vector Forward(float** din, int* len, bool input_finished, const std::vector> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1) 26 | {return std::vector();}; 27 | virtual std::vector Forward(float** din, int* len, bool input_finished, std::string svs_lang="auto", bool svs_itn=false, int batch_in=1) 28 | {return std::vector();}; 29 | virtual std::string Rescoring() = 0; 30 | virtual void InitHwCompiler(const std::string &hw_model, int thread_num){}; 31 | virtual void InitSegDict(const std::string &seg_dict_model){}; 32 | virtual std::vector> CompileHotwordEmbedding(std::string &hotwords){return std::vector>();}; 33 | virtual std::string GetLang(){return "";}; 34 | virtual int GetAsrSampleRate() = 0; 35 | virtual void SetBatchSize(int batch_size) {}; 36 | virtual int GetBatchSize() {return 0;}; 37 | // virtual Vocab* GetVocab() {return nullptr;}; 38 | // virtual Vocab* GetLmVocab() {return nullptr;}; 39 | // virtual PhoneSet* GetPhoneSet() {return nullptr;}; 40 | }; 41 | 42 | Model *CreateModel(std::map& model_path, int 
thread_num=1); 43 | Model *CreateModel(void* asr_handle, std::vector chunk_size); 44 | 45 | } // namespace funasr -------------------------------------------------------------------------------- /include/funasr/paraformer/paraformer-online.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 3 | * MIT License (https://opensource.org/licenses/MIT) 4 | */ 5 | #pragma once 6 | 7 | #include "model.h" 8 | 9 | namespace funasr { 10 | 11 | class ParaformerOnline : public Model { 12 | /** 13 | * Author: Speech Lab of DAMO Academy, Alibaba Group 14 | * ParaformerOnline: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition 15 | * https://arxiv.org/pdf/2206.08317.pdf 16 | */ 17 | private: 18 | 19 | void FbankKaldi(float sample_rate, std::vector> &wav_feats, 20 | std::vector &waves); 21 | int OnlineLfrCmvn(std::vector> &wav_feats, bool input_finished); 22 | void GetPosEmb(std::vector> &wav_feats, int timesteps, int feat_dim); 23 | void CifSearch(std::vector> hidden, std::vector alphas, bool is_final, std::vector> &list_frame); 24 | 25 | static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) { 26 | int frame_num = static_cast((sample_length - frame_sample_length) / frame_shift_sample_length + 1); 27 | if (frame_num >= 1 && sample_length >= frame_sample_length) 28 | return frame_num; 29 | else 30 | return 0; 31 | } 32 | void InitOnline( 33 | knf::FbankOptions &fbank_opts, 34 | std::shared_ptr &encoder_session, 35 | std::shared_ptr &decoder_session, 36 | std::vector &en_szInputNames, 37 | std::vector &en_szOutputNames, 38 | std::vector &de_szInputNames, 39 | std::vector &de_szOutputNames, 40 | std::vector &means_list, 41 | std::vector &vars_list, 42 | int frame_length_, 43 | int frame_shift_, 44 | int n_mels_, 45 | int lfr_m_, 46 | int lfr_n_, 47 | int 
encoder_size_, 48 | int fsmn_layers_, 49 | int fsmn_lorder_, 50 | int fsmn_dims_, 51 | float cif_threshold_, 52 | float tail_alphas_); 53 | 54 | void StartUtterance() 55 | { 56 | } 57 | 58 | void EndUtterance() 59 | { 60 | } 61 | 62 | Model* offline_handle_ = nullptr; 63 | // from offline_handle_ 64 | knf::FbankOptions fbank_opts_; 65 | std::shared_ptr encoder_session_ = nullptr; 66 | std::shared_ptr decoder_session_ = nullptr; 67 | Ort::SessionOptions session_options_; 68 | std::vector en_szInputNames_; 69 | std::vector en_szOutputNames_; 70 | std::vector de_szInputNames_; 71 | std::vector de_szOutputNames_; 72 | std::vector means_list_; 73 | std::vector vars_list_; 74 | // configs from offline_handle_ 75 | int frame_length = 25; 76 | int frame_shift = 10; 77 | int n_mels = 80; 78 | int lfr_m = PARA_LFR_M; 79 | int lfr_n = PARA_LFR_N; 80 | int encoder_size = 512; 81 | int fsmn_layers = 16; 82 | int fsmn_lorder = 10; 83 | int fsmn_dims = 512; 84 | float cif_threshold = 1.0; 85 | float tail_alphas = 0.45; 86 | 87 | // configs 88 | int feat_dims = lfr_m*n_mels; 89 | std::vector chunk_size = {5,10,5}; 90 | int frame_sample_length_ = MODEL_SAMPLE_RATE / 1000 * frame_length; 91 | int frame_shift_sample_length_ = MODEL_SAMPLE_RATE / 1000 * frame_shift; 92 | 93 | // The reserved waveforms by fbank 94 | std::vector reserve_waveforms_; 95 | // waveforms reserved after last shift position 96 | std::vector input_cache_; 97 | // lfr reserved cache 98 | std::vector> lfr_splice_cache_; 99 | // position index cache 100 | int start_idx_cache_ = 0; 101 | // cif alpha 102 | std::vector alphas_cache_; 103 | std::vector> hidden_cache_; 104 | std::vector> feats_cache_; 105 | // fsmn init caches 106 | std::vector fsmn_init_cache_; 107 | std::vector decoder_onnx; 108 | 109 | bool is_first_chunk = true; 110 | bool is_last_chunk = false; 111 | double sqrt_factor; 112 | 113 | public: 114 | ParaformerOnline(Model* offline_handle, std::vector chunk_size, std::string model_type=MODEL_PARA); 
115 | ~ParaformerOnline(); 116 | void Reset(); 117 | void ResetCache(); 118 | void InitCache(); 119 | void ExtractFeats(float sample_rate, std::vector> &wav_feats, std::vector &waves, bool input_finished); 120 | void AddOverlapChunk(std::vector> &wav_feats, bool input_finished); 121 | 122 | std::string ForwardChunk(std::vector> &wav_feats, bool input_finished); 123 | std::string Forward(float* din, int len, bool input_finished, const std::vector> &hw_emb={{0.0}}, void* wfst_decoder=nullptr); 124 | std::string Rescoring(); 125 | 126 | int GetAsrSampleRate() { return offline_handle_->GetAsrSampleRate(); }; 127 | 128 | // 2pass 129 | std::string online_res; 130 | int chunk_len; 131 | }; 132 | 133 | } // namespace funasr -------------------------------------------------------------------------------- /include/funasr/paraformer/paraformer.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
3 | * MIT License (https://opensource.org/licenses/MIT) 4 | */ 5 | #pragma once 6 | 7 | #include "model.h" 8 | #include "vocab.h" 9 | #include "phone-set.h" 10 | #include "seg-dict.h" 11 | 12 | namespace funasr { 13 | 14 | class Paraformer : public Model { 15 | /** 16 | * Author: Speech Lab of DAMO Academy, Alibaba Group 17 | * Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition 18 | * https://arxiv.org/pdf/2206.08317.pdf 19 | */ 20 | private: 21 | Vocab* vocab = nullptr; 22 | Vocab* lm_vocab = nullptr; 23 | SegDict* seg_dict = nullptr; 24 | PhoneSet* phone_set_ = nullptr; 25 | //const float scale = 22.6274169979695; 26 | const float scale = 1.0; 27 | 28 | void LoadConfigFromYaml(const char* filename); 29 | void LoadOnlineConfigFromYaml(const char* filename); 30 | void LoadCmvn(const char *filename); 31 | void LfrCmvn(std::vector> &asr_feats); 32 | 33 | std::shared_ptr hw_m_session = nullptr; 34 | Ort::Env hw_env_; 35 | Ort::SessionOptions hw_session_options; 36 | std::vector hw_m_strInputNames, hw_m_strOutputNames; 37 | std::vector hw_m_szInputNames; 38 | std::vector hw_m_szOutputNames; 39 | bool use_hotword; 40 | 41 | public: 42 | Paraformer(); 43 | ~Paraformer(); 44 | // void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num); 45 | // online 46 | void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num); 47 | // 2pass 48 | // void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, 49 | // const std::string &am_config, const std::string &token_file, const std::string &online_token_file, int thread_num); 50 | // void InitHwCompiler(const std::string &hw_model, int thread_num); 51 | // void InitSegDict(const std::string &seg_dict_model); 
52 | // std::vector> CompileHotwordEmbedding(std::string &hotwords); 53 | void Reset(); 54 | void FbankKaldi(float sample_rate, const float* waves, int len, std::vector> &asr_feats); 55 | // std::vector Forward(float** din, int* len, bool input_finished=true, const std::vector> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1); 56 | std::string GreedySearch( float* in, int n_len, int64_t token_nums, 57 | bool is_stamp=false, std::vector us_alphas={0}, std::vector us_cif_peak={0}); 58 | 59 | std::string Rescoring(); 60 | std::string GetLang(){return language;}; 61 | int GetAsrSampleRate() { return asr_sample_rate; }; 62 | int GetBatchSize() {return batch_size_;}; 63 | void StartUtterance(); 64 | void EndUtterance(); 65 | // void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file); 66 | Vocab* GetVocab(); 67 | Vocab* GetLmVocab(); 68 | PhoneSet* GetPhoneSet(); 69 | 70 | knf::FbankOptions fbank_opts_; 71 | std::vector means_list_; 72 | std::vector vars_list_; 73 | int lfr_m = PARA_LFR_M; 74 | int lfr_n = PARA_LFR_N; 75 | 76 | // paraformer-offline 77 | std::shared_ptr m_session_ = nullptr; 78 | Ort::Env env_; 79 | Ort::SessionOptions session_options_; 80 | 81 | std::vector m_strInputNames, m_strOutputNames; 82 | std::vector m_szInputNames; 83 | std::vector m_szOutputNames; 84 | 85 | std::string language="zh-cn"; 86 | 87 | // paraformer-online 88 | std::shared_ptr encoder_session_ = nullptr; 89 | std::shared_ptr decoder_session_ = nullptr; 90 | std::vector en_strInputNames, en_strOutputNames; 91 | std::vector en_szInputNames_; 92 | std::vector en_szOutputNames_; 93 | std::vector de_strInputNames, de_strOutputNames; 94 | std::vector de_szInputNames_; 95 | std::vector de_szOutputNames_; 96 | 97 | std::string window_type = "hamming"; 98 | int frame_length = 25; 99 | int frame_shift = 10; 100 | int n_mels = 80; 101 | int encoder_size = 512; 102 | int fsmn_layers = 16; 103 | int fsmn_lorder = 10; 104 | int fsmn_dims = 
512; 105 | float cif_threshold = 1.0; 106 | float tail_alphas = 0.45; 107 | int asr_sample_rate = MODEL_SAMPLE_RATE; 108 | int batch_size_ = 1; 109 | }; 110 | 111 | } // namespace funasr -------------------------------------------------------------------------------- /include/funasr/paraformer/phone-set.h: -------------------------------------------------------------------------------- 1 | #ifndef PHONESET_H 2 | #define PHONESET_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #define UNIT_BEG_SIL_SYMBOL "" 11 | #define UNIT_END_SIL_SYMBOL "" 12 | #define UNIT_BLK_SYMBOL "" 13 | 14 | using namespace std; 15 | 16 | namespace funasr { 17 | class PhoneSet { 18 | public: 19 | PhoneSet(const char *filename); 20 | ~PhoneSet(); 21 | int Size() const; 22 | int String2Id(string str) const; 23 | string Id2String(int id) const; 24 | bool Find(string str) const; 25 | int GetBegSilPhnId() const; 26 | int GetEndSilPhnId() const; 27 | int GetBlkPhnId() const; 28 | 29 | private: 30 | vector phone_; 31 | unordered_map phn2Id_; 32 | void LoadPhoneSetFromYaml(const char* filename); 33 | void LoadPhoneSetFromJson(const char* filename); 34 | }; 35 | 36 | } // namespace funasr 37 | #endif -------------------------------------------------------------------------------- /include/funasr/paraformer/seg-dict.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
3 | * MIT License (https://opensource.org/licenses/MIT) 4 | */ 5 | #ifndef SEG_DICT_H 6 | #define SEG_DICT_H 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | using namespace std; 13 | 14 | namespace funasr { 15 | class SegDict { 16 | private: 17 | std::map> seg_dict; 18 | 19 | public: 20 | SegDict(const char *filename); 21 | ~SegDict(); 22 | std::vector GetTokensByWord(const std::string &word); 23 | }; 24 | 25 | } // namespace funasr 26 | #endif -------------------------------------------------------------------------------- /include/funasr/paraformer/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace funasr { 6 | std::string PathAppend(const std::string &p1, const std::string &p2); 7 | void FindMax(float *din, int len, float &max_val, int &max_idx); 8 | std::vector split(const std::string &s, char delim); 9 | } -------------------------------------------------------------------------------- /include/funasr/paraformer/vocab.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef VOCAB_H 3 | #define VOCAB_H 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | namespace funasr { 12 | class Vocab { 13 | private: 14 | std::vector vocab; 15 | std::map token_id; 16 | std::map lex_map; 17 | bool IsEnglish(std::string ch); 18 | void LoadVocabFromYaml(const char* filename); 19 | void LoadVocabFromJson(const char* filename); 20 | void LoadLex(const char* filename); 21 | 22 | public: 23 | Vocab(const char *filename); 24 | Vocab(const char *filename, const char *lex_file); 25 | ~Vocab(); 26 | int Size() const; 27 | bool IsChinese(std::string ch); 28 | void Vector2String(std::vector in, std::vector &preds); 29 | std::string Vector2String(std::vector in); 30 | std::string Vector2StringV2(std::vector in, std::string language=""); 31 | std::string Id2String(int id) const; 32 | std::string 
WordFormat(std::string word); 33 | int GetIdByToken(const std::string &token) const; 34 | std::string Word2Lex(const std::string &word) const; 35 | }; 36 | 37 | } // namespace funasr 38 | #endif -------------------------------------------------------------------------------- /include/xz-cpp-server/asr/base.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace xiaozhi { 4 | namespace asr { 5 | class Base { 6 | public: 7 | virtual ~Base(); 8 | virtual net::awaitable detect_opus(const std::optional& buf) = 0; 9 | }; 10 | std::unique_ptr createASR(const net::any_io_executor& executor); 11 | } 12 | } -------------------------------------------------------------------------------- /include/xz-cpp-server/asr/bytedancev2.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "base.h" 3 | 4 | namespace xiaozhi { 5 | namespace asr { 6 | class BytedanceV2: public Base { 7 | public: 8 | BytedanceV2(const net::any_io_executor& executor, const YAML::Node& config); 9 | ~BytedanceV2(); 10 | net::awaitable detect_opus(const std::optional& buf) override; 11 | private: 12 | class Impl; 13 | std::unique_ptr impl_; 14 | }; 15 | } 16 | } -------------------------------------------------------------------------------- /include/xz-cpp-server/asr/paraformer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "base.h" 3 | 4 | namespace xiaozhi { 5 | namespace asr { 6 | class Paraformer: public Base { 7 | public: 8 | Paraformer(const net::any_io_executor& executor, const YAML::Node& config); 9 | ~Paraformer(); 10 | net::awaitable detect_opus(const std::optional& buf) override; 11 | private: 12 | class Impl; 13 | std::unique_ptr impl_; 14 | }; 15 | } 16 | } -------------------------------------------------------------------------------- /include/xz-cpp-server/common/logger.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // 自定义格式化函数,根据日志级别设置颜色 4 | void color_formatter(boost::log::record_view const& rec, boost::log::formatting_ostream& strm); 5 | 6 | void init_logging(std::string log_level); -------------------------------------------------------------------------------- /include/xz-cpp-server/common/request.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace request { 4 | struct UrlInfo { 5 | bool is_https; 6 | std::string host; 7 | std::string port; 8 | std::string path; 9 | }; 10 | UrlInfo parse_url(const std::string& url); 11 | net::awaitable> connect(const UrlInfo& url_info); 12 | net::awaitable send(ssl::stream& stream, const http::verb method, const UrlInfo& url_info, const json::value& header, const std::string& data=""); 13 | net::awaitable request(const http::verb method, const std::string& url, const json::value& header, const std::string& data=""); 14 | net::awaitable get(const std::string& url, const json::value& header); 15 | net::awaitable post(const std::string& url, const json::value& header, const std::string& data); 16 | net::awaitable stream_post(const std::string& url, const json::value& header, const std::string& data, const std::function& callback); 17 | } -------------------------------------------------------------------------------- /include/xz-cpp-server/common/setting.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace xiaozhi { 4 | class Setting { 5 | public: 6 | static std::shared_ptr getSetting(const char* path=0); 7 | // ~Setting(); 8 | YAML::Node config; 9 | private: 10 | Setting(const char* path=0); 11 | }; 12 | } 13 | -------------------------------------------------------------------------------- /include/xz-cpp-server/common/threadsafe_queue.hpp: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef XIAOZHI_THREADSAFE_QUEUE_H 3 | #define XIAOZHI_THREADSAFE_QUEUE_H 4 | 5 | #include 6 | #include 7 | template 8 | class ThreadSafeQueue { 9 | private: 10 | mutable std::mutex mutex; 11 | std::queue data; 12 | public: 13 | ThreadSafeQueue() = default; 14 | ThreadSafeQueue(ThreadSafeQueue&) = delete; 15 | ThreadSafeQueue& operator=(const ThreadSafeQueue&) = delete; 16 | 17 | void push(T&& value) { 18 | std::lock_guard lk(mutex); 19 | data.push(std::forward(value)); 20 | } 21 | bool try_pop(T &value) { 22 | std::lock_guard lk(mutex); 23 | if(data.empty()) { 24 | return false; 25 | } 26 | value = std::move(data.front()); 27 | data.pop(); 28 | return true; 29 | } 30 | 31 | void clear() { 32 | std::lock_guard lk(mutex); 33 | std::queue().swap(data); 34 | } 35 | 36 | bool empty() const { 37 | std::lock_guard lk(mutex); 38 | return data.empty(); 39 | } 40 | }; 41 | #endif -------------------------------------------------------------------------------- /include/xz-cpp-server/common/tools.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace tools { 5 | enum SegmentRet { 6 | NONE = 0, 7 | EN = 1, 8 | CHINESE = 3 9 | }; 10 | SegmentRet is_segment(const std::string& str, std::string::size_type pos); 11 | std::string::size_type find_last_segment(const std::string& input); 12 | std::string generate_uuid(); 13 | std::string gzip_compress(const std::string &data); 14 | std::string gzip_decompress(const std::string &data); 15 | long long get_tms(); 16 | std::tuple create_opus_coders(int sample_rate, bool create_encoder=true, bool create_decoder=true); 17 | void on_spawn_complete(std::string_view title, std::exception_ptr e); 18 | } -------------------------------------------------------------------------------- /include/xz-cpp-server/connection.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace xiaozhi { 10 | class Connection: public std::enable_shared_from_this { 11 | private: 12 | std::atomic is_released_ = false; 13 | int min_silence_tms_ = 700; 14 | int close_connection_no_voice_time_ = 120; 15 | std::shared_ptr setting_ = nullptr; 16 | std::string session_id_; 17 | Vad vad_; 18 | net::any_io_executor executor_; 19 | boost::json::array dialogue_; 20 | std::vector cmd_exit_; 21 | 22 | ThreadSafeQueue llm_response_; 23 | ThreadSafeQueue> asr_audio_; 24 | 25 | boost::asio::steady_timer silence_timer_; 26 | std::unique_ptr asr_ = nullptr; 27 | std::unique_ptr llm_ = nullptr; 28 | std::unique_ptr tts_ = nullptr; 29 | websocket::stream ws_; 30 | net::strand strand_; 31 | 32 | net::awaitable handle_asr_text(std::string); 33 | net::awaitable handle_text(beast::flat_buffer &buffer); 34 | net::awaitable handle_binary(beast::flat_buffer &buffer); 35 | net::awaitable send_welcome(); 36 | 37 | void push_llm_response(std::string str); 38 | net::awaitable asr_loop(); 39 | net::awaitable tts_loop(); 40 | net::awaitable handle(); 41 | public: 42 | Connection(std::shared_ptr setting, websocket::stream ws, net::any_io_executor executor); 43 | ~Connection(); 44 | void start(); 45 | }; 46 | } -------------------------------------------------------------------------------- /include/xz-cpp-server/llm/base.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace xiaozhi { 4 | namespace llm { 5 | class Base { 6 | public: 7 | virtual ~Base(); 8 | virtual net::awaitable create_session() = 0; 9 | virtual net::awaitable response(const boost::json::array& dialogue, const std::function& callback) = 0; 10 | }; 11 | std::unique_ptr createLLM(const net::any_io_executor& executor); 12 | } 13 | } 
-------------------------------------------------------------------------------- /include/xz-cpp-server/llm/cozev3.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "base.h" 3 | 4 | namespace xiaozhi { 5 | namespace llm { 6 | class CozeV3: public Base { 7 | public: 8 | CozeV3(const net::any_io_executor& executor, const YAML::Node& config); 9 | ~CozeV3(); 10 | net::awaitable create_session() override; 11 | net::awaitable response(const boost::json::array& dialogue, const std::function& callback) override; 12 | private: 13 | class Impl; 14 | std::unique_ptr impl_; 15 | }; 16 | } 17 | } -------------------------------------------------------------------------------- /include/xz-cpp-server/llm/dify.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "base.h" 3 | 4 | namespace xiaozhi { 5 | namespace llm { 6 | class Dify: public Base { 7 | public: 8 | Dify(const net::any_io_executor &executor, const YAML::Node& config); 9 | ~Dify(); 10 | net::awaitable create_session() override; 11 | net::awaitable response(const boost::json::array& dialogue, const std::function& callback) override; 12 | private: 13 | class Impl; 14 | std::unique_ptr impl_; 15 | }; 16 | } 17 | } -------------------------------------------------------------------------------- /include/xz-cpp-server/llm/openai.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "base.h" 3 | 4 | namespace xiaozhi { 5 | namespace llm { 6 | class Openai: public Base { 7 | public: 8 | Openai(const net::any_io_executor &executor, const YAML::Node& config); 9 | ~Openai(); 10 | net::awaitable create_session() override; 11 | net::awaitable response(const boost::json::array& dialogue, const std::function& callback) override; 12 | private: 13 | class Impl; 14 | std::unique_ptr impl_; 15 | }; 16 | } 17 | } 
-------------------------------------------------------------------------------- /include/xz-cpp-server/precomp.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "yaml-cpp/yaml.h" 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | namespace net = boost::asio; 12 | namespace beast = boost::beast; 13 | namespace json = boost::json; 14 | namespace http = beast::http; 15 | namespace ssl = net::ssl; 16 | namespace websocket = beast::websocket; 17 | using tcp = net::ip::tcp; -------------------------------------------------------------------------------- /include/xz-cpp-server/server.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace xiaozhi { 5 | class Server { 6 | private: 7 | net::io_context ioc; 8 | std::shared_ptr setting; 9 | net::awaitable listen(net::ip::tcp::endpoint endpoint); 10 | net::awaitable run_session(websocket::stream ws); 11 | net::awaitable authenticate(websocket::stream &ws, http::request &req); 12 | public: 13 | Server(std::shared_ptr setting); 14 | void run(); 15 | }; 16 | } -------------------------------------------------------------------------------- /include/xz-cpp-server/silero_vad/silero_vad.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "onnxruntime_cxx_api.h" 3 | 4 | namespace xiaozhi { 5 | class VoiceDetector { 6 | private: 7 | std::shared_ptr session_ = nullptr; 8 | Ort::MemoryInfo memory_info_; 9 | const int window_size_ = 512; // Silero-VAD 窗口大小 10 | std::vector state_; // Silero-VAD 隐藏状态 11 | int64_t sample_rate_ = 16000; 12 | public: 13 | VoiceDetector(); 14 | ~VoiceDetector(); 15 | float predict(std::vector &pcm_data); 16 | }; 17 | class SileroVad { 18 | public: 19 | static std::shared_ptr getVad(); 20 | // ~Vad(); 21 | float predict(std::vector &pcm_data); 22 | 23 | private: 
24 | VoiceDetector detector; 25 | SileroVad(); 26 | }; 27 | } -------------------------------------------------------------------------------- /include/xz-cpp-server/silero_vad/vad.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace xiaozhi { 6 | class Vad { 7 | private: 8 | OpusDecoder* decoder_; 9 | float threshold_; 10 | public: 11 | Vad(std::shared_ptr setting); 12 | ~Vad(); 13 | bool is_vad(beast::flat_buffer &buffer); 14 | }; 15 | } -------------------------------------------------------------------------------- /include/xz-cpp-server/tts/base.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace xiaozhi { 4 | namespace tts { 5 | class Base { 6 | public: 7 | virtual ~Base(); 8 | virtual net::awaitable>> text_to_speak(const std::string& text) = 0; 9 | }; 10 | std::unique_ptr createTTS(const net::any_io_executor& executor); 11 | } 12 | } -------------------------------------------------------------------------------- /include/xz-cpp-server/tts/bytedancev3.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "base.h" 3 | 4 | namespace xiaozhi { 5 | namespace tts { 6 | class BytedanceV3: public Base { 7 | public: 8 | BytedanceV3(const net::any_io_executor& executor, const YAML::Node& config, int sample_rate); 9 | ~BytedanceV3(); 10 | net::awaitable>> text_to_speak(const std::string& text) override; 11 | private: 12 | class Impl; 13 | std::unique_ptr impl_; 14 | }; 15 | } 16 | } -------------------------------------------------------------------------------- /include/xz-cpp-server/tts/edge.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "base.h" 3 | 4 | namespace xiaozhi { 5 | namespace tts { 6 | class Edge: public Base { 7 | public: 8 | Edge(const net::any_io_executor& executor, 
const YAML::Node& config, int sample_rate); 9 | ~Edge(); 10 | net::awaitable>> text_to_speak(const std::string& text) override; 11 | private: 12 | class Impl; 13 | std::unique_ptr impl_; 14 | }; 15 | } 16 | } -------------------------------------------------------------------------------- /models/silero_vad.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daxpot/xiaozhi-cpp-server/eae257d6e94594b2f4112ed94c33ba4adf33869c/models/silero_vad.onnx -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 先运行scripts/install_deps.sh安装依赖 3 | # 如果没有传入参数 $1,则默认使用 web_server 4 | if [ -z "$1" ]; then 5 | TARGET="web_server" 6 | else 7 | TARGET="$1" 8 | fi 9 | 10 | git submodule update --init --recursive 11 | mkdir -p build 12 | cd build || { echo "Error: Cannot enter build directory"; exit 1; } 13 | cmake -DKALDI_NATIVE_FBANK_BUILD_TESTS=OFF -DKALDI_NATIVE_FBANK_BUILD_PYTHON=OFF .. 14 | 15 | TARGET_NAME=$(basename "$TARGET") 16 | 17 | make "$TARGET_NAME" || { echo "Error: Make failed"; exit 1; } 18 | cd .. 
19 | ./build/apps/"$TARGET" || { echo "Error: Failed to run $TARGET_NAME"; exit 1; } -------------------------------------------------------------------------------- /scripts/install_deps.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | apt-get update && apt-get install -y \ 3 | libboost-all-dev \ 4 | openssl \ 5 | libssl-dev \ 6 | libopus-dev \ 7 | libogg-dev -------------------------------------------------------------------------------- /scripts/tests/concurr_ws_client.py: -------------------------------------------------------------------------------- 1 | import websockets 2 | import asyncio 3 | import opuslib_next 4 | from pydub import AudioSegment 5 | import os 6 | import numpy as np 7 | import json 8 | import time 9 | 10 | ws_url = "ws://127.0.0.1:8000" 11 | 12 | def wav_to_opus_data(wav_file_path): 13 | # 使用pydub加载PCM文件 14 | # 获取文件后缀名 15 | file_type = os.path.splitext(wav_file_path)[1] 16 | if file_type: 17 | file_type = file_type.lstrip('.') 18 | audio = AudioSegment.from_file(wav_file_path, format=file_type) 19 | 20 | duration = len(audio) / 1000.0 21 | 22 | # 转换为单声道和16kHz采样率(确保与编码器匹配) 23 | audio = audio.set_channels(1).set_frame_rate(16000) 24 | 25 | # 获取原始PCM数据(16位小端) 26 | raw_data = audio.raw_data 27 | 28 | # 初始化Opus编码器 29 | encoder = opuslib_next.Encoder(16000, 1, opuslib_next.APPLICATION_AUDIO) 30 | 31 | # 编码参数 32 | frame_duration = 60 # 60ms per frame 33 | frame_size = int(16000 * frame_duration / 1000) # 960 samples/frame 34 | 35 | opus_datas = [] 36 | # 按帧处理所有音频数据(包括最后一帧可能补零) 37 | for i in range(0, len(raw_data), frame_size * 2): # 16bit=2bytes/sample 38 | # 获取当前帧的二进制数据 39 | chunk = raw_data[i:i + frame_size * 2] 40 | 41 | # 如果最后一帧不足,补零 42 | if len(chunk) < frame_size * 2: 43 | chunk += b'\x00' * (frame_size * 2 - len(chunk)) 44 | 45 | # 转换为numpy数组处理 46 | np_frame = np.frombuffer(chunk, dtype=np.int16) 47 | 48 | # 编码Opus数据 49 | opus_data = encoder.encode(np_frame.tobytes(), 
frame_size) 50 | opus_datas.append(opus_data) 51 | 52 | return opus_datas, duration 53 | 54 | 55 | async def connect(i, opus_data): 56 | async with websockets.connect(ws_url, additional_headers={"Authorization": "Bearer test-token", "Device-Id": "test-device"}) as ws: 57 | await ws.send('{"type":"hello","version": 1,"transport":"websocket","audio_params":{"format":"opus", "sample_rate":16000, "channels":1, "frame_duration":60}}') 58 | ret = await ws.recv() 59 | print(i, ret) 60 | for data in opus_data[0]: 61 | await ws.send(data, text=False) 62 | # return 63 | start = time.time() 64 | while True: 65 | ret = await ws.recv() 66 | if isinstance(ret, bytes): 67 | # print(i, len(ret)) 68 | pass 69 | else: 70 | print(i, time.time() - start, ret) 71 | rej = json.loads(ret) 72 | if rej["type"] == "tts" and rej["state"] == "stop": 73 | break 74 | # if rej.get("state") == "sentence_start": 75 | # break 76 | 77 | async def main(): 78 | opus_data = wav_to_opus_data("tmp/example.wav") 79 | loop = asyncio.get_event_loop() 80 | futs = [] 81 | for i in range(0, 10): 82 | fut = loop.create_task(connect(i, opus_data)) 83 | futs.append(fut) 84 | for fut in futs: 85 | await fut 86 | 87 | if __name__ == "__main__": 88 | asyncio.run(main()) -------------------------------------------------------------------------------- /scripts/tests/connect.py: -------------------------------------------------------------------------------- 1 | import websockets 2 | import asyncio 3 | 4 | ws_url = "ws://127.0.0.1:8000" 5 | 6 | async def connect(i): 7 | async with websockets.connect(ws_url, additional_headers={"Authorization": "Bearer test-token", "Device-Id": "test-device"}) as ws: 8 | await ws.send('{"type":"hello","version": 1,"transport":"websocket","audio_params":{"format":"opus", "sample_rate":16000, "channels":1, "frame_duration":60}}') 9 | ret = await ws.recv() 10 | print(i, ret) 11 | 12 | async def main(): 13 | loop = asyncio.get_event_loop() 14 | futs = [] 15 | for i in range(0, 1): 16 | fut 
= loop.create_task(connect(i)) 17 | futs.append(fut) 18 | for fut in futs: 19 | await fut 20 | 21 | if __name__ == "__main__": 22 | asyncio.run(main()) -------------------------------------------------------------------------------- /scripts/tests/https_echo.py: -------------------------------------------------------------------------------- 1 | import ssl 2 | from aiohttp import web 3 | import json 4 | 5 | async def handle(request: web.Request): 6 | debug_info = { 7 | "method": request.method, 8 | "path": request.path, 9 | "query": dict(request.query), # 查询参数 10 | "headers": dict(request.headers), # 请求头部 11 | } 12 | print(request.method, request.path, dict(request.query)) 13 | print(dict(request.headers)) 14 | 15 | # 尝试读取请求体(如果是 POST/PUT 等) 16 | body = None 17 | if request.can_read_body: 18 | try: 19 | body = await request.json() # 尝试解析为 JSON 20 | except json.JSONDecodeError: 21 | body = await request.text() # 如果不是 JSON,则返回原始文本 22 | debug_info["body"] = body 23 | print(body) 24 | # 格式化调试信息 25 | response_text = json.dumps(debug_info, indent=2, ensure_ascii=False) 26 | 27 | # 返回调试信息 28 | return web.Response( 29 | text=response_text, 30 | content_type="application/json", 31 | charset="utf-8" 32 | ) 33 | 34 | # 创建应用并设置路由 35 | app = web.Application() 36 | app.router.add_route("*", "/{path:.*}", handle) 37 | 38 | # 配置 SSL 上下文 39 | ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) 40 | ssl_context.load_cert_chain(certfile='tmp/server.crt', keyfile='tmp/server.key') 41 | 42 | # 运行服务器 43 | if __name__ == '__main__': 44 | web.run_app(app, host='localhost', port=8002, ssl_context=ssl_context) -------------------------------------------------------------------------------- /scripts/tests/vad.py: -------------------------------------------------------------------------------- 1 | import onnx 2 | 3 | model = onnx.load("models/silero_vad.onnx") 4 | for input in model.graph.input: 5 | print("input", input.name, input.type) 6 | 7 | 8 | for output in model.graph.output: 9 | 
print("output", output.name, output.type) -------------------------------------------------------------------------------- /scripts/tests/wss_echo.py: -------------------------------------------------------------------------------- 1 | import websockets 2 | import asyncio 3 | import ssl 4 | import pathlib 5 | 6 | async def handle(conn: websockets.ServerConnection): 7 | print("handle header") 8 | print(conn.request.path) 9 | print(conn.request.headers) 10 | while True: 11 | data = await conn.recv() 12 | print("read", data.encode("utf-8")) 13 | # ret = await conn.send(data) 14 | # print("write", ret) 15 | 16 | async def main(): 17 | ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) 18 | ssl_context.load_cert_chain("tmp/server.crt", "tmp/server.key") 19 | async with websockets.serve( 20 | handle, 21 | "127.0.0.1", 22 | "8000", 23 | ssl=ssl_context 24 | ): 25 | await asyncio.Future() 26 | 27 | if __name__ == "__main__": 28 | asyncio.run(main()) -------------------------------------------------------------------------------- /src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(common) 2 | add_subdirectory(paraformer) 3 | add_subdirectory(asr) 4 | add_subdirectory(tts) 5 | add_subdirectory(llm) 6 | add_subdirectory(silero_vad) 7 | 8 | 9 | #server library 10 | add_library(server connection.cpp server.cpp) 11 | target_link_libraries(server PUBLIC precomp) 12 | target_link_libraries(server PUBLIC common) 13 | target_link_libraries(server PUBLIC vad) 14 | target_link_libraries(server PUBLIC asr) 15 | target_link_libraries(server PUBLIC tts) 16 | target_link_libraries(server PUBLIC llm) 17 | 18 | -------------------------------------------------------------------------------- /src/asr/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB_RECURSE cpp_files *.cpp) 2 | add_library(asr ${cpp_files}) 3 | 4 | target_link_libraries(asr PRIVATE precomp) 5 
| target_link_libraries(asr PRIVATE common) 6 | target_link_libraries(asr PRIVATE OpenSSL::SSL) 7 | target_link_libraries(asr PRIVATE ${OGG_LIBRARY}) 8 | target_link_libraries(asr PRIVATE ${OPUS_LIBRARY}) 9 | target_link_libraries(asr PRIVATE paraformer) 10 | -------------------------------------------------------------------------------- /src/asr/base.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | namespace xiaozhi { 7 | namespace asr { 8 | std::unique_ptr createASR(const net::any_io_executor& executor) { 9 | auto setting = Setting::getSetting(); 10 | auto selected_module = setting->config["selected_module"]["ASR"].as(); 11 | if(selected_module == "BytedanceASRV2") { 12 | return std::make_unique(executor, setting->config["ASR"][selected_module]); 13 | } else if(selected_module == "Paraformer") { 14 | return std::make_unique(executor, setting->config["ASR"][selected_module]); 15 | } else { 16 | throw std::invalid_argument("Selected_module ASR not be supported"); 17 | } 18 | } 19 | Base::~Base() { 20 | BOOST_LOG_TRIVIAL(debug) << "ASR base destroyed"; 21 | } 22 | } 23 | } -------------------------------------------------------------------------------- /src/asr/paraformer.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace xiaozhi { 9 | namespace asr { 10 | class ParaformerSingleton { 11 | public: 12 | static funasr::Model* get_instance(const YAML::Node& config) { 13 | std::call_once(init_flag_, [&config]() { 14 | std::map model_path; 15 | model_path[MODEL_DIR] = config["model_dir"].as(); // 替换为你的模型目录 16 | model_path[QUANTIZE] = config["quantize"].as(); 17 | int thread_num = config["thread_num"].as(); 18 | instance_ = funasr::CreateModel(model_path, thread_num); 19 | }); 20 | return instance_; 21 | } 22 | ParaformerSingleton(const 
ParaformerSingleton&) = delete; 23 | ParaformerSingleton& operator=(const ParaformerSingleton&) = delete; 24 | 25 | private: 26 | ParaformerSingleton() = default; // 私有构造函数 27 | static funasr::Model* instance_; 28 | static std::once_flag init_flag_; 29 | }; 30 | 31 | funasr::Model* ParaformerSingleton::instance_ = nullptr; 32 | std::once_flag ParaformerSingleton::init_flag_; 33 | 34 | class Paraformer::Impl { 35 | private: 36 | 37 | std::vector pcm_; 38 | int decoded_length_ = 0; 39 | std::string full_result_; 40 | 41 | funasr::ParaformerOnline* online_model; 42 | OpusDecoder* decoder_; 43 | 44 | public: 45 | Impl(const net::any_io_executor& executor, const YAML::Node& config) { 46 | auto offline_model = ParaformerSingleton::get_instance(config); 47 | online_model = static_cast(funasr::CreateModel(offline_model, {5, 10, 5})); 48 | int error; 49 | decoder_ = opus_decoder_create(16000, 1, &error); 50 | if (error != OPUS_OK) throw std::runtime_error("Paraformer Opus 解码器初始化失败"); 51 | pcm_.reserve(960*11); 52 | } 53 | 54 | 55 | ~Impl() { 56 | BOOST_LOG_TRIVIAL(debug) << "Paraformer asr destroyed"; 57 | } 58 | 59 | net::awaitable detect_opus(const std::optional& buf) { 60 | if(buf) { 61 | int decoded_samples = opus_decode_float(decoder_, 62 | static_cast(buf->data().data()), 63 | buf->size(), pcm_.data() + decoded_length_, 960, 0); 64 | if(decoded_samples < 0) { 65 | BOOST_LOG_TRIVIAL(error) << "Paraformer opus 解码失败:" << opus_strerror(decoded_samples); 66 | co_return ""; 67 | } 68 | decoded_length_ += decoded_samples; 69 | } 70 | if(decoded_length_ >= 9600 || !buf) { 71 | std::string result = online_model->Forward(pcm_.data(), decoded_length_, !buf); 72 | full_result_ += result; 73 | if(!result.empty()) { 74 | BOOST_LOG_TRIVIAL(debug) << "Paraformer detect asr:" << result << ",pcm length:" << decoded_length_; 75 | } 76 | decoded_length_ = 0; 77 | } 78 | if(!buf) { 79 | co_return std::move(full_result_); 80 | } 81 | co_return ""; 82 | } 83 | }; 84 | 85 | 
Paraformer::Paraformer(const net::any_io_executor& executor, const YAML::Node& config) { 86 | impl_ = std::make_unique(executor, config); 87 | } 88 | 89 | Paraformer::~Paraformer() { 90 | 91 | } 92 | 93 | net::awaitable Paraformer::detect_opus(const std::optional& buf) { 94 | co_return co_await impl_->detect_opus(buf); 95 | } 96 | } 97 | } -------------------------------------------------------------------------------- /src/common/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB_RECURSE cpp_files *.cpp) 2 | add_library(common ${cpp_files}) 3 | 4 | target_link_libraries(common PRIVATE precomp) 5 | target_link_libraries(common PRIVATE OpenSSL::SSL) 6 | target_link_libraries(common PRIVATE ${OPUS_LIBRARY}) 7 | target_link_libraries(common PRIVATE Boost::iostreams) 8 | target_link_libraries(common PRIVATE Boost::random) 9 | target_link_libraries(common PRIVATE Boost::log_setup) -------------------------------------------------------------------------------- /src/common/logger.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | // 自定义格式化函数,根据日志级别设置颜色 8 | void color_formatter(boost::log::record_view const& rec, boost::log::formatting_ostream& strm) { 9 | auto severity = rec[boost::log::trivial::severity]; 10 | if (severity) { 11 | // 根据级别设置颜色 12 | if (severity.get() == boost::log::trivial::info) { 13 | strm << "\033[32m"; // 绿色 14 | } else if (severity.get() == boost::log::trivial::error) { 15 | strm << "\033[31m"; // 红色 16 | } else { 17 | strm << "\033[0m"; // 默认颜色(无色) 18 | } 19 | // 提取 TimeStamp 的值并格式化 20 | auto timestamp = rec["TimeStamp"]; 21 | if (timestamp) { 22 | boost::posix_time::ptime time = timestamp.extract().get(); 23 | strm << "[" << boost::posix_time::to_simple_string(time) << "] "; 24 | } 25 | // 输出格式 26 | strm << "[" << severity << "] " 27 | << "\033[0m" 28 | << 
rec[boost::log::expressions::smessage]; 29 | } 30 | } 31 | 32 | void init_logging(std::string log_level) { 33 | // 设置日志输出到终端 34 | auto console_sink = boost::log::add_console_log(std::cout); 35 | // 设置自定义格式化函数 36 | console_sink->set_formatter(&color_formatter); 37 | 38 | // 设置日志输出到文件 39 | boost::log::add_file_log( 40 | boost::log::keywords::file_name = "tmp/server_%N.log", // 文件名模式,%N 为文件序号 41 | boost::log::keywords::rotation_size = 10 * 1024 * 1024, // 10MB 轮换 42 | boost::log::keywords::format = "[%TimeStamp%] [%Severity%]: %Message%" // 日志格式 43 | ); 44 | 45 | // 添加常用属性,如时间戳 46 | boost::log::add_common_attributes(); 47 | auto level = boost::log::trivial::info; 48 | if(log_level == "DEBUG") { 49 | level = boost::log::trivial::debug; 50 | } else if(log_level == "ERROR") { 51 | level = boost::log::trivial::error; 52 | } else if(log_level == "WARNING") { 53 | level = boost::log::trivial::warning; 54 | } 55 | boost::log::core::get()->set_filter( 56 | boost::log::trivial::severity >= level 57 | ); 58 | } -------------------------------------------------------------------------------- /src/common/request.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | namespace request { 4 | UrlInfo parse_url(const std::string& url) { 5 | UrlInfo info{false, "", "", "/"}; 6 | std::regex url_regex(R"(^(https?)://([^:/]+)(?::(\d+))?(/.*)?$)"); 7 | std::smatch matches; 8 | if (std::regex_match(url, matches, url_regex)) { 9 | info.is_https = (matches[1] == "https"); 10 | info.host = matches[2]; 11 | info.port = matches[3].matched ? std::string(matches[3]) : (info.is_https ? "443" : "80"); 12 | info.path = matches[4].matched ? 
std::string(matches[4]) : "/"; 13 | } else { 14 | throw std::invalid_argument("Invalid URL format"); 15 | } 16 | 17 | return info; 18 | } 19 | static ssl::context get_ssl_context() { 20 | ssl::context ctx{ssl::context::sslv23_client}; 21 | ctx.set_verify_mode(ssl::verify_peer); 22 | ctx.set_default_verify_paths(); 23 | return ctx; 24 | } 25 | 26 | net::awaitable> connect(const UrlInfo& url_info) { 27 | auto executor = co_await net::this_coro::executor; 28 | auto ctx = get_ssl_context(); 29 | tcp::resolver resolver{executor}; 30 | ssl::stream stream{executor, ctx}; 31 | 32 | if(url_info.is_https && !SSL_set_tlsext_host_name(stream.native_handle(), url_info.host.c_str())) { 33 | throw boost::system::system_error( 34 | static_cast(::ERR_get_error()), 35 | net::error::get_ssl_category()); 36 | } 37 | auto const results = co_await resolver.async_resolve(url_info.host, url_info.port, net::use_awaitable); 38 | beast::get_lowest_layer(stream).expires_after(std::chrono::seconds(30)); 39 | co_await beast::get_lowest_layer(stream).async_connect(results, net::use_awaitable); 40 | // Set the timeout. 41 | if(url_info.is_https) { 42 | beast::get_lowest_layer(stream).expires_after(std::chrono::seconds(30)); 43 | co_await stream.async_handshake(ssl::stream_base::client, net::use_awaitable); 44 | } 45 | co_return std::move(stream); 46 | } 47 | net::awaitable send(ssl::stream& stream, const http::verb method, const UrlInfo& url_info, const json::value& header, const std::string& data) { 48 | http::request req{ method, url_info.path, 11 }; 49 | req.set(http::field::host, url_info.host); 50 | if(header.is_object()) { 51 | for(const auto& item : header.as_object()) { 52 | req.set(item.key(), item.value().as_string()); 53 | } 54 | } 55 | if(method == http::verb::post) { 56 | req.body() = data; 57 | req.prepare_payload(); 58 | } 59 | 60 | // Set the timeout. 
61 | beast::get_lowest_layer(stream).expires_after(std::chrono::seconds(30)); 62 | 63 | if(url_info.is_https) { 64 | co_await http::async_write(stream, req, net::use_awaitable); 65 | } else { 66 | co_await http::async_write(stream.next_layer(), req, net::use_awaitable); 67 | } 68 | co_return; 69 | } 70 | 71 | net::awaitable close(ssl::stream& stream, bool is_https) { 72 | if(is_https) { 73 | auto [ec] = co_await stream.async_shutdown(net::as_tuple(net::use_awaitable)); 74 | if(ec && ec != net::ssl::error::stream_truncated) 75 | BOOST_LOG_TRIVIAL(error) << "Request shutdown error:" << ec.message(); 76 | } else { 77 | beast::error_code ec; 78 | auto r = stream.next_layer().socket().shutdown(net::ip::tcp::socket::shutdown_both, ec); 79 | if(ec && ec != beast::errc::not_connected) 80 | BOOST_LOG_TRIVIAL(error) << "Request shutdown error:" << ec.message(); 81 | } 82 | } 83 | 84 | net::awaitable request(const http::verb method, const std::string& url, const json::value& header, const std::string& data) { 85 | auto url_info = parse_url(url); 86 | auto stream = co_await connect(url_info); 87 | co_await send(stream, method, url_info, header, data); 88 | // This buffer is used for reading and must be persisted 89 | beast::flat_buffer buffer; 90 | 91 | // Declare a container to hold the response 92 | http::response res; 93 | url_info.is_https 94 | ? 
co_await http::async_read(stream, buffer, res, net::use_awaitable) 95 | : co_await http::async_read(stream.next_layer(), buffer, res, net::use_awaitable); 96 | 97 | co_await close(stream, url_info.is_https); 98 | 99 | std::string ret; 100 | ret.reserve(res.body().size()); 101 | for(auto buf : res.body().data()) { 102 | ret.append(static_cast(buf.data()), buf.size()); 103 | } 104 | co_return ret; 105 | } 106 | 107 | net::awaitable get(const std::string& url, const json::value& header) { 108 | co_return co_await request(http::verb::get, url, header); 109 | } 110 | 111 | net::awaitable post(const std::string& url, const json::value& header, const std::string& data) { 112 | co_return co_await request(http::verb::post, url, header, data); 113 | } 114 | 115 | static std::tuple parse_chunk(beast::flat_buffer& buffer) { 116 | std::string body; 117 | std::string_view view(static_cast(buffer.data().data()), buffer.data().size()); 118 | size_t pos = 0; 119 | if(view.starts_with("HTTP")) { 120 | auto header_end = view.find("\r\n\r\n"); 121 | if(header_end == std::string_view::npos) { //header数据不完整 122 | return {body, false}; 123 | } 124 | pos = header_end + 4; 125 | } 126 | bool is_over = false; 127 | while (pos < view.size()) { 128 | size_t size_end = view.find("\r\n", pos); 129 | if (size_end == std::string_view::npos) break; 130 | 131 | std::string size_str(view.substr(pos, size_end - pos)); 132 | size_t chunk_size = std::stoul(size_str, nullptr, 16); 133 | pos = size_end + 2; 134 | if (chunk_size == 0) { 135 | is_over = true; 136 | break; 137 | } 138 | 139 | if (pos + chunk_size > view.size()) break; 140 | body.append(view.substr(pos, chunk_size)); 141 | pos += chunk_size + 2; 142 | } 143 | 144 | buffer.consume(pos); 145 | 146 | return {body, is_over}; 147 | } 148 | 149 | net::awaitable stream_post(const std::string& url, const json::value& header, const std::string& data, const std::function& callback) { 150 | auto url_info = parse_url(url); 151 | auto stream = co_await 
connect(url_info); 152 | co_await send(stream, http::verb::post, url_info, header, data); 153 | 154 | beast::flat_buffer buffer; 155 | 156 | while (true) { 157 | beast::get_lowest_layer(stream).expires_after(std::chrono::seconds(30)); 158 | auto [ec, bytes_transferred] = url_info.is_https 159 | ? co_await stream.async_read_some(buffer.prepare(8192), net::as_tuple(net::use_awaitable)) 160 | : co_await stream.next_layer().async_read_some( buffer.prepare(8192), net::as_tuple(net::use_awaitable)); 161 | if(ec) { 162 | BOOST_LOG_TRIVIAL(error) << "Stream request read some error:" << ec.message(); 163 | break; 164 | } 165 | buffer.commit(bytes_transferred); 166 | auto [chunk, is_over] = parse_chunk(buffer); 167 | if(chunk.size() > 0) { 168 | callback(std::move(chunk)); 169 | } 170 | if(is_over) 171 | break; 172 | } 173 | co_await close(stream, url_info.is_https); 174 | } 175 | } -------------------------------------------------------------------------------- /src/common/setting.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | static std::shared_ptr setting = nullptr; 4 | static std::once_flag settingFlag; 5 | 6 | std::shared_ptr xiaozhi::Setting::getSetting(const char* path) { 7 | std::call_once(settingFlag, [&] { 8 | setting = std::shared_ptr(new Setting(path)); 9 | }); 10 | return setting; 11 | } 12 | 13 | xiaozhi::Setting::Setting(const char* path) { 14 | config = YAML::LoadFile(path == 0 ? "config.yaml" : path); 15 | } -------------------------------------------------------------------------------- /src/common/tools.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | // 常见的中文标点符号(UTF-8 编码) 11 | static const std::unordered_set chineseSegments { 12 | 0xE38081, // 、 13 | 0xEFBC8C, // , 14 | 0xE38082, // 。 15 | 0xEFBC81, // ! 16 | 0xEFBC9F, // ? 
17 | 0xEFBC9B, // ; 18 | 0xEFBC9A // : 19 | }; 20 | 21 | namespace tools { 22 | SegmentRet is_segment(const std::string& str, std::string::size_type pos) { 23 | unsigned char byte = static_cast(str[pos]); 24 | 25 | // 单字节英文标点符号 26 | if (byte < 0x80) { 27 | return std::string(",.!?;:").find(byte) == std::string::npos 28 | ? SegmentRet::NONE 29 | : SegmentRet::EN; 30 | } 31 | 32 | // 多字节中文标点符号(UTF-8 编码) 33 | if (byte >= 0xE0) { 34 | // 将 3 个字节组合成一个 uint32_t(只用低 24 位) 35 | uint32_t seq = (static_cast(static_cast(str[pos])) << 16) | 36 | (static_cast(static_cast(str[pos + 1])) << 8) | 37 | static_cast(static_cast(str[pos + 2])); 38 | if(chineseSegments.contains(seq)) { 39 | return SegmentRet::CHINESE; 40 | } 41 | } 42 | return SegmentRet::NONE; 43 | } 44 | 45 | std::string::size_type find_last_segment(const std::string& input) { 46 | if (input.empty()) { 47 | return std::string::npos; 48 | } 49 | 50 | // 从后往前遍历字节 51 | for (std::string::size_type i = input.size() - 1; i != std::string::npos; --i) { 52 | auto ret = is_segment(input, i); 53 | if(ret == SegmentRet::EN) { 54 | return i; 55 | } else if(ret == SegmentRet::CHINESE) { 56 | return i+2; 57 | } 58 | } 59 | 60 | return std::string::npos; 61 | } 62 | 63 | std::string generate_uuid() { 64 | auto uuid = boost::uuids::random_generator()(); 65 | return boost::uuids::to_string(uuid); 66 | } 67 | 68 | std::string gzip_compress(const std::string &data) { 69 | std::stringstream compressed; 70 | std::stringstream origin(data); 71 | boost::iostreams::filtering_streambuf out; 72 | out.push(boost::iostreams::gzip_compressor()); 73 | out.push(origin); 74 | boost::iostreams::copy(out, compressed); 75 | return compressed.str(); 76 | } 77 | 78 | std::string gzip_decompress(const std::string &data) { 79 | std::stringstream compressed(data); 80 | std::stringstream decompressed; 81 | 82 | boost::iostreams::filtering_streambuf out; 83 | out.push(boost::iostreams::gzip_decompressor()); 84 | out.push(compressed); 85 | 
boost::iostreams::copy(out, decompressed); 86 | 87 | return decompressed.str(); 88 | } 89 | 90 | long long get_tms() { 91 | auto now = std::chrono::system_clock::now(); 92 | auto ms = std::chrono::duration_cast(now.time_since_epoch()); 93 | return ms.count(); 94 | } 95 | 96 | std::tuple create_opus_coders(int sample_rate, bool create_encoder, bool create_decoder) { 97 | int error; 98 | OpusEncoder* encoder = nullptr; 99 | OpusDecoder* decoder = nullptr; 100 | if(create_encoder) { 101 | encoder = opus_encoder_create(sample_rate, 1, OPUS_APPLICATION_AUDIO, &error); 102 | if (error != OPUS_OK) { 103 | BOOST_LOG_TRIVIAL(error) << "Failed to create opus encoder:" << opus_strerror(error); 104 | return {nullptr, nullptr}; 105 | } 106 | } 107 | if(create_decoder) { 108 | decoder = opus_decoder_create(sample_rate, 1, &error); 109 | if (error != OPUS_OK) { 110 | BOOST_LOG_TRIVIAL(error) << "Failed to create opus decoder:" << opus_strerror(error); 111 | if(encoder) 112 | opus_encoder_destroy(encoder); 113 | return {nullptr, nullptr}; 114 | } 115 | } 116 | return {encoder, decoder}; 117 | } 118 | 119 | void on_spawn_complete(std::string_view title, std::exception_ptr e) { 120 | if(e) { 121 | try { 122 | std::rethrow_exception(e); 123 | } catch(std::exception& e) { 124 | BOOST_LOG_TRIVIAL(error) << title << " spawn error:" << e.what(); 125 | } 126 | } 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /src/connection.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | namespace xiaozhi { 5 | Connection::Connection(std::shared_ptr setting, websocket::stream ws, net::any_io_executor executor): 6 | setting_(setting), 7 | cmd_exit_(setting->config["CMD_exit"].as>()), 8 | vad_(setting), 9 | executor_(executor), 10 | ws_(std::move(ws)), 11 | strand_(ws_.get_executor()), 12 | silence_timer_(executor_) { 13 | auto prompt = setting->config["prompt"].as(); 14 | size_t 
pos = prompt.find("{date_time}"); 15 | if(pos != prompt.npos) { 16 | auto now = std::chrono::system_clock::now(); 17 | prompt.replace(pos, 11, std::format("{:%Y-%m-%d %H:%M:%S}", std::chrono::time_point_cast(now))); 18 | } 19 | 20 | dialogue_.push_back(boost::json::object{{"role", "system"}, {"content", std::move(prompt)}}); 21 | min_silence_tms_ = setting->config["VAD"]["SileroVAD"]["min_silence_duration_ms"].as(); 22 | close_connection_no_voice_time_ = setting->config["close_connection_no_voice_time"].as(), 23 | asr_ = asr::createASR(executor_); 24 | llm_ = llm::createLLM(executor_); 25 | tts_ = tts::createTTS(executor_); 26 | } 27 | 28 | void Connection::start() { 29 | auto self = shared_from_this(); 30 | net::co_spawn(executor_, [self] { 31 | return self->asr_loop(); 32 | }, std::bind_front(tools::on_spawn_complete, "Connection asr loop")); 33 | net::co_spawn(strand_, [self] { 34 | return self->tts_loop(); 35 | }, std::bind_front(tools::on_spawn_complete, "Connection tts loop")); 36 | net::co_spawn(strand_, [self] { 37 | return self->handle(); 38 | }, std::bind_front(tools::on_spawn_complete, "Connection handle")); 39 | } 40 | 41 | net::awaitable Connection::asr_loop() { 42 | while(!is_released_) { 43 | std::optional buf; 44 | if(!asr_audio_.try_pop(buf)) { 45 | co_await net::steady_timer(executor_, std::chrono::milliseconds(20)).async_wait(net::use_awaitable); 46 | continue; 47 | } 48 | auto text = co_await asr_->detect_opus(buf); 49 | if(!buf && text.size() > 0) { 50 | net::co_spawn(executor_, [self=shared_from_this(), text=std::move(text)] { 51 | return self->handle_asr_text(std::move(text)); 52 | }, std::bind_front(tools::on_spawn_complete, "Connection handle asr text")); 53 | } 54 | } 55 | BOOST_LOG_TRIVIAL(info) << "Connection asr loop over"; 56 | } 57 | 58 | net::awaitable Connection::tts_loop() { 59 | long long tts_stop_end_timestamp = 0; 60 | std::queue> tts_sentence_queue; 61 | while(!is_released_) { 62 | std::string text; 63 | auto now = 
tools::get_tms(); 64 | if(!llm_response_.try_pop(text)) { 65 | net::steady_timer timer(executor_, std::chrono::milliseconds(60)); 66 | co_await timer.async_wait(net::use_awaitable); 67 | } else { 68 | BOOST_LOG_TRIVIAL(info) << "获取大模型输出:" << text; 69 | if(text == "") { 70 | std::queue>().swap(tts_sentence_queue); 71 | const std::string_view data = R"({"type":"tts","state":"stop"})"; 72 | co_await ws_.async_write(net::buffer(data.data(), data.size()), net::use_awaitable); 73 | tts_stop_end_timestamp = 0; 74 | continue; 75 | } 76 | if(tts_stop_end_timestamp == 0) { 77 | ws_.text(true); 78 | const std::string_view data = R"({"type":"tts","state":"start"})"; 79 | co_await ws_.async_write(net::buffer(data.data(), data.size()), net::use_awaitable); 80 | now = tools::get_tms(); 81 | tts_stop_end_timestamp = now; 82 | BOOST_LOG_TRIVIAL(debug) << "tts start:" << now; 83 | } 84 | auto audio = co_await tts_->text_to_speak(text); 85 | tts_sentence_queue.push({text, tts_stop_end_timestamp}); 86 | for(auto& data: audio) { 87 | ws_.binary(true); 88 | co_await ws_.async_write(net::buffer(data), net::use_awaitable); 89 | tts_stop_end_timestamp += 60; //60ms一段音频 90 | } 91 | } 92 | while(!tts_sentence_queue.empty()) { 93 | auto front = tts_sentence_queue.front(); 94 | now = tools::get_tms(); 95 | if(now < front.second) { 96 | break; 97 | } 98 | ws_.text(true); 99 | boost::json::object obj = { 100 | {"type", "tts"}, 101 | {"state", "sentence_start"}, 102 | {"text", front.first} 103 | }; 104 | co_await ws_.async_write(net::buffer(boost::json::serialize(obj)), net::use_awaitable); 105 | tts_sentence_queue.pop(); 106 | // BOOST_LOG_TRIVIAL(debug) << "tts sentence start:" << front.first << ",now:" << now << ",plan:" << front.second; 107 | } 108 | now = tools::get_tms(); 109 | if(tts_stop_end_timestamp != 0 && now > tts_stop_end_timestamp) { 110 | ws_.text(true); 111 | const std::string_view data = R"({"type":"tts","state":"stop"})"; 112 | co_await ws_.async_write(net::buffer(data.data(), 
data.size()), net::use_awaitable); 113 | tts_stop_end_timestamp = 0; 114 | BOOST_LOG_TRIVIAL(debug) << "tts end:" << now; 115 | } 116 | } 117 | BOOST_LOG_TRIVIAL(info) << "Connection tts loop over"; 118 | } 119 | 120 | void Connection::push_llm_response(std::string str) { 121 | //去除首位空格和markdown的*号#号 122 | str = std::regex_replace(str, std::regex(R"((^\s+)|(\s+$)[\*\#])"), ""); 123 | if(str.size() > 0) { 124 | llm_response_.push(std::move(str)); 125 | } 126 | } 127 | 128 | net::awaitable Connection::handle_asr_text(std::string text) { 129 | BOOST_LOG_TRIVIAL(info) << "Connection handle asr text:" << text; 130 | for(auto& cmd : cmd_exit_) { 131 | if(text == cmd) { 132 | is_released_ = true; 133 | co_return; 134 | } 135 | } 136 | dialogue_.push_back(boost::json::object{{"role", "user"}, {"content", text}}); 137 | std::string message; 138 | size_t pos = 0; 139 | co_await llm_->response(dialogue_, [this, &message, &pos](std::string_view res) { 140 | message.append(res.data(), res.size()); 141 | auto p = tools::find_last_segment(message); 142 | if(p != message.npos && p - pos + 1 > 6) { 143 | push_llm_response(message.substr(pos, p-pos+1)); 144 | pos = p+1; 145 | } 146 | }); 147 | if(pos < message.size()) { 148 | push_llm_response(message.substr(pos)); 149 | } 150 | dialogue_.push_back(boost::json::object{{"role", "assistant"}, {"content", message}}); 151 | } 152 | 153 | net::awaitable Connection::send_welcome() { 154 | session_id_ = tools::generate_uuid(); 155 | boost::json::object welcome { 156 | {"type", "hello"}, 157 | {"transport", setting_->config["welcome"]["transport"].as()}, 158 | {"audio_params", { 159 | {"sample_rate", setting_->config["welcome"]["audio_params"]["sample_rate"].as()} 160 | }} 161 | }; 162 | auto welcome_msg_str = boost::json::serialize(welcome); 163 | BOOST_LOG_TRIVIAL(info) << "发送welcome_msg:" << welcome_msg_str; 164 | ws_.text(true); 165 | co_await ws_.async_write(net::buffer(std::move(welcome_msg_str)), net::use_awaitable); 166 | } 167 | 
168 | net::awaitable Connection::handle_text(beast::flat_buffer &buffer) { 169 | auto data_str = boost::beast::buffers_to_string(buffer.data()); 170 | BOOST_LOG_TRIVIAL(info) << "收到文本消息(" << &ws_ << "):" << data_str; 171 | auto data = boost::json::parse(data_str).as_object(); 172 | if(data["type"] == "hello") { 173 | co_await send_welcome(); 174 | } else if(data["type"] == "listen" && data["state"] == "detect") { 175 | } else if(data["type"] == "abort") { 176 | llm_response_.clear(); 177 | llm_response_.push(""); 178 | } 179 | co_return; 180 | } 181 | 182 | net::awaitable Connection::handle_binary(beast::flat_buffer &buffer) { 183 | if(vad_.is_vad(buffer)) { 184 | // BOOST_LOG_TRIVIAL(debug) << "收到声音(" << &ws_ << "):" << buffer.size(); 185 | asr_audio_.push(std::move(buffer)); 186 | silence_timer_.cancel(); 187 | silence_timer_.expires_after(std::chrono::milliseconds(min_silence_tms_)); 188 | silence_timer_.async_wait([self=shared_from_this()](const boost::system::error_code& ec) { 189 | if(ec != net::error::operation_aborted) { 190 | self->asr_audio_.push(std::nullopt); 191 | } 192 | }); 193 | } else { 194 | // BOOST_LOG_TRIVIAL(debug) << "收到音频(" << &ws_ << "):" << buffer.size(); 195 | } 196 | co_return; 197 | } 198 | 199 | 200 | net::awaitable Connection::handle() { 201 | while(!is_released_) { 202 | beast::flat_buffer buffer; 203 | beast::get_lowest_layer(ws_).expires_after(std::chrono::seconds(close_connection_no_voice_time_)); 204 | auto [ec, _] = co_await ws_.async_read(buffer, net::as_tuple(net::use_awaitable)); 205 | if(ec == websocket::error::closed) { 206 | BOOST_LOG_TRIVIAL(debug) << "handle closed"; 207 | break; 208 | } else if(ec) { 209 | BOOST_LOG_TRIVIAL(debug) << "handle error" << ec.message(); 210 | break; 211 | } 212 | if(ws_.got_text()) { 213 | co_await handle_text(buffer); 214 | } else if(ws_.got_binary()) { 215 | co_await handle_binary(buffer); 216 | } 217 | } 218 | is_released_ = true; 219 | BOOST_LOG_TRIVIAL(debug) << "handle ended"; 220 | } 
221 | 222 | Connection::~Connection() { 223 | is_released_ = true; 224 | BOOST_LOG_TRIVIAL(info) << "Connection destroyed"; 225 | } 226 | } -------------------------------------------------------------------------------- /src/llm/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB_RECURSE cpp_files *.cpp) 2 | add_library(llm ${cpp_files}) 3 | 4 | target_link_libraries(llm PRIVATE precomp) 5 | target_link_libraries(llm PRIVATE common) -------------------------------------------------------------------------------- /src/llm/base.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | namespace xiaozhi { 8 | namespace llm { 9 | std::unique_ptr createLLM(const net::any_io_executor& executor) { 10 | auto setting = Setting::getSetting(); 11 | auto selected_module = setting->config["selected_module"]["LLM"].as(); 12 | if(selected_module == "CozeLLMV3") { 13 | return std::make_unique(executor, setting->config["LLM"][selected_module]); 14 | } else if(selected_module == "DifyLLM") { 15 | return std::make_unique(executor, setting->config["LLM"][selected_module]); 16 | } else if(setting->config["LLM"][selected_module]["type"].IsDefined() && setting->config["LLM"][selected_module]["type"].as() == "openai") { 17 | return std::make_unique(executor, setting->config["LLM"][selected_module]); 18 | } else { 19 | throw std::invalid_argument("Selected_module LLM not be supported"); 20 | } 21 | } 22 | Base::~Base() { 23 | BOOST_LOG_TRIVIAL(debug) << "LLM base destroyed"; 24 | } 25 | } 26 | } -------------------------------------------------------------------------------- /src/llm/cozev3.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | namespace xiaozhi { 6 | namespace llm { 7 | class CozeV3::Impl { 8 | private: 9 | std::string conversation_id_; 
10 | std::string bot_id_; 11 | std::string user_id_; 12 | boost::json::object header_; 13 | public: 14 | Impl(const net::any_io_executor &executor, const YAML::Node& config): 15 | bot_id_(config["bot_id"].as()), 16 | user_id_(config["user_id"].as()), 17 | header_({ 18 | {"Authorization", std::format("Bearer {}", config["personal_access_token"].as())}, 19 | {"Content-Type", "application/json"}, 20 | {"Connection", "keep-alive"} 21 | }) { 22 | } 23 | 24 | net::awaitable create_session() { 25 | const std::string url = "https://api.coze.cn/v1/conversation/create"; 26 | auto res = co_await request::post(url, header_, ""); 27 | auto rej = boost::json::parse(res).as_object(); 28 | if(!rej.contains("code") || rej.at("code").as_int64() != 0) { 29 | BOOST_LOG_TRIVIAL(error) << "CozeV3 create session failed:" << res; 30 | co_return ""; 31 | } 32 | conversation_id_ = rej.at("data").at("id").as_string(); 33 | co_return conversation_id_; 34 | } 35 | 36 | net::awaitable response(const boost::json::array& dialogue, const std::function& callback) { 37 | std::string query; 38 | for(auto it = dialogue.rbegin(); it != dialogue.rend(); --it) { 39 | if(it->at("role").as_string() == "user") { 40 | query = it->at("content").as_string(); 41 | break; 42 | } 43 | } 44 | const std::string url = std::format("https://api.coze.cn/v3/chat?conversation_id={}", conversation_id_); 45 | boost::json::object data = { 46 | {"bot_id", bot_id_}, 47 | {"user_id", user_id_}, 48 | {"auto_save_history", true}, 49 | {"stream", true}, 50 | {"additional_messages", { 51 | { 52 | {"role", "user"}, 53 | {"content", std::move(query)} 54 | } 55 | }} 56 | }; 57 | co_await request::stream_post(url, header_, boost::json::serialize(data), [&callback](const std::string res) { 58 | std::vector result; 59 | boost::split(result, res, boost::is_any_of("\n")); 60 | bool is_delta = false; 61 | for(auto& line : result) { 62 | if(line.size() == 0) 63 | continue; 64 | if(line == "event:conversation.message.delta") { 65 | is_delta 
= true; 66 | continue; 67 | } 68 | if(is_delta && line.starts_with("data:")) { 69 | auto rej = boost::json::parse(std::string_view(line.data() + 5, line.size() - 5)).as_object(); 70 | if(rej.contains("role") && rej.contains("type") && rej["role"] == "assistant" && rej["type"] == "answer") { 71 | callback(rej["content"].as_string()); 72 | } 73 | } 74 | } 75 | }); 76 | } 77 | 78 | }; 79 | 80 | CozeV3::CozeV3(const net::any_io_executor &executor, const YAML::Node& config) { 81 | impl_ = std::make_unique(executor, config); 82 | } 83 | 84 | CozeV3::~CozeV3() { 85 | 86 | } 87 | 88 | net::awaitable CozeV3::create_session() { 89 | co_return co_await impl_->create_session(); 90 | } 91 | 92 | net::awaitable CozeV3::response(const boost::json::array& dialogue, const std::function& callback) { 93 | co_await impl_->response(dialogue, callback); 94 | } 95 | } 96 | } -------------------------------------------------------------------------------- /src/llm/dify.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | 6 | namespace xiaozhi { 7 | namespace llm { 8 | class Dify::Impl { 9 | private: 10 | std::string conversation_id_; 11 | const std::string base_url_; 12 | const std::string api_key_; 13 | boost::json::object header_; 14 | public: 15 | Impl(const net::any_io_executor &executor, const YAML::Node& config): 16 | base_url_(config["base_url"].as()), 17 | api_key_(config["api_key"].as()), 18 | header_({ 19 | {"Authorization", std::format("Bearer {}", api_key_)}, 20 | {"Content-Type", "application/json"}, 21 | {"Connection", "keep-alive"} 22 | }) { 23 | } 24 | 25 | net::awaitable create_session() { 26 | co_return ""; 27 | } 28 | 29 | net::awaitable response(const boost::json::array& dialogue, const std::function& callback) { 30 | std::string query; 31 | for(auto it = dialogue.rbegin(); it != dialogue.rend(); --it) { 32 | if(it->at("role").as_string() == "user") { 33 | query = 
it->at("content").as_string(); 34 | break; 35 | } 36 | } 37 | const std::string url = std::format("{}/chat-messages", base_url_); 38 | boost::json::object data = { 39 | {"query", std::move(query)}, 40 | {"response_mode", "streaming"}, 41 | {"user", "143523"}, 42 | {"inputs", boost::json::object({})} 43 | }; 44 | if(!conversation_id_.empty()) { 45 | data["conversation_id"] = conversation_id_; 46 | } 47 | co_await request::stream_post(url, header_, boost::json::serialize(data), [&callback, this](std::string res) { 48 | std::vector result; 49 | boost::split(result, res, boost::is_any_of("\n")); 50 | for(auto& line : result) { 51 | if(line.size() == 0) 52 | continue; 53 | if(line.starts_with("data:")) { 54 | auto rej = boost::json::parse(std::string_view(line.data() + 5, line.size() - 5)).as_object(); 55 | if(conversation_id_.empty() && rej.contains("conversation_id")) { 56 | conversation_id_ = rej["conversation_id"].as_string(); 57 | } 58 | if(rej.at("event") == "error") { 59 | BOOST_LOG_TRIVIAL(error) << "Dify api error:" << line; 60 | } else if(rej.contains("answer")) { 61 | callback(rej.at("answer").as_string()); 62 | } 63 | } 64 | } 65 | }); 66 | } 67 | 68 | }; 69 | 70 | Dify::Dify(const net::any_io_executor &executor, const YAML::Node& config) { 71 | impl_ = std::make_unique(executor, config); 72 | } 73 | 74 | Dify::~Dify() { 75 | 76 | } 77 | 78 | net::awaitable Dify::create_session() { 79 | co_return co_await impl_->create_session(); 80 | } 81 | 82 | net::awaitable Dify::response(const boost::json::array& dialogue, const std::function& callback) { 83 | co_await impl_->response(dialogue, callback); 84 | } 85 | } 86 | } -------------------------------------------------------------------------------- /src/llm/openai.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | 6 | namespace xiaozhi { 7 | namespace llm { 8 | class Openai::Impl { 9 | private: 10 | std::string conversation_id_; 11 
| const std::string url_; 12 | const std::string api_key_; 13 | const std::string model_name_; 14 | boost::json::object header_; 15 | public: 16 | Impl(const net::any_io_executor &executor, const YAML::Node& config): 17 | url_(config["url"].as()), 18 | api_key_(config["api_key"].as()), 19 | model_name_(config["model_name"].as()), 20 | header_({ 21 | {"Authorization", std::format("Bearer {}", api_key_)}, 22 | {"Content-Type", "application/json"}, 23 | {"Connection", "keep-alive"} 24 | }) { 25 | } 26 | 27 | net::awaitable create_session() { 28 | co_return ""; 29 | } 30 | 31 | net::awaitable response(const boost::json::array& dialogue, const std::function& callback) { 32 | const std::string url = std::format("{}/chat/completions", url_); 33 | boost::json::object data = { 34 | {"model", model_name_}, 35 | {"messages", dialogue}, 36 | {"stream", true} 37 | }; 38 | bool is_active = true; 39 | std::string last_line; //处理json分段的情况 40 | co_await request::stream_post(url, header_, boost::json::serialize(data), [&callback, this, &is_active, &last_line](std::string res) { 41 | std::vector result; 42 | boost::split(result, res, boost::is_any_of("\n")); 43 | for(auto& line : result) { 44 | if(line.size() == 0) 45 | continue; 46 | if(!line.starts_with("data:")) { 47 | line = last_line + line; 48 | last_line = ""; 49 | BOOST_LOG_TRIVIAL(info) << "llm response segment connected:" << line; 50 | } 51 | if(line.starts_with("data:") && line != "data: [DONE]") { 52 | boost::json::object rej; 53 | try { 54 | rej = boost::json::parse(std::string_view(line.data() + 5, line.size() - 5)).as_object(); 55 | } catch(std::exception e) { 56 | BOOST_LOG_TRIVIAL(error) << "llm can't parse response:" << line; 57 | last_line = line; 58 | continue; 59 | } 60 | if(rej.contains("choices")) { 61 | for(auto& value : rej["choices"].as_array()) { 62 | auto& item = value.as_object(); 63 | if(item.contains("delta") && item["delta"].as_object().contains("content")) { 64 | auto content = 
std::string_view(item["delta"].as_object()["content"].as_string()); 65 | if(content.find("") != content.npos) { 66 | is_active = false; 67 | } 68 | auto p = content.find(""); 69 | if(p != content.npos) { 70 | is_active = true; 71 | content = content.substr(p + 8); 72 | } 73 | if(is_active) { 74 | callback(content); 75 | } 76 | } 77 | } 78 | } 79 | } else if(line != "data: [DONE]") { 80 | BOOST_LOG_TRIVIAL(info) << "llm response exception:" << line; 81 | } 82 | } 83 | }); 84 | } 85 | 86 | }; 87 | 88 | Openai::Openai(const net::any_io_executor &executor, const YAML::Node& config) { 89 | impl_ = std::make_unique(executor, config); 90 | } 91 | 92 | Openai::~Openai() { 93 | 94 | } 95 | 96 | net::awaitable Openai::create_session() { 97 | co_return co_await impl_->create_session(); 98 | } 99 | 100 | net::awaitable Openai::response(const boost::json::array& dialogue, const std::function& callback) { 101 | co_await impl_->response(dialogue, callback); 102 | } 103 | } 104 | } -------------------------------------------------------------------------------- /src/paraformer/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB_RECURSE cpp_files *.cpp) 2 | add_library(paraformer ${cpp_files}) 3 | 4 | target_link_libraries(paraformer PUBLIC Boost::log) 5 | target_link_libraries(paraformer PUBLIC Boost::json) 6 | target_link_libraries(paraformer PUBLIC yaml-cpp::yaml-cpp) 7 | target_link_libraries(paraformer PUBLIC ${ONNXRUNTIME_LIB}) 8 | target_link_libraries(paraformer PUBLIC kaldi-native-fbank-core) -------------------------------------------------------------------------------- /src/paraformer/model.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | namespace funasr { 7 | Model *CreateModel(std::map& model_path, int thread_num) 8 | { 9 | std::string en_model_path; 10 | std::string de_model_path; 11 | std::string am_cmvn_path; 
12 | std::string am_config_path; 13 | std::string token_path; 14 | 15 | en_model_path = PathAppend(model_path.at(MODEL_DIR), ENCODER_NAME); 16 | de_model_path = PathAppend(model_path.at(MODEL_DIR), DECODER_NAME); 17 | if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){ 18 | en_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_ENCODER_NAME); 19 | de_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_DECODER_NAME); 20 | } 21 | am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME); 22 | am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME); 23 | token_path = PathAppend(model_path.at(MODEL_DIR), TOKEN_PATH); 24 | 25 | Model *mm; 26 | mm = new Paraformer(); 27 | mm->InitAsr(en_model_path, de_model_path, am_cmvn_path, am_config_path, token_path, thread_num); 28 | return mm; 29 | } 30 | 31 | Model *CreateModel(void* asr_handle, std::vector chunk_size) 32 | { 33 | Model* mm; 34 | mm = new ParaformerOnline((Paraformer*)asr_handle, chunk_size); 35 | return mm; 36 | } 37 | 38 | } // namespace funasr -------------------------------------------------------------------------------- /src/paraformer/paraformer.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
namespace funasr {

// Paraformer: ONNX-Runtime-backed ASR model used by FunASR.
// Two Ort::Env/session-option pairs are kept: env_/session_options_ for the
// main encoder/decoder graphs, and hw_env_/hw_session_options — presumably
// for the hotword model (use_hotword starts false) — TODO confirm.
// The Ort::Env members are declared before the sessions that use them.
Paraformer::Paraformer()
:use_hotword(false),
env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options_{},
hw_env_(ORT_LOGGING_LEVEL_ERROR, "paraformer_hw"),hw_session_options{} {
}
BOOST_LOG_TRIVIAL(error) << "Error when load am decoder model: " << e.what(); 53 | exit(-1); 54 | } 55 | 56 | // encoder 57 | std::string strName; 58 | GetInputName(encoder_session_.get(), strName); 59 | en_strInputNames.push_back(strName.c_str()); 60 | GetInputName(encoder_session_.get(), strName,1); 61 | en_strInputNames.push_back(strName); 62 | 63 | GetOutputName(encoder_session_.get(), strName); 64 | en_strOutputNames.push_back(strName); 65 | GetOutputName(encoder_session_.get(), strName,1); 66 | en_strOutputNames.push_back(strName); 67 | GetOutputName(encoder_session_.get(), strName,2); 68 | en_strOutputNames.push_back(strName); 69 | 70 | for (auto& item : en_strInputNames) 71 | en_szInputNames_.push_back(item.c_str()); 72 | for (auto& item : en_strOutputNames) 73 | en_szOutputNames_.push_back(item.c_str()); 74 | 75 | // decoder 76 | int de_input_len = 4 + fsmn_layers; 77 | int de_out_len = 2 + fsmn_layers; 78 | for(int i=0;iwindow_type = frontend_conf["window"].as(); 115 | this->n_mels = frontend_conf["n_mels"].as(); 116 | this->frame_length = frontend_conf["frame_length"].as(); 117 | this->frame_shift = frontend_conf["frame_shift"].as(); 118 | this->lfr_m = frontend_conf["lfr_m"].as(); 119 | this->lfr_n = frontend_conf["lfr_n"].as(); 120 | 121 | this->encoder_size = encoder_conf["output_size"].as(); 122 | this->fsmn_dims = encoder_conf["output_size"].as(); 123 | 124 | this->fsmn_layers = decoder_conf["num_blocks"].as(); 125 | this->fsmn_lorder = decoder_conf["kernel_size"].as()-1; 126 | 127 | this->cif_threshold = predictor_conf["threshold"].as(); 128 | this->tail_alphas = predictor_conf["tail_threshold"].as(); 129 | 130 | this->asr_sample_rate = frontend_conf["fs"].as(); 131 | 132 | 133 | }catch(std::exception const &e){ 134 | BOOST_LOG_TRIVIAL(error) << "Error when load argument from vad config YAML."; 135 | exit(-1); 136 | } 137 | } 138 | 139 | 140 | Paraformer::~Paraformer() 141 | { 142 | if(vocab){ 143 | delete vocab; 144 | } 145 | if(lm_vocab){ 146 | 
delete lm_vocab; 147 | } 148 | if(seg_dict){ 149 | delete seg_dict; 150 | } 151 | if(phone_set_){ 152 | delete phone_set_; 153 | } 154 | } 155 | 156 | void Paraformer::StartUtterance() 157 | { 158 | } 159 | 160 | void Paraformer::EndUtterance() 161 | { 162 | } 163 | 164 | void Paraformer::Reset() 165 | { 166 | } 167 | 168 | 169 | void Paraformer::LoadCmvn(const char *filename) 170 | { 171 | std::ifstream cmvn_stream(filename); 172 | if (!cmvn_stream.is_open()) { 173 | BOOST_LOG_TRIVIAL(error) << "Failed to open file: " << filename; 174 | exit(-1); 175 | } 176 | std::string line; 177 | 178 | while (getline(cmvn_stream, line)) { 179 | std::istringstream iss(line); 180 | std::vector line_item{std::istream_iterator{iss}, std::istream_iterator{}}; 181 | if (line_item[0] == "") { 182 | getline(cmvn_stream, line); 183 | std::istringstream means_lines_stream(line); 184 | std::vector means_lines{std::istream_iterator{means_lines_stream}, std::istream_iterator{}}; 185 | if (means_lines[0] == "") { 186 | for (int j = 3; j < means_lines.size() - 1; j++) { 187 | means_list_.push_back(stof(means_lines[j])); 188 | } 189 | continue; 190 | } 191 | } 192 | else if (line_item[0] == "") { 193 | getline(cmvn_stream, line); 194 | std::istringstream vars_lines_stream(line); 195 | std::vector vars_lines{std::istream_iterator{vars_lines_stream}, std::istream_iterator{}}; 196 | if (vars_lines[0] == "") { 197 | for (int j = 3; j < vars_lines.size() - 1; j++) { 198 | vars_list_.push_back(stof(vars_lines[j])*scale); 199 | } 200 | continue; 201 | } 202 | } 203 | } 204 | } 205 | 206 | std::string Paraformer::GreedySearch(float * in, int n_len, int64_t token_nums, bool is_stamp, std::vector us_alphas, std::vector us_cif_peak) 207 | { 208 | std::vector hyps; 209 | int Tmax = n_len; 210 | for (int i = 0; i < Tmax; i++) { 211 | int max_idx; 212 | float max_val; 213 | FindMax(in + i * token_nums, token_nums, max_val, max_idx); 214 | hyps.push_back(max_idx); 215 | } 216 | return 
vocab->Vector2StringV2(hyps, language); 217 | } 218 | 219 | 220 | Vocab* Paraformer::GetVocab() 221 | { 222 | return vocab; 223 | } 224 | 225 | Vocab* Paraformer::GetLmVocab() 226 | { 227 | return lm_vocab; 228 | } 229 | 230 | PhoneSet* Paraformer::GetPhoneSet() 231 | { 232 | return phone_set_; 233 | } 234 | 235 | std::string Paraformer::Rescoring() 236 | { 237 | BOOST_LOG_TRIVIAL(error)<<"Not Imp!!!!!!"; 238 | return ""; 239 | } 240 | } // namespace funasr -------------------------------------------------------------------------------- /src/paraformer/phone-set.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | 8 | namespace funasr { 9 | PhoneSet::PhoneSet(const char *filename) { 10 | ifstream in(filename); 11 | LoadPhoneSetFromJson(filename); 12 | } 13 | PhoneSet::~PhoneSet() 14 | { 15 | } 16 | 17 | void PhoneSet::LoadPhoneSetFromYaml(const char* filename) { 18 | YAML::Node config; 19 | try{ 20 | config = YAML::LoadFile(filename); 21 | }catch(exception const &e){ 22 | BOOST_LOG_TRIVIAL(info) << "Error loading file, yaml file error or not exist."; 23 | exit(-1); 24 | } 25 | YAML::Node myList = config["token_list"]; 26 | int id = 0; 27 | for (YAML::const_iterator it = myList.begin(); it != myList.end(); ++it, id++) { 28 | phone_.push_back(it->as()); 29 | phn2Id_.emplace(it->as(), id); 30 | } 31 | } 32 | 33 | void PhoneSet::LoadPhoneSetFromJson(const char* filename) { 34 | boost::json::value json_array; 35 | std::ifstream file(filename); 36 | if (file.is_open()) { 37 | std::string json_str((std::istreambuf_iterator(file)), std::istreambuf_iterator()); 38 | json_array = boost::json::parse(json_str); 39 | file.close(); 40 | } else { 41 | BOOST_LOG_TRIVIAL(info) << "Error loading token file, token file error or not exist."; 42 | exit(-1); 43 | } 44 | 45 | int id = 0; 46 | for (const auto& element : json_array.as_array()) { 47 | const std::string 
value(element.as_string()); 48 | phone_.push_back(value); 49 | phn2Id_.emplace(value, id); 50 | id++; 51 | } 52 | } 53 | 54 | int PhoneSet::Size() const { 55 | return phone_.size(); 56 | } 57 | 58 | int PhoneSet::String2Id(string phn_str) const { 59 | if (phn2Id_.count(phn_str)) { 60 | return phn2Id_.at(phn_str); 61 | } else { 62 | //BOOST_LOG_TRIVIAL(info) << "Phone unit not exist."; 63 | return -1; 64 | } 65 | } 66 | 67 | string PhoneSet::Id2String(int id) const { 68 | if (id < 0 || id > Size()) { 69 | //BOOST_LOG_TRIVIAL(info) << "Phone id not exist."; 70 | return ""; 71 | } else { 72 | return phone_[id]; 73 | } 74 | } 75 | 76 | bool PhoneSet::Find(string phn_str) const { 77 | return phn2Id_.count(phn_str) > 0; 78 | } 79 | 80 | int PhoneSet::GetBegSilPhnId() const { 81 | return String2Id(UNIT_BEG_SIL_SYMBOL); 82 | } 83 | 84 | int PhoneSet::GetEndSilPhnId() const { 85 | return String2Id(UNIT_END_SIL_SYMBOL); 86 | } 87 | 88 | int PhoneSet::GetBlkPhnId() const { 89 | return String2Id(UNIT_BLK_SYMBOL); 90 | } 91 | 92 | } -------------------------------------------------------------------------------- /src/paraformer/seg-dict.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
3 | * MIT License (https://opensource.org/licenses/MIT) 4 | */ 5 | #include "funasr/paraformer/seg-dict.h" 6 | #include "funasr/paraformer/utils.h" 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | using namespace std; 14 | 15 | namespace funasr { 16 | SegDict::SegDict(const char *filename) 17 | { 18 | ifstream in(filename); 19 | if (!in) { 20 | BOOST_LOG_TRIVIAL(error) << filename << " open failed !!"; 21 | return; 22 | } 23 | string textline; 24 | while (getline(in, textline)) { 25 | std::vector line_item = split(textline, '\t'); 26 | //std::cout << textline << std::endl; 27 | if (line_item.size() > 1) { 28 | std::string word = line_item[0]; 29 | std::string segs = line_item[1]; 30 | std::vector segs_vec = split(segs, ' '); 31 | seg_dict[word] = segs_vec; 32 | } 33 | } 34 | BOOST_LOG_TRIVIAL(info) << "load seg dict successfully"; 35 | } 36 | std::vector SegDict::GetTokensByWord(const std::string &word) { 37 | if (seg_dict.count(word)) 38 | return seg_dict[word]; 39 | else { 40 | BOOST_LOG_TRIVIAL(info)<< word <<" is OOV!"; 41 | std::vector vec; 42 | return vec; 43 | } 44 | } 45 | 46 | SegDict::~SegDict() 47 | { 48 | } 49 | 50 | 51 | } // namespace funasr -------------------------------------------------------------------------------- /src/paraformer/utils.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | namespace funasr { 6 | std::string PathAppend(const std::string &p1, const std::string &p2) { 7 | char sep = '/'; 8 | std::string tmp = p1; 9 | if (p1[p1.length()-1] != sep) { // Need to add a 10 | tmp += sep; // path separator 11 | return (tmp + p2); 12 | } else 13 | return (p1 + p2); 14 | } 15 | 16 | void FindMax(float *din, int len, float &max_val, int &max_idx) { 17 | int i; 18 | max_val = -INFINITY; 19 | max_idx = -1; 20 | for (i = 0; i < len; i++) { 21 | if (din[i] > max_val) { 22 | max_val = din[i]; 23 | max_idx = i; 24 | } 25 | } 26 | } 27 | 28 | 
namespace funasr {
// Splits `s` on every occurrence of `delim`.
// Mirrors std::getline semantics: consecutive delimiters yield empty
// elements, a trailing delimiter yields no final empty element, and an
// empty input yields an empty vector.
std::vector<std::string> split(const std::string &s, char delim) {
    std::vector<std::string> parts;
    std::stringstream stream(s);
    for (std::string token; std::getline(stream, token, delim); ) {
        parts.push_back(token);
    }
    return parts;
}
}
// Decodes a single 3-byte UTF-8 sequence (code points U+0800..U+FFFF) to
// its Unicode code point. Returns 0 for anything that is not a well-formed
// 3-byte sequence.
int Str2Int(std::string str)
{
    // FIX: reject short inputs up front. The byte checks below read
    // ch_array[1] and ch_array[2]; for a string shorter than 3 bytes whose
    // first byte matches the 3-byte lead pattern (0xE0..0xEF) those reads
    // were past the end of the buffer.
    if (str.size() < 3)
        return 0;
    const char *ch_array = str.c_str();
    // Lead byte 1110xxxx, continuation bytes 10xxxxxx.
    if (((ch_array[0] & 0xf0) != 0xe0) || ((ch_array[1] & 0xc0) != 0x80) ||
        ((ch_array[2] & 0xc0) != 0x80))
        return 0;
    int val = ((ch_array[0] & 0x0f) << 12) | ((ch_array[1] & 0x3f) << 6) |
              (ch_array[2] & 0x3f);
    return val;
}
143 | return false; 144 | } 145 | int unicode = Str2Int(ch); 146 | if (unicode >= 19968 && unicode <= 40959) { 147 | return true; 148 | } 149 | return false; 150 | } 151 | 152 | string Vocab::WordFormat(std::string word) 153 | { 154 | if(word == "i"){ 155 | return "I"; 156 | }else if(word == "i'm"){ 157 | return "I'm"; 158 | }else if(word == "i've"){ 159 | return "I've"; 160 | }else if(word == "i'll"){ 161 | return "I'll"; 162 | }else{ 163 | return word; 164 | } 165 | } 166 | 167 | string Vocab::Vector2StringV2(vector in, std::string language) 168 | { 169 | int i; 170 | list words; 171 | int is_pre_english = false; 172 | int pre_english_len = 0; 173 | int is_combining = false; 174 | std::string combine = ""; 175 | std::string unicodeChar = "▁"; 176 | 177 | for (i=0; i" || word == "" || word == "") 181 | continue; 182 | if (language == "en-bpe"){ 183 | size_t found = word.find(unicodeChar); 184 | if(found != std::string::npos){ 185 | if (combine != ""){ 186 | combine = WordFormat(combine); 187 | if (words.size() != 0){ 188 | combine = " " + combine; 189 | } 190 | words.push_back(combine); 191 | } 192 | combine = word.substr(3); 193 | }else{ 194 | combine += word; 195 | } 196 | continue; 197 | } 198 | // step2 combie phoneme to full word 199 | { 200 | int sub_word = !(word.find("@@") == string::npos); 201 | // process word start and middle part 202 | if (sub_word) { 203 | // if badcase: lo@@ chinese 204 | if (i == in.size()-1 || i 1) { 250 | words.push_back(" "); 251 | words.push_back(word); 252 | pre_english_len = word.size(); 253 | } 254 | else { 255 | if (word.size() > 1) { 256 | words.push_back(" "); 257 | } 258 | words.push_back(word); 259 | pre_english_len = word.size(); 260 | } 261 | } 262 | is_pre_english = true; 263 | } 264 | } 265 | } 266 | 267 | if (language == "en-bpe" && combine != ""){ 268 | combine = WordFormat(combine); 269 | if (words.size() != 0){ 270 | combine = " " + combine; 271 | } 272 | words.push_back(combine); 273 | } 274 | 275 | stringstream 
ss; 276 | for (auto it = words.begin(); it != words.end(); it++) { 277 | ss << *it; 278 | } 279 | 280 | return ss.str(); 281 | } 282 | 283 | int Vocab::Size() const 284 | { 285 | return vocab.size(); 286 | } 287 | 288 | } // namespace funasr -------------------------------------------------------------------------------- /src/server.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | namespace xiaozhi { 6 | Server::Server(std::shared_ptr setting): 7 | setting(setting), 8 | ioc(net::io_context{setting->config["threads"].as()}) { 9 | } 10 | 11 | net::awaitable Server::authenticate(websocket::stream &ws, http::request &req) { 12 | beast::flat_buffer buffer; 13 | co_await http::async_read(ws.next_layer(), buffer, req, net::use_awaitable); 14 | 15 | if(!setting->config["server"]["auth"]["enabled"].IsDefined() || !setting->config["server"]["auth"]["enabled"].as()) { 16 | co_return true; 17 | } 18 | 19 | auto auth = req.find("Authorization"); 20 | if(auth == req.end()) { 21 | BOOST_LOG_TRIVIAL(info) << "not valid auth"; 22 | co_return false; 23 | } 24 | auto val = auth->value(); 25 | val = val.substr(7); //"Bearer token" 去除 "Bearer " 26 | for(auto i=0; iconfig["server"]["auth"]["tokens"].size(); ++i) { 27 | if(val == setting->config["server"]["auth"]["tokens"][i]["token"].as()) { 28 | co_return true; 29 | } 30 | } 31 | co_return false; 32 | } 33 | 34 | net::awaitable Server::run_session(websocket::stream ws) { 35 | ws.set_option(websocket::stream_base::timeout::suggested(beast::role_type::server)); 36 | http::request req; 37 | auto ret = co_await authenticate(ws, req); 38 | if(!ret) { 39 | co_return; 40 | } 41 | co_await ws.async_accept(req, net::use_awaitable); 42 | auto executor = co_await net::this_coro::executor; 43 | auto conn = std::make_shared(setting, std::move(ws), executor); 44 | conn->start(); 45 | } 46 | 47 | net::awaitable Server::listen(net::ip::tcp::endpoint endpoint) { 48 | 
auto executor = co_await net::this_coro::executor; 49 | auto acceptor = net::ip::tcp::acceptor{executor, endpoint}; 50 | BOOST_LOG_TRIVIAL(info) << "Server is running at " << endpoint.address().to_string() << ":" << endpoint.port(); 51 | while(true) { 52 | net::co_spawn(executor, 53 | run_session(websocket::stream{ 54 | co_await acceptor.async_accept(net::use_awaitable) 55 | }), 56 | std::bind_front(tools::on_spawn_complete, "Session")); 57 | } 58 | } 59 | 60 | void Server::run() { 61 | auto address = net::ip::make_address(setting->config["server"]["ip"].as()); 62 | net::co_spawn(ioc, 63 | listen(net::ip::tcp::endpoint{address, setting->config["server"]["port"].as()}), 64 | std::bind_front(tools::on_spawn_complete, "Listen")); 65 | std::vector v; 66 | auto threads = setting->config["threads"].as(); 67 | v.reserve(threads-1); 68 | for(int i=0; i 2 | #include 3 | 4 | namespace xiaozhi { 5 | 6 | static std::shared_ptr vad = nullptr; 7 | static std::once_flag vadFlag; 8 | 9 | std::shared_ptr SileroVad::getVad() { 10 | std::call_once(vadFlag, [&] { 11 | vad = std::shared_ptr(new SileroVad()); 12 | }); 13 | return vad; 14 | } 15 | 16 | SileroVad::SileroVad() { 17 | } 18 | 19 | float SileroVad::predict(std::vector &pcm_data) { 20 | return detector.predict(pcm_data); 21 | } 22 | 23 | VoiceDetector::VoiceDetector() 24 | : memory_info_(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault)) { 25 | const auto setting = Setting::getSetting(); 26 | 27 | // 初始化 ONNX Runtime 和 Silero-VAD 模型 28 | Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "SileroVAD"); 29 | Ort::SessionOptions session_options; 30 | session_options.SetIntraOpNumThreads(1); 31 | session_options.SetInterOpNumThreads(1); 32 | session_ = std::make_shared(env, setting->config["VAD"]["SileroVAD"]["model_path"].as().c_str(), session_options); 33 | 34 | // 初始化隐藏状态 35 | state_.resize(2 * 128, 0.0f); 36 | } 37 | 38 | VoiceDetector::~VoiceDetector() { 39 | } 40 | 41 | //todo: 考虑速度,暂时只检测前512个样本,后续如果有问题再改为滑动窗口检测后面的样本 42 
// Runs one Silero-VAD inference over the first window_size_ samples of
// `pcm_data` and returns the speech probability.
// Original author's note (translated): for speed, only the first 512
// samples are checked; a sliding window over later samples may be added
// if this proves insufficient.
// NOTE(review): template arguments below were stripped by the dump and
// reconstructed (<float>/<int64_t>) — confirm against the original file.
float VoiceDetector::predict(std::vector<float> &pcm_data) {
    // Tensor shapes expected by the Silero-VAD ONNX graph.
    std::vector<int64_t> input_shape = {1, window_size_};
    std::vector<int64_t> sr_shape = {1};
    std::vector<int64_t> state_shape = {2, 1, 128};

    // The tensors alias pcm_data / sample_rate_ / state_ — no copies are
    // made, so those buffers must stay alive until Run() returns.
    auto input_tensor = Ort::Value::CreateTensor<float>(memory_info_, pcm_data.data(), window_size_,
                                                        input_shape.data(), input_shape.size());
    auto sr_tensor = Ort::Value::CreateTensor<int64_t>(memory_info_, &sample_rate_, 1,
                                                       sr_shape.data(), sr_shape.size());
    auto state_tensor = Ort::Value::CreateTensor<float>(memory_info_, state_.data(), state_.size(),
                                                        state_shape.data(), state_shape.size());

    std::vector<Ort::Value> input_tensors;
    input_tensors.push_back(std::move(input_tensor));
    input_tensors.push_back(std::move(sr_tensor));
    input_tensors.push_back(std::move(state_tensor));

    // Run Silero-VAD.
    const char* input_names[] = {"input", "sr", "state"};
    const char* output_names[] = {"output", "stateN"};
    auto output_tensors = session_->Run(Ort::RunOptions{nullptr}, input_names, input_tensors.data(), 3,
                                        output_names, 2);

    // Speech probability: the single float in "output".
    float probability = output_tensors[0].GetTensorMutableData<float>()[0];

    // Carry the recurrent state over to the next call.
    float* stateN = output_tensors[1].GetTensorMutableData<float>();
    std::memcpy(state_.data(), stateN, state_.size() * sizeof(float));

    // Caller compares this against its own threshold to decide "speech".
    return probability;
}
Vad::is_vad(beast::flat_buffer &buffer) { 18 | std::vector pcm(960); // 最大帧大小 19 | int decoded_samples = opus_decode_float(decoder_, 20 | static_cast(buffer.data().data()), 21 | buffer.size(), pcm.data(), 960, 0); 22 | if (decoded_samples > 0) { 23 | // todo:考虑到buffer实际为960样本的数据,所以暂时不缓存到pcm_buffer中累积 24 | if(pcm.size() > 512) { 25 | auto vad = SileroVad::getVad(); 26 | auto ret = vad->predict(pcm); 27 | if(ret > threshold_) { 28 | return true; 29 | } 30 | } else { 31 | BOOST_LOG_TRIVIAL(error) << "音频样本不足512,无法检测vad: " << pcm.size(); 32 | } 33 | } else { 34 | BOOST_LOG_TRIVIAL(error) << "opus解码失败: " << opus_strerror(decoded_samples); 35 | } 36 | return false; 37 | } 38 | } -------------------------------------------------------------------------------- /src/tts/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB_RECURSE cpp_files *.cpp) 2 | add_library(tts ${cpp_files}) 3 | 4 | target_link_libraries(tts PRIVATE precomp) 5 | target_link_libraries(tts PRIVATE common) 6 | target_link_libraries(tts PRIVATE ${OPUS_LIBRARY}) -------------------------------------------------------------------------------- /src/tts/base.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | namespace xiaozhi { 7 | namespace tts { 8 | std::unique_ptr createTTS(const net::any_io_executor& executor) { 9 | auto setting = Setting::getSetting(); 10 | auto selected_module = setting->config["selected_module"]["TTS"].as(); 11 | if(selected_module == "BytedanceTTSV3") { 12 | return std::make_unique(executor, setting->config["TTS"][selected_module], setting->config["welcome"]["audio_params"]["sample_rate"].as()); 13 | } else if(selected_module == "EdgeTTS") { 14 | return std::make_unique(executor, setting->config["TTS"][selected_module], setting->config["welcome"]["audio_params"]["sample_rate"].as()); 15 | } else { 16 | throw 
// Virtual-base destructor: logs at debug level so TTS object teardown is
// visible when tracing session shutdown.
Base::~Base() {
    BOOST_LOG_TRIVIAL(debug) << "TTS base destroyed";
}
0x1 : 0x0); // Serialization and compression 19 | header[3] = 0x00; // Reserved 20 | return header; 21 | } 22 | 23 | static void payload_insert_big_num(std::vector &data, uint32_t num) { 24 | data.insert(data.end(), {(uint8_t)(num >> 24), (uint8_t)(num >> 16), 25 | (uint8_t)(num >> 8), (uint8_t)num}); 26 | } 27 | 28 | static std::vector build_payload(int32_t event_code, std::string payload_str, std::string session_id = "") { 29 | if(is_gzip) { 30 | payload_str = tools::gzip_compress(payload_str); 31 | } 32 | 33 | auto data = make_header(true, is_gzip); // Full client request, JSON, Gzip 34 | payload_insert_big_num(data, event_code); 35 | 36 | if(session_id.size() > 0) { 37 | uint32_t session_size = session_id.size(); 38 | payload_insert_big_num(data, session_size); 39 | data.insert(data.end(), session_id.begin(), session_id.end()); 40 | } 41 | 42 | std::vector payload(payload_str.begin(), payload_str.end()); 43 | uint32_t payload_size = payload.size(); 44 | payload_insert_big_num(data, payload_size); 45 | data.insert(data.end(), payload.begin(), payload.end()); 46 | return data; 47 | } 48 | 49 | static int32_t parser_response_code(const std::string& payload, int header_len=4) { 50 | uint32_t event_code = *(unsigned int *) (payload.data() + header_len); 51 | return boost::endian::big_to_native(event_code); 52 | } 53 | 54 | static int32_t parser_response_code(const unsigned char* data, int header_len=4) { 55 | uint32_t event_code = *(unsigned int *) (data + header_len); 56 | return boost::endian::big_to_native(event_code); 57 | } 58 | 59 | enum EventCodes: int32_t { 60 | StartConnection = 1, //1 61 | FinishConnection = 2, 62 | ConnectionStarted = 50, 63 | ConnectionFailed = 51, 64 | ConnectionFinished = 52, 65 | StartSession = 100, //up 66 | FinishSession = 102, //up 67 | SessionStarted = 150, 68 | SessionCanceled = 151, 69 | SessionFinished = 152, //up | down 70 | SessionFailed = 153, 71 | TaskRequest = 200, //up 72 | TTSSentenceStart = 350, 73 | TTSSentenceEnd = 
351, 74 | TTSResponse = 352 75 | }; 76 | 77 | // Ogg页面头部结构(简化版) 78 | struct OggPageHeader { 79 | uint8_t capture_pattern[4]; // "OggS" 80 | uint8_t version; 81 | uint8_t header_type; 82 | uint64_t granule_position; 83 | uint32_t serial_number; 84 | uint32_t page_sequence; 85 | uint32_t checksum; 86 | uint8_t segment_count; 87 | }; 88 | 89 | // 从数据中提取Opus帧 90 | std::vector> extractOpusFrames(const unsigned char* data, size_t data_size) { 91 | std::vector> opus_frames_pos; // 存储提取的Opus帧 92 | size_t offset = 0; 93 | 94 | while (offset < data_size) { 95 | // 检查是否还有足够的数据读取Ogg页面头部 96 | if (offset + 27 > data_size) break; 97 | 98 | // 读取Ogg页面头部 99 | OggPageHeader header; 100 | memcpy(&header.capture_pattern, data + offset, 4); 101 | offset += 4; 102 | header.version = data[offset++]; 103 | header.header_type = data[offset++]; 104 | header.granule_position = *(uint64_t*)(data + offset); // 小端序 105 | offset += 8; 106 | header.serial_number = *(uint32_t*)(data + offset); // 小端序 107 | offset += 4; 108 | header.page_sequence = *(uint32_t*)(data + offset); // 小端序 109 | offset += 4; 110 | header.checksum = *(uint32_t*)(data + offset); // 小端序 111 | offset += 4; 112 | header.segment_count = data[offset++]; 113 | 114 | // 验证OggS标记 115 | if (memcmp(header.capture_pattern, "OggS", 4) != 0) { 116 | BOOST_LOG_TRIVIAL(error) << "BytedanceTTSV3 Invalid OggS pattern at offset " << offset - 27 << std::endl; 117 | break; 118 | } 119 | 120 | // 读取段表 121 | if (offset + header.segment_count > data_size) break; 122 | std::vector segment_table(header.segment_count); 123 | memcpy(segment_table.data(), data + offset, header.segment_count); 124 | offset += header.segment_count; 125 | 126 | // 计算页面数据总长度 127 | size_t payload_size = 0; 128 | for (uint8_t len : segment_table) { 129 | payload_size += len; 130 | } 131 | 132 | // 检查数据是否足够 133 | if (offset + payload_size > data_size) break; 134 | 135 | // 跳过元数据页面(序列号0和1) 136 | if (header.page_sequence == 0 || header.page_sequence == 1) { 137 | offset += 
payload_size; // 跳过OpusHead或OpusTags 138 | continue; 139 | } 140 | 141 | // 提取Opus帧(每个段可能是一个完整的Opus帧) 142 | size_t segment_offset = offset; 143 | for (uint8_t len : segment_table) { 144 | if (len > 0) { 145 | opus_frames_pos.push_back({segment_offset, len}); 146 | segment_offset += len; 147 | } 148 | } 149 | offset += payload_size; 150 | } 151 | 152 | return opus_frames_pos; 153 | } 154 | 155 | namespace xiaozhi { 156 | namespace tts { 157 | class BytedanceV3::Impl { 158 | private: 159 | int sample_rate_; 160 | std::string appid_; 161 | std::string access_token_; 162 | std::string voice_; 163 | std::string uuid_; 164 | 165 | net::any_io_executor executor_; //需要比resolver和ws先初始化,所以申明在前面 166 | std::unique_ptr ws_; 167 | 168 | void clear(const char* title, beast::error_code ec) { 169 | if(ec) { 170 | BOOST_LOG_TRIVIAL(error) << title << ec.message(); 171 | } 172 | } 173 | 174 | net::awaitable connect() { 175 | try { 176 | auto stream = co_await request::connect({true, host, port, path}); 177 | ws_ = std::make_unique(std::move(stream)); 178 | } catch(const std::exception& e) { 179 | BOOST_LOG_TRIVIAL(error) << "BytedanceTTSV3 connect error:" << e.what(); 180 | co_return false; 181 | } 182 | beast::get_lowest_layer(*ws_).expires_never(); 183 | ws_->set_option( 184 | websocket::stream_base::timeout::suggested( 185 | beast::role_type::client)); 186 | ws_->set_option(websocket::stream_base::decorator( 187 | [this](websocket::request_type& req) { 188 | req.set("X-Api-App-Key", appid_); 189 | req.set("X-Api-Access-Key", access_token_); 190 | req.set("X-Api-Resource-Id", "volc.service_type.10029"); 191 | req.set("X-Api-Connect-Id", uuid_); 192 | })); 193 | auto [ec] = co_await ws_->async_handshake(host + ':' + port, path, net::as_tuple(net::use_awaitable)); 194 | if(ec) { 195 | clear("BytedanceTTSV3 handshake:", ec); 196 | co_return false; 197 | } 198 | co_return true; 199 | } 200 | 201 | net::awaitable start_connection() { 202 | auto data = 
build_payload(EventCodes::StartConnection, "{}"); 203 | co_await ws_->async_write(net::buffer(data), net::use_awaitable); 204 | 205 | beast::flat_buffer buffer; 206 | co_await ws_->async_read(buffer, net::use_awaitable); 207 | auto event_code = parser_response_code(beast::buffers_to_string(buffer.data())); 208 | if(event_code != EventCodes::ConnectionStarted) { 209 | BOOST_LOG_TRIVIAL(error) << "BytedanceTTS start connection fail with code:" << event_code << ",data:" << beast::make_printable(buffer.data());; 210 | co_return false; 211 | } 212 | co_return true; 213 | } 214 | 215 | net::awaitable start_session() { 216 | boost::json::object obj = { 217 | {"event", EventCodes::StartSession}, 218 | {"req_params", { 219 | {"speaker", voice_}, 220 | {"audio_params", { 221 | {"format", "ogg_opus"}, 222 | {"sample_rate", sample_rate_} 223 | }} 224 | }} 225 | }; 226 | auto data = build_payload(EventCodes::StartSession, boost::json::serialize(obj), uuid_); 227 | co_await ws_->async_write(net::buffer(data), net::use_awaitable); 228 | 229 | beast::flat_buffer buffer; 230 | co_await ws_->async_read(buffer, net::use_awaitable); 231 | auto event_code = parser_response_code(beast::buffers_to_string(buffer.data())); 232 | if(event_code != EventCodes::SessionStarted) { 233 | BOOST_LOG_TRIVIAL(error) << "BytedanceTTS start session fail with code:" << event_code << ",data:" << beast::make_printable(buffer.data()); 234 | co_return false; 235 | } 236 | co_return true; 237 | } 238 | 239 | net::awaitable task_request(const std::string& text) { 240 | boost::json::object obj = { 241 | {"event", EventCodes::TaskRequest}, 242 | {"req_params", { 243 | {"text", text} 244 | }} 245 | }; 246 | auto data = build_payload(EventCodes::TaskRequest, boost::json::serialize(obj), uuid_); 247 | co_await ws_->async_write(net::buffer(data), net::use_awaitable); 248 | 249 | data = build_payload(EventCodes::FinishSession, "{}", uuid_); 250 | co_await ws_->async_write(net::buffer(data), net::use_awaitable); 251 
| } 252 | 253 | net::awaitable finish_connection() { 254 | 255 | auto data = build_payload(EventCodes::FinishConnection, "{}"); 256 | co_await ws_->async_write(net::buffer(data), net::use_awaitable); 257 | 258 | beast::flat_buffer buffer; 259 | co_await ws_->async_read(buffer, net::use_awaitable); 260 | auto event_code = parser_response_code(beast::buffers_to_string(buffer.data())); 261 | if(event_code != EventCodes::ConnectionFinished ) { 262 | BOOST_LOG_TRIVIAL(error) << "BytedanceTTS finish connection fail with code:" << event_code << ",data:" << beast::make_printable(buffer.data()); 263 | } 264 | } 265 | 266 | void encode_to_audio(OpusEncoder* encoder, std::vector& pcm, int frame_size, std::vector>& audio) { 267 | std::vector opus_packet_target(frame_size*2); // 最大缓冲区大小 268 | int bytes_written = opus_encode(encoder, pcm.data(), frame_size, 269 | opus_packet_target.data(), opus_packet_target.size()); 270 | if (bytes_written < 0) { 271 | BOOST_LOG_TRIVIAL(error) << "BytedanceTTSV3 opus encode failed:" << opus_strerror(bytes_written); 272 | } else { 273 | opus_packet_target.resize(bytes_written); 274 | audio.push_back(std::move(opus_packet_target)); 275 | } 276 | } 277 | public: 278 | Impl(const net::any_io_executor& executor, const YAML::Node& config, int sample_rate): 279 | executor_(executor), 280 | appid_(config["appid"].as()), 281 | access_token_(config["access_token"].as()), 282 | voice_(config["voice"].as()), 283 | sample_rate_(sample_rate), 284 | uuid_(tools::generate_uuid()) { 285 | } 286 | 287 | net::awaitable>> text_to_speak(const std::string& text) { 288 | std::vector> audio; 289 | if(co_await connect() == false) { 290 | co_return audio; 291 | } 292 | 293 | if(co_await start_connection() == false) { 294 | co_return audio; 295 | } 296 | 297 | if(co_await start_session() == false) { 298 | co_return audio; 299 | } 300 | co_await task_request(text); 301 | 302 | auto [encoder, decoder] = tools::create_opus_coders(sample_rate_); 303 | if(encoder == nullptr 
|| decoder == nullptr) { 304 | co_return audio; 305 | } 306 | auto frame_size = sample_rate_ / 1000 * 60; 307 | 308 | std::vector pcm(frame_size); 309 | int samples_decoded = 0; 310 | 311 | while(true) { 312 | beast::flat_buffer buffer; 313 | co_await ws_->async_read(buffer, net::use_awaitable); 314 | auto data = static_cast(buffer.data().data()); 315 | uint8_t message_type = data[2]; 316 | if(message_type == 0xf0) { 317 | BOOST_LOG_TRIVIAL(error) << "BytedanceTTSV3 message type error"; 318 | break; 319 | } 320 | auto event_code = parser_response_code(data); 321 | if(event_code == EventCodes::SessionFinished) { 322 | break; 323 | } else if(event_code == EventCodes::TTSResponse) { 324 | auto session_id_len = parser_response_code(data, 8); 325 | auto packet = data + 16 + session_id_len; 326 | auto frames_pos = extractOpusFrames(packet, buffer.size() - 16 - session_id_len); 327 | for(auto& pos : frames_pos) { 328 | int origin_frame_size = opus_packet_get_samples_per_frame(packet + pos.first, sample_rate_); 329 | int decoded_frame_size = opus_decode(decoder, packet + pos.first, pos.second, pcm.data() + samples_decoded, origin_frame_size, 0); 330 | if (decoded_frame_size < 0) { 331 | BOOST_LOG_TRIVIAL(error) << "BytedanceTTSV3 opus decode failed:" << origin_frame_size << " len:" << pos.second << " error:" << opus_strerror(decoded_frame_size); 332 | } else { 333 | samples_decoded += decoded_frame_size; 334 | if(samples_decoded >= frame_size) { 335 | samples_decoded -= frame_size; 336 | encode_to_audio(encoder, pcm, frame_size, audio); 337 | } 338 | } 339 | } 340 | } 341 | } 342 | if(samples_decoded > 0) { 343 | pcm.resize(samples_decoded); 344 | if(samples_decoded < frame_size) { 345 | pcm.insert(pcm.end(), frame_size - samples_decoded, 0); //补足60ms 346 | } 347 | encode_to_audio(encoder, pcm, frame_size, audio); 348 | } 349 | opus_decoder_destroy(decoder); 350 | opus_encoder_destroy(encoder); 351 | co_await finish_connection(); 352 | co_return audio; 353 | } 354 | }; 355 
| 356 | BytedanceV3::BytedanceV3(const net::any_io_executor& executor, const YAML::Node& config, int sample_rate) { 357 | impl_ = std::make_unique(executor, config, sample_rate); 358 | } 359 | 360 | BytedanceV3::~BytedanceV3() = default; 361 | 362 | net::awaitable>> BytedanceV3::text_to_speak(const std::string& text) { 363 | co_return co_await impl_->text_to_speak(text); 364 | } 365 | } 366 | } -------------------------------------------------------------------------------- /src/tts/edge.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | using wss_stream = websocket::stream>; 6 | 7 | const std::string host{"speech.platform.bing.com"}; 8 | const std::string port{"443"}; 9 | const std::string TRUSTED_CLIENT_TOKEN{"6A5AA1D4EAFF4E9FB37E23D68491D6F4"}; 10 | const constexpr char * path_format = "/consumer/speech/synthesize/readaloud/edge/v1?TrustedClientToken={}&Sec-MS-GEC={}&Sec-MS-GEC-Version=1-130.0.2849.68&ConnectionId={}"; 11 | 12 | 13 | namespace DRM { 14 | // 常量定义 15 | const int64_t WIN_EPOCH = 11644473600LL; // Windows 时间纪元偏移 (1601-01-01 到 1970-01-01 的秒数) 16 | const double S_TO_NS = 1e9; // 秒到纳秒的转换因子 17 | 18 | class DRM { 19 | public: 20 | static double clock_skew_seconds; // 静态成员,用于时钟偏差校正 21 | 22 | // 调整时钟偏差 23 | static void adj_clock_skew_seconds(double skew_seconds) { 24 | clock_skew_seconds += skew_seconds; 25 | } 26 | 27 | // 获取当前 Unix 时间戳(带时钟偏差校正) 28 | static double get_unix_timestamp() { 29 | auto now = std::chrono::system_clock::now(); 30 | auto duration = now.time_since_epoch(); 31 | double seconds = std::chrono::duration_cast(duration).count() / 1e6; 32 | return seconds + clock_skew_seconds; 33 | } 34 | 35 | // 生成 Sec-MS-GEC 令牌 36 | static std::string generate_sec_ms_gec() { 37 | // 获取当前时间戳(带时钟偏差校正) 38 | double ticks = get_unix_timestamp(); 39 | 40 | // 转换为 Windows 文件时间纪元 41 | ticks += WIN_EPOCH; 42 | 43 | // 向下取整到最近的 5 分钟(300 秒) 44 | ticks -= std::fmod(ticks, 
300.0); 45 | 46 | // 转换为 100 纳秒间隔(Windows 文件时间格式) 47 | int64_t ticks_ns = static_cast(ticks * S_TO_NS / 100); 48 | 49 | // 拼接时间戳和 TRUSTED_CLIENT_TOKEN 50 | std::ostringstream str_to_hash; 51 | str_to_hash << ticks_ns << TRUSTED_CLIENT_TOKEN; 52 | 53 | // 计算 SHA256 哈希 54 | unsigned char hash[SHA256_DIGEST_LENGTH]; 55 | SHA256(reinterpret_cast(str_to_hash.str().c_str()), 56 | str_to_hash.str().length(), hash); 57 | 58 | // 转换为大写十六进制字符串 59 | std::ostringstream hex_digest; 60 | hex_digest << std::hex << std::uppercase << std::setfill('0'); 61 | for (int i = 0; i < SHA256_DIGEST_LENGTH; ++i) { 62 | hex_digest << std::setw(2) << static_cast(hash[i]); 63 | } 64 | 65 | return hex_digest.str(); 66 | } 67 | }; 68 | 69 | // 初始化静态成员 70 | double DRM::clock_skew_seconds = 5.0; 71 | } 72 | static std::string date_to_string() { 73 | // 获取当前 UTC 时间 74 | auto now = std::chrono::system_clock::now(); 75 | std::time_t tt = std::chrono::system_clock::to_time_t(now); 76 | std::tm* utc_tm = std::gmtime(&tt); // 转换为 UTC 时间结构体 77 | 78 | // 格式化时间戳 79 | std::ostringstream oss; 80 | oss << std::put_time(utc_tm, "%a %b %d %Y %H:%M:%S"); // Thu Mar 20 2025 09:35:43 81 | oss << " GMT+0000 (Coordinated Universal Time)"; // 添加固定后缀 82 | 83 | return oss.str(); 84 | } 85 | namespace xiaozhi { 86 | namespace tts { 87 | class Edge::Impl { 88 | private: 89 | const int sample_rate_ = 24000; //edge只支持24000khz采样 90 | 91 | std::string voice_; 92 | std::string uuid_; 93 | 94 | net::any_io_executor executor_; //需要比resolver和ws先初始化,所以申明在前面 95 | std::unique_ptr ws_; 96 | 97 | net::awaitable connect() { 98 | auto sec_ms_gec = DRM::DRM::generate_sec_ms_gec(); 99 | std::string path = std::format(path_format, TRUSTED_CLIENT_TOKEN, sec_ms_gec, uuid_); 100 | try { 101 | auto stream = co_await request::connect({true, host, port, path}); 102 | ws_ = std::make_unique(std::move(stream)); 103 | } catch(const std::exception& e) { 104 | BOOST_LOG_TRIVIAL(info) << "Edge tts connect error:" << e.what(); 105 | co_return false; 
106 | } 107 | beast::get_lowest_layer(*ws_).expires_never(); 108 | ws_->set_option( 109 | websocket::stream_base::timeout::suggested( 110 | beast::role_type::client)); 111 | ws_->set_option(websocket::stream_base::decorator( 112 | [this](websocket::request_type& req) { 113 | req.set("Origin", "chrome-extension://jdiccldimpahaajbacbfkddppajiklmg"); 114 | })); 115 | auto [ec] = co_await ws_->async_handshake(host + ':' + port, path, net::as_tuple(net::use_awaitable)); 116 | if(ec) { 117 | BOOST_LOG_TRIVIAL(info) << "Edge tts handshake:" << ec.message(); 118 | co_return false; 119 | } 120 | co_return true; 121 | } 122 | 123 | net::awaitable send_command_request() { 124 | std::string timestamp = date_to_string(); 125 | std::ostringstream oss; 126 | oss << "X-Timestamp:" << timestamp << "\r\n" 127 | << "Content-Type:application/json; charset=utf-8\r\n" 128 | << "Path:speech.config\r\n" 129 | << "\r\n" 130 | << R"({"context":{"synthesis":{"audio":{"metadataoptions":{"sentenceBoundaryEnabled":"false","wordBoundaryEnabled":"false"},"outputFormat":"webm-24khz-16bit-mono-opus"}}}})" << "\r\n"; 131 | co_await ws_->async_write(net::buffer(oss.str()), net::use_awaitable); 132 | } 133 | 134 | net::awaitable send_ssml_request(const std::string& text) { 135 | std::string timestamp = date_to_string(); 136 | std::ostringstream oss; 137 | oss << "X-RequestId:" << uuid_ << "\r\n" 138 | << "Content-Type:application/ssml+xml\r\n" 139 | << "X-Timestamp:" << timestamp << "\r\n" 140 | << "Path:ssml\r\n" 141 | << "\r\n" 142 | << "" 143 | << "" 144 | << "" 145 | << text 146 | << "" 147 | << "\r\n"; 148 | co_await ws_->async_write(net::buffer(oss.str()), net::use_awaitable); 149 | } 150 | 151 | void encode_to_audio(OpusEncoder* encoder, std::vector& pcm, int frame_size, std::vector>& audio) { 152 | std::vector opus_packet_target(frame_size*2); // 最大缓冲区大小 153 | int bytes_written = opus_encode(encoder, pcm.data(), frame_size, 154 | opus_packet_target.data(), opus_packet_target.size()); 155 | if 
(bytes_written < 0) { 156 | BOOST_LOG_TRIVIAL(error) << "BytedanceTTSV3 opus encode failed:" << opus_strerror(bytes_written); 157 | } else { 158 | opus_packet_target.resize(bytes_written); 159 | audio.push_back(std::move(opus_packet_target)); 160 | } 161 | } 162 | public: 163 | Impl(const net::any_io_executor& executor, const YAML::Node& config, int sample_rate): 164 | executor_(executor), 165 | voice_(config["voice"].as()), 166 | uuid_(tools::generate_uuid()) { 167 | 168 | } 169 | 170 | net::awaitable>> text_to_speak(const std::string& text) { 171 | std::vector> audio; 172 | bool is_connected = false; 173 | int retry = 0; 174 | while(!is_connected && retry++ < 3) { //edge偶尔会连接不上,重试3次 175 | is_connected = co_await connect(); 176 | } 177 | if(!is_connected) { 178 | BOOST_LOG_TRIVIAL(info) << "Edge tts connect max retry"; 179 | co_return audio; 180 | } 181 | co_await send_command_request(); 182 | co_await send_ssml_request(text); 183 | 184 | auto [encoder, decoder] = tools::create_opus_coders(sample_rate_); 185 | if(encoder == nullptr || decoder == nullptr) { 186 | co_return audio; 187 | } 188 | auto frame_size = sample_rate_ / 1000 * 60; 189 | 190 | std::vector pcm(frame_size); 191 | int samples_decoded = 0; 192 | bool is_opus = false; 193 | while(true) { 194 | beast::flat_buffer buffer; 195 | co_await ws_->async_read(buffer, net::use_awaitable); 196 | if(ws_->got_binary()) { 197 | auto packet = static_cast(buffer.data().data()); 198 | uint32_t header_len = (static_cast(packet[0]) << 8) | static_cast(packet[1]); 199 | header_len += 2; 200 | packet += header_len; 201 | auto packet_len = buffer.size() - header_len; 202 | if(packet_len < 3) { 203 | continue; 204 | } 205 | uint32_t bin_type = (static_cast(packet[0]) << 16) | (static_cast(packet[1]) << 8) | static_cast(packet[2]); 206 | if(bin_type == 0xa3fc81) { 207 | is_opus = true; 208 | continue; 209 | } else if(bin_type == 0xab820c) { 210 | is_opus = false; 211 | continue; 212 | } 213 | if(!is_opus) { 214 | 
continue; 215 | } 216 | packet_len -= 6; 217 | int origin_frame_size = opus_packet_get_samples_per_frame(packet, sample_rate_); 218 | int decoded_frame_size = opus_decode(decoder, packet, packet_len < 120 ? packet_len : 120, pcm.data() + samples_decoded, origin_frame_size, 0); 219 | if (decoded_frame_size < 0) { 220 | BOOST_LOG_TRIVIAL(error) << "Edge tts opus decode failed:" << origin_frame_size << " len:" << buffer.size() - header_len << " error:" << opus_strerror(decoded_frame_size); 221 | } else { 222 | samples_decoded += decoded_frame_size; 223 | if(samples_decoded >= frame_size) { 224 | samples_decoded -= frame_size; 225 | encode_to_audio(encoder, pcm, frame_size, audio); 226 | } 227 | } 228 | } else { 229 | std::string data = beast::buffers_to_string(buffer.data()); 230 | auto p = data.find("Path:turn.end"); 231 | if(p != data.npos) { 232 | break; 233 | } 234 | } 235 | 236 | } 237 | if(samples_decoded > 0) { 238 | pcm.resize(samples_decoded); 239 | if(samples_decoded < frame_size) { 240 | pcm.insert(pcm.end(), frame_size - samples_decoded, 0); //补足60ms 241 | } 242 | encode_to_audio(encoder, pcm, frame_size, audio); 243 | } 244 | opus_decoder_destroy(decoder); 245 | opus_encoder_destroy(encoder); 246 | co_return audio; 247 | } 248 | }; 249 | 250 | 251 | Edge::Edge(const net::any_io_executor& executor, const YAML::Node& config, int sample_rate) { 252 | impl_ = std::make_unique(executor, config, sample_rate); 253 | } 254 | 255 | Edge::~Edge() = default; 256 | 257 | net::awaitable>> Edge::text_to_speak(const std::string& text) { 258 | co_return co_await impl_->text_to_speak(text); 259 | } 260 | } 261 | } -------------------------------------------------------------------------------- /tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | include(Catch) 2 | 3 | function(add_target_with_config target_name) 4 | add_executable(${target_name} ${ARGN}) 5 | target_link_libraries(${target_name} PRIVATE 
Catch2::Catch2WithMain) 6 | catch_discover_tests(${target_name}) 7 | 8 | endfunction() 9 | 10 | file(GLOB_RECURSE cpp_files *.cpp) 11 | 12 | foreach(cpp_file ${cpp_files}) 13 | file(RELATIVE_PATH file_path "${PROJECT_SOURCE_DIR}/tests" "${cpp_file}") 14 | string(REPLACE "/" "_" file_with_path ${file_path}) 15 | get_filename_component(name_with_path ${file_with_path} NAME_WE) 16 | add_target_with_config(${name_with_path} ${cpp_file}) 17 | message(STATUS "Add test: ${name_with_path}") 18 | endforeach() 19 | 20 | target_link_libraries(common_test_setting PUBLIC common precomp) 21 | target_link_libraries(common_test_request PUBLIC common precomp) 22 | target_link_libraries(test_find_last_segment PUBLIC common) -------------------------------------------------------------------------------- /tests/common/test_request.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | TEST_CASE("request get") { 9 | net::io_context ioc; 10 | std::promise promise; 11 | auto future = promise.get_future(); 12 | net::co_spawn(ioc, [&promise]() -> net::awaitable { 13 | boost::json::object header = { 14 | {"Content-Type", "application/json"} 15 | }; 16 | auto ret = co_await request::get("https://echo.websocket.org/test?v=1", header); 17 | promise.set_value(ret); 18 | }, net::detached); 19 | ioc.run(); 20 | std::string ret = future.get(); 21 | // REQUIRE(ret.find("Content-Type: application/json") != std::string::npos); 22 | REQUIRE_THAT(ret, Catch::Matchers::ContainsSubstring("Content-Type") && Catch::Matchers::ContainsSubstring("application/json")); 23 | } 24 | 25 | TEST_CASE("request post") { 26 | net::io_context ioc; 27 | std::promise promise; 28 | auto future = promise.get_future(); 29 | net::co_spawn(ioc, [&promise]() -> net::awaitable { 30 | boost::json::object header = { 31 | {"Content-Type", "application/json"} 32 | }; 33 | auto ret = co_await 
request::post("https://echo.websocket.org/test?v=1", header, R"({"test": "value"})"); 34 | promise.set_value(ret); 35 | }, net::detached); 36 | ioc.run(); 37 | std::string ret = future.get(); 38 | // REQUIRE(ret.find("Content-Type: application/json") != std::string::npos); 39 | REQUIRE_THAT(ret, Catch::Matchers::ContainsSubstring("Content-Type") 40 | && Catch::Matchers::ContainsSubstring("application/json") 41 | && Catch::Matchers::ContainsSubstring("test") 42 | && Catch::Matchers::ContainsSubstring("value")); 43 | } 44 | 45 | TEST_CASE("request stream post") { 46 | net::io_context ioc; 47 | std::promise promise; 48 | auto future = promise.get_future(); 49 | net::co_spawn(ioc, [&promise]() -> net::awaitable { 50 | boost::json::object header = { 51 | {"Content-Type", "application/json"} 52 | }; 53 | std::string value; 54 | co_await request::stream_post("https://echo.websocket.org/test?v=1", header, R"({"stream": true})", [&promise, &value](std::string data) { 55 | value += data; 56 | }); 57 | promise.set_value(value); 58 | }, net::detached); 59 | ioc.run(); 60 | std::string ret = future.get(); 61 | // REQUIRE(ret.find("Content-Type: application/json") != std::string::npos); 62 | REQUIRE_THAT(ret, Catch::Matchers::ContainsSubstring("Content-Type") 63 | && Catch::Matchers::ContainsSubstring("application/json") 64 | && Catch::Matchers::ContainsSubstring("stream") 65 | && Catch::Matchers::ContainsSubstring("true")); 66 | } -------------------------------------------------------------------------------- /tests/common/test_setting.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | TEST_CASE("load setting") { 5 | auto setting = xiaozhi::Setting::getSetting(); 6 | REQUIRE(setting->config["server"]["ip"].as() == "0.0.0.0"); 7 | REQUIRE(setting->config["server"]["port"].as() == 8000); 8 | } -------------------------------------------------------------------------------- /tests/test_find_last_segment.cpp: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | TEST_CASE("find last segment") { 5 | std::string txt = "."; 6 | auto p = tools::find_last_segment(txt); 7 | REQUIRE(p == 0); 8 | txt = "。"; 9 | p = tools::find_last_segment(txt); 10 | REQUIRE(p == 2); //只有一个标点符号的情况无视错误 11 | 12 | txt = "你好,我的世界。"; 13 | p = tools::find_last_segment(txt); 14 | REQUIRE(p == 23); 15 | 16 | txt = "你好,我的世界"; 17 | p = tools::find_last_segment(txt); 18 | REQUIRE(p == 8); 19 | 20 | txt = "。你好我的世界"; 21 | p = tools::find_last_segment(txt); 22 | REQUIRE(p == 2); 23 | 24 | txt = "你好,我的世界."; 25 | p = tools::find_last_segment(txt); 26 | REQUIRE(p == 19); 27 | 28 | txt = "你好,我的世界"; 29 | p = tools::find_last_segment(txt); 30 | REQUIRE(p == 6); 31 | 32 | txt = ".你好我的世界"; 33 | p = tools::find_last_segment(txt); 34 | REQUIRE(p == 0); 35 | 36 | txt = "你好、我的世界"; 37 | p = tools::find_last_segment(txt); 38 | REQUIRE(p == 8); 39 | txt = "你好,我的世界!"; 40 | p = tools::find_last_segment(txt); 41 | REQUIRE(p == 23); 42 | txt = "你好,我的世界?"; 43 | p = tools::find_last_segment(txt); 44 | REQUIRE(p == 23); 45 | txt = "你好,我的世界;"; 46 | p = tools::find_last_segment(txt); 47 | REQUIRE(p == 23); 48 | txt = "你好,我的世界:"; 49 | p = tools::find_last_segment(txt); 50 | REQUIRE(p == 23); 51 | } --------------------------------------------------------------------------------