├── .dockerignore ├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ └── hallucination.md ├── pull_request_template.md └── workflows │ ├── ci.yml │ ├── publish-docker.yml │ └── publish-ghcr.yml ├── .gitignore ├── Dockerfile ├── Install.bat ├── Install.sh ├── LICENSE ├── README.md ├── app.py ├── backend ├── Dockerfile ├── README.md ├── __init__.py ├── cache │ └── cached_files_are_generated_here ├── common │ ├── audio.py │ ├── cache_manager.py │ ├── compresser.py │ ├── config_loader.py │ └── models.py ├── configs │ ├── .env.example │ └── config.yaml ├── db │ ├── __init__.py │ ├── db_instance.py │ └── task │ │ ├── __init__.py │ │ ├── dao.py │ │ └── models.py ├── docker-compose.yaml ├── main.py ├── nginx │ ├── logs │ │ └── logs_are_generated_here │ ├── nginx.conf │ └── temp │ │ └── temps_are_generated_here ├── requirements-backend.txt ├── routers │ ├── __init__.py │ ├── bgm_separation │ │ ├── __init__.py │ │ ├── models.py │ │ └── router.py │ ├── task │ │ ├── __init__.py │ │ └── router.py │ ├── transcription │ │ ├── __init__.py │ │ └── router.py │ └── vad │ │ ├── __init__.py │ │ └── router.py └── tests │ ├── __init__.py │ ├── test_backend_bgm_separation.py │ ├── test_backend_config.py │ ├── test_backend_transcription.py │ ├── test_backend_vad.py │ └── test_task_status.py ├── configs ├── default_parameters.yaml └── translation.yaml ├── docker-compose.yaml ├── models ├── Diarization │ └── diarization_models_will_be_saved_here ├── NLLB │ └── nllb_models_will_be_saved_here ├── UVR │ └── uvr_models_will_be_saved_here ├── Whisper │ ├── faster-whisper │ │ └── faster_whisper_models_will_be_saved_here │ ├── insanely-fast-whisper │ │ └── insanely_fast_whisper_models_will_be_saved_here │ └── whisper_models_will_be_saved_here └── models_will_be_saved_here ├── modules ├── __init__.py ├── diarize │ ├── __init__.py │ ├── audio_loader.py │ ├── diarize_pipeline.py │ └── diarizer.py ├── translation │ ├── __init__.py │ ├── deepl_api.py │ ├── nllb_inference.py │ └── translation_base.py ├── ui │ ├── __init__.py │ └── htmls.py ├── utils │ ├── __init__.py │ ├── audio_manager.py │ ├── cli_manager.py │ ├── constants.py │ ├── files_manager.py │ ├── logger.py │ ├── paths.py │ ├── subtitle_manager.py │ └── youtube_manager.py ├── uvr │ └── music_separator.py ├── vad │ ├── __init__.py │ └── silero_vad.py └── whisper │ ├── __init__.py │ ├── base_transcription_pipeline.py │ ├── data_classes.py │ ├── faster_whisper_inference.py │ ├── insanely_fast_whisper_inference.py │ ├── whisper_Inference.py │ └── whisper_factory.py ├── notebook └── whisper-webui.ipynb ├── outputs ├── UVR │ ├── instrumental │ │ └── UVR_outputs_for_instrumental_will_be_saved_here │ └── vocals │ │ └── UVR_outputs_for_vocals_will_be_saved_here ├── outputs_will_be_saved_here └── translations │ └── translation_outputs_will_be_saved_here ├── requirements.txt ├── start-webui.bat ├── start-webui.sh ├── tests ├── test_bgm_separation.py ├── test_config.py ├── test_diarization.py ├── test_srt.srt ├── test_transcription.py ├── test_translation.py ├── test_vad.py └── test_vtt.vtt └── user-start-webui.bat /.dockerignore: -------------------------------------------------------------------------------- 1 | # from .gitignore 2 | modules/yt_tmp.wav 3 | **/venv/ 4 | **/__pycache__/ 5 | **/outputs/ 6 | **/models/ 7 | 8 | **/.idea 9 | **/.git 10 | **/.github 11 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: 
-------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [jhj0517] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: jhj0517 7 | 8 | --- 9 | 10 | **Which OS are you using?** 11 | - OS: [e.g. iOS or Windows.. If you are using Google Colab, just Colab.] 12 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Any feature you want 4 | title: '' 5 | labels: enhancement 6 | assignees: jhj0517 7 | 8 | --- 9 | 10 | 11 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/hallucination.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Hallucination 3 | about: Whisper hallucinations. ( Repeating certain words or subtitles starting too 4 | early, etc. ) 5 | title: '' 6 | labels: hallucination 7 | assignees: jhj0517 8 | 9 | --- 10 | 11 | **Download URL for sample audio** 12 | - Please upload download URL for sample audio file so I can test with some settings for better result. You can use https://easyupload.io/ or any other service to share. 13 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Related issues / PRs. Summarize issues. 2 | - # 3 | 4 | ## Summarize Changes 5 | 1. 
6 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | push: 7 | branches: 8 | - master 9 | - intel-gpu 10 | pull_request: 11 | branches: 12 | - master 13 | - intel-gpu 14 | 15 | jobs: 16 | test: 17 | runs-on: ubuntu-latest 18 | strategy: 19 | matrix: 20 | python: ["3.10", "3.11", "3.12"] 21 | 22 | env: 23 | DEEPL_API_KEY: ${{ secrets.DEEPL_API_KEY }} 24 | 25 | steps: 26 | - name: Clean up space for action 27 | run: rm -rf /opt/hostedtoolcache 28 | 29 | - uses: actions/checkout@v4 30 | - name: Setup Python 31 | uses: actions/setup-python@v5 32 | with: 33 | python-version: ${{ matrix.python }} 34 | 35 | - name: Install git and ffmpeg 36 | run: sudo apt-get update && sudo apt-get install -y git ffmpeg 37 | 38 | - name: Install dependencies 39 | run: pip install -r requirements.txt pytest jiwer 40 | 41 | - name: Run test 42 | run: python -m pytest -rs tests 43 | 44 | test-backend: 45 | runs-on: ubuntu-latest 46 | strategy: 47 | matrix: 48 | python: ["3.10", "3.11", "3.12"] 49 | 50 | env: 51 | DEEPL_API_KEY: ${{ secrets.DEEPL_API_KEY }} 52 | TEST_ENV: true 53 | 54 | steps: 55 | - name: Clean up space for action 56 | run: rm -rf /opt/hostedtoolcache 57 | 58 | - uses: actions/checkout@v4 59 | - name: Setup Python 60 | uses: actions/setup-python@v5 61 | with: 62 | python-version: ${{ matrix.python }} 63 | 64 | - name: Install git and ffmpeg 65 | run: sudo apt-get update && sudo apt-get install -y git ffmpeg 66 | 67 | - name: Install dependencies 68 | run: pip install -r backend/requirements-backend.txt pytest pytest-asyncio jiwer 69 | 70 | - name: Run test 71 | run: python -m pytest -rs backend/tests 72 | 73 | test-shell-script: 74 | runs-on: ubuntu-latest 75 | strategy: 76 | matrix: 77 | python: [ "3.10", "3.11", "3.12" ] 78 | 79 | steps: 80 | - name: Clean up space for action 81 | run: rm -rf /opt/hostedtoolcache 82 | 83 | - uses: actions/checkout@v4 84 | - name: Setup Python 85 | uses: actions/setup-python@v5 86 | with: 87 | python-version: ${{ matrix.python }} 88 | 89 | - name: Install git and ffmpeg 90 | run: sudo apt-get update && sudo apt-get install -y git ffmpeg 91 | 92 | - name: Execute Install.sh 93 | run: | 94 | chmod +x ./Install.sh 95 | ./Install.sh 96 | 97 | - name: Execute start-webui.sh 98 | run: | 99 | chmod +x ./start-webui.sh 100 | timeout 60s ./start-webui.sh || true 101 | 102 | -------------------------------------------------------------------------------- /.github/workflows/publish-docker.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Docker Hub 2 | 3 | on: 4 | # Triggers minor version ( vX.Y.Z-ShortHash ) 5 | push: 6 | branches: 7 | - master 8 | # Triggers major version ( vX.Y.Z ) 9 | release: 10 | types: [created] 11 | 12 | workflow_dispatch: 13 | 14 | jobs: 15 | build-and-push: 16 | runs-on: ubuntu-latest 17 | 18 | strategy: 19 | matrix: 20 | name: [whisper-webui, whisper-webui-backend] 21 | 22 | steps: 23 | - name: Clean up space for action 24 | run: rm -rf /opt/hostedtoolcache 25 | 26 | - name: Checkout repository 27 | uses: actions/checkout@v3 28 | 29 | - name: Set up Docker Buildx 30 | uses: docker/setup-buildx-action@v3 31 | 32 | - name: Set up QEMU 33 | uses: docker/setup-qemu-action@v3 34 | 35 | - name: Extract metadata 36 | id: meta 37 | run: | 38 | SHORT_SHA=$(git rev-parse --short HEAD) 39 | echo 
"SHORT_SHA=$SHORT_SHA" >> $GITHUB_ENV 40 | 41 | # Triggered by a release event — versioning as major ( vX.Y.Z ) 42 | if [[ "${GITHUB_EVENT_NAME}" == "release" ]]; then 43 | TAG_NAME="${{ github.event.release.tag_name }}" 44 | echo "GIT_TAG=$TAG_NAME" >> $GITHUB_ENV 45 | echo "IS_RELEASE=true" >> $GITHUB_ENV 46 | 47 | # Triggered by a general push event — versioning as minor ( vX.Y.Z-ShortHash ) 48 | else 49 | git fetch --tags 50 | LATEST_TAG=$(git tag --list 'v*.*.*' | sort -V | tail -n1) 51 | FALLBACK_TAG="${LATEST_TAG:-v0.0.0}" 52 | echo "GIT_TAG=${FALLBACK_TAG}-${SHORT_SHA}" >> $GITHUB_ENV 53 | echo "IS_RELEASE=false" >> $GITHUB_ENV 54 | fi 55 | 56 | - name: Set Dockerfile path 57 | id: dockerfile 58 | run: | 59 | if [ "${{ matrix.name }}" = "whisper-webui" ]; then 60 | echo "DOCKERFILE=./Dockerfile" >> $GITHUB_ENV 61 | elif [ "${{ matrix.name }}" = "whisper-webui-backend" ]; then 62 | echo "DOCKERFILE=./backend/Dockerfile" >> $GITHUB_ENV 63 | else 64 | echo "Unknown component: ${{ matrix.name }}" 65 | exit 1 66 | fi 67 | 68 | - name: Log in to Docker Hub 69 | uses: docker/login-action@v2 70 | with: 71 | username: ${{ secrets.DOCKER_USERNAME }} 72 | password: ${{ secrets.DOCKER_PASSWORD }} 73 | 74 | - name: Build and push Docker image (version tag) 75 | uses: docker/build-push-action@v5 76 | with: 77 | context: . 78 | file: ${{ env.DOCKERFILE }} 79 | push: true 80 | tags: | 81 | ${{ secrets.DOCKER_USERNAME }}/${{ matrix.name }}:${{ env.GIT_TAG }} 82 | 83 | - name: Tag and push as latest (if release) 84 | if: env.IS_RELEASE == 'true' 85 | run: | 86 | docker pull ${{ secrets.DOCKER_USERNAME }}/${{ matrix.name }}:${{ env.GIT_TAG }} 87 | docker tag ${{ secrets.DOCKER_USERNAME }}/${{ matrix.name }}:${{ env.GIT_TAG }} \ 88 | ${{ secrets.DOCKER_USERNAME }}/${{ matrix.name }}:latest 89 | docker push ${{ secrets.DOCKER_USERNAME }}/${{ matrix.name }}:latest 90 | 91 | - name: Log out of Docker Hub 92 | run: docker logout 93 | -------------------------------------------------------------------------------- /.github/workflows/publish-ghcr.yml: -------------------------------------------------------------------------------- 1 | name: Publish to GHCR 2 | 3 | on: 4 | # Triggers minor version ( vX.Y.Z-ShortHash ) 5 | push: 6 | branches: 7 | - master 8 | # Triggers major version ( vX.Y.Z ) 9 | release: 10 | types: [created] 11 | 12 | workflow_dispatch: 13 | 14 | jobs: 15 | build-and-push: 16 | runs-on: ubuntu-latest 17 | permissions: 18 | packages: write 19 | 20 | strategy: 21 | matrix: 22 | name: [whisper-webui, whisper-webui-backend] 23 | 24 | steps: 25 | - name: Checkout repository 26 | uses: actions/checkout@v3 27 | 28 | - name: Set up Docker Buildx 29 | uses: docker/setup-buildx-action@v3 30 | 31 | - name: Set up QEMU 32 | uses: docker/setup-qemu-action@v3 33 | 34 | - name: Extract metadata 35 | id: meta 36 | run: | 37 | SHORT_SHA=$(git rev-parse --short HEAD) 38 | echo "SHORT_SHA=$SHORT_SHA" >> $GITHUB_ENV 39 | 40 | # Triggered by a release event — versioning as major ( vX.Y.Z ) 41 | if [[ "${GITHUB_EVENT_NAME}" == "release" ]]; then 42 | TAG_NAME="${{ github.event.release.tag_name }}" 43 | echo "GIT_TAG=$TAG_NAME" >> $GITHUB_ENV 44 | echo "IS_RELEASE=true" >> $GITHUB_ENV 45 | 46 | # Triggered by a general push event — versioning as minor ( vX.Y.Z-ShortHash ) 47 | else 48 | git fetch --tags 49 | LATEST_TAG=$(git tag --list 'v*.*.*' | sort -V | tail -n1) 50 | FALLBACK_TAG="${LATEST_TAG:-v0.0.0}" 51 | echo "GIT_TAG=${FALLBACK_TAG}-${SHORT_SHA}" >> $GITHUB_ENV 52 | echo "IS_RELEASE=false" >> 
$GITHUB_ENV 53 | fi 54 | 55 | echo "REPO_OWNER_LC=${GITHUB_REPOSITORY_OWNER,,}" >> $GITHUB_ENV 56 | 57 | - name: Set Dockerfile path 58 | id: dockerfile 59 | run: | 60 | if [ "${{ matrix.name }}" = "whisper-webui" ]; then 61 | echo "DOCKERFILE=./Dockerfile" >> $GITHUB_ENV 62 | elif [ "${{ matrix.name }}" = "whisper-webui-backend" ]; then 63 | echo "DOCKERFILE=./backend/Dockerfile" >> $GITHUB_ENV 64 | else 65 | echo "Unknown component: ${{ matrix.name }}" 66 | exit 1 67 | fi 68 | 69 | - name: Log in to GitHub Container Registry 70 | uses: docker/login-action@v2 71 | with: 72 | registry: ghcr.io 73 | username: ${{ github.actor }} 74 | password: ${{ secrets.GITHUB_TOKEN }} 75 | 76 | - name: Build and push Docker image (version tag) 77 | uses: docker/build-push-action@v5 78 | with: 79 | context: . 80 | file: ${{ env.DOCKERFILE }} 81 | push: true 82 | tags: | 83 | ghcr.io/${{ env.REPO_OWNER_LC }}/${{ matrix.name }}:${{ env.GIT_TAG }} 84 | 85 | - name: Tag and push as latest (if release) 86 | if: env.IS_RELEASE == 'true' 87 | run: | 88 | docker pull ghcr.io/${{ env.REPO_OWNER_LC }}/${{ matrix.name }}:${{ env.GIT_TAG }} 89 | docker tag ghcr.io/${{ env.REPO_OWNER_LC }}/${{ matrix.name }}:${{ env.GIT_TAG }} \ 90 | ghcr.io/${{ env.REPO_OWNER_LC }}/${{ matrix.name }}:latest 91 | docker push ghcr.io/${{ env.REPO_OWNER_LC }}/${{ matrix.name }}:latest 92 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.wav 2 | *.png 3 | *.mp4 4 | *.mp3 5 | **/.env 6 | **/.idea/ 7 | **/.pytest_cache/ 8 | **/venv/ 9 | **/__pycache__/ 10 | outputs/ 11 | models/ 12 | modules/yt_tmp.wav 13 | configs/default_parameters.yaml -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM debian:bookworm-slim AS builder 2 | 3 | RUN apt-get update && \ 4 | apt-get install -y curl git python3 python3-pip python3-venv && \ 5 | rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* && \ 6 | mkdir -p /Whisper-WebUI 7 | 8 | WORKDIR /Whisper-WebUI 9 | 10 | COPY requirements.txt . 11 | 12 | RUN python3 -m venv venv && \ 13 | . venv/bin/activate && \ 14 | pip install -U -r requirements.txt 15 | 16 | 17 | FROM debian:bookworm-slim AS runtime 18 | 19 | RUN apt-get update && \ 20 | apt-get install -y curl ffmpeg python3 && \ 21 | rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* 22 | 23 | WORKDIR /Whisper-WebUI 24 | 25 | COPY . . 26 | COPY --from=builder /Whisper-WebUI/venv /Whisper-WebUI/venv 27 | 28 | VOLUME [ "/Whisper-WebUI/models" ] 29 | VOLUME [ "/Whisper-WebUI/outputs" ] 30 | 31 | ENV PATH="/Whisper-WebUI/venv/bin:$PATH" 32 | ENV LD_LIBRARY_PATH=/Whisper-WebUI/venv/lib64/python3.11/site-packages/nvidia/cublas/lib:/Whisper-WebUI/venv/lib64/python3.11/site-packages/nvidia/cudnn/lib 33 | 34 | ENTRYPOINT [ "python", "app.py" ] 35 | -------------------------------------------------------------------------------- /Install.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | if not exist "%~dp0\venv\Scripts" ( 4 | echo Creating venv... 5 | python -m venv venv 6 | ) 7 | echo checked the venv folder. now installing requirements.. 8 | 9 | call "%~dp0\venv\scripts\activate" 10 | 11 | python -m pip install -U pip 12 | pip install -r requirements.txt 13 | 14 | if errorlevel 1 ( 15 | echo. 16 | echo Requirements installation failed. 
please remove venv folder and run install.bat again. 17 | ) else ( 18 | echo. 19 | echo Requirements installed successfully. 20 | ) 21 | pause -------------------------------------------------------------------------------- /Install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ ! -d "venv" ]; then 4 | echo "Creating virtual environment..." 5 | python -m venv venv 6 | fi 7 | 8 | source venv/bin/activate 9 | 10 | python -m pip install -U pip 11 | pip install -r requirements.txt && echo "Requirements installed successfully." || { 12 | echo "" 13 | echo "Requirements installation failed. Please remove the venv folder and run the script again." 14 | deactivate 15 | exit 1 16 | } 17 | 18 | deactivate 19 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Whisper-WebUI 2 | A Gradio-based browser interface for [Whisper](https://github.com/openai/whisper). You can use it as an Easy Subtitle Generator! 3 | 4 | ![screen](https://github.com/user-attachments/assets/caea3afd-a73c-40af-a347-8d57914b1d0f) 5 | 6 | 7 | 8 | ## Notebook 9 | If you wish to try this on Colab, you can do it in [here](https://colab.research.google.com/github/jhj0517/Whisper-WebUI/blob/master/notebook/whisper-webui.ipynb)! 10 | 11 | # Feature 12 | - Select the Whisper implementation you want to use between : 13 | - [openai/whisper](https://github.com/openai/whisper) 14 | - [SYSTRAN/faster-whisper](https://github.com/SYSTRAN/faster-whisper) (used by default) 15 | - [Vaibhavs10/insanely-fast-whisper](https://github.com/Vaibhavs10/insanely-fast-whisper) 16 | - Generate subtitles from various sources, including : 17 | - Files 18 | - Youtube 19 | - Microphone 20 | - Currently supported subtitle formats : 21 | - SRT 22 | - WebVTT 23 | - txt ( only text file without timeline ) 24 | - Speech to Text Translation 25 | - From other languages to English. ( This is Whisper's end-to-end speech-to-text translation feature ) 26 | - Text to Text Translation 27 | - Translate subtitle files using Facebook NLLB models 28 | - Translate subtitle files using DeepL API 29 | - Pre-processing audio input with [Silero VAD](https://github.com/snakers4/silero-vad). 30 | - Pre-processing audio input to separate BGM with [UVR](https://github.com/Anjok07/ultimatevocalremovergui). 31 | - Post-processing with speaker diarization using the [pyannote](https://huggingface.co/pyannote/speaker-diarization-3.1) model. 32 | - To download the pyannote model, you need to have a Huggingface token and manually accept their terms in the pages below. 33 | 1. https://huggingface.co/pyannote/speaker-diarization-3.1 34 | 2. https://huggingface.co/pyannote/segmentation-3.0 35 | 36 | ### Pipeline Diagram 37 | ![Transcription Pipeline](https://github.com/user-attachments/assets/1d8c63ac-72a4-4a0b-9db0-e03695dcf088) 38 | 39 | # Installation and Running 40 | 41 | - ## Running with Pinokio 42 | 43 | The app is able to run with [Pinokio](https://github.com/pinokiocomputer/pinokio). 44 | 45 | 1. Install [Pinokio Software](https://program.pinokio.computer/#/?id=install). 46 | 2. Open the software and search for Whisper-WebUI and install it. 47 | 3. Start the Whisper-WebUI and connect to the `http://localhost:7860`. 48 | 49 | - ## Running with Docker 50 | 51 | 1. Install and launch [Docker-Desktop](https://www.docker.com/products/docker-desktop/). 52 | 53 | 2. 
Git clone the repository 54 | 55 | ```sh 56 | git clone https://github.com/jhj0517/Whisper-WebUI.git 57 | ``` 58 | 59 | 3. Build the image ( Image is about 7GB~ ) 60 | 61 | ```sh 62 | docker compose build 63 | ``` 64 | 65 | 4. Run the container 66 | 67 | ```sh 68 | docker compose up 69 | ``` 70 | 71 | 5. Connect to the WebUI with your browser at `http://localhost:7860` 72 | 73 | If needed, update the [`docker-compose.yaml`](https://github.com/jhj0517/Whisper-WebUI/blob/master/docker-compose.yaml) to match your environment. 74 | 75 | - ## Run Locally 76 | 77 | ### Prerequisite 78 | To run this WebUI, you need to have `git`, `3.10 <= python <= 3.12`, `FFmpeg`. 79 | 80 | **Edit `--extra-index-url` in the [`requirements.txt`](https://github.com/jhj0517/Whisper-WebUI/blob/master/requirements.txt) to match your device.
** 81 | By default, the WebUI assumes you're using an Nvidia GPU and **CUDA 12.6.** If you're using Intel or another CUDA version, read the [`requirements.txt`](https://github.com/jhj0517/Whisper-WebUI/blob/master/requirements.txt) and edit `--extra-index-url`. 82 | 83 | Please follow the links below to install the necessary software: 84 | - git : [https://git-scm.com/downloads](https://git-scm.com/downloads) 85 | - python : [https://www.python.org/downloads/](https://www.python.org/downloads/) **`3.10 ~ 3.12` is recommended.** 86 | - FFmpeg : [https://ffmpeg.org/download.html](https://ffmpeg.org/download.html) 87 | - CUDA : [https://developer.nvidia.com/cuda-downloads](https://developer.nvidia.com/cuda-downloads) 88 | 89 | After installing FFmpeg, **make sure to add the `FFmpeg/bin` folder to your system PATH!** 90 | 91 | ### Installation Using the Script Files 92 | 93 | 1. git clone this repository 94 | ```shell 95 | git clone https://github.com/jhj0517/Whisper-WebUI.git 96 | ``` 97 | 2. Run `install.bat` or `install.sh` to install dependencies. (It will create a `venv` directory and install dependencies there.) 98 | 3. Start WebUI with `start-webui.bat` or `start-webui.sh` (It will run `python app.py` after activating the venv) 99 | 100 | And you can also run the project with command line arguments if you like to, see [wiki](https://github.com/jhj0517/Whisper-WebUI/wiki/Command-Line-Arguments) for a guide to arguments. 101 | 102 | # VRAM Usages 103 | This project is integrated with [faster-whisper](https://github.com/guillaumekln/faster-whisper) by default for better VRAM usage and transcription speed. 104 | 105 | According to faster-whisper, the efficiency of the optimized whisper model is as follows: 106 | | Implementation | Precision | Beam size | Time | Max. GPU memory | Max. CPU memory | 107 | |-------------------|-----------|-----------|-------|-----------------|-----------------| 108 | | openai/whisper | fp16 | 5 | 4m30s | 11325MB | 9439MB | 109 | | faster-whisper | fp16 | 5 | 54s | 4755MB | 3244MB | 110 | 111 | If you want to use an implementation other than faster-whisper, use `--whisper_type` arg and the repository name.
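For example, to switch to the original openai/whisper implementation, an invocation along these lines should work (the exact value accepted by `--whisper_type` is an assumption here; check the wiki below for the supported names):

```sh
# Hypothetical example: run the WebUI with the openai/whisper implementation
python app.py --whisper_type whisper
```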
112 | Read [wiki](https://github.com/jhj0517/Whisper-WebUI/wiki/Command-Line-Arguments) for more info about CLI args. 113 | 114 | If you want to use a fine-tuned model, manually place the models in `models/Whisper/` corresponding to the implementation. 115 | 116 | Alternatively, if you enter the huggingface repo id (e.g, [deepdml/faster-whisper-large-v3-turbo-ct2](https://huggingface.co/deepdml/faster-whisper-large-v3-turbo-ct2)) in the "Model" dropdown, it will be automatically downloaded in the directory. 117 | 118 | ![image](https://github.com/user-attachments/assets/76487a46-b0a5-4154-b735-ded73b2d83d4) 119 | 120 | # REST API 121 | If you're interested in deploying this app as a REST API, please check out [/backend](https://github.com/jhj0517/Whisper-WebUI/tree/master/backend). 122 | 123 | ## TODO🗓 124 | 125 | - [x] Add DeepL API translation 126 | - [x] Add NLLB Model translation 127 | - [x] Integrate with faster-whisper 128 | - [x] Integrate with insanely-fast-whisper 129 | - [x] Integrate with whisperX ( Only speaker diarization part ) 130 | - [x] Add background music separation pre-processing with [UVR](https://github.com/Anjok07/ultimatevocalremovergui) 131 | - [x] Add fast api script 132 | - [ ] Add CLI usages 133 | - [ ] Support real-time transcription for microphone 134 | 135 | ### Translation 🌐 136 | Any PRs that translate the language into [translation.yaml](https://github.com/jhj0517/Whisper-WebUI/blob/master/configs/translation.yaml) would be greatly appreciated! 137 | -------------------------------------------------------------------------------- /backend/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM debian:bookworm-slim AS builder 2 | 3 | RUN apt-get update && \ 4 | apt-get install -y curl git python3 python3-pip python3-venv && \ 5 | rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* && \ 6 | mkdir -p /Whisper-WebUI 7 | 8 | WORKDIR /Whisper-WebUI 9 | 10 | COPY backend/ backend/ 11 | COPY requirements.txt requirements.txt 12 | 13 | RUN python3 -m venv venv && \ 14 | . venv/bin/activate && \ 15 | pip install -U -r backend/requirements-backend.txt 16 | 17 | 18 | FROM debian:bookworm-slim AS runtime 19 | 20 | RUN apt-get update && \ 21 | apt-get install -y curl ffmpeg python3 && \ 22 | rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* 23 | 24 | WORKDIR /Whisper-WebUI 25 | 26 | COPY . . 27 | COPY --from=builder /Whisper-WebUI/venv /Whisper-WebUI/venv 28 | 29 | VOLUME [ "/Whisper-WebUI/models" ] 30 | VOLUME [ "/Whisper-WebUI/outputs" ] 31 | VOLUME [ "/Whisper-WebUI/backend" ] 32 | 33 | ENV PATH="/Whisper-WebUI/venv/bin:$PATH" 34 | ENV LD_LIBRARY_PATH=/Whisper-WebUI/venv/lib64/python3.11/site-packages/nvidia/cublas/lib:/Whisper-WebUI/venv/lib64/python3.11/site-packages/nvidia/cudnn/lib 35 | 36 | ENTRYPOINT ["uvicorn", "backend.main:app", "--host", "0.0.0.0", "--port", "8000"] 37 | -------------------------------------------------------------------------------- /backend/README.md: -------------------------------------------------------------------------------- 1 | # Whisper-WebUI REST API 2 | REST API for Whisper-WebUI. Documentation is auto-generated upon deploying the app. 3 |
[Swagger UI](https://github.com/swagger-api/swagger-ui) is available at `app/docs` or root URL with redirection. [Redoc](https://github.com/Redocly/redoc) is available at `app/redoc`. 4 | 5 | # Setup and Installation 6 | 7 | Installation assumes that you are in the root directory of Whisper-WebUI 8 | 9 | 1. Create `.env` in `backend/configs/.env` 10 | ``` 11 | HF_TOKEN="YOUR_HF_TOKEN FOR DIARIZATION MODEL (READ PERMISSION)" 12 | DB_URL="sqlite:///backend/records.db" 13 | ``` 14 | `HF_TOKEN` is used to download diarization model, `DB_URL` indicates where your db file is located. It is stored in `backend/` by default. 15 | 16 | 2. Install dependency 17 | ``` 18 | pip install -r backend/requirements-backend.txt 19 | ``` 20 | 21 | 3. Deploy the server with `uvicorn` or whatever. 22 | ``` 23 | uvicorn backend.main:app --host 0.0.0.0 --port 8000 24 | ``` 25 | 26 | ### Deploy with your domain name 27 | You can deploy the server with your domain name by setting up a reverse proxy with Nginx. 28 | 29 | 1. Install Nginx if you don't already have it. 30 | - Linux : https://nginx.org/en/docs/install.html 31 | - Windows : https://nginx.org/en/docs/windows.html 32 | 33 | 2. Edit [`nginx.conf`](https://github.com/jhj0517/Whisper-WebUI/blob/master/backend/nginx/nginx.conf) for your domain name. 34 | https://github.com/jhj0517/Whisper-WebUI/blob/895cafe400944396ad8be5b1cc793b54fecc8bbe/backend/nginx/nginx.conf#L12 35 | 36 | 3. Add an A type record of your public IPv4 address in your domain provider. (you can get it by searching "What is my IP" in Google) 37 | 38 | 4. Open a terminal and go to the location of [`nginx.conf`](https://github.com/jhj0517/Whisper-WebUI/blob/master/backend/nginx/nginx.conf), then start the nginx server, so that you can manage nginx-related logs there. 39 | ```shell 40 | cd backend/nginx 41 | nginx -c "/path/to/Whisper-WebUI/backend/nginx/nginx.conf" 42 | ``` 43 | 44 | 5. Open another terminal in the root project location `/Whisper-WebUI`, and deploy the app with `uvicorn` or whatever. Now the app will be available at your domain. 45 | ```shell 46 | uvicorn backend.main:app --host 0.0.0.0 --port 8000 47 | ``` 48 | 49 | 6. When you turn off nginx, you can use `nginx -s stop`. 50 | ```shell 51 | cd backend/nginx 52 | nginx -s stop -c "/path/to/Whisper-WebUI/backend/nginx/nginx.conf" 53 | ``` 54 | 55 | 56 | ## Configuration 57 | You can set some server configurations in [config.yaml](https://github.com/jhj0517/Whisper-WebUI/blob/master/backend/configs/config.yaml). 58 |
For example, initial model size for Whisper or the cleanup frequency and TTL for cached files. 59 |
If the endpoint generates and saves the file, all output files are stored in the `cache` directory, e.g. separated vocal/instrument files for `/bgm-separation` are saved in `cache` directory. 60 | 61 | ## Docker 62 | You can also deploy the server with Docker for easy deployment. 63 | The Dockerfile should be built when you're in the root directory of Whisper-WebUI. 64 | 65 | 1. git clone this repository 66 | ``` 67 | git clone https://github.com/jhj0517/Whisper-WebUI.git 68 | ``` 69 | 2. Mount volume paths with your local paths in `docker-compose.yaml` 70 | https://github.com/jhj0517/Whisper-WebUI/blob/1dd708ec3844dbf0c1f77de9ef5764e883dd4c78/backend/docker-compose.yaml#L12-L15 71 | 3. Build the image 72 | ``` 73 | docker compose -f backend/docker-compose.yaml build 74 | ``` 75 | 4. Run the container 76 | ``` 77 | docker compose -f backend/docker-compose.yaml up 78 | ``` 79 | 80 | 5. Then you can read docs at `localhost:8000` (default port is set to `8000` in `docker-compose.yaml`) and run your own tests. 81 | 82 | 83 | # Architecture 84 | 85 | ![diagram](https://github.com/user-attachments/assets/37d2ab2d-4eb4-4513-bb7b-027d0d631971) 86 | 87 | The response can be obtained through [the polling API](https://docs.oracle.com/en/cloud/saas/marketing/responsys-develop/API/REST/Async/asyncApi-v1.3-requests-requestId-get.htm). 88 | Each task is stored in the DB whenever the task is queued or updated by the process. 89 | 90 | When the client first sends the `POST` request, the server returns an `identifier` to the client that can be used to track the status of the task. The task status is updated by the processes, and once the task is completed, the client can finally obtain the result. 91 | 92 | The client needs to implement manual API polling to do this, this is the example for the python client: 93 | ```python 94 | def wait_for_task_completion(identifier: str, 95 | max_attempts: int = 20, 96 | frequency: int = 3) -> httpx.Response: 97 | """ 98 | Polls the task status every `frequency` until it is completed, failed, or the `max_attempts` are reached. 
99 | """ 100 | attempts = 0 101 | while attempts < max_attempts: 102 | task = fetch_task(identifier) 103 | status = task.json()["status"] 104 | if status == "COMPLETED": 105 | return task["result"] 106 | if status == "FAILED": 107 | raise Exception("Task polling failed") 108 | time.sleep(frequency) 109 | attempts += 1 110 | return None 111 | ``` 112 | -------------------------------------------------------------------------------- /backend/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/__init__.py -------------------------------------------------------------------------------- /backend/cache/cached_files_are_generated_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/cache/cached_files_are_generated_here -------------------------------------------------------------------------------- /backend/common/audio.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | import numpy as np 3 | import httpx 4 | import faster_whisper 5 | from pydantic import BaseModel 6 | from fastapi import ( 7 | HTTPException, 8 | UploadFile, 9 | ) 10 | from typing import Annotated, Any, BinaryIO, Literal, Generator, Union, Optional, List, Tuple 11 | 12 | 13 | class AudioInfo(BaseModel): 14 | duration: float 15 | 16 | 17 | async def read_audio( 18 | file: Optional[UploadFile] = None, 19 | file_url: Optional[str] = None 20 | ): 21 | """Read audio from "UploadFile". This resamples sampling rates to 16000.""" 22 | if (file and file_url) or (not file and not file_url): 23 | raise HTTPException(status_code=400, detail="Provide only one of file or file_url") 24 | 25 | if file: 26 | file_content = await file.read() 27 | elif file_url: 28 | async with httpx.AsyncClient() as client: 29 | file_response = await client.get(file_url) 30 | if file_response.status_code != 200: 31 | raise HTTPException(status_code=422, detail="Could not download the file") 32 | file_content = file_response.content 33 | file_bytes = BytesIO(file_content) 34 | audio = faster_whisper.audio.decode_audio(file_bytes) 35 | duration = len(audio) / 16000 36 | return audio, AudioInfo(duration=duration) 37 | -------------------------------------------------------------------------------- /backend/common/cache_manager.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | from typing import Optional 4 | 5 | from modules.utils.paths import BACKEND_CACHE_DIR 6 | 7 | 8 | def cleanup_old_files(cache_dir: str = BACKEND_CACHE_DIR, ttl: int = 60): 9 | now = time.time() 10 | place_holder_name = "cached_files_are_generated_here" 11 | for root, dirs, files in os.walk(cache_dir): 12 | for filename in files: 13 | if filename == place_holder_name: 14 | continue 15 | filepath = os.path.join(root, filename) 16 | if now - os.path.getmtime(filepath) > ttl: 17 | try: 18 | os.remove(filepath) 19 | except Exception as e: 20 | print(f"Error removing {filepath}") 21 | raise 22 | -------------------------------------------------------------------------------- /backend/common/compresser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | from typing import List, Optional 4 | import hashlib 5 | 6 | 7 | def 
compress_files(file_paths: List[str], output_zip_path: str) -> str: 8 | """ 9 | Compress multiple files into a single zip file. 10 | 11 | Args: 12 | file_paths (List[str]): List of paths to files to be compressed. 13 | output_zip (str): Path and name of the output zip file. 14 | 15 | Raises: 16 | FileNotFoundError: If any of the input files doesn't exist. 17 | """ 18 | os.makedirs(os.path.dirname(output_zip_path), exist_ok=True) 19 | compression = zipfile.ZIP_DEFLATED 20 | 21 | with zipfile.ZipFile(output_zip_path, 'w', compression=compression) as zipf: 22 | for file_path in file_paths: 23 | if not os.path.exists(file_path): 24 | raise FileNotFoundError(f"File not found: {file_path}") 25 | 26 | file_name = os.path.basename(file_path) 27 | zipf.write(file_path, file_name) 28 | return output_zip_path 29 | 30 | 31 | def get_file_hash(file_path: str) -> str: 32 | """Generate the hash of a file using the specified hashing algorithm. It generates hash by content not path. """ 33 | hash_func = hashlib.new("sha256") 34 | try: 35 | with open(file_path, 'rb') as f: 36 | for chunk in iter(lambda: f.read(4096), b""): 37 | hash_func.update(chunk) 38 | return hash_func.hexdigest() 39 | except FileNotFoundError: 40 | return f"File not found: {file_path}" 41 | except Exception as e: 42 | return f"An error occurred: {str(e)}" 43 | 44 | 45 | def find_file_by_hash(dir_path: str, hash_str: str) -> Optional[str]: 46 | """Get file path from the directory based on its hash""" 47 | if not os.path.exists(dir_path) and os.path.isdir(dir_path): 48 | raise ValueError(f"Directory {dir_path} does not exist") 49 | 50 | files = [os.path.join(dir_path, f) for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f))] 51 | 52 | for f in files: 53 | f_hash = get_file_hash(f) 54 | if hash_str == f_hash: 55 | return f 56 | return None 57 | 58 | 59 | -------------------------------------------------------------------------------- /backend/common/config_loader.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | import os 3 | from modules.utils.paths import SERVER_CONFIG_PATH, SERVER_DOTENV_PATH 4 | from modules.utils.files_manager import load_yaml, save_yaml 5 | 6 | import functools 7 | 8 | 9 | @functools.lru_cache 10 | def load_server_config(config_path: str = SERVER_CONFIG_PATH) -> dict: 11 | if os.getenv("TEST_ENV", "false").lower() == "true": 12 | server_config = load_yaml(config_path) 13 | server_config["whisper"]["model_size"] = "tiny" 14 | server_config["whisper"]["compute_type"] = "float32" 15 | save_yaml(server_config, config_path) 16 | 17 | return load_yaml(config_path) 18 | 19 | 20 | @functools.lru_cache 21 | def read_env(key: str, default: str = None, dotenv_path: str = SERVER_DOTENV_PATH): 22 | load_dotenv(dotenv_path) 23 | value = os.getenv(key, default) 24 | return value 25 | 26 | -------------------------------------------------------------------------------- /backend/common/models.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field, validator 2 | from typing import List, Any, Optional 3 | from backend.db.task.models import TaskStatus, ResultType, TaskType 4 | 5 | 6 | class QueueResponse(BaseModel): 7 | identifier: str = Field(..., description="Unique identifier for the queued task that can be used for tracking") 8 | status: TaskStatus = Field(..., description="Current status of the task") 9 | message: str = Field(..., description="Message providing 
additional information about the task") 10 | 11 | 12 | class Response(BaseModel): 13 | identifier: str 14 | message: str 15 | -------------------------------------------------------------------------------- /backend/configs/.env.example: -------------------------------------------------------------------------------- 1 | HF_TOKEN="YOUR_HF_TOKEN FOR DIARIZATION MODEL (READ PERMISSION)" 2 | DB_URL="sqlite:///backend/records.db" -------------------------------------------------------------------------------- /backend/configs/config.yaml: -------------------------------------------------------------------------------- 1 | whisper: 2 | # Default implementation is faster-whisper. This indicates model name within `models\Whisper\faster-whisper` 3 | model_size: large-v2 4 | # Compute type. 'float16' for CUDA, 'float32' for CPU. 5 | compute_type: float16 6 | # Whether to offload the model after the inference. 7 | enable_offload: true 8 | 9 | bgm_separation: 10 | # UVR model sizes between ["UVR-MDX-NET-Inst_HQ_4", "UVR-MDX-NET-Inst_3"] 11 | model_size: UVR-MDX-NET-Inst_HQ_4 12 | # Whether to offload the model after the inference. Should be true if your setup has a VRAM less than <16GB 13 | enable_offload: true 14 | # Device to load BGM separation model between ["cuda", "cpu", "xpu"] 15 | device: cuda 16 | 17 | # Settings that apply to the `cache' directory. The output files for `/bgm-separation` are stored in the `cache' directory, 18 | # (You can check out the actual generated files by testing `/bgm-separation`.) 19 | # You can adjust the TTL/cleanup frequency of the files in the `cache' directory here. 20 | cache: 21 | # TTL (Time-To-Live) in seconds, defaults to 10 minutes 22 | ttl: 600 23 | # Clean up frequency in seconds, defaults to 1 minutes 24 | frequency: 60 25 | 26 | -------------------------------------------------------------------------------- /backend/db/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/db/__init__.py -------------------------------------------------------------------------------- /backend/db/db_instance.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | from sqlalchemy import create_engine 4 | from sqlalchemy.orm import sessionmaker 5 | from functools import wraps 6 | from sqlalchemy.exc import SQLAlchemyError 7 | from fastapi import HTTPException 8 | from sqlmodel import SQLModel 9 | from dotenv import load_dotenv 10 | 11 | from backend.common.config_loader import read_env 12 | 13 | 14 | @functools.lru_cache 15 | def init_db(): 16 | db_url = read_env("DB_URL", "sqlite:///backend/records.db") 17 | engine = create_engine(db_url, connect_args={"check_same_thread": False}) 18 | SQLModel.metadata.create_all(engine) 19 | return sessionmaker(autocommit=False, autoflush=False, bind=engine) 20 | 21 | 22 | def get_db_session(): 23 | db_instance = init_db() 24 | return db_instance() 25 | 26 | 27 | def handle_database_errors(func): 28 | @wraps(func) 29 | def wrapper(*args, **kwargs): 30 | session = None 31 | try: 32 | session = get_db_session() 33 | kwargs['session'] = session 34 | 35 | return func(*args, **kwargs) 36 | except Exception as e: 37 | print(f"Database error has occurred: {e}") 38 | raise 39 | finally: 40 | if session: 41 | session.close() 42 | return wrapper 43 | -------------------------------------------------------------------------------- 
/backend/db/task/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/db/task/__init__.py -------------------------------------------------------------------------------- /backend/db/task/dao.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | from sqlalchemy.orm import Session 3 | from fastapi import Depends 4 | 5 | from ..db_instance import handle_database_errors, get_db_session 6 | from .models import Task, TasksResult, TaskStatus 7 | 8 | 9 | @handle_database_errors 10 | def add_task_to_db( 11 | session, 12 | status=TaskStatus.QUEUED, 13 | task_type=None, 14 | language=None, 15 | task_params=None, 16 | file_name=None, 17 | url=None, 18 | audio_duration=None, 19 | ): 20 | """ 21 | Add task to the db 22 | """ 23 | task = Task( 24 | status=status, 25 | language=language, 26 | file_name=file_name, 27 | url=url, 28 | task_type=task_type, 29 | task_params=task_params, 30 | audio_duration=audio_duration, 31 | ) 32 | session.add(task) 33 | session.commit() 34 | return task.uuid 35 | 36 | 37 | @handle_database_errors 38 | def update_task_status_in_db( 39 | identifier: str, 40 | update_data: Dict[str, Any], 41 | session: Session, 42 | ): 43 | """ 44 | Update task status and attributes in the database. 45 | 46 | Args: 47 | identifier (str): Identifier of the task to be updated. 48 | update_data (Dict[str, Any]): Dictionary containing the attributes to update along with their new values. 49 | session (Session, optional): Database session. Defaults to Depends(get_db_session). 50 | 51 | Returns: 52 | None 53 | """ 54 | task = session.query(Task).filter_by(uuid=identifier).first() 55 | if task: 56 | for key, value in update_data.items(): 57 | setattr(task, key, value) 58 | session.commit() 59 | 60 | 61 | @handle_database_errors 62 | def get_task_status_from_db( 63 | identifier: str, session: Session 64 | ): 65 | """Retrieve task status from db""" 66 | task = session.query(Task).filter(Task.uuid == identifier).first() 67 | if task: 68 | return task 69 | else: 70 | return None 71 | 72 | 73 | @handle_database_errors 74 | def get_all_tasks_status_from_db(session: Session): 75 | """Get all tasks from db""" 76 | columns = [Task.uuid, Task.status, Task.task_type] 77 | query = session.query(*columns) 78 | tasks = [task for task in query] 79 | return TasksResult(tasks=tasks) 80 | 81 | 82 | @handle_database_errors 83 | def delete_task_from_db(identifier: str, session: Session): 84 | """Delete task from db""" 85 | task = session.query(Task).filter(Task.uuid == identifier).first() 86 | 87 | if task: 88 | # If the task exists, delete it from the database 89 | session.delete(task) 90 | session.commit() 91 | return True 92 | else: 93 | # If the task does not exist, return False 94 | return False 95 | -------------------------------------------------------------------------------- /backend/db/task/models.py: -------------------------------------------------------------------------------- 1 | # Ported from https://github.com/pavelzbornik/whisperX-FastAPI/blob/main/app/models.py 2 | 3 | from enum import Enum 4 | from pydantic import BaseModel 5 | from typing import Optional, List 6 | from uuid import uuid4 7 | from datetime import datetime 8 | from sqlalchemy.types import Enum as SQLAlchemyEnum 9 | from typing import Any 10 | from sqlmodel import SQLModel, Field, JSON, Column 11 | 12 | 13 | class 
ResultType(str, Enum): 14 | JSON = "json" 15 | FILEPATH = "filepath" 16 | 17 | 18 | class TaskStatus(str, Enum): 19 | PENDING = "pending" 20 | IN_PROGRESS = "in_progress" 21 | COMPLETED = "completed" 22 | FAILED = "failed" 23 | CANCELLED = "cancelled" 24 | QUEUED = "queued" 25 | PAUSED = "paused" 26 | RETRYING = "retrying" 27 | 28 | def __str__(self): 29 | return self.value 30 | 31 | 32 | class TaskType(str, Enum): 33 | TRANSCRIPTION = "transcription" 34 | VAD = "vad" 35 | BGM_SEPARATION = "bgm_separation" 36 | 37 | def __str__(self): 38 | return self.value 39 | 40 | 41 | class TaskStatusResponse(BaseModel): 42 | """`TaskStatusResponse` is a wrapper class that hides sensitive information from `Task`""" 43 | identifier: str = Field(..., description="Unique identifier for the queued task that can be used for tracking") 44 | status: TaskStatus = Field(..., description="Current status of the task") 45 | task_type: Optional[TaskType] = Field( 46 | default=None, 47 | description="Type/category of the task" 48 | ) 49 | result_type: Optional[ResultType] = Field( 50 | default=ResultType.JSON, 51 | description="Result type whether it's a filepath or JSON" 52 | ) 53 | result: Optional[Any] = Field( 54 | default=None, 55 | description="JSON data representing the result of the task" 56 | ) 57 | task_params: Optional[dict] = Field( 58 | default=None, 59 | description="Parameters of the task" 60 | ) 61 | error: Optional[str] = Field( 62 | default=None, 63 | description="Error message, if any, associated with the task" 64 | ) 65 | duration: Optional[float] = Field( 66 | default=None, 67 | description="Duration of the task execution" 68 | ) 69 | progress: Optional[float] = Field( 70 | default=0.0, 71 | description="Progress of the task" 72 | ) 73 | 74 | 75 | class Task(SQLModel, table=True): 76 | """ 77 | Table to store tasks information. 78 | 79 | Attributes: 80 | - id: Unique identifier for each task (Primary Key). 81 | - uuid: Universally unique identifier for each task. 82 | - status: Current status of the task. 83 | - result: JSON data representing the result of the task. 84 | - result_type: Type of the data whether it is normal JSON data or filepath. 85 | - file_name: Name of the file associated with the task. 86 | - task_type: Type/category of the task. 87 | - duration: Duration of the task execution. 88 | - error: Error message, if any, associated with the task. 89 | - created_at: Date and time of creation. 90 | - updated_at: Date and time of last update. 91 | - progress: Progress of the task. If it is None, it means the progress tracking for the task is not started yet. 
92 | """ 93 | 94 | __tablename__ = "tasks" 95 | 96 | id: Optional[int] = Field( 97 | default=None, 98 | primary_key=True, 99 | description="Unique identifier for each task (Primary Key)" 100 | ) 101 | uuid: str = Field( 102 | default_factory=lambda: str(uuid4()), 103 | description="Universally unique identifier for each task" 104 | ) 105 | status: Optional[TaskStatus] = Field( 106 | default=None, 107 | sa_column=Field(sa_column=SQLAlchemyEnum(TaskStatus)), 108 | description="Current status of the task", 109 | ) 110 | result: Optional[dict] = Field( 111 | default_factory=dict, 112 | sa_column=Column(JSON), 113 | description="JSON data representing the result of the task" 114 | ) 115 | result_type: Optional[ResultType] = Field( 116 | default=ResultType.JSON, 117 | sa_column=Field(sa_column=SQLAlchemyEnum(ResultType)), 118 | description="Result type whether it's a filepath or JSON" 119 | ) 120 | file_name: Optional[str] = Field( 121 | default=None, 122 | description="Name of the file associated with the task" 123 | ) 124 | url: Optional[str] = Field( 125 | default=None, 126 | description="URL of the file associated with the task" 127 | ) 128 | audio_duration: Optional[float] = Field( 129 | default=None, 130 | description="Duration of the audio in seconds" 131 | ) 132 | language: Optional[str] = Field( 133 | default=None, 134 | description="Language of the file associated with the task" 135 | ) 136 | task_type: Optional[TaskType] = Field( 137 | default=None, 138 | sa_column=Field(sa_column=SQLAlchemyEnum(TaskType)), 139 | description="Type/category of the task" 140 | ) 141 | task_params: Optional[dict] = Field( 142 | default_factory=dict, 143 | sa_column=Column(JSON), 144 | description="Parameters of the task" 145 | ) 146 | duration: Optional[float] = Field( 147 | default=None, 148 | description="Duration of the task execution" 149 | ) 150 | error: Optional[str] = Field( 151 | default=None, 152 | description="Error message, if any, associated with the task" 153 | ) 154 | created_at: datetime = Field( 155 | default_factory=datetime.utcnow, 156 | description="Date and time of creation" 157 | ) 158 | updated_at: datetime = Field( 159 | default_factory=datetime.utcnow, 160 | sa_column_kwargs={"onupdate": datetime.utcnow}, 161 | description="Date and time of last update" 162 | ) 163 | progress: Optional[float] = Field( 164 | default=0.0, 165 | description="Progress of the task" 166 | ) 167 | 168 | def to_response(self) -> "TaskStatusResponse": 169 | return TaskStatusResponse( 170 | identifier=self.uuid, 171 | status=self.status, 172 | task_type=self.task_type, 173 | result_type=self.result_type, 174 | result=self.result, 175 | task_params=self.task_params, 176 | error=self.error, 177 | duration=self.duration, 178 | progress=self.progress 179 | ) 180 | 181 | 182 | class TasksResult(BaseModel): 183 | tasks: List[Task] 184 | 185 | -------------------------------------------------------------------------------- /backend/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | app: 3 | build: 4 | dockerfile: backend/Dockerfile 5 | context: .. 6 | image: jhj0517/whisper-webui-backend:latest 7 | 8 | volumes: 9 | # You can mount the container's volume paths to directory paths on your local machine. 10 | # Models will be stored in the `./models' directory on your machine. 11 | # Similarly, all output files will be stored in the `./outputs` directory. 
12 | # The DB file is saved in /Whisper-WebUI/backend/records.db unless you edit it in /Whisper-WebUI/backend/configs/.env 13 | - ./models:/Whisper-WebUI/models 14 | - ./outputs:/Whisper-WebUI/outputs 15 | - ./backend:/Whisper-WebUI/backend 16 | 17 | ports: 18 | - "8000:8000" 19 | 20 | stdin_open: true 21 | tty: true 22 | 23 | entrypoint: ["uvicorn", "backend.main:app", "--host", "0.0.0.0", "--port", "8000"] 24 | 25 | # If you're not using Nvidia GPU, Update device to match yours. 26 | # See more info at : https://docs.docker.com/compose/compose-file/deploy/#driver 27 | deploy: 28 | resources: 29 | reservations: 30 | devices: 31 | - driver: nvidia 32 | count: all 33 | capabilities: [ gpu ] -------------------------------------------------------------------------------- /backend/main.py: -------------------------------------------------------------------------------- 1 | from contextlib import asynccontextmanager 2 | from fastapi import ( 3 | FastAPI, 4 | ) 5 | from fastapi.responses import RedirectResponse 6 | from fastapi.middleware.cors import CORSMiddleware 7 | import os 8 | import time 9 | import threading 10 | 11 | from backend.db.db_instance import init_db 12 | from backend.routers.transcription.router import transcription_router, get_pipeline 13 | from backend.routers.vad.router import get_vad_model, vad_router 14 | from backend.routers.bgm_separation.router import get_bgm_separation_inferencer, bgm_separation_router 15 | from backend.routers.task.router import task_router 16 | from backend.common.config_loader import read_env, load_server_config 17 | from backend.common.cache_manager import cleanup_old_files 18 | from modules.utils.paths import SERVER_CONFIG_PATH, BACKEND_CACHE_DIR 19 | 20 | 21 | def clean_cache_thread(ttl: int, frequency: int) -> threading.Thread: 22 | def clean_cache(_ttl: int, _frequency: int): 23 | while True: 24 | cleanup_old_files(cache_dir=BACKEND_CACHE_DIR, ttl=_ttl) 25 | time.sleep(_frequency) 26 | 27 | return threading.Thread( 28 | target=clean_cache, 29 | args=(ttl, frequency), 30 | daemon=True 31 | ) 32 | 33 | 34 | @asynccontextmanager 35 | async def lifespan(app: FastAPI): 36 | # Basic setup initialization 37 | server_config = load_server_config() 38 | read_env("DB_URL") # Place .env file into /configs/.env 39 | init_db() 40 | 41 | # Inferencer initialization 42 | transcription_pipeline = get_pipeline() 43 | vad_inferencer = get_vad_model() 44 | bgm_separation_inferencer = get_bgm_separation_inferencer() 45 | 46 | # Thread initialization 47 | cache_thread = clean_cache_thread(server_config["cache"]["ttl"], server_config["cache"]["frequency"]) 48 | cache_thread.start() 49 | 50 | yield 51 | 52 | # Release VRAM when server shutdown 53 | transcription_pipeline = None 54 | vad_inferencer = None 55 | bgm_separation_inferencer = None 56 | 57 | 58 | app = FastAPI( 59 | title="Whisper-WebUI-Backend", 60 | description=f""" 61 | REST API for Whisper-WebUI. Swagger UI is available via /docs or root URL with redirection. Redoc is available via /redoc. 62 | """, 63 | version="0.0.1", 64 | lifespan=lifespan, 65 | openapi_tags=[ 66 | { 67 | "name": "BGM Separation", 68 | "description": "Cached files for /bgm-separation are generated in the `backend/cache` directory," 69 | " you can set TLL for these files in `backend/configs/config.yaml`." 
70 | } 71 | ] 72 | ) 73 | app.add_middleware( 74 | CORSMiddleware, 75 | allow_origins=["*"], 76 | allow_credentials=True, 77 | allow_methods=["GET", "POST", "PUT", "PATCH", "OPTIONS"], # Disable DELETE 78 | allow_headers=["*"], 79 | ) 80 | app.include_router(transcription_router) 81 | app.include_router(vad_router) 82 | app.include_router(bgm_separation_router) 83 | app.include_router(task_router) 84 | 85 | 86 | @app.get("/", response_class=RedirectResponse, include_in_schema=False) 87 | async def index(): 88 | """ 89 | Redirect to the documentation. Defaults to Swagger UI. 90 | You can also check the /redoc with redoc style: https://github.com/Redocly/redoc 91 | """ 92 | return "/docs" 93 | -------------------------------------------------------------------------------- /backend/nginx/logs/logs_are_generated_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/nginx/logs/logs_are_generated_here -------------------------------------------------------------------------------- /backend/nginx/nginx.conf: -------------------------------------------------------------------------------- 1 | worker_processes 1; 2 | 3 | events { 4 | worker_connections 1024; 5 | } 6 | 7 | http { 8 | server { 9 | listen 80; 10 | client_max_body_size 4G; 11 | 12 | server_name your-own-domain-name.com; 13 | 14 | location / { 15 | proxy_pass http://127.0.0.1:8000; 16 | proxy_set_header Host $host; 17 | proxy_set_header X-Real-IP $remote_addr; 18 | proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; 19 | proxy_set_header X-Forwarded-Proto $scheme; 20 | } 21 | } 22 | } 23 | 24 | -------------------------------------------------------------------------------- /backend/nginx/temp/temps_are_generated_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/nginx/temp/temps_are_generated_here -------------------------------------------------------------------------------- /backend/requirements-backend.txt: -------------------------------------------------------------------------------- 1 | # Whisper-WebUI dependencies 2 | -r ../requirements.txt 3 | 4 | # Backend dependencies 5 | python-dotenv 6 | uvicorn 7 | SQLAlchemy 8 | sqlmodel 9 | pydantic 10 | 11 | # Test dependencies 12 | # pytest 13 | # pytest-asyncio -------------------------------------------------------------------------------- /backend/routers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/routers/__init__.py -------------------------------------------------------------------------------- /backend/routers/bgm_separation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/routers/bgm_separation/__init__.py -------------------------------------------------------------------------------- /backend/routers/bgm_separation/models.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | 3 | 4 | class BGMSeparationResult(BaseModel): 5 | instrumental_hash: str = Field(..., description="Instrumental file hash") 6 | vocal_hash: str 
= Field(..., description="Vocal file hash") 7 | -------------------------------------------------------------------------------- /backend/routers/bgm_separation/router.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import numpy as np 3 | from fastapi import ( 4 | File, 5 | UploadFile, 6 | ) 7 | import gradio as gr 8 | from fastapi import APIRouter, BackgroundTasks, Depends, Response, status 9 | from fastapi.responses import FileResponse 10 | from typing import List, Dict, Tuple 11 | from datetime import datetime 12 | import os 13 | 14 | from modules.whisper.data_classes import * 15 | from modules.uvr.music_separator import MusicSeparator 16 | from modules.utils.paths import BACKEND_CACHE_DIR 17 | from backend.common.audio import read_audio 18 | from backend.common.models import QueueResponse 19 | from backend.common.config_loader import load_server_config 20 | from backend.common.compresser import get_file_hash, find_file_by_hash 21 | from backend.db.task.models import TaskStatus, TaskType, ResultType 22 | from backend.db.task.dao import add_task_to_db, update_task_status_in_db 23 | from .models import BGMSeparationResult 24 | 25 | 26 | bgm_separation_router = APIRouter(prefix="/bgm-separation", tags=["BGM Separation"]) 27 | 28 | 29 | @functools.lru_cache 30 | def get_bgm_separation_inferencer() -> 'MusicSeparator': 31 | config = load_server_config()["bgm_separation"] 32 | inferencer = MusicSeparator( 33 | output_dir=os.path.join(BACKEND_CACHE_DIR, "UVR") 34 | ) 35 | inferencer.update_model( 36 | model_name=config["model_size"], 37 | device=config["device"] 38 | ) 39 | return inferencer 40 | 41 | 42 | def run_bgm_separation( 43 | audio: np.ndarray, 44 | params: BGMSeparationParams, 45 | identifier: str, 46 | ) -> Tuple[np.ndarray, np.ndarray]: 47 | update_task_status_in_db( 48 | identifier=identifier, 49 | update_data={ 50 | "uuid": identifier, 51 | "status": TaskStatus.IN_PROGRESS, 52 | "updated_at": datetime.utcnow() 53 | } 54 | ) 55 | 56 | start_time = datetime.utcnow() 57 | instrumental, vocal, filepaths = get_bgm_separation_inferencer().separate( 58 | audio=audio, 59 | model_name=params.uvr_model_size, 60 | device=params.uvr_device, 61 | segment_size=params.segment_size, 62 | save_file=True, 63 | progress=gr.Progress() 64 | ) 65 | instrumental_path, vocal_path = filepaths 66 | elapsed_time = (datetime.utcnow() - start_time).total_seconds() 67 | 68 | update_task_status_in_db( 69 | identifier=identifier, 70 | update_data={ 71 | "uuid": identifier, 72 | "status": TaskStatus.COMPLETED, 73 | "result": BGMSeparationResult( 74 | instrumental_hash=get_file_hash(instrumental_path), 75 | vocal_hash=get_file_hash(vocal_path) 76 | ).model_dump(), 77 | "result_type": ResultType.FILEPATH, 78 | "updated_at": datetime.utcnow(), 79 | "duration": elapsed_time 80 | } 81 | ) 82 | return instrumental, vocal 83 | 84 | 85 | @bgm_separation_router.post( 86 | "/", 87 | response_model=QueueResponse, 88 | status_code=status.HTTP_201_CREATED, 89 | summary="Separate Background BGM abd vocal", 90 | description="Separate background music and vocal from an uploaded audio or video file.", 91 | ) 92 | async def bgm_separation( 93 | background_tasks: BackgroundTasks, 94 | file: UploadFile = File(..., description="Audio or video file to separate background music."), 95 | params: BGMSeparationParams = Depends() 96 | ) -> QueueResponse: 97 | if not isinstance(file, np.ndarray): 98 | audio, info = await read_audio(file=file) 99 | else: 100 | audio, info = file, 
None 101 | 102 | identifier = add_task_to_db( 103 | status=TaskStatus.QUEUED, 104 | file_name=file.filename, 105 | audio_duration=info.duration if info else None, 106 | task_type=TaskType.BGM_SEPARATION, 107 | task_params=params.model_dump(), 108 | ) 109 | 110 | background_tasks.add_task( 111 | run_bgm_separation, 112 | audio=audio, 113 | params=params, 114 | identifier=identifier 115 | ) 116 | 117 | return QueueResponse(identifier=identifier, status=TaskStatus.QUEUED, message="BGM Separation task has queued") 118 | 119 | 120 | -------------------------------------------------------------------------------- /backend/routers/task/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/routers/task/__init__.py -------------------------------------------------------------------------------- /backend/routers/task/router.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Depends, HTTPException, status 2 | from fastapi.responses import FileResponse 3 | from sqlalchemy.orm import Session 4 | import os 5 | 6 | from backend.db.db_instance import get_db_session 7 | from backend.db.task.dao import ( 8 | get_task_status_from_db, 9 | get_all_tasks_status_from_db, 10 | delete_task_from_db, 11 | ) 12 | from backend.db.task.models import ( 13 | TasksResult, 14 | Task, 15 | TaskStatusResponse, 16 | TaskType 17 | ) 18 | from backend.common.models import ( 19 | Response, 20 | ) 21 | from backend.common.compresser import compress_files, find_file_by_hash 22 | from modules.utils.paths import BACKEND_CACHE_DIR 23 | 24 | task_router = APIRouter(prefix="/task", tags=["Tasks"]) 25 | 26 | 27 | @task_router.get( 28 | "/{identifier}", 29 | response_model=TaskStatusResponse, 30 | status_code=status.HTTP_200_OK, 31 | summary="Retrieve Task by Identifier", 32 | description="Retrieve the specific task by its identifier.", 33 | ) 34 | async def get_task( 35 | identifier: str, 36 | session: Session = Depends(get_db_session), 37 | ) -> TaskStatusResponse: 38 | """ 39 | Retrieve the specific task by its identifier. 40 | """ 41 | task = get_task_status_from_db(identifier=identifier, session=session) 42 | 43 | if task is not None: 44 | return task.to_response() 45 | else: 46 | raise HTTPException(status_code=404, detail="Identifier not found") 47 | 48 | 49 | @task_router.get( 50 | "/file/{identifier}", 51 | status_code=status.HTTP_200_OK, 52 | summary="Retrieve FileResponse Task by Identifier", 53 | description="Retrieve the file response task by its identifier. You can use this endpoint if you need to download" 54 | " The file as a response", 55 | ) 56 | async def get_file_task( 57 | identifier: str, 58 | session: Session = Depends(get_db_session), 59 | ) -> FileResponse: 60 | """ 61 | Retrieve the downloadable file response of a specific task by its identifier. 62 | Compressed by ZIP basically. 
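    A minimal client-side sketch for this endpoint (illustrative only; it assumes
    the backend is reachable at http://localhost:8000 and that `identifier` refers
    to a completed BGM-separation task):

        import io, zipfile, requests

        resp = requests.get(f"http://localhost:8000/task/file/{identifier}")
        resp.raise_for_status()
        with zipfile.ZipFile(io.BytesIO(resp.content)) as zf:
            zf.extractall("separated_audio")  # extracts the instrumental and vocal files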
63 | """ 64 | task = get_task_status_from_db(identifier=identifier, session=session) 65 | 66 | if task is not None: 67 | if task.task_type == TaskType.BGM_SEPARATION: 68 | output_zip_path = os.path.join(BACKEND_CACHE_DIR, f"{identifier}_bgm_separation.zip") 69 | instrumental_path = find_file_by_hash( 70 | os.path.join(BACKEND_CACHE_DIR, "UVR", "instrumental"), 71 | task.result["instrumental_hash"] 72 | ) 73 | vocal_path = find_file_by_hash( 74 | os.path.join(BACKEND_CACHE_DIR, "UVR", "vocals"), 75 | task.result["vocal_hash"] 76 | ) 77 | 78 | output_zip_path = compress_files( 79 | [instrumental_path, vocal_path], 80 | output_zip_path 81 | ) 82 | return FileResponse( 83 | path=output_zip_path, 84 | status_code=200, 85 | filename=output_zip_path, 86 | media_type="application/zip" 87 | ) 88 | else: 89 | raise HTTPException(status_code=404, detail=f"File download is only supported for bgm separation." 90 | f" The given type is {task.task_type}") 91 | else: 92 | raise HTTPException(status_code=404, detail="Identifier not found") 93 | 94 | 95 | # Delete method, commented by default because this endpoint is likely to require special permissions 96 | # @task_router.delete( 97 | # "/{identifier}", 98 | # response_model=Response, 99 | # status_code=status.HTTP_200_OK, 100 | # summary="Delete Task by Identifier", 101 | # description="Delete a task from the system using its identifier.", 102 | # ) 103 | async def delete_task( 104 | identifier: str, 105 | session: Session = Depends(get_db_session), 106 | ) -> Response: 107 | """ 108 | Delete a task by its identifier. 109 | """ 110 | if delete_task_from_db(identifier, session): 111 | return Response(identifier=identifier, message="Task deleted") 112 | else: 113 | raise HTTPException(status_code=404, detail="Task not found") 114 | 115 | 116 | # Get All method, commented by default because this endpoint is likely to require special permissions 117 | # @task_router.get( 118 | # "/all", 119 | # response_model=TasksResult, 120 | # status_code=status.HTTP_200_OK, 121 | # summary="Retrieve All Task Statuses", 122 | # description="Retrieve the statuses of all tasks available in the system.", 123 | # ) 124 | async def get_all_tasks_status( 125 | session: Session = Depends(get_db_session), 126 | ) -> TasksResult: 127 | """ 128 | Retrieve all tasks. 
129 | """ 130 | return get_all_tasks_status_from_db(session=session) -------------------------------------------------------------------------------- /backend/routers/transcription/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/routers/transcription/__init__.py -------------------------------------------------------------------------------- /backend/routers/transcription/router.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import uuid 3 | import numpy as np 4 | from fastapi import ( 5 | File, 6 | UploadFile, 7 | ) 8 | import gradio as gr 9 | from fastapi import APIRouter, BackgroundTasks, Depends, Response, status 10 | from typing import List, Dict 11 | from sqlalchemy.orm import Session 12 | from datetime import datetime 13 | from modules.whisper.data_classes import * 14 | from modules.utils.paths import BACKEND_CACHE_DIR 15 | from modules.whisper.faster_whisper_inference import FasterWhisperInference 16 | from backend.common.audio import read_audio 17 | from backend.common.models import QueueResponse 18 | from backend.common.config_loader import load_server_config 19 | from backend.db.task.dao import ( 20 | add_task_to_db, 21 | get_db_session, 22 | update_task_status_in_db 23 | ) 24 | from backend.db.task.models import TaskStatus, TaskType 25 | 26 | transcription_router = APIRouter(prefix="/transcription", tags=["Transcription"]) 27 | 28 | 29 | def create_progress_callback(identifier: str): 30 | def progress_callback(progress_value: float): 31 | update_task_status_in_db( 32 | identifier=identifier, 33 | update_data={ 34 | "uuid": identifier, 35 | "status": TaskStatus.IN_PROGRESS, 36 | "progress": round(progress_value, 2), 37 | "updated_at": datetime.utcnow() 38 | }, 39 | ) 40 | return progress_callback 41 | 42 | 43 | @functools.lru_cache 44 | def get_pipeline() -> 'FasterWhisperInference': 45 | config = load_server_config()["whisper"] 46 | inferencer = FasterWhisperInference( 47 | output_dir=BACKEND_CACHE_DIR 48 | ) 49 | inferencer.update_model( 50 | model_size=config["model_size"], 51 | compute_type=config["compute_type"] 52 | ) 53 | return inferencer 54 | 55 | 56 | def run_transcription( 57 | audio: np.ndarray, 58 | params: TranscriptionPipelineParams, 59 | identifier: str, 60 | ) -> List[Segment]: 61 | update_task_status_in_db( 62 | identifier=identifier, 63 | update_data={ 64 | "uuid": identifier, 65 | "status": TaskStatus.IN_PROGRESS, 66 | "updated_at": datetime.utcnow() 67 | }, 68 | ) 69 | 70 | progress_callback = create_progress_callback(identifier) 71 | segments, elapsed_time = get_pipeline().run( 72 | audio, 73 | gr.Progress(), 74 | "SRT", 75 | False, 76 | progress_callback, 77 | *params.to_list() 78 | ) 79 | segments = [seg.model_dump() for seg in segments] 80 | 81 | update_task_status_in_db( 82 | identifier=identifier, 83 | update_data={ 84 | "uuid": identifier, 85 | "status": TaskStatus.COMPLETED, 86 | "result": segments, 87 | "updated_at": datetime.utcnow(), 88 | "duration": elapsed_time, 89 | "progress": 1.0, 90 | }, 91 | ) 92 | return segments 93 | 94 | 95 | @transcription_router.post( 96 | "/", 97 | response_model=QueueResponse, 98 | status_code=status.HTTP_201_CREATED, 99 | summary="Transcribe Audio", 100 | description="Process the provided audio or video file to generate a transcription.", 101 | ) 102 | async def transcription( 103 | background_tasks: 
BackgroundTasks, 104 | file: UploadFile = File(..., description="Audio or video file to transcribe."), 105 | whisper_params: WhisperParams = Depends(), 106 | vad_params: VadParams = Depends(), 107 | bgm_separation_params: BGMSeparationParams = Depends(), 108 | diarization_params: DiarizationParams = Depends(), 109 | ) -> QueueResponse: 110 | if not isinstance(file, np.ndarray): 111 | audio, info = await read_audio(file=file) 112 | else: 113 | audio, info = file, None 114 | 115 | params = TranscriptionPipelineParams( 116 | whisper=whisper_params, 117 | vad=vad_params, 118 | bgm_separation=bgm_separation_params, 119 | diarization=diarization_params 120 | ) 121 | 122 | identifier = add_task_to_db( 123 | status=TaskStatus.QUEUED, 124 | file_name=file.filename, 125 | audio_duration=info.duration if info else None, 126 | language=params.whisper.lang, 127 | task_type=TaskType.TRANSCRIPTION, 128 | task_params=params.to_dict(), 129 | ) 130 | 131 | background_tasks.add_task( 132 | run_transcription, 133 | audio=audio, 134 | params=params, 135 | identifier=identifier, 136 | ) 137 | 138 | return QueueResponse(identifier=identifier, status=TaskStatus.QUEUED, message="Transcription task has queued") 139 | 140 | 141 | -------------------------------------------------------------------------------- /backend/routers/vad/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/routers/vad/__init__.py -------------------------------------------------------------------------------- /backend/routers/vad/router.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import numpy as np 3 | from faster_whisper.vad import VadOptions 4 | from fastapi import ( 5 | File, 6 | UploadFile, 7 | ) 8 | from fastapi import APIRouter, BackgroundTasks, Depends, Response, status 9 | from typing import List, Dict 10 | from datetime import datetime 11 | 12 | from modules.vad.silero_vad import SileroVAD 13 | from modules.whisper.data_classes import VadParams 14 | from backend.common.audio import read_audio 15 | from backend.common.models import QueueResponse 16 | from backend.db.task.dao import add_task_to_db, update_task_status_in_db 17 | from backend.db.task.models import TaskStatus, TaskType 18 | 19 | vad_router = APIRouter(prefix="/vad", tags=["Voice Activity Detection"]) 20 | 21 | 22 | @functools.lru_cache 23 | def get_vad_model() -> SileroVAD: 24 | inferencer = SileroVAD() 25 | inferencer.update_model() 26 | return inferencer 27 | 28 | 29 | def run_vad( 30 | audio: np.ndarray, 31 | params: VadOptions, 32 | identifier: str, 33 | ) -> List[Dict]: 34 | update_task_status_in_db( 35 | identifier=identifier, 36 | update_data={ 37 | "uuid": identifier, 38 | "status": TaskStatus.IN_PROGRESS, 39 | "updated_at": datetime.utcnow() 40 | } 41 | ) 42 | 43 | start_time = datetime.utcnow() 44 | audio, speech_chunks = get_vad_model().run( 45 | audio=audio, 46 | vad_parameters=params 47 | ) 48 | elapsed_time = (datetime.utcnow() - start_time).total_seconds() 49 | 50 | update_task_status_in_db( 51 | identifier=identifier, 52 | update_data={ 53 | "uuid": identifier, 54 | "status": TaskStatus.COMPLETED, 55 | "updated_at": datetime.utcnow(), 56 | "result": speech_chunks, 57 | "duration": elapsed_time 58 | } 59 | ) 60 | 61 | return speech_chunks 62 | 63 | 64 | @vad_router.post( 65 | "/", 66 | response_model=QueueResponse, 67 | 
status_code=status.HTTP_201_CREATED, 68 | summary="Voice Activity Detection", 69 | description="Detect voice parts in the provided audio or video file to generate a timeline of speech segments.", 70 | ) 71 | async def vad( 72 | background_tasks: BackgroundTasks, 73 | file: UploadFile = File(..., description="Audio or video file to detect voices."), 74 | params: VadParams = Depends() 75 | ) -> QueueResponse: 76 | if not isinstance(file, np.ndarray): 77 | audio, info = await read_audio(file=file) 78 | else: 79 | audio, info = file, None 80 | 81 | vad_options = VadOptions( 82 | threshold=params.threshold, 83 | min_speech_duration_ms=params.min_speech_duration_ms, 84 | max_speech_duration_s=params.max_speech_duration_s, 85 | min_silence_duration_ms=params.min_silence_duration_ms, 86 | speech_pad_ms=params.speech_pad_ms 87 | ) 88 | 89 | identifier = add_task_to_db( 90 | status=TaskStatus.QUEUED, 91 | file_name=file.filename, 92 | audio_duration=info.duration if info else None, 93 | task_type=TaskType.VAD, 94 | task_params=params.model_dump(), 95 | ) 96 | 97 | background_tasks.add_task(run_vad, audio=audio, params=vad_options, identifier=identifier) 98 | 99 | return QueueResponse(identifier=identifier, status=TaskStatus.QUEUED, message="VAD task has queued") 100 | 101 | 102 | -------------------------------------------------------------------------------- /backend/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/tests/__init__.py -------------------------------------------------------------------------------- /backend/tests/test_backend_bgm_separation.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from fastapi import UploadFile 3 | from io import BytesIO 4 | import os 5 | import torch 6 | 7 | from backend.db.task.models import TaskStatus 8 | from backend.tests.test_task_status import wait_for_task_completion, fetch_file_response 9 | from backend.tests.test_backend_config import ( 10 | get_client, setup_test_file, get_upload_file_instance, calculate_wer, 11 | TEST_BGM_SEPARATION_PARAMS, TEST_ANSWER, TEST_BGM_SEPARATION_OUTPUT_PATH 12 | ) 13 | 14 | 15 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip the test because CUDA is not available") 16 | @pytest.mark.parametrize( 17 | "bgm_separation_params", 18 | [ 19 | TEST_BGM_SEPARATION_PARAMS 20 | ] 21 | ) 22 | def test_transcription_endpoint( 23 | get_upload_file_instance, 24 | bgm_separation_params: dict 25 | ): 26 | client = get_client() 27 | file_content = BytesIO(get_upload_file_instance.file.read()) 28 | get_upload_file_instance.file.seek(0) 29 | 30 | response = client.post( 31 | "/bgm-separation", 32 | files={"file": (get_upload_file_instance.filename, file_content, "audio/mpeg")}, 33 | params=bgm_separation_params 34 | ) 35 | 36 | assert response.status_code == 201 37 | assert response.json()["status"] == TaskStatus.QUEUED 38 | task_identifier = response.json()["identifier"] 39 | assert isinstance(task_identifier, str) and task_identifier 40 | 41 | completed_task = wait_for_task_completion( 42 | identifier=task_identifier 43 | ) 44 | 45 | assert completed_task is not None, f"Task with identifier {task_identifier} did not complete within the " \ 46 | f"expected time." 
47 | 48 | result = completed_task.json()["result"] 49 | assert "instrumental_hash" in result and result["instrumental_hash"] 50 | assert "vocal_hash" in result and result["vocal_hash"] 51 | 52 | file_response = fetch_file_response(task_identifier) 53 | assert file_response.status_code == 200, f"Fetching File Response has failed. Response is: {file_response}" 54 | 55 | with open(TEST_BGM_SEPARATION_OUTPUT_PATH, "wb") as file: 56 | file.write(file_response.content) 57 | 58 | assert os.path.exists(TEST_BGM_SEPARATION_OUTPUT_PATH) 59 | 60 | -------------------------------------------------------------------------------- /backend/tests/test_backend_config.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from fastapi import FastAPI, UploadFile 3 | from fastapi.testclient import TestClient 4 | from starlette.datastructures import UploadFile as StarletteUploadFile 5 | from io import BytesIO 6 | import os 7 | import requests 8 | import pytest 9 | import yaml 10 | import jiwer 11 | 12 | from backend.main import app 13 | from modules.whisper.data_classes import * 14 | from modules.utils.paths import * 15 | from modules.utils.files_manager import load_yaml, save_yaml 16 | 17 | TEST_PIPELINE_PARAMS = {**WhisperParams(model_size="tiny", compute_type="float32").model_dump(exclude_none=True), 18 | **VadParams().model_dump(exclude_none=True), 19 | **BGMSeparationParams().model_dump(exclude_none=True), 20 | **DiarizationParams().model_dump(exclude_none=True)} 21 | TEST_VAD_PARAMS = VadParams().model_dump() 22 | TEST_BGM_SEPARATION_PARAMS = BGMSeparationParams().model_dump() 23 | TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav" 24 | TEST_FILE_PATH = os.path.join(WEBUI_DIR, "backend", "tests", "jfk.wav") 25 | TEST_BGM_SEPARATION_OUTPUT_PATH = os.path.join(WEBUI_DIR, "backend", "tests", "separated_audio.zip") 26 | TEST_ANSWER = "And so my fellow Americans ask not what your country can do for you ask what you can do for your country" 27 | TEST_WHISPER_MODEL = "tiny" 28 | TEST_COMPUTE_TYPE = "float32" 29 | 30 | 31 | @pytest.fixture(autouse=True) 32 | def setup_test_file(): 33 | @functools.lru_cache 34 | def download_file(url=TEST_FILE_DOWNLOAD_URL, file_path=TEST_FILE_PATH): 35 | if os.path.exists(file_path): 36 | return 37 | 38 | if not os.path.exists(os.path.dirname(file_path)): 39 | os.makedirs(os.path.dirname(file_path)) 40 | 41 | response = requests.get(url) 42 | 43 | with open(file_path, "wb") as file: 44 | file.write(response.content) 45 | 46 | print(f"File downloaded to: {file_path}") 47 | 48 | download_file(TEST_FILE_DOWNLOAD_URL, TEST_FILE_PATH) 49 | 50 | 51 | @pytest.fixture 52 | @functools.lru_cache 53 | def get_upload_file_instance(filepath: str = TEST_FILE_PATH) -> UploadFile: 54 | with open(filepath, "rb") as f: 55 | file_contents = BytesIO(f.read()) 56 | filename = os.path.basename(filepath) 57 | upload_file = StarletteUploadFile(file=file_contents, filename=filename) 58 | return upload_file 59 | 60 | 61 | @functools.lru_cache 62 | def get_client(app: FastAPI = app): 63 | return TestClient(app) 64 | 65 | 66 | def calculate_wer(answer, prediction): 67 | return jiwer.wer(answer, prediction) 68 | -------------------------------------------------------------------------------- /backend/tests/test_backend_transcription.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from fastapi import UploadFile 3 | from io import BytesIO 4 | 
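# End-to-end API example: POST an audio file to /transcription, poll /task/{identifier}
# until the task completes, then compare the transcript to TEST_ANSWER with a WER threshold.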
5 | from backend.db.task.models import TaskStatus 6 | from backend.tests.test_task_status import wait_for_task_completion 7 | from backend.tests.test_backend_config import ( 8 | get_client, setup_test_file, get_upload_file_instance, calculate_wer, 9 | TEST_PIPELINE_PARAMS, TEST_ANSWER 10 | ) 11 | 12 | 13 | @pytest.mark.parametrize( 14 | "pipeline_params", 15 | [ 16 | TEST_PIPELINE_PARAMS 17 | ] 18 | ) 19 | def test_transcription_endpoint( 20 | get_upload_file_instance, 21 | pipeline_params: dict 22 | ): 23 | client = get_client() 24 | file_content = BytesIO(get_upload_file_instance.file.read()) 25 | get_upload_file_instance.file.seek(0) 26 | 27 | response = client.post( 28 | "/transcription", 29 | files={"file": (get_upload_file_instance.filename, file_content, "audio/mpeg")}, 30 | params=pipeline_params 31 | ) 32 | 33 | assert response.status_code == 201 34 | assert response.json()["status"] == TaskStatus.QUEUED 35 | task_identifier = response.json()["identifier"] 36 | assert isinstance(task_identifier, str) and task_identifier 37 | 38 | completed_task = wait_for_task_completion( 39 | identifier=task_identifier 40 | ) 41 | 42 | assert completed_task is not None, f"Task with identifier {task_identifier} did not complete within the " \ 43 | f"expected time." 44 | 45 | result = completed_task.json()["result"] 46 | assert result, "Transcription text is empty" 47 | 48 | wer = calculate_wer(TEST_ANSWER, result[0]["text"].strip().replace(",", "").replace(".", "")) 49 | assert wer < 0.1, f"WER is too high, it's {wer}" 50 | 51 | -------------------------------------------------------------------------------- /backend/tests/test_backend_vad.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from fastapi import UploadFile 3 | from io import BytesIO 4 | 5 | from backend.db.task.models import TaskStatus 6 | from backend.tests.test_task_status import wait_for_task_completion 7 | from backend.tests.test_backend_config import ( 8 | get_client, setup_test_file, get_upload_file_instance, calculate_wer, 9 | TEST_VAD_PARAMS, TEST_ANSWER 10 | ) 11 | 12 | 13 | @pytest.mark.parametrize( 14 | "vad_params", 15 | [ 16 | TEST_VAD_PARAMS 17 | ] 18 | ) 19 | def test_transcription_endpoint( 20 | get_upload_file_instance, 21 | vad_params: dict 22 | ): 23 | client = get_client() 24 | file_content = BytesIO(get_upload_file_instance.file.read()) 25 | get_upload_file_instance.file.seek(0) 26 | 27 | response = client.post( 28 | "/vad", 29 | files={"file": (get_upload_file_instance.filename, file_content, "audio/mpeg")}, 30 | params=vad_params 31 | ) 32 | 33 | assert response.status_code == 201 34 | assert response.json()["status"] == TaskStatus.QUEUED 35 | task_identifier = response.json()["identifier"] 36 | assert isinstance(task_identifier, str) and task_identifier 37 | 38 | completed_task = wait_for_task_completion( 39 | identifier=task_identifier 40 | ) 41 | 42 | assert completed_task is not None, f"Task with identifier {task_identifier} did not complete within the " \ 43 | f"expected time." 
44 | 45 | result = completed_task.json()["result"] 46 | assert result and "start" in result[0] and "end" in result[0] 47 | 48 | -------------------------------------------------------------------------------- /backend/tests/test_task_status.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pytest 3 | from typing import Optional, Union 4 | import httpx 5 | 6 | from backend.db.task.models import TaskStatus, Task 7 | from backend.tests.test_backend_config import get_client 8 | 9 | 10 | def fetch_task(identifier: str): 11 | """Get task status""" 12 | client = get_client() 13 | response = client.get( 14 | f"/task/{identifier}" 15 | ) 16 | if response.status_code == 200: 17 | return response 18 | return None 19 | 20 | 21 | def fetch_file_response(identifier: str): 22 | """Get task status""" 23 | client = get_client() 24 | response = client.get( 25 | f"/task/file/{identifier}" 26 | ) 27 | if response.status_code == 200: 28 | return response 29 | return None 30 | 31 | 32 | def wait_for_task_completion(identifier: str, 33 | max_attempts: int = 20, 34 | frequency: int = 3) -> httpx.Response: 35 | """ 36 | Polls the task status until it is completed, failed, or the maximum attempts are reached. 37 | 38 | Args: 39 | identifier (str): The unique identifier of the task to monitor. 40 | max_attempts (int): The maximum number of polling attempts.. 41 | frequency (int): The time (in seconds) to wait between polling attempts. 42 | 43 | Returns: 44 | bool: Returns json if the task completes successfully within the allowed attempts. 45 | """ 46 | attempts = 0 47 | while attempts < max_attempts: 48 | task = fetch_task(identifier) 49 | status = task.json()["status"] 50 | if status == TaskStatus.COMPLETED: 51 | return task 52 | if status == TaskStatus.FAILED: 53 | raise Exception("Task polling failed") 54 | time.sleep(frequency) 55 | attempts += 1 56 | return None 57 | -------------------------------------------------------------------------------- /configs/default_parameters.yaml: -------------------------------------------------------------------------------- 1 | whisper: 2 | model_size: "large-v2" 3 | file_format: "SRT" 4 | lang: "Automatic Detection" 5 | is_translate: false 6 | beam_size: 5 7 | log_prob_threshold: -1 8 | no_speech_threshold: 0.6 9 | best_of: 5 10 | patience: 1 11 | condition_on_previous_text: true 12 | prompt_reset_on_temperature: 0.5 13 | initial_prompt: null 14 | temperature: 0 15 | compression_ratio_threshold: 2.4 16 | chunk_length: 30 17 | batch_size: 24 18 | length_penalty: 1 19 | repetition_penalty: 1 20 | no_repeat_ngram_size: 0 21 | prefix: null 22 | suppress_blank: true 23 | suppress_tokens: "[-1]" 24 | max_initial_timestamp: 1 25 | word_timestamps: false 26 | prepend_punctuations: "\"'“¿([{-" 27 | append_punctuations: "\"'.。,,!!??::”)]}、" 28 | max_new_tokens: null 29 | hallucination_silence_threshold: null 30 | hotwords: null 31 | language_detection_threshold: 0.5 32 | language_detection_segments: 1 33 | add_timestamp: false 34 | enable_offload: true 35 | 36 | vad: 37 | vad_filter: false 38 | threshold: 0.5 39 | min_speech_duration_ms: 250 40 | max_speech_duration_s: 9999 41 | min_silence_duration_ms: 1000 42 | speech_pad_ms: 2000 43 | 44 | diarization: 45 | is_diarize: false 46 | hf_token: "" 47 | enable_offload: true 48 | 49 | bgm_separation: 50 | is_separate_bgm: false 51 | uvr_model_size: "UVR-MDX-NET-Inst_HQ_4" 52 | segment_size: 256 53 | save_file: false 54 | enable_offload: true 55 | 56 | translation: 57 | deepl: 
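    # DeepL subtitle-translation defaults. modules/translation/deepl_api.py
    # (DeepLAPI.cache_parameters) writes the last-used values back to this block.
    # `is_pro` switches requests from the free endpoint (api-free.deepl.com)
    # to the paid endpoint (api.deepl.com).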
58 | api_key: "" 59 | is_pro: false 60 | source_lang: "Automatic Detection" 61 | target_lang: "English" 62 | nllb: 63 | model_size: "facebook/nllb-200-1.3B" 64 | source_lang: null 65 | target_lang: null 66 | max_length: 200 67 | add_timestamp: false 68 | -------------------------------------------------------------------------------- /docker-compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | whisper-webui: 3 | container_name: whisper-webui 4 | build: . 5 | image: jhj0517/whisper-webui:latest 6 | 7 | volumes: 8 | # You can mount the container's volume paths to directory paths on your local machine. 9 | # Models will be stored in the `./models' directory on your machine. 10 | # Similarly, all output files will be stored in the `./outputs` directory. 11 | - ./models:/Whisper-WebUI/models 12 | - ./outputs:/Whisper-WebUI/outputs 13 | - ./configs:/Whisper-WebUI/configs 14 | 15 | ports: 16 | - "7860:7860" 17 | 18 | stdin_open: true 19 | tty: true 20 | 21 | entrypoint: ["python", "app.py", "--server_port", "7860", "--server_name", "0.0.0.0",] 22 | 23 | # If you're not using nvidia GPU, Update device to match yours. 24 | # See more info at : https://docs.docker.com/compose/compose-file/deploy/#driver 25 | # You can remove the entire `deploy' section if you are using CPU. 26 | deploy: 27 | resources: 28 | reservations: 29 | devices: 30 | - driver: nvidia 31 | count: all 32 | capabilities: [ gpu ] 33 | -------------------------------------------------------------------------------- /models/Diarization/diarization_models_will_be_saved_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/models/Diarization/diarization_models_will_be_saved_here -------------------------------------------------------------------------------- /models/NLLB/nllb_models_will_be_saved_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/models/NLLB/nllb_models_will_be_saved_here -------------------------------------------------------------------------------- /models/UVR/uvr_models_will_be_saved_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/models/UVR/uvr_models_will_be_saved_here -------------------------------------------------------------------------------- /models/Whisper/faster-whisper/faster_whisper_models_will_be_saved_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/models/Whisper/faster-whisper/faster_whisper_models_will_be_saved_here -------------------------------------------------------------------------------- /models/Whisper/insanely-fast-whisper/insanely_fast_whisper_models_will_be_saved_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/models/Whisper/insanely-fast-whisper/insanely_fast_whisper_models_will_be_saved_here -------------------------------------------------------------------------------- /models/Whisper/whisper_models_will_be_saved_here: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/models/Whisper/whisper_models_will_be_saved_here -------------------------------------------------------------------------------- /models/models_will_be_saved_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/models/models_will_be_saved_here -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/modules/__init__.py -------------------------------------------------------------------------------- /modules/diarize/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/modules/diarize/__init__.py -------------------------------------------------------------------------------- /modules/diarize/audio_loader.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/m-bain/whisperX/blob/main/whisperx/audio.py 2 | 3 | import os 4 | import subprocess 5 | from functools import lru_cache 6 | from typing import Optional, Union 7 | from scipy.io.wavfile import write 8 | import tempfile 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | def exact_div(x, y): 14 | assert x % y == 0 15 | return x // y 16 | 17 | # hard-coded audio hyperparameters 18 | SAMPLE_RATE = 16000 19 | N_FFT = 400 20 | HOP_LENGTH = 160 21 | CHUNK_LENGTH = 30 22 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk 23 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input 24 | 25 | N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 26 | FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame 27 | TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token 28 | 29 | 30 | def load_audio(file: Union[str, np.ndarray], sr: int = SAMPLE_RATE) -> np.ndarray: 31 | """ 32 | Open an audio file or process a numpy array containing audio data as mono waveform, resampling as necessary. 33 | 34 | Parameters 35 | ---------- 36 | file: Union[str, np.ndarray] 37 | The audio file to open or a numpy array containing the audio data. 38 | 39 | sr: int 40 | The sample rate to resample the audio if necessary. 41 | 42 | Returns 43 | ------- 44 | A NumPy array containing the audio waveform, in float32 dtype. 
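    Examples
    --------
    A minimal usage sketch ("sample.wav" is a hypothetical path):

        audio = load_audio("sample.wav")   # mono float32 waveform resampled to 16 kHz
        chunk = pad_or_trim(audio)         # pad/trim to 30 s (N_SAMPLES) for the encoder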
45 | """ 46 | if isinstance(file, np.ndarray): 47 | if file.dtype != np.float32: 48 | file = file.astype(np.float32) 49 | if file.ndim > 1: 50 | file = np.mean(file, axis=1) 51 | 52 | temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") 53 | write(temp_file.name, SAMPLE_RATE, (file * 32768).astype(np.int16)) 54 | temp_file_path = temp_file.name 55 | temp_file.close() 56 | else: 57 | temp_file_path = file 58 | 59 | try: 60 | cmd = [ 61 | "ffmpeg", 62 | "-nostdin", 63 | "-threads", 64 | "0", 65 | "-i", 66 | temp_file_path, 67 | "-f", 68 | "s16le", 69 | "-ac", 70 | "1", 71 | "-acodec", 72 | "pcm_s16le", 73 | "-ar", 74 | str(sr), 75 | "-", 76 | ] 77 | out = subprocess.run(cmd, capture_output=True, check=True).stdout 78 | except subprocess.CalledProcessError as e: 79 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e 80 | finally: 81 | if isinstance(file, np.ndarray): 82 | os.remove(temp_file_path) 83 | 84 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 85 | 86 | 87 | 88 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): 89 | """ 90 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder. 91 | """ 92 | if torch.is_tensor(array): 93 | if array.shape[axis] > length: 94 | array = array.index_select( 95 | dim=axis, index=torch.arange(length, device=array.device) 96 | ) 97 | 98 | if array.shape[axis] < length: 99 | pad_widths = [(0, 0)] * array.ndim 100 | pad_widths[axis] = (0, length - array.shape[axis]) 101 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) 102 | else: 103 | if array.shape[axis] > length: 104 | array = array.take(indices=range(length), axis=axis) 105 | 106 | if array.shape[axis] < length: 107 | pad_widths = [(0, 0)] * array.ndim 108 | pad_widths[axis] = (0, length - array.shape[axis]) 109 | array = np.pad(array, pad_widths) 110 | 111 | return array 112 | 113 | 114 | @lru_cache(maxsize=None) 115 | def mel_filters(device, n_mels: int) -> torch.Tensor: 116 | """ 117 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 
118 | Allows decoupling librosa dependency; saved using: 119 | 120 | np.savez_compressed( 121 | "mel_filters.npz", 122 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), 123 | ) 124 | """ 125 | assert n_mels in [80, 128], f"Unsupported n_mels: {n_mels}" 126 | with np.load( 127 | os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") 128 | ) as f: 129 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) 130 | 131 | 132 | def log_mel_spectrogram( 133 | audio: Union[str, np.ndarray, torch.Tensor], 134 | n_mels: int, 135 | padding: int = 0, 136 | device: Optional[Union[str, torch.device]] = None, 137 | ): 138 | """ 139 | Compute the log-Mel spectrogram of 140 | 141 | Parameters 142 | ---------- 143 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*) 144 | The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz 145 | 146 | n_mels: int 147 | The number of Mel-frequency filters, only 80 is supported 148 | 149 | padding: int 150 | Number of zero samples to pad to the right 151 | 152 | device: Optional[Union[str, torch.device]] 153 | If given, the audio tensor is moved to this device before STFT 154 | 155 | Returns 156 | ------- 157 | torch.Tensor, shape = (80, n_frames) 158 | A Tensor that contains the Mel spectrogram 159 | """ 160 | if not torch.is_tensor(audio): 161 | if isinstance(audio, str): 162 | audio = load_audio(audio) 163 | audio = torch.from_numpy(audio) 164 | 165 | if device is not None: 166 | audio = audio.to(device) 167 | if padding > 0: 168 | audio = F.pad(audio, (0, padding)) 169 | window = torch.hann_window(N_FFT).to(audio.device) 170 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) 171 | magnitudes = stft[..., :-1].abs() ** 2 172 | 173 | filters = mel_filters(audio.device, n_mels) 174 | mel_spec = filters @ magnitudes 175 | 176 | log_spec = torch.clamp(mel_spec, min=1e-10).log10() 177 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) 178 | log_spec = (log_spec + 4.0) / 4.0 179 | return log_spec -------------------------------------------------------------------------------- /modules/diarize/diarize_pipeline.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/m-bain/whisperX/blob/main/whisperx/diarize.py 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import os 6 | from pyannote.audio import Pipeline 7 | from typing import Optional, Union 8 | import torch 9 | 10 | from modules.whisper.data_classes import * 11 | from modules.utils.paths import DIARIZATION_MODELS_DIR 12 | from modules.diarize.audio_loader import load_audio, SAMPLE_RATE 13 | 14 | 15 | class DiarizationPipeline: 16 | def __init__( 17 | self, 18 | model_name="pyannote/speaker-diarization-3.1", 19 | cache_dir: str = DIARIZATION_MODELS_DIR, 20 | use_auth_token=None, 21 | device: Optional[Union[str, torch.device]] = "cpu", 22 | ): 23 | if isinstance(device, str): 24 | device = torch.device(device) 25 | self.model = Pipeline.from_pretrained( 26 | model_name, 27 | use_auth_token=use_auth_token, 28 | cache_dir=cache_dir 29 | ).to(device) 30 | 31 | def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speakers=None): 32 | if isinstance(audio, str): 33 | audio = load_audio(audio) 34 | audio_data = { 35 | 'waveform': torch.from_numpy(audio[None, :]), 36 | 'sample_rate': SAMPLE_RATE 37 | } 38 | segments = self.model(audio_data, min_speakers=min_speakers, max_speakers=max_speakers) 39 | diarize_df = 
pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker']) 40 | diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start) 41 | diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end) 42 | return diarize_df 43 | 44 | 45 | def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False): 46 | transcript_segments = transcript_result["segments"] 47 | if transcript_segments and isinstance(transcript_segments[0], Segment): 48 | transcript_segments = [seg.model_dump() for seg in transcript_segments] 49 | for seg in transcript_segments: 50 | # assign speaker to segment (if any) 51 | diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'], 52 | seg['start']) 53 | diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start']) 54 | 55 | intersected = diarize_df[diarize_df["intersection"] > 0] 56 | 57 | speaker = None 58 | if len(intersected) > 0: 59 | # Choosing most strong intersection 60 | speaker = intersected.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0] 61 | elif fill_nearest: 62 | # Otherwise choosing closest 63 | speaker = diarize_df.sort_values(by=["intersection"], ascending=False)["speaker"].values[0] 64 | 65 | if speaker is not None: 66 | seg["speaker"] = speaker 67 | 68 | # assign speaker to words 69 | if 'words' in seg and seg['words'] is not None: 70 | for word in seg['words']: 71 | if 'start' in word: 72 | diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum( 73 | diarize_df['start'], word['start']) 74 | diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'], 75 | word['start']) 76 | 77 | intersected = diarize_df[diarize_df["intersection"] > 0] 78 | 79 | word_speaker = None 80 | if len(intersected) > 0: 81 | # Choosing most strong intersection 82 | word_speaker = \ 83 | intersected.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0] 84 | elif fill_nearest: 85 | # Otherwise choosing closest 86 | word_speaker = diarize_df.sort_values(by=["intersection"], ascending=False)["speaker"].values[0] 87 | 88 | if word_speaker is not None: 89 | word["speaker"] = word_speaker 90 | 91 | return {"segments": transcript_segments} 92 | 93 | 94 | class DiarizationSegment: 95 | def __init__(self, start, end, speaker=None): 96 | self.start = start 97 | self.end = end 98 | self.speaker = speaker 99 | -------------------------------------------------------------------------------- /modules/diarize/diarizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from typing import List, Union, BinaryIO, Optional, Tuple 4 | import numpy as np 5 | import time 6 | import logging 7 | import gc 8 | 9 | from modules.utils.paths import DIARIZATION_MODELS_DIR 10 | from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers 11 | from modules.diarize.audio_loader import load_audio 12 | from modules.whisper.data_classes import * 13 | 14 | 15 | class Diarizer: 16 | def __init__(self, 17 | model_dir: str = DIARIZATION_MODELS_DIR 18 | ): 19 | self.device = self.get_device() 20 | self.available_device = self.get_available_device() 21 | self.compute_type = "float16" 22 | self.model_dir = model_dir 23 | os.makedirs(self.model_dir, exist_ok=True) 24 | self.pipe = None 25 | 26 | def run(self, 27 | audio: Union[str, BinaryIO, np.ndarray], 28 | 
transcribed_result: List[Segment], 29 | use_auth_token: str, 30 | device: Optional[str] = None 31 | ) -> Tuple[List[Segment], float]: 32 | """ 33 | Diarize transcribed result as a post-processing 34 | 35 | Parameters 36 | ---------- 37 | audio: Union[str, BinaryIO, np.ndarray] 38 | Audio input. This can be file path or binary type. 39 | transcribed_result: List[Segment] 40 | transcribed result through whisper. 41 | use_auth_token: str 42 | Huggingface token with READ permission. This is only needed the first time you download the model. 43 | You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model. 44 | device: Optional[str] 45 | Device for diarization. 46 | 47 | Returns 48 | ---------- 49 | segments_result: List[Segment] 50 | list of Segment that includes start, end timestamps and transcribed text 51 | elapsed_time: float 52 | elapsed time for running 53 | """ 54 | start_time = time.time() 55 | 56 | if device is None: 57 | device = self.device 58 | 59 | if device != self.device or self.pipe is None: 60 | self.update_pipe( 61 | device=device, 62 | use_auth_token=use_auth_token 63 | ) 64 | 65 | audio = load_audio(audio) 66 | 67 | diarization_segments = self.pipe(audio) 68 | diarized_result = assign_word_speakers( 69 | diarization_segments, 70 | {"segments": transcribed_result} 71 | ) 72 | 73 | segments_result = [] 74 | for segment in diarized_result["segments"]: 75 | speaker = "None" 76 | if "speaker" in segment: 77 | speaker = segment["speaker"] 78 | diarized_text = speaker + "|" + segment["text"].strip() 79 | segments_result.append(Segment( 80 | start=segment["start"], 81 | end=segment["end"], 82 | text=diarized_text 83 | )) 84 | 85 | elapsed_time = time.time() - start_time 86 | return segments_result, elapsed_time 87 | 88 | def update_pipe(self, 89 | use_auth_token: Optional[str] = None, 90 | device: Optional[str] = None, 91 | ): 92 | """ 93 | Set pipeline for diarization 94 | 95 | Parameters 96 | ---------- 97 | use_auth_token: str 98 | Huggingface token with READ permission. This is only needed the first time you download the model. 99 | You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model. 100 | device: str 101 | Device for diarization. 102 | """ 103 | if device is None: 104 | device = self.get_device() 105 | self.device = device 106 | 107 | os.makedirs(self.model_dir, exist_ok=True) 108 | 109 | if (not os.listdir(self.model_dir) and 110 | not use_auth_token): 111 | print( 112 | "\nFailed to diarize. 
You need huggingface token and agree to their requirements to download the diarization model.\n" 113 | "Go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and follow their instructions to download the model.\n" 114 | ) 115 | return 116 | 117 | logger = logging.getLogger("speechbrain.utils.train_logger") 118 | # Disable redundant torchvision warning message 119 | logger.disabled = True 120 | self.pipe = DiarizationPipeline( 121 | use_auth_token=use_auth_token, 122 | device=device, 123 | cache_dir=self.model_dir 124 | ) 125 | logger.disabled = False 126 | 127 | def offload(self): 128 | """Offload the model and free up the memory""" 129 | if self.pipe is not None: 130 | del self.pipe 131 | self.pipe = None 132 | if self.device == "cuda": 133 | torch.cuda.empty_cache() 134 | torch.cuda.reset_max_memory_allocated() 135 | if self.device == "xpu": 136 | torch.xpu.empty_cache() 137 | torch.xpu.reset_accumulated_memory_stats() 138 | torch.xpu.reset_peak_memory_stats() 139 | gc.collect() 140 | 141 | @staticmethod 142 | def get_device(): 143 | if torch.cuda.is_available(): 144 | return "cuda" 145 | if torch.xpu.is_available(): 146 | return "xpu" 147 | elif torch.backends.mps.is_available(): 148 | return "mps" 149 | else: 150 | return "cpu" 151 | 152 | @staticmethod 153 | def get_available_device(): 154 | devices = ["cpu"] 155 | if torch.cuda.is_available(): 156 | devices.append("cuda") 157 | if torch.xpu.is_available(): 158 | devices.append("xpu") 159 | if torch.backends.mps.is_available(): 160 | devices.append("mps") 161 | return devices -------------------------------------------------------------------------------- /modules/translation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/modules/translation/__init__.py -------------------------------------------------------------------------------- /modules/translation/deepl_api.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import time 3 | import os 4 | from datetime import datetime 5 | import gradio as gr 6 | 7 | from modules.utils.paths import TRANSLATION_OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH 8 | from modules.utils.constants import AUTOMATIC_DETECTION 9 | from modules.utils.subtitle_manager import * 10 | from modules.utils.files_manager import load_yaml, save_yaml 11 | 12 | """ 13 | This is written with reference to the DeepL API documentation. 
14 | If you want to know the information of the DeepL API, see here: https://www.deepl.com/docs-api/documents 15 | """ 16 | 17 | DEEPL_AVAILABLE_TARGET_LANGS = { 18 | 'Bulgarian': 'BG', 19 | 'Czech': 'CS', 20 | 'Danish': 'DA', 21 | 'German': 'DE', 22 | 'Greek': 'EL', 23 | 'English': 'EN', 24 | 'English (British)': 'EN-GB', 25 | 'English (American)': 'EN-US', 26 | 'Spanish': 'ES', 27 | 'Estonian': 'ET', 28 | 'Finnish': 'FI', 29 | 'French': 'FR', 30 | 'Hungarian': 'HU', 31 | 'Indonesian': 'ID', 32 | 'Italian': 'IT', 33 | 'Japanese': 'JA', 34 | 'Korean': 'KO', 35 | 'Lithuanian': 'LT', 36 | 'Latvian': 'LV', 37 | 'Norwegian (Bokmål)': 'NB', 38 | 'Dutch': 'NL', 39 | 'Polish': 'PL', 40 | 'Portuguese': 'PT', 41 | 'Portuguese (Brazilian)': 'PT-BR', 42 | 'Portuguese (all Portuguese varieties excluding Brazilian Portuguese)': 'PT-PT', 43 | 'Romanian': 'RO', 44 | 'Russian': 'RU', 45 | 'Slovak': 'SK', 46 | 'Slovenian': 'SL', 47 | 'Swedish': 'SV', 48 | 'Turkish': 'TR', 49 | 'Ukrainian': 'UK', 50 | 'Chinese (simplified)': 'ZH' 51 | } 52 | 53 | DEEPL_AVAILABLE_SOURCE_LANGS = { 54 | AUTOMATIC_DETECTION: None, 55 | 'Bulgarian': 'BG', 56 | 'Czech': 'CS', 57 | 'Danish': 'DA', 58 | 'German': 'DE', 59 | 'Greek': 'EL', 60 | 'English': 'EN', 61 | 'Spanish': 'ES', 62 | 'Estonian': 'ET', 63 | 'Finnish': 'FI', 64 | 'French': 'FR', 65 | 'Hungarian': 'HU', 66 | 'Indonesian': 'ID', 67 | 'Italian': 'IT', 68 | 'Japanese': 'JA', 69 | 'Korean': 'KO', 70 | 'Lithuanian': 'LT', 71 | 'Latvian': 'LV', 72 | 'Norwegian (Bokmål)': 'NB', 73 | 'Dutch': 'NL', 74 | 'Polish': 'PL', 75 | 'Portuguese (all Portuguese varieties mixed)': 'PT', 76 | 'Romanian': 'RO', 77 | 'Russian': 'RU', 78 | 'Slovak': 'SK', 79 | 'Slovenian': 'SL', 80 | 'Swedish': 'SV', 81 | 'Turkish': 'TR', 82 | 'Ukrainian': 'UK', 83 | 'Chinese': 'ZH' 84 | } 85 | 86 | 87 | class DeepLAPI: 88 | def __init__(self, 89 | output_dir: str = TRANSLATION_OUTPUT_DIR 90 | ): 91 | self.api_interval = 1 92 | self.max_text_batch_size = 50 93 | self.available_target_langs = DEEPL_AVAILABLE_TARGET_LANGS 94 | self.available_source_langs = DEEPL_AVAILABLE_SOURCE_LANGS 95 | self.output_dir = output_dir 96 | 97 | def translate_deepl(self, 98 | auth_key: str, 99 | fileobjs: list, 100 | source_lang: str, 101 | target_lang: str, 102 | is_pro: bool = False, 103 | add_timestamp: bool = True, 104 | progress=gr.Progress()) -> list: 105 | """ 106 | Translate subtitle files using DeepL API 107 | Parameters 108 | ---------- 109 | auth_key: str 110 | API Key for DeepL from gr.Textbox() 111 | fileobjs: list 112 | List of files to transcribe from gr.Files() 113 | source_lang: str 114 | Source language of the file to transcribe from gr.Dropdown() 115 | target_lang: str 116 | Target language of the file to transcribe from gr.Dropdown() 117 | is_pro: str 118 | Boolean value that is about pro user or not from gr.Checkbox(). 119 | add_timestamp: bool 120 | Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename. 121 | progress: gr.Progress 122 | Indicator to show progress directly in gradio. 
123 | 124 | Returns 125 | ---------- 126 | A List of 127 | String to return to gr.Textbox() 128 | Files to return to gr.Files() 129 | """ 130 | if fileobjs and isinstance(fileobjs[0], gr.utils.NamedString): 131 | fileobjs = [fileobj.name for fileobj in fileobjs] 132 | 133 | self.cache_parameters( 134 | api_key=auth_key, 135 | is_pro=is_pro, 136 | source_lang=source_lang, 137 | target_lang=target_lang, 138 | add_timestamp=add_timestamp 139 | ) 140 | 141 | files_info = {} 142 | for file_path in fileobjs: 143 | file_name, file_ext = os.path.splitext(os.path.basename(file_path)) 144 | writer = get_writer(file_ext, self.output_dir) 145 | segments = writer.to_segments(file_path) 146 | 147 | batch_size = self.max_text_batch_size 148 | for batch_start in range(0, len(segments), batch_size): 149 | progress(batch_start / len(segments), desc="Translating..") 150 | sentences_to_translate = [seg.text for seg in segments[batch_start:batch_start+batch_size]] 151 | translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang, 152 | target_lang, is_pro) 153 | for i, translated_text in enumerate(translated_texts): 154 | segments[batch_start + i].text = translated_text["text"] 155 | 156 | subtitle, output_path = generate_file( 157 | output_dir=self.output_dir, 158 | output_file_name=file_name, 159 | output_format=file_ext, 160 | result=segments, 161 | add_timestamp=add_timestamp 162 | ) 163 | 164 | files_info[file_name] = {"subtitle": subtitle, "path": output_path} 165 | 166 | total_result = '' 167 | for file_name, info in files_info.items(): 168 | total_result += '------------------------------------\n' 169 | total_result += f'{file_name}\n\n' 170 | total_result += f'{info["subtitle"]}' 171 | gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}" 172 | 173 | output_file_paths = [item["path"] for key, item in files_info.items()] 174 | return [gr_str, output_file_paths] 175 | 176 | def request_deepl_translate(self, 177 | auth_key: str, 178 | text: list, 179 | source_lang: str, 180 | target_lang: str, 181 | is_pro: bool = False): 182 | """Request API response to DeepL server""" 183 | if source_lang not in list(DEEPL_AVAILABLE_SOURCE_LANGS.keys()): 184 | raise ValueError(f"Source language {source_lang} is not supported." 185 | f"Use one of {list(DEEPL_AVAILABLE_SOURCE_LANGS.keys())}") 186 | if target_lang not in list(DEEPL_AVAILABLE_TARGET_LANGS.keys()): 187 | raise ValueError(f"Target language {target_lang} is not supported." 
188 | f"Use one of {list(DEEPL_AVAILABLE_TARGET_LANGS.keys())}") 189 | 190 | url = 'https://api.deepl.com/v2/translate' if is_pro else 'https://api-free.deepl.com/v2/translate' 191 | headers = { 192 | 'Authorization': f'DeepL-Auth-Key {auth_key}' 193 | } 194 | data = { 195 | 'text': text, 196 | 'source_lang': DEEPL_AVAILABLE_SOURCE_LANGS[source_lang], 197 | 'target_lang': DEEPL_AVAILABLE_TARGET_LANGS[target_lang] 198 | } 199 | response = requests.post(url, headers=headers, data=data).json() 200 | time.sleep(self.api_interval) 201 | return response["translations"] 202 | 203 | @staticmethod 204 | def cache_parameters(api_key: str, 205 | is_pro: bool, 206 | source_lang: str, 207 | target_lang: str, 208 | add_timestamp: bool): 209 | cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH) 210 | cached_params["translation"]["deepl"] = { 211 | "api_key": api_key, 212 | "is_pro": is_pro, 213 | "source_lang": source_lang, 214 | "target_lang": target_lang 215 | } 216 | cached_params["translation"]["add_timestamp"] = add_timestamp 217 | save_yaml(cached_params, DEFAULT_PARAMETERS_CONFIG_PATH) 218 | -------------------------------------------------------------------------------- /modules/translation/nllb_inference.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline 2 | import gradio as gr 3 | import os 4 | 5 | from modules.utils.paths import TRANSLATION_OUTPUT_DIR, NLLB_MODELS_DIR 6 | from modules.translation.translation_base import TranslationBase 7 | 8 | 9 | class NLLBInference(TranslationBase): 10 | def __init__(self, 11 | model_dir: str = NLLB_MODELS_DIR, 12 | output_dir: str = TRANSLATION_OUTPUT_DIR 13 | ): 14 | super().__init__( 15 | model_dir=model_dir, 16 | output_dir=output_dir 17 | ) 18 | self.tokenizer = None 19 | self.available_models = ["facebook/nllb-200-3.3B", "facebook/nllb-200-1.3B", "facebook/nllb-200-distilled-600M"] 20 | self.available_source_langs = list(NLLB_AVAILABLE_LANGS.keys()) 21 | self.available_target_langs = list(NLLB_AVAILABLE_LANGS.keys()) 22 | self.pipeline = None 23 | 24 | def translate(self, 25 | text: str, 26 | max_length: int 27 | ): 28 | result = self.pipeline( 29 | text, 30 | max_length=max_length 31 | ) 32 | return result[0]["translation_text"] 33 | 34 | def update_model(self, 35 | model_size: str, 36 | src_lang: str, 37 | tgt_lang: str, 38 | progress: gr.Progress = gr.Progress() 39 | ): 40 | def validate_language(lang: str) -> str: 41 | if lang in NLLB_AVAILABLE_LANGS: 42 | return NLLB_AVAILABLE_LANGS[lang] 43 | elif lang not in NLLB_AVAILABLE_LANGS.values(): 44 | raise ValueError(f"Language '{lang}' is not supported. 
Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}") 45 | return lang 46 | 47 | src_lang = validate_language(src_lang) 48 | tgt_lang = validate_language(tgt_lang) 49 | 50 | if model_size != self.current_model_size or self.model is None: 51 | print("\nInitializing NLLB Model..\n") 52 | progress(0, desc="Initializing NLLB Model..") 53 | self.current_model_size = model_size 54 | local_files_only = self.is_model_exists(self.current_model_size) 55 | self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size, 56 | cache_dir=self.model_dir, 57 | local_files_only=local_files_only) 58 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size, 59 | cache_dir=os.path.join(self.model_dir, "tokenizers"), 60 | local_files_only=local_files_only) 61 | 62 | self.pipeline = pipeline("translation", 63 | model=self.model, 64 | tokenizer=self.tokenizer, 65 | src_lang=src_lang, 66 | tgt_lang=tgt_lang, 67 | device=self.device) 68 | 69 | def is_model_exists(self, 70 | model_size: str): 71 | """Check if model exists or not (Only facebook model)""" 72 | prefix = "models--facebook--" 73 | _id, model_size_name = model_size.split("/") 74 | model_dir_name = prefix + model_size_name 75 | model_dir_path = os.path.join(self.model_dir, model_dir_name) 76 | if os.path.exists(model_dir_path) and os.listdir(model_dir_path): 77 | return True 78 | for model_dir_name in os.listdir(self.model_dir): 79 | if (model_size in model_dir_name or model_size_name in model_dir_name) and \ 80 | os.listdir(os.path.join(self.model_dir, model_dir_name)): 81 | return True 82 | return False 83 | 84 | 85 | NLLB_AVAILABLE_LANGS = { 86 | "Acehnese (Arabic script)": "ace_Arab", 87 | "Acehnese (Latin script)": "ace_Latn", 88 | "Mesopotamian Arabic": "acm_Arab", 89 | "Ta’izzi-Adeni Arabic": "acq_Arab", 90 | "Tunisian Arabic": "aeb_Arab", 91 | "Afrikaans": "afr_Latn", 92 | "South Levantine Arabic": "ajp_Arab", 93 | "Akan": "aka_Latn", 94 | "Amharic": "amh_Ethi", 95 | "North Levantine Arabic": "apc_Arab", 96 | "Modern Standard Arabic": "arb_Arab", 97 | "Modern Standard Arabic (Romanized)": "arb_Latn", 98 | "Najdi Arabic": "ars_Arab", 99 | "Moroccan Arabic": "ary_Arab", 100 | "Egyptian Arabic": "arz_Arab", 101 | "Assamese": "asm_Beng", 102 | "Asturian": "ast_Latn", 103 | "Awadhi": "awa_Deva", 104 | "Central Aymara": "ayr_Latn", 105 | "South Azerbaijani": "azb_Arab", 106 | "North Azerbaijani": "azj_Latn", 107 | "Bashkir": "bak_Cyrl", 108 | "Bambara": "bam_Latn", 109 | "Balinese": "ban_Latn", 110 | "Belarusian": "bel_Cyrl", 111 | "Bemba": "bem_Latn", 112 | "Bengali": "ben_Beng", 113 | "Bhojpuri": "bho_Deva", 114 | "Banjar (Arabic script)": "bjn_Arab", 115 | "Banjar (Latin script)": "bjn_Latn", 116 | "Standard Tibetan": "bod_Tibt", 117 | "Bosnian": "bos_Latn", 118 | "Buginese": "bug_Latn", 119 | "Bulgarian": "bul_Cyrl", 120 | "Catalan": "cat_Latn", 121 | "Cebuano": "ceb_Latn", 122 | "Czech": "ces_Latn", 123 | "Chokwe": "cjk_Latn", 124 | "Central Kurdish": "ckb_Arab", 125 | "Crimean Tatar": "crh_Latn", 126 | "Welsh": "cym_Latn", 127 | "Danish": "dan_Latn", 128 | "German": "deu_Latn", 129 | "Southwestern Dinka": "dik_Latn", 130 | "Dyula": "dyu_Latn", 131 | "Dzongkha": "dzo_Tibt", 132 | "Greek": "ell_Grek", 133 | "English": "eng_Latn", 134 | "Esperanto": "epo_Latn", 135 | "Estonian": "est_Latn", 136 | "Basque": "eus_Latn", 137 | "Ewe": "ewe_Latn", 138 | "Faroese": "fao_Latn", 139 | "Fijian": "fij_Latn", 140 | "Finnish": "fin_Latn", 141 | "Fon": "fon_Latn", 142 | "French": "fra_Latn", 143 | 
"Friulian": "fur_Latn", 144 | "Nigerian Fulfulde": "fuv_Latn", 145 | "Scottish Gaelic": "gla_Latn", 146 | "Irish": "gle_Latn", 147 | "Galician": "glg_Latn", 148 | "Guarani": "grn_Latn", 149 | "Gujarati": "guj_Gujr", 150 | "Haitian Creole": "hat_Latn", 151 | "Hausa": "hau_Latn", 152 | "Hebrew": "heb_Hebr", 153 | "Hindi": "hin_Deva", 154 | "Chhattisgarhi": "hne_Deva", 155 | "Croatian": "hrv_Latn", 156 | "Hungarian": "hun_Latn", 157 | "Armenian": "hye_Armn", 158 | "Igbo": "ibo_Latn", 159 | "Ilocano": "ilo_Latn", 160 | "Indonesian": "ind_Latn", 161 | "Icelandic": "isl_Latn", 162 | "Italian": "ita_Latn", 163 | "Javanese": "jav_Latn", 164 | "Japanese": "jpn_Jpan", 165 | "Kabyle": "kab_Latn", 166 | "Jingpho": "kac_Latn", 167 | "Kamba": "kam_Latn", 168 | "Kannada": "kan_Knda", 169 | "Kashmiri (Arabic script)": "kas_Arab", 170 | "Kashmiri (Devanagari script)": "kas_Deva", 171 | "Georgian": "kat_Geor", 172 | "Central Kanuri (Arabic script)": "knc_Arab", 173 | "Central Kanuri (Latin script)": "knc_Latn", 174 | "Kazakh": "kaz_Cyrl", 175 | "Kabiyè": "kbp_Latn", 176 | "Kabuverdianu": "kea_Latn", 177 | "Khmer": "khm_Khmr", 178 | "Kikuyu": "kik_Latn", 179 | "Kinyarwanda": "kin_Latn", 180 | "Kyrgyz": "kir_Cyrl", 181 | "Kimbundu": "kmb_Latn", 182 | "Northern Kurdish": "kmr_Latn", 183 | "Kikongo": "kon_Latn", 184 | "Korean": "kor_Hang", 185 | "Lao": "lao_Laoo", 186 | "Ligurian": "lij_Latn", 187 | "Limburgish": "lim_Latn", 188 | "Lingala": "lin_Latn", 189 | "Lithuanian": "lit_Latn", 190 | "Lombard": "lmo_Latn", 191 | "Latgalian": "ltg_Latn", 192 | "Luxembourgish": "ltz_Latn", 193 | "Luba-Kasai": "lua_Latn", 194 | "Ganda": "lug_Latn", 195 | "Luo": "luo_Latn", 196 | "Mizo": "lus_Latn", 197 | "Standard Latvian": "lvs_Latn", 198 | "Magahi": "mag_Deva", 199 | "Maithili": "mai_Deva", 200 | "Malayalam": "mal_Mlym", 201 | "Marathi": "mar_Deva", 202 | "Minangkabau (Arabic script)": "min_Arab", 203 | "Minangkabau (Latin script)": "min_Latn", 204 | "Macedonian": "mkd_Cyrl", 205 | "Plateau Malagasy": "plt_Latn", 206 | "Maltese": "mlt_Latn", 207 | "Meitei (Bengali script)": "mni_Beng", 208 | "Halh Mongolian": "khk_Cyrl", 209 | "Mossi": "mos_Latn", 210 | "Maori": "mri_Latn", 211 | "Burmese": "mya_Mymr", 212 | "Dutch": "nld_Latn", 213 | "Norwegian Nynorsk": "nno_Latn", 214 | "Norwegian Bokmål": "nob_Latn", 215 | "Nepali": "npi_Deva", 216 | "Northern Sotho": "nso_Latn", 217 | "Nuer": "nus_Latn", 218 | "Nyanja": "nya_Latn", 219 | "Occitan": "oci_Latn", 220 | "West Central Oromo": "gaz_Latn", 221 | "Odia": "ory_Orya", 222 | "Pangasinan": "pag_Latn", 223 | "Eastern Panjabi": "pan_Guru", 224 | "Papiamento": "pap_Latn", 225 | "Western Persian": "pes_Arab", 226 | "Polish": "pol_Latn", 227 | "Portuguese": "por_Latn", 228 | "Dari": "prs_Arab", 229 | "Southern Pashto": "pbt_Arab", 230 | "Ayacucho Quechua": "quy_Latn", 231 | "Romanian": "ron_Latn", 232 | "Rundi": "run_Latn", 233 | "Russian": "rus_Cyrl", 234 | "Sango": "sag_Latn", 235 | "Sanskrit": "san_Deva", 236 | "Santali": "sat_Olck", 237 | "Sicilian": "scn_Latn", 238 | "Shan": "shn_Mymr", 239 | "Sinhala": "sin_Sinh", 240 | "Slovak": "slk_Latn", 241 | "Slovenian": "slv_Latn", 242 | "Samoan": "smo_Latn", 243 | "Shona": "sna_Latn", 244 | "Sindhi": "snd_Arab", 245 | "Somali": "som_Latn", 246 | "Southern Sotho": "sot_Latn", 247 | "Spanish": "spa_Latn", 248 | "Tosk Albanian": "als_Latn", 249 | "Sardinian": "srd_Latn", 250 | "Serbian": "srp_Cyrl", 251 | "Swati": "ssw_Latn", 252 | "Sundanese": "sun_Latn", 253 | "Swedish": "swe_Latn", 254 | "Swahili": "swh_Latn", 255 | "Silesian": 
"szl_Latn", 256 | "Tamil": "tam_Taml", 257 | "Tatar": "tat_Cyrl", 258 | "Telugu": "tel_Telu", 259 | "Tajik": "tgk_Cyrl", 260 | "Tagalog": "tgl_Latn", 261 | "Thai": "tha_Thai", 262 | "Tigrinya": "tir_Ethi", 263 | "Tamasheq (Latin script)": "taq_Latn", 264 | "Tamasheq (Tifinagh script)": "taq_Tfng", 265 | "Tok Pisin": "tpi_Latn", 266 | "Tswana": "tsn_Latn", 267 | "Tsonga": "tso_Latn", 268 | "Turkmen": "tuk_Latn", 269 | "Tumbuka": "tum_Latn", 270 | "Turkish": "tur_Latn", 271 | "Twi": "twi_Latn", 272 | "Central Atlas Tamazight": "tzm_Tfng", 273 | "Uyghur": "uig_Arab", 274 | "Ukrainian": "ukr_Cyrl", 275 | "Umbundu": "umb_Latn", 276 | "Urdu": "urd_Arab", 277 | "Northern Uzbek": "uzn_Latn", 278 | "Venetian": "vec_Latn", 279 | "Vietnamese": "vie_Latn", 280 | "Waray": "war_Latn", 281 | "Wolof": "wol_Latn", 282 | "Xhosa": "xho_Latn", 283 | "Eastern Yiddish": "ydd_Hebr", 284 | "Yoruba": "yor_Latn", 285 | "Yue Chinese": "yue_Hant", 286 | "Chinese (Simplified)": "zho_Hans", 287 | "Chinese (Traditional)": "zho_Hant", 288 | "Standard Malay": "zsm_Latn", 289 | "Zulu": "zul_Latn", 290 | } 291 | -------------------------------------------------------------------------------- /modules/translation/translation_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import gradio as gr 4 | from abc import ABC, abstractmethod 5 | import gc 6 | from typing import List 7 | from datetime import datetime 8 | 9 | import modules.translation.nllb_inference as nllb 10 | from modules.whisper.data_classes import * 11 | from modules.utils.subtitle_manager import * 12 | from modules.utils.files_manager import load_yaml, save_yaml 13 | from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR 14 | 15 | 16 | class TranslationBase(ABC): 17 | def __init__(self, 18 | model_dir: str = NLLB_MODELS_DIR, 19 | output_dir: str = TRANSLATION_OUTPUT_DIR 20 | ): 21 | super().__init__() 22 | self.model = None 23 | self.model_dir = model_dir 24 | self.output_dir = output_dir 25 | os.makedirs(self.model_dir, exist_ok=True) 26 | os.makedirs(self.output_dir, exist_ok=True) 27 | self.current_model_size = None 28 | self.device = self.get_device() 29 | 30 | @abstractmethod 31 | def translate(self, 32 | text: str, 33 | max_length: int 34 | ): 35 | pass 36 | 37 | @abstractmethod 38 | def update_model(self, 39 | model_size: str, 40 | src_lang: str, 41 | tgt_lang: str, 42 | progress: gr.Progress = gr.Progress() 43 | ): 44 | pass 45 | 46 | def translate_file(self, 47 | fileobjs: list, 48 | model_size: str, 49 | src_lang: str, 50 | tgt_lang: str, 51 | max_length: int = 200, 52 | add_timestamp: bool = True, 53 | progress=gr.Progress()) -> list: 54 | """ 55 | Translate subtitle file from source language to target language 56 | 57 | Parameters 58 | ---------- 59 | fileobjs: list 60 | List of files to transcribe from gr.Files() 61 | model_size: str 62 | Whisper model size from gr.Dropdown() 63 | src_lang: str 64 | Source language of the file to translate from gr.Dropdown() 65 | tgt_lang: str 66 | Target language of the file to translate from gr.Dropdown() 67 | max_length: int 68 | Max length per line to translate 69 | add_timestamp: bool 70 | Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename. 71 | progress: gr.Progress 72 | Indicator to show progress directly in gradio. 73 | I use a forked version of whisper for this. 
To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback 74 | 75 | Returns 76 | ---------- 77 | A List of 78 | String to return to gr.Textbox() 79 | Files to return to gr.Files() 80 | """ 81 | try: 82 | if fileobjs and isinstance(fileobjs[0], gr.utils.NamedString): 83 | fileobjs = [file.name for file in fileobjs] 84 | 85 | self.cache_parameters(model_size=model_size, 86 | src_lang=src_lang, 87 | tgt_lang=tgt_lang, 88 | max_length=max_length, 89 | add_timestamp=add_timestamp) 90 | 91 | self.update_model(model_size=model_size, 92 | src_lang=src_lang, 93 | tgt_lang=tgt_lang, 94 | progress=progress) 95 | 96 | files_info = {} 97 | for fileobj in fileobjs: 98 | file_name, file_ext = os.path.splitext(os.path.basename(fileobj)) 99 | writer = get_writer(file_ext, self.output_dir) 100 | segments = writer.to_segments(fileobj) 101 | for i, segment in enumerate(segments): 102 | progress(i / len(segments), desc="Translating..") 103 | translated_text = self.translate(segment.text, max_length=max_length) 104 | segment.text = translated_text 105 | 106 | subtitle, file_path = generate_file( 107 | output_dir=self.output_dir, 108 | output_file_name=file_name, 109 | output_format=file_ext, 110 | result=segments, 111 | add_timestamp=add_timestamp 112 | ) 113 | 114 | files_info[file_name] = {"subtitle": subtitle, "path": file_path} 115 | 116 | total_result = '' 117 | for file_name, info in files_info.items(): 118 | total_result += '------------------------------------\n' 119 | total_result += f'{file_name}\n\n' 120 | total_result += f'{info["subtitle"]}' 121 | gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}" 122 | 123 | output_file_paths = [item["path"] for key, item in files_info.items()] 124 | return [gr_str, output_file_paths] 125 | 126 | except Exception as e: 127 | print(f"Error translating file: {e}") 128 | raise 129 | finally: 130 | self.offload() 131 | 132 | @staticmethod 133 | def get_device(): 134 | if torch.cuda.is_available(): 135 | return "cuda" 136 | if torch.xpu.is_available(): 137 | return "xpu" 138 | elif torch.backends.mps.is_available(): 139 | return "mps" 140 | else: 141 | return "cpu" 142 | 143 | def offload(self): 144 | """Offload the model and free up the memory""" 145 | if self.model is not None: 146 | del self.model 147 | self.model = None 148 | if self.device == "cuda": 149 | torch.cuda.empty_cache() 150 | torch.cuda.reset_max_memory_allocated() 151 | if self.device == "xpu": 152 | torch.xpu.empty_cache() 153 | torch.xpu.reset_accumulated_memory_stats() 154 | torch.xpu.reset_peak_memory_stats() 155 | gc.collect() 156 | 157 | @staticmethod 158 | def remove_input_files(file_paths: List[str]): 159 | if not file_paths: 160 | return 161 | 162 | for file_path in file_paths: 163 | if file_path and os.path.exists(file_path): 164 | os.remove(file_path) 165 | 166 | @staticmethod 167 | def cache_parameters(model_size: str, 168 | src_lang: str, 169 | tgt_lang: str, 170 | max_length: int, 171 | add_timestamp: bool): 172 | def validate_lang(lang: str): 173 | if lang in list(nllb.NLLB_AVAILABLE_LANGS.values()): 174 | flipped = {value: key for key, value in nllb.NLLB_AVAILABLE_LANGS.items()} 175 | return flipped[lang] 176 | return lang 177 | 178 | cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH) 179 | cached_params["translation"]["nllb"] = { 180 | "model_size": model_size, 181 | "source_lang": validate_lang(src_lang), 182 | "target_lang": validate_lang(tgt_lang), 183 | "max_length": max_length, 184 | } 185 | 
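        # For reference, a rough sketch of the "translation" section that ends up in
        # configs/default_parameters.yaml after this method runs. The values below are
        # illustrative examples only, not repository defaults:
        #
        #   translation:
        #     nllb:
        #       model_size: facebook/nllb-200-distilled-600M
        #       source_lang: English
        #       target_lang: Korean
        #       max_length: 200
        #     add_timestamp: true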
cached_params["translation"]["add_timestamp"] = add_timestamp 186 | save_yaml(cached_params, DEFAULT_PARAMETERS_CONFIG_PATH) 187 | -------------------------------------------------------------------------------- /modules/ui/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/modules/ui/__init__.py -------------------------------------------------------------------------------- /modules/ui/htmls.py: -------------------------------------------------------------------------------- 1 | CSS = """ 2 | .bmc-button { 3 | padding: 2px 5px; 4 | border-radius: 5px; 5 | background-color: #FF813F; 6 | color: white; 7 | box-shadow: 0px 1px 2px rgba(0, 0, 0, 0.3); 8 | text-decoration: none; 9 | display: inline-block; 10 | font-size: 20px; 11 | margin: 2px; 12 | cursor: pointer; 13 | -webkit-transition: background-color 0.3s ease; 14 | -ms-transition: background-color 0.3s ease; 15 | transition: background-color 0.3s ease; 16 | } 17 | .bmc-button:hover, 18 | .bmc-button:active, 19 | .bmc-button:focus { 20 | background-color: #FF5633; 21 | } 22 | .markdown { 23 | margin-bottom: 0; 24 | padding-bottom: 0; 25 | } 26 | .tabs { 27 | margin-top: 0; 28 | padding-top: 0; 29 | } 30 | 31 | #md_project a { 32 | color: black; 33 | text-decoration: none; 34 | } 35 | #md_project a:hover { 36 | text-decoration: underline; 37 | } 38 | """ 39 | 40 | MARKDOWN = """ 41 | ### [Whisper-WebUI](https://github.com/jhj0517/Whsiper-WebUI) 42 | """ 43 | 44 | 45 | NLLB_VRAM_TABLE = """ 46 | 47 | 48 | 49 | 50 | 51 | 65 | 66 | 67 | 68 |
69 |     <details>
70 |         <summary>VRAM usage for each model</summary>
71 |         <table>
72 |             <tr>
73 |                 <th>Model name</th>
74 |                 <th>Required VRAM</th>
75 |             </tr>
76 |             <tr>
77 |                 <td>nllb-200-3.3B</td>
78 |                 <td>~16GB</td>
79 |             </tr>
80 |             <tr>
81 |                 <td>nllb-200-1.3B</td>
82 |                 <td>~8GB</td>
83 |             </tr>
84 |             <tr>
85 |                 <td>nllb-200-distilled-600M</td>
86 |                 <td>~4GB</td>
87 |             </tr>
88 |         </table>
89 |         <p>
90 |             Note: Be mindful of your VRAM! The table above provides an approximate VRAM usage for each model.
91 |         </p>
92 |     </details>
93 | 
94 | 95 | 96 | 97 | """ -------------------------------------------------------------------------------- /modules/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/modules/utils/__init__.py -------------------------------------------------------------------------------- /modules/utils/audio_manager.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | import soundfile as sf 3 | import os 4 | import numpy as np 5 | from faster_whisper.audio import decode_audio 6 | 7 | from modules.utils.files_manager import is_video 8 | from modules.utils.logger import get_logger 9 | 10 | logger = get_logger() 11 | 12 | 13 | def validate_audio(audio: Optional[str] = None): 14 | """Validate audio file and check if it's corrupted""" 15 | if isinstance(audio, np.ndarray): 16 | return True 17 | 18 | if not os.path.exists(audio): 19 | logger.info(f"The file {audio} does not exist. Please check the path.") 20 | return False 21 | 22 | try: 23 | audio = decode_audio(audio) 24 | return True 25 | except Exception as e: 26 | logger.info(f"The file {audio} is not able to open or corrupted. Please check the file. {e}") 27 | return False 28 | -------------------------------------------------------------------------------- /modules/utils/cli_manager.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def str2bool(v): 5 | if isinstance(v, bool): 6 | return v 7 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 8 | return True 9 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 10 | return False 11 | else: 12 | raise argparse.ArgumentTypeError('Boolean value expected.') -------------------------------------------------------------------------------- /modules/utils/constants.py: -------------------------------------------------------------------------------- 1 | from gradio_i18n import Translate, gettext as _ 2 | 3 | AUTOMATIC_DETECTION = _("Automatic Detection") 4 | GRADIO_NONE_STR = "" 5 | GRADIO_NONE_NUMBER_MAX = 9999 6 | GRADIO_NONE_NUMBER_MIN = 0 7 | -------------------------------------------------------------------------------- /modules/utils/files_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import fnmatch 3 | from ruamel.yaml import YAML 4 | from gradio.utils import NamedString 5 | 6 | from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH 7 | 8 | AUDIO_EXTENSION = ['.mp3', '.wav', '.wma', '.aac', '.flac', '.ogg', '.m4a', '.aiff', '.alac', '.opus', '.webm', '.ac3', 9 | '.amr', '.au', '.mid', '.midi', '.mka'] 10 | 11 | VIDEO_EXTENSION = ['.mp4', '.mkv', '.flv', '.avi', '.mov', '.wmv', '.webm', '.m4v', '.mpeg', '.mpg', '.3gp', 12 | '.f4v', '.ogv', '.vob', '.mts', '.m2ts', '.divx', '.mxf', '.rm', '.rmvb', '.ts'] 13 | 14 | MEDIA_EXTENSION = VIDEO_EXTENSION + AUDIO_EXTENSION 15 | 16 | 17 | def load_yaml(path: str = DEFAULT_PARAMETERS_CONFIG_PATH): 18 | yaml = YAML(typ="safe") 19 | yaml.preserve_quotes = True 20 | with open(path, 'r', encoding='utf-8') as file: 21 | config = yaml.load(file) 22 | return config 23 | 24 | 25 | def save_yaml(data: dict, path: str = DEFAULT_PARAMETERS_CONFIG_PATH): 26 | yaml = YAML(typ="safe") 27 | yaml.map_indent = 2 28 | yaml.sequence_indent = 4 29 | yaml.sequence_dash_offset = 2 30 | yaml.preserve_quotes = True 31 | 
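    # Emit block-style YAML and keep the original key order so that the cached
    # parameters file stays human-readable and diff-friendly.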
yaml.default_flow_style = False 32 | yaml.sort_base_mapping_type_on_output = False 33 | 34 | with open(path, 'w', encoding='utf-8') as file: 35 | yaml.dump(data, file) 36 | return path 37 | 38 | 39 | def get_media_files(folder_path, include_sub_directory=False): 40 | media_extensions = ['*' + extension for extension in MEDIA_EXTENSION] 41 | 42 | media_files = [] 43 | 44 | if include_sub_directory: 45 | for root, _, files in os.walk(folder_path): 46 | for extension in media_extensions: 47 | media_files.extend( 48 | os.path.join(root, file) for file in fnmatch.filter(files, extension) 49 | if os.path.exists(os.path.join(root, file)) 50 | ) 51 | else: 52 | for extension in media_extensions: 53 | media_files.extend( 54 | os.path.join(folder_path, file) for file in fnmatch.filter(os.listdir(folder_path), extension) 55 | if os.path.isfile(os.path.join(folder_path, file)) and os.path.exists(os.path.join(folder_path, file)) 56 | ) 57 | 58 | return media_files 59 | 60 | 61 | def format_gradio_files(files: list): 62 | if not files: 63 | return files 64 | 65 | gradio_files = [] 66 | for file in files: 67 | gradio_files.append(NamedString(file)) 68 | return gradio_files 69 | 70 | 71 | def is_video(file_path): 72 | extension = os.path.splitext(file_path)[1].lower() 73 | return extension in VIDEO_EXTENSION 74 | 75 | 76 | def read_file(file_path): 77 | with open(file_path, "r", encoding="utf-8") as f: 78 | subtitle_content = f.read() 79 | return subtitle_content 80 | -------------------------------------------------------------------------------- /modules/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Optional 3 | 4 | 5 | def get_logger(name: Optional[str] = None): 6 | if name is None: 7 | name = "Whisper-WebUI" 8 | logger = logging.getLogger(name) 9 | 10 | if not logger.handlers: 11 | logger.setLevel(logging.INFO) 12 | 13 | formatter = logging.Formatter( 14 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 15 | ) 16 | 17 | handler = logging.StreamHandler() 18 | # handler.setFormatter(formatter) 19 | 20 | logger.addHandler(handler) 21 | 22 | return logger -------------------------------------------------------------------------------- /modules/utils/paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | WEBUI_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) 4 | MODELS_DIR = os.path.join(WEBUI_DIR, "models") 5 | WHISPER_MODELS_DIR = os.path.join(MODELS_DIR, "Whisper") 6 | FASTER_WHISPER_MODELS_DIR = os.path.join(WHISPER_MODELS_DIR, "faster-whisper") 7 | INSANELY_FAST_WHISPER_MODELS_DIR = os.path.join(WHISPER_MODELS_DIR, "insanely-fast-whisper") 8 | NLLB_MODELS_DIR = os.path.join(MODELS_DIR, "NLLB") 9 | DIARIZATION_MODELS_DIR = os.path.join(MODELS_DIR, "Diarization") 10 | UVR_MODELS_DIR = os.path.join(MODELS_DIR, "UVR", "MDX_Net_Models") 11 | CONFIGS_DIR = os.path.join(WEBUI_DIR, "configs") 12 | DEFAULT_PARAMETERS_CONFIG_PATH = os.path.join(CONFIGS_DIR, "default_parameters.yaml") 13 | I18N_YAML_PATH = os.path.join(CONFIGS_DIR, "translation.yaml") 14 | OUTPUT_DIR = os.path.join(WEBUI_DIR, "outputs") 15 | TRANSLATION_OUTPUT_DIR = os.path.join(OUTPUT_DIR, "translations") 16 | UVR_OUTPUT_DIR = os.path.join(OUTPUT_DIR, "UVR") 17 | UVR_INSTRUMENTAL_OUTPUT_DIR = os.path.join(UVR_OUTPUT_DIR, "instrumental") 18 | UVR_VOCALS_OUTPUT_DIR = os.path.join(UVR_OUTPUT_DIR, "vocals") 19 | BACKEND_DIR_PATH = os.path.join(WEBUI_DIR, "backend") 
20 | SERVER_CONFIG_PATH = os.path.join(BACKEND_DIR_PATH, "configs", "config.yaml") 21 | SERVER_DOTENV_PATH = os.path.join(BACKEND_DIR_PATH, "configs", ".env") 22 | BACKEND_CACHE_DIR = os.path.join(BACKEND_DIR_PATH, "cache") 23 | 24 | for dir_path in [MODELS_DIR, 25 | WHISPER_MODELS_DIR, 26 | FASTER_WHISPER_MODELS_DIR, 27 | INSANELY_FAST_WHISPER_MODELS_DIR, 28 | NLLB_MODELS_DIR, 29 | DIARIZATION_MODELS_DIR, 30 | UVR_MODELS_DIR, 31 | CONFIGS_DIR, 32 | OUTPUT_DIR, 33 | TRANSLATION_OUTPUT_DIR, 34 | UVR_INSTRUMENTAL_OUTPUT_DIR, 35 | UVR_VOCALS_OUTPUT_DIR, 36 | BACKEND_CACHE_DIR]: 37 | os.makedirs(dir_path, exist_ok=True) 38 | -------------------------------------------------------------------------------- /modules/utils/youtube_manager.py: -------------------------------------------------------------------------------- 1 | from pytubefix import YouTube 2 | import subprocess 3 | import os 4 | 5 | 6 | def get_ytdata(link): 7 | return YouTube(link) 8 | 9 | 10 | def get_ytmetas(link): 11 | yt = YouTube(link) 12 | return yt.thumbnail_url, yt.title, yt.description 13 | 14 | 15 | def get_ytaudio(ytdata: YouTube): 16 | # Somehow the audio is corrupted so need to convert to valid audio file. 17 | # Fix for : https://github.com/jhj0517/Whisper-WebUI/issues/304 18 | 19 | audio_path = ytdata.streams.get_audio_only().download(filename=os.path.join("modules", "yt_tmp.wav")) 20 | temp_audio_path = os.path.join("modules", "yt_tmp_fixed.wav") 21 | 22 | try: 23 | subprocess.run([ 24 | 'ffmpeg', '-y', 25 | '-i', audio_path, 26 | temp_audio_path 27 | ], check=True) 28 | 29 | os.replace(temp_audio_path, audio_path) 30 | return audio_path 31 | except subprocess.CalledProcessError as e: 32 | print(f"Error during ffmpeg conversion: {e}") 33 | return None 34 | -------------------------------------------------------------------------------- /modules/uvr/music_separator.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, List, Dict 2 | import numpy as np 3 | import torchaudio 4 | import soundfile as sf 5 | import os 6 | import torch 7 | import gc 8 | import gradio as gr 9 | from datetime import datetime 10 | import traceback 11 | 12 | from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, UVR_MODELS_DIR, UVR_OUTPUT_DIR 13 | from modules.utils.files_manager import load_yaml, save_yaml, is_video 14 | from modules.diarize.audio_loader import load_audio 15 | from modules.utils.logger import get_logger 16 | logger = get_logger() 17 | 18 | try: 19 | from uvr.models import MDX, Demucs, VrNetwork, MDXC 20 | except Exception as e: 21 | logger.warning( 22 | "Failed to import uvr. BGM separation feature will not work. " 23 | "Please open an issue on GitHub if you encounter this error. 
" 24 | f"Error: {type(e).__name__}: {traceback.format_exc()}" 25 | ) 26 | 27 | 28 | class MusicSeparator: 29 | def __init__(self, 30 | model_dir: Optional[str] = UVR_MODELS_DIR, 31 | output_dir: Optional[str] = UVR_OUTPUT_DIR): 32 | self.model = None 33 | self.device = self.get_device() 34 | self.available_devices = ["cpu", "cuda", "xpu", "mps"] 35 | self.model_dir = model_dir 36 | self.output_dir = output_dir 37 | instrumental_output_dir = os.path.join(self.output_dir, "instrumental") 38 | vocals_output_dir = os.path.join(self.output_dir, "vocals") 39 | os.makedirs(instrumental_output_dir, exist_ok=True) 40 | os.makedirs(vocals_output_dir, exist_ok=True) 41 | self.audio_info = None 42 | self.available_models = ["UVR-MDX-NET-Inst_HQ_4", "UVR-MDX-NET-Inst_3"] 43 | self.default_model = self.available_models[0] 44 | self.current_model_size = self.default_model 45 | self.model_config = { 46 | "segment": 256, 47 | "split": True 48 | } 49 | 50 | def update_model(self, 51 | model_name: str = "UVR-MDX-NET-Inst_1", 52 | device: Optional[str] = None, 53 | segment_size: int = 256): 54 | """ 55 | Update model with the given model name 56 | 57 | Args: 58 | model_name (str): Model name. 59 | device (str): Device to use for the model. 60 | segment_size (int): Segment size for the prediction. 61 | """ 62 | if device is None: 63 | device = self.device 64 | 65 | self.device = device 66 | self.model_config = { 67 | "segment": segment_size, 68 | "split": True 69 | } 70 | self.model = MDX(name=model_name, 71 | other_metadata=self.model_config, 72 | device=self.device, 73 | logger=None, 74 | model_dir=self.model_dir) 75 | 76 | def separate(self, 77 | audio: Union[str, np.ndarray], 78 | model_name: str, 79 | device: Optional[str] = None, 80 | segment_size: int = 256, 81 | save_file: bool = False, 82 | progress: gr.Progress = gr.Progress()) -> tuple[np.ndarray, np.ndarray, List]: 83 | """ 84 | Separate the background music from the audio. 85 | 86 | Args: 87 | audio (Union[str, np.ndarray]): Audio path or numpy array. 88 | model_name (str): Model name. 89 | device (str): Device to use for the model. 90 | segment_size (int): Segment size for the prediction. 91 | save_file (bool): Whether to save the separated audio to output path or not. 92 | progress (gr.Progress): Gradio progress indicator. 93 | 94 | Returns: 95 | A Tuple of 96 | np.ndarray: Instrumental numpy arrays. 97 | np.ndarray: Vocals numpy arrays. 98 | file_paths: List of file paths where the separated audio is saved. Return empty when save_file is False. 
99 | """ 100 | if isinstance(audio, str): 101 | output_filename, ext = os.path.basename(audio), ".wav" 102 | output_filename, orig_ext = os.path.splitext(output_filename) 103 | 104 | if is_video(audio): 105 | audio = load_audio(audio) 106 | sample_rate = 16000 107 | else: 108 | self.audio_info = torchaudio.info(audio) 109 | sample_rate = self.audio_info.sample_rate 110 | else: 111 | timestamp = datetime.now().strftime("%m%d%H%M%S") 112 | output_filename, ext = f"UVR-{timestamp}", ".wav" 113 | sample_rate = 16000 114 | 115 | model_config = { 116 | "segment": segment_size, 117 | "split": True 118 | } 119 | 120 | if (self.model is None or 121 | self.current_model_size != model_name or 122 | self.model_config != model_config or 123 | self.model.sample_rate != sample_rate or 124 | self.device != device): 125 | progress(0, desc="Initializing UVR Model..") 126 | self.update_model( 127 | model_name=model_name, 128 | device=device, 129 | segment_size=segment_size 130 | ) 131 | self.model.sample_rate = sample_rate 132 | 133 | progress(0, desc="Separating background music from the audio.. " 134 | "(It will only display 0% until the job is complete.) ") 135 | result = self.model(audio) 136 | instrumental, vocals = result["instrumental"].T, result["vocals"].T 137 | 138 | file_paths = [] 139 | if save_file: 140 | instrumental_output_path = os.path.join(self.output_dir, "instrumental", f"{output_filename}-instrumental{ext}") 141 | vocals_output_path = os.path.join(self.output_dir, "vocals", f"{output_filename}-vocals{ext}") 142 | sf.write(instrumental_output_path, instrumental, sample_rate, format="WAV") 143 | sf.write(vocals_output_path, vocals, sample_rate, format="WAV") 144 | file_paths += [instrumental_output_path, vocals_output_path] 145 | 146 | return instrumental, vocals, file_paths 147 | 148 | def separate_files(self, 149 | files: List, 150 | model_name: str, 151 | device: Optional[str] = None, 152 | segment_size: int = 256, 153 | save_file: bool = True, 154 | progress: gr.Progress = gr.Progress()) -> List[str]: 155 | """Separate the background music from the audio files. 
Returns only last Instrumental and vocals file paths 156 | to display into gr.Audio()""" 157 | self.cache_parameters(model_size=model_name, segment_size=segment_size) 158 | 159 | for file_path in files: 160 | instrumental, vocals, file_paths = self.separate( 161 | audio=file_path, 162 | model_name=model_name, 163 | device=device, 164 | segment_size=segment_size, 165 | save_file=save_file, 166 | progress=progress 167 | ) 168 | return file_paths 169 | 170 | @staticmethod 171 | def get_device(): 172 | if torch.cuda.is_available(): 173 | return "cuda" 174 | if torch.xpu.is_available(): 175 | return "xpu" 176 | elif torch.backends.mps.is_available(): 177 | return "mps" 178 | else: 179 | return "cpu" 180 | 181 | def offload(self): 182 | """Offload the model and free up the memory""" 183 | if self.model is not None: 184 | del self.model 185 | self.model = None 186 | if self.device == "cuda": 187 | torch.cuda.empty_cache() 188 | torch.cuda.reset_max_memory_allocated() 189 | if self.device == "xpu": 190 | torch.xpu.empty_cache() 191 | torch.xpu.reset_accumulated_memory_stats() 192 | torch.xpu.reset_peak_memory_stats() 193 | gc.collect() 194 | self.audio_info = None 195 | 196 | @staticmethod 197 | def cache_parameters(model_size: str, 198 | segment_size: int): 199 | cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH) 200 | cached_uvr_params = cached_params["bgm_separation"] 201 | uvr_params_to_cache = { 202 | "model_size": model_size, 203 | "segment_size": segment_size 204 | } 205 | cached_uvr_params = {**cached_uvr_params, **uvr_params_to_cache} 206 | cached_params["bgm_separation"] = cached_uvr_params 207 | save_yaml(cached_params, DEFAULT_PARAMETERS_CONFIG_PATH) 208 | -------------------------------------------------------------------------------- /modules/vad/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/modules/vad/__init__.py -------------------------------------------------------------------------------- /modules/vad/silero_vad.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py 2 | 3 | from faster_whisper.vad import VadOptions, get_vad_model 4 | import numpy as np 5 | from typing import BinaryIO, Union, List, Optional, Tuple 6 | import warnings 7 | import bisect 8 | import faster_whisper 9 | from faster_whisper.transcribe import SpeechTimestampsMap 10 | import gradio as gr 11 | 12 | from modules.whisper.data_classes import * 13 | 14 | 15 | class SileroVAD: 16 | def __init__(self): 17 | self.sampling_rate = 16000 18 | self.window_size_samples = 512 19 | self.model = None 20 | 21 | def run(self, 22 | audio: Union[str, BinaryIO, np.ndarray], 23 | vad_parameters: VadOptions, 24 | progress: gr.Progress = gr.Progress() 25 | ) -> Tuple[np.ndarray, List[dict]]: 26 | """ 27 | Run VAD 28 | 29 | Parameters 30 | ---------- 31 | audio: Union[str, BinaryIO, np.ndarray] 32 | Audio path or file binary or Audio numpy array 33 | vad_parameters: 34 | Options for VAD processing. 35 | progress: gr.Progress 36 | Indicator to show progress directly in gradio. 
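        A minimal usage sketch (the audio path is a placeholder):

            vad = SileroVAD()
            processed_audio, speech_chunks = vad.run("speech.wav", VadOptions(threshold=0.5))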
37 | 38 | Returns 39 | ---------- 40 | np.ndarray 41 | Pre-processed audio with VAD 42 | List[dict] 43 | Chunks of speeches to be used to restore the timestamps later 44 | """ 45 | 46 | sampling_rate = self.sampling_rate 47 | 48 | if not isinstance(audio, np.ndarray): 49 | audio = faster_whisper.decode_audio(audio, sampling_rate=sampling_rate) 50 | 51 | duration = audio.shape[0] / sampling_rate 52 | duration_after_vad = duration 53 | 54 | if vad_parameters is None: 55 | vad_parameters = VadOptions() 56 | elif isinstance(vad_parameters, dict): 57 | vad_parameters = VadOptions(**vad_parameters) 58 | speech_chunks = self.get_speech_timestamps( 59 | audio=audio, 60 | vad_options=vad_parameters, 61 | progress=progress 62 | ) 63 | 64 | audio = self.collect_chunks(audio, speech_chunks) 65 | duration_after_vad = audio.shape[0] / sampling_rate 66 | 67 | return audio, speech_chunks 68 | 69 | def get_speech_timestamps( 70 | self, 71 | audio: np.ndarray, 72 | vad_options: Optional[VadOptions] = None, 73 | progress: gr.Progress = gr.Progress(), 74 | **kwargs, 75 | ) -> List[dict]: 76 | """This method is used for splitting long audios into speech chunks using silero VAD. 77 | 78 | Args: 79 | audio: One dimensional float array. 80 | vad_options: Options for VAD processing. 81 | kwargs: VAD options passed as keyword arguments for backward compatibility. 82 | progress: Gradio progress to indicate progress. 83 | 84 | Returns: 85 | List of dicts containing begin and end samples of each speech chunk. 86 | """ 87 | 88 | if self.model is None: 89 | self.update_model() 90 | 91 | if vad_options is None: 92 | vad_options = VadOptions(**kwargs) 93 | 94 | threshold = vad_options.threshold 95 | neg_threshold = vad_options.neg_threshold 96 | min_speech_duration_ms = vad_options.min_speech_duration_ms 97 | max_speech_duration_s = vad_options.max_speech_duration_s 98 | min_silence_duration_ms = vad_options.min_silence_duration_ms 99 | window_size_samples = self.window_size_samples 100 | speech_pad_ms = vad_options.speech_pad_ms 101 | min_speech_samples = self.sampling_rate * min_speech_duration_ms / 1000 102 | speech_pad_samples = self.sampling_rate * speech_pad_ms / 1000 103 | max_speech_samples = ( 104 | self.sampling_rate * max_speech_duration_s 105 | - window_size_samples 106 | - 2 * speech_pad_samples 107 | ) 108 | min_silence_samples = self.sampling_rate * min_silence_duration_ms / 1000 109 | min_silence_samples_at_max_speech = self.sampling_rate * 98 / 1000 110 | 111 | audio_length_samples = len(audio) 112 | 113 | padded_audio = np.pad( 114 | audio, (0, window_size_samples - audio.shape[0] % window_size_samples) 115 | ) 116 | speech_probs = self.model(padded_audio.reshape(1, -1)).squeeze(0) 117 | 118 | triggered = False 119 | speeches = [] 120 | current_speech = {} 121 | if neg_threshold is None: 122 | neg_threshold = max(threshold - 0.15, 0.01) 123 | 124 | # to save potential segment end (and tolerate some silence) 125 | temp_end = 0 126 | # to save potential segment limits in case of maximum segment size reached 127 | prev_end = next_start = 0 128 | 129 | for i, speech_prob in enumerate(speech_probs): 130 | if (speech_prob >= threshold) and temp_end: 131 | temp_end = 0 132 | if next_start < prev_end: 133 | next_start = window_size_samples * i 134 | 135 | if (speech_prob >= threshold) and not triggered: 136 | triggered = True 137 | current_speech["start"] = window_size_samples * i 138 | continue 139 | 140 | if ( 141 | triggered 142 | and (window_size_samples * i) - current_speech["start"] > max_speech_samples 
143 | ): 144 | if prev_end: 145 | current_speech["end"] = prev_end 146 | speeches.append(current_speech) 147 | current_speech = {} 148 | # previously reached silence (< neg_thres) and is still not speech (< thres) 149 | if next_start < prev_end: 150 | triggered = False 151 | else: 152 | current_speech["start"] = next_start 153 | prev_end = next_start = temp_end = 0 154 | else: 155 | current_speech["end"] = window_size_samples * i 156 | speeches.append(current_speech) 157 | current_speech = {} 158 | prev_end = next_start = temp_end = 0 159 | triggered = False 160 | continue 161 | 162 | if (speech_prob < neg_threshold) and triggered: 163 | if not temp_end: 164 | temp_end = window_size_samples * i 165 | # condition to avoid cutting in very short silence 166 | if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech: 167 | prev_end = temp_end 168 | if (window_size_samples * i) - temp_end < min_silence_samples: 169 | continue 170 | else: 171 | current_speech["end"] = temp_end 172 | if ( 173 | current_speech["end"] - current_speech["start"] 174 | ) > min_speech_samples: 175 | speeches.append(current_speech) 176 | current_speech = {} 177 | prev_end = next_start = temp_end = 0 178 | triggered = False 179 | continue 180 | 181 | if ( 182 | current_speech 183 | and (audio_length_samples - current_speech["start"]) > min_speech_samples 184 | ): 185 | current_speech["end"] = audio_length_samples 186 | speeches.append(current_speech) 187 | 188 | for i, speech in enumerate(speeches): 189 | if i == 0: 190 | speech["start"] = int(max(0, speech["start"] - speech_pad_samples)) 191 | if i != len(speeches) - 1: 192 | silence_duration = speeches[i + 1]["start"] - speech["end"] 193 | if silence_duration < 2 * speech_pad_samples: 194 | speech["end"] += int(silence_duration // 2) 195 | speeches[i + 1]["start"] = int( 196 | max(0, speeches[i + 1]["start"] - silence_duration // 2) 197 | ) 198 | else: 199 | speech["end"] = int( 200 | min(audio_length_samples, speech["end"] + speech_pad_samples) 201 | ) 202 | speeches[i + 1]["start"] = int( 203 | max(0, speeches[i + 1]["start"] - speech_pad_samples) 204 | ) 205 | else: 206 | speech["end"] = int( 207 | min(audio_length_samples, speech["end"] + speech_pad_samples) 208 | ) 209 | 210 | return speeches 211 | 212 | def update_model(self): 213 | self.model = get_vad_model() 214 | 215 | @staticmethod 216 | def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray: 217 | """Collects and concatenates audio chunks.""" 218 | if not chunks: 219 | return np.array([], dtype=np.float32) 220 | 221 | return np.concatenate([audio[chunk["start"]: chunk["end"]] for chunk in chunks]) 222 | 223 | @staticmethod 224 | def format_timestamp( 225 | seconds: float, 226 | always_include_hours: bool = False, 227 | decimal_marker: str = ".", 228 | ) -> str: 229 | assert seconds >= 0, "non-negative timestamp expected" 230 | milliseconds = round(seconds * 1000.0) 231 | 232 | hours = milliseconds // 3_600_000 233 | milliseconds -= hours * 3_600_000 234 | 235 | minutes = milliseconds // 60_000 236 | milliseconds -= minutes * 60_000 237 | 238 | seconds = milliseconds // 1_000 239 | milliseconds -= seconds * 1_000 240 | 241 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" 242 | return ( 243 | f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" 244 | ) 245 | 246 | def restore_speech_timestamps( 247 | self, 248 | segments: List[Segment], 249 | speech_chunks: List[dict], 250 | sampling_rate: Optional[int] = None, 251 | ) 
-> List[Segment]: 252 | if sampling_rate is None: 253 | sampling_rate = self.sampling_rate 254 | 255 | ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate) 256 | 257 | for segment in segments: 258 | if segment.words: 259 | words = [] 260 | for word in segment.words: 261 | # Ensure the word start and end times are resolved to the same chunk. 262 | middle = (word.start + word.end) / 2 263 | chunk_index = ts_map.get_chunk_index(middle) 264 | word.start = ts_map.get_original_time(word.start, chunk_index) 265 | word.end = ts_map.get_original_time(word.end, chunk_index) 266 | words.append(word) 267 | 268 | segment.start = words[0].start 269 | segment.end = words[-1].end 270 | segment.words = words 271 | 272 | else: 273 | segment.start = ts_map.get_original_time(segment.start) 274 | segment.end = ts_map.get_original_time(segment.end) 275 | 276 | return segments 277 | 278 | -------------------------------------------------------------------------------- /modules/whisper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/modules/whisper/__init__.py -------------------------------------------------------------------------------- /modules/whisper/faster_whisper_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import huggingface_hub 4 | import numpy as np 5 | import torch 6 | from typing import BinaryIO, Union, Tuple, List, Callable 7 | import faster_whisper 8 | from faster_whisper.vad import VadOptions 9 | import ast 10 | import ctranslate2 11 | import whisper 12 | import gradio as gr 13 | from argparse import Namespace 14 | 15 | from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR) 16 | from modules.whisper.data_classes import * 17 | from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline 18 | 19 | 20 | class FasterWhisperInference(BaseTranscriptionPipeline): 21 | def __init__(self, 22 | model_dir: str = FASTER_WHISPER_MODELS_DIR, 23 | diarization_model_dir: str = DIARIZATION_MODELS_DIR, 24 | uvr_model_dir: str = UVR_MODELS_DIR, 25 | output_dir: str = OUTPUT_DIR, 26 | ): 27 | super().__init__( 28 | model_dir=model_dir, 29 | diarization_model_dir=diarization_model_dir, 30 | uvr_model_dir=uvr_model_dir, 31 | output_dir=output_dir 32 | ) 33 | self.model_dir = model_dir 34 | os.makedirs(self.model_dir, exist_ok=True) 35 | 36 | self.model_paths = self.get_model_paths() 37 | self.device = self.get_device() 38 | self.available_models = self.model_paths.keys() 39 | 40 | def transcribe(self, 41 | audio: Union[str, BinaryIO, np.ndarray], 42 | progress: gr.Progress = gr.Progress(), 43 | progress_callback: Optional[Callable] = None, 44 | *whisper_params, 45 | ) -> Tuple[List[Segment], float]: 46 | """ 47 | transcribe method for faster-whisper. 48 | 49 | Parameters 50 | ---------- 51 | audio: Union[str, BinaryIO, np.ndarray] 52 | Audio path or file binary or Audio numpy array 53 | progress: gr.Progress 54 | Indicator to show progress directly in gradio. 55 | progress_callback: Optional[Callable] 56 | callback function to show progress. Can be used to update progress in the backend. 57 | *whisper_params: tuple 58 | Parameters related with whisper. 
This will be dealt with "WhisperParameters" data class 59 | 60 | Returns 61 | ---------- 62 | segments_result: List[Segment] 63 | list of Segment that includes start, end timestamps and transcribed text 64 | elapsed_time: float 65 | elapsed time for transcription 66 | """ 67 | start_time = time.time() 68 | 69 | params = WhisperParams.from_list(list(whisper_params)) 70 | 71 | if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: 72 | self.update_model(params.model_size, params.compute_type, progress) 73 | 74 | segments, info = self.model.transcribe( 75 | audio=audio, 76 | language=params.lang, 77 | task="translate" if params.is_translate else "transcribe", 78 | beam_size=params.beam_size, 79 | log_prob_threshold=params.log_prob_threshold, 80 | no_speech_threshold=params.no_speech_threshold, 81 | best_of=params.best_of, 82 | patience=params.patience, 83 | temperature=params.temperature, 84 | initial_prompt=params.initial_prompt, 85 | compression_ratio_threshold=params.compression_ratio_threshold, 86 | length_penalty=params.length_penalty, 87 | repetition_penalty=params.repetition_penalty, 88 | no_repeat_ngram_size=params.no_repeat_ngram_size, 89 | prefix=params.prefix, 90 | suppress_blank=params.suppress_blank, 91 | suppress_tokens=params.suppress_tokens, 92 | max_initial_timestamp=params.max_initial_timestamp, 93 | word_timestamps=params.word_timestamps, 94 | prepend_punctuations=params.prepend_punctuations, 95 | append_punctuations=params.append_punctuations, 96 | max_new_tokens=params.max_new_tokens, 97 | chunk_length=params.chunk_length, 98 | hallucination_silence_threshold=params.hallucination_silence_threshold, 99 | hotwords=params.hotwords, 100 | language_detection_threshold=params.language_detection_threshold, 101 | language_detection_segments=params.language_detection_segments, 102 | prompt_reset_on_temperature=params.prompt_reset_on_temperature, 103 | ) 104 | progress(0, desc="Loading audio..") 105 | 106 | segments_result = [] 107 | for segment in segments: 108 | progress_n = segment.start / info.duration 109 | progress(progress_n, desc="Transcribing..") 110 | if progress_callback is not None: 111 | progress_callback(progress_n) 112 | segments_result.append(Segment.from_faster_whisper(segment)) 113 | 114 | elapsed_time = time.time() - start_time 115 | return segments_result, elapsed_time 116 | 117 | def update_model(self, 118 | model_size: str, 119 | compute_type: str, 120 | progress: gr.Progress = gr.Progress() 121 | ): 122 | """ 123 | Update current model setting 124 | 125 | Parameters 126 | ---------- 127 | model_size: str 128 | Size of whisper model. If you enter the huggingface repo id, it will try to download the model 129 | automatically from huggingface. 130 | compute_type: str 131 | Compute type for transcription. 132 | see more info : https://opennmt.net/CTranslate2/quantization.html 133 | progress: gr.Progress 134 | Indicator to show progress directly in gradio. 135 | """ 136 | progress(0, desc="Initializing Model..") 137 | 138 | model_size_dirname = model_size.replace("/", "--") if "/" in model_size else model_size 139 | if model_size not in self.model_paths and model_size_dirname not in self.model_paths: 140 | print(f"Model is not detected. 
Trying to download \"{model_size}\" from huggingface to " 141 | f"\"{os.path.join(self.model_dir, model_size_dirname)} ...") 142 | huggingface_hub.snapshot_download( 143 | model_size, 144 | local_dir=os.path.join(self.model_dir, model_size_dirname), 145 | ) 146 | self.model_paths = self.get_model_paths() 147 | gr.Info(f"Model is downloaded with the name \"{model_size_dirname}\"") 148 | 149 | self.current_model_size = self.model_paths[model_size_dirname] 150 | 151 | local_files_only = False 152 | hf_prefix = "models--Systran--faster-whisper-" 153 | official_model_path = os.path.join(self.model_dir, hf_prefix+model_size) 154 | if ((os.path.isdir(self.current_model_size) and os.path.exists(self.current_model_size)) or 155 | (model_size in faster_whisper.available_models() and os.path.exists(official_model_path))): 156 | local_files_only = True 157 | 158 | self.current_compute_type = compute_type 159 | self.model = faster_whisper.WhisperModel( 160 | device=self.device, 161 | model_size_or_path=self.current_model_size, 162 | download_root=self.model_dir, 163 | compute_type=self.current_compute_type, 164 | local_files_only=local_files_only 165 | ) 166 | 167 | def get_model_paths(self): 168 | """ 169 | Get available models from models path including fine-tuned model. 170 | 171 | Returns 172 | ---------- 173 | Name list of models 174 | """ 175 | model_paths = {model:model for model in faster_whisper.available_models()} 176 | faster_whisper_prefix = "models--Systran--faster-whisper-" 177 | 178 | existing_models = os.listdir(self.model_dir) 179 | wrong_dirs = [".locks", "faster_whisper_models_will_be_saved_here"] 180 | existing_models = list(set(existing_models) - set(wrong_dirs)) 181 | 182 | for model_name in existing_models: 183 | if faster_whisper_prefix in model_name: 184 | model_name = model_name[len(faster_whisper_prefix):] 185 | 186 | if model_name not in whisper.available_models(): 187 | model_paths[model_name] = os.path.join(self.model_dir, model_name) 188 | return model_paths 189 | 190 | @staticmethod 191 | def get_device(): 192 | if torch.cuda.is_available(): 193 | return "cuda" 194 | else: 195 | return "auto" 196 | 197 | @staticmethod 198 | def format_suppress_tokens_str(suppress_tokens_str: str) -> List[int]: 199 | try: 200 | suppress_tokens = ast.literal_eval(suppress_tokens_str) 201 | if not isinstance(suppress_tokens, list) or not all(isinstance(item, int) for item in suppress_tokens): 202 | raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]") 203 | return suppress_tokens 204 | except Exception as e: 205 | raise ValueError("Invalid Suppress Tokens. 
The value must be type of List[int]") 206 | -------------------------------------------------------------------------------- /modules/whisper/insanely_fast_whisper_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | from typing import BinaryIO, Union, Tuple, List, Callable 5 | import torch 6 | from transformers import pipeline 7 | from transformers.utils import is_flash_attn_2_available 8 | import gradio as gr 9 | from huggingface_hub import hf_hub_download 10 | import whisper 11 | from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn 12 | from argparse import Namespace 13 | 14 | from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR) 15 | from modules.whisper.data_classes import * 16 | from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline 17 | from modules.utils.logger import get_logger 18 | 19 | logger = get_logger() 20 | 21 | 22 | class InsanelyFastWhisperInference(BaseTranscriptionPipeline): 23 | def __init__(self, 24 | model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR, 25 | diarization_model_dir: str = DIARIZATION_MODELS_DIR, 26 | uvr_model_dir: str = UVR_MODELS_DIR, 27 | output_dir: str = OUTPUT_DIR, 28 | ): 29 | super().__init__( 30 | model_dir=model_dir, 31 | output_dir=output_dir, 32 | diarization_model_dir=diarization_model_dir, 33 | uvr_model_dir=uvr_model_dir 34 | ) 35 | self.model_dir = model_dir 36 | os.makedirs(self.model_dir, exist_ok=True) 37 | 38 | self.available_models = self.get_model_paths() 39 | 40 | def transcribe(self, 41 | audio: Union[str, np.ndarray, torch.Tensor], 42 | progress: gr.Progress = gr.Progress(), 43 | progress_callback: Optional[Callable] = None, 44 | *whisper_params, 45 | ) -> Tuple[List[Segment], float]: 46 | """ 47 | transcribe method for faster-whisper. 48 | 49 | Parameters 50 | ---------- 51 | audio: Union[str, BinaryIO, np.ndarray] 52 | Audio path or file binary or Audio numpy array 53 | progress: gr.Progress 54 | Indicator to show progress directly in gradio. 55 | progress_callback: Optional[Callable] 56 | callback function to show progress. Can be used to update progress in the backend. 57 | *whisper_params: tuple 58 | Parameters related with whisper. 
This will be dealt with "WhisperParameters" data class 59 | 60 | Returns 61 | ---------- 62 | segments_result: List[Segment] 63 | list of Segment that includes start, end timestamps and transcribed text 64 | elapsed_time: float 65 | elapsed time for transcription 66 | """ 67 | start_time = time.time() 68 | params = WhisperParams.from_list(list(whisper_params)) 69 | 70 | if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: 71 | self.update_model(params.model_size, params.compute_type, progress) 72 | 73 | progress(0, desc="Transcribing...Progress is not shown in insanely-fast-whisper.") 74 | with Progress( 75 | TextColumn("[progress.description]{task.description}"), 76 | BarColumn(style="yellow1", pulse_style="white"), 77 | TimeElapsedColumn(), 78 | ) as progress: 79 | progress.add_task("[yellow]Transcribing...", total=None) 80 | 81 | kwargs = { 82 | "no_speech_threshold": params.no_speech_threshold, 83 | "temperature": params.temperature, 84 | "compression_ratio_threshold": params.compression_ratio_threshold, 85 | "logprob_threshold": params.log_prob_threshold, 86 | } 87 | 88 | if self.current_model_size.endswith(".en"): 89 | pass 90 | else: 91 | kwargs["language"] = params.lang 92 | kwargs["task"] = "translate" if params.is_translate else "transcribe" 93 | 94 | segments = self.model( 95 | inputs=audio, 96 | return_timestamps=True, 97 | chunk_length_s=params.chunk_length, 98 | batch_size=params.batch_size, 99 | generate_kwargs=kwargs 100 | ) 101 | 102 | segments_result = [] 103 | for item in segments["chunks"]: 104 | start, end = item["timestamp"][0], item["timestamp"][1] 105 | if end is None: 106 | end = start 107 | segments_result.append(Segment( 108 | text=item["text"], 109 | start=start, 110 | end=end 111 | )) 112 | 113 | elapsed_time = time.time() - start_time 114 | return segments_result, elapsed_time 115 | 116 | def update_model(self, 117 | model_size: str, 118 | compute_type: str, 119 | progress: gr.Progress = gr.Progress(), 120 | ): 121 | """ 122 | Update current model setting 123 | 124 | Parameters 125 | ---------- 126 | model_size: str 127 | Size of whisper model 128 | compute_type: str 129 | Compute type for transcription. 130 | see more info : https://opennmt.net/CTranslate2/quantization.html 131 | progress: gr.Progress 132 | Indicator to show progress directly in gradio. 133 | """ 134 | progress(0, desc="Initializing Model..") 135 | model_path = os.path.join(self.model_dir, model_size) 136 | if not os.path.isdir(model_path) or not os.listdir(model_path): 137 | self.download_model( 138 | model_size=model_size, 139 | download_root=model_path, 140 | progress=progress 141 | ) 142 | 143 | self.current_compute_type = compute_type 144 | self.current_model_size = model_size 145 | self.model = pipeline( 146 | "automatic-speech-recognition", 147 | model=os.path.join(self.model_dir, model_size), 148 | torch_dtype=self.current_compute_type, 149 | device=self.device, 150 | model_kwargs={"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"}, 151 | ) 152 | 153 | def get_model_paths(self): 154 | """ 155 | Get available models from models path including fine-tuned model. 
156 | 157 | Returns 158 | ---------- 159 | Name set of models 160 | """ 161 | openai_models = whisper.available_models() 162 | distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"] 163 | default_models = openai_models + distil_models 164 | 165 | existing_models = os.listdir(self.model_dir) 166 | wrong_dirs = [".locks", "insanely_fast_whisper_models_will_be_saved_here"] 167 | 168 | available_models = default_models + existing_models 169 | available_models = [model for model in available_models if model not in wrong_dirs] 170 | available_models = sorted(set(available_models), key=available_models.index) 171 | 172 | return available_models 173 | 174 | @staticmethod 175 | def download_model( 176 | model_size: str, 177 | download_root: str, 178 | progress: gr.Progress 179 | ): 180 | progress(0, 'Initializing model..') 181 | logger.info(f'Downloading {model_size} to "{download_root}"....') 182 | 183 | os.makedirs(download_root, exist_ok=True) 184 | download_list = [ 185 | "model.safetensors", 186 | "config.json", 187 | "generation_config.json", 188 | "preprocessor_config.json", 189 | "tokenizer.json", 190 | "tokenizer_config.json", 191 | "added_tokens.json", 192 | "special_tokens_map.json", 193 | "vocab.json", 194 | ] 195 | 196 | if model_size.startswith("distil"): 197 | repo_id = f"distil-whisper/{model_size}" 198 | else: 199 | repo_id = f"openai/whisper-{model_size}" 200 | for item in download_list: 201 | hf_hub_download(repo_id=repo_id, filename=item, local_dir=download_root) 202 | -------------------------------------------------------------------------------- /modules/whisper/whisper_Inference.py: -------------------------------------------------------------------------------- 1 | import whisper 2 | import gradio as gr 3 | import time 4 | from typing import BinaryIO, Union, Tuple, List, Callable, Optional 5 | import numpy as np 6 | import torch 7 | import os 8 | from argparse import Namespace 9 | 10 | from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, UVR_MODELS_DIR) 11 | from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline 12 | from modules.whisper.data_classes import * 13 | 14 | 15 | class WhisperInference(BaseTranscriptionPipeline): 16 | def __init__(self, 17 | model_dir: str = WHISPER_MODELS_DIR, 18 | diarization_model_dir: str = DIARIZATION_MODELS_DIR, 19 | uvr_model_dir: str = UVR_MODELS_DIR, 20 | output_dir: str = OUTPUT_DIR, 21 | ): 22 | super().__init__( 23 | model_dir=model_dir, 24 | output_dir=output_dir, 25 | diarization_model_dir=diarization_model_dir, 26 | uvr_model_dir=uvr_model_dir 27 | ) 28 | 29 | def transcribe(self, 30 | audio: Union[str, np.ndarray, torch.Tensor], 31 | progress: gr.Progress = gr.Progress(), 32 | progress_callback: Optional[Callable] = None, 33 | *whisper_params, 34 | ) -> Tuple[List[Segment], float]: 35 | """ 36 | transcribe method for faster-whisper. 37 | 38 | Parameters 39 | ---------- 40 | audio: Union[str, BinaryIO, np.ndarray] 41 | Audio path or file binary or Audio numpy array 42 | progress: gr.Progress 43 | Indicator to show progress directly in gradio. 44 | progress_callback: Optional[Callable] 45 | callback function to show progress. Can be used to update progress in the backend. 46 | *whisper_params: tuple 47 | Parameters related with whisper. 
This will be dealt with "WhisperParameters" data class 48 | 49 | Returns 50 | ---------- 51 | segments_result: List[Segment] 52 | list of Segment that includes start, end timestamps and transcribed text 53 | elapsed_time: float 54 | elapsed time for transcription 55 | """ 56 | start_time = time.time() 57 | params = WhisperParams.from_list(list(whisper_params)) 58 | 59 | if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: 60 | self.update_model(params.model_size, params.compute_type, progress) 61 | 62 | def progress_callback(progress_value): 63 | progress(progress_value, desc="Transcribing..") 64 | 65 | result = self.model.transcribe(audio=audio, 66 | language=params.lang, 67 | verbose=False, 68 | beam_size=params.beam_size, 69 | logprob_threshold=params.log_prob_threshold, 70 | no_speech_threshold=params.no_speech_threshold, 71 | task="translate" if params.is_translate else "transcribe", 72 | fp16=True if params.compute_type == "float16" else False, 73 | best_of=params.best_of, 74 | patience=params.patience, 75 | temperature=params.temperature, 76 | compression_ratio_threshold=params.compression_ratio_threshold, 77 | progress_callback=progress_callback,)["segments"] 78 | segments_result = [] 79 | for segment in result: 80 | segments_result.append(Segment( 81 | start=segment["start"], 82 | end=segment["end"], 83 | text=segment["text"] 84 | )) 85 | 86 | elapsed_time = time.time() - start_time 87 | return segments_result, elapsed_time 88 | 89 | def update_model(self, 90 | model_size: str, 91 | compute_type: str, 92 | progress: gr.Progress = gr.Progress(), 93 | ): 94 | """ 95 | Update current model setting 96 | 97 | Parameters 98 | ---------- 99 | model_size: str 100 | Size of whisper model 101 | compute_type: str 102 | Compute type for transcription. 103 | see more info : https://opennmt.net/CTranslate2/quantization.html 104 | progress: gr.Progress 105 | Indicator to show progress directly in gradio. 
106 | """ 107 | progress(0, desc="Initializing Model..") 108 | self.current_compute_type = compute_type 109 | self.current_model_size = model_size 110 | self.model = whisper.load_model( 111 | name=model_size, 112 | device=self.device, 113 | download_root=self.model_dir 114 | ) -------------------------------------------------------------------------------- /modules/whisper/whisper_factory.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import os 3 | import torch 4 | 5 | from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, 6 | INSANELY_FAST_WHISPER_MODELS_DIR, WHISPER_MODELS_DIR, UVR_MODELS_DIR) 7 | from modules.whisper.faster_whisper_inference import FasterWhisperInference 8 | from modules.whisper.whisper_Inference import WhisperInference 9 | from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference 10 | from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline 11 | from modules.whisper.data_classes import * 12 | from modules.utils.logger import get_logger 13 | 14 | 15 | logger = get_logger() 16 | 17 | 18 | class WhisperFactory: 19 | @staticmethod 20 | def create_whisper_inference( 21 | whisper_type: str, 22 | whisper_model_dir: str = WHISPER_MODELS_DIR, 23 | faster_whisper_model_dir: str = FASTER_WHISPER_MODELS_DIR, 24 | insanely_fast_whisper_model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR, 25 | diarization_model_dir: str = DIARIZATION_MODELS_DIR, 26 | uvr_model_dir: str = UVR_MODELS_DIR, 27 | output_dir: str = OUTPUT_DIR, 28 | ) -> "BaseTranscriptionPipeline": 29 | """ 30 | Create a whisper inference class based on the provided whisper_type. 31 | 32 | Parameters 33 | ---------- 34 | whisper_type : str 35 | The type of Whisper implementation to use. Supported values (case-insensitive): 36 | - "faster-whisper": https://github.com/openai/whisper 37 | - "whisper": https://github.com/openai/whisper 38 | - "insanely-fast-whisper": https://github.com/Vaibhavs10/insanely-fast-whisper 39 | whisper_model_dir : str 40 | Directory path for the Whisper model. 41 | faster_whisper_model_dir : str 42 | Directory path for the Faster Whisper model. 43 | insanely_fast_whisper_model_dir : str 44 | Directory path for the Insanely Fast Whisper model. 45 | diarization_model_dir : str 46 | Directory path for the diarization model. 47 | uvr_model_dir : str 48 | Directory path for the UVR model. 49 | output_dir : str 50 | Directory path where output files will be saved. 51 | 52 | Returns 53 | ------- 54 | BaseTranscriptionPipeline 55 | An instance of the appropriate whisper inference class based on the whisper_type. 56 | """ 57 | # Temporal fix of the bug : https://github.com/jhj0517/Whisper-WebUI/issues/144 58 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 59 | 60 | whisper_type = whisper_type.strip().lower() 61 | 62 | if whisper_type == WhisperImpl.FASTER_WHISPER.value: 63 | if torch.xpu.is_available(): 64 | logger.warning("XPU is detected but faster-whisper only supports CUDA. 
" 65 | "Automatically switching to insanely-whisper implementation.") 66 | return InsanelyFastWhisperInference( 67 | model_dir=insanely_fast_whisper_model_dir, 68 | output_dir=output_dir, 69 | diarization_model_dir=diarization_model_dir, 70 | uvr_model_dir=uvr_model_dir 71 | ) 72 | 73 | return FasterWhisperInference( 74 | model_dir=faster_whisper_model_dir, 75 | output_dir=output_dir, 76 | diarization_model_dir=diarization_model_dir, 77 | uvr_model_dir=uvr_model_dir 78 | ) 79 | elif whisper_type == WhisperImpl.WHISPER.value: 80 | return WhisperInference( 81 | model_dir=whisper_model_dir, 82 | output_dir=output_dir, 83 | diarization_model_dir=diarization_model_dir, 84 | uvr_model_dir=uvr_model_dir 85 | ) 86 | elif whisper_type == WhisperImpl.INSANELY_FAST_WHISPER.value: 87 | return InsanelyFastWhisperInference( 88 | model_dir=insanely_fast_whisper_model_dir, 89 | output_dir=output_dir, 90 | diarization_model_dir=diarization_model_dir, 91 | uvr_model_dir=uvr_model_dir 92 | ) 93 | else: 94 | return FasterWhisperInference( 95 | model_dir=faster_whisper_model_dir, 96 | output_dir=output_dir, 97 | diarization_model_dir=diarization_model_dir, 98 | uvr_model_dir=uvr_model_dir 99 | ) 100 | -------------------------------------------------------------------------------- /notebook/whisper-webui.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "---\n", 7 | "\n", 8 | "📌 **This notebook has been updated [here](https://github.com/jhj0517/Whisper-WebUI.git)!**\n", 9 | "\n", 10 | "🖋 **Author**: [jhj0517](https://github.com/jhj0517/Whisper-WebUI/blob/master/notebook/whisper-webui.ipynb)\n", 11 | "\n", 12 | "😎 **Support the Project**:\n", 13 | "\n", 14 | "If you find this project useful, please consider supporting it:\n", 15 | "\n", 16 | "\n", 17 | " \"Buy\n", 18 | "\n", 19 | "\n", 20 | "---" 21 | ], 22 | "metadata": { 23 | "id": "doKhBBXIfS21" 24 | } 25 | }, 26 | { 27 | "cell_type": "code", 28 | "source": [ 29 | "#@title #(Optional) Check GPU\n", 30 | "#@markdown Some models may not function correctly on a CPU runtime.\n", 31 | "\n", 32 | "#@markdown so you should check your GPU setup before run.\n", 33 | "!nvidia-smi" 34 | ], 35 | "metadata": { 36 | "id": "23yZvUlagEsx", 37 | "cellView": "form" 38 | }, 39 | "execution_count": null, 40 | "outputs": [] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": { 46 | "id": "kNbSbsctxahq", 47 | "cellView": "form" 48 | }, 49 | "outputs": [], 50 | "source": [ 51 | "#@title #Installation\n", 52 | "#@markdown This cell will install dependencies for Whisper-WebUI!\n", 53 | "!git clone https://github.com/jhj0517/Whisper-WebUI.git\n", 54 | "%cd Whisper-WebUI\n", 55 | "!pip install git+https://github.com/jhj0517/jhj0517-whisper.git\n", 56 | "!pip install faster-whisper==1.1.1\n", 57 | "!pip install gradio\n", 58 | "!pip install gradio-i18n\n", 59 | "# Temporal bug fix from https://github.com/jhj0517/Whisper-WebUI/issues/256\n", 60 | "!pip install git+https://github.com/JuanBindez/pytubefix.git\n", 61 | "!pip install pyannote.audio==3.3.1\n", 62 | "!pip install git+https://github.com/jhj0517/ultimatevocalremover_api.git" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "source": [ 68 | "#@title # (Optional) Mount Google Drive\n", 69 | "#@markdown Uploading large input files directly via UI may consume alot of time because it has to be uploaded in colab's server.\n", 70 | "#@markdown
This section lets you use input file paths from Google Drive instead, so you can avoid that upload time.\n", 71 | "#@markdown
For example, you can first upload the input file to Google Drive and use the directory path in the \"Input Folder Path\" input, as shown below.\n", 72 | "\n", 73 | "#@markdown ![image](https://github.com/user-attachments/assets/85330905-e3ec-4502-bc4b-b9d1c5b41aa2)\n", 74 | "\n", 75 | "#@markdown
And it will mount the output paths to your Google Drive as well. This section is optional and can be ignored.\n", 76 | "\n", 77 | "\n", 78 | "# Mount Google Drive\n", 79 | "from google.colab import drive\n", 80 | "import os\n", 81 | "drive.mount('/content/drive')\n", 82 | "\n", 83 | "\n", 84 | "# Symlink Output Paths for Whisper-WebUI\n", 85 | "import os\n", 86 | "\n", 87 | "OUTPUT_DIRECTORY_PATH = '/content/drive/MyDrive/Whisper-WebUI/outputs' # @param {type:\"string\"}\n", 88 | "local_output_path = '/content/Whisper-WebUI/outputs'\n", 89 | "os.makedirs(local_output_path, exist_ok=True)\n", 90 | "os.makedirs(OUTPUT_DIRECTORY_PATH, exist_ok=True)\n", 91 | "\n", 92 | "if os.path.exists(local_output_path):\n", 93 | " !rm -r \"$local_output_path\"\n", 94 | "\n", 95 | "os.symlink(OUTPUT_DIRECTORY_PATH, local_output_path)\n", 96 | "!ls \"$local_output_path\"" 97 | ], 98 | "metadata": { 99 | "cellView": "form", 100 | "id": "y2DY5oSb9Bol" 101 | }, 102 | "execution_count": null, 103 | "outputs": [] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "source": [ 108 | "#@title # (Optional) Configure arguments\n", 109 | "#@markdown This section is used to configure some command line arguments.\n", 110 | "\n", 111 | "#@markdown You can simply ignore this section and the default values will be used.\n", 112 | "\n", 113 | "USERNAME = '' #@param {type: \"string\"}\n", 114 | "PASSWORD = '' #@param {type: \"string\"}\n", 115 | "WHISPER_TYPE = 'faster-whisper' # @param [\"whisper\", \"faster-whisper\", \"insanely_fast_whisper\"]\n", 116 | "THEME = '' #@param {type: \"string\"}\n", 117 | "\n", 118 | "arguments = \"\"\n", 119 | "if USERNAME:\n", 120 | " arguments += f\" --username {USERNAME}\"\n", 121 | "if PASSWORD:\n", 122 | " arguments += f\" --password {PASSWORD}\"\n", 123 | "if THEME:\n", 124 | " arguments += f\" --theme {THEME}\"\n", 125 | "if WHISPER_TYPE:\n", 126 | " arguments += f\" --whisper_type {WHISPER_TYPE}\"\n", 127 | "\n", 128 | "\n", 129 | "#@markdown If you are wondering how these arguments are used, see the [Wiki](https://github.com/jhj0517/Whisper-WebUI/wiki/Command-Line-Arguments)."
130 | ], 131 | "metadata": { 132 | "id": "Qosz9BFlGui3", 133 | "cellView": "form" 134 | }, 135 | "execution_count": null, 136 | "outputs": [] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 2, 141 | "metadata": { 142 | "id": "PQroYRRZzQiN", 143 | "cellView": "form" 144 | }, 145 | "outputs": [], 146 | "source": [ 147 | "#@title #Run\n", 148 | "#@markdown Once the installation is complete, you can use public URL that is displayed.\n", 149 | "if 'arguments' in locals():\n", 150 | " !python app.py --share --colab --allowed_path \"['/content/Whisper-WebUI/outputs']\"{arguments}\n", 151 | "else:\n", 152 | " !python app.py --share --colab --allowed_path \"['/content/Whisper-WebUI/outputs']\"" 153 | ] 154 | } 155 | ], 156 | "metadata": { 157 | "colab": { 158 | "provenance": [], 159 | "gpuType": "T4" 160 | }, 161 | "kernelspec": { 162 | "display_name": "Python 3", 163 | "name": "python3" 164 | }, 165 | "language_info": { 166 | "name": "python" 167 | }, 168 | "accelerator": "GPU" 169 | }, 170 | "nbformat": 4, 171 | "nbformat_minor": 0 172 | } -------------------------------------------------------------------------------- /outputs/UVR/instrumental/UVR_outputs_for_instrumental_will_be_saved_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/outputs/UVR/instrumental/UVR_outputs_for_instrumental_will_be_saved_here -------------------------------------------------------------------------------- /outputs/UVR/vocals/UVR_outputs_for_vocals_will_be_saved_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/outputs/UVR/vocals/UVR_outputs_for_vocals_will_be_saved_here -------------------------------------------------------------------------------- /outputs/outputs_will_be_saved_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/outputs/outputs_will_be_saved_here -------------------------------------------------------------------------------- /outputs/translations/translation_outputs_will_be_saved_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/outputs/translations/translation_outputs_will_be_saved_here -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu126 2 | 3 | # Update above --extra-index-url according to your device. 
4 | 5 | ## Nvidia GPU 6 | # CUDA 12.6 : https://download.pytorch.org/whl/cu126 7 | # CUDA 12.8 : https://download.pytorch.org/whl/cu128 8 | 9 | ## Intel GPU 10 | # https://download.pytorch.org/whl/xpu 11 | 12 | 13 | torch 14 | torchaudio 15 | git+https://github.com/jhj0517/jhj0517-whisper.git 16 | faster-whisper==1.1.1 17 | transformers==4.47.1 18 | gradio 19 | gradio-i18n==0.3.0 20 | pytubefix 21 | ruamel.yaml==0.18.6 22 | pyannote.audio==3.3.2 23 | git+https://github.com/jhj0517/ultimatevocalremover_api.git 24 | git+https://github.com/jhj0517/pyrubberband.git 25 | -------------------------------------------------------------------------------- /start-webui.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | call venv\scripts\activate 4 | python app.py %* 5 | 6 | echo "launching the app" 7 | pause 8 | -------------------------------------------------------------------------------- /start-webui.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source venv/bin/activate 4 | python app.py "$@" 5 | 6 | echo "launching the app" 7 | -------------------------------------------------------------------------------- /tests/test_bgm_separation.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import pytest 3 | import torch 4 | import os 5 | 6 | from modules.utils.paths import * 7 | from modules.whisper.whisper_factory import WhisperFactory 8 | from modules.whisper.data_classes import * 9 | from test_config import * 10 | from test_transcription import download_file, run_asr_pipeline 11 | 12 | 13 | @pytest.mark.skipif( 14 | not is_cuda_available(), 15 | reason="Skipping because the test only works on GPU" 16 | ) 17 | @pytest.mark.parametrize( 18 | "whisper_type,vad_filter,bgm_separation,diarization", 19 | [ 20 | (WhisperImpl.WHISPER.value, False, True, False), 21 | (WhisperImpl.FASTER_WHISPER.value, False, True, False), 22 | (WhisperImpl.INSANELY_FAST_WHISPER.value, False, True, False) 23 | ] 24 | ) 25 | def test_bgm_separation_pipeline( 26 | whisper_type: str, 27 | vad_filter: bool, 28 | bgm_separation: bool, 29 | diarization: bool, 30 | ): 31 | run_asr_pipeline(whisper_type, vad_filter, bgm_separation, diarization) 32 | 33 | 34 | @pytest.mark.skipif( 35 | not is_cuda_available(), 36 | reason="Skipping because the test only works on GPU" 37 | ) 38 | @pytest.mark.parametrize( 39 | "whisper_type,vad_filter,bgm_separation,diarization", 40 | [ 41 | (WhisperImpl.WHISPER.value, True, True, False), 42 | (WhisperImpl.FASTER_WHISPER.value, True, True, False), 43 | (WhisperImpl.INSANELY_FAST_WHISPER.value, True, True, False) 44 | ] 45 | ) 46 | def test_bgm_separation_with_vad_pipeline( 47 | whisper_type: str, 48 | vad_filter: bool, 49 | bgm_separation: bool, 50 | diarization: bool, 51 | ): 52 | run_asr_pipeline(whisper_type, vad_filter, bgm_separation, diarization) 53 | 54 | -------------------------------------------------------------------------------- /tests/test_config.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import jiwer 3 | import os 4 | import pytest 5 | import requests 6 | import torch 7 | 8 | from modules.utils.paths import * 9 | from modules.utils.youtube_manager import * 10 | 11 | TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav" 12 | TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav") 13 | 
TEST_ANSWER = "And so my fellow Americans ask not what your country can do for you ask what you can do for your country" 14 | TEST_YOUTUBE_URL = "https://www.youtube.com/watch?v=4WEQtgnBu0I&ab_channel=AndriaFitzer" 15 | TEST_WHISPER_MODEL = "tiny" 16 | TEST_UVR_MODEL = "UVR-MDX-NET-Inst_HQ_4" 17 | TEST_NLLB_MODEL = "facebook/nllb-200-distilled-600M" 18 | TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt") 19 | TEST_SUBTITLE_VTT_PATH = os.path.join(WEBUI_DIR, "tests", "test_vtt.vtt") 20 | 21 | 22 | @functools.lru_cache 23 | def is_xpu_available(): 24 | return torch.xpu.is_available() 25 | 26 | 27 | @functools.lru_cache 28 | def is_cuda_available(): 29 | return torch.cuda.is_available() 30 | 31 | 32 | @functools.lru_cache 33 | def is_pytube_detected_bot(url: str = TEST_YOUTUBE_URL): 34 | try: 35 | yt_temp_path = os.path.join("modules", "yt_tmp.wav") 36 | if os.path.exists(yt_temp_path): 37 | return False 38 | yt = get_ytdata(url) 39 | audio = get_ytaudio(yt) 40 | return False 41 | except Exception as e: 42 | print(f"Pytube has detected as a bot: {e}") 43 | return True 44 | 45 | 46 | @pytest.fixture(autouse=True) 47 | def download_file(url=TEST_FILE_DOWNLOAD_URL, file_path=TEST_FILE_PATH): 48 | if os.path.exists(file_path): 49 | return 50 | 51 | if not os.path.exists(os.path.dirname(file_path)): 52 | os.makedirs(os.path.dirname(file_path)) 53 | 54 | response = requests.get(url) 55 | 56 | with open(file_path, "wb") as file: 57 | file.write(response.content) 58 | 59 | print(f"File downloaded to: {file_path}") 60 | 61 | 62 | def calculate_wer(answer, prediction): 63 | return jiwer.wer(answer, prediction) 64 | -------------------------------------------------------------------------------- /tests/test_diarization.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import pytest 3 | import os 4 | 5 | from modules.utils.paths import * 6 | from modules.whisper.whisper_factory import WhisperFactory 7 | from modules.whisper.data_classes import * 8 | from test_config import * 9 | from test_transcription import download_file, run_asr_pipeline 10 | 11 | 12 | @pytest.mark.skipif( 13 | not is_cuda_available(), 14 | reason="Skipping because the test only works on GPU" 15 | ) 16 | @pytest.mark.parametrize( 17 | "whisper_type,vad_filter,bgm_separation,diarization", 18 | [ 19 | (WhisperImpl.WHISPER.value, False, False, True), 20 | (WhisperImpl.FASTER_WHISPER.value, False, False, True), 21 | (WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, True) 22 | ] 23 | ) 24 | def test_diarization_pipeline( 25 | whisper_type: str, 26 | vad_filter: bool, 27 | bgm_separation: bool, 28 | diarization: bool, 29 | ): 30 | run_asr_pipeline(whisper_type, vad_filter, bgm_separation, diarization) 31 | 32 | -------------------------------------------------------------------------------- /tests/test_srt.srt: -------------------------------------------------------------------------------- 1 | 1 2 | 00:00:00,000 --> 00:00:02,240 3 | You've got 4 | 5 | 2 6 | 00:00:02,240 --> 00:00:04,160 7 | a friend in me. 
8 | -------------------------------------------------------------------------------- /tests/test_transcription.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import pytest 3 | import gradio as gr 4 | import os 5 | 6 | from modules.whisper.whisper_factory import WhisperFactory 7 | from modules.whisper.data_classes import * 8 | from modules.utils.subtitle_manager import read_file 9 | from modules.utils.paths import WEBUI_DIR 10 | from test_config import * 11 | 12 | 13 | def run_asr_pipeline( 14 | whisper_type: str, 15 | vad_filter: bool, 16 | bgm_separation: bool, 17 | diarization: bool, 18 | ): 19 | audio_path = TEST_FILE_PATH 20 | 21 | answer = TEST_ANSWER 22 | if diarization: 23 | answer = "SPEAKER_00|"+TEST_ANSWER 24 | 25 | whisper_inferencer = WhisperFactory.create_whisper_inference( 26 | whisper_type=whisper_type, 27 | ) 28 | print( 29 | f"""Whisper Device : {whisper_inferencer.device}\n""" 30 | f"""BGM Separation Device: {whisper_inferencer.music_separator.device}\n""" 31 | f"""Diarization Device: {whisper_inferencer.diarizer.device}""" 32 | ) 33 | 34 | hparams = TranscriptionPipelineParams( 35 | whisper=WhisperParams( 36 | model_size=TEST_WHISPER_MODEL, 37 | compute_type=whisper_inferencer.current_compute_type 38 | ), 39 | vad=VadParams( 40 | vad_filter=vad_filter 41 | ), 42 | bgm_separation=BGMSeparationParams( 43 | is_separate_bgm=bgm_separation, 44 | enable_offload=True 45 | ), 46 | diarization=DiarizationParams( 47 | is_diarize=diarization 48 | ), 49 | ).to_list() 50 | 51 | subtitle_str, file_paths = whisper_inferencer.transcribe_file( 52 | [audio_path], 53 | None, 54 | None, 55 | None, 56 | "SRT", 57 | False, 58 | gr.Progress(), 59 | *hparams, 60 | ) 61 | subtitle = read_file(file_paths[0]).split("\n") 62 | assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1 63 | 64 | if not is_pytube_detected_bot(): 65 | subtitle_str, file_path = whisper_inferencer.transcribe_youtube( 66 | TEST_YOUTUBE_URL, 67 | "SRT", 68 | False, 69 | gr.Progress(), 70 | *hparams, 71 | ) 72 | assert isinstance(subtitle_str, str) and subtitle_str 73 | assert os.path.exists(file_path) 74 | 75 | subtitle_str, file_path = whisper_inferencer.transcribe_mic( 76 | audio_path, 77 | "SRT", 78 | False, 79 | gr.Progress(), 80 | *hparams, 81 | ) 82 | subtitle = read_file(file_path).split("\n") 83 | wer = calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) 84 | assert wer < 0.1, f"WER is too high, it's {wer}" 85 | 86 | 87 | @pytest.mark.parametrize( 88 | "whisper_type,vad_filter,bgm_separation,diarization", 89 | [ 90 | (WhisperImpl.WHISPER.value, False, False, False), 91 | (WhisperImpl.FASTER_WHISPER.value, False, False, False), 92 | (WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, False) 93 | ] 94 | ) 95 | def test_transcribe( 96 | whisper_type: str, 97 | vad_filter: bool, 98 | bgm_separation: bool, 99 | diarization: bool, 100 | ): 101 | run_asr_pipeline(whisper_type, vad_filter, bgm_separation, diarization) 102 | 103 | 104 | -------------------------------------------------------------------------------- /tests/test_translation.py: -------------------------------------------------------------------------------- 1 | from modules.translation.deepl_api import DeepLAPI 2 | from modules.translation.nllb_inference import NLLBInference 3 | from test_config import * 4 | 5 | import os 6 | import pytest 7 | 8 | 9 | @pytest.mark.parametrize("model_size, file_path", [ 10 | (TEST_NLLB_MODEL, 
TEST_SUBTITLE_SRT_PATH), 11 | (TEST_NLLB_MODEL, TEST_SUBTITLE_VTT_PATH), 12 | ]) 13 | def test_nllb_inference( 14 | model_size: str, 15 | file_path: str 16 | ): 17 | nllb_inferencer = NLLBInference() 18 | print(f"NLLB Device : {nllb_inferencer.device}") 19 | 20 | result_str, file_paths = nllb_inferencer.translate_file( 21 | fileobjs=[file_path], 22 | model_size=model_size, 23 | src_lang="eng_Latn", 24 | tgt_lang="kor_Hang", 25 | ) 26 | 27 | assert isinstance(result_str, str) 28 | assert isinstance(file_paths[0], str) 29 | 30 | 31 | @pytest.mark.skipif( 32 | os.getenv("DEEPL_API_KEY") is None or not os.getenv("DEEPL_API_KEY"), 33 | reason="DeepL API key is unavailable" 34 | ) 35 | @pytest.mark.parametrize("file_path", [ 36 | TEST_SUBTITLE_SRT_PATH, 37 | TEST_SUBTITLE_VTT_PATH, 38 | ]) 39 | def test_deepl_api( 40 | file_path: str 41 | ): 42 | deepl_api = DeepLAPI() 43 | 44 | api_key = os.getenv("DEEPL_API_KEY") 45 | 46 | result_str, file_paths = deepl_api.translate_deepl( 47 | auth_key=api_key, 48 | fileobjs=[file_path], 49 | source_lang="English", 50 | target_lang="Korean", 51 | is_pro=False, 52 | add_timestamp=True, 53 | ) 54 | 55 | assert isinstance(result_str, str) 56 | assert isinstance(file_paths[0], str) 57 | -------------------------------------------------------------------------------- /tests/test_vad.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import pytest 3 | import os 4 | 5 | from modules.whisper.data_classes import * 6 | from modules.vad.silero_vad import SileroVAD 7 | from test_config import * 8 | from test_transcription import download_file, run_asr_pipeline 9 | from faster_whisper.vad import VadOptions, get_speech_timestamps 10 | 11 | 12 | @pytest.mark.parametrize( 13 | "whisper_type,vad_filter,bgm_separation,diarization", 14 | [ 15 | (WhisperImpl.WHISPER.value, True, False, False), 16 | (WhisperImpl.FASTER_WHISPER.value, True, False, False), 17 | (WhisperImpl.INSANELY_FAST_WHISPER.value, True, False, False) 18 | ] 19 | ) 20 | def test_vad_pipeline( 21 | whisper_type: str, 22 | vad_filter: bool, 23 | bgm_separation: bool, 24 | diarization: bool, 25 | ): 26 | run_asr_pipeline(whisper_type, vad_filter, bgm_separation, diarization) 27 | 28 | 29 | @pytest.mark.parametrize( 30 | "threshold,min_speech_duration_ms,min_silence_duration_ms", 31 | [ 32 | (0.5, 250, 2000), 33 | ] 34 | ) 35 | def test_vad( 36 | threshold: float, 37 | min_speech_duration_ms: int, 38 | min_silence_duration_ms: int 39 | ): 40 | audio_path_dir = os.path.join(WEBUI_DIR, "tests") 41 | audio_path = os.path.join(audio_path_dir, "jfk.wav") 42 | 43 | if not os.path.exists(audio_path): 44 | download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir) 45 | 46 | vad_model = SileroVAD() 47 | vad_model.update_model() 48 | 49 | audio, speech_chunks = vad_model.run( 50 | audio=audio_path, 51 | vad_parameters=VadOptions( 52 | threshold=threshold, 53 | min_silence_duration_ms=min_silence_duration_ms, 54 | min_speech_duration_ms=min_speech_duration_ms 55 | ) 56 | ) 57 | 58 | assert speech_chunks 59 | -------------------------------------------------------------------------------- /tests/test_vtt.vtt: -------------------------------------------------------------------------------- 1 | WEBVTT 2 | 00:00:00.500 --> 00:00:02.000 3 | You've got 4 | 5 | 00:00:02.500 --> 00:00:04.300 6 | a friend in me. 
-------------------------------------------------------------------------------- /user-start-webui.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | :: This batch file is for launching with command line args 3 | :: See the wiki for a guide to command line arguments: https://github.com/jhj0517/Whisper-WebUI/wiki/Command-Line-Arguments 4 | :: Set the values here to whatever you want. See the wiki above for how to set this. 5 | set SERVER_NAME= 6 | set SERVER_PORT= 7 | set USERNAME= 8 | set PASSWORD= 9 | set SHARE= 10 | set THEME= 11 | set API_OPEN= 12 | set WHISPER_TYPE= 13 | set WHISPER_MODEL_DIR= 14 | set FASTER_WHISPER_MODEL_DIR= 15 | set INSANELY_FAST_WHISPER_MODEL_DIR= 16 | set DIARIZATION_MODEL_DIR= 17 | 18 | 19 | if not "%SERVER_NAME%"=="" ( 20 | set SERVER_NAME_ARG=--server_name %SERVER_NAME% 21 | ) 22 | if not "%SERVER_PORT%"=="" ( 23 | set SERVER_PORT_ARG=--server_port %SERVER_PORT% 24 | ) 25 | if not "%USERNAME%"=="" ( 26 | set USERNAME_ARG=--username %USERNAME% 27 | ) 28 | if not "%PASSWORD%"=="" ( 29 | set PASSWORD_ARG=--password %PASSWORD% 30 | ) 31 | if /I "%SHARE%"=="true" ( 32 | set SHARE_ARG=--share 33 | ) 34 | if not "%THEME%"=="" ( 35 | set THEME_ARG=--theme %THEME% 36 | ) 37 | if /I "%DISABLE_FASTER_WHISPER%"=="true" ( 38 | set DISABLE_FASTER_WHISPER_ARG=--disable_faster_whisper 39 | ) 40 | if /I "%API_OPEN%"=="true" ( 41 | set API_OPEN=--api_open 42 | ) 43 | if not "%WHISPER_TYPE%"=="" ( 44 | set WHISPER_TYPE_ARG=--whisper_type %WHISPER_TYPE% 45 | ) 46 | if not "%WHISPER_MODEL_DIR%"=="" ( 47 | set WHISPER_MODEL_DIR_ARG=--whisper_model_dir "%WHISPER_MODEL_DIR%" 48 | ) 49 | if not "%FASTER_WHISPER_MODEL_DIR%"=="" ( 50 | set FASTER_WHISPER_MODEL_DIR_ARG=--faster_whisper_model_dir "%FASTER_WHISPER_MODEL_DIR%" 51 | ) 52 | if not "%INSANELY_FAST_WHISPER_MODEL_DIR%"=="" ( 53 | set INSANELY_FAST_WHISPER_MODEL_DIR_ARG=--insanely_fast_whisper_model_dir "%INSANELY_FAST_WHISPER_MODEL_DIR%" 54 | ) 55 | if not "%DIARIZATION_MODEL_DIR%"=="" ( 56 | set DIARIZATION_MODEL_DIR_ARG=--diarization_model_dir "%DIARIZATION_MODEL_DIR%" 57 | ) 58 | 59 | :: Call the original .bat script with cli arguments 60 | start-webui.bat %SERVER_NAME_ARG% %SERVER_PORT_ARG% %USERNAME_ARG% %PASSWORD_ARG% %SHARE_ARG% %THEME_ARG% %API_OPEN% %WHISPER_TYPE_ARG% %WHISPER_MODEL_DIR_ARG% %FASTER_WHISPER_MODEL_DIR_ARG% %INSANELY_FAST_WHISPER_MODEL_DIR_ARG% %DIARIZATION_MODEL_DIR_ARG% 61 | pause --------------------------------------------------------------------------------
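
For orientation, the sketch below shows how the pieces dumped above fit together: WhisperFactory.create_whisper_inference builds one of the three inference backends, the per-stage parameter data classes are flattened with to_list(), and transcribe_file writes the subtitle output, mirroring the call pattern in tests/test_transcription.py. This is a minimal sketch rather than part of the repository: it assumes the requirements are installed and that it is run from the repository root, and "my_audio.wav" is a hypothetical input path.

import gradio as gr

from modules.whisper.whisper_factory import WhisperFactory
from modules.whisper.data_classes import *  # WhisperImpl, WhisperParams, VadParams, etc., as imported in the tests

# Build an inference backend; unrecognized whisper_type values fall back to faster-whisper.
inferencer = WhisperFactory.create_whisper_inference(
    whisper_type=WhisperImpl.FASTER_WHISPER.value,
)

# Flatten the per-stage parameters into the positional list the pipeline expects.
hparams = TranscriptionPipelineParams(
    whisper=WhisperParams(model_size="tiny", compute_type=inferencer.current_compute_type),
    vad=VadParams(vad_filter=True),
    bgm_separation=BGMSeparationParams(is_separate_bgm=False, enable_offload=True),
    diarization=DiarizationParams(is_diarize=False),
).to_list()

# Same positional arguments as the transcribe_file call in tests/test_transcription.py above.
subtitle_str, file_paths = inferencer.transcribe_file(
    ["my_audio.wav"],  # hypothetical audio file
    None,
    None,
    None,
    "SRT",
    False,
    gr.Progress(),
    *hparams,
)
print(subtitle_str)
print(file_paths)

The whisper_type values accepted here are the same ones assembled as --whisper_type by the notebook's arguments string and by user-start-webui.bat, and all of them are resolved by the WhisperFactory shown above.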