├── .github └── workflows │ └── pr.yml ├── .gitignore ├── .gitmodules ├── BUILDING.md ├── CHANGELOG.md ├── Cargo.toml ├── LICENSE ├── README.md ├── build.rs ├── examples ├── audio_transcription.rs ├── basic_use.rs └── full_usage │ ├── 2830-3980-0043.wav │ ├── Cargo.toml │ └── src │ └── main.rs ├── src ├── common_logging.rs ├── error.rs ├── ggml_logging_hook.rs ├── lib.rs ├── standalone.rs ├── utilities.rs ├── whisper_ctx.rs ├── whisper_ctx_wrapper.rs ├── whisper_grammar.rs ├── whisper_logging_hook.rs ├── whisper_params.rs └── whisper_state.rs └── sys ├── Cargo.toml ├── build.rs ├── src ├── bindings.rs └── lib.rs └── wrapper.h /.github/workflows/pr.yml: -------------------------------------------------------------------------------- 1 | name: Check code 2 | on: 3 | push: 4 | pull_request: 5 | workflow_dispatch: 6 | 7 | jobs: 8 | rustfmt: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Check out code into the proper directory 12 | uses: actions/checkout@v3 13 | with: 14 | submodules: 'recursive' 15 | 16 | - name: Cache rust 17 | uses: Swatinem/rust-cache@v2 18 | 19 | - name: Install rust 20 | uses: dtolnay/rust-toolchain@master 21 | with: 22 | toolchain: stable 23 | components: rustfmt 24 | 25 | - name: Check formatting 26 | run: cargo fmt --check 27 | 28 | 29 | clippy: 30 | strategy: 31 | fail-fast: false 32 | matrix: 33 | os: [ ubuntu-latest, windows-latest, macos-latest ] 34 | rust-version: [ stable, nightly ] 35 | runs-on: ${{ matrix.os }} 36 | steps: 37 | - name: Check out code into the proper directory 38 | uses: actions/checkout@v3 39 | with: 40 | submodules: 'recursive' 41 | 42 | - name: Cache rust 43 | uses: Swatinem/rust-cache@v2 44 | 45 | - name: Install rust 46 | uses: dtolnay/rust-toolchain@master 47 | with: 48 | toolchain: ${{ matrix.rust-version }} 49 | components: clippy 50 | 51 | - name: Check clippy lints 52 | run: cargo clippy 53 | 54 | build: 55 | strategy: 56 | fail-fast: false 57 | matrix: 58 | os: [ ubuntu-latest, windows-latest, 
macos-latest ] 59 | rust-version: [ stable, nightly ] 60 | runs-on: ${{ matrix.os }} 61 | steps: 62 | - name: Check out code into the proper directory 63 | uses: actions/checkout@v3 64 | with: 65 | submodules: 'recursive' 66 | 67 | - name: Cache rust 68 | uses: Swatinem/rust-cache@v2 69 | 70 | - name: Install rust 71 | uses: dtolnay/rust-toolchain@master 72 | with: 73 | toolchain: ${{ matrix.rust-version }} 74 | 75 | - name: Check build 76 | run: cargo build -F log_backend,tracing_backend --verbose --examples 77 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/target 2 | **/Cargo.lock 3 | /.idea 4 | /.vscode 5 | *.bin 6 | *.wav -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "sys/whisper.cpp"] 2 | path = sys/whisper.cpp 3 | url = https://github.com/ggerganov/whisper.cpp 4 | -------------------------------------------------------------------------------- /BUILDING.md: -------------------------------------------------------------------------------- 1 | # Running on Arch Linux 2 | `sudo pacman -Syu llvm clang cmake` 3 | `cargo build` 4 | 5 | # Running on Windows using MSYS2 6 | 7 | The following are instructions for building whisper-rs on Windows using the MSYS2 set of compilers. 8 | 9 | 1. Install MSYS2/MinGW by following [https://code.visualstudio.com/docs/cpp/config-mingw](https://code.visualstudio.com/docs/cpp/config-mingw) 10 | 1. Install g++ and make within the MSYS2 UCRT64 environment 11 | - `pacman -S --needed base-devel mingw-w64-ucrt-x86_64-toolchain` 12 | 2. Add the MSYS2 UCRT64 bin folder to your PATH: `C:\msys64\ucrt64\bin` 13 | 2. Install make by running `pacman -S make` in the MSYS2 UCRT64 shell 14 | 3.
Set Rust to use MSYS2 by running `rustup toolchain install stable-x86_64-pc-windows-gnu` in Windows PowerShell/Cmd 15 | 4. Add a `.cargo/config.toml` file to the project with the following contents: 16 | ``` 17 | [target.x86_64-pc-windows-gnu] 18 | linker = "C:\\msys64\\ucrt64\\bin\\gcc.exe" 19 | ar = "C:\\msys64\\ucrt64\\bin\\ar.exe" 20 | ``` 21 | 5. Run `cargo run` in Windows PowerShell/Cmd 22 | 23 | # Running on Windows using Microsoft Visual Studio C++ 24 | 25 | It has been reported that it is also possible to build whisper-rs using Visual Studio C++. 26 | 27 | Make sure the following are installed and on your PATH: 28 | 29 | - Visual Studio C++ 30 | - cmake 31 | - LLVM (clang) 32 | 33 | ### Instructions (for builds with `cuda` enabled) 34 | 1. Download [CUDA](https://developer.nvidia.com/cuda-downloads?target_os=Windows) 35 | 2. Download [Visual Studio with Desktop C++ and Clang enabled](https://visualstudio.microsoft.com/downloads/) (see the Clang link below for an installer walkthrough) 36 | 3. Download [Clang](https://www.wikihow.com/Install-Clang-on-Windows) 37 | 4. Download [CMake](https://cmake.org/download/) 38 | 5. Run `where.exe clang`, then `setx LIBCLANG_PATH "C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Tools\Llvm\x64\bin"`, adjusting the path to the directory reported by `where.exe clang` 39 | 6. Restart your shell so the new environment variable takes effect! 40 | 7. Run `cargo build` 41 | 42 | # Running on M1 OSX 43 | 44 | To build on an M1 Mac, make sure to add the following to your project's `.cargo/config.toml`: 45 | 46 | ``` 47 | [target.aarch64-apple-darwin] 48 | rustflags = "-lc++ -l framework=Accelerate" 49 | ``` 50 | 51 | See https://github.com/tazz4843/whisper-rs/pull/2 for more information. 52 | 53 | You also need to have CMake installed. You can obtain this using Homebrew: 54 | 55 | ``` 56 | brew install cmake 57 | ``` 58 | 59 | CMake can also be installed from https://cmake.org/download/, but the `cmake` binary must be in your PATH.
60 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Version 0.8.0 (-sys bindings 0.6.1) (2023-06-18) 2 | * Fix CUDA and OpenCL build broken due to missing API headers. 3 | * Use PIC when building whisper.cpp (fixes building a cdylib on x86 Linux) 4 | 5 | # Version 0.8.0 (2023-05-14) 6 | * Update upstream whisper.cpp to v1.4.2 (OpenCL support) 7 | * Add CUDA and OpenCL support to bindings 8 | * No MacOS testers were able to test CoreML support, so it may be broken. Please open an issue if it is. 9 | * Enable CUDA support by enabling the `cuda` feature. 10 | * Enable OpenCL support by enabling the `opencl` feature. 11 | * Add `FullParams::set_detect_language` 12 | 13 | # Version 0.7.0 (2023-05-10) 14 | * Update upstream whisper.cpp to v1.4.0 (integer quantization support, see last point for CUDA support) 15 | * Expose `WhisperState` as a public type, allowing for more control over the state. 16 | * `WhisperContext::create_state` now returns a `WhisperState` instead of `()`. 17 | * All methods that took a key argument in v0.6.0 have been moved to `WhisperState`. 18 | * Generic key argument on `WhisperContext` has been removed. 19 | * Note: CUDA and OpenCL acceleration is supported on the `cuda-and-opencl-support` branch of the git repo, 20 | and will probably be released in v0.8.0. 21 | 22 | # Version 0.6.0 (2023-04-17) 23 | * Update upstream whisper.cpp to v1.3.0 24 | * Fix breaking changes in update, which cascade to users: 25 | * `WhisperContext`s now have a generic type parameter, which is a hashable key for a state map. 26 | This allows for a single context to be reused for multiple different states, saving memory. 27 | * You must create a new state upon creation, even if you are using the context only once, by calling `WhisperContext::create_key`. 
28 | * Each method that operates on a state now takes a key, which is used internally to look up the state. 29 | * This also makes `WhisperContext` an entirely immutable object, meaning it can be safely shared across threads and used concurrently. 30 | * Send feedback on these changes to the PR: https://github.com/tazz4843/whisper-rs/pull/33 31 | 32 | # Version 0.2.0 (2022-10-28) 33 | * Update upstream whisper.cpp to 2c281d190b7ec351b8128ba386d110f100993973. 34 | * Fix breaking changes in update, which cascade to users: 35 | * `DecodeStrategy` has been renamed to `SamplingStrategy` 36 | * `WhisperContext::sample_best`'s signature has changed: `needs_timestamp` has been removed. 37 | * New features 38 | * `WhisperContext::full_n_tokens` 39 | * `WhisperContext::full_get_token_text` 40 | * `WhisperContext::full_get_token_id` 41 | * `WhisperContext::full_get_token_prob` 42 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = ["sys"] 3 | exclude = ["examples/full_usage"] 4 | 5 | [package] 6 | name = "whisper-rs" 7 | version = "0.14.2" 8 | edition = "2021" 9 | description = "Rust bindings for whisper.cpp" 10 | license = "Unlicense" 11 | documentation = "https://docs.rs/whisper-rs" 12 | repository = "https://github.com/tazz4843/whisper-rs" 13 | 14 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 15 | 16 | [dependencies] 17 | whisper-rs-sys = { path = "sys", version = "0.12" } 18 | log = { version = "0.4", optional = true } 19 | tracing = { version = "0.1", optional = true } 20 | 21 | [dev-dependencies] 22 | hound = "3.5.0" 23 | rand = "0.8.4" 24 | 25 | [features] 26 | default = [] 27 | 28 | raw-api = [] 29 | coreml = ["whisper-rs-sys/coreml"] 30 | cuda = ["whisper-rs-sys/cuda", "_gpu"] 31 | hipblas = ["whisper-rs-sys/hipblas", "_gpu"] 32 | openblas =
["whisper-rs-sys/openblas"] 33 | metal = ["whisper-rs-sys/metal", "_gpu"] 34 | vulkan = ["whisper-rs-sys/vulkan", "_gpu"] 35 | openmp = ["whisper-rs-sys/openmp"] 36 | _gpu = [] 37 | test-with-tiny-model = [] 38 | 39 | # Bring logs into Rust via the log crate. *Warning*: not mutually exclusive with tracing_backend, 40 | # will result in duplicate logs if both are enabled and one consumes logs from the other. 41 | log_backend = ["dep:log"] 42 | 43 | # Bring logs into Rust via the tracing crate. *Warning*: not mutually exclusive with log_backend, 44 | # will result in duplicate logs if both are enabled and one consumes logs from the other. 45 | tracing_backend = ["dep:tracing"] 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 23 | 24 | For more information, please refer to -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # whisper-rs 2 | 3 | Rust bindings to [whisper.cpp](https://github.com/ggerganov/whisper.cpp/) 4 | 5 | ## Usage 6 | 7 | ```bash 8 | git clone --recursive https://github.com/tazz4843/whisper-rs.git 9 | 10 | cd whisper-rs 11 | 12 | cargo run --example basic_use 13 | 14 | cargo run --example audio_transcription 15 | ``` 16 | 17 | ```rust 18 | use whisper_rs::{WhisperContext, WhisperContextParameters, FullParams, SamplingStrategy}; 19 | 20 | fn main() { 21 | let path_to_model = std::env::args().nth(1).unwrap(); 22 | 23 | // load a context and model 24 | let ctx = WhisperContext::new_with_params( 25 | path_to_model, 26 | WhisperContextParameters::default() 27 | ).expect("failed to load model"); 28 | 29 | // create a params object 30 | let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 }); 31 | 32 | // assume we have a buffer of audio data 33 | // here we'll make a fake one, floating point samples, 32 bit, 16KHz, mono 34 | let audio_data = vec![0_f32; 16000 * 2]; 35 | 36 | // now we can run the model 37 | let mut state = ctx.create_state().expect("failed to create state"); 38 | state 39 | .full(params, &audio_data[..]) 40 | .expect("failed to run model"); 41 | 42 | // fetch the results 43 | let num_segments = state 44 | .full_n_segments() 45 | .expect("failed to get number of segments"); 46 | for i in 0..num_segments { 47 | let segment = state 48 | .full_get_segment_text(i) 49 | .expect("failed to get segment"); 50 | let start_timestamp = state 51 | .full_get_segment_t0(i) 52 | 
.expect("failed to get segment start timestamp"); 53 | let end_timestamp = state 54 | .full_get_segment_t1(i) 55 | .expect("failed to get segment end timestamp"); 56 | println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment); 57 | } 58 | } 59 | ``` 60 | 61 | See [examples/basic_use.rs](examples/basic_use.rs) for more details. 62 | 63 | Lower-level bindings are exposed if needed, but the above should be enough for most use cases. 64 | See the docs: https://docs.rs/whisper-rs/ for more details. 65 | 66 | ## Feature flags 67 | 68 | All disabled by default unless otherwise specified. 69 | 70 | * `raw-api`: expose whisper-rs-sys without having to pull it in as a dependency. 71 | **NOTE**: enabling this no longer guarantees semver compliance, 72 | as whisper-rs-sys may be upgraded to a breaking version in a patch release of whisper-rs. 73 | * `cuda`: enable CUDA support. Implicitly enables hidden GPU flag at runtime. 74 | * `hipblas`: enable ROCm/hipBLAS support. Only available on Linux. Implicitly enables hidden GPU flag at runtime. 75 | * `openblas`: enable OpenBLAS support. 76 | * `metal`: enable Metal support. Implicitly enables hidden GPU flag at runtime. 77 | * `vulkan`: enable Vulkan support. Implicitly enables hidden GPU flag at runtime. 78 | * `log_backend`: allows hooking into whisper.cpp's log output and sending it to the `log` backend. Requires calling `whisper_rs::install_logging_hooks()` to activate. 79 | * `tracing_backend`: allows hooking into whisper.cpp's log output and sending it to the `tracing` backend. Requires calling `whisper_rs::install_logging_hooks()` to activate. 80 | 81 | ## Building 82 | 83 | See [BUILDING.md](BUILDING.md) for instructions for building whisper-rs on Windows and OSX M1. Linux builds should just 84 | work out of the box. 85 | 86 | ## Troubleshooting 87 | 88 | * Something other than Windows/macOS/Linux isn't working! 89 | * I don't have a way to test these platforms, so I can't really help you. 90 | * If you can get it working, please open a PR with any changes to make it work and build instructions in 91 | BUILDING.md!
92 | * I get a panic during binding generation build! 93 | * You can attempt to fix it yourself, or you can set the `WHISPER_DONT_GENERATE_BINDINGS` environment variable. 94 | This skips attempting to build the bindings whatsoever and copies the existing ones. They may be out of date, 95 | but it's better than nothing. 96 | * `WHISPER_DONT_GENERATE_BINDINGS=1 cargo build` 97 | * If you can fix the issue, please open a PR! 98 | 99 | ## License 100 | 101 | [Unlicense](LICENSE) 102 | 103 | tl;dr: public domain 104 | -------------------------------------------------------------------------------- /build.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | 3 | fn main() { 4 | let whisper_cpp_version = env::var("DEP_WHISPER_WHISPER_CPP_VERSION").unwrap_or_else(|e| { 5 | if env::var("DOCS_RS").is_ok() { 6 | // not sure why but this fails on docs.rs 7 | // return a default string 8 | "0.0.0-fake".to_string() 9 | } else { 10 | panic!("Failed to find upstream whisper.cpp version: your build environment is messed up. {}", e); 11 | } 12 | }); 13 | println!( 14 | "cargo:rustc-env=WHISPER_CPP_VERSION={}", 15 | whisper_cpp_version 16 | ); 17 | } 18 | -------------------------------------------------------------------------------- /examples/audio_transcription.rs: -------------------------------------------------------------------------------- 1 | // This example is not going to build in this folder. 2 | // You need to copy this code into your project and add the dependencies whisper_rs and hound in your cargo.toml 3 | 4 | use hound; 5 | use std::fs::File; 6 | use std::io::Write; 7 | use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters}; 8 | 9 | /// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout. 10 | fn main() -> Result<(), &'static str> { 11 | // Load a context and model. 
12 | let mut context_param = WhisperContextParameters::default(); 13 | 14 | // Enable DTW token level timestamp for known model by using model preset 15 | context_param.dtw_parameters.mode = whisper_rs::DtwMode::ModelPreset { 16 | model_preset: whisper_rs::DtwModelPreset::BaseEn, 17 | }; 18 | 19 | // Enable DTW token level timestamp for unknown model by providing custom aheads 20 | // see details https://github.com/ggerganov/whisper.cpp/pull/1485#discussion_r1519681143 21 | // values corresponds to ggml-base.en.bin, result will be the same as with DtwModelPreset::BaseEn 22 | let custom_aheads = [ 23 | (3, 1), 24 | (4, 2), 25 | (4, 3), 26 | (4, 7), 27 | (5, 1), 28 | (5, 2), 29 | (5, 4), 30 | (5, 6), 31 | ] 32 | .map(|(n_text_layer, n_head)| whisper_rs::DtwAhead { 33 | n_text_layer, 34 | n_head, 35 | }); 36 | context_param.dtw_parameters.mode = whisper_rs::DtwMode::Custom { 37 | aheads: &custom_aheads, 38 | }; 39 | 40 | let ctx = WhisperContext::new_with_params( 41 | "example/path/to/model/whisper.cpp/models/ggml-base.en.bin", 42 | context_param, 43 | ) 44 | .expect("failed to load model"); 45 | // Create a state 46 | let mut state = ctx.create_state().expect("failed to create key"); 47 | 48 | // Create a params object for running the model. 49 | // The number of past samples to consider defaults to 0. 50 | let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 }); 51 | 52 | // Edit params as needed. 53 | // Set the number of threads to use to 1. 54 | params.set_n_threads(1); 55 | // Enable translation. 56 | params.set_translate(true); 57 | // Set the language to translate to to English. 58 | params.set_language(Some("en")); 59 | // Disable anything that prints to stdout. 60 | params.set_print_special(false); 61 | params.set_print_progress(false); 62 | params.set_print_realtime(false); 63 | params.set_print_timestamps(false); 64 | // Enable token level timestamps 65 | params.set_token_timestamps(true); 66 | 67 | // Open the audio file. 
68 | let reader = hound::WavReader::open("audio.wav").expect("failed to open file"); 69 | #[allow(unused_variables)] 70 | let hound::WavSpec { 71 | channels, 72 | sample_rate, 73 | bits_per_sample, 74 | .. 75 | } = reader.spec(); 76 | 77 | // Convert the audio to floating point samples. 78 | let samples: Vec = reader 79 | .into_samples::() 80 | .map(|x| x.expect("Invalid sample")) 81 | .collect(); 82 | let mut audio = vec![0.0f32; samples.len().try_into().unwrap()]; 83 | whisper_rs::convert_integer_to_float_audio(&samples, &mut audio).expect("Conversion error"); 84 | 85 | // Convert audio to 16KHz mono f32 samples, as required by the model. 86 | // These utilities are provided for convenience, but can be replaced with custom conversion logic. 87 | // SIMD variants of these functions are also available on nightly Rust (see the docs). 88 | if channels == 2 { 89 | audio = whisper_rs::convert_stereo_to_mono_audio(&audio).expect("Conversion error"); 90 | } else if channels != 1 { 91 | panic!(">2 channels unsupported"); 92 | } 93 | 94 | if sample_rate != 16000 { 95 | panic!("sample rate must be 16KHz"); 96 | } 97 | 98 | // Run the model. 99 | state.full(params, &audio[..]).expect("failed to run model"); 100 | 101 | // Create a file to write the transcript to. 102 | let mut file = File::create("transcript.txt").expect("failed to create file"); 103 | 104 | // Iterate through the segments of the transcript. 105 | let num_segments = state 106 | .full_n_segments() 107 | .expect("failed to get number of segments"); 108 | for i in 0..num_segments { 109 | // Get the transcribed text and timestamps for the current segment. 
110 | let segment = state 111 | .full_get_segment_text(i) 112 | .expect("failed to get segment"); 113 | let start_timestamp = state 114 | .full_get_segment_t0(i) 115 | .expect("failed to get start timestamp"); 116 | let end_timestamp = state 117 | .full_get_segment_t1(i) 118 | .expect("failed to get end timestamp"); 119 | 120 | let first_token_dtw_ts = if let Ok(token_count) = state.full_n_tokens(i) { 121 | if token_count > 0 { 122 | if let Ok(token_data) = state.full_get_token_data(i, 0) { 123 | token_data.t_dtw 124 | } else { 125 | -1i64 126 | } 127 | } else { 128 | -1i64 129 | } 130 | } else { 131 | -1i64 132 | }; 133 | // Print the segment to stdout. 134 | println!( 135 | "[{} - {} ({})]: {}", 136 | start_timestamp, end_timestamp, first_token_dtw_ts, segment 137 | ); 138 | 139 | // Format the segment information as a string. 140 | let line = format!("[{} - {}]: {}\n", start_timestamp, end_timestamp, segment); 141 | 142 | // Write the segment information to the file. 143 | file.write_all(line.as_bytes()) 144 | .expect("failed to write to file"); 145 | } 146 | Ok(()) 147 | } 148 | -------------------------------------------------------------------------------- /examples/basic_use.rs: -------------------------------------------------------------------------------- 1 | /* 2 | wget https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.bin 3 | wget https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav 4 | cargo run --example basic_use ggml-tiny.bin jfk.wav 5 | */ 6 | 7 | use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters}; 8 | 9 | fn main() { 10 | let model_path = std::env::args() 11 | .nth(1) 12 | .expect("Please specify path to model"); 13 | let wav_path = std::env::args() 14 | .nth(2) 15 | .expect("Please specify path to wav file"); 16 | let language = "en"; 17 | 18 | let samples: Vec = hound::WavReader::open(wav_path) 19 | .unwrap() 20 | .into_samples::() 21 | .map(|x| x.unwrap()) 22 | .collect(); 
23 | 24 | // load a context and model 25 | let ctx = WhisperContext::new_with_params(&model_path, WhisperContextParameters::default()) 26 | .expect("failed to load model"); 27 | 28 | let mut state = ctx.create_state().expect("failed to create state"); 29 | 30 | let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 }); 31 | 32 | // and set the language to translate to to english 33 | params.set_language(Some(&language)); 34 | 35 | // we also explicitly disable anything that prints to stdout 36 | params.set_print_special(false); 37 | params.set_print_progress(false); 38 | params.set_print_realtime(false); 39 | params.set_print_timestamps(false); 40 | 41 | // we must convert to 16KHz mono f32 samples for the model 42 | // some utilities exist for this 43 | // note that you don't need to use these, you can do it yourself or any other way you want 44 | // these are just provided for convenience 45 | // SIMD variants of these functions are also available, but only on nightly Rust: see the docs 46 | let mut inter_samples = vec![Default::default(); samples.len()]; 47 | 48 | whisper_rs::convert_integer_to_float_audio(&samples, &mut inter_samples) 49 | .expect("failed to convert audio data"); 50 | let samples = whisper_rs::convert_stereo_to_mono_audio(&inter_samples) 51 | .expect("failed to convert audio data"); 52 | 53 | // now we can run the model 54 | // note the key we use here is the one we created above 55 | state 56 | .full(params, &samples[..]) 57 | .expect("failed to run model"); 58 | 59 | // fetch the results 60 | let num_segments = state 61 | .full_n_segments() 62 | .expect("failed to get number of segments"); 63 | for i in 0..num_segments { 64 | let segment = state 65 | .full_get_segment_text(i) 66 | .expect("failed to get segment"); 67 | let start_timestamp = state 68 | .full_get_segment_t0(i) 69 | .expect("failed to get segment start timestamp"); 70 | let end_timestamp = state 71 | .full_get_segment_t1(i) 72 | .expect("failed to get segment end 
timestamp"); 73 | println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment); 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /examples/full_usage/2830-3980-0043.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tazz4843/whisper-rs/de30f9c23da52c81b06fa78bab005ab353d74637/examples/full_usage/2830-3980-0043.wav -------------------------------------------------------------------------------- /examples/full_usage/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "full_usage" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | hound = "3" 10 | whisper-rs = { path = "../.." } 11 | -------------------------------------------------------------------------------- /examples/full_usage/src/main.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::uninlined_format_args)] 2 | 3 | use hound::{SampleFormat, WavReader}; 4 | use std::path::Path; 5 | use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters}; 6 | 7 | fn parse_wav_file(path: &Path) -> Vec { 8 | let reader = WavReader::open(path).expect("failed to read file"); 9 | 10 | if reader.spec().channels != 1 { 11 | panic!("expected mono audio file"); 12 | } 13 | if reader.spec().sample_format != SampleFormat::Int { 14 | panic!("expected integer sample format"); 15 | } 16 | if reader.spec().sample_rate != 16000 { 17 | panic!("expected 16KHz sample rate"); 18 | } 19 | if reader.spec().bits_per_sample != 16 { 20 | panic!("expected 16 bits per sample"); 21 | } 22 | 23 | reader 24 | .into_samples::() 25 | .map(|x| x.expect("sample")) 26 | .collect::>() 27 | } 28 | 29 | fn main() { 30 | let arg1 = std::env::args() 31 | .nth(1) 32 | 
.expect("first argument should be path to WAV file"); 33 | let audio_path = Path::new(&arg1); 34 | if !audio_path.exists() { 35 | panic!("audio file doesn't exist"); 36 | } 37 | let arg2 = std::env::args() 38 | .nth(2) 39 | .expect("second argument should be path to Whisper model"); 40 | let whisper_path = Path::new(&arg2); 41 | if !whisper_path.exists() { 42 | panic!("whisper file doesn't exist") 43 | } 44 | 45 | let original_samples = parse_wav_file(audio_path); 46 | let mut samples = vec![0.0f32; original_samples.len()]; 47 | whisper_rs::convert_integer_to_float_audio(&original_samples, &mut samples) 48 | .expect("failed to convert samples"); 49 | 50 | let ctx = WhisperContext::new_with_params( 51 | &whisper_path.to_string_lossy(), 52 | WhisperContextParameters::default(), 53 | ) 54 | .expect("failed to open model"); 55 | let mut state = ctx.create_state().expect("failed to create key"); 56 | let mut params = FullParams::new(SamplingStrategy::default()); 57 | params.set_initial_prompt("experience"); 58 | params.set_progress_callback_safe(|progress| println!("Progress callback: {}%", progress)); 59 | 60 | let st = std::time::Instant::now(); 61 | state 62 | .full(params, &samples) 63 | .expect("failed to convert samples"); 64 | let et = std::time::Instant::now(); 65 | 66 | let num_segments = state 67 | .full_n_segments() 68 | .expect("failed to get number of segments"); 69 | for i in 0..num_segments { 70 | let segment = state 71 | .full_get_segment_text(i) 72 | .expect("failed to get segment"); 73 | let start_timestamp = state 74 | .full_get_segment_t0(i) 75 | .expect("failed to get start timestamp"); 76 | let end_timestamp = state 77 | .full_get_segment_t1(i) 78 | .expect("failed to get end timestamp"); 79 | println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment); 80 | } 81 | println!("took {}ms", (et - st).as_millis()); 82 | } 83 | -------------------------------------------------------------------------------- /src/common_logging.rs: 
-------------------------------------------------------------------------------- 1 | macro_rules! generic_error { 2 | ($($expr:tt)*) => { 3 | #[cfg(feature = "log_backend")] 4 | log::error!($($expr)*); 5 | #[cfg(feature = "tracing_backend")] 6 | tracing::error!($($expr)*); 7 | }; 8 | } 9 | 10 | macro_rules! generic_warn { 11 | ($($expr:tt)*) => { 12 | #[cfg(feature = "log_backend")] 13 | log::warn!($($expr)*); 14 | #[cfg(feature = "tracing_backend")] 15 | tracing::warn!($($expr)*); 16 | } 17 | } 18 | 19 | macro_rules! generic_info { 20 | ($($expr:tt)*) => { 21 | #[cfg(feature = "log_backend")] 22 | log::info!($($expr)*); 23 | #[cfg(feature = "tracing_backend")] 24 | tracing::info!($($expr)*); 25 | } 26 | } 27 | 28 | macro_rules! generic_debug { 29 | ($($expr:tt)*) => { 30 | #[cfg(feature = "log_backend")] 31 | log::debug!($($expr)*); 32 | #[cfg(feature = "tracing_backend")] 33 | tracing::debug!($($expr)*); 34 | } 35 | } 36 | 37 | macro_rules! generic_trace { 38 | ($($expr:tt)*) => { 39 | #[cfg(feature = "log_backend")] 40 | log::trace!($($expr)*); 41 | #[cfg(feature = "tracing_backend")] 42 | tracing::trace!($($expr)*); 43 | } 44 | } 45 | 46 | use whisper_rs_sys::ggml_log_level; 47 | pub(crate) use {generic_debug, generic_error, generic_info, generic_trace, generic_warn}; 48 | 49 | // Unsigned integer type on most platforms is 32 bit, niche platforms that whisper.cpp 50 | // likely doesn't even support would use 16 bit and would still fit 51 | #[cfg_attr(any(not(windows), target_env = "gnu"), repr(u32))] 52 | // Of course Windows thinks it's a special little shit and 53 | // picks a signed integer for an unsigned type 54 | #[cfg_attr(all(windows, not(target_env = "gnu")), repr(i32))] 55 | pub enum GGMLLogLevel { 56 | None = whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_NONE, 57 | Info = whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_INFO, 58 | Warn = whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_WARN, 59 | Error = whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_ERROR, 
60 | Debug = whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_DEBUG, 61 | Cont = whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_CONT, 62 | Unknown(ggml_log_level), 63 | } 64 | impl From for GGMLLogLevel { 65 | fn from(level: ggml_log_level) -> Self { 66 | match level { 67 | whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_NONE => GGMLLogLevel::None, 68 | whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_INFO => GGMLLogLevel::Info, 69 | whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_WARN => GGMLLogLevel::Warn, 70 | whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_ERROR => GGMLLogLevel::Error, 71 | whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_DEBUG => GGMLLogLevel::Debug, 72 | whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_CONT => GGMLLogLevel::Cont, 73 | other => GGMLLogLevel::Unknown(other), 74 | } 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | use std::ffi::{c_int, NulError}; 2 | use std::str::Utf8Error; 3 | 4 | /// If you have not configured a logging trampoline with [crate::whisper_sys_log::install_whisper_log_trampoline] or 5 | /// [crate::whisper_sys_tracing::install_whisper_tracing_trampoline], 6 | /// then `whisper.cpp`'s errors will be output to stderr, 7 | /// so you can check there for more information upon receiving a `WhisperError`. 8 | #[derive(Debug, Copy, Clone)] 9 | pub enum WhisperError { 10 | /// Failed to create a new context. 11 | InitError, 12 | /// User didn't initialize spectrogram 13 | SpectrogramNotInitialized, 14 | /// Encode was not called. 15 | EncodeNotComplete, 16 | /// Decode was not called. 17 | DecodeNotComplete, 18 | /// Failed to calculate the spectrogram for some reason. 19 | UnableToCalculateSpectrogram, 20 | /// Failed to evaluate model. 
21 | UnableToCalculateEvaluation, 22 | /// Failed to run the encoder 23 | FailedToEncode, 24 | /// Failed to run the decoder 25 | FailedToDecode, 26 | /// Invalid number of mel bands. 27 | InvalidMelBands, 28 | /// Invalid thread count 29 | InvalidThreadCount, 30 | /// Invalid UTF-8 detected in a string from Whisper. 31 | InvalidUtf8 { 32 | error_len: Option, 33 | valid_up_to: usize, 34 | }, 35 | /// A null byte was detected in a user-provided string. 36 | NullByteInString { idx: usize }, 37 | /// Whisper returned a null pointer. 38 | NullPointer, 39 | /// Generic whisper error. Varies depending on the function. 40 | GenericError(c_int), 41 | /// Whisper failed to convert the provided text into tokens. 42 | InvalidText, 43 | /// Creating a state pointer failed. Check stderr for more information. 44 | FailedToCreateState, 45 | /// No samples were provided. 46 | NoSamples, 47 | /// Input and output slices were not the same length. 48 | InputOutputLengthMismatch { input_len: usize, output_len: usize }, 49 | /// Input slice was not an even number of samples. 
50 | HalfSampleMissing(usize), 51 | } 52 | 53 | impl From for WhisperError { 54 | fn from(e: Utf8Error) -> Self { 55 | Self::InvalidUtf8 { 56 | error_len: e.error_len(), 57 | valid_up_to: e.valid_up_to(), 58 | } 59 | } 60 | } 61 | 62 | impl From for WhisperError { 63 | fn from(e: NulError) -> Self { 64 | Self::NullByteInString { 65 | idx: e.nul_position(), 66 | } 67 | } 68 | } 69 | 70 | impl std::fmt::Display for WhisperError { 71 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 72 | use WhisperError::*; 73 | match self { 74 | InitError => write!(f, "Failed to create a new whisper context."), 75 | SpectrogramNotInitialized => write!(f, "User didn't initialize spectrogram."), 76 | EncodeNotComplete => write!(f, "Encode was not called."), 77 | DecodeNotComplete => write!(f, "Decode was not called."), 78 | UnableToCalculateSpectrogram => { 79 | write!(f, "Failed to calculate the spectrogram for some reason.") 80 | } 81 | UnableToCalculateEvaluation => write!(f, "Failed to evaluate model."), 82 | FailedToEncode => write!(f, "Failed to run the encoder."), 83 | FailedToDecode => write!(f, "Failed to run the decoder."), 84 | InvalidMelBands => write!(f, "Invalid number of mel bands."), 85 | InvalidThreadCount => write!(f, "Invalid thread count."), 86 | InvalidUtf8 { 87 | valid_up_to, 88 | error_len: Some(len), 89 | } => write!( 90 | f, 91 | "Invalid UTF-8 detected in a string from Whisper. Index: {}, Length: {}.", 92 | valid_up_to, len 93 | ), 94 | InvalidUtf8 { 95 | valid_up_to, 96 | error_len: None, 97 | } => write!( 98 | f, 99 | "Invalid UTF-8 detected in a string from Whisper. Index: {}.", 100 | valid_up_to 101 | ), 102 | NullByteInString { idx } => write!( 103 | f, 104 | "A null byte was detected in a user-provided string. Index: {}", 105 | idx 106 | ), 107 | NullPointer => write!(f, "Whisper returned a null pointer."), 108 | InvalidText => write!( 109 | f, 110 | "Whisper failed to convert the provided text into tokens." 
111 | ), 112 | FailedToCreateState => write!(f, "Creating a state pointer failed."), 113 | GenericError(c_int) => write!( 114 | f, 115 | "Generic whisper error. Varies depending on the function. Error code: {}", 116 | c_int 117 | ), 118 | NoSamples => write!(f, "Input sample buffer was empty."), 119 | InputOutputLengthMismatch { 120 | output_len, 121 | input_len, 122 | } => { 123 | write!( 124 | f, 125 | "Input and output slices were not the same length. Input: {}, Output: {}", 126 | input_len, output_len 127 | ) 128 | } 129 | HalfSampleMissing(size) => { 130 | write!( 131 | f, 132 | "Input slice was not an even number of samples, got {}, expected {}", 133 | size, 134 | size + 1 135 | ) 136 | } 137 | } 138 | } 139 | } 140 | 141 | impl std::error::Error for WhisperError {} 142 | -------------------------------------------------------------------------------- /src/ggml_logging_hook.rs: -------------------------------------------------------------------------------- 1 | use crate::common_logging::{ 2 | generic_debug, generic_error, generic_info, generic_trace, generic_warn, GGMLLogLevel, 3 | }; 4 | use core::ffi::{c_char, c_void}; 5 | use std::borrow::Cow; 6 | use std::ffi::CStr; 7 | use std::sync::Once; 8 | use whisper_rs_sys::ggml_log_level; 9 | 10 | static GGML_LOG_TRAMPOLINE_INSTALL: Once = Once::new(); 11 | pub(crate) fn install_ggml_logging_hook() { 12 | GGML_LOG_TRAMPOLINE_INSTALL.call_once(|| unsafe { 13 | whisper_rs_sys::ggml_log_set(Some(ggml_logging_trampoline), std::ptr::null_mut()) 14 | }); 15 | } 16 | 17 | unsafe extern "C" fn ggml_logging_trampoline( 18 | level: ggml_log_level, 19 | text: *const c_char, 20 | _: *mut c_void, // user_data 21 | ) { 22 | if text.is_null() { 23 | generic_error!("ggml_logging_trampoline: text is nullptr"); return; // CStr::from_ptr below must never receive a null pointer 24 | } 25 | let level = GGMLLogLevel::from(level); 26 | 27 | // SAFETY: we must trust ggml that it will not pass us a string that does not satisfy 28 | // from_ptr's requirements.
29 | let log_str = unsafe { CStr::from_ptr(text) }.to_string_lossy(); 30 | 31 | ggml_logging_trampoline_safe(level, log_str) 32 | } 33 | 34 | // this code essentially compiles down to a noop if neither feature is enabled 35 | #[cfg_attr( 36 | not(any(feature = "log_backend", feature = "tracing_backend")), 37 | allow(unused_variables) 38 | )] 39 | fn ggml_logging_trampoline_safe(level: GGMLLogLevel, text: Cow) { 40 | match level { 41 | GGMLLogLevel::None => { 42 | // no clue what to do here, trace it? 43 | generic_trace!("{}", text.trim()); 44 | } 45 | GGMLLogLevel::Info => { 46 | generic_info!("{}", text.trim()); 47 | } 48 | GGMLLogLevel::Warn => { 49 | generic_warn!("{}", text.trim()); 50 | } 51 | GGMLLogLevel::Error => { 52 | generic_error!("{}", text.trim()); 53 | } 54 | GGMLLogLevel::Debug => { 55 | generic_debug!("{}", text.trim()); 56 | } 57 | GGMLLogLevel::Cont => { 58 | // this means continue previous log 59 | // storing state to do this is a massive pain so it's just a lot easier to not 60 | // plus as far as i can tell it's not actually *used* anywhere 61 | // ggml splits at 128 chars and doesn't actually change the kind of log 62 | // so technically this is unused 63 | generic_trace!("{}", text.trim()); 64 | } 65 | GGMLLogLevel::Unknown(level) => { 66 | generic_warn!( 67 | "ggml_logging_trampoline: unknown log level {}: message: {}", 68 | level, 69 | text.trim() 70 | ); 71 | } 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::uninlined_format_args)] 2 | #![cfg_attr(test, feature(test))] 3 | 4 | mod common_logging; 5 | mod error; 6 | mod ggml_logging_hook; 7 | mod standalone; 8 | mod utilities; 9 | mod whisper_ctx; 10 | mod whisper_ctx_wrapper; 11 | mod whisper_grammar; 12 | mod whisper_logging_hook; 13 | mod whisper_params; 14 | mod whisper_state; 15 | 16 | pub use 
common_logging::GGMLLogLevel; 17 | pub use error::WhisperError; 18 | pub use standalone::*; 19 | pub use utilities::*; 20 | pub use whisper_ctx::DtwMode; 21 | pub use whisper_ctx::DtwModelPreset; 22 | pub use whisper_ctx::DtwParameters; 23 | pub use whisper_ctx::WhisperContextParameters; 24 | use whisper_ctx::WhisperInnerContext; 25 | pub use whisper_ctx_wrapper::WhisperContext; 26 | pub use whisper_grammar::{WhisperGrammarElement, WhisperGrammarElementType}; 27 | pub use whisper_params::{FullParams, SamplingStrategy, SegmentCallbackData}; 28 | #[cfg(feature = "raw-api")] 29 | pub use whisper_rs_sys; 30 | pub use whisper_state::WhisperState; 31 | 32 | pub type WhisperSysContext = whisper_rs_sys::whisper_context; 33 | pub type WhisperSysState = whisper_rs_sys::whisper_state; 34 | 35 | pub type WhisperTokenData = whisper_rs_sys::whisper_token_data; 36 | pub type WhisperToken = whisper_rs_sys::whisper_token; 37 | pub type WhisperNewSegmentCallback = whisper_rs_sys::whisper_new_segment_callback; 38 | pub type WhisperStartEncoderCallback = whisper_rs_sys::whisper_encoder_begin_callback; 39 | pub type WhisperProgressCallback = whisper_rs_sys::whisper_progress_callback; 40 | pub type WhisperLogitsFilterCallback = whisper_rs_sys::whisper_logits_filter_callback; 41 | pub type WhisperAbortCallback = whisper_rs_sys::ggml_abort_callback; 42 | pub type WhisperLogCallback = whisper_rs_sys::ggml_log_callback; 43 | pub type DtwAhead = whisper_rs_sys::whisper_ahead; 44 | 45 | /// The version of whisper.cpp that whisper-rs was linked with. 46 | pub static WHISPER_CPP_VERSION: &str = env!("WHISPER_CPP_VERSION"); 47 | 48 | /// Redirect all whisper.cpp and GGML logs to logging hooks installed by whisper-rs. 49 | /// 50 | /// This will stop most logs from being output to stdout/stderr and will bring them into 51 | /// `log` or `tracing`, if the `log_backend` or `tracing_backend` features, respectively, 52 | /// are enabled. 
If neither is enabled, this will essentially disable logging, as they won't 53 | /// be output anywhere. 54 | /// 55 | /// Note whisper.cpp and GGML do not reliably follow Rust logging conventions. 56 | /// Use your logging crate's configuration to control how these logs will be output. 57 | /// whisper-rs does not currently output any logs, but this may change in the future. 58 | /// You should configure by module path and use `whisper_rs::ggml_logging_hook`, 59 | /// and/or `whisper_rs::whisper_logging_hook`, to avoid possibly ignoring useful 60 | /// `whisper-rs` logs in the future. 61 | /// 62 | /// Safe to call multiple times. Only has an effect the first time. 63 | /// (note this means installing your own logging handlers with unsafe functions after this call 64 | /// is permanent and cannot be undone) 65 | pub fn install_logging_hooks() { 66 | crate::whisper_logging_hook::install_whisper_logging_hook(); 67 | crate::ggml_logging_hook::install_ggml_logging_hook(); 68 | } 69 | -------------------------------------------------------------------------------- /src/standalone.rs: -------------------------------------------------------------------------------- 1 | //! Standalone functions that have no associated type. 2 | 3 | use std::ffi::{c_int, CStr, CString}; 4 | 5 | /// Return the id of the specified language, returns -1 if not found 6 | /// 7 | /// # Arguments 8 | /// * lang: The language to get the id for. 9 | /// 10 | /// # Returns 11 | /// The ID of the language, None if not found. 12 | /// 13 | /// # Panics 14 | /// Panics if the language contains a null byte. 
15 | /// 16 | /// # C++ equivalent 17 | /// `int whisper_lang_id(const char * lang)` 18 | pub fn get_lang_id(lang: &str) -> Option { 19 | let c_lang = CString::new(lang).expect("Language contains null byte"); 20 | let ret = unsafe { whisper_rs_sys::whisper_lang_id(c_lang.as_ptr()) }; 21 | if ret == -1 { 22 | None 23 | } else { 24 | Some(ret) 25 | } 26 | } 27 | 28 | /// Return the ID of the maximum language (ie the number of languages - 1) 29 | /// 30 | /// # Returns 31 | /// i32 32 | /// 33 | /// # C++ equivalent 34 | /// `int whisper_lang_max_id()` 35 | pub fn get_lang_max_id() -> i32 { 36 | unsafe { whisper_rs_sys::whisper_lang_max_id() } 37 | } 38 | 39 | /// Get the short string of the specified language id (e.g. 2 -> "de"). 40 | /// 41 | /// # Returns 42 | /// The short string of the language, None if not found. 43 | /// 44 | /// # C++ equivalent 45 | /// `const char * whisper_lang_str(int id)` 46 | pub fn get_lang_str(id: i32) -> Option<&'static str> { 47 | let c_buf = unsafe { whisper_rs_sys::whisper_lang_str(id) }; 48 | if c_buf.is_null() { 49 | None 50 | } else { 51 | let c_str = unsafe { CStr::from_ptr(c_buf) }; 52 | Some(c_str.to_str().unwrap()) 53 | } 54 | } 55 | 56 | /// Get the full string of the specified language name (e.g. 2 -> "german"). 57 | /// 58 | /// # Returns 59 | /// The full string of the language, None if not found. 60 | /// 61 | /// # C++ equivalent 62 | /// `const char * whisper_lang_str_full(int id)` 63 | pub fn get_lang_str_full(id: i32) -> Option<&'static str> { 64 | let c_buf = unsafe { whisper_rs_sys::whisper_lang_str_full(id) }; 65 | if c_buf.is_null() { 66 | None 67 | } else { 68 | let c_str = unsafe { CStr::from_ptr(c_buf) }; 69 | Some(c_str.to_str().unwrap()) 70 | } 71 | } 72 | 73 | /// Callback to control logging output: default behaviour is to print to stderr. 74 | /// 75 | /// # Safety 76 | /// The callback must be safe to call from C (i.e. no panicking, no unwinding, etc). 
77 | /// 78 | /// # C++ equivalent 79 | /// `void whisper_log_set(ggml_log_callback log_callback, void * user_data);` 80 | pub unsafe fn set_log_callback( 81 | log_callback: crate::WhisperLogCallback, 82 | user_data: *mut std::ffi::c_void, 83 | ) { 84 | unsafe { 85 | whisper_rs_sys::whisper_log_set(log_callback, user_data); 86 | } 87 | } 88 | 89 | /// Print system information. 90 | /// 91 | /// # C++ equivalent 92 | /// `const char * whisper_print_system_info()` 93 | pub fn print_system_info() -> &'static str { 94 | let c_buf = unsafe { whisper_rs_sys::whisper_print_system_info() }; 95 | let c_str = unsafe { CStr::from_ptr(c_buf) }; 96 | c_str.to_str().unwrap() 97 | } 98 | 99 | /// Programmatically exposes the information provided by `print_system_info` 100 | /// 101 | /// # C++ equivalent 102 | /// `int ggml_cpu_has_...` 103 | pub struct SystemInfo { 104 | pub avx: bool, 105 | pub avx2: bool, 106 | pub fma: bool, 107 | pub f16c: bool, 108 | } 109 | 110 | impl Default for SystemInfo { 111 | fn default() -> Self { 112 | unsafe { 113 | Self { 114 | avx: whisper_rs_sys::ggml_cpu_has_avx() != 0, 115 | avx2: whisper_rs_sys::ggml_cpu_has_avx2() != 0, 116 | fma: whisper_rs_sys::ggml_cpu_has_fma() != 0, 117 | f16c: whisper_rs_sys::ggml_cpu_has_f16c() != 0, 118 | } 119 | } 120 | } 121 | } 122 | 123 | #[cfg(test)] 124 | mod tests { 125 | use super::*; 126 | 127 | #[test] 128 | fn test_system_info() { 129 | let _info = SystemInfo::default(); 130 | assert!(!print_system_info().is_empty()); 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /src/utilities.rs: -------------------------------------------------------------------------------- 1 | use crate::WhisperError; 2 | 3 | /// Convert an array of 16 bit mono audio samples to a vector of 32 bit floats. 4 | /// 5 | /// # Arguments 6 | /// * `samples` - The array of 16 bit mono audio samples. 7 | /// * `output` - The vector of 32 bit floats to write the converted samples to.
8 | /// 9 | /// # Errors 10 | /// * [`WhisperError::InputOutputLengthMismatch`] if `samples.len() != output.len()` 11 | /// 12 | /// # Examples 13 | /// ``` 14 | /// # use whisper_rs::convert_integer_to_float_audio; 15 | /// let samples = [0i16; 1024]; 16 | /// let mut output = vec![0.0f32; samples.len()]; 17 | /// convert_integer_to_float_audio(&samples, &mut output).expect("input and output lengths should be equal"); 18 | /// ``` 19 | pub fn convert_integer_to_float_audio( 20 | samples: &[i16], 21 | output: &mut [f32], 22 | ) -> Result<(), WhisperError> { 23 | if samples.len() != output.len() { 24 | return Err(WhisperError::InputOutputLengthMismatch { 25 | input_len: samples.len(), 26 | output_len: output.len(), 27 | }); 28 | } 29 | 30 | for (input, output) in samples.iter().zip(output.iter_mut()) { 31 | *output = *input as f32 / 32768.0; 32 | } 33 | 34 | Ok(()) 35 | } 36 | 37 | /// Convert 32-bit floating point stereo PCM audio to 32-bit floating point mono PCM audio. 38 | /// 39 | /// # Arguments 40 | /// * `samples` - The array of 32-bit floating point stereo PCM audio samples. 41 | /// 42 | /// # Errors 43 | /// * if `samples.len()` is odd 44 | /// 45 | /// # Returns 46 | /// A vector of 32-bit floating point mono PCM audio samples.
47 | /// 48 | /// # Examples 49 | /// ``` 50 | /// # use whisper_rs::convert_stereo_to_mono_audio; 51 | /// let samples = [0.0f32; 1024]; 52 | /// let mono = convert_stereo_to_mono_audio(&samples).expect("should be no half samples missing"); 53 | /// ``` 54 | pub fn convert_stereo_to_mono_audio(samples: &[f32]) -> Result, WhisperError> { 55 | if samples.len() & 1 != 0 { 56 | return Err(WhisperError::HalfSampleMissing(samples.len())); 57 | } 58 | 59 | Ok(samples 60 | .chunks_exact(2) 61 | .map(|x| (x[0] + x[1]) / 2.0) 62 | .collect()) 63 | } 64 | 65 | #[cfg(test)] 66 | mod test { 67 | use super::*; 68 | use rand::distributions::{Distribution, Standard}; 69 | use rand::Rng; 70 | use std::hint::black_box; 71 | 72 | extern crate test; 73 | 74 | fn random_sample_data() -> Vec 75 | where 76 | Standard: Distribution, 77 | { 78 | const SAMPLE_SIZE: usize = 1_048_576; 79 | 80 | let mut rng = rand::thread_rng(); 81 | let mut samples = Vec::with_capacity(SAMPLE_SIZE); 82 | for _ in 0..SAMPLE_SIZE { 83 | samples.push(rng.gen::()); 84 | } 85 | samples 86 | } 87 | 88 | #[test] 89 | pub fn assert_stereo_to_mono_err() { 90 | let samples = random_sample_data::(); 91 | let mono = convert_stereo_to_mono_audio(&samples[..samples.len() - 1]); // SAMPLE_SIZE is even, so drop one sample to get an odd length 92 | assert!(matches!(mono, Err(WhisperError::HalfSampleMissing(_)))); 93 | } 94 | 95 | #[bench] 96 | pub fn bench_stereo_to_mono(b: &mut test::Bencher) { 97 | let samples = random_sample_data::(); 98 | b.iter(|| black_box(convert_stereo_to_mono_audio(black_box(&samples)))); 99 | } 100 | 101 | #[bench] 102 | pub fn bench_integer_to_float(b: &mut test::Bencher) { 103 | let samples = random_sample_data::(); 104 | let mut output = vec![0.0f32; samples.len()]; 105 | b.iter(|| { 106 | black_box(convert_integer_to_float_audio( 107 | black_box(&samples), 108 | black_box(&mut output), 109 | )) 110 | }); 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /src/whisper_ctx.rs: -------------------------------------------------------------------------------- 1 | use
crate::error::WhisperError; 2 | use crate::WhisperToken; 3 | use std::ffi::{c_int, CStr, CString}; 4 | 5 | /// Safe Rust wrapper around a Whisper context. 6 | /// 7 | /// You likely want to create this with [WhisperInnerContext::new_with_params], 8 | /// create a state with [WhisperInnerContext::create_state], 9 | /// then run a full transcription with [WhisperState::full]. 10 | #[derive(Debug)] 11 | pub struct WhisperInnerContext { 12 | pub(crate) ctx: *mut whisper_rs_sys::whisper_context, 13 | } 14 | 15 | impl WhisperInnerContext { 16 | /// Create a new WhisperContext from a file, with parameters. 17 | /// 18 | /// # Arguments 19 | /// * path: The path to the model file. 20 | /// * parameters: A parameter struct containing the parameters to use. 21 | /// 22 | /// # Returns 23 | /// Ok(Self) on success, Err(WhisperError) on failure. 24 | /// 25 | /// # C++ equivalent 26 | /// `struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);` 27 | pub fn new_with_params( 28 | path: &str, 29 | parameters: WhisperContextParameters, 30 | ) -> Result { 31 | let path_cstr = CString::new(path)?; 32 | let ctx = unsafe { 33 | whisper_rs_sys::whisper_init_from_file_with_params_no_state( 34 | path_cstr.as_ptr(), 35 | parameters.to_c_struct(), 36 | ) 37 | }; 38 | if ctx.is_null() { 39 | Err(WhisperError::InitError) 40 | } else { 41 | Ok(Self { ctx }) 42 | } 43 | } 44 | 45 | /// Create a new WhisperContext from a buffer. 46 | /// 47 | /// # Arguments 48 | /// * buffer: The buffer containing the model. 49 | /// 50 | /// # Returns 51 | /// Ok(Self) on success, Err(WhisperError) on failure. 
52 | /// 53 | /// # C++ equivalent 54 | /// `struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);` 55 | pub fn new_from_buffer_with_params( 56 | buffer: &[u8], 57 | parameters: WhisperContextParameters, 58 | ) -> Result { 59 | let ctx = unsafe { 60 | whisper_rs_sys::whisper_init_from_buffer_with_params_no_state( 61 | buffer.as_ptr() as _, 62 | buffer.len(), 63 | parameters.to_c_struct(), 64 | ) 65 | }; 66 | if ctx.is_null() { 67 | Err(WhisperError::InitError) 68 | } else { 69 | Ok(Self { ctx }) 70 | } 71 | } 72 | 73 | /// Convert the provided text into tokens. 74 | /// 75 | /// # Arguments 76 | /// * text: The text to convert. 77 | /// * max_tokens: The maximum number of tokens to produce; if the text needs more, `Err(WhisperError::InvalidText)` is returned. 78 | /// # Returns 79 | /// `Ok(Vec)` on success, `Err(WhisperError)` on failure. 80 | /// 81 | /// # C++ equivalent 82 | /// `int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens);` 83 | pub fn tokenize( 84 | &self, 85 | text: &str, 86 | max_tokens: usize, 87 | ) -> Result, WhisperError> { 88 | // convert the text to a nul-terminated C string. Will raise an error if the text contains 89 | // any nul bytes. 90 | let text = CString::new(text)?; 91 | // allocate at least max_tokens to ensure the memory is valid 92 | let mut tokens: Vec = Vec::with_capacity(max_tokens); 93 | let ret = unsafe { 94 | whisper_rs_sys::whisper_tokenize( 95 | self.ctx, 96 | text.as_ptr(), 97 | tokens.as_mut_ptr(), 98 | max_tokens as c_int, 99 | ) 100 | }; 101 | if ret < 0 { 102 | Err(WhisperError::InvalidText) 103 | } else { 104 | // SAFETY: whisper_tokenize returns a negative value (minus the required token count) when the tokens do not fit; a non-negative ret is the number of tokens written, which never exceeds max_tokens, i.e. the vector's capacity 105 | unsafe { tokens.set_len(ret as usize) }; 106 | Ok(tokens) 107 | } 108 | } 109 | 110 | /// Get n_vocab.
111 | /// 112 | /// # Returns 113 | /// c_int 114 | /// 115 | /// # C++ equivalent 116 | /// `int whisper_n_vocab (struct whisper_context * ctx)` 117 | #[inline] 118 | pub fn n_vocab(&self) -> c_int { 119 | unsafe { whisper_rs_sys::whisper_n_vocab(self.ctx) } 120 | } 121 | 122 | /// Get n_text_ctx. 123 | /// 124 | /// # Returns 125 | /// c_int 126 | /// 127 | /// # C++ equivalent 128 | /// `int whisper_n_text_ctx (struct whisper_context * ctx);` 129 | #[inline] 130 | pub fn n_text_ctx(&self) -> c_int { 131 | unsafe { whisper_rs_sys::whisper_n_text_ctx(self.ctx) } 132 | } 133 | 134 | /// Get n_audio_ctx. 135 | /// 136 | /// # Returns 137 | /// c_int 138 | /// 139 | /// # C++ equivalent 140 | /// `int whisper_n_audio_ctx (struct whisper_context * ctx);` 141 | #[inline] 142 | pub fn n_audio_ctx(&self) -> c_int { 143 | unsafe { whisper_rs_sys::whisper_n_audio_ctx(self.ctx) } 144 | } 145 | 146 | /// Does this model support multiple languages? 147 | /// 148 | /// # C++ equivalent 149 | /// `int whisper_is_multilingual(struct whisper_context * ctx)` 150 | #[inline] 151 | pub fn is_multilingual(&self) -> bool { 152 | unsafe { whisper_rs_sys::whisper_is_multilingual(self.ctx) != 0 } 153 | } 154 | 155 | /// Get model_n_vocab. 156 | /// 157 | /// # Returns 158 | /// c_int 159 | /// 160 | /// # C++ equivalent 161 | /// `int whisper_model_n_vocab (struct whisper_context * ctx);` 162 | #[inline] 163 | pub fn model_n_vocab(&self) -> c_int { 164 | unsafe { whisper_rs_sys::whisper_model_n_vocab(self.ctx) } 165 | } 166 | 167 | /// Get model_n_audio_ctx. 168 | /// 169 | /// # Returns 170 | /// c_int 171 | /// 172 | /// # C++ equivalent 173 | /// `int whisper_model_n_audio_ctx (struct whisper_context * ctx)` 174 | #[inline] 175 | pub fn model_n_audio_ctx(&self) -> c_int { 176 | unsafe { whisper_rs_sys::whisper_model_n_audio_ctx(self.ctx) } 177 | } 178 | 179 | /// Get model_n_audio_state. 
180 | /// 181 | /// # Returns 182 | /// c_int 183 | /// 184 | /// # C++ equivalent 185 | /// `int whisper_model_n_audio_state(struct whisper_context * ctx);` 186 | #[inline] 187 | pub fn model_n_audio_state(&self) -> c_int { 188 | unsafe { whisper_rs_sys::whisper_model_n_audio_state(self.ctx) } 189 | } 190 | 191 | /// Get model_n_audio_head. 192 | /// 193 | /// # Returns 194 | /// c_int 195 | /// 196 | /// # C++ equivalent 197 | /// `int whisper_model_n_audio_head (struct whisper_context * ctx);` 198 | #[inline] 199 | pub fn model_n_audio_head(&self) -> c_int { 200 | unsafe { whisper_rs_sys::whisper_model_n_audio_head(self.ctx) } 201 | } 202 | 203 | /// Get model_n_audio_layer. 204 | /// 205 | /// # Returns 206 | /// c_int 207 | /// 208 | /// # C++ equivalent 209 | /// `int whisper_model_n_audio_layer(struct whisper_context * ctx);` 210 | #[inline] 211 | pub fn model_n_audio_layer(&self) -> c_int { 212 | unsafe { whisper_rs_sys::whisper_model_n_audio_layer(self.ctx) } 213 | } 214 | 215 | /// Get model_n_text_ctx. 216 | /// 217 | /// # Returns 218 | /// c_int 219 | /// 220 | /// # C++ equivalent 221 | /// `int whisper_model_n_text_ctx (struct whisper_context * ctx)` 222 | #[inline] 223 | pub fn model_n_text_ctx(&self) -> c_int { 224 | unsafe { whisper_rs_sys::whisper_model_n_text_ctx(self.ctx) } 225 | } 226 | 227 | /// Get model_n_text_state. 228 | /// 229 | /// # Returns 230 | /// c_int 231 | /// 232 | /// # C++ equivalent 233 | /// `int whisper_model_n_text_state (struct whisper_context * ctx);` 234 | #[inline] 235 | pub fn model_n_text_state(&self) -> c_int { 236 | unsafe { whisper_rs_sys::whisper_model_n_text_state(self.ctx) } 237 | } 238 | 239 | /// Get model_n_text_head. 
240 | /// 241 | /// # Returns 242 | /// c_int 243 | /// 244 | /// # C++ equivalent 245 | /// `int whisper_model_n_text_head (struct whisper_context * ctx);` 246 | #[inline] 247 | pub fn model_n_text_head(&self) -> c_int { 248 | unsafe { whisper_rs_sys::whisper_model_n_text_head(self.ctx) } 249 | } 250 | 251 | /// Get model_n_text_layer. 252 | /// 253 | /// # Returns 254 | /// c_int 255 | /// 256 | /// # C++ equivalent 257 | /// `int whisper_model_n_text_layer (struct whisper_context * ctx);` 258 | #[inline] 259 | pub fn model_n_text_layer(&self) -> c_int { 260 | unsafe { whisper_rs_sys::whisper_model_n_text_layer(self.ctx) } 261 | } 262 | 263 | /// Get model_n_mels. 264 | /// 265 | /// # Returns 266 | /// c_int 267 | /// 268 | /// # C++ equivalent 269 | /// `int whisper_model_n_mels (struct whisper_context * ctx);` 270 | #[inline] 271 | pub fn model_n_mels(&self) -> c_int { 272 | unsafe { whisper_rs_sys::whisper_model_n_mels(self.ctx) } 273 | } 274 | 275 | /// Get model_ftype. 276 | /// 277 | /// # Returns 278 | /// c_int 279 | /// 280 | /// # C++ equivalent 281 | /// `int whisper_model_ftype (struct whisper_context * ctx);` 282 | #[inline] 283 | pub fn model_ftype(&self) -> c_int { 284 | unsafe { whisper_rs_sys::whisper_model_ftype(self.ctx) } 285 | } 286 | 287 | /// Get model_type. 288 | /// 289 | /// # Returns 290 | /// c_int 291 | /// 292 | /// # C++ equivalent 293 | /// `int whisper_model_type (struct whisper_context * ctx);` 294 | #[inline] 295 | pub fn model_type(&self) -> c_int { 296 | unsafe { whisper_rs_sys::whisper_model_type(self.ctx) } 297 | } 298 | 299 | // token functions 300 | /// Convert a token ID to a string. 301 | /// 302 | /// # Arguments 303 | /// * token_id: ID of the token. 304 | /// 305 | /// # Returns 306 | /// Ok(&str) on success, Err(WhisperError) on failure. 
307 | /// 308 | /// # C++ equivalent 309 | /// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)` 310 | pub fn token_to_str(&self, token_id: WhisperToken) -> Result<&str, WhisperError> { 311 | let c_str = self.token_to_cstr(token_id)?; 312 | let r_str = c_str.to_str()?; 313 | Ok(r_str) 314 | } 315 | 316 | /// Convert a token ID to a &CStr. 317 | /// 318 | /// # Arguments 319 | /// * token_id: ID of the token. 320 | /// 321 | /// # Returns 322 | /// Ok(String) on success, Err(WhisperError) on failure. 323 | /// 324 | /// # C++ equivalent 325 | /// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)` 326 | pub fn token_to_cstr(&self, token_id: WhisperToken) -> Result<&CStr, WhisperError> { 327 | let ret = unsafe { whisper_rs_sys::whisper_token_to_str(self.ctx, token_id) }; 328 | if ret.is_null() { 329 | return Err(WhisperError::NullPointer); 330 | } 331 | Ok(unsafe { CStr::from_ptr(ret) }) 332 | } 333 | 334 | /// Undocumented but exposed function in the C++ API. 335 | /// `const char * whisper_model_type_readable(struct whisper_context * ctx);` 336 | /// 337 | /// # Returns 338 | /// Ok(String) on success, Err(WhisperError) on failure. 339 | pub fn model_type_readable(&self) -> Result { 340 | let ret = unsafe { whisper_rs_sys::whisper_model_type_readable(self.ctx) }; 341 | if ret.is_null() { 342 | return Err(WhisperError::NullPointer); 343 | } 344 | let c_str = unsafe { CStr::from_ptr(ret) }; 345 | let r_str = c_str.to_str()?; 346 | Ok(r_str.to_string()) 347 | } 348 | 349 | /// Get the ID of the eot token. 350 | /// 351 | /// # C++ equivalent 352 | /// `whisper_token whisper_token_eot (struct whisper_context * ctx)` 353 | #[inline] 354 | pub fn token_eot(&self) -> WhisperToken { 355 | unsafe { whisper_rs_sys::whisper_token_eot(self.ctx) } 356 | } 357 | 358 | /// Get the ID of the sot token. 
359 | /// 360 | /// # C++ equivalent 361 | /// `whisper_token whisper_token_sot (struct whisper_context * ctx)` 362 | #[inline] 363 | pub fn token_sot(&self) -> WhisperToken { 364 | unsafe { whisper_rs_sys::whisper_token_sot(self.ctx) } 365 | } 366 | 367 | /// Get the ID of the solm token. 368 | /// 369 | /// # C++ equivalent 370 | /// `whisper_token whisper_token_solm(struct whisper_context * ctx)` 371 | #[inline] 372 | pub fn token_solm(&self) -> WhisperToken { 373 | unsafe { whisper_rs_sys::whisper_token_solm(self.ctx) } 374 | } 375 | 376 | /// Get the ID of the prev token. 377 | /// 378 | /// # C++ equivalent 379 | /// `whisper_token whisper_token_prev(struct whisper_context * ctx)` 380 | #[inline] 381 | pub fn token_prev(&self) -> WhisperToken { 382 | unsafe { whisper_rs_sys::whisper_token_prev(self.ctx) } 383 | } 384 | 385 | /// Get the ID of the nosp token. 386 | /// 387 | /// # C++ equivalent 388 | /// `whisper_token whisper_token_nosp(struct whisper_context * ctx)` 389 | #[inline] 390 | pub fn token_nosp(&self) -> WhisperToken { 391 | unsafe { whisper_rs_sys::whisper_token_nosp(self.ctx) } 392 | } 393 | 394 | /// Get the ID of the not token. 395 | /// 396 | /// # C++ equivalent 397 | /// `whisper_token whisper_token_not (struct whisper_context * ctx)` 398 | #[inline] 399 | pub fn token_not(&self) -> WhisperToken { 400 | unsafe { whisper_rs_sys::whisper_token_not(self.ctx) } 401 | } 402 | 403 | /// Get the ID of the beg token. 
404 | /// 405 | /// # C++ equivalent 406 | /// `whisper_token whisper_token_beg (struct whisper_context * ctx)` 407 | #[inline] 408 | pub fn token_beg(&self) -> WhisperToken { 409 | unsafe { whisper_rs_sys::whisper_token_beg(self.ctx) } 410 | } 411 | 412 | /// Get the ID of a specified language token 413 | /// 414 | /// # Arguments 415 | /// * lang_id: ID of the language 416 | /// 417 | /// # C++ equivalent 418 | /// `whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id)` 419 | #[inline] 420 | pub fn token_lang(&self, lang_id: c_int) -> WhisperToken { 421 | unsafe { whisper_rs_sys::whisper_token_lang(self.ctx, lang_id) } 422 | } 423 | 424 | /// Print performance statistics to stderr. 425 | /// 426 | /// # C++ equivalent 427 | /// `void whisper_print_timings(struct whisper_context * ctx)` 428 | #[inline] 429 | pub fn print_timings(&self) { 430 | unsafe { whisper_rs_sys::whisper_print_timings(self.ctx) } 431 | } 432 | 433 | /// Reset performance statistics. 434 | /// 435 | /// # C++ equivalent 436 | /// `void whisper_reset_timings(struct whisper_context * ctx)` 437 | #[inline] 438 | pub fn reset_timings(&self) { 439 | unsafe { whisper_rs_sys::whisper_reset_timings(self.ctx) } 440 | } 441 | 442 | // task tokens 443 | /// Get the ID of the translate task token. 444 | /// 445 | /// # C++ equivalent 446 | /// `whisper_token whisper_token_translate ()` 447 | pub fn token_translate(&self) -> WhisperToken { 448 | unsafe { whisper_rs_sys::whisper_token_translate(self.ctx) } 449 | } 450 | 451 | /// Get the ID of the transcribe task token. 
452 | /// 453 | /// # C++ equivalent 454 | /// `whisper_token whisper_token_transcribe()` 455 | pub fn token_transcribe(&self) -> WhisperToken { 456 | unsafe { whisper_rs_sys::whisper_token_transcribe(self.ctx) } 457 | } 458 | } 459 | 460 | impl Drop for WhisperInnerContext { 461 | #[inline] 462 | fn drop(&mut self) { 463 | unsafe { whisper_rs_sys::whisper_free(self.ctx) }; 464 | } 465 | } 466 | 467 | // following implementations are safe 468 | // see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388 469 | unsafe impl Send for WhisperInnerContext {} 470 | unsafe impl Sync for WhisperInnerContext {} 471 | 472 | pub struct WhisperContextParameters<'a> { 473 | /// Use GPU if available. 474 | pub use_gpu: bool, 475 | /// Enable flash attention, default false 476 | /// 477 | /// **Warning** Can't be used with DTW. DTW will be disabled if flash_attn is true 478 | pub flash_attn: bool, 479 | /// GPU device id, default 0 480 | pub gpu_device: c_int, 481 | /// DTW token level timestamp parameters 482 | pub dtw_parameters: DtwParameters<'a>, 483 | } 484 | 485 | #[allow(clippy::derivable_impls)] // this impl cannot be derived 486 | impl<'a> Default for WhisperContextParameters<'a> { 487 | fn default() -> Self { 488 | Self { 489 | use_gpu: cfg!(feature = "_gpu"), 490 | flash_attn: false, 491 | gpu_device: 0, 492 | dtw_parameters: DtwParameters::default(), 493 | } 494 | } 495 | } 496 | impl<'a> WhisperContextParameters<'a> { 497 | pub fn new() -> Self { 498 | Self::default() 499 | } 500 | pub fn use_gpu(&mut self, use_gpu: bool) -> &mut Self { 501 | self.use_gpu = use_gpu; 502 | self 503 | } 504 | pub fn flash_attn(&mut self, flash_attn: bool) -> &mut Self { 505 | self.flash_attn = flash_attn; 506 | self 507 | } 508 | pub fn gpu_device(&mut self, gpu_device: c_int) -> &mut Self { 509 | self.gpu_device = gpu_device; 510 | self 511 | } 512 | pub fn dtw_parameters(&mut self, dtw_parameters: DtwParameters<'a>) -> &mut Self { 513 | self.dtw_parameters = 
dtw_parameters; 514 | self 515 | } 516 | 517 | fn to_c_struct(&self) -> whisper_rs_sys::whisper_context_params { 518 | let dtw_token_timestamps = !matches!(self.dtw_parameters.mode, DtwMode::None); 519 | let mut dtw_aheads_preset = 520 | whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_NONE; 521 | let mut dtw_n_top: c_int = -1; 522 | let mut dtw_aheads = whisper_rs_sys::whisper_aheads { 523 | n_heads: 0, 524 | heads: std::ptr::null(), 525 | }; 526 | 527 | match &self.dtw_parameters.mode { 528 | DtwMode::None => {} 529 | DtwMode::TopMost { n_top } => { 530 | dtw_aheads_preset = 531 | whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_N_TOP_MOST; 532 | dtw_n_top = *n_top; 533 | } 534 | DtwMode::Custom { aheads } => { 535 | dtw_aheads_preset = 536 | whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_CUSTOM; 537 | 538 | dtw_aheads = whisper_rs_sys::whisper_aheads { 539 | n_heads: aheads.len(), 540 | heads: aheads.as_ptr(), 541 | }; 542 | } 543 | DtwMode::ModelPreset { model_preset } => match model_preset { 544 | DtwModelPreset::TinyEn => { 545 | dtw_aheads_preset = 546 | whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY_EN; 547 | } 548 | DtwModelPreset::Tiny => { 549 | dtw_aheads_preset = 550 | whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY; 551 | } 552 | DtwModelPreset::BaseEn => { 553 | dtw_aheads_preset = 554 | whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE_EN; 555 | } 556 | DtwModelPreset::Base => { 557 | dtw_aheads_preset = 558 | whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE; 559 | } 560 | DtwModelPreset::SmallEn => { 561 | dtw_aheads_preset = 562 | whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL_EN; 563 | } 564 | DtwModelPreset::Small => { 565 | dtw_aheads_preset = 566 | whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL; 567 | } 568 | DtwModelPreset::MediumEn => { 569 | dtw_aheads_preset = 570 | 
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM_EN; 571 | } 572 | DtwModelPreset::Medium => { 573 | dtw_aheads_preset = 574 | whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM; 575 | } 576 | DtwModelPreset::LargeV1 => { 577 | dtw_aheads_preset = 578 | whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V1; 579 | } 580 | DtwModelPreset::LargeV2 => { 581 | dtw_aheads_preset = 582 | whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V2; 583 | } 584 | DtwModelPreset::LargeV3 => { 585 | dtw_aheads_preset = 586 | whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V3; 587 | } 588 | DtwModelPreset::LargeV3Turbo => { 589 | dtw_aheads_preset = 590 | whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V3_TURBO; 591 | } 592 | }, 593 | } 594 | 595 | whisper_rs_sys::whisper_context_params { 596 | use_gpu: self.use_gpu, 597 | flash_attn: self.flash_attn, 598 | gpu_device: self.gpu_device, 599 | dtw_token_timestamps, 600 | dtw_aheads_preset, 601 | dtw_n_top, 602 | dtw_aheads, 603 | dtw_mem_size: self.dtw_parameters.dtw_mem_size, 604 | } 605 | } 606 | } 607 | 608 | /// [EXPERIMENTAL] Enable Token-level timestamps with DTW, default Disabled 609 | #[derive(Debug, Clone)] 610 | pub struct DtwParameters<'a> { 611 | pub mode: DtwMode<'a>, 612 | pub dtw_mem_size: usize, 613 | } 614 | 615 | impl Default for DtwParameters<'_> { 616 | fn default() -> Self { 617 | Self { 618 | mode: DtwMode::None, 619 | dtw_mem_size: 1024 * 1024 * 128, 620 | } 621 | } 622 | } 623 | 624 | #[derive(Debug, Clone)] 625 | pub enum DtwMode<'a> { 626 | /// DTW token level timestamps disabled 627 | None, 628 | /// Use N Top Most layers from loaded model 629 | TopMost { 630 | /// Number of top text layers used from model, should be 0 < n_top <= model n_text_layer 631 | n_top: c_int, 632 | }, 633 | /// Use custom aheads, non-empty list of whisper_ahead. 
634 | /// 0 < n_text_layer < model n_text_layer, 0 < n_head < model n_text_head for each element 635 | /// See details https://github.com/ggerganov/whisper.cpp/pull/1485#discussion_r1519681143 636 | Custom { 637 | aheads: &'a [whisper_rs_sys::whisper_ahead], 638 | }, 639 | /// Use predefined preset for standard models 640 | ModelPreset { model_preset: DtwModelPreset }, 641 | } 642 | 643 | #[derive(Debug, Clone)] 644 | pub enum DtwModelPreset { 645 | TinyEn, 646 | Tiny, 647 | BaseEn, 648 | Base, 649 | SmallEn, 650 | Small, 651 | MediumEn, 652 | Medium, 653 | LargeV1, 654 | LargeV2, 655 | LargeV3, 656 | LargeV3Turbo, 657 | } 658 | 659 | #[cfg(test)] 660 | #[cfg(feature = "test-with-tiny-model")] 661 | mod test_with_tiny_model { 662 | use super::*; 663 | const MODEL_PATH: &str = "./sys/whisper.cpp/models/ggml-tiny.en.bin"; 664 | 665 | // These tests expect that the tiny.en model has been downloaded 666 | // using the script `sys/whisper.cpp/models/download-ggml-model.sh tiny.en` 667 | 668 | #[test] 669 | fn test_tokenize_round_trip() { 670 | let ctx = WhisperInnerContext::new(MODEL_PATH).expect("Download the ggml-tiny.en model using 'sys/whisper.cpp/models/download-ggml-model.sh tiny.en'"); 671 | let text_in = " And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country."; 672 | let tokens = ctx.tokenize(text_in, 1024).unwrap(); 673 | let text_out = tokens 674 | .into_iter() 675 | .map(|t| ctx.token_to_str(t).unwrap()) 676 | .collect::>() 677 | .join(""); 678 | assert_eq!(text_in, text_out); 679 | } 680 | } 681 | -------------------------------------------------------------------------------- /src/whisper_ctx_wrapper.rs: -------------------------------------------------------------------------------- 1 | use std::ffi::{c_int, CStr}; 2 | use std::sync::Arc; 3 | 4 | use crate::{ 5 | WhisperContextParameters, WhisperError, WhisperInnerContext, WhisperState, WhisperToken, 6 | }; 7 | 8 | pub struct WhisperContext { 9 | ctx: 
Arc, 10 | } 11 | 12 | impl WhisperContext { 13 | fn wrap(ctx: WhisperInnerContext) -> Self { 14 | Self { ctx: Arc::new(ctx) } 15 | } 16 | 17 | /// Create a new WhisperContext from a file, with parameters. 18 | /// 19 | /// # Arguments 20 | /// * path: The path to the model file. 21 | /// * parameters: A parameter struct containing the parameters to use. 22 | /// 23 | /// # Returns 24 | /// Ok(Self) on success, Err(WhisperError) on failure. 25 | /// 26 | /// # C++ equivalent 27 | /// `struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);` 28 | pub fn new_with_params( 29 | path: &str, 30 | parameters: WhisperContextParameters, 31 | ) -> Result { 32 | let ctx = WhisperInnerContext::new_with_params(path, parameters)?; 33 | Ok(Self::wrap(ctx)) 34 | } 35 | 36 | /// Create a new WhisperContext from a buffer. 37 | /// 38 | /// # Arguments 39 | /// * buffer: The buffer containing the model. 40 | /// 41 | /// # Returns 42 | /// Ok(Self) on success, Err(WhisperError) on failure. 43 | /// 44 | /// # C++ equivalent 45 | /// `struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);` 46 | pub fn new_from_buffer_with_params( 47 | buffer: &[u8], 48 | parameters: WhisperContextParameters, 49 | ) -> Result { 50 | let ctx = WhisperInnerContext::new_from_buffer_with_params(buffer, parameters)?; 51 | Ok(Self::wrap(ctx)) 52 | } 53 | 54 | /// Convert the provided text into tokens. 55 | /// 56 | /// # Arguments 57 | /// * text: The text to convert. 58 | /// 59 | /// # Returns 60 | /// `Ok(Vec)` on success, `Err(WhisperError)` on failure. 
61 | /// 62 | /// # C++ equivalent 63 | /// `int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens);` 64 | pub fn tokenize( 65 | &self, 66 | text: &str, 67 | max_tokens: usize, 68 | ) -> Result, WhisperError> { 69 | self.ctx.tokenize(text, max_tokens) 70 | } 71 | 72 | /// Get n_vocab. 73 | /// 74 | /// # Returns 75 | /// c_int 76 | /// 77 | /// # C++ equivalent 78 | /// `int whisper_n_vocab (struct whisper_context * ctx)` 79 | #[inline] 80 | pub fn n_vocab(&self) -> c_int { 81 | self.ctx.n_vocab() 82 | } 83 | 84 | /// Get n_text_ctx. 85 | /// 86 | /// # Returns 87 | /// c_int 88 | /// 89 | /// # C++ equivalent 90 | /// `int whisper_n_text_ctx (struct whisper_context * ctx);` 91 | #[inline] 92 | pub fn n_text_ctx(&self) -> c_int { 93 | self.ctx.n_text_ctx() 94 | } 95 | 96 | /// Get n_audio_ctx. 97 | /// 98 | /// # Returns 99 | /// c_int 100 | /// 101 | /// # C++ equivalent 102 | /// `int whisper_n_audio_ctx (struct whisper_context * ctx);` 103 | #[inline] 104 | pub fn n_audio_ctx(&self) -> c_int { 105 | self.ctx.n_audio_ctx() 106 | } 107 | 108 | /// Does this model support multiple languages? 109 | /// 110 | /// # C++ equivalent 111 | /// `int whisper_is_multilingual(struct whisper_context * ctx)` 112 | #[inline] 113 | pub fn is_multilingual(&self) -> bool { 114 | self.ctx.is_multilingual() 115 | } 116 | 117 | /// Get model_n_vocab. 118 | /// 119 | /// # Returns 120 | /// c_int 121 | /// 122 | /// # C++ equivalent 123 | /// `int whisper_model_n_vocab (struct whisper_context * ctx);` 124 | #[inline] 125 | pub fn model_n_vocab(&self) -> c_int { 126 | self.ctx.model_n_vocab() 127 | } 128 | 129 | /// Get model_n_audio_ctx. 
130 | /// 131 | /// # Returns 132 | /// c_int 133 | /// 134 | /// # C++ equivalent 135 | /// `int whisper_model_n_audio_ctx (struct whisper_context * ctx)` 136 | #[inline] 137 | pub fn model_n_audio_ctx(&self) -> c_int { 138 | self.ctx.model_n_audio_ctx() 139 | } 140 | 141 | /// Get model_n_audio_state. 142 | /// 143 | /// # Returns 144 | /// c_int 145 | /// 146 | /// # C++ equivalent 147 | /// `int whisper_model_n_audio_state(struct whisper_context * ctx);` 148 | #[inline] 149 | pub fn model_n_audio_state(&self) -> c_int { 150 | self.ctx.model_n_audio_state() 151 | } 152 | 153 | /// Get model_n_audio_head. 154 | /// 155 | /// # Returns 156 | /// c_int 157 | /// 158 | /// # C++ equivalent 159 | /// `int whisper_model_n_audio_head (struct whisper_context * ctx);` 160 | #[inline] 161 | pub fn model_n_audio_head(&self) -> c_int { 162 | self.ctx.model_n_audio_head() 163 | } 164 | 165 | /// Get model_n_audio_layer. 166 | /// 167 | /// # Returns 168 | /// c_int 169 | /// 170 | /// # C++ equivalent 171 | /// `int whisper_model_n_audio_layer(struct whisper_context * ctx);` 172 | #[inline] 173 | pub fn model_n_audio_layer(&self) -> c_int { 174 | self.ctx.model_n_audio_layer() 175 | } 176 | 177 | /// Get model_n_text_ctx. 178 | /// 179 | /// # Returns 180 | /// c_int 181 | /// 182 | /// # C++ equivalent 183 | /// `int whisper_model_n_text_ctx (struct whisper_context * ctx)` 184 | #[inline] 185 | pub fn model_n_text_ctx(&self) -> c_int { 186 | self.ctx.model_n_text_ctx() 187 | } 188 | 189 | /// Get model_n_text_state. 190 | /// 191 | /// # Returns 192 | /// c_int 193 | /// 194 | /// # C++ equivalent 195 | /// `int whisper_model_n_text_state (struct whisper_context * ctx);` 196 | #[inline] 197 | pub fn model_n_text_state(&self) -> c_int { 198 | self.ctx.model_n_text_state() 199 | } 200 | 201 | /// Get model_n_text_head. 
202 | /// 203 | /// # Returns 204 | /// c_int 205 | /// 206 | /// # C++ equivalent 207 | /// `int whisper_model_n_text_head (struct whisper_context * ctx);` 208 | #[inline] 209 | pub fn model_n_text_head(&self) -> c_int { 210 | self.ctx.model_n_text_head() 211 | } 212 | 213 | /// Get model_n_text_layer. 214 | /// 215 | /// # Returns 216 | /// c_int 217 | /// 218 | /// # C++ equivalent 219 | /// `int whisper_model_n_text_layer (struct whisper_context * ctx);` 220 | #[inline] 221 | pub fn model_n_text_layer(&self) -> c_int { 222 | self.ctx.model_n_text_layer() 223 | } 224 | 225 | /// Get model_n_mels. 226 | /// 227 | /// # Returns 228 | /// c_int 229 | /// 230 | /// # C++ equivalent 231 | /// `int whisper_model_n_mels (struct whisper_context * ctx);` 232 | #[inline] 233 | pub fn model_n_mels(&self) -> c_int { 234 | self.ctx.model_n_mels() 235 | } 236 | 237 | /// Get model_ftype. 238 | /// 239 | /// # Returns 240 | /// c_int 241 | /// 242 | /// # C++ equivalent 243 | /// `int whisper_model_ftype (struct whisper_context * ctx);` 244 | #[inline] 245 | pub fn model_ftype(&self) -> c_int { 246 | self.ctx.model_ftype() 247 | } 248 | 249 | /// Get model_type. 250 | /// 251 | /// # Returns 252 | /// c_int 253 | /// 254 | /// # C++ equivalent 255 | /// `int whisper_model_type (struct whisper_context * ctx);` 256 | #[inline] 257 | pub fn model_type(&self) -> c_int { 258 | self.ctx.model_type() 259 | } 260 | 261 | // token functions 262 | /// Convert a token ID to a string. 263 | /// 264 | /// # Arguments 265 | /// * token_id: ID of the token. 266 | /// 267 | /// # Returns 268 | /// Ok(&str) on success, Err(WhisperError) on failure. 269 | /// 270 | /// # C++ equivalent 271 | /// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)` 272 | pub fn token_to_str(&self, token_id: WhisperToken) -> Result<&str, WhisperError> { 273 | self.ctx.token_to_str(token_id) 274 | } 275 | 276 | /// Convert a token ID to a &CStr. 
277 | /// 278 | /// # Arguments 279 | /// * token_id: ID of the token. 280 | /// 281 | /// # Returns 282 | /// Ok(String) on success, Err(WhisperError) on failure. 283 | /// 284 | /// # C++ equivalent 285 | /// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)` 286 | pub fn token_to_cstr(&self, token_id: WhisperToken) -> Result<&CStr, WhisperError> { 287 | self.ctx.token_to_cstr(token_id) 288 | } 289 | 290 | /// Undocumented but exposed function in the C++ API. 291 | /// `const char * whisper_model_type_readable(struct whisper_context * ctx);` 292 | /// 293 | /// # Returns 294 | /// Ok(String) on success, Err(WhisperError) on failure. 295 | pub fn model_type_readable(&self) -> Result { 296 | self.ctx.model_type_readable() 297 | } 298 | 299 | /// Get the ID of the eot token. 300 | /// 301 | /// # C++ equivalent 302 | /// `whisper_token whisper_token_eot (struct whisper_context * ctx)` 303 | #[inline] 304 | pub fn token_eot(&self) -> WhisperToken { 305 | self.ctx.token_eot() 306 | } 307 | 308 | /// Get the ID of the sot token. 309 | /// 310 | /// # C++ equivalent 311 | /// `whisper_token whisper_token_sot (struct whisper_context * ctx)` 312 | #[inline] 313 | pub fn token_sot(&self) -> WhisperToken { 314 | self.ctx.token_sot() 315 | } 316 | 317 | /// Get the ID of the solm token. 318 | /// 319 | /// # C++ equivalent 320 | /// `whisper_token whisper_token_solm(struct whisper_context * ctx)` 321 | #[inline] 322 | pub fn token_solm(&self) -> WhisperToken { 323 | self.ctx.token_solm() 324 | } 325 | 326 | /// Get the ID of the prev token. 327 | /// 328 | /// # C++ equivalent 329 | /// `whisper_token whisper_token_prev(struct whisper_context * ctx)` 330 | #[inline] 331 | pub fn token_prev(&self) -> WhisperToken { 332 | self.ctx.token_prev() 333 | } 334 | 335 | /// Get the ID of the nosp token. 
336 | /// 337 | /// # C++ equivalent 338 | /// `whisper_token whisper_token_nosp(struct whisper_context * ctx)` 339 | #[inline] 340 | pub fn token_nosp(&self) -> WhisperToken { 341 | self.ctx.token_nosp() 342 | } 343 | 344 | /// Get the ID of the not token. 345 | /// 346 | /// # C++ equivalent 347 | /// `whisper_token whisper_token_not (struct whisper_context * ctx)` 348 | #[inline] 349 | pub fn token_not(&self) -> WhisperToken { 350 | self.ctx.token_not() 351 | } 352 | 353 | /// Get the ID of the beg token. 354 | /// 355 | /// # C++ equivalent 356 | /// `whisper_token whisper_token_beg (struct whisper_context * ctx)` 357 | #[inline] 358 | pub fn token_beg(&self) -> WhisperToken { 359 | self.ctx.token_beg() 360 | } 361 | 362 | /// Get the ID of a specified language token 363 | /// 364 | /// # Arguments 365 | /// * lang_id: ID of the language 366 | /// 367 | /// # C++ equivalent 368 | /// `whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id)` 369 | #[inline] 370 | pub fn token_lang(&self, lang_id: c_int) -> WhisperToken { 371 | self.ctx.token_lang(lang_id) 372 | } 373 | 374 | /// Print performance statistics to stderr. 375 | /// 376 | /// # C++ equivalent 377 | /// `void whisper_print_timings(struct whisper_context * ctx)` 378 | #[inline] 379 | pub fn print_timings(&self) { 380 | self.ctx.print_timings() 381 | } 382 | 383 | /// Reset performance statistics. 384 | /// 385 | /// # C++ equivalent 386 | /// `void whisper_reset_timings(struct whisper_context * ctx)` 387 | #[inline] 388 | pub fn reset_timings(&self) { 389 | self.ctx.reset_timings() 390 | } 391 | 392 | // task tokens 393 | /// Get the ID of the translate task token. 394 | /// 395 | /// # C++ equivalent 396 | /// `whisper_token whisper_token_translate ()` 397 | pub fn token_translate(&self) -> WhisperToken { 398 | self.ctx.token_translate() 399 | } 400 | 401 | /// Get the ID of the transcribe task token. 
402 | /// 403 | /// # C++ equivalent 404 | /// `whisper_token whisper_token_transcribe()` 405 | pub fn token_transcribe(&self) -> WhisperToken { 406 | self.ctx.token_transcribe() 407 | } 408 | 409 | // we don't implement `whisper_init()` here since i have zero clue what `whisper_model_loader` does 410 | 411 | /// Create a new state object, ready for use. 412 | /// 413 | /// # Returns 414 | /// Ok(WhisperState) on success, Err(WhisperError) on failure. 415 | /// 416 | /// # C++ equivalent 417 | /// `struct whisper_state * whisper_init_state(struct whisper_context * ctx);` 418 | pub fn create_state(&self) -> Result { 419 | let state = unsafe { whisper_rs_sys::whisper_init_state(self.ctx.ctx) }; 420 | if state.is_null() { 421 | Err(WhisperError::InitError) 422 | } else { 423 | // SAFETY: this is known to be a valid pointer to a `whisper_state` struct 424 | Ok(WhisperState::new(self.ctx.clone(), state)) 425 | } 426 | } 427 | } 428 | -------------------------------------------------------------------------------- /src/whisper_grammar.rs: -------------------------------------------------------------------------------- 1 | use whisper_rs_sys::{ 2 | whisper_gretype_WHISPER_GRETYPE_ALT, whisper_gretype_WHISPER_GRETYPE_CHAR, 3 | whisper_gretype_WHISPER_GRETYPE_CHAR_ALT, whisper_gretype_WHISPER_GRETYPE_CHAR_NOT, 4 | whisper_gretype_WHISPER_GRETYPE_CHAR_RNG_UPPER, whisper_gretype_WHISPER_GRETYPE_END, 5 | whisper_gretype_WHISPER_GRETYPE_RULE_REF, 6 | }; 7 | 8 | #[cfg_attr(any(not(windows), target_env = "gnu"), repr(u32))] // include windows-gnu 9 | #[cfg_attr(all(windows, not(target_env = "gnu")), repr(i32))] // msvc being *special* again 10 | #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] 11 | pub enum WhisperGrammarElementType { 12 | /// End of rule definition 13 | End = whisper_gretype_WHISPER_GRETYPE_END, 14 | /// Start of alternate definition for a rule 15 | Alternate = whisper_gretype_WHISPER_GRETYPE_ALT, 16 | /// Non-terminal element: reference to another rule 17 | 
RuleReference = whisper_gretype_WHISPER_GRETYPE_RULE_REF, 18 | /// Terminal element: character (code point) 19 | Character = whisper_gretype_WHISPER_GRETYPE_CHAR, 20 | /// Inverse of a character(s) 21 | NotCharacter = whisper_gretype_WHISPER_GRETYPE_CHAR_NOT, 22 | /// Modifies a preceding [Self::Character] to be an inclusive range 23 | CharacterRangeUpper = whisper_gretype_WHISPER_GRETYPE_CHAR_RNG_UPPER, 24 | /// Modifies a preceding [Self::Character] to add an alternate character to match 25 | CharacterAlternate = whisper_gretype_WHISPER_GRETYPE_CHAR_ALT, 26 | } 27 | 28 | impl From for WhisperGrammarElementType { 29 | fn from(value: whisper_rs_sys::whisper_gretype) -> Self { 30 | assert!( 31 | (0..=6).contains(&value), 32 | "Invalid WhisperGrammarElementType value: {}", 33 | value 34 | ); 35 | 36 | #[allow(non_upper_case_globals)] // weird place to trigger this 37 | match value { 38 | whisper_gretype_WHISPER_GRETYPE_END => WhisperGrammarElementType::End, 39 | whisper_gretype_WHISPER_GRETYPE_ALT => WhisperGrammarElementType::Alternate, 40 | whisper_gretype_WHISPER_GRETYPE_RULE_REF => WhisperGrammarElementType::RuleReference, 41 | whisper_gretype_WHISPER_GRETYPE_CHAR => WhisperGrammarElementType::Character, 42 | whisper_gretype_WHISPER_GRETYPE_CHAR_NOT => WhisperGrammarElementType::NotCharacter, 43 | whisper_gretype_WHISPER_GRETYPE_CHAR_RNG_UPPER => { 44 | WhisperGrammarElementType::CharacterRangeUpper 45 | } 46 | whisper_gretype_WHISPER_GRETYPE_CHAR_ALT => { 47 | WhisperGrammarElementType::CharacterAlternate 48 | } 49 | _ => unreachable!(), 50 | } 51 | } 52 | } 53 | 54 | impl From for whisper_rs_sys::whisper_gretype { 55 | fn from(value: WhisperGrammarElementType) -> Self { 56 | value as Self 57 | } 58 | } 59 | 60 | #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] 61 | pub struct WhisperGrammarElement { 62 | pub element_type: WhisperGrammarElementType, 63 | pub value: u32, 64 | } 65 | 66 | impl WhisperGrammarElement { 67 | pub fn new(element_type: 
WhisperGrammarElementType, value: u32) -> Self { 68 | Self { 69 | element_type, 70 | value, 71 | } 72 | } 73 | 74 | pub fn to_c_type(self) -> whisper_rs_sys::whisper_grammar_element { 75 | whisper_rs_sys::whisper_grammar_element { 76 | type_: self.element_type.into(), 77 | value: self.value, 78 | } 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /src/whisper_logging_hook.rs: -------------------------------------------------------------------------------- 1 | use crate::common_logging::{ 2 | generic_debug, generic_error, generic_info, generic_trace, generic_warn, GGMLLogLevel, 3 | }; 4 | use core::ffi::{c_char, c_void}; 5 | use std::borrow::Cow; 6 | use std::ffi::CStr; 7 | use std::sync::Once; 8 | use whisper_rs_sys::ggml_log_level; 9 | 10 | static WHISPER_LOG_TRAMPOLINE_INSTALL: Once = Once::new(); 11 | pub(crate) fn install_whisper_logging_hook() { 12 | WHISPER_LOG_TRAMPOLINE_INSTALL.call_once(|| unsafe { 13 | whisper_rs_sys::whisper_log_set(Some(whisper_logging_trampoline), std::ptr::null_mut()) 14 | }); 15 | } 16 | 17 | unsafe extern "C" fn whisper_logging_trampoline( 18 | level: ggml_log_level, 19 | text: *const c_char, 20 | _: *mut c_void, // user_data 21 | ) { 22 | if text.is_null() { 23 | generic_error!("whisper_logging_trampoline: text is nullptr"); 24 | } 25 | let level = GGMLLogLevel::from(level); 26 | 27 | // SAFETY: we must trust whisper.cpp that it will not pass us a string that does not satisfy 28 | // from_ptr's requirements. 
29 | let log_str = unsafe { CStr::from_ptr(text) }.to_string_lossy(); 30 | 31 | whisper_logging_trampoline_safe(level, log_str) 32 | } 33 | 34 | // this code essentially compiles down to a noop if neither feature is enabled 35 | #[cfg_attr( 36 | not(any(feature = "log_backend", feature = "tracing_backend")), 37 | allow(unused_variables) 38 | )] 39 | fn whisper_logging_trampoline_safe(level: GGMLLogLevel, text: Cow) { 40 | match level { 41 | GGMLLogLevel::None => { 42 | // no clue what to do here, trace it? 43 | generic_trace!("{}", text.trim()); 44 | } 45 | GGMLLogLevel::Info => { 46 | generic_info!("{}", text.trim()); 47 | } 48 | GGMLLogLevel::Warn => { 49 | generic_warn!("{}", text.trim()); 50 | } 51 | GGMLLogLevel::Error => { 52 | generic_error!("{}", text.trim()); 53 | } 54 | GGMLLogLevel::Debug => { 55 | generic_debug!("{}", text.trim()); 56 | } 57 | GGMLLogLevel::Cont => { 58 | // this means continue previous log 59 | // storing state to do this is a massive pain so it's just a lot easier to not 60 | // plus as far as i can tell it's not actually *used* anywhere 61 | // whisper splits at 1024 chars and doesn't actually change the kind 62 | // so technically this is unused 63 | generic_trace!("{}", text.trim()); 64 | } 65 | GGMLLogLevel::Unknown(level) => { 66 | generic_warn!( 67 | "whisper_logging_trampoline: unknown log level {}: message: {}", 68 | level, 69 | text.trim() 70 | ); 71 | } 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/whisper_params.rs: -------------------------------------------------------------------------------- 1 | use crate::whisper_grammar::WhisperGrammarElement; 2 | use std::ffi::{c_char, c_float, c_int, CString}; 3 | use std::marker::PhantomData; 4 | use std::sync::Arc; 5 | use whisper_rs_sys::whisper_token; 6 | 7 | #[derive(Debug, Clone)] 8 | pub enum SamplingStrategy { 9 | Greedy { 10 | best_of: c_int, 11 | }, 12 | BeamSearch { 13 | beam_size: c_int, 14 | // not implemented 
in whisper.cpp as of this writing (v1.2.0) 15 | patience: c_float, 16 | }, 17 | } 18 | 19 | impl Default for SamplingStrategy { 20 | fn default() -> Self { 21 | Self::Greedy { best_of: 1 } 22 | } 23 | } 24 | 25 | #[derive(Debug, Clone)] 26 | pub struct SegmentCallbackData { 27 | pub segment: i32, 28 | pub start_timestamp: i64, 29 | pub end_timestamp: i64, 30 | pub text: String, 31 | } 32 | 33 | type SegmentCallbackFn = Box; 34 | 35 | #[derive(Clone)] 36 | pub struct FullParams<'a, 'b> { 37 | pub(crate) fp: whisper_rs_sys::whisper_full_params, 38 | phantom_lang: PhantomData<&'a str>, 39 | phantom_tokens: PhantomData<&'b [c_int]>, 40 | grammar: Option>, 41 | progess_callback_safe: Option>>, 42 | abort_callback_safe: Option bool>>>, 43 | segment_calllback_safe: Option>, 44 | } 45 | 46 | impl<'a, 'b> FullParams<'a, 'b> { 47 | /// Create a new set of parameters for the decoder. 48 | pub fn new(sampling_strategy: SamplingStrategy) -> FullParams<'a, 'b> { 49 | let mut fp = unsafe { 50 | whisper_rs_sys::whisper_full_default_params(match sampling_strategy { 51 | SamplingStrategy::Greedy { .. } => { 52 | whisper_rs_sys::whisper_sampling_strategy_WHISPER_SAMPLING_GREEDY 53 | } 54 | SamplingStrategy::BeamSearch { .. } => { 55 | whisper_rs_sys::whisper_sampling_strategy_WHISPER_SAMPLING_BEAM_SEARCH 56 | } 57 | } as _) 58 | }; 59 | 60 | match sampling_strategy { 61 | SamplingStrategy::Greedy { best_of } => { 62 | fp.greedy.best_of = best_of; 63 | } 64 | SamplingStrategy::BeamSearch { 65 | beam_size, 66 | patience, 67 | } => { 68 | fp.beam_search.beam_size = beam_size; 69 | fp.beam_search.patience = patience; 70 | } 71 | } 72 | 73 | Self { 74 | fp, 75 | phantom_lang: PhantomData, 76 | phantom_tokens: PhantomData, 77 | grammar: None, 78 | progess_callback_safe: None, 79 | abort_callback_safe: None, 80 | segment_calllback_safe: None, 81 | } 82 | } 83 | 84 | /// Set the number of threads to use for decoding. 85 | /// 86 | /// Defaults to min(4, std::thread::hardware_concurrency()). 
87 | pub fn set_n_threads(&mut self, n_threads: c_int) { 88 | self.fp.n_threads = n_threads; 89 | } 90 | 91 | /// Max tokens to use from past text as prompt for the decoder 92 | /// 93 | /// Defaults to 16384. 94 | pub fn set_n_max_text_ctx(&mut self, n_max_text_ctx: c_int) { 95 | self.fp.n_max_text_ctx = n_max_text_ctx; 96 | } 97 | 98 | /// Set the start offset in milliseconds to use for decoding. 99 | /// 100 | /// Defaults to 0. 101 | pub fn set_offset_ms(&mut self, offset_ms: c_int) { 102 | self.fp.offset_ms = offset_ms; 103 | } 104 | 105 | /// Set the audio duration to process in milliseconds. 106 | /// 107 | /// Defaults to 0. 108 | pub fn set_duration_ms(&mut self, duration_ms: c_int) { 109 | self.fp.duration_ms = duration_ms; 110 | } 111 | 112 | /// Set whether to translate the output to the language specified by `language`. 113 | /// 114 | /// Defaults to false. 115 | pub fn set_translate(&mut self, translate: bool) { 116 | self.fp.translate = translate; 117 | } 118 | 119 | /// Do not use past transcription (if any) as initial prompt for the decoder. 120 | /// 121 | /// Defaults to false. 122 | pub fn set_no_context(&mut self, no_context: bool) { 123 | self.fp.no_context = no_context; 124 | } 125 | 126 | /// Do not generate timestamps. 127 | /// 128 | /// Defaults to false. 129 | pub fn set_no_timestamps(&mut self, no_timestamps: bool) { 130 | self.fp.no_timestamps = no_timestamps; 131 | } 132 | 133 | /// Force single segment output. This may be useful for streaming. 134 | /// 135 | /// Defaults to false. 136 | pub fn set_single_segment(&mut self, single_segment: bool) { 137 | self.fp.single_segment = single_segment; 138 | } 139 | 140 | /// Print special tokens (e.g. ``, ``, ``, etc.) 141 | /// 142 | /// Defaults to false. 143 | pub fn set_print_special(&mut self, print_special: bool) { 144 | self.fp.print_special = print_special; 145 | } 146 | 147 | /// Set whether to print progress. 148 | /// 149 | /// Defaults to true. 
150 | pub fn set_print_progress(&mut self, print_progress: bool) { 151 | self.fp.print_progress = print_progress; 152 | } 153 | 154 | /// Print results from within whisper.cpp. 155 | /// Try to use the callback methods instead: [set_new_segment_callback](FullParams::set_new_segment_callback), 156 | /// [set_new_segment_callback_user_data](FullParams::set_new_segment_callback_user_data). 157 | /// 158 | /// Defaults to false. 159 | pub fn set_print_realtime(&mut self, print_realtime: bool) { 160 | self.fp.print_realtime = print_realtime; 161 | } 162 | 163 | /// Print timestamps for each text segment when printing realtime. Only has an effect if 164 | /// [set_print_realtime](FullParams::set_print_realtime) is set to true. 165 | /// 166 | /// Defaults to true. 167 | pub fn set_print_timestamps(&mut self, print_timestamps: bool) { 168 | self.fp.print_timestamps = print_timestamps; 169 | } 170 | 171 | /// # EXPERIMENTAL 172 | /// 173 | /// Enable token-level timestamps. 174 | /// 175 | /// Defaults to false. 176 | pub fn set_token_timestamps(&mut self, token_timestamps: bool) { 177 | self.fp.token_timestamps = token_timestamps; 178 | } 179 | 180 | /// # EXPERIMENTAL 181 | /// 182 | /// Set timestamp token probability threshold. 183 | /// 184 | /// Defaults to 0.01. 185 | pub fn set_thold_pt(&mut self, thold_pt: f32) { 186 | self.fp.thold_pt = thold_pt; 187 | } 188 | 189 | /// # EXPERIMENTAL 190 | /// 191 | /// Set timestamp token sum probability threshold. 192 | /// 193 | /// Defaults to 0.01. 194 | pub fn set_thold_ptsum(&mut self, thold_ptsum: f32) { 195 | self.fp.thold_ptsum = thold_ptsum; 196 | } 197 | 198 | /// # EXPERIMENTAL 199 | /// 200 | /// Set maximum segment length in characters. 201 | /// 202 | /// Defaults to 0. 203 | pub fn set_max_len(&mut self, max_len: c_int) { 204 | self.fp.max_len = max_len; 205 | } 206 | 207 | /// # EXPERIMENTAL 208 | /// 209 | /// Should the timestamps be split on words instead of characters? 210 | /// 211 | /// Defaults to false. 
212 | pub fn set_split_on_word(&mut self, split_on_word: bool) { 213 | self.fp.split_on_word = split_on_word; 214 | } 215 | 216 | /// # EXPERIMENTAL 217 | /// 218 | /// Set maximum tokens per segment. 0 means no limit. 219 | /// 220 | /// Defaults to 0. 221 | pub fn set_max_tokens(&mut self, max_tokens: c_int) { 222 | self.fp.max_tokens = max_tokens; 223 | } 224 | 225 | /// # EXPERIMENTAL 226 | /// 227 | /// Enables debug mode, such as dumping the log mel spectrogram. 228 | /// 229 | /// Defaults to false. 230 | pub fn set_debug_mode(&mut self, debug: bool) { 231 | self.fp.debug_mode = debug; 232 | } 233 | 234 | /// # EXPERIMENTAL 235 | /// 236 | /// Overwrite the audio context size. 0 = default. 237 | /// 238 | /// Defaults to 0. 239 | pub fn set_audio_ctx(&mut self, audio_ctx: c_int) { 240 | self.fp.audio_ctx = audio_ctx; 241 | } 242 | 243 | /// # EXPERIMENTAL 244 | /// 245 | /// Enable tinydiarize support. 246 | /// Experimental speaker turn detection. 247 | /// 248 | /// Defaults to false. 249 | pub fn set_tdrz_enable(&mut self, tdrz_enable: bool) { 250 | self.fp.tdrz_enable = tdrz_enable; 251 | } 252 | 253 | /// Set tokens to provide the model as initial input. 254 | /// 255 | /// These tokens are prepended to any existing text content from a previous call. 256 | /// 257 | /// Calling this more than once will overwrite the previous tokens. 258 | /// 259 | /// Defaults to an empty vector. 260 | pub fn set_tokens(&mut self, tokens: &'b [c_int]) { 261 | // turn into ptr and len 262 | let tokens_ptr: *const whisper_token = tokens.as_ptr(); 263 | let tokens_len: c_int = tokens.len() as c_int; 264 | 265 | // set the tokens 266 | self.fp.prompt_tokens = tokens_ptr; 267 | self.fp.prompt_n_tokens = tokens_len; 268 | } 269 | 270 | /// Set the target language. 271 | /// 272 | /// For auto-detection, set this to either "auto" or None. 273 | /// 274 | /// Defaults to "en". 
275 | pub fn set_language(&mut self, language: Option<&'a str>) { 276 | self.fp.language = match language { 277 | Some(language) => CString::new(language) 278 | .expect("Language contains null byte") 279 | .into_raw() as *const _, 280 | None => std::ptr::null(), 281 | }; 282 | } 283 | 284 | /// Set `detect_language`. 285 | /// 286 | /// Has the same effect as setting the language to "auto" or None. 287 | /// 288 | /// Defaults to false. 289 | pub fn set_detect_language(&mut self, detect_language: bool) { 290 | self.fp.detect_language = detect_language; 291 | } 292 | 293 | /// Set suppress_blank. 294 | /// See 295 | /// for more information. 296 | /// 297 | /// Defaults to true. 298 | pub fn set_suppress_blank(&mut self, suppress_blank: bool) { 299 | self.fp.suppress_blank = suppress_blank; 300 | } 301 | 302 | /// Set suppress_non_speech_tokens. 303 | /// See 304 | /// for more information. 305 | /// 306 | /// Defaults to false. 307 | pub fn set_suppress_nst(&mut self, suppress_nst: bool) { 308 | self.fp.suppress_nst = suppress_nst; 309 | } 310 | 311 | /// Set initial decoding temperature. 312 | /// See for more information. 313 | /// 314 | /// Defaults to 0.0. 315 | pub fn set_temperature(&mut self, temperature: f32) { 316 | self.fp.temperature = temperature; 317 | } 318 | 319 | /// Set max_initial_ts. 320 | /// See 321 | /// for more information. 322 | /// 323 | /// Defaults to 1.0. 324 | pub fn set_max_initial_ts(&mut self, max_initial_ts: f32) { 325 | self.fp.max_initial_ts = max_initial_ts; 326 | } 327 | 328 | /// Set length_penalty. 329 | /// See 330 | /// for more information. 331 | /// 332 | /// Defaults to -1.0. 333 | pub fn set_length_penalty(&mut self, length_penalty: f32) { 334 | self.fp.length_penalty = length_penalty; 335 | } 336 | 337 | /// Set temperature_inc. 338 | /// See 339 | /// for more information. 340 | /// 341 | /// Defaults to 0.2. 
342 | pub fn set_temperature_inc(&mut self, temperature_inc: f32) { 343 | self.fp.temperature_inc = temperature_inc; 344 | } 345 | 346 | /// Set entropy_thold. Similar to OpenAI's compression_ratio_threshold. 347 | /// See for more information. 348 | /// 349 | /// Defaults to 2.4. 350 | pub fn set_entropy_thold(&mut self, entropy_thold: f32) { 351 | self.fp.entropy_thold = entropy_thold; 352 | } 353 | 354 | /// Set logprob_thold. 355 | /// See 356 | /// for more information. 357 | /// 358 | /// Defaults to -1.0. 359 | pub fn set_logprob_thold(&mut self, logprob_thold: f32) { 360 | self.fp.logprob_thold = logprob_thold; 361 | } 362 | 363 | /// Set no_speech_thold. Currently (as of v1.3.0) not implemented. 364 | /// 365 | /// Defaults to 0.6. 366 | pub fn set_no_speech_thold(&mut self, no_speech_thold: f32) { 367 | self.fp.no_speech_thold = no_speech_thold; 368 | } 369 | 370 | /// Set the callback for new segments. 371 | /// 372 | /// Note that this callback has not been Rustified yet (and likely never will be, unless someone else feels the need to do so). 373 | /// It is still a C callback. 374 | /// 375 | /// # Safety 376 | /// Do not use this function unless you know what you are doing. 377 | /// * Be careful not to mutate the state of the whisper_context pointer returned in the callback. 378 | /// This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library. 379 | /// **Warning** Can't be used with DTW. DTW will produce inconsistent callback invocation 380 | /// 381 | /// Defaults to None. 382 | pub unsafe fn set_new_segment_callback( 383 | &mut self, 384 | new_segment_callback: crate::WhisperNewSegmentCallback, 385 | ) { 386 | self.fp.new_segment_callback = new_segment_callback; 387 | } 388 | 389 | /// Set the user data to be passed to the new segment callback. 390 | /// 391 | /// # Safety 392 | /// See the safety notes for `set_new_segment_callback`. 393 | /// **Warning** Can't be used with DTW. 
DTW will produce inconsistent callback invocation 394 | /// 395 | /// Defaults to None. 396 | pub unsafe fn set_new_segment_callback_user_data(&mut self, user_data: *mut std::ffi::c_void) { 397 | self.fp.new_segment_callback_user_data = user_data; 398 | } 399 | 400 | /// Set the callback for segment updates. 401 | /// 402 | /// Provides a limited segment_callback to ensure safety. 403 | /// See `set_new_segment_callback` if you need to use `whisper_context` and `whisper_state` 404 | /// **Warning** Can't be used with DTW. DTW will produce inconsistent callback invocation 405 | /// 406 | /// Defaults to None. 407 | pub fn set_segment_callback_safe(&mut self, closure: O) 408 | where 409 | F: FnMut(SegmentCallbackData) + 'static, 410 | O: Into>, 411 | { 412 | use std::ffi::{c_void, CStr}; 413 | use whisper_rs_sys::{whisper_context, whisper_state}; 414 | 415 | extern "C" fn trampoline( 416 | _: *mut whisper_context, 417 | state: *mut whisper_state, 418 | n_new: i32, 419 | user_data: *mut c_void, 420 | ) where 421 | F: FnMut(SegmentCallbackData) + 'static, 422 | { 423 | unsafe { 424 | let user_data = &mut *(user_data as *mut SegmentCallbackFn); 425 | let n_segments = whisper_rs_sys::whisper_full_n_segments_from_state(state); 426 | let s0 = n_segments - n_new; 427 | //let user_data = user_data as *mut Box; 428 | 429 | for i in s0..n_segments { 430 | let text = whisper_rs_sys::whisper_full_get_segment_text_from_state(state, i); 431 | let text = CStr::from_ptr(text); 432 | 433 | let t0 = whisper_rs_sys::whisper_full_get_segment_t0_from_state(state, i); 434 | let t1 = whisper_rs_sys::whisper_full_get_segment_t1_from_state(state, i); 435 | 436 | match text.to_str() { 437 | Ok(n) => user_data(SegmentCallbackData { 438 | segment: i, 439 | start_timestamp: t0, 440 | end_timestamp: t1, 441 | text: n.to_string(), 442 | }), 443 | Err(_) => {} 444 | } 445 | } 446 | } 447 | } 448 | 449 | match closure.into() { 450 | Some(closure) => { 451 | // Stable address 452 | let closure = 
Box::new(closure) as SegmentCallbackFn; 453 | // Thin pointer 454 | let closure = Box::new(closure); 455 | // Raw pointer 456 | let closure = Box::into_raw(closure); 457 | 458 | self.fp.new_segment_callback_user_data = closure as *mut c_void; 459 | self.fp.new_segment_callback = Some(trampoline::); 460 | self.segment_calllback_safe = None; 461 | } 462 | None => { 463 | self.segment_calllback_safe = None; 464 | self.fp.new_segment_callback = None; 465 | self.fp.new_segment_callback_user_data = std::ptr::null_mut::(); 466 | } 467 | } 468 | } 469 | 470 | /// Set the callback for segment updates. 471 | /// 472 | /// Provides a limited segment_callback to ensure safety with lossy handling of bad UTF-8 characters. 473 | /// See `set_new_segment_callback` if you need to use `whisper_context` and `whisper_state`. 474 | /// **Warning** Can't be used with DTW. DTW will produce inconsistent callback invocation 475 | /// 476 | /// Defaults to None. 477 | pub fn set_segment_callback_safe_lossy(&mut self, closure: O) 478 | where 479 | F: FnMut(SegmentCallbackData) + 'static, 480 | O: Into>, 481 | { 482 | use std::ffi::{c_void, CStr}; 483 | use whisper_rs_sys::{whisper_context, whisper_state}; 484 | 485 | extern "C" fn trampoline( 486 | _: *mut whisper_context, 487 | state: *mut whisper_state, 488 | n_new: i32, 489 | user_data: *mut c_void, 490 | ) where 491 | F: FnMut(SegmentCallbackData) + 'static, 492 | { 493 | unsafe { 494 | let user_data = &mut *(user_data as *mut SegmentCallbackFn); 495 | let n_segments = whisper_rs_sys::whisper_full_n_segments_from_state(state); 496 | let s0 = n_segments - n_new; 497 | //let user_data = user_data as *mut Box; 498 | 499 | for i in s0..n_segments { 500 | let text = whisper_rs_sys::whisper_full_get_segment_text_from_state(state, i); 501 | let text = CStr::from_ptr(text); 502 | 503 | let t0 = whisper_rs_sys::whisper_full_get_segment_t0_from_state(state, i); 504 | let t1 = whisper_rs_sys::whisper_full_get_segment_t1_from_state(state, i); 505 | 
user_data(SegmentCallbackData { 506 | segment: i, 507 | start_timestamp: t0, 508 | end_timestamp: t1, 509 | text: text.to_string_lossy().to_string(), 510 | }); 511 | } 512 | } 513 | } 514 | 515 | match closure.into() { 516 | Some(closure) => { 517 | // Stable address 518 | let closure = Box::new(closure) as SegmentCallbackFn; 519 | // Thin pointer 520 | let closure = Box::new(closure); 521 | // Raw pointer 522 | let closure = Box::into_raw(closure); 523 | 524 | self.fp.new_segment_callback_user_data = closure as *mut c_void; 525 | self.fp.new_segment_callback = Some(trampoline::); 526 | self.segment_calllback_safe = None; 527 | } 528 | None => { 529 | self.segment_calllback_safe = None; 530 | self.fp.new_segment_callback = None; 531 | self.fp.new_segment_callback_user_data = std::ptr::null_mut::(); 532 | } 533 | } 534 | } 535 | 536 | /// Set the callback for progress updates. 537 | /// 538 | /// Note that is still a C callback. 539 | /// See `set_progress_callback_safe` for a limited yet safe version. 540 | /// 541 | /// # Safety 542 | /// Do not use this function unless you know what you are doing. 543 | /// * Be careful not to mutate the state of the whisper_context pointer returned in the callback. 544 | /// This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library. 545 | /// 546 | /// Defaults to None. 547 | pub unsafe fn set_progress_callback( 548 | &mut self, 549 | progress_callback: crate::WhisperProgressCallback, 550 | ) { 551 | self.fp.progress_callback = progress_callback; 552 | } 553 | 554 | /// Set the callback for progress updates, potentially using a closure. 555 | /// 556 | /// Note that, in order to ensure safety, the callback only accepts the progress in percent. 557 | /// See `set_progress_callback` if you need to use `whisper_context` and `whisper_state` 558 | /// (or extend this one to support their use). 559 | /// 560 | /// Defaults to None. 
561 | pub fn set_progress_callback_safe(&mut self, closure: O) 562 | where 563 | F: FnMut(i32) + 'static, 564 | O: Into>, 565 | { 566 | use std::ffi::c_void; 567 | use whisper_rs_sys::{whisper_context, whisper_state}; 568 | 569 | unsafe extern "C" fn trampoline( 570 | _: *mut whisper_context, 571 | _: *mut whisper_state, 572 | progress: c_int, 573 | user_data: *mut c_void, 574 | ) where 575 | F: FnMut(i32), 576 | { 577 | let user_data = &mut *(user_data as *mut F); 578 | user_data(progress); 579 | } 580 | 581 | match closure.into() { 582 | Some(mut closure) => { 583 | self.fp.progress_callback = Some(trampoline::); 584 | self.fp.progress_callback_user_data = &mut closure as *mut F as *mut c_void; 585 | // store the closure internally to make sure that the pointer above remains valid 586 | self.progess_callback_safe = Some(Arc::new(Box::new(closure))); 587 | } 588 | None => { 589 | self.fp.progress_callback = None; 590 | self.fp.progress_callback_user_data = std::ptr::null_mut::(); 591 | self.progess_callback_safe = None; 592 | } 593 | } 594 | } 595 | 596 | /// Set the callback for abort conditions, potentially using a closure. 597 | /// 598 | /// Note that, for safety, the callback only accepts a function that returns a boolean 599 | /// indicating whether to abort or not. 600 | /// 601 | /// See `set_progress_callback` if you need to use `whisper_context` and `whisper_state`, 602 | /// or extend this one to support their use. 603 | /// 604 | /// Defaults to None. 
605 | pub fn set_abort_callback_safe(&mut self, closure: O) 606 | where 607 | F: FnMut() -> bool + 'static, 608 | O: Into>, 609 | { 610 | use std::ffi::c_void; 611 | 612 | unsafe extern "C" fn trampoline(user_data: *mut c_void) -> bool 613 | where 614 | F: FnMut() -> bool, 615 | { 616 | let user_data = &mut *(user_data as *mut F); 617 | user_data() 618 | } 619 | 620 | match closure.into() { 621 | Some(closure) => { 622 | // Stable address 623 | let closure = Box::new(closure) as Box bool>; 624 | // Thin pointer 625 | let closure = Box::new(closure); 626 | // Raw pointer 627 | let closure = Box::into_raw(closure); 628 | 629 | self.fp.abort_callback = Some(trampoline::); 630 | self.fp.abort_callback_user_data = closure as *mut c_void; 631 | self.abort_callback_safe = None; 632 | } 633 | None => { 634 | self.fp.abort_callback = None; 635 | self.fp.abort_callback_user_data = std::ptr::null_mut::(); 636 | self.abort_callback_safe = None; 637 | } 638 | } 639 | } 640 | 641 | /// Set the user data to be passed to the progress callback. 642 | /// 643 | /// # Safety 644 | /// See the safety notes for `set_progress_callback`. 645 | /// 646 | /// Defaults to None. 647 | pub unsafe fn set_progress_callback_user_data(&mut self, user_data: *mut std::ffi::c_void) { 648 | self.fp.progress_callback_user_data = user_data; 649 | } 650 | 651 | /// Set the callback that is called each time before the encoder begins. 652 | /// 653 | /// Note that this callback has not been Rustified yet (and likely never will be, unless someone else feels the need to do so). 654 | /// It is still a C callback. 655 | /// 656 | /// # Safety 657 | /// Do not use this function unless you know what you are doing. 658 | /// * Be careful not to mutate the state of the whisper_context pointer returned in the callback. 659 | /// This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library. 660 | /// 661 | /// Defaults to None. 
662 | pub unsafe fn set_start_encoder_callback( 663 | &mut self, 664 | start_encoder_callback: crate::WhisperStartEncoderCallback, 665 | ) { 666 | self.fp.encoder_begin_callback = start_encoder_callback; 667 | } 668 | 669 | /// Set the user data to be passed to the start encoder callback. 670 | /// 671 | /// # Safety 672 | /// See the safety notes for `set_start_encoder_callback`. 673 | /// 674 | /// Defaults to None. 675 | pub unsafe fn set_start_encoder_callback_user_data( 676 | &mut self, 677 | user_data: *mut std::ffi::c_void, 678 | ) { 679 | self.fp.encoder_begin_callback_user_data = user_data; 680 | } 681 | 682 | /// Set the callback that is called by each decoder to filter obtained logits. 683 | /// 684 | /// Note that this callback has not been Rustified yet (and likely never will be, unless someone else feels the need to do so). 685 | /// It is still a C callback. 686 | /// 687 | /// # Safety 688 | /// Do not use this function unless you know what you are doing. 689 | /// * Be careful not to mutate the state of the whisper_context pointer returned in the callback. 690 | /// This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library. 691 | /// 692 | /// Defaults to None. 693 | pub unsafe fn set_filter_logits_callback( 694 | &mut self, 695 | logits_filter_callback: crate::WhisperLogitsFilterCallback, 696 | ) { 697 | self.fp.logits_filter_callback = logits_filter_callback; 698 | } 699 | 700 | /// Set the user data to be passed to the logits filter callback. 701 | /// 702 | /// # Safety 703 | /// See the safety notes for `set_filter_logits_callback`. 704 | /// 705 | /// Defaults to None. 706 | pub unsafe fn set_filter_logits_callback_user_data( 707 | &mut self, 708 | user_data: *mut std::ffi::c_void, 709 | ) { 710 | self.fp.logits_filter_callback_user_data = user_data; 711 | } 712 | 713 | /// Set the callback that is called each time before ggml computation starts. 
714 | /// 715 | /// Note that this callback has not been Rustified yet (and likely never will be, unless someone else feels the need to do so). 716 | /// It is still a C callback. 717 | /// 718 | /// # Safety 719 | /// Do not use this function unless you know what you are doing. 720 | /// * Be careful not to mutate the state of the whisper_context pointer returned in the callback. 721 | /// This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library. 722 | /// 723 | /// Defaults to None. 724 | pub unsafe fn set_abort_callback(&mut self, abort_callback: crate::WhisperAbortCallback) { 725 | self.fp.abort_callback = abort_callback; 726 | } 727 | 728 | /// Set the user data to be passed to the abort callback. 729 | /// 730 | /// # Safety 731 | /// See the safety notes for `set_abort_callback`. 732 | /// 733 | /// Defaults to None. 734 | pub unsafe fn set_abort_callback_user_data(&mut self, user_data: *mut std::ffi::c_void) { 735 | self.fp.abort_callback_user_data = user_data; 736 | } 737 | 738 | /// Enable an array of grammar elements to be passed to the whisper model. 739 | /// 740 | /// Defaults to an empty vector. 741 | pub fn set_grammar(&mut self, grammar: Option<&[WhisperGrammarElement]>) { 742 | if let Some(grammar) = grammar { 743 | // convert to c types 744 | let inner = grammar.iter().map(|e| e.to_c_type()).collect::>(); 745 | // turn into ptr and len 746 | let grammar_ptr = inner.as_ptr() as *mut _; 747 | let grammar_len = inner.len(); 748 | 749 | self.grammar = Some(inner); 750 | 751 | // set the grammar 752 | self.fp.grammar_rules = grammar_ptr; 753 | self.fp.n_grammar_rules = grammar_len; 754 | } else { 755 | self.grammar = None; 756 | self.fp.grammar_rules = std::ptr::null_mut(); 757 | self.fp.n_grammar_rules = 0; 758 | self.fp.i_start_rule = 0; 759 | } 760 | } 761 | 762 | /// Set the start grammar rule. Does nothing if no grammar is set. 763 | /// 764 | /// Defaults to 0. 
765 | pub fn set_start_rule(&mut self, start_rule: usize) { 766 | if self.grammar.is_some() { 767 | self.fp.i_start_rule = start_rule; 768 | } 769 | } 770 | 771 | /// Set grammar penalty. 772 | /// 773 | /// Defaults to 100.0. 774 | pub fn set_grammar_penalty(&mut self, grammar_penalty: f32) { 775 | self.fp.grammar_penalty = grammar_penalty; 776 | } 777 | 778 | /// Set the initial prompt for the model. 779 | /// 780 | /// This is the text that will be used as the starting point for the model's decoding. 781 | /// Calling this more than once will overwrite the previous initial prompt. 782 | /// 783 | /// # Arguments 784 | /// * `initial_prompt` - A string slice representing the initial prompt text. 785 | /// 786 | /// # Panics 787 | /// This method will panic if `initial_prompt` contains a null byte, as it cannot be converted into a `CString`. 788 | /// 789 | /// # Examples 790 | /// ``` 791 | /// # use whisper_rs::{FullParams, SamplingStrategy}; 792 | /// let mut params = FullParams::new(SamplingStrategy::default()); 793 | /// params.set_initial_prompt("Hello, world!"); 794 | /// // ... further usage of params ... 
795 | /// ``` 796 | pub fn set_initial_prompt(&mut self, initial_prompt: &str) { 797 | self.fp.initial_prompt = CString::new(initial_prompt) 798 | .expect("Initial prompt contains null byte") 799 | .into_raw() as *const c_char; 800 | } 801 | } 802 | 803 | // following implementations are safe 804 | // see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388 805 | // concurrent usage is prevented by &mut self on methods that modify the struct 806 | unsafe impl Send for FullParams<'_, '_> {} 807 | unsafe impl Sync for FullParams<'_, '_> {} 808 | 809 | #[cfg(test)] 810 | mod test_whisper_params_initial_prompt { 811 | use super::*; 812 | 813 | impl<'a, 'b> FullParams<'a, 'b> { 814 | pub fn get_initial_prompt(&self) -> &str { 815 | // SAFETY: Ensure this is safe and respects the lifetime of the string in self.fp 816 | unsafe { 817 | std::ffi::CStr::from_ptr(self.fp.initial_prompt) 818 | .to_str() 819 | .unwrap() 820 | } 821 | } 822 | } 823 | 824 | #[test] 825 | fn test_initial_prompt_normal_usage() { 826 | let mut params = FullParams::new(SamplingStrategy::default()); 827 | let prompt = "Hello, world!"; 828 | params.set_initial_prompt(prompt); 829 | assert_eq!(params.get_initial_prompt(), prompt); 830 | } 831 | 832 | #[test] 833 | #[should_panic(expected = "Initial prompt contains null byte")] 834 | fn test_initial_prompt_null_byte() { 835 | let mut params = FullParams::new(SamplingStrategy::default()); 836 | let prompt = "Hello\0, world!"; 837 | params.set_initial_prompt(prompt); 838 | // Should panic 839 | } 840 | 841 | #[test] 842 | fn test_initial_prompt_empty_string() { 843 | let mut params = FullParams::new(SamplingStrategy::default()); 844 | let prompt = ""; 845 | params.set_initial_prompt(prompt); 846 | 847 | assert_eq!( 848 | params.get_initial_prompt(), 849 | prompt, 850 | "The initial prompt should be an empty string." 
851 | ); 852 | } 853 | 854 | #[test] 855 | fn test_initial_prompt_repeated_calls() { 856 | let mut params = FullParams::new(SamplingStrategy::default()); 857 | params.set_initial_prompt("First prompt"); 858 | assert_eq!( 859 | params.get_initial_prompt(), 860 | "First prompt", 861 | "The initial prompt should be 'First prompt'." 862 | ); 863 | 864 | params.set_initial_prompt("Second prompt"); 865 | assert_eq!( 866 | params.get_initial_prompt(), 867 | "Second prompt", 868 | "The initial prompt should be 'Second prompt' after second set." 869 | ); 870 | } 871 | 872 | #[test] 873 | fn test_initial_prompt_long_string() { 874 | let mut params = FullParams::new(SamplingStrategy::default()); 875 | let long_prompt = "a".repeat(10000); // a long string of 10,000 'a' characters 876 | params.set_initial_prompt(&long_prompt); 877 | 878 | assert_eq!( 879 | params.get_initial_prompt(), 880 | long_prompt.as_str(), 881 | "The initial prompt should match the long string provided." 882 | ); 883 | } 884 | } 885 | -------------------------------------------------------------------------------- /src/whisper_state.rs: -------------------------------------------------------------------------------- 1 | use std::ffi::{c_int, CStr}; 2 | use std::sync::Arc; 3 | 4 | use crate::{FullParams, WhisperError, WhisperInnerContext, WhisperToken, WhisperTokenData}; 5 | 6 | /// Rustified pointer to a Whisper state. 
#[derive(Debug)]
pub struct WhisperState {
    ctx: Arc<WhisperInnerContext>,
    ptr: *mut whisper_rs_sys::whisper_state,
}

unsafe impl Send for WhisperState {}

unsafe impl Sync for WhisperState {}

impl Drop for WhisperState {
    fn drop(&mut self) {
        // Assumes `ptr` came from crate code via `WhisperState::new` and is released
        // nowhere else — TODO confirm against all callers of `new`.
        unsafe { whisper_rs_sys::whisper_free_state(self.ptr) };
    }
}

impl WhisperState {
    /// Bundle a raw whisper.cpp state pointer with the shared context handle.
    pub(crate) fn new(
        ctx: Arc<WhisperInnerContext>,
        ptr: *mut whisper_rs_sys::whisper_state,
    ) -> Self {
        Self { ctx, ptr }
    }

    /// Turn raw 32-bit float PCM audio into a log mel spectrogram.
    /// The spectrogram is kept inside this state; nothing is returned to the caller.
    ///
    /// # Arguments
    /// * pcm: The raw PCM audio.
    /// * threads: How many threads to use. Must be at least 1; an error is returned otherwise.
    ///
    /// # Returns
    /// Ok(()) on success, Err(WhisperError) on failure.
    ///
    /// # C++ equivalent
    /// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)`
    pub fn pcm_to_mel(&mut self, pcm: &[f32], threads: usize) -> Result<(), WhisperError> {
        if threads == 0 {
            return Err(WhisperError::InvalidThreadCount);
        }

        let n_samples = pcm.len() as c_int;
        let n_threads = threads as c_int;
        let ret = unsafe {
            whisper_rs_sys::whisper_pcm_to_mel_with_state(
                self.ctx.ctx,
                self.ptr,
                pcm.as_ptr(),
                n_samples,
                n_threads,
            )
        };

        match ret {
            0 => Ok(()),
            -1 => Err(WhisperError::UnableToCalculateSpectrogram),
            code => Err(WhisperError::GenericError(code)),
        }
    }

    /// Install a caller-supplied log mel spectrogram in this whisper state.
    /// Use it instead of whisper_pcm_to_mel() when you compute the spectrogram yourself.
    ///
    /// # Note
    /// This is a low-level function.
    /// If you're a typical user, you probably don't want to use this function.
73 | /// See instead [WhisperState::pcm_to_mel]. 74 | /// 75 | /// # Arguments 76 | /// * data: The log mel spectrogram. 77 | /// 78 | /// # Returns 79 | /// Ok(()) on success, Err(WhisperError) on failure. 80 | /// 81 | /// # C++ equivalent 82 | /// `int whisper_set_mel(struct whisper_context * ctx, const float * data, int n_len, int n_mel)` 83 | pub fn set_mel(&mut self, data: &[f32]) -> Result<(), WhisperError> { 84 | let hop_size = 160; 85 | let n_len = (data.len() / hop_size) * 2; 86 | let ret = unsafe { 87 | whisper_rs_sys::whisper_set_mel_with_state( 88 | self.ctx.ctx, 89 | self.ptr, 90 | data.as_ptr(), 91 | n_len as c_int, 92 | 80 as c_int, 93 | ) 94 | }; 95 | if ret == -1 { 96 | Err(WhisperError::InvalidMelBands) 97 | } else if ret == 0 { 98 | Ok(()) 99 | } else { 100 | Err(WhisperError::GenericError(ret)) 101 | } 102 | } 103 | 104 | /// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper state. 105 | /// Make sure to call [WhisperState::pcm_to_mel] or [WhisperState::set_mel] first. 106 | /// 107 | /// # Arguments 108 | /// * offset: Can be used to specify the offset of the first frame in the spectrogram. Usually 0. 109 | /// * threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise. 110 | /// 111 | /// # Returns 112 | /// Ok(()) on success, Err(WhisperError) on failure. 
113 | /// 114 | /// # C++ equivalent 115 | /// `int whisper_encode(struct whisper_context * ctx, int offset, int n_threads)` 116 | pub fn encode(&mut self, offset: usize, threads: usize) -> Result<(), WhisperError> { 117 | if threads < 1 { 118 | return Err(WhisperError::InvalidThreadCount); 119 | } 120 | let ret = unsafe { 121 | whisper_rs_sys::whisper_encode_with_state( 122 | self.ctx.ctx, 123 | self.ptr, 124 | offset as c_int, 125 | threads as c_int, 126 | ) 127 | }; 128 | if ret == -1 { 129 | Err(WhisperError::UnableToCalculateEvaluation) 130 | } else if ret == 0 { 131 | Ok(()) 132 | } else { 133 | Err(WhisperError::GenericError(ret)) 134 | } 135 | } 136 | 137 | /// Run the Whisper decoder to obtain the logits and probabilities for the next token. 138 | /// Make sure to call [WhisperState::encode] first. 139 | /// tokens + n_tokens is the provided context for the decoder. 140 | /// 141 | /// # Arguments 142 | /// * tokens: The tokens to decode. 143 | /// * n_tokens: The number of tokens to decode. 144 | /// * n_past: The number of past tokens to use for the decoding. 145 | /// * n_threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise. 146 | /// 147 | /// # Returns 148 | /// Ok(()) on success, Err(WhisperError) on failure. 
149 | /// 150 | /// # C++ equivalent 151 | /// `int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads)` 152 | pub fn decode( 153 | &mut self, 154 | tokens: &[WhisperToken], 155 | n_past: usize, 156 | threads: usize, 157 | ) -> Result<(), WhisperError> { 158 | if threads < 1 { 159 | return Err(WhisperError::InvalidThreadCount); 160 | } 161 | let ret = unsafe { 162 | whisper_rs_sys::whisper_decode_with_state( 163 | self.ctx.ctx, 164 | self.ptr, 165 | tokens.as_ptr(), 166 | tokens.len() as c_int, 167 | n_past as c_int, 168 | threads as c_int, 169 | ) 170 | }; 171 | if ret == -1 { 172 | Err(WhisperError::UnableToCalculateEvaluation) 173 | } else if ret == 0 { 174 | Ok(()) 175 | } else { 176 | Err(WhisperError::GenericError(ret)) 177 | } 178 | } 179 | 180 | // Language functions 181 | /// Use mel data at offset_ms to try and auto-detect the spoken language 182 | /// Make sure to call pcm_to_mel() or set_mel() first 183 | /// 184 | /// # Arguments 185 | /// * offset_ms: The offset in milliseconds to use for the language detection. 186 | /// * n_threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise. 187 | /// 188 | /// # Returns 189 | /// `Ok((i32, Vec))` on success where the i32 is detected language id and Vec 190 | /// is array with the probabilities of all languages, `Err(WhisperError)` on failure. 
191 | /// 192 | /// # C++ equivalent 193 | /// `int whisper_lang_auto_detect(struct whisper_context * ctx, int offset_ms, int n_threads, float * lang_probs)` 194 | pub fn lang_detect( 195 | &self, 196 | offset_ms: usize, 197 | threads: usize, 198 | ) -> Result<(i32, Vec), WhisperError> { 199 | if threads < 1 { 200 | return Err(WhisperError::InvalidThreadCount); 201 | } 202 | 203 | let mut lang_probs: Vec = vec![0.0; crate::standalone::get_lang_max_id() as usize + 1]; 204 | let ret = unsafe { 205 | whisper_rs_sys::whisper_lang_auto_detect_with_state( 206 | self.ctx.ctx, 207 | self.ptr, 208 | offset_ms as c_int, 209 | threads as c_int, 210 | lang_probs.as_mut_ptr(), 211 | ) 212 | }; 213 | if ret < 0 { 214 | Err(WhisperError::GenericError(ret)) 215 | } else { 216 | Ok((ret as i32, lang_probs)) 217 | } 218 | } 219 | 220 | // logit functions 221 | /// Gets logits obtained from the last call to [WhisperState::decode]. 222 | /// As of whisper.cpp 1.4.1, only a single row of logits is available, corresponding to the last token in the input. 223 | /// 224 | /// # Returns 225 | /// A slice of logits with length equal to n_vocab. 226 | /// 227 | /// # C++ equivalent 228 | /// `float * whisper_get_logits(struct whisper_context * ctx)` 229 | pub fn get_logits(&self) -> Result<&[f32], WhisperError> { 230 | let ret = unsafe { whisper_rs_sys::whisper_get_logits_from_state(self.ptr) }; 231 | if ret.is_null() { 232 | return Err(WhisperError::NullPointer); 233 | } 234 | let n_vocab = self.n_vocab(); 235 | Ok(unsafe { std::slice::from_raw_parts(ret, n_vocab as usize) }) 236 | } 237 | 238 | // model attributes 239 | /// Get the mel spectrogram length. 240 | /// 241 | /// # Returns 242 | /// Ok(c_int) on success, Err(WhisperError) on failure. 
243 | /// 244 | /// # C++ equivalent 245 | /// `int whisper_n_len_from_state(struct whisper_context * ctx)` 246 | #[inline] 247 | pub fn n_len(&self) -> Result { 248 | Ok(unsafe { whisper_rs_sys::whisper_n_len_from_state(self.ptr) }) 249 | } 250 | 251 | /// Get n_vocab. 252 | /// 253 | /// # Returns 254 | /// c_int 255 | /// 256 | /// # C++ equivalent 257 | /// `int whisper_n_vocab (struct whisper_context * ctx)` 258 | #[inline] 259 | pub fn n_vocab(&self) -> c_int { 260 | unsafe { whisper_rs_sys::whisper_n_vocab(self.ctx.ctx) } 261 | } 262 | 263 | /// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text 264 | /// Uses the specified decoding strategy to obtain the text. 265 | /// 266 | /// This is usually the only function you need to call as an end user. 267 | /// 268 | /// # Arguments 269 | /// * params: [crate::FullParams] struct. 270 | /// * pcm: raw PCM audio data, 32 bit floating point at a sample rate of 16 kHz, 1 channel. 271 | /// See utilities in the root of this crate for functions to convert audio to this format. 272 | /// 273 | /// # Returns 274 | /// Ok(c_int) on success, Err(WhisperError) on failure. 
275 | /// 276 | /// # C++ equivalent 277 | /// `int whisper_full(struct whisper_context * ctx, struct whisper_full_params params, const float * samples, int n_samples)` 278 | pub fn full(&mut self, params: FullParams, data: &[f32]) -> Result { 279 | if data.is_empty() { 280 | // can randomly trigger segmentation faults if we don't check this 281 | return Err(WhisperError::NoSamples); 282 | } 283 | 284 | let ret = unsafe { 285 | whisper_rs_sys::whisper_full_with_state( 286 | self.ctx.ctx, 287 | self.ptr, 288 | params.fp, 289 | data.as_ptr(), 290 | data.len() as c_int, 291 | ) 292 | }; 293 | if ret == -1 { 294 | Err(WhisperError::UnableToCalculateSpectrogram) 295 | } else if ret == 7 { 296 | Err(WhisperError::FailedToEncode) 297 | } else if ret == 8 { 298 | Err(WhisperError::FailedToDecode) 299 | } else if ret == 0 { 300 | Ok(ret) 301 | } else { 302 | Err(WhisperError::GenericError(ret)) 303 | } 304 | } 305 | 306 | /// Number of generated text segments. 307 | /// A segment can be a few words, a sentence, or even a paragraph. 308 | /// 309 | /// # C++ equivalent 310 | /// `int whisper_full_n_segments(struct whisper_context * ctx)` 311 | #[inline] 312 | pub fn full_n_segments(&self) -> Result { 313 | Ok(unsafe { whisper_rs_sys::whisper_full_n_segments_from_state(self.ptr) }) 314 | } 315 | 316 | /// Language ID associated with the provided state. 317 | /// 318 | /// # C++ equivalent 319 | /// `int whisper_full_lang_id_from_state(struct whisper_state * state);` 320 | #[inline] 321 | pub fn full_lang_id_from_state(&self) -> Result { 322 | Ok(unsafe { whisper_rs_sys::whisper_full_lang_id_from_state(self.ptr) }) 323 | } 324 | 325 | /// Get the start time of the specified segment. 326 | /// 327 | /// # Arguments 328 | /// * segment: Segment index. 
329 | /// 330 | /// # C++ equivalent 331 | /// `int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment)` 332 | #[inline] 333 | pub fn full_get_segment_t0(&self, segment: c_int) -> Result { 334 | Ok(unsafe { whisper_rs_sys::whisper_full_get_segment_t0_from_state(self.ptr, segment) }) 335 | } 336 | 337 | /// Get the end time of the specified segment. 338 | /// 339 | /// # Arguments 340 | /// * segment: Segment index. 341 | /// 342 | /// # C++ equivalent 343 | /// `int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)` 344 | #[inline] 345 | pub fn full_get_segment_t1(&self, segment: c_int) -> Result { 346 | Ok(unsafe { whisper_rs_sys::whisper_full_get_segment_t1_from_state(self.ptr, segment) }) 347 | } 348 | 349 | fn full_get_segment_raw(&self, segment: c_int) -> Result<&CStr, WhisperError> { 350 | let ret = 351 | unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(self.ptr, segment) }; 352 | if ret.is_null() { 353 | return Err(WhisperError::NullPointer); 354 | } 355 | unsafe { Ok(CStr::from_ptr(ret)) } 356 | } 357 | 358 | /// Get the raw bytes of the specified segment. 359 | /// 360 | /// # Arguments 361 | /// * segment: Segment index. 362 | /// 363 | /// # Returns 364 | /// `Ok(Vec)` on success, with the returned bytes or 365 | /// `Err(WhisperError::NullPointer)` on failure (this is the only possible error) 366 | /// 367 | /// # C++ equivalent 368 | /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)` 369 | pub fn full_get_segment_bytes(&self, segment: c_int) -> Result, WhisperError> { 370 | Ok(self.full_get_segment_raw(segment)?.to_bytes().to_vec()) 371 | } 372 | 373 | /// Get the text of the specified segment. 374 | /// 375 | /// # Arguments 376 | /// * segment: Segment index. 
377 | /// 378 | /// # Returns 379 | /// `Ok(String)` on success, with the UTF-8 validated string, or 380 | /// `Err(WhisperError)` on failure (either `NullPointer` or `InvalidUtf8`) 381 | /// 382 | /// # C++ equivalent 383 | /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)` 384 | pub fn full_get_segment_text(&self, segment: c_int) -> Result { 385 | Ok(self.full_get_segment_raw(segment)?.to_str()?.to_string()) 386 | } 387 | 388 | /// Get the text of the specified segment. 389 | /// This function differs from [WhisperState::full_get_segment_text] 390 | /// in that it ignores invalid UTF-8 in whisper strings, 391 | /// instead opting to replace it with the replacement character. 392 | /// 393 | /// # Arguments 394 | /// * segment: Segment index. 395 | /// 396 | /// # Returns 397 | /// `Ok(String)` on success, or 398 | /// `Err(WhisperError::NullPointer)` on failure (this is the only possible error) 399 | /// 400 | /// # C++ equivalent 401 | /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)` 402 | pub fn full_get_segment_text_lossy(&self, segment: c_int) -> Result { 403 | Ok(self 404 | .full_get_segment_raw(segment)? 405 | .to_string_lossy() 406 | .to_string()) 407 | } 408 | 409 | /// Get number of tokens in the specified segment. 410 | /// 411 | /// # Arguments 412 | /// * segment: Segment index. 
413 | /// 414 | /// # Returns 415 | /// c_int 416 | /// 417 | /// # C++ equivalent 418 | /// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)` 419 | #[inline] 420 | pub fn full_n_tokens(&self, segment: c_int) -> Result { 421 | Ok(unsafe { whisper_rs_sys::whisper_full_n_tokens_from_state(self.ptr, segment) }) 422 | } 423 | 424 | fn full_get_token_raw(&self, segment: c_int, token: c_int) -> Result<&CStr, WhisperError> { 425 | let ret = unsafe { 426 | whisper_rs_sys::whisper_full_get_token_text_from_state( 427 | self.ctx.ctx, 428 | self.ptr, 429 | segment, 430 | token, 431 | ) 432 | }; 433 | if ret.is_null() { 434 | return Err(WhisperError::NullPointer); 435 | } 436 | unsafe { Ok(CStr::from_ptr(ret)) } 437 | } 438 | 439 | /// Get the raw token bytes of the specified token in the specified segment. 440 | /// 441 | /// Useful if you're using a language for which whisper is known to split tokens 442 | /// away from UTF-8 character boundaries. 443 | /// 444 | /// # Arguments 445 | /// * segment: Segment index. 446 | /// * token: Token index. 447 | /// 448 | /// # Returns 449 | /// `Ok(Vec)` on success, with the returned bytes or 450 | /// `Err(WhisperError::NullPointer)` on failure (this is the only possible error) 451 | /// 452 | /// # C++ equivalent 453 | /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` 454 | pub fn full_get_token_bytes( 455 | &self, 456 | segment: c_int, 457 | token: c_int, 458 | ) -> Result, WhisperError> { 459 | Ok(self.full_get_token_raw(segment, token)?.to_bytes().to_vec()) 460 | } 461 | 462 | /// Get the token text of the specified token in the specified segment. 463 | /// 464 | /// # Arguments 465 | /// * segment: Segment index. 466 | /// * token: Token index. 
467 | /// 468 | /// # Returns 469 | /// `Ok(String)` on success, with the UTF-8 validated string, or 470 | /// `Err(WhisperError)` on failure (either `NullPointer` or `InvalidUtf8`) 471 | /// 472 | /// # C++ equivalent 473 | /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` 474 | pub fn full_get_token_text( 475 | &self, 476 | segment: c_int, 477 | token: c_int, 478 | ) -> Result { 479 | Ok(self 480 | .full_get_token_raw(segment, token)? 481 | .to_str()? 482 | .to_string()) 483 | } 484 | 485 | /// Get the token text of the specified token in the specified segment. 486 | /// This function differs from [WhisperState::full_get_token_text] 487 | /// in that it ignores invalid UTF-8 in whisper strings, 488 | /// instead opting to replace it with the replacement character. 489 | /// 490 | /// # Arguments 491 | /// * segment: Segment index. 492 | /// * token: Token index. 493 | /// 494 | /// # Returns 495 | /// `Ok(String)` on success, or 496 | /// `Err(WhisperError::NullPointer)` on failure (this is the only possible error) 497 | /// 498 | /// # C++ equivalent 499 | /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` 500 | pub fn full_get_token_text_lossy( 501 | &self, 502 | segment: c_int, 503 | token: c_int, 504 | ) -> Result { 505 | Ok(self 506 | .full_get_token_raw(segment, token)? 507 | .to_string_lossy() 508 | .to_string()) 509 | } 510 | 511 | /// Get the token ID of the specified token in the specified segment. 512 | /// 513 | /// # Arguments 514 | /// * segment: Segment index. 515 | /// * token: Token index. 
516 | /// 517 | /// # Returns 518 | /// [crate::WhisperToken] 519 | /// 520 | /// # C++ equivalent 521 | /// `whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token)` 522 | pub fn full_get_token_id( 523 | &self, 524 | segment: c_int, 525 | token: c_int, 526 | ) -> Result { 527 | Ok(unsafe { 528 | whisper_rs_sys::whisper_full_get_token_id_from_state(self.ptr, segment, token) 529 | }) 530 | } 531 | 532 | /// Get token data for the specified token in the specified segment. 533 | /// 534 | /// # Arguments 535 | /// * segment: Segment index. 536 | /// * token: Token index. 537 | /// 538 | /// # Returns 539 | /// [crate::WhisperTokenData] 540 | /// 541 | /// # C++ equivalent 542 | /// `whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token)` 543 | #[inline] 544 | pub fn full_get_token_data( 545 | &self, 546 | segment: c_int, 547 | token: c_int, 548 | ) -> Result { 549 | Ok(unsafe { 550 | whisper_rs_sys::whisper_full_get_token_data_from_state(self.ptr, segment, token) 551 | }) 552 | } 553 | 554 | /// Get the probability of the specified token in the specified segment. 555 | /// 556 | /// # Arguments 557 | /// * segment: Segment index. 558 | /// * token: Token index. 559 | /// 560 | /// # Returns 561 | /// f32 562 | /// 563 | /// # C++ equivalent 564 | /// `float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token)` 565 | #[inline] 566 | pub fn full_get_token_prob(&self, segment: c_int, token: c_int) -> Result { 567 | Ok( 568 | unsafe { 569 | whisper_rs_sys::whisper_full_get_token_p_from_state(self.ptr, segment, token) 570 | }, 571 | ) 572 | } 573 | 574 | /// Get whether the next segment is predicted as a speaker turn. 575 | /// 576 | /// # Arguments 577 | /// * i_segment: Segment index. 
578 | /// 579 | /// # Returns 580 | /// bool 581 | /// 582 | /// # C++ equivalent 583 | /// `bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment)` 584 | pub fn full_get_segment_speaker_turn_next(&mut self, i_segment: c_int) -> bool { 585 | unsafe { 586 | whisper_rs_sys::whisper_full_get_segment_speaker_turn_next_from_state( 587 | self.ptr, i_segment, 588 | ) 589 | } 590 | } 591 | } 592 | -------------------------------------------------------------------------------- /sys/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "whisper-rs-sys" 3 | version = "0.12.1" 4 | edition = "2021" 5 | description = "Rust bindings for whisper.cpp (FFI bindings)" 6 | license = "Unlicense" 7 | documentation = "https://docs.rs/whisper-rs-sys" 8 | repository = "https://github.com/tazz4843/whisper-rs" 9 | links = "whisper" 10 | include = [ 11 | "whisper.cpp/bindings/javascript/package-tmpl.json", 12 | "whisper.cpp/bindings/CMakeLists.txt", 13 | "whisper.cpp/CMakeLists.txt", 14 | "whisper.cpp/cmake", 15 | "whisper.cpp/src/**", 16 | "whisper.cpp/include/whisper.h", 17 | "whisper.cpp/ggml/cmake", 18 | "whisper.cpp/ggml/CMakeLists.txt", 19 | "whisper.cpp/ggml/src/**", 20 | "whisper.cpp/ggml/include/*.h", 21 | "whisper.cpp/LICENSE", 22 | "src/*.rs", 23 | "build.rs", 24 | "wrapper.h", 25 | ] 26 | 27 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 28 | 29 | [features] 30 | coreml = [] 31 | cuda = [] 32 | hipblas = [] 33 | openblas = [] 34 | metal = [] 35 | vulkan = [] 36 | force-debug = [] 37 | openmp = [] 38 | 39 | [build-dependencies] 40 | cmake = "0.1" 41 | bindgen = "0.71" 42 | cfg-if = "1" 43 | fs_extra = "1.3" 44 | -------------------------------------------------------------------------------- /sys/build.rs: -------------------------------------------------------------------------------- 1 | 
#![allow(clippy::uninlined_format_args)] 2 | 3 | extern crate bindgen; 4 | 5 | use cmake::Config; 6 | use std::env; 7 | use std::fs::File; 8 | use std::io::{BufRead, BufReader}; 9 | use std::path::PathBuf; 10 | 11 | fn main() { 12 | let target = env::var("TARGET").unwrap(); 13 | // Link C++ standard library 14 | if let Some(cpp_stdlib) = get_cpp_link_stdlib(&target) { 15 | println!("cargo:rustc-link-lib=dylib={}", cpp_stdlib); 16 | } 17 | // Link macOS Accelerate framework for matrix calculations 18 | if target.contains("apple") { 19 | println!("cargo:rustc-link-lib=framework=Accelerate"); 20 | #[cfg(feature = "coreml")] 21 | { 22 | println!("cargo:rustc-link-lib=framework=Foundation"); 23 | println!("cargo:rustc-link-lib=framework=CoreML"); 24 | } 25 | #[cfg(feature = "metal")] 26 | { 27 | println!("cargo:rustc-link-lib=framework=Foundation"); 28 | println!("cargo:rustc-link-lib=framework=Metal"); 29 | println!("cargo:rustc-link-lib=framework=MetalKit"); 30 | } 31 | } 32 | 33 | #[cfg(feature = "coreml")] 34 | println!("cargo:rustc-link-lib=static=whisper.coreml"); 35 | 36 | #[cfg(feature = "openblas")] 37 | { 38 | if let Ok(openblas_path) = env::var("OPENBLAS_PATH") { 39 | println!( 40 | "cargo::rustc-link-search={}", 41 | PathBuf::from(openblas_path).join("lib").display() 42 | ); 43 | } 44 | if cfg!(windows) { 45 | println!("cargo:rustc-link-lib=libopenblas"); 46 | } else { 47 | println!("cargo:rustc-link-lib=openblas"); 48 | } 49 | } 50 | #[cfg(feature = "cuda")] 51 | { 52 | println!("cargo:rustc-link-lib=cublas"); 53 | println!("cargo:rustc-link-lib=cudart"); 54 | println!("cargo:rustc-link-lib=cublasLt"); 55 | println!("cargo:rustc-link-lib=cuda"); 56 | cfg_if::cfg_if! 
{ 57 | if #[cfg(target_os = "windows")] { 58 | let cuda_path = PathBuf::from(env::var("CUDA_PATH").unwrap()).join("lib/x64"); 59 | println!("cargo:rustc-link-search={}", cuda_path.display()); 60 | } else { 61 | println!("cargo:rustc-link-lib=culibos"); 62 | println!("cargo:rustc-link-search=/usr/local/cuda/lib64"); 63 | println!("cargo:rustc-link-search=/usr/local/cuda/lib64/stubs"); 64 | println!("cargo:rustc-link-search=/opt/cuda/lib64"); 65 | println!("cargo:rustc-link-search=/opt/cuda/lib64/stubs"); 66 | } 67 | } 68 | } 69 | #[cfg(feature = "hipblas")] 70 | { 71 | println!("cargo:rustc-link-lib=hipblas"); 72 | println!("cargo:rustc-link-lib=rocblas"); 73 | println!("cargo:rustc-link-lib=amdhip64"); 74 | 75 | cfg_if::cfg_if! { 76 | if #[cfg(target_os = "windows")] { 77 | panic!("Due to a problem with the last revision of the ROCm 5.7 library, it is not possible to compile the library for the windows environment.\nSee https://github.com/ggerganov/whisper.cpp/issues/2202 for more details.") 78 | } else { 79 | println!("cargo:rerun-if-env-changed=HIP_PATH"); 80 | 81 | let hip_path = match env::var("HIP_PATH") { 82 | Ok(path) =>PathBuf::from(path), 83 | Err(_) => PathBuf::from("/opt/rocm"), 84 | }; 85 | let hip_lib_path = hip_path.join("lib"); 86 | 87 | println!("cargo:rustc-link-search={}",hip_lib_path.display()); 88 | } 89 | } 90 | } 91 | 92 | #[cfg(feature = "openmp")] 93 | { 94 | if target.contains("gnu") { 95 | println!("cargo:rustc-link-lib=gomp"); 96 | } else if target.contains("apple") { 97 | println!("cargo:rustc-link-lib=omp"); 98 | println!("cargo:rustc-link-search=/opt/homebrew/opt/libomp/lib"); 99 | } 100 | } 101 | 102 | println!("cargo:rerun-if-changed=wrapper.h"); 103 | 104 | let out = PathBuf::from(env::var("OUT_DIR").unwrap()); 105 | let whisper_root = out.join("whisper.cpp/"); 106 | 107 | if !whisper_root.exists() { 108 | std::fs::create_dir_all(&whisper_root).unwrap(); 109 | fs_extra::dir::copy("./whisper.cpp", &out, 
&Default::default()).unwrap_or_else(|e| { 110 | panic!( 111 | "Failed to copy whisper sources into {}: {}", 112 | whisper_root.display(), 113 | e 114 | ) 115 | }); 116 | } 117 | 118 | if env::var("WHISPER_DONT_GENERATE_BINDINGS").is_ok() { 119 | let _: u64 = std::fs::copy("src/bindings.rs", out.join("bindings.rs")) 120 | .expect("Failed to copy bindings.rs"); 121 | } else { 122 | let bindings = bindgen::Builder::default().header("wrapper.h"); 123 | 124 | #[cfg(feature = "metal")] 125 | let bindings = bindings.header("whisper.cpp/ggml/include/ggml-metal.h"); 126 | 127 | let bindings = bindings 128 | .clang_arg("-I./whisper.cpp/") 129 | .clang_arg("-I./whisper.cpp/include") 130 | .clang_arg("-I./whisper.cpp/ggml/include") 131 | .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) 132 | .generate(); 133 | 134 | match bindings { 135 | Ok(b) => { 136 | let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); 137 | b.write_to_file(out_path.join("bindings.rs")) 138 | .expect("Couldn't write bindings!"); 139 | } 140 | Err(e) => { 141 | println!("cargo:warning=Unable to generate bindings: {}", e); 142 | println!("cargo:warning=Using bundled bindings.rs, which may be out of date"); 143 | // copy src/bindings.rs to OUT_DIR 144 | std::fs::copy("src/bindings.rs", out.join("bindings.rs")) 145 | .expect("Unable to copy bindings.rs"); 146 | } 147 | } 148 | }; 149 | 150 | // stop if we're on docs.rs 151 | if env::var("DOCS_RS").is_ok() { 152 | return; 153 | } 154 | 155 | let mut config = Config::new(&whisper_root); 156 | 157 | config 158 | .profile("Release") 159 | .define("BUILD_SHARED_LIBS", "OFF") 160 | .define("WHISPER_ALL_WARNINGS", "OFF") 161 | .define("WHISPER_ALL_WARNINGS_3RD_PARTY", "OFF") 162 | .define("WHISPER_BUILD_TESTS", "OFF") 163 | .define("WHISPER_BUILD_EXAMPLES", "OFF") 164 | .very_verbose(true) 165 | .pic(true); 166 | 167 | if cfg!(target_os = "windows") { 168 | config.cxxflag("/utf-8"); 169 | } 170 | 171 | if cfg!(feature = "coreml") { 172 | 
config.define("WHISPER_COREML", "ON"); 173 | config.define("WHISPER_COREML_ALLOW_FALLBACK", "1"); 174 | } 175 | 176 | if cfg!(feature = "cuda") { 177 | config.define("GGML_CUDA", "ON"); 178 | } 179 | 180 | if cfg!(feature = "hipblas") { 181 | config.define("GGML_HIPBLAS", "ON"); 182 | config.define("CMAKE_C_COMPILER", "hipcc"); 183 | config.define("CMAKE_CXX_COMPILER", "hipcc"); 184 | println!("cargo:rerun-if-env-changed=AMDGPU_TARGETS"); 185 | if let Ok(gpu_targets) = env::var("AMDGPU_TARGETS") { 186 | config.define("AMDGPU_TARGETS", gpu_targets); 187 | } 188 | } 189 | 190 | if cfg!(feature = "vulkan") { 191 | config.define("GGML_VULKAN", "ON"); 192 | if cfg!(windows) { 193 | println!("cargo:rerun-if-env-changed=VULKAN_SDK"); 194 | println!("cargo:rustc-link-lib=vulkan-1"); 195 | let vulkan_path = match env::var("VULKAN_SDK") { 196 | Ok(path) => PathBuf::from(path), 197 | Err(_) => panic!( 198 | "Please install Vulkan SDK and ensure that VULKAN_SDK env variable is set" 199 | ), 200 | }; 201 | let vulkan_lib_path = vulkan_path.join("Lib"); 202 | println!("cargo:rustc-link-search={}", vulkan_lib_path.display()); 203 | } else if cfg!(target_os = "macos") { 204 | println!("cargo:rerun-if-env-changed=VULKAN_SDK"); 205 | println!("cargo:rustc-link-lib=vulkan"); 206 | let vulkan_path = match env::var("VULKAN_SDK") { 207 | Ok(path) => PathBuf::from(path), 208 | Err(_) => panic!( 209 | "Please install Vulkan SDK and ensure that VULKAN_SDK env variable is set" 210 | ), 211 | }; 212 | let vulkan_lib_path = vulkan_path.join("lib"); 213 | println!("cargo:rustc-link-search={}", vulkan_lib_path.display()); 214 | } else { 215 | println!("cargo:rustc-link-lib=vulkan"); 216 | } 217 | } 218 | 219 | if cfg!(feature = "openblas") { 220 | config.define("GGML_BLAS", "ON"); 221 | config.define("GGML_BLAS_VENDOR", "OpenBLAS"); 222 | if env::var("BLAS_INCLUDE_DIRS").is_err() { 223 | panic!("BLAS_INCLUDE_DIRS environment variable must be set when using OpenBLAS"); 224 | } 225 | 
config.define("BLAS_INCLUDE_DIRS", env::var("BLAS_INCLUDE_DIRS").unwrap()); 226 | println!("cargo:rerun-if-env-changed=BLAS_INCLUDE_DIRS"); 227 | } 228 | 229 | if cfg!(feature = "metal") { 230 | config.define("GGML_METAL", "ON"); 231 | config.define("GGML_METAL_NDEBUG", "ON"); 232 | config.define("GGML_METAL_EMBED_LIBRARY", "ON"); 233 | } else { 234 | // Metal is enabled by default, so we need to explicitly disable it 235 | config.define("GGML_METAL", "OFF"); 236 | } 237 | 238 | if cfg!(debug_assertions) || cfg!(feature = "force-debug") { 239 | // debug builds are too slow to even remotely be usable, 240 | // so we build with optimizations even in debug mode 241 | config.define("CMAKE_BUILD_TYPE", "RelWithDebInfo"); 242 | config.cxxflag("-DWHISPER_DEBUG"); 243 | } 244 | 245 | // Allow passing any WHISPER or CMAKE compile flags 246 | for (key, value) in env::vars() { 247 | let is_whisper_flag = 248 | key.starts_with("WHISPER_") && key != "WHISPER_DONT_GENERATE_BINDINGS"; 249 | let is_cmake_flag = key.starts_with("CMAKE_"); 250 | if is_whisper_flag || is_cmake_flag { 251 | config.define(&key, &value); 252 | } 253 | } 254 | 255 | if cfg!(not(feature = "openmp")) { 256 | config.define("GGML_OPENMP", "OFF"); 257 | } 258 | 259 | let destination = config.build(); 260 | 261 | add_link_search_path(&out.join("build")).unwrap(); 262 | 263 | println!("cargo:rustc-link-search=native={}", destination.display()); 264 | println!("cargo:rustc-link-lib=static=whisper"); 265 | println!("cargo:rustc-link-lib=static=ggml"); 266 | println!("cargo:rustc-link-lib=static=ggml-base"); 267 | println!("cargo:rustc-link-lib=static=ggml-cpu"); 268 | if cfg!(target_os = "macos") || cfg!(feature = "openblas") { 269 | println!("cargo:rustc-link-lib=static=ggml-blas"); 270 | } 271 | if cfg!(feature = "vulkan") { 272 | println!("cargo:rustc-link-lib=static=ggml-vulkan"); 273 | } 274 | 275 | if cfg!(feature = "metal") { 276 | println!("cargo:rustc-link-lib=static=ggml-metal"); 277 | } 278 | 279 | if 
cfg!(feature = "cuda") { 280 | println!("cargo:rustc-link-lib=static=ggml-cuda"); 281 | } 282 | 283 | if cfg!(feature = "openblas") { 284 | println!("cargo:rustc-link-lib=static=ggml-blas"); 285 | } 286 | 287 | println!( 288 | "cargo:WHISPER_CPP_VERSION={}", 289 | get_whisper_cpp_version(&whisper_root) 290 | .expect("Failed to read whisper.cpp CMake config") 291 | .expect("Could not find whisper.cpp version declaration"), 292 | ); 293 | 294 | // for whatever reason this file is generated during build and triggers cargo complaining 295 | _ = std::fs::remove_file("bindings/javascript/package.json"); 296 | } 297 | 298 | // From https://github.com/alexcrichton/cc-rs/blob/fba7feded71ee4f63cfe885673ead6d7b4f2f454/src/lib.rs#L2462 299 | fn get_cpp_link_stdlib(target: &str) -> Option<&'static str> { 300 | if target.contains("msvc") { 301 | None 302 | } else if target.contains("apple") || target.contains("freebsd") || target.contains("openbsd") { 303 | Some("c++") 304 | } else if target.contains("android") { 305 | Some("c++_shared") 306 | } else { 307 | Some("stdc++") 308 | } 309 | } 310 | 311 | fn add_link_search_path(dir: &std::path::Path) -> std::io::Result<()> { 312 | if dir.is_dir() { 313 | println!("cargo:rustc-link-search={}", dir.display()); 314 | for entry in std::fs::read_dir(dir)? 
{ 315 | add_link_search_path(&entry?.path())?; 316 | } 317 | } 318 | Ok(()) 319 | } 320 | 321 | fn get_whisper_cpp_version(whisper_root: &std::path::Path) -> std::io::Result> { 322 | let cmake_lists = BufReader::new(File::open(whisper_root.join("CMakeLists.txt"))?); 323 | 324 | for line in cmake_lists.lines() { 325 | let line = line?; 326 | 327 | if let Some(suffix) = line.strip_prefix(r#"project("whisper.cpp" VERSION "#) { 328 | let whisper_cpp_version = suffix.trim_end_matches(')'); 329 | return Ok(Some(whisper_cpp_version.into())); 330 | } 331 | } 332 | 333 | Ok(None) 334 | } 335 | -------------------------------------------------------------------------------- /sys/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_upper_case_globals)] 2 | #![allow(non_camel_case_types)] 3 | #![allow(non_snake_case)] 4 | 5 | include!(concat!(env!("OUT_DIR"), "/bindings.rs")); 6 | -------------------------------------------------------------------------------- /sys/wrapper.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | --------------------------------------------------------------------------------