├── .dockerignore
├── .github
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── feature_request.md
│   │   └── hallucination.md
│   ├── pull_request_template.md
│   └── workflows
│       ├── ci.yml
│       ├── publish-docker.yml
│       └── publish-ghcr.yml
├── .gitignore
├── Dockerfile
├── Install.bat
├── Install.sh
├── LICENSE
├── README.md
├── app.py
├── backend
│   ├── Dockerfile
│   ├── README.md
│   ├── __init__.py
│   ├── cache
│   │   └── cached_files_are_generated_here
│   ├── common
│   │   ├── audio.py
│   │   ├── cache_manager.py
│   │   ├── compresser.py
│   │   ├── config_loader.py
│   │   └── models.py
│   ├── configs
│   │   ├── .env.example
│   │   └── config.yaml
│   ├── db
│   │   ├── __init__.py
│   │   ├── db_instance.py
│   │   └── task
│   │       ├── __init__.py
│   │       ├── dao.py
│   │       └── models.py
│   ├── docker-compose.yaml
│   ├── main.py
│   ├── nginx
│   │   ├── logs
│   │   │   └── logs_are_generated_here
│   │   ├── nginx.conf
│   │   └── temp
│   │       └── temps_are_generated_here
│   ├── requirements-backend.txt
│   ├── routers
│   │   ├── __init__.py
│   │   ├── bgm_separation
│   │   │   ├── __init__.py
│   │   │   ├── models.py
│   │   │   └── router.py
│   │   ├── task
│   │   │   ├── __init__.py
│   │   │   └── router.py
│   │   ├── transcription
│   │   │   ├── __init__.py
│   │   │   └── router.py
│   │   └── vad
│   │       ├── __init__.py
│   │       └── router.py
│   └── tests
│       ├── __init__.py
│       ├── test_backend_bgm_separation.py
│       ├── test_backend_config.py
│       ├── test_backend_transcription.py
│       ├── test_backend_vad.py
│       └── test_task_status.py
├── configs
│   ├── default_parameters.yaml
│   └── translation.yaml
├── docker-compose.yaml
├── models
│   ├── Diarization
│   │   └── diarization_models_will_be_saved_here
│   ├── NLLB
│   │   └── nllb_models_will_be_saved_here
│   ├── UVR
│   │   └── uvr_models_will_be_saved_here
│   ├── Whisper
│   │   ├── faster-whisper
│   │   │   └── faster_whisper_models_will_be_saved_here
│   │   ├── insanely-fast-whisper
│   │   │   └── insanely_fast_whisper_models_will_be_saved_here
│   │   └── whisper_models_will_be_saved_here
│   └── models_will_be_saved_here
├── modules
│   ├── __init__.py
│   ├── diarize
│   │   ├── __init__.py
│   │   ├── audio_loader.py
│   │   ├── diarize_pipeline.py
│   │   └── diarizer.py
│   ├── translation
│   │   ├── __init__.py
│   │   ├── deepl_api.py
│   │   ├── nllb_inference.py
│   │   └── translation_base.py
│   ├── ui
│   │   ├── __init__.py
│   │   └── htmls.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── audio_manager.py
│   │   ├── cli_manager.py
│   │   ├── constants.py
│   │   ├── files_manager.py
│   │   ├── logger.py
│   │   ├── paths.py
│   │   ├── subtitle_manager.py
│   │   └── youtube_manager.py
│   ├── uvr
│   │   └── music_separator.py
│   ├── vad
│   │   ├── __init__.py
│   │   └── silero_vad.py
│   └── whisper
│       ├── __init__.py
│       ├── base_transcription_pipeline.py
│       ├── data_classes.py
│       ├── faster_whisper_inference.py
│       ├── insanely_fast_whisper_inference.py
│       ├── whisper_Inference.py
│       └── whisper_factory.py
├── notebook
│   └── whisper-webui.ipynb
├── outputs
│   ├── UVR
│   │   ├── instrumental
│   │   │   └── UVR_outputs_for_instrumental_will_be_saved_here
│   │   └── vocals
│   │       └── UVR_outputs_for_vocals_will_be_saved_here
│   ├── outputs_will_be_saved_here
│   └── translations
│       └── translation_outputs_will_be_saved_here
├── requirements.txt
├── start-webui.bat
├── start-webui.sh
├── tests
│   ├── test_bgm_separation.py
│   ├── test_config.py
│   ├── test_diarization.py
│   ├── test_srt.srt
│   ├── test_transcription.py
│   ├── test_translation.py
│   ├── test_vad.py
│   └── test_vtt.vtt
└── user-start-webui.bat
/.dockerignore:
--------------------------------------------------------------------------------
1 | # from .gitignore
2 | modules/yt_tmp.wav
3 | **/venv/
4 | **/__pycache__/
5 | **/outputs/
6 | **/models/
7 |
8 | **/.idea
9 | **/.git
10 | **/.github
11 |
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: [jhj0517]
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: #
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | otechie: # Replace with a single Otechie username
12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
14 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: bug
6 | assignees: jhj0517
7 |
8 | ---
9 |
10 | **Which OS are you using?**
11 | - OS: [e.g. iOS or Windows. If you are using Google Colab, just write Colab.]
12 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Any feature you want
4 | title: ''
5 | labels: enhancement
6 | assignees: jhj0517
7 |
8 | ---
9 |
10 |
11 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/hallucination.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Hallucination
3 | about: Whisper hallucinations. ( Repeating certain words or subtitles starting too
4 | early, etc. )
5 | title: ''
6 | labels: hallucination
7 | assignees: jhj0517
8 |
9 | ---
10 |
11 | **Download URL for sample audio**
12 | - Please upload a download URL for a sample audio file so I can test some settings for a better result. You can use https://easyupload.io/ or any other service to share it.
13 |
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 | ## Related issues / PRs. Summarize issues.
2 | - #
3 |
4 | ## Summarize Changes
5 | 1.
6 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | workflow_dispatch:
5 |
6 | push:
7 | branches:
8 | - master
9 | - intel-gpu
10 | pull_request:
11 | branches:
12 | - master
13 | - intel-gpu
14 |
15 | jobs:
16 | test:
17 | runs-on: ubuntu-latest
18 | strategy:
19 | matrix:
20 | python: ["3.10", "3.11", "3.12"]
21 |
22 | env:
23 | DEEPL_API_KEY: ${{ secrets.DEEPL_API_KEY }}
24 |
25 | steps:
26 | - name: Clean up space for action
27 | run: rm -rf /opt/hostedtoolcache
28 |
29 | - uses: actions/checkout@v4
30 | - name: Setup Python
31 | uses: actions/setup-python@v5
32 | with:
33 | python-version: ${{ matrix.python }}
34 |
35 | - name: Install git and ffmpeg
36 | run: sudo apt-get update && sudo apt-get install -y git ffmpeg
37 |
38 | - name: Install dependencies
39 | run: pip install -r requirements.txt pytest jiwer
40 |
41 | - name: Run test
42 | run: python -m pytest -rs tests
43 |
44 | test-backend:
45 | runs-on: ubuntu-latest
46 | strategy:
47 | matrix:
48 | python: ["3.10", "3.11", "3.12"]
49 |
50 | env:
51 | DEEPL_API_KEY: ${{ secrets.DEEPL_API_KEY }}
52 | TEST_ENV: true
53 |
54 | steps:
55 | - name: Clean up space for action
56 | run: rm -rf /opt/hostedtoolcache
57 |
58 | - uses: actions/checkout@v4
59 | - name: Setup Python
60 | uses: actions/setup-python@v5
61 | with:
62 | python-version: ${{ matrix.python }}
63 |
64 | - name: Install git and ffmpeg
65 | run: sudo apt-get update && sudo apt-get install -y git ffmpeg
66 |
67 | - name: Install dependencies
68 | run: pip install -r backend/requirements-backend.txt pytest pytest-asyncio jiwer
69 |
70 | - name: Run test
71 | run: python -m pytest -rs backend/tests
72 |
73 | test-shell-script:
74 | runs-on: ubuntu-latest
75 | strategy:
76 | matrix:
77 | python: [ "3.10", "3.11", "3.12" ]
78 |
79 | steps:
80 | - name: Clean up space for action
81 | run: rm -rf /opt/hostedtoolcache
82 |
83 | - uses: actions/checkout@v4
84 | - name: Setup Python
85 | uses: actions/setup-python@v5
86 | with:
87 | python-version: ${{ matrix.python }}
88 |
89 | - name: Install git and ffmpeg
90 | run: sudo apt-get update && sudo apt-get install -y git ffmpeg
91 |
92 | - name: Execute Install.sh
93 | run: |
94 | chmod +x ./Install.sh
95 | ./Install.sh
96 |
97 | - name: Execute start-webui.sh
98 | run: |
99 | chmod +x ./start-webui.sh
100 | timeout 60s ./start-webui.sh || true
101 |
102 |
--------------------------------------------------------------------------------
/.github/workflows/publish-docker.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Docker Hub
2 |
3 | on:
4 | # Triggers minor version ( vX.Y.Z-ShortHash )
5 | push:
6 | branches:
7 | - master
8 | # Triggers major version ( vX.Y.Z )
9 | release:
10 | types: [created]
11 |
12 | workflow_dispatch:
13 |
14 | jobs:
15 | build-and-push:
16 | runs-on: ubuntu-latest
17 |
18 | strategy:
19 | matrix:
20 | name: [whisper-webui, whisper-webui-backend]
21 |
22 | steps:
23 | - name: Clean up space for action
24 | run: rm -rf /opt/hostedtoolcache
25 |
26 | - name: Checkout repository
27 | uses: actions/checkout@v3
28 |
29 | - name: Set up Docker Buildx
30 | uses: docker/setup-buildx-action@v3
31 |
32 | - name: Set up QEMU
33 | uses: docker/setup-qemu-action@v3
34 |
35 | - name: Extract metadata
36 | id: meta
37 | run: |
38 | SHORT_SHA=$(git rev-parse --short HEAD)
39 | echo "SHORT_SHA=$SHORT_SHA" >> $GITHUB_ENV
40 |
41 | # Triggered by a release event — versioning as major ( vX.Y.Z )
42 | if [[ "${GITHUB_EVENT_NAME}" == "release" ]]; then
43 | TAG_NAME="${{ github.event.release.tag_name }}"
44 | echo "GIT_TAG=$TAG_NAME" >> $GITHUB_ENV
45 | echo "IS_RELEASE=true" >> $GITHUB_ENV
46 |
47 | # Triggered by a general push event — versioning as minor ( vX.Y.Z-ShortHash )
48 | else
49 | git fetch --tags
50 | LATEST_TAG=$(git tag --list 'v*.*.*' | sort -V | tail -n1)
51 | FALLBACK_TAG="${LATEST_TAG:-v0.0.0}"
52 | echo "GIT_TAG=${FALLBACK_TAG}-${SHORT_SHA}" >> $GITHUB_ENV
53 | echo "IS_RELEASE=false" >> $GITHUB_ENV
54 | fi
55 |
56 | - name: Set Dockerfile path
57 | id: dockerfile
58 | run: |
59 | if [ "${{ matrix.name }}" = "whisper-webui" ]; then
60 | echo "DOCKERFILE=./Dockerfile" >> $GITHUB_ENV
61 | elif [ "${{ matrix.name }}" = "whisper-webui-backend" ]; then
62 | echo "DOCKERFILE=./backend/Dockerfile" >> $GITHUB_ENV
63 | else
64 | echo "Unknown component: ${{ matrix.name }}"
65 | exit 1
66 | fi
67 |
68 | - name: Log in to Docker Hub
69 | uses: docker/login-action@v2
70 | with:
71 | username: ${{ secrets.DOCKER_USERNAME }}
72 | password: ${{ secrets.DOCKER_PASSWORD }}
73 |
74 | - name: Build and push Docker image (version tag)
75 | uses: docker/build-push-action@v5
76 | with:
77 | context: .
78 | file: ${{ env.DOCKERFILE }}
79 | push: true
80 | tags: |
81 | ${{ secrets.DOCKER_USERNAME }}/${{ matrix.name }}:${{ env.GIT_TAG }}
82 |
83 | - name: Tag and push as latest (if release)
84 | if: env.IS_RELEASE == 'true'
85 | run: |
86 | docker pull ${{ secrets.DOCKER_USERNAME }}/${{ matrix.name }}:${{ env.GIT_TAG }}
87 | docker tag ${{ secrets.DOCKER_USERNAME }}/${{ matrix.name }}:${{ env.GIT_TAG }} \
88 | ${{ secrets.DOCKER_USERNAME }}/${{ matrix.name }}:latest
89 | docker push ${{ secrets.DOCKER_USERNAME }}/${{ matrix.name }}:latest
90 |
91 | - name: Log out of Docker Hub
92 | run: docker logout
93 |
--------------------------------------------------------------------------------
/.github/workflows/publish-ghcr.yml:
--------------------------------------------------------------------------------
1 | name: Publish to GHCR
2 |
3 | on:
4 | # Triggers minor version ( vX.Y.Z-ShortHash )
5 | push:
6 | branches:
7 | - master
8 | # Triggers major version ( vX.Y.Z )
9 | release:
10 | types: [created]
11 |
12 | workflow_dispatch:
13 |
14 | jobs:
15 | build-and-push:
16 | runs-on: ubuntu-latest
17 | permissions:
18 | packages: write
19 |
20 | strategy:
21 | matrix:
22 | name: [whisper-webui, whisper-webui-backend]
23 |
24 | steps:
25 | - name: Checkout repository
26 | uses: actions/checkout@v3
27 |
28 | - name: Set up Docker Buildx
29 | uses: docker/setup-buildx-action@v3
30 |
31 | - name: Set up QEMU
32 | uses: docker/setup-qemu-action@v3
33 |
34 | - name: Extract metadata
35 | id: meta
36 | run: |
37 | SHORT_SHA=$(git rev-parse --short HEAD)
38 | echo "SHORT_SHA=$SHORT_SHA" >> $GITHUB_ENV
39 |
40 | # Triggered by a release event — versioning as major ( vX.Y.Z )
41 | if [[ "${GITHUB_EVENT_NAME}" == "release" ]]; then
42 | TAG_NAME="${{ github.event.release.tag_name }}"
43 | echo "GIT_TAG=$TAG_NAME" >> $GITHUB_ENV
44 | echo "IS_RELEASE=true" >> $GITHUB_ENV
45 |
46 | # Triggered by a general push event — versioning as minor ( vX.Y.Z-ShortHash )
47 | else
48 | git fetch --tags
49 | LATEST_TAG=$(git tag --list 'v*.*.*' | sort -V | tail -n1)
50 | FALLBACK_TAG="${LATEST_TAG:-v0.0.0}"
51 | echo "GIT_TAG=${FALLBACK_TAG}-${SHORT_SHA}" >> $GITHUB_ENV
52 | echo "IS_RELEASE=false" >> $GITHUB_ENV
53 | fi
54 |
55 | echo "REPO_OWNER_LC=${GITHUB_REPOSITORY_OWNER,,}" >> $GITHUB_ENV
56 |
57 | - name: Set Dockerfile path
58 | id: dockerfile
59 | run: |
60 | if [ "${{ matrix.name }}" = "whisper-webui" ]; then
61 | echo "DOCKERFILE=./Dockerfile" >> $GITHUB_ENV
62 | elif [ "${{ matrix.name }}" = "whisper-webui-backend" ]; then
63 | echo "DOCKERFILE=./backend/Dockerfile" >> $GITHUB_ENV
64 | else
65 | echo "Unknown component: ${{ matrix.name }}"
66 | exit 1
67 | fi
68 |
69 | - name: Log in to GitHub Container Registry
70 | uses: docker/login-action@v2
71 | with:
72 | registry: ghcr.io
73 | username: ${{ github.actor }}
74 | password: ${{ secrets.GITHUB_TOKEN }}
75 |
76 | - name: Build and push Docker image (version tag)
77 | uses: docker/build-push-action@v5
78 | with:
79 | context: .
80 | file: ${{ env.DOCKERFILE }}
81 | push: true
82 | tags: |
83 | ghcr.io/${{ env.REPO_OWNER_LC }}/${{ matrix.name }}:${{ env.GIT_TAG }}
84 |
85 | - name: Tag and push as latest (if release)
86 | if: env.IS_RELEASE == 'true'
87 | run: |
88 | docker pull ghcr.io/${{ env.REPO_OWNER_LC }}/${{ matrix.name }}:${{ env.GIT_TAG }}
89 | docker tag ghcr.io/${{ env.REPO_OWNER_LC }}/${{ matrix.name }}:${{ env.GIT_TAG }} \
90 | ghcr.io/${{ env.REPO_OWNER_LC }}/${{ matrix.name }}:latest
91 | docker push ghcr.io/${{ env.REPO_OWNER_LC }}/${{ matrix.name }}:latest
92 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.wav
2 | *.png
3 | *.mp4
4 | *.mp3
5 | **/.env
6 | **/.idea/
7 | **/.pytest_cache/
8 | **/venv/
9 | **/__pycache__/
10 | outputs/
11 | models/
12 | modules/yt_tmp.wav
13 | configs/default_parameters.yaml
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM debian:bookworm-slim AS builder
2 |
3 | RUN apt-get update && \
4 | apt-get install -y curl git python3 python3-pip python3-venv && \
5 | rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* && \
6 | mkdir -p /Whisper-WebUI
7 |
8 | WORKDIR /Whisper-WebUI
9 |
10 | COPY requirements.txt .
11 |
12 | RUN python3 -m venv venv && \
13 | . venv/bin/activate && \
14 | pip install -U -r requirements.txt
15 |
16 |
17 | FROM debian:bookworm-slim AS runtime
18 |
19 | RUN apt-get update && \
20 | apt-get install -y curl ffmpeg python3 && \
21 | rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
22 |
23 | WORKDIR /Whisper-WebUI
24 |
25 | COPY . .
26 | COPY --from=builder /Whisper-WebUI/venv /Whisper-WebUI/venv
27 |
28 | VOLUME [ "/Whisper-WebUI/models" ]
29 | VOLUME [ "/Whisper-WebUI/outputs" ]
30 |
31 | ENV PATH="/Whisper-WebUI/venv/bin:$PATH"
32 | ENV LD_LIBRARY_PATH=/Whisper-WebUI/venv/lib64/python3.11/site-packages/nvidia/cublas/lib:/Whisper-WebUI/venv/lib64/python3.11/site-packages/nvidia/cudnn/lib
33 |
34 | ENTRYPOINT [ "python", "app.py" ]
35 |
--------------------------------------------------------------------------------
/Install.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | if not exist "%~dp0\venv\Scripts" (
4 | echo Creating venv...
5 | python -m venv venv
6 | )
7 | echo Checked the venv folder. Now installing requirements...
8 |
9 | call "%~dp0\venv\scripts\activate"
10 |
11 | python -m pip install -U pip
12 | pip install -r requirements.txt
13 |
14 | if errorlevel 1 (
15 | echo.
16 | echo Requirements installation failed. Please remove the venv folder and run Install.bat again.
17 | ) else (
18 | echo.
19 | echo Requirements installed successfully.
20 | )
21 | pause
--------------------------------------------------------------------------------
/Install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ ! -d "venv" ]; then
4 | echo "Creating virtual environment..."
5 | python -m venv venv
6 | fi
7 |
8 | source venv/bin/activate
9 |
10 | python -m pip install -U pip
11 | pip install -r requirements.txt && echo "Requirements installed successfully." || {
12 | echo ""
13 | echo "Requirements installation failed. Please remove the venv folder and run the script again."
14 | deactivate
15 | exit 1
16 | }
17 |
18 | deactivate
19 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Whisper-WebUI
2 | A Gradio-based browser interface for [Whisper](https://github.com/openai/whisper). You can use it as an Easy Subtitle Generator!
3 |
4 | 
5 |
6 |
7 |
8 | ## Notebook
9 | If you wish to try this on Colab, you can do so [here](https://colab.research.google.com/github/jhj0517/Whisper-WebUI/blob/master/notebook/whisper-webui.ipynb)!
10 |
11 | # Features
12 | - Select the Whisper implementation you want to use from :
13 | - [openai/whisper](https://github.com/openai/whisper)
14 | - [SYSTRAN/faster-whisper](https://github.com/SYSTRAN/faster-whisper) (used by default)
15 | - [Vaibhavs10/insanely-fast-whisper](https://github.com/Vaibhavs10/insanely-fast-whisper)
16 | - Generate subtitles from various sources, including :
17 | - Files
18 | - YouTube
19 | - Microphone
20 | - Currently supported subtitle formats :
21 | - SRT
22 | - WebVTT
23 | - txt ( plain text without timestamps )
24 | - Speech to Text Translation
25 | - From other languages to English. ( This is Whisper's end-to-end speech-to-text translation feature )
26 | - Text to Text Translation
27 | - Translate subtitle files using Facebook NLLB models
28 | - Translate subtitle files using DeepL API
29 | - Pre-processing audio input with [Silero VAD](https://github.com/snakers4/silero-vad).
30 | - Pre-processing audio input to separate BGM with [UVR](https://github.com/Anjok07/ultimatevocalremovergui).
31 | - Post-processing with speaker diarization using the [pyannote](https://huggingface.co/pyannote/speaker-diarization-3.1) model.
32 | - To download the pyannote model, you need to have a Huggingface token and manually accept their terms in the pages below.
33 | 1. https://huggingface.co/pyannote/speaker-diarization-3.1
34 | 2. https://huggingface.co/pyannote/segmentation-3.0
35 |
36 | ### Pipeline Diagram
37 | 
38 |
39 | # Installation and Running
40 |
41 | - ## Running with Pinokio
42 |
43 | The app is able to run with [Pinokio](https://github.com/pinokiocomputer/pinokio).
44 |
45 | 1. Install [Pinokio Software](https://program.pinokio.computer/#/?id=install).
46 | 2. Open the software and search for Whisper-WebUI and install it.
47 | 3. Start Whisper-WebUI and connect to `http://localhost:7860`.
48 |
49 | - ## Running with Docker
50 |
51 | 1. Install and launch [Docker-Desktop](https://www.docker.com/products/docker-desktop/).
52 |
53 | 2. Git clone the repository
54 |
55 | ```sh
56 | git clone https://github.com/jhj0517/Whisper-WebUI.git
57 | ```
58 |
59 | 3. Build the image ( the image is about 7GB )
60 |
61 | ```sh
62 | docker compose build
63 | ```
64 |
65 | 4. Run the container
66 |
67 | ```sh
68 | docker compose up
69 | ```
70 |
71 | 5. Connect to the WebUI with your browser at `http://localhost:7860`
72 |
73 | If needed, update the [`docker-compose.yaml`](https://github.com/jhj0517/Whisper-WebUI/blob/master/docker-compose.yaml) to match your environment.
74 |
75 | - ## Run Locally
76 |
77 | ### Prerequisite
78 | To run this WebUI, you need to have `git`, `3.10 <= python <= 3.12`, and `FFmpeg` installed.
79 |
80 | **Edit `--extra-index-url` in the [`requirements.txt`](https://github.com/jhj0517/Whisper-WebUI/blob/master/requirements.txt) to match your device.**
81 | By default, the WebUI assumes you're using an Nvidia GPU and **CUDA 12.6.** If you're using Intel or another CUDA version, read the [`requirements.txt`](https://github.com/jhj0517/Whisper-WebUI/blob/master/requirements.txt) and edit `--extra-index-url`.
82 |
83 | Please follow the links below to install the necessary software:
84 | - git : [https://git-scm.com/downloads](https://git-scm.com/downloads)
85 | - python : [https://www.python.org/downloads/](https://www.python.org/downloads/) **`3.10 ~ 3.12` is recommended.**
86 | - FFmpeg : [https://ffmpeg.org/download.html](https://ffmpeg.org/download.html)
87 | - CUDA : [https://developer.nvidia.com/cuda-downloads](https://developer.nvidia.com/cuda-downloads)
88 |
89 | After installing FFmpeg, **make sure to add the `FFmpeg/bin` folder to your system PATH!**
90 |
91 | ### Installation Using the Script Files
92 |
93 | 1. git clone this repository
94 | ```shell
95 | git clone https://github.com/jhj0517/Whisper-WebUI.git
96 | ```
97 | 2. Run `Install.bat` or `Install.sh` to install dependencies. (It will create a `venv` directory and install dependencies there.)
98 | 3. Start WebUI with `start-webui.bat` or `start-webui.sh` (It will run `python app.py` after activating the venv)
99 |
100 | You can also run the project with command line arguments if you like; see the [wiki](https://github.com/jhj0517/Whisper-WebUI/wiki/Command-Line-Arguments) for a guide to the arguments.
101 |
102 | # VRAM Usages
103 | This project is integrated with [faster-whisper](https://github.com/guillaumekln/faster-whisper) by default for better VRAM usage and transcription speed.
104 |
105 | According to faster-whisper, the efficiency of the optimized whisper model is as follows:
106 | | Implementation | Precision | Beam size | Time | Max. GPU memory | Max. CPU memory |
107 | |-------------------|-----------|-----------|-------|-----------------|-----------------|
108 | | openai/whisper | fp16 | 5 | 4m30s | 11325MB | 9439MB |
109 | | faster-whisper | fp16 | 5 | 54s | 4755MB | 3244MB |
110 |
111 | If you want to use an implementation other than faster-whisper, use the `--whisper_type` arg with the repository name.
112 | Read [wiki](https://github.com/jhj0517/Whisper-WebUI/wiki/Command-Line-Arguments) for more info about CLI args.
113 |
114 | If you want to use a fine-tuned model, manually place the models in `models/Whisper/` corresponding to the implementation.
115 |
116 | Alternatively, if you enter the huggingface repo id (e.g. [deepdml/faster-whisper-large-v3-turbo-ct2](https://huggingface.co/deepdml/faster-whisper-large-v3-turbo-ct2)) in the "Model" dropdown, it will be automatically downloaded into the directory.
117 |
118 | 
119 |
120 | # REST API
121 | If you're interested in deploying this app as a REST API, please check out [/backend](https://github.com/jhj0517/Whisper-WebUI/tree/master/backend).
122 |
123 | ## TODO🗓
124 |
125 | - [x] Add DeepL API translation
126 | - [x] Add NLLB Model translation
127 | - [x] Integrate with faster-whisper
128 | - [x] Integrate with insanely-fast-whisper
129 | - [x] Integrate with whisperX ( Only speaker diarization part )
130 | - [x] Add background music separation pre-processing with [UVR](https://github.com/Anjok07/ultimatevocalremovergui)
131 | - [x] Add FastAPI script
132 | - [ ] Add CLI usages
133 | - [ ] Support real-time transcription for microphone
134 |
135 | ### Translation 🌐
136 | Any PRs that add a language translation to [translation.yaml](https://github.com/jhj0517/Whisper-WebUI/blob/master/configs/translation.yaml) would be greatly appreciated!
137 |
--------------------------------------------------------------------------------
/backend/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM debian:bookworm-slim AS builder
2 |
3 | RUN apt-get update && \
4 | apt-get install -y curl git python3 python3-pip python3-venv && \
5 | rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* && \
6 | mkdir -p /Whisper-WebUI
7 |
8 | WORKDIR /Whisper-WebUI
9 |
10 | COPY backend/ backend/
11 | COPY requirements.txt requirements.txt
12 |
13 | RUN python3 -m venv venv && \
14 | . venv/bin/activate && \
15 | pip install -U -r backend/requirements-backend.txt
16 |
17 |
18 | FROM debian:bookworm-slim AS runtime
19 |
20 | RUN apt-get update && \
21 | apt-get install -y curl ffmpeg python3 && \
22 | rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
23 |
24 | WORKDIR /Whisper-WebUI
25 |
26 | COPY . .
27 | COPY --from=builder /Whisper-WebUI/venv /Whisper-WebUI/venv
28 |
29 | VOLUME [ "/Whisper-WebUI/models" ]
30 | VOLUME [ "/Whisper-WebUI/outputs" ]
31 | VOLUME [ "/Whisper-WebUI/backend" ]
32 |
33 | ENV PATH="/Whisper-WebUI/venv/bin:$PATH"
34 | ENV LD_LIBRARY_PATH=/Whisper-WebUI/venv/lib64/python3.11/site-packages/nvidia/cublas/lib:/Whisper-WebUI/venv/lib64/python3.11/site-packages/nvidia/cudnn/lib
35 |
36 | ENTRYPOINT ["uvicorn", "backend.main:app", "--host", "0.0.0.0", "--port", "8000"]
37 |
--------------------------------------------------------------------------------
/backend/README.md:
--------------------------------------------------------------------------------
1 | # Whisper-WebUI REST API
2 | REST API for Whisper-WebUI. Documentation is auto-generated upon deploying the app.
3 | [Swagger UI](https://github.com/swagger-api/swagger-ui) is available at `/docs` or the root URL with redirection. [Redoc](https://github.com/Redocly/redoc) is available at `/redoc`.
4 |
5 | # Setup and Installation
6 |
7 | Installation assumes that you are in the root directory of Whisper-WebUI.
8 |
9 | 1. Create a `.env` file at `backend/configs/.env`
10 | ```
11 | HF_TOKEN="YOUR_HF_TOKEN FOR DIARIZATION MODEL (READ PERMISSION)"
12 | DB_URL="sqlite:///backend/records.db"
13 | ```
14 | `HF_TOKEN` is used to download the diarization model, and `DB_URL` indicates where your DB file is located. It is stored in `backend/` by default.
15 |
16 | 2. Install dependencies
17 | ```
18 | pip install -r backend/requirements-backend.txt
19 | ```
20 |
21 | 3. Deploy the server with `uvicorn` or any other ASGI server.
22 | ```
23 | uvicorn backend.main:app --host 0.0.0.0 --port 8000
24 | ```
25 |
26 | ### Deploy with your domain name
27 | You can deploy the server with your domain name by setting up a reverse proxy with Nginx.
28 |
29 | 1. Install Nginx if you don't already have it.
30 | - Linux : https://nginx.org/en/docs/install.html
31 | - Windows : https://nginx.org/en/docs/windows.html
32 |
33 | 2. Edit [`nginx.conf`](https://github.com/jhj0517/Whisper-WebUI/blob/master/backend/nginx/nginx.conf) for your domain name.
34 | https://github.com/jhj0517/Whisper-WebUI/blob/895cafe400944396ad8be5b1cc793b54fecc8bbe/backend/nginx/nginx.conf#L12
35 |
36 | 3. Add an A record with your public IPv4 address at your domain provider. (You can find the address by searching "What is my IP" on Google.)
37 |
38 | 4. Open a terminal, go to the location of [`nginx.conf`](https://github.com/jhj0517/Whisper-WebUI/blob/master/backend/nginx/nginx.conf), and start the nginx server there, so that you can manage the nginx-related logs in that directory.
39 | ```shell
40 | cd backend/nginx
41 | nginx -c "/path/to/Whisper-WebUI/backend/nginx/nginx.conf"
42 | ```
43 |
44 | 5. Open another terminal in the project root `/Whisper-WebUI`, and deploy the app with `uvicorn` as above. The app will now be available at your domain.
45 | ```shell
46 | uvicorn backend.main:app --host 0.0.0.0 --port 8000
47 | ```
48 |
49 | 6. To turn off nginx, use `nginx -s stop`.
50 | ```shell
51 | cd backend/nginx
52 | nginx -s stop -c "/path/to/Whisper-WebUI/backend/nginx/nginx.conf"
53 | ```
54 |
55 |
56 | ## Configuration
57 | You can set some server configurations in [config.yaml](https://github.com/jhj0517/Whisper-WebUI/blob/master/backend/configs/config.yaml).
58 | For example, the initial model size for Whisper, or the cleanup frequency and TTL for cached files.
59 | If an endpoint generates and saves files, all output files are stored in the `cache` directory; e.g. the separated vocal/instrumental files for `/bgm-separation` are saved there.
60 |
61 | ## Docker
62 | You can also deploy the server with Docker for easy deployment.
63 | The Dockerfile should be built from the root directory of Whisper-WebUI.
64 |
65 | 1. git clone this repository
66 | ```
67 | git clone https://github.com/jhj0517/Whisper-WebUI.git
68 | ```
69 | 2. Update the volume paths in `docker-compose.yaml` to mount your local directories
70 | https://github.com/jhj0517/Whisper-WebUI/blob/1dd708ec3844dbf0c1f77de9ef5764e883dd4c78/backend/docker-compose.yaml#L12-L15
71 | 3. Build the image
72 | ```
73 | docker compose -f backend/docker-compose.yaml build
74 | ```
75 | 4. Run the container
76 | ```
77 | docker compose -f backend/docker-compose.yaml up
78 | ```
79 |
80 | 5. Then you can read the docs at `localhost:8000` (the default port is set to `8000` in `docker-compose.yaml`) and run your own tests.
81 |
82 |
83 | # Architecture
84 |
85 | 
86 |
87 | The response can be obtained through [the polling API](https://docs.oracle.com/en/cloud/saas/marketing/responsys-develop/API/REST/Async/asyncApi-v1.3-requests-requestId-get.htm).
88 | Each task is stored in the DB whenever the task is queued or updated by the process.
89 |
90 | When the client first sends the `POST` request, the server returns an `identifier` to the client that can be used to track the status of the task. The task status is updated by the processes, and once the task is completed, the client can finally obtain the result.
91 |
92 | The client needs to implement manual API polling for this; here is an example for a Python client:
93 | ```python
94 | import time
95 | from typing import Any, Optional
96 |
97 | def wait_for_task_completion(identifier: str,
98 |                              max_attempts: int = 20,
99 |                              frequency: int = 3) -> Optional[Any]:
100 |     """
101 |     Polls the task status every `frequency` seconds until the task is completed,
102 |     failed, or `max_attempts` is reached.
103 |     """
104 |     attempts = 0
105 |     while attempts < max_attempts:
106 |         task = fetch_task(identifier)
107 |         # `TaskStatus` values are lowercase strings (see backend/db/task/models.py)
108 |         status = task.json()["status"]
109 |         if status == "completed":
110 |             return task.json()["result"]
111 |         if status == "failed":
112 |             raise Exception("Task polling failed")
113 |         time.sleep(frequency)
114 |         attempts += 1
115 |     return None
116 | ```
117 |
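118 | The example assumes a `fetch_task` helper that requests the task status endpoint. A minimal sketch using `httpx`, assuming a local deployment at `http://localhost:8000` (the exact route is an assumption; verify it in the auto-generated Swagger UI at `/docs`):
119 |
120 | ```python
121 | import httpx
122 |
123 | SERVER_URL = "http://localhost:8000"  # assumed local deployment
124 |
125 | def fetch_task(identifier: str) -> httpx.Response:
126 |     # The task route below is an assumption; check /docs for the actual endpoint.
127 |     return httpx.get(f"{SERVER_URL}/task/{identifier}")
128 | ```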
--------------------------------------------------------------------------------
/backend/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/__init__.py
--------------------------------------------------------------------------------
/backend/cache/cached_files_are_generated_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/cache/cached_files_are_generated_here
--------------------------------------------------------------------------------
/backend/common/audio.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 | import numpy as np
3 | import httpx
4 | import faster_whisper
5 | from pydantic import BaseModel
6 | from fastapi import (
7 | HTTPException,
8 | UploadFile,
9 | )
10 | from typing import Annotated, Any, BinaryIO, Literal, Generator, Union, Optional, List, Tuple
11 |
12 |
13 | class AudioInfo(BaseModel):
14 | duration: float
15 |
16 |
17 | async def read_audio(
18 | file: Optional[UploadFile] = None,
19 | file_url: Optional[str] = None
20 | ):
21 | """Read audio from "UploadFile". This resamples sampling rates to 16000."""
22 | if (file and file_url) or (not file and not file_url):
23 | raise HTTPException(status_code=400, detail="Provide only one of file or file_url")
24 |
25 | if file:
26 | file_content = await file.read()
27 | elif file_url:
28 | async with httpx.AsyncClient() as client:
29 | file_response = await client.get(file_url)
30 | if file_response.status_code != 200:
31 | raise HTTPException(status_code=422, detail="Could not download the file")
32 | file_content = file_response.content
33 | file_bytes = BytesIO(file_content)
34 | audio = faster_whisper.audio.decode_audio(file_bytes)
35 | duration = len(audio) / 16000
36 | return audio, AudioInfo(duration=duration)
37 |
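38 | # Usage sketch (hypothetical URL, inside an async context):
39 | #     audio, info = await read_audio(file_url="https://example.com/sample.wav")
40 | #     print(info.duration)
41 | # `decode_audio` resamples to 16 kHz, which is why the duration is `len(audio) / 16000`.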
--------------------------------------------------------------------------------
/backend/common/cache_manager.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | from typing import Optional
4 |
5 | from modules.utils.paths import BACKEND_CACHE_DIR
6 |
7 |
8 | def cleanup_old_files(cache_dir: str = BACKEND_CACHE_DIR, ttl: int = 60):
9 | now = time.time()
10 | place_holder_name = "cached_files_are_generated_here"
11 | for root, dirs, files in os.walk(cache_dir):
12 | for filename in files:
13 | if filename == place_holder_name:
14 | continue
15 | filepath = os.path.join(root, filename)
16 | if now - os.path.getmtime(filepath) > ttl:
17 | try:
18 | os.remove(filepath)
19 | except Exception as e:
20 | print(f"Error removing {filepath}")
21 | raise
22 |
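23 | # Note: `cleanup_old_files` is run periodically by the daemon thread created in
24 | # `backend/main.py` (`clean_cache_thread`), with `ttl` and `frequency` taken from the
25 | # `cache` section of `backend/configs/config.yaml`.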
--------------------------------------------------------------------------------
/backend/common/compresser.py:
--------------------------------------------------------------------------------
1 | import os
2 | import zipfile
3 | from typing import List, Optional
4 | import hashlib
5 |
6 |
7 | def compress_files(file_paths: List[str], output_zip_path: str) -> str:
8 | """
9 | Compress multiple files into a single zip file.
10 |
11 | Args:
12 | file_paths (List[str]): List of paths to files to be compressed.
13 |         output_zip_path (str): Path and name of the output zip file.
14 |
15 | Raises:
16 | FileNotFoundError: If any of the input files doesn't exist.
17 | """
18 | os.makedirs(os.path.dirname(output_zip_path), exist_ok=True)
19 | compression = zipfile.ZIP_DEFLATED
20 |
21 | with zipfile.ZipFile(output_zip_path, 'w', compression=compression) as zipf:
22 | for file_path in file_paths:
23 | if not os.path.exists(file_path):
24 | raise FileNotFoundError(f"File not found: {file_path}")
25 |
26 | file_name = os.path.basename(file_path)
27 | zipf.write(file_path, file_name)
28 | return output_zip_path
29 |
30 |
31 | def get_file_hash(file_path: str) -> str:
32 | """Generate the hash of a file using the specified hashing algorithm. It generates hash by content not path. """
33 | hash_func = hashlib.new("sha256")
34 | try:
35 | with open(file_path, 'rb') as f:
36 | for chunk in iter(lambda: f.read(4096), b""):
37 | hash_func.update(chunk)
38 | return hash_func.hexdigest()
39 | except FileNotFoundError:
40 | return f"File not found: {file_path}"
41 | except Exception as e:
42 | return f"An error occurred: {str(e)}"
43 |
44 |
45 | def find_file_by_hash(dir_path: str, hash_str: str) -> Optional[str]:
46 | """Get file path from the directory based on its hash"""
47 |     if not os.path.isdir(dir_path):
48 |         raise ValueError(f"Directory {dir_path} does not exist")
49 |
50 | files = [os.path.join(dir_path, f) for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f))]
51 |
52 | for f in files:
53 | f_hash = get_file_hash(f)
54 | if hash_str == f_hash:
55 | return f
56 | return None
57 |
58 |
59 |
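60 | # Usage sketch (illustrative paths): round-trip a cached file by its content hash.
61 | #     zip_path = compress_files(["backend/cache/UVR/vocal.wav"], "backend/cache/result.zip")
62 | #     file_hash = get_file_hash(zip_path)
63 | #     assert find_file_by_hash("backend/cache", file_hash) == zip_path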
--------------------------------------------------------------------------------
/backend/common/config_loader.py:
--------------------------------------------------------------------------------
1 | from dotenv import load_dotenv
2 | import os
3 | from modules.utils.paths import SERVER_CONFIG_PATH, SERVER_DOTENV_PATH
4 | from modules.utils.files_manager import load_yaml, save_yaml
5 |
6 | import functools
7 |
8 |
9 | @functools.lru_cache
10 | def load_server_config(config_path: str = SERVER_CONFIG_PATH) -> dict:
11 | if os.getenv("TEST_ENV", "false").lower() == "true":
12 | server_config = load_yaml(config_path)
13 | server_config["whisper"]["model_size"] = "tiny"
14 | server_config["whisper"]["compute_type"] = "float32"
15 | save_yaml(server_config, config_path)
16 |
17 | return load_yaml(config_path)
18 |
19 |
20 | @functools.lru_cache
21 | def read_env(key: str, default: str = None, dotenv_path: str = SERVER_DOTENV_PATH):
22 | load_dotenv(dotenv_path)
23 | value = os.getenv(key, default)
24 | return value
25 |
26 |
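27 | # Usage sketch: values are read from `backend/configs/.env` (SERVER_DOTENV_PATH) by default,
28 | # e.g. `db_url = read_env("DB_URL", "sqlite:///backend/records.db")`, as done in
29 | # `backend/db/db_instance.py`.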
--------------------------------------------------------------------------------
/backend/common/models.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field, validator
2 | from typing import List, Any, Optional
3 | from backend.db.task.models import TaskStatus, ResultType, TaskType
4 |
5 |
6 | class QueueResponse(BaseModel):
7 | identifier: str = Field(..., description="Unique identifier for the queued task that can be used for tracking")
8 | status: TaskStatus = Field(..., description="Current status of the task")
9 | message: str = Field(..., description="Message providing additional information about the task")
10 |
11 |
12 | class Response(BaseModel):
13 | identifier: str
14 | message: str
15 |
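16 |
17 | # Note: `QueueResponse` is what the routers return right after queuing a task; the client
18 | # then polls the task status endpoint with `identifier` until a terminal status is reached.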
--------------------------------------------------------------------------------
/backend/configs/.env.example:
--------------------------------------------------------------------------------
1 | HF_TOKEN="YOUR_HF_TOKEN FOR DIARIZATION MODEL (READ PERMISSION)"
2 | DB_URL="sqlite:///backend/records.db"
--------------------------------------------------------------------------------
/backend/configs/config.yaml:
--------------------------------------------------------------------------------
1 | whisper:
2 |   # Default implementation is faster-whisper. This indicates the model name within `models/Whisper/faster-whisper`
3 | model_size: large-v2
4 | # Compute type. 'float16' for CUDA, 'float32' for CPU.
5 | compute_type: float16
6 | # Whether to offload the model after the inference.
7 | enable_offload: true
8 |
9 | bgm_separation:
10 | # UVR model sizes between ["UVR-MDX-NET-Inst_HQ_4", "UVR-MDX-NET-Inst_3"]
11 | model_size: UVR-MDX-NET-Inst_HQ_4
12 |   # Whether to offload the model after the inference. Should be true if your setup has less than 16GB of VRAM
13 | enable_offload: true
14 | # Device to load BGM separation model between ["cuda", "cpu", "xpu"]
15 | device: cuda
16 |
17 | # Settings that apply to the `cache` directory. The output files for `/bgm-separation` are stored in the `cache` directory.
18 | # (You can check out the actual generated files by testing `/bgm-separation`.)
19 | # You can adjust the TTL/cleanup frequency of the files in the `cache` directory here.
20 | cache:
21 | # TTL (Time-To-Live) in seconds, defaults to 10 minutes
22 | ttl: 600
23 |   # Clean up frequency in seconds, defaults to 1 minute
24 | frequency: 60
25 |
26 |
--------------------------------------------------------------------------------
/backend/db/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/db/__init__.py
--------------------------------------------------------------------------------
/backend/db/db_instance.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import os
3 | from sqlalchemy import create_engine
4 | from sqlalchemy.orm import sessionmaker
5 | from functools import wraps
6 | from sqlalchemy.exc import SQLAlchemyError
7 | from fastapi import HTTPException
8 | from sqlmodel import SQLModel
9 | from dotenv import load_dotenv
10 |
11 | from backend.common.config_loader import read_env
12 |
13 |
14 | @functools.lru_cache
15 | def init_db():
16 | db_url = read_env("DB_URL", "sqlite:///backend/records.db")
17 | engine = create_engine(db_url, connect_args={"check_same_thread": False})
18 | SQLModel.metadata.create_all(engine)
19 | return sessionmaker(autocommit=False, autoflush=False, bind=engine)
20 |
21 |
22 | def get_db_session():
23 | db_instance = init_db()
24 | return db_instance()
25 |
26 |
27 | def handle_database_errors(func):
28 | @wraps(func)
29 | def wrapper(*args, **kwargs):
30 | session = None
31 | try:
32 | session = get_db_session()
33 | kwargs['session'] = session
34 |
35 | return func(*args, **kwargs)
36 | except Exception as e:
37 | print(f"Database error has occurred: {e}")
38 | raise
39 | finally:
40 | if session:
41 | session.close()
42 | return wrapper
43 |
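44 | # Usage sketch: DAO functions decorated with `@handle_database_errors` receive a fresh
45 | # `session` kwarg and have it closed afterwards, e.g. `add_task_to_db` in
46 | # `backend/db/task/dao.py`.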
--------------------------------------------------------------------------------
/backend/db/task/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/db/task/__init__.py
--------------------------------------------------------------------------------
/backend/db/task/dao.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any
2 | from sqlalchemy.orm import Session
3 | from fastapi import Depends
4 |
5 | from ..db_instance import handle_database_errors, get_db_session
6 | from .models import Task, TasksResult, TaskStatus
7 |
8 |
9 | @handle_database_errors
10 | def add_task_to_db(
11 | session,
12 | status=TaskStatus.QUEUED,
13 | task_type=None,
14 | language=None,
15 | task_params=None,
16 | file_name=None,
17 | url=None,
18 | audio_duration=None,
19 | ):
20 | """
21 | Add task to the db
22 | """
23 | task = Task(
24 | status=status,
25 | language=language,
26 | file_name=file_name,
27 | url=url,
28 | task_type=task_type,
29 | task_params=task_params,
30 | audio_duration=audio_duration,
31 | )
32 | session.add(task)
33 | session.commit()
34 | return task.uuid
35 |
36 |
37 | @handle_database_errors
38 | def update_task_status_in_db(
39 | identifier: str,
40 | update_data: Dict[str, Any],
41 | session: Session,
42 | ):
43 | """
44 | Update task status and attributes in the database.
45 |
46 | Args:
47 | identifier (str): Identifier of the task to be updated.
48 | update_data (Dict[str, Any]): Dictionary containing the attributes to update along with their new values.
49 | session (Session, optional): Database session. Defaults to Depends(get_db_session).
50 |
51 | Returns:
52 | None
53 | """
54 | task = session.query(Task).filter_by(uuid=identifier).first()
55 | if task:
56 | for key, value in update_data.items():
57 | setattr(task, key, value)
58 | session.commit()
59 |
60 |
61 | @handle_database_errors
62 | def get_task_status_from_db(
63 | identifier: str, session: Session
64 | ):
65 | """Retrieve task status from db"""
66 | task = session.query(Task).filter(Task.uuid == identifier).first()
67 | if task:
68 | return task
69 | else:
70 | return None
71 |
72 |
73 | @handle_database_errors
74 | def get_all_tasks_status_from_db(session: Session):
75 | """Get all tasks from db"""
76 | columns = [Task.uuid, Task.status, Task.task_type]
77 | query = session.query(*columns)
78 | tasks = [task for task in query]
79 | return TasksResult(tasks=tasks)
80 |
81 |
82 | @handle_database_errors
83 | def delete_task_from_db(identifier: str, session: Session):
84 | """Delete task from db"""
85 | task = session.query(Task).filter(Task.uuid == identifier).first()
86 |
87 | if task:
88 | # If the task exists, delete it from the database
89 | session.delete(task)
90 | session.commit()
91 | return True
92 | else:
93 | # If the task does not exist, return False
94 | return False
95 |
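96 | # Usage sketch of the task lifecycle as driven by the routers:
97 | #     identifier = add_task_to_db(status=TaskStatus.QUEUED, task_type=TaskType.TRANSCRIPTION)
98 | #     update_task_status_in_db(identifier, {"status": TaskStatus.COMPLETED, "result": {...}})
99 | #     task = get_task_status_from_db(identifier)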
--------------------------------------------------------------------------------
/backend/db/task/models.py:
--------------------------------------------------------------------------------
1 | # Ported from https://github.com/pavelzbornik/whisperX-FastAPI/blob/main/app/models.py
2 |
3 | from enum import Enum
4 | from pydantic import BaseModel
5 | from typing import Optional, List
6 | from uuid import uuid4
7 | from datetime import datetime
8 | from sqlalchemy.types import Enum as SQLAlchemyEnum
9 | from typing import Any
10 | from sqlmodel import SQLModel, Field, JSON, Column
11 |
12 |
13 | class ResultType(str, Enum):
14 | JSON = "json"
15 | FILEPATH = "filepath"
16 |
17 |
18 | class TaskStatus(str, Enum):
19 | PENDING = "pending"
20 | IN_PROGRESS = "in_progress"
21 | COMPLETED = "completed"
22 | FAILED = "failed"
23 | CANCELLED = "cancelled"
24 | QUEUED = "queued"
25 | PAUSED = "paused"
26 | RETRYING = "retrying"
27 |
28 | def __str__(self):
29 | return self.value
30 |
31 |
32 | class TaskType(str, Enum):
33 | TRANSCRIPTION = "transcription"
34 | VAD = "vad"
35 | BGM_SEPARATION = "bgm_separation"
36 |
37 | def __str__(self):
38 | return self.value
39 |
40 |
41 | class TaskStatusResponse(BaseModel):
42 | """`TaskStatusResponse` is a wrapper class that hides sensitive information from `Task`"""
43 | identifier: str = Field(..., description="Unique identifier for the queued task that can be used for tracking")
44 | status: TaskStatus = Field(..., description="Current status of the task")
45 | task_type: Optional[TaskType] = Field(
46 | default=None,
47 | description="Type/category of the task"
48 | )
49 | result_type: Optional[ResultType] = Field(
50 | default=ResultType.JSON,
51 | description="Result type whether it's a filepath or JSON"
52 | )
53 | result: Optional[Any] = Field(
54 | default=None,
55 | description="JSON data representing the result of the task"
56 | )
57 | task_params: Optional[dict] = Field(
58 | default=None,
59 | description="Parameters of the task"
60 | )
61 | error: Optional[str] = Field(
62 | default=None,
63 | description="Error message, if any, associated with the task"
64 | )
65 | duration: Optional[float] = Field(
66 | default=None,
67 | description="Duration of the task execution"
68 | )
69 | progress: Optional[float] = Field(
70 | default=0.0,
71 | description="Progress of the task"
72 | )
73 |
74 |
75 | class Task(SQLModel, table=True):
76 | """
77 | Table to store tasks information.
78 |
79 | Attributes:
80 | - id: Unique identifier for each task (Primary Key).
81 | - uuid: Universally unique identifier for each task.
82 | - status: Current status of the task.
83 | - result: JSON data representing the result of the task.
84 | - result_type: Type of the data whether it is normal JSON data or filepath.
85 | - file_name: Name of the file associated with the task.
86 | - task_type: Type/category of the task.
87 | - duration: Duration of the task execution.
88 | - error: Error message, if any, associated with the task.
89 | - created_at: Date and time of creation.
90 | - updated_at: Date and time of last update.
91 | - progress: Progress of the task. If it is None, it means the progress tracking for the task is not started yet.
92 | """
93 |
94 | __tablename__ = "tasks"
95 |
96 | id: Optional[int] = Field(
97 | default=None,
98 | primary_key=True,
99 | description="Unique identifier for each task (Primary Key)"
100 | )
101 | uuid: str = Field(
102 | default_factory=lambda: str(uuid4()),
103 | description="Universally unique identifier for each task"
104 | )
105 | status: Optional[TaskStatus] = Field(
106 | default=None,
107 | sa_column=Field(sa_column=SQLAlchemyEnum(TaskStatus)),
108 | description="Current status of the task",
109 | )
110 | result: Optional[dict] = Field(
111 | default_factory=dict,
112 | sa_column=Column(JSON),
113 | description="JSON data representing the result of the task"
114 | )
115 | result_type: Optional[ResultType] = Field(
116 | default=ResultType.JSON,
117 | sa_column=Field(sa_column=SQLAlchemyEnum(ResultType)),
118 | description="Result type whether it's a filepath or JSON"
119 | )
120 | file_name: Optional[str] = Field(
121 | default=None,
122 | description="Name of the file associated with the task"
123 | )
124 | url: Optional[str] = Field(
125 | default=None,
126 | description="URL of the file associated with the task"
127 | )
128 | audio_duration: Optional[float] = Field(
129 | default=None,
130 | description="Duration of the audio in seconds"
131 | )
132 | language: Optional[str] = Field(
133 | default=None,
134 | description="Language of the file associated with the task"
135 | )
136 | task_type: Optional[TaskType] = Field(
137 | default=None,
138 | sa_column=Field(sa_column=SQLAlchemyEnum(TaskType)),
139 | description="Type/category of the task"
140 | )
141 | task_params: Optional[dict] = Field(
142 | default_factory=dict,
143 | sa_column=Column(JSON),
144 | description="Parameters of the task"
145 | )
146 | duration: Optional[float] = Field(
147 | default=None,
148 | description="Duration of the task execution"
149 | )
150 | error: Optional[str] = Field(
151 | default=None,
152 | description="Error message, if any, associated with the task"
153 | )
154 | created_at: datetime = Field(
155 | default_factory=datetime.utcnow,
156 | description="Date and time of creation"
157 | )
158 | updated_at: datetime = Field(
159 | default_factory=datetime.utcnow,
160 | sa_column_kwargs={"onupdate": datetime.utcnow},
161 | description="Date and time of last update"
162 | )
163 | progress: Optional[float] = Field(
164 | default=0.0,
165 | description="Progress of the task"
166 | )
167 |
168 | def to_response(self) -> "TaskStatusResponse":
169 | return TaskStatusResponse(
170 | identifier=self.uuid,
171 | status=self.status,
172 | task_type=self.task_type,
173 | result_type=self.result_type,
174 | result=self.result,
175 | task_params=self.task_params,
176 | error=self.error,
177 | duration=self.duration,
178 | progress=self.progress
179 | )
180 |
181 |
182 | class TasksResult(BaseModel):
183 | tasks: List[Task]
184 |
185 |
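186 | # Note: routers convert rows with `task.to_response()` rather than returning `Task`
187 | # directly, since `TaskStatusResponse` omits sensitive fields such as `file_name` and `url`.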
--------------------------------------------------------------------------------
/backend/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | services:
2 | app:
3 | build:
4 | dockerfile: backend/Dockerfile
5 | context: ..
6 | image: jhj0517/whisper-webui-backend:latest
7 |
8 | volumes:
9 | # You can mount the container's volume paths to directory paths on your local machine.
10 |       # Models will be stored in the `./models` directory on your machine.
11 | # Similarly, all output files will be stored in the `./outputs` directory.
12 | # The DB file is saved in /Whisper-WebUI/backend/records.db unless you edit it in /Whisper-WebUI/backend/configs/.env
13 | - ./models:/Whisper-WebUI/models
14 | - ./outputs:/Whisper-WebUI/outputs
15 | - ./backend:/Whisper-WebUI/backend
16 |
17 | ports:
18 | - "8000:8000"
19 |
20 | stdin_open: true
21 | tty: true
22 |
23 | entrypoint: ["uvicorn", "backend.main:app", "--host", "0.0.0.0", "--port", "8000"]
24 |
25 |     # If you're not using an Nvidia GPU, update the device to match yours.
26 | # See more info at : https://docs.docker.com/compose/compose-file/deploy/#driver
27 | deploy:
28 | resources:
29 | reservations:
30 | devices:
31 | - driver: nvidia
32 | count: all
33 | capabilities: [ gpu ]
--------------------------------------------------------------------------------
/backend/main.py:
--------------------------------------------------------------------------------
1 | from contextlib import asynccontextmanager
2 | from fastapi import (
3 | FastAPI,
4 | )
5 | from fastapi.responses import RedirectResponse
6 | from fastapi.middleware.cors import CORSMiddleware
7 | import os
8 | import time
9 | import threading
10 |
11 | from backend.db.db_instance import init_db
12 | from backend.routers.transcription.router import transcription_router, get_pipeline
13 | from backend.routers.vad.router import get_vad_model, vad_router
14 | from backend.routers.bgm_separation.router import get_bgm_separation_inferencer, bgm_separation_router
15 | from backend.routers.task.router import task_router
16 | from backend.common.config_loader import read_env, load_server_config
17 | from backend.common.cache_manager import cleanup_old_files
18 | from modules.utils.paths import SERVER_CONFIG_PATH, BACKEND_CACHE_DIR
19 |
20 |
21 | def clean_cache_thread(ttl: int, frequency: int) -> threading.Thread:
22 | def clean_cache(_ttl: int, _frequency: int):
23 | while True:
24 | cleanup_old_files(cache_dir=BACKEND_CACHE_DIR, ttl=_ttl)
25 | time.sleep(_frequency)
26 |
27 | return threading.Thread(
28 | target=clean_cache,
29 | args=(ttl, frequency),
30 | daemon=True
31 | )
32 |
33 |
34 | @asynccontextmanager
35 | async def lifespan(app: FastAPI):
36 | # Basic setup initialization
37 | server_config = load_server_config()
38 | read_env("DB_URL") # Place .env file into /configs/.env
39 | init_db()
40 |
41 | # Inferencer initialization
42 | transcription_pipeline = get_pipeline()
43 | vad_inferencer = get_vad_model()
44 | bgm_separation_inferencer = get_bgm_separation_inferencer()
45 |
46 | # Thread initialization
47 | cache_thread = clean_cache_thread(server_config["cache"]["ttl"], server_config["cache"]["frequency"])
48 | cache_thread.start()
49 |
50 | yield
51 |
52 | # Release VRAM when server shutdown
53 | transcription_pipeline = None
54 | vad_inferencer = None
55 | bgm_separation_inferencer = None
56 |
57 |
58 | app = FastAPI(
59 | title="Whisper-WebUI-Backend",
60 | description=f"""
61 | REST API for Whisper-WebUI. Swagger UI is available via /docs or root URL with redirection. Redoc is available via /redoc.
62 | """,
63 | version="0.0.1",
64 | lifespan=lifespan,
65 | openapi_tags=[
66 | {
67 | "name": "BGM Separation",
68 | "description": "Cached files for /bgm-separation are generated in the `backend/cache` directory,"
69 | " you can set TLL for these files in `backend/configs/config.yaml`."
70 | }
71 | ]
72 | )
73 | app.add_middleware(
74 | CORSMiddleware,
75 | allow_origins=["*"],
76 | allow_credentials=True,
77 | allow_methods=["GET", "POST", "PUT", "PATCH", "OPTIONS"], # Disable DELETE
78 | allow_headers=["*"],
79 | )
80 | app.include_router(transcription_router)
81 | app.include_router(vad_router)
82 | app.include_router(bgm_separation_router)
83 | app.include_router(task_router)
84 |
85 |
86 | @app.get("/", response_class=RedirectResponse, include_in_schema=False)
87 | async def index():
88 | """
89 | Redirect to the documentation. Defaults to Swagger UI.
90 |     You can also check /redoc for the Redoc-style documentation: https://github.com/Redocly/redoc
91 | """
92 | return "/docs"
93 |
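94 | # Run from the project root (see backend/README.md):
95 | #     uvicorn backend.main:app --host 0.0.0.0 --port 8000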
--------------------------------------------------------------------------------
/backend/nginx/logs/logs_are_generated_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/nginx/logs/logs_are_generated_here
--------------------------------------------------------------------------------
/backend/nginx/nginx.conf:
--------------------------------------------------------------------------------
1 | worker_processes 1;
2 |
3 | events {
4 | worker_connections 1024;
5 | }
6 |
7 | http {
8 | server {
9 | listen 80;
10 | client_max_body_size 4G;
11 |
12 | server_name your-own-domain-name.com;
13 |
14 | location / {
15 | proxy_pass http://127.0.0.1:8000;
16 | proxy_set_header Host $host;
17 | proxy_set_header X-Real-IP $remote_addr;
18 | proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
19 | proxy_set_header X-Forwarded-Proto $scheme;
20 | }
21 | }
22 | }
23 |
24 |
--------------------------------------------------------------------------------
/backend/nginx/temp/temps_are_generated_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/nginx/temp/temps_are_generated_here
--------------------------------------------------------------------------------
/backend/requirements-backend.txt:
--------------------------------------------------------------------------------
1 | # Whisper-WebUI dependencies
2 | -r ../requirements.txt
3 |
4 | # Backend dependencies
5 | python-dotenv
6 | uvicorn
7 | SQLAlchemy
8 | sqlmodel
9 | pydantic
10 |
11 | # Test dependencies
12 | # pytest
13 | # pytest-asyncio
--------------------------------------------------------------------------------
/backend/routers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/routers/__init__.py
--------------------------------------------------------------------------------
/backend/routers/bgm_separation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/routers/bgm_separation/__init__.py
--------------------------------------------------------------------------------
/backend/routers/bgm_separation/models.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field
2 |
3 |
4 | class BGMSeparationResult(BaseModel):
5 | instrumental_hash: str = Field(..., description="Instrumental file hash")
6 | vocal_hash: str = Field(..., description="Vocal file hash")
7 |
--------------------------------------------------------------------------------
/backend/routers/bgm_separation/router.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import numpy as np
3 | from fastapi import (
4 | File,
5 | UploadFile,
6 | )
7 | import gradio as gr
8 | from fastapi import APIRouter, BackgroundTasks, Depends, Response, status
9 | from fastapi.responses import FileResponse
10 | from typing import List, Dict, Tuple
11 | from datetime import datetime
12 | import os
13 |
14 | from modules.whisper.data_classes import *
15 | from modules.uvr.music_separator import MusicSeparator
16 | from modules.utils.paths import BACKEND_CACHE_DIR
17 | from backend.common.audio import read_audio
18 | from backend.common.models import QueueResponse
19 | from backend.common.config_loader import load_server_config
20 | from backend.common.compresser import get_file_hash, find_file_by_hash
21 | from backend.db.task.models import TaskStatus, TaskType, ResultType
22 | from backend.db.task.dao import add_task_to_db, update_task_status_in_db
23 | from .models import BGMSeparationResult
24 |
25 |
26 | bgm_separation_router = APIRouter(prefix="/bgm-separation", tags=["BGM Separation"])
27 |
28 |
29 | @functools.lru_cache
30 | def get_bgm_separation_inferencer() -> 'MusicSeparator':
31 | config = load_server_config()["bgm_separation"]
32 | inferencer = MusicSeparator(
33 | output_dir=os.path.join(BACKEND_CACHE_DIR, "UVR")
34 | )
35 | inferencer.update_model(
36 | model_name=config["model_size"],
37 | device=config["device"]
38 | )
39 | return inferencer
40 |
41 |
42 | def run_bgm_separation(
43 | audio: np.ndarray,
44 | params: BGMSeparationParams,
45 | identifier: str,
46 | ) -> Tuple[np.ndarray, np.ndarray]:
47 | update_task_status_in_db(
48 | identifier=identifier,
49 | update_data={
50 | "uuid": identifier,
51 | "status": TaskStatus.IN_PROGRESS,
52 | "updated_at": datetime.utcnow()
53 | }
54 | )
55 |
56 | start_time = datetime.utcnow()
57 | instrumental, vocal, filepaths = get_bgm_separation_inferencer().separate(
58 | audio=audio,
59 | model_name=params.uvr_model_size,
60 | device=params.uvr_device,
61 | segment_size=params.segment_size,
62 | save_file=True,
63 | progress=gr.Progress()
64 | )
65 | instrumental_path, vocal_path = filepaths
66 | elapsed_time = (datetime.utcnow() - start_time).total_seconds()
67 |
68 | update_task_status_in_db(
69 | identifier=identifier,
70 | update_data={
71 | "uuid": identifier,
72 | "status": TaskStatus.COMPLETED,
73 | "result": BGMSeparationResult(
74 | instrumental_hash=get_file_hash(instrumental_path),
75 | vocal_hash=get_file_hash(vocal_path)
76 | ).model_dump(),
77 | "result_type": ResultType.FILEPATH,
78 | "updated_at": datetime.utcnow(),
79 | "duration": elapsed_time
80 | }
81 | )
82 | return instrumental, vocal
83 |
84 |
85 | @bgm_separation_router.post(
86 | "/",
87 | response_model=QueueResponse,
88 | status_code=status.HTTP_201_CREATED,
89 |     summary="Separate Background Music and Vocals",
90 |     description="Separate background music and vocals from an uploaded audio or video file.",
91 | )
92 | async def bgm_separation(
93 | background_tasks: BackgroundTasks,
94 | file: UploadFile = File(..., description="Audio or video file to separate background music."),
95 | params: BGMSeparationParams = Depends()
96 | ) -> QueueResponse:
97 | if not isinstance(file, np.ndarray):
98 | audio, info = await read_audio(file=file)
99 | else:
100 | audio, info = file, None
101 |
102 | identifier = add_task_to_db(
103 | status=TaskStatus.QUEUED,
104 | file_name=file.filename,
105 | audio_duration=info.duration if info else None,
106 | task_type=TaskType.BGM_SEPARATION,
107 | task_params=params.model_dump(),
108 | )
109 |
110 | background_tasks.add_task(
111 | run_bgm_separation,
112 | audio=audio,
113 | params=params,
114 | identifier=identifier
115 | )
116 |
117 |     return QueueResponse(identifier=identifier, status=TaskStatus.QUEUED, message="BGM separation task has been queued")
118 |
119 |
120 |
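121 | # A hedged usage sketch (assumes a server at http://localhost:8000; polling and
122 | # the zipped download go through the /task router):
123 | #
124 | #   import requests
125 | #   with open("song.mp3", "rb") as f:
126 | #       res = requests.post("http://localhost:8000/bgm-separation", files={"file": f})
127 | #   identifier = res.json()["identifier"]
128 | #   # once the task completes, fetch the instrumental/vocal stems as a ZIP:
129 | #   zip_bytes = requests.get(f"http://localhost:8000/task/file/{identifier}").content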
--------------------------------------------------------------------------------
/backend/routers/task/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/routers/task/__init__.py
--------------------------------------------------------------------------------
/backend/routers/task/router.py:
--------------------------------------------------------------------------------
1 | from fastapi import APIRouter, Depends, HTTPException, status
2 | from fastapi.responses import FileResponse
3 | from sqlalchemy.orm import Session
4 | import os
5 |
6 | from backend.db.db_instance import get_db_session
7 | from backend.db.task.dao import (
8 | get_task_status_from_db,
9 | get_all_tasks_status_from_db,
10 | delete_task_from_db,
11 | )
12 | from backend.db.task.models import (
13 | TasksResult,
14 | Task,
15 | TaskStatusResponse,
16 | TaskType
17 | )
18 | from backend.common.models import (
19 | Response,
20 | )
21 | from backend.common.compresser import compress_files, find_file_by_hash
22 | from modules.utils.paths import BACKEND_CACHE_DIR
23 |
24 | task_router = APIRouter(prefix="/task", tags=["Tasks"])
25 |
26 |
27 | @task_router.get(
28 | "/{identifier}",
29 | response_model=TaskStatusResponse,
30 | status_code=status.HTTP_200_OK,
31 | summary="Retrieve Task by Identifier",
32 | description="Retrieve the specific task by its identifier.",
33 | )
34 | async def get_task(
35 | identifier: str,
36 | session: Session = Depends(get_db_session),
37 | ) -> TaskStatusResponse:
38 | """
39 | Retrieve the specific task by its identifier.
40 | """
41 | task = get_task_status_from_db(identifier=identifier, session=session)
42 |
43 | if task is not None:
44 | return task.to_response()
45 | else:
46 | raise HTTPException(status_code=404, detail="Identifier not found")
47 |
48 |
49 | @task_router.get(
50 | "/file/{identifier}",
51 | status_code=status.HTTP_200_OK,
52 | summary="Retrieve FileResponse Task by Identifier",
53 |     description="Retrieve the file response of a task by its identifier. You can use this endpoint if you need to"
54 |                 " download the file as a response.",
55 | )
56 | async def get_file_task(
57 | identifier: str,
58 | session: Session = Depends(get_db_session),
59 | ) -> FileResponse:
60 | """
61 | Retrieve the downloadable file response of a specific task by its identifier.
62 |     The result files are compressed into a ZIP archive.
63 | """
64 | task = get_task_status_from_db(identifier=identifier, session=session)
65 |
66 | if task is not None:
67 | if task.task_type == TaskType.BGM_SEPARATION:
68 | output_zip_path = os.path.join(BACKEND_CACHE_DIR, f"{identifier}_bgm_separation.zip")
69 | instrumental_path = find_file_by_hash(
70 | os.path.join(BACKEND_CACHE_DIR, "UVR", "instrumental"),
71 | task.result["instrumental_hash"]
72 | )
73 | vocal_path = find_file_by_hash(
74 | os.path.join(BACKEND_CACHE_DIR, "UVR", "vocals"),
75 | task.result["vocal_hash"]
76 | )
77 |
78 | output_zip_path = compress_files(
79 | [instrumental_path, vocal_path],
80 | output_zip_path
81 | )
82 | return FileResponse(
83 | path=output_zip_path,
84 | status_code=200,
85 |                 filename=os.path.basename(output_zip_path),
86 | media_type="application/zip"
87 | )
88 | else:
89 |             raise HTTPException(status_code=404, detail=f"File download is only supported for BGM separation."
90 |                                                         f" The given task type is {task.task_type}")
91 | else:
92 | raise HTTPException(status_code=404, detail="Identifier not found")
93 |
94 |
95 | # Delete method, commented by default because this endpoint is likely to require special permissions
96 | # @task_router.delete(
97 | # "/{identifier}",
98 | # response_model=Response,
99 | # status_code=status.HTTP_200_OK,
100 | # summary="Delete Task by Identifier",
101 | # description="Delete a task from the system using its identifier.",
102 | # )
103 | async def delete_task(
104 | identifier: str,
105 | session: Session = Depends(get_db_session),
106 | ) -> Response:
107 | """
108 | Delete a task by its identifier.
109 | """
110 | if delete_task_from_db(identifier, session):
111 | return Response(identifier=identifier, message="Task deleted")
112 | else:
113 | raise HTTPException(status_code=404, detail="Task not found")
114 |
115 |
116 | # Get All method, commented by default because this endpoint is likely to require special permissions
117 | # @task_router.get(
118 | # "/all",
119 | # response_model=TasksResult,
120 | # status_code=status.HTTP_200_OK,
121 | # summary="Retrieve All Task Statuses",
122 | # description="Retrieve the statuses of all tasks available in the system.",
123 | # )
124 | async def get_all_tasks_status(
125 | session: Session = Depends(get_db_session),
126 | ) -> TasksResult:
127 | """
128 | Retrieve all tasks.
129 | """
130 | return get_all_tasks_status_from_db(session=session)
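131 |
132 |
133 | # A hedged usage sketch of the two active endpoints (the task JSON is known to
134 | # carry "status" and "result"; the server address and `identifier` are assumptions):
135 | #
136 | #   import requests
137 | #   task = requests.get(f"http://localhost:8000/task/{identifier}").json()
138 | #   # for BGM separation tasks, the files come back as a ZIP:
139 | #   zip_bytes = requests.get(f"http://localhost:8000/task/file/{identifier}").content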
--------------------------------------------------------------------------------
/backend/routers/transcription/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/routers/transcription/__init__.py
--------------------------------------------------------------------------------
/backend/routers/transcription/router.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import uuid
3 | import numpy as np
4 | from fastapi import (
5 | File,
6 | UploadFile,
7 | )
8 | import gradio as gr
9 | from fastapi import APIRouter, BackgroundTasks, Depends, Response, status
10 | from typing import List, Dict
11 | from sqlalchemy.orm import Session
12 | from datetime import datetime
13 | from modules.whisper.data_classes import *
14 | from modules.utils.paths import BACKEND_CACHE_DIR
15 | from modules.whisper.faster_whisper_inference import FasterWhisperInference
16 | from backend.common.audio import read_audio
17 | from backend.common.models import QueueResponse
18 | from backend.common.config_loader import load_server_config
19 | from backend.db.task.dao import (
20 | add_task_to_db,
21 | get_db_session,
22 | update_task_status_in_db
23 | )
24 | from backend.db.task.models import TaskStatus, TaskType
25 |
26 | transcription_router = APIRouter(prefix="/transcription", tags=["Transcription"])
27 |
28 |
29 | def create_progress_callback(identifier: str):
30 | def progress_callback(progress_value: float):
31 | update_task_status_in_db(
32 | identifier=identifier,
33 | update_data={
34 | "uuid": identifier,
35 | "status": TaskStatus.IN_PROGRESS,
36 | "progress": round(progress_value, 2),
37 | "updated_at": datetime.utcnow()
38 | },
39 | )
40 | return progress_callback
41 |
42 |
43 | @functools.lru_cache
44 | def get_pipeline() -> 'FasterWhisperInference':
45 | config = load_server_config()["whisper"]
46 | inferencer = FasterWhisperInference(
47 | output_dir=BACKEND_CACHE_DIR
48 | )
49 | inferencer.update_model(
50 | model_size=config["model_size"],
51 | compute_type=config["compute_type"]
52 | )
53 | return inferencer
54 |
55 |
56 | def run_transcription(
57 | audio: np.ndarray,
58 | params: TranscriptionPipelineParams,
59 | identifier: str,
60 | ) -> List[Segment]:
61 | update_task_status_in_db(
62 | identifier=identifier,
63 | update_data={
64 | "uuid": identifier,
65 | "status": TaskStatus.IN_PROGRESS,
66 | "updated_at": datetime.utcnow()
67 | },
68 | )
69 |
70 | progress_callback = create_progress_callback(identifier)
71 | segments, elapsed_time = get_pipeline().run(
72 | audio,
73 | gr.Progress(),
74 | "SRT",
75 | False,
76 | progress_callback,
77 | *params.to_list()
78 | )
79 | segments = [seg.model_dump() for seg in segments]
80 |
81 | update_task_status_in_db(
82 | identifier=identifier,
83 | update_data={
84 | "uuid": identifier,
85 | "status": TaskStatus.COMPLETED,
86 | "result": segments,
87 | "updated_at": datetime.utcnow(),
88 | "duration": elapsed_time,
89 | "progress": 1.0,
90 | },
91 | )
92 | return segments
93 |
94 |
95 | @transcription_router.post(
96 | "/",
97 | response_model=QueueResponse,
98 | status_code=status.HTTP_201_CREATED,
99 | summary="Transcribe Audio",
100 | description="Process the provided audio or video file to generate a transcription.",
101 | )
102 | async def transcription(
103 | background_tasks: BackgroundTasks,
104 | file: UploadFile = File(..., description="Audio or video file to transcribe."),
105 | whisper_params: WhisperParams = Depends(),
106 | vad_params: VadParams = Depends(),
107 | bgm_separation_params: BGMSeparationParams = Depends(),
108 | diarization_params: DiarizationParams = Depends(),
109 | ) -> QueueResponse:
110 | if not isinstance(file, np.ndarray):
111 | audio, info = await read_audio(file=file)
112 | else:
113 | audio, info = file, None
114 |
115 | params = TranscriptionPipelineParams(
116 | whisper=whisper_params,
117 | vad=vad_params,
118 | bgm_separation=bgm_separation_params,
119 | diarization=diarization_params
120 | )
121 |
122 | identifier = add_task_to_db(
123 | status=TaskStatus.QUEUED,
124 | file_name=file.filename,
125 | audio_duration=info.duration if info else None,
126 | language=params.whisper.lang,
127 | task_type=TaskType.TRANSCRIPTION,
128 | task_params=params.to_dict(),
129 | )
130 |
131 | background_tasks.add_task(
132 | run_transcription,
133 | audio=audio,
134 | params=params,
135 | identifier=identifier,
136 | )
137 |
138 |     return QueueResponse(identifier=identifier, status=TaskStatus.QUEUED, message="Transcription task has been queued")
139 |
140 |
141 |
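142 | # A hedged usage sketch (query parameters map to the Depends() models above,
143 | # e.g. WhisperParams.model_size; the server address is an assumption):
144 | #
145 | #   import requests
146 | #   with open("jfk.wav", "rb") as f:
147 | #       res = requests.post(
148 | #           "http://localhost:8000/transcription",
149 | #           files={"file": f},
150 | #           params={"model_size": "tiny"},
151 | #       )
152 | #   identifier = res.json()["identifier"]  # then poll /task/{identifier}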
--------------------------------------------------------------------------------
/backend/routers/vad/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/routers/vad/__init__.py
--------------------------------------------------------------------------------
/backend/routers/vad/router.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import numpy as np
3 | from faster_whisper.vad import VadOptions
4 | from fastapi import (
5 | File,
6 | UploadFile,
7 | )
8 | from fastapi import APIRouter, BackgroundTasks, Depends, Response, status
9 | from typing import List, Dict
10 | from datetime import datetime
11 |
12 | from modules.vad.silero_vad import SileroVAD
13 | from modules.whisper.data_classes import VadParams
14 | from backend.common.audio import read_audio
15 | from backend.common.models import QueueResponse
16 | from backend.db.task.dao import add_task_to_db, update_task_status_in_db
17 | from backend.db.task.models import TaskStatus, TaskType
18 |
19 | vad_router = APIRouter(prefix="/vad", tags=["Voice Activity Detection"])
20 |
21 |
22 | @functools.lru_cache
23 | def get_vad_model() -> SileroVAD:
24 | inferencer = SileroVAD()
25 | inferencer.update_model()
26 | return inferencer
27 |
28 |
29 | def run_vad(
30 | audio: np.ndarray,
31 | params: VadOptions,
32 | identifier: str,
33 | ) -> List[Dict]:
34 | update_task_status_in_db(
35 | identifier=identifier,
36 | update_data={
37 | "uuid": identifier,
38 | "status": TaskStatus.IN_PROGRESS,
39 | "updated_at": datetime.utcnow()
40 | }
41 | )
42 |
43 | start_time = datetime.utcnow()
44 | audio, speech_chunks = get_vad_model().run(
45 | audio=audio,
46 | vad_parameters=params
47 | )
48 | elapsed_time = (datetime.utcnow() - start_time).total_seconds()
49 |
50 | update_task_status_in_db(
51 | identifier=identifier,
52 | update_data={
53 | "uuid": identifier,
54 | "status": TaskStatus.COMPLETED,
55 | "updated_at": datetime.utcnow(),
56 | "result": speech_chunks,
57 | "duration": elapsed_time
58 | }
59 | )
60 |
61 | return speech_chunks
62 |
63 |
64 | @vad_router.post(
65 | "/",
66 | response_model=QueueResponse,
67 | status_code=status.HTTP_201_CREATED,
68 | summary="Voice Activity Detection",
69 | description="Detect voice parts in the provided audio or video file to generate a timeline of speech segments.",
70 | )
71 | async def vad(
72 | background_tasks: BackgroundTasks,
73 | file: UploadFile = File(..., description="Audio or video file to detect voices."),
74 | params: VadParams = Depends()
75 | ) -> QueueResponse:
76 | if not isinstance(file, np.ndarray):
77 | audio, info = await read_audio(file=file)
78 | else:
79 | audio, info = file, None
80 |
81 | vad_options = VadOptions(
82 | threshold=params.threshold,
83 | min_speech_duration_ms=params.min_speech_duration_ms,
84 | max_speech_duration_s=params.max_speech_duration_s,
85 | min_silence_duration_ms=params.min_silence_duration_ms,
86 | speech_pad_ms=params.speech_pad_ms
87 | )
88 |
89 | identifier = add_task_to_db(
90 | status=TaskStatus.QUEUED,
91 | file_name=file.filename,
92 | audio_duration=info.duration if info else None,
93 | task_type=TaskType.VAD,
94 | task_params=params.model_dump(),
95 | )
96 |
97 | background_tasks.add_task(run_vad, audio=audio, params=vad_options, identifier=identifier)
98 |
99 |     return QueueResponse(identifier=identifier, status=TaskStatus.QUEUED, message="VAD task has been queued")
100 |
101 |
102 |
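103 | # A hedged usage sketch (parameters map to VadParams; the completed task's
104 | # result is a list of {"start": ..., "end": ...} speech chunks):
105 | #
106 | #   import requests
107 | #   with open("speech.wav", "rb") as f:
108 | #       res = requests.post("http://localhost:8000/vad",
109 | #                           files={"file": f}, params={"threshold": 0.5})
110 | #   identifier = res.json()["identifier"]  # then poll /task/{identifier}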
--------------------------------------------------------------------------------
/backend/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/backend/tests/__init__.py
--------------------------------------------------------------------------------
/backend/tests/test_backend_bgm_separation.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from fastapi import UploadFile
3 | from io import BytesIO
4 | import os
5 | import torch
6 |
7 | from backend.db.task.models import TaskStatus
8 | from backend.tests.test_task_status import wait_for_task_completion, fetch_file_response
9 | from backend.tests.test_backend_config import (
10 | get_client, setup_test_file, get_upload_file_instance, calculate_wer,
11 | TEST_BGM_SEPARATION_PARAMS, TEST_ANSWER, TEST_BGM_SEPARATION_OUTPUT_PATH
12 | )
13 |
14 |
15 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip the test because CUDA is not available")
16 | @pytest.mark.parametrize(
17 | "bgm_separation_params",
18 | [
19 | TEST_BGM_SEPARATION_PARAMS
20 | ]
21 | )
22 | def test_bgm_separation_endpoint(
23 | get_upload_file_instance,
24 | bgm_separation_params: dict
25 | ):
26 | client = get_client()
27 | file_content = BytesIO(get_upload_file_instance.file.read())
28 | get_upload_file_instance.file.seek(0)
29 |
30 | response = client.post(
31 | "/bgm-separation",
32 | files={"file": (get_upload_file_instance.filename, file_content, "audio/mpeg")},
33 | params=bgm_separation_params
34 | )
35 |
36 | assert response.status_code == 201
37 | assert response.json()["status"] == TaskStatus.QUEUED
38 | task_identifier = response.json()["identifier"]
39 | assert isinstance(task_identifier, str) and task_identifier
40 |
41 | completed_task = wait_for_task_completion(
42 | identifier=task_identifier
43 | )
44 |
45 | assert completed_task is not None, f"Task with identifier {task_identifier} did not complete within the " \
46 | f"expected time."
47 |
48 | result = completed_task.json()["result"]
49 | assert "instrumental_hash" in result and result["instrumental_hash"]
50 | assert "vocal_hash" in result and result["vocal_hash"]
51 |
52 | file_response = fetch_file_response(task_identifier)
53 | assert file_response.status_code == 200, f"Fetching File Response has failed. Response is: {file_response}"
54 |
55 | with open(TEST_BGM_SEPARATION_OUTPUT_PATH, "wb") as file:
56 | file.write(file_response.content)
57 |
58 | assert os.path.exists(TEST_BGM_SEPARATION_OUTPUT_PATH)
59 |
60 |
--------------------------------------------------------------------------------
/backend/tests/test_backend_config.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from fastapi import FastAPI, UploadFile
3 | from fastapi.testclient import TestClient
4 | from starlette.datastructures import UploadFile as StarletteUploadFile
5 | from io import BytesIO
6 | import os
7 | import requests
8 | import pytest
9 | import yaml
10 | import jiwer
11 |
12 | from backend.main import app
13 | from modules.whisper.data_classes import *
14 | from modules.utils.paths import *
15 | from modules.utils.files_manager import load_yaml, save_yaml
16 |
17 | TEST_PIPELINE_PARAMS = {**WhisperParams(model_size="tiny", compute_type="float32").model_dump(exclude_none=True),
18 | **VadParams().model_dump(exclude_none=True),
19 | **BGMSeparationParams().model_dump(exclude_none=True),
20 | **DiarizationParams().model_dump(exclude_none=True)}
21 | TEST_VAD_PARAMS = VadParams().model_dump()
22 | TEST_BGM_SEPARATION_PARAMS = BGMSeparationParams().model_dump()
23 | TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
24 | TEST_FILE_PATH = os.path.join(WEBUI_DIR, "backend", "tests", "jfk.wav")
25 | TEST_BGM_SEPARATION_OUTPUT_PATH = os.path.join(WEBUI_DIR, "backend", "tests", "separated_audio.zip")
26 | TEST_ANSWER = "And so my fellow Americans ask not what your country can do for you ask what you can do for your country"
27 | TEST_WHISPER_MODEL = "tiny"
28 | TEST_COMPUTE_TYPE = "float32"
29 |
30 |
31 | @pytest.fixture(autouse=True)
32 | def setup_test_file():
33 | @functools.lru_cache
34 | def download_file(url=TEST_FILE_DOWNLOAD_URL, file_path=TEST_FILE_PATH):
35 | if os.path.exists(file_path):
36 | return
37 |
38 | if not os.path.exists(os.path.dirname(file_path)):
39 | os.makedirs(os.path.dirname(file_path))
40 |
41 | response = requests.get(url)
42 |
43 | with open(file_path, "wb") as file:
44 | file.write(response.content)
45 |
46 | print(f"File downloaded to: {file_path}")
47 |
48 | download_file(TEST_FILE_DOWNLOAD_URL, TEST_FILE_PATH)
49 |
50 |
51 | @pytest.fixture
52 | @functools.lru_cache
53 | def get_upload_file_instance(filepath: str = TEST_FILE_PATH) -> UploadFile:
54 | with open(filepath, "rb") as f:
55 | file_contents = BytesIO(f.read())
56 | filename = os.path.basename(filepath)
57 | upload_file = StarletteUploadFile(file=file_contents, filename=filename)
58 | return upload_file
59 |
60 |
61 | @functools.lru_cache
62 | def get_client(app: FastAPI = app):
63 | return TestClient(app)
64 |
65 |
66 | def calculate_wer(answer, prediction):
67 | return jiwer.wer(answer, prediction)
68 |
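69 |
70 | # Quick sanity check for the helper above: jiwer counts one substitution out of
71 | # two reference words here, so the expected WER is 0.5.
72 | #
73 | #   assert calculate_wer("hello world", "hello there") == 0.5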
--------------------------------------------------------------------------------
/backend/tests/test_backend_transcription.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from fastapi import UploadFile
3 | from io import BytesIO
4 |
5 | from backend.db.task.models import TaskStatus
6 | from backend.tests.test_task_status import wait_for_task_completion
7 | from backend.tests.test_backend_config import (
8 | get_client, setup_test_file, get_upload_file_instance, calculate_wer,
9 | TEST_PIPELINE_PARAMS, TEST_ANSWER
10 | )
11 |
12 |
13 | @pytest.mark.parametrize(
14 | "pipeline_params",
15 | [
16 | TEST_PIPELINE_PARAMS
17 | ]
18 | )
19 | def test_transcription_endpoint(
20 | get_upload_file_instance,
21 | pipeline_params: dict
22 | ):
23 | client = get_client()
24 | file_content = BytesIO(get_upload_file_instance.file.read())
25 | get_upload_file_instance.file.seek(0)
26 |
27 | response = client.post(
28 | "/transcription",
29 | files={"file": (get_upload_file_instance.filename, file_content, "audio/mpeg")},
30 | params=pipeline_params
31 | )
32 |
33 | assert response.status_code == 201
34 | assert response.json()["status"] == TaskStatus.QUEUED
35 | task_identifier = response.json()["identifier"]
36 | assert isinstance(task_identifier, str) and task_identifier
37 |
38 | completed_task = wait_for_task_completion(
39 | identifier=task_identifier
40 | )
41 |
42 | assert completed_task is not None, f"Task with identifier {task_identifier} did not complete within the " \
43 | f"expected time."
44 |
45 | result = completed_task.json()["result"]
46 | assert result, "Transcription text is empty"
47 |
48 | wer = calculate_wer(TEST_ANSWER, result[0]["text"].strip().replace(",", "").replace(".", ""))
49 | assert wer < 0.1, f"WER is too high, it's {wer}"
50 |
51 |
--------------------------------------------------------------------------------
/backend/tests/test_backend_vad.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from fastapi import UploadFile
3 | from io import BytesIO
4 |
5 | from backend.db.task.models import TaskStatus
6 | from backend.tests.test_task_status import wait_for_task_completion
7 | from backend.tests.test_backend_config import (
8 | get_client, setup_test_file, get_upload_file_instance, calculate_wer,
9 | TEST_VAD_PARAMS, TEST_ANSWER
10 | )
11 |
12 |
13 | @pytest.mark.parametrize(
14 | "vad_params",
15 | [
16 | TEST_VAD_PARAMS
17 | ]
18 | )
19 | def test_vad_endpoint(
20 | get_upload_file_instance,
21 | vad_params: dict
22 | ):
23 | client = get_client()
24 | file_content = BytesIO(get_upload_file_instance.file.read())
25 | get_upload_file_instance.file.seek(0)
26 |
27 | response = client.post(
28 | "/vad",
29 | files={"file": (get_upload_file_instance.filename, file_content, "audio/mpeg")},
30 | params=vad_params
31 | )
32 |
33 | assert response.status_code == 201
34 | assert response.json()["status"] == TaskStatus.QUEUED
35 | task_identifier = response.json()["identifier"]
36 | assert isinstance(task_identifier, str) and task_identifier
37 |
38 | completed_task = wait_for_task_completion(
39 | identifier=task_identifier
40 | )
41 |
42 | assert completed_task is not None, f"Task with identifier {task_identifier} did not complete within the " \
43 | f"expected time."
44 |
45 | result = completed_task.json()["result"]
46 | assert result and "start" in result[0] and "end" in result[0]
47 |
48 |
--------------------------------------------------------------------------------
/backend/tests/test_task_status.py:
--------------------------------------------------------------------------------
1 | import time
2 | import pytest
3 | from typing import Optional, Union
4 | import httpx
5 |
6 | from backend.db.task.models import TaskStatus, Task
7 | from backend.tests.test_backend_config import get_client
8 |
9 |
10 | def fetch_task(identifier: str):
11 | """Get task status"""
12 | client = get_client()
13 | response = client.get(
14 | f"/task/{identifier}"
15 | )
16 | if response.status_code == 200:
17 | return response
18 | return None
19 |
20 |
21 | def fetch_file_response(identifier: str):
22 |     """Get the file response of a task"""
23 | client = get_client()
24 | response = client.get(
25 | f"/task/file/{identifier}"
26 | )
27 | if response.status_code == 200:
28 | return response
29 | return None
30 |
31 |
32 | def wait_for_task_completion(identifier: str,
33 | max_attempts: int = 20,
34 |                              frequency: int = 3) -> Optional[httpx.Response]:
35 | """
36 | Polls the task status until it is completed, failed, or the maximum attempts are reached.
37 |
38 | Args:
39 | identifier (str): The unique identifier of the task to monitor.
40 |         max_attempts (int): The maximum number of polling attempts.
41 |         frequency (int): The time (in seconds) to wait between polling attempts.
42 |
43 |     Returns:
44 |         Optional[httpx.Response]: The response for the completed task, or None if it did not complete within the allowed attempts.
45 | """
46 | attempts = 0
47 |     while attempts < max_attempts:
48 |         task = fetch_task(identifier)
49 |         status = task.json()["status"] if task is not None else None
50 |         if status == TaskStatus.COMPLETED:
51 |             return task
52 |         if status == TaskStatus.FAILED:
53 |             raise Exception("Task polling failed")
54 |         time.sleep(frequency)
55 |         attempts += 1
56 |     return None
57 |
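58 | # Example of how the endpoint tests drive this helper (mirrors
59 | # test_backend_transcription.py):
60 | #
61 | #   completed = wait_for_task_completion(identifier=task_identifier)
62 | #   assert completed is not None
63 | #   result = completed.json()["result"]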
--------------------------------------------------------------------------------
/configs/default_parameters.yaml:
--------------------------------------------------------------------------------
1 | whisper:
2 | model_size: "large-v2"
3 | file_format: "SRT"
4 | lang: "Automatic Detection"
5 | is_translate: false
6 | beam_size: 5
7 | log_prob_threshold: -1
8 | no_speech_threshold: 0.6
9 | best_of: 5
10 | patience: 1
11 | condition_on_previous_text: true
12 | prompt_reset_on_temperature: 0.5
13 | initial_prompt: null
14 | temperature: 0
15 | compression_ratio_threshold: 2.4
16 | chunk_length: 30
17 | batch_size: 24
18 | length_penalty: 1
19 | repetition_penalty: 1
20 | no_repeat_ngram_size: 0
21 | prefix: null
22 | suppress_blank: true
23 | suppress_tokens: "[-1]"
24 | max_initial_timestamp: 1
25 | word_timestamps: false
26 | prepend_punctuations: "\"'“¿([{-"
27 | append_punctuations: "\"'.。,,!!??::”)]}、"
28 | max_new_tokens: null
29 | hallucination_silence_threshold: null
30 | hotwords: null
31 | language_detection_threshold: 0.5
32 | language_detection_segments: 1
33 | add_timestamp: false
34 | enable_offload: true
35 |
36 | vad:
37 | vad_filter: false
38 | threshold: 0.5
39 | min_speech_duration_ms: 250
40 | max_speech_duration_s: 9999
41 | min_silence_duration_ms: 1000
42 | speech_pad_ms: 2000
43 |
44 | diarization:
45 | is_diarize: false
46 | hf_token: ""
47 | enable_offload: true
48 |
49 | bgm_separation:
50 | is_separate_bgm: false
51 | uvr_model_size: "UVR-MDX-NET-Inst_HQ_4"
52 | segment_size: 256
53 | save_file: false
54 | enable_offload: true
55 |
56 | translation:
57 | deepl:
58 | api_key: ""
59 | is_pro: false
60 | source_lang: "Automatic Detection"
61 | target_lang: "English"
62 | nllb:
63 | model_size: "facebook/nllb-200-1.3B"
64 | source_lang: null
65 | target_lang: null
66 | max_length: 200
67 | add_timestamp: false
68 |
--------------------------------------------------------------------------------
/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | services:
2 | whisper-webui:
3 | container_name: whisper-webui
4 | build: .
5 | image: jhj0517/whisper-webui:latest
6 |
7 | volumes:
8 | # You can mount the container's volume paths to directory paths on your local machine.
9 |       # Models will be stored in the `./models` directory on your machine.
10 | # Similarly, all output files will be stored in the `./outputs` directory.
11 | - ./models:/Whisper-WebUI/models
12 | - ./outputs:/Whisper-WebUI/outputs
13 | - ./configs:/Whisper-WebUI/configs
14 |
15 | ports:
16 | - "7860:7860"
17 |
18 | stdin_open: true
19 | tty: true
20 |
21 |     entrypoint: ["python", "app.py", "--server_port", "7860", "--server_name", "0.0.0.0"]
22 |
23 |     # If you're not using an NVIDIA GPU, update the device to match yours.
24 |     # See more info at: https://docs.docker.com/compose/compose-file/deploy/#driver
25 |     # You can remove the entire `deploy` section if you are using CPU.
26 | deploy:
27 | resources:
28 | reservations:
29 | devices:
30 | - driver: nvidia
31 | count: all
32 | capabilities: [ gpu ]
33 |
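34 | # To build and start the container (drop the `deploy` section on CPU-only hosts):
35 | #   docker compose up --build
36 | # The WebUI is then reachable at http://localhost:7860.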
--------------------------------------------------------------------------------
/models/Diarization/diarization_models_will_be_saved_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/models/Diarization/diarization_models_will_be_saved_here
--------------------------------------------------------------------------------
/models/NLLB/nllb_models_will_be_saved_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/models/NLLB/nllb_models_will_be_saved_here
--------------------------------------------------------------------------------
/models/UVR/uvr_models_will_be_saved_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/models/UVR/uvr_models_will_be_saved_here
--------------------------------------------------------------------------------
/models/Whisper/faster-whisper/faster_whisper_models_will_be_saved_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/models/Whisper/faster-whisper/faster_whisper_models_will_be_saved_here
--------------------------------------------------------------------------------
/models/Whisper/insanely-fast-whisper/insanely_fast_whisper_models_will_be_saved_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/models/Whisper/insanely-fast-whisper/insanely_fast_whisper_models_will_be_saved_here
--------------------------------------------------------------------------------
/models/Whisper/whisper_models_will_be_saved_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/models/Whisper/whisper_models_will_be_saved_here
--------------------------------------------------------------------------------
/models/models_will_be_saved_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/models/models_will_be_saved_here
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/modules/__init__.py
--------------------------------------------------------------------------------
/modules/diarize/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/modules/diarize/__init__.py
--------------------------------------------------------------------------------
/modules/diarize/audio_loader.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/m-bain/whisperX/blob/main/whisperx/audio.py
2 |
3 | import os
4 | import subprocess
5 | from functools import lru_cache
6 | from typing import Optional, Union
7 | from scipy.io.wavfile import write
8 | import tempfile
9 | import numpy as np
10 | import torch
11 | import torch.nn.functional as F
12 |
13 | def exact_div(x, y):
14 | assert x % y == 0
15 | return x // y
16 |
17 | # hard-coded audio hyperparameters
18 | SAMPLE_RATE = 16000
19 | N_FFT = 400
20 | HOP_LENGTH = 160
21 | CHUNK_LENGTH = 30
22 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
23 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
24 |
25 | N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions have stride 2
26 | FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
27 | TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
28 |
29 |
30 | def load_audio(file: Union[str, np.ndarray], sr: int = SAMPLE_RATE) -> np.ndarray:
31 | """
32 | Open an audio file or process a numpy array containing audio data as mono waveform, resampling as necessary.
33 |
34 | Parameters
35 | ----------
36 | file: Union[str, np.ndarray]
37 | The audio file to open or a numpy array containing the audio data.
38 |
39 | sr: int
40 | The sample rate to resample the audio if necessary.
41 |
42 | Returns
43 | -------
44 | A NumPy array containing the audio waveform, in float32 dtype.
45 | """
46 | if isinstance(file, np.ndarray):
47 | if file.dtype != np.float32:
48 | file = file.astype(np.float32)
49 | if file.ndim > 1:
50 | file = np.mean(file, axis=1)
51 |
52 | temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
53 | write(temp_file.name, SAMPLE_RATE, (file * 32768).astype(np.int16))
54 | temp_file_path = temp_file.name
55 | temp_file.close()
56 | else:
57 | temp_file_path = file
58 |
59 | try:
60 | cmd = [
61 | "ffmpeg",
62 | "-nostdin",
63 | "-threads",
64 | "0",
65 | "-i",
66 | temp_file_path,
67 | "-f",
68 | "s16le",
69 | "-ac",
70 | "1",
71 | "-acodec",
72 | "pcm_s16le",
73 | "-ar",
74 | str(sr),
75 | "-",
76 | ]
77 | out = subprocess.run(cmd, capture_output=True, check=True).stdout
78 | except subprocess.CalledProcessError as e:
79 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
80 | finally:
81 | if isinstance(file, np.ndarray):
82 | os.remove(temp_file_path)
83 |
84 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
85 |
86 |
87 |
88 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
89 | """
90 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
91 | """
92 | if torch.is_tensor(array):
93 | if array.shape[axis] > length:
94 | array = array.index_select(
95 | dim=axis, index=torch.arange(length, device=array.device)
96 | )
97 |
98 | if array.shape[axis] < length:
99 | pad_widths = [(0, 0)] * array.ndim
100 | pad_widths[axis] = (0, length - array.shape[axis])
101 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
102 | else:
103 | if array.shape[axis] > length:
104 | array = array.take(indices=range(length), axis=axis)
105 |
106 | if array.shape[axis] < length:
107 | pad_widths = [(0, 0)] * array.ndim
108 | pad_widths[axis] = (0, length - array.shape[axis])
109 | array = np.pad(array, pad_widths)
110 |
111 | return array
112 |
113 |
114 | @lru_cache(maxsize=None)
115 | def mel_filters(device, n_mels: int) -> torch.Tensor:
116 | """
117 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
118 | Allows decoupling librosa dependency; saved using:
119 |
120 | np.savez_compressed(
121 | "mel_filters.npz",
122 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
123 | )
124 | """
125 | assert n_mels in [80, 128], f"Unsupported n_mels: {n_mels}"
126 | with np.load(
127 | os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
128 | ) as f:
129 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
130 |
131 |
132 | def log_mel_spectrogram(
133 | audio: Union[str, np.ndarray, torch.Tensor],
134 | n_mels: int,
135 | padding: int = 0,
136 | device: Optional[Union[str, torch.device]] = None,
137 | ):
138 | """
139 |     Compute the log-Mel spectrogram of the audio.
140 |
141 | Parameters
142 | ----------
143 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
144 | The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
145 |
146 | n_mels: int
147 |         The number of Mel-frequency filters; only 80 and 128 are supported
148 |
149 | padding: int
150 | Number of zero samples to pad to the right
151 |
152 | device: Optional[Union[str, torch.device]]
153 | If given, the audio tensor is moved to this device before STFT
154 |
155 | Returns
156 | -------
157 |     torch.Tensor, shape = (n_mels, n_frames)
158 | A Tensor that contains the Mel spectrogram
159 | """
160 | if not torch.is_tensor(audio):
161 | if isinstance(audio, str):
162 | audio = load_audio(audio)
163 | audio = torch.from_numpy(audio)
164 |
165 | if device is not None:
166 | audio = audio.to(device)
167 | if padding > 0:
168 | audio = F.pad(audio, (0, padding))
169 | window = torch.hann_window(N_FFT).to(audio.device)
170 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
171 | magnitudes = stft[..., :-1].abs() ** 2
172 |
173 | filters = mel_filters(audio.device, n_mels)
174 | mel_spec = filters @ magnitudes
175 |
176 | log_spec = torch.clamp(mel_spec, min=1e-10).log10()
177 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
178 | log_spec = (log_spec + 4.0) / 4.0
179 | return log_spec
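180 |
181 |
182 | # A minimal sketch of the decode-then-featurize path above (assumes ffmpeg is on
183 | # PATH and `sample.wav` exists; n_mels=80 matches the bundled filterbank):
184 | #
185 | #   audio = load_audio("sample.wav")             # float32 mono at 16 kHz
186 | #   mel = log_mel_spectrogram(audio, n_mels=80)  # torch.Tensor, (80, n_frames)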
--------------------------------------------------------------------------------
/modules/diarize/diarize_pipeline.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/m-bain/whisperX/blob/main/whisperx/diarize.py
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import os
6 | from pyannote.audio import Pipeline
7 | from typing import Optional, Union
8 | import torch
9 |
10 | from modules.whisper.data_classes import *
11 | from modules.utils.paths import DIARIZATION_MODELS_DIR
12 | from modules.diarize.audio_loader import load_audio, SAMPLE_RATE
13 |
14 |
15 | class DiarizationPipeline:
16 | def __init__(
17 | self,
18 | model_name="pyannote/speaker-diarization-3.1",
19 | cache_dir: str = DIARIZATION_MODELS_DIR,
20 | use_auth_token=None,
21 | device: Optional[Union[str, torch.device]] = "cpu",
22 | ):
23 | if isinstance(device, str):
24 | device = torch.device(device)
25 | self.model = Pipeline.from_pretrained(
26 | model_name,
27 | use_auth_token=use_auth_token,
28 | cache_dir=cache_dir
29 | ).to(device)
30 |
31 | def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speakers=None):
32 | if isinstance(audio, str):
33 | audio = load_audio(audio)
34 | audio_data = {
35 | 'waveform': torch.from_numpy(audio[None, :]),
36 | 'sample_rate': SAMPLE_RATE
37 | }
38 | segments = self.model(audio_data, min_speakers=min_speakers, max_speakers=max_speakers)
39 | diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
40 | diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
41 | diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)
42 | return diarize_df
43 |
44 |
45 | def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
46 | transcript_segments = transcript_result["segments"]
47 | if transcript_segments and isinstance(transcript_segments[0], Segment):
48 | transcript_segments = [seg.model_dump() for seg in transcript_segments]
49 | for seg in transcript_segments:
50 | # assign speaker to segment (if any)
51 | diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'],
52 | seg['start'])
53 | diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start'])
54 |
55 | intersected = diarize_df[diarize_df["intersection"] > 0]
56 |
57 | speaker = None
58 | if len(intersected) > 0:
59 |                 # Choose the strongest intersection
60 | speaker = intersected.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
61 | elif fill_nearest:
62 |             # Otherwise choose the closest segment
63 | speaker = diarize_df.sort_values(by=["intersection"], ascending=False)["speaker"].values[0]
64 |
65 | if speaker is not None:
66 | seg["speaker"] = speaker
67 |
68 | # assign speaker to words
69 | if 'words' in seg and seg['words'] is not None:
70 | for word in seg['words']:
71 | if 'start' in word:
72 | diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(
73 | diarize_df['start'], word['start'])
74 | diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'],
75 | word['start'])
76 |
77 | intersected = diarize_df[diarize_df["intersection"] > 0]
78 |
79 | word_speaker = None
80 | if len(intersected) > 0:
81 |                         # Choose the strongest intersection
82 | word_speaker = \
83 | intersected.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
84 | elif fill_nearest:
85 |                         # Otherwise choose the closest segment
86 | word_speaker = diarize_df.sort_values(by=["intersection"], ascending=False)["speaker"].values[0]
87 |
88 | if word_speaker is not None:
89 | word["speaker"] = word_speaker
90 |
91 | return {"segments": transcript_segments}
92 |
93 |
94 | class DiarizationSegment:
95 | def __init__(self, start, end, speaker=None):
96 | self.start = start
97 | self.end = end
98 | self.speaker = speaker
99 |
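100 |
101 | # A hedged usage sketch (requires a Hugging Face token accepted for
102 | # pyannote/speaker-diarization-3.1; `segments` here stands for a Whisper-style
103 | # segment list):
104 | #
105 | #   pipe = DiarizationPipeline(use_auth_token="hf_...", device="cuda")
106 | #   diarize_df = pipe("audio.wav")
107 | #   labeled = assign_word_speakers(diarize_df, {"segments": segments})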
--------------------------------------------------------------------------------
/modules/diarize/diarizer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from typing import List, Union, BinaryIO, Optional, Tuple
4 | import numpy as np
5 | import time
6 | import logging
7 | import gc
8 |
9 | from modules.utils.paths import DIARIZATION_MODELS_DIR
10 | from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
11 | from modules.diarize.audio_loader import load_audio
12 | from modules.whisper.data_classes import *
13 |
14 |
15 | class Diarizer:
16 | def __init__(self,
17 | model_dir: str = DIARIZATION_MODELS_DIR
18 | ):
19 | self.device = self.get_device()
20 | self.available_device = self.get_available_device()
21 | self.compute_type = "float16"
22 | self.model_dir = model_dir
23 | os.makedirs(self.model_dir, exist_ok=True)
24 | self.pipe = None
25 |
26 | def run(self,
27 | audio: Union[str, BinaryIO, np.ndarray],
28 | transcribed_result: List[Segment],
29 | use_auth_token: str,
30 | device: Optional[str] = None
31 | ) -> Tuple[List[Segment], float]:
32 | """
33 | Diarize transcribed result as a post-processing
34 |
35 | Parameters
36 | ----------
37 | audio: Union[str, BinaryIO, np.ndarray]
38 | Audio input. This can be file path or binary type.
39 | transcribed_result: List[Segment]
40 |             Transcribed result from Whisper.
41 | use_auth_token: str
42 | Huggingface token with READ permission. This is only needed the first time you download the model.
43 | You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model.
44 | device: Optional[str]
45 | Device for diarization.
46 |
47 | Returns
48 | ----------
49 | segments_result: List[Segment]
50 | list of Segment that includes start, end timestamps and transcribed text
51 | elapsed_time: float
52 | elapsed time for running
53 | """
54 | start_time = time.time()
55 |
56 | if device is None:
57 | device = self.device
58 |
59 | if device != self.device or self.pipe is None:
60 | self.update_pipe(
61 | device=device,
62 | use_auth_token=use_auth_token
63 | )
64 |
65 | audio = load_audio(audio)
66 |
67 | diarization_segments = self.pipe(audio)
68 | diarized_result = assign_word_speakers(
69 | diarization_segments,
70 | {"segments": transcribed_result}
71 | )
72 |
73 | segments_result = []
74 | for segment in diarized_result["segments"]:
75 | speaker = "None"
76 | if "speaker" in segment:
77 | speaker = segment["speaker"]
78 | diarized_text = speaker + "|" + segment["text"].strip()
79 | segments_result.append(Segment(
80 | start=segment["start"],
81 | end=segment["end"],
82 | text=diarized_text
83 | ))
84 |
85 | elapsed_time = time.time() - start_time
86 | return segments_result, elapsed_time
87 |
88 | def update_pipe(self,
89 | use_auth_token: Optional[str] = None,
90 | device: Optional[str] = None,
91 | ):
92 | """
93 | Set pipeline for diarization
94 |
95 | Parameters
96 | ----------
97 | use_auth_token: str
98 | Huggingface token with READ permission. This is only needed the first time you download the model.
99 | You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model.
100 | device: str
101 | Device for diarization.
102 | """
103 | if device is None:
104 | device = self.get_device()
105 | self.device = device
106 |
107 | os.makedirs(self.model_dir, exist_ok=True)
108 |
109 | if (not os.listdir(self.model_dir) and
110 | not use_auth_token):
111 | print(
112 |                 "\nFailed to diarize. You need a Hugging Face token and must agree to the model's terms to download the diarization model.\n"
113 | "Go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and follow their instructions to download the model.\n"
114 | )
115 | return
116 |
117 | logger = logging.getLogger("speechbrain.utils.train_logger")
118 | # Disable redundant torchvision warning message
119 | logger.disabled = True
120 | self.pipe = DiarizationPipeline(
121 | use_auth_token=use_auth_token,
122 | device=device,
123 | cache_dir=self.model_dir
124 | )
125 | logger.disabled = False
126 |
127 | def offload(self):
128 | """Offload the model and free up the memory"""
129 | if self.pipe is not None:
130 | del self.pipe
131 | self.pipe = None
132 | if self.device == "cuda":
133 | torch.cuda.empty_cache()
134 | torch.cuda.reset_max_memory_allocated()
135 | if self.device == "xpu":
136 | torch.xpu.empty_cache()
137 | torch.xpu.reset_accumulated_memory_stats()
138 | torch.xpu.reset_peak_memory_stats()
139 | gc.collect()
140 |
141 | @staticmethod
142 | def get_device():
143 | if torch.cuda.is_available():
144 | return "cuda"
145 | if torch.xpu.is_available():
146 | return "xpu"
147 |         if torch.backends.mps.is_available():
148 | return "mps"
149 | else:
150 | return "cpu"
151 |
152 | @staticmethod
153 | def get_available_device():
154 | devices = ["cpu"]
155 | if torch.cuda.is_available():
156 | devices.append("cuda")
157 | if torch.xpu.is_available():
158 | devices.append("xpu")
159 | if torch.backends.mps.is_available():
160 | devices.append("mps")
161 | return devices
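162 |
163 |
164 | # A hedged usage sketch (the token is only needed for the first model download):
165 | #
166 | #   diarizer = Diarizer()
167 | #   segments_result, elapsed = diarizer.run(
168 | #       audio="audio.wav",
169 | #       transcribed_result=segments,   # List[Segment] from Whisper
170 | #       use_auth_token="hf_...",
171 | #   )
172 | #   diarizer.offload()  # free VRAM when done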
--------------------------------------------------------------------------------
/modules/translation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/modules/translation/__init__.py
--------------------------------------------------------------------------------
/modules/translation/deepl_api.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import time
3 | import os
4 | from datetime import datetime
5 | import gradio as gr
6 |
7 | from modules.utils.paths import TRANSLATION_OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH
8 | from modules.utils.constants import AUTOMATIC_DETECTION
9 | from modules.utils.subtitle_manager import *
10 | from modules.utils.files_manager import load_yaml, save_yaml
11 |
12 | """
13 | This is written with reference to the DeepL API documentation.
14 | If you want to know the information of the DeepL API, see here: https://www.deepl.com/docs-api/documents
15 | """
16 |
17 | DEEPL_AVAILABLE_TARGET_LANGS = {
18 | 'Bulgarian': 'BG',
19 | 'Czech': 'CS',
20 | 'Danish': 'DA',
21 | 'German': 'DE',
22 | 'Greek': 'EL',
23 | 'English': 'EN',
24 | 'English (British)': 'EN-GB',
25 | 'English (American)': 'EN-US',
26 | 'Spanish': 'ES',
27 | 'Estonian': 'ET',
28 | 'Finnish': 'FI',
29 | 'French': 'FR',
30 | 'Hungarian': 'HU',
31 | 'Indonesian': 'ID',
32 | 'Italian': 'IT',
33 | 'Japanese': 'JA',
34 | 'Korean': 'KO',
35 | 'Lithuanian': 'LT',
36 | 'Latvian': 'LV',
37 | 'Norwegian (Bokmål)': 'NB',
38 | 'Dutch': 'NL',
39 | 'Polish': 'PL',
40 | 'Portuguese': 'PT',
41 | 'Portuguese (Brazilian)': 'PT-BR',
42 | 'Portuguese (all Portuguese varieties excluding Brazilian Portuguese)': 'PT-PT',
43 | 'Romanian': 'RO',
44 | 'Russian': 'RU',
45 | 'Slovak': 'SK',
46 | 'Slovenian': 'SL',
47 | 'Swedish': 'SV',
48 | 'Turkish': 'TR',
49 | 'Ukrainian': 'UK',
50 | 'Chinese (simplified)': 'ZH'
51 | }
52 |
53 | DEEPL_AVAILABLE_SOURCE_LANGS = {
54 | AUTOMATIC_DETECTION: None,
55 | 'Bulgarian': 'BG',
56 | 'Czech': 'CS',
57 | 'Danish': 'DA',
58 | 'German': 'DE',
59 | 'Greek': 'EL',
60 | 'English': 'EN',
61 | 'Spanish': 'ES',
62 | 'Estonian': 'ET',
63 | 'Finnish': 'FI',
64 | 'French': 'FR',
65 | 'Hungarian': 'HU',
66 | 'Indonesian': 'ID',
67 | 'Italian': 'IT',
68 | 'Japanese': 'JA',
69 | 'Korean': 'KO',
70 | 'Lithuanian': 'LT',
71 | 'Latvian': 'LV',
72 | 'Norwegian (Bokmål)': 'NB',
73 | 'Dutch': 'NL',
74 | 'Polish': 'PL',
75 | 'Portuguese (all Portuguese varieties mixed)': 'PT',
76 | 'Romanian': 'RO',
77 | 'Russian': 'RU',
78 | 'Slovak': 'SK',
79 | 'Slovenian': 'SL',
80 | 'Swedish': 'SV',
81 | 'Turkish': 'TR',
82 | 'Ukrainian': 'UK',
83 | 'Chinese': 'ZH'
84 | }
85 |
86 |
87 | class DeepLAPI:
88 | def __init__(self,
89 | output_dir: str = TRANSLATION_OUTPUT_DIR
90 | ):
91 | self.api_interval = 1
92 | self.max_text_batch_size = 50
93 | self.available_target_langs = DEEPL_AVAILABLE_TARGET_LANGS
94 | self.available_source_langs = DEEPL_AVAILABLE_SOURCE_LANGS
95 | self.output_dir = output_dir
96 |
97 | def translate_deepl(self,
98 | auth_key: str,
99 | fileobjs: list,
100 | source_lang: str,
101 | target_lang: str,
102 | is_pro: bool = False,
103 | add_timestamp: bool = True,
104 | progress=gr.Progress()) -> list:
105 | """
106 | Translate subtitle files using DeepL API
107 | Parameters
108 | ----------
109 | auth_key: str
110 | API Key for DeepL from gr.Textbox()
111 | fileobjs: list
112 |             List of files to translate from gr.Files()
113 |         source_lang: str
114 |             Source language of the files to translate from gr.Dropdown()
115 |         target_lang: str
116 |             Target language of the files to translate from gr.Dropdown()
117 |         is_pro: bool
118 |             Boolean value from gr.Checkbox() indicating whether the user has a DeepL Pro subscription.
119 | add_timestamp: bool
120 | Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
121 | progress: gr.Progress
122 | Indicator to show progress directly in gradio.
123 |
124 | Returns
125 | ----------
126 | A List of
127 | String to return to gr.Textbox()
128 | Files to return to gr.Files()
129 | """
130 | if fileobjs and isinstance(fileobjs[0], gr.utils.NamedString):
131 | fileobjs = [fileobj.name for fileobj in fileobjs]
132 |
133 | self.cache_parameters(
134 | api_key=auth_key,
135 | is_pro=is_pro,
136 | source_lang=source_lang,
137 | target_lang=target_lang,
138 | add_timestamp=add_timestamp
139 | )
140 |
141 | files_info = {}
142 | for file_path in fileobjs:
143 | file_name, file_ext = os.path.splitext(os.path.basename(file_path))
144 | writer = get_writer(file_ext, self.output_dir)
145 | segments = writer.to_segments(file_path)
146 |
147 | batch_size = self.max_text_batch_size
148 | for batch_start in range(0, len(segments), batch_size):
149 | progress(batch_start / len(segments), desc="Translating..")
150 | sentences_to_translate = [seg.text for seg in segments[batch_start:batch_start+batch_size]]
151 | translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
152 | target_lang, is_pro)
153 | for i, translated_text in enumerate(translated_texts):
154 | segments[batch_start + i].text = translated_text["text"]
155 |
156 | subtitle, output_path = generate_file(
157 | output_dir=self.output_dir,
158 | output_file_name=file_name,
159 | output_format=file_ext,
160 | result=segments,
161 | add_timestamp=add_timestamp
162 | )
163 |
164 | files_info[file_name] = {"subtitle": subtitle, "path": output_path}
165 |
166 | total_result = ''
167 | for file_name, info in files_info.items():
168 | total_result += '------------------------------------\n'
169 | total_result += f'{file_name}\n\n'
170 | total_result += f'{info["subtitle"]}'
171 | gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
172 |
173 | output_file_paths = [item["path"] for key, item in files_info.items()]
174 | return [gr_str, output_file_paths]
175 |
176 | def request_deepl_translate(self,
177 | auth_key: str,
178 | text: list,
179 | source_lang: str,
180 | target_lang: str,
181 | is_pro: bool = False):
182 | """Request API response to DeepL server"""
183 | if source_lang not in list(DEEPL_AVAILABLE_SOURCE_LANGS.keys()):
184 | raise ValueError(f"Source language {source_lang} is not supported."
185 |                              f" Use one of {list(DEEPL_AVAILABLE_SOURCE_LANGS.keys())}")
186 | if target_lang not in list(DEEPL_AVAILABLE_TARGET_LANGS.keys()):
187 | raise ValueError(f"Target language {target_lang} is not supported."
188 |                              f" Use one of {list(DEEPL_AVAILABLE_TARGET_LANGS.keys())}")
189 |
190 | url = 'https://api.deepl.com/v2/translate' if is_pro else 'https://api-free.deepl.com/v2/translate'
191 | headers = {
192 | 'Authorization': f'DeepL-Auth-Key {auth_key}'
193 | }
194 | data = {
195 | 'text': text,
196 | 'source_lang': DEEPL_AVAILABLE_SOURCE_LANGS[source_lang],
197 | 'target_lang': DEEPL_AVAILABLE_TARGET_LANGS[target_lang]
198 | }
199 | response = requests.post(url, headers=headers, data=data).json()  # {"translations": [{"text": ...}, ...]}
200 | time.sleep(self.api_interval)  # simple rate limiting between batch requests
201 | return response["translations"]
202 |
203 | @staticmethod
204 | def cache_parameters(api_key: str,
205 | is_pro: bool,
206 | source_lang: str,
207 | target_lang: str,
208 | add_timestamp: bool):
209 | cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
210 | cached_params["translation"]["deepl"] = {
211 | "api_key": api_key,
212 | "is_pro": is_pro,
213 | "source_lang": source_lang,
214 | "target_lang": target_lang
215 | }
216 | cached_params["translation"]["add_timestamp"] = add_timestamp
217 | save_yaml(cached_params, DEFAULT_PARAMETERS_CONFIG_PATH)
218 |
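For reference, `request_deepl_translate` above is a thin wrapper over DeepL's v2 REST endpoint. A minimal standalone sketch of the same request, assuming a valid key (`YOUR_AUTH_KEY` is a placeholder, not from this repo):

import requests

# Standalone sketch of the DeepL v2 call made by request_deepl_translate.
# YOUR_AUTH_KEY is a placeholder; Pro accounts use https://api.deepl.com instead.
url = "https://api-free.deepl.com/v2/translate"
headers = {"Authorization": "DeepL-Auth-Key YOUR_AUTH_KEY"}
data = {
    "text": ["Hello, world!", "How are you?"],  # a list is form-encoded as repeated 'text' fields
    "source_lang": "EN",
    "target_lang": "KO",
}
response = requests.post(url, headers=headers, data=data).json()
# Response shape: {"translations": [{"detected_source_language": "EN", "text": "..."}, ...]}
for translation in response["translations"]:
    print(translation["text"])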
--------------------------------------------------------------------------------
/modules/translation/nllb_inference.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
2 | import gradio as gr
3 | import os
4 |
5 | from modules.utils.paths import TRANSLATION_OUTPUT_DIR, NLLB_MODELS_DIR
6 | from modules.translation.translation_base import TranslationBase
7 |
8 |
9 | class NLLBInference(TranslationBase):
10 | def __init__(self,
11 | model_dir: str = NLLB_MODELS_DIR,
12 | output_dir: str = TRANSLATION_OUTPUT_DIR
13 | ):
14 | super().__init__(
15 | model_dir=model_dir,
16 | output_dir=output_dir
17 | )
18 | self.tokenizer = None
19 | self.available_models = ["facebook/nllb-200-3.3B", "facebook/nllb-200-1.3B", "facebook/nllb-200-distilled-600M"]
20 | self.available_source_langs = list(NLLB_AVAILABLE_LANGS.keys())
21 | self.available_target_langs = list(NLLB_AVAILABLE_LANGS.keys())
22 | self.pipeline = None
23 |
24 | def translate(self,
25 | text: str,
26 | max_length: int
27 | ):
28 | result = self.pipeline(
29 | text,
30 | max_length=max_length
31 | )
32 | return result[0]["translation_text"]
33 |
34 | def update_model(self,
35 | model_size: str,
36 | src_lang: str,
37 | tgt_lang: str,
38 | progress: gr.Progress = gr.Progress()
39 | ):
40 | def validate_language(lang: str) -> str:
41 | if lang in NLLB_AVAILABLE_LANGS:
42 | return NLLB_AVAILABLE_LANGS[lang]
43 | elif lang not in NLLB_AVAILABLE_LANGS.values():
44 | raise ValueError(f"Language '{lang}' is not supported. Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}")
45 | return lang
46 |
47 | src_lang = validate_language(src_lang)
48 | tgt_lang = validate_language(tgt_lang)
49 |
50 | if model_size != self.current_model_size or self.model is None:
51 | print("\nInitializing NLLB Model..\n")
52 | progress(0, desc="Initializing NLLB Model..")
53 | self.current_model_size = model_size
54 | local_files_only = self.is_model_exists(self.current_model_size)  # skip the Hub download when a local copy exists
55 | self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
56 | cache_dir=self.model_dir,
57 | local_files_only=local_files_only)
58 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
59 | cache_dir=os.path.join(self.model_dir, "tokenizers"),
60 | local_files_only=local_files_only)
61 |
62 | self.pipeline = pipeline("translation",
63 | model=self.model,
64 | tokenizer=self.tokenizer,
65 | src_lang=src_lang,
66 | tgt_lang=tgt_lang,
67 | device=self.device)
68 |
69 | def is_model_exists(self,
70 | model_size: str):
71 | """Check whether the model already exists locally (facebook/* models only)"""
72 | prefix = "models--facebook--"  # Hugging Face cache-style directory naming
73 | _id, model_size_name = model_size.split("/")
74 | model_dir_name = prefix + model_size_name
75 | model_dir_path = os.path.join(self.model_dir, model_dir_name)
76 | if os.path.exists(model_dir_path) and os.listdir(model_dir_path):
77 | return True
78 | for model_dir_name in os.listdir(self.model_dir):
79 | if (model_size in model_dir_name or model_size_name in model_dir_name) and \
80 | os.listdir(os.path.join(self.model_dir, model_dir_name)):
81 | return True
82 | return False
83 |
84 |
85 | NLLB_AVAILABLE_LANGS = {
86 | "Acehnese (Arabic script)": "ace_Arab",
87 | "Acehnese (Latin script)": "ace_Latn",
88 | "Mesopotamian Arabic": "acm_Arab",
89 | "Ta’izzi-Adeni Arabic": "acq_Arab",
90 | "Tunisian Arabic": "aeb_Arab",
91 | "Afrikaans": "afr_Latn",
92 | "South Levantine Arabic": "ajp_Arab",
93 | "Akan": "aka_Latn",
94 | "Amharic": "amh_Ethi",
95 | "North Levantine Arabic": "apc_Arab",
96 | "Modern Standard Arabic": "arb_Arab",
97 | "Modern Standard Arabic (Romanized)": "arb_Latn",
98 | "Najdi Arabic": "ars_Arab",
99 | "Moroccan Arabic": "ary_Arab",
100 | "Egyptian Arabic": "arz_Arab",
101 | "Assamese": "asm_Beng",
102 | "Asturian": "ast_Latn",
103 | "Awadhi": "awa_Deva",
104 | "Central Aymara": "ayr_Latn",
105 | "South Azerbaijani": "azb_Arab",
106 | "North Azerbaijani": "azj_Latn",
107 | "Bashkir": "bak_Cyrl",
108 | "Bambara": "bam_Latn",
109 | "Balinese": "ban_Latn",
110 | "Belarusian": "bel_Cyrl",
111 | "Bemba": "bem_Latn",
112 | "Bengali": "ben_Beng",
113 | "Bhojpuri": "bho_Deva",
114 | "Banjar (Arabic script)": "bjn_Arab",
115 | "Banjar (Latin script)": "bjn_Latn",
116 | "Standard Tibetan": "bod_Tibt",
117 | "Bosnian": "bos_Latn",
118 | "Buginese": "bug_Latn",
119 | "Bulgarian": "bul_Cyrl",
120 | "Catalan": "cat_Latn",
121 | "Cebuano": "ceb_Latn",
122 | "Czech": "ces_Latn",
123 | "Chokwe": "cjk_Latn",
124 | "Central Kurdish": "ckb_Arab",
125 | "Crimean Tatar": "crh_Latn",
126 | "Welsh": "cym_Latn",
127 | "Danish": "dan_Latn",
128 | "German": "deu_Latn",
129 | "Southwestern Dinka": "dik_Latn",
130 | "Dyula": "dyu_Latn",
131 | "Dzongkha": "dzo_Tibt",
132 | "Greek": "ell_Grek",
133 | "English": "eng_Latn",
134 | "Esperanto": "epo_Latn",
135 | "Estonian": "est_Latn",
136 | "Basque": "eus_Latn",
137 | "Ewe": "ewe_Latn",
138 | "Faroese": "fao_Latn",
139 | "Fijian": "fij_Latn",
140 | "Finnish": "fin_Latn",
141 | "Fon": "fon_Latn",
142 | "French": "fra_Latn",
143 | "Friulian": "fur_Latn",
144 | "Nigerian Fulfulde": "fuv_Latn",
145 | "Scottish Gaelic": "gla_Latn",
146 | "Irish": "gle_Latn",
147 | "Galician": "glg_Latn",
148 | "Guarani": "grn_Latn",
149 | "Gujarati": "guj_Gujr",
150 | "Haitian Creole": "hat_Latn",
151 | "Hausa": "hau_Latn",
152 | "Hebrew": "heb_Hebr",
153 | "Hindi": "hin_Deva",
154 | "Chhattisgarhi": "hne_Deva",
155 | "Croatian": "hrv_Latn",
156 | "Hungarian": "hun_Latn",
157 | "Armenian": "hye_Armn",
158 | "Igbo": "ibo_Latn",
159 | "Ilocano": "ilo_Latn",
160 | "Indonesian": "ind_Latn",
161 | "Icelandic": "isl_Latn",
162 | "Italian": "ita_Latn",
163 | "Javanese": "jav_Latn",
164 | "Japanese": "jpn_Jpan",
165 | "Kabyle": "kab_Latn",
166 | "Jingpho": "kac_Latn",
167 | "Kamba": "kam_Latn",
168 | "Kannada": "kan_Knda",
169 | "Kashmiri (Arabic script)": "kas_Arab",
170 | "Kashmiri (Devanagari script)": "kas_Deva",
171 | "Georgian": "kat_Geor",
172 | "Central Kanuri (Arabic script)": "knc_Arab",
173 | "Central Kanuri (Latin script)": "knc_Latn",
174 | "Kazakh": "kaz_Cyrl",
175 | "Kabiyè": "kbp_Latn",
176 | "Kabuverdianu": "kea_Latn",
177 | "Khmer": "khm_Khmr",
178 | "Kikuyu": "kik_Latn",
179 | "Kinyarwanda": "kin_Latn",
180 | "Kyrgyz": "kir_Cyrl",
181 | "Kimbundu": "kmb_Latn",
182 | "Northern Kurdish": "kmr_Latn",
183 | "Kikongo": "kon_Latn",
184 | "Korean": "kor_Hang",
185 | "Lao": "lao_Laoo",
186 | "Ligurian": "lij_Latn",
187 | "Limburgish": "lim_Latn",
188 | "Lingala": "lin_Latn",
189 | "Lithuanian": "lit_Latn",
190 | "Lombard": "lmo_Latn",
191 | "Latgalian": "ltg_Latn",
192 | "Luxembourgish": "ltz_Latn",
193 | "Luba-Kasai": "lua_Latn",
194 | "Ganda": "lug_Latn",
195 | "Luo": "luo_Latn",
196 | "Mizo": "lus_Latn",
197 | "Standard Latvian": "lvs_Latn",
198 | "Magahi": "mag_Deva",
199 | "Maithili": "mai_Deva",
200 | "Malayalam": "mal_Mlym",
201 | "Marathi": "mar_Deva",
202 | "Minangkabau (Arabic script)": "min_Arab",
203 | "Minangkabau (Latin script)": "min_Latn",
204 | "Macedonian": "mkd_Cyrl",
205 | "Plateau Malagasy": "plt_Latn",
206 | "Maltese": "mlt_Latn",
207 | "Meitei (Bengali script)": "mni_Beng",
208 | "Halh Mongolian": "khk_Cyrl",
209 | "Mossi": "mos_Latn",
210 | "Maori": "mri_Latn",
211 | "Burmese": "mya_Mymr",
212 | "Dutch": "nld_Latn",
213 | "Norwegian Nynorsk": "nno_Latn",
214 | "Norwegian Bokmål": "nob_Latn",
215 | "Nepali": "npi_Deva",
216 | "Northern Sotho": "nso_Latn",
217 | "Nuer": "nus_Latn",
218 | "Nyanja": "nya_Latn",
219 | "Occitan": "oci_Latn",
220 | "West Central Oromo": "gaz_Latn",
221 | "Odia": "ory_Orya",
222 | "Pangasinan": "pag_Latn",
223 | "Eastern Panjabi": "pan_Guru",
224 | "Papiamento": "pap_Latn",
225 | "Western Persian": "pes_Arab",
226 | "Polish": "pol_Latn",
227 | "Portuguese": "por_Latn",
228 | "Dari": "prs_Arab",
229 | "Southern Pashto": "pbt_Arab",
230 | "Ayacucho Quechua": "quy_Latn",
231 | "Romanian": "ron_Latn",
232 | "Rundi": "run_Latn",
233 | "Russian": "rus_Cyrl",
234 | "Sango": "sag_Latn",
235 | "Sanskrit": "san_Deva",
236 | "Santali": "sat_Olck",
237 | "Sicilian": "scn_Latn",
238 | "Shan": "shn_Mymr",
239 | "Sinhala": "sin_Sinh",
240 | "Slovak": "slk_Latn",
241 | "Slovenian": "slv_Latn",
242 | "Samoan": "smo_Latn",
243 | "Shona": "sna_Latn",
244 | "Sindhi": "snd_Arab",
245 | "Somali": "som_Latn",
246 | "Southern Sotho": "sot_Latn",
247 | "Spanish": "spa_Latn",
248 | "Tosk Albanian": "als_Latn",
249 | "Sardinian": "srd_Latn",
250 | "Serbian": "srp_Cyrl",
251 | "Swati": "ssw_Latn",
252 | "Sundanese": "sun_Latn",
253 | "Swedish": "swe_Latn",
254 | "Swahili": "swh_Latn",
255 | "Silesian": "szl_Latn",
256 | "Tamil": "tam_Taml",
257 | "Tatar": "tat_Cyrl",
258 | "Telugu": "tel_Telu",
259 | "Tajik": "tgk_Cyrl",
260 | "Tagalog": "tgl_Latn",
261 | "Thai": "tha_Thai",
262 | "Tigrinya": "tir_Ethi",
263 | "Tamasheq (Latin script)": "taq_Latn",
264 | "Tamasheq (Tifinagh script)": "taq_Tfng",
265 | "Tok Pisin": "tpi_Latn",
266 | "Tswana": "tsn_Latn",
267 | "Tsonga": "tso_Latn",
268 | "Turkmen": "tuk_Latn",
269 | "Tumbuka": "tum_Latn",
270 | "Turkish": "tur_Latn",
271 | "Twi": "twi_Latn",
272 | "Central Atlas Tamazight": "tzm_Tfng",
273 | "Uyghur": "uig_Arab",
274 | "Ukrainian": "ukr_Cyrl",
275 | "Umbundu": "umb_Latn",
276 | "Urdu": "urd_Arab",
277 | "Northern Uzbek": "uzn_Latn",
278 | "Venetian": "vec_Latn",
279 | "Vietnamese": "vie_Latn",
280 | "Waray": "war_Latn",
281 | "Wolof": "wol_Latn",
282 | "Xhosa": "xho_Latn",
283 | "Eastern Yiddish": "ydd_Hebr",
284 | "Yoruba": "yor_Latn",
285 | "Yue Chinese": "yue_Hant",
286 | "Chinese (Simplified)": "zho_Hans",
287 | "Chinese (Traditional)": "zho_Hant",
288 | "Standard Malay": "zsm_Latn",
289 | "Zulu": "zul_Latn",
290 | }
291 |
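A minimal sketch of driving this class directly, assuming the default model/output directories and enough memory for the distilled 600M checkpoint (the sample sentence is illustrative):

from modules.translation.nllb_inference import NLLBInference

# Language names are the human-readable keys of NLLB_AVAILABLE_LANGS;
# update_model maps them to FLORES-200 codes via validate_language().
translator = NLLBInference()
translator.update_model(model_size="facebook/nllb-200-distilled-600M",
                        src_lang="English",
                        tgt_lang="Korean")
print(translator.translate("The weather is nice today.", max_length=200))
translator.offload()  # release GPU memory when finished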
--------------------------------------------------------------------------------
/modules/translation/translation_base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import gradio as gr
4 | from abc import ABC, abstractmethod
5 | import gc
6 | from typing import List
7 | from datetime import datetime
8 |
9 | import modules.translation.nllb_inference as nllb
10 | from modules.whisper.data_classes import *
11 | from modules.utils.subtitle_manager import *
12 | from modules.utils.files_manager import load_yaml, save_yaml
13 | from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR
14 |
15 |
16 | class TranslationBase(ABC):
17 | def __init__(self,
18 | model_dir: str = NLLB_MODELS_DIR,
19 | output_dir: str = TRANSLATION_OUTPUT_DIR
20 | ):
21 | super().__init__()
22 | self.model = None
23 | self.model_dir = model_dir
24 | self.output_dir = output_dir
25 | os.makedirs(self.model_dir, exist_ok=True)
26 | os.makedirs(self.output_dir, exist_ok=True)
27 | self.current_model_size = None
28 | self.device = self.get_device()
29 |
30 | @abstractmethod
31 | def translate(self,
32 | text: str,
33 | max_length: int
34 | ):
35 | pass
36 |
37 | @abstractmethod
38 | def update_model(self,
39 | model_size: str,
40 | src_lang: str,
41 | tgt_lang: str,
42 | progress: gr.Progress = gr.Progress()
43 | ):
44 | pass
45 |
46 | def translate_file(self,
47 | fileobjs: list,
48 | model_size: str,
49 | src_lang: str,
50 | tgt_lang: str,
51 | max_length: int = 200,
52 | add_timestamp: bool = True,
53 | progress=gr.Progress()) -> list:
54 | """
55 | Translate subtitle file from source language to target language
56 |
57 | Parameters
58 | ----------
59 | fileobjs: list
60 | List of files to translate from gr.Files()
61 | model_size: str
62 | Translation model size from gr.Dropdown()
63 | src_lang: str
64 | Source language of the file to translate from gr.Dropdown()
65 | tgt_lang: str
66 | Target language of the file to translate from gr.Dropdown()
67 | max_length: int
68 | Max length per line to translate
69 | add_timestamp: bool
70 | Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
71 | progress: gr.Progress
72 | Indicator to show progress directly in gradio.
73 | A forked version of whisper is used for this. For more info: https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
74 |
75 | Returns
76 | ----------
77 | A List of
78 | String to return to gr.Textbox()
79 | Files to return to gr.Files()
80 | """
81 | try:
82 | if fileobjs and isinstance(fileobjs[0], gr.utils.NamedString):
83 | fileobjs = [file.name for file in fileobjs]
84 |
85 | self.cache_parameters(model_size=model_size,
86 | src_lang=src_lang,
87 | tgt_lang=tgt_lang,
88 | max_length=max_length,
89 | add_timestamp=add_timestamp)
90 |
91 | self.update_model(model_size=model_size,
92 | src_lang=src_lang,
93 | tgt_lang=tgt_lang,
94 | progress=progress)
95 |
96 | files_info = {}
97 | for fileobj in fileobjs:
98 | file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
99 | writer = get_writer(file_ext, self.output_dir)
100 | segments = writer.to_segments(fileobj)
101 | for i, segment in enumerate(segments):
102 | progress(i / len(segments), desc="Translating..")
103 | translated_text = self.translate(segment.text, max_length=max_length)
104 | segment.text = translated_text
105 |
106 | subtitle, file_path = generate_file(
107 | output_dir=self.output_dir,
108 | output_file_name=file_name,
109 | output_format=file_ext,
110 | result=segments,
111 | add_timestamp=add_timestamp
112 | )
113 |
114 | files_info[file_name] = {"subtitle": subtitle, "path": file_path}
115 |
116 | total_result = ''
117 | for file_name, info in files_info.items():
118 | total_result += '------------------------------------\n'
119 | total_result += f'{file_name}\n\n'
120 | total_result += f'{info["subtitle"]}'
121 | gr_str = f"Done! Subtitle is in the outputs/translations folder.\n\n{total_result}"
122 |
123 | output_file_paths = [info["path"] for info in files_info.values()]
124 | return [gr_str, output_file_paths]
125 |
126 | except Exception as e:
127 | print(f"Error translating file: {e}")
128 | raise
129 | finally:
130 | self.offload()
131 |
132 | @staticmethod
133 | def get_device():
134 | if torch.cuda.is_available():
135 | return "cuda"
136 | if torch.xpu.is_available():
137 | return "xpu"
138 | if torch.backends.mps.is_available():
139 | return "mps"
140 | else:
141 | return "cpu"
142 |
143 | def offload(self):
144 | """Offload the model and free up the memory"""
145 | if self.model is not None:
146 | del self.model
147 | self.model = None
148 | if self.device == "cuda":
149 | torch.cuda.empty_cache()
150 | torch.cuda.reset_max_memory_allocated()
151 | if self.device == "xpu":
152 | torch.xpu.empty_cache()
153 | torch.xpu.reset_accumulated_memory_stats()
154 | torch.xpu.reset_peak_memory_stats()
155 | gc.collect()
156 |
157 | @staticmethod
158 | def remove_input_files(file_paths: List[str]):
159 | if not file_paths:
160 | return
161 |
162 | for file_path in file_paths:
163 | if file_path and os.path.exists(file_path):
164 | os.remove(file_path)
165 |
166 | @staticmethod
167 | def cache_parameters(model_size: str,
168 | src_lang: str,
169 | tgt_lang: str,
170 | max_length: int,
171 | add_timestamp: bool):
172 | def validate_lang(lang: str):
173 | if lang in list(nllb.NLLB_AVAILABLE_LANGS.values()):
174 | flipped = {value: key for key, value in nllb.NLLB_AVAILABLE_LANGS.items()}
175 | return flipped[lang]
176 | return lang
177 |
178 | cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
179 | cached_params["translation"]["nllb"] = {
180 | "model_size": model_size,
181 | "source_lang": validate_lang(src_lang),
182 | "target_lang": validate_lang(tgt_lang),
183 | "max_length": max_length,
184 | }
185 | cached_params["translation"]["add_timestamp"] = add_timestamp
186 | save_yaml(cached_params, DEFAULT_PARAMETERS_CONFIG_PATH)
187 |
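TranslationBase only requires subclasses to implement translate() and update_model(); file parsing, per-segment translation, device selection, and offloading are inherited. A minimal sketch of a custom backend, where UpperCaseTranslation is a purely hypothetical stub rather than a real model:

import gradio as gr

from modules.translation.translation_base import TranslationBase


class UpperCaseTranslation(TranslationBase):
    """Hypothetical backend: 'translates' by upper-casing; useful for wiring tests."""

    def update_model(self, model_size: str, src_lang: str, tgt_lang: str,
                     progress: gr.Progress = gr.Progress()):
        self.current_model_size = model_size  # nothing to load for this stub

    def translate(self, text: str, max_length: int):
        return text.upper()[:max_length]


# The inherited translate_file() then handles parsing, translation, and output, e.g.:
# UpperCaseTranslation().translate_file(["sample.srt"], "stub", "English", "Korean")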
--------------------------------------------------------------------------------
/modules/ui/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhj0517/Whisper-WebUI/f96deb2aa0b4208064d0d8d607419e3099f6b9a7/modules/ui/__init__.py
--------------------------------------------------------------------------------
/modules/ui/htmls.py:
--------------------------------------------------------------------------------
1 | CSS = """
2 | .bmc-button {
3 | padding: 2px 5px;
4 | border-radius: 5px;
5 | background-color: #FF813F;
6 | color: white;
7 | box-shadow: 0px 1px 2px rgba(0, 0, 0, 0.3);
8 | text-decoration: none;
9 | display: inline-block;
10 | font-size: 20px;
11 | margin: 2px;
12 | cursor: pointer;
13 | -webkit-transition: background-color 0.3s ease;
14 | -ms-transition: background-color 0.3s ease;
15 | transition: background-color 0.3s ease;
16 | }
17 | .bmc-button:hover,
18 | .bmc-button:active,
19 | .bmc-button:focus {
20 | background-color: #FF5633;
21 | }
22 | .markdown {
23 | margin-bottom: 0;
24 | padding-bottom: 0;
25 | }
26 | .tabs {
27 | margin-top: 0;
28 | padding-top: 0;
29 | }
30 |
31 | #md_project a {
32 | color: black;
33 | text-decoration: none;
34 | }
35 | #md_project a:hover {
36 | text-decoration: underline;
37 | }
38 | """
39 |
40 | MARKDOWN = """
41 | ### [Whisper-WebUI](https://github.com/jhj0517/Whisper-WebUI)
42 | """
43 |
44 |
45 | NLLB_VRAM_TABLE = """
46 | <table>
47 | <thead>
48 | <tr>
49 | <th>Model name</th>
50 | <th>Required VRAM</th>
51 | </tr>
52 | </thead>
53 | <tbody>
54 | <tr>
55 | <td>nllb-200-3.3B</td>
56 | <td>~16GB</td>
57 | </tr>
58 | <tr>
59 | <td>nllb-200-1.3B</td>
60 | <td>~8GB</td>
61 | </tr>
62 | <tr>
63 | <td>nllb-200-distilled-600M</td>
64 | <td>~4GB</td>
65 | </tr>
66 | </tbody>
67 | </table>
68 | <p>Note: Be mindful of your VRAM! The table above provides an approximate VRAM usage for each model.</p>
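These constants are plain CSS/Markdown/HTML strings consumed by the Gradio UI. A minimal sketch of how they are typically wired up, assuming a gr.Blocks layout (the real layout lives in app.py, so treat this as illustrative):

import gradio as gr

from modules.ui.htmls import CSS, MARKDOWN, NLLB_VRAM_TABLE

# Illustrative only: app.py defines the actual layout and components.
with gr.Blocks(css=CSS) as demo:
    gr.Markdown(MARKDOWN, elem_id="md_project")
    gr.HTML(NLLB_VRAM_TABLE)

demo.launch()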