├── .github ├── dependabot.yml └── workflows │ ├── release-plz.yml │ ├── release.yml │ └── test-pr.yml ├── .gitignore ├── CHANGELOG.md ├── Cargo.toml ├── LICENSE ├── README.md ├── assets └── samples_jfk.wav ├── cliff.toml ├── dist-workspace.toml ├── release-plz.toml ├── rust-toolchain.toml ├── simple-whisper-cli ├── Cargo.toml ├── README.md └── src │ └── main.rs ├── simple-whisper-server ├── Cargo.toml ├── README.md └── src │ └── main.rs └── simple-whisper ├── Cargo.toml ├── README.md └── src ├── download.rs ├── language.rs ├── lib.rs ├── model.rs └── transcribe.rs /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # Maintain dependencies for Cargo 4 | - package-ecosystem: cargo 5 | directory: "/" 6 | schedule: 7 | interval: daily 8 | open-pull-requests-limit: 10 9 | 10 | # Maintain dependencies for GitHub Actions 11 | - package-ecosystem: github-actions 12 | directory: "/" 13 | schedule: 14 | interval: daily 15 | open-pull-requests-limit: 10 -------------------------------------------------------------------------------- /.github/workflows/release-plz.yml: -------------------------------------------------------------------------------- 1 | name: Release-Plz 2 | on: 3 | push: 4 | branches: 5 | - main 6 | 7 | jobs: 8 | release-plz: 9 | name: Release-plz 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout repository 13 | uses: actions/checkout@v4 14 | with: 15 | fetch-depth: 0 16 | token: ${{ secrets.RELEASE_PLZ_TOKEN }} 17 | - name: install dependencies (ubuntu only) 18 | run: | 19 | sudo apt-get update 20 | sudo apt-get install -y libasound2-dev 21 | - name: Install Rust toolchain 22 | uses: dtolnay/rust-toolchain@stable 23 | - name: Run release-plz 24 | uses: MarcoIeni/release-plz-action@v0.5 25 | env: 26 | GITHUB_TOKEN: ${{ secrets.RELEASE_PLZ_TOKEN }} 27 | CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} 28 | 
-------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by dist: https://opensource.axo.dev/cargo-dist/ 2 | # 3 | # Copyright 2022-2024, axodotdev 4 | # SPDX-License-Identifier: MIT or Apache-2.0 5 | # 6 | # CI that: 7 | # 8 | # * checks for a Git Tag that looks like a release 9 | # * builds artifacts with dist (archives, installers, hashes) 10 | # * uploads those artifacts to temporary workflow zip 11 | # * on success, uploads the artifacts to a GitHub Release 12 | # 13 | # Note that the GitHub Release will be created with a generated 14 | # title/body based on your changelogs. 15 | 16 | name: Release 17 | permissions: 18 | "contents": "write" 19 | 20 | # This task will run whenever you push a git tag that looks like a version 21 | # like "1.0.0", "v0.1.0-prerelease.1", "my-app/0.1.0", "releases/v1.0.0", etc. 22 | # Various formats will be parsed into a VERSION and an optional PACKAGE_NAME, where 23 | # PACKAGE_NAME must be the name of a Cargo package in your workspace, and VERSION 24 | # must be a Cargo-style SemVer Version (must have at least major.minor.patch). 25 | # 26 | # If PACKAGE_NAME is specified, then the announcement will be for that 27 | # package (erroring out if it doesn't have the given version or isn't dist-able). 28 | # 29 | # If PACKAGE_NAME isn't specified, then the announcement will be for all 30 | # (dist-able) packages in the workspace with that version (this mode is 31 | # intended for workspaces with only one dist-able package, or with all dist-able 32 | # packages versioned/released in lockstep). 33 | # 34 | # If you push multiple tags at once, separate instances of this workflow will 35 | # spin up, creating an independent announcement for each one. However, GitHub 36 | # will hard limit this to 3 tags per commit, as it will assume more tags is a 37 | # mistake. 
38 | # 39 | # If there's a prerelease-style suffix to the version, then the release(s) 40 | # will be marked as a prerelease. 41 | on: 42 | pull_request: 43 | push: 44 | tags: 45 | - '**[0-9]+.[0-9]+.[0-9]+*' 46 | 47 | jobs: 48 | # Run 'dist plan' (or host) to determine what tasks we need to do 49 | plan: 50 | runs-on: "ubuntu-22.04" 51 | outputs: 52 | val: ${{ steps.plan.outputs.manifest }} 53 | tag: ${{ !github.event.pull_request && github.ref_name || '' }} 54 | tag-flag: ${{ !github.event.pull_request && format('--tag={0}', github.ref_name) || '' }} 55 | publishing: ${{ !github.event.pull_request }} 56 | env: 57 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 58 | steps: 59 | - uses: actions/checkout@v4 60 | with: 61 | submodules: recursive 62 | - name: Install dist 63 | # we specify bash to get pipefail; it guards against the `curl` command 64 | # failing. otherwise `sh` won't catch that `curl` returned non-0 65 | shell: bash 66 | run: "curl --proto '=https' --tlsv1.2 -LsSf https://github.com/axodotdev/cargo-dist/releases/download/v0.28.0/cargo-dist-installer.sh | sh" 67 | - name: Cache dist 68 | uses: actions/upload-artifact@v4 69 | with: 70 | name: cargo-dist-cache 71 | path: ~/.cargo/bin/dist 72 | # sure would be cool if github gave us proper conditionals... 73 | # so here's a doubly-nested ternary-via-truthiness to try to provide the best possible 74 | # functionality based on whether this is a pull_request, and whether it's from a fork. 75 | # (PRs run on the *source* but secrets are usually on the *target* -- that's *good* 76 | # but also really annoying to build CI around when it needs secrets to work right.) 77 | - id: plan 78 | run: | 79 | dist ${{ (!github.event.pull_request && format('host --steps=create --tag={0}', github.ref_name)) || 'plan' }} --output-format=json > plan-dist-manifest.json 80 | echo "dist ran successfully" 81 | cat plan-dist-manifest.json 82 | echo "manifest=$(jq -c "." 
plan-dist-manifest.json)" >> "$GITHUB_OUTPUT" 83 | - name: "Upload dist-manifest.json" 84 | uses: actions/upload-artifact@v4 85 | with: 86 | name: artifacts-plan-dist-manifest 87 | path: plan-dist-manifest.json 88 | 89 | # Build and packages all the platform-specific things 90 | build-local-artifacts: 91 | name: build-local-artifacts (${{ join(matrix.targets, ', ') }}) 92 | # Let the initial task tell us to not run (currently very blunt) 93 | needs: 94 | - plan 95 | if: ${{ fromJson(needs.plan.outputs.val).ci.github.artifacts_matrix.include != null && (needs.plan.outputs.publishing == 'true' || fromJson(needs.plan.outputs.val).ci.github.pr_run_mode == 'upload') }} 96 | strategy: 97 | fail-fast: false 98 | # Target platforms/runners are computed by dist in create-release. 99 | # Each member of the matrix has the following arguments: 100 | # 101 | # - runner: the github runner 102 | # - dist-args: cli flags to pass to dist 103 | # - install-dist: expression to run to install dist on the runner 104 | # 105 | # Typically there will be: 106 | # - 1 "global" task that builds universal installers 107 | # - N "local" tasks that build each platform's binaries and platform-specific installers 108 | matrix: ${{ fromJson(needs.plan.outputs.val).ci.github.artifacts_matrix }} 109 | runs-on: ${{ matrix.runner }} 110 | container: ${{ matrix.container && matrix.container.image || null }} 111 | env: 112 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 113 | BUILD_MANIFEST_NAME: target/distrib/${{ join(matrix.targets, '-') }}-dist-manifest.json 114 | steps: 115 | - name: enable windows longpaths 116 | run: | 117 | git config --global core.longpaths true 118 | - uses: actions/checkout@v4 119 | with: 120 | submodules: recursive 121 | - name: Install Rust non-interactively if not already installed 122 | if: ${{ matrix.container }} 123 | run: | 124 | if ! 
command -v cargo > /dev/null 2>&1; then 125 | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y 126 | echo "$HOME/.cargo/bin" >> $GITHUB_PATH 127 | fi 128 | - name: Install dist 129 | run: ${{ matrix.install_dist.run }} 130 | # Get the dist-manifest 131 | - name: Fetch local artifacts 132 | uses: actions/download-artifact@v4 133 | with: 134 | pattern: artifacts-* 135 | path: target/distrib/ 136 | merge-multiple: true 137 | - name: Install dependencies 138 | run: | 139 | ${{ matrix.packages_install }} 140 | - name: install dependencies (ubuntu only) 141 | if: matrix.runner == 'ubuntu-22.04' 142 | run: | 143 | sudo apt-get update 144 | sudo apt-get install -y libasound2-dev 145 | - name: Build artifacts 146 | run: | 147 | # Actually do builds and make zips and whatnot 148 | dist build ${{ needs.plan.outputs.tag-flag }} --print=linkage --output-format=json ${{ matrix.dist_args }} > dist-manifest.json 149 | echo "dist ran successfully" 150 | - id: cargo-dist 151 | name: Post-build 152 | # We force bash here just because github makes it really hard to get values up 153 | # to "real" actions without writing to env-vars, and writing to env-vars has 154 | # inconsistent syntax between shell and powershell. 
155 | shell: bash 156 | run: | 157 | # Parse out what we just built and upload it to scratch storage 158 | echo "paths<<EOF" >> "$GITHUB_OUTPUT" 159 | dist print-upload-files-from-manifest --manifest dist-manifest.json >> "$GITHUB_OUTPUT" 160 | echo "EOF" >> "$GITHUB_OUTPUT" 161 | 162 | cp dist-manifest.json "$BUILD_MANIFEST_NAME" 163 | - name: "Upload artifacts" 164 | uses: actions/upload-artifact@v4 165 | with: 166 | name: artifacts-build-local-${{ join(matrix.targets, '_') }} 167 | path: | 168 | ${{ steps.cargo-dist.outputs.paths }} 169 | ${{ env.BUILD_MANIFEST_NAME }} 170 | 171 | # Build and package all the platform-agnostic(ish) things 172 | build-global-artifacts: 173 | needs: 174 | - plan 175 | - build-local-artifacts 176 | runs-on: "ubuntu-22.04" 177 | env: 178 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 179 | BUILD_MANIFEST_NAME: target/distrib/global-dist-manifest.json 180 | steps: 181 | - uses: actions/checkout@v4 182 | with: 183 | submodules: recursive 184 | - name: Install cached dist 185 | uses: actions/download-artifact@v4 186 | with: 187 | name: cargo-dist-cache 188 | path: ~/.cargo/bin/ 189 | - run: chmod +x ~/.cargo/bin/dist 190 | # Get all the local artifacts for the global tasks to use (for e.g. 
checksums) 191 | - name: Fetch local artifacts 192 | uses: actions/download-artifact@v4 193 | with: 194 | pattern: artifacts-* 195 | path: target/distrib/ 196 | merge-multiple: true 197 | - id: cargo-dist 198 | shell: bash 199 | run: | 200 | dist build ${{ needs.plan.outputs.tag-flag }} --output-format=json "--artifacts=global" > dist-manifest.json 201 | echo "dist ran successfully" 202 | 203 | # Parse out what we just built and upload it to scratch storage 204 | echo "paths<<EOF" >> "$GITHUB_OUTPUT" 205 | jq --raw-output ".upload_files[]" dist-manifest.json >> "$GITHUB_OUTPUT" 206 | echo "EOF" >> "$GITHUB_OUTPUT" 207 | 208 | cp dist-manifest.json "$BUILD_MANIFEST_NAME" 209 | - name: "Upload artifacts" 210 | uses: actions/upload-artifact@v4 211 | with: 212 | name: artifacts-build-global 213 | path: | 214 | ${{ steps.cargo-dist.outputs.paths }} 215 | ${{ env.BUILD_MANIFEST_NAME }} 216 | # Determines if we should publish/announce 217 | host: 218 | needs: 219 | - plan 220 | - build-local-artifacts 221 | - build-global-artifacts 222 | # Only run if we're "publishing", and only if local and global didn't fail (skipped is fine) 223 | if: ${{ always() && needs.plan.outputs.publishing == 'true' && (needs.build-global-artifacts.result == 'skipped' || needs.build-global-artifacts.result == 'success') && (needs.build-local-artifacts.result == 'skipped' || needs.build-local-artifacts.result == 'success') }} 224 | env: 225 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 226 | runs-on: "ubuntu-22.04" 227 | outputs: 228 | val: ${{ steps.host.outputs.manifest }} 229 | steps: 230 | - uses: actions/checkout@v4 231 | with: 232 | submodules: recursive 233 | - name: Install cached dist 234 | uses: actions/download-artifact@v4 235 | with: 236 | name: cargo-dist-cache 237 | path: ~/.cargo/bin/ 238 | - run: chmod +x ~/.cargo/bin/dist 239 | # Fetch artifacts from scratch-storage 240 | - name: Fetch artifacts 241 | uses: actions/download-artifact@v4 242 | with: 243 | pattern: artifacts-* 244 | path: 
target/distrib/ 245 | merge-multiple: true 246 | - id: host 247 | shell: bash 248 | run: | 249 | dist host ${{ needs.plan.outputs.tag-flag }} --steps=upload --steps=release --output-format=json > dist-manifest.json 250 | echo "artifacts uploaded and released successfully" 251 | cat dist-manifest.json 252 | echo "manifest=$(jq -c "." dist-manifest.json)" >> "$GITHUB_OUTPUT" 253 | - name: "Upload dist-manifest.json" 254 | uses: actions/upload-artifact@v4 255 | with: 256 | # Overwrite the previous copy 257 | name: artifacts-dist-manifest 258 | path: dist-manifest.json 259 | # Create a GitHub Release while uploading all files to it 260 | - name: "Download GitHub Artifacts" 261 | uses: actions/download-artifact@v4 262 | with: 263 | pattern: artifacts-* 264 | path: artifacts 265 | merge-multiple: true 266 | - name: Cleanup 267 | run: | 268 | # Remove the granular manifests 269 | rm -f artifacts/*-dist-manifest.json 270 | - name: Create GitHub Release 271 | env: 272 | PRERELEASE_FLAG: "${{ fromJson(steps.host.outputs.manifest).announcement_is_prerelease && '--prerelease' || '' }}" 273 | ANNOUNCEMENT_TITLE: "${{ fromJson(steps.host.outputs.manifest).announcement_title }}" 274 | ANNOUNCEMENT_BODY: "${{ fromJson(steps.host.outputs.manifest).announcement_github_body }}" 275 | RELEASE_COMMIT: "${{ github.sha }}" 276 | run: | 277 | # Write and read notes from a file to avoid quoting breaking things 278 | echo "$ANNOUNCEMENT_BODY" > $RUNNER_TEMP/notes.txt 279 | 280 | gh release create "${{ needs.plan.outputs.tag }}" --target "$RELEASE_COMMIT" $PRERELEASE_FLAG --title "$ANNOUNCEMENT_TITLE" --notes-file "$RUNNER_TEMP/notes.txt" artifacts/* 281 | 282 | announce: 283 | needs: 284 | - plan 285 | - host 286 | # use "always() && ..." to allow us to wait for all publish jobs while 287 | # still allowing individual publish jobs to skip themselves (for prereleases). 288 | # "host" however must run to completion, no skipping allowed! 
289 | if: ${{ always() && needs.host.result == 'success' }} 290 | runs-on: "ubuntu-22.04" 291 | env: 292 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 293 | steps: 294 | - uses: actions/checkout@v4 295 | with: 296 | submodules: recursive -------------------------------------------------------------------------------- /.github/workflows/test-pr.yml: -------------------------------------------------------------------------------- 1 | name: "test-on-pr" 2 | on: [pull_request] 3 | 4 | jobs: 5 | check-whisper: 6 | runs-on: ubuntu-latest 7 | steps: 8 | - name: Checkout repository 9 | uses: actions/checkout@v4 10 | 11 | - name: Install Rust toolchain 12 | uses: dtolnay/rust-toolchain@stable 13 | with: 14 | components: rustfmt 15 | 16 | - name: check fmt 17 | run: cargo fmt --all -- --check 18 | 19 | test-whisper: 20 | strategy: 21 | fail-fast: false 22 | matrix: 23 | platform: [macos-latest, ubuntu-latest, windows-latest] 24 | 25 | runs-on: ${{ matrix.platform }} 26 | steps: 27 | - name: Checkout repository 28 | uses: actions/checkout@v4 29 | 30 | - name: install dependencies (ubuntu only) 31 | if: matrix.platform == 'ubuntu-latest' 32 | run: | 33 | sudo apt-get update 34 | sudo apt-get install -y libasound2-dev 35 | 36 | - name: install vulkan sdk 37 | if: matrix.platform == 'ubuntu-latest' 38 | uses: humbletim/install-vulkan-sdk@c2aa128094d42ba02959a660f03e0a4e012192f9 39 | 40 | - name: Install Rust toolchain 41 | uses: dtolnay/rust-toolchain@stable 42 | 43 | - name: Rust cache 44 | uses: swatinem/rust-cache@v2 45 | 46 | - name: cargo build 47 | run: cargo build 48 | 49 | - name: run generic tests 50 | run: cargo test 51 | 52 | - name: cargo build (metal) 53 | if: matrix.platform == 'macos-latest' 54 | run: cargo build --features metal 55 | 56 | - name: run generic tests (metal) 57 | if: matrix.platform == 'macos-latest' 58 | run: cargo test --features metal -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | /target 2 | bin/act 3 | Cargo.lock 4 | .idea 5 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | ## [0.1.5] - 2025-04-22 6 | 7 | ### 🚀 Features 8 | 9 | - 2024 edition (#27) 10 | 11 | 12 | ## [0.1.4] - 2025-04-10 13 | 14 | ### 🚀 Features 15 | 16 | - Transcribe as stream 17 | 18 | ### 🐛 Bug Fixes 19 | 20 | - *(doc)* Typo fixing 21 | 22 | ### 🚜 Refactor 23 | 24 | - Removed whisper_cpp_impl mod and improved model download methods (#26) 25 | 26 | 27 | ## [0.1.3] - 2025-03-03 28 | 29 | ### 🚀 Features 30 | 31 | - Whisper cpp quantized and v3 turbo models ([#17](https://github.com/newfla/simple-whisper/pull/17)) 32 | - Model download progress event ([#21](https://github.com/newfla/simple-whisper/pull/21)) 33 | 34 | ### 🐛 Bug Fixes 35 | 36 | - *(cli)* Display progressbar model download 37 | 38 | ### 🚜 Refactor 39 | 40 | - Remove burn backend ([#20](https://github.com/newfla/simple-whisper/pull/20)) 41 | 42 | 43 | ## [0.1.2] - 2024-09-16 44 | 45 | ### 🚀 Features 46 | 47 | - Whisper.cpp vulkan backend ([#15](https://github.com/newfla/simple-whisper/pull/15)) 48 | 49 | ### ⚙️ Miscellaneous Tasks 50 | 51 | - Typo fixing ([#13](https://github.com/newfla/simple-whisper/pull/13)) 52 | 53 | 54 | ## [0.1.1] - 2024-07-05 55 | 56 | ### 🚀 Features 57 | 58 | - Overlapping mel segments ([#10](https://github.com/newfla/simple-whisper/pull/10)) 59 | 60 | 61 | ## [0.1.0] - 2024-07-04 62 | 63 | ### 🚀 Features 64 | 65 | - Server tracing 66 | - Improved download model functions 67 | - Cli progressbar 68 | - Added audio sampling 69 | - First implementation 70 | - Min optimization 71 | - Added offest handling 72 | 73 | ### 🐛 Bug Fixes 74 | 75 | - Models hf links 76 | - Linting clippy 77 
| 78 | ### 📚 Documentation 79 | 80 | - *(lib)* Added doc to struct 81 | 82 | ### ⚙️ Miscellaneous Tasks 83 | 84 | - Dependabot setup 85 | - Test pr action 86 | - Added release-plz cargo-dist actions 87 | - Fix ubuntu runner 88 | - Fix cargo publish ([#6](https://github.com/newfla/simple-whisper/pull/6)) 89 | 90 | 91 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [ 3 | "simple-whisper-cli", 4 | "simple-whisper-server", 5 | "simple-whisper" 6 | ] 7 | resolver = "2" 8 | 9 | [workspace.package] 10 | edition = "2024" 11 | version = "0.1.5" 12 | authors = ["Flavio Bizzarri "] 13 | license = "Apache-2.0" 14 | repository = "https://github.com/newfla/simple-whisper" 15 | keywords = ["ai", "whisper"] 16 | categories = ["science","multimedia"] 17 | 18 | [workspace.dependencies] 19 | anyhow = "1.0.86" 20 | axum = { version = "0.7.5", features = ["json", "ws"] } 21 | clap = { version = "4.5.7", features = ["derive"] } 22 | derive_builder = "0.20.0" 23 | hf-hub = { version = "0.4.2", features = ["tokio"] } 24 | indicatif = { version = "0.17.8", features = ["improved_unicode"] } 25 | num_cpus = "1.16.0" 26 | rodio = { version = "0.20.1"} 27 | serde = "1.0.203" 28 | serde_json = "1.0.117" 29 | strum = { version = "0.26", features = ["derive"] } 30 | tempfile = "3.10.1" 31 | thiserror = "1.0.61" 32 | tokenizers = "0.19.1" 33 | tokio = { version = "1.38.0", features = ["full"] } 34 | tokio-stream = "0.1.17" 35 | tower-http = { version = "0.5.2", features = ["trace"] } 36 | tracing = "0.1" 37 | tracing-subscriber = { version = "0.3", features = ["env-filter"] } 38 | 39 | # The profile that 'cargo dist' will build with 40 | [profile.dist] 41 | inherits = "release" 42 | lto = "thin" 43 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Simple Whisper 2 | 3 | Transcription library/cli/server based on [OpenAI Whisper](https://github.com/openai/whisper) model written using [whisper-rs](https://github.com/tazz4843/whisper-rs). 4 | 5 | ## What is included? 6 | 7 | - [Simple Whisper lib](./simple-whisper/): Implements the Whisper model via: 8 | - [whisper.cpp Backend](https://github.com/tazz4843/whisper-rs). Weights are automatically downloaded from [Hugging Face repo](https://huggingface.co/ggerganov/whisper.cpp). 9 | - Supported codec: flac, vorbis, wav, mp3 10 | 11 | - [Simple Whisper cli](./simple-whisper-cli/): CLI application useful to transcribe audio file. For more information see the [README.md](./simple-whisper-cli/README.md). 12 | 13 | - [Simple Whisper server](./simple-whisper-server/): Websocket server that transcribe uploaded files. 14 | 15 | ## Goals 16 | - Show how malleable RUST is, scaling from server to GPU code. 17 | - Support a high variety of platforms. 
18 | - Fast enough on every platform. 19 | 20 | ## No Goals 21 | - It is **NOT** intended to be the fastest or the most accurate Whisper implementation. 22 | 23 | ## Credits 24 | The project was inspired by: 25 | - Candle implementation: [rwhisper](https://github.com/floneum/floneum/tree/main/models/rwhisper). -------------------------------------------------------------------------------- /assets/samples_jfk.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/newfla/simple-whisper/19b9891b6034fdb1acba0bbc070f0d15248f64b5/assets/samples_jfk.wav -------------------------------------------------------------------------------- /cliff.toml: -------------------------------------------------------------------------------- 1 | # git-cliff ~ default configuration file 2 | # https://git-cliff.org/docs/configuration 3 | # 4 | # Lines starting with "#" are comments. 5 | # Configuration options are organized into tables and keys. 6 | # See documentation for more information on available options.
7 | 8 | [changelog] 9 | # template for the changelog footer 10 | header = """ 11 | # Changelog\n 12 | All notable changes to this project will be documented in this file.\n 13 | """ 14 | # template for the changelog body 15 | # https://keats.github.io/tera/docs/#introduction 16 | body = """ 17 | {% if version %}\ 18 | ## [{{ version | trim_start_matches(pat="v") }}] - {{ timestamp | date(format="%Y-%m-%d") }} 19 | {% else %}\ 20 | ## [unreleased] 21 | {% endif %}\ 22 | {% for group, commits in commits | group_by(attribute="group") %} 23 | ### {{ group | striptags | trim | upper_first }} 24 | {% for commit in commits %} 25 | - {% if commit.scope %}*({{ commit.scope }})* {% endif %}\ 26 | {% if commit.breaking %}[**breaking**] {% endif %}\ 27 | {{ commit.message | upper_first }}\ 28 | {% endfor %} 29 | {% endfor %}\n 30 | """ 31 | # template for the changelog footer 32 | footer = """ 33 | 34 | """ 35 | # remove the leading and trailing s 36 | trim = true 37 | # postprocessors 38 | postprocessors = [ 39 | # { pattern = '', replace = "https://github.com/orhun/git-cliff" }, # replace repository URL 40 | ] 41 | 42 | [git] 43 | # parse the commits based on https://www.conventionalcommits.org 44 | conventional_commits = true 45 | # filter out the commits that are not conventional 46 | filter_unconventional = true 47 | # process each line of a commit as an individual commit 48 | split_commits = false 49 | # regex for preprocessing the commit messages 50 | commit_preprocessors = [ 51 | # Replace issue numbers 52 | #{ pattern = '\((\w+\s)?#([0-9]+)\)', replace = "([#${2}](/issues/${2}))"}, 53 | # Check spelling of the commit with https://github.com/crate-ci/typos 54 | # If the spelling is incorrect, it will be automatically fixed. 
55 | #{ pattern = '.*', replace_command = 'typos --write-changes -' }, 56 | ] 57 | # regex for parsing and grouping commits 58 | commit_parsers = [ 59 | { message = "^feat", group = "🚀 Features" }, 60 | { message = "^fix", group = "🐛 Bug Fixes" }, 61 | { message = "^doc", group = "📚 Documentation" }, 62 | { message = "^perf", group = "⚡ Performance" }, 63 | { message = "^refactor", group = "🚜 Refactor" }, 64 | { message = "^style", group = "🎨 Styling" }, 65 | { message = "^test", group = "🧪 Testing" }, 66 | { message = "^chore\\(release\\): prepare for", skip = true }, 67 | { message = "^chore\\(deps.*\\)", skip = true }, 68 | { message = "^chore\\(pr\\)", skip = true }, 69 | { message = "^chore\\(pull\\)", skip = true }, 70 | { message = "^chore|^ci", group = "⚙️ Miscellaneous Tasks" }, 71 | { body = ".*security", group = "🛡️ Security" }, 72 | { message = "^revert", group = "◀️ Revert" }, 73 | ] 74 | # protect breaking changes from being skipped due to matching a skipping commit_parser 75 | protect_breaking_commits = false 76 | # filter out the commits that are not matched by commit parsers 77 | filter_commits = false 78 | # regex for matching git tags 79 | # tag_pattern = "v[0-9].*" 80 | # regex for skipping tags 81 | # skip_tags = "" 82 | # regex for ignoring tags 83 | # ignore_tags = "" 84 | # sort the tags topologically 85 | topo_order = false 86 | # sort the commits inside sections by oldest/newest order 87 | sort_commits = "oldest" 88 | # limit the number of commits included in the changelog. 
89 | # limit_commits = 42 90 | -------------------------------------------------------------------------------- /dist-workspace.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = ["cargo:."] 3 | 4 | # Config for 'dist' 5 | [dist] 6 | # The preferred dist version to use in CI (Cargo.toml SemVer syntax) 7 | cargo-dist-version = "0.28.0" 8 | # CI backends to support 9 | ci = "github" 10 | # The installers to generate for each app 11 | installers = [] 12 | # Target platforms to build apps for (Rust target-triple syntax) 13 | targets = ["aarch64-apple-darwin", "x86_64-apple-darwin", "x86_64-unknown-linux-gnu", "x86_64-pc-windows-msvc"] 14 | # Which actions to run on pull requests 15 | pr-run-mode = "upload" 16 | # Skip checking whether the specified configuration files are up to date 17 | allow-dirty = ["ci"] 18 | 19 | [dist.github-custom-runners] 20 | x86_64-unknown-linux-gnu = "ubuntu-22.04" -------------------------------------------------------------------------------- /release-plz.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | # path of the git-cliff configuration 3 | changelog_config = "cliff.toml" 4 | 5 | # enable changelog updates 6 | changelog_update = true 7 | 8 | # update dependencies with `cargo update` 9 | dependencies_update = false 10 | 11 | # create tags for the releases 12 | git_tag_enable = true 13 | 14 | git_tag_name = "v{{ version }}" 15 | 16 | # disable GitHub releases 17 | git_release_enable = false 18 | 19 | # labels for the release PR 20 | pr_labels = ["release"] 21 | 22 | # disallow updating repositories with uncommitted changes 23 | allow_dirty = true 24 | 25 | # disallow packaging with uncommitted changes 26 | publish_allow_dirty = false 27 | 28 | # disable running `cargo-semver-checks` 29 | semver_check = false 30 | 31 | changelog_path = "./CHANGELOG.md" 32 | 
-------------------------------------------------------------------------------- /rust-toolchain.toml: -------------------------------------------------------------------------------- 1 | [toolchain] 2 | channel = "1.85.1" -------------------------------------------------------------------------------- /simple-whisper-cli/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "simple-whisper-cli" 3 | version.workspace = true 4 | edition.workspace = true 5 | authors.workspace = true 6 | license.workspace = true 7 | readme.workspace = true 8 | repository.workspace = true 9 | keywords.workspace = true 10 | categories.workspace = true 11 | publish = false 12 | 13 | [package.metadata.dist] 14 | dist = true 15 | 16 | [dependencies] 17 | clap.workspace = true 18 | indicatif.workspace = true 19 | simple-whisper = { path = "../simple-whisper"} 20 | strum.workspace = true 21 | tokio.workspace = true 22 | tokio-stream.workspace = true 23 | 24 | [features] 25 | vulkan = ["simple-whisper/vulkan"] 26 | cuda = ["simple-whisper/cuda"] 27 | metal = ["simple-whisper/metal"] 28 | -------------------------------------------------------------------------------- /simple-whisper-cli/README.md: -------------------------------------------------------------------------------- 1 | # Simple Whisper CLI 2 | A modest CLI for speech transcription 3 | 4 | ## Example 5 | `simple-whisper-cli transcribe recording.mp3 tiny_en en output.txt` 6 | 7 | ## Usage 8 | 9 | ``` 10 | Usage: simple-whisper-cli transcribe [OPTIONS] 11 | 12 | Arguments: 13 | Audio file 14 | Which whisper model to use 15 | Audio language 16 | Output transcription file 17 | 18 | Options: 19 | --ignore-cache Ignore cached model files 20 | -v, --verbose Verbose STDOUT 21 | -h, --help Print help 22 | ``` -------------------------------------------------------------------------------- /simple-whisper-cli/src/main.rs: 
-------------------------------------------------------------------------------- 1 | use std::{path::PathBuf, str::FromStr}; 2 | 3 | use clap::{Parser, Subcommand}; 4 | use indicatif::{ProgressBar, ProgressStyle}; 5 | use simple_whisper::{Event, Language, Model, WhisperBuilder}; 6 | use strum::{EnumMessage, IntoEnumIterator}; 7 | use tokio::fs::write; 8 | use tokio_stream::StreamExt; 9 | 10 | #[derive(Parser, Debug)] 11 | #[command(version, about, long_about = None)] 12 | struct Cli { 13 | #[command(subcommand)] 14 | command: Commands, 15 | } 16 | #[derive(Debug, Subcommand)] 17 | enum Commands { 18 | /// Provide information on supported languages 19 | Languages { 20 | #[command(subcommand)] 21 | sub_command: LangCommands, 22 | }, 23 | /// Provide information on supported models 24 | Models { 25 | #[command(subcommand)] 26 | sub_command: ModelCommands, 27 | }, 28 | /// Transcribe audio file 29 | Transcribe { 30 | /// Audio file 31 | input_file: PathBuf, 32 | 33 | /// Which whisper model to use 34 | model: Model, 35 | 36 | /// Audio language 37 | language: Language, 38 | 39 | /// Output transcription file 40 | output_file: PathBuf, 41 | 42 | /// Ignore cached model files 43 | #[arg(long, required = false)] 44 | ignore_cache: bool, 45 | 46 | /// Force single segment output. This may be useful for streaming. 
47 | #[arg(long, required = false)] 48 | single_segment: bool, 49 | 50 | /// Verbose STDOUT 51 | #[arg(long, required = false, short = 'v')] 52 | verbose: bool, 53 | }, 54 | } 55 | 56 | #[derive(Debug, Subcommand)] 57 | enum LangCommands { 58 | /// List supported languages 59 | List, 60 | 61 | /// Check if a language is supported by providing its code 62 | Check { 63 | /// The code associated to the language 64 | code: String, 65 | }, 66 | } 67 | 68 | #[derive(Debug, Subcommand)] 69 | enum ModelCommands { 70 | /// List supported models 71 | List, 72 | /// Download a model by providing its code 73 | Download { 74 | /// The code associated to the model 75 | code: String, 76 | 77 | /// Ignore cached model files 78 | #[arg(long, required = false)] 79 | ignore_cache: bool, 80 | }, 81 | } 82 | 83 | #[tokio::main] 84 | async fn main() { 85 | let cli = Cli::parse(); 86 | match cli.command { 87 | Commands::Languages { sub_command } => match sub_command { 88 | LangCommands::List => { 89 | for lang in Language::iter() { 90 | println!("{}", lang.get_message().unwrap()) 91 | } 92 | } 93 | LangCommands::Check { code } => match Language::from_str(&code) { 94 | Ok(lang) => println!("{lang} is supported"), 95 | Err(_) => println!("{code} not associated to any supported language"), 96 | }, 97 | }, 98 | Commands::Transcribe { 99 | input_file, 100 | output_file, 101 | model, 102 | language, 103 | ignore_cache, 104 | single_segment, 105 | verbose, 106 | } => { 107 | match WhisperBuilder::default() 108 | .language(language) 109 | .model(model) 110 | .progress_bar(true) 111 | .force_download(ignore_cache) 112 | .force_single_segment(single_segment) 113 | .build() 114 | { 115 | Ok(model) => { 116 | let mut segments: Vec = Vec::new(); 117 | let mut stream = model.transcribe(input_file); 118 | let pb = if verbose { 119 | None 120 | } else { 121 | let pb = ProgressBar::new(100); 122 | pb.set_style( 123 | ProgressStyle::with_template( 124 | "[{elapsed_precise}] {wide_bar} eta({eta})", 125 | ) 
126 | .unwrap(), 127 | ); 128 | Some(pb) 129 | }; 130 | while let Some(msg) = stream.next().await { 131 | match msg { 132 | Ok(msg) => { 133 | if msg.is_segment() { 134 | segments.push(msg.to_string()); 135 | if verbose { 136 | println!("{msg:?}") 137 | } else if let Event::Segment { percentage, .. } = msg { 138 | pb.as_ref() 139 | .unwrap() 140 | .set_position((percentage * 100.) as u64); 141 | } 142 | } 143 | } 144 | Err(err) => println!("{err} occurred\nAborting!"), 145 | } 146 | } 147 | if let Some(pb) = pb { 148 | pb.finish(); 149 | } 150 | if let Err(err) = write(output_file, segments.join("\n")).await { 151 | println!("{err} occurred\nAborting!"); 152 | } 153 | } 154 | Err(err) => println!("{err} occurred\nAborting!"), 155 | } 156 | } 157 | Commands::Models { sub_command } => match sub_command { 158 | ModelCommands::List => { 159 | for model in Model::iter() { 160 | println!("{model}") 161 | } 162 | } 163 | ModelCommands::Download { code, ignore_cache } => match Model::from_str(&code) { 164 | Ok(model) => { 165 | if let Err(err) = model.download_model(ignore_cache).await { 166 | println!("Error {err}.\nAborting!"); 167 | } else { 168 | println!("Download completed"); 169 | } 170 | } 171 | Err(_) => println!("{code} not associated to any supported model"), 172 | }, 173 | }, 174 | } 175 | } 176 | -------------------------------------------------------------------------------- /simple-whisper-server/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "simple-whisper-server" 3 | version.workspace = true 4 | edition.workspace = true 5 | authors.workspace = true 6 | license.workspace = true 7 | readme.workspace = true 8 | repository.workspace = true 9 | keywords.workspace = true 10 | categories.workspace = true 11 | publish = false 12 | 13 | [package.metadata.dist] 14 | dist = true 15 | 16 | [dependencies] 17 | anyhow.workspace = true 18 | axum.workspace = true 19 | clap.workspace = true 20 | 
serde.workspace = true 21 | serde_json.workspace = true 22 | simple-whisper = { path = "../simple-whisper"} 23 | strum.workspace = true 24 | tempfile.workspace = true 25 | tokio.workspace = true 26 | tokio-stream.workspace = true 27 | tower-http.workspace = true 28 | thiserror.workspace = true 29 | tracing.workspace = true 30 | tracing-subscriber.workspace = true 31 | 32 | [dev-dependencies] 33 | futures = "0.3" 34 | reqwest = {version = "0.12.5", features = ["json"] } 35 | reqwest-websocket = "0.4.0" 36 | 37 | [features] 38 | vulkan = ["simple-whisper/vulkan"] 39 | cuda = ["simple-whisper/cuda"] 40 | metal = ["simple-whisper/metal"] 41 | -------------------------------------------------------------------------------- /simple-whisper-server/README.md: -------------------------------------------------------------------------------- 1 | # Simple Whisper Server 2 | A modest server for speech transcription 3 | 4 | ## Usage 5 | 6 | ``` 7 | Usage: simple-whisper-server [OPTIONS] 8 | 9 | Options: 10 | -p, --server-port Server listening port [default: 3000] 11 | -h, --help Print help 12 | -V, --version Print version 13 | ``` -------------------------------------------------------------------------------- /simple-whisper-server/src/main.rs: -------------------------------------------------------------------------------- 1 | use std::{str::FromStr, time::Duration}; 2 | 3 | use axum::{ 4 | Json, Router, 5 | extract::{ 6 | DefaultBodyLimit, MatchedPath, Path, Query, WebSocketUpgrade, 7 | ws::{Message, WebSocket}, 8 | }, 9 | http::{Request, StatusCode}, 10 | response::{IntoResponse, Response}, 11 | routing::get, 12 | serve, 13 | }; 14 | use clap::Parser; 15 | use serde::{Deserialize, Serialize}; 16 | use simple_whisper::{Event, Language, Model, Whisper, WhisperBuilder}; 17 | use strum::{EnumIs, EnumMessage, IntoEnumIterator}; 18 | use tempfile::NamedTempFile; 19 | use thiserror::Error; 20 | use tokio::{fs::write, net::TcpListener, spawn, sync::mpsc::unbounded_channel}; 21 | use 
tokio_stream::StreamExt; 22 | use tower_http::trace::TraceLayer; 23 | use tracing::info_span; 24 | use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; 25 | 26 | #[derive(Parser, Debug)] 27 | #[command(version, about, long_about = None)] 28 | struct Cli { 29 | /// Server listening port 30 | #[arg(long, short = 'p', default_value = "3000")] 31 | server_port: u16, 32 | } 33 | 34 | #[derive(Error, Debug)] 35 | enum Error { 36 | #[error("Model {0} not supported")] 37 | ModelNotSupported(String), 38 | #[error("Language {0} not supported")] 39 | LanguageNotSupported(String), 40 | } 41 | 42 | impl IntoResponse for Error { 43 | fn into_response(self) -> Response { 44 | match self { 45 | Error::ModelNotSupported(_) => (StatusCode::BAD_REQUEST, format!("{self}")), 46 | Error::LanguageNotSupported(_) => (StatusCode::BAD_REQUEST, format!("{self}")), 47 | } 48 | .into_response() 49 | } 50 | } 51 | #[derive(Deserialize, Serialize)] 52 | struct LanguageResponse { 53 | id: String, 54 | lang: String, 55 | } 56 | 57 | #[derive(Deserialize, Serialize)] 58 | struct ModelResponse { 59 | id: String, 60 | model: String, 61 | } 62 | 63 | #[derive(Deserialize)] 64 | struct ModelParameters { 65 | ignore_cache: bool, 66 | } 67 | 68 | #[derive(Deserialize)] 69 | struct TranscribeParameters { 70 | #[serde(default)] 71 | ignore_cache: bool, 72 | #[serde(default)] 73 | single_segment: bool, 74 | } 75 | 76 | #[derive(EnumIs, Debug, Deserialize, Serialize)] 77 | enum ServerResponse { 78 | FileStarted { 79 | file: String, 80 | }, 81 | FileCompleted { 82 | file: String, 83 | }, 84 | FileProgress { 85 | file: String, 86 | percentage: f32, 87 | elapsed_time: Duration, 88 | remaining_time: Duration, 89 | }, 90 | Failed, 91 | DownloadModelCompleted, 92 | Segment { 93 | start_offset: Duration, 94 | end_offset: Duration, 95 | percentage: f32, 96 | transcription: String, 97 | }, 98 | } 99 | 100 | impl From for ServerResponse { 101 | fn from(value: Event) -> Self { 102 | match value { 
103 | Event::DownloadStarted { file } => Self::FileStarted { file }, 104 | Event::DownloadCompleted { file } => Self::FileCompleted { file }, 105 | Event::Segment { 106 | start_offset, 107 | end_offset, 108 | percentage, 109 | transcription, 110 | } => Self::Segment { 111 | start_offset, 112 | end_offset, 113 | percentage, 114 | transcription, 115 | }, 116 | Event::DownloadProgress { 117 | file, 118 | percentage, 119 | elapsed_time, 120 | remaining_time, 121 | } => Self::FileProgress { 122 | file, 123 | percentage, 124 | elapsed_time, 125 | remaining_time, 126 | }, 127 | } 128 | } 129 | } 130 | 131 | #[tokio::main] 132 | async fn main() { 133 | let cli = Cli::parse(); 134 | tracing_subscriber::registry() 135 | .with( 136 | tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { 137 | "simple-whisper-server=debug,tower_http=debug,axum::rejection=trace".into() 138 | }), 139 | ) 140 | .with(tracing_subscriber::fmt::layer()) 141 | .init(); 142 | 143 | let listener = TcpListener::bind(("127.0.0.1", cli.server_port)) 144 | .await 145 | .unwrap(); 146 | serve(listener, app()).await.unwrap(); 147 | } 148 | 149 | fn app() -> Router { 150 | Router::new() 151 | .layer( 152 | TraceLayer::new_for_http().make_span_with(|request: &Request<_>| { 153 | let matched_path = request 154 | .extensions() 155 | .get::() 156 | .map(MatchedPath::as_str); 157 | 158 | info_span!( 159 | "http_request", 160 | method = ?request.method(), 161 | matched_path 162 | ) 163 | }), 164 | ) 165 | .nest("/languages", languages_router()) 166 | .nest("/models", models_router()) 167 | .nest("/transcribe", transcribe_router()) 168 | } 169 | 170 | fn languages_router() -> Router { 171 | Router::new() 172 | .route("/list", get(list_languages)) 173 | .route("/check/:id", get(valid_language)) 174 | } 175 | 176 | async fn list_languages() -> Json> { 177 | Json( 178 | Language::iter() 179 | .map(|l| { 180 | let binding = l.get_message().unwrap(); 181 | let (lang, code) = 
binding.split_once('-').unwrap(); 182 | LanguageResponse { 183 | id: code.trim().to_owned(), 184 | lang: lang.trim().to_owned(), 185 | } 186 | }) 187 | .collect(), 188 | ) 189 | } 190 | 191 | async fn valid_language(Path(id): Path) -> Result<(), Error> { 192 | Language::from_str(&id) 193 | .map(|_| ()) 194 | .map_err(|_| Error::LanguageNotSupported(id)) 195 | } 196 | 197 | fn models_router() -> Router { 198 | Router::new() 199 | .route("/list", get(list_models)) 200 | .route("/download/:id", get(download_model)) 201 | } 202 | 203 | async fn list_models() -> Json> { 204 | Json( 205 | Model::iter() 206 | .map(|l| { 207 | let binding = l.to_string(); 208 | let (model, code) = binding.split_once('-').unwrap(); 209 | ModelResponse { 210 | id: code.trim().to_owned(), 211 | model: model.trim().to_owned(), 212 | } 213 | }) 214 | .collect(), 215 | ) 216 | } 217 | 218 | async fn download_model( 219 | ws: WebSocketUpgrade, 220 | Path(id): Path, 221 | parameters: Query, 222 | ) -> Response { 223 | let maybe_model: Result = 224 | Model::from_str(&id).map_err(|_| Error::ModelNotSupported(id)); 225 | match maybe_model { 226 | Ok(model) => ws.on_upgrade(|socket| handle_download_model(socket, model, parameters.0)), 227 | Err(err) => err.into_response(), 228 | } 229 | } 230 | 231 | async fn handle_download_model(socket: WebSocket, model: Model, params: ModelParameters) { 232 | let _ = internal_handle_download_model(socket, model, params).await; 233 | } 234 | 235 | async fn internal_handle_download_model( 236 | mut socket: WebSocket, 237 | model: Model, 238 | params: ModelParameters, 239 | ) -> anyhow::Result<()> { 240 | let (tx, mut rx) = unbounded_channel(); 241 | let download = 242 | spawn(async move { model.download_model_listener(params.ignore_cache, tx).await }); 243 | 244 | while let Some(msg) = rx.recv().await { 245 | socket 246 | .send(Message::Text(serde_json::to_string( 247 | &Into::::into(msg), 248 | )?)) 249 | .await?; 250 | } 251 | match download.await { 252 | Ok(_) => 
{ 253 | socket 254 | .send(Message::Text(serde_json::to_string( 255 | &ServerResponse::DownloadModelCompleted, 256 | )?)) 257 | .await? 258 | } 259 | Err(_) => { 260 | socket 261 | .send(Message::Text(serde_json::to_string( 262 | &ServerResponse::Failed, 263 | )?)) 264 | .await? 265 | } 266 | } 267 | Ok(()) 268 | } 269 | 270 | fn transcribe_router() -> Router { 271 | Router::new() 272 | .route("/:model/:lang", get(transcribe)) 273 | .layer(DefaultBodyLimit::max(100 * 1024 * 1024)) 274 | } 275 | 276 | async fn transcribe( 277 | ws: WebSocketUpgrade, 278 | Path((model, lang)): Path<(String, String)>, 279 | parameters: Query, 280 | ) -> Response { 281 | let model = Model::from_str(&model).map_err(|_| Error::ModelNotSupported(model)); 282 | let lang = Language::from_str(&lang).map_err(|_| Error::LanguageNotSupported(lang)); 283 | if let Err(err) = model { 284 | return err.into_response(); 285 | } 286 | 287 | if let Err(err) = lang { 288 | return err.into_response(); 289 | } 290 | 291 | let whisper = WhisperBuilder::default() 292 | .language(lang.unwrap()) 293 | .model(model.unwrap()) 294 | .force_download(parameters.0.ignore_cache) 295 | .force_single_segment(parameters.0.single_segment) 296 | .build() 297 | .unwrap(); 298 | 299 | ws.on_upgrade(|socket| handle_transcription_model(socket, whisper)) 300 | } 301 | 302 | async fn handle_transcription_model(socket: WebSocket, model: Whisper) { 303 | let _ = internal_handle_transcription_model(socket, model).await; 304 | } 305 | 306 | async fn internal_handle_transcription_model( 307 | mut socket: WebSocket, 308 | model: Whisper, 309 | ) -> anyhow::Result<()> { 310 | if let Some(Ok(Message::Binary(data))) = socket.recv().await { 311 | let file = NamedTempFile::new()?; 312 | write(file.path(), data).await?; 313 | let mut stream = model.transcribe(file.path()); 314 | while let Some(msg) = stream.next().await { 315 | match msg { 316 | Ok(msg) => { 317 | if msg.is_segment() { 318 | socket 319 | 
.send(Message::Text(serde_json::to_string( 320 | &Into::::into(msg), 321 | )?)) 322 | .await?; 323 | } 324 | } 325 | Err(_) => { 326 | socket 327 | .send(Message::Text(serde_json::to_string( 328 | &ServerResponse::Failed, 329 | )?)) 330 | .await? 331 | } 332 | } 333 | } 334 | } 335 | Ok(()) 336 | } 337 | 338 | #[cfg(test)] 339 | mod tests { 340 | use std::future::IntoFuture; 341 | 342 | use axum::serve; 343 | use futures::{SinkExt, StreamExt}; 344 | use reqwest::Client; 345 | use reqwest_websocket::{Message, RequestBuilderExt}; 346 | use tokio::{net::TcpListener, spawn}; 347 | 348 | use crate::{LanguageResponse, ModelResponse, ServerResponse, app}; 349 | 350 | macro_rules! test_file { 351 | ($file_name:expr) => { 352 | concat!(env!("CARGO_MANIFEST_DIR"), "/../assets/", $file_name) 353 | }; 354 | } 355 | 356 | #[tokio::test] 357 | async fn integration_test_languages() { 358 | let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap(); 359 | spawn(serve(listener, app()).into_future()); 360 | 361 | let languages: Vec = reqwest::get("http://127.0.0.1:3000/languages/list") 362 | .await 363 | .unwrap() 364 | .json() 365 | .await 366 | .unwrap(); 367 | assert_eq!(99, languages.len()); 368 | 369 | let good_request = reqwest::get("http://127.0.0.1:3000/languages/check/en") 370 | .await 371 | .unwrap() 372 | .status(); 373 | assert!(good_request.is_success()); 374 | 375 | let bad_request = reqwest::get("http://127.0.0.1:3000/languages/check/zy") 376 | .await 377 | .unwrap() 378 | .status(); 379 | assert_eq!(bad_request.as_u16(), 400); 380 | } 381 | 382 | #[tokio::test] 383 | async fn integration_test_models() { 384 | let listener = TcpListener::bind("127.0.0.1:4000").await.unwrap(); 385 | spawn(serve(listener, app()).into_future()); 386 | 387 | let models: Vec = reqwest::get("http://127.0.0.1:4000/models/list") 388 | .await 389 | .unwrap() 390 | .json() 391 | .await 392 | .unwrap(); 393 | assert_eq!(33, models.len()); 394 | 395 | let websocket = Client::default() 396 
| .get("ws://127.0.0.1:4000/models/download/tiny_en?ignore_cache=true") 397 | .upgrade() 398 | .send() 399 | .await 400 | .unwrap() 401 | .into_websocket() 402 | .await 403 | .unwrap(); 404 | 405 | let (_, mut rx) = websocket.split(); 406 | while let Some(Ok(Message::Text(msg))) = rx.next().await { 407 | let msg: ServerResponse = serde_json::from_str(&msg).unwrap(); 408 | println!("{msg:?}"); 409 | assert!( 410 | msg.is_file_started() 411 | || msg.is_file_completed() 412 | || msg.is_file_progress() 413 | || msg.is_download_model_completed() 414 | ) 415 | } 416 | } 417 | 418 | #[ignore] 419 | #[tokio::test] 420 | async fn integration_test_transcription() { 421 | let listener = TcpListener::bind("127.0.0.1:5000").await.unwrap(); 422 | spawn(serve(listener, app()).into_future()); 423 | 424 | let client = Client::new(); 425 | let websocket = client 426 | .get("ws://127.0.0.1:5000/transcribe/tiny/en?single_segment=true") 427 | .upgrade() 428 | .send() 429 | .await 430 | .unwrap() 431 | .into_websocket() 432 | .await 433 | .unwrap(); 434 | 435 | let (mut tx, mut rx) = websocket.split(); 436 | 437 | let data = tokio::fs::read(test_file!("samples_jfk.wav")) 438 | .await 439 | .unwrap(); 440 | tx.send(Message::Binary(data)).await.unwrap(); 441 | 442 | while let Some(Ok(Message::Text(msg))) = rx.next().await { 443 | let msg: ServerResponse = serde_json::from_str(&msg).unwrap(); 444 | println!("{msg:?}"); 445 | } 446 | } 447 | } 448 | -------------------------------------------------------------------------------- /simple-whisper/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "simple-whisper" 3 | version.workspace = true 4 | readme = "README.md" 5 | description = "OpenAI whisper library based on whisper.cpp" 6 | edition.workspace = true 7 | authors.workspace = true 8 | license.workspace = true 9 | repository.workspace = true 10 | keywords.workspace = true 11 | categories.workspace = true 12 | 13 | 
[dependencies] 14 | whisper-rs = "0.14.2" 15 | derive_builder.workspace = true 16 | hf-hub.workspace = true 17 | num_cpus.workspace = true 18 | rodio.workspace = true 19 | strum.workspace = true 20 | thiserror.workspace = true 21 | tokenizers.workspace = true 22 | tokio.workspace = true 23 | tokio-stream.workspace = true 24 | 25 | [features] 26 | vulkan = ["whisper-rs/vulkan"] 27 | cuda = ["whisper-rs/cuda"] 28 | metal = ["whisper-rs/metal"] 29 | -------------------------------------------------------------------------------- /simple-whisper/README.md: -------------------------------------------------------------------------------- 1 | # Simple Whisper 2 | Implements the Whisper model via [whisper-rs](https://github.com/tazz4843/whisper-rs). 3 | 4 | Weights are automatically downloaded from Hugging Face. 5 | 6 | ## Feature flags 7 | - `vulkan` = enables the Vulkan whisper.cpp backend 8 | - `cuda` = enables the Cuda whisper.cpp backend 9 | - `metal` = enables the Metal whisper.cpp backend 10 | ## Other resources 11 | See [newfla/simple-whisper](https://github.com/newfla/simple-whisper) for prebuilt cli & server binaries 12 | -------------------------------------------------------------------------------- /simple-whisper/src/download.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | path::PathBuf, 3 | sync::{Arc, Mutex}, 4 | time::Instant, 5 | }; 6 | 7 | use hf_hub::{Cache, Repo, api::tokio::Progress}; 8 | use tokio::sync::mpsc::UnboundedSender; 9 | 10 | use crate::{Error, Event}; 11 | 12 | /// Store the state of a download 13 | #[derive(Debug, Clone)] 14 | struct DownloadState { 15 | start_time: Instant, 16 | len: usize, 17 | offset: usize, 18 | url: String, 19 | } 20 | 21 | impl DownloadState { 22 | fn new(len: usize, url: &str) -> DownloadState { 23 | DownloadState { 24 | start_time: Instant::now(), 25 | len, 26 | offset: 0, 27 | url: url.to_string(), 28 | } 29 | } 30 | 31 | fn update(&mut self, delta: usize) 
-> Option { 32 | if delta == 0 { 33 | return None; 34 | } 35 | 36 | self.offset += delta; 37 | 38 | let elapsed_time = Instant::now() - self.start_time; 39 | 40 | let progress = self.offset as f32 / self.len as f32; 41 | let progress_100 = progress * 100.; 42 | 43 | let remaining_percentage = 100. - progress_100; 44 | let duration_unit = elapsed_time 45 | / if progress_100 as u32 == 0 { 46 | 1 47 | } else { 48 | progress_100 as u32 49 | }; 50 | let remaining_time = duration_unit * remaining_percentage as u32; 51 | 52 | let event = Event::DownloadProgress { 53 | file: self.url.clone(), 54 | percentage: progress_100, 55 | elapsed_time, 56 | remaining_time, 57 | }; 58 | Some(event) 59 | } 60 | } 61 | 62 | #[derive(Clone)] 63 | struct DownloadCallback { 64 | download_state: Arc>>, 65 | tx: UnboundedSender, 66 | } 67 | 68 | impl Progress for DownloadCallback { 69 | async fn init(&mut self, len: usize, file: &str) { 70 | self.download_state = Arc::new(Mutex::new(Some(DownloadState::new(len, file)))); 71 | 72 | let _ = self.tx.send(Event::DownloadStarted { 73 | file: file.to_owned(), 74 | }); 75 | } 76 | 77 | async fn update(&mut self, delta: usize) { 78 | let update = self 79 | .download_state 80 | .lock() 81 | .unwrap() 82 | .as_mut() 83 | .unwrap() 84 | .update(delta); 85 | if let Some(event) = update { 86 | let _ = self.tx.send(event); 87 | } 88 | } 89 | 90 | async fn finish(&mut self) { 91 | let file = self 92 | .download_state 93 | .lock() 94 | .unwrap() 95 | .as_ref() 96 | .unwrap() 97 | .url 98 | .clone(); 99 | let _ = self.tx.send(Event::DownloadCompleted { file }); 100 | } 101 | } 102 | 103 | pub enum ProgressType { 104 | Callback(UnboundedSender), 105 | ProgressBar, 106 | } 107 | 108 | pub async fn download_file( 109 | file: &str, 110 | force_download: bool, 111 | progress: ProgressType, 112 | repo: Repo, 113 | ) -> Result { 114 | let cache = Cache::from_env().repo(repo.clone()); 115 | let mut in_cache = cache.get(file); 116 | if force_download { 117 | in_cache 
= None 118 | } 119 | if let Some(val) = in_cache { 120 | Ok(val) 121 | } else { 122 | match progress { 123 | ProgressType::ProgressBar => { 124 | hf_hub::api::tokio::ApiBuilder::default() 125 | .with_progress(true) 126 | .build() 127 | .map(|api| api.repo(repo)) 128 | .map_err(Into::::into)? 129 | .download(file) 130 | .await 131 | } 132 | ProgressType::Callback(tx) => { 133 | let progress = DownloadCallback { 134 | download_state: Default::default(), 135 | tx, 136 | }; 137 | hf_hub::api::tokio::ApiBuilder::default() 138 | .with_progress(false) 139 | .build() 140 | .map(|api| api.repo(repo)) 141 | .map_err(Into::::into)? 142 | .download_with_progress(file, progress) 143 | .await 144 | } 145 | } 146 | .map_err(Into::into) 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /simple-whisper/src/language.rs: -------------------------------------------------------------------------------- 1 | use strum::{Display, EnumIs, EnumIter, EnumMessage, EnumString}; 2 | 3 | /// Languages supported by the tokenizer 4 | #[derive(Default, Clone, Copy, Debug, EnumIs, EnumIter, EnumString, Display, EnumMessage)] 5 | pub enum Language { 6 | #[default] 7 | #[strum(serialize = "en", message = "English - en")] 8 | English, 9 | #[strum(serialize = "zh", message = "Chinese - zh")] 10 | Chinese, 11 | #[strum(serialize = "de", message = "German - de")] 12 | German, 13 | #[strum(serialize = "es", message = "Spanish - es")] 14 | Spanish, 15 | #[strum(serialize = "ru", message = "Russian - ru")] 16 | Russian, 17 | #[strum(serialize = "ko", message = "Korean - ko")] 18 | Korean, 19 | #[strum(serialize = "fr", message = "French - fr")] 20 | French, 21 | #[strum(serialize = "ja", message = "Japanese - ja")] 22 | Japanese, 23 | #[strum(serialize = "pt", message = "Portuguese - pt")] 24 | Portuguese, 25 | #[strum(serialize = "tr", message = "Turkish - tr")] 26 | Turkish, 27 | #[strum(serialize = "pl", message = "Polish - pl")] 28 | Polish, 29 | 
    // Each variant pairs a Whisper language code (`serialize`, accepted by the
    // `EnumString`-generated `FromStr`) with a human-readable label (`message`,
    // exposed through `EnumMessage`).
    #[strum(serialize = "ca", message = "Catalan - ca")]
    Catalan,
    #[strum(serialize = "nl", message = "Dutch - nl")]
    Dutch,
    #[strum(serialize = "ar", message = "Arabic - ar")]
    Arabic,
    #[strum(serialize = "sv", message = "Swedish - sv")]
    Swedish,
    #[strum(serialize = "it", message = "Italian - it")]
    Italian,
    #[strum(serialize = "id", message = "Indonesian - id")]
    Indonesian,
    #[strum(serialize = "hi", message = "Hindi - hi")]
    Hindi,
    #[strum(serialize = "fi", message = "Finnish - fi")]
    Finnish,
    #[strum(serialize = "vi", message = "Vietnamese - vi")]
    Vietnamese,
    #[strum(serialize = "he", message = "Hebrew - he")]
    Hebrew,
    #[strum(serialize = "uk", message = "Ukrainian - uk")]
    Ukrainian,
    #[strum(serialize = "el", message = "Greek - el")]
    Greek,
    #[strum(serialize = "ms", message = "Malay - ms")]
    Malay,
    #[strum(serialize = "cs", message = "Czech - cs")]
    Czech,
    #[strum(serialize = "ro", message = "Romanian - ro")]
    Romanian,
    #[strum(serialize = "da", message = "Danish - da")]
    Danish,
    #[strum(serialize = "hu", message = "Hungarian - hu")]
    Hungarian,
    #[strum(serialize = "ta", message = "Tamil - ta")]
    Tamil,
    #[strum(serialize = "no", message = "Norwegian - no")]
    Norwegian,
    #[strum(serialize = "th", message = "Thai - th")]
    Thai,
    #[strum(serialize = "ur", message = "Urdu - ur")]
    Urdu,
    #[strum(serialize = "hr", message = "Croatian - hr")]
    Croatian,
    #[strum(serialize = "bg", message = "Bulgarian - bg")]
    Bulgarian,
    #[strum(serialize = "lt", message = "Lithuanian - lt")]
    Lithuanian,
    #[strum(serialize = "la", message = "Latin - la")]
    Latin,
    #[strum(serialize = "mi", message = "Maori - mi")]
    Maori,
    #[strum(serialize = "ml", message = "Malayalam - ml")]
    Malayalam,
    #[strum(serialize = "cy", message = "Welsh - cy")]
    Welsh,
    #[strum(serialize = "sk", message = "Slovak - sk")]
    Slovak,
    #[strum(serialize = "te", message = "Telugu - te")]
    Telugu,
    #[strum(serialize = "fa", message = "Persian - fa")]
    Persian,
    #[strum(serialize = "lv", message = "Latvian - lv")]
    Latvian,
    #[strum(serialize = "bn", message = "Bengali - bn")]
    Bengali,
    #[strum(serialize = "sr", message = "Serbian - sr")]
    Serbian,
    #[strum(serialize = "az", message = "Azerbaijani - az")]
    Azerbaijani,
    #[strum(serialize = "sl", message = "Slovenian - sl")]
    Slovenian,
    #[strum(serialize = "kn", message = "Kannada - kn")]
    Kannada,
    #[strum(serialize = "et", message = "Estonian - et")]
    Estonian,
    #[strum(serialize = "mk", message = "Macedonian - mk")]
    Macedonian,
    #[strum(serialize = "br", message = "Breton - br")]
    Breton,
    #[strum(serialize = "eu", message = "Basque - eu")]
    Basque,
    #[strum(serialize = "is", message = "Icelandic - is")]
    Icelandic,
    #[strum(serialize = "hy", message = "Armenian - hy")]
    Armenian,
    #[strum(serialize = "ne", message = "Nepali - ne")]
    Nepali,
    #[strum(serialize = "mn", message = "Mongolian - mn")]
    Mongolian,
    #[strum(serialize = "bs", message = "Bosnian - bs")]
    Bosnian,
    #[strum(serialize = "kk", message = "Kazakh - kk")]
    Kazakh,
    #[strum(serialize = "sq", message = "Albanian - sq")]
    Albanian,
    #[strum(serialize = "sw", message = "Swahili - sw")]
    Swahili,
    #[strum(serialize = "gl", message = "Galician - gl")]
    Galician,
    #[strum(serialize = "mr", message = "Marathi - mr")]
    Marathi,
    #[strum(serialize = "pa", message = "Punjabi - pa")]
    Punjabi,
    #[strum(serialize = "si", message = "Sinhala - si")]
    Sinhala,
    #[strum(serialize = "km", message = "Khmer - km")]
    Khmer,
    #[strum(serialize = "sn", message = "Shona - sn")]
    Shona,
    #[strum(serialize = "yo", message = "Yoruba - yo")]
    Yoruba,
    #[strum(serialize = "so", message = "Somali - so")]
    Somali,
    #[strum(serialize = "af", message = "Afrikaans - af")]
    Afrikaans,
    #[strum(serialize = "oc", message = "Occitan - oc")]
    Occitan,
    #[strum(serialize = "ka", message = "Georgian - ka")]
    Georgian,
    #[strum(serialize = "be", message = "Belarusian - be")]
    Belarusian,
    #[strum(serialize = "tg", message = "Tajik - tg")]
    Tajik,
    #[strum(serialize = "sd", message = "Sindhi - sd")]
    Sindhi,
    #[strum(serialize = "gu", message = "Gujarati - gu")]
    Gujarati,
    #[strum(serialize = "am", message = "Amharic - am")]
    Amharic,
    #[strum(serialize = "yi", message = "Yiddish - yi")]
    Yiddish,
    #[strum(serialize = "lo", message = "Lao - lo")]
    Lao,
    #[strum(serialize = "uz", message = "Uzbek - uz")]
    Uzbek,
    #[strum(serialize = "fo", message = "Faroese - fo")]
    Faroese,
    #[strum(serialize = "ht", message = "HaitianCreole - ht")]
    HaitianCreole,
    #[strum(serialize = "ps", message = "Pashto - ps")]
    Pashto,
    #[strum(serialize = "tk", message = "Turkmen - tk")]
    Turkmen,
    #[strum(serialize = "nn", message = "Nynorsk - nn")]
    Nynorsk,
    #[strum(serialize = "mt", message = "Maltese - mt")]
    Maltese,
    #[strum(serialize = "sa", message = "Sanskrit - sa")]
    Sanskrit,
    #[strum(serialize = "lb", message = "Luxembourgish - lb")]
    Luxembourgish,
    #[strum(serialize = "my", message = "Myanmar - my")]
    Myanmar,
    #[strum(serialize = "bo", message = "Tibetan - bo")]
    Tibetan,
    #[strum(serialize = "tl", message = "Tagalog - tl")]
    Tagalog,
    #[strum(serialize = "mg", message = "Malagasy - mg")]
    Malagasy,
    #[strum(serialize = "as", message = "Assamese - as")]
    Assamese,
    #[strum(serialize = "tt", message = "Tatar - tt")]
    Tatar,
    #[strum(serialize = "haw", message = "Hawaiian - haw")]
    Hawaiian,
    #[strum(serialize = "ln", message = "Lingala - ln")]
    Lingala,
    #[strum(serialize = "ha", message = "Hausa - ha")]
    Hausa,
    #[strum(serialize = "ba", message = "Bashkir - ba")]
    Bashkir,
    #[strum(serialize = "jw", message = "Javanese - jw")]
    Javanese,
    #[strum(serialize = "su", message = "Sundanese - su")]
    Sundanese,
}
--------------------------------------------------------------------------------
/simple-whisper/src/lib.rs:
--------------------------------------------------------------------------------
use std::{
    fs::File,
    io::{self, BufReader},
    path::{Path, PathBuf},
    sync::Arc,
    time::Duration,
};

use derive_builder::Builder;

mod download;
mod language;
mod model;
mod transcribe;

use download::ProgressType;
pub use language::Language;
pub use model::Model;
use rodio::{Decoder, Source, source::UniformSourceIterator};
use strum::{Display, EnumIs};
use thiserror::Error;
use tokio::{
    spawn,
    sync::{Notify, mpsc::unbounded_channel},
    task::spawn_blocking,
};
pub use transcribe::TranscribeBuilderError;

use tokio_stream::{Stream, wrappers::UnboundedReceiverStream};
use transcribe::TranscribeBuilder;
use whisper_rs::WhisperError;

// One-shot synchronization handle (`notify_one`/`notified`) used by
// `Whisper::transcribe` to wait until the download-event forwarder task has
// drained its channel before starting the transcription.
type Barrier = Arc<Notify>;

// Sample rate (Hz) that every input file is resampled to before being handed
// to whisper.cpp; also used to derive the audio duration from the sample count.
pub const SAMPLE_RATE: u32 = 16000;

/// The Whisper audio transcription model.
#[derive(Default, Builder, Debug)]
#[builder(setter(into), build_fn(validate = "Self::validate"))]
pub struct Whisper {
    // Spoken language of the input audio; compatibility with `model` is
    // checked by `WhisperBuilder::validate`.
    language: Language,
    // ggml Whisper model to download (if needed) and run.
    model: Model,
    // When true, show an interactive progress bar during the model download
    // instead of forwarding `Event::Download*` messages on the output stream.
    #[builder(default = "false")]
    progress_bar: bool,
    // When true, re-download model files even if they are already cached.
    #[builder(default = "false")]
    force_download: bool,
    // When true, ask the backend to emit the transcript as a single segment.
    #[builder(default = "false")]
    force_single_segment: bool,
}

/// Error conditions
#[derive(Error, Debug)]
pub enum Error {
    /// Error that can occur during model files download from huggingface
    #[error(transparent)]
    Download(#[from] hf_hub::api::tokio::ApiError),
    /// I/O failure while opening/reading the audio file.
    #[error(transparent)]
    Io(#[from] io::Error),
    /// Error that can occur during audio file decoding phase
    #[error(transparent)]
    AudioDecoder(#[from] rodio::decoder::DecoderError),
    /// The library was unable to determine the audio file duration
    #[error("Unable to find duration")]
    AudioDuration,
    #[error(transparent)]
    /// Missing parameters to instantiate the whisper cpp backend
    ComputeBuilder(#[from] TranscribeBuilderError),
    /// Error reported by the whisper.cpp backend.
    #[error(transparent)]
    Whisper(#[from] WhisperError),
}

/// Events generated by the [Whisper::transcribe] method
#[derive(Clone, Debug, Display, EnumIs)]
pub enum Event {
    /// Download of `file` has started.
    #[strum(to_string = "Downloading {file}")]
    DownloadStarted { file: String },
    /// Download of `file` has finished.
    #[strum(to_string = "{file} has been downloaded")]
    DownloadCompleted { file: String },
    #[strum(
        to_string = "Downloading {file} --> {percentage} {elapsed_time:#?} | {remaining_time:#?}"
    )]
    DownloadProgress {
        /// The resource to download
        file: String,

        /// The progress expressed as %
        percentage: f32,

        /// Time elapsed since the download as being started
        elapsed_time: Duration,

        /// Estimated time to complete the download
        remaining_time: Duration,
    },
    /// Audio chunk transcript
    #[strum(to_string = "{transcription}")]
    Segment {
        // Start of the chunk within the audio file.
        start_offset: Duration,
        // End of the chunk within the audio file.
        end_offset: Duration,
        // Fraction of the whole file transcribed so far, clamped to [0, 1].
        percentage: f32,
        // Text produced for this chunk.
        transcription: String,
    },
}

impl WhisperBuilder {
    // Reject builds that pair a non-English language with an English-only
    // model; all other combinations are accepted.
    fn validate(&self) -> Result<(), WhisperBuilderError> {
        if self.language.as_ref().is_some_and(|l| !l.is_english())
            && self.model.as_ref().is_some_and(|m| !m.is_multilingual())
        {
            let err = format!(
                "The requested language {} is not compatible with {} model",
                self.language.as_ref().unwrap(),
                self.model.as_ref().unwrap()
            );
            return Err(WhisperBuilderError::ValidationError(err));
        }
        Ok(())
    }
}

impl Whisper {
    /// Transcribe an audio file into text.
    ///
    /// Returns a stream of [`Event`]s: download events first (unless a
    /// progress bar was requested), then one [`Event::Segment`] per
    /// transcribed chunk. Failures are delivered in-band as `Err` items.
    pub fn transcribe(self, path: impl AsRef<Path>) -> impl Stream<Item = Result<Event, Error>> {
        let (tx, rx) = unbounded_channel();
        let (tx_event, mut rx_event) = unbounded_channel();

        // Barrier: transcription must not start before every download event
        // has been forwarded to the output stream.
        let wait_download = Barrier::default();
        let download_completed = wait_download.clone();

        let path = path.as_ref().into();

        // Download events forwarder
        let tx_forwarder = tx.clone();
        spawn(async move {
            while let Some(msg) = rx_event.recv().await {
                let _ = tx_forwarder.send(Ok(msg));
            }
            // Channel closed (sender dropped): signal that forwarding is done.
            wait_download.notify_one();
        });

        spawn(async move {
            // Download model data from Hugging Face
            let progress = if self.progress_bar {
                // Dropping the sender immediately closes the forwarder loop,
                // since no download events will be produced in this mode.
                drop(tx_event);
                ProgressType::ProgressBar
            } else {
                ProgressType::Callback(tx_event)
            };
            let model = self
                .model
                .internal_download_model(self.force_download, progress)
                .await;
            download_completed.notified().await;

            // whisper.cpp inference is synchronous and CPU-bound: run it off
            // the async runtime.
            spawn_blocking(move || {
                // Load audio file
                let audio = Self::load_audio(path);

                match audio.map(|audio| (audio, model)) {
                    Ok((audio, Ok(model_files))) => {
                        match TranscribeBuilder::default()
                            .language(self.language)
                            .audio(audio)
                            .single_segment(self.force_single_segment)
                            .tx(tx.clone())
                            .model(model_files)
                            .build()
                        {
                            Ok(compute) => compute.transcribe(),
                            Err(err) => {
                                let _ = tx.send(Err(err.into()));
                            }
                        }
                    }
                    // Audio decoded fine but the download failed.
                    Ok((_, Err(err))) => {
                        let _ = tx.send(Err(err));
                    }
                    // Audio decoding failed.
                    Err(err) => {
                        let _ = tx.send(Err(err));
                    }
                }
            });
        });

        UnboundedReceiverStream::new(rx)
    }

    // Decode `path`, downmix to one channel, resample to `SAMPLE_RATE`, and
    // band-pass filter (200 Hz high-pass, 3 kHz low-pass) the samples.
    fn load_audio(path: PathBuf) -> Result<(Vec<f32>, Duration), Error> {
        let reader = BufReader::new(File::open(&path)?);
        let decoder = Decoder::new(reader)?;
        let resample: UniformSourceIterator<Decoder<BufReader<File>>, f32> =
            UniformSourceIterator::new(decoder, 1, SAMPLE_RATE);
        let samples = resample
            .low_pass(3000)
            .high_pass(200)
            .convert_samples()
            .collect::<Vec<_>>();

        let duration = Self::get_audio_duration(samples.len());

        Ok((samples, duration))
    }

    // Duration of a mono buffer with `samples` samples at `SAMPLE_RATE`.
    fn get_audio_duration(samples: usize) -> Duration {
        let secs = samples as f64 / SAMPLE_RATE as f64;
        Duration::from_secs_f64(secs)
    }
}

#[cfg(test)]
mod tests {
    use tokio_stream::StreamExt;

    use super::*;

    // Absolute path of a fixture inside the repository's `assets/` directory.
    macro_rules! test_file {
        ($file_name:expr) => {
            concat!(env!("CARGO_MANIFEST_DIR"), "/../assets/", $file_name)
        };
    }

    #[test]
    fn incompatible_lang_model() {
        let error = WhisperBuilder::default()
            .language(Language::Italian)
            .model(Model::BaseEn)
            .build()
            .unwrap_err();
        assert!(matches!(error, WhisperBuilderError::ValidationError(_)));
    }

    #[test]
    fn compatible_lang_model() {
        WhisperBuilder::default()
            .language(Language::Italian)
            .model(Model::Base)
            .build()
            .unwrap();
    }

    // Downloads the Tiny model from Hugging Face; run manually via
    // `cargo test -- --ignored`.
    #[ignore]
    #[tokio::test]
    async fn simple_transcribe_ok() {
        let mut rx = WhisperBuilder::default()
            .language(Language::English)
            .model(Model::Tiny)
            .progress_bar(true)
            .build()
            .unwrap()
            .transcribe(test_file!("samples_jfk.wav"));

        while let Some(msg) = rx.next().await {
            assert!(msg.is_ok());
            println!("{msg:?}");
        }
    }
}
--------------------------------------------------------------------------------
/simple-whisper/src/model.rs:
--------------------------------------------------------------------------------
use std::path::PathBuf;

use hf_hub::{Cache, Repo};
use strum::{Display, EnumIter, EnumString};
use tokio::sync::mpsc::UnboundedSender;

use crate::{
    Error, Event,
    download::{ProgressType, download_file},
};

// Location of a model on Hugging Face: the repository plus the weight
// file name inside it.
struct HFCoordinates {
    repo: Repo,
    model: String,
}

/// OpenAI supported models
#[derive(Default, Clone, Debug, EnumIter, EnumString, Display)]
#[strum(serialize_all = "snake_case")]
pub enum Model {
    /// The tiny model.
    #[strum(serialize = "tiny", to_string = "Tiny - tiny")]
    Tiny,
    /// The tiny-q5_1 model.
    #[strum(serialize = "tiny-q5_1", to_string = "Tiny - tiny-q5_1")]
    TinyQ5_1,
    /// The tiny-q8_0 model.
28 | #[strum(serialize = "tiny-q8_0", to_string = "Tiny - tiny-q8_0")] 29 | TinyQ8_0, 30 | /// The tiny model with only English support. 31 | #[strum(serialize = "tiny_en", to_string = "TinyEn - tiny_en")] 32 | TinyEn, 33 | /// The tiny-q5_1 model with only English support. 34 | #[strum(serialize = "tiny_en-q5_1", to_string = "TinyEn - tiny_en-q5_1")] 35 | TinyEnQ5_1, 36 | /// The tiny-q8_0 model with only English support. 37 | #[strum(serialize = "tiny_en-q8_0", to_string = "Tiny - tiny_en-q8_0")] 38 | TinyEnQ8_0, 39 | /// The base model. 40 | #[default] 41 | #[strum(serialize = "base", to_string = "Base - base")] 42 | Base, 43 | /// The base-q5_1 model. 44 | #[strum(serialize = "base-q5_1", to_string = "Base - base-q5_1")] 45 | BaseQ5_1, 46 | /// The base-q8_0 model. 47 | #[strum(serialize = "base-q8_0", to_string = "Base - base-q8_0")] 48 | BaseQ8_0, 49 | /// The base model with only English support. 50 | #[strum(serialize = "base_en", to_string = "BaseEn - base_en")] 51 | BaseEn, 52 | /// The base-q5_1 model with only English support. 53 | #[strum(serialize = "base_en-q5_1", to_string = "BaseEn -base_en-q5_1")] 54 | BaseEnQ5_1, 55 | /// The base-q8_0 model with only English support. 56 | #[strum(serialize = "base_en-q8_0", to_string = "BaseEn - base_en-q8_0")] 57 | BaseEnQ8_0, 58 | /// The small model. 59 | #[strum(serialize = "small", to_string = "Small - small")] 60 | Small, 61 | /// The small-q5_1 model. 62 | #[strum(serialize = "small-q5_1", to_string = "Small - small-q5_1")] 63 | SmallQ5_1, 64 | /// The small-q8_0 model. 65 | #[strum(serialize = "small-q8_0", to_string = "Small - small-q8_0")] 66 | SmallQ8_0, 67 | /// The small model with only English support. 68 | #[strum(serialize = "small_en", to_string = "SmallEn - small_en")] 69 | SmallEn, 70 | /// The small-q5_1 model with only English support. 71 | #[strum(serialize = "small_en-q5_1", to_string = "SmallEn - small_en-q5_1")] 72 | SmallEnQ5_1, 73 | /// The small-q8_0 model with only English support. 
74 | #[strum(serialize = "small_en-q8_0", to_string = "SmallEn - small_en-q8_0")] 75 | SmallEnQ8_0, 76 | /// The medium model. 77 | #[strum(serialize = "medium", to_string = "Medium - medium")] 78 | Medium, 79 | /// The medium-q5_0 model. 80 | #[strum(serialize = "medium-q5_0", to_string = "Medium - medium-q5_0")] 81 | MediumQ5_0, 82 | /// The medium-q8_0 model. 83 | #[strum(serialize = "medium-q8_0", to_string = "Medium - medium-q8_0")] 84 | MediumQ8_0, 85 | /// The medium model with only English support. 86 | #[strum(serialize = "medium_en", to_string = "MediumEn - medium_en")] 87 | MediumEn, 88 | /// The medium-q5_0 model with only English support. 89 | #[strum(serialize = "medium_en-q5_0 ", to_string = "MediumEn - medium_en-q5_0")] 90 | MediumEnQ5_0, 91 | /// The medium-q8_0 model with only English support. 92 | #[strum(serialize = "medium_en-q8_0", to_string = "MediumEn - medium_en-q8_0")] 93 | MediumEnQ8_0, 94 | /// The large model. 95 | #[strum(serialize = "large", to_string = "Large V1 - large")] 96 | Large, 97 | /// The large model v2. 98 | #[strum(serialize = "large_v2", to_string = "Large V2 - large_v2")] 99 | LargeV2, 100 | #[strum(serialize = "large_v2-q5_0", to_string = "Large V2 - large_v2-q5_0")] 101 | LargeV2Q5_0, 102 | #[strum(serialize = "large_v2-q8_0", to_string = "Large V2 - large_v2-q8_0")] 103 | LargeV2Q8_0, 104 | /// The large model v3. 105 | #[strum(serialize = "large_v3", to_string = "Large V3 - large_v3")] 106 | LargeV3, 107 | /// The large_v3-q5_0 model v3. 108 | #[strum(serialize = "large_v3-q5_0", to_string = "Large V3 - large_v3-q5_0")] 109 | LargeV3Q5_0, 110 | /// The large model v3 turbo. 111 | #[strum( 112 | serialize = "large_v3_turbo", 113 | to_string = "Large V3 Turbo - large_v3_turbo" 114 | )] 115 | LargeV3Turbo, 116 | /// The large_v3_turbo-q5_0 model v3 turbo. 
    #[strum(
        serialize = "large_v3_turbo-q5_0",
        to_string = "Large V3 Turbo - large_v3_turbo-q5_0"
    )]
    LargeV3TurboQ5_0,
    /// The large_v3_turbo-q8_0 model v3 turbo.
    #[strum(
        serialize = "large_v3_turbo-q8_0",
        to_string = "Large V3 Turbo - large_v3_turbo-q8_0"
    )]
    LargeV3TurboQ8_0,
}

impl Model {
    // Map this variant to the Hugging Face repository and the ggml weight
    // file that `download_file` should fetch.
    fn hf_coordinates(&self) -> HFCoordinates {
        // All weights are published in ggerganov's whisper.cpp repository.
        // NOTE(review): "main" is a moving branch, not a pinned revision —
        // confirm this is intentional.
        let repo = Repo::with_revision(
            "ggerganov/whisper.cpp".to_owned(),
            hf_hub::RepoType::Model,
            "main".to_owned(),
        );
        match self {
            Model::Tiny => HFCoordinates {
                repo,
                model: "ggml-tiny.bin".to_owned(),
            },
            Model::TinyEn => HFCoordinates {
                repo,
                model: "ggml-tiny.en.bin".to_owned(),
            },
            Model::Base => HFCoordinates {
                repo,
                model: "ggml-base.bin".to_owned(),
            },
            Model::BaseEn => HFCoordinates {
                repo,
                model: "ggml-base.en.bin".to_owned(),
            },
            Model::Small => HFCoordinates {
                repo,
                model: "ggml-small.bin".to_owned(),
            },
            Model::SmallEn => HFCoordinates {
                repo,
                model: "ggml-small.en.bin".to_owned(),
            },
            Model::Medium => HFCoordinates {
                repo,
                model: "ggml-medium.bin".to_owned(),
            },
            Model::MediumEn => HFCoordinates {
                repo,
                model: "ggml-medium.en.bin".to_owned(),
            },
            Model::Large => HFCoordinates {
                repo,
                model: "ggml-large-v1.bin".to_owned(),
            },
            Model::LargeV2 => HFCoordinates {
                repo,
                model: "ggml-large-v2.bin".to_owned(),
            },
            Model::LargeV3 => HFCoordinates {
                repo,
                model: "ggml-large-v3.bin".to_owned(),
            },
            Model::TinyQ5_1 => HFCoordinates {
                repo,
                model: "ggml-tiny-q5_1.bin".to_owned(),
            },
            Model::TinyQ8_0 => HFCoordinates {
                repo,
                model: "ggml-tiny-q8_0.bin".to_owned(),
            },
            Model::TinyEnQ5_1 => HFCoordinates {
                repo,
                model: "ggml-tiny.en-q5_1.bin".to_owned(),
            },
            Model::TinyEnQ8_0 => HFCoordinates {
                repo,
                model: "ggml-tiny.en-q8_0.bin".to_owned(),
            },
            Model::BaseQ5_1 => HFCoordinates {
                repo,
                model: "ggml-base-q5_1.bin".to_owned(),
            },
            Model::BaseQ8_0 => HFCoordinates {
                repo,
                model: "ggml-base-q8_0.bin".to_owned(),
            },
            Model::BaseEnQ5_1 => HFCoordinates {
                repo,
                model: "ggml-base.en-q5_1.bin".to_owned(),
            },
            Model::BaseEnQ8_0 => HFCoordinates {
                repo,
                model: "ggml-base.en-q8_0.bin".to_owned(),
            },
            Model::SmallQ5_1 => HFCoordinates {
                repo,
                model: "ggml-small-q5_1.bin".to_owned(),
            },
            Model::SmallQ8_0 => HFCoordinates {
                repo,
                model: "ggml-small-q8_0.bin".to_owned(),
            },
            Model::SmallEnQ5_1 => HFCoordinates {
                repo,
                model: "ggml-small.en-q5_1.bin".to_owned(),
            },
            Model::SmallEnQ8_0 => HFCoordinates {
                repo,
                model: "ggml-small.en-q8_0.bin".to_owned(),
            },
            Model::MediumQ5_0 => HFCoordinates {
                repo,
                model: "ggml-medium-q5_0.bin".to_owned(),
            },
            Model::MediumQ8_0 => HFCoordinates {
                repo,
                model: "ggml-medium-q8_0.bin".to_owned(),
            },
            Model::MediumEnQ5_0 => HFCoordinates {
                repo,
                model: "ggml-medium.en-q5_0.bin".to_owned(),
            },
            Model::MediumEnQ8_0 => HFCoordinates {
                repo,
                model: "ggml-medium.en-q8_0.bin".to_owned(),
            },
            Model::LargeV2Q5_0 => HFCoordinates {
                repo,
                model: "ggml-large-v2-q5_0.bin".to_owned(),
            },
            Model::LargeV2Q8_0 => HFCoordinates {
                repo,
                model: "ggml-large-v2-q8_0.bin".to_owned(),
            },
            Model::LargeV3Q5_0 => HFCoordinates {
                repo,
                model: "ggml-large-v3-q5_0.bin".to_owned(),
            },
            Model::LargeV3Turbo => HFCoordinates {
                repo,
                model: "ggml-large-v3-turbo.bin".to_owned(),
            },
            Model::LargeV3TurboQ5_0 => HFCoordinates {
                repo,
                model: "ggml-large-v3-turbo-q5_0.bin".to_owned(),
            },
            Model::LargeV3TurboQ8_0 => HFCoordinates {
                repo,
                model: "ggml-large-v3-turbo-q8_0.bin".to_owned(),
            },
        }
    }

    /// True if the model supports multiple languages, false otherwise.
    pub fn is_multilingual(&self) -> bool {
        // English-only variants all embed "en" in their Display string (e.g.
        // "TinyEn - tiny_en"); multilingual ones currently never do.
        // NOTE(review): substring matching is fragile — a future variant whose
        // label happens to contain "en" would be misclassified; consider
        // matching the English-only variants explicitly instead.
        !self.to_string().contains("en")
    }

    /// Check if the file model has been cached before
    pub fn cached(&self) -> bool {
        let coordinates = self.hf_coordinates();
        let cache = Cache::from_env().repo(coordinates.repo);
        cache.get(&coordinates.model).is_some()
    }

    // Shared download path: fetch this model's weight file, reporting progress
    // either via a progress bar or via `Event` messages on `progress`.
    pub(crate) async fn internal_download_model(
        &self,
        force_download: bool,
        progress: ProgressType,
    ) -> Result<PathBuf, Error> {
        let coordinates = self.hf_coordinates();

        download_file(
            &coordinates.model,
            force_download,
            progress,
            coordinates.repo,
        )
        .await
    }

    /// Download the model weights, showing an interactive progress bar.
    /// Returns the local path of the (possibly cached) weight file.
    pub async fn download_model(&self, force_download: bool) -> Result<PathBuf, Error> {
        self.internal_download_model(force_download, ProgressType::ProgressBar)
            .await
    }

    /// Download the model weights, reporting progress as [`Event`]s on `tx`.
    /// Returns the local path of the (possibly cached) weight file.
    pub async fn download_model_listener(
        &self,
        force_download: bool,
        tx: UnboundedSender<Event>,
    ) -> Result<PathBuf, Error> {
        self.internal_download_model(force_download, ProgressType::Callback(tx))
            .await
    }
}
--------------------------------------------------------------------------------
/simple-whisper/src/transcribe.rs:
--------------------------------------------------------------------------------
use std::{
    path::{Path, PathBuf},
    time::Duration,
};

use derive_builder::Builder;
use tokio::sync::mpsc::UnboundedSender;
use whisper_rs::{
    FullParams, SamplingStrategy, SegmentCallbackData, WhisperContext, WhisperContextParameters,
    WhisperError, WhisperState,
};

use crate::{Error, Event, Language};

// Everything needed to run one whisper.cpp transcription pass.
// The builder's generated `build` is skipped; the hand-written one below
// also constructs the `WhisperState`.
#[derive(Builder)]
#[builder(
    setter(into),
    pattern = "owned",
    build_fn(skip, error = "TranscribeBuilderError")
)]
pub struct Transcribe {
    // Language passed to whisper.cpp (its short code via `Display`).
    language: Language,
    // Mono 16 kHz samples plus the total audio duration.
    audio: (Vec<f32>, Duration),
    // Channel on which segments and errors are emitted.
    tx: UnboundedSender<Result<Event, Error>>,
    // Path of the downloaded ggml weight file; only used to build `state`.
    #[builder(setter(name = "model"))]
    _model: PathBuf,
    #[builder(setter(skip))]
    state: WhisperState,
    // Forwarded to `FullParams::set_single_segment`.
    single_segment: bool,
}

impl TranscribeBuilder {
    // Manual `build`: validates that every required field was set, then
    // creates the whisper.cpp state from the model file.
    pub fn build(self) -> Result<Transcribe, TranscribeBuilderError> {
        if self.language.is_none() {
            return Err(TranscribeBuilderError::UninitializedFieldError("language"));
        }

        if self.audio.is_none() {
            return Err(TranscribeBuilderError::UninitializedFieldError("audio"));
        }

        if self.tx.is_none() {
            return Err(TranscribeBuilderError::UninitializedFieldError("tx"));
        }

        if self._model.is_none() {
            return Err(TranscribeBuilderError::UninitializedFieldError("model"));
        }

        let state = state_builder(self._model.as_ref().unwrap())?;

        Ok(Transcribe {
            language: self.language.unwrap(),
            audio: self.audio.unwrap(),
            tx: self.tx.unwrap(),
            _model: self._model.unwrap(),
            state,
            single_segment: self.single_segment.unwrap_or(false),
        })
    }
}

/// Error type for TranscribeBuilder
// NOTE(review): `#[derive(Error, ...)]` needs thiserror's derive macro in
// scope; only the `crate::Error` *type* is imported above — confirm a
// `use thiserror::Error;` exists in the real file.
#[derive(Error, Debug)]
pub enum TranscribeBuilderError {
    #[error("Field not initialized: {0}")]
    UninitializedFieldError(&'static str),
    #[error(transparent)]
    WhisperCppError(#[from] WhisperError),
}

// Create a whisper.cpp context from the weight file and derive a fresh
// inference state from it. GPU use is requested unconditionally.
fn state_builder(model: &Path) -> Result<WhisperState, WhisperError> {
    let mut context_param = WhisperContextParameters::default();

    context_param.use_gpu(true);

    let ctx = WhisperContext::new_with_params(model.to_str().unwrap(), context_param)?;

    ctx.create_state()
}

impl Transcribe {
    // Run the (blocking) transcription, emitting one `Event::Segment` per
    // decoded segment on `tx` and a final `Err` if whisper.cpp fails.
    pub fn transcribe(mut self) {
        // Weak sender: the segment callback must not keep the channel alive.
        let tx_callback = self.tx.downgrade();

        let (audio, duration) = &self.audio;
        let duration = *duration;
        let lang = self.language.to_string();

        // NOTE(review): `best_of: 0` — confirm whisper.cpp's greedy sampler
        // ignores this field or that 0 is the intended value.
        let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
        params.set_single_segment(self.single_segment);
        params.set_n_threads(num_cpus::get().try_into().unwrap());
        params.set_language(Some(&lang));
        params.set_print_special(false);
        params.set_print_progress(false);
        params.set_print_timestamps(false);

        params.set_segment_callback_safe(move |seg: SegmentCallbackData| {
            // Segment timestamps arrive in 10 ms ticks; convert to Duration.
            let start_offset = Duration::from_millis(seg.start_timestamp as u64 * 10);
            let end_offset = Duration::from_millis(seg.end_timestamp as u64 * 10);
            // Progress = end of this segment over total audio length, clamped.
            let mut percentage = end_offset.as_millis() as f32 / duration.as_millis() as f32;
            if percentage > 1. {
                percentage = 1.;
            }
            let seg = Event::Segment {
                start_offset,
                end_offset,
                percentage,
                transcription: seg.text,
            };
            // NOTE(review): `upgrade().unwrap()` panics if the receiver side
            // is dropped while transcription is still running.
            let _ = tx_callback.upgrade().unwrap().send(Ok(seg));
        });

        if let Err(err) = self.state.full(params, audio) {
            let _ = self.tx.send(Err(Error::Whisper(err)));
        }
    }
}
--------------------------------------------------------------------------------