├── .dockerignore
├── .github
│   ├── PULL_REQUEST_TEMPLATE.md
│   └── workflows
│       └── ci.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── Dockerfile
├── LICENSE
├── README.md
├── SWARM.md
├── bake_deploy_prod.sh
├── bake_deploy_staging.sh
├── docker-compose.yml
├── dockerless
│   ├── .gitignore
│   ├── start_backend.sh
│   ├── start_frontend.sh
│   ├── start_llm.sh
│   ├── start_stt.sh
│   └── start_tts.sh
├── docs
│   └── browser_backend_communication.md
├── frontend
│   ├── .gitignore
│   ├── Dockerfile
│   ├── README.md
│   ├── eslint.config.mjs
│   ├── hot-reloading.Dockerfile
│   ├── next.config.ts
│   ├── package.json
│   ├── pnpm-lock.yaml
│   ├── postcss.config.mjs
│   ├── public
│   │   ├── audio-output-processor.js
│   │   ├── decoderWorker.min.js
│   │   ├── decoderWorker.min.wasm
│   │   └── encoderWorker.min.js
│   ├── src
│   │   ├── app
│   │   │   ├── ConsentModal.tsx
│   │   │   ├── CouldNotConnect.tsx
│   │   │   ├── ErrorMessages.tsx
│   │   │   ├── Modal.tsx
│   │   │   ├── PositionedAudioVisualizer.tsx
│   │   │   ├── SingleRoleSubtitles.tsx
│   │   │   ├── SlantedButton.tsx
│   │   │   ├── SquareButton.tsx
│   │   │   ├── Subtitles.tsx
│   │   │   ├── TrimmedAudioPreview.tsx
│   │   │   ├── Unmute.tsx
│   │   │   ├── UnmuteConfigurator.tsx
│   │   │   ├── UnmuteHeader.tsx
│   │   │   ├── VoiceAttribution.tsx
│   │   │   ├── VoiceRecorder.tsx
│   │   │   ├── VoiceUpload.tsx
│   │   │   ├── audioUtil.ts
│   │   │   ├── chatHistory.ts
│   │   │   ├── cssUtil.ts
│   │   │   ├── favicon.ico
│   │   │   ├── faviconKyutai.ico
│   │   │   ├── faviconKyutai.png
│   │   │   ├── globals.css
│   │   │   ├── layout.tsx
│   │   │   ├── opus-recorder.d.ts
│   │   │   ├── page.tsx
│   │   │   ├── useAudioProcessor.ts
│   │   │   ├── useAudioVisualizerCircle.ts
│   │   │   ├── useBackendServerUrl.ts
│   │   │   ├── useGoogleAnalytics.ts
│   │   │   ├── useKeyboardShortcuts.ts
│   │   │   ├── useLocalStorage.ts
│   │   │   ├── useMicrophoneAccess.ts
│   │   │   ├── useRecordingCanvas.ts
│   │   │   ├── useWakeLock.ts
│   │   │   └── voice-donation
│   │   │       ├── DonationConsent.tsx
│   │   │       ├── IntroText.mdx
│   │   │       ├── page.tsx
│   │   │       ├── privacy-policy
│   │   │       │   └── page.tsx
│   │   │       └── terms-of-use
│   │   │           └── page.tsx
│   │   ├── assets
│   │   │   ├── fonts
│   │   │   │   ├── Satoshi-Variable.eot
│   │   │   │   ├── Satoshi-Variable.ttf
│   │   │   │   ├── Satoshi-Variable.woff
│   │   │   │   ├── Satoshi-Variable.woff2
│   │   │   │   ├── Satoshi-VariableItalic.eot
│   │   │   │   ├── Satoshi-VariableItalic.ttf
│   │   │   │   ├── Satoshi-VariableItalic.woff
│   │   │   │   └── Satoshi-VariableItalic.woff2
│   │   │   └── kyutai-logo-cropped.svg
│   │   └── mdx-components.tsx
│   └── tsconfig.json
├── notebooks
│   ├── .gitignore
│   └── create-voice-donation-sentences.ipynb
├── pyproject.toml
├── services
│   ├── debugger
│   │   └── Dockerfile
│   ├── grafana
│   │   ├── Dockerfile
│   │   ├── dashboards
│   │   │   └── unmute-monitoring-1751624072717.json
│   │   ├── grafana.ini
│   │   └── provisioning
│   │       ├── dashboards
│   │       │   └── dashboards.yaml
│   │       └── datasources
│   │           └── datasources.yaml
│   ├── moshi-server
│   │   ├── configs
│   │   │   ├── stt-prod.toml
│   │   │   ├── stt.toml
│   │   │   ├── tts-prod.toml
│   │   │   ├── tts.toml
│   │   │   └── voice-cloning.toml
│   │   ├── private.Dockerfile
│   │   ├── public.Dockerfile
│   │   ├── start_moshi_server_private.sh
│   │   └── start_moshi_server_public.sh
│   └── prometheus
│       ├── Dockerfile
│       └── prometheus.yml
├── setup_gpu_swarm_node.py
├── swarm-deploy.yml
├── tests
│   ├── test_exponential_moving_average.py
│   └── test_llm_utils.py
├── unmute
│   ├── audio_input_override.py
│   ├── audio_stream_saver.py
│   ├── cache.py
│   ├── exceptions.py
│   ├── kyutai_constants.py
│   ├── llm
│   │   ├── chatbot.py
│   │   ├── llm_utils.py
│   │   ├── newsapi.py
│   │   ├── quiz_show_questions.py
│   │   └── system_prompt.py
│   ├── loadtest
│   │   ├── dummy_tts_server.py
│   │   ├── generate_dataset_for_vllm.py
│   │   ├── loadtest_client.py
│   │   ├── loadtest_result.py
│   │   └── voices
│   │       ├── Bear-or-shark-trim.mp3
│   │       ├── dog-or-cat-3-nowait.mp3
│   │       ├── seine.mp3
│   │       └── vaclav_english_news_trim.mp3
│   ├── main_gradio.py
│   ├── main_websocket.py
│   ├── metrics.py
│   ├── openai_realtime_api_events.py
│   ├── process_recording.py
│   ├── quest_manager.py
│   ├── recorder.py
│   ├── scripts
│   │   ├── check_hugging_face_token_not_write.py
│   │   ├── copy_voice_to_prod.py
│   │   ├── example_websocket_client.py
│   │   ├── mistral_streaming.py
│   │   ├── output_from_file.py
│   │   ├── output_sine.py
│   │   ├── output_sine_async.py
│   │   ├── output_tts.py
│   │   ├── pitch_detection_handler.py
│   │   ├── stt_from_file_example.py
│   │   ├── stt_microphone_example.py
│   │   ├── tts_example.py
│   │   ├── update_voice_list.py
│   │   └── vllm_wrapper_example.py
│   ├── service_discovery.py
│   ├── stt
│   │   ├── dummy_speech_to_text.py
│   │   ├── exponential_moving_average.py
│   │   └── speech_to_text.py
│   ├── timer.py
│   ├── tts
│   │   ├── copy_approved_voice_donations.py
│   │   ├── create_voice_donation_table.py
│   │   ├── freesound_download.py
│   │   ├── realtime_queue.py
│   │   ├── text_to_speech.py
│   │   ├── trim_voice_donation_clip.py
│   │   ├── voice_cloning.py
│   │   ├── voice_donation.py
│   │   ├── voice_donation_sentences.txt
│   │   └── voices.py
│   ├── unmute_handler.py
│   ├── webrtc_utils.py
│   └── websocket_utils.py
├── uv.lock
└── voices.yaml
/.dockerignore:
--------------------------------------------------------------------------------
1 | *.egg-info
2 | *.pyc
3 |
4 | debug/
5 | recordings/
6 | .venv/
7 |
8 | Dockerfile
9 |
10 | frontend/node_modules
11 | frontend/.next
12 | volumes/
13 | notebooks/
14 | voices/
15 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ## Checklist
2 |
3 | - [ ] Read CONTRIBUTING.md and accept the CLA by including the provided snippet. We will not accept PRs without this.
4 |
5 | ## PR Description
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | push:
5 | branches:
6 | - prod
7 | pull_request:
8 |
9 | jobs:
10 | # Enable again when we don't have private dependencies
11 | #build-docker-images:
12 | # runs-on: ubuntu-latest
13 | # steps:
14 | # - name: Checkout code
15 | # uses: actions/checkout@v3
16 | #
17 | # - name: Set up a builder (we don't want to load the images)
18 | # run: docker buildx create --name mybuilder --use
19 | #
20 | # - name: Build all docker images
21 | # run: docker buildx bake --progress=plain -f swarm-deploy.yml workers frontend tts
22 | # env:
23 | # DOMAIN: dummy
24 |
25 | pre-commit:
26 | runs-on: ubuntu-latest
27 | steps:
28 | - name: Checkout code
29 | uses: actions/checkout@v3
30 |
31 | - name: Install uv
32 | uses: astral-sh/setup-uv@v5
33 | with:
34 | version: "0.7.12"
35 |
36 | - name: Install Node.js
37 | uses: actions/setup-node@v4
38 | with:
39 | node-version: 20
40 |
41 | - name: Install pnpm
42 | run: npm install -g pnpm
43 |
44 | - name: Install dependencies
45 | run: cd frontend && pnpm install
46 |
47 | - name: Run pre-commit
48 | run: |
49 | uv run pre-commit run --all-files
50 | # Some redundancy here because some hooks will run in any stage,
51 | # but I don't think there is a cleaner way to make sure they all run
52 | uv run pre-commit run --all-files --hook-stage pre-push
53 |
54 | backend-unit-tests:
55 | runs-on: ubuntu-latest
56 | steps:
57 | - name: Checkout code
58 | uses: actions/checkout@v3
59 |
60 | - name: Install uv
61 | uses: astral-sh/setup-uv@v5
62 | with:
63 | version: "0.7.12"
64 |
65 | - name: Run backend unit tests
66 | run: uv run pytest -v
67 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.egg-info
2 | *.pyc
3 | .DS_Store
4 |
5 | debug/
6 | recordings/
7 | .venv/
8 | # only ignore voices/ in the root directory
9 | /voices/
10 |
11 | # env files (can opt-in for committing if needed)
12 | .env*
13 |
14 | # vercel
15 | .vercel
16 |
17 | # typescript
18 | *.tsbuildinfo
19 | next-env.d.ts
20 |
21 | # Traefik/HTTPS
22 | certs/
23 |
24 | volumes/
25 | CLAUDE.md
26 | .claude/settings.local.json
27 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/kynan/nbstripout
3 | rev: 0.8.1
4 | hooks:
5 | - id: nbstripout
6 | - repo: https://github.com/pre-commit/pre-commit-hooks
7 | rev: v5.0.0 # Use the ref you want to point at
8 | hooks:
9 | - id: check-added-large-files
10 | args: ["--maxkb=2048"]
11 | - repo: https://github.com/astral-sh/ruff-pre-commit
12 | # Ruff version.
13 | rev: v0.11.7
14 | hooks:
15 | # Run the linter.
16 | - id: ruff
17 | types_or: [python, pyi] # Don't run on `jupyter` files
18 | args: [--fix]
19 | # Run the formatter.
20 | - id: ruff-format
21 | types_or: [python, pyi] # Don't run on `jupyter` files
22 | - repo: https://github.com/pre-commit/pre-commit-hooks
23 | rev: v3.2.0
24 | hooks:
25 | - id: trailing-whitespace
26 | - repo: local
27 | hooks:
28 | - id: pnpm-run-lint
29 | name: pnpm run lint
30 | language: system
31 | entry: bash -c 'cd frontend && pnpm run lint --max-warnings 0'
32 | files: ^frontend/src/.*$
33 | pass_filenames: false
34 | stages: [pre-commit]
35 | - id: pnpm-run-build
36 | name: pnpm run build
37 | language: system
38 | entry: bash -c 'cd frontend && pnpm run build'
39 | files: ^frontend/src/.*$
40 | pass_filenames: false
41 | stages: [pre-push]
42 | - id: pyright
43 | name: Pyright type-checking
44 | language: system
45 | entry: bash -c 'uv run pyright'
46 | files: ^unmute/.*$
47 | pass_filenames: false
48 | stages: [pre-push]
49 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to Unmute
2 |
3 | ## Pull Requests
4 |
5 | 1. Fork the repo and create your branch from `main`.
6 | 2. If you have changed APIs, update the documentation accordingly.
7 | 3. Ensure pre-commit hooks pass properly, in particular the linting and typing.
8 | 4. Accept the Contributor License Agreement (see after).
9 |
10 | ## Contributor License Agreement ("CLA")
11 |
12 | In order to accept your pull request, we need you to submit a Contributor License Agreement.
13 |
14 | If you agree with the full CLA provided in the next paragraph, copy the following statement into your PR, filling in your GitHub handle:
15 |
16 | > I, {your GitHub handle}, confirm that I have read and understood the terms of the CLA of Kyutai-labs, as outlined in the repository's CONTRIBUTING.md, and I agree to be bound by these terms.
17 | The full CLA is provided as follows:
18 |
19 | > I, {your GitHub handle}, hereby grant to Kyutai-labs a perpetual, worldwide, non-exclusive, royalty-free,
20 | > irrevocable license to use, modify, distribute, and sublicense my Contributions.
21 | > I understand and accept that Contributions are limited to modifications, improvements, or changes
22 | > to the project’s source code submitted via pull requests. I accept that Kyutai-labs has full discretion to
23 | > review, accept, reject, or request changes to any Contributions I submit, and that submitting
24 | > a pull request does not guarantee its inclusion in the project.
25 | > By submitting a Contribution, I grant Kyutai-labs a perpetual, worldwide license to use, modify,
26 | > reproduce, distribute, and create derivative works based on my Contributions.
27 | > I also agree to assign all patent rights for any inventions or improvements that arise from my Contributions,
28 | > giving the Kyutai-labs full rights to file for and enforce patents.
29 | > I understand that the Kyutai-labs may commercialize, relicense, or exploit the project and my Contributions without further notice or obligation to me.
30 | > I confirm that my Contributions are original and that I have the legal right to grant this license.
31 | > If my Contributions include third-party materials, I will ensure that I have the necessary permissions
32 | > and will disclose this information. I accept that once my Contributions are integrated, they may be altered or removed at the Kyutai-labs’s discretion.
33 | > I acknowledge that I am making these Contributions voluntarily and will not receive any compensation.
34 | > Furthermore, I understand that all Contributions, including mine, are provided on an "as-is" basis, with no warranties.
35 | > By submitting a pull request, I agree to be bound by these terms.
36 |
37 | ## Issues
38 |
39 | Please submit issues on our GitHub repository.
40 |
41 | ## License
42 |
43 | By contributing to Unmute, you agree that your contributions will be licensed under the MIT license.
44 | See the `LICENSE` file in the root directory of this source tree.
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM ghcr.io/astral-sh/uv:0.6.17-debian AS build
2 | WORKDIR /app
3 |
4 | ENV UV_COMPILE_BYTECODE=1 UV_LOCKED=1
5 |
6 | RUN --mount=type=bind,source=uv.lock,target=uv.lock \
7 | --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
8 | uv run --no-dev echo hello
9 |
10 | COPY . .
11 | ENV HOSTNAME="0.0.0.0"
12 |
13 | HEALTHCHECK --start-period=15s \
14 | CMD curl --fail http://localhost:80/metrics || exit 1
15 |
16 | FROM build AS prod
17 | # Run through uvicorn directly so that we can deactivate the WebSocket per-message deflate,
18 | # which slows down the replies by a few ms.
19 | CMD ["uv", "run", "--no-dev", "uvicorn", "unmute.main_websocket:app", "--host", "0.0.0.0", "--port", "80", "--ws-per-message-deflate=false"]
20 |
21 |
22 | FROM build AS hot-reloading
23 | CMD ["uv", "run", "--no-dev", "uvicorn", "unmute.main_websocket:app", "--reload", "--host", "0.0.0.0", "--port", "80", "--ws-per-message-deflate=false"]
24 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 kyutai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/bake_deploy_prod.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e # Exit on error
3 |
4 | uv run unmute/scripts/check_hugging_face_token_not_write.py $HUGGING_FACE_HUB_TOKEN
5 |
6 | expected_branch="prod"
7 |
8 | current_branch=$(git rev-parse --abbrev-ref HEAD)
9 | if [[ "$current_branch" != "$expected_branch" ]]; then
10 | echo "❌ You are on branch '$current_branch'. Please switch to '$expected_branch' before deploying."
11 | exit 1
12 | fi
13 |
14 | if [[ -n $(git status --porcelain) ]]; then
15 | echo "❌ You have uncommitted changes. Please commit or stash them before deploying."
16 | exit 1
17 | fi
18 |
19 | set -x # Print commands
20 |
21 | export DOMAIN=unmute.sh
22 | # Note that using non-Mistral models also requires changing the vLLM args in ./swarm-deploy.yml
23 | export KYUTAI_LLM_MODEL=mistralai/Mistral-Small-3.2-24B-Instruct-2506
24 | export DOCKER_HOST=ssh://root@${DOMAIN}
25 |
26 | echo "If you get an connection error, do: ssh root@${DOMAIN}"
27 |
28 | docker buildx bake -f ./swarm-deploy.yml --allow=ssh --push
29 | docker stack deploy --with-registry-auth --prune --compose-file ./swarm-deploy.yml llm-wrapper
30 |
--------------------------------------------------------------------------------
/bake_deploy_staging.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -ex
3 |
4 | uv run unmute/scripts/check_hugging_face_token_not_write.py $HUGGING_FACE_HUB_TOKEN
5 |
6 | export DOMAIN=unmute-staging.kyutai.io
7 | export KYUTAI_LLM_MODEL=google/gemma-3-4b-it
8 | export DOCKER_HOST=ssh://root@${DOMAIN}
9 |
10 | echo "If you get an connection error, do: ssh root@${DOMAIN}"
11 |
12 | docker buildx bake -f ./swarm-deploy.yml --allow=ssh --push
13 | docker stack deploy --with-registry-auth --prune --compose-file ./swarm-deploy.yml llm-wrapper
14 | docker service scale -d llm-wrapper_tts=1 llm-wrapper_llm=1
15 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | # See NOTE comments for places to modify.
2 | services:
3 | traefik:
4 | image: traefik:v3.3.1
5 | command:
6 | # Swarm provider configuration
7 | - "--providers.docker=true"
8 | - "--providers.docker.exposedbydefault=false"
9 |
10 | # This is set up for HTTP. If you want HTTPS support for production, use Docker Swarm
11 | # (check out swarm-deploy.yml) or ask ChatGPT to modify this file for you.
12 | - "--entrypoints.web.address=:80"
13 | ports:
14 | - "80:80"
15 | volumes:
16 | - "/var/run/docker.sock:/var/run/docker.sock:ro"
17 |
18 | frontend:
19 | image: unmute-frontend:latest
20 | build:
21 | context: frontend/
22 | dockerfile: hot-reloading.Dockerfile
23 | volumes:
24 | - ./frontend/src:/app/src
25 | labels:
26 | - "traefik.enable=true"
27 | - "traefik.http.routers.frontend.rule=PathPrefix(`/`)"
28 | - "traefik.http.routers.frontend.entrypoints=web"
29 | - "traefik.http.services.frontend.loadbalancer.server.port=3000"
30 | - "traefik.http.routers.frontend.priority=10" # lowest priority
31 |
32 | backend:
33 | image: unmute-backend:latest
34 | build:
35 | context: ./
36 | target: hot-reloading
37 | volumes:
38 | - ./unmute:/app/unmute
39 | environment:
40 | - KYUTAI_STT_URL=ws://stt:8080
41 | - KYUTAI_TTS_URL=ws://tts:8080
42 | - KYUTAI_LLM_URL=http://llm:8000
43 | - NEWSAPI_API_KEY=$NEWSAPI_API_KEY
44 | labels:
45 | - "traefik.enable=true"
46 | - "traefik.http.routers.backend.rule=PathPrefix(`/api`)"
47 | - "traefik.http.routers.backend.middlewares=strip-api"
48 | - "traefik.http.middlewares.strip-api.replacepathregex.regex=^/api/(.*)"
49 | - "traefik.http.middlewares.strip-api.replacepathregex.replacement=/$$1"
50 | - "traefik.http.routers.backend.entrypoints=web"
51 | - "traefik.http.services.backend.loadbalancer.server.port=80"
52 | - "traefik.http.routers.backend.priority=100" # higher priority than frontend
53 | - "prometheus-port=80"
54 |
55 | tts:
56 | image: moshi-server:latest
57 | command: ["worker", "--config", "configs/tts.toml"]
58 | build:
59 | context: services/moshi-server
60 | dockerfile: public.Dockerfile
61 | environment:
62 | - HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN
63 | volumes:
64 | - ./volumes/hf-cache:/root/.cache/huggingface
65 | - ./volumes/cargo-registry-tts:/root/.cargo/registry
66 | - ./volumes/tts-target:/app/target
67 | - ./volumes/uv-cache:/root/.cache/uv
68 | - /tmp/models/:/models
69 | - ./volumes/tts-logs:/logs
70 | deploy:
71 | resources:
72 | reservations:
73 | devices:
74 | - driver: nvidia
75 | count: all
76 | capabilities: [gpu]
77 |
78 | stt:
79 | image: moshi-server:latest
80 | command: ["worker", "--config", "configs/stt.toml"]
81 | build:
82 | context: services/moshi-server
83 | dockerfile: public.Dockerfile
84 | environment:
85 | - HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN
86 | volumes:
87 | - ./volumes/hf-cache:/root/.cache/huggingface
88 | - ./volumes/cargo-registry-stt:/root/.cargo/registry
89 | - ./volumes/stt-target:/app/target
90 | - ./volumes/uv-cache:/root/.cache/uv
91 | - /tmp/models/:/models
92 | - ./volumes/stt-logs:/logs
93 | deploy:
94 | resources:
95 | reservations:
96 | devices:
97 | - driver: nvidia
98 | count: all
99 | capabilities: [gpu]
100 |
101 | llm:
102 | image: vllm/vllm-openai:v0.9.1
103 | command:
104 | [
105 | # NOTE: Change the LLM here if you want.
106 | # (caution: gemma-3-1b-it also exists but it's slow on vLLM: https://github.com/vllm-project/vllm/issues/19575)
107 | "--model=meta-llama/Llama-3.2-1B-Instruct",
108 | # NOTE: You can adapt this based on your GPU memory.
109 | # A higher value takes more memory but supports longer conversations.
110 | "--max-model-len=1536",
111 | "--dtype=bfloat16",
112 | # NOTE: Change this based on your GPU memory.
113 | # A higher value can make inference faster.
114 | "--gpu-memory-utilization=0.4",
115 | ]
116 | volumes:
117 | - ./volumes/hf-cache:/root/.cache/huggingface
118 | - ./volumes/vllm-cache:/root/.cache/vllm
119 | environment:
120 | - HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN
121 | deploy:
122 | resources:
123 | reservations:
124 | devices:
125 | - driver: nvidia
126 | count: all
127 | capabilities: [gpu]
128 |
129 | networks:
130 | default:
--------------------------------------------------------------------------------
/dockerless/.gitignore:
--------------------------------------------------------------------------------
1 | # This is part of a hack to get dependencies needed for the TTS Rust server, because it integrates a Python component.
2 | pyproject.toml
3 | uv.lock
--------------------------------------------------------------------------------
/dockerless/start_backend.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -ex
3 | cd "$(dirname "$0")/.."
4 |
5 | uv run uvicorn unmute.main_websocket:app --reload --host 0.0.0.0 --port 8000 --ws-per-message-deflate=false
6 |
--------------------------------------------------------------------------------
/dockerless/start_frontend.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -ex
3 | cd "$(dirname "$0")/.."
4 |
5 | cd frontend
6 | pnpm install
7 | pnpm env use --global lts
8 | pnpm dev
9 |
--------------------------------------------------------------------------------
/dockerless/start_llm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -ex
3 | cd "$(dirname "$0")/.."
4 |
5 | uv tool run vllm@v0.9.1 serve \
6 | --model=google/gemma-3-1b-it \
7 | --max-model-len=8192 \
8 | --dtype=bfloat16 \
9 | --gpu-memory-utilization=0.3 \
10 | --port=8091
11 |
--------------------------------------------------------------------------------
/dockerless/start_stt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -ex
3 | cd "$(dirname "$0")/.."
4 |
5 | # A fix for building Sentencepiece on GCC 15, see: https://github.com/google/sentencepiece/issues/1108
6 | export CXXFLAGS="-include cstdint"
7 |
8 | cargo install --features cuda moshi-server@0.6.4
9 | moshi-server worker --config services/moshi-server/configs/stt.toml --port 8090
10 |
--------------------------------------------------------------------------------
/dockerless/start_tts.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -ex
3 | cd "$(dirname "$0")/"
4 |
5 | # This is part of a hack to get dependencies needed for the TTS Rust server, because it integrates a Python component
6 | [ -f pyproject.toml ] || wget https://raw.githubusercontent.com/kyutai-labs/moshi/9837ca328d58deef5d7a4fe95a0fb49c902ec0ae/rust/moshi-server/pyproject.toml
7 | [ -f uv.lock ] || wget https://raw.githubusercontent.com/kyutai-labs/moshi/9837ca328d58deef5d7a4fe95a0fb49c902ec0ae/rust/moshi-server/uv.lock
8 |
9 | uv venv
10 | source .venv/bin/activate
11 |
12 | cd ..
13 |
14 | # This env var must be set to get the correct environment for the Rust build.
15 | # Must be set before running `cargo install`!
16 | # If you don't have it, you'll see an error like `no module named 'huggingface_hub'`
17 | # or similar, which means you don't have the necessary Python packages installed.
18 | export LD_LIBRARY_PATH=$(python -c 'import sysconfig; print(sysconfig.get_config_var("LIBDIR"))')
19 |
20 | # A fix for building Sentencepiece on GCC 15, see: https://github.com/google/sentencepiece/issues/1108
21 | export CXXFLAGS="-include cstdint"
22 |
23 | # If you already have moshi-server installed and things are not working because of the LD_LIBRARY_PATH issue,
24 | # you might have to force a rebuild with --force.
25 | cargo install --features cuda moshi-server@0.6.4
26 |
27 | # If you're getting `moshi-server: error: unrecognized arguments: worker`, it means you're
28 | # using the binary from the `moshi` Python package rather than from the Rust package.
29 | # Use `pip install moshi --upgrade` to update the Python package to >=0.2.8.
30 | uv run --locked --project ./dockerless moshi-server worker --config services/moshi-server/configs/tts.toml --port 8089
31 |
--------------------------------------------------------------------------------
/docs/browser_backend_communication.md:
--------------------------------------------------------------------------------
1 | # Browser-backend communication protocol
2 |
3 | This document explains how the browser frontend and backend service communicate through WebSocket connections in the Unmute system.
4 |
5 | ## Overview
6 |
7 | Unmute uses a WebSocket-based protocol inspired by the [OpenAI Realtime API](https://platform.openai.com/docs/api-reference/realtime) for real-time voice conversations. The protocol handles:
8 |
9 | - Real-time audio streaming (bidirectional)
10 | - Voice conversation transcription
11 | - Session configuration
12 | - Error handling and debugging
13 |
14 | ## WebSocket connection
15 |
16 | ### Endpoint
17 | - **URL**: `/v1/realtime`
18 | - **Protocol**: `realtime` (specified in WebSocket subprotocol)
19 | - **Port**: 8000 (development); routed through Traefik in Docker Swarm and Compose, where Traefik serves HTTP (port 80) and HTTPS (port 443).
20 |
21 | ### Connection setup
22 |
23 | The WebSocket connection is established using the `realtime` subprotocol. See implementation details in:
24 | - **Frontend**: [`frontend/src/app/Unmute.tsx`](../frontend/src/app/Unmute.tsx)
25 | - **Backend**: [`unmute/main_websocket.py`](../unmute/main_websocket.py)
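
For illustration only, a raw browser connection might look like the sketch below. The endpoint path and subprotocol come from this document; the host and port are placeholders, and the actual frontend relies on the `react-use-websocket` package (see `frontend/package.json`) rather than the raw WebSocket API.

```typescript
// Minimal sketch (not the actual frontend code): connect with the "realtime"
// subprotocol and log incoming events. Host and port are placeholders.
const socket = new WebSocket("ws://localhost:8000/v1/realtime", "realtime");

socket.addEventListener("open", () => {
  console.log("Connected, negotiated subprotocol:", socket.protocol); // "realtime"
});

socket.addEventListener("message", (event) => {
  const message = JSON.parse(event.data);
  console.log("Received event:", message.type);
});
```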
26 |
27 | ## Message protocol
28 |
29 | All messages are JSON-encoded with a common structure defined in [`unmute/openai_realtime_api_events.py`](../unmute/openai_realtime_api_events.py).
30 |
31 | ### Base message structure
32 |
33 | All messages inherit from [`BaseEvent`](https://github.com/kyutai-labs/unmute/blob/main/unmute/openai_realtime_api_events.py#L32-L50), which provides the common `type` and `event_id` fields.
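
As a sketch, the client-side counterpart of this envelope could be written as below; the exact field set is an assumption here, and the Python `BaseEvent` model is authoritative.

```typescript
// Assumed shape of the shared event envelope; see
// unmute/openai_realtime_api_events.py for the authoritative definition.
interface BaseEvent {
  type: string;      // e.g. "session.update" or "response.audio.delta"
  event_id?: string; // identifier used for debugging/correlating events
}
```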
34 |
35 | ## Client → server messages
36 |
37 | ### 1. Audio input streaming
38 |
39 | **Message Type**: `input_audio_buffer.append`
40 |
41 | **Purpose**: Stream real-time audio data from microphone to backend
42 |
43 | **Model**: [`InputAudioBufferAppend`](https://github.com/kyutai-labs/unmute/blob/main/unmute/openai_realtime_api_events.py#L80-L81)
44 |
45 | **Audio Format**:
46 | - **Codec**: Opus
47 | - **Sample Rate**: 24kHz
48 | - **Channels**: Mono
49 | - **Encoding**: Base64-encoded bytes
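
Putting the format together, a client could wrap each encoded Opus frame as in the sketch below. The base64 conversion mirrors `base64EncodeOpus` in [`frontend/src/app/audioUtil.ts`](../frontend/src/app/audioUtil.ts); the `audio` field name follows the OpenAI Realtime API convention and is an assumption here, with `InputAudioBufferAppend` being authoritative.

```typescript
// Sketch: send one Opus frame as an input_audio_buffer.append event.
// The "audio" field name is assumed; see InputAudioBufferAppend in
// unmute/openai_realtime_api_events.py for the actual schema.
function sendAudioChunk(socket: WebSocket, opusFrame: Uint8Array): void {
  // Same conversion as base64EncodeOpus in frontend/src/app/audioUtil.ts
  let binary = "";
  for (let i = 0; i < opusFrame.byteLength; i++) {
    binary += String.fromCharCode(opusFrame[i]);
  }
  socket.send(
    JSON.stringify({
      type: "input_audio_buffer.append",
      audio: window.btoa(binary),
    })
  );
}
```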
50 |
51 | ### 2. Session configuration
52 |
53 | **Message Type**: `session.update`
54 |
55 | **Purpose**: Configure the voice character and conversation instructions. The backend will not start sending messages until it receives a `session.update` message that sets its instructions.
56 |
57 | **Models**:
58 | - [`SessionUpdate`](https://github.com/kyutai-labs/unmute/blob/main/unmute/openai_realtime_api_events.py#L72-L73)
59 | - [`SessionConfig`](https://github.com/kyutai-labs/unmute/blob/main/unmute/openai_realtime_api_events.py#L66-L69)
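
For example, a client might configure the session as in the sketch below, reusing the `socket` from the connection sketch above. The nesting and field names inside `session` mirror the OpenAI Realtime API and are assumptions; `SessionUpdate`/`SessionConfig` define the actual schema.

```typescript
// Sketch of a session.update message. Field names inside "session" are
// assumptions; SessionUpdate/SessionConfig in
// unmute/openai_realtime_api_events.py are authoritative.
socket.send(
  JSON.stringify({
    type: "session.update",
    session: {
      instructions: "You are a helpful voice assistant. Keep answers short.",
      voice: "placeholder-voice-id", // placeholder voice identifier
    },
  })
);
```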
60 |
61 | ## Server → client messages
62 |
63 | ### 1. Audio response streaming
64 |
65 | **Message Type**: `response.audio.delta`
66 |
67 | **Purpose**: Stream generated speech audio to frontend
68 |
69 | **Model**: [`ResponseAudioDelta`](https://github.com/kyutai-labs/unmute/blob/main/unmute/openai_realtime_api_events.py#L133-L134)
70 |
71 | ### 2. Speech transcription
72 |
73 | **Message Type**: `conversation.item.input_audio_transcription.delta`
74 |
75 | **Purpose**: Real-time transcription of user speech
76 |
77 | **Model**: [`ConversationItemInputAudioTranscriptionDelta`](https://github.com/kyutai-labs/unmute/blob/main/unmute/openai_realtime_api_events.py#L147-L151)
78 |
79 | ### 3. Text response streaming
80 |
81 | **Message Type**: `response.text.delta`
82 |
83 | **Purpose**: Stream generated text responses (for display/debugging)
84 |
85 | **Model**: [`ResponseTextDelta`](https://github.com/kyutai-labs/unmute/blob/main/unmute/openai_realtime_api_events.py#L125-L126)
86 |
87 | ### 4. Speech detection events
88 |
89 | **Message Types**:
90 | - `input_audio_buffer.speech_started`
91 | - `input_audio_buffer.speech_stopped`
92 |
93 | **Purpose**: Indicate when the user starts/stops speaking (for UI feedback). The Unmute frontend currently just ignores these events, even though the backend reports them.
94 |
95 | **Models**:
96 | - [`InputAudioBufferSpeechStarted`](https://github.com/kyutai-labs/unmute/blob/main/unmute/openai_realtime_api_events.py#L95-L105)
97 | - [`InputAudioBufferSpeechStopped`](https://github.com/kyutai-labs/unmute/blob/main/unmute/openai_realtime_api_events.py#L108-L111)
98 |
99 | ### 5. Response status updates
100 |
101 | **Message Type**: `response.created`
102 |
103 | **Purpose**: Indicate when assistant starts generating a response
104 |
105 | **Models**:
106 | - [`ResponseCreated`](https://github.com/kyutai-labs/unmute/blob/main/unmute/openai_realtime_api_events.py#L121-L122)
107 | - [`Response`](https://github.com/kyutai-labs/unmute/blob/main/unmute/openai_realtime_api_events.py#L114-L118)
108 |
109 | ### 6. Error handling
110 |
111 | **Message Type**: `error`
112 |
113 | **Purpose**: Communicate errors and warnings
114 |
115 | **Models**:
116 | - [`Error`](https://github.com/kyutai-labs/unmute/blob/main/unmute/openai_realtime_api_events.py#L62-L63)
117 | - [`ErrorDetails`](https://github.com/kyutai-labs/unmute/blob/main/unmute/openai_realtime_api_events.py#L53-L59)
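
A client typically dispatches on the `type` field of each incoming event. The sketch below covers the message types listed above; the payload field names (such as `delta`) are assumptions, and the linked Python models are authoritative.

```typescript
// Sketch of a client-side dispatcher over the server → client events above.
// Payload field names (e.g. "delta") are assumptions.
socket.addEventListener("message", (event) => {
  const msg = JSON.parse(event.data);
  switch (msg.type) {
    case "response.audio.delta":
      // Base64-encoded Opus audio: decode and queue it for playback.
      break;
    case "conversation.item.input_audio_transcription.delta":
    case "response.text.delta":
      // Append the text delta to the corresponding subtitle line.
      break;
    case "input_audio_buffer.speech_started":
    case "input_audio_buffer.speech_stopped":
      // Reported by the backend but currently ignored by the frontend.
      break;
    case "response.created":
      // The assistant has started generating a response.
      break;
    case "error":
      console.error("Server reported an error:", msg);
      break;
  }
});
```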
118 |
119 | ## Connection lifecycle
120 |
121 | 1. **Health Check**: Frontend checks `/v1/health` endpoint
122 | 2. **WebSocket Connection**: Establish connection with `realtime` protocol
123 | 3. **Session Setup**: Send `session.update` with voice and instructions
124 | 4. **Audio Streaming**: Bidirectional real-time audio communication
125 | 5. **Graceful Shutdown**: Handle disconnection and cleanup
126 |
127 |
--------------------------------------------------------------------------------
/frontend/.gitignore:
--------------------------------------------------------------------------------
1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
2 |
3 | # dependencies
4 | /node_modules
5 | /.pnp
6 | .pnp.*
7 | .yarn/*
8 | !.yarn/patches
9 | !.yarn/plugins
10 | !.yarn/releases
11 | !.yarn/versions
12 |
13 | # testing
14 | /coverage
15 |
16 | # next.js
17 | /.next/
18 | /out/
19 |
20 | # production
21 | /build
22 |
23 | # misc
24 | .DS_Store
25 | *.pem
26 |
27 | # debug
28 | npm-debug.log*
29 | yarn-debug.log*
30 | yarn-error.log*
31 | .pnpm-debug.log*
32 |
--------------------------------------------------------------------------------
/frontend/Dockerfile:
--------------------------------------------------------------------------------
1 | # syntax=docker.io/docker/dockerfile:1
2 | # Taken from: https://github.com/vercel/next.js/tree/aebf26f7923c7c4da7734798048bf48d2e57b521/examples/with-docker
3 |
4 | FROM node:18-alpine AS base
5 |
6 | # Install dependencies only when needed
7 | FROM base AS deps
8 | # Check https://github.com/nodejs/docker-node/tree/b4117f9333da4138b03a546ec926ef50a31506c3#nodealpine to understand why libc6-compat might be needed.
9 | RUN apk add --no-cache libc6-compat
10 | WORKDIR /app
11 |
12 | # Install dependencies based on the preferred package manager
13 | COPY package.json yarn.lock* package-lock.json* pnpm-lock.yaml* .npmrc* ./
14 | RUN corepack enable pnpm && pnpm i --frozen-lockfile
15 |
16 |
17 | # Rebuild the source code only when needed
18 | FROM base AS builder
19 | WORKDIR /app
20 | COPY --from=deps /app/node_modules ./node_modules
21 | COPY . .
22 |
23 | # Next.js collects completely anonymous telemetry data about general usage.
24 | # Learn more here: https://nextjs.org/telemetry
25 | # Uncomment the following line in case you want to disable telemetry during the build.
26 | ENV NEXT_TELEMETRY_DISABLED=1
27 |
28 | ENV NEXT_PUBLIC_IN_DOCKER=true
29 |
30 | RUN corepack enable pnpm && pnpm run build
31 |
32 | # Production image, copy all the files and run next
33 | FROM base AS runner
34 | WORKDIR /app
35 |
36 | ENV NODE_ENV=production
37 | # Uncomment the following line in case you want to disable telemetry during runtime.
38 | ENV NEXT_TELEMETRY_DISABLED=1
39 |
40 | RUN apk add --no-cache curl
41 |
42 | RUN addgroup --system --gid 1001 nodejs
43 | RUN adduser --system --uid 1001 nextjs
44 |
45 | COPY --from=builder /app/public ./public
46 |
47 | # Automatically leverage output traces to reduce image size
48 | # https://nextjs.org/docs/advanced-features/output-file-tracing
49 | COPY --from=builder --chown=nextjs:nodejs /app/.next/standalone ./
50 | COPY --from=builder --chown=nextjs:nodejs /app/.next/static ./.next/static
51 |
52 | USER nextjs
53 |
54 | EXPOSE 3000
55 |
56 | ENV PORT=3000
57 |
58 | HEALTHCHECK --start-period=15s \
59 | CMD curl --fail http://localhost:3000/ || exit 1
60 |
61 | # server.js is created by next build from the standalone output
62 | # https://nextjs.org/docs/pages/api-reference/config/next-config-js/output
63 | ENV HOSTNAME="0.0.0.0"
64 | CMD ["node", "server.js"]
--------------------------------------------------------------------------------
/frontend/README.md:
--------------------------------------------------------------------------------
1 | # Unmute frontend
2 |
3 | This is the frontend for Unmute, written in Next.js.
4 |
5 | Use `pnpm` to install:
6 |
7 | ```bash
8 | pnpm install
9 | # if you don't have Node:
10 | pnpm env use --global lts
11 | ```
12 |
13 | Then run:
14 |
15 | ```bash
16 | pnpm run dev
17 | ```
18 |
--------------------------------------------------------------------------------
/frontend/eslint.config.mjs:
--------------------------------------------------------------------------------
1 | import { dirname } from "path";
2 | import { fileURLToPath } from "url";
3 | import { FlatCompat } from "@eslint/eslintrc";
4 |
5 | const __filename = fileURLToPath(import.meta.url);
6 | const __dirname = dirname(__filename);
7 |
8 | const compat = new FlatCompat({
9 | baseDirectory: __dirname,
10 | });
11 |
12 | const eslintConfig = [
13 | ...compat.extends("next/core-web-vitals", "next/typescript"),
14 | {
15 | rules: {
16 | '@next/next/no-img-element': 'off',
17 | },
18 | },
19 | ];
20 |
21 | export default eslintConfig;
22 |
--------------------------------------------------------------------------------
/frontend/hot-reloading.Dockerfile:
--------------------------------------------------------------------------------
1 | # syntax=docker.io/docker/dockerfile:1
2 |
3 | FROM node:18-alpine AS dev
4 |
5 | # Install required dependencies
6 | RUN apk add --no-cache libc6-compat curl
7 |
8 | # Set working directory
9 | WORKDIR /app
10 |
11 | # Install dependencies using the package manager (detected automatically via lockfile)
12 | COPY package.json tsconfig.json yarn.lock* package-lock.json* pnpm-lock.yaml* .npmrc* postcss.config.mjs ./
13 | COPY public/ ./public/
14 | RUN corepack enable pnpm && pnpm i --frozen-lockfile
15 |
16 | # Expose the port the dev server runs on
17 | EXPOSE 3000
18 |
19 | # Set environment variables
20 | ENV NODE_ENV=development
21 | ENV NEXT_TELEMETRY_DISABLED=1
22 | ENV HOSTNAME=0.0.0.0
23 | ENV PORT=3000
24 | ENV NEXT_PUBLIC_IN_DOCKER=true
25 |
26 | HEALTHCHECK --start-period=15s \
27 | CMD curl --fail http://localhost:3000/ || exit 1
28 |
29 | # The source code will be mounted as a volume, so no need to copy it here
30 | # Default command to run the development server with hot reloading
31 | CMD ["pnpm", "dev"]
32 |
--------------------------------------------------------------------------------
/frontend/next.config.ts:
--------------------------------------------------------------------------------
1 | import createMDX from "@next/mdx";
2 | import type { NextConfig } from "next";
3 |
4 | const nextConfig: NextConfig = {
5 | output: "standalone", // For Docker
6 | // Configure `pageExtensions` to include markdown and MDX files
7 | pageExtensions: ["js", "jsx", "md", "mdx", "ts", "tsx"],
8 | };
9 |
10 | const withMDX = createMDX({
11 | // markdown plugins go here
12 | });
13 |
14 | // Merge MDX config with Next.js config
15 | export default withMDX(nextConfig);
16 |
--------------------------------------------------------------------------------
/frontend/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "tts-demo",
3 | "version": "0.1.0",
4 | "private": true,
5 | "scripts": {
6 | "dev": "next dev",
7 | "build": "next build",
8 | "start": "next start",
9 | "lint": "next lint"
10 | },
11 | "dependencies": {
12 | "@mdx-js/loader": "^3.1.0",
13 | "@mdx-js/react": "^3.1.0",
14 | "@next/mdx": "^15.3.4",
15 | "@next/third-parties": "^15.3.2",
16 | "@types/http-proxy": "^1.17.16",
17 | "@types/mdx": "^2.0.13",
18 | "bcryptjs": "^3.0.2",
19 | "clsx": "^2.1.1",
20 | "http-proxy": "^1.18.1",
21 | "lucide-react": "^0.503.0",
22 | "next": "15.2.2",
23 | "opus-recorder": "^8.0.5",
24 | "pretty-print-json": "^3.0.4",
25 | "react": "^19.0.0",
26 | "react-dom": "^19.0.0",
27 | "react-use-websocket": "^4.13.0"
28 | },
29 | "devDependencies": {
30 | "@eslint/eslintrc": "^3",
31 | "@next/eslint-plugin-next": "^15.3.2",
32 | "@tailwindcss/postcss": "^4",
33 | "@types/bcrypt": "^5.0.2",
34 | "@types/node": "^20",
35 | "@types/react": "^19",
36 | "@types/react-dom": "^19",
37 | "eslint": "^9.24.0",
38 | "eslint-config-next": "15.2.2",
39 | "eslint-plugin-react-hooks": "^5.2.0",
40 | "tailwindcss": "^4",
41 | "typescript": "^5"
42 | },
43 | "packageManager": "pnpm@10.7.1+sha512.2d92c86b7928dc8284f53494fb4201f983da65f0fb4f0d40baafa5cf628fa31dae3e5968f12466f17df7e97310e30f343a648baea1b9b350685dafafffdf5808"
44 | }
45 |
--------------------------------------------------------------------------------
/frontend/postcss.config.mjs:
--------------------------------------------------------------------------------
1 | const config = {
2 | plugins: ["@tailwindcss/postcss"],
3 | };
4 |
5 | export default config;
6 |
--------------------------------------------------------------------------------
/frontend/public/decoderWorker.min.wasm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/unmute/d85fbf05a90d2f019a10bf521c1c029d3208b280/frontend/public/decoderWorker.min.wasm
--------------------------------------------------------------------------------
/frontend/src/app/CouldNotConnect.tsx:
--------------------------------------------------------------------------------
1 | import clsx from "clsx";
2 | import UnmuteHeader from "./UnmuteHeader";
3 |
4 | export type HealthStatus = {
5 | connected: "no" | "yes_request_ok" | "yes_request_fail";
6 | ok: boolean;
7 | tts_up?: boolean;
8 | stt_up?: boolean;
9 | llm_up?: boolean;
10 | voice_cloning_up?: boolean;
11 | };
12 |
13 | const renderServiceStatus = (
14 | name: string,
15 | status: string | boolean | undefined,
16 | necessary: boolean = true
17 | ) => {
18 | if (status === undefined) {
19 | status = "Unknown";
20 | } else if (status === true) {
21 | status = "Up";
22 | } else if (status === false) {
23 | status = "Down";
24 | }
25 |
26 | return (
27 |
28 | {name}:{" "}
29 |
38 | {status}
39 |
40 |
41 | );
42 | };
43 |
44 | const humanReadableStatus = {
45 | no: "Down",
46 | yes_request_ok: "Up",
47 | yes_request_fail: "Up, but with errors",
48 | };
49 |
50 | const CouldNotConnect = ({ healthStatus }: { healthStatus: HealthStatus }) => {
51 | if (healthStatus.ok) {
52 | return null;
53 | }
54 |
55 | return (
56 |
57 |
58 |
59 |
{"Couldn't connect :("}
60 |
Service status:
61 | {renderServiceStatus(
62 | "Backend",
63 | humanReadableStatus[healthStatus.connected]
64 | )}
65 | {renderServiceStatus("STT", healthStatus.stt_up)}
66 | {renderServiceStatus("LLM", healthStatus.llm_up)}
67 | {renderServiceStatus("TTS", healthStatus.tts_up)}
68 | {renderServiceStatus(
69 | "Voice cloning",
70 | healthStatus.voice_cloning_up,
71 | false
72 | )}
73 |
74 |
75 | );
76 | };
77 |
78 | export default CouldNotConnect;
79 |
--------------------------------------------------------------------------------
/frontend/src/app/ErrorMessages.tsx:
--------------------------------------------------------------------------------
1 | import React, { useEffect } from "react";
2 | import { X } from "lucide-react";
3 |
4 | export interface ErrorItem {
5 | id: string;
6 | message: string;
7 | timestamp: number;
8 | }
9 |
10 | export const makeErrorItem = (message: string): ErrorItem => {
11 | const timestamp = Date.now();
12 | return {
13 | id: `${timestamp}-${Math.random()}`,
14 | message,
15 | timestamp,
16 | };
17 | };
18 |
19 | const ERROR_TIMEOUT_SEC = 10;
20 |
21 | export default function ErrorMessages({
22 | errors,
23 | setErrors,
24 | }: {
25 | errors: ErrorItem[];
26 | setErrors: React.Dispatch<React.SetStateAction<ErrorItem[]>>;
27 | }) {
28 | // Auto-dismiss errors after 10 seconds
29 | useEffect(() => {
30 | const interval = setInterval(() => {
31 | setErrors((prev) => {
32 | const now = Date.now();
33 | const filtered = prev.filter(
34 | (error) => now - error.timestamp < ERROR_TIMEOUT_SEC * 1000
35 | );
36 | return filtered;
37 | });
38 | }, 1000);
39 |
40 | return () => clearInterval(interval);
41 | }, [setErrors]);
42 |
43 | const handleDismiss = (index: number, errorId: string) => {
44 | setErrors((prev) => prev.filter((error) => error.id !== errorId));
45 | };
46 |
47 | if (errors.length === 0) {
48 | return null;
49 | }
50 |
51 | return (
52 |
53 | {errors.map((error, index) => (
54 |
59 |
60 |
61 |
62 | {error.message}
63 |
64 |
65 |
handleDismiss(index, error.id)}
67 | className="flex-shrink-0 text-red-600 hover:text-red-800 transition-colors"
68 | aria-label="Dismiss error"
69 | >
70 |
71 |
72 |
73 |
74 | ))}
75 |
76 | );
77 | }
78 |
--------------------------------------------------------------------------------
/frontend/src/app/PositionedAudioVisualizer.tsx:
--------------------------------------------------------------------------------
1 | import clsx from "clsx";
2 | import { ChatMessage } from "./chatHistory";
3 | import { useAudioVisualizerCircle } from "./useAudioVisualizerCircle";
4 | import { useEffect, useRef } from "react";
5 |
6 | const PositionedAudioVisualizer = ({
7 | chatHistory,
8 | role,
9 | analyserNode,
10 | isConnected,
11 | onCircleClick,
12 | }: {
13 | chatHistory: ChatMessage[];
14 | role: "user" | "assistant";
15 | analyserNode: AnalyserNode | null;
16 | isConnected: boolean;
17 | onCircleClick?: () => void;
18 | }) => {
19 | const canvasRef = useRef<HTMLCanvasElement>(null);
20 | const isAssistant = role === "assistant";
21 |
22 | useAudioVisualizerCircle(canvasRef, {
23 | chatHistory,
24 | role,
25 | analyserNode,
26 | isConnected,
27 | showPlayButton: !!onCircleClick,
28 | clearCanvas: true,
29 | });
30 |
31 | // Resize the canvas to fit its parent element
32 | useEffect(() => {
33 | const canvas = canvasRef.current;
34 | if (!canvas) return;
35 |
36 | const parent = canvas.parentElement;
37 | if (!parent) return;
38 |
39 | const size = Math.min(parent.clientWidth, parent.clientHeight);
40 |
41 | // If we don't do this `if` check, the recording ends up with flickering
42 | if (canvas.width !== size || canvas.height !== size) {
43 | canvas.width = size;
44 | canvas.height = size;
45 | }
46 | });
47 |
48 | return (
49 |
71 | );
72 | };
73 |
74 | export default PositionedAudioVisualizer;
75 |
--------------------------------------------------------------------------------
/frontend/src/app/SingleRoleSubtitles.tsx:
--------------------------------------------------------------------------------
1 | import clsx from "clsx";
2 | import React, { useCallback, useEffect, useRef, useState } from "react";
3 |
4 | const SingleRoleSubtitles = ({
5 | text,
6 | role,
7 | nLines = 3,
8 | }: {
9 | text: string;
10 | role: "user" | "assistant";
11 | nLines?: number;
12 | }) => {
13 | const containerRef = useRef(null);
14 | const [displayText, setDisplayText] = useState<string[]>([]);
15 | const [previousText, setPreviousText] = useState("");
16 |
17 | const updateDisplayText = useCallback(() => {
18 | if (!containerRef.current) return;
19 |
20 | const container = containerRef.current;
21 | const containerWidth = container.clientWidth;
22 |
23 | // Create a temporary span to measure text width
24 | const tempSpan = document.createElement("span");
25 | tempSpan.style.visibility = "hidden";
26 | tempSpan.style.position = "absolute";
27 | tempSpan.style.whiteSpace = "nowrap";
28 | tempSpan.style.font = window.getComputedStyle(container).font;
29 | document.body.appendChild(tempSpan);
30 |
31 | const words = text.split(" ");
32 | const lines: string[] = [];
33 | let currentLine = "";
34 |
35 | // Build lines word by word
36 | for (const word of words) {
37 | const testLine = currentLine ? `${currentLine} ${word}` : word;
38 | tempSpan.textContent = testLine;
39 |
40 | if (tempSpan.offsetWidth <= containerWidth) {
41 | currentLine = testLine;
42 | } else {
43 | if (currentLine) {
44 | lines.push(currentLine);
45 | currentLine = word;
46 | } else {
47 | // Word is too long for one line
48 | lines.push(word);
49 | currentLine = "";
50 | }
51 | }
52 | }
53 |
54 | // Add the last line if it's not empty
55 | if (currentLine) {
56 | lines.push(currentLine);
57 | }
58 |
59 | // Remove the temporary span
60 | document.body.removeChild(tempSpan);
61 |
62 | const lastLines = lines.slice(-nLines);
63 | setDisplayText(lastLines);
64 | }, [nLines, text]);
65 |
66 | useEffect(() => {
67 | // If the new text is not a prefix of the old text, reset
68 | if (!text.startsWith(previousText)) {
69 | setDisplayText([]);
70 | }
71 |
72 | setPreviousText(text);
73 |
74 | updateDisplayText();
75 | }, [previousText, text, updateDisplayText]);
76 |
77 | // Re-calculate when the window resizes
78 | useEffect(() => {
79 | const handleResize = () => {
80 | updateDisplayText();
81 | };
82 |
83 | window.addEventListener("resize", handleResize);
84 | return () => {
85 | window.removeEventListener("resize", handleResize);
86 | };
87 | }, [text, updateDisplayText]);
88 |
89 | return (
90 | // Apply padding from the outside because otherwise we have to take it into
91 | // account when deciding how to break lines
92 |
98 |
99 | {displayText.map((line, index) => (
100 |
101 | {line}
102 |
103 | ))}
104 |
105 |
106 | );
107 | };
108 |
109 | export default SingleRoleSubtitles;
110 |
--------------------------------------------------------------------------------
/frontend/src/app/SlantedButton.tsx:
--------------------------------------------------------------------------------
1 | import React from "react";
2 | import clsx from "clsx";
3 |
4 | const SlantedButton = ({
5 | onClick = () => {},
6 | children,
7 | kind = "primary",
8 | style,
9 | extraClasses,
10 | }: {
11 | onClick?: () => void;
12 | children: React.ReactNode;
13 | kind?: "primary" | "secondary" | "disabled";
14 | style?: React.CSSProperties;
15 | extraClasses?: string;
16 | }) => {
17 | const kindToClass = {
18 | primary: "cursor-pointer after:bg-green text-black after:border-green",
19 | secondary:
20 | "cursor-pointer after:bg-darkgray text-white after:border-white after:border-dashed",
21 | disabled:
22 | "cursor-not-allowed after:bg-darkgray text-lightgray after:border-lightgray after:border-dashed",
23 | };
24 |
25 | return (
26 |
45 | {children}
46 |
47 | );
48 | };
49 |
50 | export default SlantedButton;
51 |
--------------------------------------------------------------------------------
/frontend/src/app/SquareButton.tsx:
--------------------------------------------------------------------------------
1 | import React from "react";
2 | import clsx from "clsx";
3 |
4 | const SquareButton = ({
5 | onClick = () => {},
6 | children,
7 | kind = "primary",
8 | extraClasses,
9 | }: {
10 | onClick?: () => void;
11 | children: React.ReactNode;
12 | kind?: "primary" | "primaryOff" | "secondary";
13 | extraClasses?: string;
14 | }) => {
15 | const kindToClass = {
16 | primary: "text-green border-green",
17 | primaryOff: "text-white border-white",
18 | secondary: "text-white border-transparent",
19 | };
20 |
21 | return (
22 |
35 | {/* The inner span ensures the content overflows in a centered way */}
36 | {children}
37 |
38 | );
39 | };
40 |
41 | export default SquareButton;
42 |
--------------------------------------------------------------------------------
/frontend/src/app/Subtitles.tsx:
--------------------------------------------------------------------------------
1 | import { ChatMessage } from "./chatHistory";
2 | import SingleRoleSubtitles from "./SingleRoleSubtitles";
3 |
4 | const Subtitles = ({ chatHistory }: { chatHistory: ChatMessage[] }) => {
5 | const lastAssistantMessage = chatHistory.findLast(
6 | (message) => message.role === "assistant" && message.content !== ""
7 | );
8 | const lastUserMessage = chatHistory.findLast(
9 | (message) => message.role === "user" && message.content !== ""
10 | );
11 |
12 | return (
13 |
14 |
18 |
19 |
20 | );
21 | };
22 |
23 | export default Subtitles;
24 |
--------------------------------------------------------------------------------
/frontend/src/app/TrimmedAudioPreview.tsx:
--------------------------------------------------------------------------------
1 | import { memo, useRef, useState } from "react";
2 | import { MIC_RECORDING_FILENAME } from "./VoiceRecorder";
3 |
4 | const TrimmedAudioPreviewUnmemoized = ({ file }: { file: File }) => {
5 | const audioRef = useRef<HTMLAudioElement>(null);
6 | const [duration, setDuration] = useState<number | null>(null);
7 | const maxDurationSec = 10;
8 |
9 | const handleTimeUpdate = () => {
10 | if (audioRef.current && audioRef.current.currentTime >= maxDurationSec) {
11 | // If playing, restart the playhead to 0 so that you can just press the play
12 | // button to play again. If paused, set it to the max duration to indicate trimming.
13 | audioRef.current.currentTime = audioRef.current.paused
14 | ? maxDurationSec
15 | : 0;
16 |
17 | audioRef.current.pause();
18 | }
19 | };
20 |
21 | const handleDurationChange = () => {
22 | setDuration(audioRef.current?.duration || null);
23 | };
24 |
25 | return (
26 |
27 | {file.name !== MIC_RECORDING_FILENAME && (
28 |
29 | Selected file: {file.name}
30 |
31 | )}
32 | {duration && duration > maxDurationSec + 1 && (
33 |
34 | Note that only the first {maxDurationSec} seconds {" "}
35 | will be used.
36 |
37 | )}
38 |
46 |
47 | );
48 | };
49 |
50 | // We memoize because otherwise the audio element resets playback
51 | // when we re-render
52 | const TrimmedAudioPreview = memo(TrimmedAudioPreviewUnmemoized);
53 |
54 | export default TrimmedAudioPreview;
55 |
--------------------------------------------------------------------------------
/frontend/src/app/VoiceAttribution.tsx:
--------------------------------------------------------------------------------
1 | import Link from "next/link";
2 | import { VoiceSample } from "./UnmuteConfigurator";
3 |
4 | const VoiceAttribution = ({ voice }: { voice: VoiceSample }) => {
5 | const inner = () => {
6 | if (voice.source.source_type === "file") {
7 | if (voice.source.description_link) {
8 | return (
9 |
15 | {voice.source.description ||
16 | "Source: " + voice.source.description_link}
17 |
18 | );
19 | } else if (voice.source.description) {
20 | return <>{voice.source.description}>;
21 | } else {
22 | // No description or link provided
23 | return <>>;
24 | }
25 | } else {
26 | return (
27 | <>
28 | The '{voice.name}' voice is based on{" "}
29 |
35 | this Freesound by {voice.source.sound_instance.username}
36 |
37 | .
38 | >
39 | );
40 | }
41 | };
42 | return {inner()}
;
43 | };
44 |
45 | export default VoiceAttribution;
46 |
--------------------------------------------------------------------------------
/frontend/src/app/VoiceRecorder.tsx:
--------------------------------------------------------------------------------
1 | import { useRef, useState } from "react";
2 | import SlantedButton from "./SlantedButton";
3 | import { convertWebmToWav } from "./audioUtil";
4 | import { Mic } from "lucide-react";
5 | import clsx from "clsx";
6 |
7 | export const MIC_RECORDING_FILENAME = "unmute-mic-recording.wav";
8 |
9 | export type RecordedAudio = {
10 | blobUrl: string;
11 | file: File;
12 | };
13 |
14 | const VoiceRecording = ({
15 | setRecordedAudio,
16 | setError,
17 | recordingDurationSec,
18 | onRecordingStarted,
19 | showProgress = true,
20 | }: {
21 | setRecordedAudio: (recordedAudio: RecordedAudio) => void;
22 | setError: (error: string | null) => void;
23 | recordingDurationSec: number;
24 | onRecordingStarted?: () => void;
25 | showProgress?: boolean;
26 | }) => {
27 | const [isRecording, setIsRecording] = useState(false);
28 | const [mediaRecorder, setMediaRecorder] = useState<MediaRecorder | null>(
29 | null
30 | );
31 | const [recordingProgress, setRecordingProgress] = useState(0);
32 | const recordingIntervalRef = useRef(null);
33 | const audioChunksRef = useRef<Blob[]>([]);
34 |
35 | const handleStartRecording = async () => {
36 | setError(null);
37 | onRecordingStarted?.();
38 | setRecordingProgress(0);
39 | try {
40 | const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
41 |
42 | // Prefer audio/wav if supported. The backend can't handle webm, so we need to convert it.
43 | // If neither is supported, don't specify and hope for the best. (That seems to work on Safari.)
44 | let mimeType = "";
45 | if (MediaRecorder.isTypeSupported("audio/wav")) {
46 | mimeType = "audio/wav";
47 | } else if (MediaRecorder.isTypeSupported("audio/webm")) {
48 | mimeType = "audio/webm";
49 | }
50 |
51 | const recorder = new MediaRecorder(stream, { mimeType });
52 | audioChunksRef.current = [];
53 | recorder.ondataavailable = (e) => {
54 | if (e.data.size > 0) {
55 | audioChunksRef.current.push(e.data);
56 | }
57 | };
58 | recorder.onstop = async () => {
59 | setRecordingProgress(0);
60 | if (recordingIntervalRef.current) {
61 | clearInterval(recordingIntervalRef.current);
62 | }
63 | const audioBlob = new Blob(audioChunksRef.current, { type: mimeType });
64 |
65 | let audioFile: File;
66 | if (mimeType === "audio/wav") {
67 | audioFile = new File([audioBlob], MIC_RECORDING_FILENAME, {
68 | type: "audio/wav",
69 | });
70 | } else {
71 | const wavBlob = await convertWebmToWav(audioBlob);
72 | audioFile = new File([wavBlob], MIC_RECORDING_FILENAME, {
73 | type: "audio/wav",
74 | });
75 | }
76 | const recordedAudio: RecordedAudio = {
77 | blobUrl: URL.createObjectURL(audioFile),
78 | file: audioFile,
79 | };
80 | setRecordedAudio(recordedAudio);
81 | };
82 | recorder.start();
83 | setMediaRecorder(recorder);
84 | setIsRecording(true);
85 |
86 | const start = Date.now();
87 | recordingIntervalRef.current = setInterval(() => {
88 | const elapsed = (Date.now() - start) / 1000;
89 | setRecordingProgress(Math.min(elapsed / recordingDurationSec, 1));
90 | }, 50);
91 |
92 | setTimeout(() => {
93 | if (recorder.state === "recording") {
94 | recorder.stop();
95 | setIsRecording(false);
96 | setMediaRecorder(null);
97 | }
98 | }, recordingDurationSec * 1000);
99 | } catch (err) {
100 | setError(
101 | err instanceof Error ? err.message : "Could not access microphone."
102 | );
103 | }
104 | };
105 |
106 | const handleStopRecording = () => {
107 | if (mediaRecorder) {
108 | mediaRecorder.stop();
109 | setIsRecording(false);
110 | setMediaRecorder(null);
111 | }
112 | setRecordingProgress(0);
113 | if (recordingIntervalRef.current) {
114 | clearInterval(recordingIntervalRef.current);
115 | }
116 | };
117 |
118 | return (
119 |
120 |
121 |
122 |
127 | {isRecording ? (
128 | "● Recording"
129 | ) : (
130 | <>
131 |
132 | Record
133 | >
134 | )}
135 |
136 |
137 |
138 | {showProgress && (
139 |
150 | )}
151 |
152 | );
153 | };
154 |
155 | export default VoiceRecording;
156 |
--------------------------------------------------------------------------------
/frontend/src/app/audioUtil.ts:
--------------------------------------------------------------------------------
1 | export const base64EncodeOpus = (opusData: Uint8Array) => {
2 | // Convert to base64
3 | let binary = "";
4 | for (let i = 0; i < opusData.byteLength; i++) {
5 | binary += String.fromCharCode(opusData[i]);
6 | }
7 | return window.btoa(binary);
8 | };
9 |
10 | export const base64DecodeOpus = (base64String: string): Uint8Array => {
11 | const binaryString = window.atob(base64String);
12 | const len = binaryString.length;
13 | const bytes = new Uint8Array(len);
14 | for (let i = 0; i < len; i++) {
15 | bytes[i] = binaryString.charCodeAt(i);
16 | }
17 | return bytes;
18 | };
19 |
20 | export const convertWebmToWav = async (webmBlob: Blob): Promise<Blob> => {
21 | const arrayBuffer = await webmBlob.arrayBuffer();
22 | const AudioContextClass =
23 | window.AudioContext ||
24 | (window.hasOwnProperty("webkitAudioContext")
25 | ? (window as unknown as { webkitAudioContext: typeof AudioContext })
26 | .webkitAudioContext
27 | : undefined);
28 | if (!AudioContextClass) throw new Error("Web Audio API not supported");
29 | const audioCtx = new AudioContextClass();
30 | const audioBuffer = await audioCtx.decodeAudioData(arrayBuffer);
31 | // Encode to wav
32 | const wavBuffer = encodeWAV(audioBuffer);
33 | return new Blob([wavBuffer], { type: "audio/wav" });
34 | };
35 |
36 | // Helper: Encode AudioBuffer to WAV format
37 | export const encodeWAV = (audioBuffer: AudioBuffer): ArrayBuffer => {
38 | const numChannels = audioBuffer.numberOfChannels;
39 | const sampleRate = audioBuffer.sampleRate;
40 | const format = 1; // PCM
41 | const bitDepth = 16;
42 | const samples = audioBuffer.length * numChannels;
43 | const buffer = new ArrayBuffer(44 + samples * 2);
44 | const view = new DataView(buffer);
45 |
46 | // Write WAV header
47 | function writeString(view: DataView, offset: number, str: string) {
48 | for (let i = 0; i < str.length; i++) {
49 | view.setUint8(offset + i, str.charCodeAt(i));
50 | }
51 | }
52 | let offset = 0;
53 | writeString(view, offset, "RIFF");
54 | offset += 4;
55 | view.setUint32(offset, 36 + samples * 2, true);
56 | offset += 4;
57 | writeString(view, offset, "WAVE");
58 | offset += 4;
59 | writeString(view, offset, "fmt ");
60 | offset += 4;
61 | view.setUint32(offset, 16, true);
62 | offset += 4;
63 | view.setUint16(offset, format, true);
64 | offset += 2;
65 | view.setUint16(offset, numChannels, true);
66 | offset += 2;
67 | view.setUint32(offset, sampleRate, true);
68 | offset += 4;
69 | view.setUint32(offset, (sampleRate * numChannels * bitDepth) / 8, true);
70 | offset += 4;
71 | view.setUint16(offset, (numChannels * bitDepth) / 8, true);
72 | offset += 2;
73 | view.setUint16(offset, bitDepth, true);
74 | offset += 2;
75 | writeString(view, offset, "data");
76 | offset += 4;
77 | view.setUint32(offset, samples * 2, true);
78 | offset += 4;
79 |
80 | // Write PCM samples
81 | for (let ch = 0; ch < numChannels; ch++) {
82 | const channel = audioBuffer.getChannelData(ch);
83 | for (let i = 0; i < channel.length; i++) {
84 | const sample = Math.max(-1, Math.min(1, channel[i]));
85 | view.setInt16(
86 | 44 + (i * numChannels + ch) * 2,
87 | sample < 0 ? sample * 0x8000 : sample * 0x7fff,
88 | true
89 | );
90 | }
91 | }
92 | return buffer;
93 | };
94 |
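
A short usage sketch (not part of the file above) showing the base64 round trip and the size bookkeeping behind `encodeWAV`: a 44-byte header followed by 2 bytes per 16-bit sample. The byte values are illustrative:

```typescript
import { base64EncodeOpus, base64DecodeOpus } from "./audioUtil";

const opusData = new Uint8Array([0x4f, 0x67, 0x67, 0x53]); // "OggS" page magic
const encoded = base64EncodeOpus(opusData); // "T2dnUw=="
const decoded = base64DecodeOpus(encoded);
console.assert(decoded.length === opusData.length);

// encodeWAV emits 16-bit PCM: 44 header bytes + 2 bytes per sample,
// with audioBuffer.length frames * numChannels samples in total.
const expectedWavBytes = (frames: number, channels: number) =>
  44 + frames * channels * 2;
console.log(expectedWavBytes(48000, 1)); // 1 s of mono 48 kHz audio -> 96044 bytes
```
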
--------------------------------------------------------------------------------
/frontend/src/app/chatHistory.ts:
--------------------------------------------------------------------------------
1 | export type ChatRole = "user" | "assistant" | "system";
2 |
3 | export type ChatMessage = {
4 | role: ChatRole;
5 | content: string;
6 | };
7 |
8 | /** If there are multiple messages from the same role in a row, combine them into one message */
9 | export const compressChatHistory = (
10 | chatHistory: ChatMessage[]
11 | ): ChatMessage[] => {
12 | const compressed: ChatMessage[] = [];
13 | for (const message of chatHistory) {
14 | if (
15 | compressed.length > 0 &&
16 | compressed[compressed.length - 1].role === message.role
17 | ) {
18 | compressed[compressed.length - 1].content += `\n${message.content}`;
19 | } else {
20 | compressed.push({ ...message });
21 | }
22 | }
23 | return compressed;
24 | };
25 |
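
A usage example with illustrative data (not from the file above), showing how `compressChatHistory` merges consecutive messages from the same role:

```typescript
import { compressChatHistory, ChatMessage } from "./chatHistory";

const history: ChatMessage[] = [
  { role: "user", content: "Hello" },
  { role: "user", content: "Are you there?" },
  { role: "assistant", content: "Yes, I'm here." },
];

const compressed = compressChatHistory(history);
// [
//   { role: "user", content: "Hello\nAre you there?" },
//   { role: "assistant", content: "Yes, I'm here." },
// ]
console.log(compressed.length); // 2
```
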
--------------------------------------------------------------------------------
/frontend/src/app/cssUtil.ts:
--------------------------------------------------------------------------------
1 | export const getCSSVariable = (name: string) => {
2 | if (!name.startsWith("--")) {
3 | name = `--${name}`;
4 | }
5 |
6 | const variable = window
7 | .getComputedStyle(document.documentElement)
8 | .getPropertyValue(name)
9 | .trim();
10 |
11 | if (variable === "") {
12 | console.warn(`CSS variable ${name} not found`);
13 | }
14 | return variable;
15 | };
16 |
--------------------------------------------------------------------------------
/frontend/src/app/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/unmute/d85fbf05a90d2f019a10bf521c1c029d3208b280/frontend/src/app/favicon.ico
--------------------------------------------------------------------------------
/frontend/src/app/faviconKyutai.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/unmute/d85fbf05a90d2f019a10bf521c1c029d3208b280/frontend/src/app/faviconKyutai.ico
--------------------------------------------------------------------------------
/frontend/src/app/faviconKyutai.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/unmute/d85fbf05a90d2f019a10bf521c1c029d3208b280/frontend/src/app/faviconKyutai.png
--------------------------------------------------------------------------------
/frontend/src/app/globals.css:
--------------------------------------------------------------------------------
1 | @import "tailwindcss";
2 |
3 | :root {
4 | --green: #39F2AE;
5 | --offwhite: #EFEFEF;
6 | --lightgray: #9d9d9d;
7 | /* #B5BCC5; */
8 | --gray: #343434;
9 | /* #8D949E; */
10 | --darkgray: #191919;
11 | /* #1e293b; */
12 | --background: #121212;
13 |
14 | --orange: #f5ab4a;
15 | --pink: #f051db;
16 | --purple: #a57de8;
17 | /* for errors */
18 | --red: #f55142;
19 |
20 | --speaker-0: var(--green);
21 | --speaker-1: var(--orange);
22 | --speaker-2: var(--pink);
23 | --speaker-3: var(--purple);
24 |
25 | font-family:
26 | var(--font-satoshi),
27 | system-ui, Avenir, Helvetica, Arial, sans-serif;
28 | line-height: 1.5;
29 | font-weight: 500;
30 | font-size: 1.35rem;
31 |
32 | color-scheme: dark;
33 | color: var(--offwhite);
34 | background-color: var(--background);
35 |
36 | font-synthesis: none;
37 | text-rendering: optimizeLegibility;
38 | -webkit-font-smoothing: antialiased;
39 | -moz-osx-font-smoothing: grayscale;
40 | }
41 |
42 |
43 | /* TODO(vv): better place for the highlighted-word stuff? */
44 | .highlighted-word {
45 | color: black;
46 | }
47 |
48 | /* Create a pseudo-element for the background */
49 | .highlighted-word::after {
50 | content: "";
51 | position: absolute;
52 | background-color: var(--speaker-color, white);
53 | left: -0.25rem;
54 | right: -0.25rem;
55 | top: -0.25rem;
56 | bottom: -0.25rem;
57 | border-radius: 0.25rem;
58 | z-index: -1;
59 | }
60 |
61 | @theme {
62 | /* TODO: deduplicate */
63 | --color-green: var(--green);
64 | --color-offwhite: var(--offwhite);
65 | --color-lightgray: var(--lightgray);
66 | --color-gray: var(--gray);
67 | --color-darkgray: var(--darkgray);
68 | --color-orange: var(--orange);
69 | --color-pink: var(--pink);
70 | --color-purple: var(--purple);
71 | --color-red: var(--red);
72 | --color-background: var(--background);
73 | /* sometimes it's useful to have these as variables too */
74 | --color-white: #ffffff;
75 | --color-black: #000000;
76 |
77 | --font-family-satoshi: var(--font-satoshi);
78 | }
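
A minimal sketch of how the `.highlighted-word` class and `--speaker-color` custom property might be applied from a component. This is assumed usage for illustration, not taken from the repository's subtitle components:

```tsx
import React from "react";

// Assumed usage: the span gets its pill-shaped background from the ::after
// pseudo-element defined above, colored by --speaker-color. It is positioned
// relative so the absolutely-positioned ::after anchors to the word itself.
const HighlightedWord = ({
  word,
  speaker,
}: {
  word: string;
  speaker: 0 | 1 | 2 | 3;
}) => (
  <span
    className="highlighted-word"
    style={
      {
        position: "relative",
        "--speaker-color": `var(--speaker-${speaker})`,
      } as React.CSSProperties
    }
  >
    {word}
  </span>
);

export default HighlightedWord;
```
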
--------------------------------------------------------------------------------
/frontend/src/app/layout.tsx:
--------------------------------------------------------------------------------
1 | import type { Metadata } from "next";
2 | import "./globals.css";
3 | import localFont from "next/font/local";
4 | import ConsentModal from "./ConsentModal";
5 |
6 | export const metadata: Metadata = {
7 | title: "Unmute by Kyutai",
8 | description: "Make LLMs listen and speak.",
9 | };
10 |
11 | const satoshi = localFont({
12 | src: [
13 | {
14 | path: "../assets/fonts/Satoshi-Variable.woff2",
15 | weight: "300 900",
16 | style: "normal",
17 | },
18 | {
19 | path: "../assets/fonts/Satoshi-VariableItalic.woff2",
20 | weight: "300 900",
21 | style: "italic",
22 | },
23 | ],
24 | variable: "--font-satoshi",
25 | display: "swap",
26 | });
27 |
28 | export default function RootLayout({
29 | children,
30 | }: Readonly<{
31 | children: React.ReactNode;
32 | }>) {
33 | return (
34 |
35 |
36 | {/* Needed for debugging JSON styling */}
37 |
41 |
42 |
43 | {children}
44 |
45 |
46 |
47 | );
48 | }
49 |
--------------------------------------------------------------------------------
/frontend/src/app/opus-recorder.d.ts:
--------------------------------------------------------------------------------
1 | declare module "opus-recorder" {
2 | interface RecorderOptions {
3 | // Confusing naming here because MediaTrackConstraints also exists as a type,
4 | // but this is actually a MediaStreamConstraints object
5 | mediaTrackConstraints?: MediaStreamConstraints;
6 | encoderPath: string;
7 | bufferLength: number;
8 | encoderFrameSize: number;
9 | encoderSampleRate: number;
10 | maxFramesPerPage: number;
11 | numberOfChannels: number;
12 | recordingGain: number;
13 | resampleQuality: number;
14 | encoderComplexity: number;
15 | encoderApplication: number;
16 | streamPages: boolean;
17 | }
18 |
19 | export default class Recorder {
20 | constructor(options: RecorderOptions);
21 | start(): void;
22 | stop(): void;
23 | ondataavailable: (data: Uint8Array) => void;
24 | encodedSamplePosition: number;
25 | }
26 | }
27 |
28 |
29 | type DecoderWorker = Worker
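
A construction sketch based on the `RecorderOptions` interface above. The numeric settings and the `/encoderWorker.min.js` path (the worker shipped in `public/`) are illustrative assumptions, not necessarily the values Unmute uses:

```typescript
import Recorder from "opus-recorder";

const recorder = new Recorder({
  // Despite the name, this is a MediaStreamConstraints object.
  mediaTrackConstraints: { audio: { echoCancellation: true } },
  encoderPath: "/encoderWorker.min.js", // assumed location of the bundled worker
  bufferLength: 4096,
  encoderFrameSize: 20, // ms
  encoderSampleRate: 24000,
  maxFramesPerPage: 1,
  numberOfChannels: 1,
  recordingGain: 1,
  resampleQuality: 3,
  encoderComplexity: 0,
  encoderApplication: 2049, // Opus "audio" application
  streamPages: true,
});

recorder.ondataavailable = (data: Uint8Array) => {
  // Each call delivers an encoded Ogg/Opus page.
  console.log("opus page bytes:", data.byteLength);
};
recorder.start();
```
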
--------------------------------------------------------------------------------
/frontend/src/app/page.tsx:
--------------------------------------------------------------------------------
1 | import Unmute from "./Unmute";
2 |
3 | export default function Home() {
4 | return (
5 |
6 |
7 |
8 | );
9 | }
10 |
--------------------------------------------------------------------------------
/frontend/src/app/useBackendServerUrl.ts:
--------------------------------------------------------------------------------
1 | import { useEffect, useState } from "react";
2 |
3 | export const useBackendServerUrl = () => {
4 | const [backendServerUrl, setBackendServerUrl] = useState<string | null>(null);
5 |
6 | // Get the backend server URL. This is a bit involved to support different deployment methods.
7 | useEffect(() => {
8 | if (typeof window !== "undefined") {
9 | const isInDocker = ["true", "1"].includes(process.env.NEXT_PUBLIC_IN_DOCKER?.toLowerCase() || "");
10 |
11 | const prefix = isInDocker ? "/api" : "";
12 |
13 | const backendUrl = new URL("", window.location.href);
14 | if (!isInDocker) {
15 | backendUrl.port = "8000";
16 | }
17 | backendUrl.pathname = prefix;
18 | backendUrl.search = ""; // strip any query parameters
19 | setBackendServerUrl(backendUrl.toString().replace(/\/$/, "")); // remove trailing slash
20 | }
21 | }, []);
22 |
23 | return backendServerUrl;
24 | };
25 |
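
To make the two deployment modes concrete, here is the same URL derivation extracted into a pure function (a sketch, not part of the file above), with the resulting backend URLs for illustrative page addresses:

```typescript
export const resolveBackendUrl = (pageHref: string, isInDocker: boolean): string => {
  const backendUrl = new URL("", pageHref);
  if (!isInDocker) {
    backendUrl.port = "8000";
  }
  backendUrl.pathname = isInDocker ? "/api" : "";
  backendUrl.search = ""; // strip any query parameters
  return backendUrl.toString().replace(/\/$/, ""); // remove trailing slash
};

// Docker deployment: same origin, "/api" prefix.
console.log(resolveBackendUrl("https://example.com/some/page?x=1", true));
// -> "https://example.com/api"

// Dockerless dev: same host, backend on port 8000.
console.log(resolveBackendUrl("http://localhost:3000/", false));
// -> "http://localhost:8000"
```
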
--------------------------------------------------------------------------------
/frontend/src/app/useGoogleAnalytics.ts:
--------------------------------------------------------------------------------
1 | import { useEffect, useRef } from "react";
2 | import { LanguageCode, UnmuteConfig } from "./UnmuteConfigurator";
3 | import { sendGAEvent } from "@next/third-parties/google";
4 |
5 | interface ConversationAnalyticsInfo {
6 | voice: string;
7 | voice_name: string;
8 | is_custom_voice: boolean;
9 | instructions: string;
10 | instructions_type: string;
11 | instructions_language: LanguageCode;
12 | is_custom_instructions: boolean;
13 | start_timestamp_ms: number;
14 | conversation_uuid: string;
15 | duration_sec?: number;
16 | }
17 |
18 | export function useGoogleAnalytics({
19 | shouldConnect,
20 | unmuteConfig,
21 | }: {
22 | shouldConnect: boolean;
23 | unmuteConfig: UnmuteConfig;
24 | }) {
25 | const conversationAnalyticsInfo = useRef<ConversationAnalyticsInfo | null>(
26 | null
27 | );
28 | const unmuteConfigRef = useRef(unmuteConfig);
29 |
30 | // We keep the unmuteConfig in a ref because the useEffect that depends on it
31 | // should only run when shouldConnect changes, not when unmuteConfig changes.
32 | useEffect(() => {
33 | unmuteConfigRef.current = unmuteConfig;
34 | }, [unmuteConfig]);
35 |
36 | useEffect(() => {
37 | if (shouldConnect) {
38 | const config = unmuteConfigRef.current;
39 | const info = {
40 | voice: config.voice.startsWith("custom:") ? "custom" : config.voice,
41 | voice_name: config.voiceName,
42 | is_custom_voice: config.voice.startsWith("custom:"),
43 | instructions: JSON.stringify(config.instructions),
44 | instructions_language: config.instructions.language ?? "en",
45 | instructions_type: config.isCustomInstructions
46 | ? "constant_custom"
47 | : config.instructions.type,
48 | is_custom_instructions: config.isCustomInstructions,
49 | start_timestamp_ms: Date.now(),
50 | // Just to make it easy to pair with the end_conversation event
51 | conversation_uuid: crypto.randomUUID(),
52 | };
53 | conversationAnalyticsInfo.current = info;
54 |
55 | sendGAEvent("event", "start_conversation", info);
56 | } else {
57 | const info = conversationAnalyticsInfo.current;
58 | if (info) {
59 | info.duration_sec = (Date.now() - info.start_timestamp_ms) / 1000;
60 | sendGAEvent("event", "end_conversation", {
61 | ...info,
62 | });
63 | }
64 | }
65 | }, [shouldConnect]);
66 |
67 | const analyticsOnDownloadRecording = () => {
68 | const info = conversationAnalyticsInfo.current;
69 | if (info) {
70 | sendGAEvent("event", "download_recording", {
71 | ...info,
72 | });
73 | }
74 | };
75 |
76 | return { analyticsOnDownloadRecording };
77 | }
78 |
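
The ref-plus-effect arrangement above is the common "latest value in a ref" React pattern: the analytics effect re-runs only when `shouldConnect` changes, yet always reads the freshest config. A generic sketch of the pattern, not from the file above:

```typescript
import { useEffect, useRef } from "react";

// Keep the most recent value in a ref without triggering re-renders or effects.
export function useLatest<T>(value: T) {
  const ref = useRef(value);
  useEffect(() => {
    ref.current = value;
  }, [value]);
  return ref;
}

// Fire only when `trigger` changes, but read the current `value` at that moment.
export function useOnTrigger<T>(trigger: boolean, value: T, onFire: (v: T) => void) {
  const latest = useLatest(value);
  useEffect(() => {
    if (trigger) {
      onFire(latest.current); // reads the current value, not a stale closure
    }
    // `latest` is a stable ref object, so depending only on `trigger` is intentional.
  }, [trigger]); // eslint-disable-line react-hooks/exhaustive-deps
}
```
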
--------------------------------------------------------------------------------
/frontend/src/app/useKeyboardShortcuts.ts:
--------------------------------------------------------------------------------
1 | import { useEffect, useState } from "react";
2 |
3 | const ALLOW_DEV_MODE = false;
4 |
5 | const useKeyboardShortcuts = () => {
6 | // local storage persistence disabled in case random users activate it accidentally
7 | // useLocalStorage("useDevMode", false)
8 | const [isDevMode, setIsDevMode] = useState(false);
9 | // useLocalStorage("showSubtitles", false)
10 | const [showSubtitles, setShowSubtitles] = useState(false);
11 |
12 | useEffect(() => {
13 | const handleKeyDown = (event: KeyboardEvent) => {
14 | const activeElement = document.activeElement;
15 | // Don't toggle dev mode if the active element is an input field
16 | const isInputField =
17 | activeElement &&
18 | (activeElement.tagName === "INPUT" ||
19 | activeElement.tagName === "TEXTAREA" ||
20 | activeElement.getAttribute("contenteditable") === "true");
21 |
22 | if (
23 | ALLOW_DEV_MODE &&
24 | !isInputField &&
25 | (event.key === "D" || event.key === "d")
26 | ) {
27 | setIsDevMode((prev) => !prev);
28 | }
29 | if (!isInputField && (event.key === "S" || event.key === "s")) {
30 | setShowSubtitles((prev) => !prev);
31 | }
32 | };
33 |
34 | window.addEventListener("keydown", handleKeyDown);
35 | return () => {
36 | window.removeEventListener("keydown", handleKeyDown);
37 | };
38 | }, [setIsDevMode, setShowSubtitles]);
39 |
40 | return { isDevMode, showSubtitles };
41 | };
42 |
43 | export default useKeyboardShortcuts;
44 |
--------------------------------------------------------------------------------
/frontend/src/app/useLocalStorage.ts:
--------------------------------------------------------------------------------
1 | import { useState, useEffect } from "react";
2 |
3 | export const useLocalStorage = <T,>(
4 | key: string,
5 | defaultValue: T
6 | ): [T, React.Dispatch<React.SetStateAction<T>>] => {
7 | const [value, setValue] = useState(defaultValue);
8 |
9 | useEffect(() => {
10 | const saved = localStorage.getItem(key);
11 | const initial = saved ? (JSON.parse(saved) as T) : defaultValue;
12 | setValue(initial);
13 | }, [key, defaultValue]);
14 |
15 | useEffect(() => {
16 | localStorage.setItem(key, JSON.stringify(value));
17 | }, [key, value]);
18 |
19 | return [value, setValue];
20 | };
21 |
--------------------------------------------------------------------------------
/frontend/src/app/useMicrophoneAccess.ts:
--------------------------------------------------------------------------------
1 | import { useState, useCallback, useRef } from "react";
2 |
3 | type MicrophoneAccessType = "unknown" | "granted" | "refused";
4 |
5 | export const useMicrophoneAccess = () => {
6 | const [microphoneAccess, setMicrophoneAccess] =
7 | useState<MicrophoneAccessType>("unknown");
8 |
9 | const mediaStream = useRef<MediaStream | null>(null);
10 |
11 | const askMicrophoneAccess = useCallback(async () => {
12 | try {
13 | mediaStream.current = await window.navigator.mediaDevices.getUserMedia({
14 | audio: {
15 | channelCount: 1,
16 | echoCancellation: true,
17 | autoGainControl: true,
18 | noiseSuppression: true,
19 | }
20 | });
21 | setMicrophoneAccess("granted");
22 | return mediaStream.current;
23 | } catch (e) {
24 | console.error(e);
25 | setMicrophoneAccess("refused");
26 | return null;
27 | }
28 | }, []);
29 |
30 | return {
31 | microphoneAccess,
32 | askMicrophoneAccess,
33 | mediaStream,
34 | };
35 | };
36 |
--------------------------------------------------------------------------------
/frontend/src/app/useWakeLock.ts:
--------------------------------------------------------------------------------
1 | import { useEffect, useRef } from "react";
2 |
3 | const useWakeLock = (shouldPreventSleep: boolean) => {
4 | const wakeLockRef = useRef<WakeLockSentinel | null>(null);
5 |
6 | useEffect(() => {
7 | const requestWakeLock = async () => {
8 | try {
9 | if ("wakeLock" in navigator && shouldPreventSleep) {
10 | wakeLockRef.current = await navigator.wakeLock.request("screen");
11 | }
12 | } catch (err) {
13 | console.error("Failed to acquire wake lock:", err);
14 | }
15 | };
16 |
17 | const releaseWakeLock = () => {
18 | if (wakeLockRef.current) {
19 | wakeLockRef.current.release().catch((err) => {
20 | console.error("Failed to release wake lock:", err);
21 | });
22 | wakeLockRef.current = null;
23 | }
24 | };
25 |
26 | if (shouldPreventSleep) {
27 | requestWakeLock();
28 | } else {
29 | releaseWakeLock();
30 | }
31 |
32 | return () => {
33 | releaseWakeLock();
34 | };
35 | }, [shouldPreventSleep]);
36 | };
37 |
38 | export default useWakeLock;
39 |
--------------------------------------------------------------------------------
/frontend/src/app/voice-donation/DonationConsent.tsx:
--------------------------------------------------------------------------------
1 | import React from "react";
2 |
3 | const GreenLink = ({
4 | href,
5 | children,
6 | }: {
7 | href: string;
8 | children: React.ReactNode;
9 | }) => (
10 |
16 | {children}
17 |
18 | );
19 |
20 | const DonationConsent = ({
21 | setConsentGiven,
22 | }: {
23 | setConsentGiven: (value: boolean) => void;
24 | }) => {
25 | const [checks, setChecks] = React.useState([false, false, false]);
26 |
27 | React.useEffect(() => {
28 | setConsentGiven(checks.every(Boolean));
29 | }, [checks, setConsentGiven]);
30 |
31 | const handleCheck =
32 | (idx: number) => (e: React.ChangeEvent<HTMLInputElement>) => {
33 | const updated = [...checks];
34 | updated[idx] = e.target.checked;
35 | setChecks(updated);
36 | };
37 |
38 | return (
39 |
40 |
41 |
47 |
48 | I am at least 18 years old, I am also of legal age in my country of residence and I
49 | have read and I agree with Kyutai’s{" "}
50 | Terms and{" "}
51 |
52 | Privacy Policy
53 |
54 | . *
55 |
56 |
57 |
58 |
64 |
65 | I authorize Kyutai to collect, process and publish worldwide my voice
66 | recordings and embeddings as part of public datasets under a CC0
67 | license or similar open-source license, in accordance with Kyutai’s{" "}
68 |
69 | Privacy Policy
70 |
71 | . *
72 |
73 |
74 |
75 |
81 |
82 | I authorize Kyutai to use my voice recording and embedding worldwide
83 | to develop and train Kyutai’s AI models and make them available to the
84 | public, in accordance with Kyutai’s{" "}
85 |
86 | Privacy Policy
87 |
88 | . *
89 |
90 |
91 |
92 | );
93 | };
94 |
95 | export default DonationConsent;
96 |
--------------------------------------------------------------------------------
/frontend/src/app/voice-donation/IntroText.mdx:
--------------------------------------------------------------------------------
1 | Here you can donate a short recording of your voice as part of our Unmute Voice Donation Project, an open-source text-to-speech initiative.
2 | These voices will be made available for use with Kyutai TTS.
3 | You can find the already-released voices in the [voice repository](https://huggingface.co/kyutai/tts-voices).
4 | By sharing a short recording of your voice, you help Kyutai to:
5 |
6 | - Provide voices for our open-science text-to-speech (TTS) models.
7 | - Build open vocal datasets that anyone can access and reuse.
8 |
9 | We value your privacy and transparency. Before proceeding, please review
10 | the following carefully:
11 |
12 | - Your voice recordings and data including voice embeddings
13 | (representations of vocal characteristics) may be collected,
14 | processed, and published openly by Kyutai. The resulting datasets will
15 | be publicly available under the Creative Commons CC0 license or any
16 | similar open-source license, allowing anyone to freely reuse them,
17 | subject to their compliance with our Acceptable Use Policy.
18 | - Your voice may be made available for use with Kyutai Text-To-Speech.
19 | As a result, third parties could generate synthetic speech that
20 | closely resembles your voice. After public release, each third-party
21 | user will be directly responsible for its own use of your voice
22 | recording. If you do not want people to reuse your voice and reproduce
23 | it, you should not submit your voice recording.
24 |
25 | For more information, see the [Terms of Use](/voice-donation/terms-of-use) and
26 | [Privacy Policy](/voice-donation/privacy-policy).
27 |
28 | ## Verification
29 |
30 |
31 |
32 | Verification sentences
33 |
34 |
35 | Whatever you want (last 10 seconds will be used)
36 |
37 |
38 |
39 | To verify that this is your voice and not a pre-recorded sample, we will
40 | ask you to read a short text out loud. Afterwards, you can say whatever
41 | you want. Have fun with it! The TTS is good at reproducing the tone and
42 | mannerisms of the voice. The last 10 seconds of your recording will be
43 | used as the voice sample. Try to use the same tone throughout the
44 | recording to make it easier to verify that it's you.
45 |
--------------------------------------------------------------------------------
/frontend/src/app/voice-donation/privacy-policy/page.tsx:
--------------------------------------------------------------------------------
1 | // A redirection page, set up so that we can change the URL it points to later if needed.
2 | "use client";
3 | import { useEffect } from "react";
4 |
5 | const LINK =
6 | "https://kyutai.org/next/legal/Privacy%20Policy%20-%20Unmute%20Voice%20Donation%20Project%20v1.pdf";
7 |
8 | export default function PrivacyPolicyRedirect() {
9 | useEffect(() => {
10 | window.location.href = LINK;
11 | }, []);
12 | return (
13 |
17 | );
18 | }
19 |
--------------------------------------------------------------------------------
/frontend/src/app/voice-donation/terms-of-use/page.tsx:
--------------------------------------------------------------------------------
1 | // A redirection page, set up so that we can change the URL it points to later if needed.
2 | "use client";
3 | import { useEffect } from "react";
4 |
5 | const LINK =
6 | "https://kyutai.org/next/legal/Terms%20of%20Use%20-%20Unmute%20Voice%20Donation%20Project%20v1.pdf";
7 |
8 | export default function TermsOfUseRedirect() {
9 | useEffect(() => {
10 | window.location.href = LINK;
11 | }, []);
12 | return (
13 |
17 | );
18 | }
19 |
--------------------------------------------------------------------------------
/frontend/src/assets/fonts/Satoshi-Variable.eot:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/unmute/d85fbf05a90d2f019a10bf521c1c029d3208b280/frontend/src/assets/fonts/Satoshi-Variable.eot
--------------------------------------------------------------------------------
/frontend/src/assets/fonts/Satoshi-Variable.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/unmute/d85fbf05a90d2f019a10bf521c1c029d3208b280/frontend/src/assets/fonts/Satoshi-Variable.ttf
--------------------------------------------------------------------------------
/frontend/src/assets/fonts/Satoshi-Variable.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/unmute/d85fbf05a90d2f019a10bf521c1c029d3208b280/frontend/src/assets/fonts/Satoshi-Variable.woff
--------------------------------------------------------------------------------
/frontend/src/assets/fonts/Satoshi-Variable.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/unmute/d85fbf05a90d2f019a10bf521c1c029d3208b280/frontend/src/assets/fonts/Satoshi-Variable.woff2
--------------------------------------------------------------------------------
/frontend/src/assets/fonts/Satoshi-VariableItalic.eot:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/unmute/d85fbf05a90d2f019a10bf521c1c029d3208b280/frontend/src/assets/fonts/Satoshi-VariableItalic.eot
--------------------------------------------------------------------------------
/frontend/src/assets/fonts/Satoshi-VariableItalic.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/unmute/d85fbf05a90d2f019a10bf521c1c029d3208b280/frontend/src/assets/fonts/Satoshi-VariableItalic.ttf
--------------------------------------------------------------------------------
/frontend/src/assets/fonts/Satoshi-VariableItalic.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/unmute/d85fbf05a90d2f019a10bf521c1c029d3208b280/frontend/src/assets/fonts/Satoshi-VariableItalic.woff
--------------------------------------------------------------------------------
/frontend/src/assets/fonts/Satoshi-VariableItalic.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/unmute/d85fbf05a90d2f019a10bf521c1c029d3208b280/frontend/src/assets/fonts/Satoshi-VariableItalic.woff2
--------------------------------------------------------------------------------
/frontend/src/assets/kyutai-logo-cropped.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/frontend/src/mdx-components.tsx:
--------------------------------------------------------------------------------
1 | import type { MDXComponents } from "mdx/types";
2 |
3 | // This file allows you to provide custom React components
4 | // to be used in MDX files. You can import and use any
5 | // React component you want, including inline styles,
6 | // components from other libraries, and more.
7 |
8 | export function useMDXComponents(components: MDXComponents): MDXComponents {
9 | return {
10 | h1: ({ children }) => (
11 | {children}
12 | ),
13 | h2: ({ children }) => (
14 | {children}
15 | ),
16 | h3: ({ children }) => (
17 | {children}
18 | ),
19 | p: ({ children }) => (
20 | {children}
21 | ),
22 | li: ({ children }) => (
23 | {children}
24 | ),
25 | a: ({ href, children }) => (
26 |
27 | {children}
28 |
29 | ),
30 | strong: ({ children }) => (
31 | {children}
32 | ),
33 | // TODO more styling here
34 | ...components,
35 | };
36 | }
37 |
--------------------------------------------------------------------------------
/frontend/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "target": "ES2017",
4 | "lib": ["dom", "dom.iterable", "esnext"],
5 | "allowJs": true,
6 | "skipLibCheck": true,
7 | "strict": true,
8 | "noEmit": true,
9 | "esModuleInterop": true,
10 | "module": "esnext",
11 | "moduleResolution": "bundler",
12 | "resolveJsonModule": true,
13 | "isolatedModules": true,
14 | "jsx": "preserve",
15 | "incremental": true,
16 | "plugins": [
17 | {
18 | "name": "next"
19 | }
20 | ],
21 | "paths": {
22 | "@/*": ["./src/*"]
23 | }
24 | },
25 | "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"],
26 | "exclude": ["node_modules"]
27 | }
28 |
--------------------------------------------------------------------------------
/notebooks/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 |
--------------------------------------------------------------------------------
/notebooks/create-voice-donation-sentences.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "0",
6 | "metadata": {},
7 | "source": [
8 | "# Create voice donation sentences\n",
9 | "\n",
10 | "When people donate their voice, we ask them to read some verification sentences so that it's not possible to just use a pre-made recording.\n",
11 | "\n",
12 | "This notebook selects sentences by filtering from CommonVoice.\n"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "id": "1",
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "import pandas as pd"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "id": "2",
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "# Download here:\n",
33 | "# https://commonvoice.mozilla.org/en/datasets select \"Common Voice Delta Segment 21.0\"\n",
34 | "df = pd.read_csv(\n",
35 | " \"./data/cv-corpus-21.0-delta-2025-03-14/en/validated_sentences.tsv\", sep=\"\\t\"\n",
36 | ")"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "id": "3",
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "def count_uppercase_letters(s):\n",
47 | " return sum(1 for c in s if c.isupper())\n",
48 | "\n",
49 | "\n",
50 | "def is_ascii(s):\n",
51 | " return all(ord(c) < 128 for c in s)\n",
52 | "\n",
53 | "\n",
54 | "def max_word_length(s):\n",
55 | " if not s.strip():\n",
56 | " return 0\n",
57 | "\n",
58 | " return max(len(word) for word in s.split())\n",
59 | "\n",
60 | "\n",
61 | "def is_ok_sentence(s):\n",
62 | " return (\n",
63 | " is_ascii(s)\n",
64 | " # Exclude long complicated words\n",
65 | " and max_word_length(s) <= 10\n",
66 | " # Proper names might be difficult for non-native speakers to pronounce\n",
67 | " # and are harder to check automatically, so we exclude them. The one capital\n",
68 | " # letter allowed is for the first letter of the sentence.\n",
69 | " and count_uppercase_letters(s) == 1\n",
70 | " and 30 <= len(s) <= 80\n",
71 | " # No questions or exclamations\n",
72 | " and s[-1] == \".\"\n",
73 | " )\n",
74 | "\n",
75 | "\n",
76 | "df[\"is_ok_sentence\"] = df[\"sentence\"].apply(is_ok_sentence)"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "id": "4",
83 | "metadata": {},
84 | "outputs": [],
85 | "source": [
86 | "df.loc[df[\"is_ok_sentence\"], \"sentence\"]"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "id": "5",
93 | "metadata": {},
94 | "outputs": [],
95 | "source": [
96 | "sum(df[\"is_ok_sentence\"])"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "id": "6",
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "sentences = df.loc[df[\"is_ok_sentence\"], \"sentence\"].tolist()[:10000]"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": null,
112 | "id": "7",
113 | "metadata": {},
114 | "outputs": [],
115 | "source": [
116 | "with open(\"../unmute/tts/voice_donation_sentences.txt\", \"w\") as f:\n",
117 | " for sentence in sentences:\n",
118 | " f.write(sentence + \"\\n\")"
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": null,
124 | "id": "8",
125 | "metadata": {},
126 | "outputs": [],
127 | "source": []
128 | }
129 | ],
130 | "metadata": {
131 | "kernelspec": {
132 | "display_name": ".venv",
133 | "language": "python",
134 | "name": "python3"
135 | },
136 | "language_info": {
137 | "codemirror_mode": {
138 | "name": "ipython",
139 | "version": 3
140 | },
141 | "file_extension": ".py",
142 | "mimetype": "text/x-python",
143 | "name": "python",
144 | "nbconvert_exporter": "python",
145 | "pygments_lexer": "ipython3",
146 | "version": "3.12.9"
147 | }
148 | },
149 | "nbformat": 4,
150 | "nbformat_minor": 5
151 | }
152 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "unmute"
3 | version = "0.1.0"
4 | description = "Make text LLMs listen and speak"
5 | readme = "README.md"
6 | requires-python = ">=3.12,<3.13"
7 | dependencies = [
8 | "fastapi[standard]>=0.115.12",
9 | "fastrtc==0.0.23",
10 | "mistralai>=1.5.1",
11 | "msgpack>=1.1.0",
12 | "msgpack-types>=0.5.0",
13 | "openai>=1.70.0",
14 | "plotly>=6.0.1",
15 | "sphn>=0.2.0",
16 | "prometheus-fastapi-instrumentator==7.1.0",
17 | "prometheus-client==0.21.0",
18 | "ruamel-yaml>=0.18.10",
19 | "redis>=6.0.0",
20 | ]
21 |
22 | [build-system]
23 | requires = ["setuptools >= 77.0.3"]
24 | build-backend = "setuptools.build_meta"
25 |
26 | [tool.setuptools.packages]
27 | find = { include = ["unmute"] }
28 |
29 | [tool.pyright]
30 | typeCheckingMode = "strict"
31 |
32 | # Unlike MyPy, Pyright makes an explicit distinction between "Unknown" (I don't know
33 | # what this is) and "Any" (I'll allow anything here). By default, "Unknown" is treated
34 | # as "Any" but these reportUnknownX settings make it an error to use "Unknown".
35 | # You'd have to explicitly cast it to "Any" or something else.
36 | # Let's disable these for now to stick to MyPy-like behavior.
37 | reportUnknownMemberType = false
38 | reportUnknownArgumentType = false
39 | reportUnknownLambdaType = false
40 | reportUnknownVariableType = false
41 | reportUnknownParameterType = false
42 |
43 | # See above for how to fix reportMissingTypeStubs issues
44 | reportMissingTypeStubs = false
45 |
46 | # Ruff removes unused imports automatically, but it doesn't hurt to have this enabled.
47 | reportUnusedImport = true # true in "strict", but make it explicit
48 |
49 | reportMissingTypeArgument = false
50 |
51 |
52 | [tool.ruff.lint]
53 | select = [
54 | "B", # bugbear
55 | "E", # pep8 rules
56 | "F", # pyflakes
57 | "I001", # isort
58 | "W", # pep8 warnings
59 | # pydocstyle - check that all arguments are documented
60 | # It can make sense to add other "D" checks for more docstring
61 | # consistency, but some of them are really pedantic - like requiring
62 | # function-level docstrings for every single function
63 | "D417",
64 | ]
65 |
66 | ignore = [
67 | # Line too long errors. Ruff format --fix will fix most of these
68 | # and sometimes we want to keep long strings in one line etc.
69 | "E501",
70 | ]
71 |
72 | pydocstyle.convention = "google" # Google docstring style
73 |
74 | [dependency-groups]
75 | dev = [
76 | "jupyter>=1.1.1",
77 | "pytest>=8.3.5",
78 | "pytest-asyncio>=0.26.0",
79 | "pyright",
80 | "ruff",
81 | "pre-commit",
82 | "pyinstrument",
83 | "ffmpeg-normalize>=1.31.3",
84 | ]
85 |
--------------------------------------------------------------------------------
/services/debugger/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM ubuntu:22.04
2 |
3 | # Install many tools useful for debugging the network
4 | RUN apt-get update && \
5 | apt-get install -y \
6 | iputils-ping \
7 | iproute2 \
8 | net-tools \
9 | curl \
10 | wget \
11 | dnsutils \
12 | traceroute \
13 | tcpdump \
14 | nmap \
15 | telnet \
16 | vim \
17 | less \
18 | git \
19 | && \
20 | apt-get clean && \
21 | rm -rf /var/lib/apt/lists/*
22 |
23 |
24 | COPY --from=ghcr.io/astral-sh/uv:0.7.2 /uv /uvx /bin/
25 |
26 | RUN curl https://sh.rustup.rs -sSf | sh -s -- -y
27 | ENV PATH="/root/.cargo/bin:${PATH}"
28 |
--------------------------------------------------------------------------------
/services/grafana/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM grafana/grafana:11.6.1-ubuntu
2 |
3 | COPY dashboards /etc/grafana/dashboards
4 | COPY provisioning /etc/grafana/provisioning
5 | COPY grafana.ini /etc/grafana/grafana.ini
6 |
--------------------------------------------------------------------------------
/services/grafana/grafana.ini:
--------------------------------------------------------------------------------
1 | [auth.anonymous]
2 | enabled = true
3 | org_role = Admin
4 |
5 | [auth.basic]
6 | enabled = false
7 |
8 | [auth]
9 | disable_login_form = true
10 |
11 |
12 | # Don't use US date format. From:
13 | # https://community.grafana.com/t/how-do-we-get-grafana-to-use-this-date-format-dd-mm-yyyy-hh-mm-ss-timestamp/54666/3
14 | [date_formats]
15 | # Default system date format used in time range picker and other places where full time is displayed
16 | full_date = DD-MM-YYYY HH:mm:ss
17 |
18 | # Used by graph and other places where we only show small intervals
19 | interval_second = HH:mm:ss
20 | interval_minute = HH:mm
21 | interval_hour = DD/MM HH:mm
22 | interval_day = DD/MM
23 | interval_month = MM-YYYY
24 | interval_year = YYYY
25 |
--------------------------------------------------------------------------------
/services/grafana/provisioning/dashboards/dashboards.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: 1
2 |
3 | providers:
4 | - name: 'default'
5 | orgId: 1
6 | folder: ''
7 | type: file
8 | disableDeletion: false
9 | editable: true
10 | updateIntervalSeconds: 10
11 | options:
12 | path: /etc/grafana/dashboards
13 |
--------------------------------------------------------------------------------
/services/grafana/provisioning/datasources/datasources.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: 1
2 |
3 | datasources:
4 | - name: Prometheus
5 | type: prometheus
6 | access: proxy
7 | orgId: 1
8 | url: http://prometheus:9090
9 | isDefault: true
10 | editable: true
11 |
--------------------------------------------------------------------------------
/services/moshi-server/configs/stt-prod.toml:
--------------------------------------------------------------------------------
1 | static_dir = "./static/"
2 | log_dir = "/tmp/unmute_logs"
3 | instance_name = "tts"
4 | authorized_ids = ["public_token"]
5 |
6 | [modules.asr]
7 | path = "/api/asr-streaming"
8 | type = "BatchedAsr"
9 | lm_model_file = "hf://kyutai/stt-1b-en_fr-candle/model.safetensors"
10 | text_tokenizer_file = "hf://kyutai/stt-1b-en_fr-candle/tokenizer_en_fr_audio_8000.model"
11 | audio_tokenizer_file = "hf://kyutai/stt-1b-en_fr-candle/mimi-pytorch-e351c8d8@125.safetensors"
12 | asr_delay_in_tokens = 6
13 | batch_size = 64 # NOTE: make this smaller if running on one GPU
14 | conditioning_learnt_padding = true
15 | temperature = 0.25
16 |
17 | [modules.asr.model]
18 | audio_vocab_size = 2049
19 | text_in_vocab_size = 8001
20 | text_out_vocab_size = 8000
21 | audio_codebooks = 20
22 |
23 | [modules.asr.model.transformer]
24 | d_model = 2048
25 | num_heads = 16
26 | num_layers = 16
27 | dim_feedforward = 8192
28 | causal = true
29 | norm_first = true
30 | bias_ff = false
31 | bias_attn = false
32 | context = 750
33 | max_period = 100000
34 | use_conv_block = false
35 | use_conv_bias = true
36 | gating = "silu"
37 | norm = "RmsNorm"
38 | positional_embedding = "Rope"
39 | conv_layout = false
40 | conv_kernel_size = 3
41 | kv_repeat = 1
42 | max_seq_len = 40960
43 |
44 | [modules.asr.model.extra_heads]
45 | num_heads = 4
46 | dim = 6
47 |
--------------------------------------------------------------------------------
/services/moshi-server/configs/stt.toml:
--------------------------------------------------------------------------------
1 | static_dir = "./static/"
2 | log_dir = "/tmp/unmute_logs"
3 | instance_name = "tts"
4 | authorized_ids = ["public_token"]
5 |
6 | [modules.asr]
7 | path = "/api/asr-streaming"
8 | type = "BatchedAsr"
9 | lm_model_file = "hf://kyutai/stt-1b-en_fr-candle/model.safetensors"
10 | text_tokenizer_file = "hf://kyutai/stt-1b-en_fr-candle/tokenizer_en_fr_audio_8000.model"
11 | audio_tokenizer_file = "hf://kyutai/stt-1b-en_fr-candle/mimi-pytorch-e351c8d8@125.safetensors"
12 | asr_delay_in_tokens = 6
13 | # A higher batch size allows you to serve more users at once, but with a higher latency and memory usage.
14 | batch_size = 1
15 | conditioning_learnt_padding = true
16 | temperature = 0.25
17 |
18 | [modules.asr.model]
19 | audio_vocab_size = 2049
20 | text_in_vocab_size = 8001
21 | text_out_vocab_size = 8000
22 | audio_codebooks = 20
23 |
24 | [modules.asr.model.transformer]
25 | d_model = 2048
26 | num_heads = 16
27 | num_layers = 16
28 | dim_feedforward = 8192
29 | causal = true
30 | norm_first = true
31 | bias_ff = false
32 | bias_attn = false
33 | context = 750
34 | max_period = 100000
35 | use_conv_block = false
36 | use_conv_bias = true
37 | gating = "silu"
38 | norm = "RmsNorm"
39 | positional_embedding = "Rope"
40 | conv_layout = false
41 | conv_kernel_size = 3
42 | kv_repeat = 1
43 | max_seq_len = 40960
44 |
45 | [modules.asr.model.extra_heads]
46 | num_heads = 4
47 | dim = 6
48 |
--------------------------------------------------------------------------------
/services/moshi-server/configs/tts-prod.toml:
--------------------------------------------------------------------------------
1 | static_dir = "./static/"
2 | log_dir = "/tmp/unmute_logs"
3 | instance_name = "tts"
4 | authorized_ids = ["public_token"]
5 |
6 | [modules.tts_py]
7 | type = "Py"
8 | path = "/api/tts_streaming"
9 | text_tokenizer_file = "hf://kyutai/tts-1.6b-en_fr/tokenizer_spm_8k_en_fr_audio.model"
10 | batch_size = 4 # NOTE: make this smaller if running on one GPU
11 | text_bos_token = 1
12 |
13 | [modules.tts_py.py]
14 | log_folder = "/tmp/unmute_logs"
15 | # We could replace **/*.safetensors with unmute-prod-website/*.safetensors
16 | # to only get the voices used in Unmute, but we are using the TTS for the demo
17 | # on the project page too and for that we want to load the other voices as well
18 | voice_folder = "hf-snapshot://kyutai/tts-voices/**/*.safetensors"
19 | default_voice = "unmute-prod-website/default_voice.wav"
20 | cfg_coef = 2.0
21 | cfg_is_no_text = true
22 | padding_between = 1
23 | n_q = 24
24 |
--------------------------------------------------------------------------------
/services/moshi-server/configs/tts.toml:
--------------------------------------------------------------------------------
1 | static_dir = "./static/"
2 | log_dir = "/tmp/unmute_logs"
3 | instance_name = "tts"
4 | authorized_ids = ["public_token"]
5 |
6 | [modules.tts_py]
7 | type = "Py"
8 | path = "/api/tts_streaming"
9 | text_tokenizer_file = "hf://kyutai/tts-1.6b-en_fr/tokenizer_spm_8k_en_fr_audio.model"
10 | # A higher batch size allows you to serve more users at once, but with a higher latency and memory usage.
11 | batch_size = 2
12 | text_bos_token = 1
13 |
14 | [modules.tts_py.py]
15 | log_folder = "/tmp/unmute_logs"
16 | # We could replace **/*.safetensors with unmute-prod-website/*.safetensors
17 | # to only get the voices used in Unmute, but we are using the TTS for the demo
18 | # on the project page too and for that we want to load the other voices as well
19 | voice_folder = "hf-snapshot://kyutai/tts-voices/**/*.safetensors"
20 | default_voice = "unmute-prod-website/default_voice.wav"
21 | cfg_coef = 2.0
22 | cfg_is_no_text = true
23 | padding_between = 1
24 | n_q = 24
25 |
--------------------------------------------------------------------------------
/services/moshi-server/configs/voice-cloning.toml:
--------------------------------------------------------------------------------
1 | static_dir = "./static/"
2 | log_dir = "$HOME/tmp/tts-logs"
3 | instance_name = "voice"
4 | authorized_ids = ["public_token"]
5 |
6 | [modules.voice_py]
7 | type = "PyPost"
8 | path = "/api/voice"
9 |
10 | [modules.voice_py.py]
11 | mimi_weight = "hf://kyutai/unmute-voice-cloning/e9d43d50_500_mimi_voice.safetensors"
12 |
--------------------------------------------------------------------------------
/services/moshi-server/private.Dockerfile:
--------------------------------------------------------------------------------
1 | # This is the Kyutai-internal version.
2 | FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS base
3 |
4 | # Set environment variables to avoid interactive prompts during package installation
5 | ENV DEBIAN_FRONTEND=noninteractive
6 |
7 | RUN apt-get update && apt-get install -y \
8 | curl \
9 | build-essential \
10 | ca-certificates \
11 | libssl-dev \
12 | git \
13 | pkg-config \
14 | cmake \
15 | openssh-client \
16 | --no-install-recommends && \
17 | rm -rf /var/lib/apt/lists/*
18 |
19 | RUN curl https://sh.rustup.rs -sSf | sh -s -- -y
20 | ENV PATH="/root/.cargo/bin:$PATH"
21 |
22 | COPY --from=ghcr.io/astral-sh/uv:0.7.2 /uv /uvx /bin/
23 |
24 | ARG GITHUB_ORG
25 | RUN --mount=type=ssh \
26 | mkdir -p ~/.ssh && \
27 | ssh-keyscan github.com >> ~/.ssh/known_hosts && \
28 | git clone git@github.com:${GITHUB_ORG}/moshi-rs.git /app \
29 | && cd /app \
30 | && git checkout 30fc5a90162ec32014672127f48da8ee4625d0e6
31 |
32 | WORKDIR /app
33 |
34 | # When starting the container for the first time, we need to compile and download
35 | # everything, so disregarding healthcheck failure for 10 minutes is fine.
36 | # We have a volume storing the build cache, so subsequent starts will be faster.
37 | HEALTHCHECK --start-period=10m \
38 | CMD curl --fail http://localhost:8080/api/build_info || exit 1
39 |
40 | EXPOSE 8080
41 | ENV RUST_BACKTRACE=1
42 |
43 | COPY . .
44 |
45 | ENTRYPOINT ["uv", "run", "--locked", "--project", "./moshi-server", "./start_moshi_server_private.sh"]
46 |
--------------------------------------------------------------------------------
/services/moshi-server/public.Dockerfile:
--------------------------------------------------------------------------------
1 | # This is the public-facing version.
2 | FROM nvidia/cuda:12.8.1-devel-ubuntu22.04 AS base
3 |
4 | # Set environment variables to avoid interactive prompts during package installation
5 | ENV DEBIAN_FRONTEND=noninteractive
6 |
7 | # Install dependencies, including dos2unix to handle Windows line endings
8 | RUN apt-get update && apt-get install -y \
9 | curl \
10 | build-essential \
11 | ca-certificates \
12 | libssl-dev \
13 | git \
14 | pkg-config \
15 | cmake \
16 | wget \
17 | openssh-client \
18 | dos2unix \
19 | --no-install-recommends && \
20 | rm -rf /var/lib/apt/lists/*
21 |
22 | RUN curl https://sh.rustup.rs -sSf | sh -s -- -y
23 | ENV PATH="/root/.cargo/bin:$PATH"
24 |
25 | COPY --from=ghcr.io/astral-sh/uv:0.7.2 /uv /uvx /bin/
26 |
27 | WORKDIR /app
28 |
29 | # When starting the container for the first time, we need to compile and download
30 | # everything, so disregarding healthcheck failure for 10 minutes is fine.
31 | # We have a volume storing the build cache, so subsequent starts will be faster.
32 | HEALTHCHECK --start-period=10m \
33 | CMD curl --fail http://localhost:8080/api/build_info || exit 1
34 |
35 | EXPOSE 8080
36 | ENV RUST_BACKTRACE=1
37 |
38 | RUN wget https://raw.githubusercontent.com/kyutai-labs/moshi/bf359af7694add34c13e65d2f009f0cb474d87cc/rust/moshi-server/pyproject.toml
39 | RUN wget https://raw.githubusercontent.com/kyutai-labs/moshi/bf359af7694add34c13e65d2f009f0cb474d87cc/rust/moshi-server/uv.lock
40 |
41 | COPY . .
42 |
43 | # Ensure the startup script is runnable inside the container.
44 | # This prevents script errors that can happen if the project was cloned on Windows,
45 | # which uses a different text file format (CRLF) than the Linux environment in the container (LF).
46 | RUN dos2unix ./start_moshi_server_public.sh && chmod +x ./start_moshi_server_public.sh
47 |
48 | ENTRYPOINT ["uv", "run", "--locked", "--project", "./moshi-server", "./start_moshi_server_public.sh"]
49 |
--------------------------------------------------------------------------------
/services/moshi-server/start_moshi_server_private.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # This is the Kyutai-internal version.
3 | set -ex
4 |
5 | export LD_LIBRARY_PATH=$(python3 -c 'import sysconfig; print(sysconfig.get_config_var("LIBDIR"))')
6 |
7 | uvx --from 'huggingface_hub[cli]' huggingface-cli login --token $HUGGING_FACE_HUB_TOKEN
8 |
9 | cargo run --features=cuda --bin=moshi-server -r -- $@
10 |
--------------------------------------------------------------------------------
/services/moshi-server/start_moshi_server_public.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # This is the public-facing version.
3 | set -ex
4 |
5 | export LD_LIBRARY_PATH=$(python3 -c 'import sysconfig; print(sysconfig.get_config_var("LIBDIR"))')
6 |
7 | uvx --from 'huggingface_hub[cli]' huggingface-cli login --token $HUGGING_FACE_HUB_TOKEN
8 |
9 | CARGO_TARGET_DIR=/app/target cargo install --features cuda moshi-server@0.6.4
10 |
11 | # Subtle detail here: We use the full path to `moshi-server` because there is a `moshi-server` binary
12 | # from the `moshi` Python package. We'll fix this conflict soon.
13 | /root/.cargo/bin/moshi-server $@
--------------------------------------------------------------------------------
/services/prometheus/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM prom/prometheus:v3.1.0
2 |
3 | COPY prometheus.yml /etc/prometheus/
4 |
--------------------------------------------------------------------------------
/services/prometheus/prometheus.yml:
--------------------------------------------------------------------------------
1 | scrape_configs:
2 | - job_name: 'dockerswarm'
3 | scrape_interval: 5s
4 | # Read the Docker Swarm API to discover the services and containers
5 | dockerswarm_sd_configs:
6 | - host: unix:///var/run/docker.sock
7 | role: tasks
8 | relabel_configs:
9 | # Keep only the tasks that are running
10 | - source_labels: [__meta_dockerswarm_task_desired_state]
11 | regex: running
12 | action: keep
13 | # Keep only the tasks that have the label prometheus-port
14 | - source_labels: [__meta_dockerswarm_service_label_prometheus_port]
15 | regex: .+
16 | action: keep
17 | # Rename the job to the service name (but remove the stack name)
18 | - source_labels: [__meta_dockerswarm_service_name]
19 | regex: .*_(.+)
20 | replacement: $1
21 | target_label: job
22 | # Set the ip and port where the /metrics are exposed
23 | - source_labels: [__address__, __meta_dockerswarm_service_label_prometheus_port]
24 | regex: ([^:]+):\d+;(\d+)
25 | replacement: $1:$2
26 | target_label: __address__
27 |
--------------------------------------------------------------------------------
/setup_gpu_swarm_node.py:
--------------------------------------------------------------------------------
1 | import json
2 | import subprocess
3 | import time
4 | from pathlib import Path
5 |
6 |
7 | def get_all_uuids() -> list[str]:
8 | return (
9 | subprocess.check_output(
10 | ["nvidia-smi", "--query-gpu=uuid", "--format=csv,noheader"],
11 | stderr=subprocess.STDOUT,
12 | text=True,
13 | )
14 | .strip()
15 | .splitlines()
16 | )
17 |
18 |
19 | def json_dumps(obj: dict) -> str:
20 | return json.dumps(obj, indent=4, sort_keys=True)
21 |
22 |
23 | def setup_docker_config():
24 | daemon_config_file = Path("/etc/docker/daemon.json")
25 | daemon_config = json.loads(daemon_config_file.read_text())
26 | print("Previous daemon config:\n", json_dumps(daemon_config))
27 |
28 | daemon_config["node-generic-resources"] = [
29 | "gpu=" + uuid for uuid in get_all_uuids()
30 | ]
31 |
32 | print("New daemon config:\n", json_dumps(daemon_config))
33 | daemon_config_file.write_text(json_dumps(daemon_config))
34 |
35 |
36 | def setup_nvidia_docker_config():
37 | nvidia_container_config_file = Path("/etc/nvidia-container-runtime/config.toml")
38 | nvidia_container_config = nvidia_container_config_file.read_text()
39 | line = '#swarm-resource = "DOCKER_RESOURCE_GPU"'
40 | if line in nvidia_container_config:
41 | print("uncommenting", line)
42 | nvidia_container_config = nvidia_container_config.replace(
43 | line, line.removeprefix("#")
44 | )
45 | print("New nvidia-container config:\n", nvidia_container_config)
46 | nvidia_container_config_file.write_text(nvidia_container_config)
47 | else:
48 | print(
49 | "Line not found in nvidia-container config, did you already uncomment it?"
50 | )
51 |
52 |
53 | def restarting_docker():
54 | print("Restarting docker")
55 | subprocess.check_call(["systemctl", "restart", "docker"])
56 | time.sleep(1)
57 |
58 |
59 | def checking_nvidia_docker():
60 | print("Checking nvidia-docker")
61 | subprocess.check_call(
62 | [
63 | "docker",
64 | "run",
65 | "--rm",
66 | "--runtime=nvidia",
67 | "--gpus",
68 | "all",
69 | "ubuntu",
70 | "nvidia-smi",
71 | ]
72 | )
73 |
74 |
75 | def main():
76 | setup_docker_config()
77 | setup_nvidia_docker_config()
78 | restarting_docker()
79 | checking_nvidia_docker()
80 |
81 |
82 | if __name__ == "__main__":
83 | main()
84 |
--------------------------------------------------------------------------------
/tests/test_exponential_moving_average.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from unmute.stt.exponential_moving_average import ExponentialMovingAverage
4 |
5 |
6 | def test_ema():
7 | ema = ExponentialMovingAverage(attack_time=0.1, release_time=0.5)
8 | ema.update(dt=1.0, new_value=0.0)
9 | assert ema.value == 0.0
10 |
11 | ema.update(dt=0.1, new_value=1.0)
12 | assert ema.value == pytest.approx(0.5)
13 |
14 | ema.update(dt=0.25, new_value=0.0)
15 | ema.update(dt=0.25, new_value=0.0)
16 | assert ema.value == pytest.approx(0.25)
17 |
18 | # Should work even with values different than 0 and 1
19 | ema.update(dt=0.1, new_value=0.75)
20 | assert ema.value == pytest.approx(0.5)
21 |
22 | ema.update(dt=1e9, new_value=1.0)
23 | assert ema.value == pytest.approx(1.0)
24 |
--------------------------------------------------------------------------------
/tests/test_llm_utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from unmute.llm.llm_utils import rechunk_to_words
4 |
5 |
6 | async def make_iterator(s: str):
7 | parts = s.split("|")
8 | for part in parts:
9 | yield part
10 |
11 |
12 | @pytest.mark.asyncio
13 | async def test_rechunk_to_words():
14 | test_strings = [
15 | "hel|lo| |w|orld",
16 | "hello world",
17 | "hello \nworld",
18 | "hello| |world",
19 | "hello| |world|.",
20 | "h|e|l|l|o| |\tw|o|r|l|d|.",
21 | "h|e|l|l|o\n| |w|o|r|l|d|.",
22 | ]
23 |
24 | for s in test_strings:
25 | parts = [x async for x in rechunk_to_words(make_iterator(s))]
26 | assert parts[0] == "hello"
27 | assert parts[1] == " world" or parts[1] == " world."
28 |
29 | async def f(s: str):
30 | x = [x async for x in rechunk_to_words(make_iterator(s))]
31 | print(x)
32 | return x
33 |
34 | assert await f("i am ok") == ["i", " am", " ok"]
35 | assert await f(" i am ok") == [" i", " am", " ok"]
36 | assert await f(" they are ok") == [" they", " are", " ok"]
37 | assert await f(" foo bar") == [" foo", " bar"]
38 | assert await f(" \t foo bar") == [" foo", " bar"]
39 |
--------------------------------------------------------------------------------
/unmute/audio_input_override.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import numpy as np
4 | import sphn
5 |
6 | from unmute.kyutai_constants import SAMPLE_RATE
7 |
8 |
9 | class AudioInputOverride:
10 | def __init__(self, file: Path):
11 | data, _sr = sphn.read(file, sample_rate=SAMPLE_RATE)
12 | assert data.ndim == 2
13 |
14 | if data.dtype != np.int16:
15 | data = (data * np.iinfo(np.int16).max).astype(np.int16)
16 |
17 | self.data = data
18 | self.position = 0
19 |
20 | def override(self, original_data: np.ndarray) -> np.ndarray:
21 | if self.position + original_data.shape[1] > self.data.shape[1]:
22 | return original_data
23 |
24 | data = self.data[
25 | :, self.position : self.position + original_data.shape[1]
26 | ].copy()
27 | self.position += original_data.shape[1]
28 |
29 | assert data.shape == original_data.shape, (
30 | f"{data.shape} != {original_data.shape}"
31 | )
32 | assert data.dtype == original_data.dtype
33 |
34 | return data
35 |
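
A hypothetical usage sketch: substituting a pre-recorded file for live microphone input inside an audio callback. The int16 dtype and (channels, samples) shape follow the assertions above; the callback name and file path are made up (the path is one of the loadtest voices):

    from pathlib import Path

    import numpy as np

    override = AudioInputOverride(Path("unmute/loadtest/voices/seine.mp3"))

    def on_audio_frame(frame: np.ndarray) -> np.ndarray:
        # frame: int16, shape (channels, samples); once the file runs out,
        # override() returns the original frame unchanged.
        return override.override(frame)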
--------------------------------------------------------------------------------
/unmute/audio_stream_saver.py:
--------------------------------------------------------------------------------
1 | from logging import getLogger
2 | from pathlib import Path
3 |
4 | import numpy as np
5 | import sphn
6 |
7 | from unmute.kyutai_constants import SAMPLE_RATE
8 |
9 | DEBUG_DIR = Path(__file__).parents[1] / "debug"
10 | logger = getLogger(__name__)
11 |
12 |
13 | class AudioStreamSaver:
14 | """Collect and save an audio stream. For debugging"""
15 |
16 | def __init__(
17 | self,
18 | interval_sec: float = 1.0,
19 | output_path: str | Path | None = None,
20 | max_saves: int | None = 1,
21 | ):
22 | self.interval_sec = interval_sec
23 | self.max_saves = max_saves
24 | self.n_saves_done = 0
25 |
26 | if output_path is None:
27 | self.output_path = DEBUG_DIR / "out.wav"
28 | else:
29 | self.output_path = Path(output_path)
30 |
31 | self.buffer = []
32 |
33 | def add(self, audio_chunk: np.ndarray):
34 | """Add a chunk of audio. Save if we've collected enough."""
35 | if self.max_saves is not None and self.n_saves_done >= self.max_saves:
36 | return
37 |
38 | assert audio_chunk.dtype == np.float32
39 | assert audio_chunk.ndim == 1
40 |
41 | self.buffer.append(audio_chunk)
42 |
43 | if sum(len(x) for x in self.buffer) / SAMPLE_RATE >= self.interval_sec:
44 | output_path = self.output_path
45 | if self.max_saves != 1: # None is ok too
46 | output_path = output_path.with_stem(
47 | output_path.stem + f"_{self.n_saves_done + 1}"
48 | )
49 |
50 | sphn.write_wav(
51 | output_path,
52 | np.concatenate(self.buffer).astype(np.float32),
53 | SAMPLE_RATE,
54 | )
55 | self.n_saves_done += 1
56 | self.buffer.clear()
57 | logger.info(f"Saved audio to {output_path}")
58 |
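
A hypothetical usage sketch, writing roughly one second of silence to a temporary file (the path and chunk count are made up; by default the saver writes to the repository's debug/ directory and stops after max_saves files):

    import numpy as np

    saver = AudioStreamSaver(interval_sec=1.0, output_path="/tmp/unmute_debug.wav")
    for _ in range(13):  # 13 * 1920 samples ≈ 1.04 s at 24 kHz
        saver.add(np.zeros(1920, dtype=np.float32))
    # The 13th add() crosses the interval, so the buffer is written out and cleared.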
--------------------------------------------------------------------------------
/unmute/cache.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 | from typing import Generic, Optional, TypeVar, cast
4 |
5 | import redis
6 | from redis.typing import EncodableT # Import EncodableT for Redis compatibility
7 |
8 | from unmute.kyutai_constants import REDIS_SERVER
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 | T = TypeVar("T", bound=EncodableT) # Generic type bound to EncodableT
13 |
14 |
15 | class CacheError(Exception):
16 | """An error happened while accessing the cache.
17 |
18 | This is so that we get the same exception type regardless of the cache implementation.
19 | """
20 |
21 |
22 | class LocalCache(Generic[T]):
23 | def __init__(self, ttl_seconds: int = 3600): # Default 1 hour expiration
24 | self.cache: dict[
25 | str, tuple[T, float]
26 | ] = {} # {key: (value, expiration_timestamp)}
27 | self.ttl_seconds = ttl_seconds
28 |
29 | def get(self, key: str) -> Optional[T]:
30 | cached = self.cache.get(key)
31 | if cached is not None:
32 | value, expiration = cached
33 | if time.time() < expiration:
34 | return value
35 | else:
36 | # Remove expired entry
37 | del self.cache[key]
38 | else:
39 | return None
40 |
41 | def set(self, key: str, value: T):
42 | expiration = time.time() + self.ttl_seconds
43 | self.cache[key] = (value, expiration)
44 |
45 | def delete(self, key: str):
46 | """Delete a key from the cache."""
47 | if key in self.cache:
48 | del self.cache[key]
49 |
50 | def cleanup(self):
51 | """Remove all expired entries"""
52 | now = time.time()
53 | expired_keys = [k for k, (_, exp) in self.cache.items() if exp < now]
54 | for k in expired_keys:
55 | del self.cache[k]
56 |
57 |
58 | class RedisCache(Generic[T]):
59 | def __init__(self, redis_url: str, prefix: str, ttl_seconds: int = 3600):
60 | self.ttl_seconds = ttl_seconds
61 | self.prefix = prefix
62 | self.redis_client = redis.Redis.from_url(redis_url, socket_connect_timeout=2)
63 |
64 | def get(self, key: str) -> Optional[T]:
65 | key = f"{self.prefix}:{key}"
66 |
67 | try:
68 | redis_value = self.redis_client.get(key)
69 | if redis_value is not None:
70 | logger.info(f"Retrieved value from Redis: {key}")
71 | return cast(T, redis_value)
72 | else:
73 | return None
74 | except redis.RedisError as e:
75 | raise CacheError(f"Failed to retrieve value from Redis: {e}") from e
76 |
77 | def set(self, key: str, value: T):
78 | key = f"{self.prefix}:{key}"
79 | try:
80 | # Store with the TTL
81 | self.redis_client.setex(key, self.ttl_seconds, value)
82 | except redis.RedisError as e:
83 | raise CacheError(f"Failed to store value in Redis: {e}") from e
84 |
85 | def delete(self, key: str):
86 | key = f"{self.prefix}:{key}"
87 | try:
88 | # No error if the key does not exist.
89 | self.redis_client.delete(key)
90 | except redis.RedisError as e:
91 | raise CacheError(f"Failed to delete value from Redis: {e}") from e
92 |
93 | def cleanup(self):
94 | pass # No cleanup needed for Redis
95 |
96 |
97 | def get_cache(prefix: str, ttl_seconds: int) -> LocalCache[T] | RedisCache[T]:
98 | """
99 | Returns the appropriate cache based on the environment variables.
100 | If KYUTAI_REDIS_URL is set, it returns a RedisCache instance.
101 | If not, it returns a LocalCache instance.
102 | """
103 | if REDIS_SERVER is not None:
104 | cache = RedisCache[T](REDIS_SERVER, prefix, ttl_seconds=ttl_seconds)
105 | else:
106 | logger.info(
107 | "Redis cache address was not given in environment variables, using local cache."
108 | )
109 | cache = LocalCache[T](ttl_seconds=ttl_seconds)
110 |
111 | return cache
112 |
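
A hypothetical caller, mirroring the pattern used by newsapi.py below (the prefix, key, and value are made up; note that RedisCache.get returns raw bytes, while LocalCache returns whatever object was stored):

    cache = get_cache("example", ttl_seconds=600)

    def get_value() -> bytes | None:
        try:
            cached = cache.get("some-key")
        except CacheError:
            # Cache is down; give up rather than hit the real source on every request.
            return None
        if cached is None:
            cached = b"freshly computed value"
            cache.set("some-key", cached)
        return cached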
--------------------------------------------------------------------------------
/unmute/exceptions.py:
--------------------------------------------------------------------------------
1 | import unmute.openai_realtime_api_events as ora
2 |
3 |
4 | class MissingServiceAtCapacity(Exception):
5 | """A service is operating at capacity, but no serious error."""
6 |
7 | def __init__(self, service: str):
8 | self.service = service
9 | super().__init__(f"{service} is not available.")
10 |
11 |
12 | class MissingServiceTimeout(Exception):
13 | """A service timed out."""
14 |
15 | def __init__(self, service: str):
16 | self.service = service
17 | super().__init__(f"{service} timed out.")
18 |
19 |
20 | class WebSocketClosedError(Exception):
21 | """Remote web socket is closed, let's move on."""
22 |
23 |
24 | def make_ora_error(type: str, message: str) -> ora.Error:
25 | details = ora.ErrorDetails(type=type, message=message)
26 | return ora.Error(error=details)
27 |
--------------------------------------------------------------------------------
/unmute/kyutai_constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | from unmute.websocket_utils import http_to_ws
5 |
6 | HEADERS = {"kyutai-api-key": "public_token"}
7 |
8 | # The defaults are already ws://, but the env vars may also be given as http:// or https://
9 | STT_SERVER = http_to_ws(os.environ.get("KYUTAI_STT_URL", "ws://localhost:8090"))
10 | TTS_SERVER = http_to_ws(os.environ.get("KYUTAI_TTS_URL", "ws://localhost:8089"))
11 | LLM_SERVER = os.environ.get("KYUTAI_LLM_URL", "http://localhost:8091")
12 | KYUTAI_LLM_MODEL = os.environ.get("KYUTAI_LLM_MODEL")
13 | KYUTAI_LLM_API_KEY = os.environ.get("KYUTAI_LLM_API_KEY")
14 | VOICE_CLONING_SERVER = os.environ.get(
15 | "KYUTAI_VOICE_CLONING_URL", "http://localhost:8092"
16 | )
17 | # If None, a dict-based cache will be used instead of Redis
18 | REDIS_SERVER = os.environ.get("KYUTAI_REDIS_URL")
19 |
20 | SPEECH_TO_TEXT_PATH = "/api/asr-streaming"
21 | TEXT_TO_SPEECH_PATH = "/api/tts_streaming"
22 |
23 | repo_root = Path(__file__).parents[1]
24 | VOICE_DONATION_DIR = Path(
25 | os.environ.get("KYUTAI_VOICE_DONATION_DIR", repo_root / "voices" / "donation")
26 | )
27 |
28 | # If None, recordings will not be saved
29 | _recordings_dir = os.environ.get("KYUTAI_RECORDINGS_DIR")
30 | RECORDINGS_DIR = Path(_recordings_dir) if _recordings_dir else None
31 |
32 | # Also checked on the frontend, see constant of the same name
33 | MAX_VOICE_FILE_SIZE_MB = 4
34 |
35 |
36 | SAMPLE_RATE = 24000
37 | SAMPLES_PER_FRAME = 1920
38 | FRAME_TIME_SEC = SAMPLES_PER_FRAME / SAMPLE_RATE # 0.08
39 | # TODO: make it so that we can read this from the ASR server?
40 | STT_DELAY_SEC = 0.5
41 |
--------------------------------------------------------------------------------
/unmute/llm/chatbot.py:
--------------------------------------------------------------------------------
1 | from logging import getLogger
2 | from typing import Any, Literal
3 |
4 | from unmute.llm.llm_utils import preprocess_messages_for_llm
5 | from unmute.llm.system_prompt import ConstantInstructions, Instructions
6 |
7 | ConversationState = Literal["waiting_for_user", "user_speaking", "bot_speaking"]
8 |
9 | logger = getLogger(__name__)
10 |
11 |
12 | class Chatbot:
13 | def __init__(self):
14 | # The type is really list[ChatCompletionStreamRequestMessagesTypedDict], but with
15 | # that annotation it's hard to convince the type checker you're passing the right type.
16 | self.chat_history: list[dict[Any, Any]] = [
17 | {"role": "system", "content": ConstantInstructions().make_system_prompt()}
18 | ]
19 | self._instructions: Instructions | None = None
20 |
21 | def conversation_state(self) -> ConversationState:
22 | if not self.chat_history:
23 | return "waiting_for_user"
24 |
25 | last_message = self.chat_history[-1]
26 | if last_message["role"] == "assistant":
27 | return "bot_speaking"
28 | elif last_message["role"] == "user":
29 | if last_message["content"].strip() != "":
30 | return "user_speaking"
31 | else:
32 | # Or do we want "user_speaking" here?
33 | return "waiting_for_user"
34 | elif last_message["role"] == "system":
35 | return "waiting_for_user"
36 | else:
37 | raise RuntimeError(f"Unknown role: {last_message['role']}")
38 |
39 | async def add_chat_message_delta(
40 | self,
41 | delta: str,
42 | role: Literal["user", "assistant"],
43 | generating_message_i: int | None = None, # Avoid race conditions
44 | ) -> bool:
45 | """Add a partial message to the chat history, adding spaces if necessary.
46 |
47 | Returns:
48 | True if the message is a new message, False if it is a continuation of
49 | the last message.
50 | """
51 | if (
52 | generating_message_i is not None
53 | and len(self.chat_history) > generating_message_i
54 | ):
55 | logger.warning(
56 | f"Tried to add {delta=} {role=} "
57 | f"but {generating_message_i=} didn't match"
58 | )
59 | return False
60 |
61 | if not self.chat_history or self.chat_history[-1]["role"] != role:
62 | self.chat_history.append({"role": role, "content": delta})
63 | return True
64 | else:
65 | last_message: str = self.chat_history[-1]["content"]
66 |
67 | # Add a space if necessary
68 | needs_space_left = last_message != "" and not last_message[-1].isspace()
69 | needs_space_right = delta != "" and not delta[0].isspace()
70 |
71 | if needs_space_left and needs_space_right:
72 | delta = " " + delta
73 |
74 | self.chat_history[-1]["content"] += delta
75 | return last_message == "" # new message if `last_message` was empty
76 |
77 | def preprocessed_messages(self):
78 | if len(self.chat_history) > 2:
79 | messages = self.chat_history
80 | else:
81 | assert len(self.chat_history) >= 1
82 | assert self.chat_history[0]["role"] == "system"
83 |
84 | messages = [
85 | self.chat_history[0],
86 | # Some models, like Gemma, don't like it when there is no user message
87 | # so we add one.
88 | {"role": "user", "content": "Hello!"},
89 | ]
90 |
91 | messages = preprocess_messages_for_llm(messages)
92 | return messages
93 |
94 | def set_instructions(self, instructions: Instructions):
95 | # Note that make_system_prompt() might not be deterministic, so we run it only
96 | # once and save the result. We still keep self._instructions because it's used
97 | # to check whether initial instructions have been set.
98 | self._update_system_prompt(instructions.make_system_prompt())
99 | self._instructions = instructions
100 |
101 | def _update_system_prompt(self, system_prompt: str):
102 | self.chat_history[0] = {"role": "system", "content": system_prompt}
103 |
104 | def get_system_prompt(self) -> str:
105 | assert len(self.chat_history) > 0
106 | assert self.chat_history[0]["role"] == "system"
107 | return self.chat_history[0]["content"]
108 |
109 | def get_instructions(self) -> Instructions | None:
110 | return self._instructions
111 |
112 | def last_message(self, role: str) -> str | None:
113 | valid_messages = [
114 | message
115 | for message in self.chat_history
116 | if message["role"] == role and message["content"].strip() != ""
117 | ]
118 | if valid_messages:
119 | return valid_messages[-1]["content"]
120 | else:
121 | return None
122 |
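
A minimal driver sketch showing how the delta/space logic above behaves (hypothetical code, assuming unmute.llm and its dependencies are importable):

    import asyncio

    async def demo() -> None:
        bot = Chatbot()
        # Streamed STT deltas for the user; a space is inserted between deltas as needed.
        print(await bot.add_chat_message_delta("hello", role="user"))  # True: new message
        print(await bot.add_chat_message_delta("there", role="user"))  # False: continuation
        print(bot.last_message("user"))                                # "hello there"
        print(bot.conversation_state())                                # "user_speaking"

    asyncio.run(demo())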
--------------------------------------------------------------------------------
/unmute/llm/newsapi.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import requests
5 | from pydantic import BaseModel
6 |
7 | from unmute.cache import CacheError, get_cache
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 | newsapi_api_key = os.environ.get("NEWSAPI_API_KEY")
12 |
13 |
14 | class Source(BaseModel):
15 | id: str | None
16 | name: str
17 |
18 |
19 | class Article(BaseModel):
20 | source: Source
21 | author: str | None
22 | title: str
23 | description: str | None
24 | # Omit the URLs because we don't need them and they take up space
25 | # url: HttpUrl
26 | # urlToImage: HttpUrl | None
27 | publishedAt: str
28 | content: str | None
29 |
30 |
31 | class NewsResponse(BaseModel):
32 | status: str
33 | totalResults: int
34 | articles: list[Article]
35 |
36 |
37 | if not newsapi_api_key:
38 | logger.warning(
39 | "NEWSAPI_API_KEY is not set. News API functionality will be disabled."
40 | )
41 |
42 |
43 | cache = get_cache("newsapi", ttl_seconds=60 * 60 * 4) # 4 hours
44 | CACHE_KEY = "news"
45 |
46 |
47 | def get_news_without_caching() -> NewsResponse | None:
48 | if not newsapi_api_key:
49 | return None
50 |
51 | logger.info("Fetching news from News API")
52 | response = requests.get(
53 | "https://newsapi.org/v2/everything?sources=the-verge",
54 | headers={"Authorization": newsapi_api_key},
55 | )
56 | response.raise_for_status()
57 | news_response = NewsResponse(**response.json())
58 |
59 | return news_response
60 |
61 |
62 | def get_news() -> NewsResponse | None:
63 | try:
64 | cached_news_raw = cache.get(CACHE_KEY)
65 | except CacheError as e:
66 | logger.error(f"Failed to fetch news from cache: {e}")
67 | # Don't fall back to the API directly: if the cache is down, we'd hit the API on every request
68 | return None
69 |
70 | cached_news = (
71 | NewsResponse.model_validate_json(cached_news_raw) if cached_news_raw else None
72 | )
73 |
74 | if cached_news is None:
75 | try:
76 | cached_news = get_news_without_caching()
77 | if cached_news:
78 | cache.set(CACHE_KEY, cached_news.model_dump_json())
79 |
80 | except Exception as e:
81 | logger.error(f"Failed to fetch news: {e}")
82 | return None
83 |
84 | return cached_news
85 |
86 |
87 | if __name__ == "__main__":
88 | news = get_news()
89 | if news:
90 | print(news.model_dump_json(indent=2))
91 |
--------------------------------------------------------------------------------
/unmute/loadtest/dummy_tts_server.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | import random
4 |
5 | import msgpack
6 | import numpy as np
7 | from fastapi import FastAPI, WebSocket, WebSocketDisconnect
8 |
9 | from unmute.kyutai_constants import SAMPLE_RATE, SAMPLES_PER_FRAME
10 |
11 | TEXT_TO_SPEECH_PATH = "/api/tts_streaming"
12 |
13 | app = FastAPI()
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | def generate_sine_wave(
19 | duration_s: float, frequency: float = 440.0
20 | ) -> list[list[float]]:
21 | """Generate a sine wave with the given duration and frequency.
22 | Returns a list of chunks of exactly SAMPLES_PER_FRAME samples each;
23 | the last chunk is zero-padded if necessary.
24 | """
25 | num_samples = int(duration_s * SAMPLE_RATE)
26 | t = np.linspace(0, duration_s, num_samples, endpoint=False)
27 |
28 | # Generate sine wave
29 | sine_wave = 0.5 * np.sin(2 * np.pi * frequency * t)
30 |
31 | # Apply envelope for smooth start and end
32 | envelope = np.ones_like(sine_wave)
33 | fade_samples = min(
34 | int(0.05 * SAMPLE_RATE), num_samples // 4
35 | ) # 50ms fade or 1/4 of sound
36 | if fade_samples > 0 and num_samples > 2 * fade_samples:
37 | envelope[:fade_samples] = np.linspace(0, 1, fade_samples)
38 | envelope[-fade_samples:] = np.linspace(1, 0, fade_samples)
39 |
40 | amplitude = 0.3
41 | envelope = amplitude * envelope
42 |
43 | # Apply envelope to sine wave
44 | audio_data = sine_wave * envelope
45 |
46 | # Split into chunks of SAMPLES_PER_FRAME samples
47 | chunks = []
48 | for i in range(0, len(audio_data), SAMPLES_PER_FRAME):
49 | chunk = audio_data[i : i + SAMPLES_PER_FRAME]
50 |
51 | # If we have a partial chunk at the end, pad it with zeros
52 | if len(chunk) < SAMPLES_PER_FRAME:
53 | padding = np.zeros(SAMPLES_PER_FRAME - len(chunk))
54 | chunk = np.concatenate([chunk, padding])
55 |
56 | chunks.append(chunk.tolist())
57 |
58 | return chunks
59 |
60 |
61 | @app.get("/api/build_info")
62 | def get_build_info():
63 | return {"note": "this is a dummy build info"}
64 |
65 |
66 | @app.websocket(TEXT_TO_SPEECH_PATH)
67 | async def websocket_endpoint(websocket: WebSocket):
68 | await websocket.accept()
69 |
70 | try:
71 | current_time = 0.0
72 |
73 | while True:
74 | try:
75 | message = await asyncio.wait_for(websocket.receive(), timeout=1.0)
76 | logger.info(f"Received message type: {type(message)}")
77 | except asyncio.TimeoutError:
78 | # This prevents the loop from completely blocking signals
79 | await asyncio.sleep(0.01)
80 | continue
81 |
82 | # message = await websocket.receive()
83 | logger.info(message)
84 |
85 | if "text" in message:
86 | text = message["text"]
87 | else:
88 | if message["bytes"] == b"\0":
89 | break
90 | else:
91 | raise ValueError(f"Invalid message: {message}")
92 |
93 | if not text.strip():
94 | continue
95 |
96 | words = text.strip().split()
97 |
98 | frame_length = SAMPLES_PER_FRAME / SAMPLE_RATE
99 |
100 | for word in words:
101 | # Sounds more fun if the lengths are uneven
102 | word_duration = frame_length * len(word)
103 |
104 | start_time = current_time
105 | stop_time = current_time + word_duration
106 |
107 | # Send text message with timing information
108 | text_message = {
109 | "type": "Text",
110 | "text": word,
111 | "start_s": start_time,
112 | "stop_s": stop_time,
113 | }
114 | await websocket.send_bytes(msgpack.packb(text_message))
115 |
116 | # Generate audio (sine wave) for this word, split into fixed-size chunks
117 | note = random.randint(0, 12)
118 | frequency = 440 * (2 ** (note / 12))
119 | audio_chunks = generate_sine_wave(word_duration, frequency=frequency)
120 |
121 | # Calculate time for each chunk (for consistent pacing)
122 | chunk_duration = SAMPLES_PER_FRAME / SAMPLE_RATE
123 | chunk_count = len(audio_chunks)
124 |
125 | # Send each audio chunk with proper timing
126 | for chunk_idx, pcm_data in enumerate(audio_chunks):
127 | audio_message = {"type": "Audio", "pcm": pcm_data}
128 | await websocket.send_bytes(msgpack.packb(audio_message))
129 |
130 | # Only sleep between chunks (not after the last chunk)
131 | if chunk_idx < chunk_count - 1:
132 | await asyncio.sleep(chunk_duration)
133 |
134 | # Calculate the remaining time to wait so the word takes word_duration in real time
135 | # We've already waited (chunk_count-1) * chunk_duration seconds
136 | remaining_wait = word_duration - (chunk_count - 1) * chunk_duration
137 | if remaining_wait > 0:
138 | await asyncio.sleep(remaining_wait)
139 |
140 | current_time += word_duration
141 |
142 | except WebSocketDisconnect:
143 | print("Client disconnected")
144 |
145 | await websocket.close()
146 |
147 |
148 | if __name__ == "__main__":
149 | import sys
150 |
151 | print(f"Run this via:\nfastapi dev {sys.argv[0]}")
152 | sys.exit(1)
153 |
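
A hypothetical probe client for exercising the dummy server manually, assuming it was started with `fastapi dev` on the default port 8000. The framing follows the handler above: plain text frames in, msgpack "Text"/"Audio" messages out, and a single b"\0" byte to end the input:

    import asyncio

    import msgpack
    import websockets

    async def probe() -> None:
        async with websockets.connect("ws://localhost:8000/api/tts_streaming") as ws:
            await ws.send("hello world")  # words to "synthesize"
            await ws.send(b"\0")          # end-of-input marker
            try:
                while True:
                    msg = msgpack.unpackb(await asyncio.wait_for(ws.recv(), timeout=2.0))
                    print(msg["type"], msg.get("text", ""))
            except (asyncio.TimeoutError, websockets.ConnectionClosed):
                pass

    asyncio.run(probe())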
--------------------------------------------------------------------------------
/unmute/loadtest/generate_dataset_for_vllm.py:
--------------------------------------------------------------------------------
1 | """Generate data for benchmarking with vLLM's benchmark_serving.py.
2 |
3 | See:
4 | https://github.com/vllm-project/vllm/tree/main/benchmarks
5 | """
6 |
7 | import json
8 | import random
9 |
10 | from unmute.tts.voices import VoiceList
11 |
12 |
13 | def random_id():
14 | return "".join(random.choices("1234567890", k=8))
15 |
16 |
17 | METATEMPLATE = "user{system_prompt}\n\n\nHello!\nmodel\n"
18 |
19 |
20 | def main():
21 | voice_list = VoiceList()
22 | possible_instructions = [
23 | v.instructions for v in voice_list.voices if v.instructions is not None
24 | ]
25 |
26 | prompts = []
27 |
28 | for _ in range(10000):
29 | instructions = random.choice(possible_instructions)
30 |
31 | # This will lead to some amount of kv-caching because the system prompts have
32 | # common prefixes. But the dynamic parts of the prompts vary, which breaks the
33 | # cache at the point where they differ; this should lead to
34 | # a realistic load.
35 | system_prompt = instructions.make_system_prompt()
36 | full_prompt = METATEMPLATE.format(system_prompt=system_prompt)
37 | prompts.append(full_prompt)
38 |
39 | s = json.dumps(
40 | [
41 | {
42 | "id": random_id(),
43 | "conversations": [
44 | {
45 | "from": "human",
46 | "value": full_prompt,
47 | },
48 | # The vLLM benchmark script looks at the length of the response to
49 | # know how many tokens to generate. This seems like a reasonable
50 | # length for a response.
51 | {
52 | "from": "gpt",
53 | "value": "Here are the main ideas of Jeff Walker's Product Launch Formula that can be applied by a growth marketing agency for their clients.",
54 | },
55 | ],
56 | }
57 | for full_prompt in prompts
58 | ],
59 | indent=2,
60 | )
61 | print(s)
62 |
63 |
64 | if __name__ == "__main__":
65 | main()
66 |
--------------------------------------------------------------------------------
/unmute/loadtest/loadtest_result.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | import numpy as np
4 | from pydantic import BaseModel, model_validator
5 |
6 |
7 | class UserMessageTiming(BaseModel):
8 | audio_start: float
9 | text_start: float
10 | audio_end: float
11 |
12 | @model_validator(mode="after")
13 | def validate_timing(self):
14 | # Note that text_start and audio_end can be in either order
15 | if not (self.audio_start < self.text_start) or not (
16 | self.audio_start < self.audio_end
17 | ):
18 | raise ValueError(f"Invalid timing: {self}")
19 | return self
20 |
21 |
22 | class AssistantMessageTiming(BaseModel):
23 | response_created: float
24 | text_start: float
25 | audio_start: float
26 | audio_end: float
27 | received_audio_length: float
28 |
29 | @model_validator(mode="after")
30 | def validate_timing(self):
31 | if not (self.response_created < self.audio_start < self.audio_end):
32 | raise ValueError(f"Invalid timing: {self}")
33 | return self
34 |
35 |
36 | class BenchmarkUserMessage(BaseModel):
37 | role: Literal["user"] = "user"
38 | content: str
39 | timing: UserMessageTiming
40 |
41 |
42 | class BenchmarkAssistantMessage(BaseModel):
43 | role: Literal["assistant"] = "assistant"
44 | content: str
45 | timing: AssistantMessageTiming
46 |
47 |
48 | BenchmarkMessage = BenchmarkUserMessage | BenchmarkAssistantMessage
49 |
50 |
51 | class LatencyReport(BaseModel):
52 | stt_latencies: list[float]
53 | vad_latencies: list[float]
54 | llm_latencies: list[float]
55 | tts_start_latencies: list[float]
56 | tts_realtime_factors: list[float]
57 |
58 | def compress(self):
59 | return LatencyReport(
60 | stt_latencies=[float(np.mean(self.stt_latencies))],
61 | vad_latencies=[float(np.mean(self.vad_latencies))],
62 | llm_latencies=[float(np.mean(self.llm_latencies))],
63 | tts_start_latencies=[float(np.mean(self.tts_start_latencies))],
64 | tts_realtime_factors=[float(np.mean(self.tts_realtime_factors))],
65 | )
66 |
67 |
68 | def combine_latency_reports(reports: list[LatencyReport]) -> LatencyReport:
69 | return LatencyReport(
70 | stt_latencies=[lat for r in reports for lat in r.stt_latencies],
71 | vad_latencies=[lat for r in reports for lat in r.vad_latencies],
72 | llm_latencies=[lat for r in reports for lat in r.llm_latencies],
73 | tts_start_latencies=[lat for r in reports for lat in r.tts_start_latencies],
74 | tts_realtime_factors=[
75 | factor for r in reports for factor in r.tts_realtime_factors
76 | ],
77 | )
78 |
79 |
80 | def make_latency_report(
81 | benchmark_chat_history: list[BenchmarkMessage],
82 | ) -> LatencyReport:
83 | stt_latencies = []
84 | vad_latencies = []
85 | llm_latencies = []
86 | tts_start_latencies = []
87 | tts_realtime_factors = []
88 |
89 | for i in range(len(benchmark_chat_history)):
90 | m = benchmark_chat_history[i]
91 |
92 | if isinstance(m, BenchmarkAssistantMessage):
93 | realtime_factor = m.timing.received_audio_length / (
94 | m.timing.audio_end - m.timing.audio_start
95 | )
96 | tts_realtime_factors.append(realtime_factor)
97 | llm_latencies.append(m.timing.text_start - m.timing.response_created)
98 | tts_start_latencies.append(m.timing.audio_start - m.timing.text_start)
99 |
100 | if i > 0:
101 | vad_latency = (
102 | m.timing.response_created
103 | - benchmark_chat_history[i - 1].timing.audio_end
104 | )
105 | vad_latencies.append(vad_latency)
106 | elif isinstance(m, BenchmarkUserMessage): # type: ignore
107 | stt_latency = m.timing.text_start - m.timing.audio_start
108 | stt_latencies.append(stt_latency)
109 |
110 | return LatencyReport(
111 | stt_latencies=stt_latencies,
112 | vad_latencies=vad_latencies,
113 | llm_latencies=llm_latencies,
114 | tts_start_latencies=tts_start_latencies,
115 | tts_realtime_factors=tts_realtime_factors,
116 | )
117 |
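
A worked example of the latency definitions above, with one user turn followed by one assistant turn (all timestamps are made up):

    history: list[BenchmarkMessage] = [
        BenchmarkUserMessage(
            content="hi",
            timing=UserMessageTiming(audio_start=0.0, text_start=0.4, audio_end=1.0),
        ),
        BenchmarkAssistantMessage(
            content="hello!",
            timing=AssistantMessageTiming(
                response_created=1.2,
                text_start=1.5,
                audio_start=1.7,
                audio_end=3.7,
                received_audio_length=4.0,
            ),
        ),
    ]
    report = make_latency_report(history)
    # report.stt_latencies        ≈ [0.4]  (text_start - audio_start of the user turn)
    # report.vad_latencies        ≈ [0.2]  (response_created - previous audio_end)
    # report.llm_latencies        ≈ [0.3]  (text_start - response_created)
    # report.tts_start_latencies  ≈ [0.2]  (audio_start - text_start)
    # report.tts_realtime_factors ≈ [2.0]  (4.0 s of audio received over 2.0 s wall time)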
--------------------------------------------------------------------------------
/unmute/loadtest/voices/Bear-or-shark-trim.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/unmute/d85fbf05a90d2f019a10bf521c1c029d3208b280/unmute/loadtest/voices/Bear-or-shark-trim.mp3
--------------------------------------------------------------------------------
/unmute/loadtest/voices/dog-or-cat-3-nowait.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/unmute/d85fbf05a90d2f019a10bf521c1c029d3208b280/unmute/loadtest/voices/dog-or-cat-3-nowait.mp3
--------------------------------------------------------------------------------
/unmute/loadtest/voices/seine.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/unmute/d85fbf05a90d2f019a10bf521c1c029d3208b280/unmute/loadtest/voices/seine.mp3
--------------------------------------------------------------------------------
/unmute/loadtest/voices/vaclav_english_news_trim.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/unmute/d85fbf05a90d2f019a10bf521c1c029d3208b280/unmute/loadtest/voices/vaclav_english_news_trim.mp3
--------------------------------------------------------------------------------
/unmute/main_gradio.py:
--------------------------------------------------------------------------------
1 | import pprint
2 | from typing import Any
3 |
4 | import gradio as gr
5 | import pandas as pd
6 | import plotly.express as px
7 | from fastrtc import Stream, get_hf_turn_credentials
8 |
9 | from unmute.unmute_handler import GradioUpdate, UnmuteHandler
10 |
11 | if __name__ == "__main__":
12 | gradio_chatbot = gr.Chatbot(type="messages")
13 | gradio_debug_textbox = gr.Textbox(label="Debug dict")
14 | gradio_debug_plot = gr.Plot(label="Debug plot")
15 |
16 | def update_outputs(
17 | _chatbot_state: Any,
18 | _debug_textbox_state: Any,
19 | _debug_plot_state: Any,
20 | update: GradioUpdate,
21 | ):
22 | # Not sure if this is expected behavior, but it seems necessary to send updates
23 | # to all of the components even if you don't want to change them. Otherwise they
24 | # get overwritten.
25 | chatbot_state = update.chat_history
26 | debug_textbox_state = pprint.pformat(update.debug_dict)
27 |
28 | debug_plot_data_variables = set().union(
29 | *[x.keys() for x in update.debug_plot_data],
30 | ) - {"t"}
31 |
32 | if debug_plot_data_variables:
33 | df = pd.DataFrame(update.debug_plot_data)
34 | df = df.ffill()
35 |
36 | fig = px.line(
37 | df,
38 | x="t",
39 | y=sorted(list(debug_plot_data_variables)),
40 | )
41 | else:
42 | fig = None
43 |
44 | return chatbot_state, debug_textbox_state, fig
45 |
46 | rtc_configuration = get_hf_turn_credentials()
47 | # rtc_configuration = get_cloudflare_rtc_configuration()
48 |
49 | stream = Stream(
50 | handler=UnmuteHandler(),
51 | modality="audio",
52 | mode="send-receive",
53 | # rtc_configuration=rtc_configuration,
54 | rtc_configuration=rtc_configuration,
55 | # additional_inputs=[gradio_chatbot],
56 | additional_outputs=[gradio_chatbot, gradio_debug_textbox, gradio_debug_plot],
57 | additional_outputs_handler=update_outputs,
58 | # TODO: check if clients actually get disconnected
59 | concurrency_limit=1,
60 | )
61 |
62 | # This variable needs to contain the Gradio UI for the autoreload to work:
63 | # https://www.gradio.app/guides/developing-faster-with-reload-mode
64 | demo = stream.ui
65 |
66 | # Not clear what `debug` does. It's not auto-reload.
67 | demo.launch(debug=False)
68 |
--------------------------------------------------------------------------------
/unmute/metrics.py:
--------------------------------------------------------------------------------
1 | from prometheus_client import Counter, Gauge, Histogram, Summary
2 |
3 | SESSION_DURATION_BINS = [1.0, 10.0, 30.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0]
4 | TURN_DURATION_BINS = [0.5, 1.0, 5.0, 10.0, 20.0, 40.0, 60.0]
5 | GENERATION_DURATION_BINS = [0.1, 0.25, 0.5, 1.0, 2.0, 5.0, 10.0, 20.0]
6 |
7 | PING_BINS_MS = [1.0, 5.0, 10.0, 25.0, 50.0, 100.0, 200.0]
8 | PING_BINS = [x / 1000 for x in PING_BINS_MS]
9 |
10 | # Time to first token.
11 | TTFT_BINS_STT_MS = [
12 | 10.0,
13 | 15.0,
14 | 25.0,
15 | 50.0,
16 | 75.0,
17 | 100.0,
18 | ]
19 | TTFT_BINS_STT = [x / 1000 for x in TTFT_BINS_STT_MS]
20 |
21 | TTFT_BINS_TTS_MS = [
22 | 200.0,
23 | 250.0,
24 | 300.0,
25 | 350.0,
26 | 400.0,
27 | 450.0,
28 | 500.0,
29 | 550.0,
30 | ]
31 | TTFT_BINS_TTS = [x / 1000 for x in TTFT_BINS_TTS_MS]
32 |
33 |
34 | TTFT_BINS_VLLM_MS = [
35 | 50.0,
36 | 75.0,
37 | 100.0,
38 | 150.0,
39 | 200.0,
40 | 250.0,
41 | 300.0,
42 | 400.0,
43 | 500.0,
44 | 750.0,
45 | 1000.0,
46 | ]
47 | TTFT_BINS_VLLM = [x / 1000 for x in TTFT_BINS_VLLM_MS]
48 |
49 | NUM_WORDS_REQUEST_BINS = [
50 | 50.0,
51 | 100.0,
52 | 200.0,
53 | 500.0,
54 | 1000.0,
55 | 2000.0,
56 | 4000.0,
57 | 6000.0,
58 | 8000.0,
59 | ]
60 | NUM_WORDS_STT_BINS = [0.0, 50.0, 100.0, 200.0, 500.0, 1000.0, 2000.0, 4000.0]
61 | NUM_WORDS_REPLY_BINS = [5.0, 10.0, 25.0, 50.0, 100.0, 200.0]
62 |
63 | SESSIONS = Counter("worker_sessions", "")
64 | SERVICE_MISSES = Counter("worker_service_misses", "")
65 | HARD_SERVICE_MISSES = Counter("worker_hard_service_misses", "")
66 | FORCE_DISCONNECTS = Counter("worker_force_disconnects", "")
67 | FATAL_SERVICE_MISSES = Counter("worker_fatal_service_misses", "")
68 | HARD_ERRORS = Counter("worker_hard_errors", "")
69 | ACTIVE_SESSIONS = Gauge("worker_active_sessions", "")
70 | SESSION_DURATION = Histogram(
71 | "worker_session_duration", "", buckets=SESSION_DURATION_BINS
72 | )
73 | HEALTH_OK = Summary("worker_health_ok", "")
74 |
75 | STT_SESSIONS = Counter("worker_stt_sessions", "")
76 | STT_ACTIVE_SESSIONS = Gauge("worker_stt_active_sessions", "")
77 | STT_MISSES = Counter("worker_stt_misses", "")
78 | STT_HARD_MISSES = Counter("worker_stt_hard_misses", "")
79 | STT_SENT_FRAMES = Counter("worker_stt_sent_frames", "")
80 | STT_RECV_FRAMES = Counter("worker_stt_recv_frames", "")
81 | STT_RECV_WORDS = Counter("worker_stt_recv_words", "")
82 | STT_PING_TIME = Histogram("worker_stt_ping_time", "", buckets=PING_BINS)
83 | STT_FIND_TIME = Histogram("worker_stt_find_time", "", buckets=PING_BINS)
84 | STT_SESSION_DURATION = Histogram(
85 | "worker_stt_session_duration", "", buckets=SESSION_DURATION_BINS
86 | )
87 | STT_AUDIO_DURATION = Histogram(
88 | "worker_stt_audio_duration", "", buckets=SESSION_DURATION_BINS
89 | )
90 | STT_NUM_WORDS = Histogram("worker_stt_num_words", "", buckets=NUM_WORDS_STT_BINS)
91 | STT_TTFT = Histogram("worker_stt_ttft", "", buckets=TTFT_BINS_STT)
92 |
93 | TTS_SESSIONS = Counter("worker_tts_sessions", "")
94 | TTS_ACTIVE_SESSIONS = Gauge("worker_tts_active_sessions", "")
95 | TTS_MISSES = Counter("worker_tts_misses", "")
96 | TTS_HARD_MISSES = Counter("worker_hard_tts_misses", "")
97 | TTS_INTERRUPT = Counter("worker_tts_interrupt", "")
98 | TTS_SENT_FRAMES = Counter("worker_tts_sent_frames", "")
99 | TTS_RECV_FRAMES = Counter("worker_tts_recv_frames", "")
100 | TTS_RECV_WORDS = Counter("worker_tts_recv_words", "")
101 | TTS_PING_TIME = Histogram("worker_tts_ping_time", "", buckets=PING_BINS)
102 | TTS_FIND_TIME = Histogram("worker_tts_find_time", "", buckets=PING_BINS)
103 | TTS_TTFT = Histogram("worker_tts_ttft", "", buckets=TTFT_BINS_TTS)
104 | TTS_AUDIO_DURATION = Histogram(
105 | "worker_tts_audio_duration", "", buckets=TURN_DURATION_BINS
106 | )
107 | TTS_GEN_DURATION = Histogram(
108 | "worker_tts_gen_duration", "", buckets=GENERATION_DURATION_BINS
109 | )
110 |
111 | VLLM_SESSIONS = Counter("worker_vllm_sessions", "")
112 | VLLM_ACTIVE_SESSIONS = Gauge("worker_vllm_active_sessions", "")
113 | VLLM_INTERRUPTS = Counter("worker_vllm_interrupt", "")
114 | VLLM_HARD_ERRORS = Counter("worker_vllm_hard_errors", "")
115 | VLLM_SENT_WORDS = Counter("worker_vllm_sent_words", "")
116 | VLLM_RECV_WORDS = Counter("worker_vllm_recv_words", "")
117 | VLLM_TTFT = Histogram("worker_vllm_ttft", "", buckets=TTFT_BINS_VLLM)
118 | VLLM_REQUEST_LENGTH = Histogram(
119 | "worker_vllm_request_length", "", buckets=NUM_WORDS_REQUEST_BINS
120 | )
121 | VLLM_REPLY_LENGTH = Histogram(
122 | "worker_vllm_reply_length", "", buckets=NUM_WORDS_REPLY_BINS
123 | )
124 | VLLM_GEN_DURATION = Histogram(
125 | "worker_vllm_gen_duration", "", buckets=GENERATION_DURATION_BINS
126 | )
127 |
128 | VOICE_DONATION_SUBMISSIONS = Counter("worker_voice_donation_submissions", "")
129 |
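
A hypothetical example of how these prometheus_client objects are updated from application code (the values are made up; exposing them still requires a metrics endpoint elsewhere):

    SESSIONS.inc()                  # a new session started
    ACTIVE_SESSIONS.inc()
    STT_TTFT.observe(0.032)         # 32 ms until the first transcribed word
    SESSION_DURATION.observe(95.0)  # the session lasted 95 seconds
    ACTIVE_SESSIONS.dec()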
--------------------------------------------------------------------------------
/unmute/recorder.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import uuid
3 | from datetime import datetime
4 | from pathlib import Path
5 | from typing import Annotated, Literal
6 |
7 | import aiofiles
8 | from pydantic import BaseModel, Field
9 |
10 | import unmute.openai_realtime_api_events as ora
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 | EventSender = Literal["client", "server"]
15 |
16 |
17 | class RecorderEvent(BaseModel):
18 | timestamp_wall: float
19 | event_sender: EventSender
20 | data: Annotated[ora.Event, Field(discriminator="type")]
21 |
22 |
23 | class Recorder:
24 | """Record the events sent between the client and the server to a file.
25 |
26 | Doesn't include the user audio for privacy reasons.
27 | """
28 |
29 | def __init__(self, recordings_dir: Path):
30 | self.path = recordings_dir / (make_filename() + ".jsonl")
31 | recordings_dir.mkdir(exist_ok=True)
32 | # We use aiofiles to avoid blocking the event loop when writing to the file.
33 | self.opened_file = None
34 |
35 | async def add_event(self, event_sender: EventSender, data: ora.Event):
36 | """If the recorder is not actually running, the event will be ignored."""
37 | if self.opened_file is None:
38 | self.opened_file = await aiofiles.open(self.path, "a")
39 |
40 | await self.opened_file.write(
41 | RecorderEvent(
42 | timestamp_wall=datetime.now().timestamp(),
43 | event_sender=event_sender,
44 | data=data,
45 | ).model_dump_json()
46 | + "\n"
47 | )
48 |
49 | async def shutdown(self, keep_recording: bool = True):
50 | """Flush any remaining events to the file and close the recorder.
51 |
52 | If `keep_recording` is False, the file will be deleted if it exists.
53 | This is because we get the user consent after we've already started recording,
54 | so if the user doesn't consent, we delete the file afterwards.
55 | """
56 | if self.opened_file is not None:
57 | await self.opened_file.close()
58 | if keep_recording:
59 | logger.info(f"Recording stored into {self.path}.")
60 | else:
61 | try:
62 | self.path.unlink()
63 | logger.info(
64 | f"Deleted recording {self.path} due to lack of consent."
65 | )
66 | except Exception as e:
67 | logger.error(f"Failed to delete recording file {self.path}: {e}")
68 |
69 |
70 | def make_filename() -> str:
71 | """Create a unique filename based on the current timestamp and a short UUID, without a suffix."""
72 | timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
73 | unique_id = uuid.uuid4().hex[:4]
74 | return f"{timestamp}_{unique_id}"
75 |
--------------------------------------------------------------------------------
/unmute/scripts/check_hugging_face_token_not_write.py:
--------------------------------------------------------------------------------
1 | """Check that a Hugging Face token does not have write access."""
2 |
3 | import argparse
4 |
5 | import requests
6 |
7 |
8 | def abbreviate_token(token: str) -> str:
9 | """Abbreviate the token for display."""
10 | assert len(token) > 10
11 | return f"{token[:4]}...{token[-4:]}"
12 |
13 |
14 | def main(token: str):
15 | response = requests.get(
16 | "https://huggingface.co/api/whoami-v2",
17 | headers={"Authorization": f"Bearer {token}"},
18 | timeout=10,
19 | )
20 | response.raise_for_status()
21 |
22 | # Example response:
23 | # {
24 | # [...]
25 | # "auth": {
26 | # "type": "access_token",
27 | # "accessToken": {
28 | # "displayName": "foo",
29 | # "role": "write",
30 | # "createdAt": "2025-03-18T10:40:56.186Z"
31 | # }
32 | # }
33 | # }
34 |
35 | data = response.json()
36 | if data["auth"]["type"] != "access_token":
37 | raise ValueError(f"Unexpected auth type: {data['auth']['type']}.")
38 |
39 | role = data["auth"]["accessToken"]["role"]
40 | if role == "fineGrained":
41 | # Harder to test. As a heuristic, just look for "write" somewhere in the JSON.
42 | if "write" in str(data["auth"]["accessToken"]["fineGrained"]).lower():
43 | raise ValueError(
44 | "The provided fine-grained Hugging Face token "
45 | f"{abbreviate_token(token)} has write access. "
46 | "Use a read-only token to deploy. "
47 | "It has the following permissions: "
48 | f"{data['auth']['accessToken']['fineGrained']}"
49 | )
50 | elif role == "write":
51 | raise ValueError(
52 | f"The provided Hugging Face token {abbreviate_token(token)} has write "
53 | "access. Use a read-only token to deploy."
54 | )
55 | else:
56 | if role != "read":
57 | raise ValueError(
58 | f"Unknown token role: {role}. Use a read-only token to deploy."
59 | )
60 |
61 | print("Ok, Hugging Face token has no write access.")
62 |
63 |
64 | if __name__ == "__main__":
65 | parser = argparse.ArgumentParser(
66 | description="Check that the Hugging Face token does not have write access. "
67 | "This is because we don't want the deployed containers to have write access "
68 | "in case they are compromised. "
69 | "Exits with non-zero exit code if the token has write access or something else "
70 | "goes wrong."
71 | )
72 | parser.add_argument("token", type=str, help="Hugging Face token to check. ")
73 |
74 | args = parser.parse_args()
75 | main(args.token)
76 |
--------------------------------------------------------------------------------
/unmute/scripts/copy_voice_to_prod.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from unmute.tts.voices import copy_voice_to_prod
4 |
5 | if __name__ == "__main__":
6 | parser = argparse.ArgumentParser(description="Copy voice to production server")
7 | parser.add_argument(
8 | "path_on_server",
9 | type=str,
10 | help="The path by which the voice is referred to by the TTS server "
11 | "(=relative to the voice directory)",
12 | )
13 | args = parser.parse_args()
14 |
15 | copy_voice_to_prod(args.path_on_server)
16 |
--------------------------------------------------------------------------------
/unmute/scripts/example_websocket_client.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import asyncio
3 | import base64
4 | import json
5 | import os
6 | from pathlib import Path
7 |
8 | import numpy as np
9 | import pydub
10 | import pydub.playback
11 | import sphn
12 | import websockets
13 | from fastrtc import audio_to_int16
14 |
15 | from unmute.kyutai_constants import SAMPLE_RATE
16 |
17 | INPUT_FRAME_SIZE = 960
18 | TARGET_SAMPLE_RATE = 24000
19 | TARGET_CHANNELS = 1 # Mono
20 |
21 |
22 | def base64_encode_audio(audio: np.ndarray):
23 | pcm_bytes = audio_to_int16(audio)
24 | encoded = base64.b64encode(pcm_bytes).decode("ascii")
25 | return encoded
26 |
27 |
28 | async def send_messages(websocket: websockets.ClientConnection, audio_path: Path):
29 | data, _sr = sphn.read(audio_path, sample_rate=SAMPLE_RATE)
30 | data = data[0] # Take first channel to make it mono
31 |
32 | try:
33 | while True:
34 | chunk_size = 1920 # Send data in chunks
35 | for i in range(0, len(data), chunk_size):
36 | event = {
37 | "type": "input_audio_buffer.append",
38 | "audio": base64_encode_audio(data[i : i + chunk_size]),
39 | }
40 |
41 | await websocket.send(json.dumps(event))
42 | await asyncio.sleep(0.01) # Simulate real-time streaming
43 |
44 | await websocket.send(json.dumps({"type": "input_audio_buffer.commit"}))
45 | await websocket.send(json.dumps({"type": "response.create"}))
46 |
47 | for _ in range(0, len(data), chunk_size):
48 | event = {
49 | "type": "input_audio_buffer.append",
50 | "audio": base64_encode_audio(
51 | np.zeros(chunk_size, dtype=np.float32)
52 | ),
53 | }
54 |
55 | await websocket.send(json.dumps(event))
56 | await asyncio.sleep(0.01) # Simulate real-time streaming
57 | except websockets.ConnectionClosed:
58 | print("Connection closed while sending messages.")
59 |
60 |
61 | async def receive_messages(websocket: websockets.ClientConnection):
62 | buffer = []
63 | transcript = ""
64 |
65 | try:
66 | async for message in websocket:
67 | message = json.loads(message)
68 | if message["type"] == "response.audio.delta":
69 | base64_audio = message["delta"]
70 | binary_audio_data = base64.b64decode(base64_audio)
71 | buffer.append(binary_audio_data)
72 | elif message["type"] == "response.audio.done":
73 | print("Received `response.audio.done` message.")
74 | break
75 | elif message["type"] == "response.audio_transcript.delta":
76 | transcript += message["delta"]
77 | print(message["delta"], end="", flush=True)
78 | else:
79 | print(f"Received message: {message}")
80 | except websockets.ConnectionClosed:
81 | print("Connection closed while receiving messages.")
82 |
83 | # save and play using pydub
84 | audio = pydub.AudioSegment(
85 | data=b"".join(buffer),
86 | sample_width=2,
87 | frame_rate=TARGET_SAMPLE_RATE,
88 | channels=TARGET_CHANNELS,
89 | )
90 | audio.export("output.wav", format="wav")
91 | pydub.playback.play(audio)
92 |
93 |
94 | async def main(audio_path: Path, server_url: str):
95 | if "openai.com" in server_url:
96 | additional_headers = {
97 | "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}",
98 | "OpenAI-Beta": "realtime=v1",
99 | }
100 | query_string = "model=gpt-4o-realtime-preview"
101 | else:
102 | additional_headers = {}
103 | query_string = ""
104 |
105 | async with websockets.connect(
106 | f"{server_url}/v1/realtime?{query_string}",
107 | additional_headers=additional_headers,
108 | subprotocols=[websockets.Subprotocol("realtime")],
109 | ) as websocket:
110 | send_task = asyncio.create_task(send_messages(websocket, audio_path))
111 | receive_task = asyncio.create_task(receive_messages(websocket))
112 | await asyncio.gather(send_task, receive_task)
113 |
114 |
115 | if __name__ == "__main__":
116 | parser = argparse.ArgumentParser()
117 | parser.add_argument("--server-url", type=str, default="wss://api.openai.com")
118 | parser.add_argument("audio_path", type=Path)
119 | args = parser.parse_args()
120 |
121 | asyncio.run(main(args.audio_path, server_url=args.server_url))
122 |
--------------------------------------------------------------------------------
/unmute/scripts/mistral_streaming.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from mistralai import Mistral
4 |
5 | if __name__ == "__main__":
6 | mistral_api_key = os.getenv("MISTRAL_API_KEY")
7 | if not mistral_api_key:
8 | raise ValueError("MISTRAL_API_KEY environment variable must be set")
9 |
10 | model = "mistral-small-latest"
11 |
12 | client = Mistral(api_key=mistral_api_key)
13 |
14 | res = client.chat.stream(
15 | model=model,
16 | messages=[
17 | {
18 | "role": "system",
19 | "content": "Keep your responses to at most a few sentences. "
20 | "They will be spoken out loud, so don't worry about formatting. "
21 | "Write as a human would speak.",
22 | },
23 | {
24 | "role": "user",
25 | "content": "What is the best French cheese?",
26 | },
27 | ],
28 | )
29 |
30 | with res as event_stream:
31 | for event in event_stream:
32 | content = event.data.choices[0].delta.content
33 | print(content, flush=True, end="")
34 |
35 | print("")
36 |
--------------------------------------------------------------------------------
/unmute/scripts/output_from_file.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import asyncio
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import sphn
7 | from fastrtc import AsyncStreamHandler, Stream, wait_for_item
8 |
9 | SAMPLE_RATE = 24000
10 | # 480 works but the default 960 doesn't!!!
11 | OUTPUT_FRAME_SIZE = 480
12 |
13 |
14 | class FilePlaybackHandler(AsyncStreamHandler):
15 | def __init__(self, audio_path: Path) -> None:
16 | super().__init__(
17 | input_sample_rate=SAMPLE_RATE,
18 | output_sample_rate=SAMPLE_RATE,
19 | output_frame_size=OUTPUT_FRAME_SIZE,
20 | )
21 | self.output_queue = asyncio.Queue()
22 | self.audio_path = audio_path
23 |
24 | async def receive(self, frame: tuple[int, np.ndarray]) -> None:
25 | pass
26 |
27 | async def emit(self) -> tuple[int, np.ndarray]:
28 | return await wait_for_item(self.output_queue)
29 |
30 | def copy(self):
31 | return FilePlaybackHandler(self.audio_path)
32 |
33 | async def start_up(self) -> None:
34 | data, _sr = sphn.read(self.audio_path, sample_rate=SAMPLE_RATE)
35 | data = data[0] # Take first channel to make it mono
36 |
37 | simulated_ratio = 1.5
38 |
39 | for i in range(0, len(data), OUTPUT_FRAME_SIZE):
40 | await self.output_queue.put((SAMPLE_RATE, data[i : i + OUTPUT_FRAME_SIZE]))
41 | await asyncio.sleep(OUTPUT_FRAME_SIZE / SAMPLE_RATE / simulated_ratio)
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument("file", type=Path)
47 | args = parser.parse_args()
48 |
49 | stream = Stream(
50 | handler=FilePlaybackHandler(args.file),
51 | modality="audio",
52 | mode="send-receive",
53 | )
54 |
55 | stream.ui.launch(debug=True)
56 |
--------------------------------------------------------------------------------
/unmute/scripts/output_sine.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from fastrtc import Stream, StreamHandler, get_hf_turn_credentials
3 |
4 | from unmute.audio_stream_saver import AudioStreamSaver
5 |
6 | SAMPLE_RATE = 24000
7 | OUTPUT_FRAME_SIZE = 1920
8 |
9 | # logging.basicConfig(level=logging.DEBUG)
10 |
11 |
12 | class SineHandler(StreamHandler):
13 | def __init__(self) -> None:
14 | super().__init__(input_sample_rate=SAMPLE_RATE, output_frame_size=960)
15 | self.cur_time_samples = 0
16 | self.saver = AudioStreamSaver()
17 |
18 | def receive(self, frame: tuple[int, np.ndarray]) -> None:
19 | pass
20 |
21 | def emit(self) -> tuple[int, np.ndarray]:
22 | times = np.arange(
23 | self.cur_time_samples,
24 | self.cur_time_samples + OUTPUT_FRAME_SIZE,
25 | )
26 | x = np.sin(2 * np.pi * 440 / SAMPLE_RATE * times) * 0.1
27 | x = x.astype(np.float32)
28 | self.cur_time_samples += OUTPUT_FRAME_SIZE
29 |
30 | self.saver.add(x)
31 |
32 | return (SAMPLE_RATE, x)
33 |
34 | def copy(self):
35 | return SineHandler()
36 |
37 | def shutdown(self):
38 | pass
39 |
40 | def start_up(self) -> None:
41 | pass
42 |
43 |
44 | if __name__ == "__main__":
45 | # rtc_configuration = get_cloudflare_rtc_configuration()
46 | rtc_configuration = get_hf_turn_credentials()
47 | stream = Stream(
48 | handler=SineHandler(),
49 | modality="audio",
50 | mode="send-receive",
51 | rtc_configuration=rtc_configuration,
52 | )
53 |
54 | stream.ui.launch(debug=True)
55 |
--------------------------------------------------------------------------------
/unmute/scripts/output_sine_async.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | import numpy as np
4 | from fastrtc import AsyncStreamHandler, Stream
5 |
6 | from unmute.audio_stream_saver import AudioStreamSaver
7 |
8 | SAMPLE_RATE = 24000
9 | OUTPUT_FRAME_SIZE = 1920
10 |
11 |
12 | class SineHandler(AsyncStreamHandler):
13 | def __init__(self) -> None:
14 | super().__init__(input_sample_rate=SAMPLE_RATE)
15 | self.cur_time_samples = 0
16 | self.saver = AudioStreamSaver()
17 |
18 | async def receive(self, frame: tuple[int, np.ndarray]) -> None:
19 | pass
20 |
21 | async def emit(self) -> tuple[int, np.ndarray]:
22 | times = np.arange(
23 | self.cur_time_samples,
24 | self.cur_time_samples + OUTPUT_FRAME_SIZE,
25 | )
26 | x = np.sin(2 * np.pi * 440 / SAMPLE_RATE * times) * 0.3
27 | x = x.astype(np.float32)
28 | self.cur_time_samples += OUTPUT_FRAME_SIZE
29 |
30 | self.saver.add(x)
31 |
32 | await asyncio.sleep(0.01)
33 |
34 | return (SAMPLE_RATE, x)
35 |
36 | def copy(self):
37 | return SineHandler()
38 |
39 | async def start_up(self) -> None:
40 | pass
41 |
42 |
43 | if __name__ == "__main__":
44 | stream = Stream(
45 | handler=SineHandler(),
46 | modality="audio",
47 | mode="send-receive",
48 | )
49 |
50 | stream.ui.launch(debug=True)
51 |
--------------------------------------------------------------------------------
/unmute/scripts/output_tts.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import time
3 |
4 | import numpy as np
5 | import websockets
6 | from fastrtc import AsyncStreamHandler, Stream, wait_for_item
7 |
8 | from unmute.tts.text_to_speech import TextToSpeech, TTSAudioMessage
9 |
10 | SAMPLE_RATE = 24000
11 | OUTPUT_FRAME_SIZE = 480
12 |
13 |
14 | class TTSHandler(AsyncStreamHandler):
15 | def __init__(self) -> None:
16 | super().__init__(
17 | input_sample_rate=SAMPLE_RATE,
18 | output_sample_rate=SAMPLE_RATE,
19 | output_frame_size=OUTPUT_FRAME_SIZE,
20 | )
21 | self.tts = TextToSpeech()
22 | self.output_queue = asyncio.Queue()
23 | self.go = False
24 | self.cur_time_samples = 0
25 |
26 | async def receive(self, frame: tuple[int, np.ndarray]) -> None:
27 | pass
28 |
29 | async def emit(self) -> tuple[int, np.ndarray]:
30 | # if not self.output_queue.empty():
31 | # return await self.output_queue.get()
32 |
33 | # times = np.arange(
34 | # self.cur_time_samples,
35 | # self.cur_time_samples + OUTPUT_FRAME_SIZE,
36 | # )
37 | # x = np.sin(2 * np.pi * 440 / SAMPLE_RATE * times) * 0.3
38 | # x = x.astype(np.float32)
39 | # self.cur_time_samples += OUTPUT_FRAME_SIZE
40 |
41 | # await asyncio.sleep(0.01)
42 |
43 | # return (SAMPLE_RATE, x)
44 |
45 | return await wait_for_item(self.output_queue)
46 | # return await self.output_queue.get()
47 |
48 | def copy(self):
49 | return TTSHandler()
50 |
51 | async def start_up(self) -> None:
52 | asyncio.create_task(self._tts_loop())
53 |
54 | async def _tts_loop(self):
55 | await self.tts.start_up()
56 |
57 | await self.tts.send(" ".join(["Hello, world! "] * 10))
58 |
59 | try:
60 | audio_started = None
61 |
62 | async for message in self.tts:
63 | if audio_started is not None:
64 | time_since_start = time.time() - audio_started
65 | time_received = self.tts.received_samples / self.input_sample_rate
66 | ratio = time_received / time_since_start
67 | assert self.input_sample_rate == SAMPLE_RATE
68 | print(
69 | f"{time_received=:.2f}, {time_since_start=:.2f}, "
70 | f"ratio {ratio:.2f}"
71 | )
72 |
73 | if isinstance(message, TTSAudioMessage):
74 | audio = np.array(message.pcm, dtype=np.float32)
75 | assert self.output_sample_rate == SAMPLE_RATE
76 |
77 | assert len(audio) % OUTPUT_FRAME_SIZE == 0, (
78 | "Audio length must be a multiple of the frame size."
79 | )
80 | for i in range(0, len(audio), OUTPUT_FRAME_SIZE):
81 | await self.output_queue.put(
82 | (SAMPLE_RATE, audio[i : i + OUTPUT_FRAME_SIZE])
83 | )
84 | # await self.output_queue.put((SAMPLE_RATE, audio))
85 |
86 | if audio_started is None:
87 | audio_started = time.time()
88 |
89 | except websockets.ConnectionClosed:
90 | print("TTS connection closed while receiving messages.")
91 |
92 |
93 | if __name__ == "__main__":
94 | stream = Stream(
95 | handler=TTSHandler(),
96 | modality="audio",
97 | mode="send-receive",
98 | )
99 |
100 | stream.ui.launch(debug=True)
101 |
--------------------------------------------------------------------------------
/unmute/scripts/pitch_detection_handler.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 |
3 | import librosa
4 | import numpy as np
5 | from fastrtc import Stream, StreamHandler
6 |
7 | from unmute.audio_stream_saver import AudioStreamSaver
8 |
9 | SAMPLE_RATE = 24000
10 | OUTPUT_FRAME_SIZE = 1920
11 |
12 |
13 | class PitchDetectionHandler(StreamHandler):
14 | def __init__(self) -> None:
15 | super().__init__(input_sample_rate=SAMPLE_RATE, output_frame_size=480)
16 | self.cur_time_samples = 0
17 | self.saver = AudioStreamSaver()
18 | self.frequency_queue = deque()
19 | self.last_phase = 0
20 | self.last_frequency = 100
21 |
22 | def receive(self, frame: tuple[int, np.ndarray]) -> None:
23 | mono_audio = frame[1][0]
24 | assert mono_audio.dtype == np.int16
25 | mono_audio = mono_audio.astype(np.float32) / np.iinfo(np.int16).max
26 |
27 | freqs = librosa.yin(
28 | mono_audio,
29 | fmin=float(librosa.note_to_hz("E2")),
30 | fmax=float(librosa.note_to_hz("E5")),
31 | sr=SAMPLE_RATE,
32 | frame_length=len(mono_audio),
33 | hop_length=len(mono_audio),
34 | center=False,
35 | )
36 | assert len(freqs) == 1
37 | self.frequency_queue.append(freqs[0])
38 |
39 | def emit(self) -> tuple[int, np.ndarray] | None:
40 | if not self.frequency_queue:
41 | return None
42 | else:
43 | frequency = self.frequency_queue.popleft()
44 |
45 | phase = self.last_phase + np.cumsum(
46 | np.linspace(self.last_frequency, frequency, OUTPUT_FRAME_SIZE) / SAMPLE_RATE
47 | )
48 | self.last_phase = phase[-1] % 1.0
49 | amplitude = 0.1
50 | x = np.sin(2 * np.pi * phase) * amplitude
51 | x = x.astype(np.float32)
52 |
53 | self.cur_time_samples += OUTPUT_FRAME_SIZE
54 | self.saver.add(x)
55 | self.last_frequency = frequency
56 |
57 | return (SAMPLE_RATE, x)
58 |
59 | def copy(self):
60 | return PitchDetectionHandler()
61 |
62 |
63 | if __name__ == "__main__":
64 | stream = Stream(
65 | handler=PitchDetectionHandler(), modality="audio", mode="send-receive"
66 | )
67 |
68 | stream.ui.launch(debug=True)
69 |
--------------------------------------------------------------------------------
/unmute/scripts/stt_from_file_example.py:
--------------------------------------------------------------------------------
1 | """Run speech-to-text on an audio file in a non-streaming way."""
2 |
3 | import asyncio
4 | import logging
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | import sphn
9 | import tqdm
10 |
11 | from unmute.kyutai_constants import SAMPLE_RATE, SAMPLES_PER_FRAME
12 | from unmute.stt.speech_to_text import SpeechToText, STTMarkerMessage, STTWordMessage
13 |
14 | TARGET_SAMPLE_RATE = 24000
15 | TARGET_CHANNELS = 1 # Mono
16 | logging.basicConfig(level=logging.INFO)
17 |
18 |
19 | def load_and_process_audio(audio_path: Path):
20 | data, _sr = sphn.read(audio_path, sample_rate=SAMPLE_RATE)
21 | data = data[0] # Take first channel to make it mono
22 | return data
23 |
24 |
25 | async def main(audio_path: Path):
26 | stt = SpeechToText()
27 | await stt.start_up()
28 |
29 | audio_data = load_and_process_audio(audio_path)
30 |
31 | for i in tqdm.trange(0, len(audio_data), SAMPLES_PER_FRAME, desc="Sending audio"):
32 | chunk = audio_data[i : i + SAMPLES_PER_FRAME]
33 | await stt.send_audio(chunk)
34 | await asyncio.sleep(SAMPLES_PER_FRAME / SAMPLE_RATE)
35 |
36 | # When we get the marker back from the server, we know it has processed the audio
37 | await stt.send_marker(0)
38 |
39 | # Send extra audio to make sure the marker is processed
40 | for _ in range(25):
41 |         await stt.send_audio(np.zeros(SAMPLES_PER_FRAME, dtype=np.float32))
42 |
43 | words = []
44 |
45 | with tqdm.tqdm() as pbar:
46 | async for msg in stt:
47 | if isinstance(msg, STTWordMessage):
48 | words.append(msg)
49 | pbar.set_postfix(n_words=len(words))
50 | elif isinstance(msg, STTMarkerMessage): # pyright: ignore[reportUnnecessaryIsInstance]
51 | break
52 |
53 | pbar.update()
54 |
55 | print("\n".join(str(s) for s in words))
56 |
57 |
58 | if __name__ == "__main__":
59 | import argparse
60 |
61 | parser = argparse.ArgumentParser()
62 | parser.add_argument("audio_path", type=Path)
63 | args = parser.parse_args()
64 |
65 | asyncio.run(main(args.audio_path))
66 |
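About the padding loop above: the marker only comes back once the server has processed up to that point in the stream, so a couple of seconds of silence are sent to flush it through. Rough arithmetic, assuming 24 kHz audio and 1920-sample frames (the real values come from unmute.kyutai_constants and may differ):

# Illustrative numbers only; use the actual constants from unmute.kyutai_constants.
SAMPLE_RATE = 24000
SAMPLES_PER_FRAME = 1920

frame_sec = SAMPLES_PER_FRAME / SAMPLE_RATE  # 0.08 s per frame
padding_sec = 25 * frame_sec                 # 2.0 s of silence sent after the marker
print(f"Padding after marker: {padding_sec:.1f} s")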
--------------------------------------------------------------------------------
/unmute/scripts/stt_microphone_example.py:
--------------------------------------------------------------------------------
1 | """Transcribe audio from the microphone in real-time."""
2 |
3 | import asyncio
4 | from typing import Any
5 |
6 | import numpy as np
7 |
8 | try:
9 | # We don't need this for anything else so it's not in the dependencies
10 | import sounddevice as sd # type: ignore
11 | except ImportError as e:
12 | raise ImportError(
13 | "Please install sounddevice to run this example: pip install sounddevice "
14 | "(or uv pip install sounddevice if you're using uv)."
15 | ) from e
16 | import tqdm
17 |
18 | from unmute.kyutai_constants import SAMPLES_PER_FRAME
19 | from unmute.stt.speech_to_text import (
20 | SpeechToText,
21 | STTMarkerMessage,
22 | STTWordMessage,
23 | )
24 |
25 |
26 | async def receive_loop(stt: SpeechToText):
27 |     delay = float("nan")  # replaced by a real value once the first marker comes back
28 | async for msg in stt:
29 | if isinstance(msg, STTWordMessage):
30 | print(f"Word: {msg.text} ({msg.start_time:.2f}s). Delay: {delay:.2f}s")
31 | elif isinstance(msg, STTMarkerMessage): # type: ignore
32 | marker_time = msg.id / 1000
33 | time = asyncio.get_event_loop().time()
34 | delay = time - marker_time
35 |
36 |
37 | async def main():
38 | stt = SpeechToText()
39 | await stt.start_up()
40 | audio_queue = asyncio.Queue()
41 |
42 | duration_sec = 30
43 |
44 | receive_task = asyncio.create_task(receive_loop(stt))
45 |
46 | def callback(indata: np.ndarray, frames: int, time: Any, status: sd.CallbackFlags):
47 | mono_audio = indata[:, 0]
48 | audio_queue.put_nowait(mono_audio.copy())
49 |
50 | start_time = asyncio.get_event_loop().time()
51 |
52 | audio_buffer = np.zeros((0,), dtype=np.float32)
53 |
54 | with sd.InputStream(callback=callback, blocksize=1024, samplerate=24000):
55 | pbar = tqdm.tqdm(total=duration_sec, desc="Recording", unit="s")
56 | while asyncio.get_event_loop().time() - start_time < duration_sec:
57 | try:
58 | audio_chunk = await asyncio.wait_for(audio_queue.get(), timeout=0.1)
59 | except asyncio.TimeoutError:
60 | continue
61 |
62 | pbar.set_postfix(
63 | volume=np.mean(np.abs(audio_chunk)),
64 | )
65 | # Updating this is a bit annoying
66 | # pbar.update(audio_chunk.shape[0] / 24000)
67 |
68 | audio_buffer = np.concatenate((audio_buffer, audio_chunk), axis=0)
69 | while audio_buffer.shape[0] > SAMPLES_PER_FRAME:
70 | audio_chunk = audio_buffer[:SAMPLES_PER_FRAME]
71 | audio_buffer = audio_buffer[SAMPLES_PER_FRAME:]
72 |
73 | await stt.send_marker(int(asyncio.get_event_loop().time() * 1000))
74 | await stt.send_audio(audio_chunk)
75 |
76 | receive_task.cancel()
77 | print(f"Quit after {duration_sec} seconds.")
78 |
79 |
80 | if __name__ == "__main__":
81 | asyncio.run(main())
82 |
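The buffering loop above collects microphone callbacks of arbitrary size and re-slices them into fixed-size frames before sending. The same idea as a small standalone helper (the function name is hypothetical, not part of the codebase):

import numpy as np


def split_into_frames(
    buffer: np.ndarray, frame_size: int
) -> tuple[list[np.ndarray], np.ndarray]:
    """Split `buffer` into full frames of `frame_size` samples plus a leftover tail."""
    n_full = buffer.shape[0] // frame_size
    frames = [buffer[i * frame_size : (i + 1) * frame_size] for i in range(n_full)]
    return frames, buffer[n_full * frame_size :]


buf = np.random.randn(5000).astype(np.float32)  # e.g. several 1024-sample callbacks
frames, leftover = split_into_frames(buf, 1920)
assert len(frames) == 2 and leftover.shape[0] == 5000 - 2 * 1920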
--------------------------------------------------------------------------------
/unmute/scripts/tts_example.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import asyncio
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import sphn
7 | import tqdm
8 |
9 | from unmute.loadtest.loadtest_client import preview_audio
10 | from unmute.tts.text_to_speech import (
11 | TextToSpeech,
12 | TTSAudioMessage,
13 | TTSClientEosMessage,
14 | TTSTextMessage,
15 | )
16 | from unmute.tts.voice_cloning import clone_voice
17 |
18 |
19 | async def main(
20 | text: str, voice_file: Path | None = None, output_path: Path | None = None
21 | ):
22 | if voice_file:
23 | voice = clone_voice(voice_file.read_bytes())
24 | else:
25 | voice = None
26 |
27 | tts = TextToSpeech(voice=voice)
28 | await tts.start_up()
29 |
30 | for word in text.split(" "):
31 | await tts.send(word)
32 | await asyncio.sleep(0.1)
33 |
34 | await tts.send(TTSClientEosMessage())
35 |
36 | audio_chunks = []
37 | n_words = 0
38 |
39 | with tqdm.tqdm() as pbar:
40 | async for msg in tts:
41 | if isinstance(msg, TTSTextMessage):
42 | pbar.set_postfix(n_words=n_words)
43 | n_words += 1
44 | elif isinstance(msg, TTSAudioMessage):
45 | audio_chunks.append(msg.pcm)
46 | pbar.update(len(msg.pcm) / 24000)
47 |
48 |     all_audio = np.concatenate(audio_chunks).astype(np.float32)
49 | preview_audio(all_audio)
50 |
51 | sphn.write_wav(output_path, all_audio, 24000)
52 | print(f"Saved to {output_path}")
53 |
54 |
55 | if __name__ == "__main__":
56 | parser = argparse.ArgumentParser()
57 | parser.add_argument(
58 | "--voice-file",
59 | type=Path,
60 | )
61 | parser.add_argument(
62 | "--output-path",
63 | type=Path,
64 | default=Path("out.wav"),
65 | help="Path to save the audio to, .wav or .ogg file (default: out.wav)",
66 | )
67 | parser.add_argument(
68 | "text",
69 | type=str,
70 | nargs="?",
71 | default="Did you know that the author of Octavia "
72 | "based one character on a former lover?",
73 | )
74 | args = parser.parse_args()
75 |
76 | asyncio.run(
77 | main(args.text, voice_file=args.voice_file, output_path=args.output_path)
78 | )
79 |
--------------------------------------------------------------------------------
/unmute/scripts/update_voice_list.py:
--------------------------------------------------------------------------------
1 | """Upload local voices to the server based on the voice list."""
2 |
3 | import asyncio
4 | import logging
5 |
6 | from unmute.tts.voices import VoiceList
7 |
8 |
9 | async def main():
10 | logging.basicConfig(level=logging.INFO)
11 |
12 | voice_list = VoiceList()
13 | await voice_list.upload_to_server()
14 | voice_list.save()
15 | print("Voices updated successfully. Voice list path:")
16 | print(voice_list.path)
17 |
18 |
19 | if __name__ == "__main__":
20 | asyncio.run(main())
21 |
--------------------------------------------------------------------------------
/unmute/scripts/vllm_wrapper_example.py:
--------------------------------------------------------------------------------
1 | # https://github.com/gabrielchua/async-stream-openai-st/blob/824eab8f3ab600d3689d8d946526e48e0e0310c2/app.py
2 | # https://qwen.readthedocs.io/en/latest/deployment/vllm.html#openai-compatible-api-service
3 |
4 | import time
5 | from typing import Any, cast
6 |
7 | from unmute.kyutai_constants import LLM_SERVER
8 | from unmute.llm.llm_utils import VLLMStream, get_openai_client, rechunk_to_words
9 |
10 | # Predefined message
11 | PREDEFINED_MESSAGE = "Explain the second law of thermodynamics"
12 |
13 |
14 | async def main(server_url: str):
15 | client = get_openai_client(server_url=server_url)
16 | s = VLLMStream(client)
17 |
18 | messages = [
19 | {"role": "system", "content": "You are a helpful assistant."},
20 | {
21 | "role": "user",
22 | "content": "Write a 200 word essay on 'bear vs shark'. "
23 | "The first line is a 2-3 word title with an emoji and then include "
24 | "2 line breaks. For example 'TITLE \n \n ' ",
25 | },
26 | ]
27 |
28 | start_time = time.time()
29 | first_token_time = None
30 | async for message in rechunk_to_words(s.chat_completion(cast(Any, messages))):
31 | if first_token_time is None:
32 | first_token_time = time.time()
33 | print(
34 | f"\nTime to first token: {first_token_time - start_time:.3f} seconds\n"
35 | )
36 | print(message, end="", flush=True)
37 |
38 | print()
39 |
40 |
41 | if __name__ == "__main__":
42 | import argparse
43 | import asyncio
44 |
45 | parser = argparse.ArgumentParser(description="Run VLLM wrapper example.")
46 | parser.add_argument(
47 | "--server-url",
48 | type=str,
49 | default=LLM_SERVER,
50 | help=f"The URL of the VLLM server (default: {LLM_SERVER}).",
51 | )
52 | args = parser.parse_args()
53 |
54 | asyncio.run(main(args.server_url))
55 |
--------------------------------------------------------------------------------
/unmute/stt/dummy_speech_to_text.py:
--------------------------------------------------------------------------------
1 | """A dummy speech-to-text that never sends any words.
2 |
3 | Useful for testing, e.g. to check whether latency improves when the STT is not
4 | running on the same GPU.
5 | """
6 |
7 | import asyncio
8 | from logging import getLogger
9 | from typing import AsyncIterator, Literal
10 |
11 | import numpy as np
12 |
13 | from unmute.kyutai_constants import FRAME_TIME_SEC, STT_DELAY_SEC, STT_SERVER
14 | from unmute.service_discovery import ServiceWithStartup
15 | from unmute.stt.exponential_moving_average import ExponentialMovingAverage
16 | from unmute.stt.speech_to_text import STTMarkerMessage, STTWordMessage
17 | from unmute.websocket_utils import WebsocketState
18 |
19 | logger = getLogger(__name__)
20 |
21 | TranscriptionStatus = Literal[
22 | "should_transcribe", "has_transcribed", "should_not_transcribe"
23 | ]
24 |
25 |
26 | class DummySpeechToText(ServiceWithStartup):
27 | def __init__(
28 | self, stt_instance: str = STT_SERVER, delay_sec: float = STT_DELAY_SEC
29 | ):
30 | self.stt_instance = stt_instance
31 | self.sent_samples = 0
32 | self.received_words = 0
33 | self.delay_sec = delay_sec
34 |         self.current_time = -delay_sec
35 |
36 | # We just keep this at 1.0 = user is not speaking
37 | self.pause_prediction = ExponentialMovingAverage(
38 | attack_time=0.01, release_time=0.01, initial_value=1.0
39 | )
40 | self.should_shutdown = asyncio.Event()
41 |
42 | def state(self) -> WebsocketState:
43 | return "connected"
44 |
45 | async def send_audio(self, audio: np.ndarray) -> None:
46 | self.current_time += FRAME_TIME_SEC
47 |
48 | async def send_marker(self, id: int) -> None:
49 | return
50 |
51 | async def start_up(self):
52 | logger.info("Starting dummy STT")
53 |
54 | async def shutdown(self):
55 | logger.info("Shutting down dummy STT")
56 | self.should_shutdown.set()
57 |
58 | async def __aiter__(
59 | self,
60 | ) -> AsyncIterator[STTWordMessage | STTMarkerMessage]:
61 |         while not self.should_shutdown.is_set():
62 | await asyncio.sleep(1.0)
63 |
64 | # Just to satisfy the type checker
65 | yield STTMarkerMessage(type="Marker", id=0)
66 |
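Because DummySpeechToText exposes the same surface that the rest of the code uses on the real client (start_up, send_audio, send_marker, async iteration, shutdown), it can be dropped in wherever the STT is constructed to take speech-to-text out of a latency measurement. A rough sketch of such a switch, assuming a factory-style call site (the environment variable and function name are illustrative, not the actual wiring):

import os

from unmute.stt.dummy_speech_to_text import DummySpeechToText
from unmute.stt.speech_to_text import SpeechToText


def make_stt():
    # Hypothetical toggle; the real code may select the implementation differently.
    if os.environ.get("USE_DUMMY_STT") == "1":
        return DummySpeechToText()
    return SpeechToText()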
--------------------------------------------------------------------------------
/unmute/stt/exponential_moving_average.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class ExponentialMovingAverage:
5 | def __init__(
6 | self, attack_time: float, release_time: float, initial_value: float = 0.0
7 | ):
8 | """An EMA that can smooth differently for attack (up) and release (down).
9 |
10 | Args:
11 | attack_time: Time in seconds to reach 50% of the target value.
12 | Used when the new value is greater than the current value.
13 | release_time: Time in seconds to decay to 50% of the target value.
14 | Used when the new value is less than the current value.
15 | initial_value: Initial value of the EMA.
16 | """
17 | self.attack_time = attack_time
18 | self.release_time = release_time
19 | self.value = initial_value
20 |
21 | def update(self, *, dt: float, new_value: float) -> float:
22 | assert dt > 0.0, f"dt must be positive, got {dt=}"
23 | assert new_value >= 0.0, f"new_value must be non-negative, got {new_value=}"
24 |
25 | if new_value > self.value:
26 | alpha = 1 - np.exp(-dt / self.attack_time * np.log(2))
27 | else:
28 | alpha = 1 - np.exp(-dt / self.release_time * np.log(2))
29 |
30 | self.value = float((1 - alpha) * self.value + alpha * new_value)
31 | return self.value
32 |
33 | def time_to_decay_to(self, value: float) -> float:
34 | """Return the time in seconds it will take for the estimate to reach `value`
35 | if it started at 1."""
36 | assert 0 < value < 1
37 | return float(-self.release_time * np.log2(value))
38 |
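A short usage sketch of the asymmetric smoothing: with a fast attack and a slow release, the estimate jumps up quickly when the input rises and decays gradually when it falls. The numbers in the comments follow directly from the half-life formula above.

from unmute.stt.exponential_moving_average import ExponentialMovingAverage

ema = ExponentialMovingAverage(attack_time=0.1, release_time=1.0, initial_value=0.0)

# Rising input: 0.1 s is one attack half-life, so we move halfway to the target.
print(ema.update(dt=0.1, new_value=1.0))  # 0.5
# Falling input: 0.1 s is only a tenth of the release half-life, so the decay is small.
print(ema.update(dt=0.1, new_value=0.0))  # ~0.467

# Decaying from 1.0 to 0.25 takes two release half-lives.
print(ema.time_to_decay_to(0.25))  # 2.0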
--------------------------------------------------------------------------------
/unmute/timer.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 |
4 | def get_time() -> float:
5 | return asyncio.get_event_loop().time()
6 |
7 |
8 | class Stopwatch:
9 | def __init__(self, autostart: bool = True):
10 | self.start_time = get_time() if autostart else None
11 | self.end_time = None
12 |
13 | def start_if_not_started(self):
14 | if self.start_time is None:
15 | self.start_time = get_time()
16 |
17 | def stop(self) -> float | None:
18 | if self.start_time is None:
19 | return None
20 |
21 | if self.end_time is not None:
22 | return None # Already stopped
23 | else:
24 | self.end_time = get_time()
25 | return self.end_time - self.start_time
26 |
27 | def time(self) -> float:
28 | if self.start_time is None:
29 | raise RuntimeError("Stopwatch not started")
30 |
31 | return get_time() - self.start_time
32 |
33 | @property
34 | def started(self) -> bool:
35 | return self.start_time is not None
36 |
37 |
38 | class PhasesStopwatch:
39 | def __init__(self, phases: list[str]):
40 | self.phases = phases
41 | self.times: list[float | None] = [None for _ in phases]
42 |
43 | def _check_previous_phases_done(self, to: int):
44 | for i in range(to):
45 | if self.times[i] is None:
46 | raise RuntimeError(
47 | f"Wanted to start phase {self.phases[to]} "
48 | f"but earlier phase {self.phases[i]} hasn't started"
49 | )
50 |
51 | def time_phase_if_not_started(
52 | self, phase: str, t: float | None = None, check_previous: bool = True
53 | ):
54 | """Time a phase, either with the current time or a given time."""
55 | if check_previous:
56 | self._check_previous_phases_done(self.get_phase_index(phase))
57 |
58 | i = self.get_phase_index(phase)
59 |
60 | if self.times[i] is None:
61 |             self.times[i] = t if t is not None else get_time()
62 |
63 | def get_phase_index(self, phase: str) -> int:
64 | """Get the index of a phase."""
65 | try:
66 | i = self.phases.index(phase)
67 | except ValueError as e:
68 | raise ValueError(
69 | f"Phase {phase} not in phases. Valid phases: {self.phases}"
70 | ) from e
71 |
72 | return i
73 |
74 | def get_time_for_phase(self, phase: str) -> float:
75 | try:
76 | i = self.phases.index(phase)
77 | except ValueError as e:
78 | raise ValueError(
79 | f"Phase {phase} not in phases. Valid phases: {self.phases}"
80 | ) from e
81 |
82 | t = self.times[i]
83 | if t is None:
84 | raise RuntimeError(
85 | f"Phase {phase} not started. {self.phase_dict_partial()=}"
86 | )
87 |
88 | return t
89 |
90 | def phase_dict(self) -> dict[str, float]:
91 | return {phase: self.get_time_for_phase(phase) for phase in self.phases}
92 |
93 | def phase_dict_partial(self) -> dict[str, float | None]:
94 | return {phase: self.times[i] for i, phase in enumerate(self.phases)}
95 |
96 | def reset(self):
97 | self.times = [None for _ in self.phases]
98 |
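A brief usage sketch for both classes. Since get_time() reads the asyncio clock, the code runs inside an event loop; the phase names are made up for the example.

import asyncio

from unmute.timer import PhasesStopwatch, Stopwatch


async def demo():
    sw = Stopwatch()  # autostart=True, so it starts immediately
    await asyncio.sleep(0.05)
    print(sw.time())  # elapsed seconds so far
    print(sw.stop())  # total elapsed; a second stop() would return None

    phases = PhasesStopwatch(["request_sent", "first_token", "done"])
    phases.time_phase_if_not_started("request_sent")
    await asyncio.sleep(0.01)
    phases.time_phase_if_not_started("first_token")
    phases.time_phase_if_not_started("done")
    print(phases.phase_dict())  # {"request_sent": ..., "first_token": ..., "done": ...}


asyncio.run(demo())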
--------------------------------------------------------------------------------
/unmute/tts/copy_approved_voice_donations.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | from pathlib import Path
4 |
5 | from unmute.tts.trim_voice_donation_clip import trim_trailing_silence
6 |
7 |
8 | def main():
9 | parser = argparse.ArgumentParser(
10 | description="Copy approved voice donation .wav files with proper naming."
11 | )
12 | parser.add_argument(
13 | "--table",
14 | required=True,
15 | help="Path to the .tsv or .csv file with metadata.",
16 | type=Path,
17 | )
18 | parser.add_argument(
19 | "--input-dir",
20 | required=True,
21 | help="Directory containing input .wav files named {verification_id}.wav",
22 | type=Path,
23 | )
24 | parser.add_argument(
25 | "--output-dir",
26 | required=True,
27 | help="Directory to copy approved .wav files to.",
28 | type=Path,
29 | )
30 | args = parser.parse_args()
31 |
32 | # Detect delimiter
33 | table_path = Path(args.table)
34 | delimiter = "\t" if table_path.suffix.lower() == ".tsv" else ","
35 |
36 | # Ask for confirmation before clearing the output directory
37 | if args.output_dir.exists() and any(args.output_dir.iterdir()):
38 | confirm = (
39 | input(
40 | f"Output directory {args.output_dir} is not empty. "
41 | "Will clear .wav and .wav.safetensors files before continuing. "
42 | "Ok? (y/N): "
43 | )
44 | .strip()
45 | .lower()
46 | )
47 | if confirm != "y":
48 | print("Exiting.")
49 | exit(1)
50 |
51 | for item in args.output_dir.iterdir():
52 | if item.is_file() and (
53 | item.suffix == ".wav" or item.name.endswith(".wav.safetensors")
54 | ):
55 | item.unlink()
56 |
57 | with table_path.open(newline="", encoding="utf-8") as f:
58 | reader = csv.DictReader(f, delimiter=delimiter)
59 | for row in reader:
60 | approval = row.get("approval", "").strip().upper()
61 | if approval != "TRUE":
62 | continue
63 |
64 | verification_id = row["verification_id"].strip()
65 |
66 | in_path = args.input_dir / f"{verification_id}.wav"
67 |
68 | if not in_path.is_file():
69 | raise FileNotFoundError(f"Input file not found: {in_path}")
70 |
71 | nickname_override = row.get("nickname override", "").strip()
72 | nickname = row.get("nickname", "").strip()
73 | if nickname_override:
74 | out_name = nickname_override
75 | elif nickname:
76 | out_name = nickname
77 | else:
78 | out_name = verification_id[:4]
79 |
80 | # Clean output name
81 | out_name = (
82 | out_name.replace(".", " ")
83 | .replace("/", " ")
84 | .replace("\\", " ")
85 | .strip() # Strip trailing spaces before turning them into underscores
86 | .replace(" ", "_")
87 | )
88 | out_path = args.output_dir / f"{out_name}.wav"
89 |
90 | trim_trailing_silence(in_path, out_path)
91 | print(f"Copied {in_path} -> {out_path}, trimming silence")
92 |
93 |
94 | if __name__ == "__main__":
95 | main()
96 |
--------------------------------------------------------------------------------
/unmute/tts/create_voice_donation_table.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | import os
4 | from pathlib import Path
5 |
6 | from unmute.tts.voice_donation import VoiceDonationMetadata
7 |
8 |
9 | def get_flattened_donation(donation: VoiceDonationMetadata) -> dict:
10 | """Flatten the VoiceDonationMetadata for easier processing."""
11 | return {
12 | "verification_id": str(donation.submission.verification_id),
13 | "timestamp_str": donation.timestamp_str,
14 | "email": donation.submission.email,
15 | "nickname": donation.submission.nickname,
16 | "verification_text": donation.verification.text,
17 | }
18 |
19 |
20 | def main(voice_donation_dir: Path, set_mtime: bool = False):
21 | backups = sorted(list(voice_donation_dir.glob("voice-donation_*/")))
22 | if not backups:
23 | print("No backups found.")
24 | exit(1)
25 |
26 | backup = backups[-1]
27 | print(f"Using backup: {backup}")
28 |
29 | donations: list[VoiceDonationMetadata] = []
30 | for donation_json in backup.glob("*.json"):
31 | with open(donation_json, "r") as f:
32 | metadata = VoiceDonationMetadata.model_validate_json(f.read())
33 | donations.append(metadata)
34 |
35 | if set_mtime:
36 | os.utime(donation_json, (metadata.timestamp, metadata.timestamp))
37 | donation_wav = donation_json.with_suffix(".wav")
38 | if not donation_wav.exists():
39 | print(f"Warning: {donation_wav} does not exist, skipping mtime set.")
40 | else:
41 | os.utime(donation_wav, (metadata.timestamp, metadata.timestamp))
42 |
43 | donations.sort(key=lambda x: x.timestamp)
44 |
45 | seen_nicknames = set()
46 | for donation in donations:
47 | if donation.submission.nickname in seen_nicknames:
48 | raise ValueError(
49 | f"Duplicate nickname found: {donation.submission.nickname}"
50 | )
51 |         seen_nicknames.add(donation.submission.nickname)
52 | assert donation.submission.license == "CC0", "Only CC0 license expected"
53 | assert donation.submission.format_version == "1.0", (
54 | "Only format version 1.0 expected"
55 | )
56 |
57 | flattened_donations = [get_flattened_donation(d) for d in donations]
58 |
59 | output_tsv = voice_donation_dir / "flattened_donations.tsv"
60 | if flattened_donations:
61 | with open(output_tsv, "w", newline="") as tsvfile:
62 | writer = csv.DictWriter(
63 | tsvfile, fieldnames=flattened_donations[0].keys(), delimiter="\t"
64 | )
65 | writer.writeheader()
66 | writer.writerows(flattened_donations)
67 | print(f"Exported {len(flattened_donations)} donations to {output_tsv}")
68 | print(
69 | "You can copy this file and use cmd+shift+v to paste the values into a spreadsheet."
70 | )
71 | else:
72 | print("No donations to export.")
73 |
74 |
75 | if __name__ == "__main__":
76 | parser = argparse.ArgumentParser(description="Process voice donation backups.")
77 | parser.add_argument(
78 | "voice_donation_dir",
79 | type=Path,
80 | help="Directory containing voice donation backups.",
81 | )
82 | parser.add_argument(
83 | "--set-mtime",
84 | action="store_true",
85 | help="Set modification time of each file to match its timestamp. "
86 | "Useful to be able to sort the folder by timestamp for manual verification, "
87 | "so that the file order matches the table.",
88 | )
89 | args = parser.parse_args()
90 |
91 | main(args.voice_donation_dir, set_mtime=args.set_mtime)
92 |
--------------------------------------------------------------------------------
/unmute/tts/realtime_queue.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import heapq
3 | from dataclasses import dataclass, field
4 | from typing import AsyncIterable, Callable, Iterable, TypeVar
5 |
6 | T = TypeVar("T")
7 |
8 |
9 | @dataclass(order=True)
10 | class TimedItem[T]:
11 | time: float
12 | item: T = field(compare=False)
13 |
14 | def as_tuple(self) -> tuple[float, T]:
15 | return self.time, self.item
16 |
17 |
18 | class RealtimeQueue[T]:
19 | """A data structure that accumulates timestamped items and releases them at the given times.
20 |
21 | Implemented as a heap, so it doesn't have to be FIFO.
22 | """
23 |
24 | def __init__(self, get_time: Callable[[], float] | None = None):
25 | self.queue: list[TimedItem] = []
26 | self.start_time: float | None = None
27 |
28 | if get_time is None:
29 | self.get_time = lambda: asyncio.get_event_loop().time()
30 | else:
31 | # Use an external time function to support use cases where "real time"
32 | # means something different
33 | self.get_time = get_time
34 |
35 | def start_if_not_started(self):
36 | if self.start_time is None:
37 | self.start_time = self.get_time()
38 |
39 | def put(self, item: T, time: float):
40 | heapq.heappush(self.queue, TimedItem(time, item))
41 |
42 | async def get(self) -> AsyncIterable[tuple[float, T]]:
43 | """Get all items that are past due. If none is, wait for the next one."""
44 |
45 | if self.start_time is None:
46 | return
47 | if not self.queue:
48 | return
49 |
50 | time_since_start = self.get_time() - self.start_time
51 | while self.queue:
52 | delta = self.queue[0].time - time_since_start
53 |
54 | if delta > 0:
55 | await asyncio.sleep(delta)
56 |
57 | yield heapq.heappop(self.queue).as_tuple()
58 |
59 | def get_nowait(self) -> Iterable[tuple[float, T]]:
60 | if self.start_time is None:
61 | return None
62 |
63 | time_since_start = self.get_time() - self.start_time
64 |
65 | while self.queue and self.queue[0].time <= time_since_start:
66 | yield heapq.heappop(self.queue).as_tuple()
67 |
68 | async def __aiter__(self):
69 | if self.start_time is None or not self.queue:
70 | return
71 |
72 | while self.queue:
73 | time_since_start = self.get_time() - self.start_time
74 | delta = self.queue[0].time - time_since_start
75 |
76 | if delta > 0:
77 | await asyncio.sleep(delta)
78 |
79 | yield heapq.heappop(self.queue).as_tuple()
80 |
81 | def empty(self):
82 | return not self.queue
83 |
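A small usage sketch: items are pushed with timestamps measured from the moment start_if_not_started() is called, and async iteration releases them once their time has passed, regardless of insertion order.

import asyncio

from unmute.tts.realtime_queue import RealtimeQueue


async def demo():
    queue = RealtimeQueue()
    queue.put("b", time=0.2)  # it's a heap, so insertion order doesn't matter
    queue.put("a", time=0.1)
    queue.start_if_not_started()

    async for t, item in queue:
        print(f"{t:.1f}s: {item}")  # "a" after ~0.1 s, then "b" after ~0.2 s


asyncio.run(demo())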
--------------------------------------------------------------------------------
/unmute/tts/trim_voice_donation_clip.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | import numpy as np
5 | import sphn
6 |
7 | from unmute.kyutai_constants import SAMPLE_RATE
8 |
9 |
10 | def trim_silence_end(
11 | audio: np.ndarray, threshold_db: float = -24.0, min_silence_sec: float = 1.0
12 | ) -> np.ndarray:
13 | """
14 | Trim silence from the end of the audio. Silence is defined as samples below a threshold (in dB relative to peak).
15 | """
16 | # Only operate on mono audio
17 | if audio.ndim != 1:
18 | raise ValueError("trim_silence_end expects mono audio (1D array)")
19 |
20 | peak = np.max(np.abs(audio))
21 | if peak == 0:
22 | return audio # silent audio
23 |
24 | threshold = peak * 10 ** (threshold_db / 20)
25 | window_sec = 0.1
26 | window_size = int(window_sec * SAMPLE_RATE)
27 | if window_size < 1:
28 | window_size = 1
29 |
30 | # Compute moving RMS (root mean square) over the window
31 | def moving_rms(x: np.ndarray, w: int) -> np.ndarray:
32 | # Pad with zeros at the end to keep length
33 | if x.shape[0] < w:
34 | return np.array([])
35 | cumsum = np.cumsum(np.insert(x**2, 0, 0))
36 | rms = np.sqrt((cumsum[w:] - cumsum[:-w]) / w)
37 | # Pad to match input length (pad end)
38 | pad = np.zeros(x.shape[0] - rms.shape[0])
39 | return np.concatenate([rms, pad])
40 |
41 | rms = moving_rms(audio, window_size)
42 | # Find last window above threshold
43 | for i in range(rms.shape[0] - 1, -1, -1):
44 | if rms[i] > threshold:
45 | end = min(
46 | i + window_size + int(min_silence_sec * SAMPLE_RATE), audio.shape[0]
47 | )
48 | if end < audio.shape[0]:
49 | print(
50 | "Trimming silence from end: "
51 | f"{(audio.shape[0] - end) / SAMPLE_RATE:.1f}s removed"
52 | )
53 | return audio[:end]
54 |
55 | raise ValueError("Internal error, no windows above threshold found.")
56 |
57 |
58 | def trim_trailing_silence(in_path: Path, out_path: Path | None = None) -> None:
59 | if out_path is None:
60 | out_path = in_path.with_stem(in_path.stem + "_trimmed")
61 |
62 | data, _sr = sphn.read(in_path, sample_rate=SAMPLE_RATE)
63 |
64 | if data.ndim == 2:
65 | data = np.mean(data, axis=0)
66 | elif data.ndim == 1:
67 | pass
68 | else:
69 | raise ValueError(f"Unexpected audio shape: {data.shape}")
70 |
71 | n_samples = data.shape[0]
72 |
73 | ten_sec_samples = int(SAMPLE_RATE * 10)
74 | if n_samples < ten_sec_samples:
75 | print(
76 | f"{in_path} is shorter than 10 seconds: "
77 | f"{n_samples / SAMPLE_RATE:.2f}s, not trimming"
78 | )
79 | sphn.write_wav(out_path, data, SAMPLE_RATE)
80 | return
81 |
82 | data = trim_silence_end(data)
83 |
84 | data_last10 = data[-ten_sec_samples:]
85 | if data_last10.shape[0] < ten_sec_samples:
86 | raise ValueError(
87 | "Less than 10 seconds remain after trimming silence: "
88 | f"{data_last10.shape[0] / SAMPLE_RATE:.2f}s"
89 | )
90 |
91 | sphn.write_wav(out_path, data_last10, SAMPLE_RATE)
92 | print(f"Wrote {out_path} ({data_last10.shape[0] / SAMPLE_RATE:.2f}s)")
93 |
94 |
95 | def main():
96 | parser = argparse.ArgumentParser(
97 | description="Trim last 10s and trailing silence from wav files."
98 | )
99 | parser.add_argument(
100 | "inputs", nargs="+", help="Input wav files or glob patterns (e.g. *.wav)"
101 | )
102 | args = parser.parse_args()
103 |
104 | for arg in args.inputs:
105 | in_path = Path(arg)
106 |
107 | # if already trimmed, skip
108 | if in_path.suffix == ".wav" and in_path.stem.endswith("_trimmed"):
109 | print(f"Skipping {in_path} (already trimmed)")
110 | continue
111 |
112 | if not in_path.is_file():
113 | print(f"Skipping {in_path} (not a file)")
114 | continue
115 | try:
116 | trim_trailing_silence(in_path)
117 | except ValueError as e:
118 | print(f"Error processing {in_path}: {e}")
119 | continue
120 |
121 |
122 | if __name__ == "__main__":
123 | main()
124 |
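For intuition on the default threshold in trim_silence_end(): -24 dB relative to peak corresponds to a linear factor of 10^(-24/20) ≈ 0.063, so a 0.1 s window counts as silence when its RMS falls below roughly 6.3% of the clip's peak amplitude. Worked numbers (the peak value is just an example):

threshold_db = -24.0
linear_factor = 10 ** (threshold_db / 20)  # ~0.0631
peak = 0.8                                 # example peak amplitude
threshold = peak * linear_factor           # ~0.0505; windows with RMS below this count as silence
print(f"{linear_factor:.4f} {threshold:.4f}")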
--------------------------------------------------------------------------------
/unmute/tts/voice_cloning.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import uuid
3 |
4 | import requests
5 |
6 | from unmute.cache import get_cache
7 | from unmute.kyutai_constants import VOICE_CLONING_SERVER
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | voice_embeddings_cache = get_cache(prefix="voice", ttl_seconds=60 * 60 * 1) # 1 hour
13 |
14 |
15 | def clone_voice(audio_data: bytes) -> str:
16 | # Generate a unique voice name
17 | voice_name = "custom:" + str(uuid.uuid4())
18 |
19 | # Call the voice cloning server
20 | response = requests.post(
21 | f"{VOICE_CLONING_SERVER}/api/voice",
22 | data=audio_data,
23 | headers={"Content-Type": "application/octet-stream"},
24 | )
25 | response.raise_for_status()
26 | msgpack_data = response.content
27 |
28 | logger.info(f"Received voice embedding of size: {len(msgpack_data)} bytes")
29 |
30 | voice_embeddings_cache.set(voice_name, msgpack_data)
31 | voice_embeddings_cache.cleanup()
32 |
33 | return voice_name
34 |
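A minimal usage sketch, assuming the voice-cloning server behind VOICE_CLONING_SERVER is reachable and the cache backend used by get_cache() is available; the wav path is hypothetical. The returned name is the key under which the msgpack embedding sits in the cache (for about an hour, per the TTL above).

from pathlib import Path

from unmute.tts.voice_cloning import clone_voice, voice_embeddings_cache

voice_name = clone_voice(Path("my_voice_sample.wav").read_bytes())
print(voice_name)  # e.g. "custom:3f2b..."

embedding = voice_embeddings_cache.get(voice_name)  # raw msgpack bytes from the server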
--------------------------------------------------------------------------------
/unmute/tts/voice_donation.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import functools
3 | import logging
4 | import random
5 | import time
6 | import uuid
7 | from pathlib import Path
8 | from typing import Literal
9 |
10 | from pydantic import BaseModel
11 |
12 | from unmute import metrics as mt
13 | from unmute.cache import get_cache
14 | from unmute.kyutai_constants import MAX_VOICE_FILE_SIZE_MB, VOICE_DONATION_DIR
15 |
16 | MINUTES_TO_VERIFY = 5
17 | SECONDS_IN_HOUR = 60 * 60
18 |
19 | voice_donation_verification_cache = get_cache(
20 | prefix="voice_donation_verification", ttl_seconds=SECONDS_IN_HOUR * 1
21 | )
22 |
23 | CONSTANT_PREFIX = "I consent to my voice being used for voice cloning."
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 |
28 | @functools.cache
29 | def get_sentences():
30 | with open(Path(__file__).parent / "voice_donation_sentences.txt", "r") as f:
31 | return [line.strip() for line in f if line.strip()]
32 |
33 |
34 | class VoiceDonationVerification(BaseModel):
35 | id: str
36 | text: str
37 |     created_at_timestamp: float  # seconds since epoch
38 |
39 |
40 | def generate_verification() -> VoiceDonationVerification:
41 | sentences = get_sentences()
42 | chosen_sentences = random.sample(sentences, 2)
43 | verification_text = f"{CONSTANT_PREFIX} {chosen_sentences[0]} {chosen_sentences[1]}"
44 | verification_id = uuid.uuid4()
45 |
46 | verification = VoiceDonationVerification(
47 | id=str(verification_id),
48 | text=verification_text,
49 |         created_at_timestamp=time.time(),
50 | )
51 |
52 | voice_donation_verification_cache.set(
53 | verification.id, verification.model_dump_json()
54 | )
55 | voice_donation_verification_cache.cleanup()
56 |
57 | return verification
58 |
59 |
60 | class VoiceDonationSubmission(BaseModel):
61 | format_version: Literal["1.0"] = "1.0"
62 |     # The email is kept (but not published) so that the person can contact us if
63 |     # they want to withdraw their donation.
64 | email: str
65 | nickname: str
66 | verification_id: uuid.UUID
67 |     # Only CC0 is allowed for now, but it's stored in case we decide to change that later.
68 | license: Literal["CC0"] = "CC0"
69 |
70 |
71 | class VoiceDonationMetadata(BaseModel):
72 | submission: VoiceDonationSubmission
73 | verification: VoiceDonationVerification
74 | timestamp: float
75 | timestamp_str: str # For human readability
76 |
77 |
78 | def submit_voice_donation(
79 | submission: VoiceDonationSubmission, audio_file: bytes
80 | ) -> None:
81 | file_size_mb = len(audio_file) / (1024 * 1024)
82 |
83 |     # A file this small is too short to contain all the verification sentences.
84 | if file_size_mb < 0.1:
85 | raise ValueError("Audio file is too small. Please provide a valid audio file.")
86 |
87 | # Should be checked by middleware already, but let's ensure it here too
88 | if file_size_mb > MAX_VOICE_FILE_SIZE_MB:
89 | raise ValueError(
90 | f"Audio file is too large. Maximum size is {MAX_VOICE_FILE_SIZE_MB} MB."
91 | )
92 |
93 | if len(submission.nickname) > 30:
94 | raise ValueError("Nickname is too long. Maximum length is 30 characters.")
95 |
96 | verification_raw = voice_donation_verification_cache.get(
97 | str(submission.verification_id)
98 | )
99 | if not verification_raw:
100 | raise ValueError(
101 | "Couldn't find verification data for the provided ID. "
102 | "Note you must complete the verification within "
103 | f"{MINUTES_TO_VERIFY:.0f} minutes."
104 | )
105 | verification = VoiceDonationVerification.model_validate_json(verification_raw)
106 |
107 |     sec_since_creation = time.time() - verification.created_at_timestamp
108 |
109 | if sec_since_creation > MINUTES_TO_VERIFY * 60:
110 | raise ValueError(
111 | f"Verification expired after {MINUTES_TO_VERIFY} minutes. "
112 | "Please request a new verification."
113 | )
114 |
115 | VOICE_DONATION_DIR.mkdir(parents=True, exist_ok=True)
116 | audio_file_path = VOICE_DONATION_DIR / f"{submission.verification_id}.wav"
117 | audio_file_path.write_bytes(audio_file)
118 |
119 | now = datetime.datetime.now().astimezone()
120 | metadata = VoiceDonationMetadata(
121 | submission=submission,
122 | verification=verification,
123 | timestamp=now.timestamp(),
124 | timestamp_str=now.isoformat(),
125 | )
126 | metadata_path = VOICE_DONATION_DIR / f"{submission.verification_id}.json"
127 | metadata_path.write_text(metadata.model_dump_json(indent=2))
128 |
129 | voice_donation_verification_cache.delete(str(submission.verification_id))
130 | voice_donation_verification_cache.cleanup()
131 |
132 | logger.info(
133 | f"Received voice donation with id {submission.verification_id}, "
134 | f"file size {file_size_mb:.2f} MB. "
135 | f"Saved to {audio_file_path}."
136 | )
137 | mt.VOICE_DONATION_SUBMISSIONS.inc()
138 |
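Putting the pieces together, the donation flow is: generate a verification, have the donor read the verification text aloud, then submit the recording together with the verification id within MINUTES_TO_VERIFY. A condensed sketch, assuming the cache backend is available; the email, nickname, and wav path are placeholders.

from pathlib import Path

from unmute.tts.voice_donation import (
    VoiceDonationSubmission,
    generate_verification,
    submit_voice_donation,
)

verification = generate_verification()
print(verification.text)  # the sentences the donor must read aloud

submission = VoiceDonationSubmission(
    email="donor@example.com",
    nickname="Sample",
    verification_id=verification.id,
)
submit_voice_donation(submission, Path("recording.wav").read_bytes())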
--------------------------------------------------------------------------------
/unmute/webrtc_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import requests
4 |
5 |
6 | def get_cloudflare_rtc_configuration():
7 | # see: https://fastrtc.org/deployment/#cloudflare-calls-api
8 | turn_key_id = os.environ.get("TURN_KEY_ID")
9 | turn_key_api_token = os.environ.get("TURN_KEY_API_TOKEN")
10 | ttl = 86400 # Can modify TTL, here it's set to 24 hours
11 |
12 | response = requests.post(
13 | f"https://rtc.live.cloudflare.com/v1/turn/keys/{turn_key_id}/credentials/generate-ice-servers",
14 | headers={
15 | "Authorization": f"Bearer {turn_key_api_token}",
16 | "Content-Type": "application/json",
17 | },
18 | json={"ttl": ttl},
19 | )
20 | if response.ok:
21 | return response.json()
22 | else:
23 | raise Exception(
24 | f"Failed to get TURN credentials: {response.status_code} {response.text}"
25 | )
26 |
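Usage sketch: with TURN_KEY_ID and TURN_KEY_API_TOKEN set in the environment, the returned JSON describes the ICE servers to hand to the WebRTC layer; per the fastrtc deployment docs linked above, it is typically passed as the stream's RTC configuration (the exact parameter name is fastrtc's, so check those docs).

import os

from unmute.webrtc_utils import get_cloudflare_rtc_configuration

# Placeholders; real credentials come from the Cloudflare dashboard.
os.environ.setdefault("TURN_KEY_ID", "your-turn-key-id")
os.environ.setdefault("TURN_KEY_API_TOKEN", "your-turn-api-token")

ice_config = get_cloudflare_rtc_configuration()
print(ice_config)  # dict of ICE servers with short-lived credentials (24 h TTL here)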
--------------------------------------------------------------------------------
/unmute/websocket_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | WebsocketState = Literal["not_created", "connecting", "connected", "closing", "closed"]
4 |
5 |
6 | def http_to_ws(url_string: str):
7 | """
8 | Converts an HTTP(S) URL string to a WebSocket (WS/WSS) URL string.
9 |
10 | Args:
11 | url_string: The input URL string starting with http:// or https://.
12 |
13 | Returns:
14 | The corresponding WebSocket URL string starting with ws:// or wss://.
15 | Returns the original string if it doesn't start with http:// or https://.
16 | """
17 | if url_string.startswith("http://"):
18 | return "ws://" + url_string[7:]
19 | elif url_string.startswith("https://"):
20 | return "wss://" + url_string[8:]
21 | else:
22 | return url_string
23 |
24 |
25 | def ws_to_http(url_string: str):
26 | """
27 | Converts a WebSocket (WS/WSS) URL string to an HTTP(S) URL string.
28 |
29 | Args:
30 | url_string: The input URL string starting with ws:// or wss://.
31 |
32 | Returns:
33 | The corresponding HTTP URL string starting with http:// or https://.
34 | Returns the original string if it doesn't start with ws:// or wss://.
35 | """
36 | if url_string.startswith("ws://"):
37 | return "http://" + url_string[5:]
38 | elif url_string.startswith("wss://"):
39 | return "https://" + url_string[6:]
40 | else:
41 | return url_string
42 |
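Quick examples of the two converters; URLs with other schemes pass through unchanged.

from unmute.websocket_utils import http_to_ws, ws_to_http

assert http_to_ws("http://localhost:8080/api") == "ws://localhost:8080/api"
assert http_to_ws("https://example.com/stt") == "wss://example.com/stt"
assert ws_to_http("wss://example.com/tts") == "https://example.com/tts"
assert ws_to_http("ftp://example.com") == "ftp://example.com"  # unknown scheme: unchanged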
--------------------------------------------------------------------------------
/voices.yaml:
--------------------------------------------------------------------------------
1 | - name: Watercooler
2 | good: true
3 | instructions:
4 | type: smalltalk
5 | source:
6 | source_type: file
7 | path_on_server: unmute-prod-website/p329_022.wav
8 | description: From the Device Recorded VCTK dataset.
9 | description_link: https://datashare.ed.ac.uk/handle/10283/3038
10 | - name: Quiz show
11 | comment: man, UK, skeptical
12 | good: true
13 | instructions:
14 | type: quiz_show
15 | source:
16 | source_type: freesound
17 | url: https://freesound.org/people/InspectorJ/sounds/519189/
18 | sound_instance:
19 | id: 519189
20 | name: "Request #42 - Hmm, I don't know.wav"
21 | username: InspectorJ
22 | license: https://creativecommons.org/licenses/by/4.0/
23 | path_on_server: unmute-prod-website/freesound/519189_request-42---hmm-i-dont-knowwav.mp3
24 | - name: Gertrude
25 | good: true
26 | instructions:
27 | type: constant
28 | text: Offer life advice. Be kind and sympathetic. Your name is Gertrude.
29 | source:
30 | source_type: freesound
31 | url: https://freesound.org/people/tender_buttons/sounds/440565/
32 | sound_instance:
33 | id: 440565
34 | name: Why is there education.wav
35 | username: tender_buttons
36 | license: http://creativecommons.org/licenses/by/3.0/
37 | path_on_server: unmute-prod-website/freesound/440565_why-is-there-educationwav.mp3
38 | - name: Dev (news)
39 | good: true
40 | instructions:
41 | type: news
42 | source:
43 | source_type: file
44 | path_on_server: unmute-prod-website/developer-1.mp3
45 | description: This is the voice of Václav Volhejn from Kyutai.
46 | - name: Explanation
47 | good: true
48 | instructions:
49 | type: unmute_explanation
50 | source:
51 | source_type: file
52 | path_on_server: unmute-prod-website/ex04_narration_longform_00001.wav
53 | description: This voice comes from the Expresso dataset.
54 | description_link: https://speechbot.github.io/expresso/
55 | - name: Charles
56 | good: true
57 | instructions:
58 | type: constant
59 | text: Tu es le général de Gaulle. Pour ton premier tour de parole, tu te présentes en français en 2 phrases. Si on te répond en français, tu parles en français. Si on te répond en anglais, tu parles en anglais, mais tu utilises au moins un mot français par phrase, entre guillemets français. Quand on te pose une question, tu réponds en parlant d'une anecdote historique que tu as vécu, comme une rencontre ou une discussion. Tu fais preuve d'une sensibilité particulière à la souffrance de tous les peuples du monde au cours de l'histoire. Tu utilises un langage grave et solennel.
60 | source:
61 | source_type: file
62 | path_on_server: unmute-prod-website/degaulle-2.wav
63 | description: From a recording of Charles de Gaulle's speech.
64 | description_link: https://www.youtube.com/watch?v=AUS5LHDkwP0
65 | - name: Développeuse
66 | good: true
67 | instructions:
68 | type: smalltalk
69 | language: fr
70 | source:
71 | source_type: file
72 | path_on_server: unmute-prod-website/developpeuse-3.wav
73 | description: This is the voice of one of the developers at Kyutai.
74 | - name: Fabieng
75 | good: true
76 | instructions:
77 | type: constant
78 | text: Ta langue principale est le français mais avec des anglicismes caractéristiques du jeune cadre dynamique. Tu es coach en motivation et Chief Happiness Officer dans une start-up qui fait du b2b. Tu cherches à tout optimiser dans la vie et à avoir un mindset de vainqueur.
79 | language: fr
80 | source:
81 | source_type: file
82 | path_on_server: unmute-prod-website/fabieng-enhanced-v2.wav
83 | description: Fabieng is voice acted by Neil Zeghidour from Kyutai.
84 |
--------------------------------------------------------------------------------