├── .clang-format ├── cmake ├── FetchNlohmannJson.cmake ├── moonshine-config.cmake.in ├── FetchSDL.cmake └── FetchOnnxruntime.cmake ├── .gitignore ├── .github └── workflows │ ├── clang-format-check.yaml │ ├── build.yaml │ └── release.yaml ├── scripts ├── build.sh ├── build.ps1 └── onnx_get_names.py ├── examples ├── CMakeLists.txt ├── demo.cpp └── live.cpp ├── LICENSE ├── src ├── moonshine.hpp └── moonshine.cpp ├── CMakeLists.txt └── README.md /.clang-format: -------------------------------------------------------------------------------- 1 | # .clang-format 2 | BasedOnStyle: Google 3 | IndentWidth: 4 4 | TabWidth: 4 5 | UseTab: Never 6 | ColumnLimit: 100 7 | AllowShortFunctionsOnASingleLine: Empty 8 | SortIncludes: false # keep #include lines in the order they are written 9 | BreakBeforeBraces: Allman -------------------------------------------------------------------------------- /cmake/FetchNlohmannJson.cmake: -------------------------------------------------------------------------------- 1 | include(FetchContent) 2 | # Fetch nlohmann/json v3.11.3 and make its targets (including nlohmann_json::nlohmann_json, which the root CMakeLists.txt links) available. 3 | FetchContent_Declare(json 4 | DOWNLOAD_EXTRACT_TIMESTAMP 1 5 | URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz) # NOTE(review): unlike FetchSDL.cmake and FetchOnnxruntime.cmake, no URL_HASH is pinned here, so this download is not integrity-checked 6 | FetchContent_MakeAvailable(json) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | CMakeLists.txt.user 2 | CMakeCache.txt 3 | CMakeFiles 4 | CMakeScripts 5 | Testing 6 | Makefile 7 | cmake_install.cmake 8 | install_manifest.txt 9 | compile_commands.json 10 | CTestTestfile.cmake 11 | _deps 12 | CMakeUserPresets.json 13 | build/ 14 | .vscode/ 15 | dist/ 16 | models/ -------------------------------------------------------------------------------- /cmake/moonshine-config.cmake.in: -------------------------------------------------------------------------------- 1 | # cmake/moonshine-config.cmake.in 2 | @PACKAGE_INIT@ 3 | 4 | include("${CMAKE_CURRENT_LIST_DIR}/moonshine-targets.cmake") # imports the Moonshine:: targets exported by the root CMakeLists.txt 5 | 6 | # Find dependencies 7 |
include(CMakeFindDependencyMacro) 8 | 9 | # Add any future dependencies here if needed 10 | # find_dependency(SomePackage) 11 | 12 | check_required_components(moonshine) -------------------------------------------------------------------------------- /.github/workflows/clang-format-check.yaml: -------------------------------------------------------------------------------- 1 | name: Clang Format Check 2 | 3 | on: 4 | workflow_call: 5 | 6 | jobs: 7 | clang-format-check: 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - name: Checkout code 12 | uses: actions/checkout@v4 13 | 14 | - name: Set up Clang 15 | run: sudo apt-get update && sudo apt-get install -y clang-format 16 | 17 | - name: Run Clang Format 18 | run: | 19 | clang-format --version 20 | find ./src -name '*.cpp' -o -name '*.hpp' | xargs clang-format -i 21 | git diff --exit-code -------------------------------------------------------------------------------- /scripts/build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Configure, build and install moonshine.cpp into the repository's dist/ directory. 3 | # Usage: scripts/build.sh [BUILD_TYPE] (BUILD_TYPE defaults to Release); stops at the first failing step. 4 | set -euo pipefail 5 | 6 | # Set default build type to Release 7 | BUILD_TYPE="${1:-Release}" 8 | 9 | # find the directory of the script 10 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 11 | ROOT_DIR="$(dirname "$DIR")" 12 | 13 | # Create build directory 14 | mkdir -p "$ROOT_DIR/build" 15 | 16 | # Configure (ONNX Runtime is downloaded by cmake/FetchOnnxruntime.cmake); paths are quoted so directories with spaces work 17 | cmake -B "$ROOT_DIR/build" -S "$ROOT_DIR" -DCMAKE_BUILD_TYPE="$BUILD_TYPE" \ 18 | -DCMAKE_INSTALL_PREFIX="$ROOT_DIR/dist" 19 | 20 | # Build 21 | cmake --build "$ROOT_DIR/build" --config "$BUILD_TYPE" 22 | 23 | # Install into $ROOT_DIR/dist 24 | cmake --install "$ROOT_DIR/build" --prefix "$ROOT_DIR/dist" \ 25 | --config "$BUILD_TYPE" 26 | -------------------------------------------------------------------------------- /examples/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # examples/CMakeLists.txt 2 | include(GNUInstallDirs) # defines CMAKE_INSTALL_BINDIR for the install() rule below, independent of include order in the parent directory 3 | include(../cmake/FetchSDL.cmake) 4 | 5 | # Create example executables 6 | add_executable(moonshine_example
demo.cpp) 7 | add_executable(moonshine_live live.cpp) 8 | 9 | target_link_libraries(moonshine_example 10 | PRIVATE 11 | moonshine 12 | ) 13 | 14 | target_link_libraries(moonshine_live 15 | PRIVATE 16 | moonshine 17 | SDL2::SDL2 18 | ) 19 | 20 | target_include_directories(moonshine_example 21 | PRIVATE 22 | ${ONNXRUNTIME_INCLUDE_DIRS} 23 | ) 24 | 25 | target_include_directories(moonshine_live 26 | PRIVATE 27 | ${ONNXRUNTIME_INCLUDE_DIRS} 28 | ${SDL2_INCLUDE_DIRS} 29 | ) 30 | 31 | # Install the example executables to the bin directory 32 | install(TARGETS moonshine_example moonshine_live 33 | RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} 34 | ) 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Locaal AI: Open Tools for AI Developers 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /scripts/build.ps1: -------------------------------------------------------------------------------- 1 | param ( 2 | [string]$BuildType = "Release", 3 | [switch]$BuildExamples = $false 4 | ) 5 | 6 | # find the directory of the script 7 | $DIR = Split-Path -Parent $MyInvocation.MyCommand.Definition 8 | $ROOT_DIR = Split-Path -Parent $DIR 9 | 10 | # Create build directory 11 | New-Item -ItemType Directory -Force -Path "$ROOT_DIR/build" -ErrorAction SilentlyContinue 12 | 13 | # Configure (ONNX Runtime is downloaded by cmake/FetchOnnxruntime.cmake). Arguments are kept in an array and passed with the call operator rather than Invoke-Expression, so paths containing spaces or PowerShell metacharacters reach cmake unchanged. 14 | $cmakeArgs = @("-B", "$ROOT_DIR/build", "-S", "$ROOT_DIR", "-DCMAKE_BUILD_TYPE=$BuildType", "-DCMAKE_INSTALL_PREFIX=$ROOT_DIR/dist") 15 | 16 | if ($BuildExamples) { 17 | $cmakeArgs += "-DBUILD_EXAMPLES=ON" 18 | } 19 | 20 | & cmake @cmakeArgs 21 | if ($LASTEXITCODE -ne 0) { 22 | Write-Host "Config failed." 23 | exit 1 24 | } 25 | 26 | # Build 27 | cmake --build "$ROOT_DIR/build" --config $BuildType 28 | if ($LASTEXITCODE -ne 0) { 29 | Write-Host "Build failed." 30 | exit 1 31 | } 32 | 33 | # Install into $ROOT_DIR/dist 34 | cmake --install "$ROOT_DIR/build" --prefix "$ROOT_DIR/dist" --config $BuildType 35 | if ($LASTEXITCODE -ne 0) { 36 | Write-Host "Install failed." 37 | exit 1 38 | } 39 | 40 | Write-Host "Build and install succeeded."
-------------------------------------------------------------------------------- /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | workflow_call: 9 | 10 | jobs: 11 | check-format: 12 | name: Check Formatting 🔍 13 | uses: ./.github/workflows/clang-format-check.yaml 14 | permissions: 15 | contents: read 16 | 17 | build: 18 | 19 | strategy: 20 | matrix: 21 | os: [ubuntu-latest, macos-latest, windows-latest] 22 | build_type: [Release, Debug] 23 | 24 | runs-on: ${{ matrix.os }} 25 | 26 | steps: 27 | - name: Checkout code 28 | uses: actions/checkout@v4 29 | 30 | - name: Set up CMake 31 | uses: jwlawson/actions-setup-cmake@v2 32 | 33 | - name: Configure CMake 34 | run: cmake -B build -S . -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} -DCMAKE_INSTALL_PREFIX=${{ github.workspace }}/dist 35 | 36 | - name: Build 37 | run: cmake --build build --config ${{ matrix.build_type }} 38 | 39 | - name: Install 40 | run: cmake --install build --prefix ${{ github.workspace }}/dist --config ${{ matrix.build_type }} 41 | 42 | - name: Upload artifacts 43 | uses: actions/upload-artifact@v4 44 | with: 45 | name: ${{ matrix.os }}-${{ matrix.build_type }} 46 | path: ${{ github.workspace }}/dist -------------------------------------------------------------------------------- /scripts/onnx_get_names.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import onnxruntime as ort 3 | import argparse 4 | 5 | 6 | def extract_io_info(model_path): 7 | """ 8 | Extract input and output information from an ONNX model.
9 | 10 | Args: 11 | model_path (str): Path to the .onnx model file 12 | """ 13 | try: 14 | # Create inference session 15 | session = ort.InferenceSession(model_path) 16 | 17 | # Get input details 18 | print("\nModel Inputs:") 19 | for input_detail in session.get_inputs(): 20 | print(f"- Name: {input_detail.name}") 21 | print(f" Shape: {input_detail.shape}") 22 | print(f" Type: {input_detail.type}") 23 | 24 | # Get output details 25 | print("\nModel Outputs:") 26 | for output_detail in session.get_outputs(): 27 | print(f"- Name: {output_detail.name}") 28 | print(f" Shape: {output_detail.shape}") 29 | print(f" Type: {output_detail.type}") 30 | 31 | except Exception as e: 32 | print(f"Error loading model: {str(e)}", file=sys.stderr) 33 | sys.exit(1) 34 | 35 | 36 | def main(): 37 | parser = argparse.ArgumentParser( 38 | description="Extract input/output information from ONNX model" 39 | ) 40 | parser.add_argument("model_path", help="Path to the .onnx model file") 41 | args = parser.parse_args() 42 | 43 | extract_io_info(args.model_path) 44 | 45 | 46 | if __name__ == "__main__": 47 | main() 48 | -------------------------------------------------------------------------------- /cmake/FetchSDL.cmake: -------------------------------------------------------------------------------- 1 | # FetchSDL.cmake 2 | include(FetchContent) 3 | 4 | set(SDL_VERSION "2.30.9") 5 | 6 | FetchContent_Declare( 7 | SDL2 8 | DOWNLOAD_EXTRACT_TIMESTAMP 1 9 | URL https://github.com/libsdl-org/SDL/releases/download/release-${SDL_VERSION}/SDL2-${SDL_VERSION}.zip 10 | URL_HASH SHA256=ec855bcd815b4b63d0c958c42c2923311c656227d6e0c1ae1e721406d346444b 11 | ) 12 | 13 | # Platform specific options 14 | if(WIN32) 15 | set(SDL_SHARED ON CACHE BOOL "Build SDL2 shared library") 16 | set(SDL_STATIC OFF CACHE BOOL "Build SDL2 static library") 17 | elseif(APPLE) 18 | set(SDL_SHARED ON CACHE BOOL "Build SDL2 shared library") 19 | set(SDL_STATIC OFF CACHE BOOL "Build SDL2 static library") 20 | set(SDL_FRAMEWORK OFF CACHE BOOL "Build
SDL2 framework") 21 | else() 22 | set(SDL_SHARED ON CACHE BOOL "Build SDL2 shared library") 23 | set(SDL_STATIC OFF CACHE BOOL "Build SDL2 static library") 24 | endif() 25 | 26 | FetchContent_MakeAvailable(SDL2) 27 | 28 | # Installation rules 29 | if(WIN32) 30 | install( 31 | FILES $ 32 | DESTINATION bin 33 | ) 34 | install( 35 | FILES $ 36 | DESTINATION lib 37 | ) 38 | else() 39 | install( 40 | TARGETS SDL2 41 | LIBRARY DESTINATION lib 42 | ARCHIVE DESTINATION lib 43 | RUNTIME DESTINATION bin 44 | ) 45 | endif() 46 | 47 | # Install headers 48 | install( 49 | DIRECTORY ${SDL2_SOURCE_DIR}/include/ 50 | DESTINATION include/SDL2 51 | FILES_MATCHING PATTERN "*.h" 52 | ) 53 | -------------------------------------------------------------------------------- /src/moonshine.hpp: -------------------------------------------------------------------------------- 1 | // moonshine.hpp 2 | #pragma once 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | /** 10 | * @class MoonshineModel 11 | * @brief A class to handle the ONNX model inference for the Moonshine project. 12 | */ 13 | class MoonshineModel 14 | { 15 | public: 16 | /** 17 | * @brief Constructor for the MoonshineModel class. 18 | * @param models_dir The directory containing the ONNX model files. 19 | */ 20 | explicit MoonshineModel(const std::string &models_dir); 21 | 22 | /** 23 | * @brief Generate tokens from audio samples. 24 | * @param audio_samples A vector of normalized float32 audio samples in the range [-1.0, 1.0]. 25 | * @param max_len The maximum length of the generated tokens. Default is 0 (no limit). 26 | * @return A vector of generated token IDs. 27 | */ 28 | std::vector generate(const std::vector &audio_samples, size_t max_len = 0); 29 | 30 | /** 31 | * @brief Detokenize the generated tokens into a string. 32 | * @param tokens A vector of token IDs. 33 | * @return A detokenized string.
34 | */ 35 | std::string detokenize(const std::vector &tokens); 36 | 37 | private: 38 | // Declared before the session pointers on purpose: members are destroyed in reverse declaration order, and ONNX Runtime expects the Env to outlive every Session created with it. 39 | Ort::Env env_; ///< ONNX Runtime environment. 40 | Ort::MemoryInfo memory_info_; ///< Memory information for ONNX Runtime. 41 | std::unique_ptr preprocess_; ///< ONNX session for the preprocessing model. 42 | std::unique_ptr encode_; ///< ONNX session for the encoding model. 43 | std::unique_ptr 44 | uncached_decode_; ///< ONNX session for the uncached decoding model. 45 | std::unique_ptr cached_decode_; ///< ONNX session for the cached decoding model. 46 | 47 | /** 48 | * @brief Helper function to create an ONNX session. 49 | * @param model_path The path to the ONNX model file. 50 | * @return A unique pointer to the created ONNX session. 51 | */ 52 | std::unique_ptr createSession(const std::string &model_path); 53 | 54 | std::map token_id_to_token_; ///< Map from token IDs to token strings. 55 | 56 | /** 57 | * @brief Load the tokenizer from a JSON string. 58 | * @param tokenizer_content The JSON string containing the tokenizer data.
59 | */ 60 | void load_tokenizer(const std::string &tokenizer_content); 61 | }; 62 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # CMakeLists.txt 2 | cmake_minimum_required(VERSION 3.15) 3 | project(moonshine-cpp VERSION 1.0.1 LANGUAGES CXX) 4 | 5 | # Set C++ standard 6 | set(CMAKE_CXX_STANDARD 17) 7 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 8 | set(CMAKE_CXX_EXTENSIONS OFF) 9 | 10 | # Build type 11 | if(NOT CMAKE_BUILD_TYPE) 12 | set(CMAKE_BUILD_TYPE Release) 13 | endif() 14 | 15 | # Compiler flags 16 | if(MSVC) 17 | add_compile_options(/W4 /utf-8) 18 | else() 19 | add_compile_options(-Wall -Wextra -Wpedantic) 20 | endif() 21 | include(GNUInstallDirs) # must precede the Fetch*.cmake scripts and examples/, which use the CMAKE_INSTALL_* variables 22 | # Fetch ONNX Runtime and nlohmann/json 23 | include(cmake/FetchOnnxruntime.cmake) 24 | include(cmake/FetchNlohmannJson.cmake) 25 | 26 | # Create moonshine library 27 | add_library(moonshine 28 | src/moonshine.cpp 29 | ) 30 | 31 | target_include_directories(moonshine 32 | PUBLIC 33 | $ 34 | $ 35 | PRIVATE 36 | ${ONNXRUNTIME_INCLUDE_DIRS} 37 | ) 38 | 39 | target_link_libraries(moonshine PUBLIC nlohmann_json::nlohmann_json) 40 | target_link_libraries(moonshine PUBLIC Ort) # PUBLIC: moonshine.hpp exposes Ort types (Ort::Env, Ort::MemoryInfo), so both the library itself and its users need ONNX Runtime 41 | 42 | # Option to build examples 43 | option(BUILD_EXAMPLES "Build example programs" OFF) 44 | 45 | if(BUILD_EXAMPLES) 46 | add_subdirectory(examples) 47 | endif() 48 | 49 | # Installation rules (GNUInstallDirs is included near the top) 50 | 51 | 52 | install(TARGETS moonshine Ort nlohmann_json 53 | EXPORT moonshine-targets 54 | LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} 55 | ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} 56 | RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} 57 | INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} 58 | ) 59 | 60 | install(FILES src/moonshine.hpp 61 | DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/moonshine 62 | ) 63 | 64 | # Export targets 65 | install(EXPORT moonshine-targets 66 | FILE moonshine-targets.cmake 67 | NAMESPACE Moonshine:: 68 |
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/moonshine 69 | ) 70 | 71 | # Create and install config file 72 | include(CMakePackageConfigHelpers) 73 | 74 | configure_package_config_file( 75 | ${CMAKE_CURRENT_SOURCE_DIR}/cmake/moonshine-config.cmake.in 76 | ${CMAKE_CURRENT_BINARY_DIR}/moonshine-config.cmake 77 | INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/moonshine 78 | ) 79 | 80 | write_basic_package_version_file( 81 | ${CMAKE_CURRENT_BINARY_DIR}/moonshine-config-version.cmake 82 | VERSION ${PROJECT_VERSION} 83 | COMPATIBILITY SameMajorVersion 84 | ) 85 | 86 | install(FILES 87 | ${CMAKE_CURRENT_BINARY_DIR}/moonshine-config.cmake 88 | ${CMAKE_CURRENT_BINARY_DIR}/moonshine-config-version.cmake 89 | DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/moonshine 90 | ) 91 | -------------------------------------------------------------------------------- /examples/demo.cpp: -------------------------------------------------------------------------------- 1 | // demo.cpp 2 | #include 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | // Read up to 30 s of 16 kHz 16-bit PCM samples from a WAV file and normalize them to [-1.0, 1.0); throws std::runtime_error if the file cannot be opened. 9 | std::vector readWavFile(const std::string &filename) 10 | { 11 | std::ifstream file(filename, std::ios::binary); 12 | if (!file.is_open()) 13 | { 14 | throw std::runtime_error("Could not open WAV file: " + filename); 15 | } 16 | 17 | // Skip the canonical 44-byte WAV header. NOTE: assumes mono 16-bit PCM at 16 kHz with no extra chunks before the sample data; the header is not validated. 18 | file.seekg(44); 19 | 20 | // Read PCM data 21 | std::vector pcm_data; 22 | int16_t sample; 23 | while (file.read(reinterpret_cast(&sample), sizeof(int16_t))) 24 | { 25 | pcm_data.push_back(sample); 26 | if (pcm_data.size() >= 16000 * 30) // stop after 30 s of 16 kHz audio 27 | { 28 | break; 29 | } 30 | } 31 | 32 | // Convert to float32 normalized [-1.0, 1.0] 33 | std::vector float_data; 34 | float_data.reserve(pcm_data.size()); 35 | for (const auto &pcm_sample : pcm_data) 36 | { 37 | float_data.push_back(static_cast(pcm_sample) / 32768.0f); 38 | } 39 | 40 | return float_data; 41 | } 42 | 43 | int main(int argc, char *argv[]) 44 | { 45 | if (argc != 3) 46 | { 47 | std::cerr << "Usage: " << argv[0] << " \n"; 48 | return 1;
49 | } 50 | 51 | try 52 | { 53 | // Read audio file 54 | auto audio_samples = readWavFile(argv[2]); 55 | 56 | std::cout << "Read " << audio_samples.size() << " samples from file (" 57 | << audio_samples.size() / 16000.0 << " seconds)\n"; 58 | 59 | // Initialize model 60 | MoonshineModel model(argv[1]); 61 | 62 | // Generate tokens (timed with steady_clock, which is monotonic and therefore suited to measuring durations) 63 | 64 | auto start = std::chrono::steady_clock::now(); 65 | auto tokens = model.generate(audio_samples); 66 | auto end = std::chrono::steady_clock::now(); 67 | std::chrono::duration duration = end - start; 68 | 69 | std::cout << "Token generation took " << duration.count() << " seconds\n"; 70 | 71 | // Print the raw token IDs (they are converted to text below by detokenize) 72 | std::cout << "Generated tokens: "; 73 | for (const auto &token : tokens) 74 | { 75 | std::cout << token << " "; 76 | } 77 | std::cout << "\n"; 78 | 79 | // Detokenize tokens 80 | std::string result = model.detokenize(tokens); 81 | std::cout << "Detokenized: " << result << "\n"; 82 | } 83 | catch (const Ort::Exception &e) 84 | { 85 | std::cerr << "ONNX Runtime error: " << e.what() << "\n"; 86 | return 1; 87 | } 88 | catch (const std::exception &e) 89 | { 90 | std::cerr << "Error: " << e.what() << "\n"; 91 | return 1; 92 | } 93 | 94 | return 0; 95 | } 96 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Release 2 | run-name: ${{ github.ref_name }} release run 🚀 3 | on: 4 | push: 5 | branches: 6 | - main 7 | tags: 8 | - '*' 9 | permissions: 10 | contents: write 11 | concurrency: 12 | group: '${{ github.workflow }} @ ${{ github.ref }}' 13 | cancel-in-progress: ${{ github.ref_type == 'tag' }} 14 | jobs: 15 | build-project: 16 | name: Build Project 🧱 17 | uses: ./.github/workflows/build.yaml 18 | secrets: inherit 19 | permissions: 20 | contents: read 21 | 22 | create-release: 23 | name: Create Release 🛫
24 | if: github.ref_type == 'tag' 25 | runs-on: ubuntu-22.04 26 | needs: build-project 27 | defaults: 28 | run: 29 | shell: bash 30 | steps: 31 | - name: Check Release Tag ☑️ 32 | id: check 33 | run: | 34 | : Check Release Tag ☑️ 35 | if [[ "${RUNNER_DEBUG}" ]]; then set -x; fi 36 | shopt -s extglob 37 | 38 | case "${GITHUB_REF_NAME}" in 39 | +([0-9]).+([0-9]).+([0-9]) ) 40 | echo 'validTag=true' >> $GITHUB_OUTPUT 41 | echo 'prerelease=false' >> $GITHUB_OUTPUT 42 | echo "version=${GITHUB_REF_NAME}" >> $GITHUB_OUTPUT 43 | ;; 44 | +([0-9]).+([0-9]).+([0-9])-@(beta|rc)*([0-9]) ) 45 | echo 'validTag=true' >> $GITHUB_OUTPUT 46 | echo 'prerelease=true' >> $GITHUB_OUTPUT 47 | echo "version=${GITHUB_REF_NAME}" >> $GITHUB_OUTPUT 48 | ;; 49 | *) echo 'validTag=false' >> $GITHUB_OUTPUT ;; 50 | esac 51 | 52 | - name: Download Build Artifacts 📥 53 | uses: actions/download-artifact@v4 54 | if: fromJSON(steps.check.outputs.validTag) 55 | id: download 56 | 57 | - name: Print downloaded artifacts 📥 58 | if: fromJSON(steps.check.outputs.validTag) 59 | run: | 60 | : Print downloaded artifacts 📥 61 | if [[ "${RUNNER_DEBUG}" ]]; then set -x; fi 62 | shopt -s extglob 63 | 64 | ls -laR "${{ steps.download.outputs.download-path }}" 65 | 66 | - name: Rename Files 🏷️ 67 | if: fromJSON(steps.check.outputs.validTag) 68 | run: | 69 | : Rename Files 🏷️ 70 | if [[ "${RUNNER_DEBUG}" ]]; then set -x; fi 71 | shopt -s extglob 72 | shopt -s nullglob 73 | 74 | commit_hash="${GITHUB_SHA:0:9}" 75 | 76 | variants=( 77 | 'ubuntu-latest-Release' 78 | 'macos-latest-Release' 79 | 'windows-latest-Release' 80 | 'ubuntu-latest-Debug' 81 | 'macos-latest-Debug' 82 | 'windows-latest-Debug' 83 | ) 84 | 85 | mkdir -p "${{ github.workspace }}/uploads" 86 | 87 | for variant in "${variants[@]}"; do 88 | zip -r "${{ github.workspace }}/uploads/moonshine-cpp-${variant}-${GITHUB_REF_NAME}.zip" "${variant}" 89 | done 90 | 91 | - name: Upload Release Artifacts 📤 92 | if: fromJSON(steps.check.outputs.validTag) 93 | uses:
actions/upload-artifact@v4 94 | with: 95 | name: moonshine-cpp-${{ github.ref_name }} 96 | path: uploads 97 | 98 | - name: Generate Checksums 🪪 99 | if: fromJSON(steps.check.outputs.validTag) 100 | run: | 101 | : Generate Checksums 🪪 102 | if [[ "${RUNNER_DEBUG}" ]]; then set -x; fi 103 | shopt -s extglob 104 | 105 | echo "### Checksums" > ${{ github.workspace }}/CHECKSUMS.txt 106 | # find the files from the above step and generate checksums 107 | for file in ${{ github.workspace }}/uploads/moonshine-cpp-*; do 108 | echo " ${file##*/}: $(sha256sum "${file}" | cut -d " " -f 1)" >> ${{ github.workspace }}/CHECKSUMS.txt 109 | done 110 | 111 | - name: Create Release 🛫 112 | if: fromJSON(steps.check.outputs.validTag) 113 | id: create_release 114 | uses: softprops/action-gh-release@v2 115 | with: 116 | draft: false 117 | prerelease: ${{ fromJSON(steps.check.outputs.prerelease) }} 118 | name: Moonshine.cpp v${{ steps.check.outputs.version }} 119 | generate_release_notes: true 120 | body_path: ${{ github.workspace }}/CHECKSUMS.txt 121 | files: | 122 | ${{ github.workspace }}/uploads/moonshine-cpp-*.zip 123 | -------------------------------------------------------------------------------- /cmake/FetchOnnxruntime.cmake: -------------------------------------------------------------------------------- 1 | include(FetchContent) 2 | include(GNUInstallDirs) # provides CMAKE_INSTALL_LIBDIR/BINDIR for the install() rules below, whatever the caller's include order 3 | set(CUSTOM_ONNXRUNTIME_URL 4 | "" 5 | CACHE STRING "URL of a downloaded ONNX Runtime tarball") 6 | 7 | set(CUSTOM_ONNXRUNTIME_HASH 8 | "" 9 | CACHE STRING "Hash of a downloaded ONNX Runtime tarball") 10 | 11 | set(Onnxruntime_VERSION "1.19.2") 12 | 13 | if(CUSTOM_ONNXRUNTIME_URL STREQUAL "") 14 | set(USE_PREDEFINED_ONNXRUNTIME ON) 15 | else() 16 | if(CUSTOM_ONNXRUNTIME_HASH STREQUAL "") 17 | message(FATAL_ERROR "Both of CUSTOM_ONNXRUNTIME_URL and CUSTOM_ONNXRUNTIME_HASH must be present!") 18 | else() 19 | set(USE_PREDEFINED_ONNXRUNTIME OFF) 20 | endif() 21 | endif() 22 | 23 | if(USE_PREDEFINED_ONNXRUNTIME) 24 | set(Onnxruntime_BASEURL
"https://github.com/microsoft/onnxruntime/releases/download/v${Onnxruntime_VERSION}") 25 | 26 | if(APPLE) 27 | set(Onnxruntime_URL "${Onnxruntime_BASEURL}/onnxruntime-osx-universal2-${Onnxruntime_VERSION}.tgz") 28 | set(Onnxruntime_HASH SHA256=b0289ddbc32f76e5d385abc7b74cc7c2c51cdf2285b7d118bf9d71206e5aee3a) 29 | elseif(MSVC) 30 | set(Onnxruntime_URL "${Onnxruntime_BASEURL}/onnxruntime-win-x64-${Onnxruntime_VERSION}.zip") 31 | set(Onnxruntime_HASH SHA256=dc4f841e511977c0a4f02e5066c3d9a58427644010ab4f89b918614a1cd4c2b0) # NOTE(review): identical to the linux-aarch64 hash below, so at least one of the two is wrong; verify both against the v1.19.2 release assets 32 | else() 33 | if(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") 34 | set(Onnxruntime_URL "${Onnxruntime_BASEURL}/onnxruntime-linux-aarch64-${Onnxruntime_VERSION}.tgz") 35 | set(Onnxruntime_HASH SHA256=dc4f841e511977c0a4f02e5066c3d9a58427644010ab4f89b918614a1cd4c2b0) 36 | else() 37 | set(Onnxruntime_URL "${Onnxruntime_BASEURL}/onnxruntime-linux-x64-gpu-${Onnxruntime_VERSION}.tgz") 38 | set(Onnxruntime_HASH SHA256=4d1c10f0b410b67261302c6e18bb1b05ba924ca9081e3a26959e0d12ab69f534) 39 | endif() 40 | endif() 41 | else() 42 | set(Onnxruntime_URL "${CUSTOM_ONNXRUNTIME_URL}") 43 | set(Onnxruntime_HASH "${CUSTOM_ONNXRUNTIME_HASH}") 44 | endif() 45 | 46 | FetchContent_Declare( 47 | onnxruntime 48 | DOWNLOAD_EXTRACT_TIMESTAMP 1 49 | URL ${Onnxruntime_URL} 50 | URL_HASH ${Onnxruntime_HASH}) 51 | FetchContent_MakeAvailable(onnxruntime) 52 | 53 | add_library(Ort INTERFACE) 54 | set(ONNXRUNTIME_INCLUDE_DIRS "${onnxruntime_SOURCE_DIR}/include") 55 | 56 | if(APPLE) 57 | set(Onnxruntime_LIB "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime.${Onnxruntime_VERSION}.dylib") 58 | 59 | target_link_libraries(Ort INTERFACE "${Onnxruntime_LIB}") 60 | target_include_directories(Ort INTERFACE 61 | $ 62 | $ 63 | ) 64 | # target_sources(Ort INTERFACE "${Onnxruntime_LIB}") 65 | # set_property(SOURCE "${Onnxruntime_LIB}" PROPERTY MACOSX_PACKAGE_LOCATION Frameworks) 66 | # source_group("Frameworks" FILES "${Onnxruntime_LIB}") 67 | # add_custom_command( 68 | # TARGET Ort 69 | # POST_BUILD
70 | # COMMAND 71 | # ${CMAKE_INSTALL_NAME_TOOL} -change "@rpath/libonnxruntime.${Onnxruntime_VERSION}.dylib" 72 | # "@loader_path/../Frameworks/libonnxruntime.${Onnxruntime_VERSION}.dylib" $) 73 | elseif(MSVC) 74 | set(Onnxruntime_LIB_NAMES onnxruntime;onnxruntime_providers_shared) 75 | 76 | foreach(lib_name IN LISTS Onnxruntime_LIB_NAMES) 77 | add_library(Ort::${lib_name} SHARED IMPORTED) 78 | set_target_properties(Ort::${lib_name} PROPERTIES IMPORTED_IMPLIB ${onnxruntime_SOURCE_DIR}/lib/${lib_name}.lib) 79 | set_target_properties(Ort::${lib_name} PROPERTIES IMPORTED_LOCATION ${onnxruntime_SOURCE_DIR}/lib/${lib_name}.dll) 80 | set_target_properties(Ort::${lib_name} PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${onnxruntime_SOURCE_DIR}/include) 81 | target_link_libraries(Ort INTERFACE Ort::${lib_name}) 82 | install(FILES ${onnxruntime_SOURCE_DIR}/lib/${lib_name}.dll DESTINATION "${CMAKE_INSTALL_LIBDIR}/") 83 | # Also install next to the executables: Windows searches the executable's own directory for DLLs, not lib/ 84 | install(FILES ${onnxruntime_SOURCE_DIR}/lib/${lib_name}.dll DESTINATION "${CMAKE_INSTALL_BINDIR}") 85 | endforeach() 86 | else() 87 | set(Onnxruntime_LINK_LIBS "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime.so.${Onnxruntime_VERSION}") 88 | set(Onnxruntime_ADDITIONAL_LIBS 89 | "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime_providers_shared.so" 90 | "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime.so" "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime.so.1") 91 | 92 | target_link_libraries(Ort INTERFACE "${Onnxruntime_LINK_LIBS}") 93 | 94 | # The installed library set is the same on every Linux architecture 95 | set(Onnxruntime_INSTALL_LIBS ${Onnxruntime_LINK_LIBS} ${Onnxruntime_ADDITIONAL_LIBS}) 96 | 97 | install(FILES ${Onnxruntime_INSTALL_LIBS} DESTINATION "${CMAKE_INSTALL_PREFIX}/lib/${CMAKE_PROJECT_NAME}") 98 | # NOTE(review): INSTALL_RPATH on an INTERFACE library is not propagated to the targets that link Ort, and CMake versions before 3.19 reject non-INTERFACE_* properties on INTERFACE libraries (cmake_minimum_required is 3.15); consumers likely need their own INSTALL_RPATH. TODO confirm and relocate. 99 | set_target_properties(Ort PROPERTIES INSTALL_RPATH "$ORIGIN/${CMAKE_PROJECT_NAME}") 100 | endif() 101 | -------------------------------------------------------------------------------- /README.md:
-------------------------------------------------------------------------------- 1 | # moonshine.cpp 2 | 3 |
4 | 5 | [![Build Status](https://github.com/locaal-ai/moonshine.cpp/actions/workflows/build.yaml/badge.svg)](https://github.com/locaal-ai/moonshine.cpp/actions) 6 | 7 |
8 | 9 | Standalone C++ implementation of [Moonshine ASR](https://github.com/usefulsensors/moonshine) built on [ONNX Runtime](https://github.com/microsoft/onnxruntime). Apart from ONNX Runtime, the library depends only on the header-only [nlohmann/json](https://github.com/nlohmann/json), and both are downloaded automatically at configure time. 10 | 11 | ## Table of Contents 12 | 13 | - [Introduction](#introduction) 14 | - [Build Instructions](#build-instructions) 15 | - [Example Usage](#example-usage) 16 | - [Using as a Library](#using-as-a-library) 17 | - [Credits](#credits) 18 | - [License](#license) 19 | 20 | ## Introduction 21 | 22 | [Moonshine ASR](https://github.com/usefulsensors/moonshine) is an Automatic Speech Recognition (ASR) model. This project is a standalone C++ implementation of Moonshine inference using [ONNX Runtime](https://github.com/microsoft/onnxruntime), and it is built and tested in CI on Windows, Linux, and macOS. 23 | 24 | ## Build Instructions 25 | 26 | ### Prerequisites 27 | 28 | - CMake 3.15 or higher 29 | - A C++17 compatible compiler 30 | - Network access at configure time: ONNX Runtime for the host OS (and SDL2, when building the examples) is downloaded automatically. 31 | 32 | ### Building on Windows 33 | 34 | 1. Open a PowerShell terminal. 35 | 2. Navigate to the root directory of the project. 36 | 3. Run the build script: 37 | 38 | ```ps1 39 | .\scripts\build.ps1 -BuildType Release 40 | ``` 41 | 42 | ### Building on Linux/macOS 43 | 44 | 1. Open a terminal. 45 | 2. Navigate to the root directory of the project. 46 | 3. Run the build script: 47 | 48 | ```sh 49 | ./scripts/build.sh 50 | ``` 51 | 52 | ### Manual Build Steps 53 | 54 | 1. Create a build directory: 55 | 56 | ```sh 57 | mkdir -p build 58 | cd build 59 | ``` 60 | 61 | 2. Configure the project with CMake: 62 | 63 | ```sh 64 | cmake -DCMAKE_BUILD_TYPE=Release .. 65 | ``` 66 | 67 | 3. Build the project: 68 | 69 | ```sh 70 | cmake --build . 71 | ``` 72 | 73 | 4. Install the project: 74 | 75 | ```sh 76 | cmake --install . --prefix ../dist 77 | ``` 78 | 79 | ### Build results 80 | 81 | After building and installing, you should have a `dist` folder like this (e.g.
on Windows): 82 | 83 | ``` 84 | ./dist 85 | ├───bin 86 | │ moonshine_example.exe 87 | │ moonshine_live.exe 88 | │ onnxruntime.dll 89 | │ onnxruntime_providers_shared.dll 90 | │ 91 | ├───include 92 | │ └───moonshine 93 | │ moonshine.hpp 94 | │ 95 | └───lib 96 | │ moonshine.lib 97 | │ onnxruntime.dll 98 | │ onnxruntime_providers_shared.dll 99 | │ 100 | └───cmake 101 | └───moonshine 102 | moonshine-config-version.cmake 103 | moonshine-config.cmake 104 | moonshine-targets-release.cmake 105 | moonshine-targets.cmake 106 | ``` 107 | 108 | This layout lets you link against moonshine.cpp with CMake's `find_package()` (see [Using as a Library](#using-as-a-library)). 109 | 110 | ## Example Usage 111 | 112 | The project includes two example applications: 113 | - `moonshine_example`: File-based transcription 114 | - `moonshine_live`: Real-time microphone transcription (requires SDL2) 115 | 116 | ### Building Examples 117 | 118 | Build the project with examples enabled: 119 | 120 | ```powershell 121 | .\scripts\build.ps1 -BuildType Release -BuildExamples 122 | ``` 123 | 124 | ### File-based Transcription 125 | Transcribe audio from a WAV file: 126 | ```powershell 127 | .\dist\bin\moonshine_example.exe MODELS_DIR WAV_FILE 128 | ``` 129 | 130 | Replace `MODELS_DIR` with the directory containing your ONNX models and `WAV_FILE` with the path to a WAV file. The example expects 16 kHz, mono, 16-bit PCM audio: it skips a fixed 44-byte WAV header and reads at most 30 seconds. 131 | 132 | Example: 133 | ```powershell 134 | .\dist\bin\moonshine_example.exe models\ example.wav 135 | ``` 136 | 137 | ### Live Transcription 138 | Real-time transcription from microphone input: 139 | ```powershell 140 | .\dist\bin\moonshine_live.exe MODELS_DIR 141 | ``` 142 | 143 | Example: 144 | ```powershell 145 | .\dist\bin\moonshine_live.exe models\ 146 | ``` 147 | 148 | Press ESC or q to stop recording and exit. 149 | 150 | ## Using as a Library 151 | 152 | To use Moonshine ASR as a library in your C++ project, follow these steps: 153 | 154 | 1. Build and install the Moonshine ASR library as described in the [Build Instructions](#build-instructions). 155 | 156 | 2.
Link against the installed library in your CMake project: 157 | 158 | ```cmake 159 | cmake_minimum_required(VERSION 3.15) 160 | project(my_project) 161 | 162 | # Find the Moonshine package 163 | find_package(Moonshine REQUIRED) 164 | 165 | add_executable(my_project main.cpp) 166 | 167 | # Link against the Moonshine library 168 | target_link_libraries(my_project PRIVATE Moonshine::moonshine) 169 | ``` 170 | 171 | 3. Include the Moonshine header in your source code: 172 | 173 | ```cpp 174 | #include <moonshine/moonshine.hpp> 175 | 176 | int main() { 177 | MoonshineModel model("path/to/models"); 178 | // Use the model... 179 | return 0; 180 | } 181 | ``` 182 | 183 | ## Credits 184 | 185 | This project is based on the Moonshine ASR system: https://github.com/usefulsensors/moonshine. Special thanks to the Moonshine team for their contributions. 186 | 187 | ## License 188 | 189 | This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details. 190 |
--------------------------------------------------------------------------------
/examples/live.cpp:
--------------------------------------------------------------------------------
// live.cpp
//
// Real-time transcription example: captures 16 kHz mono float32 audio from the default
// recording device with SDL2 and repeatedly transcribes the buffered audio with
// MoonshineModel until 'q' or ESC is pressed. Keyboard handling uses <conio.h>, so this
// example builds on Windows only.
#define SDL_MAIN_HANDLED
#include <SDL.h>
#include "moonshine.hpp"
#include <atomic>
#include <chrono>
#include <cstddef>
#include <exception>
#include <iostream>
#include <string>
#include <thread>
#include <vector>
#include <conio.h>  // For _kbhit() and _getch() on Windows

const int SAMPLE_RATE = 16000;
const int BUFFER_SIZE = 4096;

// Once more than this many samples (~10 s) are buffered, the buffer is cleared so each
// transcription window stays bounded.
const std::size_t MAX_BUFFERED_SAMPLES = 10 * SAMPLE_RATE;

// SDL capture callback: appends the `len` bytes of float32 samples in `stream` to the
// std::vector<float> passed as `userdata`.
//
// SDL runs it on its audio thread with the device lock held, so any other thread must wrap
// its accesses to that vector in SDL_LockAudioDevice()/SDL_UnlockAudioDevice(). The device
// must be opened without SDL_AUDIO_ALLOW_FORMAT_CHANGE so `stream` really holds AUDIO_F32.
void audioCallback(void* userdata, Uint8* stream, int len)
{
    auto* buffer = static_cast<std::vector<float>*>(userdata);
    const auto* samples = reinterpret_cast<const float*>(stream);
    const std::size_t sample_count = static_cast<std::size_t>(len) / sizeof(float);
    buffer->insert(buffer->end(), samples, samples + sample_count);
}

// Prints the index and name of every recording device SDL reports.
void listAudioDevices()
{
    const int count = SDL_GetNumAudioDevices(SDL_TRUE);  // SDL_TRUE for recording devices
    std::cout << "Available recording devices:\n";
    for (int i = 0; i < count; ++i)
    {
        const char* name = SDL_GetAudioDeviceName(i, SDL_TRUE);
        std::cout << i << ": " << (name ? name : "Unknown Device") << "\n";
    }
}

int main(int argc, char* argv[])
{
    std::cout << "Moonshine Live Transcription\n";

    SDL_SetMainReady();  // Tell SDL we'll handle the main entry point
    if (argc != 2)
    {
        std::cerr << "Usage: " << argv[0] << " <models_dir>\n";
        return 1;
    }

    try
    {
        std::cout << "Initializing...\n";

        // Initialize model
        MoonshineModel model(argv[1]);

        std::cout << "Model initialized\n";

        // Initialize SDL
        if (SDL_Init(SDL_INIT_AUDIO) < 0)
        {
            std::cerr << "Could not initialize SDL: " << SDL_GetError() << "\n";
            return 1;
        }

        std::cout << "SDL initialized\n";

        // List available devices
        listAudioDevices();

        // Request 16 kHz mono float32, the layout audioCallback and the model expect.
        SDL_AudioSpec desired_spec;
        SDL_AudioSpec obtained_spec;
        SDL_zero(desired_spec);
        SDL_zero(obtained_spec);
        desired_spec.freq = SAMPLE_RATE;
        desired_spec.format = AUDIO_F32;
        desired_spec.channels = 1;
        desired_spec.samples = BUFFER_SIZE;
        desired_spec.callback = audioCallback;

        // Shared with the SDL audio thread: only touch it while the device is locked.
        std::vector<float> audio_buffer;
        desired_spec.userdata = &audio_buffer;

        // Open the default recording device. allowed_changes = 0 makes SDL convert the
        // hardware stream to exactly the requested spec, so audioCallback always receives
        // float32 data (allowing a format change would break its reinterpret_cast).
        SDL_AudioDeviceID dev = SDL_OpenAudioDevice(nullptr,         // default device
                                                    SDL_TRUE,        // is_capture (recording)
                                                    &desired_spec,   // desired spec
                                                    &obtained_spec,  // obtained spec
                                                    0);              // no spec changes allowed

        if (dev == 0)
        {
            std::cerr << "Could not open audio device: " << SDL_GetError() << "\n";
            SDL_Quit();
            return 1;
        }

        // The device was opened by NULL name (system default), which is not necessarily
        // enumerated device 0, so no specific device name is reported here.
        std::cout << "Audio device opened: default recording device\n";
        // print the obtained spec
        std::cout << "Obtained spec: " << obtained_spec.freq << " Hz, "
                  << SDL_AUDIO_BITSIZE(obtained_spec.format) << " bits, "
                  << (obtained_spec.channels == 1 ? "mono" : "stereo") << "\n";

        // Start audio capture
        SDL_PauseAudioDevice(dev, 0);

        std::atomic<bool> running(true);
        std::thread transcription_thread(
            [&]()
            {
                std::cout << "Transcribing...\n";
                std::size_t last_buffer_size = 0;
                while (running)
                {
                    // Snapshot (and possibly trim) the shared buffer while the SDL callback
                    // is locked out; inference runs after unlocking so capture never stalls.
                    std::vector<float> buffer;
                    SDL_LockAudioDevice(dev);
                    const std::size_t buffered = audio_buffer.size();
                    const bool has_new_audio =
                        buffered >= static_cast<std::size_t>(SAMPLE_RATE) &&
                        buffered != last_buffer_size;
                    if (has_new_audio)
                    {
                        last_buffer_size = buffered;
                        buffer = audio_buffer;
                        // Limit the buffer size to 10 seconds: start a fresh window.
                        if (buffered > MAX_BUFFERED_SAMPLES)
                        {
                            audio_buffer.clear();
                        }
                    }
                    SDL_UnlockAudioDevice(dev);

                    if (!has_new_audio)
                    {
                        // Less than one second buffered, or nothing new since last pass.
                        std::this_thread::sleep_for(std::chrono::milliseconds(100));
                        continue;
                    }

                    // An exception escaping a std::thread calls std::terminate, so report
                    // it here and stop the example instead.
                    try
                    {
                        const auto tokens = model.generate(buffer);
                        const std::string result = model.detokenize(tokens);

                        // Move the cursor up one line, then clear it, so the transcript
                        // is rewritten in place.
                        std::cout << "\x1b[A";
                        std::cout << "\r\033[K";
                        std::cout << "Transcription: " << result << "\n";
                    }
                    catch (const std::exception& e)
                    {
                        std::cerr << "Transcription error: " << e.what() << "\n";
                        running = false;
                    }
                }
                std::cout << "Transcription thread finished\n";
            });

        std::cout << "Recording... Press 'q' or 'ESC' to stop.\n";
        while (running)
        {
            if (_kbhit())
            {
                const int ch = _getch();
                if (ch == 'q' || ch == 27)  // 27 is the ASCII code for ESC
                {
                    running = false;
                }
            }
            std::this_thread::sleep_for(std::chrono::milliseconds(100));
        }

        // Stop audio capture
        SDL_PauseAudioDevice(dev, 1);

        // Wait for transcription thread to finish
        transcription_thread.join();

        // Clean up
        SDL_CloseAudioDevice(dev);
        SDL_Quit();
    }
    catch (const Ort::Exception& e)
    {
        std::cerr << "ONNX Runtime error: " << e.what() << "\n";
        return 1;
    }
    catch (const std::exception& e)
    {
        std::cerr << "Error: " << e.what() << "\n";
        return 1;
    }

    return 0;
}
--------------------------------------------------------------------------------
/src/moonshine.cpp:
--------------------------------------------------------------------------------
#include "moonshine.hpp"
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <iterator>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include <nlohmann/json.hpp>

#ifdef _WIN32
#include <windows.h>
#endif

namespace
{
// Tensor names of the exported ONNX graphs. They are file-local and immutable; Run() only
// needs a `const char* const*`.

// cached_decode.onnx inputs: token, context, decoder seq_len, then the 24 cache tensors
// produced by the previous decode step.
const std::vector<const char *> cached_decode_input_names = {
    "args_0",  "args_1",  "args_2",  "args_3",  "args_4",  "args_5",  "args_6",
    "args_7",  "args_8",  "args_9",  "args_10", "args_11", "args_12", "args_13",
    "args_14", "args_15", "args_16", "args_17", "args_18", "args_19", "args_20",
    "args_21", "args_22", "args_23", "args_24", "args_25", "args_26",
};

// cached_decode.onnx outputs: logits first, then the updated cache tensors.
const std::vector<const char *> cached_decode_output_names = {
    "reversible_embedding", "functional_23",   "functional_23_1", "input_layer_102",
    "input_layer_103",      "functional_26",   "functional_26_1", "input_layer_106",
    "input_layer_107",      "functional_29",   "functional_29_1", "input_layer_110",
    "input_layer_111",      "functional_32",   "functional_32_1", "input_layer_114",
    "input_layer_115",      "functional_35",   "functional_35_1", "input_layer_118",
    "input_layer_119",      "functional_38",   "functional_38_1", "input_layer_122",
    "input_layer_123",
};

// uncached_decode.onnx inputs (token, context, decoder seq_len) and outputs (logits, cache).
const std::vector<const char *> decode_input_names = {"args_0", "args_1", "args_2"};
const std::vector<const char *> decode_output_names = {
    "reversible_embedding", "functional_22",   "functional_22_1", "functional_22_2",
    "functional_22_3",      "functional_25",   "functional_25_1", "functional_25_2",
    "functional_25_3",      "functional_28",   "functional_28_1", "functional_28_2",
    "functional_28_3",      "functional_31",   "functional_31_1", "functional_31_2",
    "functional_31_3",      "functional_34",   "functional_34_1", "functional_34_2",
    "functional_34_3",      "functional_37",   "functional_37_1", "functional_37_2",
    "functional_37_3",
};

// encode.onnx inputs (features, frame count) and output (context).
const std::vector<const char *> encode_input_names = {"args_0", "args_1"};
const std::vector<const char *> encode_output_names = {"layer_normalization_12"};
}  // namespace

// Reads the whole file at `file_path` as raw bytes (UTF-8 content is returned untouched).
// Throws std::runtime_error if the file cannot be opened.
std::string readFileAsUtf8(const std::string &file_path)
{
    std::ifstream file(file_path, std::ios::binary);
    if (!file.is_open())
    {
        throw std::runtime_error("File not found: " + file_path);
    }

    std::ostringstream buffer;
    buffer << file.rdbuf();
    return buffer.str();
}

// Loads preprocess/encode/uncached_decode/cached_decode ONNX sessions and tokenizer.json
// from `models_dir`. Throws std::runtime_error if a model or the tokenizer file is missing.
MoonshineModel::MoonshineModel(const std::string &models_dir)
    : memory_info_(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU))
{
    std::cout << "Initializing Moonshine model from " << models_dir << std::endl;
    this->env_ = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "MoonshineModel");
    preprocess_ = createSession(models_dir + "/preprocess.onnx");
    encode_ = createSession(models_dir + "/encode.onnx");
    uncached_decode_ = createSession(models_dir + "/uncached_decode.onnx");
    cached_decode_ = createSession(models_dir + "/cached_decode.onnx");

    // Read tokenizer JSON as UTF-8
    std::string tokenizer_content = readFileAsUtf8(models_dir + "/tokenizer.json");
    load_tokenizer(tokenizer_content);
}

// Creates a single-intra-op-thread ONNX Runtime session with full graph optimization.
// Throws std::runtime_error if `model_path` does not exist.
std::unique_ptr<Ort::Session> MoonshineModel::createSession(const std::string &model_path)
{
    const std::filesystem::path path(model_path);
    if (!std::filesystem::exists(path))
    {
        throw std::runtime_error("Model file not found: " + model_path);
    }

    Ort::SessionOptions session_options;
    session_options.SetIntraOpNumThreads(1);
    session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);

    // path::c_str() yields the platform-native character type (wchar_t on Windows, char
    // elsewhere), which is what the Ort::Session constructor expects on each platform.
    return std::make_unique<Ort::Session>(env_, path.c_str(), session_options);
}

// Transcribes `audio_samples` into token ids with greedy decoding.
//
// audio_samples: mono float samples; assumed 16 kHz, since the default budget below
//                divides the sample count by 16000.
// max_len:       maximum number of generated tokens; 0 means ~6 tokens per second of audio.
// Returns the token ids, starting with the start token (1) and ending with the end token (2)
// if the model emitted it within max_len steps. ONNX Runtime failures surface as
// Ort::Exception; std::runtime_error is thrown if a graph has more inputs/outputs than the
// name tables above describe.
std::vector<int32_t> MoonshineModel::generate(const std::vector<float> &audio_samples,
                                              size_t max_len)
{
    constexpr int32_t kStartToken = 1;
    constexpr int32_t kEndToken = 2;
    const std::vector<int64_t> scalar_shape = {1};
    const std::vector<int64_t> token_shape = {1, 1};

    // 1. Preprocess: raw audio [1, num_samples] -> features.
    //    ONNX Runtime only reads input buffers, so the const_cast is safe.
    const std::vector<int64_t> audio_shape = {1, static_cast<int64_t>(audio_samples.size())};
    std::vector<Ort::Value> preprocess_inputs;
    preprocess_inputs.push_back(Ort::Value::CreateTensor<float>(
        memory_info_, const_cast<float *>(audio_samples.data()), audio_samples.size(),
        audio_shape.data(), audio_shape.size()));
    const char *const preprocess_input_names[] = {"args_0"};
    const char *const preprocess_output_names[] = {"sequential"};
    auto preprocessed = preprocess_->Run(Ort::RunOptions{nullptr}, preprocess_input_names,
                                         preprocess_inputs.data(), preprocess_inputs.size(),
                                         preprocess_output_names, 1);

    // 2. Encode: dimension 1 of the preprocessed features is passed as the frame count.
    const auto features_shape = preprocessed[0].GetTensorTypeAndShapeInfo().GetShape();
    int32_t encode_seq_len = static_cast<int32_t>(features_shape[1]);
    std::vector<Ort::Value> encode_inputs;
    encode_inputs.reserve(2);
    encode_inputs.push_back(std::move(preprocessed[0]));
    encode_inputs.push_back(Ort::Value::CreateTensor<int32_t>(
        memory_info_, &encode_seq_len, 1, scalar_shape.data(), scalar_shape.size()));
    auto context =
        encode_->Run(Ort::RunOptions{nullptr}, encode_input_names.data(), encode_inputs.data(),
                     encode_inputs.size(), encode_output_names.data(), encode_output_names.size());

    // Non-owning views over the encoder output: context[0] keeps the data alive, and a new
    // view is created for every decode call.
    const auto context_info = context[0].GetTensorTypeAndShapeInfo();
    const auto context_shape = context_info.GetShape();
    const size_t context_count = context_info.GetElementCount();
    float *context_data = context[0].GetTensorMutableData<float>();

    // Calculate max_len if not provided: at most ~6 tokens per second of audio.
    if (max_len == 0)
    {
        max_len = static_cast<size_t>((audio_samples.size() / 16000.0) * 6);
    }

    // 3. Decode. The decoder's seq_len input counts the tokens fed so far, starting at 1 for
    //    the start token; it is unrelated to the encoder frame count used above.
    std::vector<int32_t> tokens = {kStartToken};
    int32_t current_token = kStartToken;
    int32_t decode_seq_len = 1;

    // Builds the three inputs shared by both decoder graphs: token, context, seq_len.
    auto make_decoder_inputs = [&]()
    {
        std::vector<Ort::Value> inputs;
        inputs.reserve(3 + decode_output_names.size());
        inputs.push_back(Ort::Value::CreateTensor<int32_t>(
            memory_info_, &current_token, 1, token_shape.data(), token_shape.size()));
        inputs.push_back(Ort::Value::CreateTensor<float>(
            memory_info_, context_data, context_count, context_shape.data(),
            context_shape.size()));
        inputs.push_back(Ort::Value::CreateTensor<int32_t>(
            memory_info_, &decode_seq_len, 1, scalar_shape.data(), scalar_shape.size()));
        return inputs;
    };

    // Output counts index into the fixed name tables; refuse graphs that don't match them.
    const size_t uncached_output_count = uncached_decode_->GetOutputCount();
    const size_t cached_output_count = cached_decode_->GetOutputCount();
    if (uncached_output_count > decode_output_names.size() ||
        cached_output_count > cached_decode_output_names.size())
    {
        throw std::runtime_error("Decoder model has more outputs than expected");
    }

    // Initial uncached decode: logits for the first position plus the initial cache.
    std::vector<Ort::Value> uncached_inputs = make_decoder_inputs();
    auto cache = uncached_decode_->Run(Ort::RunOptions{nullptr}, decode_input_names.data(),
                                       uncached_inputs.data(), uncached_inputs.size(),
                                       decode_output_names.data(), uncached_output_count);

    for (size_t step = 0; step < max_len; ++step)
    {
        // Greedy choice: index of the largest logit (the first one on ties).
        const float *logits = cache[0].GetTensorData<float>();
        const size_t logit_count = cache[0].GetTensorTypeAndShapeInfo().GetElementCount();
        const int32_t next_token = static_cast<int32_t>(
            std::distance(logits, std::max_element(logits, logits + logit_count)));

        tokens.push_back(next_token);
        if (next_token == kEndToken)
        {
            break;
        }

        current_token = next_token;
        ++decode_seq_len;

        // Cached decode: shared inputs followed by every cache tensor from the last step.
        std::vector<Ort::Value> cached_inputs = make_decoder_inputs();
        for (size_t j = 1; j < cache.size(); ++j)
        {
            cached_inputs.push_back(std::move(cache[j]));
        }
        if (cached_inputs.size() > cached_decode_input_names.size())
        {
            throw std::runtime_error("Decoder cache has more tensors than expected");
        }

        cache = cached_decode_->Run(Ort::RunOptions{nullptr}, cached_decode_input_names.data(),
                                    cached_inputs.data(), cached_inputs.size(),
                                    cached_decode_output_names.data(), cached_output_count);
    }

    return tokens;
}

// Builds the id -> token map from the "model"."vocab" object of a tokenizer.json document.
// Throws nlohmann::json::parse_error on malformed JSON and nlohmann::json::out_of_range if
// either key is missing (instead of silently producing an empty vocabulary).
void MoonshineModel::load_tokenizer(const std::string &tokenizer_content)
{
    nlohmann::json tokenizer = nlohmann::json::parse(tokenizer_content);

    // Create token ID to token map
    for (const auto &item : tokenizer.at("model").at("vocab").items())
    {
        token_id_to_token_[item.value()] = item.key();
    }
}

// Converts token ids back to text. Ids missing from the vocabulary are skipped; a leading
// '▁' (U+2581, UTF-8 E2 96 81) on a token is replaced by a space, and a single leading space
// on the final string is removed.
std::string MoonshineModel::detokenize(const std::vector<int32_t> &tokens)
{
    static const std::string kWordBoundary = "\xE2\x96\x81";

    std::string result;
    for (const auto token : tokens)
    {
        const auto it = token_id_to_token_.find(token);
        if (it == token_id_to_token_.end())
        {
            continue;
        }

        const std::string &piece = it->second;
        // Match the full 3-byte marker: other characters (e.g. "—", "’") also start with
        // 0xE2 and must be kept intact.
        if (piece.compare(0, kWordBoundary.size(), kWordBoundary) == 0)
        {
            result += ' ';
            result.append(piece, kWordBoundary.size(), std::string::npos);
        }
        else
        {
            result += piece;
        }
    }

    // Trim leading space if exists
    if (!result.empty() && result.front() == ' ')
    {
        result.erase(0, 1);
    }
    return result;
}
--------------------------------------------------------------------------------