├── .github └── workflows │ └── build.yml ├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── LICENSE ├── README.md ├── client ├── readme.md └── websocket_client.py ├── common ├── common-ggml.cpp ├── common-ggml.h ├── common-m4a.cpp ├── common-m4a.h ├── common-sdl.cpp ├── common-sdl.h ├── common.cpp ├── common.h ├── utils.cpp └── utils.h ├── distribute └── docker │ ├── base.en │ ├── Dockerfile │ └── readme.md │ ├── builder │ └── Dockerfile │ ├── large-v3 │ ├── Dockerfile │ └── readme.md │ ├── medium-q5_0 │ ├── Dockerfile │ └── readme.md │ ├── pure │ ├── Dockerfile │ └── readme.md │ ├── tiny.en-q5_1 │ ├── Dockerfile │ └── readme.md │ └── whisper-cpp-server-builder │ ├── Dockerfile │ └── readme.md ├── doc ├── client_code.md ├── whiser-cpp-server-http-cn.md ├── whiser-cpp-server-http-en.md ├── whiser-cpp-server-websocket-cn.md └── whiser-cpp-server-websocket-en.md ├── examples ├── audio_vad.cpp ├── sdl_version.cpp ├── simplest.cpp └── stream_local.cpp ├── ggml-metal.metal ├── handler ├── echo_handler.h ├── hello_handler.cpp ├── hello_handler.h ├── inference_handler.cpp ├── inference_handler.h └── ws_save_handler.h ├── models ├── .gitignore ├── README.md ├── convert-h5-to-coreml.py ├── convert-h5-to-ggml.py ├── convert-pt-to-ggml.py ├── convert-whisper-to-coreml.py ├── convert-whisper-to-openvino.py ├── download-coreml-model.sh ├── download-ggml-model.cmd ├── download-ggml-model.sh ├── generate-coreml-interface.sh ├── generate-coreml-model.sh ├── ggml-tiny.mlmodel ├── ggml_to_pt.py └── openvino-conversion-requirements.txt ├── params ├── whisper_params.cpp └── whisper_params.h ├── pcm ├── 16k_1.pcm ├── 16k_57test.pcm ├── 16k_test.pcm └── nopause.pcm ├── public └── index.html ├── resources └── json │ ├── transcription01.json │ └── whisper_local.json ├── samples ├── .gitignore ├── README.md ├── jfk.wav └── zh.wav ├── stream ├── stream_components.h ├── stream_components_audio.cpp ├── stream_components_audio.h ├── stream_components_output.cpp ├── stream_components_output.h ├── stream_components_params.h ├── stream_components_service.cpp └── stream_components_service.h ├── thirdparty └── CMakeLists.txt ├── vcpkg.json ├── web ├── favicon.ico ├── index.html ├── paddle_web_demo.png └── readme.md ├── whisper_http_server_base_httplib.cpp └── whisper_server_base_on_uwebsockets.cpp /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: C++ Build with CMake and vcpkg 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | jobs: 9 | build_ubuntu: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | with: 14 | submodules: recursive 15 | - name: Install Dependencies 16 | run: | 17 | sudo apt-get update 18 | sudo apt-get install -y gcc g++ cmake 19 | sudo apt-get install nasm 20 | 21 | - name: Cmake Version 22 | run: | 23 | cmake --version 24 | 25 | - name: Configure CMake 26 | run: | 27 | cmake -B cmake-build-release-linux 28 | 29 | - name: Build 30 | run: cmake --build cmake-build-release-linux --config Release -- -j 12 31 | 32 | - name: Copy Additional Files 33 | run: cp ./ggml-metal.metal cmake-build-release-linux 34 | 35 | - name: Archive Build Artifacts 36 | uses: actions/upload-artifact@v2 37 | with: 38 | name: cmake-build-release-ubuntu-latest 39 | path: cmake-build-release-linux/ 40 | 41 | - name: Archive Build binary Artifacts 42 | uses: actions/upload-artifact@v2 43 | with: 44 | name: whisper-cpp-server-ubuntu-latest 45 | path: | 46 | 
cmake-build-release-linux/ggml-metal.metal 47 | cmake-build-release-linux/simplest 48 | cmake-build-release-linux/whisper_http_server_base_httplib 49 | cmake-build-release-linux/whisper_server_base_on_uwebsockets 50 | cmake-build-release-linux/thirdparty/whisper.cpp/libwhisper.so 51 | 52 | 53 | build_macos: 54 | runs-on: macos-latest 55 | steps: 56 | - uses: actions/checkout@v3 57 | with: 58 | submodules: recursive 59 | - name: Install Dependencies 60 | run: | 61 | brew install cmake 62 | brew install nasm 63 | 64 | - name: Cmake Version 65 | run: | 66 | cmake --version 67 | 68 | - name: Setup vcpkg 69 | run: | 70 | export VCPKG_HOME=$PWD/thirdparty/vcpkg 71 | echo $VCPKG_HOME 72 | $VCPKG_HOME/bootstrap-vcpkg.sh 73 | 74 | - name: Configure CMake 75 | run: | 76 | export VCPKG_HOME=$PWD/thirdparty/vcpkg 77 | cmake -B cmake-build-release -DCMAKE_TOOLCHAIN_FILE=$VCPKG_HOME/scripts/buildsystems/vcpkg.cmake 78 | 79 | - name: Build 80 | run: cmake --build cmake-build-release --config Release -- -j 81 | 82 | - name: Copy Additional Files 83 | run: cp ./ggml-metal.metal cmake-build-release 84 | 85 | - name: Archive Build Artifacts 86 | uses: actions/upload-artifact@v2 87 | with: 88 | name: cmake-build-release-macos-latest 89 | path: cmake-build-release/ 90 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### Eclipse template 2 | *.pydevproject 3 | .metadata 4 | .gradle* 5 | classes/ 6 | bin/ 7 | tmp/ 8 | *.tmp 9 | *.bak 10 | *.swp 11 | *~.nib 12 | local.properties 13 | .settings/ 14 | .loadpath 15 | rebel.xml 16 | 17 | # Eclipse Core 18 | .project 19 | 20 | generatedsources 21 | 22 | # External tool builders 23 | .externalToolBuilders/ 24 | 25 | # Locally stored "Eclipse launch configurations" 26 | *.launch 27 | 28 | # CDT-specific 29 | .cproject 30 | 31 | # JDT-specific (Eclipse Java Development Tools) 32 | .classpath 33 | 34 | # PDT-specific 35 | .buildpath 36 | 37 | # sbteclipse plugin 38 | .target 39 | 40 | # TeXlipse plugin 41 | .texlipse 42 | 43 | 44 | 45 | ### JetBrains template 46 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm 47 | 48 | *.iml 49 | .flattened-pom.xml 50 | ## Directory-based project format: 51 | .idea/ 52 | # if you remove the above rule, at least ignore the following: 53 | 54 | # User-specific stuff: 55 | # .idea/workspace.xml 56 | # .idea/tasks.xml 57 | # .idea/dictionaries 58 | 59 | # Sensitive or high-churn files: 60 | # .idea/dataSources.ids 61 | # .idea/dataSources.xml 62 | # .idea/sqlDataSources.xml 63 | # .idea/dynamic.xml 64 | # .idea/uiDesigner.xml 65 | 66 | # Gradle: 67 | # .idea/gradle.xml 68 | # .idea/libraries 69 | 70 | # Mongo Explorer plugin: 71 | # .idea/mongoSettings.xml 72 | 73 | ## File-based project format: 74 | *.ipr 75 | *.iws 76 | 77 | ## Plugin-specific files: 78 | 79 | # IntelliJ 80 | /out/ 81 | 82 | # mpeltonen/sbt-idea plugin 83 | .idea_modules/ 84 | 85 | # JIRA plugin 86 | atlassian-ide-plugin.xml 87 | 88 | # Crashlytics plugin (for Android Studio and IntelliJ) 89 | com_crashlytics_export_strings.xml 90 | crashlytics.properties 91 | crashlytics-build.properties 92 | 93 | build/ 94 | 95 | # Ignore Gradle GUI config 96 | gradle-app.setting 97 | 98 | # Avoid ignoring Gradle wrapper jar file (.jar files are usually ignored) 99 | !gradle-wrapper.jar 100 | 101 | db 102 | 103 | ### Java template 104 | *.class 105 | 106 | # Mobile Tools for Java (J2ME) 107 | .mtj.tmp/ 108 | 109 | # Package Files # 
110 | #*.jar 111 | 112 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 113 | hs_err_pid* 114 | 115 | 116 | ### Leiningen template 117 | classes/ 118 | target/ 119 | logs/ 120 | checkouts/ 121 | .lein-deps-sum 122 | .lein-repl-history 123 | .lein-plugins/ 124 | .lein-failures 125 | .nrepl-port 126 | 127 | querydsl/ 128 | 129 | .DS_Store 130 | 131 | *.exe 132 | *.out 133 | 134 | *.log 135 | node_modules/ 136 | dist/ 137 | dist.zip 138 | package-lock.json 139 | *.lock 140 | local.properties 141 | .cxx 142 | .externalNativeBuild 143 | /captures 144 | /build 145 | __pycache__/ 146 | *.pyc 147 | 148 | 149 | cmake-build-debug/ 150 | cmake-build-debug-mingw/ 151 | venv/ 152 | .vs/ 153 | Debug/ 154 | *.bin 155 | *.wav 156 | cmake-build-release/ 157 | whisper-cpp-server.zip 158 | 159 | .cache -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "thirdparty/vcpkg"] 2 | path = thirdparty/vcpkg 3 | url = https://github.com/microsoft/vcpkg 4 | [submodule "thirdparty/whisper.cpp"] 5 | path = thirdparty/whisper.cpp 6 | url = https://github.com/ggerganov/whisper.cpp.git 7 | [submodule "thirdparty/dr_libs"] 8 | path = thirdparty/dr_libs 9 | url = https://github.com/mackron/dr_libs.git 10 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.16) 2 | project(whisper_cpp_server) 3 | 4 | set(CMAKE_CXX_STANDARD 20) 5 | 6 | add_subdirectory(thirdparty) 7 | 8 | add_executable(audio_vad examples/audio_vad.cpp common/common.cpp 9 | stream/stream_components_service.cpp common/utils.cpp) 10 | target_link_libraries(audio_vad PRIVATE whisper SampleRate::samplerate ${SPEEXDSP_LIBRARY} nlohmann_json::nlohmann_json dr_lib_header-only) 11 | # linke header file 12 | target_include_directories(audio_vad PRIVATE ${SPEEXDSP_INCLUDE_DIRS}) 13 | 14 | 15 | add_executable(sdl_version examples/sdl_version.cpp) 16 | target_link_libraries(sdl_version PRIVATE SDL2::SDL2-static) 17 | 18 | add_executable(simplest examples/simplest.cpp common/common.cpp common/utils.cpp) 19 | target_link_libraries(simplest PRIVATE whisper SampleRate::samplerate nlohmann_json::nlohmann_json dr_lib_header-only) 20 | 21 | add_executable(stream_local examples/stream_local.cpp common/common.cpp common/common-sdl.cpp common/utils.cpp 22 | stream/stream_components_service.cpp stream/stream_components_audio.cpp 23 | stream/stream_components_output.cpp 24 | ) 25 | target_link_libraries(stream_local whisper SDL2::SDL2-static SampleRate::samplerate nlohmann_json::nlohmann_json dr_lib_header-only) 26 | 27 | add_executable(whisper_http_server_base_httplib whisper_http_server_base_httplib.cpp 28 | common/common.cpp common/utils.cpp handler/inference_handler.cpp params/whisper_params.cpp 29 | common/common-m4a.cpp handler/ws_save_handler.h) 30 | target_include_directories(whisper_http_server_base_httplib PRIVATE ${FFMPEG_INCLUDE_DIRS}) 31 | target_link_directories(whisper_http_server_base_httplib PRIVATE ${FFMPEG_LIBRARY_DIRS}) 32 | target_link_libraries(whisper_http_server_base_httplib PRIVATE whisper SampleRate::samplerate ${FFMPEG_LIBRARIES} httplib::httplib nlohmann_json::nlohmann_json dr_lib_header-only) 33 | 34 | 35 | add_executable(whisper_server_base_on_uwebsockets whisper_server_base_on_uwebsockets.cpp common/common.cpp 36 
| stream/stream_components_service.cpp common/utils.cpp handler/hello_handler.cpp) 37 | # add uWebSockets header files 38 | target_include_directories(whisper_server_base_on_uwebsockets PRIVATE ${UWEBSOCKETS_INCLUDE_DIRS}) 39 | # link uWebSockets, zlib, libuv and uSockets libs 40 | # Detecting Operating Systems 41 | if (WIN32) 42 | # if Windows 43 | target_link_libraries(whisper_server_base_on_uwebsockets PRIVATE libuv::uv) 44 | elseif (APPLE) 45 | # if macOS 46 | target_link_libraries(whisper_server_base_on_uwebsockets PRIVATE libuv::uv_a) 47 | else () 48 | # if others, e.g. Linux 49 | target_link_libraries(whisper_server_base_on_uwebsockets PRIVATE libuv::uv_a) 50 | endif () 51 | 52 | target_link_libraries(whisper_server_base_on_uwebsockets PRIVATE whisper ZLIB::ZLIB ${USOCKETS_LIBRARY} 53 | SampleRate::samplerate ${SPEEXDSP_LIBRARY} nlohmann_json::nlohmann_json dr_lib_header-only) 54 | # include header directories 55 | target_include_directories(whisper_server_base_on_uwebsockets PRIVATE ${SPEEXDSP_INCLUDE_DIRS}) 56 | 57 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Tong Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # whisper-cpp service 2 | ## open source address 3 | [github](https://github.com/ppnt/whisper-cpp-server) 4 | [gitee](https://gitee.com/ppnt/whisper-cpp-server) 5 | ## Whisper-CPP-Server Introduction 6 | Whisper-CPP-Server is a high-performance speech recognition service written in C++, designed to provide developers and enterprises with a reliable and efficient speech-to-text inference engine. The project uses ggml to run inference on the open-source Whisper model. While maintaining speed and accuracy, it supports pure CPU-based inference, delivering high-quality speech recognition without the need for specialized hardware accelerators.
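To make this concrete, the following is a minimal sketch of the whisper.cpp C API (bundled as a submodule under thirdparty/whisper.cpp) that this service builds on. It is an illustrative sketch rather than the project's actual source: the model path is an example, and the silent buffer stands in for real 16 kHz mono PCM decoded from an audio file.
```cpp
// Illustrative sketch only: minimal use of the whisper.cpp C API that this service wraps.
// Assumes a ggml model has been downloaded to models/ggml-base.en.bin.
#include "whisper.h"
#include <cstdio>
#include <vector>

int main() {
  struct whisper_context *ctx = whisper_init_from_file("models/ggml-base.en.bin");
  if (ctx == nullptr) {
    fprintf(stderr, "failed to load model\n");
    return 1;
  }

  // placeholder input: one second of silence; a real caller first decodes a
  // WAV/MP3/M4A file into 16 kHz mono float PCM (see common/common.cpp in this repo)
  std::vector<float> pcmf32(16000, 0.0f);

  whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
  wparams.language = "en";

  if (whisper_full(ctx, wparams, pcmf32.data(), (int) pcmf32.size()) == 0) {
    const int n_segments = whisper_full_n_segments(ctx);
    for (int i = 0; i < n_segments; ++i) {
      printf("%s\n", whisper_full_get_segment_text(ctx, i));
    }
  }

  whisper_free(ctx);
  return 0;
}
```
The examples and servers in this repository wrap these same calls; common/utils.cpp, for instance, walks the resulting segments and formats them as JSON.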
7 | 8 | Real-time speech recognition with recognition results displayed in the browser. Backend: 9 | ``` 10 | https://github.com/litongjava/whisper-cpp-server 11 | ``` 12 | Frontend: 13 | ``` 14 | https://github.com/litongjava/listen-know-web 15 | ``` 16 | Test video: 17 | 18 | https://github.com/litongjava/whisper-cpp-server/assets/31761981/ba7268fa-312c-47b2-a538-804b96bb656f 19 | 20 | 21 | ## Main Features 22 | 1. Pure C++ Inference Engine 23 | Whisper-CPP-Server is written entirely in C++, leveraging the efficiency of C++ for rapid processing of large volumes of voice data, even in environments where only CPUs are available for compute. 24 | 25 | 2. High Performance 26 | Thanks to the computational efficiency of C++, Whisper-CPP-Server offers very high processing speeds, meeting real-time or near-real-time speech recognition demands. It is especially suited for scenarios that require processing large volumes of voice data. 27 | 28 | 3. Support for Multiple Languages 29 | The service supports speech recognition in multiple languages, broadening its applicability across various linguistic contexts. 30 | 31 | 4. Docker Container Support 32 | A Docker image is provided, enabling quick deployment of the service through simple command-line operations and significantly simplifying installation and configuration. Deploy using the following command: 33 | ``` 34 | docker run -dit --name whisper-server -p 8080:8080 litongjava/whisper-cpp-server:1.0.0-large-v3 35 | ``` 36 | This means you can run Whisper-CPP-Server on any platform that supports Docker, including but not limited to Linux, Windows, and macOS. 37 | 38 | 5. Easy Integration for Clients 39 | Detailed client integration documentation is provided, helping developers quickly incorporate speech recognition functionality into their applications.
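As a quick illustration of client integration over HTTP, here is a hypothetical C++ client sketch based on cpp-httplib (already a dependency of this project). The endpoint and form fields mirror the curl requests in the request examples section below; the host, port, and file path are assumptions for illustration only.
```cpp
// Hypothetical client sketch: POST an audio file to the /inference endpoint.
// Uses cpp-httplib; field names follow the curl examples in this README.
#include "httplib.h"
#include <fstream>
#include <iostream>
#include <sstream>

int main() {
  // read the audio file to upload (path is an example)
  std::ifstream file("samples/jfk.wav", std::ios::binary);
  std::stringstream audio;
  audio << file.rdbuf();

  httplib::Client cli("http://127.0.0.1:8080");
  httplib::MultipartFormDataItems items = {
      {"file", audio.str(), "jfk.wav", "audio/wav"},
      {"temperature", "0.2", "", ""},
      {"response-format", "json", "", ""},
      {"audio_format", "wav", "", ""},
  };

  auto res = cli.Post("/inference", items);
  if (res && res->status == 200) {
    std::cout << res->body << std::endl;  // JSON transcription result
  } else {
    std::cerr << "request failed" << std::endl;
    return 1;
  }
  return 0;
}
```
A ready-to-use Python WebSocket client for the streaming endpoint is also included under client/ (see the client section below).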
40 | [Client Code Documentation](https://github.com/litongjava/whisper-cpp-server/blob/main/doc/client_code.md) 41 | 42 | ## Applicable Scenarios 43 | Whisper-CPP-Server is suitable for a variety of applications that require fast and accurate speech recognition, including but not limited to: 44 | 45 | - Voice-driven interactive applications 46 | - Transcription of meeting records 47 | - Automatic subtitle generation 48 | - Automatic translation of multi-language content 49 | 50 | ## How to build it 51 | build with cmake and vcpkg 52 | ``` 53 | git clone https://github.com/litongjava/whisper-cpp-server.git 54 | git submodule init 55 | git submodule update 56 | cmake -B cmake-build-release 57 | cp ./ggml-metal.metal cmake-build-release 58 | cmake --build cmake-build-release --config Release -- -j 12 59 | ``` 60 | macos 61 | ```shell 62 | cmake -B cmake-build-release -DWHISPER_COREML=1 63 | ``` 64 | 65 | run with simplest 66 | ``` 67 | ./cmake-build-release/simplest -m models/ggml-base.en.bin test.wav 68 | ``` 69 | 70 | run with http-server 71 | ``` 72 | ./cmake-build-release/whisper_http_server_base_httplib -m models/ggml-base.en.bin 73 | ``` 74 | 75 | run with websocket-server 76 | ``` 77 | ./cmake-build-release/whisper_server_base_on_uwebsockets -m models/ggml-base.en.bin 78 | ``` 79 | 80 | copy command 81 | ``` 82 | mkdir bin 83 | cp ./ggml-metal.metal bin 84 | cp ./cmake-build-release/simplest bin 85 | cp ./cmake-build-release/whisper_http_server_base_httplib bin 86 | cp ./cmake-build-release/whisper_server_base_on_uwebsockets bin 87 | ``` 88 | 89 | ## simplest 90 | ```shell 91 | cmake-build-debug/simplest -m models/ggml-base.en.bin samples/jfk.wav 92 | ``` 93 | ``` 94 | simplest [options] file0.wav file1.wav ... 95 | 96 | options: 97 | -h, --help [default] show this help message and exit 98 | -m FNAME, --model FNAME [models/ggml-base.en.bin] model path 99 | -di, --diarize [false ] stereo audio diarization 100 | ``` 101 | ## whisper_http_server_base_httplib 102 | 103 | Simple http service. WAV mp4 and m4a Files are passed to the inference model via http requests. 104 | 105 | ``` 106 | ./whisper_http_server_base_httplib -h 107 | 108 | usage: ./bin/whisper_http_server_base_httplib [options] 109 | 110 | options: 111 | -h, --help [default] show this help message and exit 112 | -t N, --threads N [4 ] number of threads to use during computation 113 | -p N, --processors N [1 ] number of processors to use during computation 114 | -ot N, --offset-t N [0 ] time offset in milliseconds 115 | -on N, --offset-n N [0 ] segment index offset 116 | -d N, --duration N [0 ] duration of audio to process in milliseconds 117 | -mc N, --max-context N [-1 ] maximum number of text context tokens to store 118 | -ml N, --max-len N [0 ] maximum segment length in characters 119 | -sow, --split-on-word [false ] split on word rather than on token 120 | -bo N, --best-of N [2 ] number of best candidates to keep 121 | -bs N, --beam-size N [-1 ] beam size for beam search 122 | -wt N, --word-thold N [0.01 ] word timestamp probability threshold 123 | -et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail 124 | -lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail 125 | -debug, --debug-mode [false ] enable debug mode (eg. 
dump log_mel) 126 | -tr, --translate [false ] translate from source language to english 127 | -di, --diarize [false ] stereo audio diarization 128 | -tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model) 129 | -nf, --no-fallback [false ] do not use temperature fallback while decoding 130 | -ps, --print-special [false ] print special tokens 131 | -pc, --print-colors [false ] print colors 132 | -pp, --print-progress [false ] print progress 133 | -nt, --no-timestamps [false ] do not print timestamps 134 | -l LANG, --language LANG [en ] spoken language ('auto' for auto-detect) 135 | -dl, --detect-language [false ] exit after automatically detecting language 136 | --prompt PROMPT [ ] initial prompt 137 | -m FNAME, --model FNAME [models/ggml-base.en.bin] model path 138 | -oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference 139 | --host HOST, [127.0.0.1] Hostname/ip-adress for the service 140 | --port PORT, [8080 ] Port number for the service 141 | ``` 142 | ## start whisper_http_server_base_httplib 143 | ``` 144 | ./cmake-build-debug/whisper_http_server_base_httplib -m models/ggml-base.en.bin 145 | ``` 146 | Test server 147 | see request doc in [doc](doc) 148 | ## request examples 149 | 150 | **/inference** 151 | ``` 152 | curl --location --request POST http://127.0.0.1:8080/inference \ 153 | --form file=@"./samples/jfk.wav" \ 154 | --form temperature="0.2" \ 155 | --form response-format="json" 156 | --form audio_format="wav" 157 | ``` 158 | 159 | **/load** 160 | ``` 161 | curl 127.0.0.1:8080/load \ 162 | -H "Content-Type: multipart/form-data" \ 163 | -F model="" 164 | ``` 165 | 166 | ## whisper_server_base_on_uwebsockets 167 | web socket server 168 | start server 169 | ``` 170 | ./cmake-build-debug/whisper_server_base_on_uwebsockets -m models/ggml-base.en.bin 171 | ``` 172 | Test server 173 | see python [client](client) 174 | 175 | ## Docker 176 | ### run whisper-cpp-server:1.0.0 177 | [Dockerfile](./distribute/docker/pure/) 178 | ``` 179 | docker run -dit --name=whisper-server -p 8080:8080 -v "$(pwd)/models/ggml-base.en.bin":/models/ggml-base.en.bin litongjava/whisper-cpp-server:1.0.0 /app/whisper_http_server_base_httplib -m /models/ggml-base.en.bin 180 | ``` 181 | the port is 8080 182 | ### test 183 | ``` 184 | curl --location --request POST 'http://127.0.0.1:8080/inference' \ 185 | --header 'Accept: */*' \ 186 | --header 'Content-Type: multipart/form-data; boundary=--------------------------671827497522367123871197' \ 187 | --form 'file=@"E:\\code\\cpp\\cpp-study\\cpp-study-clion\\audio\\jfk.wav"' \ 188 | --form 'temperature="0.2"' \ 189 | --form 'response-format="json"' \ 190 | --form 'audio_format="wav"' 191 | ``` 192 | ### run whisper-cpp-server:1.0.0-base-en 193 | [Dockerfile](./distribute/docker/base.en/) 194 | ``` 195 | docker run -dit --name whisper-server -p 8080:8080 litongjava/whisper-cpp-server:1.0.0-base-en 196 | ``` 197 | 198 | ### run whisper-cpp-server:1.0.0-large-v3 199 | [Dockerfile](./distribute/docker/large-v3/) 200 | ``` 201 | docker run -dit --name whisper-server -p 8080:8080 litongjava/whisper-cpp-server:1.0.0-large-v3 202 | ``` 203 | ## run whisper-cpp-server:1.0.0-tiny.en-q5_1 204 | [Dockerfile](./distribute/docker/tiny.en-q5_1/) 205 | ``` 206 | docker run -dit --name whisper-server -p 8080:8080 litongjava/whisper-cpp-server:1.0.0-tiny.en-q5_1 207 | ``` 208 | ### Client code 209 | [Client code](./doc/client_code.md) 210 | -------------------------------------------------------------------------------- /client/readme.md: 
-------------------------------------------------------------------------------- 1 | 2 | ```shell 3 | pip install soundfile 4 | pip install websockets 5 | ``` 6 | ```shell script 7 | python client\websocket_client.py --server_ip 127.0.0.1 --port 8090 --wavfile samples/jfk.wav 8 | ``` 9 | only save audio 10 | ```shell script 11 | python client\websocket_client.py --server_ip 127.0.0.1 --port 8090 --endpoint /paddlespeech/streaming/save --wavfile samples/jfk.wav 12 | ``` 13 | 14 | -------------------------------------------------------------------------------- /client/websocket_client.py: -------------------------------------------------------------------------------- 1 | # python3 websocket_client.py --server_ip 127.0.0.1 --port 8090 --wavfile ./zh.wav 2 | import argparse 3 | import asyncio 4 | import codecs 5 | import functools 6 | import json 7 | import logging 8 | import os 9 | import time 10 | 11 | import numpy as np 12 | import soundfile 13 | import websockets 14 | 15 | 16 | class Logger(object): 17 | def __init__(self, name: str = None): 18 | name = 'PaddleSpeech' if not name else name 19 | self.logger = logging.getLogger(name) 20 | 21 | log_config = { 22 | 'DEBUG': 10, 23 | 'INFO': 20, 24 | 'TRAIN': 21, 25 | 'EVAL': 22, 26 | 'WARNING': 30, 27 | 'ERROR': 40, 28 | 'CRITICAL': 50, 29 | 'EXCEPTION': 100, 30 | } 31 | for key, level in log_config.items(): 32 | logging.addLevelName(level, key) 33 | if key == 'EXCEPTION': 34 | self.__dict__[key.lower()] = self.logger.exception 35 | else: 36 | self.__dict__[key.lower()] = functools.partial(self.__call__, 37 | level) 38 | 39 | self.format = logging.Formatter( 40 | fmt='[%(asctime)-15s] [%(levelname)8s] - %(message)s') 41 | 42 | self.handler = logging.StreamHandler() 43 | self.handler.setFormatter(self.format) 44 | 45 | self.logger.addHandler(self.handler) 46 | self.logger.setLevel(logging.INFO) 47 | self.logger.propagate = False 48 | 49 | def __call__(self, log_level: str, msg: str): 50 | self.logger.log(log_level, msg) 51 | 52 | 53 | class ASRWsAudioHandler: 54 | def __init__(self, 55 | logger=None, 56 | url=None, 57 | port=None, 58 | endpoint="/paddlespeech/asr/streaming", ): 59 | """Online ASR Server Client audio handler 60 | Online asr server use the websocket protocal 61 | Args: 62 | url (str, optional): the server ip. Defaults to None. 63 | port (int, optional): the server port. Defaults to None. 64 | endpoint(str, optional): to compatiable with python server and c++ server. 
65 | """ 66 | self.url = url 67 | self.port = port 68 | self.logger = logger 69 | if url is None or port is None or endpoint is None: 70 | self.url = None 71 | else: 72 | self.url = "ws://" + self.url + ":" + str(self.port) + endpoint 73 | self.logger.info(f"endpoint: {self.url}") 74 | 75 | def read_wave(self, wavfile_path: str): 76 | """read the audio file from specific wavfile path 77 | 78 | Args: 79 | wavfile_path (str): the audio wavfile, 80 | we assume that audio sample rate matches the model 81 | 82 | Yields: 83 | numpy.array: the samall package audio pcm data 84 | """ 85 | samples, sample_rate = soundfile.read(wavfile_path, dtype='int16') 86 | x_len = len(samples) 87 | assert sample_rate == 16000 88 | 89 | chunk_size = int(85 * sample_rate / 1000) # 85ms, sample_rate = 16kHz 90 | 91 | if x_len % chunk_size != 0: 92 | padding_len_x = chunk_size - x_len % chunk_size 93 | else: 94 | padding_len_x = 0 95 | 96 | padding = np.zeros((padding_len_x), dtype=samples.dtype) 97 | padded_x = np.concatenate([samples, padding], axis=0) 98 | 99 | assert (x_len + padding_len_x) % chunk_size == 0 100 | num_chunk = (x_len + padding_len_x) / chunk_size 101 | num_chunk = int(num_chunk) 102 | for i in range(0, num_chunk): 103 | start = i * chunk_size 104 | end = start + chunk_size 105 | x_chunk = padded_x[start:end] 106 | yield x_chunk 107 | 108 | async def run(self, wavfile_path: str): 109 | """Send a audio file to online server 110 | 111 | Args: 112 | wavfile_path (str): audio path 113 | 114 | Returns: 115 | str: the final asr result 116 | """ 117 | logging.debug("send a message to the server") 118 | 119 | results = [] 120 | if self.url is None: 121 | self.logger.error("No asr server, please input valid ip and port") 122 | return results 123 | 124 | # 1. send websocket handshake protocal 125 | start_time = time.time() 126 | async with websockets.connect(self.url) as ws: 127 | # 2. server has already received handshake protocal 128 | # client start to send the command 129 | audio_info = json.dumps( 130 | { 131 | "name": "test.wav", 132 | "signal": "start", 133 | "nbest": 1 134 | }, 135 | sort_keys=True, 136 | indent=4, 137 | separators=(',', ': ')) 138 | await ws.send(audio_info) 139 | msg = await ws.recv() 140 | self.logger.info("client receive msg={}".format(msg)) 141 | 142 | # 3. send chunk audio data to engine 143 | for chunk_data in self.read_wave(wavfile_path): 144 | await ws.send(chunk_data.tobytes()) 145 | msg = await ws.recv() 146 | if msg: 147 | try: 148 | json_object = json.loads(msg) 149 | self.logger.info("client receive msg={}".format(json_object)) 150 | if "result" in json_object: 151 | result = json_object.get("result") 152 | for sentence in result: 153 | # print(type(sentence)) 154 | sentence = "{}->{}:{}".format(sentence['t0'], sentence['t1'], sentence['sentence']) 155 | results.append(sentence) 156 | # print(sentence) 157 | except Exception as e: 158 | self.logger.error("Unexpected error: {}".format(e)) 159 | 160 | # 4. we must send finished signal to the server 161 | audio_info = json.dumps( 162 | { 163 | "name": "test.wav", 164 | "signal": "end", 165 | "nbest": 1 166 | }, 167 | sort_keys=True, 168 | indent=4, 169 | separators=(',', ': ')) 170 | await ws.send(audio_info) 171 | msg = await ws.recv() 172 | 173 | # 5. decode the bytes to str 174 | json_object = json.loads(msg) 175 | 176 | # 6. 
logging the final result and comptute the statstics 177 | elapsed_time = time.time() - start_time 178 | audio_info = soundfile.info(wavfile_path) 179 | self.logger.info("client final receive msg={}".format(json_object)) 180 | 181 | # print(type(json_object)) 182 | if "result" in json_object: 183 | result = json_object.get("result") 184 | for sentence in result: 185 | # print(type(sentence)) 186 | sentence = "{}->{}:{}".format(sentence['t0'], sentence['t1'], sentence['sentence']) 187 | results.append(sentence) 188 | # print(sentence) 189 | self.logger.info( 190 | f"audio duration: {audio_info.duration}, elapsed time: {elapsed_time}, RTF={elapsed_time / audio_info.duration}" 191 | ) 192 | return results 193 | 194 | 195 | logger = Logger() 196 | 197 | 198 | def main(args): 199 | logger.info("asr websocket client start") 200 | handler = ASRWsAudioHandler( 201 | logger, 202 | args.server_ip, 203 | args.port, 204 | endpoint=args.endpoint) 205 | loop = asyncio.get_event_loop() 206 | 207 | # support to process single audio file 208 | if args.wavfile and os.path.exists(args.wavfile): 209 | logger.info(f"start to process the wavscp: {args.wavfile}") 210 | results = loop.run_until_complete(handler.run(args.wavfile)) 211 | if results: 212 | for sentence in results: 213 | print(sentence) 214 | 215 | # support to process batch audios from wav.scp 216 | if args.wavscp and os.path.exists(args.wavscp): 217 | logger.info(f"start to process the wavscp: {args.wavscp}") 218 | with codecs.open(args.wavscp, 'r', encoding='utf-8') as f, \ 219 | codecs.open("result.txt", 'w', encoding='utf-8') as w: 220 | for line in f: 221 | utt_name, utt_path = line.strip().split() 222 | result = loop.run_until_complete(handler.run(utt_path)) 223 | result = result["result"] 224 | w.write(f"{utt_name} {result}\n") 225 | 226 | 227 | if __name__ == "__main__": 228 | logger.info("Start to do streaming asr client") 229 | parser = argparse.ArgumentParser() 230 | parser.add_argument( 231 | '--server_ip', type=str, default='127.0.0.1', help='server ip') 232 | parser.add_argument('--port', type=int, default=8090, help='server port') 233 | parser.add_argument( 234 | "--endpoint", 235 | type=str, 236 | default="/paddlespeech/asr/streaming", 237 | help="ASR websocket endpoint") 238 | parser.add_argument( 239 | "--wavfile", 240 | action="store", 241 | help="wav file path ", 242 | default="./16_audio.wav") 243 | parser.add_argument( 244 | "--wavscp", type=str, default=None, help="The batch audios dict text") 245 | args = parser.parse_args() 246 | 247 | main(args) 248 | -------------------------------------------------------------------------------- /common/common-ggml.cpp: -------------------------------------------------------------------------------- 1 | #include "common-ggml.h" 2 | 3 | #include 4 | #include 5 | 6 | static const std::map GGML_FTYPE_MAP = { 7 | {"q4_0", GGML_FTYPE_MOSTLY_Q4_0}, 8 | {"q4_1", GGML_FTYPE_MOSTLY_Q4_1}, 9 | {"q5_0", GGML_FTYPE_MOSTLY_Q5_0}, 10 | {"q5_1", GGML_FTYPE_MOSTLY_Q5_1}, 11 | {"q8_0", GGML_FTYPE_MOSTLY_Q8_0}, 12 | {"q2_k", GGML_FTYPE_MOSTLY_Q2_K}, 13 | {"q3_k", GGML_FTYPE_MOSTLY_Q3_K}, 14 | {"q4_k", GGML_FTYPE_MOSTLY_Q4_K}, 15 | {"q5_k", GGML_FTYPE_MOSTLY_Q5_K}, 16 | {"q6_k", GGML_FTYPE_MOSTLY_Q6_K}, 17 | }; 18 | 19 | void ggml_print_ftypes(FILE * fp) { 20 | for (auto it = GGML_FTYPE_MAP.begin(); it != GGML_FTYPE_MAP.end(); it++) { 21 | fprintf(fp, " type = \"%s\" or %d\n", it->first.c_str(), it->second); 22 | } 23 | } 24 | 25 | enum ggml_ftype ggml_parse_ftype(const char * str) { 26 | enum ggml_ftype 
ftype; 27 | if (str[0] == 'q') { 28 | const auto it = GGML_FTYPE_MAP.find(str); 29 | if (it == GGML_FTYPE_MAP.end()) { 30 | fprintf(stderr, "%s: unknown ftype '%s'\n", __func__, str); 31 | return GGML_FTYPE_UNKNOWN; 32 | } 33 | ftype = it->second; 34 | } else { 35 | ftype = (enum ggml_ftype) atoi(str); 36 | } 37 | 38 | return ftype; 39 | } 40 | 41 | bool ggml_common_quantize_0( 42 | std::ifstream & finp, 43 | std::ofstream & fout, 44 | const ggml_ftype ftype, 45 | const std::vector & to_quant, 46 | const std::vector & to_skip) { 47 | 48 | ggml_type qtype = GGML_TYPE_F32; 49 | 50 | switch (ftype) { 51 | case GGML_FTYPE_MOSTLY_Q4_0: qtype = GGML_TYPE_Q4_0; break; 52 | case GGML_FTYPE_MOSTLY_Q4_1: qtype = GGML_TYPE_Q4_1; break; 53 | case GGML_FTYPE_MOSTLY_Q5_0: qtype = GGML_TYPE_Q5_0; break; 54 | case GGML_FTYPE_MOSTLY_Q5_1: qtype = GGML_TYPE_Q5_1; break; 55 | case GGML_FTYPE_MOSTLY_Q8_0: qtype = GGML_TYPE_Q8_0; break; 56 | case GGML_FTYPE_MOSTLY_Q2_K: qtype = GGML_TYPE_Q2_K; break; 57 | case GGML_FTYPE_MOSTLY_Q3_K: qtype = GGML_TYPE_Q3_K; break; 58 | case GGML_FTYPE_MOSTLY_Q4_K: qtype = GGML_TYPE_Q4_K; break; 59 | case GGML_FTYPE_MOSTLY_Q5_K: qtype = GGML_TYPE_Q5_K; break; 60 | case GGML_FTYPE_MOSTLY_Q6_K: qtype = GGML_TYPE_Q6_K; break; 61 | case GGML_FTYPE_UNKNOWN: 62 | case GGML_FTYPE_ALL_F32: 63 | case GGML_FTYPE_MOSTLY_F16: 64 | case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: 65 | { 66 | fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype); 67 | return false; 68 | } 69 | }; 70 | 71 | if (!ggml_is_quantized(qtype)) { 72 | fprintf(stderr, "%s: invalid quantization type %d (%s)\n", __func__, qtype, ggml_type_name(qtype)); 73 | return false; 74 | } 75 | 76 | size_t total_size_org = 0; 77 | size_t total_size_new = 0; 78 | 79 | std::vector work; 80 | 81 | std::vector data_u8; 82 | std::vector data_f16; 83 | std::vector data_f32; 84 | 85 | std::vector hist_all(1 << 4, 0); 86 | 87 | while (true) { 88 | int32_t n_dims; 89 | int32_t length; 90 | int32_t ttype; 91 | 92 | finp.read(reinterpret_cast(&n_dims), sizeof(n_dims)); 93 | finp.read(reinterpret_cast(&length), sizeof(length)); 94 | finp.read(reinterpret_cast(&ttype), sizeof(ttype)); 95 | 96 | if (finp.eof()) { 97 | break; 98 | } 99 | 100 | int32_t nelements = 1; 101 | int32_t ne[4] = { 1, 1, 1, 1 }; 102 | for (int i = 0; i < n_dims; ++i) { 103 | finp.read (reinterpret_cast(&ne[i]), sizeof(ne[i])); 104 | nelements *= ne[i]; 105 | } 106 | 107 | std::string name(length, 0); 108 | finp.read (&name[0], length); 109 | 110 | printf("%64s - [%5d, %5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype)); 111 | 112 | bool quantize = false; 113 | 114 | // check if we should quantize this tensor 115 | for (const auto & s : to_quant) { 116 | if (std::regex_match(name, std::regex(s))) { 117 | quantize = true; 118 | break; 119 | } 120 | } 121 | 122 | // check if we should skip this tensor 123 | for (const auto & s : to_skip) { 124 | if (std::regex_match(name, std::regex(s))) { 125 | quantize = false; 126 | break; 127 | } 128 | } 129 | 130 | // quantize only 2D tensors 131 | quantize &= (n_dims == 2); 132 | 133 | if (quantize) { 134 | if (ttype != GGML_TYPE_F32 && ttype != GGML_TYPE_F16) { 135 | fprintf(stderr, "%s: unsupported ttype %d (%s) for integer quantization\n", __func__, ttype, ggml_type_name((ggml_type) ttype)); 136 | return false; 137 | } 138 | 139 | if (ttype == GGML_TYPE_F16) { 140 | data_f16.resize(nelements); 141 | finp.read(reinterpret_cast(data_f16.data()), nelements * sizeof(ggml_fp16_t)); 142 | 
data_f32.resize(nelements); 143 | for (int i = 0; i < nelements; ++i) { 144 | data_f32[i] = ggml_fp16_to_fp32(data_f16[i]); 145 | } 146 | } else { 147 | data_f32.resize(nelements); 148 | finp.read(reinterpret_cast(data_f32.data()), nelements * sizeof(float)); 149 | } 150 | 151 | ttype = qtype; 152 | } else { 153 | const int bpe = (ttype == 0) ? sizeof(float) : sizeof(uint16_t); 154 | 155 | data_u8.resize(nelements*bpe); 156 | finp.read(reinterpret_cast(data_u8.data()), nelements * bpe); 157 | } 158 | 159 | fout.write(reinterpret_cast(&n_dims), sizeof(n_dims)); 160 | fout.write(reinterpret_cast(&length), sizeof(length)); 161 | fout.write(reinterpret_cast(&ttype), sizeof(ttype)); 162 | for (int i = 0; i < n_dims; ++i) { 163 | fout.write(reinterpret_cast(&ne[i]), sizeof(ne[i])); 164 | } 165 | fout.write(&name[0], length); 166 | 167 | if (quantize) { 168 | work.resize(nelements); // for quantization 169 | 170 | size_t cur_size = 0; 171 | std::vector hist_cur(1 << 4, 0); 172 | 173 | switch ((ggml_type) ttype) { 174 | case GGML_TYPE_Q4_0: 175 | case GGML_TYPE_Q4_1: 176 | case GGML_TYPE_Q5_0: 177 | case GGML_TYPE_Q5_1: 178 | case GGML_TYPE_Q8_0: 179 | case GGML_TYPE_Q2_K: 180 | case GGML_TYPE_Q3_K: 181 | case GGML_TYPE_Q4_K: 182 | case GGML_TYPE_Q5_K: 183 | case GGML_TYPE_Q6_K: 184 | { 185 | cur_size = ggml_quantize_chunk((ggml_type) ttype, data_f32.data(), work.data(), 0, nelements, hist_cur.data()); 186 | } break; 187 | case GGML_TYPE_F32: 188 | case GGML_TYPE_F16: 189 | case GGML_TYPE_I8: 190 | case GGML_TYPE_I16: 191 | case GGML_TYPE_I32: 192 | case GGML_TYPE_Q8_1: 193 | case GGML_TYPE_Q8_K: 194 | case GGML_TYPE_COUNT: 195 | { 196 | fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype)); 197 | return false; 198 | } 199 | } 200 | 201 | fout.write(reinterpret_cast(work.data()), cur_size); 202 | total_size_new += cur_size; 203 | 204 | printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0); 205 | for (int i = 0; i < (int) hist_cur.size(); ++i) { 206 | hist_all[i] += hist_cur[i]; 207 | } 208 | 209 | for (int i = 0; i < (int) hist_cur.size(); ++i) { 210 | printf("%5.3f ", hist_cur[i] / (float)nelements); 211 | } 212 | printf("\n"); 213 | } else { 214 | printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0); 215 | fout.write(reinterpret_cast(data_u8.data()), data_u8.size()); 216 | total_size_new += data_u8.size(); 217 | } 218 | 219 | total_size_org += nelements * sizeof(float); 220 | } 221 | 222 | printf("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); 223 | printf("%s: quant size = %8.2f MB | ftype = %d (%s)\n", __func__, total_size_new/1024.0/1024.0, ftype, ggml_type_name(qtype)); 224 | 225 | { 226 | int64_t sum_all = 0; 227 | for (int i = 0; i < (int) hist_all.size(); ++i) { 228 | sum_all += hist_all[i]; 229 | } 230 | 231 | printf("%s: hist: ", __func__); 232 | for (int i = 0; i < (int) hist_all.size(); ++i) { 233 | printf("%5.3f ", hist_all[i] / (float)sum_all); 234 | } 235 | printf("\n"); 236 | } 237 | 238 | return true; 239 | } 240 | -------------------------------------------------------------------------------- /common/common-ggml.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ggml.h" 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | enum ggml_ftype ggml_parse_ftype(const char * str); 10 | 11 | void ggml_print_ftypes(FILE * fp = stderr); 12 | 13 | bool ggml_common_quantize_0( 14 
| std::ifstream & finp, 15 | std::ofstream & fout, 16 | const ggml_ftype ftype, 17 | const std::vector & to_quant, 18 | const std::vector & to_skip); 19 | -------------------------------------------------------------------------------- /common/common-m4a.cpp: -------------------------------------------------------------------------------- 1 | #include "common.h" 2 | #include "common-m4a.h" 3 | extern "C" { 4 | #include 5 | #include 6 | #include 7 | } 8 | 9 | #include 10 | #include 11 | 12 | bool read_m4a(const std::string &fname, std::vector &pcmf32, std::vector> &pcmf32s, 13 | bool stereo) { 14 | avformat_network_init(); 15 | 16 | AVFormatContext *formatContext = avformat_alloc_context(); 17 | if (avformat_open_input(&formatContext, fname.c_str(), nullptr, nullptr) != 0) { 18 | fprintf(stderr, "Could not open file %s\n", fname.c_str()); 19 | return false; 20 | } 21 | 22 | if (avformat_find_stream_info(formatContext, nullptr) < 0) { 23 | fprintf(stderr, "Could not find stream information\n"); 24 | avformat_close_input(&formatContext); 25 | return false; 26 | } 27 | 28 | const AVCodec *codec = nullptr; 29 | int streamIndex = av_find_best_stream(formatContext, AVMEDIA_TYPE_AUDIO, -1, -1, &codec, 0); 30 | if (streamIndex < 0) { 31 | fprintf(stderr, "Could not find any audio stream in the file\n"); 32 | avformat_close_input(&formatContext); 33 | return false; 34 | } 35 | 36 | AVCodecContext *codecContext = avcodec_alloc_context3(codec); 37 | avcodec_parameters_to_context(codecContext, formatContext->streams[streamIndex]->codecpar); 38 | 39 | if (avcodec_open2(codecContext, codec, nullptr) < 0) { 40 | fprintf(stderr, "Could not open codec\n"); 41 | avcodec_free_context(&codecContext); 42 | avformat_close_input(&formatContext); 43 | return false; 44 | } 45 | 46 | //bool need_resample = (codecContext->sample_rate != COMMON_SAMPLE_RATE); 47 | SwrContext *swrCtx = nullptr; 48 | swrCtx = swr_alloc_set_opts(nullptr, 49 | stereo ? AV_CH_LAYOUT_STEREO : AV_CH_LAYOUT_MONO, 50 | AV_SAMPLE_FMT_FLT, 51 | COMMON_SAMPLE_RATE, 52 | codecContext->channel_layout, 53 | codecContext->sample_fmt, 54 | codecContext->sample_rate, 55 | 0, nullptr); 56 | if (!swrCtx || swr_init(swrCtx) < 0) { 57 | fprintf(stderr, "Could not initialize the resampling context\n"); 58 | swr_free(&swrCtx); 59 | avcodec_free_context(&codecContext); 60 | avformat_close_input(&formatContext); 61 | return false; 62 | } 63 | 64 | 65 | AVPacket packet; 66 | av_init_packet(&packet); 67 | packet.data = nullptr; 68 | packet.size = 0; 69 | 70 | AVFrame *frame = av_frame_alloc(); 71 | 72 | while (av_read_frame(formatContext, &packet) >= 0) { 73 | if (packet.stream_index == streamIndex) { 74 | //decode 75 | int ret = avcodec_send_packet(codecContext, &packet); 76 | if (ret < 0) { 77 | fprintf(stderr, "Error sending packet for decoding\n"); 78 | break; 79 | } 80 | 81 | while (ret >= 0) { 82 | ret = avcodec_receive_frame(codecContext, frame); 83 | if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) { 84 | break; 85 | } else if (ret < 0) { 86 | fprintf(stderr, "Error during decoding\n"); 87 | break; 88 | } 89 | 90 | // Direct processing of decoded frames 91 | uint8_t *out_buf[2] = {nullptr, nullptr}; 92 | int out_channels = stereo ? 
2 : 1; 93 | int out_samples = av_rescale_rnd(swr_get_delay(swrCtx, codecContext->sample_rate) + frame->nb_samples, 94 | COMMON_SAMPLE_RATE, codecContext->sample_rate, AV_ROUND_UP); 95 | av_samples_alloc(out_buf, nullptr, out_channels, out_samples, AV_SAMPLE_FMT_FLT, 0); 96 | swr_convert(swrCtx, out_buf, out_samples, (const uint8_t **) frame->data, frame->nb_samples); 97 | 98 | int data_size = av_samples_get_buffer_size(nullptr, out_channels, out_samples, AV_SAMPLE_FMT_FLT, 0); 99 | for (int i = 0; i < data_size / sizeof(float); ++i) { 100 | pcmf32.push_back(((float *) out_buf[0])[i]); 101 | if (stereo && out_buf[1] != nullptr) { 102 | pcmf32s[0].push_back(((float *) out_buf[0])[i]); 103 | pcmf32s[1].push_back(((float *) out_buf[1])[i]); 104 | } 105 | } 106 | 107 | if (out_buf[0]) { 108 | av_freep(&out_buf[0]); 109 | } 110 | if (stereo && out_buf[1]) { 111 | av_freep(&out_buf[1]); 112 | } 113 | 114 | av_frame_unref(frame); 115 | } 116 | av_packet_unref(&packet); 117 | } 118 | av_packet_unref(&packet); 119 | } 120 | 121 | // Clean up 122 | av_frame_free(&frame); 123 | swr_free(&swrCtx); 124 | avcodec_free_context(&codecContext); 125 | avformat_close_input(&formatContext); 126 | avformat_network_deinit(); 127 | 128 | return true; 129 | } -------------------------------------------------------------------------------- /common/common-m4a.h: -------------------------------------------------------------------------------- 1 | #ifndef WHISPER_CPP_SERVER_COMMON_M4A_H 2 | #define WHISPER_CPP_SERVER_COMMON_M4A_H 3 | #include 4 | #include 5 | bool read_m4a(const std::string &fname, std::vector &pcmf32, std::vector> &pcmf32s, 6 | bool stereo); 7 | #endif //WHISPER_CPP_SERVER_COMMON_M4A_H 8 | -------------------------------------------------------------------------------- /common/common-sdl.cpp: -------------------------------------------------------------------------------- 1 | #include "common-sdl.h" 2 | 3 | audio_async::audio_async(int len_ms) { 4 | m_len_ms = len_ms; 5 | 6 | m_running = false; 7 | } 8 | 9 | audio_async::~audio_async() { 10 | if (m_dev_id_in) { 11 | SDL_CloseAudioDevice(m_dev_id_in); 12 | } 13 | } 14 | 15 | bool audio_async::init(int capture_id, int sample_rate) { 16 | SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO); 17 | 18 | if (SDL_Init(SDL_INIT_AUDIO) < 0) { 19 | SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError()); 20 | return false; 21 | } 22 | 23 | SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE); 24 | 25 | { 26 | int nDevices = SDL_GetNumAudioDevices(SDL_TRUE); 27 | fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices); 28 | for (int i = 0; i < nDevices; i++) { 29 | fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE)); 30 | } 31 | } 32 | 33 | SDL_AudioSpec capture_spec_requested; 34 | SDL_AudioSpec capture_spec_obtained; 35 | 36 | SDL_zero(capture_spec_requested); 37 | SDL_zero(capture_spec_obtained); 38 | 39 | capture_spec_requested.freq = sample_rate; 40 | capture_spec_requested.format = AUDIO_F32; 41 | capture_spec_requested.channels = 1; 42 | capture_spec_requested.samples = 1024; 43 | capture_spec_requested.callback = [](void *userdata, uint8_t *stream, int len) { 44 | audio_async *audio = (audio_async *) userdata; 45 | audio->callback(stream, len); 46 | }; 47 | capture_spec_requested.userdata = this; 48 | 49 | if (capture_id >= 0) { 50 | fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", 
__func__, capture_id, 51 | SDL_GetAudioDeviceName(capture_id, SDL_TRUE)); 52 | m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, 53 | &capture_spec_obtained, 0); 54 | } else { 55 | fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__); 56 | m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0); 57 | } 58 | 59 | if (!m_dev_id_in) { 60 | fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError()); 61 | m_dev_id_in = 0; 62 | 63 | return false; 64 | } else { 65 | if (capture_id > -1) { 66 | const char *deviceName = SDL_GetAudioDeviceName(capture_id, SDL_TRUE); 67 | fprintf(stderr, "%s: opened device name:'%s'\n", __func__, deviceName); 68 | } 69 | fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in); 70 | fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq); 71 | fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format, 72 | capture_spec_requested.format); 73 | fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels, 74 | capture_spec_requested.channels); 75 | fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples); 76 | } 77 | 78 | 79 | m_sample_rate = capture_spec_obtained.freq; 80 | 81 | m_audio.resize((m_sample_rate * m_len_ms) / 1000); 82 | 83 | return true; 84 | } 85 | 86 | bool audio_async::resume() { 87 | if (!m_dev_id_in) { 88 | fprintf(stderr, "%s: no audio device to resume!\n", __func__); 89 | return false; 90 | } 91 | 92 | if (m_running) { 93 | fprintf(stderr, "%s: already running!\n", __func__); 94 | return false; 95 | } 96 | 97 | SDL_PauseAudioDevice(m_dev_id_in, 0); 98 | 99 | m_running = true; 100 | 101 | return true; 102 | } 103 | 104 | bool audio_async::pause() { 105 | if (!m_dev_id_in) { 106 | fprintf(stderr, "%s: no audio device to pause!\n", __func__); 107 | return false; 108 | } 109 | 110 | if (!m_running) { 111 | fprintf(stderr, "%s: already paused!\n", __func__); 112 | return false; 113 | } 114 | 115 | SDL_PauseAudioDevice(m_dev_id_in, 1); 116 | 117 | m_running = false; 118 | 119 | return true; 120 | } 121 | 122 | bool audio_async::clear() { 123 | if (!m_dev_id_in) { 124 | fprintf(stderr, "%s: no audio device to clear!\n", __func__); 125 | return false; 126 | } 127 | 128 | if (!m_running) { 129 | fprintf(stderr, "%s: not running!\n", __func__); 130 | return false; 131 | } 132 | 133 | { 134 | std::lock_guard lock(m_mutex); 135 | 136 | m_audio_pos = 0; 137 | m_audio_len = 0; 138 | } 139 | 140 | return true; 141 | } 142 | 143 | // callback to be called by SDL 144 | void audio_async::callback(uint8_t *stream, int len) { 145 | if (!m_running) { 146 | return; 147 | } 148 | 149 | const size_t n_samples = len / sizeof(float); 150 | 151 | m_audio_new.resize(n_samples); 152 | memcpy(m_audio_new.data(), stream, n_samples * sizeof(float)); 153 | 154 | //fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len); 155 | 156 | { 157 | std::lock_guard lock(m_mutex); 158 | 159 | if (m_audio_pos + n_samples > m_audio.size()) { 160 | const size_t n0 = m_audio.size() - m_audio_pos; 161 | 162 | memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float)); 163 | memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float)); 164 | 165 | m_audio_pos = (m_audio_pos + n_samples) % m_audio.size(); 166 | 
m_audio_len = m_audio.size(); 167 | } else { 168 | memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float)); 169 | 170 | m_audio_pos = (m_audio_pos + n_samples) % m_audio.size(); 171 | m_audio_len = std::min(m_audio_len + n_samples, m_audio.size()); 172 | } 173 | } 174 | } 175 | 176 | void audio_async::get(int ms, std::vector &result) { 177 | if (!m_dev_id_in) { 178 | fprintf(stderr, "%s: no audio device to get audio from!\n", __func__); 179 | return; 180 | } 181 | 182 | if (!m_running) { 183 | fprintf(stderr, "%s: not running!\n", __func__); 184 | return; 185 | } 186 | 187 | result.clear(); 188 | 189 | { 190 | std::lock_guard lock(m_mutex); 191 | 192 | if (ms <= 0) { 193 | ms = m_len_ms; 194 | } 195 | 196 | size_t n_samples = (m_sample_rate * ms) / 1000; 197 | if (n_samples > m_audio_len) { 198 | n_samples = m_audio_len; 199 | } 200 | 201 | result.resize(n_samples); 202 | 203 | int s0 = m_audio_pos - n_samples; 204 | if (s0 < 0) { 205 | s0 += m_audio.size(); 206 | } 207 | 208 | if (s0 + n_samples > m_audio.size()) { 209 | const size_t n0 = m_audio.size() - s0; 210 | 211 | memcpy(result.data(), &m_audio[s0], n0 * sizeof(float)); 212 | memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float)); 213 | } else { 214 | memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float)); 215 | } 216 | } 217 | } 218 | 219 | bool sdl_poll_events() { 220 | SDL_Event event; 221 | while (SDL_PollEvent(&event)) { 222 | switch (event.type) { 223 | case SDL_QUIT: { 224 | return false; 225 | } 226 | break; 227 | default: 228 | break; 229 | } 230 | } 231 | 232 | return true; 233 | } 234 | -------------------------------------------------------------------------------- /common/common-sdl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "SDL2/SDL.h" 4 | #include "SDL2/SDL_audio.h" 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | // 12 | // SDL Audio capture 13 | // 14 | 15 | class audio_async { 16 | public: 17 | audio_async(int len_ms); 18 | 19 | ~audio_async(); 20 | 21 | bool init(int capture_id, int sample_rate); 22 | 23 | // start capturing audio via the provided SDL callback 24 | // keep last len_ms seconds of audio in a circular buffer 25 | bool resume(); 26 | 27 | bool pause(); 28 | 29 | bool clear(); 30 | 31 | // callback to be called by SDL 32 | void callback(uint8_t *stream, int len); 33 | 34 | // get audio data from the circular buffer 35 | void get(int ms, std::vector &audio); 36 | 37 | private: 38 | SDL_AudioDeviceID m_dev_id_in = 0; 39 | 40 | int m_len_ms = 0; 41 | int m_sample_rate = 0; 42 | 43 | std::atomic_bool m_running; 44 | std::mutex m_mutex; 45 | 46 | std::vector m_audio; 47 | std::vector m_audio_new; 48 | size_t m_audio_pos = 0; 49 | size_t m_audio_len = 0; 50 | }; 51 | 52 | // Return false if need to quit 53 | bool sdl_poll_events(); 54 | -------------------------------------------------------------------------------- /common/common.h: -------------------------------------------------------------------------------- 1 | // Various helper functions and utilities 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #define COMMON_SAMPLE_RATE 16000 14 | 15 | // 16 | // GPT CLI argument parsing 17 | // 18 | 19 | struct gpt_params { 20 | int32_t seed = -1; // RNG seed 21 | int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); 22 | int32_t n_predict = 200; // new tokens to predict 23 | int32_t 
n_parallel = 1; // number of parallel streams 24 | int32_t n_batch = 8; // batch size for prompt processing 25 | int32_t n_ctx = 2048; // context size (this is the KV cache max size) 26 | int32_t n_gpu_layers = 0; // number of layers to offlload to the GPU 27 | 28 | bool ignore_eos = false; // ignore EOS token when generating text 29 | 30 | // sampling parameters 31 | int32_t top_k = 40; 32 | float top_p = 0.9f; 33 | float temp = 0.9f; 34 | int32_t repeat_last_n = 64; 35 | float repeat_penalty = 1.00f; 36 | 37 | std::string model = "models/gpt-2-117M/ggml-model.bin"; // model path 38 | std::string prompt = ""; 39 | std::string token_test = ""; 40 | 41 | bool interactive = false; 42 | int32_t interactive_port = -1; 43 | }; 44 | 45 | 46 | bool gpt_params_parse(int argc, char **argv, gpt_params ¶ms); 47 | 48 | void gpt_print_usage(int argc, char **argv, const gpt_params ¶ms); 49 | 50 | std::string gpt_random_prompt(std::mt19937 &rng); 51 | 52 | // 53 | // Vocab utils 54 | // 55 | 56 | std::string trim(const std::string &s); 57 | 58 | std::string replace( 59 | const std::string &s, 60 | const std::string &from, 61 | const std::string &to); 62 | 63 | struct gpt_vocab { 64 | using id = int32_t; 65 | using token = std::string; 66 | 67 | std::map token_to_id; 68 | std::map id_to_token; 69 | std::vector special_tokens; 70 | 71 | void add_special_token(const std::string &token); 72 | }; 73 | 74 | // poor-man's JSON parsing 75 | std::map json_parse(const std::string &fname); 76 | 77 | std::string convert_to_utf8(const std::wstring &input); 78 | 79 | std::wstring convert_to_wstring(const std::string &input); 80 | 81 | void gpt_split_words(std::string str, std::vector &words); 82 | 83 | // split text into tokens 84 | // 85 | // ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53 86 | // 87 | // Regex (Python): 88 | // r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" 89 | // 90 | // Regex (C++): 91 | // R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)" 92 | // 93 | std::vector gpt_tokenize(const gpt_vocab &vocab, const std::string &text); 94 | 95 | // test outputs of gpt_tokenize 96 | // 97 | // - compare with tokens generated by the huggingface tokenizer 98 | // - test cases are chosen based on the model's main language (under 'prompt' directory) 99 | // - if all sentences are tokenized identically, print 'All tests passed.' 
100 | // - otherwise, print sentence, huggingface tokens, ggml tokens 101 | // 102 | void test_gpt_tokenizer(gpt_vocab &vocab, const std::string &fpath_test); 103 | 104 | // load the tokens from encoder.json 105 | bool gpt_vocab_init(const std::string &fname, gpt_vocab &vocab); 106 | 107 | // sample next token given probabilities for each embedding 108 | // 109 | // - consider only the top K tokens 110 | // - from them, consider only the top tokens with cumulative probability > P 111 | // 112 | // TODO: not sure if this implementation is correct 113 | // TODO: temperature is not implemented 114 | // 115 | gpt_vocab::id gpt_sample_top_k_top_p( 116 | const gpt_vocab &vocab, 117 | const float *logits, 118 | int top_k, 119 | double top_p, 120 | double temp, 121 | std::mt19937 &rng); 122 | 123 | gpt_vocab::id gpt_sample_top_k_top_p_repeat( 124 | const gpt_vocab &vocab, 125 | const float *logits, 126 | const int32_t *last_n_tokens_data, 127 | size_t last_n_tokens_data_size, 128 | int top_k, 129 | double top_p, 130 | double temp, 131 | int repeat_last_n, 132 | float repeat_penalty, 133 | std::mt19937 &rng); 134 | 135 | // 136 | // Audio utils 137 | // 138 | bool resample(const float *input, size_t inputSampleRate, size_t inputSize, 139 | std::vector &output, size_t outputSampleRate); 140 | // Read WAV audio file and store the PCM data into pcmf32 141 | // The sample rate of the audio must be equal to COMMON_SAMPLE_RATE 142 | // If stereo flag is set and the audio has 2 channels, the pcmf32s will contain 2 channel PCM 143 | bool read_wav( 144 | const std::string &fname, 145 | std::vector &pcmf32, 146 | std::vector> &pcmf32s, 147 | bool stereo); 148 | bool read_mp3(const std::string &fname, std::vector &pcmf32, bool stereo); 149 | // Write PCM data into WAV audio file 150 | class wav_writer { 151 | private: 152 | std::ofstream fstream; 153 | uint32_t dataSize = 0; 154 | std::string wav_filename; 155 | 156 | bool write_header(const uint32_t sample_rate, 157 | const uint16_t bits_per_sample, 158 | const uint16_t channels) { 159 | 160 | fstream.write("RIFF", 4); 161 | fstream.write("\0\0\0\0", 4); // Placeholder for file size 162 | fstream.write("WAVE", 4); 163 | fstream.write("fmt ", 4); 164 | 165 | const uint32_t sub_chunk_size = 16; 166 | const uint16_t audio_format = 1; // PCM format 167 | const uint32_t byte_rate = sample_rate * channels * bits_per_sample / 8; 168 | const uint16_t block_align = channels * bits_per_sample / 8; 169 | 170 | fstream.write(reinterpret_cast(&sub_chunk_size), 4); 171 | fstream.write(reinterpret_cast(&audio_format), 2); 172 | fstream.write(reinterpret_cast(&channels), 2); 173 | fstream.write(reinterpret_cast(&sample_rate), 4); 174 | fstream.write(reinterpret_cast(&byte_rate), 4); 175 | fstream.write(reinterpret_cast(&block_align), 2); 176 | fstream.write(reinterpret_cast(&bits_per_sample), 2); 177 | fstream.write("data", 4); 178 | fstream.write("\0\0\0\0", 4); // Placeholder for data size 179 | 180 | return true; 181 | } 182 | 183 | // It is assumed that PCM data is normalized to a range from -1 to 1 184 | bool write_audio(const int16_t *data, size_t length) { 185 | for (size_t i = 0; i < length; ++i) { 186 | // Ensure that the data is in the range of -1 to 1 187 | fstream.write(reinterpret_cast(&data[i]), sizeof(int16_t)); 188 | dataSize += sizeof(int16_t); 189 | 190 | // Check if write was successful 191 | if (!fstream.good()) { 192 | fprintf(stderr, "Error writing to WAV file\n"); 193 | return false; // Stop writing and return an error 194 | } 195 | } 196 | if 
(fstream.is_open()) { 197 | fstream.seekp(4, std::ios::beg); 198 | uint32_t fileSize = 36 + dataSize; 199 | fstream.write(reinterpret_cast(&fileSize), 4); 200 | fstream.seekp(40, std::ios::beg); 201 | fstream.write(reinterpret_cast(&dataSize), 4); 202 | fstream.seekp(0, std::ios::end); 203 | } 204 | return true; 205 | } 206 | 207 | bool open_wav(const std::string &filename) { 208 | if (filename != wav_filename) { 209 | if (fstream.is_open()) { 210 | fstream.close(); 211 | } 212 | } 213 | if (!fstream.is_open()) { 214 | fstream.open(filename, std::ios::binary); 215 | wav_filename = filename; 216 | dataSize = 0; 217 | } 218 | return fstream.is_open(); 219 | } 220 | 221 | public: 222 | bool open(const std::string &filename, 223 | const uint32_t sample_rate, 224 | const uint16_t bits_per_sample, 225 | const uint16_t channels) { 226 | 227 | if (open_wav(filename)) { 228 | write_header(sample_rate, bits_per_sample, channels); 229 | } else { 230 | return false; 231 | } 232 | 233 | return true; 234 | } 235 | 236 | bool close() { 237 | fstream.close(); 238 | return true; 239 | } 240 | 241 | bool write(const int16_t *data, size_t length) { 242 | return write_audio(data, length); 243 | } 244 | 245 | ~wav_writer() { 246 | if (fstream.is_open()) { 247 | fstream.close(); 248 | } 249 | } 250 | }; 251 | 252 | 253 | // Apply a high-pass frequency filter to PCM audio 254 | // Suppresses frequencies below cutoff Hz 255 | void high_pass_filter( 256 | std::vector &data, 257 | float cutoff, 258 | float sample_rate); 259 | 260 | // Basic voice activity detection (VAD) using audio energy adaptive threshold 261 | bool vad_simple( 262 | std::vector &pcmf32, 263 | int sample_rate, 264 | int last_ms, 265 | float vad_thold, 266 | float freq_thold, 267 | bool verbose); 268 | 269 | // compute similarity between two strings using Levenshtein distance 270 | float similarity(const std::string &s0, const std::string &s1); 271 | 272 | // 273 | // SAM argument parsing 274 | // 275 | 276 | struct sam_params { 277 | int32_t seed = -1; // RNG seed 278 | int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); 279 | 280 | std::string model = "models/sam-vit-b/ggml-model-f16.bin"; // model path 281 | std::string fname_inp = "img.jpg"; 282 | std::string fname_out = "img.out"; 283 | }; 284 | 285 | bool sam_params_parse(int argc, char **argv, sam_params ¶ms); 286 | 287 | void sam_print_usage(int argc, char **argv, const sam_params ¶ms); 288 | -------------------------------------------------------------------------------- /common/utils.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by Ping Lee on 2023/11/21. 3 | // 4 | 5 | #include "utils.h" 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | 12 | std::string get_current_time() { 13 | auto now = std::chrono::system_clock::now(); 14 | auto now_c = std::chrono::system_clock::to_time_t(now); 15 | auto milliseconds = std::chrono::duration_cast(now.time_since_epoch()) % 1000; 16 | 17 | std::stringstream current_time_ss; 18 | current_time_ss << std::put_time(std::localtime(&now_c), "%y-%m-%d %H:%M:%S") 19 | << '.' 
<< std::setfill('0') << std::setw(3) << milliseconds.count(); 20 | 21 | std::string current_time = current_time_ss.str(); 22 | return current_time; 23 | } 24 | 25 | long get_current_time_millis(){ 26 | auto start = std::chrono::system_clock::now(); 27 | return std::chrono::duration_cast(start.time_since_epoch()).count(); 28 | } 29 | 30 | // 500 -> 00:05.000 31 | // 6000 -> 01:00.000 32 | std::string to_timestamp(int64_t t, bool comma) { 33 | int64_t msec = t * 10; 34 | int64_t hr = msec / (1000 * 60 * 60); 35 | msec = msec - hr * (1000 * 60 * 60); 36 | int64_t min = msec / (1000 * 60); 37 | msec = msec - min * (1000 * 60); 38 | int64_t sec = msec / 1000; 39 | msec = msec - sec * 1000; 40 | 41 | char buf[32]; 42 | snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec); 43 | 44 | return std::string(buf); 45 | } 46 | 47 | nlohmann::json get_result(whisper_context *ctx) { 48 | nlohmann::json results = nlohmann::json(nlohmann::json::array()); 49 | const int n_segments = whisper_full_n_segments(ctx); 50 | for (int i = 0; i < n_segments; ++i) { 51 | nlohmann::json segment; 52 | int64_t t0 = whisper_full_get_segment_t0(ctx, i); 53 | int64_t t1 = whisper_full_get_segment_t1(ctx, i); 54 | const char *sentence = whisper_full_get_segment_text(ctx, i); 55 | auto result = std::to_string(t0) + "-->" + std::to_string(t1) + ":" + sentence + "\n"; 56 | // printf("%s: result:%s\n", get_current_time().c_str(), result.c_str()); 57 | segment["t0"] = to_timestamp(t0); 58 | segment["t1"] = to_timestamp(t1); 59 | segment["sentence"] = sentence; 60 | results.push_back(segment); 61 | } 62 | return results; 63 | } 64 | -------------------------------------------------------------------------------- /common/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include "nlohmann/json.hpp" 6 | 7 | std::string get_current_time(); 8 | long get_current_time_millis(); 9 | std::string to_timestamp(int64_t t, bool comma = false); 10 | nlohmann::json get_result(whisper_context *ctx); -------------------------------------------------------------------------------- /distribute/docker/base.en/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM litongjava/whisper-cpp-server:1.0.0 2 | 3 | COPY models/ggml-base.en.bin /app/models/ 4 | 5 | EXPOSE 8080 6 | 7 | CMD ["/app/whisper_http_server_base_httplib", "-m", "/app/models/ggml-base.en.bin"] -------------------------------------------------------------------------------- /distribute/docker/base.en/readme.md: -------------------------------------------------------------------------------- 1 | ```shell 2 | docker build -t litongjava/whisper-cpp-server:1.0.0-base-en -f distribute/docker/base.en/Dockerfile . 
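# note: run from the repository root; the base.en Dockerfile copies models/ggml-base.en.bin into
# the image, so the model file must already exist locally (it can be fetched with models/download-ggml-model.sh base.en)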
3 | ``` 4 | 5 | ``` 6 | docker run -dit --name whisper-server -p 8080:8080 litongjava/whisper-cpp-server:1.0.0-base-en 7 | ``` -------------------------------------------------------------------------------- /distribute/docker/builder/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM litongjava/whisper-cpp-server-builder:1.0.0 as builder 2 | 3 | WORKDIR /src 4 | RUN cmake -B cmake-build-release-linux 5 | RUN cmake --build cmake-build-release-linux --config Release -- -j $(nproc) 6 | -------------------------------------------------------------------------------- /distribute/docker/large-v3/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM litongjava/whisper-cpp-server:1.0.0 2 | 3 | COPY models/ggml-large-v3.bin /app/models/ 4 | 5 | EXPOSE 8080 6 | 7 | CMD ["/app/whisper_http_server_base_httplib", "-m", "/app/models/ggml-large-v3.bin"] -------------------------------------------------------------------------------- /distribute/docker/large-v3/readme.md: -------------------------------------------------------------------------------- 1 | ```shell 2 | docker build -t litongjava/whisper-cpp-server:1.0.0-large-v3 -f distribute/docker/large-v3/Dockerfile . 3 | ``` 4 | 5 | ## test run 6 | ``` 7 | docker run -dit --name whisper-server -p 8080:8080 litongjava/whisper-cpp-server:1.0.0-large-v3 8 | ``` 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /distribute/docker/medium-q5_0/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM litongjava/whisper-cpp-server:1.0.0 2 | 3 | COPY models/ggml-medium-q5_0.bin /app/models/ 4 | 5 | EXPOSE 8080 6 | 7 | CMD ["/app/whisper_http_server_base_httplib", "-m", "/app/models/ggml-medium-q5_0.bin"] -------------------------------------------------------------------------------- /distribute/docker/medium-q5_0/readme.md: -------------------------------------------------------------------------------- 1 | ```shell 2 | docker build -t litongjava/whisper-cpp-server:1.0.0-medium-q5_0 -f distribute/docker/medium-q5_0/Dockerfile . 3 | ``` 4 | 5 | ``` 6 | docker run -dit --name whisper-server -p 8080:8080 litongjava/whisper-cpp-server:1.0.0-medium-q5_0 7 | ``` -------------------------------------------------------------------------------- /distribute/docker/pure/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:latest 2 | 3 | WORKDIR /app 4 | 5 | COPY cmake-build-release/ggml-metal.metal /app/ 6 | COPY cmake-build-release/simplest /app/ 7 | COPY cmake-build-release/whisper_http_server_base_httplib /app/ 8 | COPY cmake-build-release/whisper_server_base_on_uwebsockets /app/ 9 | COPY cmake-build-release/thirdparty/whisper.cpp/libwhisper.so /lib -------------------------------------------------------------------------------- /distribute/docker/pure/readme.md: -------------------------------------------------------------------------------- 1 | ## build 2 | ```shell 3 | docker build -t litongjava/whisper-cpp-server:1.0.0 -f distribute/docker/pure/Dockerfile . 
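# note: the "pure" image only packages prebuilt binaries; the project must already be built into
# cmake-build-release/ so that the COPY paths in the Dockerfile above resolve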
4 | ``` 5 | ## test 6 | ### test with ggml-base.en.bin 7 | ``` 8 | docker run --rm \ 9 | -v "$(pwd)/models":/models \ 10 | -v "$(pwd)/samples/jfk.wav":/jfk.wav \ 11 | litongjava/whisper-cpp-server:1.0.0 /app/simplest -m /models/ggml-base.en.bin /jfk.wav 12 | ``` 13 | 14 | ``` 15 | docker run --rm -v "$(pwd)/models/ggml-base.en.bin":/models/ggml-base.en.bin -v "$(pwd)/samples/zh.wav":/samples/zh.wav litongjava/whisper-cpp-server:1.0.0 /app/simplest -m /models/ggml-base.en.bin /samples/zh.wav 16 | ``` 17 | 18 | ### test with ggml-large-v3.bin 19 | ``` 20 | docker run --rm -v "$(pwd)/models/ggml-large-v3.bin":/models/ggml-large-v3.bin -v "$(pwd)/samples/zh.wav":/samples/zh.wav litongjava/whisper-cpp-server:1.0.0 /app/simplest -m /models/ggml-large-v3.bin /samples/zh.wav 21 | ``` 22 | 23 | ### test with server 24 | ``` 25 | docker run -dit --name=whisper-server --net=host -v "$(pwd)/models/ggml-base.en.bin":/models/ggml-base.en.bin litongjava/whisper-cpp-server:1.0.0 /app/whisper_http_server_base_httplib -m /models/ggml-base.en.bin 26 | ``` 27 | 28 | ### test outout 29 | #### English 30 | ``` 31 | root@ping-Inspiron-3458:~/code/whisper-cpp-server# docker run --rm \ 32 | > -v "$(pwd)/models":/models \ 33 | > -v "$(pwd)/jfk.wav":/jfk.wav \ 34 | > litongjava/whisper-cpp-server:1.0.0 /app/simplest -m /models/ggml-base.en.bin /jfk.wav 35 | whisper_init_from_file_with_params_no_state: loading model from '/models/ggml-base.en.bin' 36 | whisper_model_load: loading model 37 | whisper_model_load: n_vocab = 51864 38 | whisper_model_load: n_audio_ctx = 1500 39 | whisper_model_load: n_audio_state = 512 40 | whisper_model_load: n_audio_head = 8 41 | whisper_model_load: n_audio_layer = 6 42 | whisper_model_load: n_text_ctx = 448 43 | whisper_model_load: n_text_state = 512 44 | whisper_model_load: n_text_head = 8 45 | whisper_model_load: n_text_layer = 6 46 | whisper_model_load: n_mels = 80 47 | whisper_model_load: ftype = 1 48 | whisper_model_load: qntvr = 0 49 | whisper_model_load: type = 2 (base) 50 | whisper_model_load: adding 1607 extra tokens 51 | whisper_model_load: n_langs = 99 52 | whisper_model_load: CPU total size = 147.37 MB 53 | 54 | whisper_model_load: model size = 147.37 MB 55 | whisper_init_state: kv self size = 16.52 MB 56 | whisper_init_state: kv cross size = 18.43 MB 57 | whisper_init_state: compute buffer (conv) = 16.39 MB 58 | whisper_init_state: compute buffer (encode) = 132.07 MB 59 | whisper_init_state: compute buffer (cross) = 4.78 MB 60 | whisper_init_state: compute buffer (decode) = 96.48 MB 61 | 62 | system_info: n_threads = 4 / 4 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | METAL = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | CUDA = 0 | COREML = 0 | OPENVINO = 0 | 63 | 64 | main: WARNING: model is not multilingual, ignoring language and translation options 65 | main: processing '/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ... 
66 | 67 | 68 | whisper_print_timings: load time = 169.51 ms 69 | whisper_print_timings: fallbacks = 0 p / 0 h 70 | whisper_print_timings: mel time = 59.75 ms 71 | whisper_print_timings: sample time = 25.05 ms / 1 runs ( 25.05 ms per run) 72 | whisper_print_timings: encode time = 6384.86 ms / 1 runs ( 6384.86 ms per run) 73 | whisper_print_timings: decode time = 236.91 ms / 27 runs ( 8.77 ms per run) 74 | whisper_print_timings: batchd time = 0.00 ms / 1 runs ( 0.00 ms per run) 75 | whisper_print_timings: prompt time = 0.00 ms / 1 runs ( 0.00 ms per run) 76 | whisper_print_timings: total time = 6885.22 ms 77 | start 78 | 79 | [00:00:00.000 --> 00:00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country. 80 | ``` 81 | #### Chinese 82 | ``` 83 | root@ping-Inspiron-3458:~/code/whisper-cpp-server# docker run --rm -v "$(pwd)/models/ggml-large-v3.bin":/models/ggml-large-v3.bin -v "$(pwd)/samples/zh.wav":/samples/zh.wav litongjava/whisper-cpp-server:1.0.0 /app/simplest -m /models/ggml-large-v3.bin /samples/zh.wav 84 | whisper_init_from_file_with_params_no_state: loading model from '/models/ggml-large-v3.bin' 85 | whisper_model_load: loading model 86 | whisper_model_load: n_vocab = 51866 87 | whisper_model_load: n_audio_ctx = 1500 88 | whisper_model_load: n_audio_state = 1280 89 | whisper_model_load: n_audio_head = 20 90 | whisper_model_load: n_audio_layer = 32 91 | whisper_model_load: n_text_ctx = 448 92 | whisper_model_load: n_text_state = 1280 93 | whisper_model_load: n_text_head = 20 94 | whisper_model_load: n_text_layer = 32 95 | whisper_model_load: n_mels = 128 96 | whisper_model_load: ftype = 1 97 | whisper_model_load: qntvr = 0 98 | whisper_model_load: type = 5 (large v3) 99 | whisper_model_load: adding 1609 extra tokens 100 | whisper_model_load: n_langs = 100 101 | whisper_model_load: CPU total size = 3094.36 MB 102 | whisper_model_load: model size = 3094.36 MB 103 | whisper_init_state: kv self size = 220.20 MB 104 | whisper_init_state: kv cross size = 245.76 MB 105 | whisper_init_state: compute buffer (conv) = 36.26 MB 106 | whisper_init_state: compute buffer (encode) = 926.66 MB 107 | whisper_init_state: compute buffer (cross) = 9.38 MB 108 | whisper_init_state: compute buffer (decode) = 209.26 MB 109 | 110 | system_info: n_threads = 4 / 4 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | METAL = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | CUDA = 0 | COREML = 0 | OPENVINO = 0 | 111 | 112 | main: processing '/samples/zh.wav' (79949 samples, 5.0 sec), 4 threads, 1 processors, lang = auto, task = transcribe, timestamps = 1 ... 
113 | 114 | whisper_full_with_state: auto-detected language: zh (p = 0.998135) 115 | start 116 | 117 | [00:00:00.000 --> 00:00:05.000] 我认为跑步最重要的就是给我带来了身体健康 118 | 119 | whisper_print_timings: load time = 5730.34 ms 120 | whisper_print_timings: fallbacks = 0 p / 0 h 121 | whisper_print_timings: mel time = 27.16 ms 122 | whisper_print_timings: sample time = 27.23 ms / 1 runs ( 27.23 ms per run) 123 | whisper_print_timings: encode time = 253393.73 ms / 2 runs (126696.87 ms per run) 124 | whisper_print_timings: decode time = 11884.19 ms / 21 runs ( 565.91 ms per run) 125 | whisper_print_timings: batchd time = 260.31 ms / 3 runs ( 86.77 ms per run) 126 | whisper_print_timings: prompt time = 0.00 ms / 1 runs ( 0.00 ms per run) 127 | whisper_print_timings: total time = 271354.41 ms 128 | 129 | ``` -------------------------------------------------------------------------------- /distribute/docker/tiny.en-q5_1/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM litongjava/whisper-cpp-server:1.0.0 2 | 3 | COPY models/ggml-tiny.en-q5_1.bin /app/models/ 4 | 5 | EXPOSE 8080 6 | 7 | CMD ["/app/whisper_http_server_base_httplib", "-m", "/app/models/ggml-tiny.en-q5_1.bin"] -------------------------------------------------------------------------------- /distribute/docker/tiny.en-q5_1/readme.md: -------------------------------------------------------------------------------- 1 | ```shell 2 | docker build -t litongjava/whisper-cpp-server:1.0.0-tiny.en-q5_1 -f distribute/docker/tiny.en-q5_1/Dockerfile . 3 | ``` 4 | 5 | ``` 6 | docker run -dit --name whisper-server -p 8080:8080 litongjava/whisper-cpp-server:1.0.0-tiny.en-q5_1 7 | ``` -------------------------------------------------------------------------------- /distribute/docker/whisper-cpp-server-builder/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:latest 2 | 3 | RUN apt-get update && \ 4 | apt-get install -y curl gcc g++ cmake nasm git && \ 5 | rm -rf /var/lib/apt/lists/* -------------------------------------------------------------------------------- /distribute/docker/whisper-cpp-server-builder/readme.md: -------------------------------------------------------------------------------- 1 | ```shell 2 | docker build -t litongjava/whisper-cpp-server-builder:1.0.0 -f distribute/docker/whisper-cpp-server-builder/Dockerfile . 
3 | ``` -------------------------------------------------------------------------------- /doc/whiser-cpp-server-http-cn.md: -------------------------------------------------------------------------------- 1 | # whisper-cpp-server http interface 2 | 3 | ## POST /inference 4 | 5 | POST /inference 6 | 7 | > Body 请求参数 8 | 9 | ```yaml 10 | file: file://1.m4a 11 | temperature: "0.2" 12 | response-format: json 13 | audio_format: m4a 14 | 15 | ``` 16 | 17 | ### 请求参数 18 | 19 | | 名称 | 位置 | 类型 | 必选 | 说明 | 20 | |-------------------|------|----------------|----|--------------------------------------| 21 | | body | body | object | 否 | none | 22 | | » file | body | string(binary) | 否 | filename | 23 | | » temperature | body | string | 否 | none | 24 | | » response-format | body | string | 否 | none | 25 | | » audio_format | body | string | 否 | audio format,support m4a,mp3,and wav | 26 | 27 | > 返回示例 28 | 29 | > 成功 30 | 31 | ```json 32 | { 33 | "code": 0, 34 | "data": [ 35 | { 36 | "sentence": " And so my fellow Americans ask not what your country can do for you, ask what you can do for your country.", 37 | "t0": 0, 38 | "t1": 1100 39 | } 40 | ] 41 | } 42 | ``` 43 | 44 | ### 返回结果 45 | 46 | | 状态码 | 状态码含义 | 说明 | 数据模型 | 47 | |-----|---------------------------------------------------------|----|--------| 48 | | 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | 成功 | Inline | 49 | 50 | -------------------------------------------------------------------------------- /doc/whiser-cpp-server-http-en.md: -------------------------------------------------------------------------------- 1 | # whisper-cpp-server HTTP Interface 2 | 3 | ## POST /inference 4 | 5 | POST /inference 6 | 7 | > Body Request Parameters 8 | 9 | ```yaml 10 | file: file://1.m4a 11 | temperature: "0.2" 12 | response-format: json 13 | audio_format: m4a 14 | ``` 15 | 16 | ### Request Parameters 17 | 18 | | Name | Location | Type | Required | Description | 19 | |-------------------|----------|----------------|----------|-----------------------------------------| 20 | | body | body | object | No | none | 21 | | » file | body | string(binary) | No | filename | 22 | | » temperature | body | string | No | none | 23 | | » response-format | body | string | No | none | 24 | | » audio_format | body | string | No | audio format, support m4a, mp3, and wav | 25 | 26 | > Response Example 27 | 28 | > Success 29 | 30 | ```json 31 | { 32 | "code": 0, 33 | "data": [ 34 | { 35 | "sentence": " And so my fellow Americans ask not what your country can do for you, ask what you can do for your country.", 36 | "t0": 0, 37 | "t1": 1100 38 | } 39 | ] 40 | } 41 | ``` 42 | 43 | ### Response Result 44 | 45 | | Status Code | Status Code Meaning | Description | Data Model | 46 | |-------------|---------------------------------------------------------|-------------|------------| 47 | | 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | Success | Inline | -------------------------------------------------------------------------------- /doc/whiser-cpp-server-websocket-cn.md: -------------------------------------------------------------------------------- 1 | # 流式语音识别接口 2 | 3 | ## 1.1. 创建连接 4 | 5 | - ws api: `ws://{server}:{port}/paddlespeech/asr/streaming` 6 | 7 | ## 1.2. 
开始信号 8 | 9 | Client 通过开始信号传入流式识别音频信息,以及解码参数。 10 | 11 | #### 字段说明 12 | 13 | | 字段 | 必选 | 类型 | 说明 | 14 | |---------------|----|--------|-----------------------------------| 15 | | name | 是 | string | 传入的音频名称 | 16 | | signal | 是 | string | 流式识别中命令类型 | 17 | | nbest | 是 | int | 识别nbest参数,默认是1 | 18 | | sampleRate | 否 | int | 例如48000,默认16000 | 19 | | bitsPerSample | 否 | int | 位深度(bitsPerSample):表示每个样本的位数,默认16 | 20 | | channels | 否 | int | 通道数,默认1 | 21 | 22 | ### 请求示例 23 | 24 | ```json 25 | { 26 | "name": "test.wav", 27 | "signal": "start", 28 | "nbest": 1 29 | } 30 | ``` 31 | 32 | ### Server 信息 33 | 34 | Server 端返回新连接的情况。 35 | 36 | #### 字段说明 37 | 38 | | 字段 | 必选 | 类型 | 说明 | 39 | |--------|----|--------|-------------------| 40 | | status | 是 | string | ASR服务端状态 | 41 | | signal | 是 | string | 该流式连接必要的准备工作是完成状态 | 42 | 43 | ```json 44 | { 45 | "status": "ok", 46 | "signal": "server_ready" 47 | } 48 | ``` 49 | 50 | 服务端同时需要保存wav端的文件。 51 | 52 | ## 1.3. 数据 53 | 54 | Client和Server建立连接之后,Client端不断地向服务端发送数据。 55 | 56 | ### Client 信息 57 | 58 | 发送 PCM16 数据流到服务端。 59 | 60 | ### Server 信息 61 | 62 | 每发送一个数据,服务端会将该数据包解码的结果返回出来。 63 | 64 | #### 字段说明 65 | 66 | | 字段 | 必选 | 类型 | 说明 | 67 | |--------|----|--------|----------| 68 | | result | 是 | string | ASR解码的结果 | 69 | 70 | ## 1.4. 识别 71 | 72 | Client和Server建立连接之后,Client端不断地向服务端发送数据。服务器会创建一个缓存区,将语音数据放到缓存区中,在对发送的语音进行进行静音检测(vad), 73 | 当检测有到有静音后才会进行语音识别.但是有些时候环境比较嘈杂.会检测有误.客户端可以发送识别命令到服务端,收到识别命令后也会进行识别 74 | 75 | ### Client 信息 76 | ``` 77 | { 78 | "name": "test.wav", 79 | "signal": "recognize", 80 | "nbest": 1 81 | } 82 | ``` 83 | 84 | ### Server 信息 85 | 86 | 每发送一个数据,服务端会将该数据包解码的结果返回出来。 87 | ``` 88 | { 89 | "status": "ok", 90 | "result": "识别结果", 91 | } 92 | ``` 93 | 94 | #### 字段说明 95 | 96 | | 字段 | 必选 | 类型 | 说明 | 97 | |--------|----|--------|----------| 98 | | result | 是 | string | ASR解码的结果 | 99 | 100 | ## 1.5. 结束 101 | 102 | Client 发送完最后一个数据包之后,需要发送给服务端一个结束的命令,通知服务端销毁该链接的相关资源。 103 | 104 | 通过开始信号传入流式识别音频信息,以及解码参数 105 | 106 | #### 字段说明 107 | 108 | | 字段 | 必选 | 类型 | 说明 | 109 | |--------|----|--------|----------------| 110 | | name | 是 | string | 传入的音频名称 | 111 | | signal | 是 | string | 流式识别中命令类型 | 112 | | nbest | 是 | int | 识别nbest参数,默认是1 | 113 | 114 | ```json 115 | { 116 | "name": "test.wav", 117 | "signal": "end", 118 | "nbest": 1 119 | } 120 | ``` 121 | 122 | ### Server 信息 123 | 124 | Server 端信息接收到结束信息之后,将最后的结果返回出去。 125 | 126 | #### 字段说明 127 | 128 | | 字段 | 必选 | 类型 | 说明 | 129 | |--------|----|--------|--------------------| 130 | | name | 是 | string | 传入的音频名称 | 131 | | signal | 是 | string | 流式识别中命令类型,取发送的命令类型 | 132 | | result | 是 | string | 最后的识别结果 | 133 | 134 | ```json 135 | { 136 | "name": "test.wav", 137 | "signal": "end", 138 | "result": "" 139 | } 140 | ``` -------------------------------------------------------------------------------- /doc/whiser-cpp-server-websocket-en.md: -------------------------------------------------------------------------------- 1 | # Streaming Voice Recognition Interface 2 | 3 | ## 1.1. Create Connection 4 | 5 | - ws api: `ws://{server}:{port}/paddlespeech/asr/streaming` 6 | 7 | ## 1.2. Start Signal 8 | 9 | The client sends streaming voice recognition audio information and decoding parameters through the start signal. 10 | 11 | ### Request Example 12 | 13 | ```json 14 | { 15 | "name": "test.wav", 16 | "signal": "start", 17 | "nbest": 1 18 | } 19 | ``` 20 | 21 | ### Server Information 22 | 23 | The server returns the status of the new connection. 
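For illustration, the start handshake can be driven with a few lines of client code. The sketch below is not part of this repository and assumes the third-party Python `websockets` package and a placeholder address; adapt `{server}:{port}` to your deployment (the repository also ships `client/websocket_client.py`).

```python
# minimal sketch of the start handshake; assumes `pip install websockets`
import asyncio
import json
import websockets

async def start_session() -> None:
    uri = "ws://localhost:8080/paddlespeech/asr/streaming"  # replace host/port as needed
    async with websockets.connect(uri) as ws:
        # send the start signal together with the decoding parameters
        await ws.send(json.dumps({"name": "test.wav", "signal": "start", "nbest": 1}))
        # the server is expected to answer with {"status": "ok", "signal": "server_ready"}
        reply = json.loads(await ws.recv())
        assert reply.get("signal") == "server_ready"

asyncio.run(start_session())
```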
24 | 25 | #### Field Description 26 | 27 | | Field | Required | Type | Description | 28 | | ------ | -------- | ------ | ---------------------------------------- | 29 | | status | Yes | string | ASR server status | 30 | | signal | Yes | string | The streaming connection preparation is complete | 31 | 32 | ```json 33 | { 34 | "status": "ok", 35 | "signal": "server_ready" 36 | } 37 | ``` 38 | 39 | The server also needs to save the WAV file. 40 | 41 | ## 1.3. Data 42 | 43 | After the client and server establish a connection, the client continuously sends data to the server. 44 | 45 | ### Client Information 46 | 47 | Send PCM16 data stream to the server. 48 | 49 | ### Server Information 50 | 51 | For each data sent, the server returns the decoding result of that data packet. 52 | 53 | #### Field Description 54 | 55 | | Field | Required | Type | Description | 56 | | ------ | -------- | ------ | ---------------- | 57 | | result | Yes | string | ASR decoding result | 58 | 59 | ## 1.4. End 60 | 61 | After sending the last data packet, the client needs to send an end command to the server to notify the server to destroy the resources associated with that connection. 62 | 63 | Inputting Streaming Voice Recognition Audio Information and Decoding Parameters Through the Start Signal 64 | 65 | #### Field Description 66 | 67 | | Field | Required | Type | Description | 68 | | ------ | -------- | ------ | ------------------------------------ | 69 | | name | Yes | string | Name of the input audio | 70 | | signal | Yes | string | Type of command in streaming recognition | 71 | | nbest | Yes | int | Recognition nbest parameter, default is 1 | 72 | 73 | ```json 74 | { 75 | "name": "test.wav", 76 | "signal": "end", 77 | "nbest": 1 78 | } 79 | ``` 80 | 81 | ### Server Information 82 | 83 | After receiving the end information, the server sends back the final result. 84 | 85 | #### Field Description 86 | 87 | | Field | Required | Type | Description | 88 | | ------ | -------- | ------ | --------------------------------------------- | 89 | | name | Yes | string | Name of the input audio | 90 | | signal | Yes | string | Type of command in streaming recognition, as sent | 91 | | result | Yes | string | The final recognition result | 92 | 93 | ```json 94 | { 95 | "name": "test.wav", 96 | "signal": "end", 97 | "result": "" 98 | } 99 | ``` -------------------------------------------------------------------------------- /examples/audio_vad.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | 7 | #include "../stream/stream_components_service.h" 8 | #include "../stream/stream_components.h" 9 | #include "../common/utils.h" 10 | #include "../common/common.h" 11 | #include 12 | 13 | using namespace stream_components; 14 | 15 | 16 | int main() { 17 | std::string wav_file_path = "../samples/jfk.wav"; // 替换为您的 WAV 文件路径 18 | // audio arrays 19 | std::vector pcmf32; // mono-channel F32 PCM 20 | std::vector> pcmf32s; // stereo-channel F32 PCM 21 | ::read_wav(wav_file_path, pcmf32, pcmf32s, false); 22 | 23 | printf("size of samples %lu\n", pcmf32.size()); 24 | 25 | 26 | whisper_local_stream_params params; 27 | struct whisper_context_params cparams{}; 28 | cparams.use_gpu = params.service.use_gpu; 29 | //Instantiate the service 30 | stream_components::WhisperService whisperService(params.service, params.audio, cparams); 31 | 32 | //Simulate websokcet by adding 1500 data each time. 
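  // The loop below walks the decoded PCM in fixed frames of chunk_size samples, runs
  // SpeexDSP voice-activity detection (SPEEX_PREPROCESS_SET_VAD) on each frame and
  // accumulates the frames in audio_buffer; on the first speech-to-silence transition
  // it runs one whisper transcription, prints the segments as JSON and exits.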
33 | std::vector audio_buffer; 34 | int chunk_size = 160; // 适用于 16 kHz 采样率的 100 毫秒帧 35 | SpeexPreprocessState *st = speex_preprocess_state_init(chunk_size, WHISPER_SAMPLE_RATE); 36 | int vad = 1; 37 | speex_preprocess_ctl(st, SPEEX_PREPROCESS_SET_VAD, &vad); 38 | 39 | bool last_is_speech = false; 40 | // 处理音频帧 41 | for (size_t i = 0; i < pcmf32.size(); i += chunk_size) { 42 | spx_int16_t frame[chunk_size]; 43 | for (int j = 0; j < chunk_size; ++j) { 44 | if (i + j < pcmf32.size()) { 45 | frame[j] = (spx_int16_t)(pcmf32[i + j] * 32768); 46 | } else { 47 | frame[j] = 0; // 对于超出范围的部分填充 0 48 | } 49 | } 50 | int is_speech = speex_preprocess_run(st, frame); 51 | 52 | // 将当前帧添加到 audio_buffer 53 | audio_buffer.insert(audio_buffer.end(), pcmf32.begin() + i, pcmf32.begin() + std::min(i + chunk_size, pcmf32.size())); 54 | printf("is_speech %d \n",is_speech); 55 | if (!is_speech && last_is_speech) { 56 | bool b = whisperService.process(pcmf32.data(), pcmf32.size()); 57 | const nlohmann::json &json_array = get_result(whisperService.ctx); 58 | const std::basic_string, std::allocator> &string = json_array.dump(); 59 | printf("%s\n",string.c_str()); 60 | return 0; 61 | audio_buffer.clear(); 62 | } 63 | 64 | last_is_speech = is_speech != 0; 65 | } 66 | 67 | speex_preprocess_state_destroy(st); 68 | } -------------------------------------------------------------------------------- /examples/sdl_version.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | int main(int argc, char *argv[]) { 5 | SDL_version compiled; 6 | SDL_version linked; 7 | 8 | SDL_VERSION(&compiled); // get compiled SDL version 9 | SDL_GetVersion(&linked); // get linked SDL version 10 | 11 | std::cout << "We compiled against SDL version " 12 | << static_cast(compiled.major) << "." 13 | << static_cast(compiled.minor) << "." 14 | << static_cast(compiled.patch) << std::endl; 15 | 16 | std::cout << "We are linking against SDL version " 17 | << static_cast(linked.major) << "." 18 | << static_cast(linked.minor) << "." 19 | << static_cast(linked.patch) << std::endl; 20 | 21 | return 0; 22 | } 23 | -------------------------------------------------------------------------------- /examples/stream_local.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../stream/stream_components_audio.h" 3 | #include "../stream/stream_components_params.h" 4 | #include "../stream/stream_components_output.h" 5 | #include "../stream/stream_components_service.h" 6 | #include "../stream/stream_components.h" 7 | //#include "json.hpp" 8 | 9 | using namespace stream_components; 10 | 11 | int main(int argc, char **argv) { 12 | 13 | // Read parameters... 
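  // Overall flow of this example: parse the CLI flags into whisper_local_stream_params,
  // open the default SDL microphone, build a WhisperService and then, until Ctrl+C,
  // feed every captured chunk to whisperService.process() and print each segment
  // with its t0/t1 timestamps.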
14 | whisper_local_stream_params params; 15 | 16 | if (whisper_params_parse(argc, argv, params) == false) { 17 | return 1; 18 | } 19 | 20 | // Compute derived parameters 21 | params.initialize(); 22 | //output params 23 | printf("vad:%d\n", params.audio.use_vad); 24 | 25 | // Check parameters 26 | if (params.service.language != "auto" && whisper_lang_id(params.service.language.c_str()) == -1) { 27 | fprintf(stderr, "error: unknown language '%s'\n", params.service.language.c_str()); 28 | whisper_print_usage(argc, argv, params); 29 | exit(0); 30 | } 31 | 32 | // Instantiate the microphone input 33 | stream_components::LocalSDLMicrophone microphone(params.audio); 34 | 35 | // Instantiate the service 36 | struct whisper_context_params cparams; 37 | cparams.use_gpu = params.service.use_gpu; 38 | stream_components::WhisperService whisperService(params.service, params.audio, cparams); 39 | 40 | // Print the 'header'... 41 | WhisperStreamOutput::to_json(std::cout, params.service, whisperService.ctx); 42 | 43 | // Run until Ctrl + C 44 | bool is_running = true; 45 | while (is_running) { 46 | 47 | // handle Ctrl + C 48 | is_running = sdl_poll_events(); 49 | if (!is_running) { 50 | break; 51 | } 52 | 53 | // get next microphone section 54 | auto pcmf32 = microphone.get_next(); 55 | 56 | // process 57 | bool isOk = whisperService.process(pcmf32.data(), pcmf32.size()); 58 | printf("isOk:%d\n", isOk); 59 | // get the whisper output 60 | if (isOk) { 61 | // WhisperOutputPtr outputPtr = std::make_shared(whisperService.ctx, params.service); 62 | // // write the output as json to stdout (for this example) 63 | // if (outputPtr) { 64 | // outputPtr->transcription_to_json(std::cout); 65 | // } 66 | const int n_segments = whisper_full_n_segments(whisperService.ctx); 67 | printf("n_segments:%d\n", n_segments); 68 | for (int i = 0; i < n_segments; ++i) { 69 | const char *text = whisper_full_get_segment_text(whisperService.ctx, i); 70 | const int64_t t0 = whisper_full_get_segment_t0(whisperService.ctx, i); 71 | const int64_t t1 = whisper_full_get_segment_t1(whisperService.ctx, i); 72 | printf("%lld-->%lld:%s\n", t0, t1, text); 73 | } 74 | } 75 | 76 | } 77 | std::cout << "EXITED MAIN LOOP" << std::endl; 78 | return 0; 79 | } -------------------------------------------------------------------------------- /handler/echo_handler.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by Ping Lee on 2024/3/10. 3 | // 4 | 5 | #ifndef WHISPER_CPP_SERVER_ECHO_HANDLER_H 6 | #define WHISPER_CPP_SERVER_ECHO_HANDLER_H 7 | 8 | #include 9 | 10 | // WebSocket /echo handler 11 | auto ws_echo_handler = [](auto *ws, std::string_view message, uWS::OpCode opCode) { 12 | ws->send(message, opCode); 13 | }; 14 | 15 | 16 | #endif //WHISPER_CPP_SERVER_ECHO_HANDLER_H 17 | -------------------------------------------------------------------------------- /handler/hello_handler.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by Tong Li on 2024/3/10. 3 | // 4 | 5 | #include "hello_handler.h" 6 | 7 | void hello_action(uWS::HttpResponse *res, uWS::HttpRequest *req) { 8 | res->end("Hello World!"); 9 | } -------------------------------------------------------------------------------- /handler/hello_handler.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by Tong Li on 2024/3/10. 
3 | // 4 | 5 | #ifndef WHISPER_CPP_SERVER_HELLO_HANDLER_H 6 | #define WHISPER_CPP_SERVER_HELLO_HANDLER_H 7 | 8 | 9 | #include // 确保这里包含了必要的定义 10 | 11 | // 更新函数声明以使用具体的参数类型 12 | // HTTP GET /hello handler 13 | void hello_action(uWS::HttpResponse *res, uWS::HttpRequest *req); 14 | 15 | #endif //WHISPER_CPP_SERVER_HELLO_HANDLER_H 16 | -------------------------------------------------------------------------------- /handler/inference_handler.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "httplib.h" 4 | #include "../params/whisper_params.h" 5 | 6 | using namespace httplib; 7 | 8 | void handleInference(const Request &request, Response &response, std::mutex &whisper_mutex, whisper_params ¶ms, 9 | whisper_context *ctx, char *arg_audio_file); 10 | 11 | void handle_events(const Request &request, Response &response, std::mutex &whisper_mutex, whisper_params ¶ms, 12 | whisper_context *ctx, char *arg_audio_file); -------------------------------------------------------------------------------- /handler/ws_save_handler.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by Ping Lee on 2024/3/10. 3 | // 4 | 5 | #ifndef WHISPER_CPP_SERVER_WS_SAVE_HANDLER_H 6 | #define WHISPER_CPP_SERVER_WS_SAVE_HANDLER_H 7 | 8 | #endif //WHISPER_CPP_SERVER_WS_SAVE_HANDLER_H 9 | -------------------------------------------------------------------------------- /models/.gitignore: -------------------------------------------------------------------------------- 1 | *.bin 2 | coreml-encoder-base.en.mlpackage/ 3 | coreml-encoder-large.mlpackage/ 4 | ggml-large-encoder.mlmodelc/ 5 | ggml-base.en-encoder.mlmodelc/ 6 | 7 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | ## Whisper model files in custom ggml format 2 | 3 | The [original Whisper PyTorch models provided by OpenAI](https://github.com/openai/whisper/blob/main/whisper/__init__.py#L17-L27) 4 | are converted to custom `ggml` format in order to be able to load them in C/C++. 5 | Conversion is performed using the [convert-pt-to-ggml.py](convert-pt-to-ggml.py) script. 6 | 7 | You can either obtain the original models and generate the `ggml` files yourself using the conversion script, 8 | or you can use the [download-ggml-model.sh](download-ggml-model.sh) script to download the already converted models. 9 | Currently, they are hosted on the following locations: 10 | 11 | - https://huggingface.co/ggerganov/whisper.cpp 12 | - https://ggml.ggerganov.com 13 | 14 | Sample download: 15 | 16 | ```java 17 | $ ./download-ggml-model.sh base.en 18 | Downloading ggml model base.en ... 19 | models/ggml-base.en.bin 100%[=============================================>] 141.11M 5.41MB/s in 22s 20 | Done! Model 'base.en' saved in 'models/ggml-base.en.bin' 21 | You can now use it like this: 22 | 23 | $ ./main -m models/ggml-base.en.bin -f samples/jfk.wav 24 | ``` 25 | 26 | To convert the files yourself, use the convert-pt-to-ggml.py script. Here is an example usage. 
27 | The original PyTorch files are assumed to have been downloaded into ~/.cache/whisper 28 | Change `~/path/to/repo/whisper/` to the location for your copy of the Whisper source: 29 | ``` 30 | mkdir models/whisper-medium 31 | python models/convert-pt-to-ggml.py ~/.cache/whisper/medium.pt ~/path/to/repo/whisper/ ./models/whisper-medium 32 | mv ./models/whisper-medium/ggml-model.bin models/ggml-medium.bin 33 | rmdir models/whisper-medium 34 | ``` 35 | 36 | A third option to obtain the model files is to download them from Hugging Face: 37 | 38 | https://huggingface.co/ggerganov/whisper.cpp/tree/main 39 | 40 | ## Available models 41 | 42 | | Model | Disk | SHA | 43 | | --- | --- | --- | 44 | | tiny | 75 MiB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` | 45 | | tiny.en | 75 MiB | `c78c86eb1a8faa21b369bcd33207cc90d64ae9df` | 46 | | base | 142 MiB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` | 47 | | base.en | 142 MiB | `137c40403d78fd54d454da0f9bd998f78703390c` | 48 | | small | 466 MiB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` | 49 | | small.en | 466 MiB | `db8a495a91d927739e50b3fc1cc4c6b8f6c2d022` | 50 | | medium | 1.5 GiB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` | 51 | | medium.en | 1.5 GiB | `8c30f0e44ce9560643ebd10bbe50cd20eafd3723` | 52 | | large-v1 | 2.9 GiB | `b1caaf735c4cc1429223d5a74f0f4d0b9b59a299` | 53 | | large-v2 | 2.9 GiB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` | 54 | | large-v3 | 2.9 GiB | `ad82bf6a9043ceed055076d0fd39f5f186ff8062` | 55 | 56 | ## Model files for testing purposes 57 | 58 | The model files prefixed with `for-tests-` are empty (i.e. do not contain any weights) and are used by the CI for 59 | testing purposes. They are directly included in this repository for convenience and the Github Actions CI uses them to 60 | run various sanitizer tests. 61 | 62 | ## Fine-tuned models 63 | 64 | There are community efforts for creating fine-tuned Whisper models using extra training data. For example, this 65 | [blog post](https://huggingface.co/blog/fine-tune-whisper) describes a method for fine-tuning using Hugging Face (HF) 66 | Transformer implementation of Whisper. The produced models are in slightly different format compared to the original 67 | OpenAI format. To read the HF models you can use the [convert-h5-to-ggml.py](convert-h5-to-ggml.py) script like this: 68 | 69 | ```bash 70 | git clone https://github.com/openai/whisper 71 | git clone https://github.com/ggerganov/whisper.cpp 72 | 73 | # clone HF fine-tuned model (this is just an example) 74 | git clone https://huggingface.co/openai/whisper-medium 75 | 76 | # convert the model to ggml 77 | python3 ./whisper.cpp/models/convert-h5-to-ggml.py ./whisper-medium/ ./whisper . 78 | ``` 79 | 80 | ## Distilled models 81 | 82 | Initial support for https://huggingface.co/distil-whisper is available. 83 | 84 | Currently, the chunk-based transcription strategy is not implemented, so there can be sub-optimal quality when using the distilled models with `whisper.cpp`. 85 | 86 | ```bash 87 | # clone OpenAI whisper and whisper.cpp 88 | git clone https://github.com/openai/whisper 89 | git clone https://github.com/ggerganov/whisper.cpp 90 | 91 | # get the models 92 | cd whisper.cpp/models 93 | git clone https://huggingface.co/distil-whisper/distil-medium.en 94 | git clone https://huggingface.co/distil-whisper/distil-large-v2 95 | 96 | # convert to ggml 97 | python3 ./convert-h5-to-ggml.py ./distil-medium.en/ ../../whisper . 
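# each conversion writes ./ggml-model.bin, so rename the output before converting the next model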
98 | mv ggml-model.bin ggml-medium.en-distil.bin 99 | 100 | python3 ./convert-h5-to-ggml.py ./distil-large-v2/ ../../whisper . 101 | mv ggml-model.bin ggml-large-v2-distil.bin 102 | ``` 103 | -------------------------------------------------------------------------------- /models/convert-h5-to-coreml.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import importlib.util 3 | 4 | spec = importlib.util.spec_from_file_location('whisper_to_coreml', 'models/convert-whisper-to-coreml.py') 5 | whisper_to_coreml = importlib.util.module_from_spec(spec) 6 | spec.loader.exec_module(whisper_to_coreml) 7 | 8 | from whisper import load_model 9 | 10 | from copy import deepcopy 11 | import torch 12 | from transformers import WhisperForConditionalGeneration 13 | from huggingface_hub import metadata_update 14 | 15 | # https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets/blob/main/src/multiple_datasets/hub_default_utils.py 16 | WHISPER_MAPPING = { 17 | "layers": "blocks", 18 | "fc1": "mlp.0", 19 | "fc2": "mlp.2", 20 | "final_layer_norm": "mlp_ln", 21 | "layers": "blocks", 22 | ".self_attn.q_proj": ".attn.query", 23 | ".self_attn.k_proj": ".attn.key", 24 | ".self_attn.v_proj": ".attn.value", 25 | ".self_attn_layer_norm": ".attn_ln", 26 | ".self_attn.out_proj": ".attn.out", 27 | ".encoder_attn.q_proj": ".cross_attn.query", 28 | ".encoder_attn.k_proj": ".cross_attn.key", 29 | ".encoder_attn.v_proj": ".cross_attn.value", 30 | ".encoder_attn_layer_norm": ".cross_attn_ln", 31 | ".encoder_attn.out_proj": ".cross_attn.out", 32 | "decoder.layer_norm.": "decoder.ln.", 33 | "encoder.layer_norm.": "encoder.ln_post.", 34 | "embed_tokens": "token_embedding", 35 | "encoder.embed_positions.weight": "encoder.positional_embedding", 36 | "decoder.embed_positions.weight": "decoder.positional_embedding", 37 | "layer_norm": "ln_post", 38 | } 39 | 40 | # https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets/blob/main/src/multiple_datasets/hub_default_utils.py 41 | def rename_keys(s_dict): 42 | keys = list(s_dict.keys()) 43 | for key in keys: 44 | new_key = key 45 | for k, v in WHISPER_MAPPING.items(): 46 | if k in key: 47 | new_key = new_key.replace(k, v) 48 | 49 | print(f"{key} -> {new_key}") 50 | 51 | s_dict[new_key] = s_dict.pop(key) 52 | return s_dict 53 | 54 | # https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets/blob/main/src/multiple_datasets/hub_default_utils.py 55 | def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str): 56 | transformer_model = WhisperForConditionalGeneration.from_pretrained(hf_model_name_or_path) 57 | config = transformer_model.config 58 | 59 | # first build dims 60 | dims = { 61 | 'n_mels': config.num_mel_bins, 62 | 'n_vocab': config.vocab_size, 63 | 'n_audio_ctx': config.max_source_positions, 64 | 'n_audio_state': config.d_model, 65 | 'n_audio_head': config.encoder_attention_heads, 66 | 'n_audio_layer': config.encoder_layers, 67 | 'n_text_ctx': config.max_target_positions, 68 | 'n_text_state': config.d_model, 69 | 'n_text_head': config.decoder_attention_heads, 70 | 'n_text_layer': config.decoder_layers 71 | } 72 | 73 | state_dict = deepcopy(transformer_model.model.state_dict()) 74 | state_dict = rename_keys(state_dict) 75 | 76 | torch.save({"dims": dims, "model_state_dict": state_dict}, whisper_state_path) 77 | 78 | # Ported from models/convert-whisper-to-coreml.py 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument("--model-name", type=str, help="name 
of model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3)", required=True) 82 | parser.add_argument("--model-path", type=str, help="path to the model (e.g. if published on HuggingFace: Oblivion208/whisper-tiny-cantonese)", required=True) 83 | parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False) 84 | parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False) 85 | parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False) 86 | args = parser.parse_args() 87 | 88 | if args.model_name not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large-v1", "large-v2", "large-v3"]: 89 | raise ValueError("Invalid model name") 90 | 91 | pt_target_path = f"models/hf-{args.model_name}.pt" 92 | convert_hf_whisper(args.model_path, pt_target_path) 93 | 94 | whisper = load_model(pt_target_path).cpu() 95 | hparams = whisper.dims 96 | print(hparams) 97 | 98 | if args.optimize_ane: 99 | whisperANE = whisper_to_coreml.WhisperANE(hparams).eval() 100 | whisperANE.load_state_dict(whisper.state_dict()) 101 | 102 | encoder = whisperANE.encoder 103 | decoder = whisperANE.decoder 104 | else: 105 | encoder = whisper.encoder 106 | decoder = whisper.decoder 107 | 108 | # Convert encoder 109 | encoder = whisper_to_coreml.convert_encoder(hparams, encoder, quantize=args.quantize) 110 | encoder.save(f"models/coreml-encoder-{args.model_name}.mlpackage") 111 | 112 | if args.encoder_only is False: 113 | # Convert decoder 114 | decoder = whisper_to_coreml.convert_decoder(hparams, decoder, quantize=args.quantize) 115 | decoder.save(f"models/coreml-decoder-{args.model_name}.mlpackage") 116 | 117 | print("done converting") 118 | -------------------------------------------------------------------------------- /models/convert-h5-to-ggml.py: -------------------------------------------------------------------------------- 1 | # Convert Hugging Face fine-tuned models to ggml format 2 | # 3 | # Usage: 4 | # 5 | # git clone https://github.com/openai/whisper 6 | # git clone https://github.com/ggerganov/whisper.cpp 7 | # git clone https://huggingface.co/openai/whisper-medium 8 | # 9 | # python3 ./whisper.cpp/models/convert-h5-to-ggml.py ./whisper-medium/ ./whisper . 
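# (pass any extra argument after the output directory to store all tensors as float32
#  in ggml-model-f32.bin instead of the default float16 ggml-model.bin)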
10 | # 11 | # This script is similar to "convert-pt-to-ggml.py" 12 | # 13 | # For more info: 14 | # 15 | # https://github.com/ggerganov/whisper.cpp/issues/157 16 | # 17 | 18 | import io 19 | import os 20 | import sys 21 | import struct 22 | import json 23 | import code 24 | import torch 25 | import numpy as np 26 | from pathlib import Path 27 | 28 | from transformers import WhisperForConditionalGeneration 29 | 30 | conv_map = { 31 | 'self_attn.k_proj' : 'attn.key', 32 | 'self_attn.q_proj' : 'attn.query', 33 | 'self_attn.v_proj' : 'attn.value', 34 | 'self_attn.out_proj' : 'attn.out', 35 | 'self_attn_layer_norm' : 'attn_ln', 36 | 'encoder_attn.q_proj' : 'cross_attn.query', 37 | 'encoder_attn.v_proj' : 'cross_attn.value', 38 | 'encoder_attn.out_proj' : 'cross_attn.out', 39 | 'encoder_attn_layer_norm' : 'cross_attn_ln', 40 | 'fc1' : 'mlp.0', 41 | 'fc2' : 'mlp.2', 42 | 'final_layer_norm' : 'mlp_ln', 43 | 'encoder.layer_norm.bias' : 'encoder.ln_post.bias', 44 | 'encoder.layer_norm.weight' : 'encoder.ln_post.weight', 45 | 'encoder.embed_positions.weight': 'encoder.positional_embedding', 46 | 'decoder.layer_norm.bias' : 'decoder.ln.bias', 47 | 'decoder.layer_norm.weight' : 'decoder.ln.weight', 48 | 'decoder.embed_positions.weight': 'decoder.positional_embedding', 49 | 'decoder.embed_tokens.weight' : 'decoder.token_embedding.weight', 50 | 'proj_out.weight' : 'decoder.proj.weight', 51 | } 52 | 53 | # ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py 54 | def bytes_to_unicode(): 55 | """ 56 | Returns list of utf-8 byte and a corresponding list of unicode strings. 57 | The reversible bpe codes work on unicode strings. 58 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 59 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 60 | This is a significant percentage of your normal, say, 32K bpe vocab. 61 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 62 | And avoids mapping to whitespace/control characters the bpe code barfs on. 
63 | """ 64 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 65 | cs = bs[:] 66 | n = 0 67 | for b in range(2**8): 68 | if b not in bs: 69 | bs.append(b) 70 | cs.append(2**8+n) 71 | n += 1 72 | cs = [chr(n) for n in cs] 73 | return dict(zip(bs, cs)) 74 | 75 | if len(sys.argv) < 4: 76 | print("Usage: convert-h5-to-ggml.py dir_model path-to-whisper-repo dir-output [use-f32]\n") 77 | sys.exit(1) 78 | 79 | dir_model = Path(sys.argv[1]) 80 | dir_whisper = Path(sys.argv[2]) 81 | dir_out = Path(sys.argv[3]) 82 | 83 | encoder = json.load((dir_model / "vocab.json").open("r", encoding="utf8")) 84 | encoder_added = json.load((dir_model / "added_tokens.json").open( "r", encoding="utf8")) 85 | hparams = json.load((dir_model / "config.json").open("r", encoding="utf8") ) 86 | 87 | model = WhisperForConditionalGeneration.from_pretrained(dir_model) 88 | 89 | #code.interact(local=locals()) 90 | 91 | n_mels = hparams["num_mel_bins"] 92 | with np.load(os.path.join(dir_whisper, "whisper/assets", "mel_filters.npz")) as f: 93 | filters = torch.from_numpy(f[f"mel_{n_mels}"]) 94 | 95 | dir_tokenizer = dir_model 96 | 97 | fname_out = dir_out / "ggml-model.bin" 98 | 99 | tokens = json.load(open(dir_tokenizer / "vocab.json", "r", encoding="utf8")) 100 | 101 | # use 16-bit or 32-bit floats 102 | use_f16 = True 103 | if len(sys.argv) > 4: 104 | use_f16 = False 105 | fname_out = dir_out / "ggml-model-f32.bin" 106 | 107 | fout = open(fname_out, "wb") 108 | 109 | fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex 110 | fout.write(struct.pack("i", hparams["vocab_size"])) 111 | fout.write(struct.pack("i", hparams["max_source_positions"])) 112 | fout.write(struct.pack("i", hparams["d_model"])) 113 | fout.write(struct.pack("i", hparams["encoder_attention_heads"])) 114 | fout.write(struct.pack("i", hparams["encoder_layers"])) 115 | fout.write(struct.pack("i", hparams["max_length"])) 116 | fout.write(struct.pack("i", hparams["d_model"])) 117 | fout.write(struct.pack("i", hparams["decoder_attention_heads"])) 118 | fout.write(struct.pack("i", hparams["decoder_layers"])) 119 | fout.write(struct.pack("i", hparams["num_mel_bins"])) 120 | fout.write(struct.pack("i", use_f16)) 121 | 122 | fout.write(struct.pack("i", filters.shape[0])) 123 | fout.write(struct.pack("i", filters.shape[1])) 124 | for i in range(filters.shape[0]): 125 | for j in range(filters.shape[1]): 126 | fout.write(struct.pack("f", filters[i][j])) 127 | 128 | byte_encoder = bytes_to_unicode() 129 | byte_decoder = {v:k for k, v in byte_encoder.items()} 130 | 131 | fout.write(struct.pack("i", len(tokens))) 132 | 133 | tokens = sorted(tokens.items(), key=lambda x: x[1]) 134 | for key in tokens: 135 | text = bytearray([byte_decoder[c] for c in key[0]]) 136 | fout.write(struct.pack("i", len(text))) 137 | fout.write(text) 138 | 139 | list_vars = model.state_dict() 140 | for name in list_vars.keys(): 141 | # this seems to not be used 142 | # ref: https://github.com/huggingface/transformers/blob/9a5b84a0076a04fe9596da72e8668069d4f09ea0/src/transformers/models/whisper/modeling_whisper.py#L1099-L1106 143 | if name == "proj_out.weight": 144 | print('Skipping', name) 145 | continue 146 | 147 | src = name 148 | 149 | nn = name 150 | if name != "proj_out.weight": 151 | nn = nn.split(".")[1:] 152 | else: 153 | nn = nn.split(".") 154 | 155 | if nn[1] == "layers": 156 | nn[1] = "blocks" 157 | if ".".join(nn[3:-1]) == "encoder_attn.k_proj": 158 | mapped = "attn.key" if nn[0] == "encoder" else "cross_attn.key" 159 | 
else: 160 | mapped = conv_map[".".join(nn[3:-1])] 161 | name = ".".join(nn[:3] + [mapped] + nn[-1:]) 162 | else: 163 | name = ".".join(nn) 164 | name = conv_map[name] if name in conv_map else name 165 | 166 | print(src, ' -> ', name) 167 | data = list_vars[src].squeeze().numpy() 168 | data = data.astype(np.float16) 169 | 170 | # reshape conv bias from [n] to [n, 1] 171 | if name in ["encoder.conv1.bias", "encoder.conv2.bias"]: 172 | data = data.reshape(data.shape[0], 1) 173 | print(" Reshaped variable: " , name , " to shape: ", data.shape) 174 | 175 | n_dims = len(data.shape) 176 | print(name, n_dims, data.shape) 177 | 178 | # looks like the whisper models are in f16 by default 179 | # so we need to convert the small tensors to f32 until we fully support f16 in ggml 180 | # ftype == 0 -> float32, ftype == 1 -> float16 181 | ftype = 1 182 | if use_f16: 183 | if n_dims < 2 or \ 184 | name == "encoder.conv1.bias" or \ 185 | name == "encoder.conv2.bias" or \ 186 | name == "encoder.positional_embedding" or \ 187 | name == "decoder.positional_embedding": 188 | print(" Converting to float32") 189 | data = data.astype(np.float32) 190 | ftype = 0 191 | else: 192 | data = data.astype(np.float32) 193 | ftype = 0 194 | 195 | # header 196 | str_ = name.encode('utf-8') 197 | fout.write(struct.pack("iii", n_dims, len(str_), ftype)) 198 | for i in range(n_dims): 199 | fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) 200 | fout.write(str_) 201 | 202 | # data 203 | data.tofile(fout) 204 | 205 | fout.close() 206 | 207 | print("Done. Output file: " , fname_out) 208 | print("") 209 | -------------------------------------------------------------------------------- /models/convert-pt-to-ggml.py: -------------------------------------------------------------------------------- 1 | # Convert Whisper transformer model from PyTorch to ggml format 2 | # 3 | # Usage: python convert-pt-to-ggml.py ~/.cache/whisper/medium.pt ~/path/to/repo/whisper/ ./models/whisper-medium 4 | # 5 | # You need to clone the original repo in ~/path/to/repo/whisper/ 6 | # 7 | # git clone https://github.com/openai/whisper ~/path/to/repo/whisper/ 8 | # 9 | # It is used to various assets needed by the algorithm: 10 | # 11 | # - tokenizer 12 | # - mel filters 13 | # 14 | # Also, you need to have the original models in ~/.cache/whisper/ 15 | # See the original repo for more details. 16 | # 17 | # This script loads the specified model and whisper assets and saves them in ggml format. 
18 | # The output is a single binary file containing the following information: 19 | # 20 | # - hparams 21 | # - mel filters 22 | # - tokenizer vocab 23 | # - model variables 24 | # 25 | # For each variable, write the following: 26 | # 27 | # - Number of dimensions (int) 28 | # - Name length (int) 29 | # - Dimensions (int[n_dims]) 30 | # - Name (char[name_length]) 31 | # - Data (float[n_dims]) 32 | # 33 | 34 | import io 35 | import os 36 | import sys 37 | import struct 38 | import json 39 | import code 40 | import torch 41 | import numpy as np 42 | import base64 43 | from pathlib import Path 44 | #from transformers import GPTJForCausalLM 45 | #from transformers import GPT2TokenizerFast 46 | 47 | # ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L10-L110 48 | #LANGUAGES = { 49 | # "en": "english", 50 | # "zh": "chinese", 51 | # "de": "german", 52 | # "es": "spanish", 53 | # "ru": "russian", 54 | # "ko": "korean", 55 | # "fr": "french", 56 | # "ja": "japanese", 57 | # "pt": "portuguese", 58 | # "tr": "turkish", 59 | # "pl": "polish", 60 | # "ca": "catalan", 61 | # "nl": "dutch", 62 | # "ar": "arabic", 63 | # "sv": "swedish", 64 | # "it": "italian", 65 | # "id": "indonesian", 66 | # "hi": "hindi", 67 | # "fi": "finnish", 68 | # "vi": "vietnamese", 69 | # "iw": "hebrew", 70 | # "uk": "ukrainian", 71 | # "el": "greek", 72 | # "ms": "malay", 73 | # "cs": "czech", 74 | # "ro": "romanian", 75 | # "da": "danish", 76 | # "hu": "hungarian", 77 | # "ta": "tamil", 78 | # "no": "norwegian", 79 | # "th": "thai", 80 | # "ur": "urdu", 81 | # "hr": "croatian", 82 | # "bg": "bulgarian", 83 | # "lt": "lithuanian", 84 | # "la": "latin", 85 | # "mi": "maori", 86 | # "ml": "malayalam", 87 | # "cy": "welsh", 88 | # "sk": "slovak", 89 | # "te": "telugu", 90 | # "fa": "persian", 91 | # "lv": "latvian", 92 | # "bn": "bengali", 93 | # "sr": "serbian", 94 | # "az": "azerbaijani", 95 | # "sl": "slovenian", 96 | # "kn": "kannada", 97 | # "et": "estonian", 98 | # "mk": "macedonian", 99 | # "br": "breton", 100 | # "eu": "basque", 101 | # "is": "icelandic", 102 | # "hy": "armenian", 103 | # "ne": "nepali", 104 | # "mn": "mongolian", 105 | # "bs": "bosnian", 106 | # "kk": "kazakh", 107 | # "sq": "albanian", 108 | # "sw": "swahili", 109 | # "gl": "galician", 110 | # "mr": "marathi", 111 | # "pa": "punjabi", 112 | # "si": "sinhala", 113 | # "km": "khmer", 114 | # "sn": "shona", 115 | # "yo": "yoruba", 116 | # "so": "somali", 117 | # "af": "afrikaans", 118 | # "oc": "occitan", 119 | # "ka": "georgian", 120 | # "be": "belarusian", 121 | # "tg": "tajik", 122 | # "sd": "sindhi", 123 | # "gu": "gujarati", 124 | # "am": "amharic", 125 | # "yi": "yiddish", 126 | # "lo": "lao", 127 | # "uz": "uzbek", 128 | # "fo": "faroese", 129 | # "ht": "haitian creole", 130 | # "ps": "pashto", 131 | # "tk": "turkmen", 132 | # "nn": "nynorsk", 133 | # "mt": "maltese", 134 | # "sa": "sanskrit", 135 | # "lb": "luxembourgish", 136 | # "my": "myanmar", 137 | # "bo": "tibetan", 138 | # "tl": "tagalog", 139 | # "mg": "malagasy", 140 | # "as": "assamese", 141 | # "tt": "tatar", 142 | # "haw": "hawaiian", 143 | # "ln": "lingala", 144 | # "ha": "hausa", 145 | # "ba": "bashkir", 146 | # "jw": "javanese", 147 | # "su": "sundanese", 148 | #} 149 | 150 | ## ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L273-L292 151 | #def build_tokenizer(path_to_whisper_repo: str, name: str = "gpt2"): 152 | # os.environ["TOKENIZERS_PARALLELISM"] = "false" 153 | 
# path = os.path.join(path_to_whisper_repo, "whisper/assets", name) 154 | # tokenizer = GPT2TokenizerFast.from_pretrained(path) 155 | # 156 | # specials = [ 157 | # "<|startoftranscript|>", 158 | # *[f"<|{lang}|>" for lang in LANGUAGES.keys()], 159 | # "<|translate|>", 160 | # "<|transcribe|>", 161 | # "<|startoflm|>", 162 | # "<|startofprev|>", 163 | # "<|nocaptions|>", 164 | # "<|notimestamps|>", 165 | # ] 166 | # 167 | # tokenizer.add_special_tokens(dict(additional_special_tokens=specials)) 168 | # return tokenizer 169 | 170 | # ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py 171 | def bytes_to_unicode(): 172 | """ 173 | Returns list of utf-8 byte and a corresponding list of unicode strings. 174 | The reversible bpe codes work on unicode strings. 175 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 176 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 177 | This is a signficant percentage of your normal, say, 32K bpe vocab. 178 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 179 | And avoids mapping to whitespace/control characters the bpe code barfs on. 180 | """ 181 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 182 | cs = bs[:] 183 | n = 0 184 | for b in range(2**8): 185 | if b not in bs: 186 | bs.append(b) 187 | cs.append(2**8+n) 188 | n += 1 189 | cs = [chr(n) for n in cs] 190 | return dict(zip(bs, cs)) 191 | 192 | 193 | if len(sys.argv) < 4: 194 | print("Usage: convert-pt-to-ggml.py model.pt path-to-whisper-repo dir-output [use-f32]\n") 195 | sys.exit(1) 196 | 197 | fname_inp = Path(sys.argv[1]) 198 | dir_whisper = Path(sys.argv[2]) 199 | dir_out = Path(sys.argv[3]) 200 | 201 | # try to load PyTorch binary data 202 | try: 203 | model_bytes = open(fname_inp, "rb").read() 204 | with io.BytesIO(model_bytes) as fp: 205 | checkpoint = torch.load(fp, map_location="cpu") 206 | except Exception: 207 | print("Error: failed to load PyTorch model file:" , fname_inp) 208 | sys.exit(1) 209 | 210 | hparams = checkpoint["dims"] 211 | print("hparams:", hparams) 212 | 213 | list_vars = checkpoint["model_state_dict"] 214 | 215 | #print(list_vars['encoder.positional_embedding']) 216 | #print(list_vars['encoder.conv1.weight']) 217 | #print(list_vars['encoder.conv1.weight'].shape) 218 | 219 | # load mel filters 220 | n_mels = hparams["n_mels"] 221 | with np.load(dir_whisper / "whisper" / "assets" / "mel_filters.npz") as f: 222 | filters = torch.from_numpy(f[f"mel_{n_mels}"]) 223 | #print (filters) 224 | 225 | #code.interact(local=locals()) 226 | 227 | # load tokenizer 228 | # for backwards compatibility, also check for older hf_transformers format tokenizer files 229 | # old format: dir_whisper/whisper/assets/[multilingual/gpt2]/vocab.json 230 | # new format: dir_whisper/whisper/assets/[multilingual/gpt2].tiktoken 231 | multilingual = hparams["n_vocab"] >= 51865 232 | tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual.tiktoken" or "gpt2.tiktoken") 233 | tokenizer_type = "tiktoken" 234 | if not tokenizer.is_file(): 235 | tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual" or "gpt2") / "vocab.json" 236 | tokenizer_type = "hf_transformers" 237 | if not tokenizer.is_file(): 238 | print("Error: failed to find either tiktoken or hf_transformers tokenizer file:", tokenizer) 239 | sys.exit(1) 240 | 241 | byte_encoder = bytes_to_unicode() 242 | 
byte_decoder = {v:k for k, v in byte_encoder.items()} 243 | 244 | if tokenizer_type == "tiktoken": 245 | with open(tokenizer, "rb") as f: 246 | contents = f.read() 247 | tokens = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)} 248 | elif tokenizer_type == "hf_transformers": 249 | with open(tokenizer, "r", encoding="utf8") as f: 250 | _tokens_raw = json.load(f) 251 | if '<|endoftext|>' in _tokens_raw: 252 | # ensures exact same model as tokenizer_type == tiktoken 253 | # details: https://github.com/ggerganov/whisper.cpp/pull/725 254 | del _tokens_raw['<|endoftext|>'] 255 | tokens = {bytes([byte_decoder[c] for c in token]): int(idx) for token, idx in _tokens_raw.items()} 256 | 257 | # output in the same directory as the model 258 | fname_out = dir_out / "ggml-model.bin" 259 | 260 | # use 16-bit or 32-bit floats 261 | use_f16 = True 262 | if len(sys.argv) > 4: 263 | use_f16 = False 264 | fname_out = dir_out / "ggml-model-f32.bin" 265 | 266 | fout = fname_out.open("wb") 267 | 268 | fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex 269 | fout.write(struct.pack("i", hparams["n_vocab"])) 270 | fout.write(struct.pack("i", hparams["n_audio_ctx"])) 271 | fout.write(struct.pack("i", hparams["n_audio_state"])) 272 | fout.write(struct.pack("i", hparams["n_audio_head"])) 273 | fout.write(struct.pack("i", hparams["n_audio_layer"])) 274 | fout.write(struct.pack("i", hparams["n_text_ctx"])) 275 | fout.write(struct.pack("i", hparams["n_text_state"])) 276 | fout.write(struct.pack("i", hparams["n_text_head"])) 277 | fout.write(struct.pack("i", hparams["n_text_layer"])) 278 | fout.write(struct.pack("i", hparams["n_mels"])) 279 | fout.write(struct.pack("i", use_f16)) 280 | 281 | # write mel filters 282 | fout.write(struct.pack("i", filters.shape[0])) 283 | fout.write(struct.pack("i", filters.shape[1])) 284 | for i in range(filters.shape[0]): 285 | for j in range(filters.shape[1]): 286 | fout.write(struct.pack("f", filters[i][j])) 287 | 288 | # write tokenizer 289 | fout.write(struct.pack("i", len(tokens))) 290 | 291 | for key in tokens: 292 | fout.write(struct.pack("i", len(key))) 293 | fout.write(key) 294 | 295 | for name in list_vars.keys(): 296 | data = list_vars[name].squeeze().numpy() 297 | print("Processing variable: " , name , " with shape: ", data.shape) 298 | 299 | # reshape conv bias from [n] to [n, 1] 300 | if name in ["encoder.conv1.bias", "encoder.conv2.bias"]: 301 | data = data.reshape(data.shape[0], 1) 302 | print(f" Reshaped variable: {name} to shape: ", data.shape) 303 | 304 | n_dims = len(data.shape) 305 | 306 | # looks like the whisper models are in f16 by default 307 | # so we need to convert the small tensors to f32 until we fully support f16 in ggml 308 | # ftype == 0 -> float32, ftype == 1 -> float16 309 | ftype = 1 310 | if use_f16: 311 | if n_dims < 2 or \ 312 | name == "encoder.conv1.bias" or \ 313 | name == "encoder.conv2.bias" or \ 314 | name == "encoder.positional_embedding" or \ 315 | name == "decoder.positional_embedding": 316 | print(" Converting to float32") 317 | data = data.astype(np.float32) 318 | ftype = 0 319 | else: 320 | data = data.astype(np.float32) 321 | ftype = 0 322 | 323 | #if name.startswith("encoder"): 324 | # if name.endswith("mlp.0.weight") or \ 325 | # name.endswith("mlp.2.weight"): 326 | # print(" Transposing") 327 | # data = data.transpose() 328 | 329 | # header 330 | str_ = name.encode('utf-8') 331 | fout.write(struct.pack("iii", n_dims, len(str_), ftype)) 332 | for i in 
range(n_dims): 333 | fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) 334 | fout.write(str_) 335 | 336 | # data 337 | data.tofile(fout) 338 | 339 | fout.close() 340 | 341 | print("Done. Output file: " , fname_out) 342 | print("") 343 | -------------------------------------------------------------------------------- /models/convert-whisper-to-coreml.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn.functional as F 4 | import coremltools as ct 5 | 6 | from torch import Tensor 7 | from torch import nn 8 | from typing import Dict 9 | from typing import Optional 10 | from ane_transformers.reference.layer_norm import LayerNormANE as LayerNormANEBase 11 | from coremltools.models.neural_network.quantization_utils import quantize_weights 12 | from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions 13 | from whisper import load_model 14 | 15 | # Use for changing dim of input in encoder and decoder embeddings 16 | def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict, 17 | missing_keys, unexpected_keys, error_msgs): 18 | """ 19 | Unsqueeze twice to map nn.Linear weights to nn.Conv2d weights 20 | """ 21 | for k in state_dict: 22 | is_attention = all(substr in k for substr in ['attn', '.weight']) 23 | is_mlp = any(k.endswith(s) for s in ['mlp.0.weight', 'mlp.2.weight']) 24 | 25 | if (is_attention or is_mlp) and len(state_dict[k].shape) == 2: 26 | state_dict[k] = state_dict[k][:, :, None, None] 27 | 28 | 29 | def correct_for_bias_scale_order_inversion(state_dict, prefix, local_metadata, 30 | strict, missing_keys, 31 | unexpected_keys, error_msgs): 32 | state_dict[prefix + 'bias'] = state_dict[prefix + 'bias'] / state_dict[prefix + 'weight'] 33 | return state_dict 34 | 35 | class LayerNormANE(LayerNormANEBase): 36 | 37 | def __init__(self, *args, **kwargs): 38 | super().__init__(*args, **kwargs) 39 | self._register_load_state_dict_pre_hook( 40 | correct_for_bias_scale_order_inversion) 41 | 42 | class MultiHeadAttentionANE(MultiHeadAttention): 43 | def __init__(self, n_state: int, n_head: int): 44 | super().__init__(n_state, n_head) 45 | self.query = nn.Conv2d(n_state, n_state, kernel_size=1) 46 | self.key = nn.Conv2d(n_state, n_state, kernel_size=1, bias=False) 47 | self.value = nn.Conv2d(n_state, n_state, kernel_size=1) 48 | self.out = nn.Conv2d(n_state, n_state, kernel_size=1) 49 | 50 | def forward(self, 51 | x: Tensor, 52 | xa: Optional[Tensor] = None, 53 | mask: Optional[Tensor] = None, 54 | kv_cache: Optional[dict] = None): 55 | 56 | q = self.query(x) 57 | 58 | if kv_cache is None or xa is None or self.key not in kv_cache: 59 | # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; 60 | # otherwise, perform key/value projections for self- or cross-attention as usual. 61 | k = self.key(x if xa is None else xa) 62 | v = self.value(x if xa is None else xa) 63 | 64 | else: 65 | # for cross-attention, calculate keys and values once and reuse in subsequent calls. 
66 | k = kv_cache[self.key] 67 | v = kv_cache[self.value] 68 | 69 | wv, qk = self.qkv_attention_ane(q, k, v, mask) 70 | 71 | return self.out(wv), qk 72 | 73 | def qkv_attention_ane(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None): 74 | 75 | _, dim, _, seqlen = q.size() 76 | 77 | dim_per_head = dim // self.n_head 78 | 79 | scale = float(dim_per_head)**-0.5 80 | 81 | q = q * scale 82 | 83 | mh_q = q.split(dim_per_head, dim=1) 84 | mh_k = k.transpose(1,3).split(dim_per_head, dim=3) 85 | mh_v = v.split(dim_per_head, dim=1) 86 | 87 | mh_qk = [ 88 | torch.einsum('bchq,bkhc->bkhq', [qi, ki]) 89 | for qi, ki in zip(mh_q, mh_k) 90 | ] # (batch_size, max_seq_length, 1, max_seq_length) * n_heads 91 | 92 | if mask is not None: 93 | for head_idx in range(self.n_head): 94 | mh_qk[head_idx] = mh_qk[head_idx] + mask[:, :seqlen, :, :seqlen] 95 | 96 | attn_weights = [aw.softmax(dim=1) for aw in mh_qk] # (batch_size, max_seq_length, 1, max_seq_length) * n_heads 97 | attn = [torch.einsum('bkhq,bchk->bchq', wi, vi) for wi, vi in zip(attn_weights, mh_v)] # (batch_size, dim_per_head, 1, max_seq_length) * n_heads 98 | attn = torch.cat(attn, dim=1) # (batch_size, dim, 1, max_seq_length) 99 | 100 | return attn, torch.cat(mh_qk, dim=1).float().detach() 101 | 102 | 103 | class ResidualAttentionBlockANE(ResidualAttentionBlock): 104 | def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): 105 | super().__init__(n_state, n_head, cross_attention) 106 | self.attn = MultiHeadAttentionANE(n_state, n_head) 107 | self.attn_ln = LayerNormANE(n_state) 108 | self.cross_attn = MultiHeadAttentionANE(n_state, n_head) if cross_attention else None 109 | self.cross_attn_ln = LayerNormANE(n_state) if cross_attention else None 110 | 111 | n_mlp = n_state * 4 112 | self.mlp = nn.Sequential( 113 | nn.Conv2d(n_state, n_mlp, kernel_size=1), 114 | nn.GELU(), 115 | nn.Conv2d(n_mlp, n_state, kernel_size=1) 116 | ) 117 | self.mlp_ln = LayerNormANE(n_state) 118 | 119 | 120 | class AudioEncoderANE(AudioEncoder): 121 | def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): 122 | super().__init__(n_mels, n_ctx, n_state, n_head, n_layer) 123 | 124 | self.blocks = nn.ModuleList( 125 | [ResidualAttentionBlockANE(n_state, n_head) for _ in range(n_layer)] 126 | ) 127 | self.ln_post = LayerNormANE(n_state) 128 | 129 | def forward(self, x: Tensor): 130 | """ 131 | x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) 132 | the mel spectrogram of the audio 133 | """ 134 | x = F.gelu(self.conv1(x)) 135 | x = F.gelu(self.conv2(x)) 136 | 137 | assert x.shape[1:] == self.positional_embedding.shape[::-1], "incorrect audio shape" 138 | 139 | # Add positional embedding and add dummy dim for ANE 140 | x = (x + self.positional_embedding.transpose(0,1)).to(x.dtype).unsqueeze(2) 141 | 142 | for block in self.blocks: 143 | x = block(x) 144 | 145 | x = self.ln_post(x) 146 | 147 | # """ 148 | # TODO: 149 | # I think we need to transpose the result here to make it fit whisper.cpp memory order. 150 | # However, even doing this, the results are still wrong. Kind of less wrong compared to 151 | # not transposing, but still wrong. 
152 | 153 | # Also, I don't know why the original OpenAI implementation does not need to transpose 154 | 155 | # transpose to (batch_size, n_ctx, n_state) 156 | # x : torch.Tensor, shape = (batch_size, n_state, 1, n_ctx) 157 | 158 | # """ 159 | # x = x.transpose(1,3) 160 | 161 | return x 162 | 163 | class TextDecoderANE(TextDecoder): 164 | 165 | def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): 166 | super().__init__(n_vocab, n_ctx, n_state, n_head, n_layer) 167 | 168 | self.blocks= nn.ModuleList( 169 | [ResidualAttentionBlockANE(n_state, n_head, cross_attention=True) for _ in range(n_layer)] 170 | ) 171 | self.ln= LayerNormANE(n_state) 172 | 173 | def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): 174 | """ 175 | x : torch.LongTensor, shape = (batch_size, <= n_ctx) 176 | the text tokens 177 | xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx) 178 | the encoded audio features to be attended on 179 | """ 180 | offset = next(iter(kv_cache.values())).shape[3] if kv_cache else 0 181 | x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]] 182 | x = x.to(xa.dtype) 183 | 184 | # Reformat for ANE 185 | mask = self.mask[None, None, :, :].permute(0,3,1,2) 186 | x = x.transpose(1,2).unsqueeze(2) 187 | 188 | for block in self.blocks: 189 | x = block(x, xa, mask=mask, kv_cache=kv_cache) 190 | 191 | x = self.ln(x) 192 | 193 | # Reformat back from ANE 194 | x = x.permute(0,2,3,1).squeeze(0) 195 | 196 | # ANE can only load tensors with dim size of at most 16,384 - whisper uses 51,864 (en) or 51,865 (multi-lang) tokens so we need to compute in chunks 197 | if self.token_embedding.weight.shape[0] >= 51865: 198 | # split in 11 chunks - 4715 each 199 | splits = self.token_embedding.weight.split(self.token_embedding.weight.shape[0]//11, dim=0) 200 | logits = torch.cat([torch.einsum('bid,jd->bij', x, split) for split in splits]).view(*x.shape[:2], -1) 201 | else: 202 | # split in 12 chunks - 4322 each 203 | assert(self.token_embedding.weight.shape[0] == 51864) 204 | splits = self.token_embedding.weight.split(self.token_embedding.weight.shape[0]//12, dim=0) 205 | logits = torch.cat([torch.einsum('bid,jd->bij', x, split) for split in splits]).view(*x.shape[:2], -1) 206 | 207 | return logits 208 | 209 | class WhisperANE(Whisper): 210 | def __init__(self, dims: ModelDimensions): 211 | super().__init__(dims) 212 | 213 | self.encoder = AudioEncoderANE( 214 | self.dims.n_mels, 215 | self.dims.n_audio_ctx, 216 | self.dims.n_audio_state, 217 | self.dims.n_audio_head, 218 | self.dims.n_audio_layer, 219 | ) 220 | self.decoder = TextDecoderANE( 221 | self.dims.n_vocab, 222 | self.dims.n_text_ctx, 223 | self.dims.n_text_state, 224 | self.dims.n_text_head, 225 | self.dims.n_text_layer, 226 | ) 227 | 228 | self._register_load_state_dict_pre_hook(linear_to_conv2d_map) 229 | 230 | def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]: 231 | return self.decoder(tokens, self.encoder(mel)) 232 | 233 | def install_kv_cache_hooks(self, cache: Optional[dict] = None): 234 | cache = {**cache} if cache is not None else {} 235 | hooks = [] 236 | 237 | def save_to_cache(module, _, output): 238 | if module not in cache or output.shape[3] > self.decoder.positional_embedding.shape[0]: 239 | cache[module] = output # save as-is, for the first token or cross attention 240 | else: 241 | cache[module] = torch.cat([cache[module], output], dim=3).detach() 242 | return cache[module] 243 | 244 | def 
install_hooks(layer: nn.Module): 245 | if isinstance(layer, MultiHeadAttentionANE): 246 | hooks.append(layer.key.register_forward_hook(save_to_cache)) 247 | hooks.append(layer.value.register_forward_hook(save_to_cache)) 248 | 249 | self.decoder.apply(install_hooks) 250 | return cache, hooks 251 | 252 | def convert_encoder(hparams, model, quantize=False): 253 | model.eval() 254 | 255 | input_shape = (1, hparams.n_mels, 3000) 256 | input_data = torch.randn(input_shape) 257 | traced_model = torch.jit.trace(model, input_data) 258 | 259 | model = ct.convert( 260 | traced_model, 261 | convert_to=None if quantize else "mlprogram", # convert will fail if weights are quantized, not sure why 262 | inputs=[ct.TensorType(name="logmel_data", shape=input_shape)], 263 | outputs=[ct.TensorType(name="output")], 264 | compute_units=ct.ComputeUnit.ALL 265 | ) 266 | 267 | if quantize: 268 | model = quantize_weights(model, nbits=16) 269 | 270 | return model 271 | 272 | def convert_decoder(hparams, model, quantize=False): 273 | model.eval() 274 | 275 | tokens_shape = (1, 1) 276 | audio_shape = (1, hparams.n_audio_state, 1, 1500) 277 | 278 | audio_data = torch.randn(audio_shape) 279 | token_data = torch.randint(50257, tokens_shape).long() 280 | traced_model = torch.jit.trace(model, (token_data, audio_data)) 281 | 282 | model = ct.convert( 283 | traced_model, 284 | convert_to=None if quantize else "mlprogram", # convert will fail if weights are quantized, not sure why 285 | inputs=[ 286 | ct.TensorType(name="token_data", shape=tokens_shape, dtype=int), 287 | ct.TensorType(name="audio_data", shape=audio_shape) 288 | ] 289 | ) 290 | 291 | if quantize: 292 | model = quantize_weights(model, nbits=16) 293 | 294 | return model 295 | 296 | 297 | if __name__ == "__main__": 298 | parser = argparse.ArgumentParser() 299 | parser.add_argument("--model", type=str, help="model to convert (e.g. 
tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3)", required=True) 300 | parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False) 301 | parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False) 302 | parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False) 303 | args = parser.parse_args() 304 | 305 | if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large-v1", "large-v2", "large-v3"]: 306 | raise ValueError("Invalid model name") 307 | 308 | whisper = load_model(args.model).cpu() 309 | hparams = whisper.dims 310 | print(hparams) 311 | 312 | if args.optimize_ane: 313 | whisperANE = WhisperANE(hparams).eval() 314 | whisperANE.load_state_dict(whisper.state_dict()) 315 | 316 | encoder = whisperANE.encoder 317 | decoder = whisperANE.decoder 318 | else: 319 | encoder = whisper.encoder 320 | decoder = whisper.decoder 321 | 322 | # Convert encoder 323 | encoder = convert_encoder(hparams, encoder, quantize=args.quantize) 324 | encoder.save(f"models/coreml-encoder-{args.model}.mlpackage") 325 | 326 | if args.encoder_only is False: 327 | # Convert decoder 328 | decoder = convert_decoder(hparams, decoder, quantize=args.quantize) 329 | decoder.save(f"models/coreml-decoder-{args.model}.mlpackage") 330 | 331 | print("done converting") 332 | -------------------------------------------------------------------------------- /models/convert-whisper-to-openvino.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from whisper import load_model 4 | import os 5 | from openvino.tools import mo 6 | from openvino.runtime import serialize 7 | import shutil 8 | 9 | def convert_encoder(hparams, encoder, mname): 10 | encoder.eval() 11 | 12 | mel = torch.zeros((1, hparams.n_mels, 3000)) 13 | 14 | onnx_folder=os.path.join(os.path.dirname(__file__),"onnx_encoder") 15 | 16 | #create a directory to store the onnx model, and other collateral that is saved during onnx export procedure 17 | if not os.path.isdir(onnx_folder): 18 | os.makedirs(onnx_folder) 19 | 20 | onnx_path = os.path.join(onnx_folder, "whisper_encoder.onnx") 21 | 22 | torch.onnx.export( 23 | encoder, 24 | mel, 25 | onnx_path, 26 | input_names=["mel"], 27 | output_names=["output_features"] 28 | ) 29 | 30 | # use model optimizer to convert onnx to OpenVINO IR format 31 | encoder_model = mo.convert_model(onnx_path, compress_to_fp16=True) 32 | serialize(encoder_model, xml_path=os.path.join(os.path.dirname(__file__),"ggml-" + mname + "-encoder-openvino.xml")) 33 | 34 | #cleanup 35 | if os.path.isdir(onnx_folder): 36 | shutil.rmtree(onnx_folder) 37 | 38 | 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--model", type=str, help="model to convert (e.g. 
tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3)", required=True) 42 | args = parser.parse_args() 43 | 44 | if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large-v1", "large-v2", "large-v3"]: 45 | raise ValueError("Invalid model name") 46 | 47 | whisper = load_model(args.model).cpu() 48 | hparams = whisper.dims 49 | 50 | encoder = whisper.encoder 51 | 52 | # Convert encoder to onnx 53 | convert_encoder(hparams, encoder, args.model) 54 | -------------------------------------------------------------------------------- /models/download-coreml-model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script downloads Whisper model files that have already been converted to Core ML format. 4 | # This way you don't have to convert them yourself. 5 | 6 | src="https://huggingface.co/datasets/ggerganov/whisper.cpp-coreml" 7 | pfx="resolve/main/ggml" 8 | 9 | # get the path of this script 10 | function get_script_path() { 11 | if [ -x "$(command -v realpath)" ]; then 12 | echo "$(dirname $(realpath $0))" 13 | else 14 | local ret="$(cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P)" 15 | echo "$ret" 16 | fi 17 | } 18 | 19 | models_path="$(get_script_path)" 20 | 21 | # Whisper models 22 | models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large-v3" ) 23 | 24 | # list available models 25 | function list_models { 26 | printf "\n" 27 | printf " Available models:" 28 | for model in "${models[@]}"; do 29 | printf " $model" 30 | done 31 | printf "\n\n" 32 | } 33 | 34 | if [ "$#" -ne 1 ]; then 35 | printf "Usage: $0 \n" 36 | list_models 37 | 38 | exit 1 39 | fi 40 | 41 | model=$1 42 | 43 | if [[ ! " ${models[@]} " =~ " ${model} " ]]; then 44 | printf "Invalid model: $model\n" 45 | list_models 46 | 47 | exit 1 48 | fi 49 | 50 | # download Core ML model 51 | 52 | printf "Downloading Core ML model $model from '$src' ...\n" 53 | 54 | cd $models_path 55 | 56 | if [ -f "ggml-$model.mlmodel" ]; then 57 | printf "Model $model already exists. Skipping download.\n" 58 | exit 0 59 | fi 60 | 61 | if [ -x "$(command -v wget)" ]; then 62 | wget --quiet --show-progress -O ggml-$model.mlmodel $src/$pfx-$model.mlmodel 63 | elif [ -x "$(command -v curl)" ]; then 64 | curl -L --output ggml-$model.mlmodel $src/$pfx-$model.mlmodel 65 | else 66 | printf "Either wget or curl is required to download models.\n" 67 | exit 1 68 | fi 69 | 70 | 71 | if [ $? -ne 0 ]; then 72 | printf "Failed to download Core ML model $model \n" 73 | printf "Please try again later or download the original Whisper model files and convert them yourself.\n" 74 | exit 1 75 | fi 76 | 77 | printf "Done! Model '$model' saved in 'models/ggml-$model.mlmodel'\n" 78 | printf "Run the following command to compile it:\n\n" 79 | printf " $ xcrun coremlc compile ./models/ggml-$model.mlmodel ./models\n\n" 80 | printf "You can now use it like this:\n\n" 81 | printf " $ ./main -m models/ggml-$model.bin -f samples/jfk.wav\n" 82 | printf "\n" 83 | -------------------------------------------------------------------------------- /models/download-ggml-model.cmd: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | pushd %~dp0 4 | set models_path=%CD% 5 | for %%d in (%~dp0..) 
do set root_path=%%~fd 6 | popd 7 | 8 | set argc=0 9 | for %%x in (%*) do set /A argc+=1 10 | 11 | set models=tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large-v3 12 | 13 | if %argc% neq 1 ( 14 | echo. 15 | echo Usage: download-ggml-model.cmd model 16 | CALL :list_models 17 | goto :eof 18 | ) 19 | 20 | set model=%1 21 | 22 | for %%b in (%models%) do ( 23 | if "%%b"=="%model%" ( 24 | CALL :download_model 25 | goto :eof 26 | ) 27 | ) 28 | 29 | echo Invalid model: %model% 30 | CALL :list_models 31 | goto :eof 32 | 33 | :download_model 34 | echo Downloading ggml model %model%... 35 | 36 | cd "%models_path%" 37 | 38 | if exist "ggml-%model%.bin" ( 39 | echo Model %model% already exists. Skipping download. 40 | goto :eof 41 | ) 42 | 43 | PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Start-BitsTransfer -Source https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-%model%.bin -Destination ggml-%model%.bin" 44 | 45 | if %ERRORLEVEL% neq 0 ( 46 | echo Failed to download ggml model %model% 47 | echo Please try again later or download the original Whisper model files and convert them yourself. 48 | goto :eof 49 | ) 50 | 51 | echo Done! Model %model% saved in %root_path%\models\ggml-%model%.bin 52 | echo You can now use it like this: 53 | echo main.exe -m %root_path%\models\ggml-%model%.bin -f %root_path%\samples\jfk.wav 54 | 55 | goto :eof 56 | 57 | :list_models 58 | echo. 59 | echo Available models: 60 | (for %%a in (%models%) do ( 61 | echo %%a 62 | )) 63 | echo. 64 | exit /b 65 | -------------------------------------------------------------------------------- /models/download-ggml-model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script downloads Whisper model files that have already been converted to ggml format. 4 | # This way you don't have to convert them yourself. 5 | 6 | #src="https://ggml.ggerganov.com" 7 | #pfx="ggml-model-whisper" 8 | 9 | src="https://huggingface.co/ggerganov/whisper.cpp" 10 | pfx="resolve/main/ggml" 11 | 12 | # get the path of this script 13 | function get_script_path() { 14 | if [ -x "$(command -v realpath)" ]; then 15 | echo "$(dirname "$(realpath "$0")")" 16 | else 17 | local ret="$(cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P)" 18 | echo "$ret" 19 | fi 20 | } 21 | 22 | models_path="$(get_script_path)" 23 | 24 | # Whisper models 25 | models=( 26 | "tiny.en" 27 | "tiny" 28 | "tiny-q5_1" 29 | "tiny.en-q5_1" 30 | "base.en" 31 | "base" 32 | "base-q5_1" 33 | "base.en-q5_1" 34 | "small.en" 35 | "small.en-tdrz" 36 | "small" 37 | "small-q5_1" 38 | "small.en-q5_1" 39 | "medium" 40 | "medium.en" 41 | "medium-q5_0" 42 | "medium.en-q5_0" 43 | "large-v1" 44 | "large-v2" 45 | "large-v3" 46 | "large-q5_0" 47 | ) 48 | 49 | # list available models 50 | function list_models { 51 | printf "\n" 52 | printf " Available models:" 53 | for model in "${models[@]}"; do 54 | printf " $model" 55 | done 56 | printf "\n\n" 57 | } 58 | 59 | if [ "$#" -ne 1 ]; then 60 | printf "Usage: $0 \n" 61 | list_models 62 | 63 | exit 1 64 | fi 65 | 66 | model=$1 67 | 68 | if [[ ! 
" ${models[@]} " =~ " ${model} " ]]; then 69 | printf "Invalid model: $model\n" 70 | list_models 71 | 72 | exit 1 73 | fi 74 | 75 | # check if model contains `tdrz` and update the src and pfx accordingly 76 | if [[ $model == *"tdrz"* ]]; then 77 | src="https://huggingface.co/akashmjn/tinydiarize-whisper.cpp" 78 | pfx="resolve/main/ggml" 79 | fi 80 | 81 | # download ggml model 82 | 83 | printf "Downloading ggml model $model from '$src' ...\n" 84 | 85 | cd "$models_path" 86 | 87 | if [ -f "ggml-$model.bin" ]; then 88 | printf "Model $model already exists. Skipping download.\n" 89 | exit 0 90 | fi 91 | 92 | if [ -x "$(command -v wget)" ]; then 93 | wget --no-config --quiet --show-progress -O ggml-$model.bin $src/$pfx-$model.bin 94 | elif [ -x "$(command -v curl)" ]; then 95 | curl -L --output ggml-$model.bin $src/$pfx-$model.bin 96 | else 97 | printf "Either wget or curl is required to download models.\n" 98 | exit 1 99 | fi 100 | 101 | 102 | if [ $? -ne 0 ]; then 103 | printf "Failed to download ggml model $model \n" 104 | printf "Please try again later or download the original Whisper model files and convert them yourself.\n" 105 | exit 1 106 | fi 107 | 108 | printf "Done! Model '$model' saved in 'models/ggml-$model.bin'\n" 109 | printf "You can now use it like this:\n\n" 110 | printf " $ ./main -m models/ggml-$model.bin -f samples/jfk.wav\n" 111 | printf "\n" 112 | -------------------------------------------------------------------------------- /models/generate-coreml-interface.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # This generates: 4 | # - coreml/whisper-encoder-impl.h and coreml/whisper-encoder-impl.m 5 | # - coreml/whisper-decoder-impl.h and coreml/whisper-decoder-impl.m 6 | # 7 | 8 | wd=$(dirname "$0") 9 | cd "$wd/../" 10 | 11 | python3 models/convert-whisper-to-coreml.py --model tiny.en 12 | 13 | mv -v models/coreml-encoder-tiny.en.mlpackage models/whisper-encoder-impl.mlpackage 14 | xcrun coremlc generate models/whisper-encoder-impl.mlpackage coreml/ 15 | mv coreml/whisper_encoder_impl.h coreml/whisper-encoder-impl.h 16 | mv coreml/whisper_encoder_impl.m coreml/whisper-encoder-impl.m 17 | sed -i '' 's/whisper_encoder_impl\.h/whisper-encoder-impl.h/g' coreml/whisper-encoder-impl.m 18 | sed -i '' 's/whisper_encoder_impl\.m/whisper-encoder-impl.m/g' coreml/whisper-encoder-impl.m 19 | sed -i '' 's/whisper_encoder_impl\.h/whisper-encoder-impl.h/g' coreml/whisper-encoder-impl.h 20 | 21 | mv -v models/coreml-decoder-tiny.en.mlpackage models/whisper-decoder-impl.mlpackage 22 | xcrun coremlc generate models/whisper-decoder-impl.mlpackage coreml/ 23 | mv coreml/whisper_decoder_impl.h coreml/whisper-decoder-impl.h 24 | mv coreml/whisper_decoder_impl.m coreml/whisper-decoder-impl.m 25 | sed -i '' 's/whisper_decoder_impl\.h/whisper-decoder-impl.h/g' coreml/whisper-decoder-impl.m 26 | sed -i '' 's/whisper_decoder_impl\.m/whisper-decoder-impl.m/g' coreml/whisper-decoder-impl.m 27 | sed -i '' 's/whisper_decoder_impl\.h/whisper-decoder-impl.h/g' coreml/whisper-decoder-impl.h 28 | 29 | rm -rfv models/whisper-encoder-impl.mlpackage models/whisper-decoder-impl.mlpackage 30 | -------------------------------------------------------------------------------- /models/generate-coreml-model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Usage: ./generate-coreml-model.sh 4 | if [ $# -eq 0 ]; then 5 | echo "No model name supplied" 6 | echo "Usage for Whisper models: 
./generate-coreml-model.sh " 7 | echo "Usage for HuggingFace models: ./generate-coreml-model.sh -h5 " 8 | exit 1 9 | elif [[ "$1" == "-h5" && $# != 3 ]]; then 10 | echo "No model name and model path supplied for a HuggingFace model" 11 | echo "Usage for HuggingFace models: ./generate-coreml-model.sh -h5 " 12 | exit 1 13 | fi 14 | 15 | mname="$1" 16 | 17 | wd=$(dirname "$0") 18 | cd "$wd/../" 19 | 20 | if [[ $mname == "-h5" ]]; then 21 | mname="$2" 22 | mpath="$3" 23 | echo $mpath 24 | python3 models/convert-h5-to-coreml.py --model-name $mname --model-path $mpath --encoder-only True 25 | else 26 | python3 models/convert-whisper-to-coreml.py --model $mname --encoder-only True 27 | fi 28 | 29 | xcrun coremlc compile models/coreml-encoder-${mname}.mlpackage models/ 30 | rm -rf models/ggml-${mname}-encoder.mlmodelc 31 | mv -v models/coreml-encoder-${mname}.mlmodelc models/ggml-${mname}-encoder.mlmodelc 32 | 33 | # TODO: decoder (sometime in the future maybe) 34 | #xcrun coremlc compile models/whisper-decoder-${mname}.mlpackage models/ 35 | #rm -rf models/ggml-${mname}-decoder.mlmodelc 36 | #mv -v models/coreml_decoder_${mname}.mlmodelc models/ggml-${mname}-decoder.mlmodelc 37 | -------------------------------------------------------------------------------- /models/ggml-tiny.mlmodel: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/litongjava/whisper-cpp-server/509f5aca1b6c35257d5f671e92385c892ddcf9ac/models/ggml-tiny.mlmodel -------------------------------------------------------------------------------- /models/ggml_to_pt.py: -------------------------------------------------------------------------------- 1 | import struct 2 | import torch 3 | import numpy as np 4 | from collections import OrderedDict 5 | from pathlib import Path 6 | import sys 7 | 8 | if len(sys.argv) < 3: 9 | print( 10 | "Usage: convert-ggml-to-pt.py model.bin dir-output\n") 11 | sys.exit(1) 12 | 13 | fname_inp = Path(sys.argv[1]) 14 | dir_out = Path(sys.argv[2]) 15 | fname_out = dir_out / "torch-model.pt" 16 | 17 | 18 | 19 | # Open the ggml file 20 | with open(fname_inp, "rb") as f: 21 | # Read magic number and hyperparameters 22 | magic_number, n_vocab, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, n_text_ctx, n_text_state, n_text_head, n_text_layer, n_mels, use_f16 = struct.unpack("12i", f.read(48)) 23 | print(f"Magic number: {magic_number}") 24 | print(f"Vocab size: {n_vocab}") 25 | print(f"Audio context size: {n_audio_ctx}") 26 | print(f"Audio state size: {n_audio_state}") 27 | print(f"Audio head size: {n_audio_head}") 28 | print(f"Audio layer size: {n_audio_layer}") 29 | print(f"Text context size: {n_text_ctx}") 30 | print(f"Text head size: {n_text_head}") 31 | print(f"Mel size: {n_mels}") 32 | # Read mel filters 33 | # mel_filters = np.fromfile(f, dtype=np.float32, count=n_mels * 2).reshape(n_mels, 2) 34 | # print(f"Mel filters: {mel_filters}") 35 | filters_shape_0 = struct.unpack("i", f.read(4))[0] 36 | print(f"Filters shape 0: {filters_shape_0}") 37 | filters_shape_1 = struct.unpack("i", f.read(4))[0] 38 | print(f"Filters shape 1: {filters_shape_1}") 39 | 40 | # Read tokenizer tokens 41 | # bytes = f.read(4) 42 | # print(bytes) 43 | 44 | 45 | # for i in range(filters.shape[0]): 46 | # for j in range(filters.shape[1]): 47 | # fout.write(struct.pack("f", filters[i][j])) 48 | mel_filters = np.zeros((filters_shape_0, filters_shape_1)) 49 | 50 | for i in range(filters_shape_0): 51 | for j in range(filters_shape_1): 52 | mel_filters[i][j] = 
struct.unpack("f", f.read(4))[0] 53 | 54 | bytes_data = f.read(4) 55 | num_tokens = struct.unpack("i", bytes_data)[0] 56 | tokens = {} 57 | 58 | 59 | for _ in range(num_tokens): 60 | token_len = struct.unpack("i", f.read(4))[0] 61 | token = f.read(token_len) 62 | tokens[token] = {} 63 | 64 | # Read model variables 65 | model_state_dict = OrderedDict() 66 | while True: 67 | try: 68 | n_dims, name_length, ftype = struct.unpack("iii", f.read(12)) 69 | except struct.error: 70 | break # End of file 71 | dims = [struct.unpack("i", f.read(4))[0] for _ in range(n_dims)] 72 | dims = dims[::-1] 73 | name = f.read(name_length).decode("utf-8") 74 | if ftype == 1: # f16 75 | data = np.fromfile(f, dtype=np.float16, count=np.prod(dims)).reshape(dims) 76 | else: # f32 77 | data = np.fromfile(f, dtype=np.float32, count=np.prod(dims)).reshape(dims) 78 | 79 | 80 | if name in ["encoder.conv1.bias", "encoder.conv2.bias"]: 81 | 82 | data = data[:, 0] 83 | 84 | 85 | model_state_dict[name] = torch.from_numpy(data) 86 | 87 | # Now you have the model's state_dict stored in model_state_dict 88 | # You can load this state_dict into a model with the same architecture 89 | 90 | # dims = ModelDimensions(**checkpoint["dims"]) 91 | # model = Whisper(dims) 92 | from whisper import Whisper, ModelDimensions 93 | dims = ModelDimensions( 94 | n_mels=n_mels, 95 | n_audio_ctx=n_audio_ctx, 96 | n_audio_state=n_audio_state, 97 | n_audio_head=n_audio_head, 98 | n_audio_layer=n_audio_layer, 99 | n_text_ctx=n_text_ctx, 100 | n_text_state=n_text_state, 101 | n_text_head=n_text_head, 102 | n_text_layer=n_text_layer, 103 | n_vocab=n_vocab, 104 | ) 105 | model = Whisper(dims) # Replace with your model's class 106 | model.load_state_dict(model_state_dict) 107 | 108 | # Save the model in PyTorch format 109 | torch.save(model.state_dict(), fname_out) 110 | -------------------------------------------------------------------------------- /models/openvino-conversion-requirements.txt: -------------------------------------------------------------------------------- 1 | openvino-dev[pytorch,onnx] 2 | openai-whisper -------------------------------------------------------------------------------- /params/whisper_params.cpp: -------------------------------------------------------------------------------- 1 | #include "whisper_params.h" 2 | 3 | void whisper_print_usage(int /*argc*/, char **argv, const whisper_params ¶ms, 4 | const server_params &sparams) { 5 | fprintf(stderr, "\n"); 6 | fprintf(stderr, "usage: %s [options] \n", argv[0]); 7 | fprintf(stderr, "\n"); 8 | fprintf(stderr, "options:\n"); 9 | fprintf(stderr, " -h, --help [default] show this help message and exit\n"); 10 | fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", 11 | params.n_threads); 12 | fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", 13 | params.n_processors); 14 | fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); 15 | fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); 16 | fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", 17 | params.duration_ms); 18 | fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", 19 | params.max_context); 20 | fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); 21 | fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on 
token\n", 22 | params.split_on_word ? "true" : "false"); 23 | fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); 24 | fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); 25 | fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", 26 | params.word_thold); 27 | fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", 28 | params.entropy_thold); 29 | fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", 30 | params.logprob_thold); 31 | // fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); 32 | fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", 33 | params.debug_mode ? "true" : "false"); 34 | fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", 35 | params.translate ? "true" : "false"); 36 | fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", 37 | params.diarize ? "true" : "false"); 38 | fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", 39 | params.tinydiarize ? "true" : "false"); 40 | fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", 41 | params.no_fallback ? "true" : "false"); 42 | fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", 43 | params.print_special ? "true" : "false"); 44 | fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); 45 | fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", 46 | params.print_progress ? "true" : "false"); 47 | fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", 48 | params.no_timestamps ? "true" : "false"); 49 | fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", 50 | params.language.c_str()); 51 | fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", 52 | params.detect_language ? "true" : "false"); 53 | fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); 54 | fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); 55 | fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", 56 | params.openvino_encode_device.c_str()); 57 | // server params 58 | fprintf(stderr, " --host HOST, [%-7s] Hostname/ip-adress for the server\n", 59 | sparams.hostname.c_str()); 60 | fprintf(stderr, " --port PORT, [%-7d] Port number for the server\n", sparams.port); 61 | fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? 
"false" : "true"); 62 | fprintf(stderr, "\n"); 63 | } 64 | 65 | 66 | bool whisper_params_parse(int argc, char **argv, whisper_params ¶ms, server_params &sparams) { 67 | for (int i = 1; i < argc; i++) { 68 | std::string arg = argv[i]; 69 | 70 | if (arg == "-h" || arg == "--help") { 71 | whisper_print_usage(argc, argv, params, sparams); 72 | exit(0); 73 | } else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } 74 | else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); } 75 | else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); } 76 | else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); } 77 | else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } 78 | else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } 79 | else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } 80 | else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } 81 | else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); } 82 | else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } 83 | else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } 84 | else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } 85 | // else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } 86 | else if (arg == "-debug" || arg == "--debug-mode") { params.debug_mode = true; } 87 | else if (arg == "-tr" || arg == "--translate") { params.translate = true; } 88 | else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } 89 | else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } 90 | else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } 91 | else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } 92 | else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; } 93 | else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } 94 | else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } 95 | else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } 96 | else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } 97 | else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } 98 | else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; } 99 | else if (arg == "--prompt") { params.prompt = argv[++i]; } 100 | else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } 101 | else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; } 102 | // server params 103 | else if (arg == "--port") { sparams.port = std::stoi(argv[++i]); } 104 | else if (arg == "--host") { sparams.hostname = argv[++i]; } 105 | else if (arg == "-ad" || arg == "--port") { params.openvino_encode_device = argv[++i]; } 106 | else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } 107 | else { 108 | fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); 109 | whisper_print_usage(argc, argv, params, sparams); 110 | exit(0); 111 | } 112 | } 113 | 114 | return true; 115 | } 
-------------------------------------------------------------------------------- /params/whisper_params.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | // output formats 12 | const std::string json_format = "json"; 13 | const std::string text_format = "text"; 14 | const std::string srt_format = "srt"; 15 | const std::string vjson_format = "verbose_json"; 16 | const std::string vtt_format = "vtt"; 17 | 18 | struct whisper_params { 19 | int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); 20 | int32_t n_processors = 1; 21 | int32_t offset_t_ms = 0; 22 | int32_t offset_n = 0; 23 | int32_t duration_ms = 0; 24 | int32_t progress_step = 5; 25 | int32_t max_context = -1; 26 | int32_t max_len = 0; 27 | int32_t best_of = 2; 28 | int32_t beam_size = -1; 29 | 30 | float word_thold = 0.01f; 31 | float entropy_thold = 2.40f; 32 | float logprob_thold = -1.00f; 33 | float userdef_temp = 0.20f; 34 | 35 | bool speed_up = false; 36 | bool debug_mode = false; 37 | bool translate = false; 38 | bool detect_language = false; 39 | bool diarize = false; 40 | bool tinydiarize = false; 41 | bool split_on_word = false; 42 | bool no_fallback = false; 43 | bool print_special = false; 44 | bool print_colors = false; 45 | bool print_progress = false; 46 | bool no_timestamps = false; 47 | bool log_score = false; 48 | bool use_gpu = true; 49 | bool stream = false; 50 | 51 | std::string language = "auto"; 52 | std::string prompt = ""; 53 | std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; 54 | std::string model = "../models/ggml-base.en.bin"; 55 | 56 | std::string response_format = json_format; 57 | 58 | // [TDRZ] speaker turn string 59 | std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line 60 | 61 | std::string openvino_encode_device = "CPU"; 62 | std::string audio_format="wav"; 63 | }; 64 | 65 | struct server_params { 66 | std::string hostname = "0.0.0.0"; 67 | std::string public_path = "public"; 68 | int32_t port = 8080; 69 | int32_t read_timeout = 600; 70 | int32_t write_timeout = 600; 71 | }; 72 | 73 | void whisper_print_usage(int /*argc*/, char **argv, const whisper_params ¶ms, const server_params &sparams); 74 | bool whisper_params_parse(int argc, char **argv, whisper_params ¶ms, server_params &sparams); -------------------------------------------------------------------------------- /pcm/16k_1.pcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/litongjava/whisper-cpp-server/509f5aca1b6c35257d5f671e92385c892ddcf9ac/pcm/16k_1.pcm -------------------------------------------------------------------------------- /pcm/16k_57test.pcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/litongjava/whisper-cpp-server/509f5aca1b6c35257d5f671e92385c892ddcf9ac/pcm/16k_57test.pcm -------------------------------------------------------------------------------- /pcm/16k_test.pcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/litongjava/whisper-cpp-server/509f5aca1b6c35257d5f671e92385c892ddcf9ac/pcm/16k_test.pcm -------------------------------------------------------------------------------- /pcm/nopause.pcm: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/litongjava/whisper-cpp-server/509f5aca1b6c35257d5f671e92385c892ddcf9ac/pcm/nopause.pcm -------------------------------------------------------------------------------- /public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 语音识别 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 |



--------------------------------------------------------------------------------
/resources/json/transcription01.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "transcription": [
 3 |     {
 4 |       "timestamps": {
 5 |         "from": "00:00:00,000",
 6 |         "to": "00:00:30,000"
 7 |       },
 8 |       "offsets": {
 9 |         "from": 0,
10 |         "to": 30000
11 |       },
12 |       "token": {
13 |         "text": "[_BEG_]",
14 |         "id": 50363,
15 |         "confidence": 0.655302,
16 |         "t0": 0,
17 |         "t1": 0
18 |       },
19 |       "token": {
20 |         "text": " Thank",
21 |         "id": 6952,
22 |         "confidence": 0.673538,
23 |         "t0": 0,
24 |         "t1": 124
25 |       },
26 |       "token": {
27 |         "text": " you",
28 |         "id": 345,
29 |         "confidence": 0.931985,
30 |         "t0": 124,
31 |         "t1": 200
32 |       },
33 |       "token": {
34 |         "text": ".",
35 |         "id": 13,
36 |         "confidence": 0.826669,
37 |         "t0": 200,
38 |         "t1": 260
39 |       },
40 |       "token": {
41 |         "text": "[_TT_130]",
42 |         "id": 50493,
43 |         "confidence": 0.0699378,
44 |         "t0": 260,
45 |         "t1": 3000
46 |       },
47 |       "token": {
48 |         "text": "<|endoftext|>",
49 |         "id": 50256,
50 |         "confidence": 0.993674,
51 |         "t0": 3000,
52 |         "t1": 3000
53 |       },
54 |       "text": "[_BEG_] Thank you.[_TT_130]<|endoftext|>"
55 |     }
56 |   ]
57 | }
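A minimal Python sketch (hypothetical reader, not part of the repository) for consuming this transcription format. Note that each segment repeats the `token` key, so a plain `json.load` keeps only the last token; the `timestamps`, `offsets` and `text` fields are read normally.

```python
import json

# assumes the resource path used in this repository
with open("resources/json/transcription01.json", "r", encoding="utf-8") as f:
    result = json.load(f)

for segment in result["transcription"]:
    ts = segment["timestamps"]
    # "text" carries the full decoded text of the segment
    print(f'[{ts["from"]} --> {ts["to"]}] {segment["text"]}')
```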


--------------------------------------------------------------------------------
/resources/json/whisper_local.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "systeminfo": "AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | METAL = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | SSSE3 = 0 | VSX = 0 | CUDA = 0 | COREML = 1 | OPENVINO = 0 | ",
 3 |   "model": {
 4 |     "type": "base",
 5 |     "multilingual": false,
 6 |     "vocab": 51864,
 7 |     "audio": {
 8 |       "ctx": 1500,
 9 |       "state": 512,
10 |       "head": 8,
11 |       "layer": 6
12 |     },
13 |     "text": {
14 |       "ctx": 448,
15 |       "state": 512,
16 |       "head": 8,
17 |       "layer": 6
18 |     },
19 |     "mels": 80,
20 |     "ftype": 1
21 |   },
22 |   "params": {
23 |     "model": "models/ggml-base.en.bin",
24 |     "language": "en",
25 |     "translate": false
26 |   }
27 | }
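A small sketch (again hypothetical, not shipped with the repo) showing how this metadata can be read to inspect the loaded model's geometry and the parameters the server was started with:

```python
import json

with open("resources/json/whisper_local.json", "r", encoding="utf-8") as f:
    info = json.load(f)

model = info["model"]
audio = model["audio"]
print(f'model type: {model["type"]}, multilingual: {model["multilingual"]}, vocab: {model["vocab"]}')
print(f'audio ctx={audio["ctx"]} state={audio["state"]} head={audio["head"]} layer={audio["layer"]}')
print(f'loaded from {info["params"]["model"]} (language: {info["params"]["language"]})')
```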


--------------------------------------------------------------------------------
/samples/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | 


--------------------------------------------------------------------------------
/samples/README.md:
--------------------------------------------------------------------------------
1 | # Audio samples
2 | 
3 | This folder contains various audio files used for testing.
4 | If you want to quickly get some more samples, simply run `make samples`. This will download several public audio files and convert them to the appropriate 16-bit WAV format using `ffmpeg`.
5 | 
6 | https://github.com/ggerganov/whisper.cpp/blob/a09ce6e8899198015729ffc49ae10f67370906b1/Makefile#L104-L123
7 | 
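For a single file, the same conversion can be done directly; below is a minimal sketch (assuming `ffmpeg` is installed and on `PATH`, with placeholder file names) that produces the 16 kHz, mono, 16-bit PCM WAV layout whisper.cpp expects:

```python
import subprocess

def to_wav_16k_mono(src: str, dst: str) -> None:
    # resample to 16 kHz, downmix to mono, encode as signed 16-bit PCM
    subprocess.run(
        ["ffmpeg", "-y", "-i", src, "-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le", dst],
        check=True,
    )

to_wav_16k_mono("input.mp3", "samples/input.wav")
```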


--------------------------------------------------------------------------------
/samples/jfk.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/litongjava/whisper-cpp-server/509f5aca1b6c35257d5f671e92385c892ddcf9ac/samples/jfk.wav


--------------------------------------------------------------------------------
/samples/zh.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/litongjava/whisper-cpp-server/509f5aca1b6c35257d5f671e92385c892ddcf9ac/samples/zh.wav


--------------------------------------------------------------------------------
/stream/stream_components.h:
--------------------------------------------------------------------------------
 1 | #pragma once
 2 | 
 3 | namespace stream_components {
 4 |   struct whisper_local_stream_params {
 5 |     audio_params audio;
 6 |     service_params service;
 7 | 
 8 |     void initialize() {
 9 |       audio.initialize();
10 |       service.initialize();
11 |     }
12 |   };
13 | 
14 |   void whisper_print_usage(int /*argc*/, char **argv, const whisper_local_stream_params &params) {
15 |     fprintf(stderr, "\n");
16 |     fprintf(stderr, "usage: %s [options]\n", argv[0]);
17 |     fprintf(stderr, "\n");
18 |     fprintf(stderr, "options:\n");
19 |     fprintf(stderr, "  -h,       --help          [default] show this help message and exit\n");
20 |     fprintf(stderr, "  -t N,     --threads N     [%-7d] number of threads to use during computation\n",
21 |             params.service.n_threads);
22 |     fprintf(stderr, "            --step N        [%-7d] audio step size in milliseconds\n", params.audio.step_ms);
23 |     fprintf(stderr, "            --length N      [%-7d] audio length in milliseconds\n", params.audio.length_ms);
24 |     fprintf(stderr, "            --keep N        [%-7d] audio to keep from previous step in ms\n",
25 |             params.audio.keep_ms);
26 |     fprintf(stderr, "  -c ID,    --capture ID    [%-7d] capture device ID\n", params.audio.capture_id);
27 |     //fprintf(stderr, "  -mt N,    --max-tokens N  [%-7d] maximum number of tokens per audio chunk\n",       params.max_tokens);
28 |     fprintf(stderr, "  -ac N,    --audio-ctx N   [%-7d] audio context size (0 - all)\n", params.audio.audio_ctx);
29 |     fprintf(stderr, "  -vth N,   --vad-thold N   [%-7.2f] voice activity detection threshold\n",
30 |             params.audio.vad_thold);
31 |     fprintf(stderr, "  -fth N,   --freq-thold N  [%-7.2f] high-pass frequency cutoff\n", params.audio.freq_thold);
32 |     fprintf(stderr, "  -su,      --speed-up      [%-7s] speed up audio by x2 (reduced accuracy)\n",
33 |             params.service.speed_up ? "true" : "false");
34 |     fprintf(stderr, "  -tr,      --translate     [%-7s] translate from source language to english\n",
35 |             params.service.translate ? "true" : "false");
36 |     fprintf(stderr, "  -nf,      --no-fallback   [%-7s] do not use temperature fallback while decoding\n",
37 |             params.service.no_fallback ? "true" : "false");
38 |     //fprintf(stderr, "  -ps,      --print-special [%-7s] print special tokens\n",                           params.print_special ? "true" : "false");
39 |     fprintf(stderr, "  -kc,      --keep-context  [%-7s] keep context between audio chunks\n",
40 |             params.service.no_context ? "false" : "true");
41 |     fprintf(stderr, "  -l LANG,  --language LANG [%-7s] spoken language\n", params.service.language.c_str());
42 |     fprintf(stderr, "  -m FNAME, --model FNAME   [%-7s] model path\n", params.service.model.c_str());
43 |     //fprintf(stderr, "  -f FNAME, --file FNAME    [%-7s] text output file name\n",                          params.fname_out.c_str());
44 |     fprintf(stderr, "  -tdrz,     --tinydiarize  [%-7s] enable tinydiarize (requires a tdrz model)\n",
45 |             params.service.tinydiarize ? "true" : "false");
46 |     //fprintf(stderr, "  -sa,      --save-audio    [%-7s] save the recorded audio to a file\n",              params.save_audio ? "true" : "false");
47 |     fprintf(stderr, "\n");
48 |   }
49 | 
50 |   bool whisper_params_parse(int argc, char **argv, whisper_local_stream_params &params) {
51 |     for (int i = 1; i < argc; i++) {
52 |       std::string arg = argv[i];
53 | 
54 |       if (arg == "-h" || arg == "--help") {
55 |         whisper_print_usage(argc, argv, params);
56 |         exit(0);
57 |       } else if (arg == "-t" || arg == "--threads") { params.service.n_threads = std::stoi(argv[++i]); }
58 |       else if (arg == "-p" || arg == "--processors") { params.service.n_processors = std::stoi(argv[++i]); }
59 |       else if (arg == "--step") { params.audio.step_ms = std::stoi(argv[++i]); }
60 |       else if (arg == "--length") { params.audio.length_ms = std::stoi(argv[++i]); }
61 |       else if (arg == "--keep") { params.audio.keep_ms = std::stoi(argv[++i]); }
62 |       else if (arg == "-c" || arg == "--capture") { params.audio.capture_id = std::stoi(argv[++i]); }
63 |         //else if (arg == "-mt"  || arg == "--max-tokens")    { params.max_tokens    = std::stoi(argv[++i]); }
64 |       else if (arg == "-ac" || arg == "--audio-ctx") { params.audio.audio_ctx = std::stoi(argv[++i]); }
65 |       else if (arg == "-vth" || arg == "--vad-thold") { params.audio.vad_thold = std::stof(argv[++i]); }
66 |       else if (arg == "-fth" || arg == "--freq-thold") { params.audio.freq_thold = std::stof(argv[++i]); }
67 |       else if (arg == "-su" || arg == "--speed-up") { params.service.speed_up = true; }
68 |       else if (arg == "-tr" || arg == "--translate") { params.service.translate = true; }
69 |       else if (arg == "-nf" || arg == "--no-fallback") { params.service.no_fallback = true; }
70 |         //else if (arg == "-ps"  || arg == "--print-special") { params.print_special = true; }
71 |       else if (arg == "-kc" || arg == "--keep-context") { params.service.no_context = false; }
72 |       else if (arg == "-l" || arg == "--language") { params.service.language = argv[++i]; }
73 |       else if (arg == "-m" || arg == "--model") { params.service.model = argv[++i]; }
74 |         //else if (arg == "-f"   || arg == "--file")          { params.fname_out     = argv[++i]; }
75 |       else if (arg == "-tdrz" || arg == "--tinydiarize") { params.service.tinydiarize = true; }
76 |         //else if (arg == "-sa"  || arg == "--save-audio")    { params.save_audio    = true; }
77 | 
78 |       else {
79 |         fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
80 |         whisper_print_usage(argc, argv, params);
81 |         exit(0);
82 |       }
83 |     }
84 | 
85 |     return true;
86 |   }
87 | }


--------------------------------------------------------------------------------
/stream/stream_components_audio.cpp:
--------------------------------------------------------------------------------
  1 | #include "stream_components_audio.h"
  2 | 
  3 | using namespace stream_components;
  4 | 
  5 | // -- LocalSDLMicrophone --
  6 | 
  7 | LocalSDLMicrophone::LocalSDLMicrophone(audio_params &params) : params(params),
  8 |                                                                audio(params.length_ms),
  9 |                                                                pcmf32_new(params.n_samples_30s, 0.0f),
 10 |                                                                pcmf32(params.n_samples_30s, 0.0f)
 11 | {
 12 |   fprintf(stderr, "%s: processing %d samples (step = %.1f sec / len = %.1f sec / keep = %.1f sec) ...\n",
 13 |           __func__,
 14 |           params.n_samples_step,
 15 |           float(params.n_samples_step) / WHISPER_SAMPLE_RATE,
 16 |           float(params.n_samples_len) / WHISPER_SAMPLE_RATE,
 17 |           float(params.n_samples_keep) / WHISPER_SAMPLE_RATE);
 18 | 
 19 |   if (params.use_vad)
 20 |   {
 21 |     fprintf(stderr, "%s: using VAD, will transcribe on speech activity\n", __func__);
 22 |   }
 23 | 
 24 |   if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE))
 25 |   {
 26 |     throw std::runtime_error("LocalSDLMicrophone(): audio.init() failed!");
 27 |   }
 28 | 
 29 |   audio.resume();
 30 | 
 31 |   t_last = std::chrono::steady_clock::now();
 32 |   t_start = t_last;
 33 | }
 34 | 
 35 | LocalSDLMicrophone::~LocalSDLMicrophone()
 36 | {
 37 |   audio.pause();
 38 | }
 39 | 
 40 | std::vector<float> &LocalSDLMicrophone::get_next()
 41 | {
 42 |   if (!params.use_vad)
 43 |   {
 44 |     while (true)
 45 |     {
 46 |       audio.get(params.step_ms, pcmf32_new);
 47 | 
 48 |       if ((int)pcmf32_new.size() > 2 * params.n_samples_step)
 49 |       {
 50 |         fprintf(stderr, "\n\n%s: WARNING: cannot process audio fast enough, dropping audio ...\n\n", __func__);
 51 |         audio.clear();
 52 |         continue;
 53 |       }
 54 | 
 55 |       if ((int)pcmf32_new.size() >= params.n_samples_step)
 56 |       {
 57 |         audio.clear();
 58 |         break;
 59 |       }
 60 | 
 61 |       std::this_thread::sleep_for(std::chrono::milliseconds(1));
 62 |     }
 63 | 
 64 |     const int n_samples_new = pcmf32_new.size();
 65 | 
 66 |     // take up to params.length_ms audio from previous iteration
 67 |     const int n_samples_take = std::min((int)pcmf32_old.size(),
 68 |                                         std::max(0, params.n_samples_keep + params.n_samples_len - n_samples_new));
 69 | 
 70 |     // printf("processing: take = %d, new = %d, old = %d\n", n_samples_take, n_samples_new, (int) pcmf32_old.size());
 71 | 
 72 |     pcmf32.resize(n_samples_new + n_samples_take);
 73 | 
 74 |     for (int i = 0; i < n_samples_take; i++)
 75 |     {
 76 |       pcmf32[i] = pcmf32_old[pcmf32_old.size() - n_samples_take + i];
 77 |     }
 78 | 
 79 |     memcpy(pcmf32.data() + n_samples_take, pcmf32_new.data(), n_samples_new * sizeof(float));
 80 | 
 81 |     pcmf32_old = pcmf32;
 82 |   }
 83 |   else
 84 |   {
 85 |     while (true)
 86 |     {
 87 |       const auto t_now = std::chrono::steady_clock::now();
 88 |       const auto t_diff = std::chrono::duration_cast<std::chrono::milliseconds>(t_now - t_last).count();
 89 | 
 90 |       if (t_diff < 2000)
 91 |       {
 92 |         std::this_thread::sleep_for(std::chrono::milliseconds(100));
 93 | 
 94 |         continue;
 95 |       }
 96 | 
 97 |       audio.get(2000, pcmf32_new);
 98 | 
 99 |       if (::vad_simple(pcmf32_new, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, false))
100 |       {
101 |         audio.get(params.length_ms, pcmf32);
102 |         t_last = t_now;
103 |         // done!
104 |         break;
105 |       }
106 |       else
107 |       {
108 |         std::this_thread::sleep_for(std::chrono::milliseconds(100));
109 |         continue;
110 |       }
111 |     }
112 |   }
113 | 
114 |   return pcmf32;
115 | }


--------------------------------------------------------------------------------
/stream/stream_components_audio.h:
--------------------------------------------------------------------------------
 1 | #ifndef WHISPER_STREAM_COMPONENTS_AUDIO_H
 2 | #define WHISPER_STREAM_COMPONENTS_AUDIO_H
 3 | 
 4 | #include <vector>
 5 | #include "../common/common-sdl.h"
 6 | #include "../common/common.h"
 7 | #include "stream_components_params.h"
 8 | 
 9 | 
10 | namespace stream_components {
11 | 
12 | /**
13 |  * Encapsulates audio capture and processing.
14 |  * Represents an SDL audio device
15 |  */
16 |   class LocalSDLMicrophone {
17 |   public:
18 |     LocalSDLMicrophone(audio_params &params);
19 | 
20 |     ~LocalSDLMicrophone();
21 | 
22 |     std::vector<float> &get_next();
23 | 
24 |   protected:
25 |     audio_params params;
26 | 
27 |     audio_async audio;
28 |     std::chrono::steady_clock::time_point t_last;
29 |     std::chrono::steady_clock::time_point t_start;
30 | 
31 |     std::vector<float> pcmf32_new;
32 |     std::vector<float> pcmf32_old;
33 |     std::vector<float> pcmf32;
34 |   };
35 | 
36 | } // namespace stream_components
37 | 
38 | #endif // WHISPER_STREAM_COMPONENTS_AUDIO_H

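The header above notes that `LocalSDLMicrophone` encapsulates SDL audio capture. Below is a minimal capture-loop sketch (not part of the repository): it assumes the `stream/` headers are on the include path and that an SDL capture device is available, and it simply prints how many samples each call to `get_next()` returns.

```cpp
#include <cstdio>
#include <vector>
#include "stream_components_params.h"
#include "stream_components_audio.h"

int main() {
  stream_components::audio_params aparams;  // defaults: 3 s step, 10 s window, VAD enabled
  aparams.initialize();                     // derive the n_samples_* fields from the millisecond values

  // Opens the default SDL capture device (capture_id = -1) and starts recording.
  stream_components::LocalSDLMicrophone mic(aparams);

  for (int i = 0; i < 3; ++i) {
    // Blocks until a chunk of audio (or, with use_vad, speech activity) is available.
    const std::vector<float> &pcm = mic.get_next();
    std::printf("chunk %d: %zu samples\n", i, pcm.size());
  }
  return 0;
}
```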

--------------------------------------------------------------------------------
/stream/stream_components_output.cpp:
--------------------------------------------------------------------------------
  1 | #include "stream_components_output.h"
  2 | #include "../common/utils.h"
  3 | using namespace stream_components;
  4 | 
  5 | 
  6 | 
  7 | char *escape_double_quotes_and_backslashes(const char *str) {
  8 |   if (str == NULL) {
  9 |     return NULL;
 10 |   }
 11 | 
 12 |   size_t escaped_length = strlen(str) + 1;
 13 | 
 14 |   for (size_t i = 0; str[i] != '\0'; i++) {
 15 |     if (str[i] == '"' || str[i] == '\\') {
 16 |       escaped_length++;
 17 |     }
 18 |   }
 19 | 
 20 |   char *escaped = (char *) calloc(escaped_length, 1); // pre-zeroed
 21 |   if (escaped == NULL) {
 22 |     return NULL;
 23 |   }
 24 | 
 25 |   size_t pos = 0;
 26 |   for (size_t i = 0; str[i] != '\0'; i++) {
 27 |     if (str[i] == '"' || str[i] == '\\') {
 28 |       escaped[pos++] = '\\';
 29 |     }
 30 |     escaped[pos++] = str[i];
 31 |   }
 32 | 
 33 |   // no need to set zero due to calloc() being used prior
 34 | 
 35 |   return escaped;
 36 | }
 37 | 
 38 | 
 39 | // -- WhisperEncoderJSON --
 40 | 
 41 | void WhisperEncoderJSON::start_arr(const char *name) {
 42 |   doindent();
 43 |   fout << "\"" << name << "\": [\n";
 44 |   indent++;
 45 | }
 46 | 
 47 | void WhisperEncoderJSON::end_arr(bool end) {
 48 |   indent--;
 49 |   doindent();
 50 |   fout << (end ? "]\n" : "},\n");
 51 | }
 52 | 
 53 | void WhisperEncoderJSON::start_obj(const char *name) {
 54 |   doindent();
 55 |   if (name) {
 56 |     fout << "\"" << name << "\": {\n";
 57 |   } else {
 58 |     fout << "{\n";
 59 |   }
 60 |   indent++;
 61 | }
 62 | 
 63 | void WhisperEncoderJSON::end_obj(bool end) {
 64 |   indent--;
 65 |   doindent();
 66 |   fout << (end ? "}\n" : "},\n");
 67 | }
 68 | 
 69 | void WhisperEncoderJSON::start_value(const char *name) {
 70 |   doindent();
 71 |   fout << "\"" << name << "\": ";
 72 | }
 73 | 
 74 | void WhisperEncoderJSON::value_s(const char *name, const char *val, bool end) {
 75 |   start_value(name);
 76 |   char *val_escaped = escape_double_quotes_and_backslashes(val);
 77 |   fout << "\"" << val_escaped << (end ? "\"\n" : "\",\n");
 78 |   free(val_escaped);
 79 | }
 80 | 
 81 | void WhisperEncoderJSON::end_value(bool end) {
 82 |   fout << (end ? "\n" : ",\n");
 83 | }
 84 | 
 85 | void WhisperEncoderJSON::value_i(const char *name, const int64_t val, bool end) {
 86 |   start_value(name);
 87 |   fout << val;
 88 |   end_value(end);
 89 | }
 90 | 
 91 | void WhisperEncoderJSON::value_b(const char *name, const bool val, bool end) {
 92 |   start_value(name);
 93 |   fout << (val ? "true" : "false");
 94 |   end_value(end);
 95 | }
 96 | 
 97 | void WhisperEncoderJSON::value_f(const char *name, const float val, bool end) {
 98 |   start_value(name);
 99 |   fout << val;
100 |   end_value(end);
101 | }
102 | 
103 | void WhisperEncoderJSON::doindent() {
104 |   for (int i = 0; i < indent; i++) fout << "\t";
105 | }
106 | 
107 | 
108 | // -- WhisperStreamOutput --
109 | 
110 | WhisperStreamOutput::WhisperStreamOutput(
111 |   struct whisper_context *ctx,
112 |   const service_params &params) :
113 |   ctx(ctx),
114 |   params(params) {
115 | 
116 | }
117 | 
118 | void WhisperStreamOutput::encode_server(
119 |   WhisperEncoder &encoder,
120 |   const service_params &params,
121 |   struct whisper_context *ctx) {
122 |   encoder.reset();
123 | 
124 |   encoder.start_obj(nullptr);
125 |   encoder.value_s("systeminfo", whisper_print_system_info(), false);
126 |   encoder.start_obj("model");
127 |   encoder.value_s("type", whisper_model_type_readable(ctx), false);
128 |   encoder.value_b("multilingual", whisper_is_multilingual(ctx), false);
129 |   encoder.value_i("vocab", whisper_model_n_vocab(ctx), false);
130 |   encoder.start_obj("audio");
131 |   encoder.value_i("ctx", whisper_model_n_audio_ctx(ctx), false);
132 |   encoder.value_i("state", whisper_model_n_audio_state(ctx), false);
133 |   encoder.value_i("head", whisper_model_n_audio_head(ctx), false);
134 |   encoder.value_i("layer", whisper_model_n_audio_layer(ctx), true);
135 |   encoder.end_obj(false);
136 |   encoder.start_obj("text");
137 |   encoder.value_i("ctx", whisper_model_n_text_ctx(ctx), false);
138 |   encoder.value_i("state", whisper_model_n_text_state(ctx), false);
139 |   encoder.value_i("head", whisper_model_n_text_head(ctx), false);
140 |   encoder.value_i("layer", whisper_model_n_text_layer(ctx), true);
141 |   encoder.end_obj(false);
142 |   encoder.value_i("mels", whisper_model_n_mels(ctx), false);
143 |   encoder.value_i("ftype", whisper_model_ftype(ctx), true);
144 |   encoder.end_obj(false);
145 |   encoder.start_obj("params");
146 |   encoder.value_s("model", params.model.c_str(), false);
147 |   encoder.value_s("language", params.language.c_str(), false);
148 |   encoder.value_b("translate", params.translate, true);
149 |   encoder.end_obj(false);
150 |   encoder.end_obj(true);
151 | }
152 | 
153 | void WhisperStreamOutput::encode_transcription(WhisperEncoder &encoder) const {
154 |   encoder.reset();
155 | 
156 |   encoder.start_obj(nullptr);
157 |   encoder.start_arr("transcription");
158 | 
159 |   const int n_segments = whisper_full_n_segments(ctx);
160 |   for (int i = 0; i < n_segments; ++i) {
161 | 
162 |     const char *text = whisper_full_get_segment_text(ctx, i);
163 | 
164 |     const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
165 |     const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
166 | 
167 |     encoder.start_obj(nullptr);
168 | 
169 |     // These don't seem to be useful...
170 |     encoder.start_obj("timestamps");
171 |     encoder.value_s("from", to_timestamp(t0, true).c_str(), false);
172 |     encoder.value_s("to", to_timestamp(t1, true).c_str(), true);
173 |     encoder.end_obj(false);
174 |     encoder.start_obj("offsets");
175 |     encoder.value_i("from", t0 * 10, false);
176 |     encoder.value_i("to", t1 * 10, true);
177 |     encoder.end_obj(false);
178 | 
179 |     for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
180 | 
181 |       whisper_token_data token = whisper_full_get_token_data(ctx, i, j);
182 | 
183 |       const char *text = whisper_full_get_token_text(ctx, i, j);
184 |       const float p = whisper_full_get_token_p(ctx, i, j);
185 |       encoder.start_obj("token");
186 |       encoder.value_s("text", text, false);
187 |       encoder.value_i("id", token.id, false);
188 |       encoder.value_f("confidence", p, false);
189 |       encoder.value_i("t0", token.t0, false);
190 |       encoder.value_i("t1", token.t1, true);
191 |       encoder.end_obj(false);
192 |     }
193 |     encoder.value_s("text", text, !params.diarize && !params.tinydiarize);
194 | 
195 |     // if (params.diarize && pcmf32s.size() == 2) {
196 |     //     encoder.value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true);
197 |     // }
198 | 
199 |     if (params.tinydiarize) {
200 |       encoder.value_b("speaker_turn_next", whisper_full_get_segment_speaker_turn_next(ctx, i), true);
201 |     }
202 |     encoder.end_obj(i == (n_segments - 1));
203 |   }
204 | 
205 |   encoder.end_arr(true);
206 |   encoder.end_obj(true);
207 | }
208 | 
209 | void WhisperStreamOutput::to_json(std::ostream &os, const service_params &params, struct whisper_context *ctx) {
210 |   WhisperEncoderJSON encoder(os);
211 |   WhisperStreamOutput::encode_server(encoder, params, ctx);
212 | }
213 | 
214 | void WhisperStreamOutput::transcription_to_json(std::ostream &os) const {
215 |   WhisperEncoderJSON encoder(os);
216 |   encode_transcription(encoder);
217 | }


--------------------------------------------------------------------------------
/stream/stream_components_output.h:
--------------------------------------------------------------------------------
 1 | #ifndef WHISPER_STREAM_COMPONENTS_OUTPUT_H
 2 | #define WHISPER_STREAM_COMPONENTS_OUTPUT_H
 3 | 
 4 | #include <memory>
 5 | #include <ostream>
 6 | #include "whisper.h"
 7 | #include "stream_components_params.h"
 8 | 
 9 | /**
10 |  * Classes that support componentization of the service.
11 |  * These classes support encoding the service state and the transcription
12 |  * into JSON (and in the future, into other formats such as msgpack,
13 |  * flexbuffers, etc.)
14 |  */
15 | 
16 | namespace stream_components {
17 | 
18 |   class WhisperEncoder {
19 |   public:
20 |     virtual void reset() = 0;
21 | 
22 |     virtual void start_arr(const char *name) = 0;
23 | 
24 |     virtual void end_arr(bool end) = 0;
25 | 
26 |     virtual void start_obj(const char *name) = 0;
27 | 
28 |     virtual void end_obj(bool end) = 0;
29 | 
30 |     virtual void start_value(const char *name) = 0;
31 | 
32 |     virtual void value_s(const char *name, const char *val, bool end) = 0;
33 | 
34 |     virtual void end_value(bool end) = 0;
35 | 
36 |     virtual void value_i(const char *name, const int64_t val, bool end) = 0;
37 | 
38 |     virtual void value_b(const char *name, const bool val, bool end) = 0;
39 | 
40 |     virtual void value_f(const char *name, const float val, bool end) = 0;
41 |   };
42 | 
43 |   class WhisperEncoderJSON : public WhisperEncoder {
44 |   public:
45 |     WhisperEncoderJSON(std::ostream &os) : fout(os) {}
46 | 
47 |     void reset() override { indent = 0; }
48 | 
49 |     void start_arr(const char *name) override;
50 | 
51 |     void end_arr(bool end) override;
52 | 
53 |     void start_obj(const char *name) override;
54 | 
55 |     void end_obj(bool end) override;
56 | 
57 |     void start_value(const char *name) override;
58 | 
59 |     void value_s(const char *name, const char *val, bool end) override;
60 | 
61 |     void end_value(bool end) override;
62 | 
63 |     void value_i(const char *name, const int64_t val, bool end) override;
64 | 
65 |     void value_b(const char *name, const bool val, bool end) override;
66 | 
67 |     void value_f(const char *name, const float val, bool end) override;
68 | 
69 |   protected:
70 |     std::ostream &fout;
71 |     int indent = 0;
72 | 
73 |     void doindent();
74 |   };
75 | 
76 |   class WhisperStreamOutput {
77 |   public:
78 |     WhisperStreamOutput(
79 |       struct whisper_context *ctx,
80 |       const service_params &params);
81 | 
82 |     static void encode_server(WhisperEncoder &encoder, const service_params &params, struct whisper_context *ctx);
83 | 
84 |     void encode_transcription(WhisperEncoder &encoder) const;
85 | 
86 |     static void to_json(std::ostream &os, const service_params &params, struct whisper_context *ctx);
87 | 
88 |     void transcription_to_json(std::ostream &os) const;
89 | 
90 |   protected:
91 |     struct whisper_context *ctx;
92 |     const service_params params;
93 |   };
94 | 
95 |   using WhisperOutputPtr = std::shared_ptr<WhisperStreamOutput>;
96 | 
97 | } // namespace stream_components
98 | 
99 | #endif // WHISPER_STREAM_COMPONENTS_OUTPUT_H

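As the comment at the top of this header says, these classes encode the service state and a finished transcription into JSON. The sketch below is illustrative only: it assumes a `whisper_context *ctx` that has already been created and run through `whisper_full*()`, and it writes both JSON documents to standard output using the interfaces declared above.

```cpp
#include <iostream>
#include <sstream>
#include "stream_components_output.h"

void dump_json(struct whisper_context *ctx, const stream_components::service_params &sparams) {
  // Static system/model/params summary, as produced by encode_server().
  std::ostringstream server_json;
  stream_components::WhisperStreamOutput::to_json(server_json, sparams, ctx);

  // Transcription of whatever the last whisper_full*() call left in ctx.
  stream_components::WhisperStreamOutput output(ctx, sparams);
  std::ostringstream transcript_json;
  output.transcription_to_json(transcript_json);

  std::cout << server_json.str() << '\n' << transcript_json.str() << std::endl;
}
```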

--------------------------------------------------------------------------------
/stream/stream_components_params.h:
--------------------------------------------------------------------------------
 1 | #ifndef WHISPER_STREAM_COMPONENTS_PARAMS_H
 2 | #define WHISPER_STREAM_COMPONENTS_PARAMS_H
 3 | 
 4 | #include <string>
 5 | #include <thread>
 6 | #include "whisper.h"
 7 | namespace stream_components {
 8 | 
 9 |   struct audio_params {
10 |     int32_t step_ms = 3000;
11 |     int32_t length_ms = 10000;
12 |     int32_t keep_ms = 200;
13 | 
14 |     int32_t capture_id = -1;
15 |     int32_t audio_ctx = 0;
16 | 
17 |     int32_t n_samples_step = 0;
18 |     int32_t n_samples_keep = 0;
19 |     int32_t n_samples_len = 0;
20 |     int32_t n_samples_30s = 0;
21 |     bool use_vad = true;
22 | 
23 |     float vad_thold = 0.6f;
24 |     float freq_thold = 100.0f;
25 | 
26 |     void initialize() {
27 |       keep_ms = std::min(keep_ms, step_ms);
28 |       length_ms = std::max(length_ms, step_ms);
29 | 
30 |       n_samples_step = int32_t(1e-3 * step_ms * WHISPER_SAMPLE_RATE);
31 |       n_samples_keep = int32_t(1e-3 * keep_ms * WHISPER_SAMPLE_RATE);
32 |       n_samples_len = int32_t(1e-3 * length_ms * WHISPER_SAMPLE_RATE);
33 |       n_samples_30s = int32_t(1e-3 * 30000.0 * WHISPER_SAMPLE_RATE);
34 | //      use_vad = n_samples_step <= 0;
35 |     }
36 |   };
37 | 
38 |   struct service_params {
39 |     int32_t n_threads = std::min(8, (int32_t) std::thread::hardware_concurrency());
40 |     int32_t n_processors = 1;
41 |     bool speed_up = false;
42 |     bool translate = false;
43 |     bool no_fallback = false;
44 |     bool no_context = true;
45 |     bool no_timestamps = true;
46 |     bool save_audio = false;
47 | 
48 |     bool tinydiarize = false;
49 |     bool diarize = false;
50 |     bool use_gpu = true;
51 | 
52 |     std::string language = "en";
53 |     std::string model = "../models/ggml-base.en.bin";
54 | 
55 |     void initialize() {}
56 |   };
57 | } // namespace stream_components
58 | 
59 | 
60 | #endif // WHISPER_STREAM_COMPONENTS_PARAMS_H

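For reference, `audio_params::initialize()` converts the millisecond settings into sample counts at whisper.cpp's 16 kHz sample rate (`WHISPER_SAMPLE_RATE`). A small check with the defaults above, assuming the `stream/` directory is on the include path:

```cpp
#include <cassert>
#include "stream_components_params.h"

int main() {
  stream_components::audio_params ap;  // defaults: step 3000 ms, length 10000 ms, keep 200 ms
  ap.initialize();
  assert(ap.n_samples_step ==  48000);  // 3 s   * 16000 Hz
  assert(ap.n_samples_keep ==   3200);  // 0.2 s * 16000 Hz
  assert(ap.n_samples_len  == 160000);  // 10 s  * 16000 Hz
  assert(ap.n_samples_30s  == 480000);  // 30 s  * 16000 Hz
  return 0;
}
```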

--------------------------------------------------------------------------------
/stream/stream_components_service.cpp:
--------------------------------------------------------------------------------
  1 | #include "stream_components_service.h"
  2 | 
  3 | using namespace stream_components;
  4 | 
  5 | // -- WhisperService --
  6 | 
  7 | WhisperService::WhisperService(const struct service_params &sparams,
  8 |                                const struct audio_params &aparams,
  9 |                                const struct whisper_context_params &cparams)
 10 |     : sparams(sparams),
 11 |       aparams(aparams),
 12 |       ctx(whisper_init_from_file_with_params(sparams.model.c_str(), cparams))
 13 | {
 14 |   // print system information
 15 |   {
 16 |     fprintf(stderr, "\n");
 17 |     fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
 18 |             sparams.n_threads * sparams.n_processors, std::thread::hardware_concurrency(),
 19 |             whisper_print_system_info());
 20 |   }
 21 |   {
 22 |     fprintf(stderr, "\n");
 23 |     if (!whisper_is_multilingual(ctx))
 24 |     {
 25 |       if (sparams.language != "en" || sparams.translate)
 26 |       {
 27 |         this->sparams.language = "en";
 28 |         this->sparams.translate = false;
 29 |         fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n",
 30 |                 __func__);
 31 |       }
 32 |     }
 33 |     fprintf(stderr, "%s: serving with %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
 34 |             __func__,
 35 |             sparams.n_threads,
 36 |             sparams.n_processors,
 37 |             sparams.language.c_str(),
 38 |             sparams.translate ? "translate" : "transcribe",
 39 |             sparams.no_timestamps ? 0 : 1);
 40 | 
 41 |     //     if (!audio_params.use_vad) {
 42 |     //         fprintf(stderr, "%s: n_new_line = %d, no_context = %d\n", __func__, n_new_line, sparams.no_context);
 43 |     //     }
 44 | 
 45 |     fprintf(stderr, "\n");
 46 |   }
 47 | }
 48 | 
 49 | WhisperService::~WhisperService()
 50 | {
 51 |   whisper_print_timings(ctx);
 52 |   whisper_free(ctx);
 53 | }
 54 | 
 55 | bool WhisperService::process(const float *samples, int sample_count)
 56 | {
 57 |   whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
 58 | 
 59 |   wparams.print_progress = false;
 60 |   wparams.print_realtime = false;
 61 |   wparams.print_timestamps = false;
 62 |   wparams.print_special = true;
 63 |   wparams.max_tokens = 0;
 64 |   wparams.token_timestamps = true;
 65 | 
 66 |   wparams.translate = sparams.translate;
 67 |   wparams.single_segment = !aparams.use_vad;
 68 |   wparams.language = sparams.language.c_str();
 69 |   wparams.n_threads = sparams.n_threads;
 70 | 
 71 |   wparams.audio_ctx = aparams.audio_ctx;
 72 |   wparams.speed_up = sparams.speed_up;
 73 | 
 74 |   wparams.tdrz_enable = sparams.tinydiarize; // [TDRZ]
 75 | 
 76 |   // disable temperature fallback
 77 |   // wparams.temperature_inc  = -1.0f;
 78 |   wparams.temperature_inc = sparams.no_fallback ? 0.0f : wparams.temperature_inc;
 79 | 
 80 |   wparams.prompt_tokens = sparams.no_context ? nullptr : prompt_tokens.data();
 81 |   wparams.prompt_n_tokens = sparams.no_context ? 0 : prompt_tokens.size();
 82 | 
 83 |   // *** Run the actual inference!!! ***
 84 |   //  if (whisper_full(ctx, wparams, samples, sample_count) != 0) {
 85 |   //    return false;
 86 |   //  }
 87 |   // whisper_full_parallel
 88 |   if (whisper_full_parallel(ctx, wparams, samples, sample_count, sparams.n_processors) != 0)
 89 |   {
 90 |     // error:ggml_metal_get_buffer: error: buffer is nil
 91 |     return false;
 92 |   }
 93 | 
 94 |   // Not sure whether n_iter and n_new_line should have ever been there...
 95 |   // *** SUSPICIOUS: what happens if we remove them? Are they essential?
 96 |   // if (!use_vad && (n_iter % n_new_line) == 0) {
 97 |   //  if (!audio_params.use_vad) {
 98 |   // printf("\n");
 99 | 
100 |   // keep part of the audio for next iteration to try to mitigate word boundary issues
101 |   // *** I don't know if we need this...
102 |   // pcmf32_old = std::vector<float>(pcmf32.end() - n_samples_keep, pcmf32.end());
103 | 
104 |   // Add tokens of the last full length segment as the prompt
105 |   //    if (!sparams.no_context) {
106 |   //      prompt_tokens.
107 |   //
108 |   //        clear();
109 |   //
110 |   //      const int n_segments = whisper_full_n_segments(ctx);
111 |   //      for (
112 |   //        int i = 0;
113 |   //        i < n_segments;
114 |   //        ++i) {
115 |   //        const int token_count = whisper_full_n_tokens(ctx, i);
116 |   //        for (
117 |   //          int j = 0;
118 |   //          j < token_count;
119 |   //          ++j) {
120 |   //          prompt_tokens.
121 |   //            push_back(whisper_full_get_token_id(ctx, i, j)
122 |   //          );
123 |   //        }
124 |   //      }
125 |   //    }
126 |   //  }
127 |   return true;
128 | }


--------------------------------------------------------------------------------
/stream/stream_components_service.h:
--------------------------------------------------------------------------------
 1 | #ifndef WHISPER_STREAM_COMPONENTS_SERVER_H
 2 | #define WHISPER_STREAM_COMPONENTS_SERVER_H
 3 | 
 4 | #include <memory>
 5 | #include <vector>
 6 | #include "stream_components_params.h"
 7 | #include "stream_components_output.h"
 8 | 
 9 | using std::shared_ptr;
10 | 
11 | namespace stream_components
12 | {
13 | 
14 |   /**
15 |    * Encapsulates the Whisper service.
16 |    */
17 |   class WhisperService
18 |   {
19 |   public:
20 |     WhisperService(
21 |         const struct service_params &sparams,
22 |         const struct audio_params &aparams,
23 |         const struct whisper_context_params &cparams);
24 | 
25 |     ~WhisperService();
26 | 
27 |     bool process(const float *samples, int size);
28 | 
29 |     service_params sparams;
30 |     audio_params aparams;
31 | 
32 |     struct whisper_context *ctx;
33 | 
34 |   protected:
35 |     std::vector<whisper_token> prompt_tokens;
36 |   };
37 | 
38 | } // namespace stream_components
39 | 
40 | #endif // WHISPER_STREAM_COMPONENTS_SERVER_H
41 | 

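Putting the pieces together, a stream-style transcriber built from these components would look roughly like the sketch below. The real wiring lives in `examples/stream_local.cpp` and the two server executables, which are not reproduced in this section; the paths and defaults here are assumptions taken from the headers above.

```cpp
#include <iostream>
#include <vector>
#include "stream_components_params.h"
#include "stream_components_audio.h"
#include "stream_components_service.h"
#include "stream_components_output.h"

int main() {
  stream_components::service_params sparams;  // model defaults to ../models/ggml-base.en.bin
  stream_components::audio_params aparams;
  sparams.initialize();
  aparams.initialize();

  whisper_context_params cparams = whisper_context_default_params();
  cparams.use_gpu = sparams.use_gpu;

  stream_components::WhisperService service(sparams, aparams, cparams);
  stream_components::LocalSDLMicrophone mic(aparams);

  while (true) {
    const std::vector<float> &pcm = mic.get_next();
    if (!service.process(pcm.data(), (int) pcm.size())) {
      continue;  // inference failed for this chunk; try the next one
    }
    stream_components::WhisperStreamOutput output(service.ctx, sparams);
    output.transcription_to_json(std::cout);  // one JSON document per processed chunk
  }
}
```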

--------------------------------------------------------------------------------
/thirdparty/CMakeLists.txt:
--------------------------------------------------------------------------------
 1 | set(CMAKE_TOOLCHAIN_FILE "${CMAKE_CURRENT_SOURCE_DIR}/vcpkg/scripts/buildsystems/vcpkg.cmake" CACHE STRING "vcpkg toolchain file")
 2 | include("$CACHE{CMAKE_TOOLCHAIN_FILE}")
 3 | 
 4 | # find the uWebSockets header file path
 5 | find_path(UWEBSOCKETS_INCLUDE_DIRS "uwebsockets/App.h")
 6 | # find zlib
 7 | find_package(ZLIB REQUIRED GLOBAL)
 8 | 
 9 | # find libuv
10 | find_package(libuv CONFIG REQUIRED GLOBAL)
11 | # find uSockets
12 | find_library(USOCKETS_LIBRARY uSockets)
13 | 
14 | # find SDL2 library
15 | find_package(SDL2 CONFIG REQUIRED GLOBAL)
16 | 
17 | find_package(SampleRate CONFIG REQUIRED GLOBAL)
18 | find_package(FFMPEG REQUIRED GLOBAL)
19 | # find SpeexDSP library
20 | find_library(SPEEXDSP_LIBRARY NAMES speexdsp)
21 | # find SPEEXDSP header file
22 | find_path(SPEEXDSP_INCLUDE_DIRS "speex/speex_preprocess.h")
23 | 
24 | # httplib
25 | find_package(httplib CONFIG REQUIRED GLOBAL)
26 | 
27 | # nlohmann json
28 | find_package(nlohmann_json CONFIG REQUIRED GLOBAL)
29 | 
30 | # whisper.cpp
31 | set(BUILD_SHARED_LIBS OFF)
32 | add_subdirectory(whisper.cpp)
33 | 
34 | # dr_lib
35 | add_library(dr_lib_header-only INTERFACE)
36 | target_include_directories(dr_lib_header-only INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/dr_libs)


--------------------------------------------------------------------------------
/vcpkg.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "name": "whisper-cpp-server",
 3 |   "version-string": "1.0.0",
 4 |   "builtin-baseline": "0c20b2a97c390e106150837042d921b0939e7ecb",
 5 |   "dependencies": [
 6 |     {
 7 |       "name": "sdl2",
 8 |       "version>=": "2.28.4#1",
 9 |       "$comment": "    find_package(SDL2 CONFIG REQUIRED)\n\n    target_link_libraries(main\n\n        PRIVATE\n\n        $<TARGET_NAME_IF_EXISTS:SDL2::SDL2main>\n\n        $<IF:$<TARGET_EXISTS:SDL2::SDL2>,SDL2::SDL2,SDL2::SDL2-static>\n\n    )\n"
10 |     },
11 |     {
12 |       "name": "uwebsockets",
13 |       "version>=": "20.47.0"
14 |     },
15 |     {
16 |       "name": "libsamplerate",
17 |       "version>=": "0.2.2#1",
18 |       "$comment": "  # this is heuristically generated, and may not be correct\n\n  find_package(SampleRate CONFIG REQUIRED)\n\n  target_link_libraries(main PRIVATE SampleRate::samplerate)\n"
19 |     },
20 |     {
21 |       "name": "ffmpeg",
22 |       "version>=": "6.1"
23 |     },
24 |     {
25 |       "name": "speexdsp",
26 |       "version>=": "1.2.1#1"
27 |     },
28 |     {
29 |       "name": "cpp-httplib",
30 |       "version>=": "0.14.1"
31 |     },
32 |     {
33 |       "name": "nlohmann-json",
34 |       "version>=": "3.11.2"
35 |     },
36 |     {
37 |       "name": "libuv",
38 |       "version>=": "1.46.0"
39 |     }
40 |   ]
41 | }


--------------------------------------------------------------------------------
/web/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/litongjava/whisper-cpp-server/509f5aca1b6c35257d5f671e92385c892ddcf9ac/web/favicon.ico


--------------------------------------------------------------------------------
/web/paddle_web_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/litongjava/whisper-cpp-server/509f5aca1b6c35257d5f671e92385c892ddcf9ac/web/paddle_web_demo.png


--------------------------------------------------------------------------------
/web/readme.md:
--------------------------------------------------------------------------------
 1 | # paddlespeech serving web demo
 2 | 
 3 | ![screenshot](./paddle_web_demo.png)
 4 | 
 5 | step 1: Start the streaming speech recognition server
 6 | 
 7 | ```
 8 | # start the streaming speech recognition service
 9 | ./cmake-build-debug/whisper_server_base_on_uwebsockets -m models/ggml-base.en.bin
10 | ```
11 | 
12 | step 2: Open `index.html` from the `web` directory in Google Chrome
13 | 
14 | step 3: Click "Connect" and verify that the WebSocket connection is established
15 | 
16 | step 4: Click "Start Recording" (allow microphone access when the browser asks)
17 | 
18 | 
19 |  
20 | 


--------------------------------------------------------------------------------
/whisper_http_server_base_httplib.cpp:
--------------------------------------------------------------------------------
  1 | #include "whisper.h"
  2 | #include "httplib.h"
  3 | #include "params/whisper_params.h"
  4 | #include "handler/inference_handler.h"
  5 | #include <cstdio>
  6 | #include <exception>
  7 | #include <fstream>
  8 | #include <mutex>
  9 | #include <string>
 10 | 
 11 | #if defined(_MSC_VER)
 12 | #pragma warning(disable: 4244 4267) // possible loss of data
 13 | #endif
 14 | 
 15 | using namespace httplib;
 16 | 
 17 | bool is_file_exist(const char *fileName) {
 18 |   std::ifstream infile(fileName);
 19 |   return infile.good();
 20 | }
 21 | 
 22 | int main(int argc, char **argv) {
 23 |   whisper_params params;
 24 |   server_params sparams;
 25 | 
 26 |   std::mutex whisper_mutex;
 27 | 
 28 |   if (whisper_params_parse(argc, argv, params, sparams) == false) {
 29 |     whisper_print_usage(argc, argv, params, sparams);
 30 |     return 1;
 31 |   }
 32 | 
 33 |   if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
 34 |     fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
 35 |     whisper_print_usage(argc, argv, params, sparams);
 36 |     exit(0);
 37 |   }
 38 | 
 39 |   if (params.diarize && params.tinydiarize) {
 40 |     fprintf(stderr, "error: cannot use both --diarize and --tinydiarize\n");
 41 |     whisper_print_usage(argc, argv, params, sparams);
 42 |     exit(0);
 43 |   }
 44 | 
 45 |   // whisper init
 46 |   struct whisper_context_params cparams = whisper_context_default_params();
 47 |   cparams.use_gpu = params.use_gpu;
 48 |   struct whisper_context *ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
 49 | 
 50 |   if (ctx == nullptr) {
 51 |     fprintf(stderr, "error: failed to initialize whisper context\n");
 52 |     return 3;
 53 |   }
 54 | 
 55 |   // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
 56 |   whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
 57 | 
 58 |   Server svr;
 59 | 
 60 |   std::string const default_content = "hello";
 61 | 
 62 |   // this is only called if no index.html is found in the public --path
 63 |   svr.Get("/", [&default_content](const Request &, Response &res) {
 64 |     res.set_content(default_content, "text/html");
 65 |     return false;
 66 |   });
 67 | 
 68 |   svr.Post("/inference", [&](const httplib::Request &req, httplib::Response &res) {
 69 |     handleInference(req, res, whisper_mutex, params, ctx, argv[0]);
 70 |   });
 71 |   svr.Post("/events", [&](const httplib::Request &req, httplib::Response &res) {
 72 |     handle_events(req, res, whisper_mutex, params, ctx, argv[0]);
 73 |   });
 74 |   svr.Post("/load", [&](const Request &req, Response &res) {
 75 |     whisper_mutex.lock();
 76 |     if (!req.has_file("model")) {
 77 |       fprintf(stderr, "error: no 'model' field in the request\n");
 78 |       const std::string error_resp = "{\"error\":\"no 'model' field in the request\"}";
 79 |       res.set_content(error_resp, "application/json");
 80 |       whisper_mutex.unlock();
 81 |       return;
 82 |     }
 83 |     std::string model = req.get_file_value("model").content;
 84 |     if (!is_file_exist(model.c_str())) {
 85 |       fprintf(stderr, "error: 'model': %s not found!\n", model.c_str());
 86 |       const std::string error_resp = "{\"error\":\"model not found!\"}";
 87 |       res.set_content(error_resp, "application/json");
 88 |       whisper_mutex.unlock();
 89 |       return;
 90 |     }
 91 | 
 92 |     // clean up
 93 |     whisper_free(ctx);
 94 | 
 95 |     // whisper init
 96 | //    ctx = whisper_init_from_file(model.c_str());
 97 |     struct whisper_context_params cparams = whisper_context_default_params();
 98 |     cparams.use_gpu = params.use_gpu;
 99 |     ctx = whisper_init_from_file_with_params(model.c_str(), cparams);  // load the model validated above
100 | 
101 |     // TODO perhaps load prior model here instead of exit
102 |     if (ctx == nullptr) {
103 |       fprintf(stderr, "error: model init failed, no model loaded, must exit\n");
104 |       exit(1);
105 |     }
106 | 
107 |     // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
108 |     whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
109 | 
110 |     const std::string success = "Load was successful!";
111 |     res.set_content(success, "text/plain");
112 | 
113 |     // check if the model is in the file system
114 |     whisper_mutex.unlock();
115 |   });
116 | 
117 |   svr.set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {
118 |     const char fmt[] = "500 Internal Server Error\n%s";
119 |     char buf[BUFSIZ];
120 |     try {
121 |       std::rethrow_exception(std::move(ep));
122 |     } catch (std::exception &e) {
123 |       snprintf(buf, sizeof(buf), fmt, e.what());
124 |     } catch (...) {
125 |       snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
126 |     }
127 |     res.set_content(buf, "text/plain");
128 |     res.status = 500;
129 |   });
130 | 
131 |   svr.set_error_handler([](const Request &, Response &res) {
132 |     if (res.status == 400) {
133 |       res.set_content("Invalid request", "text/plain");
134 |     } else if (res.status != 500) {
135 |       res.set_content("File Not Found", "text/plain");
136 |       res.status = 404;
137 |     }
138 |   });
139 | 
140 |   // set timeouts and change hostname and port
141 |   svr.set_read_timeout(sparams.read_timeout);
142 |   svr.set_write_timeout(sparams.write_timeout);
143 | 
144 |   if (!svr.bind_to_port(sparams.hostname, sparams.port)) {
145 |     fprintf(stderr, "\ncouldn't bind to service socket: hostname=%s port=%d\n\n",
146 |             sparams.hostname.c_str(), sparams.port);
147 |     return 1;
148 |   }
149 | 
150 |   // Set the base directory for serving static files
151 |   svr.set_base_dir(sparams.public_path);
152 | 
153 |   // to make it ctrl+clickable:
154 |   printf("\nwhisper service listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
155 | 
156 |   if (!svr.listen_after_bind()) {
157 |     return 1;
158 |   }
159 | 
160 |   whisper_print_timings(ctx);
161 |   whisper_free(ctx);
162 | 
163 |   return 0;
164 | }

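For completeness, here is a hypothetical client for the `/inference` endpoint registered above, written against the same cpp-httplib dependency. The multipart field names (`file`, `response_format`), hostname, and port are assumptions; the actual field names are defined in `handler/inference_handler.cpp`, which is not reproduced in this section.

```cpp
#include <fstream>
#include <iostream>
#include <sstream>
#include "httplib.h"

int main() {
  // Read a sample WAV file into memory (path assumed).
  std::ifstream wav("samples/jfk.wav", std::ios::binary);
  std::stringstream buf;
  buf << wav.rdbuf();

  httplib::Client cli("127.0.0.1", 8080);  // hostname/port assumed; see server_params
  httplib::MultipartFormDataItems items = {
    {"file", buf.str(), "jfk.wav", "audio/wav"},  // field name is an assumption
    {"response_format", "json", "", ""},          // ditto
  };

  if (auto res = cli.Post("/inference", items)) {
    std::cout << res->status << "\n" << res->body << std::endl;
  } else {
    std::cerr << "request failed: " << httplib::to_string(res.error()) << std::endl;
  }
  return 0;
}
```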

--------------------------------------------------------------------------------