├── .github ├── ISSUE_TEMPLATE │ ├── bug.yml │ └── question.yml ├── PULL_REQUEST_TEMPLATE.md ├── actions │ ├── moshi_build │ │ └── action.yml │ └── rust_build │ │ └── action.yml └── workflows │ ├── precommit.yml │ ├── rust-ci.yml │ └── rustymimi-ci.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── FAQ.md ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── client ├── .env.local ├── .eslinrc.json ├── .nvmrc ├── .prettierignore ├── .prettierrc.json ├── Dockerfile ├── LICENSE ├── README.md ├── index.html ├── package-lock.json ├── package.json ├── postcss.config.js ├── public │ └── assets │ │ ├── decoderWorker.min.js │ │ ├── decoderWorker.min.wasm │ │ ├── favicon-16x16.png │ │ ├── favicon-32x32.png │ │ ├── favicon.ico │ │ ├── images │ │ └── demo │ │ │ ├── image1.jpg │ │ │ ├── image10.jpg │ │ │ ├── image11.jpg │ │ │ ├── image12.jpg │ │ │ ├── image13.jpg │ │ │ ├── image14.jpg │ │ │ ├── image15.jpg │ │ │ ├── image16.jpg │ │ │ ├── image17.jpg │ │ │ ├── image18.jpg │ │ │ ├── image19.jpg │ │ │ ├── image2.jpg │ │ │ ├── image20.jpg │ │ │ ├── image3.jpg │ │ │ ├── image4.jpg │ │ │ ├── image5.jpg │ │ │ ├── image6.jpg │ │ │ ├── image7.jpg │ │ │ ├── image8.jpg │ │ │ └── image9.jpg │ │ └── logo.svg ├── src │ ├── app.tsx │ ├── audio-processor.ts │ ├── components │ │ ├── Button │ │ │ └── Button.tsx │ │ ├── ImageGallery │ │ │ └── ImageGallery.tsx │ │ └── Input │ │ │ └── Input.tsx │ ├── decoder │ │ └── decoderWorker.ts │ ├── env.ts │ ├── index.css │ ├── modules.d.ts │ ├── pages │ │ ├── Conversation │ │ │ ├── Conversation.tsx │ │ │ ├── MediaContext.ts │ │ │ ├── SocketContext.ts │ │ │ ├── canvas-logo-black.png │ │ │ ├── canvas-logo.png │ │ │ ├── components │ │ │ │ ├── AudioVisualizer │ │ │ │ │ ├── AudioVisualizer.tsx │ │ │ │ │ ├── ClientVisualizer.tsx │ │ │ │ │ └── ServerVisualizer.tsx │ │ │ │ ├── Controls │ │ │ │ │ └── Controls.tsx │ │ │ │ ├── ModelParams │ │ │ │ │ └── ModelParams.tsx │ │ │ │ ├── ServerAudio │ │ │ │ │ ├── ServerAudio.tsx │ │ │ │ │ └── ServerAudioStats.tsx │ │ │ │ ├── ServerInfo │ │ │ │ │ └── ServerInfo.tsx │ │ │ │ ├── TextDisplay │ │ │ │ │ ├── TextDisplay.tsx │ │ │ │ │ └── TextDisplayStats.tsx │ │ │ │ └── UserAudio │ │ │ │ │ ├── UserAudio.tsx │ │ │ │ │ └── UserAudioStats.tsx │ │ │ ├── getMimeType.ts │ │ │ └── hooks │ │ │ │ ├── audioUtils.ts │ │ │ │ ├── useModelParams.ts │ │ │ │ ├── useServerAudio.ts │ │ │ │ ├── useServerInfo.ts │ │ │ │ ├── useServerText.ts │ │ │ │ ├── useSocket.ts │ │ │ │ └── useUserAudio.ts │ │ └── Queue │ │ │ ├── Queue.tsx │ │ │ ├── api │ │ │ ├── client.ts │ │ │ ├── errors │ │ │ │ ├── api_error.ts │ │ │ │ └── response_error.ts │ │ │ └── validators.ts │ │ │ └── hooks │ │ │ └── useUserEmail.ts │ └── protocol │ │ ├── encoder.ts │ │ ├── testMessages.ts │ │ └── types.ts ├── tailwind.config.js ├── tsconfig.json └── vite.config.ts ├── configs ├── moshi_7b_202409.json ├── moshi_dev_2b.json └── moshi_mlx_2b.json ├── data ├── sample_fr_hibiki_crepes.mp3 ├── sample_fr_hibiki_intro.mp3 └── sample_fr_hibiki_monologue_otis.mp3 ├── docker-compose.yml ├── mimi.png ├── moshi.png ├── moshi ├── Dockerfile ├── LICENSE ├── LICENSE.audiocraft ├── MANIFEST.in ├── README.md ├── demo_moshi.ipynb ├── moshi │ ├── __init__.py │ ├── client.py │ ├── client_gradio.py │ ├── client_utils.py │ ├── conditioners │ │ ├── __init__.py │ │ ├── base.py │ │ ├── tensors.py │ │ └── text.py │ ├── models │ │ ├── __init__.py │ │ ├── compression.py │ │ ├── lm.py │ │ ├── lm_utils.py │ │ └── loaders.py │ ├── modules │ │ ├── __init__.py │ │ ├── conv.py │ │ ├── conv_test.py │ │ ├── gating.py │ │ ├── lora.py 
│ │ ├── resample.py │ │ ├── rope.py │ │ ├── seanet.py │ │ ├── seanet_test.py │ │ ├── streaming.py │ │ └── transformer.py │ ├── quantization │ │ ├── __init__.py │ │ ├── base.py │ │ ├── core_vq.py │ │ └── vq.py │ ├── run_inference.py │ ├── server.py │ └── utils │ │ ├── __init__.py │ │ ├── autocast.py │ │ ├── compile.py │ │ ├── quantize.py │ │ ├── sampling.py │ │ └── utils.py ├── pyproject.toml ├── requirements.txt ├── setup.cfg └── tests │ ├── assets │ ├── test_lm_codes.safetensors │ ├── test_lm_model.safetensors │ └── test_lm_out.safetensors │ └── test_lm.py ├── moshi_mlx ├── LICENSE ├── MANIFEST.in ├── README.md ├── moshi_mlx │ ├── __init__.py │ ├── client_utils.py │ ├── local.py │ ├── local_web.py │ ├── models │ │ ├── __init__.py │ │ ├── generate.py │ │ ├── lm.py │ │ └── mimi.py │ ├── modules │ │ ├── __init__.py │ │ ├── conditioner.py │ │ ├── conv.py │ │ ├── kv_cache.py │ │ ├── quantization.py │ │ ├── seanet.py │ │ └── transformer.py │ ├── py.typed │ ├── run_helium.py │ ├── run_inference.py │ └── utils │ │ ├── __init__.py │ │ └── sampling.py ├── pyproject.toml ├── requirements.txt └── setup.cfg ├── requirements-dev.txt ├── rust ├── .github │ └── workflows │ │ └── rust-ci.yml ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── README.md ├── mimi-pyo3 │ ├── Cargo.toml │ ├── py_src │ │ └── rustymimi │ │ │ ├── __init__.py │ │ │ └── __init__.pyi │ ├── pyproject.toml │ ├── src │ │ └── lib.rs │ └── stub.py ├── moshi-backend │ ├── Cargo.toml │ ├── build.rs │ ├── config-q8.json │ ├── config.json │ └── src │ │ ├── audio.rs │ │ ├── benchmark.rs │ │ ├── build.rs │ │ ├── main.rs │ │ ├── standalone.rs │ │ ├── stream_both.rs │ │ └── utils.rs ├── moshi-cli │ ├── Cargo.toml │ └── src │ │ ├── audio_io.rs │ │ ├── gen.rs │ │ ├── main.rs │ │ └── multistream.rs ├── moshi-core │ ├── Cargo.toml │ └── src │ │ ├── asr.rs │ │ ├── batched_transformer.rs │ │ ├── conditioner.rs │ │ ├── conv.rs │ │ ├── kv_cache.rs │ │ ├── lib.rs │ │ ├── lm.rs │ │ ├── lm_generate.rs │ │ ├── lm_generate_multistream.rs │ │ ├── mimi.rs │ │ ├── nn.rs │ │ ├── quantization.rs │ │ ├── seanet.rs │ │ ├── streaming.rs │ │ ├── transformer.rs │ │ ├── tts.rs │ │ ├── tts_streaming.rs │ │ └── wav.rs ├── protocol.md ├── rustfmt.toml └── s2st-1b.toml └── scripts ├── __init__.py ├── export_quantized.py ├── export_torch.py ├── import_helium_mlx.py ├── import_lightformer.py ├── import_mlx.py ├── import_mlx_lora.py ├── import_pytorch.py ├── import_rust.py ├── import_rust_lora.py ├── mimi_mlx.py ├── mimi_streaming_test.py ├── moshi_benchmark.py ├── quantize_mlx.py ├── run_ci_when_installed.sh ├── setup.cfg ├── test_mimi.py ├── test_missing_data.py ├── test_missing_data_lm.py ├── test_mlx.py └── update_repo.py /.github/ISSUE_TEMPLATE/bug.yml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: You found a bug. 3 | labels: ["bug", "triage"] 4 | body: 5 | - type: dropdown 6 | id: backend 7 | attributes: 8 | label: Backend impacted 9 | description: Which backend is concerned with your bug report? 10 | options: 11 | - The PyTorch implementation 12 | - The MLX implementation 13 | - The Rust implementation 14 | - Other / All 15 | default: 0 16 | validations: 17 | required: true 18 | - type: dropdown 19 | id: os 20 | attributes: 21 | label: Operating system 22 | description: What is your operating system? 
23 | options: 24 | - Linux 25 | - Mac OS X 26 | - Windows (unsupported) 27 | default: 0 28 | validations: 29 | required: true 30 | - type: dropdown 31 | id: hardware 32 | attributes: 33 | label: Hardware 34 | description: What hardware are you using? 35 | options: 36 | - CPU 37 | - GPU with CUDA 38 | - Metal with MLX 39 | default: 0 40 | validations: 41 | required: true 42 | - type: textarea 43 | id: description 44 | attributes: 45 | label: Description 46 | description: Provide a detailed description of your bug. 47 | placeholder: 48 | value: 49 | validations: 50 | required: true 51 | - type: textarea 52 | id: more_info 53 | attributes: 54 | label: Extra information 55 | description: Please provide any other relevant information, such as log extracts, code etc. 56 | placeholder: 57 | value: 58 | validations: 59 | required: true 60 | - type: textarea 61 | id: env 62 | attributes: 63 | label: Environment 64 | description: Please fill in the following information about the system you are running on. 65 | placeholder: 66 | value: | 67 | Fill in the following information on your system. 68 | - Operating system version: 69 | 70 | If the backend impacted is PyTorch: 71 | - Python version: 72 | - PyTorch version: 73 | - CUDA version (run `python -c 'import torch; print(torch.version.cuda)'`): 74 | - GPU model and memory: 75 | 76 | If the backend is MLX: 77 | - Mac model: 78 | validations: 79 | required: true 80 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.yml: -------------------------------------------------------------------------------- 1 | name: Question 2 | description: You have a question about Moshi/Mimi or this codebase. 3 | labels: ["question", "triage"] 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Please first check the [FAQ](https://github.com/kyutai-labs/moshi/blob/main/FAQ.md). 9 | - type: checkboxes 10 | id: terms 11 | attributes: 12 | label: Due diligence 13 | description: Have you searched the existing issues / FAQ / Google / asked ChatGPT? 14 | options: 15 | - label: I have done my due diligence in trying to find the answer myself. 16 | required: true 17 | 18 | - type: dropdown 19 | id: backend 20 | attributes: 21 | label: Topic 22 | description: What is your question about? 23 | options: 24 | - The paper 25 | - The PyTorch implementation 26 | - The MLX implementation 27 | - The Rust implementation 28 | - Other / All 29 | default: 0 30 | validations: 31 | required: true 32 | - type: textarea 33 | id: question 34 | attributes: 35 | label: Question 36 | description: What is your question? 37 | placeholder: Your question. Please make sure this is directly related to our codebase. We will not provide support for installing PyTorch, CUDA, Rust etc. 38 | value: 39 | validations: 40 | required: true 41 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Checklist 2 | 3 | - [ ] Read CONTRIBUTING.md, and accept the CLA by including the provided snippet. We will not accept PRs without this. 4 | - [ ] Run the pre-commit hooks (example commands below). 5 | - [ ] If you changed Rust code, run `cargo check`, `cargo clippy`, and `cargo test`.
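For reference, the checks above can be run locally before opening a PR (a sketch; it assumes `pre-commit` is installed and that you start from the repository root):

```bash
pre-commit run --all-files                            # runs the hooks from .pre-commit-config.yaml
cd rust && cargo check && cargo clippy && cargo test  # Rust-only changes
```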
6 | 7 | ## PR Description 8 | 9 | 10 | -------------------------------------------------------------------------------- /.github/actions/moshi_build/action.yml: -------------------------------------------------------------------------------- 1 | name: moshi_build 2 | description: 'Build env.' 3 | runs: 4 | using: "composite" 5 | steps: 6 | - uses: actions/setup-python@v2 7 | with: 8 | python-version: '3.10.14' 9 | - uses: actions/cache@v3 10 | id: cache 11 | with: 12 | path: env 13 | key: env-${{ hashFiles('moshi/pyproject.toml') }} 14 | - name: Install dependencies 15 | if: steps.cache.outputs.cache-hit != 'true' 16 | shell: bash 17 | run: | 18 | python3 -m venv env 19 | . env/bin/activate 20 | python -m pip install --upgrade pip 21 | pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cpu 22 | - name: Setup env 23 | shell: bash 24 | run: | 25 | source env/bin/activate 26 | pip install -e './moshi[dev]' 27 | pre-commit install 28 | -------------------------------------------------------------------------------- /.github/actions/rust_build/action.yml: -------------------------------------------------------------------------------- 1 | name: rust_build 2 | description: 'Setup rust env' 3 | inputs: 4 | os: 5 | default: ubuntu-latest 6 | toolchain: 7 | default: stable 8 | target: 9 | default: check 10 | runs: 11 | using: "composite" 12 | steps: 13 | - uses: actions-rs/toolchain@v1 14 | with: 15 | profile: minimal 16 | toolchain: ${{ inputs.toolchain }} 17 | override: true 18 | - name: cargo cache 19 | uses: actions/cache@v3 20 | with: 21 | path: | 22 | ~/.cargo/bin/ 23 | ~/.cargo/registry/index/ 24 | ~/.cargo/registry/cache/ 25 | ~/.cargo/git/db/ 26 | rust/target/ 27 | key: ${{ inputs.os }}-cargo-${{ inputs.target }}-${{ hashFiles('**/Cargo.toml') }} 28 | restore-keys: ${{ inputs.os }}-cargo- 29 | - name: install deps 30 | shell: bash 31 | run: | 32 | sudo apt-get update 33 | sudo apt-get install libasound2-dev cmake 34 | echo "test" 35 | cmake --version 36 | apt-cache show cmake 37 | -------------------------------------------------------------------------------- /.github/workflows/precommit.yml: -------------------------------------------------------------------------------- 1 | name: precommit 2 | on: 3 | push: 4 | branches: [ main ] 5 | pull_request: 6 | branches: [ main, refacto ] 7 | 8 | jobs: 9 | run_precommit: 10 | name: Run precommit 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | - uses: ./.github/actions/moshi_build 15 | - run: | 16 | . 
env/bin/activate 17 | bash .git/hooks/pre-commit 18 | -------------------------------------------------------------------------------- /.github/workflows/rust-ci.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: [ main ] 4 | pull_request: 5 | branches: [ main, refacto ] 6 | 7 | name: Rust CI 8 | 9 | jobs: 10 | check: 11 | name: Check 12 | defaults: 13 | run: 14 | working-directory: ./rust 15 | runs-on: ${{ matrix.os }} 16 | strategy: 17 | matrix: 18 | os: [ubuntu-latest] 19 | rust: [stable] 20 | steps: 21 | - uses: actions/checkout@v2 22 | - uses: ./.github/actions/rust_build 23 | - name: check 24 | shell: bash 25 | run: | 26 | cargo check 27 | - name: clippy 28 | shell: bash 29 | run: | 30 | cargo clippy -- -D warnings 31 | - name: fmt 32 | shell: bash 33 | run: | 34 | cargo fmt --all -- --check 35 | test: 36 | name: Test 37 | defaults: 38 | run: 39 | working-directory: ./rust 40 | runs-on: ${{ matrix.os }} 41 | strategy: 42 | matrix: 43 | os: [ubuntu-latest] 44 | rust: [stable] 45 | steps: 46 | - uses: actions/checkout@v2 47 | - uses: actions/setup-python@v5 48 | with: 49 | python-version: 3.11 50 | - uses: ./.github/actions/rust_build 51 | with: 52 | target: test 53 | - name: test 54 | shell: bash 55 | run: | 56 | cargo test 57 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | target-trunk/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 
96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
161 | #.idea/ 162 | 163 | # VsCode 164 | .vscode/ 165 | 166 | *~ 167 | *.safetensors 168 | *.wav 169 | trace*.json 170 | *.flac 171 | pkg 172 | *.nsys-rep 173 | *.sqlite 174 | *.pem 175 | *.tgz 176 | *.mp3 177 | *.ogg 178 | /moshi-demo/config.sh 179 | log.* 180 | laurent/cuda-test/check 181 | client/node_modules 182 | timings.json 183 | mlx-trace.json 184 | /scripts/token.txt 185 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: local 3 | hooks: 4 | - id: flake8-moshi 5 | name: flake8 on moshi package 6 | language: system 7 | entry: bash -c 'cd moshi && flake8' 8 | pass_filenames: false 9 | always_run: true 10 | - id: pyright-moshi 11 | name: pyright on moshi package 12 | language: system 13 | entry: scripts/run_ci_when_installed.sh moshi 'cd moshi && pyright' 14 | pass_filenames: false 15 | always_run: true 16 | - id: tests-moshi 17 | name: pytests on moshi package 18 | language: system 19 | entry: scripts/run_ci_when_installed.sh moshi 'cd moshi && pytest tests' 20 | pass_filenames: false 21 | always_run: true 22 | - id: flake8-moshi_mlx 23 | name: flake8 on moshi_mlx package 24 | language: system 25 | entry: bash -c 'cd moshi_mlx && flake8' 26 | pass_filenames: false 27 | always_run: true 28 | - id: pyright-moshi_mlx 29 | name: pyright on moshi_mlx package 30 | language: system 31 | entry: scripts/run_ci_when_installed.sh moshi_mlx 'cd moshi_mlx && pyright' 32 | pass_filenames: false 33 | always_run: true 34 | 35 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Moshi 2 | 3 | ## Pull Requests 4 | 5 | Moshi is the implementation of a research paper. 6 | Therefore, we do not plan on accepting many pull requests for new features. 7 | However, we certainly welcome them for bug fixes. 8 | 9 | 1. Fork the repo and create your branch from `main`. 10 | 2. If you have changed APIs, update the documentation accordingly. 11 | 3. Ensure pre-commit hooks pass properly, in particular the linting and typing. 12 | 4. When changing the Rust code, run `cargo check`, `cargo clippy`, `cargo test`. 13 | 5. Accept the Contributor License Agreement (see after). 14 | 15 | Note that in general, we will not accept refactoring of the code. 16 | 17 | 18 | ## Contributor License Agreement ("CLA") 19 | 20 | In order to accept your pull request, we need you to submit a Contributor License Agreement. 21 | 22 | If you agree with the full CLA provided in the next paragraph, copy the following statement in your PR, changing your Github Handle: 23 | 24 | > I, {your GitHub handle}, confirm that I have read and understood the terms of the CLA of Kyutai-labs, as outlined in the repository's CONTRIBUTING.md, and I agree to be bound by these terms. 25 | 26 | The full CLA is provided as follows: 27 | 28 | > I, {your GitHub handle}, hereby grant to Kyutai-labs a perpetual, worldwide, non-exclusive, royalty-free, 29 | > irrevocable license to use, modify, distribute, and sublicense my Contributions. 30 | 31 | > I understand and accept that Contributions are limited to modifications, improvements, or changes 32 | > to the project’s source code submitted via pull requests. 
I accept that Kyutai-labs has full discretion to 33 | > review, accept, reject, or request changes to any Contributions I submit, and that submitting 34 | > a pull request does not guarantee its inclusion in the project. 35 | 36 | > By submitting a Contribution, I grant Kyutai-labs a perpetual, worldwide license to use, modify, 37 | > reproduce, distribute, and create derivative works based on my Contributions. 38 | > I also agree to assign all patent rights for any inventions or improvements that arise from my Contributions, 39 | > giving the Kyutai-labs full rights to file for and enforce patents. 40 | > I understand that the Kyutai-labs may commercialize, relicense, or exploit the project and my Contributions without further notice or obligation to me. 41 | > I confirm that my Contributions are original and that I have the legal right to grant this license. 42 | > If my Contributions include third-party materials, I will ensure that I have the necessary permissions 43 | > and will disclose this information. I accept that once my Contributions are integrated, they may be altered or removed at the Kyutai-labs’s discretion. 44 | 45 | > I acknowledge that I am making these Contributions voluntarily and will not receive any compensation. 46 | > Furthermore, I understand that all Contributions, including mine, are provided on an "as-is" basis, with no warranties. 47 | > By submitting a pull request, I agree to be bound by these terms. 48 | 49 | ## Issues 50 | 51 | Please submit issues on our GitHub repository. 52 | 53 | ## License 54 | 55 | By contributing to Moshi, you agree that your contributions will be licensed 56 | under the LICENSE-* files in the root directory of this source tree. 57 | In particular, the Rust code is licensed under Apache, and the Python code under MIT. 58 | -------------------------------------------------------------------------------- /FAQ.md: -------------------------------------------------------------------------------- 1 | # FAQ 2 | 3 | Here are the answers to a number of frequently asked questions. 4 | 5 | ### Will you release training code? 6 | 7 | Some finetuning code can be found in the [kyutai-labs/moshi-finetune repo](https://github.com/kyutai-labs/moshi-finetune). 8 | 9 | ### Will you release the dataset? 10 | 11 | We will not release the pre-training dataset. 12 | 13 | ### Is Moshi multilingual? 14 | 15 | At the moment, no: Moshi only speaks English. It has some basic support for translating some sentences 16 | or words to other languages, but you shouldn't expect to use it fully in any language other than English. 17 | 18 | ### Can I change Moshi's voice / personality? 19 | 20 | This would require fine-tuning, which is not currently supported. 21 | 22 | ### Can Moshi run on an M1, or smaller GPUs? 23 | 24 | Sadly we do not think this is currently possible. Quantizing beyond 4 bits leads to a dramatic 25 | decrease in quality, see [PR #58](https://github.com/kyutai-labs/moshi/pull/58). 26 | While we keep those limitations in mind for future versions, there is no immediate solution. 27 | 28 | ### Can we run quantized Moshi with PyTorch? 29 | 30 | At the moment no, though we might look into adding this feature when we get the time. It is 31 | however already possible to use the Rust backend, which should run in int8 with CUDA. 32 | 33 | ### Moshi stopped talking after 5 min. 34 | 35 | This is expected with the MLX and Rust implementations. 36 | We only use a fixed buffer, and we do not discard past entries.
37 | The PyTorch version should work for an unlimited time, although this is mostly untested and we 38 | expect the quality to degrade after a while (we have no attention sink or other mechanism to improve streaming 39 | beyond the finite context used at training). 40 | 41 | ### The server seems to be running but nothing happens on connect. 42 | 43 | For diagnosis, check your browser console for any errors being 44 | reported. 45 | 46 | If you see issues that look like the following: 47 | ``` 48 | Uncaught (in promise) TypeError: Cannot read properties of undefined (reading 'addModule') 49 | ``` 50 | this is likely because the http server is remote: browsers disable audio (the audio worklet API) 51 | over insecure http in that case. 52 | 53 | To get around this, tunnel port 8998 from the remote server to your local machine 54 | via ssh (an example command is given below) and access [localhost:8998](http://localhost:8998) via http normally 55 | after that. 56 | 57 | ### How to get the key.pem and cert.pem files required for serving over https? 58 | ```bash 59 | openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes -subj "/CN=localhost" 60 | ``` 61 | 62 | ### Can I run on a 12GB / 8GB GPU? 63 | For a 12GB GPU, this is possible following the instructions in [issue #54](https://github.com/kyutai-labs/moshi/issues/54). 64 | For an 8GB GPU, this is not possible at the moment. 65 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any 2 | person obtaining a copy of this software and associated 3 | documentation files (the "Software"), to deal in the 4 | Software without restriction, including without 5 | limitation the rights to use, copy, modify, merge, 6 | publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software 8 | is furnished to do so, subject to the following 9 | conditions: 10 | 11 | The above copyright notice and this permission notice 12 | shall be included in all copies or substantial portions 13 | of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 16 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 17 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 18 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT 19 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 22 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 23 | DEALINGS IN THE SOFTWARE.
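A concrete command for the ssh tunneling described in the FAQ entry above (a sketch; `user@remote-host` is a hypothetical login for the machine running the server, which is assumed to listen on port 8998):

```bash
# Forward local port 8998 to port 8998 on the remote machine,
# then open http://localhost:8998 in the local browser.
ssh -L 8998:localhost:8998 user@remote-host
```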
24 | -------------------------------------------------------------------------------- /client/.env.local: -------------------------------------------------------------------------------- 1 | VITE_QUEUE_API_PATH=/api 2 | VITE_QUEUE_API_URL=https://moshi.chat -------------------------------------------------------------------------------- /client/.eslinrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "env": { 3 | "browser": true, 4 | "es2021": true 5 | }, 6 | "extends": [ 7 | "plugin:react/recommended", 8 | "standard-with-typescript", 9 | "plugin:import/typescript", 10 | "plugin:prettier/recommended" 11 | ], 12 | "parser": "@typescript-eslint/parser", 13 | "overrides": [], 14 | "parserOptions": { 15 | "ecmaVersion": "latest", 16 | "sourceType": "module", 17 | "project": "./tsconfig.json" 18 | }, 19 | "plugins": ["react", "prettier"], 20 | "rules": { 21 | "@typescript-eslint/triple-slash-reference": "off" 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /client/.nvmrc: -------------------------------------------------------------------------------- 1 | v20.12.2 2 | -------------------------------------------------------------------------------- /client/.prettierignore: -------------------------------------------------------------------------------- 1 | dist/* -------------------------------------------------------------------------------- /client/.prettierrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "arrowParens": "avoid", 3 | "singleQuote": false, 4 | "trailingComma": "all", 5 | "tabWidth": 2, 6 | "useTabs": false, 7 | "semi": true, 8 | "printWidth": 80, 9 | "plugins": ["prettier-plugin-tailwindcss"] 10 | } 11 | -------------------------------------------------------------------------------- /client/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM node:20 2 | 3 | WORKDIR /app 4 | 5 | COPY . /app 6 | 7 | RUN npm install 8 | 9 | RUN npm run build 10 | 11 | # Install OpenSSL 12 | RUN apt-get update && apt-get install -y openssl 13 | 14 | # Generate self-signed SSL certificate 15 | RUN openssl req -x509 -nodes -days 365 -newkey rsa:2048 \ 16 | -keyout /app/key.pem -out /app/cert.pem \ 17 | -subj "/C=US/ST=State/L=City/O=Organization/CN=localhost" 18 | 19 | EXPOSE 5173 20 | 21 | CMD ["npm", "run", "dev"] 22 | -------------------------------------------------------------------------------- /client/LICENSE: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any 2 | person obtaining a copy of this software and associated 3 | documentation files (the "Software"), to deal in the 4 | Software without restriction, including without 5 | limitation the rights to use, copy, modify, merge, 6 | publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software 8 | is furnished to do so, subject to the following 9 | conditions: 10 | 11 | The above copyright notice and this permission notice 12 | shall be included in all copies or substantial portions 13 | of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 16 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 17 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 18 | PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT 19 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 22 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 23 | DEALINGS IN THE SOFTWARE. 24 | -------------------------------------------------------------------------------- /client/README.md: -------------------------------------------------------------------------------- 1 | # moshi-client 2 | 3 | Frontend for the demo. 4 | 5 | ## Run the client 6 | 7 | - Node is required. We recommend using [NVM](https://github.com/nvm-sh/nvm) to manage your Node version and make sure you're on the version recommended for this project; if you use it, run `nvm use`. 8 | - Generate a certificate and private key, `cert.pem` and `key.pem`, and copy them to the root of this package (the `openssl` command in the repository FAQ produces such a pair). 9 | - Create a `.env.local` file and add an entry for `VITE_QUEUE_API_PATH` (the default should be `/api`). 10 | - Before running the project for the first time, or after a dependency update, run `npm install`. 11 | - To run the project, use `npm run dev`. 12 | - To build the project, use `npm run build`. 13 | 14 | ## Skipping the queue 15 | To skip the queue for standalone use, once the project is running go to `/?worker_addr={WORKER_ADDR}`, where `WORKER_ADDR` is your worker instance address. 16 | For example: `https://localhost:5173/?worker_addr=0.0.0.0:8088` 17 | 18 | ## License 19 | 20 | The present code is provided under the MIT license. 21 | -------------------------------------------------------------------------------- /client/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | moshi.chat 9 | 10 | 11 |
12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /client/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "kyutai-client", 3 | "private": true, 4 | "version": "0.0.0", 5 | "type": "module", 6 | "scripts": { 7 | "dev": "vite", 8 | "build": "tsc && vite build", 9 | "lint": "eslint", 10 | "lint:fix": "eslint --fix", 11 | "prettier": "prettier --write .", 12 | "preview": "vite preview" 13 | }, 14 | "devDependencies": { 15 | "@eslint/js": "^9.3.0", 16 | "@types/react": "^18.3.1", 17 | "@types/react-dom": "^18.3.0", 18 | "@types/ws": "^8.5.10", 19 | "autoprefixer": "^10.4.19", 20 | "daisyui": "^4.12.2", 21 | "eslint": "^8.57.0", 22 | "eslint-config-prettier": "^9.1.0", 23 | "eslint-plugin-prettier": "^5.1.3", 24 | "eslint-plugin-react": "^7.34.1", 25 | "globals": "^15.2.0", 26 | "postcss": "^8.4.38", 27 | "prettier": "^3.2.5", 28 | "prettier-eslint": "^16.3.0", 29 | "prettier-plugin-tailwindcss": "^0.5.14", 30 | "tailwindcss": "^3.4.3", 31 | "typescript": "^5.2.2", 32 | "typescript-eslint": "^7.9.0", 33 | "vite": "^5.2.14", 34 | "vite-plugin-top-level-await": "^1.4.1" 35 | }, 36 | "dependencies": { 37 | "eruda": "^3.0.1", 38 | "opus-recorder": "^8.0.5", 39 | "react": "^18.3.1", 40 | "react-dom": "^18.3.1", 41 | "react-router-dom": "^6.23.1", 42 | "webm-duration-fix": "^1.0.4", 43 | "ws": "^8.16.0", 44 | "zod": "^3.23.8" 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /client/postcss.config.js: -------------------------------------------------------------------------------- 1 | export default { 2 | plugins: { 3 | tailwindcss: {}, 4 | autoprefixer: {}, 5 | }, 6 | }; 7 | -------------------------------------------------------------------------------- /client/public/assets/decoderWorker.min.wasm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/decoderWorker.min.wasm -------------------------------------------------------------------------------- /client/public/assets/favicon-16x16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/favicon-16x16.png -------------------------------------------------------------------------------- /client/public/assets/favicon-32x32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/favicon-32x32.png -------------------------------------------------------------------------------- /client/public/assets/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/favicon.ico -------------------------------------------------------------------------------- /client/public/assets/images/demo/image1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image1.jpg -------------------------------------------------------------------------------- 
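Taking the client README and the `package.json` scripts above together, a typical local workflow could look like this (a sketch; it assumes you are inside the `client/` directory and have NVM installed):

```bash
nvm use          # switch to the Node version pinned in .nvmrc
npm install      # first run, and after any dependency update
npm run dev      # start the Vite dev server (the demo expects https://localhost:5173)
npm run build    # type-check with tsc and produce a production build
```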
/client/public/assets/images/demo/image10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image10.jpg -------------------------------------------------------------------------------- /client/public/assets/images/demo/image11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image11.jpg -------------------------------------------------------------------------------- /client/public/assets/images/demo/image12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image12.jpg -------------------------------------------------------------------------------- /client/public/assets/images/demo/image13.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image13.jpg -------------------------------------------------------------------------------- /client/public/assets/images/demo/image14.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image14.jpg -------------------------------------------------------------------------------- /client/public/assets/images/demo/image15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image15.jpg -------------------------------------------------------------------------------- /client/public/assets/images/demo/image16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image16.jpg -------------------------------------------------------------------------------- /client/public/assets/images/demo/image17.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image17.jpg -------------------------------------------------------------------------------- /client/public/assets/images/demo/image18.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image18.jpg -------------------------------------------------------------------------------- /client/public/assets/images/demo/image19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image19.jpg -------------------------------------------------------------------------------- /client/public/assets/images/demo/image2.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image2.jpg -------------------------------------------------------------------------------- /client/public/assets/images/demo/image20.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image20.jpg -------------------------------------------------------------------------------- /client/public/assets/images/demo/image3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image3.jpg -------------------------------------------------------------------------------- /client/public/assets/images/demo/image4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image4.jpg -------------------------------------------------------------------------------- /client/public/assets/images/demo/image5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image5.jpg -------------------------------------------------------------------------------- /client/public/assets/images/demo/image6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image6.jpg -------------------------------------------------------------------------------- /client/public/assets/images/demo/image7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image7.jpg -------------------------------------------------------------------------------- /client/public/assets/images/demo/image8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image8.jpg -------------------------------------------------------------------------------- /client/public/assets/images/demo/image9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image9.jpg -------------------------------------------------------------------------------- /client/public/assets/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /client/src/app.tsx: -------------------------------------------------------------------------------- 1 | import ReactDOM from "react-dom/client"; 2 | import { 3 | createBrowserRouter, 4 | RouterProvider, 5 | } from "react-router-dom"; 6 | import 
"./index.css"; 7 | // @ts-expect-error - Worker is not recognized by the TS compiler 8 | import { DecoderWorker } from "./decoder/decoderWorker"; 9 | import { Queue } from "./pages/Queue/Queue"; 10 | 11 | const router = createBrowserRouter([ 12 | { 13 | path: "/", 14 | element: , 15 | }, 16 | ]); 17 | 18 | ReactDOM.createRoot(document.getElementById("root") as HTMLElement).render( 19 | 20 | ); 21 | -------------------------------------------------------------------------------- /client/src/components/Button/Button.tsx: -------------------------------------------------------------------------------- 1 | import { FC } from "react"; 2 | 3 | type ButtonProps = React.ButtonHTMLAttributes; 4 | export const Button: FC = ({ children, className, ...props }) => { 5 | return ( 6 | 12 | ); 13 | }; 14 | 15 | 16 | export const SwitchButton: FC = ({ children, className, ...props }) => { 17 | return ( 18 | 24 | ); 25 | }; 26 | -------------------------------------------------------------------------------- /client/src/components/ImageGallery/ImageGallery.tsx: -------------------------------------------------------------------------------- 1 | 2 | import { useState } from "react"; 3 | import { Button } from "../Button/Button"; 4 | 5 | // Natural images 6 | import img1 from "/assets/images/demo/image1.jpg"; 7 | import img2 from "/assets/images/demo/image2.jpg"; 8 | import img3 from "/assets/images/demo/image3.jpg"; 9 | import img4 from "/assets/images/demo/image4.jpg"; 10 | import img5 from "/assets/images/demo/image5.jpg"; 11 | import img6 from "/assets/images/demo/image6.jpg"; 12 | import img7 from "/assets/images/demo/image7.jpg"; 13 | import img8 from "/assets/images/demo/image8.jpg"; 14 | import img9 from "/assets/images/demo/image9.jpg"; 15 | import img10 from "/assets/images/demo/image10.jpg"; 16 | import img11 from "/assets/images/demo/image11.jpg"; 17 | import img12 from "/assets/images/demo/image12.jpg"; 18 | import img13 from "/assets/images/demo/image13.jpg"; 19 | import img14 from "/assets/images/demo/image14.jpg"; 20 | import img15 from "/assets/images/demo/image15.jpg"; 21 | import img16 from "/assets/images/demo/image16.jpg"; 22 | import img17 from "/assets/images/demo/image17.jpg"; 23 | import img18 from "/assets/images/demo/image18.jpg"; 24 | import img19 from "/assets/images/demo/image19.jpg"; 25 | import img20 from "/assets/images/demo/image20.jpg"; 26 | 27 | const images = [ 28 | img1, 29 | img2, 30 | img3, 31 | img4, 32 | img5, 33 | img6, 34 | img7, 35 | img8, 36 | img9, 37 | img10, 38 | img11, 39 | img12, 40 | img13, 41 | img14, 42 | img15, 43 | img16, 44 | img17, 45 | img18, 46 | img19, 47 | img20, 48 | ] 49 | 50 | var images_order: number[] = []; 51 | for (let i = 0; i < images.length; i++) { 52 | images_order.push(i) 53 | } 54 | 55 | type ImageGalleryProps = React.InputHTMLAttributes & { 56 | // Properties for the ImageGallery 57 | paramsSetter: Function; 58 | clickAction: Function; 59 | size: number; 60 | numImages: number; 61 | } 62 | 63 | 64 | type ImageItemProps = React.InputHTMLAttributes & { 65 | // Properties for a single item in the ImageGallery 66 | // Two actions: 67 | // paramsSetter sets the chosen image url into the model params 68 | // clickAction then starts the conversation 69 | paramsSetter: Function; 70 | clickAction: Function; 71 | size: number; 72 | imageUrl: string; 73 | } 74 | 75 | 76 | function ImageSelect(props: ImageItemProps) { 77 | // Represents a single image in the gallery 78 | const [isHover, setIsHover] = useState(false); 79 | 80 | const 
handleMouseEnter = () => { 81 | setIsHover(true); 82 | }; 83 | const handleMouseLeave = () => { 84 | setIsHover(false); 85 | }; 86 | let bordercolor = isHover ? "#f7a319" : "black"; 87 | let bgalpha = isHover ? 0.05 : 0.6; 88 | let textalpha = isHover ? 1.0 : 0.0 89 | let label = isHover ? "Connect" : "X"; 90 | let style = { 91 | width: props.size, 92 | height: props.size, 93 | background: `url(${props.imageUrl})`, 94 | backgroundSize: "100% 100%", 95 | border: `3px solid ${bordercolor}`, 96 | margin: "2px", 97 | padding: "0px", 98 | color: `rgba(255, 255, 255, ${textalpha})`, 99 | boxShadow: `inset 0 0 0 1000px rgba(0,0,0,${bgalpha})`, 100 | textShadow: `2px 2px 2px rgba(0, 0, 0, ${textalpha})` 101 | }; 102 | return ( 103 | /* markup reconstructed: the original button element was lost in extraction */ 104 | <button style={style} onMouseEnter={handleMouseEnter} onMouseLeave={handleMouseLeave} 105 | onClick={() => { props.paramsSetter(props.imageUrl); props.clickAction(); }}>{label}</button> 106 | ); 107 | } 108 | 109 | 110 | const shuffle = (array: number[]) => { 111 | return array.sort(() => Math.random() - 0.5); 112 | }; 113 | 114 | export const ImageGallery = (props: ImageGalleryProps) => { 115 | const [ordering, SetOrdering] = useState(images_order); 116 | 117 | function handleShuffle() { 118 | SetOrdering(shuffle([...ordering])); 119 | } 120 | 121 | // Image Gallery widget (random subset) 122 | const steps = []; 123 | for (let i = 0; i < props.numImages; i++) { 124 | steps.push(<ImageSelect key={i} size={props.size} imageUrl={images[ordering[i]]} 125 | /* props reconstructed from the ImageItemProps definition above */ 126 | paramsSetter={props.paramsSetter} clickAction={props.clickAction} />); 127 | } 128 | 129 | return ( 130 | <div className="gallery">
131 |       {/* markup reconstructed: the original attributes were lost in extraction */}
132 |       <Button onClick={handleShuffle}>
133 |         Shuffle
134 |       </Button>
135 |       <div>
136 |         {steps}
137 |       </div>
138 |     </div>
) 139 | ; 140 | }; 141 | -------------------------------------------------------------------------------- /client/src/components/Input/Input.tsx: -------------------------------------------------------------------------------- 1 | type InputProps = React.InputHTMLAttributes<HTMLInputElement> & { 2 | error?: string; 3 | } 4 | 5 | export const Input = ({className, error, ...props}:InputProps) => { 6 | return ( 7 |
8 |     <>
9 |       {/* markup reconstructed: the original wrapper and input attributes were lost in extraction */}
10 |       <input className={className} {...props} />
11 |       {error && <p>{error}</p>}
12 |
13 |     </>
14 | ); 15 | } -------------------------------------------------------------------------------- /client/src/decoder/decoderWorker.ts: -------------------------------------------------------------------------------- 1 | export const DecoderWorker = new Worker( 2 | new URL("/assets/decoderWorker.min.js", import.meta.url), 3 | ); 4 | -------------------------------------------------------------------------------- /client/src/env.ts: -------------------------------------------------------------------------------- 1 | type ENV = { 2 | VITE_QUEUE_API_PATH: string; 3 | VITE_ENV: 'development' | 'production'; 4 | }; 5 | 6 | const parseEnv = (): ENV => { 7 | const VITE_QUEUE_API_PATH = import.meta.env.VITE_QUEUE_API_PATH; 8 | 9 | if (!VITE_QUEUE_API_PATH) { 10 | throw new Error("VITE_QUEUE_API_PATH is not defined"); 11 | } 12 | 13 | return { 14 | VITE_QUEUE_API_PATH, 15 | VITE_ENV: import.meta.env.DEV ? 'development' : 'production', 16 | }; 17 | }; 18 | 19 | export const env = parseEnv(); 20 | -------------------------------------------------------------------------------- /client/src/index.css: -------------------------------------------------------------------------------- 1 | @tailwind base; 2 | @tailwind components; 3 | @tailwind utilities; 4 | 5 | @layer utilities { 6 | 7 | /* Hide scrollbar for Chrome, Safari and Opera */ 8 | .no-scrollbar::-webkit-scrollbar { 9 | display: none; 10 | } 11 | 12 | /* Hide scrollbar for IE, Edge and Firefox */ 13 | .no-scrollbar { 14 | -ms-overflow-style: none; 15 | /* IE and Edge */ 16 | scrollbar-width: none; 17 | /* Firefox */ 18 | } 19 | 20 | .scrollbar::-webkit-scrollbar { 21 | width: 10px; 22 | } 23 | 24 | .scrollbar::-webkit-scrollbar-track { 25 | background: transparent; 26 | } 27 | 28 | .scrollbar::-webkit-scrollbar-thumb { 29 | background: white; 30 | border: 3px solid #f6f7ed; 31 | } 32 | } 33 | 34 | .main-grid { 35 | display: grid; 36 | grid-template-columns: 1fr; 37 | grid-template-rows: min-content 1fr 1fr; 38 | gap: 30px; 39 | grid-auto-flow: column; 40 | grid-template-areas: 41 | "controls" 42 | "player" 43 | "player-text"; 44 | 45 | @media screen and (min-width: 768px) { 46 | grid-template-columns: 2fr 2.5fr; 47 | grid-template-rows: min-content min-content min-content 1fr; 48 | gap: 30px 30px; 49 | grid-auto-flow: column; 50 | align-items: center; 51 | justify-items: center; 52 | grid-template-areas: 53 | "controls controls" 54 | "player player-stats" 55 | "player player-text" 56 | "player player-text"; 57 | } 58 | } 59 | 60 | .presentation { 61 | max-width: 450px; 62 | } 63 | 64 | .presentation>p { 65 | padding-top: 10px; 66 | } 67 | 68 | 69 | .gallery { 70 | max-width: 450px; 71 | } 72 | 73 | .cute-words { 74 | color: #54e8b3; 75 | } 76 | 77 | 78 | .download-links { 79 | color: #54e8b3; 80 | } 81 | 82 | .explain-links { 83 | color: #BCFCE5; 84 | } 85 | 86 | 87 | .controls { 88 | grid-area: controls; 89 | } 90 | 91 | .player { 92 | grid-area: player; 93 | grid-template-areas: 94 | "server-audio" 95 | "user-audio"; 96 | display: grid; 97 | grid-template-columns: 1fr; 98 | grid-template-rows: 1fr 1fr; 99 | align-items: center; 100 | justify-items: center; 101 | /* margin:auto; */ 102 | } 103 | 104 | .server-audio { 105 | grid-area: server-audio; 106 | } 107 | 108 | .user-audio { 109 | grid-area: user-audio; 110 | } 111 | 112 | .player-stats { 113 | grid-area: player-stats; 114 | width: 100%; 115 | height: 100%; 116 | } 117 | 118 | .commands { 119 | grid-area: commands; 120 | width: 100%; 121 | height: 100%; 122 | } 123 | 124 | .player-text { 125 
| grid-area: player-text; 126 | width: 100%; 127 | height: 100%; 128 | overflow: scroll; 129 | } -------------------------------------------------------------------------------- /client/src/modules.d.ts: -------------------------------------------------------------------------------- 1 | declare module "opus-recorder"; 2 | -------------------------------------------------------------------------------- /client/src/pages/Conversation/MediaContext.ts: -------------------------------------------------------------------------------- 1 | import { MutableRefObject, createContext, useContext } from "react"; 2 | type MediaContextType = { 3 | startRecording: () => void; 4 | stopRecording: () => void; 5 | audioContext: MutableRefObject; 6 | audioStreamDestination: MutableRefObject; 7 | worklet: MutableRefObject; 8 | micDuration: MutableRefObject; 9 | actualAudioPlayed: MutableRefObject; 10 | }; 11 | 12 | export const MediaContext = createContext(null); 13 | 14 | export const useMediaContext = () => { 15 | const context = useContext(MediaContext); 16 | if (!context) { 17 | throw new Error( 18 | "useMediaContext must be used within a MediaContextProvider", 19 | ); 20 | } 21 | 22 | return context; 23 | }; -------------------------------------------------------------------------------- /client/src/pages/Conversation/SocketContext.ts: -------------------------------------------------------------------------------- 1 | import { createContext, useContext } from "react"; 2 | import { WSMessage } from "../../protocol/types"; 3 | 4 | type SocketContextType = { 5 | isConnected: boolean; 6 | socket: WebSocket | null; 7 | sendMessage: (message: WSMessage) => void; 8 | }; 9 | 10 | export const SocketContext = createContext({ 11 | isConnected: false, 12 | socket: null, 13 | sendMessage: () => {}, 14 | }); 15 | 16 | export const useSocketContext = () => { 17 | return useContext(SocketContext); 18 | }; 19 | -------------------------------------------------------------------------------- /client/src/pages/Conversation/canvas-logo-black.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/src/pages/Conversation/canvas-logo-black.png -------------------------------------------------------------------------------- /client/src/pages/Conversation/canvas-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/src/pages/Conversation/canvas-logo.png -------------------------------------------------------------------------------- /client/src/pages/Conversation/components/AudioVisualizer/AudioVisualizer.tsx: -------------------------------------------------------------------------------- 1 | import { FC, useCallback, useEffect, useRef } from "react"; 2 | 3 | type AudioVisualizerProps = { 4 | analyser: AnalyserNode | null; 5 | }; 6 | 7 | export const AudioVisualizer: FC = ({ analyser }) => { 8 | const requestRef = useRef(null); 9 | const canvasRef = useRef(null); 10 | 11 | const visualizeData = useCallback(() => { 12 | requestRef.current = window.requestAnimationFrame(() => visualizeData()); 13 | if (!canvasRef.current) { 14 | console.log("Canvas not found"); 15 | return; 16 | } 17 | const audioData = new Uint8Array(140); 18 | analyser?.getByteFrequencyData(audioData); 19 | const bar_width = 3; 20 | let start = 0; 21 | const ctx = 
canvasRef.current.getContext("2d"); 22 | if (!ctx) { 23 | console.log("Canvas context not found"); 24 | return; 25 | } 26 | ctx.clearRect(0, 0, canvasRef.current.width, canvasRef.current.height); 27 | for (let i = 0; i < audioData.length; i++) { 28 | start = i * 4; 29 | let gradient = ctx.createLinearGradient( 30 | 0, 31 | 0, 32 | canvasRef.current.width, 33 | canvasRef.current.height, 34 | ); 35 | gradient.addColorStop(0.2, "#2392f5"); 36 | gradient.addColorStop(0.5, "#fe0095"); 37 | gradient.addColorStop(1.0, "purple"); 38 | ctx.fillStyle = gradient; 39 | ctx.fillRect( 40 | start, 41 | canvasRef.current.height, 42 | bar_width, 43 | (-audioData[i] * 100) / 255, 44 | ); 45 | } 46 | }, [analyser]); 47 | 48 | const resetCanvas = useCallback(() => { 49 | if (!canvasRef.current) { 50 | return; 51 | } 52 | const ctx = canvasRef.current.getContext("2d"); 53 | if (!ctx) { 54 | return; 55 | } 56 | ctx.clearRect(0, 0, canvasRef.current.width, canvasRef.current.height); 57 | }, []); 58 | 59 | useEffect(() => { 60 | if (!analyser) { 61 | return; 62 | } 63 | visualizeData(); 64 | return () => { 65 | if (requestRef.current) { 66 | console.log("Canceling animation frame"); 67 | cancelAnimationFrame(requestRef.current); 68 | } 69 | }; 70 | }, [visualizeData, analyser, resetCanvas]); 71 | 72 | return ; 73 | }; 74 | -------------------------------------------------------------------------------- /client/src/pages/Conversation/components/AudioVisualizer/ClientVisualizer.tsx: -------------------------------------------------------------------------------- 1 | import { FC, RefObject, useCallback, useEffect, useRef, useState } from "react"; 2 | import { clamp } from "../../hooks/audioUtils"; 3 | 4 | type AudioVisualizerProps = { 5 | analyser: AnalyserNode | null; 6 | parent: RefObject; 7 | copyCanvasRef: RefObject; 8 | }; 9 | 10 | const MAX_INTENSITY = 255; 11 | 12 | const COLORS = [ 13 | "#197556", 14 | "#299e77", 15 | "#32b89b", 16 | "#31d4b8", 17 | "#14d9d5", 18 | "#41eff2", 19 | "#7ff3f5", 20 | "#789bf5", 21 | "#eb94eb", 22 | "#e63280", 23 | "#c41862", 24 | ]; 25 | 26 | export const ClientVisualizer: FC = ({ analyser, parent, copyCanvasRef }) => { 27 | const [canvasWidth, setCanvasWidth] = useState(parent.current ? 
Math.min(parent.current.clientWidth, parent.current.clientHeight) : 0); 28 | const requestRef = useRef(null); 29 | const canvasRef = useRef(null); 30 | 31 | const drawBars = useCallback( 32 | ( 33 | ctx: CanvasRenderingContext2D, 34 | x: number, 35 | y: number, 36 | volume: number, 37 | height: number, 38 | width: number, 39 | gap: number, 40 | ) => { 41 | const barHeight = height / 10 - gap; 42 | for (let i = 1; i <= 10; i++) { 43 | const barY = y + height + gap + Math.min(1, width / 30) - (i * barHeight + i * gap); 44 | ctx.fillStyle = COLORS[i - 1]; 45 | ctx.strokeStyle = "white"; 46 | ctx.lineWidth = Math.min(1, height / 100); 47 | if (i <= volume) { 48 | ctx.fillRect(x, barY, width, barHeight); 49 | } 50 | ctx.strokeRect(x, barY, width, barHeight); 51 | } 52 | }, 53 | [], 54 | ); 55 | 56 | const draw = useCallback((ctx: CanvasRenderingContext2D, audioData: Uint8Array, x: number, y: number, width: number, height: number) => { 57 | const stereoGap = Math.floor(width / 30); 58 | const barGap = Math.floor(height / 30); 59 | const padding = Math.floor(width / 30); 60 | const maxBarHeight = Math.floor(height - padding * 2); 61 | const maxBarWidth = Math.floor( 62 | width / 2.5 - stereoGap - padding * 2, 63 | ); 64 | 65 | const centerX = x + width / 2; 66 | const averageIntensity = Math.sqrt( 67 | audioData.reduce((acc, curr) => acc + curr * curr, 0) / audioData.length, 68 | ); 69 | const intensity = clamp( 70 | averageIntensity * 1.4, 71 | averageIntensity, 72 | MAX_INTENSITY, 73 | ); 74 | const volume = Math.floor((intensity * 10) / MAX_INTENSITY); 75 | ctx.fillStyle = "rgba(0, 0, 0, 0)"; 76 | ctx.fillRect(x, y, width, height); 77 | drawBars( 78 | ctx, 79 | centerX - maxBarWidth - stereoGap / 2, 80 | y, 81 | volume, 82 | maxBarHeight, 83 | maxBarWidth, 84 | barGap, 85 | ); 86 | drawBars( 87 | ctx, 88 | centerX + stereoGap / 2, 89 | y, 90 | volume, 91 | maxBarHeight, 92 | maxBarWidth, 93 | barGap, 94 | ); 95 | }, [analyser, drawBars]); 96 | 97 | const visualizeData = useCallback(() => { 98 | const width = parent.current ? 
Math.min(parent.current.clientWidth, parent.current.clientHeight) : 0 99 | if (width !== canvasWidth) { 100 | console.log("Setting canvas width"); 101 | setCanvasWidth(width); 102 | } 103 | requestRef.current = window.requestAnimationFrame(() => visualizeData()); 104 | if (!canvasRef.current) { 105 | console.log("Canvas not found"); 106 | return; 107 | } 108 | const audioData = new Uint8Array(140); 109 | analyser?.getByteFrequencyData(audioData); 110 | 111 | const ctx = canvasRef.current.getContext("2d"); 112 | if (!ctx) { 113 | console.log("Canvas context not found"); 114 | return; 115 | } 116 | ctx.clearRect(0, 0, canvasRef.current.width, canvasRef.current.height); 117 | draw(ctx, audioData, 0, 0, width, width); 118 | if (copyCanvasRef?.current) { 119 | const copyCtx = copyCanvasRef.current.getContext("2d"); 120 | if (copyCtx) { 121 | copyCtx.clearRect(220, 40, 140, 180); 122 | draw(copyCtx, audioData, 220, 40, 140, 180); 123 | } 124 | } 125 | }, [analyser, canvasWidth, drawBars, parent, copyCanvasRef, draw]); 126 | 127 | useEffect(() => { 128 | visualizeData(); 129 | return () => { 130 | if (requestRef.current) { 131 | console.log("Canceling animation frame"); 132 | cancelAnimationFrame(requestRef.current); 133 | } 134 | }; 135 | }, [visualizeData, analyser]); 136 | return ( 137 | 143 | ); 144 | }; 145 | -------------------------------------------------------------------------------- /client/src/pages/Conversation/components/Controls/Controls.tsx: -------------------------------------------------------------------------------- 1 | import { 2 | controlBOSMessage, 3 | controlEOSMessage, 4 | } from "../../../../protocol/testMessages"; 5 | import { useSocketContext } from "../../SocketContext"; 6 | import { Button } from "../../../../components/Button/Button"; 7 | 8 | export const Controls = () => { 9 | const { sendMessage } = useSocketContext(); 10 | 11 | const sendControlBOS = () => { 12 | sendMessage(controlBOSMessage); 13 | }; 14 | 15 | const sendControlEOS = () => { 16 | sendMessage(controlEOSMessage); 17 | }; 18 | return ( 19 |
<div>
20 | <Button onClick={sendControlBOS}>Send BOS</Button>
23 | <Button onClick={sendControlEOS}>Send EOS</Button>
26 | </div>
27 | ); 28 | }; 29 | -------------------------------------------------------------------------------- /client/src/pages/Conversation/components/ModelParams/ModelParams.tsx: -------------------------------------------------------------------------------- 1 | import { FC, RefObject } from "react"; 2 | import { useModelParams } from "../../hooks/useModelParams"; 3 | import { Button } from "../../../../components/Button/Button"; 4 | 5 | type ModelParamsProps = { 6 | isConnected: boolean; 7 | isImageMode: boolean; 8 | modal?: RefObject, 9 | } & ReturnType; 10 | export const ModelParams: FC = ({ 11 | textTemperature, 12 | textTopk, 13 | audioTemperature, 14 | audioTopk, 15 | padMult, 16 | repetitionPenalty, 17 | repetitionPenaltyContext, 18 | imageResolution, 19 | setTextTemperature, 20 | setTextTopk, 21 | setAudioTemperature, 22 | setAudioTopk, 23 | setPadMult, 24 | setRepetitionPenalty, 25 | setRepetitionPenaltyContext, 26 | setImageResolution, 27 | resetParams, 28 | isConnected, 29 | isImageMode, 30 | modal, 31 | }) => { 32 | return ( 33 |
<div>
34 | <table>
35 | <tbody>
36 | <tr><td>Text temperature:</td><td>{textTemperature}</td><td><input type="range" value={textTemperature} onChange={e => setTextTemperature(parseFloat(e.target.value))} /></td></tr>
41 | <tr><td>Text topk:</td><td>{textTopk}</td><td><input type="range" value={textTopk} onChange={e => setTextTopk(parseInt(e.target.value))} /></td></tr>
46 | <tr><td>Audio temperature:</td><td>{audioTemperature}</td><td><input type="range" value={audioTemperature} onChange={e => setAudioTemperature(parseFloat(e.target.value))} /></td></tr>
51 | <tr><td>Audio topk:</td><td>{audioTopk}</td><td><input type="range" value={audioTopk} onChange={e => setAudioTopk(parseInt(e.target.value))} /></td></tr>
56 | <tr><td>Padding multiplier:</td><td>{padMult}</td><td><input type="range" value={padMult} onChange={e => setPadMult(parseFloat(e.target.value))} /></td></tr>
61 | <tr><td>Repeat penalty:</td><td>{repetitionPenalty}</td><td><input type="range" value={repetitionPenalty} onChange={e => setRepetitionPenalty(parseFloat(e.target.value))} /></td></tr>
66 | <tr><td>Repeat penalty last N:</td><td>{repetitionPenaltyContext}</td><td><input type="range" value={repetitionPenaltyContext} onChange={e => setRepetitionPenaltyContext(parseFloat(e.target.value))} /></td></tr>
71 | {isImageMode &&
72 | <tr><td>Image max-side (px):</td><td>{imageResolution}</td><td><input type="range" value={imageResolution} onChange={e => setImageResolution(parseFloat(e.target.value))} /></td></tr>
77 | }
78 | </tbody>
79 | </table>
80 | <div>
81 | {!isConnected && <Button onClick={resetParams}>Reset to defaults</Button>}
82 | {!isConnected && <Button onClick={() => modal?.current?.close()}>Close</Button>}
83 | </div>
84 | </div>
85 | ) 86 | }; 87 | -------------------------------------------------------------------------------- /client/src/pages/Conversation/components/ServerAudio/ServerAudio.tsx: -------------------------------------------------------------------------------- 1 | import { FC, useRef } from "react"; 2 | import { AudioStats, useServerAudio } from "../../hooks/useServerAudio"; 3 | import { ServerVisualizer } from "../AudioVisualizer/ServerVisualizer"; 4 | 5 | type ServerAudioProps = { 6 | setGetAudioStats: (getAudioStats: () => AudioStats) => void; 7 | imageUrl: string | undefined; 8 | copyCanvasRef?: React.RefObject; 9 | }; 10 | export const ServerAudio: FC = ({ setGetAudioStats, imageUrl, copyCanvasRef }) => { 11 | const { analyser, hasCriticalDelay, setHasCriticalDelay } = useServerAudio({ 12 | setGetAudioStats, 13 | }); 14 | const containerRef = useRef(null); 15 | return ( 16 | <> 17 | {hasCriticalDelay && ( 18 |
<div>
19 | <p>A connection issue has been detected, you've been reconnected</p>
20 | <button onClick={() => setHasCriticalDelay(false)}>Dismiss</button>
28 | </div>
29 | )}
30 | <div ref={containerRef}>
31 | <ServerVisualizer analyser={analyser} parent={containerRef} imageUrl={imageUrl} copyCanvasRef={copyCanvasRef} />
32 | </div>
33 | 34 | ); 35 | }; 36 | -------------------------------------------------------------------------------- /client/src/pages/Conversation/components/ServerAudio/ServerAudioStats.tsx: -------------------------------------------------------------------------------- 1 | import { useState, useEffect, useRef } from "react"; 2 | 3 | type ServerAudioStatsProps = { 4 | getAudioStats: React.MutableRefObject< 5 | () => { 6 | playedAudioDuration: number; 7 | missedAudioDuration: number; 8 | totalAudioMessages: number; 9 | delay: number; 10 | minPlaybackDelay: number; 11 | maxPlaybackDelay: number; 12 | } 13 | >; 14 | }; 15 | 16 | export const ServerAudioStats = ({ getAudioStats }: ServerAudioStatsProps) => { 17 | const [audioStats, setAudioStats] = useState(getAudioStats.current()); 18 | 19 | const movingAverageSum = useRef(0.); 20 | const movingAverageCount = useRef(0.); 21 | const movingBeta = 0.85; 22 | 23 | let convertMinSecs = (total_secs: number) => { 24 | // convert secs to the format mm:ss.cc 25 | let mins = (Math.floor(total_secs / 60)).toString(); 26 | let secs = (Math.floor(total_secs) % 60).toString(); 27 | let cents = (Math.floor(100 * (total_secs - Math.floor(total_secs)))).toString(); 28 | if (secs.length < 2) { 29 | secs = "0" + secs; 30 | } 31 | if (cents.length < 2) { 32 | cents = "0" + cents; 33 | } 34 | return mins + ":" + secs + "." + cents; 35 | }; 36 | 37 | useEffect(() => { 38 | const interval = setInterval(() => { 39 | const newAudioStats = getAudioStats.current(); 40 | setAudioStats(newAudioStats); 41 | movingAverageCount.current *= movingBeta; 42 | movingAverageCount.current += (1 - movingBeta) * 1; 43 | movingAverageSum.current *= movingBeta; 44 | movingAverageSum.current += (1 - movingBeta) * newAudioStats.delay; 45 | 46 | }, 141); 47 | return () => { 48 | clearInterval(interval); 49 | }; 50 | }, []); 51 | 52 | return ( 53 |
<div>
54 | <div>Server Audio Stats</div>
55 | <table>
56 | <tbody>
57 | <tr><td>Audio played:</td><td>{convertMinSecs(audioStats.playedAudioDuration)}</td></tr>
61 | <tr><td>Missed audio:</td><td>{convertMinSecs(audioStats.missedAudioDuration)}</td></tr>
65 | <tr><td>Latency:</td><td>{(movingAverageSum.current / movingAverageCount.current).toFixed(3)}</td></tr>
69 | <tr><td>Min/Max buffer:</td><td>{audioStats.minPlaybackDelay.toFixed(3)} / {audioStats.maxPlaybackDelay.toFixed(3)}</td></tr>
73 | </tbody>
74 | </table>
75 | </div>
76 | ); 77 | }; 78 | -------------------------------------------------------------------------------- /client/src/pages/Conversation/components/ServerInfo/ServerInfo.tsx: -------------------------------------------------------------------------------- 1 | import { useServerInfo } from "../../hooks/useServerInfo"; 2 | 3 | export const ServerInfo = () => { 4 | const { serverInfo } = useServerInfo(); 5 | if (!serverInfo) { 6 | return null; 7 | } 8 | return ( 9 |
<div>
10 | Our server is running on the following configuration:
11 | <div>Text temperature: {serverInfo.text_temperature}</div>
12 | <div>Text topk: {serverInfo.text_topk}</div>
13 | <div>Audio temperature: {serverInfo.audio_temperature}</div>
14 | <div>Audio topk: {serverInfo.audio_topk}</div>
15 | <div>Pad mult: {serverInfo.pad_mult}</div>
16 | <div>Repeat penalty last N: {serverInfo.repetition_penalty_context}</div>
17 | <div>Repeat penalty: {serverInfo.repetition_penalty}</div>
18 | <div>LM model file: {serverInfo.lm_model_file}</div>
19 | <div>Instance name: {serverInfo.instance_name}</div>
20 | </div>
21 | ); 22 | }; 23 | -------------------------------------------------------------------------------- /client/src/pages/Conversation/components/TextDisplay/TextDisplay.tsx: -------------------------------------------------------------------------------- 1 | import { FC, useEffect, useRef } from "react"; 2 | import { useServerText } from "../../hooks/useServerText"; 3 | 4 | type TextDisplayProps = { 5 | containerRef: React.RefObject; 6 | displayColor: boolean | undefined; 7 | }; 8 | 9 | // Palette 2: Purple to Green Moshi 10 | // sns.diverging_palette(288, 145, s=90, l=72, n=11) 11 | const textDisplayColors = [ 12 | "#d19bf7", "#d7acf6", "#debdf5", "#e4cef4", 13 | "#ebe0f3", "#eef2f0", "#c8ead9", "#a4e2c4", 14 | "#80d9af", "#5bd09a", "#38c886"] 15 | 16 | function clamp_color(v: number) { 17 | return v <= 0 18 | ? 0 19 | : v >= textDisplayColors.length 20 | ? textDisplayColors.length 21 | : v 22 | } 23 | 24 | export const TextDisplay: FC = ({ 25 | containerRef, displayColor 26 | }) => { 27 | const { text, textColor } = useServerText(); 28 | const currentIndex = text.length - 1; 29 | const prevScrollTop = useRef(0); 30 | 31 | useEffect(() => { 32 | if (containerRef.current) { 33 | prevScrollTop.current = containerRef.current.scrollTop; 34 | containerRef.current.scroll({ 35 | top: containerRef.current.scrollHeight, 36 | behavior: "smooth", 37 | }); 38 | } 39 | }, [text]); 40 | if (displayColor && (textColor.length == text.length)) { 41 | return ( 42 |
<div>
43 | {text.map((t, i) => (
44 | <span key={i} style={{ backgroundColor: textDisplayColors[clamp_color(textColor[i])] }}>
51 | {t}
52 | </span>
53 | ))
54 | }
</div>
56 | ); 57 | } 58 | else { 59 | return ( 60 |
<div>
61 | {text.map((t, i) => (
62 | <span key={i}>
66 | {t}
67 | </span>
68 | ))}
69 | </div>
70 | ); 71 | }; 72 | }; 73 | -------------------------------------------------------------------------------- /client/src/pages/Conversation/components/TextDisplay/TextDisplayStats.tsx: -------------------------------------------------------------------------------- 1 | import { FC } from "react"; 2 | 3 | type TextDisplayStatsProps = { 4 | totalTextMessages: number; 5 | }; 6 | export const TextDisplayStats: FC = ({ 7 | totalTextMessages, 8 | }) => { 9 | return ( 10 |
<div>
11 | <div>Text Display Stats</div>
12 | <div>
13 | <div>
14 | <div>Total messages:</div>
15 | <div>{totalTextMessages}</div>
16 | </div>
17 | </div>
18 | </div>
19 | ); 20 | }; 21 | -------------------------------------------------------------------------------- /client/src/pages/Conversation/components/UserAudio/UserAudio.tsx: -------------------------------------------------------------------------------- 1 | import { FC, useCallback, useEffect, useRef, useState } from "react"; 2 | import { useSocketContext } from "../../SocketContext"; 3 | import { useUserAudio } from "../../hooks/useUserAudio"; 4 | import { ClientVisualizer } from "../AudioVisualizer/ClientVisualizer"; 5 | 6 | type UserAudioProps = { 7 | copyCanvasRef: React.RefObject; 8 | }; 9 | export const UserAudio: FC = ({ copyCanvasRef }) => { 10 | const [analyser, setAnalyser] = useState(null); 11 | const { sendMessage, isConnected } = useSocketContext(); 12 | const containerRef = useRef(null); 13 | const onRecordingStart = useCallback(() => { 14 | console.log("Recording started"); 15 | }, []); 16 | 17 | const onRecordingStop = useCallback(() => { 18 | console.log("Recording stopped"); 19 | }, []); 20 | 21 | const onRecordingChunk = useCallback( 22 | (chunk: Uint8Array) => { 23 | if (!isConnected) { 24 | return; 25 | } 26 | sendMessage({ 27 | type: "audio", 28 | data: chunk, 29 | }); 30 | }, 31 | [sendMessage, isConnected], 32 | ); 33 | 34 | const { startRecordingUser, stopRecording } = useUserAudio({ 35 | constraints: { 36 | audio: { 37 | echoCancellation: true, 38 | noiseSuppression: true, 39 | autoGainControl: true, 40 | channelCount: 1, 41 | }, 42 | video: false, 43 | }, 44 | onDataChunk: onRecordingChunk, 45 | onRecordingStart, 46 | onRecordingStop, 47 | }); 48 | 49 | useEffect(() => { 50 | let res: Awaited>; 51 | if (isConnected) { 52 | startRecordingUser().then(result => { 53 | if (result) { 54 | res = result; 55 | setAnalyser(result.analyser); 56 | } 57 | }); 58 | } 59 | return () => { 60 | console.log("Stop recording called from somewhere else."); 61 | stopRecording(); 62 | res?.source?.disconnect(); 63 | }; 64 | }, [startRecordingUser, stopRecording, isConnected]); 65 | 66 | return ( 67 |
<div ref={containerRef}>
68 | <ClientVisualizer analyser={analyser} parent={containerRef} copyCanvasRef={copyCanvasRef} />
69 | </div>
70 | ); 71 | }; 72 | -------------------------------------------------------------------------------- /client/src/pages/Conversation/components/UserAudio/UserAudioStats.tsx: -------------------------------------------------------------------------------- 1 | import { FC } from "react"; 2 | 3 | type UserAudioStatsProps = { 4 | sentMessagesCount: number; 5 | }; 6 | 7 | export const UserAudioStats: FC = ({ 8 | sentMessagesCount, 9 | }) => { 10 | return ( 11 |
<div>
12 | <div>User Audio Stats</div>
13 | <div>
14 | <div>
15 | <div>Total messages:</div>
16 | <div>{sentMessagesCount}</div>
17 | </div>
18 | </div>
19 | </div>
20 | ); 21 | }; 22 | -------------------------------------------------------------------------------- /client/src/pages/Conversation/getMimeType.ts: -------------------------------------------------------------------------------- 1 | export const mimeTypeCheck = () => { 2 | const types = [ 3 | "audio/ogg", 4 | "audio/wav", 5 | "audio/webm;codecs=opus", 6 | "audio/webm;codecs=pcm", 7 | "audio/webm;codecs=pcm_s16le", 8 | "audio/webm;codecs=pcm_f32le", 9 | "audio/mp3", 10 | "audio/aac", 11 | "audio/mp4", 12 | "audio/webm", 13 | "audio/mpeg", 14 | "video/mp4", 15 | "video/webm;codecs=vp9", 16 | "video/webm;codecs=vp8", 17 | "video/webm", 18 | ]; 19 | for (const mime of types) { 20 | console.log(mime, MediaRecorder.isTypeSupported(mime)); 21 | } 22 | } 23 | 24 | const getVideoMimeType = () => { 25 | if (!MediaRecorder.isTypeSupported){ 26 | return "video/mp4"; 27 | } 28 | if (MediaRecorder.isTypeSupported("video/webm")) { 29 | return "video/webm"; 30 | } 31 | if (MediaRecorder.isTypeSupported("video/mp4")) { 32 | return "video/mp4"; 33 | } 34 | console.log("No supported video mime type found") 35 | return ""; 36 | }; 37 | 38 | const getAudioMimeType = () => { 39 | if (!MediaRecorder.isTypeSupported){ 40 | return "audio/mp4"; 41 | } 42 | if (MediaRecorder.isTypeSupported("audio/webm")) { 43 | return "audio/webm"; 44 | } 45 | if (MediaRecorder.isTypeSupported("audio/mpeg")) { 46 | return "audio/mpeg"; 47 | }`` 48 | if (MediaRecorder.isTypeSupported("audio/mp4")) { 49 | return "audio/mp4"; 50 | } 51 | console.log("No supported audio mime type found") 52 | return ""; 53 | } 54 | 55 | export const getMimeType = (type: "audio" | "video") => { 56 | if(type === "audio") { 57 | return getAudioMimeType(); 58 | } 59 | return getVideoMimeType(); 60 | } 61 | 62 | export const getExtension = (type: "audio" | "video") => { 63 | if(getMimeType(type).includes("mp4")) { 64 | return "mp4"; 65 | } 66 | if(getMimeType(type).includes("mpeg")) { 67 | return "mp3"; 68 | } 69 | return "webm"; 70 | } -------------------------------------------------------------------------------- /client/src/pages/Conversation/hooks/audioUtils.ts: -------------------------------------------------------------------------------- 1 | export const clamp = (value: number, min: number, max: number) => { 2 | return Math.min(Math.max(value, min), max); 3 | }; 4 | -------------------------------------------------------------------------------- /client/src/pages/Conversation/hooks/useServerInfo.ts: -------------------------------------------------------------------------------- 1 | import { useCallback, useEffect, useState } from "react"; 2 | import { useSocketContext } from "../SocketContext"; 3 | import { decodeMessage } from "../../../protocol/encoder"; 4 | import { z } from "zod"; 5 | 6 | const ServersInfoSchema = z.object({ 7 | text_temperature: z.number(), 8 | text_topk: z.number(), 9 | audio_temperature: z.number(), 10 | audio_topk: z.number(), 11 | pad_mult: z.number(), 12 | repetition_penalty_context: z.number(), 13 | repetition_penalty: z.number(), 14 | lm_model_file: z.string(), 15 | instance_name: z.string(), 16 | build_info: z.object({ 17 | build_timestamp: z.string(), 18 | build_date: z.string(), 19 | git_branch: z.string(), 20 | git_timestamp: z.string(), 21 | git_date: z.string(), 22 | git_hash: z.string(), 23 | git_describe: z.string(), 24 | rustc_host_triple: z.string(), 25 | rustc_version: z.string(), 26 | cargo_target_triple: z.string(), 27 | }), 28 | }); 29 | 30 | const parseInfo = (infos: any) => { 31 | const serverInfo = 
ServersInfoSchema.safeParse(infos); 32 | if (!serverInfo.success) { 33 | console.error(serverInfo.error); 34 | return null; 35 | } 36 | return serverInfo.data; 37 | }; 38 | 39 | type ServerInfo = { 40 | text_temperature: number; 41 | text_topk: number; 42 | audio_temperature: number; 43 | audio_topk: number; 44 | pad_mult: number; 45 | repetition_penalty_context: number; 46 | repetition_penalty: number; 47 | lm_model_file: string; 48 | instance_name: string; 49 | build_info: { 50 | build_timestamp: string; 51 | build_date: string; 52 | git_branch: string; 53 | git_timestamp: string; 54 | git_date: string; 55 | git_hash: string; 56 | git_describe: string; 57 | rustc_host_triple: string; 58 | rustc_version: string; 59 | cargo_target_triple: string; 60 | }; 61 | } 62 | 63 | export const useServerInfo = () => { 64 | const [serverInfo, setServerInfo] = useState(null); 65 | const { socket } = useSocketContext(); 66 | 67 | const onSocketMessage = useCallback((e: MessageEvent) => { 68 | const dataArray = new Uint8Array(e.data); 69 | const message = decodeMessage(dataArray); 70 | if (message.type === "metadata") { 71 | const infos = parseInfo(message.data); 72 | if (infos) { 73 | setServerInfo(infos); 74 | console.log("received metadata", infos); 75 | } 76 | } 77 | }, [setServerInfo]); 78 | 79 | useEffect(() => { 80 | const currentSocket = socket; 81 | if (!currentSocket) { 82 | return; 83 | } 84 | setServerInfo(null); 85 | currentSocket.addEventListener("message", onSocketMessage); 86 | return () => { 87 | currentSocket.removeEventListener("message", onSocketMessage); 88 | }; 89 | }, [socket]); 90 | 91 | return { serverInfo }; 92 | }; 93 | -------------------------------------------------------------------------------- /client/src/pages/Conversation/hooks/useServerText.ts: -------------------------------------------------------------------------------- 1 | import { useCallback, useEffect, useState } from "react"; 2 | import { useSocketContext } from "../SocketContext"; 3 | import { decodeMessage } from "../../../protocol/encoder"; 4 | 5 | export const useServerText = () => { 6 | const [text, setText] = useState([]); 7 | const [textColor, setTextColor] = useState([]); 8 | const [totalTextMessages, setTotalTextMessages] = useState(0); 9 | const { socket } = useSocketContext(); 10 | 11 | const onSocketMessage = useCallback((e: MessageEvent) => { 12 | const dataArray = new Uint8Array(e.data); 13 | const message = decodeMessage(dataArray); 14 | if (message.type === "text") { 15 | setText(text => [...text, message.data]); 16 | setTotalTextMessages(count => count + 1); 17 | } else if (message.type === "coloredtext") { 18 | setText(text => [...text, message.data]); 19 | setTextColor(textColor => [...textColor, message.color]); 20 | setTotalTextMessages(count => count + 1); 21 | } 22 | }, []); 23 | 24 | useEffect(() => { 25 | const currentSocket = socket; 26 | if (!currentSocket) { 27 | return; 28 | } 29 | setText([]); 30 | currentSocket.addEventListener("message", onSocketMessage); 31 | return () => { 32 | currentSocket.removeEventListener("message", onSocketMessage); 33 | }; 34 | }, [socket]); 35 | 36 | return { text, textColor, totalTextMessages }; 37 | }; 38 | -------------------------------------------------------------------------------- /client/src/pages/Conversation/hooks/useSocket.ts: -------------------------------------------------------------------------------- 1 | import { useState, useEffect, useCallback, useRef } from "react"; 2 | import { WSMessage } from "../../../protocol/types"; 3 | 
import { decodeMessage, encodeMessage } from "../../../protocol/encoder"; 4 | 5 | export const useSocket = ({ 6 | onMessage, 7 | uri, 8 | onDisconnect: onDisconnectProp, 9 | }: { 10 | onMessage?: (message: WSMessage) => void; 11 | uri: string; 12 | onDisconnect?: () => void; 13 | }) => { 14 | const lastMessageTime = useRef(null); 15 | const [isConnected, setIsConnected] = useState(false); 16 | const [socket, setSocket] = useState(null); 17 | 18 | const sendMessage = useCallback( 19 | (message: WSMessage) => { 20 | if (!socket || !isConnected) { 21 | console.log("socket not connected"); 22 | return; 23 | } 24 | socket.send(encodeMessage(message)); 25 | }, 26 | [isConnected], 27 | ); 28 | 29 | const onConnect = useCallback(() => { 30 | console.log("connected, now waiting for handshake."); 31 | // setIsConnected(true); 32 | }, [setIsConnected]); 33 | 34 | const onDisconnect = useCallback(() => { 35 | console.log("disconnected"); 36 | if (onDisconnectProp) { 37 | onDisconnectProp(); 38 | } 39 | setIsConnected(false); 40 | }, [onDisconnectProp]); 41 | 42 | const onMessageEvent = useCallback( 43 | (eventData: MessageEvent) => { 44 | lastMessageTime.current = Date.now(); 45 | const dataArray = new Uint8Array(eventData.data); 46 | const message = decodeMessage(dataArray); 47 | if (message.type == "handshake") { 48 | console.log("Handshake received, let's rocknroll."); 49 | setIsConnected(true); 50 | } 51 | if (!onMessage) { 52 | return; 53 | } 54 | onMessage(message); 55 | }, 56 | [onMessage, setIsConnected], 57 | ); 58 | 59 | const start = useCallback(() => { 60 | const ws = new WebSocket(uri); 61 | ws.binaryType = "arraybuffer"; 62 | ws.addEventListener("open", onConnect); 63 | ws.addEventListener("close", onDisconnect); 64 | ws.addEventListener("message", onMessageEvent); 65 | setSocket(ws); 66 | console.log("Socket created", ws); 67 | lastMessageTime.current = Date.now(); 68 | }, [uri, onMessage, onDisconnectProp]); 69 | 70 | const stop = useCallback(() => { 71 | setIsConnected(false); 72 | if (onDisconnectProp) { 73 | onDisconnectProp(); 74 | } 75 | socket?.close(); 76 | setSocket(null); 77 | }, [socket]); 78 | 79 | useEffect(() => { 80 | if(!isConnected){ 81 | return; 82 | } 83 | let intervalId = setInterval(() => { 84 | if (lastMessageTime.current && Date.now() - lastMessageTime.current > 10000) { 85 | console.log("closing socket due to inactivity", socket); 86 | socket?.close(); 87 | onDisconnect(); 88 | clearInterval(intervalId); 89 | } 90 | }, 500); 91 | 92 | return () => { 93 | lastMessageTime.current = null; 94 | clearInterval(intervalId); 95 | }; 96 | }, [isConnected, socket]); 97 | 98 | return { 99 | isConnected, 100 | socket, 101 | sendMessage, 102 | start, 103 | stop, 104 | }; 105 | }; 106 | -------------------------------------------------------------------------------- /client/src/pages/Queue/api/client.ts: -------------------------------------------------------------------------------- 1 | import { APIError } from "./errors/api_error"; 2 | import { ResponseError } from "./errors/response_error"; 3 | import { validateAddUser, validateCheckUser } from "./validators"; 4 | 5 | export const getAPIClient = (url:string) => ({ 6 | addUser: async (queueId:string) => { 7 | const encodedQueueId = encodeURIComponent(queueId); 8 | const response = await fetch(`${url}/add_user?queue_id=${encodedQueueId}`); 9 | if (!response.ok) { 10 | const errorText = await response.text(); 11 | throw new APIError(errorText , response.status); 12 | } 13 | const json = await response.json(); 14 | const result 
= validateAddUser(json); 15 | if(result.success) { 16 | return result.data; 17 | } 18 | console.error(result.error.message); 19 | throw new ResponseError("Failed to validate response"); 20 | 21 | }, 22 | checkUser: async (sessionId:number, sessionAuthId:string) => { 23 | const encodedSessionAuthId = encodeURIComponent(sessionAuthId); 24 | const encodedSessionId = encodeURIComponent(sessionId); 25 | const response = await fetch(`${url}/check_user?session_id=${encodedSessionId}&session_auth_id=${encodedSessionAuthId}`); 26 | if (!response.ok) { 27 | const errorText = await response.text(); 28 | throw new APIError(errorText , response.status); 29 | } 30 | const json = await response.json(); 31 | const result = validateCheckUser(json); 32 | if(result.success) { 33 | return result.data; 34 | } 35 | console.error(result.error.message); 36 | throw new ResponseError("Failed to validate response"); 37 | }, 38 | addFeedback: async ({ 39 | workerAuthId, 40 | sessionId, 41 | sessionAuthId, 42 | feedback, 43 | timestamp, 44 | email 45 | }:{ 46 | workerAuthId:string; 47 | sessionId:number; 48 | sessionAuthId:string; 49 | feedback:0|1; 50 | timestamp:number; 51 | email:string; 52 | 53 | } ) => { 54 | const encodedWorkerAuthId = encodeURIComponent(workerAuthId); 55 | const encodedSessionAuthId = encodeURIComponent(sessionAuthId); 56 | const encodedSessionId = encodeURIComponent(sessionId); 57 | const encodedFeedback = encodeURIComponent(feedback); 58 | const encodedTimestamp = encodeURIComponent(timestamp); 59 | const encodedEmail = encodeURIComponent(email); 60 | const response = await fetch(`${url}/user_feedback?worker_auth_id=${encodedWorkerAuthId}&session_id=${encodedSessionId}&session_auth_id=${encodedSessionAuthId}&feedback=${encodedFeedback}×tamp=${encodedTimestamp}&email=${encodedEmail}`); 61 | if (!response.ok) { 62 | const errorText = await response.text(); 63 | throw new APIError(errorText , response.status); 64 | } 65 | return response.json(); 66 | } 67 | }); 68 | -------------------------------------------------------------------------------- /client/src/pages/Queue/api/errors/api_error.ts: -------------------------------------------------------------------------------- 1 | export class APIError extends Error { 2 | status:number; 3 | 4 | constructor(message:string, status:number) { 5 | super(message); 6 | this.status = status; 7 | this.name = "APIError"; 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /client/src/pages/Queue/api/errors/response_error.ts: -------------------------------------------------------------------------------- 1 | export class ResponseError extends Error { 2 | constructor(message:string) { 3 | super(message); 4 | this.name = "ResponseError"; 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /client/src/pages/Queue/api/validators.ts: -------------------------------------------------------------------------------- 1 | import { z } from "zod" 2 | 3 | export const validateAddUser = (response: unknown) => { 4 | const AddUser = z.object({ 5 | session_id: z.number(), 6 | session_auth_id: z.string(), 7 | }); 8 | return AddUser.safeParse(response); 9 | }; 10 | 11 | export const validateCheckUser = (response: unknown) => { 12 | const CheckUser = z.object({ 13 | session_id: z.number(), 14 | // TODO: add more statuses 15 | status: z.enum(['wait', 'ready']), 16 | worker_auth_id: z.string().nullable(), 17 | worker_addr: z.string().nullable(), 18 | current_position: z.string(), 19 | }); 
20 | return CheckUser.safeParse(response); 21 | } -------------------------------------------------------------------------------- /client/src/pages/Queue/hooks/useUserEmail.ts: -------------------------------------------------------------------------------- 1 | import { useCallback, useState } from "react"; 2 | import {z} from "zod"; 3 | 4 | const validateEmail = z.string().email(); 5 | 6 | export const useUserEmail = (isBypass: boolean) => { 7 | const [userEmail, setUserEmail] = useState(''); 8 | const [error, setError] = useState(null); 9 | 10 | const validate = useCallback((email: string) => { 11 | if(isBypass) { 12 | setError(null); 13 | return true; 14 | } 15 | const result = validateEmail.safeParse(email); 16 | if(result.success) { 17 | setError(null); 18 | return true; 19 | } 20 | setError('Invalid email address'); 21 | return false; 22 | }, [setError]); 23 | return {userEmail, setUserEmail, error, validate}; 24 | } 25 | -------------------------------------------------------------------------------- /client/src/protocol/encoder.ts: -------------------------------------------------------------------------------- 1 | import { 2 | CONTROL_MESSAGE, 3 | CONTROL_MESSAGES_MAP, 4 | MODELS_MAP, 5 | WSMessage, 6 | VERSIONS_MAP, 7 | } from "./types"; 8 | 9 | export const encodeMessage = (message: WSMessage): Uint8Array => { 10 | switch (message.type) { 11 | case "handshake": 12 | return new Uint8Array([ 13 | 0x00, 14 | VERSIONS_MAP[message.version], 15 | MODELS_MAP[message.model], 16 | ]); 17 | case "audio": 18 | return new Uint8Array([0x01, ...message.data]); 19 | case "text": 20 | return new Uint8Array([0x02, ...new TextEncoder().encode(message.data)]); 21 | case "coloredtext": 22 | return new Uint8Array([0x02, 0x05, ...new TextEncoder().encode(message.data)]); 23 | case "control": 24 | return new Uint8Array([0x03, CONTROL_MESSAGES_MAP[message.action]]); 25 | case "metadata": 26 | return new Uint8Array([ 27 | 0x04, 28 | ...new TextEncoder().encode(JSON.stringify(message.data)), 29 | ]); 30 | case "error": 31 | return new Uint8Array([0x05, ...new TextEncoder().encode(message.data)]); 32 | case "ping": 33 | return new Uint8Array([0x06]); 34 | } 35 | }; 36 | 37 | export const decodeMessage = (data: Uint8Array): WSMessage => { 38 | const type = data[0]; 39 | const payload = data.slice(1); 40 | switch (type) { 41 | case 0x00: { 42 | return { 43 | type: "handshake", 44 | version: 0, 45 | model: 0, 46 | }; 47 | } 48 | case 0x01: 49 | return { 50 | type: "audio", 51 | data: payload, 52 | }; 53 | case 0x02: 54 | return { 55 | type: "text", 56 | data: new TextDecoder().decode(payload), 57 | }; 58 | case 0x07: 59 | return { 60 | type: "coloredtext", 61 | color: payload[0], 62 | data: new TextDecoder().decode(payload.slice(1)), 63 | }; 64 | case 0x03: { 65 | const action = Object.keys(CONTROL_MESSAGES_MAP).find( 66 | key => CONTROL_MESSAGES_MAP[key as CONTROL_MESSAGE] === payload[0], 67 | ) as CONTROL_MESSAGE | undefined; 68 | 69 | //TODO: log this and don't throw 70 | if (!action) { 71 | throw new Error("Unknown control message"); 72 | } 73 | return { 74 | type: "control", 75 | action, 76 | }; 77 | } 78 | case 0x04: 79 | return { 80 | type: "metadata", 81 | data: JSON.parse(new TextDecoder().decode(payload)), 82 | } 83 | case 0x05: 84 | return { 85 | type: "error", 86 | data: new TextDecoder().decode(payload), 87 | } 88 | case 0x06: 89 | return { 90 | type: "ping", 91 | } 92 | default: { 93 | console.log(type); 94 | throw new Error("Unknown message type"); 95 | } 96 | } 97 | }; 98 | 
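The encoder above defines the binary wire format: every websocket frame is a one-byte type tag (0x01 audio, 0x02 text, 0x03 control, 0x04 metadata, 0x05 error, 0x06 ping) followed by the raw payload. A minimal round-trip sketch, assuming only the exports of encoder.ts and types.ts (illustrative snippet, not a file in the repo):

import { encodeMessage, decodeMessage } from "./encoder";
import { WSMessage } from "./types";

// A control message is exactly two bytes: the 0x03 tag, then the action code.
const control = encodeMessage({ type: "control", action: "endTurn" });
console.assert(control[0] === 0x03 && control[1] === 0b00000001);

// Text frames carry UTF-8 after the 0x02 tag and round-trip losslessly.
const msg: WSMessage = { type: "text", data: "hello" };
const decoded = decodeMessage(encodeMessage(msg));
console.assert(decoded.type === "text" && decoded.data === "hello");

Note that "coloredtext" is asymmetric in the code above: encodeMessage emits it under the 0x02 text tag, while decodeMessage only recognizes it under 0x07, so colored text effectively flows server-to-client only.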
-------------------------------------------------------------------------------- /client/src/protocol/testMessages.ts: -------------------------------------------------------------------------------- 1 | import { WSMessage } from "./types"; 2 | 3 | export const handshakeMessage: WSMessage = { 4 | type: "handshake", 5 | version: 0, 6 | model: 0, 7 | }; 8 | 9 | export const audioMessage: WSMessage = { 10 | type: "audio", 11 | data: new Uint8Array(10), 12 | }; 13 | 14 | export const textMessage: WSMessage = { 15 | type: "text", 16 | data: "Hello", 17 | }; 18 | 19 | export const controlBOSMessage: WSMessage = { 20 | type: "control", 21 | action: "start", 22 | }; 23 | 24 | export const controlEOSMessage: WSMessage = { 25 | type: "control", 26 | action: "endTurn", 27 | }; 28 | 29 | export const metadataMessage: WSMessage = { 30 | type: "metadata", 31 | data: { key: "value" }, 32 | }; 33 | -------------------------------------------------------------------------------- /client/src/protocol/types.ts: -------------------------------------------------------------------------------- 1 | export type MessageType = 2 | | "handshake" 3 | | "audio" 4 | | "text" 5 | | "coloredtext" 6 | | "control" 7 | | "metadata"; 8 | 9 | export const VERSIONS_MAP = { 10 | 0: 0b00000000, 11 | } as const; 12 | 13 | export const MODELS_MAP = { 14 | 0: 0b00000000, 15 | } as const; 16 | 17 | export type VERSION = keyof typeof VERSIONS_MAP; 18 | 19 | export type MODEL = keyof typeof MODELS_MAP; 20 | 21 | export type WSMessage = 22 | | { 23 | type: "handshake"; 24 | version: VERSION; 25 | model: MODEL; 26 | } 27 | | { 28 | type: "audio"; 29 | data: Uint8Array; 30 | } 31 | | { 32 | type: "text"; 33 | data: string; 34 | } 35 | | { 36 | type: "coloredtext"; 37 | color: number; 38 | data: string; 39 | } 40 | | { 41 | type: "control"; 42 | action: CONTROL_MESSAGE; 43 | } 44 | | { 45 | type: "metadata"; 46 | data: unknown; 47 | } 48 | | { 49 | type: "error"; 50 | data: string; 51 | } 52 | | { 53 | type: "ping"; 54 | } 55 | 56 | export const CONTROL_MESSAGES_MAP = { 57 | start: 0b00000000, 58 | endTurn: 0b00000001, 59 | pause: 0b00000010, 60 | restart: 0b00000011, 61 | } as const; 62 | 63 | export type CONTROL_MESSAGE = keyof typeof CONTROL_MESSAGES_MAP; 64 | -------------------------------------------------------------------------------- /client/tailwind.config.js: -------------------------------------------------------------------------------- 1 | /** @type {import('tailwindcss').Config} */ 2 | 3 | export default { 4 | content: ["./src/**/*.{js,jsx,ts,tsx}", "./index.html"], 5 | theme: { 6 | extend: {}, 7 | }, 8 | plugins: [require('daisyui')], 9 | }; 10 | -------------------------------------------------------------------------------- /client/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "ES2020", 4 | "useDefineForClassFields": true, 5 | "module": "ESNext", 6 | "lib": [ 7 | "ES2020", 8 | "DOM", 9 | "DOM.Iterable" 10 | ], 11 | "skipLibCheck": true, 12 | "outDir": "dist", 13 | /* Bundler mode */ 14 | "moduleResolution": "bundler", 15 | "allowImportingTsExtensions": true, 16 | "resolveJsonModule": true, 17 | "isolatedModules": true, 18 | "noEmit": true, 19 | "jsx": "react-jsx", 20 | /* Linting */ 21 | "strict": true, 22 | "noUnusedLocals": true, 23 | "noUnusedParameters": true, 24 | "noFallthroughCasesInSwitch": true, 25 | "types": [ 26 | "vite/client" 27 | ] 28 | }, 29 | "include": [ 30 | "src" 31 | ] 32 | } 
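Tying the client pieces together: useSocket.ts (above) treats a socket as live only once the server's 0x00 handshake frame arrives; the "open" event alone never flips isConnected, and 10 seconds without a message forces a close. A standalone sketch of that flow, with a placeholder URI, assuming only the protocol helpers above (not repo code):

import { encodeMessage, decodeMessage } from "./src/protocol/encoder";

const ws = new WebSocket("wss://example.invalid/api/chat"); // placeholder URI
ws.binaryType = "arraybuffer"; // the protocol is binary, not JSON/text

ws.addEventListener("message", (e: MessageEvent) => {
  const message = decodeMessage(new Uint8Array(e.data));
  if (message.type === "handshake") {
    // Only now is the session considered connected (mirroring useSocket.ts);
    // signal beginning-of-stream, then audio chunks can follow.
    ws.send(encodeMessage({ type: "control", action: "start" }));
  }
});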
-------------------------------------------------------------------------------- /client/vite.config.ts: -------------------------------------------------------------------------------- 1 | import { ProxyOptions, defineConfig, loadEnv } from "vite"; 2 | import topLevelAwait from "vite-plugin-top-level-await"; 3 | 4 | export default defineConfig(({ mode }) => { 5 | const env = loadEnv(mode, process.cwd()); 6 | const proxyConf: Record = env.VITE_QUEUE_API_URL ? { 7 | "/api": { 8 | target: env.VITE_QUEUE_API_URL, 9 | changeOrigin: true, 10 | }, 11 | } : {}; 12 | return { 13 | server: { 14 | host: "0.0.0.0", 15 | https: { 16 | cert: "./cert.pem", 17 | key: "./key.pem", 18 | }, 19 | proxy: { 20 | ...proxyConf, 21 | } 22 | }, 23 | plugins: [ 24 | topLevelAwait({ 25 | // The export name of top-level await promise for each chunk module 26 | promiseExportName: "__tla", 27 | // The function to generate import names of top-level await promise in each chunk module 28 | promiseImportName: i => `__tla_${i}`, 29 | }), 30 | ], 31 | }; 32 | }); 33 | -------------------------------------------------------------------------------- /configs/moshi_7b_202409.json: -------------------------------------------------------------------------------- 1 | { 2 | "dim": 4096, 3 | "text_card": 32000, 4 | "existing_text_padding_id": 3, 5 | "n_q": 16, 6 | "dep_q": 8, 7 | "card": 2048, 8 | "num_heads": 32, 9 | "num_layers": 32, 10 | "hidden_scale": 4.125, 11 | "causal": true, 12 | "layer_scale": null, 13 | "context": 3000, 14 | "max_period": 10000, 15 | "gating": "silu", 16 | "norm": "rms_norm_f32", 17 | "positional_embedding": "rope", 18 | "depformer_dim": 1024, 19 | "depformer_dim_feedforward": 4224, 20 | "depformer_num_heads": 16, 21 | "depformer_num_layers": 6, 22 | "depformer_causal": true, 23 | "depformer_layer_scale": null, 24 | "depformer_multi_linear": true, 25 | "depformer_context": 8, 26 | "depformer_max_period": 10000, 27 | "depformer_gating": "silu", 28 | "depformer_pos_emb": "none", 29 | "depformer_weights_per_step": true, 30 | "delays": [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1] 31 | } 32 | 33 | 34 | -------------------------------------------------------------------------------- /configs/moshi_dev_2b.json: -------------------------------------------------------------------------------- 1 | { 2 | "dim": 2560, 3 | "text_card": 48000, 4 | "existing_text_padding_id": 3, 5 | "n_q": 32, 6 | "dep_q": 16, 7 | "card": 2048, 8 | "num_heads": 20, 9 | "num_layers": 24, 10 | "hidden_scale": 4.125, 11 | "causal": true, 12 | "layer_scale": null, 13 | "context": 3000, 14 | "max_period": 100000, 15 | "gating": "silu", 16 | "norm": "rms_norm_f32", 17 | "positional_embedding": "rope", 18 | "depformer_dim": 1024, 19 | "depformer_dim_feedforward": 4224, 20 | "depformer_num_heads": 16, 21 | "depformer_num_layers": 6, 22 | "depformer_causal": true, 23 | "depformer_layer_scale": null, 24 | "depformer_multi_linear": true, 25 | "depformer_context": 16, 26 | "depformer_max_period": 10000, 27 | "depformer_gating": "silu", 28 | "depformer_pos_emb": "none", 29 | "depformer_weights_per_step": true, 30 | "delays": [0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], 31 | "conditioners": { 32 | "description": { 33 | "type": "lut", 34 | "lut": { 35 | "n_bins": 31, 36 | "dim": 16, 37 | "tokenizer": "noop", 38 | "possible_values": ["very_bad", "bad", "neutral", "good", "very_good"] 39 | } 40 | } 41 | }, 42 | "fuser": { 43 | "sum": ["description"] 44 | } 45 | } 46 | 
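The "delays" arrays in these configs stagger each codebook stream by a number of Mimi frames; one frame is 80 ms (12.5 Hz frame rate). A sketch of the theoretical-latency arithmetic this implies (FRAME_MS and the helper are illustrative, not repo code):

// One Mimi frame is 80 ms (12.5 Hz frame rate).
const FRAME_MS = 80;

// Theoretical latency = one frame of lookahead + the largest per-codebook delay.
function theoreticalLatencyMs(config: { delays: number[] }): number {
  const acousticDelayFrames = Math.max(...config.delays);
  return FRAME_MS + acousticDelayFrames * FRAME_MS;
}

// moshi_7b_202409.json has a max delay of 1 frame, giving 80 + 80 = 160 ms,
// the figure quoted in the demo notebook further down.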
-------------------------------------------------------------------------------- /configs/moshi_mlx_2b.json: -------------------------------------------------------------------------------- 1 | { 2 | "dim": 2560, 3 | "text_card": 48000, 4 | "existing_text_padding_id": 3, 5 | "n_q": 32, 6 | "dep_q": 16, 7 | "card": 2048, 8 | "num_heads": 20, 9 | "num_layers": 24, 10 | "hidden_scale": 4.125, 11 | "causal": true, 12 | "layer_scale": null, 13 | "context": 3000, 14 | "max_period": 100000, 15 | "gating": "silu", 16 | "norm": "rms_norm_f32", 17 | "positional_embedding": "rope", 18 | "depformer_dim": 1024, 19 | "depformer_dim_feedforward": 4224, 20 | "depformer_num_heads": 16, 21 | "depformer_num_layers": 6, 22 | "depformer_causal": true, 23 | "depformer_layer_scale": null, 24 | "depformer_multi_linear": true, 25 | "depformer_context": 16, 26 | "depformer_max_period": 10000, 27 | "depformer_gating": "silu", 28 | "depformer_pos_emb": "none", 29 | "depformer_weights_per_step": true, 30 | "delays": [0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], 31 | "conditioners": { 32 | "description": { 33 | "type": "lut", 34 | "lut": { 35 | "n_bins": 31, 36 | "dim": 16, 37 | "tokenizer": "noop", 38 | "possible_values": ["very_bad", "bad", "neutral", "good", "very_good"] 39 | } 40 | } 41 | }, 42 | "fuser": { 43 | "sum": ["description"] 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /data/sample_fr_hibiki_crepes.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/data/sample_fr_hibiki_crepes.mp3 -------------------------------------------------------------------------------- /data/sample_fr_hibiki_intro.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/data/sample_fr_hibiki_intro.mp3 -------------------------------------------------------------------------------- /data/sample_fr_hibiki_monologue_otis.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/data/sample_fr_hibiki_monologue_otis.mp3 -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.8" 2 | 3 | name: moshi 4 | 5 | services: 6 | 7 | moshi: 8 | build: 9 | context: ./moshi 10 | expose: 11 | - 8998/tcp 12 | restart: unless-stopped 13 | volumes: 14 | - hf-cache:/root/.cache/huggingface 15 | environment: 16 | #- HF_REPO=kyutai/moshika-pytorch-bf16 17 | - HF_REPO=kyutai/moshiko-pytorch-bf16 18 | deploy: 19 | resources: 20 | reservations: 21 | devices: 22 | - driver: nvidia 23 | capabilities: [gpu] 24 | count: all 25 | 26 | tunnel: 27 | image: cloudflare/cloudflared:latest 28 | pull_policy: always 29 | restart: unless-stopped 30 | expose: 31 | - 43337/tcp 32 | environment: 33 | TUNNEL_URL: http://moshi:8998 34 | TUNNEL_METRICS: 0.0.0.0:43337 35 | command: tunnel --no-autoupdate 36 | depends_on: 37 | - moshi 38 | 39 | volumes: 40 | hf-cache: 41 | -------------------------------------------------------------------------------- /mimi.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/mimi.png -------------------------------------------------------------------------------- /moshi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/moshi.png -------------------------------------------------------------------------------- /moshi/Dockerfile: -------------------------------------------------------------------------------- 1 | # Use an official Python runtime as a parent image 2 | FROM python:3.10 3 | 4 | # Set the working directory in the container 5 | WORKDIR /app 6 | 7 | # Copy the current directory contents into the container at /app 8 | COPY . /app 9 | 10 | # Install any needed packages specified in requirements.txt 11 | # Assuming you have a requirements.txt file in the moshi directory 12 | RUN pip install --no-cache-dir -r requirements.txt 13 | 14 | # Install Moshi and gradio 15 | RUN pip install --no-cache-dir moshi gradio 16 | 17 | # Expose the port used by the server 18 | EXPOSE 8998 19 | 20 | # Set environment variable for the model (with a default value) 21 | ENV HF_REPO=kyutai/moshiko-pytorch-bf16 22 | 23 | # Run the server when the container launches 24 | CMD python -m moshi.server --gradio-tunnel --hf-repo $HF_REPO 25 | -------------------------------------------------------------------------------- /moshi/LICENSE: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any 2 | person obtaining a copy of this software and associated 3 | documentation files (the "Software"), to deal in the 4 | Software without restriction, including without 5 | limitation the rights to use, copy, modify, merge, 6 | publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software 8 | is furnished to do so, subject to the following 9 | conditions: 10 | 11 | The above copyright notice and this permission notice 12 | shall be included in all copies or substantial portions 13 | of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 16 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 17 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 18 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT 19 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 22 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 23 | DEALINGS IN THE SOFTWARE. 24 | -------------------------------------------------------------------------------- /moshi/LICENSE.audiocraft: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /moshi/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE* 2 | include *.md 3 | include *.cfg 4 | include requirements.txt 5 | include moshi/py.typed 6 | include tests/assets/*.safetensors 7 | -------------------------------------------------------------------------------- /moshi/demo_moshi.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "gpuType": "T4" 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3" 12 | }, 13 | "language_info": { 14 | "name": "python" 15 | }, 16 | "accelerator": "GPU" 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "source": [ 22 | "# Moshi: a speech-text foundation model for real time dialogue\n", 23 | "\n", 24 | " [Moshi][moshi] is a speech-text foundation model and **full-duplex** spoken dialogue framework.\n", 25 | " It uses [Mimi][moshi], a state-of-the-art streaming neural audio codec. Mimi processes 24 kHz audio, down to a 12.5 Hz representation\n", 26 | " with a bandwidth of 1.1 kbps, in a fully streaming manner (latency of 80ms, the frame size),\n", 27 | " yet performs better than existing, non-streaming, codecs like\n", 28 | " [SpeechTokenizer](https://github.com/ZhangXInFD/SpeechTokenizer) (50 Hz, 4kbps), or [SemantiCodec](https://github.com/haoheliu/SemantiCodec-inference) (50 Hz, 1.3kbps).\n", 29 | "\n", 30 | " Moshi models **two streams of audio**: one corresponds to Moshi, and the other one to the user.\n", 31 | " At inference, the stream from the user is taken from the audio input,\n", 32 | "and the one for Moshi is sampled from the model's output. Along these two audio streams, Moshi predicts text tokens corresponding to its own speech, its **inner monologue**,\n", 33 | "which greatly improves the quality of its generation. A small Depth Transformer models inter codebook dependencies for a given time step,\n", 34 | "while a large, 7B parameter Temporal Transformer models the temporal dependencies. 
Moshi achieves a theoretical latency\n", 35 | "of 160ms (80ms for the frame size of Mimi + 80ms of acoustic delay), with a practical overall latency as low as 200ms on an L4 GPU.\n", 36 | "\n", 37 | "\n", 38 | "For more information, checkout our repo\n", 39 | "[[repo]](https://github.com/kyutai-labs/hibiki),\n", 40 | "[[samples]](https://huggingface.co/spaces/kyutai/hibiki-samples), and [[paper]](https://arxiv.org/abs/2410.00037)." 41 | ], 42 | "metadata": { 43 | "id": "iuzciRNiOznZ" 44 | } 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": { 50 | "id": "zYz-oOcaOwwr" 51 | }, 52 | "outputs": [], 53 | "source": [ 54 | "!pip install \"git+https://git@github.com/kyutai-labs/moshi.git#egg=moshi&subdirectory=moshi\"\n", 55 | "!pip install gradio" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "source": [ 61 | "# You can also run the model through our WebUI for live translation.\n", 62 | "# Click on the gradio.live link!\n", 63 | "! python -m moshi.server --gradio-tunnel --hf-repo kyutai/moshiko-pytorch-q8 --half\n", 64 | "# Or with Moshika's voice\n", 65 | "# ! python -m moshi.server --gradio-tunnel --hf-repo kyutai/moshika-1b-pytorch-bf16 --half | grep -v 'frame handled'" 66 | ], 67 | "metadata": { 68 | "id": "jykmkKO0f7dK" 69 | }, 70 | "execution_count": null, 71 | "outputs": [] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "source": [], 76 | "metadata": { 77 | "id": "ouGqgLmWL0x0" 78 | }, 79 | "execution_count": null, 80 | "outputs": [] 81 | } 82 | ] 83 | } -------------------------------------------------------------------------------- /moshi/moshi/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | """ 6 | moshi is the inference codebase for Kyutai audio generation models. 7 | 8 | The code has been adapted from Audiocraft, see LICENSE.audiocraft 9 | Copyright (c) Meta Platforms, Inc. and affiliates. 10 | """ 11 | 12 | # flake8: noqa 13 | from . import conditioners 14 | from . import models 15 | from . import modules 16 | from . import quantization 17 | from . import utils 18 | 19 | __version__ = "0.2.5a7" 20 | -------------------------------------------------------------------------------- /moshi/moshi/conditioners/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # Copyright (c) Kyutai, all rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | """ 6 | Modules to help doing generations under some fixed conditions. 7 | """ 8 | 9 | from .base import (ConditionType, ConditionAttributes, ConditionFuser, ConditionProvider, 10 | BaseConditioner, TensorCondition, ConditionTensors, dropout_all_conditions) 11 | -------------------------------------------------------------------------------- /moshi/moshi/conditioners/tensors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | from .base import _BaseTensorConditioner, TensorCondition, ConditionType 5 | 6 | 7 | class TensorConditioner(_BaseTensorConditioner[TensorCondition]): 8 | """Does basically nothing. 
9 | """ 10 | 11 | def prepare(self, tensor: TensorCondition) -> TensorCondition: 12 | device = next(iter(self.parameters())).device 13 | return TensorCondition(tensor.tensor.to(device=device), tensor.mask.to(device=device)) 14 | 15 | def _get_condition(self, inputs: TensorCondition) -> ConditionType: 16 | return ConditionType(inputs.tensor, inputs.mask) 17 | -------------------------------------------------------------------------------- /moshi/moshi/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | """ 5 | Models for the compression model Moshi, 6 | """ 7 | 8 | # flake8: noqa 9 | from .compression import ( 10 | CompressionModel, 11 | MimiModel, 12 | ) 13 | from .lm import LMModel, LMGen 14 | from .loaders import get_mimi, get_moshi_lm 15 | -------------------------------------------------------------------------------- /moshi/moshi/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | # Copyright (c) Meta Platforms, Inc. and affiliates. 6 | # All rights reserved. 7 | # 8 | # This source code is licensed under the license found in the 9 | # LICENSE file in the root directory of this source tree. 10 | """Modules used for building the models.""" 11 | 12 | # flake8: noqa 13 | from .conv import ( 14 | NormConv1d, 15 | NormConvTranspose1d, 16 | StreamingConv1d, 17 | StreamingConvTranspose1d, 18 | pad_for_conv1d, 19 | pad1d, 20 | unpad1d, 21 | ) 22 | from .seanet import SEANetEncoder, SEANetDecoder 23 | from .transformer import StreamingTransformer 24 | -------------------------------------------------------------------------------- /moshi/moshi/modules/gating.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | from contextlib import ExitStack 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | from ..utils.compile import torch_compile_lazy, no_compile 11 | 12 | 13 | @torch_compile_lazy 14 | def gating_forward_kernel( 15 | weight_in: torch.Tensor, weight_out: torch.Tensor, activation, x: torch.Tensor 16 | ): 17 | x = F.linear(x, weight_in) 18 | B, T, _ = x.shape 19 | x = x.view(B, T, 2, -1) 20 | x = activation(x[..., 0, :]) * x[..., 1, :] 21 | x = F.linear(x, weight_out) 22 | return x 23 | 24 | 25 | def gating_forward_generic( 26 | linear_in: nn.Module, 27 | linear_out: nn.Module, 28 | activation, 29 | x: torch.Tensor 30 | ): 31 | x = linear_in(x) 32 | B, T, _ = x.shape 33 | x = x.view(B, T, 2, -1) 34 | x = activation(x[..., 0, :]) * x[..., 1, :] 35 | x = linear_out(x) 36 | return x 37 | 38 | 39 | class ActivationGating(nn.Module): 40 | """ 41 | Gating FFN layer, using the given activation. 42 | Args: 43 | dim (int): dimension of the input and output of the transformer. 44 | activation (any callable Tensor to Tensor): activation function to use. 45 | **factory_kwargs: other kwargs passed to the linear layer, in particular device and dtype. 
46 | """ 47 | 48 | _fsdp_final = True 49 | 50 | def __init__(self, dim: int, dim_feedforward: int, activation, quantized: bool = False, **factory_kwargs): 51 | super().__init__() 52 | # We should have 8 d^2 param, instead we will have 53 | # 2 * h * d + h * d = 3 h * d = 8 d^2 54 | # so h = 8 d / 3 but following Hervé's advice we use 21 / 8 as an approx. 55 | if dim_feedforward == 4 * dim: 56 | hidden = (21 * dim) // 8 57 | else: 58 | hidden = (2 * dim_feedforward) // 3 59 | 60 | self.linear_in = nn.Linear(dim, 2 * hidden, bias=False, **factory_kwargs) 61 | self.linear_out = nn.Linear(hidden, dim, bias=False, **factory_kwargs) 62 | 63 | # We try to follow the default PyTorch MHA convention, to easily compare results. 64 | 65 | self.activation = activation 66 | 67 | def forward(self, x: torch.Tensor): 68 | if isinstance(self.linear_in, nn.Linear): 69 | assert isinstance(self.linear_out, nn.Linear) 70 | with ExitStack() as stack: 71 | if self.training: 72 | stack.enter_context(no_compile()) 73 | return gating_forward_kernel( 74 | self.linear_in.weight, self.linear_out.weight, self.activation, x 75 | ) 76 | else: 77 | return gating_forward_generic( 78 | self.linear_in, 79 | self.linear_out, 80 | self.activation, 81 | x 82 | ) 83 | 84 | 85 | def _get_activation(name: str): 86 | if name in ["sigmoid", "tanh", "relu"]: 87 | return getattr(torch, name) 88 | elif name in ["leaky_relu", "elu", "gelu", "silu", "mish", "softsign"]: 89 | return getattr(torch.nn.functional, name) 90 | elif name == "identity": 91 | return torch.nn.Identity() 92 | else: 93 | raise ValueError(f"Unknown activation {name}") 94 | 95 | 96 | def _make_gating( 97 | name: str, dim: int, dim_feedforward: int, 98 | **factory_kwargs 99 | ) -> nn.Module: 100 | return ActivationGating( 101 | dim, dim_feedforward, _get_activation(name), **factory_kwargs 102 | ) 103 | 104 | 105 | def make_gating( 106 | name: str, dim: int, dim_feedforward: int, **factory_kwargs 107 | ) -> nn.Module: 108 | gating = _make_gating(name, dim, dim_feedforward, **factory_kwargs) 109 | if isinstance(gating.linear_in, nn.Linear): 110 | max_params = 2 * dim * dim_feedforward 111 | params = sum(p.numel() for p in gating.parameters()) 112 | assert ( 113 | params <= max_params 114 | ), f"{name} gating has {params} params, max is {max_params}" 115 | return gating 116 | -------------------------------------------------------------------------------- /moshi/moshi/modules/lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def replace_all_linear_with_lora(module, rank: int, scaling: float, device=None, dtype=None): 6 | """ Recursively replace all Linear layers with LoRALinear layers.""" 7 | for name, child in module.named_children(): 8 | if isinstance(child, nn.Linear): 9 | if device is None: 10 | this_device = child.weight.device 11 | else: 12 | this_device = device 13 | if dtype is None: 14 | this_dtype = child.weight.dtype 15 | else: 16 | this_dtype = dtype 17 | lora = LoRALinear(child.in_features, child.out_features, 18 | rank, scaling, device=this_device, dtype=this_dtype) 19 | lora.frozen_W = child 20 | setattr(module, name, lora) 21 | else: 22 | replace_all_linear_with_lora(child, rank, scaling, device=device, dtype=dtype) 23 | 24 | 25 | def replace_lora_with_linear(module): 26 | """Recursively replace all LoRALinear layers with Linear layers.""" 27 | for name, child in module.named_children(): 28 | if isinstance(child, LoRALinear): 29 | # Compute merged weights: W' = W 
+ scaling * B @ A 30 | merged_weight = child.frozen_W.weight.data + \ 31 | child.scaling * (child.lora_B.weight @ child.lora_A.weight) 32 | # Create a standard Linear layer with the same in/out features 33 | new_linear = nn.Linear(child.frozen_W.in_features, 34 | child.frozen_W.out_features, bias=False, 35 | device=torch.device('meta'), 36 | dtype=merged_weight.dtype) 37 | new_linear.weight = nn.Parameter( 38 | merged_weight, requires_grad=merged_weight.requires_grad) # Transfer merged weights 39 | setattr(module, name, new_linear) # Replace the module 40 | else: 41 | replace_lora_with_linear(child) # Recursively process submodules 42 | 43 | 44 | class LoRALinear(nn.Module): 45 | """ 46 | Implementation of: 47 | - LoRA: https://arxiv.org/abs/2106.09685 48 | 49 | Notes: 50 | - Freezing is handled at the network level, not the layer level. 51 | - Scaling factor controls relative importance of LoRA skip 52 | connection versus original frozen weight. General guidance is 53 | to keep it to 2.0 and sweep over learning rate when changing 54 | the rank. 55 | """ 56 | 57 | def __init__( 58 | self, 59 | in_features: int, 60 | out_features: int, 61 | rank: int, 62 | scaling: float, 63 | bias: bool = False, 64 | device: torch.device | None = None, 65 | dtype: torch.dtype = torch.bfloat16, 66 | ): 67 | super().__init__() 68 | 69 | self.in_features = in_features 70 | self.out_features = out_features 71 | assert not bias 72 | self.bias = bias 73 | self.rank = rank 74 | self.scaling = scaling 75 | 76 | self.lora_A = nn.Linear( 77 | self.in_features, 78 | self.rank, 79 | bias=self.bias, 80 | device=device, 81 | dtype=dtype, 82 | ) 83 | self.lora_B = nn.Linear( 84 | self.rank, 85 | self.out_features, 86 | bias=self.bias, 87 | device=device, 88 | dtype=dtype, 89 | ) 90 | 91 | self.frozen_W = nn.Linear(self.in_features, 92 | self.out_features, 93 | bias=self.bias, 94 | device=device, 95 | dtype=dtype) 96 | 97 | self._register_load_state_dict_pre_hook(LoRALinear._load_hook, with_module=True) 98 | 99 | def merge_weight(self): 100 | with torch.no_grad(): 101 | down_weight = self.lora_A.weight 102 | up_weight = self.lora_B.weight 103 | 104 | weight = up_weight.mm(down_weight) * self.scaling 105 | 106 | weight += self.frozen_W.weight 107 | return weight 108 | 109 | @staticmethod 110 | def _load_hook(module, state_dict, prefix, *_): 111 | key_name = prefix + "weight" 112 | if key_name in state_dict: 113 | w_ref = state_dict.pop(key_name) 114 | state_dict[prefix + 'frozen_W.weight'] = w_ref 115 | 116 | def forward(self, x: torch.Tensor): 117 | lora = self.lora_B(self.lora_A(x)) 118 | return self.frozen_W(x) + lora * self.scaling 119 | 120 | def __repr__(self) -> str: 121 | return "{}Linear(in_features={}, out_features={}, r={})".format( 122 | "LoRA", self.in_features, self.out_features, self.rank) 123 | -------------------------------------------------------------------------------- /moshi/moshi/modules/resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import typing as tp 6 | 7 | from einops import rearrange 8 | import torch 9 | from torch import nn 10 | 11 | from .conv import StreamingConv1d, StreamingConvTranspose1d 12 | 13 | 14 | class ConvDownsample1d(nn.Module): 15 | """ 16 | Downsampling by some integer amount `stride` using convolutions 17 | with a kernel size of twice the stride. 
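    When `learnt` is False, the kernel is frozen to a constant box filter with value
    1 / (2 * stride), i.e. a plain moving average over each window.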
18 | If `causal` is True, the output uses a causal convolution. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | stride: int, 24 | dimension: tp.Optional[int] = None, 25 | causal: bool = False, 26 | learnt: bool = False, 27 | channel_wise: bool = False, 28 | ): 29 | super().__init__() 30 | self.learnt = learnt 31 | self.channel_wise = channel_wise 32 | groups = 1 33 | if learnt: 34 | assert dimension is not None, "Dimension required for learnt convolutions." 35 | in_channels = dimension 36 | out_channels = dimension 37 | if channel_wise: 38 | groups = dimension 39 | else: 40 | in_channels = 1 41 | out_channels = 1 42 | 43 | self.conv = StreamingConv1d( 44 | in_channels, 45 | out_channels, 46 | kernel_size=2 * stride, 47 | stride=stride, 48 | causal=causal, 49 | groups=groups, 50 | bias=False, 51 | pad_mode="replicate", 52 | ) 53 | if not learnt: 54 | actual_conv = self.conv.conv.conv 55 | actual_conv.weight.requires_grad_(False) 56 | actual_conv.weight.data.fill_(1.0 / (2 * stride)) 57 | 58 | def forward(self, x: torch.Tensor): 59 | batch_size = len(x) 60 | if not self.learnt: 61 | x = rearrange(x, "b c t -> (b c) () t") 62 | y = self.conv(x) 63 | if not self.learnt: 64 | y = rearrange(y, "(b c) () t -> b c t", b=batch_size) 65 | return y 66 | 67 | 68 | class ConvTrUpsample1d(nn.Module): 69 | """ 70 | Upsample by some integer amount `stride` using transposed convolutions. 71 | """ 72 | 73 | def __init__( 74 | self, 75 | stride: int, 76 | dimension: tp.Optional[int] = None, 77 | causal: bool = False, 78 | learnt: bool = False, 79 | channel_wise: bool = False, 80 | ): 81 | super().__init__() 82 | self.learnt = learnt 83 | self.channel_wise = channel_wise 84 | groups = 1 85 | if learnt: 86 | assert dimension is not None, "Dimension required for learnt convolutions." 87 | in_channels = dimension 88 | out_channels = dimension 89 | if channel_wise: 90 | groups = dimension 91 | else: 92 | in_channels = 1 93 | out_channels = 1 94 | 95 | self.convtr = StreamingConvTranspose1d( 96 | in_channels, 97 | out_channels, 98 | kernel_size=2 * stride, 99 | stride=stride, 100 | causal=causal, 101 | groups=groups, 102 | bias=False, 103 | ) 104 | if not learnt: 105 | actual_convtr = self.convtr.convtr.convtr 106 | actual_convtr.weight.requires_grad_(False) 107 | actual_convtr.weight.data.fill_(1.0) 108 | 109 | def forward(self, x: torch.Tensor): 110 | batch_size = len(x) 111 | if not self.learnt: 112 | x = rearrange(x, "b c t -> (b c) () t") 113 | y = self.convtr(x) 114 | if not self.learnt: 115 | x_for_normalization = torch.ones_like(x[:1]) 116 | normalization = self.convtr(x_for_normalization) 117 | y = y / normalization 118 | y = rearrange(y, "(b c) () t -> b c t", b=batch_size) 119 | return y 120 | -------------------------------------------------------------------------------- /moshi/moshi/modules/rope.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | from torch import nn 6 | import math 7 | import torch 8 | from ..utils.compile import torch_compile_lazy 9 | 10 | 11 | @torch_compile_lazy 12 | def apply_rope( 13 | q: torch.Tensor, 14 | k: torch.Tensor, 15 | offset: torch.Tensor, 16 | max_period: float = 10_000, 17 | time_before_heads: bool = False, 18 | ): 19 | """ 20 | Args: 21 | q (torch.Tensor): queries, shape `[B, T, H, D]`. 22 | k (torch.Tensor): keys, shape `[B, T, H, D]`. 
23 | offset (torch.Tensor): current offset, e.g. when streaming. 24 | max_period (float): maximum period for the cos and sin. 25 | time_before_heads (bool): if True, expected [B, T, H, D], else [B, H, T, D] 26 | """ 27 | 28 | if time_before_heads: 29 | B, T, H, D = q.shape 30 | else: 31 | B, H, T, D = q.shape 32 | assert k.shape == q.shape 33 | assert D > 0 34 | assert D % 2 == 0 35 | assert max_period > 0 36 | 37 | ds = torch.arange(D // 2, device=q.device, dtype=torch.float32) 38 | freqs = torch.exp(ds * (-math.log(max_period) * 2 / D)) 39 | ts = offset.float().view(-1, 1) + torch.arange(T, device=q.device, dtype=torch.float32) 40 | if time_before_heads: 41 | ts = ts.view(B, -1, 1, 1) 42 | else: 43 | ts = ts.view(B, 1, -1, 1) 44 | 45 | dims = q.shape[:-1] 46 | q = q.view(*dims, D // 2, 2) 47 | k = k.view(*dims, D // 2, 2) 48 | 49 | # convention is `r` suffix is real part, `i` is imaginary. 50 | qr = q[..., 0].float() 51 | qi = q[..., 1].float() 52 | 53 | kr = k[..., 0].float() 54 | ki = k[..., 1].float() 55 | 56 | rotr = torch.cos(freqs * ts) 57 | roti = torch.sin(freqs * ts) 58 | qor = qr * rotr - qi * roti 59 | qoi = qr * roti + qi * rotr 60 | 61 | kor = kr * rotr - ki * roti 62 | koi = kr * roti + ki * rotr 63 | 64 | dtype = q.dtype 65 | qo = torch.stack([qor.to(dtype), qoi.to(dtype)], dim=-1) 66 | ko = torch.stack([kor.to(dtype), koi.to(dtype)], dim=-1) 67 | 68 | return qo.view(*dims, D), ko.view(*dims, D) 69 | 70 | 71 | class RotaryEmbedding(nn.Module): 72 | """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864). 73 | 74 | Args: 75 | max_period (float): Maximum period of the rotation frequencies. 76 | """ 77 | 78 | def __init__(self, max_period: float = 10000.0): 79 | super().__init__() 80 | self.max_period = max_period 81 | 82 | def forward( 83 | self, 84 | q: torch.Tensor, 85 | k: torch.Tensor, 86 | offset: torch.Tensor, 87 | time_before_heads: bool = False, 88 | ): 89 | """Apply rope rotation to the query and key tensors.""" 90 | return apply_rope(q, k, offset, self.max_period, time_before_heads) 91 | -------------------------------------------------------------------------------- /moshi/moshi/quantization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | # Copyright (c) Meta Platforms, Inc. and affiliates. 6 | # All rights reserved. 7 | # 8 | # This source code is licensed under the license found in the 9 | # LICENSE file in the root directory of this source tree. 10 | """RVQ.""" 11 | # flake8: noqa 12 | from .vq import ResidualVectorQuantizer, SplitResidualVectorQuantizer 13 | from .base import BaseQuantizer, DummyQuantizer, QuantizedResult 14 | -------------------------------------------------------------------------------- /moshi/moshi/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | # Copyright (c) Meta Platforms, Inc. and affiliates. 6 | # All rights reserved. 7 | # 8 | # This source code is licensed under the license found in the 9 | # LICENSE file in the root directory of this source tree. 
10 | """Utilities.""" 11 | -------------------------------------------------------------------------------- /moshi/moshi/utils/autocast.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | # Copyright (c) Meta Platforms, Inc. and affiliates. 6 | # All rights reserved. 7 | # 8 | # This source code is licensed under the license found in the 9 | # LICENSE file in the root directory of this source tree. 10 | 11 | import torch 12 | 13 | 14 | class TorchAutocast: 15 | """TorchAutocast utility class. 16 | Allows you to enable and disable autocast. This is specially useful 17 | when dealing with different architectures and clusters with different 18 | levels of support. 19 | 20 | Args: 21 | enabled (bool): Whether to enable torch.autocast or not. 22 | args: Additional args for torch.autocast. 23 | kwargs: Additional kwargs for torch.autocast 24 | """ 25 | 26 | def __init__(self, enabled: bool, *args, **kwargs): 27 | self.autocast = torch.autocast(*args, **kwargs) if enabled else None 28 | 29 | def __enter__(self): 30 | if self.autocast is None: 31 | return 32 | try: 33 | self.autocast.__enter__() 34 | except RuntimeError: 35 | device = self.autocast.device 36 | dtype = self.autocast.fast_dtype 37 | raise RuntimeError( 38 | f"There was an error autocasting with dtype={dtype} device={device}\n" 39 | "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16" 40 | ) 41 | 42 | def __exit__(self, *args, **kwargs): 43 | if self.autocast is None: 44 | return 45 | self.autocast.__exit__(*args, **kwargs) 46 | -------------------------------------------------------------------------------- /moshi/moshi/utils/quantize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | """Quantization based on bitsandbytes, supporting only 8 bits for now. 6 | We are taking from freedom from the intended use of bnb: 7 | """ 8 | 9 | import torch 10 | from torch import nn 11 | 12 | 13 | class QLinear(nn.Module): 14 | def __init__(self, linear: nn.Linear): 15 | super().__init__() 16 | from bitsandbytes import functional as bnbF # type: ignore 17 | weight = linear.weight 18 | assert weight.data.dtype.is_floating_point 19 | assert linear.bias is None 20 | CB, SCB, _ = bnbF.int8_vectorwise_quant(weight.data.to(torch.float16)) # type: ignore 21 | self.weight = nn.Parameter(CB, requires_grad=False) 22 | self.weight_scb = nn.Parameter(SCB, requires_grad=False) 23 | 24 | def forward(self, x): 25 | import bitsandbytes as bnb # type: ignore 26 | state = bnb.MatmulLtState() 27 | state.CB = self.weight # type: ignore 28 | assert isinstance(state.CB, torch.Tensor) 29 | state.SCB = self.weight_scb # type: ignore 30 | assert isinstance(state.SCB, torch.Tensor) 31 | if state.SCB.dtype != torch.float: 32 | raise RuntimeError( 33 | "Expected `weight_scb` to have type float, but got bfloat16. 
" 34 | "When using quantized models, care should be taken not to change the dtype of " 35 | "the model once initialized.") 36 | assert state.SCB.dtype == torch.float, state.SCB.dtype 37 | state.has_fp16_weights = False 38 | y = bnb.matmul(x.half(), state.CB, state=state) 39 | assert isinstance(y, torch.Tensor) 40 | return y 41 | 42 | 43 | def replace_linear_with_qlinear(module): 44 | """Recursively replace all Linear layers with QLinear layers.""" 45 | for name, child in module.named_children(): 46 | if isinstance(child, nn.Linear): 47 | setattr(module, name, QLinear(child)) 48 | elif isinstance(child, QLinear): 49 | # Slight issue with the way we implement things: the scale param 50 | # might get casted with the rest of the model to bfloat16, altough 51 | # we most likely want to keep it as float. For the LM model we might call this function twice, 52 | # first layer by layer to avoid to big of a memory usage, and second, at the end 53 | # of the LM init, after all other modules are initialized and properly dtyped. 54 | # In any case that should happen before loading the state dict to avoid a loss of precision. 55 | child.float() 56 | else: 57 | replace_linear_with_qlinear(child) 58 | -------------------------------------------------------------------------------- /moshi/moshi/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .compile import torch_compile_lazy 4 | 5 | 6 | @torch_compile_lazy 7 | def cross_entropy( 8 | logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor, dtype=torch.float32, 9 | logits_soft_clip: float | None = None) -> torch.Tensor: 10 | """Compute cross entropy between multi-codebook targets and model's logits. 11 | The cross entropy is computed per codebook to provide codebook-level cross entropy. 12 | Valid timesteps for each of the codebook are pulled from the mask, where invalid 13 | timesteps are set to 0. 14 | 15 | Args: 16 | logits (torch.Tensor): Model's logits of shape [B, K, T, card]. 17 | targets (torch.Tensor): Target codes, of shape [B, K, T]. 18 | mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T]. 19 | dtype (type): Data type of the output cross entropy. 20 | logits_soft_clip (float): Clipping value for the logits to avoid numerical instability. 21 | Recommended value: 30.0. 22 | Returns: 23 | ce (torch.Tensor): Cross entropy [B, K, T] with type dtype. 24 | """ 25 | output_shape = targets.shape 26 | assert logits.shape[:-1] == targets.shape 27 | assert mask.shape == targets.shape 28 | logits = logits.view(-1, logits.shape[-1]) 29 | targets = targets.reshape(-1) 30 | mask = mask.reshape(-1) 31 | 32 | safe_targets = torch.where( 33 | mask, 34 | targets, 35 | torch.zeros(1, device=targets.device, dtype=targets.dtype), 36 | ) 37 | 38 | # Chunking the conversion to float32 to avoid OOMs. 39 | ce_chunks = [] 40 | for logits_chunk, targets_chunk in zip(torch.chunk(logits, 4), torch.chunk(safe_targets, 4)): 41 | logits_chunk = logits_chunk.to(dtype) 42 | if logits_soft_clip is not None: 43 | logits_chunk = logits_soft_clip * torch.tanh(logits_chunk / logits_soft_clip) 44 | log_partition = torch.logsumexp(logits_chunk, dim=-1, keepdim=True) 45 | 46 | # For some reason, the PyTorch cross entropy is super slow with inputs with large cardinality (e.g. 32000) 47 | # so we reimplement the cross entropy ourselves... 
48 | ce_chunks.append(log_partition - logits_chunk.gather(-1, targets_chunk[..., None])) 49 | ce = torch.cat(ce_chunks, dim=0) 50 | ce = ce[..., 0] 51 | ce = torch.where(mask, ce, torch.zeros(1, device=ce.device, dtype=ce.dtype)) 52 | return ce.view(output_shape) 53 | -------------------------------------------------------------------------------- /moshi/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "moshi" 3 | requires-python = ">= 3.10" 4 | description = "Moshi is moshi" 5 | dependencies = [ 6 | "numpy >= 1.26, < 2.3", 7 | "safetensors >= 0.4.0, < 0.6", 8 | "huggingface-hub >= 0.24, < 0.29", 9 | "bitsandbytes >= 0.45, < 0.46; sys_platform == 'linux'", 10 | "einops >= 0.7, < 0.9", 11 | "sentencepiece == 0.2", 12 | "sounddevice == 0.5", 13 | "sphn >= 0.1.4, < 0.2.0", 14 | "torch >= 2.2.0, < 2.7", 15 | "aiohttp>=3.10.5, <3.12", 16 | "pytest >= 8.3.3", 17 | ] 18 | authors = [{name="Laurent Mazaré", email="laurent@kyutai.org"}] 19 | maintainers = [{name="Laurent Mazaré", email="laurent@kyutai.org"}] 20 | license = {text = "MIT"} 21 | dynamic = ["version"] 22 | readme = "README.md" 23 | 24 | [project.scripts] 25 | moshi-server = "moshi.server:main" 26 | moshi-client = "moshi.client:main" 27 | 28 | [tool.setuptools.dynamic] 29 | version = {attr = "moshi.__version__"} 30 | 31 | [build-system] 32 | requires = ["setuptools"] 33 | build-backend = "setuptools.build_meta" 34 | 35 | [project.optional-dependencies] 36 | dev = [ 37 | "pyright", 38 | "pytest", 39 | "flake8", 40 | "pre-commit", 41 | "gradio-webrtc>=0.0.18" 42 | ] 43 | -------------------------------------------------------------------------------- /moshi/requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.7.0 2 | safetensors==0.4.4 3 | sentencepiece==0.2.0 4 | sounddevice==0.5.0 5 | soundfile==0.12.1 6 | sphn==0.1.4 7 | torch==2.2.0 8 | numpy==1.26.4 9 | aiohttp>=3.10.5, <3.11 10 | huggingface-hub==0.24.6 11 | pytest==8.3.3 -------------------------------------------------------------------------------- /moshi/setup.cfg: -------------------------------------------------------------------------------- 1 | [pep8] 2 | max-line-length = 120 3 | 4 | [flake8] 5 | max-line-length = 120 6 | ignore = E203,E704 7 | exclude = 8 | dist 9 | build 10 | 11 | -------------------------------------------------------------------------------- /moshi/tests/assets/test_lm_codes.safetensors: -------------------------------------------------------------------------------- 1 | @{"codes":{"dtype":"I64","shape":[3,4,7],"data_offsets":[0,672]}}     2 |   3 |     4 |   -------------------------------------------------------------------------------- /moshi/tests/assets/test_lm_model.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/moshi/tests/assets/test_lm_model.safetensors -------------------------------------------------------------------------------- /moshi/tests/assets/test_lm_out.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/moshi/tests/assets/test_lm_out.safetensors -------------------------------------------------------------------------------- /moshi/tests/test_lm.py: -------------------------------------------------------------------------------- 1 | 
from pathlib import Path 2 | from safetensors.torch import load_file, load_model 3 | import torch 4 | 5 | from moshi.models import lm 6 | from moshi.utils.utils import cross_entropy 7 | 8 | 9 | def _get_assets() -> Path: 10 | return Path(__file__).parent / 'assets' 11 | 12 | 13 | def _get_lm(device=None, dtype=torch.float32) -> lm.LMModel: 14 | torch.manual_seed(1234) 15 | model = lm.LMModel( 16 | delays=[0, 1, 2, 4], 17 | n_q=3, 18 | dep_q=3, 19 | card=32, 20 | text_card=48, 21 | dim=16, 22 | num_layers=2, 23 | num_heads=1, 24 | hidden_scale=1, 25 | depformer_dim=16, 26 | depformer_multi_linear=True, 27 | depformer_weights_per_step=True, 28 | depformer_weights_per_step_schedule=[0, 1, 1], 29 | depformer_low_rank_embeddings=8, 30 | depformer_num_heads=1, 31 | depformer_gating='silu', 32 | context=4, 33 | device=device, 34 | dtype=dtype, 35 | ) 36 | return model 37 | 38 | 39 | def test_init(): 40 | _get_lm(dtype=torch.float32) 41 | _get_lm(dtype=torch.bfloat16) 42 | _get_lm(dtype=torch.float16) 43 | 44 | 45 | @torch.no_grad 46 | def test_forward(): 47 | model = _get_lm() 48 | load_model(model, _get_assets() / 'test_lm_model.safetensors') 49 | codes = load_file(_get_assets() / 'test_lm_codes.safetensors')['codes'] 50 | out = model(codes) 51 | assert out.logits is not None 52 | assert out.text_logits is not None 53 | assert out.mask.shape == codes[:, 1:].shape 54 | assert out.text_mask.shape == codes[:, :1].shape 55 | assert out.logits.shape[:-1] == codes[:, 1:].shape 56 | assert out.logits.shape[-1] == model.card 57 | assert out.text_logits.shape[-1] == model.text_card 58 | 59 | ref_out = load_file(_get_assets() / 'test_lm_out.safetensors') 60 | assert (ref_out['mask'] == out.mask).all() 61 | assert (ref_out['text_mask'] == out.text_mask).all() 62 | ce = cross_entropy(out.logits, codes[:, 1:], out.mask) 63 | ce_ref = cross_entropy(ref_out['logits'], codes[:, 1:], out.mask) 64 | delta = (ce.mean(dim=(0, 2)) - ce_ref.mean(dim=(0, 2))).abs() / ce_ref.mean(dim=(0, 2)) 65 | assert delta.amax() <= 1e-6, delta.amax() 66 | 67 | ce = cross_entropy(out.text_logits, codes[:, :1], out.text_mask) 68 | ce_ref = cross_entropy(ref_out['text_logits'], codes[:, :1], out.text_mask) 69 | delta = (ce.mean(dim=(0, 2)) - ce_ref.mean(dim=(0, 2))).abs() / ce_ref.mean(dim=(0, 2)) 70 | assert delta.amax() <= 1e-6, delta.amax() 71 | -------------------------------------------------------------------------------- /moshi_mlx/LICENSE: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any 2 | person obtaining a copy of this software and associated 3 | documentation files (the "Software"), to deal in the 4 | Software without restriction, including without 5 | limitation the rights to use, copy, modify, merge, 6 | publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software 8 | is furnished to do so, subject to the following 9 | conditions: 10 | 11 | The above copyright notice and this permission notice 12 | shall be included in all copies or substantial portions 13 | of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 16 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 17 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 18 | PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT 19 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 22 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 23 | DEALINGS IN THE SOFTWARE. 24 | -------------------------------------------------------------------------------- /moshi_mlx/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE* 2 | include *.md 3 | include *.cfg 4 | include requirements.txt 5 | include moshi_mlx/py.typed 6 | -------------------------------------------------------------------------------- /moshi_mlx/README.md: -------------------------------------------------------------------------------- 1 | # Moshi - MLX 2 | 3 | See the [top-level README.md][main_repo] for more information on Moshi. 4 | 5 | [Moshi][moshi] is a speech-text foundation model and full-duplex spoken dialogue framework. 6 | It uses [Mimi][moshi], a state-of-the-art streaming neural audio codec. Mimi operates at a framerate of 12.5 Hz, and compresses 7 | 24 kHz audio down to 1.1 kbps, in a fully streaming manner (latency of 80ms, the frame size), yet performs better than existing non-streaming codecs. 8 | 9 | This is the MLX implementation for Moshi. For Mimi, this uses our Rust based implementation through the Python binding provided in `rustymimi`, available in the [rust/](https://github.com/kyutai-labs/moshi/tree/main/rust) folder of our main repository. 10 | 11 | ## Requirements 12 | 13 | You will need at least Python 3.10; we recommend Python 3.12. 14 | 15 | ```bash 16 | pip install moshi_mlx # moshi MLX, from PyPI, best with Python 3.12. 17 | # Or the bleeding edge versions for Moshi and Moshi-MLX. 18 | pip install -e "git+https://git@github.com/kyutai-labs/moshi#egg=moshi_mlx&subdirectory=moshi_mlx" 19 | ``` 20 | We have tested the MLX version with a MacBook Pro M3. 21 | 22 | If you are not using Python 3.12, you might get an error when installing 23 | `moshi_mlx` or `rustymimi` (which `moshi_mlx` depends on). Then, you will need to install the [Rust toolchain](https://rustup.rs/), or switch to Python 3.12. 24 | 25 | ## Usage 26 | 27 | 28 | Once you have installed `moshi_mlx`, you can run 29 | ```bash 30 | python -m moshi_mlx.local -q 4 # weights quantized to 4 bits 31 | python -m moshi_mlx.local -q 8 # weights quantized to 8 bits 32 | # And using a different pretrained model: 33 | python -m moshi_mlx.local -q 4 --hf-repo kyutai/moshika-mlx-q4 34 | python -m moshi_mlx.local -q 8 --hf-repo kyutai/moshika-mlx-q8 35 | # be careful to always match the `-q` and `--hf-repo` flag. 36 | ``` 37 | 38 | This uses a command line interface, which is barebones. It does not perform any echo cancellation, 39 | nor does it try to compensate for a growing lag by skipping frames. 40 | 41 | You can use `--hf-repo` to select a different pretrained model, by setting the proper Hugging Face repository. 42 | See [the model list](https://github.com/kyutai-labs/moshi?tab=readme-ov-file#models) for a reference of the available models. 43 | 44 | Alternatively you can use `python -m moshi_mlx.local_web` to use 45 | the web UI; the connection is via http, at [localhost:8998](http://localhost:8998). 46 | 47 | 48 | ## License 49 | 50 | The present code is provided under the MIT license. 
51 | 52 | ## Citation 53 | 54 | If you use either Mimi or Moshi, please cite the following paper: 55 | 56 | ``` 57 | @techreport{kyutai2024moshi, 58 | author = {Alexandre D\'efossez and Laurent Mazar\'e and Manu Orsini and Am\'elie Royer and 59 | Patrick P\'erez and Herv\'e J\'egou and Edouard Grave and Neil Zeghidour}, 60 | title = {Moshi: a speech-text foundation model for real-time dialogue}, 61 | institution = {Kyutai}, 62 | year={2024}, 63 | month={September}, 64 | url={http://kyutai.org/Moshi.pdf}, 65 | } 66 | ``` 67 | 68 | [moshi]: https://kyutai.org/Moshi.pdf 69 | [main_repo]: https://github.com/kyutai-labs/moshi 70 | -------------------------------------------------------------------------------- /moshi_mlx/moshi_mlx/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | # flake8: noqa 5 | 6 | """ 7 | moshi_mlx is the MLX inference codebase for Kyutai audio generation models. 8 | """ 9 | 10 | from . import modules, models, utils 11 | 12 | __version__ = "0.2.4" 13 | -------------------------------------------------------------------------------- /moshi_mlx/moshi_mlx/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | # flake8: noqa 5 | """ 6 | Models for Moshi and Mimi, as well as the generic Lm language model. 7 | """ 8 | 9 | from .lm import ( 10 | Lm, 11 | LmConfig, 12 | config_v0_1, 13 | config1b_202412, 14 | config1b_202412_16rvq, 15 | config_helium_1_preview_2b, 16 | ) 17 | from .generate import LmGen 18 | from .mimi import mimi_202407, MimiConfig 19 | -------------------------------------------------------------------------------- /moshi_mlx/moshi_mlx/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | # flake8: noqa 5 | """Modules used for building the models.""" 6 | 7 | from .conv import Conv1d, ConvTranspose1d, StreamableConv1d, StreamableConvTranspose1d, NormConv1d, NormConvTranspose1d, ConvDownsample1d, ConvTrUpsample1d 8 | from .quantization import SplitResidualVectorQuantizer 9 | from .seanet import SeanetConfig, SeanetEncoder, SeanetDecoder 10 | from .kv_cache import KVCache, RotatingKVCache 11 | from .transformer import Transformer, TransformerConfig, ProjectedTransformer 12 | -------------------------------------------------------------------------------- /moshi_mlx/moshi_mlx/modules/conditioner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 
4 | # flake8: noqa 5 | """ 6 | Conditioners 7 | """ 8 | 9 | from dataclasses import dataclass 10 | 11 | import mlx.core as mx 12 | import mlx.nn as nn 13 | 14 | 15 | @dataclass 16 | class LutConditionerConfig: 17 | n_bins: int 18 | dim: int 19 | tokenizer: str 20 | possible_values: dict[str, int] 21 | 22 | 23 | class LutConditioner(nn.Module): 24 | def __init__(self, output_dim: int, cfg: LutConditionerConfig): 25 | super().__init__() 26 | 27 | if cfg.tokenizer != "noop": 28 | raise ValueError(f"unsupported tokenizer {cfg.tokenizer}") 29 | 30 | self.embed = nn.Embedding(cfg.n_bins + 1, cfg.dim) 31 | self.output_proj = nn.Linear(cfg.dim, output_dim, bias=False) 32 | self.learnt_padding = mx.zeros((1, 1, output_dim)) 33 | self.possible_values = { v: i for i, v in enumerate(cfg.possible_values) } 34 | 35 | def condition(self, value: str) -> mx.array: 36 | idx = self.possible_values.get(value, None) 37 | if idx is None: 38 | raise ValueError(f"unknown value {value}, possible-values: {self.possible_values}") 39 | idx = mx.array([idx]) 40 | return self.output_proj(self.embed(idx)) 41 | 42 | @dataclass 43 | class ConditionTensor: 44 | tensor: mx.array 45 | 46 | class ConditionProvider(nn.Module): 47 | def __init__(self, output_dim: int, cfg: dict[str, LutConditionerConfig]): 48 | self.conditioners = { name: LutConditioner(output_dim, c) for name, c in cfg.items() } 49 | 50 | def condition_tensor(self, name: str, value: str) -> ConditionTensor: 51 | if name not in self.conditioners: 52 | raise ValueError(f"unsupported conditioner {name}") 53 | tensor = self.conditioners[name].condition(value) 54 | return ConditionTensor(tensor) 55 | -------------------------------------------------------------------------------- /moshi_mlx/moshi_mlx/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/moshi_mlx/moshi_mlx/py.typed -------------------------------------------------------------------------------- /moshi_mlx/moshi_mlx/run_helium.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 
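# Minimal example CLI: sample text from the Helium language model with moshi_mlx
# (loads tokenizer + weights, optionally quantizes, then samples token by token).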
4 | 5 | import argparse 6 | import sentencepiece 7 | import huggingface_hub 8 | import mlx.core as mx 9 | import mlx.nn as nn 10 | from moshi_mlx import models, utils 11 | 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--tokenizer", type=str) 16 | parser.add_argument("--weights", type=str) 17 | parser.add_argument("--nsteps", type=int, default=20) 18 | parser.add_argument("--hf-repo", type=str, default="kyutai/helium-1-preview-2b-mlx") 19 | parser.add_argument("--prompt", type=str, default="Aujourd'hui, il est temps") 20 | parser.add_argument("--verbose", action="store_true") 21 | parser.add_argument("--quantize-bits", type=int) 22 | parser.add_argument("--save-quantized", type=str) 23 | parser.add_argument("--quantize-group-size", type=int, default=64) 24 | args = parser.parse_args() 25 | 26 | weights = args.weights 27 | if weights is None: 28 | weights = huggingface_hub.hf_hub_download( 29 | args.hf_repo, "helium-1-preview-2b-bf16.safetensors" 30 | ) 31 | tokenizer = args.tokenizer 32 | if tokenizer is None: 33 | tokenizer = huggingface_hub.hf_hub_download( 34 | args.hf_repo, "tokenizer_spm_48k_multi6_2.model" 35 | ) 36 | 37 | mx.random.seed(299792458) 38 | lm_config = models.config_helium_1_preview_2b() 39 | model = models.Lm(lm_config) 40 | model.set_dtype(mx.bfloat16) 41 | model.load_weights(weights, strict=True) 42 | if args.quantize_bits is not None: 43 | nn.quantize(model, bits=args.quantize_bits, group_size=args.quantize_group_size) 44 | if args.save_quantized is not None: 45 | print(f"saving quantized weights in {args.save_quantized}") 46 | model.save_weights(args.save_quantized) 47 | sampler = utils.Sampler() 48 | tokenizer = sentencepiece.SentencePieceProcessor(tokenizer) # type: ignore 49 | if args.verbose: 50 | print("prompt", args.prompt) 51 | else: 52 | print(args.prompt, end="", flush=True) 53 | prompt_tokens = tokenizer.encode(args.prompt) # type: ignore 54 | token = mx.array([[1] + prompt_tokens]) 55 | for step_idx in range(args.nsteps): 56 | logits = model(token) 57 | token, _ = sampler(logits[:, -1]) 58 | text_token = token.item() 59 | _text = tokenizer.id_to_piece(text_token) # type: ignore 60 | _text = _text.replace("▁", " ") 61 | _text = _text.replace("<0x0A>", "\n") 62 | if args.verbose: 63 | print(step_idx, token, _text) 64 | else: 65 | print(_text, end="", flush=True) 66 | token = token[None] 67 | print() 68 | 69 | 70 | if __name__ == "__main__": 71 | main() 72 | -------------------------------------------------------------------------------- /moshi_mlx/moshi_mlx/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 
4 | # flake8: noqa 5 | """Utilities.""" 6 | 7 | from .sampling import Sampler 8 | -------------------------------------------------------------------------------- /moshi_mlx/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "moshi_mlx" 3 | requires-python = ">= 3.10" 4 | description = "Moshi is moshi, but running on macOS" 5 | dependencies = [ 6 | "numpy >= 2.1.0, < 2.3", 7 | "safetensors >= 0.4.0, < 0.6", 8 | "huggingface-hub >= 0.24, < 0.29", 9 | "rustymimi == 0.4.1", 10 | "sentencepiece == 0.2", 11 | "sounddevice == 0.5", 12 | "sphn >= 0.1.4, < 0.2.0", 13 | "mlx >= 0.24.0, < 0.25", 14 | "aiohttp>=3.10.5, <3.12", 15 | ] 16 | authors = [{name="Laurent Mazaré", email="laurent@kyutai.org"}] 17 | maintainers = [{name="Laurent Mazaré", email="laurent@kyutai.org"}] 18 | license = {text = "MIT"} 19 | dynamic = ["version"] 20 | readme = "README.md" 21 | 22 | [project.scripts] 23 | moshi-local = "moshi_mlx.local:main" 24 | moshi-local-web = "moshi_mlx.local_web:main" 25 | 26 | [build-system] 27 | requires = ["setuptools"] 28 | build-backend = "setuptools.build_meta" 29 | 30 | [tool.setuptools.dynamic] 31 | version = {attr = "moshi_mlx.__version__"} 32 | 33 | [project.optional-dependencies] 34 | dev = [ 35 | "pyright", 36 | "flake8", 37 | "pre-commit", 38 | ] 39 | -------------------------------------------------------------------------------- /moshi_mlx/requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp>=3.10.5, <3.11 2 | blessed==1.20.0 3 | cffi==1.17.0 4 | dashing==0.1.0 5 | huggingface-hub==0.24.6 6 | mlx==0.22.0 7 | nodeenv==1.9.1 8 | numpy==2.1.0 9 | psutil==6.0.0 10 | pycparser==2.22 11 | pyright==1.1.378 12 | rustymimi==0.4.1 13 | safetensors==0.4.4 14 | sentencepiece==0.2.0 15 | six==1.16.0 16 | sounddevice==0.5.0 17 | wcwidth==0.2.13 18 | -------------------------------------------------------------------------------- /moshi_mlx/setup.cfg: -------------------------------------------------------------------------------- 1 | [pep8] 2 | max-line-length = 120 3 | 4 | [flake8] 5 | max-line-length = 120 6 | ignore = E203,E704 7 | exclude = 8 | dist 9 | build 10 | 11 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pre-commit>=3.8 2 | pyright>=1.1 3 | flake8>=7.1 4 | -------------------------------------------------------------------------------- /rust/.github/workflows/rust-ci.yml: -------------------------------------------------------------------------------- 1 | on: [push, pull_request] 2 | 3 | name: Continuous integration 4 | 5 | jobs: 6 | check: 7 | name: Check 8 | defaults: 9 | run: 10 | working-directory: ./rust 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | matrix: 14 | os: [ubuntu-latest, windows-latest, macOS-latest] 15 | rust: [stable, nightly] 16 | steps: 17 | - uses: actions/checkout@v2 18 | - uses: actions-rs/toolchain@v1 19 | with: 20 | profile: minimal 21 | toolchain: ${{ matrix.rust }} 22 | override: true 23 | - uses: actions-rs/cargo@v1 24 | with: 25 | command: check 26 | 27 | test: 28 | name: Test Suite 29 | defaults: 30 | run: 31 | working-directory: ./rust 32 | runs-on: ${{ matrix.os }} 33 | strategy: 34 | matrix: 35 | os: [ubuntu-latest, windows-latest, macOS-latest] 36 | rust: [stable, nightly] 37 | steps: 38 | - uses: actions/checkout@v2 39 | - uses: actions-rs/toolchain@v1 40 | with: 41 | 
profile: minimal 42 | toolchain: ${{ matrix.rust }} 43 | override: true 44 | - uses: actions-rs/cargo@v1 45 | with: 46 | command: test 47 | 48 | fmt: 49 | name: Rustfmt 50 | defaults: 51 | run: 52 | working-directory: ./rust 53 | runs-on: ubuntu-latest 54 | steps: 55 | - uses: actions/checkout@v2 56 | - uses: actions-rs/toolchain@v1 57 | with: 58 | profile: minimal 59 | toolchain: stable 60 | override: true 61 | - run: rustup component add rustfmt 62 | - uses: actions-rs/cargo@v1 63 | with: 64 | command: fmt 65 | args: --all -- --check 66 | 67 | clippy: 68 | name: Clippy 69 | defaults: 70 | run: 71 | working-directory: ./rust 72 | runs-on: ubuntu-latest 73 | steps: 74 | - uses: actions/checkout@v2 75 | - uses: actions-rs/toolchain@v1 76 | with: 77 | profile: minimal 78 | toolchain: stable 79 | override: true 80 | - run: rustup component add clippy 81 | - uses: actions-rs/cargo@v1 82 | with: 83 | command: clippy 84 | args: -- -D warnings 85 | -------------------------------------------------------------------------------- /rust/Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [ 3 | "mimi-pyo3", 4 | "moshi-backend", 5 | "moshi-cli", 6 | "moshi-core", 7 | ] 8 | resolver = "2" 9 | 10 | [workspace.package] 11 | version = "0.6.0-alpha.1" 12 | edition = "2021" 13 | license = "MIT/Apache-2.0" 14 | description = "moshi, a real-time voice AI" 15 | repository = "https://github.com/kyutai-labs/moshi" 16 | keywords = ["machine-learning", "audio"] 17 | categories = ["science"] 18 | 19 | [workspace.dependencies] 20 | anyhow = "1" 21 | axum = { version = "0.7.3", features = ["ws"] } 22 | axum-server = { version = "0.6", features = ["tls-rustls"] } 23 | base64ct = { version = "1.6.0", features = ["alloc"] } 24 | bincode = "1.3.3" 25 | byteorder = "1.5.0" 26 | candle = { git = "https://github.com/huggingface/candle.git", rev = "e3db30021fb08efbe4ee71d840f5d1230d050cd3", package = "candle-core" } 27 | candle-flash-attn = { git = "https://github.com/huggingface/candle.git", rev = "e3db30021fb08efbe4ee71d840f5d1230d050cd3" } 28 | candle-nn = { git = "https://github.com/huggingface/candle.git", rev = "e3db30021fb08efbe4ee71d840f5d1230d050cd3" } 29 | candle-transformers = { git = "https://github.com/huggingface/candle.git", rev = "e3db30021fb08efbe4ee71d840f5d1230d050cd3" } 30 | clap = { version = "4.4.12", features = ["derive"] } 31 | color-eyre = "0.6.2" 32 | cpal = "0.15.3" 33 | crossterm = { version = "0.27.0", features = ["event-stream"] } 34 | env_logger = "0.10.1" 35 | futures = "0.3.28" 36 | futures-util = "0.3.30" 37 | hf-hub = { version = "0.3.2", features = ["tokio"] } 38 | http = "1.1.0" 39 | lazy_static = "1.5.0" 40 | log = "0.4.20" 41 | moshi = { path = "./moshi-core", version = "0.6.0-alpha.1" } 42 | native-tls = "0.2.11" 43 | numpy = "0.23.0" 44 | ogg = { version = "0.9.1", features = ["async"] } 45 | opus = "0.3.0" 46 | pyo3 = "0.23.0" 47 | rand = { version = "0.8.5", features = ["getrandom"] } 48 | rand_chacha = "0.3.1" 49 | ratatui = "0.27.0" 50 | rayon = "1.8.1" 51 | rcgen = "0.13.1" 52 | regex = "1.10.3" 53 | rubato = "0.15.0" 54 | rustls = "0.23.5" 55 | sentencepiece = "0.11.2" 56 | serde = { version = "1.0", features = ["derive"] } 57 | serde_json = "1.0.115" 58 | sha3 = "0.10.8" 59 | symphonia = { version = "0.5.3", features = ["all"] } 60 | tokenizers = "0.15.2" 61 | tokio = { version = "1.35.1", features = ["full"] } 62 | tokio-rustls = "0.24.1" 63 | tokio-tungstenite = { version = "0.21.0", features = ["rustls", 
"native-tls"] } 64 | toml = "0.8.19" 65 | tower = "0.4.13" 66 | tower-http = { version = "0.5", features = ["full"] } 67 | tracing = "0.1.40" 68 | tracing-appender = "0.2.3" 69 | tracing-chrome = "0.7.2" 70 | tracing-subscriber = "0.3.18" 71 | tui-logger = "0.11.2" 72 | vergen = { version = "8.3.1", features = ["build", "cargo", "git", "gitcl", "rustc", "si"] } 73 | 74 | [profile.release] 75 | debug = true 76 | 77 | [profile.release-no-debug] 78 | inherits = "release" 79 | debug = false 80 | -------------------------------------------------------------------------------- /rust/README.md: -------------------------------------------------------------------------------- 1 | # moshi - rust 2 | 3 | [![Latest version](https://img.shields.io/crates/v/moshi.svg)](https://crates.io/crates/moshi) 4 | [![Documentation](https://docs.rs/moshi/badge.svg)](https://docs.rs/moshi) 5 | ![License](https://img.shields.io/crates/l/moshi.svg) 6 | 7 | See the [top-level README.md](../README.md) for more information. 8 | 9 | This provides the Rust backend (both Mimi and Moshi) and client implementation. 10 | The Mimi implementation is available through Python bindings, through the `rustymimi` package. 11 | 12 | ## Requirements 13 | 14 | You will need a recent version of the [Rust toolchain](https://rustup.rs/). 15 | To compile GPU support, you will also need the [CUDA](https://developer.nvidia.com/cuda-toolkit) properly installed for your GPU, in particular with `nvcc`. 16 | 17 | 18 | ## Rust based Mimi with Python bindings 19 | 20 | First, a standalone rust based implementation of Mimi is provided, along with Python bindings. 21 | This is the one used by `moshi_mlx`. It is automatically installed with `moshi_mlx`, but you 22 | can install it separately as 23 | ```bash 24 | # Install from pip: 25 | pip install rustymimi 26 | # Alternatively, if you want to compile the package run from the root of the repo. 27 | maturin dev -r -m rust/mimi-pyo3/Cargo.toml 28 | ``` 29 | 30 | ## Rust server 31 | 32 | If you don't have ssl certificates yet, generate a `key.pem` and `cert.pem` file 33 | using the following command. 34 | ```bash 35 | openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes -subj "/CN=localhost" 36 | ``` 37 | 38 | In order to run the rust inference server, use the following command from within 39 | the this directory: 40 | 41 | ```bash 42 | cargo run --features cuda --bin moshi-backend -r -- --config moshi-backend/config.json standalone 43 | ``` 44 | 45 | When using macOS, you can replace `--features cuda` with `--features metal`. 46 | 47 | Alternatively you can use `config-q8.json` rather than `config.json` to use the 48 | quantified q8 model. You can select a different pretrained model, e.g. Moshika, 49 | by changing the `"hf_repo"` key in either file. 50 | 51 | Once the server has printed 'standalone worker listening', you can use the web 52 | UI. By default the rust version uses https so it will be at 53 | [localhost:8998](https://localhost:8998). 54 | 55 | You will get some warnings about the site being unsafe. When using chrome you 56 | can bypass it by selecting "Details" or "Advanced", then "Visit this unsafe 57 | site" or "Proceed to localhost (unsafe)". 58 | 59 | ## Rust client 60 | 61 | We recommend using the web UI as it provides some echo cancellation that helps 62 | the overall model quality. 
Alternatively we provide some command line interfaces 63 | for the rust and python versions, the protocol is the same as with the web UI so 64 | there is nothing to change on the server side. 65 | 66 | ### Rust Command Line 67 | 68 | From within the `rust` directory, run the following: 69 | ```bash 70 | cargo run --bin moshi-cli -r -- tui --host localhost 71 | ``` 72 | 73 | ## License 74 | 75 | The present code is provided under the Apache license. 76 | -------------------------------------------------------------------------------- /rust/mimi-pyo3/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "mimi-pyo3" 3 | version.workspace = true 4 | edition.workspace = true 5 | description.workspace = true 6 | repository.workspace = true 7 | keywords.workspace = true 8 | categories.workspace = true 9 | license.workspace = true 10 | 11 | [lib] 12 | name = "rustymimi" 13 | crate-type = ["cdylib"] 14 | 15 | [dependencies] 16 | anyhow = { workspace = true } 17 | numpy = { workspace = true } 18 | pyo3 = { workspace = true } 19 | moshi = { workspace = true } 20 | -------------------------------------------------------------------------------- /rust/mimi-pyo3/py_src/rustymimi/__init__.py: -------------------------------------------------------------------------------- 1 | from .rustymimi import * 2 | -------------------------------------------------------------------------------- /rust/mimi-pyo3/py_src/rustymimi/__init__.pyi: -------------------------------------------------------------------------------- 1 | # Generated content DO NOT EDIT 2 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence 3 | from os import PathLike 4 | 5 | @staticmethod 6 | def write_wav(filename, data, sample_rate): 7 | """ 8 | Writes an audio file using the wav format based on pcm data from a numpy array. 9 | 10 | This only supports a single channel at the moment so the input array data is expected to have a 11 | single dimension. 
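    Example (sketch, assuming `pcm` is a 1D float numpy array): `write_wav("out.wav", pcm, 24000)`.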
12 | """ 13 | pass 14 | 15 | class StreamTokenizer: 16 | def __init__(path, *, dtype="f32", max_seq_len=None): 17 | pass 18 | 19 | def decode(self, codes): 20 | """ """ 21 | pass 22 | 23 | def encode(self, pcm_data): 24 | """ """ 25 | pass 26 | 27 | def get_decoded(self): 28 | """ """ 29 | pass 30 | 31 | def get_encoded(self): 32 | """ """ 33 | pass 34 | 35 | class Tokenizer: 36 | def __init__(path, *, dtype="f32", max_seq_len=None): 37 | pass 38 | 39 | def decode(self, codes): 40 | """ """ 41 | pass 42 | 43 | def decode_step(self, codes): 44 | """ """ 45 | pass 46 | 47 | def encode(self, pcm_data): 48 | """ """ 49 | pass 50 | 51 | def encode_step(self, pcm_data): 52 | """ """ 53 | pass 54 | 55 | def reset(self): 56 | """ """ 57 | pass 58 | -------------------------------------------------------------------------------- /rust/mimi-pyo3/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["maturin>=1.4,<2.0"] 3 | build-backend = "maturin" 4 | 5 | [project] 6 | name = "rustymimi" 7 | requires-python = ">=3.8" 8 | classifiers = [ 9 | "Programming Language :: Rust", 10 | "Programming Language :: Python :: Implementation :: CPython", 11 | "Programming Language :: Python :: Implementation :: PyPy", 12 | ] 13 | dynamic = ["version"] 14 | 15 | [tool.maturin] 16 | python-source = "py_src" 17 | module-name = "rustymimi.rustymimi" 18 | bindings = 'pyo3' 19 | features = ["pyo3/extension-module"] 20 | -------------------------------------------------------------------------------- /rust/moshi-backend/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "moshi-backend" 3 | version.workspace = true 4 | edition.workspace = true 5 | description.workspace = true 6 | repository.workspace = true 7 | keywords.workspace = true 8 | categories.workspace = true 9 | license.workspace = true 10 | 11 | [dependencies] 12 | anyhow = { workspace = true } 13 | axum = { workspace = true } 14 | axum-server = { workspace = true } 15 | base64ct = { workspace = true } 16 | bincode = { workspace = true } 17 | byteorder = { workspace = true } 18 | candle = { workspace = true } 19 | candle-nn = { workspace = true } 20 | candle-transformers = { workspace = true } 21 | clap = { workspace = true } 22 | env_logger = { workspace = true } 23 | futures-util = { workspace = true } 24 | hf-hub = { workspace = true } 25 | rcgen = { workspace = true } 26 | http = { workspace = true } 27 | lazy_static = { workspace = true } 28 | log = { workspace = true } 29 | moshi = { workspace = true } 30 | ogg = { workspace = true } 31 | opus = { workspace = true } 32 | rand = { workspace = true } 33 | rand_chacha = { workspace = true } 34 | regex = { workspace = true } 35 | rubato = { workspace = true } 36 | sentencepiece = { workspace = true } 37 | serde = { workspace = true } 38 | serde_json = { workspace = true } 39 | sha3 = { workspace = true } 40 | symphonia = { workspace = true } 41 | tokenizers = { workspace = true } 42 | tokio = { workspace = true } 43 | tokio-rustls = { workspace = true } 44 | tower = { workspace = true } 45 | tower-http = { workspace = true } 46 | tracing = { workspace = true } 47 | tracing-appender = { workspace = true } 48 | tracing-chrome = { workspace = true } 49 | tracing-subscriber = { workspace = true } 50 | 51 | [build-dependencies] 52 | anyhow = { workspace = true } 53 | vergen = { workspace = true } 54 | 55 | [features] 56 | default = [] 57 | cuda = ["moshi/cuda", 
"candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] 58 | metal = ["moshi/metal", "candle/metal", "candle-nn/metal", "candle-transformers/metal"] 59 | 60 | [profile.release] 61 | debug = true 62 | 63 | [profile.release-no-debug] 64 | inherits = "release" 65 | debug = false 66 | 67 | -------------------------------------------------------------------------------- /rust/moshi-backend/build.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Kyutai, all rights reserved. 2 | // This source code is licensed under the license found in the 3 | // LICENSE file in the root directory of this source tree. 4 | 5 | use anyhow::Result; 6 | use vergen::EmitBuilder; 7 | 8 | pub fn main() -> Result<()> { 9 | // NOTE: This will output everything, and requires all features enabled. 10 | // NOTE: See the EmitBuilder documentation for configuration options. 11 | EmitBuilder::builder().all_build().all_cargo().all_git().all_rustc().all_sysinfo().emit()?; 12 | Ok(()) 13 | } 14 | -------------------------------------------------------------------------------- /rust/moshi-backend/config-q8.json: -------------------------------------------------------------------------------- 1 | { 2 | "instance_name": "foo", 3 | "hf_repo": "kyutai/moshiko-candle-q8", 4 | "lm_model_file": "$HOME/tmp/moshiko_rs_301e30bf@120/model.q8.gguf", 5 | "text_tokenizer_file": "$HOME/tmp/tokenizer_spm_32k_3.model", 6 | "log_dir": "$HOME/tmp/moshi-logs", 7 | "mimi_model_file": "$HOME/tmp/tokenizer-e351c8d8-checkpoint125.safetensors", 8 | "mimi_num_codebooks": 8, 9 | "static_dir": "../client/dist", 10 | "addr": "0.0.0.0", 11 | "port": 8998, 12 | "cert_dir": "." 13 | } 14 | 15 | -------------------------------------------------------------------------------- /rust/moshi-backend/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "instance_name": "foo", 3 | "hf_repo": "kyutai/moshiko-candle-bf16", 4 | "lm_model_file": "$HOME/tmp/moshiko_rs_301e30bf@120/model.safetensors", 5 | "text_tokenizer_file": "$HOME/tmp/tokenizer_spm_32k_3.model", 6 | "log_dir": "$HOME/tmp/moshi-logs", 7 | "mimi_model_file": "$HOME/tmp/tokenizer-e351c8d8-checkpoint125.safetensors", 8 | "mimi_num_codebooks": 8, 9 | "static_dir": "../client/dist", 10 | "addr": "0.0.0.0", 11 | "port": 8998, 12 | "cert_dir": "." 13 | } 14 | -------------------------------------------------------------------------------- /rust/moshi-backend/src/build.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Kyutai, all rights reserved. 2 | // This source code is licensed under the license found in the 3 | // LICENSE file in the root directory of this source tree. 4 | 5 | use anyhow::Result; 6 | use vergen::EmitBuilder; 7 | 8 | pub fn main() -> Result<()> { 9 | // NOTE: This will output everything, and requires all features enabled. 10 | // NOTE: See the EmitBuilder documentation for configuration options. 11 | EmitBuilder::builder().all_build().all_cargo().all_git().all_rustc().all_sysinfo().emit()?; 12 | Ok(()) 13 | } 14 | -------------------------------------------------------------------------------- /rust/moshi-backend/src/utils.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Kyutai, all rights reserved. 2 | // This source code is licensed under the license found in the 3 | // LICENSE file in the root directory of this source tree. 
 4 | 
 5 | #[derive(Debug, PartialEq, Clone, serde::Deserialize, serde::Serialize)]
 6 | pub struct BuildInfo {
 7 |     build_timestamp: String,
 8 |     build_date: String,
 9 |     git_branch: String,
10 |     git_timestamp: String,
11 |     git_date: String,
12 |     git_hash: String,
13 |     git_describe: String,
14 |     rustc_host_triple: String,
15 |     rustc_version: String,
16 |     cargo_target_triple: String,
17 | }
18 | 
19 | impl BuildInfo {
20 |     pub fn new() -> BuildInfo {
21 |         BuildInfo {
22 |             build_timestamp: String::from(env!("VERGEN_BUILD_TIMESTAMP")),
23 |             build_date: String::from(env!("VERGEN_BUILD_DATE")),
24 |             git_branch: String::from(env!("VERGEN_GIT_BRANCH")),
25 |             git_timestamp: String::from(env!("VERGEN_GIT_COMMIT_TIMESTAMP")),
26 |             git_date: String::from(env!("VERGEN_GIT_COMMIT_DATE")),
27 |             git_hash: String::from(env!("VERGEN_GIT_SHA")),
28 |             git_describe: String::from(env!("VERGEN_GIT_DESCRIBE")),
29 |             rustc_host_triple: String::from(env!("VERGEN_RUSTC_HOST_TRIPLE")),
30 |             rustc_version: String::from(env!("VERGEN_RUSTC_SEMVER")),
31 |             cargo_target_triple: String::from(env!("VERGEN_CARGO_TARGET_TRIPLE")),
32 |         }
33 |     }
34 | }
35 | 
36 | pub struct WrapJson<T>(pub anyhow::Result<T>);
37 | 
38 | impl<T: serde::Serialize> axum::response::IntoResponse for WrapJson<T> {
39 |     fn into_response(self) -> axum::response::Response {
40 |         match self.0 {
41 |             Ok(v) => axum::Json(v).into_response(),
42 |             Err(err) => {
43 |                 tracing::error!(?err, "returning internal server error 500");
44 |                 (axum::http::StatusCode::INTERNAL_SERVER_ERROR, format!("{err}")).into_response()
45 |             }
46 |         }
47 |     }
48 | }
49 | 
50 | pub fn replace_env_vars(input: &str) -> String {
51 |     let re = regex::Regex::new(r"\$([A-Za-z_][A-Za-z0-9_]*)").unwrap();
52 |     re.replace_all(input, |caps: &regex::Captures| {
53 |         let var_name = &caps[1];
54 |         std::env::var(var_name).unwrap_or_else(|_| "".to_string())
55 |     })
56 |     .to_string()
57 | }
58 | 
59 | pub struct WrapBincode<T>(pub anyhow::Result<T>);
60 | 
61 | impl<T: serde::Serialize> axum::response::IntoResponse for WrapBincode<T> {
62 |     fn into_response(self) -> axum::response::Response {
63 |         match self.0.and_then(|v| Ok(bincode::serialize(&v)?)) {
64 |             Ok(v) => (axum::http::StatusCode::OK, v).into_response(),
65 |             Err(err) => {
66 |                 tracing::error!(?err, "returning internal server error 500");
67 |                 (axum::http::StatusCode::INTERNAL_SERVER_ERROR, format!("{err}")).into_response()
68 |             }
69 |         }
70 |     }
71 | }
72 | 
--------------------------------------------------------------------------------
/rust/moshi-cli/Cargo.toml:
--------------------------------------------------------------------------------
 1 | [package]
 2 | name = "moshi-cli"
 3 | version.workspace = true
 4 | edition.workspace = true
 5 | description.workspace = true
 6 | repository.workspace = true
 7 | keywords.workspace = true
 8 | categories.workspace = true
 9 | license.workspace = true
10 | 
11 | [dependencies]
12 | anyhow = { workspace = true }
13 | byteorder = { workspace = true }
14 | candle = { workspace = true }
15 | candle-nn = { workspace = true }
16 | candle-transformers = { workspace = true }
17 | clap = { workspace = true }
18 | color-eyre = { workspace = true }
19 | cpal = { workspace = true }
20 | crossterm = { workspace = true }
21 | env_logger = { workspace = true }
22 | futures = { workspace = true }
23 | futures-util = { workspace = true }
24 | log = { workspace = true }
25 | moshi = { workspace = true }
26 | native-tls = { workspace = true }
27 | ogg = { workspace = true }
28 | opus = { workspace = true }
29 | rand = { workspace = true }
30 | ratatui = { workspace = true }
31 | rubato = { workspace = true }
32 | rustls = { workspace = true }
33 | sentencepiece = { workspace = true }
34 | serde_json = { workspace = true }
35 | symphonia = { workspace = true }
36 | tokio = { workspace = true }
37 | tokio-tungstenite = { workspace = true }
38 | toml = { workspace = true }
39 | tracing = { workspace = true }
40 | tracing-chrome = { workspace = true }
41 | tracing-subscriber = { workspace = true }
42 | tui-logger = { workspace = true }
43 | 
44 | [features]
45 | default = []
46 | cuda = ["moshi/cuda", "candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
47 | metal = ["moshi/metal", "candle/metal", "candle-nn/metal", "candle-transformers/metal"]
48 | 
49 | [profile.release]
50 | debug = true
51 | 
52 | [profile.release-no-debug]
53 | inherits = "release"
54 | debug = false
55 | 
--------------------------------------------------------------------------------
/rust/moshi-cli/src/main.rs:
--------------------------------------------------------------------------------
  1 | // Copyright (c) Kyutai, all rights reserved.
  2 | // This source code is licensed under the license found in the
  3 | // LICENSE file in the root directory of this source tree.
  4 | 
  5 | use anyhow::Result;
  6 | use clap::Parser;
  7 | 
  8 | mod audio_io;
  9 | mod gen;
 10 | mod multistream;
 11 | 
 12 | use candle::Device;
 13 | 
 14 | #[derive(Debug, Parser)]
 15 | struct Args {
 16 |     #[command(subcommand)]
 17 |     command: Command,
 18 | 
 19 |     /// Enable tracing (generates a trace-timestamp.json file).
 20 |     #[arg(long)]
 21 |     tracing: bool,
 22 | }
 23 | 
 24 | #[derive(Debug, clap::Subcommand)]
 25 | enum Command {
 26 |     Client {
 27 |         #[arg(long)]
 28 |         host: String,
 29 | 
 30 |         #[arg(long, default_value_t = 8998)]
 31 |         port: usize,
 32 |     },
 33 |     Tui {
 34 |         #[arg(long)]
 35 |         host: String,
 36 | 
 37 |         #[arg(long, default_value_t = 8998)]
 38 |         port: usize,
 39 |     },
 40 |     Gen {
 41 |         #[arg(long)]
 42 |         lm_model_file: String,
 43 | 
 44 |         #[arg(long)]
 45 |         mimi_model_file: String,
 46 | 
 47 |         #[arg(long)]
 48 |         lm_config_file: String,
 49 | 
 50 |         #[arg(long)]
 51 |         text_tokenizer: String,
 52 | 
 53 |         #[arg(long)]
 54 |         audio_input_file: String,
 55 | 
 56 |         #[arg(long)]
 57 |         audio_output_file: String,
 58 | 
 59 |         #[arg(long, default_value_t = 299_792_458)]
 60 |         seed: u64,
 61 | 
 62 |         #[arg(long)]
 63 |         cfg_alpha: Option<f64>,
 64 | 
 65 |         /// Run on cpu
 66 |         #[arg(long)]
 67 |         cpu: bool,
 68 |     },
 69 | }
 70 | 
 71 | pub fn device(cpu: bool) -> Result<Device> {
 72 |     if cpu {
 73 |         Ok(Device::Cpu)
 74 |     } else if candle::utils::cuda_is_available() {
 75 |         Ok(Device::new_cuda(0)?)
 76 |     } else if candle::utils::metal_is_available() {
 77 |         Ok(Device::new_metal(0)?)
 78 |     } else {
 79 |         Ok(Device::Cpu)
 80 |     }
 81 | }
 82 | 
 83 | #[tokio::main(flavor = "multi_thread", worker_threads = 10)]
 84 | async fn main() -> Result<()> {
 85 |     use tracing_chrome::ChromeLayerBuilder;
 86 |     use tracing_subscriber::prelude::*;
 87 | 
 88 |     let args = Args::parse();
 89 |     let _guard = if args.tracing {
 90 |         let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
 91 |         tracing_subscriber::registry().with(chrome_layer).init();
 92 |         Some(guard)
 93 |     } else {
 94 |         None
 95 |     };
 96 |     match args.command {
 97 |         Command::Client { host, port } => {
 98 |             tracing_subscriber::fmt::init();
 99 |             multistream::client::run(host, port).await?
100 |         }
101 |         Command::Tui { host, port } => {
102 |             tracing_subscriber::fmt::init();
103 |             multistream::client_tui::run(host, port).await?
104 | } 105 | Command::Gen { 106 | seed, 107 | text_tokenizer, 108 | lm_model_file, 109 | lm_config_file, 110 | mimi_model_file, 111 | audio_input_file, 112 | audio_output_file, 113 | cfg_alpha, 114 | cpu, 115 | } => { 116 | let dev = device(cpu)?; 117 | tracing_subscriber::fmt::init(); 118 | let args = gen::Args { 119 | lm_model_file, 120 | mimi_model_file, 121 | text_tokenizer, 122 | lm_config_file, 123 | audio_input_file, 124 | audio_output_file, 125 | seed, 126 | cfg_alpha, 127 | }; 128 | gen::run(&args, &dev)? 129 | } 130 | } 131 | Ok(()) 132 | } 133 | -------------------------------------------------------------------------------- /rust/moshi-core/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "moshi" 3 | version.workspace = true 4 | edition.workspace = true 5 | description.workspace = true 6 | repository.workspace = true 7 | keywords.workspace = true 8 | categories.workspace = true 9 | license.workspace = true 10 | readme = "../README.md" 11 | 12 | [dependencies] 13 | candle = { workspace = true } 14 | candle-nn = { workspace = true } 15 | candle-transformers = { workspace = true } 16 | candle-flash-attn = { workspace = true, optional = true } 17 | 18 | rayon = { workspace = true } 19 | serde = { workspace = true } 20 | tracing = { workspace = true } 21 | 22 | [features] 23 | default = [] 24 | cuda = ["candle/cuda", "candle-nn/cuda"] 25 | metal = ["candle/metal", "candle-nn/metal"] 26 | flash-attn = ["cuda", "dep:candle-flash-attn"] 27 | -------------------------------------------------------------------------------- /rust/moshi-core/src/lib.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Kyutai, all rights reserved. 2 | // This source code is licensed under the license found in the 3 | // LICENSE file in the root directory of this source tree. 4 | 5 | pub use candle; 6 | pub use candle_nn; 7 | 8 | pub mod asr; 9 | pub mod batched_transformer; 10 | pub mod conditioner; 11 | pub mod conv; 12 | pub mod kv_cache; 13 | pub mod lm; 14 | pub mod lm_generate; 15 | pub mod lm_generate_multistream; 16 | pub mod mimi; 17 | pub mod nn; 18 | pub mod quantization; 19 | pub mod seanet; 20 | pub mod streaming; 21 | pub mod transformer; 22 | pub mod tts; 23 | pub mod tts_streaming; 24 | pub mod wav; 25 | 26 | #[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)] 27 | pub enum NormType { 28 | RmsNorm, 29 | LayerNorm, 30 | } 31 | 32 | pub use streaming::{StreamMask, StreamTensor, StreamingModule}; 33 | -------------------------------------------------------------------------------- /rust/moshi-core/src/wav.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Kyutai, all rights reserved. 2 | // This source code is licensed under the license found in the 3 | // LICENSE file in the root directory of this source tree. 
 4 | 
 5 | use std::io::prelude::*;
 6 | 
 7 | pub trait Sample {
 8 |     fn to_i16(&self) -> i16;
 9 | }
10 | 
11 | impl Sample for f32 {
12 |     fn to_i16(&self) -> i16 {
13 |         (self.clamp(-1.0, 1.0) * 32767.0) as i16
14 |     }
15 | }
16 | 
17 | impl Sample for f64 {
18 |     fn to_i16(&self) -> i16 {
19 |         (self.clamp(-1.0, 1.0) * 32767.0) as i16
20 |     }
21 | }
22 | 
23 | impl Sample for i16 {
24 |     fn to_i16(&self) -> i16 {
25 |         *self
26 |     }
27 | }
28 | 
29 | pub fn write_pcm_as_wav<W: Write, S: Sample>(
30 |     w: &mut W,
31 |     samples: &[S],
32 |     sample_rate: u32,
33 | ) -> std::io::Result<()> {
34 |     let len = 12u32; // header
35 |     let len = len + 24u32; // fmt
36 |     let len = len + samples.len() as u32 * 2 + 8; // data
37 |     let n_channels = 1u16;
38 |     let bytes_per_second = sample_rate * 2 * n_channels as u32;
39 |     w.write_all(b"RIFF")?;
40 |     w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes
41 |     w.write_all(b"WAVE")?;
42 | 
43 |     // Format block
44 |     w.write_all(b"fmt ")?;
45 |     w.write_all(&16u32.to_le_bytes())?; // block len minus 8 bytes
46 |     w.write_all(&1u16.to_le_bytes())?; // PCM
47 |     w.write_all(&n_channels.to_le_bytes())?; // one channel
48 |     w.write_all(&sample_rate.to_le_bytes())?;
49 |     w.write_all(&bytes_per_second.to_le_bytes())?;
50 |     w.write_all(&2u16.to_le_bytes())?; // 2 bytes of data per sample
51 |     w.write_all(&16u16.to_le_bytes())?; // bits per sample
52 | 
53 |     // Data block
54 |     w.write_all(b"data")?;
55 |     w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?;
56 |     for sample in samples.iter() {
57 |         w.write_all(&sample.to_i16().to_le_bytes())?
58 |     }
59 |     Ok(())
60 | }
--------------------------------------------------------------------------------
/rust/protocol.md:
--------------------------------------------------------------------------------
 1 | # Protocol
 2 | 
 3 | The connection takes place using a websocket. This handles the message lengths
 4 | for us. The binary protocol for messages is as follows. The protocol uses
 5 | little-endian encoding.
 6 | 
 7 | Each message starts with a single byte indicating the message type `MT`.
 8 | The format for the rest of the message, aka the payload, depends on `MT`.
 9 | 
10 | ```
11 | - Handshake MT=0. The payload is made of two fields.
12 |   1. Protocol version (`u32`) - always 0 for now.
13 |   2. Model version (`u32`).
14 | - Audio MT=1. The payload is made of a single field.
15 |   - Binary data for the ogg frames containing opus encoded audio (24kHz, mono).
16 | - Text MT=2. The payload is made of a single field.
17 |   - UTF8 encoded string.
18 | - Control MT=3. The payload is made of a single field. This is not used in full
19 |   streaming mode.
20 |   - One byte B describing the control itself.
21 |     - Start B=0.
22 |     - EndTurn B=1.
23 |     - Pause B=2.
24 |     - Restart B=3.
25 | - MetaData MT=4. The payload is made of a single field.
26 |   - UTF8 encoded string with json data.
27 | - Error MT=5. The payload is made of a single field.
28 |   - UTF8 encoded string containing the error description.
29 | - Ping MT=6. No payload, this message type is currently unused.
30 | ```
31 | Messages with an unknown message type should be discarded.
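To make the framing concrete, here is a minimal Python sketch of a client-side dispatcher for this protocol. It is not part of the repo; the function name and return convention are invented for illustration, and only the byte layout comes from the spec above.

```python
import struct

def parse_message(data: bytes):
    """Split one websocket frame into (kind, payload), per the spec above."""
    mt, payload = data[0], data[1:]
    if mt == 0:  # Handshake: two little-endian u32 fields
        return "handshake", struct.unpack("<II", payload[:8])
    if mt == 1:  # Audio: raw Ogg frames carrying Opus-encoded audio
        return "audio", payload
    if mt in (2, 4, 5):  # Text / MetaData / Error: UTF8 payload
        return {2: "text", 4: "metadata", 5: "error"}[mt], payload.decode("utf-8")
    if mt == 3:  # Control: one byte, 0=Start 1=EndTurn 2=Pause 3=Restart
        return "control", payload[0]
    if mt == 6:  # Ping: no payload
        return "ping", None
    return None  # unknown message types are discarded
```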
32 | 
--------------------------------------------------------------------------------
/rust/rustfmt.toml:
--------------------------------------------------------------------------------
1 | use_small_heuristics = "Max"
2 | edition = "2021"
3 | 
4 | 
--------------------------------------------------------------------------------
/rust/s2st-1b.toml:
--------------------------------------------------------------------------------
 1 | text_in_vocab_size = 48001
 2 | text_out_vocab_size = 48000
 3 | audio_vocab_size = 2049
 4 | audio_codebooks = 16
 5 | 
 6 | [transformer]
 7 | d_model = 2048
 8 | num_heads = 16
 9 | num_layers = 16
10 | dim_feedforward = 8192
11 | causal = true
12 | norm_first = true
13 | bias_ff = false
14 | bias_attn = false
15 | context = 3000
16 | max_period = 100000
17 | use_conv_block = false
18 | use_conv_bias = true
19 | gating = "silu"
20 | norm = "RmsNorm"
21 | positional_embedding = "Rope"
22 | conv_layout = false
23 | conv_kernel_size = 3
24 | kv_repeat = 1
25 | max_seq_len = 4096
26 | 
27 | [depformer]
28 | num_slices = 8
29 | 
30 | [depformer.transformer]
31 | d_model = 1024
32 | num_heads = 16
33 | num_layers = 6
34 | dim_feedforward = 4096
35 | causal = true
36 | norm_first = true
37 | bias_ff = false
38 | bias_attn = false
39 | context = 32
40 | max_period = 10000
41 | use_conv_block = false
42 | use_conv_bias = true
43 | gating = "silu"
44 | norm = "RmsNorm"
45 | positional_embedding = "None"
46 | conv_layout = false
47 | conv_kernel_size = 3
48 | kv_repeat = 1
49 | max_seq_len = 4096
50 | 
51 | [conditioners.description]
52 | type = "Lut"
53 | n_bins = 31
54 | dim = 16
55 | possible_values = ["very_bad", "bad", "neutral", "good", "very_good"]
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/export_quantized.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Kyutai, all rights reserved.
 2 | # This source code is licensed under the license found in the
 3 | # LICENSE file in the root directory of this source tree.
 4 | """Convert a repo into a quantized one (PyTorch only). Needs to run on a GPU."""
 5 | 
 6 | 
 7 | import argparse
 8 | import json
 9 | from pathlib import Path
10 | import tempfile
11 | 
12 | from huggingface_hub import HfApi, hf_hub_download
13 | from safetensors.torch import save_file
14 | 
15 | from moshi.models import loaders
16 | from moshi.utils.quantize import replace_linear_with_qlinear
17 | 
18 | 
19 | def get_api():
20 |     token = input("Write token? ").strip()
21 |     api = HfApi(token=token)
22 |     return api
23 | 
24 | 
25 | def main():
26 |     parser = argparse.ArgumentParser('export_quantized')
27 |     parser.add_argument('hf_repo')
28 |     parser.add_argument('new_hf_repo', nargs='?', default=None)
29 | 
30 |     args = parser.parse_args()
31 |     api = get_api()
32 | 
33 |     repo = args.hf_repo
34 | 
35 |     print("Downloading base model.")
36 |     info = loaders.CheckpointInfo.from_hf_repo(args.hf_repo)
37 |     print("Creating model.")
38 |     model = info.get_moshi(fuse_lora=True, device='cuda')
39 |     print("Quantizing model.")
40 |     replace_linear_with_qlinear(model)
41 | 
42 |     if args.new_hf_repo is None:
43 |         new_repo = repo.rsplit('-', 1)[0] + '-q8'
44 |     else:
45 |         new_repo = args.new_hf_repo
46 |     if not api.repo_exists(new_repo):
47 |         api.create_repo(new_repo, repo_type='model')
48 |         print("Repo created.")
49 | 
50 |     to_copy = ['README.md']
51 |     for file in to_copy:
52 |         if not api.file_exists(repo, file):
53 |             continue
54 |         if not api.file_exists(new_repo, file):
55 |             print("File", file, "is missing")
56 |             old_file = hf_hub_download(repo, file)
57 |             api.upload_file(
58 |                 path_or_fileobj=old_file,
59 |                 path_in_repo=file,
60 |                 repo_id=new_repo,
61 |                 repo_type="model")
62 |     with tempfile.NamedTemporaryFile(suffix='.safetensors', delete=True) as file:
63 |         save_file(model.state_dict(), file.name)
64 |         size = Path(file.name).stat().st_size / 1e9
65 |         print(f"Checkpoint size: {size:.1f}GB")
66 |         old_name, old_ext = info.moshi_weights.name.rsplit('.', 1)
67 |         new_name = old_name + '.q8.' + old_ext
68 |         api.upload_file(
69 |             path_or_fileobj=file.name,
70 |             path_in_repo=new_name,
71 |             repo_id=new_repo,
72 |             repo_type="model")
73 |     config = json.load(open(hf_hub_download(repo, 'config.json')))
74 |     config['moshi_name'] = new_name
75 |     config['quantize'] = True
76 |     if not config['mimi_name'].startswith('hf://'):
77 |         config['mimi_name'] = f'hf://{repo}/{config["mimi_name"]}'
78 |     if not config['tokenizer_name'].startswith('hf://'):
79 |         config['tokenizer_name'] = f'hf://{repo}/{config["tokenizer_name"]}'
80 |     with tempfile.NamedTemporaryFile(mode='w') as file:
81 |         json.dump(config, file, indent=2)
82 |         file.flush()
83 |         api.upload_file(
84 |             path_or_fileobj=file.name,
85 |             path_in_repo='config.json',
86 |             repo_id=new_repo,
87 |             repo_type="model")
88 | 
89 | 
90 | if __name__ == "__main__":
91 |     main()
92 | 
--------------------------------------------------------------------------------
/scripts/export_torch.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Kyutai, all rights reserved.
 2 | # This source code is licensed under the license found in the
 3 | # LICENSE file in the root directory of this source tree.
 4 | """Export a model to HF."""
 5 | 
 6 | 
 7 | import argparse
 8 | import json
 9 | import tempfile
10 | 
11 | from huggingface_hub import HfApi
12 | 
13 | from moshi.models import loaders
14 | 
15 | 
16 | def get_api():
17 |     token = input("Write token? ").strip()
18 |     api = HfApi(token=token)
19 |     return api
20 | 
21 | 
22 | def main():
23 |     parser = argparse.ArgumentParser('export_torch')
24 |     parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
25 |     parser.add_argument("--moshi-weight", type=str, help="Path to a local checkpoint file for Moshi.")
26 |     parser.add_argument("--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi.")
27 |     parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO,
28 |                         help="HF repo to look into, defaults to Moshiko. 
" 29 | "Use this to select a different pre-trained model.") 30 | parser.add_argument("--config", "--lm-config", dest="config", type=str, help="The config as a json file.") 31 | parser.add_argument('new_hf_repo') 32 | 33 | args = parser.parse_args() 34 | api = get_api() 35 | 36 | info = loaders.CheckpointInfo.from_hf_repo( 37 | args.hf_repo, moshi_weights=args.moshi_weight, mimi_weights=args.mimi_weight, 38 | tokenizer=args.tokenizer, config_path=args.config) 39 | 40 | if not api.repo_exists(args.new_hf_repo): 41 | api.create_repo(args.new_hf_repo, repo_type='model', private=True) 42 | print("Repo created.") 43 | 44 | config = info.raw_config 45 | assert config is not None 46 | config['mimi_name'] = info.mimi_weights.name 47 | config['moshi_name'] = info.moshi_weights.name 48 | config['tokenizer_name'] = info.tokenizer.name 49 | for file in [info.mimi_weights, info.moshi_weights, info.tokenizer]: 50 | if not api.file_exists(args.new_hf_repo, file.name): 51 | print("Uploading file", file) 52 | api.upload_file( 53 | path_or_fileobj=file, 54 | path_in_repo=file.name, 55 | repo_id=args.new_hf_repo, 56 | repo_type="model") 57 | with tempfile.NamedTemporaryFile(mode='w') as file: 58 | json.dump(config, file, indent=2) 59 | file.flush() 60 | api.upload_file( 61 | path_or_fileobj=file.name, 62 | path_in_repo='config.json', 63 | repo_id=args.new_hf_repo, 64 | repo_type="model") 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /scripts/import_helium_mlx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import argparse 6 | import torch 7 | from pathlib import Path 8 | from safetensors import safe_open 9 | from safetensors.torch import save_file 10 | from huggingface_hub import hf_hub_download 11 | 12 | 13 | def import_model(in_path: Path, out_path: Path, silent: bool = False) -> None: 14 | with safe_open(in_path, framework="pt", device="cpu") as f: 15 | tensors = {key: f.get_tensor(key) for key in f.keys()} 16 | model = { 17 | "text_emb.weight": tensors["model.embed_tokens.weight"], 18 | "text_linear.weight": tensors["lm_head.weight"], 19 | "out_norm.weight": tensors["model.norm.weight"], 20 | } 21 | n_layers = -1 22 | for key in tensors.keys(): 23 | if key.startswith("model.layers."): 24 | layer_idx = int(key.split(".")[2]) 25 | n_layers = max(layer_idx, n_layers) 26 | n_layers += 1 27 | if not silent: 28 | print(f"found {n_layers} layers") 29 | for layer_idx in range(n_layers): 30 | dst_prefix = f"transformer.layers.{layer_idx}." 31 | src_prefix = f"model.layers.{layer_idx}." 
32 |         _model = {
33 |             "norm1.weight": "input_layernorm.weight",
34 |             "norm2.weight": "post_attention_layernorm.weight",
35 |             "self_attn.out_proj.weight": "self_attn.o_proj.weight",
36 |             "gating.linear_out.weight": "mlp.down_proj.weight",
37 |         }
38 |         for dst, src in _model.items():
39 |             model[dst_prefix + dst] = tensors[src_prefix + src]
40 |         gate_proj = tensors[src_prefix + "mlp.gate_proj.weight"]
41 |         up_proj = tensors[src_prefix + "mlp.up_proj.weight"]
42 |         linear_in = torch.cat([gate_proj, up_proj], dim=0)
43 |         model[dst_prefix + "gating.linear_in.weight"] = linear_in
44 |         q = tensors[src_prefix + "self_attn.q_proj.weight"]
45 |         k = tensors[src_prefix + "self_attn.k_proj.weight"]
46 |         v = tensors[src_prefix + "self_attn.v_proj.weight"]
47 |         in_proj = torch.cat([q, k, v], dim=0)
48 |         model[dst_prefix + "self_attn.in_proj.weight"] = in_proj
49 | 
50 |     save_file(model, out_path)
51 | 
52 | 
53 | def main():
54 |     parser = argparse.ArgumentParser()
55 |     parser.add_argument(
56 |         "--checkpoint",
57 |         type=str,
58 |         default="kyutai/helium-1-preview-2b",
59 |         help="the transformers checkpoint to import",
60 |     )
61 |     parser.add_argument("--out", type=str, help="the mlx safetensors file to generate")
62 |     parser.add_argument(
63 |         "-s", "--silent", action="store_true", help="Only prints the checkpoint name"
64 |     )
65 |     args = parser.parse_args()
66 | 
67 |     ckpt_path = Path(args.checkpoint)
68 |     if not ckpt_path.exists():
69 |         ckpt_path = hf_hub_download(
70 |             repo_id=args.checkpoint, filename="model.safetensors"
71 |         )
72 |     out_path = Path(args.out)
73 |     if not out_path.exists():
74 |         import_model(ckpt_path, out_path, silent=args.silent)
75 |     print(out_path)
76 | 
77 | 
78 | if __name__ == "__main__":
79 |     main()
80 | 
--------------------------------------------------------------------------------
/scripts/import_lightformer.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Kyutai, all rights reserved.
 2 | # This source code is licensed under the license found in the
 3 | # LICENSE file in the root directory of this source tree.
 4 | """Import Moshi model, in particular with support for a 'light' depth transformer
 5 | with low rank embeddings and weight sharing for some codebooks."""
 6 | 
 7 | import argparse
 8 | from pathlib import Path
 9 | from safetensors.torch import save_file
10 | 
11 | import omegaconf
12 | import torch
13 | 
14 | 
15 | def import_model(
16 |     in_path: str,
17 |     out_path: str,
18 | ) -> None:
19 |     pkg = torch.load(in_path, map_location=torch.device("cpu"))
20 |     if 'xp.cfg' in pkg:
21 |         cfg = pkg['xp.cfg']
22 |     else:
23 |         cfg = omegaconf.OmegaConf.load(Path(in_path).parent / '.hydra/config.yaml')
24 | 
25 |     model = pkg["fsdp_best_state"]["model"]
26 | 
27 |     # Assuming the same n_q for both streams.
28 |     in_n_q = cfg.compression_model_n_q * 2
29 |     out_n_q = cfg.compression_model_n_q
30 |     print(f"in_n_q: {in_n_q}, out_n_q: {out_n_q}")
31 |     schedule = cfg.transformer_lm.get('depformer_weights_per_step_schedule', None)
32 |     if schedule is None:
33 |         schedule = list(range(in_n_q))
34 | 
35 |     num_weights = max(schedule) + 1
36 |     schedule = schedule[:out_n_q]
37 |     kept_weights = max(schedule) + 1
38 |     print(f"Number of dep weights: {num_weights}, keeping {kept_weights}")
39 | 
40 |     for idx in range(cfg.transformer_lm.depformer_num_layers):
41 |         in_proj_key = f"depformer.layers.{idx}.self_attn.in_proj_weight"
42 |         in_proj = model[in_proj_key]
43 |         in_proj = in_proj.view(num_weights, -1, *in_proj.shape[1:])
44 |         model[in_proj_key] = in_proj[:kept_weights].view(-1, *in_proj.shape[2:]).contiguous()
45 |         out_proj_key = f"depformer.layers.{idx}.self_attn.out_proj.weight"
46 |         out_proj = model[out_proj_key]
47 |         out_proj = out_proj.view(num_weights, -1, *out_proj.shape[1:])
48 |         model[out_proj_key] = out_proj[:kept_weights].view(-1, *out_proj.shape[2:]).contiguous()
49 | 
50 |     # For mimi inference, we trim the depformer layers that are unused.
51 |     for dep_idx in range(out_n_q - 1, in_n_q - 1):
52 |         del model[f"depformer_emb.{dep_idx}.weight"]
53 |         if cfg.transformer_lm.get('depformer_low_rank_embeddings'):
54 |             del model[f"depformer_emb.{dep_idx}.low_rank.weight"]
55 |     for dep_idx in range(out_n_q, in_n_q):
56 |         del model[f"linears.{dep_idx}.weight"]
57 |     for real_idx in range(kept_weights, num_weights):
58 |         model.pop(f"depformer_in.{real_idx}.weight")
59 |         for idx in range(cfg.transformer_lm.depformer_num_layers):
60 |             model.pop(f"depformer.layers.{idx}.gating.{real_idx}.linear_in.weight")
61 |             model.pop(f"depformer.layers.{idx}.gating.{real_idx}.linear_out.weight")
62 | 
63 |     schedule = schedule[:out_n_q]
64 | 
65 |     save_file(model, out_path)
66 | 
67 | 
68 | def main():
69 |     parser = argparse.ArgumentParser(
70 |         prog="moshi_import", description="Imports moshi checkpoints"
71 |     )
72 |     parser.add_argument("checkpoint", help="The checkpoint to be imported.")
73 |     parser.add_argument("out", help="The safetensors out file.")
74 |     args = parser.parse_args()
75 | 
76 |     out_path = Path(args.out)
77 | 
78 |     if out_path.exists():
79 |         print("file already exists")
80 |     else:
81 |         import_model(args.checkpoint, out_path)
82 |         print(out_path)
83 | 
84 | 
85 | if __name__ == "__main__":
86 |     main()
87 | 
--------------------------------------------------------------------------------
/scripts/mimi_mlx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 | 5 | import argparse 6 | from huggingface_hub import hf_hub_download 7 | import numpy as np 8 | import mlx.core as mx 9 | import sphn 10 | import moshi_mlx 11 | 12 | 13 | def run(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--input", type=str) 16 | parser.add_argument("--model-file", type=str) 17 | parser.add_argument("--hf-repo", type=str, default="kyutai/moshiko-mlx-q4") 18 | parser.add_argument("--streaming", action="store_true") 19 | args = parser.parse_args() 20 | 21 | pcm_in, _ = sphn.read(args.input, sample_rate=24000) 22 | pcm_in = mx.array(pcm_in[0])[None, None] 23 | print(pcm_in.shape) 24 | 25 | if args.model_file is None: 26 | model_file = hf_hub_download(args.hf_repo, "tokenizer-e351c8d8-checkpoint125.safetensors") 27 | else: 28 | model_file = args.model_file 29 | cfg = moshi_mlx.models.mimi.mimi_202407(32) 30 | print("building model", flush=True) 31 | model = moshi_mlx.models.mimi.Mimi(cfg) 32 | print(f"loading weights {model_file}", flush=True) 33 | model.load_pytorch_weights(model_file, strict=True) 34 | print("weights loaded") 35 | 36 | if args.streaming: 37 | chunk_size = 1920 38 | pcm_out = [] 39 | len_ = pcm_in.shape[-1] 40 | print("starting streaming conversion") 41 | for start_idx in range(0, len_, chunk_size): 42 | end_idx = start_idx + chunk_size 43 | if end_idx >= len_: 44 | break 45 | _pcm_in = pcm_in[..., start_idx:end_idx] 46 | codes = model.encode_step(_pcm_in) 47 | _pcm_out = model.decode_step(codes) 48 | pcm_out.append(_pcm_out) 49 | pct = int(100 * start_idx / len_) 50 | print(f"{pct}%", end="\r", flush=True) 51 | print() 52 | pcm_out = mx.concat(pcm_out, axis=-1) 53 | else: 54 | codes = model.encode(pcm_in) 55 | print(codes.shape) 56 | pcm_out = model.decode(codes) 57 | print("writing output file with audio shape", pcm_out.shape) 58 | sphn.write_wav("out.wav", np.array(pcm_out[0]), sample_rate=24000) 59 | 60 | if __name__ == "__main__": 61 | run() 62 | -------------------------------------------------------------------------------- /scripts/mimi_streaming_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 
4 | 5 | import argparse 6 | import random 7 | import time 8 | 9 | from huggingface_hub import hf_hub_download 10 | import numpy as np 11 | import sphn 12 | import torch 13 | from torch.profiler import profile, ProfilerActivity 14 | 15 | from moshi.models import loaders 16 | 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--mimi-weight", type=str) 20 | parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO) 21 | parser.add_argument( 22 | "--device", type=str, default="cuda" if torch.cuda.device_count() else "cpu" 23 | ) 24 | parser.add_argument("--profile", action="store_true") 25 | args = parser.parse_args() 26 | 27 | 28 | def seed_all(seed): 29 | torch.manual_seed(seed) 30 | if torch.cuda.is_available(): 31 | torch.cuda.manual_seed(seed) 32 | torch.cuda.manual_seed_all(seed) # for multi-GPU setups 33 | random.seed(seed) 34 | np.random.seed(seed) 35 | torch.backends.cudnn.deterministic = True 36 | torch.backends.cudnn.benchmark = False 37 | 38 | 39 | seed_all(42424242) 40 | 41 | 42 | print("loading mimi") 43 | if args.mimi_weight is None: 44 | args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME) 45 | mimi = loaders.get_mimi(args.mimi_weight, args.device) 46 | print("mimi loaded") 47 | 48 | 49 | def mimi_streaming_test(mimi, max_duration_sec=10.0): 50 | pcm_chunk_size = int(mimi.sample_rate / mimi.frame_rate) 51 | # wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3 52 | sample_pcm, sample_sr = sphn.read("bria.mp3") 53 | sample_rate = mimi.sample_rate 54 | print("loaded pcm", sample_pcm.shape, sample_sr) 55 | sample_pcm = sphn.resample( 56 | sample_pcm, src_sample_rate=sample_sr, dst_sample_rate=sample_rate 57 | ) 58 | sample_pcm = torch.tensor(sample_pcm, device=args.device) 59 | max_duration_len = int(sample_rate * max_duration_sec) 60 | if sample_pcm.shape[-1] > max_duration_len: 61 | sample_pcm = sample_pcm[..., :max_duration_len] 62 | print("resampled pcm", sample_pcm.shape, sample_sr) 63 | sample_pcm = sample_pcm[None].to(device=args.device) 64 | 65 | print("streaming encoding...") 66 | start_time = time.time() 67 | all_codes = [] 68 | 69 | def run_loop(): 70 | for start_idx in range(0, sample_pcm.shape[-1], pcm_chunk_size): 71 | end_idx = min(sample_pcm.shape[-1], start_idx + pcm_chunk_size) 72 | chunk = sample_pcm[..., start_idx:end_idx] 73 | codes = mimi.encode(chunk) 74 | if codes.shape[-1]: 75 | print(start_idx, codes.shape, end="\r") 76 | all_codes.append(codes) 77 | 78 | if args.profile: 79 | with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: 80 | run_loop() 81 | prof.export_chrome_trace("trace.json") 82 | else: 83 | run_loop() 84 | all_codes_th = torch.cat(all_codes, dim=-1) 85 | print(f"codes {all_codes_th.shape} generated in {time.time() - start_time:.2f}s") 86 | print("streaming decoding...") 87 | all_pcms = [] 88 | with mimi.streaming(1): 89 | for i in range(all_codes_th.shape[-1]): 90 | codes = all_codes_th[..., i : i + 1] 91 | pcm = mimi.decode(codes) 92 | print(i, pcm.shape, end="\r") 93 | all_pcms.append(pcm) 94 | all_pcms = torch.cat(all_pcms, dim=-1) 95 | print("pcm", all_pcms.shape, all_pcms.dtype) 96 | sphn.write_wav("streaming_out.wav", all_pcms[0, 0].cpu().numpy(), sample_rate) 97 | pcm = mimi.decode(all_codes_th) 98 | print("pcm", pcm.shape, pcm.dtype) 99 | sphn.write_wav("roundtrip_out.wav", pcm[0, 0].cpu().numpy(), sample_rate) 100 | 101 | 102 | with torch.no_grad(): 103 | mimi_streaming_test(mimi) 104 | 
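A note on the chunk size the script above derives: `pcm_chunk_size = int(mimi.sample_rate / mimi.frame_rate)` is the number of PCM samples per Mimi frame. Assuming the 24 kHz sample rate used throughout this repo and the 1920-sample chunks hard-coded in `scripts/mimi_mlx.py`, that works out to a 12.5 Hz frame rate, i.e. 80 ms of audio per encode/decode step:

```python
# Assumed values: 24 kHz appears throughout the repo (e.g. the Opus audio in
# rust/protocol.md), and 1920 is the chunk size used in scripts/mimi_mlx.py.
sample_rate = 24_000                       # samples per second
pcm_chunk_size = 1_920                     # samples per Mimi frame
frame_rate = sample_rate / pcm_chunk_size  # 12.5 frames per second
ms_per_frame = 1_000 * pcm_chunk_size / sample_rate  # 80.0 ms of audio per step
print(frame_rate, ms_per_frame)
```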
-------------------------------------------------------------------------------- /scripts/quantize_mlx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import argparse 6 | 7 | import mlx.core as mx 8 | import mlx.nn as nn 9 | 10 | import moshi_mlx 11 | 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("original_weights", type=str) 16 | parser.add_argument("--out", type=str) 17 | parser.add_argument("--config", type=str, default="v0_1") 18 | parser.add_argument("--bits", type=int, default=8) 19 | parser.add_argument("--group-size", type=int, default=64) 20 | args = parser.parse_args() 21 | 22 | model_file = args.original_weights 23 | 24 | if args.config == "v0_1": 25 | lm_config = moshi_mlx.models.config_v0_1() 26 | elif args.config == "1b": 27 | lm_config = moshi_mlx.models.config1b_202412() 28 | elif args.config == "1b-16rvq": 29 | lm_config = moshi_mlx.models.config1b_202412_16rvq() 30 | elif args.config == "helium-2b": 31 | lm_config = moshi_mlx.models.config_helium_1_preview_2b() 32 | else: 33 | raise ValueError(f"unknown config name '{args.config}'") 34 | print(f"model config:\n{lm_config}") 35 | 36 | model = moshi_mlx.models.Lm(lm_config) 37 | model.set_dtype(mx.bfloat16) 38 | print(f"loading weights {model_file}") 39 | model.load_weights(model_file, strict=True) 40 | print("weights loaded") 41 | 42 | nn.quantize(model, bits=args.bits, group_size=args.group_size) 43 | print(f"saving the quantized q{args.bits} weights in {args.out}") 44 | model.save_weights(args.out) 45 | 46 | 47 | if __name__ == "__main__": 48 | main() 49 | -------------------------------------------------------------------------------- /scripts/run_ci_when_installed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script is used to detect if moshi or moshi_mlx are installed, and run 4 | # their CI only in that case! 5 | 6 | package=$1 7 | if python -c "from $package import models"; then 8 | # package is installed, let's run the command 9 | eval $2 10 | else 11 | echo "Package $package not installed, skipping the CI for it." 12 | fi 13 | -------------------------------------------------------------------------------- /scripts/setup.cfg: -------------------------------------------------------------------------------- 1 | [pep8] 2 | max-line-length = 120 3 | 4 | [flake8] 5 | max-line-length = 120 6 | ignore = E203,E704 7 | -------------------------------------------------------------------------------- /scripts/test_mimi.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 
4 | 5 | import argparse 6 | import numpy as np 7 | import time 8 | 9 | import rustymimi 10 | 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--model", type=str) 15 | parser.add_argument("--steps", default=100, type=int) 16 | args = parser.parse_args() 17 | 18 | steps = args.steps 19 | model = rustymimi.Tokenizer(str(args.model)) 20 | print(model) 21 | 22 | start_time = 0 23 | for i in range(steps + 1): 24 | if i == 1: 25 | start_time = time.time() 26 | pcm_data = np.array([[[0.0] * 1920]]).astype(np.float32) 27 | out = model.encode_step(pcm_data) 28 | print(out.shape) 29 | pcm_data = model.decode_step(out) 30 | print(pcm_data) 31 | token_per_second = steps / (time.time() - start_time) 32 | print(f"steps: {steps}, token per sec: {token_per_second}") 33 | 34 | 35 | if __name__ == "__main__": 36 | main() 37 | -------------------------------------------------------------------------------- /scripts/test_missing_data.py: -------------------------------------------------------------------------------- 1 | """Testing exec_mask feature of the streaming module, where each batch entry 2 | can advance at its own pace, while retaining full compat with CUDA Graph.""" 3 | import sys 4 | import sphn 5 | import torch 6 | 7 | from moshi.models import loaders 8 | 9 | device = 'cuda' 10 | wnp, sr = sphn.read(sys.argv[1], 11 | start_sec=0, duration_sec=8, sample_rate=24000) 12 | wav = torch.from_numpy(wnp) 13 | ci = loaders.CheckpointInfo.from_hf_repo("kyutai/moshiko-pytorch-bf16") 14 | mimi = ci.get_mimi(device=device) 15 | mimi.eval() 16 | 17 | frame = int(mimi.sample_rate / mimi.frame_rate) 18 | 19 | total = wav.shape[-1] // frame 20 | B = 4 21 | wav = wav[..., :total * frame][None] 22 | remaining = [total for _ in range(B)] 23 | 24 | print("Ref computation") 25 | with torch.no_grad(): 26 | ref_codes = mimi.encode(wav[:1].to(device=device)) 27 | ref_audio = mimi.decode(ref_codes[:1].to(device=device)) 28 | 29 | out_audio = torch.zeros_like(ref_audio.expand(B, -1, -1)) 30 | out_codes = torch.zeros_like(ref_codes.expand(B, -1, -1)) 31 | 32 | print("Going streaming") 33 | with mimi.streaming(B), torch.no_grad(): 34 | while any(remaining): 35 | inputs = [] 36 | exec_mask = torch.rand(B) < 0.5 37 | offsets = [] 38 | for b, this_remaining in enumerate(remaining): 39 | offset = min(total - this_remaining, total - 1) 40 | offsets.append(offset) 41 | inputs.append(wav[0, :, offset * frame: offset * frame + frame]) 42 | input_ = torch.stack(inputs) 43 | mimi.set_exec_mask(exec_mask) 44 | codes = mimi.encode(input_.to(device=device)) 45 | assert codes.shape[-1] == 1, codes.shape 46 | w = mimi.decode(codes) 47 | assert w.shape[-1] == frame, w.shape 48 | print(remaining) 49 | for b, active in enumerate(exec_mask.tolist()): 50 | if not active or remaining[b] == 0: 51 | continue 52 | remaining[b] = max(0, remaining[b] - 1) 53 | offset = offsets[b] 54 | out_codes[b, :, offset: offset + 1] = codes[b] 55 | out_audio[b, :, offset * frame: offset * frame + frame] = w[b] 56 | 57 | print(ref_codes[0, :, -1]) 58 | print(out_codes[0, :, -1]) 59 | for b in range(B): 60 | print((out_codes[b] != ref_codes[:1]).any(dim=0).nonzero()[:1]) 61 | d = (out_codes[..., :1] == ref_codes[..., :1]).float().mean(dim=(1, 2)) 62 | print(d) 63 | d = (out_codes[..., :2] == ref_codes[..., :2]).float().mean(dim=(1, 2)) 64 | print(d) 65 | d = (out_codes[..., :10] == ref_codes[..., :10]).float().mean(dim=(1, 2)) 66 | print(d) 67 | d = (out_codes == ref_codes).float().mean(dim=(1, 2)) 68 | print(d) 69 | d = 
(out_audio - ref_audio).norm(dim=-1, p=2) / ref_audio.norm(dim=-1, p=2) 70 | print(d.sum(dim=1)) 71 | -------------------------------------------------------------------------------- /scripts/test_missing_data_lm.py: -------------------------------------------------------------------------------- 1 | """Testing exec_mask feature of the streaming module, where each batch entry 2 | can advance at its own pace, while retaining full compat with CUDA Graph.""" 3 | import sys 4 | import sphn 5 | import torch 6 | 7 | from moshi.models import loaders 8 | from moshi.models.lm import LMGen 9 | from moshi.conditioners import ConditionAttributes 10 | 11 | device = 'cuda' 12 | wnp, sr = sphn.read(sys.argv[1], 13 | start_sec=0, duration_sec=8, sample_rate=24000) 14 | wav = torch.from_numpy(wnp) 15 | ci = loaders.CheckpointInfo.from_hf_repo("kyutai/hibiki-2b-pytorch-bf16") 16 | mimi = ci.get_mimi(device=device) 17 | mimi.eval() 18 | 19 | lm = ci.get_moshi(device=device) 20 | 21 | B = 4 22 | 23 | with torch.no_grad(): 24 | codes = mimi.encode(wav[:1].to(device=device)[None]) 25 | 26 | T = codes.shape[-1] 27 | offsets = [0 for _ in range(B)] 28 | 29 | out_codes: list[list[torch.Tensor]] = [[] for _ in range(B)] 30 | assert lm.condition_provider is not None 31 | conditions = [ConditionAttributes(text={"description": "very_good"}, tensor={})] * B 32 | prepared = lm.condition_provider.prepare(conditions) 33 | condition_tensors = lm.condition_provider(prepared) 34 | lm_gen = LMGen(lm, temp=0., temp_text=0., support_out_of_sync=True, condition_tensors=condition_tensors) 35 | print("Going streaming") 36 | with torch.no_grad(), lm_gen.streaming(B): 37 | while any(o < T for o in offsets): 38 | inputs = [] 39 | exec_mask = torch.rand(B) < 0.5 40 | exec_mask[0] = True 41 | for offset in offsets: 42 | inputs.append(codes[:, :, min(offset, T - 1)]) 43 | input_ = torch.cat(inputs)[..., None] 44 | lm_gen.set_exec_mask(exec_mask) 45 | pred = lm_gen.step(input_.to(device=device)) 46 | assert pred is not None 47 | assert pred.shape[-1] == 1, pred.shape 48 | for b, active in enumerate(exec_mask.tolist()): 49 | if not active or offsets[b] >= T: 50 | continue 51 | if offsets[b] >= 2: 52 | assert (pred[b] >= 0).all() 53 | offsets[b] += 1 54 | out_codes[b].append(pred[b]) 55 | 56 | alls = [] 57 | for frames in out_codes: 58 | out = torch.cat(frames, -1) 59 | alls.append(out) 60 | 61 | ys = torch.stack(alls) 62 | r = ys[:1] 63 | o = ys[1:] 64 | 65 | ma = (r == o).float().mean(dim=(0, 1)) 66 | print(ma) 67 | print(r[0, :2, :5]) 68 | print(o[0, :2, :5]) 69 | print(o[1, :2, :5]) 70 | print(o[2, :2, :5]) 71 | -------------------------------------------------------------------------------- /scripts/update_repo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Kyutai, all rights reserved. 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | """Create or update a HF repo.""" 5 | 6 | 7 | import argparse 8 | import json 9 | from pathlib import Path 10 | import tempfile 11 | 12 | from huggingface_hub import HfApi, hf_hub_download 13 | from safetensors.torch import save_file 14 | 15 | from moshi.models import loaders 16 | from moshi.modules.transformer import quantize_transformer 17 | 18 | 19 | def get_api(): 20 | token = input("Write token? 
").strip() 21 | api = HfApi(token=token) 22 | return api 23 | 24 | def main(): 25 | parser = argparse.ArgumentParser('update_repo') 26 | parser.add_argument('repo') 27 | parser.add_argument('-c', '--config') 28 | parser.add_argument('-r', '--readme') 29 | parser.add_argument('-m', '--mimi-weight') 30 | parser.add_argument('-M', '--moshi-weight') 31 | parser.add_argument('-t', '--tokenizer') 32 | 33 | args = parser.parse_args() 34 | 35 | api = get_api() 36 | if not api.repo_exists(args.repo): 37 | api.create_repo(args.repo, repo_type='model') 38 | print(f"Repo {args.repo} created.") 39 | 40 | old_config = None 41 | if api.file_exists(args.repo, 'config.json'): 42 | old_config = json.load(open(hf_hub_download(args.repo, 'config.json'))) 43 | 44 | changes = False 45 | if args.config: 46 | changes = True 47 | new_config = json.load(open(args.config)) 48 | elif old_config: 49 | new_config = old_config 50 | else: 51 | new_config = {} 52 | 53 | names = ['mimi_name', 'moshi_name', 'tokenizer_name'] 54 | paths = [args.mimi_weight, args.moshi_weight, args.tokenizer] 55 | for name, path in zip(names, paths): 56 | if path is None: 57 | if old_config is not None and name in old_config: 58 | new_config[name] = old_config[name] 59 | continue 60 | filename = Path(path).name 61 | print(f"Uploading {path}") 62 | api.upload_file( 63 | path_or_fileobj=path, 64 | path_in_repo=filename, 65 | repo_id=args.repo, 66 | repo_type="model") 67 | new_config[name] = filename 68 | changes = True 69 | 70 | if changes: 71 | with tempfile.NamedTemporaryFile(mode='w') as file: 72 | json.dump(new_config, file, indent=2) 73 | file.flush() 74 | api.upload_file( 75 | path_or_fileobj=file.name, 76 | path_in_repo='config.json', 77 | repo_id=args.repo, 78 | repo_type="model") 79 | 80 | 81 | if __name__ == "__main__": 82 | main() 83 | --------------------------------------------------------------------------------