├── .github
├── ISSUE_TEMPLATE
│ ├── bug.yml
│ └── question.yml
├── PULL_REQUEST_TEMPLATE.md
├── actions
│ ├── moshi_build
│ │ └── action.yml
│ └── rust_build
│ │ └── action.yml
└── workflows
│ ├── precommit.yml
│ ├── rust-ci.yml
│ └── rustymimi-ci.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── FAQ.md
├── LICENSE-APACHE
├── LICENSE-MIT
├── README.md
├── client
├── .env.local
├── .eslinrc.json
├── .nvmrc
├── .prettierignore
├── .prettierrc.json
├── Dockerfile
├── LICENSE
├── README.md
├── index.html
├── package-lock.json
├── package.json
├── postcss.config.js
├── public
│ └── assets
│ │ ├── decoderWorker.min.js
│ │ ├── decoderWorker.min.wasm
│ │ ├── favicon-16x16.png
│ │ ├── favicon-32x32.png
│ │ ├── favicon.ico
│ │ ├── images
│ │ │ └── demo
│ │ │ │ ├── image1.jpg
│ │ │ │ ├── image10.jpg
│ │ │ │ ├── image11.jpg
│ │ │ │ ├── image12.jpg
│ │ │ │ ├── image13.jpg
│ │ │ │ ├── image14.jpg
│ │ │ │ ├── image15.jpg
│ │ │ │ ├── image16.jpg
│ │ │ │ ├── image17.jpg
│ │ │ │ ├── image18.jpg
│ │ │ │ ├── image19.jpg
│ │ │ │ ├── image2.jpg
│ │ │ │ ├── image20.jpg
│ │ │ │ ├── image3.jpg
│ │ │ │ ├── image4.jpg
│ │ │ │ ├── image5.jpg
│ │ │ │ ├── image6.jpg
│ │ │ │ ├── image7.jpg
│ │ │ │ ├── image8.jpg
│ │ │ │ └── image9.jpg
│ │ └── logo.svg
├── src
│ ├── app.tsx
│ ├── audio-processor.ts
│ ├── components
│ │ ├── Button
│ │ │ └── Button.tsx
│ │ ├── ImageGallery
│ │ │ └── ImageGallery.tsx
│ │ └── Input
│ │ │ └── Input.tsx
│ ├── decoder
│ │ └── decoderWorker.ts
│ ├── env.ts
│ ├── index.css
│ ├── modules.d.ts
│ ├── pages
│ │ ├── Conversation
│ │ │ ├── Conversation.tsx
│ │ │ ├── MediaContext.ts
│ │ │ ├── SocketContext.ts
│ │ │ ├── canvas-logo-black.png
│ │ │ ├── canvas-logo.png
│ │ │ ├── components
│ │ │ │ ├── AudioVisualizer
│ │ │ │ │ ├── AudioVisualizer.tsx
│ │ │ │ │ ├── ClientVisualizer.tsx
│ │ │ │ │ └── ServerVisualizer.tsx
│ │ │ │ ├── Controls
│ │ │ │ │ └── Controls.tsx
│ │ │ │ ├── ModelParams
│ │ │ │ │ └── ModelParams.tsx
│ │ │ │ ├── ServerAudio
│ │ │ │ │ ├── ServerAudio.tsx
│ │ │ │ │ └── ServerAudioStats.tsx
│ │ │ │ ├── ServerInfo
│ │ │ │ │ └── ServerInfo.tsx
│ │ │ │ ├── TextDisplay
│ │ │ │ │ ├── TextDisplay.tsx
│ │ │ │ │ └── TextDisplayStats.tsx
│ │ │ │ └── UserAudio
│ │ │ │ │ ├── UserAudio.tsx
│ │ │ │ │ └── UserAudioStats.tsx
│ │ │ ├── getMimeType.ts
│ │ │ └── hooks
│ │ │ │ ├── audioUtils.ts
│ │ │ │ ├── useModelParams.ts
│ │ │ │ ├── useServerAudio.ts
│ │ │ │ ├── useServerInfo.ts
│ │ │ │ ├── useServerText.ts
│ │ │ │ ├── useSocket.ts
│ │ │ │ └── useUserAudio.ts
│ │ └── Queue
│ │ │ ├── Queue.tsx
│ │ │ ├── api
│ │ │ │ ├── client.ts
│ │ │ │ ├── errors
│ │ │ │ │ ├── api_error.ts
│ │ │ │ │ └── response_error.ts
│ │ │ │ └── validators.ts
│ │ │ └── hooks
│ │ │ │ └── useUserEmail.ts
│ └── protocol
│ │ ├── encoder.ts
│ │ ├── testMessages.ts
│ │ └── types.ts
├── tailwind.config.js
├── tsconfig.json
└── vite.config.ts
├── configs
├── moshi_7b_202409.json
├── moshi_dev_2b.json
└── moshi_mlx_2b.json
├── data
├── sample_fr_hibiki_crepes.mp3
├── sample_fr_hibiki_intro.mp3
└── sample_fr_hibiki_monologue_otis.mp3
├── docker-compose.yml
├── mimi.png
├── moshi.png
├── moshi
├── Dockerfile
├── LICENSE
├── LICENSE.audiocraft
├── MANIFEST.in
├── README.md
├── demo_moshi.ipynb
├── moshi
│ ├── __init__.py
│ ├── client.py
│ ├── client_gradio.py
│ ├── client_utils.py
│ ├── conditioners
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── tensors.py
│ │ └── text.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── compression.py
│ │ ├── lm.py
│ │ ├── lm_utils.py
│ │ └── loaders.py
│ ├── modules
│ │ ├── __init__.py
│ │ ├── conv.py
│ │ ├── conv_test.py
│ │ ├── gating.py
│ │ ├── lora.py
│ │ ├── resample.py
│ │ ├── rope.py
│ │ ├── seanet.py
│ │ ├── seanet_test.py
│ │ ├── streaming.py
│ │ └── transformer.py
│ ├── quantization
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── core_vq.py
│ │ └── vq.py
│ ├── run_inference.py
│ ├── server.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── autocast.py
│ │ ├── compile.py
│ │ ├── quantize.py
│ │ ├── sampling.py
│ │ └── utils.py
├── pyproject.toml
├── requirements.txt
├── setup.cfg
└── tests
│ ├── assets
│ │ ├── test_lm_codes.safetensors
│ │ ├── test_lm_model.safetensors
│ │ └── test_lm_out.safetensors
│ └── test_lm.py
├── moshi_mlx
├── LICENSE
├── MANIFEST.in
├── README.md
├── moshi_mlx
│ ├── __init__.py
│ ├── client_utils.py
│ ├── local.py
│ ├── local_web.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── generate.py
│ │ ├── lm.py
│ │ └── mimi.py
│ ├── modules
│ │ ├── __init__.py
│ │ ├── conditioner.py
│ │ ├── conv.py
│ │ ├── kv_cache.py
│ │ ├── quantization.py
│ │ ├── seanet.py
│ │ └── transformer.py
│ ├── py.typed
│ ├── run_helium.py
│ ├── run_inference.py
│ └── utils
│ │ ├── __init__.py
│ │ └── sampling.py
├── pyproject.toml
├── requirements.txt
└── setup.cfg
├── requirements-dev.txt
├── rust
├── .github
│ └── workflows
│ │ └── rust-ci.yml
├── Cargo.lock
├── Cargo.toml
├── LICENSE
├── README.md
├── mimi-pyo3
│ ├── Cargo.toml
│ ├── py_src
│ │ └── rustymimi
│ │ │ ├── __init__.py
│ │ │ └── __init__.pyi
│ ├── pyproject.toml
│ ├── src
│ │ └── lib.rs
│ └── stub.py
├── moshi-backend
│ ├── Cargo.toml
│ ├── build.rs
│ ├── config-q8.json
│ ├── config.json
│ └── src
│ │ ├── audio.rs
│ │ ├── benchmark.rs
│ │ ├── build.rs
│ │ ├── main.rs
│ │ ├── standalone.rs
│ │ ├── stream_both.rs
│ │ └── utils.rs
├── moshi-cli
│ ├── Cargo.toml
│ └── src
│ │ ├── audio_io.rs
│ │ ├── gen.rs
│ │ ├── main.rs
│ │ └── multistream.rs
├── moshi-core
│ ├── Cargo.toml
│ └── src
│ │ ├── asr.rs
│ │ ├── batched_transformer.rs
│ │ ├── conditioner.rs
│ │ ├── conv.rs
│ │ ├── kv_cache.rs
│ │ ├── lib.rs
│ │ ├── lm.rs
│ │ ├── lm_generate.rs
│ │ ├── lm_generate_multistream.rs
│ │ ├── mimi.rs
│ │ ├── nn.rs
│ │ ├── quantization.rs
│ │ ├── seanet.rs
│ │ ├── streaming.rs
│ │ ├── transformer.rs
│ │ ├── tts.rs
│ │ ├── tts_streaming.rs
│ │ └── wav.rs
├── protocol.md
├── rustfmt.toml
└── s2st-1b.toml
└── scripts
├── __init__.py
├── export_quantized.py
├── export_torch.py
├── import_helium_mlx.py
├── import_lightformer.py
├── import_mlx.py
├── import_mlx_lora.py
├── import_pytorch.py
├── import_rust.py
├── import_rust_lora.py
├── mimi_mlx.py
├── mimi_streaming_test.py
├── moshi_benchmark.py
├── quantize_mlx.py
├── run_ci_when_installed.sh
├── setup.cfg
├── test_mimi.py
├── test_missing_data.py
├── test_missing_data_lm.py
├── test_mlx.py
└── update_repo.py
/.github/ISSUE_TEMPLATE/bug.yml:
--------------------------------------------------------------------------------
1 | name: Bug Report
2 | description: You found a bug.
3 | labels: ["bug", "triage"]
4 | body:
5 |   - type: dropdown
6 |     id: backend
7 |     attributes:
8 |       label: Backend impacted
9 |       description: Which backend is concerned with your bug report?
10 |       options:
11 |         - The PyTorch implementation
12 |         - The MLX implementation
13 |         - The Rust implementation
14 |         - Other / All
15 |       default: 0
16 |     validations:
17 |       required: true
18 |   - type: dropdown
19 |     id: os
20 |     attributes:
21 |       label: Operating system
22 |       description: What is your operating system?
23 |       options:
24 |         - Linux
25 |         - Mac OS X
26 |         - Windows (unsupported)
27 |       default: 0
28 |     validations:
29 |       required: true
30 |   - type: dropdown
31 |     id: hardware
32 |     attributes:
33 |       label: Hardware
34 |       description: What hardware are you using?
35 |       options:
36 |         - CPU
37 |         - GPU with CUDA
38 |         - Metal with MLX
39 |       default: 0
40 |     validations:
41 |       required: true
42 |   - type: textarea
43 |     id: description
44 |     attributes:
45 |       label: Description
46 |       description: Provide a detailed description of your bug.
47 |       placeholder:
48 |       value:
49 |     validations:
50 |       required: true
51 |   - type: textarea
52 |     id: more_info
53 |     attributes:
54 |       label: Extra information
55 |       description: Please provide any other relevant information, such as log extracts, code etc.
56 |       placeholder:
57 |       value:
58 |     validations:
59 |       required: true
60 |   - type: textarea
61 |     id: env
62 |     attributes:
63 |       label: Environment
64 |       description: Please fill in the information below for your environment.
65 |       placeholder:
66 |       value: |
67 |         Fill in the following information on your system.
68 |         - Operating system version:
69 | 
70 |         If the backend impacted is PyTorch:
71 |         - Python version:
72 |         - PyTorch version:
73 |         - CUDA version (run `python -c 'import torch; print(torch.version.cuda)'`):
74 |         - GPU model and memory:
75 | 
76 |         If the backend is MLX:
77 |         - Mac model:
78 |     validations:
79 |       required: true
80 | 
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/question.yml:
--------------------------------------------------------------------------------
1 | name: Question
2 | description: You have a question about Moshi/Mimi or this codebase.
3 | labels: ["question", "triage"]
4 | body:
5 |   - type: markdown
6 |     attributes:
7 |       value: |
8 |         Please first check the [FAQ](https://github.com/kyutai-labs/moshi/blob/main/FAQ.md).
9 |   - type: checkboxes
10 |     id: terms
11 |     attributes:
12 |       label: Due diligence
13 |       description: Have you searched the existing issues / FAQ / Google / asked ChatGPT?
14 |       options:
15 |         - label: I have done my due diligence in trying to find the answer myself.
16 |           required: true
17 | 
18 |   - type: dropdown
19 |     id: backend
20 |     attributes:
21 |       label: Topic
22 |       description: What is your question about?
23 |       options:
24 |         - The paper
25 |         - The PyTorch implementation
26 |         - The MLX implementation
27 |         - The Rust implementation
28 |         - Other / All
29 |       default: 0
30 |     validations:
31 |       required: true
32 |   - type: textarea
33 |     id: question
34 |     attributes:
35 |       label: Question
36 |       description: What is your question?
37 |       placeholder: Your question. Please make sure this is directly related to our codebase. We will not provide support for installing PyTorch, CUDA, Rust etc.
38 |       value:
39 |     validations:
40 |       required: true
41 | 
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ## Checklist
2 |
3 | - [ ] Read CONTRIBUTING.md, and accept the CLA by including the provided snippet. We will not accept PRs without this.
4 | - [ ] Run the pre-commit hooks.
5 | - [ ] If you changed Rust code, run `cargo check`, `cargo clippy`, `cargo test`.
6 |
7 | ## PR Description
8 |
9 |
10 |
--------------------------------------------------------------------------------
/.github/actions/moshi_build/action.yml:
--------------------------------------------------------------------------------
1 | name: moshi_build
2 | description: 'Build env.'
3 | runs:
4 |   using: "composite"
5 |   steps:
6 |     - uses: actions/setup-python@v2
7 |       with:
8 |         python-version: '3.10.14'
9 |     - uses: actions/cache@v3
10 |       id: cache
11 |       with:
12 |         path: env
13 |         key: env-${{ hashFiles('moshi/pyproject.toml') }}
14 |     - name: Install dependencies
15 |       if: steps.cache.outputs.cache-hit != 'true'
16 |       shell: bash
17 |       run: |
18 |         python3 -m venv env
19 |         . env/bin/activate
20 |         python -m pip install --upgrade pip
21 |         pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cpu
22 |     - name: Setup env
23 |       shell: bash
24 |       run: |
25 |         source env/bin/activate
26 |         pip install -e './moshi[dev]'
27 |         pre-commit install
28 | 
--------------------------------------------------------------------------------
/.github/actions/rust_build/action.yml:
--------------------------------------------------------------------------------
1 | name: rust_build
2 | description: 'Setup rust env'
3 | inputs:
4 |   os:
5 |     default: ubuntu-latest
6 |   toolchain:
7 |     default: stable
8 |   target:
9 |     default: check
10 | runs:
11 |   using: "composite"
12 |   steps:
13 |     - uses: actions-rs/toolchain@v1
14 |       with:
15 |         profile: minimal
16 |         toolchain: ${{ inputs.toolchain }}
17 |         override: true
18 |     - name: cargo cache
19 |       uses: actions/cache@v3
20 |       with:
21 |         path: |
22 |           ~/.cargo/bin/
23 |           ~/.cargo/registry/index/
24 |           ~/.cargo/registry/cache/
25 |           ~/.cargo/git/db/
26 |           rust/target/
27 |         key: ${{ inputs.os }}-cargo-${{ inputs.target }}-${{ hashFiles('**/Cargo.toml') }}
28 |         restore-keys: ${{ inputs.os }}-cargo-
29 |     - name: install deps
30 |       shell: bash
31 |       run: |
32 |         sudo apt-get update
33 |         sudo apt-get install libasound2-dev cmake
34 |         echo "test"
35 |         cmake --version
36 |         apt-cache show cmake
37 | 
--------------------------------------------------------------------------------
/.github/workflows/precommit.yml:
--------------------------------------------------------------------------------
1 | name: precommit
2 | on:
3 |   push:
4 |     branches: [ main ]
5 |   pull_request:
6 |     branches: [ main, refacto ]
7 | 
8 | jobs:
9 |   run_precommit:
10 |     name: Run precommit
11 |     runs-on: ubuntu-latest
12 |     steps:
13 |       - uses: actions/checkout@v2
14 |       - uses: ./.github/actions/moshi_build
15 |       - run: |
16 |           . env/bin/activate
17 |           bash .git/hooks/pre-commit
18 | 
--------------------------------------------------------------------------------
/.github/workflows/rust-ci.yml:
--------------------------------------------------------------------------------
1 | on:
2 |   push:
3 |     branches: [ main ]
4 |   pull_request:
5 |     branches: [ main, refacto ]
6 | 
7 | name: Rust CI
8 | 
9 | jobs:
10 |   check:
11 |     name: Check
12 |     defaults:
13 |       run:
14 |         working-directory: ./rust
15 |     runs-on: ${{ matrix.os }}
16 |     strategy:
17 |       matrix:
18 |         os: [ubuntu-latest]
19 |         rust: [stable]
20 |     steps:
21 |       - uses: actions/checkout@v2
22 |       - uses: ./.github/actions/rust_build
23 |       - name: check
24 |         shell: bash
25 |         run: |
26 |           cargo check
27 |       - name: clippy
28 |         shell: bash
29 |         run: |
30 |           cargo clippy -- -D warnings
31 |       - name: fmt
32 |         shell: bash
33 |         run: |
34 |           cargo fmt --all -- --check
35 |   test:
36 |     name: Test
37 |     defaults:
38 |       run:
39 |         working-directory: ./rust
40 |     runs-on: ${{ matrix.os }}
41 |     strategy:
42 |       matrix:
43 |         os: [ubuntu-latest]
44 |         rust: [stable]
45 |     steps:
46 |       - uses: actions/checkout@v2
47 |       - uses: actions/setup-python@v5
48 |         with:
49 |           python-version: 3.11
50 |       - uses: ./.github/actions/rust_build
51 |         with:
52 |           target: test
53 |       - name: test
54 |         shell: bash
55 |         run: |
56 |           cargo test
57 | 
--------------------------------------------------------------------------------
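The checks run by the two CI jobs above can be reproduced locally before opening a PR; the commands below match the workflow steps and the checklist in `PULL_REQUEST_TEMPLATE.md`:

```bash
cd rust
cargo check
cargo clippy -- -D warnings
cargo fmt --all -- --check
cargo test
```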
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 | target-trunk/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | # .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # poetry
99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100 | # This is especially recommended for binary packages to ensure reproducibility, and is more
101 | # commonly ignored for libraries.
102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103 | #poetry.lock
104 |
105 | # pdm
106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107 | #pdm.lock
108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109 | # in version control.
110 | # https://pdm.fming.dev/#use-with-ide
111 | .pdm.toml
112 |
113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
114 | __pypackages__/
115 |
116 | # Celery stuff
117 | celerybeat-schedule
118 | celerybeat.pid
119 |
120 | # SageMath parsed files
121 | *.sage.py
122 |
123 | # Environments
124 | .env
125 | .venv
126 | env/
127 | venv/
128 | ENV/
129 | env.bak/
130 | venv.bak/
131 |
132 | # Spyder project settings
133 | .spyderproject
134 | .spyproject
135 |
136 | # Rope project settings
137 | .ropeproject
138 |
139 | # mkdocs documentation
140 | /site
141 |
142 | # mypy
143 | .mypy_cache/
144 | .dmypy.json
145 | dmypy.json
146 |
147 | # Pyre type checker
148 | .pyre/
149 |
150 | # pytype static type analyzer
151 | .pytype/
152 |
153 | # Cython debug symbols
154 | cython_debug/
155 |
156 | # PyCharm
157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159 | # and can be added to the global gitignore or merged into this file. For a more nuclear
160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161 | #.idea/
162 |
163 | # VsCode
164 | .vscode/
165 |
166 | *~
167 | *.safetensors
168 | *.wav
169 | trace*.json
170 | *.flac
171 | pkg
172 | *.nsys-rep
173 | *.sqlite
174 | *.pem
175 | *.tgz
176 | *.mp3
177 | *.ogg
178 | /moshi-demo/config.sh
179 | log.*
180 | laurent/cuda-test/check
181 | client/node_modules
182 | timings.json
183 | mlx-trace.json
184 | /scripts/token.txt
185 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: local
3 |     hooks:
4 |       - id: flake8-moshi
5 |         name: flake8 on moshi package
6 |         language: system
7 |         entry: bash -c 'cd moshi && flake8'
8 |         pass_filenames: false
9 |         always_run: true
10 |       - id: pyright-moshi
11 |         name: pyright on moshi package
12 |         language: system
13 |         entry: scripts/run_ci_when_installed.sh moshi 'cd moshi && pyright'
14 |         pass_filenames: false
15 |         always_run: true
16 |       - id: tests-moshi
17 |         name: pytests on moshi package
18 |         language: system
19 |         entry: scripts/run_ci_when_installed.sh moshi 'cd moshi && pytest tests'
20 |         pass_filenames: false
21 |         always_run: true
22 |       - id: flake8-moshi_mlx
23 |         name: flake8 on moshi_mlx package
24 |         language: system
25 |         entry: bash -c 'cd moshi_mlx && flake8'
26 |         pass_filenames: false
27 |         always_run: true
28 |       - id: pyright-moshi_mlx
29 |         name: pyright on moshi_mlx package
30 |         language: system
31 |         entry: scripts/run_ci_when_installed.sh moshi_mlx 'cd moshi_mlx && pyright'
32 |         pass_filenames: false
33 |         always_run: true
34 | 
35 | 
--------------------------------------------------------------------------------
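The hooks above are `system`-language hooks, so the Python tooling must already be installed in the active environment (see `.github/actions/moshi_build/action.yml` earlier in this dump). A minimal local sketch, assuming the `dev` extra of the `moshi` package provides the linters:

```bash
pip install -e './moshi[dev]'    # install the package plus dev tooling
pre-commit install               # register the git pre-commit hook
pre-commit run --all-files       # run every configured hook once
```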
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to Moshi
2 |
3 | ## Pull Requests
4 |
5 | Moshi is the implementation of a research paper.
6 | Therefore, we do not plan on accepting many pull requests for new features.
7 | However, we certainly welcome them for bug fixes.
8 |
9 | 1. Fork the repo and create your branch from `main`.
10 | 2. If you have changed APIs, update the documentation accordingly.
11 | 3. Ensure pre-commit hooks pass properly, in particular the linting and typing.
12 | 4. When changing the Rust code, run `cargo check`, `cargo clippy`, `cargo test`.
13 | 5. Accept the Contributor License Agreement (see after).
14 |
15 | Note that in general, we will not accept refactoring of the code.
16 |
17 |
18 | ## Contributor License Agreement ("CLA")
19 |
20 | In order to accept your pull request, we need you to submit a Contributor License Agreement.
21 |
22 | If you agree with the full CLA provided in the next paragraph, copy the following statement into your PR, replacing {your GitHub handle} with your own:
23 |
24 | > I, {your GitHub handle}, confirm that I have read and understood the terms of the CLA of Kyutai-labs, as outlined in the repository's CONTRIBUTING.md, and I agree to be bound by these terms.
25 |
26 | The full CLA is provided as follows:
27 |
28 | > I, {your GitHub handle}, hereby grant to Kyutai-labs a perpetual, worldwide, non-exclusive, royalty-free,
29 | > irrevocable license to use, modify, distribute, and sublicense my Contributions.
30 |
31 | > I understand and accept that Contributions are limited to modifications, improvements, or changes
32 | > to the project’s source code submitted via pull requests. I accept that Kyutai-labs has full discretion to
33 | > review, accept, reject, or request changes to any Contributions I submit, and that submitting
34 | > a pull request does not guarantee its inclusion in the project.
35 |
36 | > By submitting a Contribution, I grant Kyutai-labs a perpetual, worldwide license to use, modify,
37 | > reproduce, distribute, and create derivative works based on my Contributions.
38 | > I also agree to assign all patent rights for any inventions or improvements that arise from my Contributions,
39 | > giving Kyutai-labs full rights to file for and enforce patents.
40 | > I understand that Kyutai-labs may commercialize, relicense, or exploit the project and my Contributions without further notice or obligation to me.
41 | > I confirm that my Contributions are original and that I have the legal right to grant this license.
42 | > If my Contributions include third-party materials, I will ensure that I have the necessary permissions
43 | > and will disclose this information. I accept that once my Contributions are integrated, they may be altered or removed at Kyutai-labs' discretion.
44 |
45 | > I acknowledge that I am making these Contributions voluntarily and will not receive any compensation.
46 | > Furthermore, I understand that all Contributions, including mine, are provided on an "as-is" basis, with no warranties.
47 | > By submitting a pull request, I agree to be bound by these terms.
48 |
49 | ## Issues
50 |
51 | Please submit issues on our GitHub repository.
52 |
53 | ## License
54 |
55 | By contributing to Moshi, you agree that your contributions will be licensed
56 | under the LICENSE-* files in the root directory of this source tree.
57 | In particular, the Rust code is licensed under Apache 2.0, and the Python code under MIT.
58 |
--------------------------------------------------------------------------------
/FAQ.md:
--------------------------------------------------------------------------------
1 | # FAQ
2 |
3 | Here are the answers to a number of frequently asked questions.
4 |
5 | ### Will you release training code?
6 |
7 | Some finetuning code can be found in the [kyutai-labs/moshi-finetune repo](https://github.com/kyutai-labs/moshi-finetune).
8 |
9 | ### Will you release the dataset?
10 |
11 | We will not release the pre-training dataset.
12 |
13 | ### Is Moshi multilingual?
14 |
15 | At the moment no. Moshi only speaks English. It has some basic support for translating some sentences
16 | or words to other languages, but you shouldn't expect to use it fully in any language other than English.
17 |
18 | ### Can I change Moshi's voice / personality?
19 |
20 | This would require fine-tuning, which is not currently supported.
21 |
22 | ### Can Moshi run on a M1, or smaller GPUs?
23 |
24 | Sadly we do not think this is currently possible. Quantizing beyond 4 bits leads to a dramatic
25 | decrease in quality, see [PR #58](https://github.com/kyutai-labs/moshi/pull/58).
26 | While we keep those limitations in mind for future versions, there is no immediate solution.
27 |
28 | ### Can we run quantized Moshi with PyTorch?
29 |
30 | At the moment no, although we might look into adding this feature when we get the time. However, it is
31 | already possible to use the Rust backend, which should run in int8 with CUDA.
32 |
33 | ### Moshi stopped talking after 5 min.
34 |
35 | This is expected with the MLX and Rust implementations.
36 | We only use a fixed buffer, and we do not discard past entries.
37 | The PyTorch version should work for an unlimited time, although this is mostly untested and we
38 | expect the quality to degrade after a while (we have no attention sink or other mechanism to improve streaming
39 | beyond the finite context used at training).
40 |
41 | ### The server seems to be running but nothing happens on connect.
42 |
43 | For diagnosis, check your browser console for any errors being
44 | reported.
45 |
46 | If you see issues that look like the following:
47 | ```
48 | Uncaught (in promise) TypeError: Cannot read properties of undefined (reading 'addModule')
49 | ```
50 | this is likely caused by the server being remote: browsers disable audio capture
51 | for insecure http origins in that case.
52 |
53 | To get around this, tunnel port 8998 from the remote server to your local machine
54 | via ssh, then access [localhost:8998](http://localhost:8998) over http normally
55 | after that.
56 |
57 | ### How to get the key.pem and cert.pem files required for serving over https?
58 | ```bash
59 | openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes -subj "/CN=localhost"
60 | ```
61 |
62 | ### Can I run it on a 12GB / 8GB GPU?
63 | For a 12GB GPU, this is possible following instructions in [issue #54](https://github.com/kyutai-labs/moshi/issues/54).
64 | For an 8GB GPU, this is not possible at the moment.
65 |
--------------------------------------------------------------------------------
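A minimal sketch of the SSH tunnel described in the FAQ above, assuming the server listens on port 8998 and the remote host is reachable as `user@remote-host` (placeholder name):

```bash
# Forward local port 8998 to port 8998 on the remote machine,
# then open http://localhost:8998 in the browser.
ssh -L 8998:localhost:8998 user@remote-host
```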
/LICENSE-MIT:
--------------------------------------------------------------------------------
1 | Permission is hereby granted, free of charge, to any
2 | person obtaining a copy of this software and associated
3 | documentation files (the "Software"), to deal in the
4 | Software without restriction, including without
5 | limitation the rights to use, copy, modify, merge,
6 | publish, distribute, sublicense, and/or sell copies of
7 | the Software, and to permit persons to whom the Software
8 | is furnished to do so, subject to the following
9 | conditions:
10 |
11 | The above copyright notice and this permission notice
12 | shall be included in all copies or substantial portions
13 | of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
16 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
17 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
18 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
19 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
22 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
23 | DEALINGS IN THE SOFTWARE.
24 |
--------------------------------------------------------------------------------
/client/.env.local:
--------------------------------------------------------------------------------
1 | VITE_QUEUE_API_PATH=/api
2 | VITE_QUEUE_API_URL=https://moshi.chat
--------------------------------------------------------------------------------
/client/.eslinrc.json:
--------------------------------------------------------------------------------
1 | {
2 |   "env": {
3 |     "browser": true,
4 |     "es2021": true
5 |   },
6 |   "extends": [
7 |     "plugin:react/recommended",
8 |     "standard-with-typescript",
9 |     "plugin:import/typescript",
10 |     "plugin:prettier/recommended"
11 |   ],
12 |   "parser": "@typescript-eslint/parser",
13 |   "overrides": [],
14 |   "parserOptions": {
15 |     "ecmaVersion": "latest",
16 |     "sourceType": "module",
17 |     "project": "./tsconfig.json"
18 |   },
19 |   "plugins": ["react", "prettier"],
20 |   "rules": {
21 |     "@typescript-eslint/triple-slash-reference": "off"
22 |   }
23 | }
24 | 
--------------------------------------------------------------------------------
/client/.nvmrc:
--------------------------------------------------------------------------------
1 | v20.12.2
2 |
--------------------------------------------------------------------------------
/client/.prettierignore:
--------------------------------------------------------------------------------
1 | dist/*
--------------------------------------------------------------------------------
/client/.prettierrc.json:
--------------------------------------------------------------------------------
1 | {
2 |   "arrowParens": "avoid",
3 |   "singleQuote": false,
4 |   "trailingComma": "all",
5 |   "tabWidth": 2,
6 |   "useTabs": false,
7 |   "semi": true,
8 |   "printWidth": 80,
9 |   "plugins": ["prettier-plugin-tailwindcss"]
10 | }
11 | 
--------------------------------------------------------------------------------
/client/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM node:20
2 |
3 | WORKDIR /app
4 |
5 | COPY . /app
6 |
7 | RUN npm install
8 |
9 | RUN npm run build
10 |
11 | # Install OpenSSL
12 | RUN apt-get update && apt-get install -y openssl
13 |
14 | # Generate self-signed SSL certificate
15 | RUN openssl req -x509 -nodes -days 365 -newkey rsa:2048 \
16 | -keyout /app/key.pem -out /app/cert.pem \
17 | -subj "/C=US/ST=State/L=City/O=Organization/CN=localhost"
18 |
19 | EXPOSE 5173
20 |
21 | CMD ["npm", "run", "dev"]
22 |
--------------------------------------------------------------------------------
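A hedged sketch of building and running the client image above; `moshi-client` is an arbitrary tag chosen for illustration, and port 5173 matches the Dockerfile's `EXPOSE` line:

```bash
docker build -t moshi-client ./client
docker run -p 5173:5173 moshi-client
```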
/client/LICENSE:
--------------------------------------------------------------------------------
1 | Permission is hereby granted, free of charge, to any
2 | person obtaining a copy of this software and associated
3 | documentation files (the "Software"), to deal in the
4 | Software without restriction, including without
5 | limitation the rights to use, copy, modify, merge,
6 | publish, distribute, sublicense, and/or sell copies of
7 | the Software, and to permit persons to whom the Software
8 | is furnished to do so, subject to the following
9 | conditions:
10 |
11 | The above copyright notice and this permission notice
12 | shall be included in all copies or substantial portions
13 | of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
16 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
17 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
18 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
19 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
22 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
23 | DEALINGS IN THE SOFTWARE.
24 |
--------------------------------------------------------------------------------
/client/README.md:
--------------------------------------------------------------------------------
1 | # moshi-client
2 |
3 | Frontend for the demo.
4 |
5 | ## Run the client
6 |
7 | - Node is required. We recommend using [NVM](https://github.com/nvm-sh/nvm) to manage your Node version and make sure you're on the version recommended for this project. If you do so, run `nvm use`.
8 | - Generate a public/private key pair, `cert.pem` and `key.pem`, and copy them to the root of this package.
9 | - Create a `.env.local` file and add an entry for `VITE_QUEUE_API_PATH` (the default should be `/api`).
10 | - Before running the project for the first time, or after a dependency update, run `npm install`.
11 | - To run the project use `npm run dev`
12 | - To build the project use `npm run build`
13 |
14 | ## Skipping the queue
15 | To skip the queue for standalone use, once the project is running go to `/?worker_addr={WORKER_ADDR}` where `WORKER_ADDR` is your worker instance address.
16 | For example: `https://localhost:5173/?worker_addr=0.0.0.0:8088`
17 |
18 | ## License
19 |
20 | The present code is provided under the MIT license.
21 |
--------------------------------------------------------------------------------
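The setup steps from the README above, collected into one shell sketch. The `openssl` command mirrors the one in `FAQ.md`; treat the whole block as an illustration rather than the project's canonical setup:

```bash
cd client
nvm use                          # switch to the Node version pinned in .nvmrc
openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem \
  -days 365 -nodes -subj "/CN=localhost"
echo "VITE_QUEUE_API_PATH=/api" > .env.local
npm install                      # first run, or after a dependency update
npm run dev                      # start the dev server
```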
/client/index.html:
--------------------------------------------------------------------------------
1 | <!doctype html>
2 | <html lang="en">
3 |   <head>
4 |     <meta charset="UTF-8" />
5 |     <link rel="icon" type="image/x-icon" href="/assets/favicon.ico" />
6 |     <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7 |     <title>
8 |       moshi.chat
9 |     </title>
10 |   </head>
11 |   <body>
12 |     <div id="root"></div>
13 |     <script type="module" src="/src/app.tsx"></script>
14 |   </body>
15 | </html>
--------------------------------------------------------------------------------
/client/package.json:
--------------------------------------------------------------------------------
1 | {
2 |   "name": "kyutai-client",
3 |   "private": true,
4 |   "version": "0.0.0",
5 |   "type": "module",
6 |   "scripts": {
7 |     "dev": "vite",
8 |     "build": "tsc && vite build",
9 |     "lint": "eslint",
10 |     "lint:fix": "eslint --fix",
11 |     "prettier": "prettier --write .",
12 |     "preview": "vite preview"
13 |   },
14 |   "devDependencies": {
15 |     "@eslint/js": "^9.3.0",
16 |     "@types/react": "^18.3.1",
17 |     "@types/react-dom": "^18.3.0",
18 |     "@types/ws": "^8.5.10",
19 |     "autoprefixer": "^10.4.19",
20 |     "daisyui": "^4.12.2",
21 |     "eslint": "^8.57.0",
22 |     "eslint-config-prettier": "^9.1.0",
23 |     "eslint-plugin-prettier": "^5.1.3",
24 |     "eslint-plugin-react": "^7.34.1",
25 |     "globals": "^15.2.0",
26 |     "postcss": "^8.4.38",
27 |     "prettier": "^3.2.5",
28 |     "prettier-eslint": "^16.3.0",
29 |     "prettier-plugin-tailwindcss": "^0.5.14",
30 |     "tailwindcss": "^3.4.3",
31 |     "typescript": "^5.2.2",
32 |     "typescript-eslint": "^7.9.0",
33 |     "vite": "^5.2.14",
34 |     "vite-plugin-top-level-await": "^1.4.1"
35 |   },
36 |   "dependencies": {
37 |     "eruda": "^3.0.1",
38 |     "opus-recorder": "^8.0.5",
39 |     "react": "^18.3.1",
40 |     "react-dom": "^18.3.1",
41 |     "react-router-dom": "^6.23.1",
42 |     "webm-duration-fix": "^1.0.4",
43 |     "ws": "^8.16.0",
44 |     "zod": "^3.23.8"
45 |   }
46 | }
47 | 
--------------------------------------------------------------------------------
/client/postcss.config.js:
--------------------------------------------------------------------------------
1 | export default {
2 |   plugins: {
3 |     tailwindcss: {},
4 |     autoprefixer: {},
5 |   },
6 | };
7 | 
--------------------------------------------------------------------------------
/client/public/assets/decoderWorker.min.wasm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/decoderWorker.min.wasm
--------------------------------------------------------------------------------
/client/public/assets/favicon-16x16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/favicon-16x16.png
--------------------------------------------------------------------------------
/client/public/assets/favicon-32x32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/favicon-32x32.png
--------------------------------------------------------------------------------
/client/public/assets/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/favicon.ico
--------------------------------------------------------------------------------
/client/public/assets/images/demo/image1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image1.jpg
--------------------------------------------------------------------------------
/client/public/assets/images/demo/image10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image10.jpg
--------------------------------------------------------------------------------
/client/public/assets/images/demo/image11.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image11.jpg
--------------------------------------------------------------------------------
/client/public/assets/images/demo/image12.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image12.jpg
--------------------------------------------------------------------------------
/client/public/assets/images/demo/image13.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image13.jpg
--------------------------------------------------------------------------------
/client/public/assets/images/demo/image14.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image14.jpg
--------------------------------------------------------------------------------
/client/public/assets/images/demo/image15.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image15.jpg
--------------------------------------------------------------------------------
/client/public/assets/images/demo/image16.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image16.jpg
--------------------------------------------------------------------------------
/client/public/assets/images/demo/image17.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image17.jpg
--------------------------------------------------------------------------------
/client/public/assets/images/demo/image18.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image18.jpg
--------------------------------------------------------------------------------
/client/public/assets/images/demo/image19.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image19.jpg
--------------------------------------------------------------------------------
/client/public/assets/images/demo/image2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image2.jpg
--------------------------------------------------------------------------------
/client/public/assets/images/demo/image20.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image20.jpg
--------------------------------------------------------------------------------
/client/public/assets/images/demo/image3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image3.jpg
--------------------------------------------------------------------------------
/client/public/assets/images/demo/image4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image4.jpg
--------------------------------------------------------------------------------
/client/public/assets/images/demo/image5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image5.jpg
--------------------------------------------------------------------------------
/client/public/assets/images/demo/image6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image6.jpg
--------------------------------------------------------------------------------
/client/public/assets/images/demo/image7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image7.jpg
--------------------------------------------------------------------------------
/client/public/assets/images/demo/image8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image8.jpg
--------------------------------------------------------------------------------
/client/public/assets/images/demo/image9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/public/assets/images/demo/image9.jpg
--------------------------------------------------------------------------------
/client/public/assets/logo.svg:
--------------------------------------------------------------------------------
1 |
10 |
--------------------------------------------------------------------------------
/client/src/app.tsx:
--------------------------------------------------------------------------------
1 | import ReactDOM from "react-dom/client";
2 | import {
3 |   createBrowserRouter,
4 |   RouterProvider,
5 | } from "react-router-dom";
6 | import "./index.css";
7 | // @ts-expect-error - Worker is not recognized by the TS compiler
8 | import { DecoderWorker } from "./decoder/decoderWorker";
9 | import { Queue } from "./pages/Queue/Queue";
10 | 
11 | const router = createBrowserRouter([
12 |   {
13 |     path: "/",
14 |     element: <Queue />,
15 |   },
16 | ]);
17 | 
18 | ReactDOM.createRoot(document.getElementById("root") as HTMLElement).render(
19 |   <RouterProvider router={router} />,
20 | );
21 | 
--------------------------------------------------------------------------------
/client/src/components/Button/Button.tsx:
--------------------------------------------------------------------------------
1 | import { FC } from "react";
2 | 
3 | type ButtonProps = React.ButtonHTMLAttributes<HTMLButtonElement>;
4 | export const Button: FC<ButtonProps> = ({ children, className, ...props }) => {
5 |   return (
6 |     <button
7 |       className={className}
8 |       {...props}
9 |     >
10 |       {children}
11 |     </button>
12 |   );
13 | };
14 | 
15 | 
16 | export const SwitchButton: FC<ButtonProps> = ({ children, className, ...props }) => {
17 |   return (
18 |     <button
19 |       className={className}
20 |       {...props}
21 |     >
22 |       {children}
23 |     </button>
24 |   );
25 | };
26 | 
--------------------------------------------------------------------------------
/client/src/components/ImageGallery/ImageGallery.tsx:
--------------------------------------------------------------------------------
1 |
2 | import { useState } from "react";
3 | import { Button } from "../Button/Button";
4 |
5 | // Natural images
6 | import img1 from "/assets/images/demo/image1.jpg";
7 | import img2 from "/assets/images/demo/image2.jpg";
8 | import img3 from "/assets/images/demo/image3.jpg";
9 | import img4 from "/assets/images/demo/image4.jpg";
10 | import img5 from "/assets/images/demo/image5.jpg";
11 | import img6 from "/assets/images/demo/image6.jpg";
12 | import img7 from "/assets/images/demo/image7.jpg";
13 | import img8 from "/assets/images/demo/image8.jpg";
14 | import img9 from "/assets/images/demo/image9.jpg";
15 | import img10 from "/assets/images/demo/image10.jpg";
16 | import img11 from "/assets/images/demo/image11.jpg";
17 | import img12 from "/assets/images/demo/image12.jpg";
18 | import img13 from "/assets/images/demo/image13.jpg";
19 | import img14 from "/assets/images/demo/image14.jpg";
20 | import img15 from "/assets/images/demo/image15.jpg";
21 | import img16 from "/assets/images/demo/image16.jpg";
22 | import img17 from "/assets/images/demo/image17.jpg";
23 | import img18 from "/assets/images/demo/image18.jpg";
24 | import img19 from "/assets/images/demo/image19.jpg";
25 | import img20 from "/assets/images/demo/image20.jpg";
26 |
27 | const images = [
28 | img1,
29 | img2,
30 | img3,
31 | img4,
32 | img5,
33 | img6,
34 | img7,
35 | img8,
36 | img9,
37 | img10,
38 | img11,
39 | img12,
40 | img13,
41 | img14,
42 | img15,
43 | img16,
44 | img17,
45 | img18,
46 | img19,
47 | img20,
48 | ]
49 |
50 | var images_order: number[] = [];
51 | for (let i = 0; i < images.length; i++) {
52 | images_order.push(i)
53 | }
54 |
55 | type ImageGalleryProps = React.InputHTMLAttributes<HTMLInputElement> & {
56 | // Properties for the ImageGallery
57 | paramsSetter: Function;
58 | clickAction: Function;
59 | size: number;
60 | numImages: number;
61 | }
62 |
63 |
64 | type ImageItemProps = React.InputHTMLAttributes<HTMLInputElement> & {
65 | // Properties for a single item in the ImageGallery
66 | // Two actions:
67 | // paramsSetter sets the chosen image url into the model params
68 | // clickAction then starts the conversation
69 | paramsSetter: Function;
70 | clickAction: Function;
71 | size: number;
72 | imageUrl: string;
73 | }
74 |
75 |
76 | function ImageSelect(props: ImageItemProps) {
77 | // Represents a single image in the gallery
78 | const [isHover, setIsHover] = useState(false);
79 |
80 | const handleMouseEnter = () => {
81 | setIsHover(true);
82 | };
83 | const handleMouseLeave = () => {
84 | setIsHover(false);
85 | };
86 | let bordercolor = isHover ? "#f7a319" : "black";
87 | let bgalpha = isHover ? 0.05 : 0.6;
88 | let textalpha = isHover ? 1.0 : 0.0
89 | let label = isHover ? "Connect" : "X";
90 | let style = {
91 | width: props.size,
92 | height: props.size,
93 | background: `url(${props.imageUrl})`,
94 | backgroundSize: "100% 100%",
95 | border: `3px solid ${bordercolor}`,
96 | margin: "2px",
97 | padding: "0px",
98 | color: `rgba(255, 255, 255, ${textalpha})`,
99 | boxShadow: `inset 0 0 0 1000px rgba(0,0,0,${bgalpha})`,
100 | textShadow: `2px 2px 2px rgba(0, 0, 0, ${textalpha})`
101 | };
102 | return (
103 | <button style={style} onMouseEnter={handleMouseEnter} onMouseLeave={handleMouseLeave}
104 | onClick={() => { props.paramsSetter(props.imageUrl); props.clickAction(); }}>
105 | {label}</button>
106 | );
107 | }
108 |
109 |
110 | const shuffle = (array: number[]) => {
111 | return array.sort(() => Math.random() - 0.5);
112 | };
113 |
114 | export const ImageGallery = (props: ImageGalleryProps) => {
115 | const [ordering, SetOrdering] = useState(images_order);
116 |
117 | function handleShuffle() {
118 | SetOrdering(shuffle([...ordering]));
119 | }
120 |
121 | // Image Gallery widget (random subset)
122 | const steps = [];
123 | for (let i = 0; i < props.numImages; i++) {
124 | steps.push(<ImageSelect key={i} size={props.size} imageUrl={images[ordering[i]]}
125 | paramsSetter={props.paramsSetter}
126 | clickAction={props.clickAction} />);
127 | }
128 |
129 | return (
130 | <div className="gallery">
131 | <Button onClick={handleShuffle}>
132 | Shuffle
133 | </Button>
134 | <div>
135 | {steps}
136 | </div>
137 | </div>
138 | )
139 | ;
140 | };
141 |
--------------------------------------------------------------------------------
/client/src/components/Input/Input.tsx:
--------------------------------------------------------------------------------
1 | type InputProps = React.InputHTMLAttributes<HTMLInputElement> & {
2 |   error?: string;
3 | }
4 | 
5 | export const Input = ({className, error, ...props}:InputProps) => {
6 |   return (
7 |     <div>
8 |       <input
9 |         className={className}
10 |         {...props}
11 |       />
12 |       {error && <p>{error}</p>}
13 |     </div>
14 |   );
15 | }
--------------------------------------------------------------------------------
/client/src/decoder/decoderWorker.ts:
--------------------------------------------------------------------------------
1 | export const DecoderWorker = new Worker(
2 |   new URL("/assets/decoderWorker.min.js", import.meta.url),
3 | );
4 |
--------------------------------------------------------------------------------
/client/src/env.ts:
--------------------------------------------------------------------------------
1 | type ENV = {
2 |   VITE_QUEUE_API_PATH: string;
3 |   VITE_ENV: 'development' | 'production';
4 | };
5 | 
6 | const parseEnv = (): ENV => {
7 |   const VITE_QUEUE_API_PATH = import.meta.env.VITE_QUEUE_API_PATH;
8 | 
9 |   if (!VITE_QUEUE_API_PATH) {
10 |     throw new Error("VITE_QUEUE_API_PATH is not defined");
11 |   }
12 | 
13 |   return {
14 |     VITE_QUEUE_API_PATH,
15 |     VITE_ENV: import.meta.env.DEV ? 'development' : 'production',
16 |   };
17 | };
18 | 
19 | export const env = parseEnv();
20 |
--------------------------------------------------------------------------------
/client/src/index.css:
--------------------------------------------------------------------------------
1 | @tailwind base;
2 | @tailwind components;
3 | @tailwind utilities;
4 |
5 | @layer utilities {
6 |
7 | /* Hide scrollbar for Chrome, Safari and Opera */
8 | .no-scrollbar::-webkit-scrollbar {
9 | display: none;
10 | }
11 |
12 | /* Hide scrollbar for IE, Edge and Firefox */
13 | .no-scrollbar {
14 | -ms-overflow-style: none;
15 | /* IE and Edge */
16 | scrollbar-width: none;
17 | /* Firefox */
18 | }
19 |
20 | .scrollbar::-webkit-scrollbar {
21 | width: 10px;
22 | }
23 |
24 | .scrollbar::-webkit-scrollbar-track {
25 | background: transparent;
26 | }
27 |
28 | .scrollbar::-webkit-scrollbar-thumb {
29 | background: white;
30 | border: 3px solid #f6f7ed;
31 | }
32 | }
33 |
34 | .main-grid {
35 | display: grid;
36 | grid-template-columns: 1fr;
37 | grid-template-rows: min-content 1fr 1fr;
38 | gap: 30px;
39 | grid-auto-flow: column;
40 | grid-template-areas:
41 | "controls"
42 | "player"
43 | "player-text";
44 |
45 | @media screen and (min-width: 768px) {
46 | grid-template-columns: 2fr 2.5fr;
47 | grid-template-rows: min-content min-content min-content 1fr;
48 | gap: 30px 30px;
49 | grid-auto-flow: column;
50 | align-items: center;
51 | justify-items: center;
52 | grid-template-areas:
53 | "controls controls"
54 | "player player-stats"
55 | "player player-text"
56 | "player player-text";
57 | }
58 | }
59 |
60 | .presentation {
61 | max-width: 450px;
62 | }
63 |
64 | .presentation>p {
65 | padding-top: 10px;
66 | }
67 |
68 |
69 | .gallery {
70 | max-width: 450px;
71 | }
72 |
73 | .cute-words {
74 | color: #54e8b3;
75 | }
76 |
77 |
78 | .download-links {
79 | color: #54e8b3;
80 | }
81 |
82 | .explain-links {
83 | color: #BCFCE5;
84 | }
85 |
86 |
87 | .controls {
88 | grid-area: controls;
89 | }
90 |
91 | .player {
92 | grid-area: player;
93 | grid-template-areas:
94 | "server-audio"
95 | "user-audio";
96 | display: grid;
97 | grid-template-columns: 1fr;
98 | grid-template-rows: 1fr 1fr;
99 | align-items: center;
100 | justify-items: center;
101 | /* margin:auto; */
102 | }
103 |
104 | .server-audio {
105 | grid-area: server-audio;
106 | }
107 |
108 | .user-audio {
109 | grid-area: user-audio;
110 | }
111 |
112 | .player-stats {
113 | grid-area: player-stats;
114 | width: 100%;
115 | height: 100%;
116 | }
117 |
118 | .commands {
119 | grid-area: commands;
120 | width: 100%;
121 | height: 100%;
122 | }
123 |
124 | .player-text {
125 | grid-area: player-text;
126 | width: 100%;
127 | height: 100%;
128 | overflow: scroll;
129 | }
--------------------------------------------------------------------------------
/client/src/modules.d.ts:
--------------------------------------------------------------------------------
1 | declare module "opus-recorder";
2 |
--------------------------------------------------------------------------------
/client/src/pages/Conversation/MediaContext.ts:
--------------------------------------------------------------------------------
1 | import { MutableRefObject, createContext, useContext } from "react";
2 | type MediaContextType = {
3 |   startRecording: () => void;
4 |   stopRecording: () => void;
5 |   audioContext: MutableRefObject<AudioContext | null>;
6 |   audioStreamDestination: MutableRefObject<MediaStreamAudioDestinationNode | null>;
7 |   worklet: MutableRefObject<AudioWorkletNode | null>;
8 |   micDuration: MutableRefObject<number>;
9 |   actualAudioPlayed: MutableRefObject<number>;
10 | };
11 | 
12 | export const MediaContext = createContext<MediaContextType | null>(null);
13 | 
14 | export const useMediaContext = () => {
15 |   const context = useContext(MediaContext);
16 |   if (!context) {
17 |     throw new Error(
18 |       "useMediaContext must be used within a MediaContextProvider",
19 |     );
20 |   }
21 | 
22 |   return context;
23 | };
--------------------------------------------------------------------------------
/client/src/pages/Conversation/SocketContext.ts:
--------------------------------------------------------------------------------
1 | import { createContext, useContext } from "react";
2 | import { WSMessage } from "../../protocol/types";
3 | 
4 | type SocketContextType = {
5 |   isConnected: boolean;
6 |   socket: WebSocket | null;
7 |   sendMessage: (message: WSMessage) => void;
8 | };
9 | 
10 | export const SocketContext = createContext<SocketContextType>({
11 |   isConnected: false,
12 |   socket: null,
13 |   sendMessage: () => {},
14 | });
15 | 
16 | export const useSocketContext = () => {
17 |   return useContext(SocketContext);
18 | };
19 | 
--------------------------------------------------------------------------------
/client/src/pages/Conversation/canvas-logo-black.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/src/pages/Conversation/canvas-logo-black.png
--------------------------------------------------------------------------------
/client/src/pages/Conversation/canvas-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/client/src/pages/Conversation/canvas-logo.png
--------------------------------------------------------------------------------
/client/src/pages/Conversation/components/AudioVisualizer/AudioVisualizer.tsx:
--------------------------------------------------------------------------------
1 | import { FC, useCallback, useEffect, useRef } from "react";
2 |
3 | type AudioVisualizerProps = {
4 | analyser: AnalyserNode | null;
5 | };
6 |
7 | export const AudioVisualizer: FC<AudioVisualizerProps> = ({ analyser }) => {
8 | const requestRef = useRef<number | null>(null);
9 | const canvasRef = useRef<HTMLCanvasElement | null>(null);
10 |
11 | const visualizeData = useCallback(() => {
12 | requestRef.current = window.requestAnimationFrame(() => visualizeData());
13 | if (!canvasRef.current) {
14 | console.log("Canvas not found");
15 | return;
16 | }
17 | const audioData = new Uint8Array(140);
18 | analyser?.getByteFrequencyData(audioData);
19 | const bar_width = 3;
20 | let start = 0;
21 | const ctx = canvasRef.current.getContext("2d");
22 | if (!ctx) {
23 | console.log("Canvas context not found");
24 | return;
25 | }
26 | ctx.clearRect(0, 0, canvasRef.current.width, canvasRef.current.height);
27 | for (let i = 0; i < audioData.length; i++) {
28 | start = i * 4;
29 | let gradient = ctx.createLinearGradient(
30 | 0,
31 | 0,
32 | canvasRef.current.width,
33 | canvasRef.current.height,
34 | );
35 | gradient.addColorStop(0.2, "#2392f5");
36 | gradient.addColorStop(0.5, "#fe0095");
37 | gradient.addColorStop(1.0, "purple");
38 | ctx.fillStyle = gradient;
39 | ctx.fillRect(
40 | start,
41 | canvasRef.current.height,
42 | bar_width,
43 | (-audioData[i] * 100) / 255,
44 | );
45 | }
46 | }, [analyser]);
47 |
48 | const resetCanvas = useCallback(() => {
49 | if (!canvasRef.current) {
50 | return;
51 | }
52 | const ctx = canvasRef.current.getContext("2d");
53 | if (!ctx) {
54 | return;
55 | }
56 | ctx.clearRect(0, 0, canvasRef.current.width, canvasRef.current.height);
57 | }, []);
58 |
59 | useEffect(() => {
60 | if (!analyser) {
61 | return;
62 | }
63 | visualizeData();
64 | return () => {
65 | if (requestRef.current) {
66 | console.log("Canceling animation frame");
67 | cancelAnimationFrame(requestRef.current);
68 | }
69 | };
70 | }, [visualizeData, analyser, resetCanvas]);
71 |
72 | return <canvas ref={canvasRef} />;
73 | };
74 |
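A minimal wiring sketch for the component above (setup assumed, not taken from this file): the analyser prop is a standard Web Audio AnalyserNode, fed here from the microphone.

async function attachMicAnalyser(): Promise<AnalyserNode> {
  const audioContext = new AudioContext();
  const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
  const analyser = audioContext.createAnalyser();
  analyser.fftSize = 512; // 256 frequency bins; the component samples the first 140
  audioContext.createMediaStreamSource(stream).connect(analyser);
  return analyser; // then render <AudioVisualizer analyser={analyser} />
}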
--------------------------------------------------------------------------------
/client/src/pages/Conversation/components/AudioVisualizer/ClientVisualizer.tsx:
--------------------------------------------------------------------------------
1 | import { FC, RefObject, useCallback, useEffect, useRef, useState } from "react";
2 | import { clamp } from "../../hooks/audioUtils";
3 |
4 | type AudioVisualizerProps = {
5 | analyser: AnalyserNode | null;
6 | parent: RefObject<HTMLDivElement>;
7 | copyCanvasRef: RefObject<HTMLCanvasElement>;
8 | };
9 |
10 | const MAX_INTENSITY = 255;
11 |
12 | const COLORS = [
13 | "#197556",
14 | "#299e77",
15 | "#32b89b",
16 | "#31d4b8",
17 | "#14d9d5",
18 | "#41eff2",
19 | "#7ff3f5",
20 | "#789bf5",
21 | "#eb94eb",
22 | "#e63280",
23 | "#c41862",
24 | ];
25 |
26 | export const ClientVisualizer: FC<AudioVisualizerProps> = ({ analyser, parent, copyCanvasRef }) => {
27 | const [canvasWidth, setCanvasWidth] = useState(parent.current ? Math.min(parent.current.clientWidth, parent.current.clientHeight) : 0);
28 | const requestRef = useRef<number | null>(null);
29 | const canvasRef = useRef<HTMLCanvasElement | null>(null);
30 |
31 | const drawBars = useCallback(
32 | (
33 | ctx: CanvasRenderingContext2D,
34 | x: number,
35 | y: number,
36 | volume: number,
37 | height: number,
38 | width: number,
39 | gap: number,
40 | ) => {
41 | const barHeight = height / 10 - gap;
42 | for (let i = 1; i <= 10; i++) {
43 | const barY = y + height + gap + Math.min(1, width / 30) - (i * barHeight + i * gap);
44 | ctx.fillStyle = COLORS[i - 1];
45 | ctx.strokeStyle = "white";
46 | ctx.lineWidth = Math.min(1, height / 100);
47 | if (i <= volume) {
48 | ctx.fillRect(x, barY, width, barHeight);
49 | }
50 | ctx.strokeRect(x, barY, width, barHeight);
51 | }
52 | },
53 | [],
54 | );
55 |
56 | const draw = useCallback((ctx: CanvasRenderingContext2D, audioData: Uint8Array, x: number, y: number, width: number, height: number) => {
57 | const stereoGap = Math.floor(width / 30);
58 | const barGap = Math.floor(height / 30);
59 | const padding = Math.floor(width / 30);
60 | const maxBarHeight = Math.floor(height - padding * 2);
61 | const maxBarWidth = Math.floor(
62 | width / 2.5 - stereoGap - padding * 2,
63 | );
64 |
65 | const centerX = x + width / 2;
66 | const averageIntensity = Math.sqrt(
67 | audioData.reduce((acc, curr) => acc + curr * curr, 0) / audioData.length,
68 | );
69 | const intensity = clamp(
70 | averageIntensity * 1.4,
71 | averageIntensity,
72 | MAX_INTENSITY,
73 | );
74 | const volume = Math.floor((intensity * 10) / MAX_INTENSITY);
75 | ctx.fillStyle = "rgba(0, 0, 0, 0)";
76 | ctx.fillRect(x, y, width, height);
77 | drawBars(
78 | ctx,
79 | centerX - maxBarWidth - stereoGap / 2,
80 | y,
81 | volume,
82 | maxBarHeight,
83 | maxBarWidth,
84 | barGap,
85 | );
86 | drawBars(
87 | ctx,
88 | centerX + stereoGap / 2,
89 | y,
90 | volume,
91 | maxBarHeight,
92 | maxBarWidth,
93 | barGap,
94 | );
95 | }, [analyser, drawBars]);
96 |
97 | const visualizeData = useCallback(() => {
98 | const width = parent.current ? Math.min(parent.current.clientWidth, parent.current.clientHeight) : 0
99 | if (width !== canvasWidth) {
100 | console.log("Setting canvas width");
101 | setCanvasWidth(width);
102 | }
103 | requestRef.current = window.requestAnimationFrame(() => visualizeData());
104 | if (!canvasRef.current) {
105 | console.log("Canvas not found");
106 | return;
107 | }
108 | const audioData = new Uint8Array(140);
109 | analyser?.getByteFrequencyData(audioData);
110 |
111 | const ctx = canvasRef.current.getContext("2d");
112 | if (!ctx) {
113 | console.log("Canvas context not found");
114 | return;
115 | }
116 | ctx.clearRect(0, 0, canvasRef.current.width, canvasRef.current.height);
117 | draw(ctx, audioData, 0, 0, width, width);
118 | if (copyCanvasRef?.current) {
119 | const copyCtx = copyCanvasRef.current.getContext("2d");
120 | if (copyCtx) {
121 | copyCtx.clearRect(220, 40, 140, 180);
122 | draw(copyCtx, audioData, 220, 40, 140, 180);
123 | }
124 | }
125 | }, [analyser, canvasWidth, drawBars, parent, copyCanvasRef, draw]);
126 |
127 | useEffect(() => {
128 | visualizeData();
129 | return () => {
130 | if (requestRef.current) {
131 | console.log("Canceling animation frame");
132 | cancelAnimationFrame(requestRef.current);
133 | }
134 | };
135 | }, [visualizeData, analyser]);
136 | return (
137 |     <canvas
138 |       ref={canvasRef}
139 |       width={canvasWidth}
140 |       height={canvasWidth}
141 |     />
142 |
143 | );
144 | };
145 |
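The bar count above is derived from the root-mean-square of the frequency bins; a worked example of that mapping, with illustrative numbers:

// e.g. a frame whose bins have RMS 100:
// intensity = clamp(100 * 1.4, 100, 255) = 140
// volume = Math.floor((140 * 10) / 255) = 5
// so drawBars fills bars 1..5 of 10, in COLORS order.
const rms = (bins: Uint8Array) =>
  Math.sqrt(bins.reduce((acc, v) => acc + v * v, 0) / bins.length);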
--------------------------------------------------------------------------------
/client/src/pages/Conversation/components/Controls/Controls.tsx:
--------------------------------------------------------------------------------
1 | import {
2 | controlBOSMessage,
3 | controlEOSMessage,
4 | } from "../../../../protocol/testMessages";
5 | import { useSocketContext } from "../../SocketContext";
6 | import { Button } from "../../../../components/Button/Button";
7 |
8 | export const Controls = () => {
9 | const { sendMessage } = useSocketContext();
10 |
11 | const sendControlBOS = () => {
12 | sendMessage(controlBOSMessage);
13 | };
14 |
15 | const sendControlEOS = () => {
16 | sendMessage(controlEOSMessage);
17 | };
18 | return (
19 |   <div>
20 |     <Button onClick={sendControlBOS}>
21 |       Send BOS
22 |     </Button>
23 |     <Button onClick={sendControlEOS}>
24 |       Send EOS
25 |     </Button>
26 |   </div>
27 | );
28 | };
29 |
--------------------------------------------------------------------------------
/client/src/pages/Conversation/components/ModelParams/ModelParams.tsx:
--------------------------------------------------------------------------------
1 | import { FC, RefObject } from "react";
2 | import { useModelParams } from "../../hooks/useModelParams";
3 | import { Button } from "../../../../components/Button/Button";
4 |
5 | type ModelParamsProps = {
6 | isConnected: boolean;
7 | isImageMode: boolean;
8 | modal?: RefObject,
9 | } & ReturnType;
10 | export const ModelParams: FC = ({
11 | textTemperature,
12 | textTopk,
13 | audioTemperature,
14 | audioTopk,
15 | padMult,
16 | repetitionPenalty,
17 | repetitionPenaltyContext,
18 | imageResolution,
19 | setTextTemperature,
20 | setTextTopk,
21 | setAudioTemperature,
22 | setAudioTopk,
23 | setPadMult,
24 | setRepetitionPenalty,
25 | setRepetitionPenaltyContext,
26 | setImageResolution,
27 | resetParams,
28 | isConnected,
29 | isImageMode,
30 | modal,
31 | }) => {
32 | return (
33 |
34 |
80 |
81 | {!isConnected && }
82 | {!isConnected && }
83 |
84 |
85 | )
86 | };
87 |
--------------------------------------------------------------------------------
/client/src/pages/Conversation/components/ServerAudio/ServerAudio.tsx:
--------------------------------------------------------------------------------
1 | import { FC, useRef } from "react";
2 | import { AudioStats, useServerAudio } from "../../hooks/useServerAudio";
3 | import { ServerVisualizer } from "../AudioVisualizer/ServerVisualizer";
4 |
5 | type ServerAudioProps = {
6 | setGetAudioStats: (getAudioStats: () => AudioStats) => void;
7 | imageUrl: string | undefined;
8 | copyCanvasRef?: React.RefObject<HTMLCanvasElement>;
9 | };
10 | export const ServerAudio: FC<ServerAudioProps> = ({ setGetAudioStats, imageUrl, copyCanvasRef }) => {
11 | const { analyser, hasCriticalDelay, setHasCriticalDelay } = useServerAudio({
12 | setGetAudioStats,
13 | });
14 | const containerRef = useRef<HTMLDivElement | null>(null);
15 | return (
16 | <>
17 | {hasCriticalDelay && (
18 |
19 |         A connection issue has been detected, you've been reconnected
20 |
28 |
29 | )}
30 |     <div ref={containerRef}>
31 |       <ServerVisualizer analyser={analyser} parent={containerRef} imageUrl={imageUrl} copyCanvasRef={copyCanvasRef} />
32 |     </div>
33 |   </>
34 | );
35 | };
36 |
--------------------------------------------------------------------------------
/client/src/pages/Conversation/components/ServerAudio/ServerAudioStats.tsx:
--------------------------------------------------------------------------------
1 | import { useState, useEffect, useRef } from "react";
2 |
3 | type ServerAudioStatsProps = {
4 | getAudioStats: React.MutableRefObject<
5 | () => {
6 | playedAudioDuration: number;
7 | missedAudioDuration: number;
8 | totalAudioMessages: number;
9 | delay: number;
10 | minPlaybackDelay: number;
11 | maxPlaybackDelay: number;
12 | }
13 | >;
14 | };
15 |
16 | export const ServerAudioStats = ({ getAudioStats }: ServerAudioStatsProps) => {
17 | const [audioStats, setAudioStats] = useState(getAudioStats.current());
18 |
19 | const movingAverageSum = useRef(0.);
20 | const movingAverageCount = useRef(0.);
21 | const movingBeta = 0.85;
22 |
23 | let convertMinSecs = (total_secs: number) => {
24 | // convert secs to the format mm:ss.cc
25 | let mins = (Math.floor(total_secs / 60)).toString();
26 | let secs = (Math.floor(total_secs) % 60).toString();
27 | let cents = (Math.floor(100 * (total_secs - Math.floor(total_secs)))).toString();
28 | if (secs.length < 2) {
29 | secs = "0" + secs;
30 | }
31 | if (cents.length < 2) {
32 | cents = "0" + cents;
33 | }
34 | return mins + ":" + secs + "." + cents;
35 | };
36 |
37 | useEffect(() => {
38 | const interval = setInterval(() => {
39 | const newAudioStats = getAudioStats.current();
40 | setAudioStats(newAudioStats);
41 | movingAverageCount.current *= movingBeta;
42 | movingAverageCount.current += (1 - movingBeta) * 1;
43 | movingAverageSum.current *= movingBeta;
44 | movingAverageSum.current += (1 - movingBeta) * newAudioStats.delay;
45 |
46 | }, 141);
47 | return () => {
48 | clearInterval(interval);
49 | };
50 | }, []);
51 |
52 | return (
53 |     <div>
54 |       <h2>Server Audio Stats</h2>
55 |       <table>
56 |         <tbody>
57 |           <tr>
58 |             <td>Audio played:</td>
59 |             <td>{convertMinSecs(audioStats.playedAudioDuration)}</td>
60 |           </tr>
61 |           <tr>
62 |             <td>Missed audio:</td>
63 |             <td>{convertMinSecs(audioStats.missedAudioDuration)}</td>
64 |           </tr>
65 |           <tr>
66 |             <td>Latency:</td>
67 |             <td>{(movingAverageSum.current / movingAverageCount.current).toFixed(3)}</td>
68 |           </tr>
69 |           <tr>
70 |             <td>Min/Max buffer:</td>
71 |             <td>{audioStats.minPlaybackDelay.toFixed(3)} / {audioStats.maxPlaybackDelay.toFixed(3)}</td>
72 |           </tr>
73 |         </tbody>
74 |       </table>
75 |     </div>
76 | );
77 | };
78 |
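The latency figure above is a bias-corrected exponential moving average: the weighted sum and the weight mass both decay by movingBeta, so their ratio is a proper average of recent delays even after only a few samples. The same estimator in isolation (a sketch, not part of this file):

const makeEma = (beta: number) => {
  let sum = 0;
  let mass = 0;
  return (x: number) => {
    sum = beta * sum + (1 - beta) * x;
    mass = beta * mass + (1 - beta);
    return sum / mass; // equals x after the first sample, then tracks recent values
  };
};
const latency = makeEma(0.85);
console.log(latency(0.2), latency(0.3)); // 0.2, then ~0.254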
--------------------------------------------------------------------------------
/client/src/pages/Conversation/components/ServerInfo/ServerInfo.tsx:
--------------------------------------------------------------------------------
1 | import { useServerInfo } from "../../hooks/useServerInfo";
2 |
3 | export const ServerInfo = () => {
4 | const { serverInfo } = useServerInfo();
5 | if (!serverInfo) {
6 | return null;
7 | }
8 | return (
9 |     <div>
10 |       Our server is running on the following configuration:
11 |       <p>Text temperature: {serverInfo.text_temperature}</p>
12 |       <p>Text topk: {serverInfo.text_topk}</p>
13 |       <p>Audio temperature: {serverInfo.audio_temperature}</p>
14 |       <p>Audio topk: {serverInfo.audio_topk}</p>
15 |       <p>Pad mult: {serverInfo.pad_mult}</p>
16 |       <p>Repeat penalty last N: {serverInfo.repetition_penalty_context}</p>
17 |       <p>Repeat penalty: {serverInfo.repetition_penalty}</p>
18 |       <p>LM model file: {serverInfo.lm_model_file}</p>
19 |       <p>Instance name: {serverInfo.instance_name}</p>
20 |     </div>
21 | );
22 | };
23 |
--------------------------------------------------------------------------------
/client/src/pages/Conversation/components/TextDisplay/TextDisplay.tsx:
--------------------------------------------------------------------------------
1 | import { FC, useEffect, useRef } from "react";
2 | import { useServerText } from "../../hooks/useServerText";
3 |
4 | type TextDisplayProps = {
5 | containerRef: React.RefObject;
6 | displayColor: boolean | undefined;
7 | };
8 |
9 | // Palette 2: Purple to Green Moshi
10 | // sns.diverging_palette(288, 145, s=90, l=72, n=11)
11 | const textDisplayColors = [
12 | "#d19bf7", "#d7acf6", "#debdf5", "#e4cef4",
13 | "#ebe0f3", "#eef2f0", "#c8ead9", "#a4e2c4",
14 | "#80d9af", "#5bd09a", "#38c886"]
15 |
16 | function clamp_color(v: number) {
17 | return v <= 0
18 | ? 0
19 | : v >= textDisplayColors.length
20 | ? textDisplayColors.length - 1
21 | : v
22 | }
23 |
24 | export const TextDisplay: FC<TextDisplayProps> = ({
25 | containerRef, displayColor
26 | }) => {
27 | const { text, textColor } = useServerText();
28 | const currentIndex = text.length - 1;
29 | const prevScrollTop = useRef(0);
30 |
31 | useEffect(() => {
32 | if (containerRef.current) {
33 | prevScrollTop.current = containerRef.current.scrollTop;
34 | containerRef.current.scroll({
35 | top: containerRef.current.scrollHeight,
36 | behavior: "smooth",
37 | });
38 | }
39 | }, [text]);
40 | if (displayColor && (textColor.length == text.length)) {
41 | return (
42 |     <div>
43 |       {text.map((t, i) => (
44 |         <span
45 |           key={i}
46 |           style={{
47 |             color:
48 |               textDisplayColors[clamp_color(textColor[i])],
49 |           }}
50 |         >
51 |           {t}
52 |         </span>
53 |       ))
54 |       }
55 |     </div>
56 | );
57 | }
58 | else {
59 | return (
60 |     <div>
61 |       {text.map((t, i) => (
62 |         <span
63 |           key={i}
64 |         >
65 |           {t}
66 |         </span>
67 |       ))}
68 |     </div>
69 |
70 | );
71 | };
72 | };
73 |
--------------------------------------------------------------------------------
/client/src/pages/Conversation/components/TextDisplay/TextDisplayStats.tsx:
--------------------------------------------------------------------------------
1 | import { FC } from "react";
2 |
3 | type TextDisplayStatsProps = {
4 | totalTextMessages: number;
5 | };
6 | export const TextDisplayStats: FC = ({
7 | totalTextMessages,
8 | }) => {
9 | return (
10 |
11 |
Text Display Stats
12 |
13 |
14 |
Total messages:
15 |
{totalTextMessages}
16 |
17 |
18 |
19 | );
20 | };
21 |
--------------------------------------------------------------------------------
/client/src/pages/Conversation/components/UserAudio/UserAudio.tsx:
--------------------------------------------------------------------------------
1 | import { FC, useCallback, useEffect, useRef, useState } from "react";
2 | import { useSocketContext } from "../../SocketContext";
3 | import { useUserAudio } from "../../hooks/useUserAudio";
4 | import { ClientVisualizer } from "../AudioVisualizer/ClientVisualizer";
5 |
6 | type UserAudioProps = {
7 | copyCanvasRef: React.RefObject<HTMLCanvasElement>;
8 | };
9 | export const UserAudio: FC<UserAudioProps> = ({ copyCanvasRef }) => {
10 | const [analyser, setAnalyser] = useState<AnalyserNode | null>(null);
11 | const { sendMessage, isConnected } = useSocketContext();
12 | const containerRef = useRef<HTMLDivElement | null>(null);
13 | const onRecordingStart = useCallback(() => {
14 | console.log("Recording started");
15 | }, []);
16 |
17 | const onRecordingStop = useCallback(() => {
18 | console.log("Recording stopped");
19 | }, []);
20 |
21 | const onRecordingChunk = useCallback(
22 | (chunk: Uint8Array) => {
23 | if (!isConnected) {
24 | return;
25 | }
26 | sendMessage({
27 | type: "audio",
28 | data: chunk,
29 | });
30 | },
31 | [sendMessage, isConnected],
32 | );
33 |
34 | const { startRecordingUser, stopRecording } = useUserAudio({
35 | constraints: {
36 | audio: {
37 | echoCancellation: true,
38 | noiseSuppression: true,
39 | autoGainControl: true,
40 | channelCount: 1,
41 | },
42 | video: false,
43 | },
44 | onDataChunk: onRecordingChunk,
45 | onRecordingStart,
46 | onRecordingStop,
47 | });
48 |
49 | useEffect(() => {
50 | let res: Awaited<ReturnType<typeof startRecordingUser>>;
51 | if (isConnected) {
52 | startRecordingUser().then(result => {
53 | if (result) {
54 | res = result;
55 | setAnalyser(result.analyser);
56 | }
57 | });
58 | }
59 | return () => {
60 | console.log("Stop recording called from somewhere else.");
61 | stopRecording();
62 | res?.source?.disconnect();
63 | };
64 | }, [startRecordingUser, stopRecording, isConnected]);
65 |
66 | return (
67 |     <div ref={containerRef}>
68 |       <ClientVisualizer analyser={analyser} parent={containerRef} copyCanvasRef={copyCanvasRef} />
69 |     </div>
70 | );
71 | };
72 |
--------------------------------------------------------------------------------
/client/src/pages/Conversation/components/UserAudio/UserAudioStats.tsx:
--------------------------------------------------------------------------------
1 | import { FC } from "react";
2 |
3 | type UserAudioStatsProps = {
4 | sentMessagesCount: number;
5 | };
6 |
7 | export const UserAudioStats: FC = ({
8 | sentMessagesCount,
9 | }) => {
10 | return (
11 |
12 |
User Audio Stats
13 |
14 |
15 |
Total messages:
16 |
{sentMessagesCount}
17 |
18 |
19 |
20 | );
21 | };
22 |
--------------------------------------------------------------------------------
/client/src/pages/Conversation/getMimeType.ts:
--------------------------------------------------------------------------------
1 | export const mimeTypeCheck = () => {
2 | const types = [
3 | "audio/ogg",
4 | "audio/wav",
5 | "audio/webm;codecs=opus",
6 | "audio/webm;codecs=pcm",
7 | "audio/webm;codecs=pcm_s16le",
8 | "audio/webm;codecs=pcm_f32le",
9 | "audio/mp3",
10 | "audio/aac",
11 | "audio/mp4",
12 | "audio/webm",
13 | "audio/mpeg",
14 | "video/mp4",
15 | "video/webm;codecs=vp9",
16 | "video/webm;codecs=vp8",
17 | "video/webm",
18 | ];
19 | for (const mime of types) {
20 | console.log(mime, MediaRecorder.isTypeSupported(mime));
21 | }
22 | }
23 |
24 | const getVideoMimeType = () => {
25 | if (!MediaRecorder.isTypeSupported){
26 | return "video/mp4";
27 | }
28 | if (MediaRecorder.isTypeSupported("video/webm")) {
29 | return "video/webm";
30 | }
31 | if (MediaRecorder.isTypeSupported("video/mp4")) {
32 | return "video/mp4";
33 | }
34 | console.log("No supported video mime type found")
35 | return "";
36 | };
37 |
38 | const getAudioMimeType = () => {
39 | if (!MediaRecorder.isTypeSupported){
40 | return "audio/mp4";
41 | }
42 | if (MediaRecorder.isTypeSupported("audio/webm")) {
43 | return "audio/webm";
44 | }
45 | if (MediaRecorder.isTypeSupported("audio/mpeg")) {
46 | return "audio/mpeg";
47 | }
48 | if (MediaRecorder.isTypeSupported("audio/mp4")) {
49 | return "audio/mp4";
50 | }
51 | console.log("No supported audio mime type found")
52 | return "";
53 | }
54 |
55 | export const getMimeType = (type: "audio" | "video") => {
56 | if(type === "audio") {
57 | return getAudioMimeType();
58 | }
59 | return getVideoMimeType();
60 | }
61 |
62 | export const getExtension = (type: "audio" | "video") => {
63 | if(getMimeType(type).includes("mp4")) {
64 | return "mp4";
65 | }
66 | if(getMimeType(type).includes("mpeg")) {
67 | return "mp3";
68 | }
69 | return "webm";
70 | }
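A usage sketch (assumed, not part of this file): the negotiated type is what a MediaRecorder would be constructed with, and getExtension names the resulting file.

declare const stream: MediaStream; // e.g. obtained via getUserMedia
const mimeType = getMimeType("audio");
const recorder = new MediaRecorder(stream, mimeType ? { mimeType } : {});
recorder.start(80); // emit a chunk roughly every 80 ms
const filename = `recording.${getExtension("audio")}`;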
--------------------------------------------------------------------------------
/client/src/pages/Conversation/hooks/audioUtils.ts:
--------------------------------------------------------------------------------
1 | export const clamp = (value: number, min: number, max: number) => {
2 | return Math.min(Math.max(value, min), max);
3 | };
4 |
--------------------------------------------------------------------------------
/client/src/pages/Conversation/hooks/useServerInfo.ts:
--------------------------------------------------------------------------------
1 | import { useCallback, useEffect, useState } from "react";
2 | import { useSocketContext } from "../SocketContext";
3 | import { decodeMessage } from "../../../protocol/encoder";
4 | import { z } from "zod";
5 |
6 | const ServersInfoSchema = z.object({
7 | text_temperature: z.number(),
8 | text_topk: z.number(),
9 | audio_temperature: z.number(),
10 | audio_topk: z.number(),
11 | pad_mult: z.number(),
12 | repetition_penalty_context: z.number(),
13 | repetition_penalty: z.number(),
14 | lm_model_file: z.string(),
15 | instance_name: z.string(),
16 | build_info: z.object({
17 | build_timestamp: z.string(),
18 | build_date: z.string(),
19 | git_branch: z.string(),
20 | git_timestamp: z.string(),
21 | git_date: z.string(),
22 | git_hash: z.string(),
23 | git_describe: z.string(),
24 | rustc_host_triple: z.string(),
25 | rustc_version: z.string(),
26 | cargo_target_triple: z.string(),
27 | }),
28 | });
29 |
30 | const parseInfo = (infos: any) => {
31 | const serverInfo = ServersInfoSchema.safeParse(infos);
32 | if (!serverInfo.success) {
33 | console.error(serverInfo.error);
34 | return null;
35 | }
36 | return serverInfo.data;
37 | };
38 |
39 | type ServerInfo = {
40 | text_temperature: number;
41 | text_topk: number;
42 | audio_temperature: number;
43 | audio_topk: number;
44 | pad_mult: number;
45 | repetition_penalty_context: number;
46 | repetition_penalty: number;
47 | lm_model_file: string;
48 | instance_name: string;
49 | build_info: {
50 | build_timestamp: string;
51 | build_date: string;
52 | git_branch: string;
53 | git_timestamp: string;
54 | git_date: string;
55 | git_hash: string;
56 | git_describe: string;
57 | rustc_host_triple: string;
58 | rustc_version: string;
59 | cargo_target_triple: string;
60 | };
61 | }
62 |
63 | export const useServerInfo = () => {
64 | const [serverInfo, setServerInfo] = useState<ServerInfo | null>(null);
65 | const { socket } = useSocketContext();
66 |
67 | const onSocketMessage = useCallback((e: MessageEvent) => {
68 | const dataArray = new Uint8Array(e.data);
69 | const message = decodeMessage(dataArray);
70 | if (message.type === "metadata") {
71 | const infos = parseInfo(message.data);
72 | if (infos) {
73 | setServerInfo(infos);
74 | console.log("received metadata", infos);
75 | }
76 | }
77 | }, [setServerInfo]);
78 |
79 | useEffect(() => {
80 | const currentSocket = socket;
81 | if (!currentSocket) {
82 | return;
83 | }
84 | setServerInfo(null);
85 | currentSocket.addEventListener("message", onSocketMessage);
86 | return () => {
87 | currentSocket.removeEventListener("message", onSocketMessage);
88 | };
89 | }, [socket]);
90 |
91 | return { serverInfo };
92 | };
93 |
--------------------------------------------------------------------------------
/client/src/pages/Conversation/hooks/useServerText.ts:
--------------------------------------------------------------------------------
1 | import { useCallback, useEffect, useState } from "react";
2 | import { useSocketContext } from "../SocketContext";
3 | import { decodeMessage } from "../../../protocol/encoder";
4 |
5 | export const useServerText = () => {
6 | const [text, setText] = useState<string[]>([]);
7 | const [textColor, setTextColor] = useState<number[]>([]);
8 | const [totalTextMessages, setTotalTextMessages] = useState(0);
9 | const { socket } = useSocketContext();
10 |
11 | const onSocketMessage = useCallback((e: MessageEvent) => {
12 | const dataArray = new Uint8Array(e.data);
13 | const message = decodeMessage(dataArray);
14 | if (message.type === "text") {
15 | setText(text => [...text, message.data]);
16 | setTotalTextMessages(count => count + 1);
17 | } else if (message.type === "coloredtext") {
18 | setText(text => [...text, message.data]);
19 | setTextColor(textColor => [...textColor, message.color]);
20 | setTotalTextMessages(count => count + 1);
21 | }
22 | }, []);
23 |
24 | useEffect(() => {
25 | const currentSocket = socket;
26 | if (!currentSocket) {
27 | return;
28 | }
29 | setText([]);
30 | currentSocket.addEventListener("message", onSocketMessage);
31 | return () => {
32 | currentSocket.removeEventListener("message", onSocketMessage);
33 | };
34 | }, [socket]);
35 |
36 | return { text, textColor, totalTextMessages };
37 | };
38 |
--------------------------------------------------------------------------------
/client/src/pages/Conversation/hooks/useSocket.ts:
--------------------------------------------------------------------------------
1 | import { useState, useEffect, useCallback, useRef } from "react";
2 | import { WSMessage } from "../../../protocol/types";
3 | import { decodeMessage, encodeMessage } from "../../../protocol/encoder";
4 |
5 | export const useSocket = ({
6 | onMessage,
7 | uri,
8 | onDisconnect: onDisconnectProp,
9 | }: {
10 | onMessage?: (message: WSMessage) => void;
11 | uri: string;
12 | onDisconnect?: () => void;
13 | }) => {
14 | const lastMessageTime = useRef<number | null>(null);
15 | const [isConnected, setIsConnected] = useState(false);
16 | const [socket, setSocket] = useState<WebSocket | null>(null);
17 |
18 | const sendMessage = useCallback(
19 | (message: WSMessage) => {
20 | if (!socket || !isConnected) {
21 | console.log("socket not connected");
22 | return;
23 | }
24 | socket.send(encodeMessage(message));
25 | },
26 | [socket, isConnected],
27 | );
28 |
29 | const onConnect = useCallback(() => {
30 | console.log("connected, now waiting for handshake.");
31 | // setIsConnected(true);
32 | }, [setIsConnected]);
33 |
34 | const onDisconnect = useCallback(() => {
35 | console.log("disconnected");
36 | if (onDisconnectProp) {
37 | onDisconnectProp();
38 | }
39 | setIsConnected(false);
40 | }, [onDisconnectProp]);
41 |
42 | const onMessageEvent = useCallback(
43 | (eventData: MessageEvent) => {
44 | lastMessageTime.current = Date.now();
45 | const dataArray = new Uint8Array(eventData.data);
46 | const message = decodeMessage(dataArray);
47 | if (message.type == "handshake") {
48 | console.log("Handshake received, let's rocknroll.");
49 | setIsConnected(true);
50 | }
51 | if (!onMessage) {
52 | return;
53 | }
54 | onMessage(message);
55 | },
56 | [onMessage, setIsConnected],
57 | );
58 |
59 | const start = useCallback(() => {
60 | const ws = new WebSocket(uri);
61 | ws.binaryType = "arraybuffer";
62 | ws.addEventListener("open", onConnect);
63 | ws.addEventListener("close", onDisconnect);
64 | ws.addEventListener("message", onMessageEvent);
65 | setSocket(ws);
66 | console.log("Socket created", ws);
67 | lastMessageTime.current = Date.now();
68 | }, [uri, onMessage, onDisconnectProp]);
69 |
70 | const stop = useCallback(() => {
71 | setIsConnected(false);
72 | if (onDisconnectProp) {
73 | onDisconnectProp();
74 | }
75 | socket?.close();
76 | setSocket(null);
77 | }, [socket]);
78 |
79 | useEffect(() => {
80 | if(!isConnected){
81 | return;
82 | }
83 | let intervalId = setInterval(() => {
84 | if (lastMessageTime.current && Date.now() - lastMessageTime.current > 10000) {
85 | console.log("closing socket due to inactivity", socket);
86 | socket?.close();
87 | onDisconnect();
88 | clearInterval(intervalId);
89 | }
90 | }, 500);
91 |
92 | return () => {
93 | lastMessageTime.current = null;
94 | clearInterval(intervalId);
95 | };
96 | }, [isConnected, socket]);
97 |
98 | return {
99 | isConnected,
100 | socket,
101 | sendMessage,
102 | start,
103 | stop,
104 | };
105 | };
106 |
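A hedged usage sketch (the component and endpoint are illustrative, not from this file): isConnected only flips to true once the server's handshake frame arrives, so sendMessage is safe to call as soon as it reads true.

const SocketDemo = () => {
  const { isConnected, sendMessage, start, stop } = useSocket({
    uri: "wss://localhost:8998/api/chat", // illustrative endpoint
    onMessage: (msg) => {
      if (msg.type === "text") console.log("server:", msg.data);
    },
    onDisconnect: () => console.log("socket closed"),
  });
  // start() opens the socket; once the handshake lands,
  // sendMessage({ type: "control", action: "start" }) and audio chunks may flow.
  return null;
};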
--------------------------------------------------------------------------------
/client/src/pages/Queue/api/client.ts:
--------------------------------------------------------------------------------
1 | import { APIError } from "./errors/api_error";
2 | import { ResponseError } from "./errors/response_error";
3 | import { validateAddUser, validateCheckUser } from "./validators";
4 |
5 | export const getAPIClient = (url:string) => ({
6 | addUser: async (queueId:string) => {
7 | const encodedQueueId = encodeURIComponent(queueId);
8 | const response = await fetch(`${url}/add_user?queue_id=${encodedQueueId}`);
9 | if (!response.ok) {
10 | const errorText = await response.text();
11 | throw new APIError(errorText, response.status);
12 | }
13 | const json = await response.json();
14 | const result = validateAddUser(json);
15 | if(result.success) {
16 | return result.data;
17 | }
18 | console.error(result.error.message);
19 | throw new ResponseError("Failed to validate response");
20 |
21 | },
22 | checkUser: async (sessionId:number, sessionAuthId:string) => {
23 | const encodedSessionAuthId = encodeURIComponent(sessionAuthId);
24 | const encodedSessionId = encodeURIComponent(sessionId);
25 | const response = await fetch(`${url}/check_user?session_id=${encodedSessionId}&session_auth_id=${encodedSessionAuthId}`);
26 | if (!response.ok) {
27 | const errorText = await response.text();
28 | throw new APIError(errorText, response.status);
29 | }
30 | const json = await response.json();
31 | const result = validateCheckUser(json);
32 | if(result.success) {
33 | return result.data;
34 | }
35 | console.error(result.error.message);
36 | throw new ResponseError("Failed to validate response");
37 | },
38 | addFeedback: async ({
39 | workerAuthId,
40 | sessionId,
41 | sessionAuthId,
42 | feedback,
43 | timestamp,
44 | email
45 | }:{
46 | workerAuthId:string;
47 | sessionId:number;
48 | sessionAuthId:string;
49 | feedback:0|1;
50 | timestamp:number;
51 | email:string;
52 |
53 | } ) => {
54 | const encodedWorkerAuthId = encodeURIComponent(workerAuthId);
55 | const encodedSessionAuthId = encodeURIComponent(sessionAuthId);
56 | const encodedSessionId = encodeURIComponent(sessionId);
57 | const encodedFeedback = encodeURIComponent(feedback);
58 | const encodedTimestamp = encodeURIComponent(timestamp);
59 | const encodedEmail = encodeURIComponent(email);
60 | const response = await fetch(`${url}/user_feedback?worker_auth_id=${encodedWorkerAuthId}&session_id=${encodedSessionId}&session_auth_id=${encodedSessionAuthId}&feedback=${encodedFeedback}×tamp=${encodedTimestamp}&email=${encodedEmail}`);
61 | if (!response.ok) {
62 | const errorText = await response.text();
63 | throw new APIError(errorText, response.status);
64 | }
65 | return response.json();
66 | }
67 | });
68 |
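A polling sketch built on this client (the loop shape, queue id, and delay are assumptions, not from this file): addUser enqueues a session, then checkUser is polled until status turns "ready" and a worker address is assigned.

async function waitForWorker() {
  const api = getAPIClient("/api"); // "/api" matches the dev proxy in vite.config.ts
  const { session_id, session_auth_id } = await api.addUser("demo-queue"); // queue id is illustrative
  let state = await api.checkUser(session_id, session_auth_id);
  while (state.status === "wait") {
    await new Promise(resolve => setTimeout(resolve, 1000));
    state = await api.checkUser(session_id, session_auth_id);
  }
  return state.worker_addr;
}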
--------------------------------------------------------------------------------
/client/src/pages/Queue/api/errors/api_error.ts:
--------------------------------------------------------------------------------
1 | export class APIError extends Error {
2 | status:number;
3 |
4 | constructor(message:string, status:number) {
5 | super(message);
6 | this.status = status;
7 | this.name = "APIError";
8 | }
9 | }
10 |
--------------------------------------------------------------------------------
/client/src/pages/Queue/api/errors/response_error.ts:
--------------------------------------------------------------------------------
1 | export class ResponseError extends Error {
2 | constructor(message:string) {
3 | super(message);
4 | this.name = "ResponseError";
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/client/src/pages/Queue/api/validators.ts:
--------------------------------------------------------------------------------
1 | import { z } from "zod"
2 |
3 | export const validateAddUser = (response: unknown) => {
4 | const AddUser = z.object({
5 | session_id: z.number(),
6 | session_auth_id: z.string(),
7 | });
8 | return AddUser.safeParse(response);
9 | };
10 |
11 | export const validateCheckUser = (response: unknown) => {
12 | const CheckUser = z.object({
13 | session_id: z.number(),
14 | // TODO: add more statuses
15 | status: z.enum(['wait', 'ready']),
16 | worker_auth_id: z.string().nullable(),
17 | worker_addr: z.string().nullable(),
18 | current_position: z.string(),
19 | });
20 | return CheckUser.safeParse(response);
21 | }
--------------------------------------------------------------------------------
/client/src/pages/Queue/hooks/useUserEmail.ts:
--------------------------------------------------------------------------------
1 | import { useCallback, useState } from "react";
2 | import {z} from "zod";
3 |
4 | const validateEmail = z.string().email();
5 |
6 | export const useUserEmail = (isBypass: boolean) => {
7 | const [userEmail, setUserEmail] = useState('');
8 | const [error, setError] = useState<string | null>(null);
9 |
10 | const validate = useCallback((email: string) => {
11 | if(isBypass) {
12 | setError(null);
13 | return true;
14 | }
15 | const result = validateEmail.safeParse(email);
16 | if(result.success) {
17 | setError(null);
18 | return true;
19 | }
20 | setError('Invalid email address');
21 | return false;
22 | }, [isBypass, setError]);
23 | return {userEmail, setUserEmail, error, validate};
24 | }
25 |
--------------------------------------------------------------------------------
/client/src/protocol/encoder.ts:
--------------------------------------------------------------------------------
1 | import {
2 | CONTROL_MESSAGE,
3 | CONTROL_MESSAGES_MAP,
4 | MODELS_MAP,
5 | WSMessage,
6 | VERSIONS_MAP,
7 | } from "./types";
8 |
9 | export const encodeMessage = (message: WSMessage): Uint8Array => {
10 | switch (message.type) {
11 | case "handshake":
12 | return new Uint8Array([
13 | 0x00,
14 | VERSIONS_MAP[message.version],
15 | MODELS_MAP[message.model],
16 | ]);
17 | case "audio":
18 | return new Uint8Array([0x01, ...message.data]);
19 | case "text":
20 | return new Uint8Array([0x02, ...new TextEncoder().encode(message.data)]);
21 | case "coloredtext":
22 | return new Uint8Array([0x07, message.color, ...new TextEncoder().encode(message.data)]);
23 | case "control":
24 | return new Uint8Array([0x03, CONTROL_MESSAGES_MAP[message.action]]);
25 | case "metadata":
26 | return new Uint8Array([
27 | 0x04,
28 | ...new TextEncoder().encode(JSON.stringify(message.data)),
29 | ]);
30 | case "error":
31 | return new Uint8Array([0x05, ...new TextEncoder().encode(message.data)]);
32 | case "ping":
33 | return new Uint8Array([0x06]);
34 | }
35 | };
36 |
37 | export const decodeMessage = (data: Uint8Array): WSMessage => {
38 | const type = data[0];
39 | const payload = data.slice(1);
40 | switch (type) {
41 | case 0x00: {
42 | return {
43 | type: "handshake",
44 | version: 0,
45 | model: 0,
46 | };
47 | }
48 | case 0x01:
49 | return {
50 | type: "audio",
51 | data: payload,
52 | };
53 | case 0x02:
54 | return {
55 | type: "text",
56 | data: new TextDecoder().decode(payload),
57 | };
58 | case 0x07:
59 | return {
60 | type: "coloredtext",
61 | color: payload[0],
62 | data: new TextDecoder().decode(payload.slice(1)),
63 | };
64 | case 0x03: {
65 | const action = Object.keys(CONTROL_MESSAGES_MAP).find(
66 | key => CONTROL_MESSAGES_MAP[key as CONTROL_MESSAGE] === payload[0],
67 | ) as CONTROL_MESSAGE | undefined;
68 |
69 | //TODO: log this and don't throw
70 | if (!action) {
71 | throw new Error("Unknown control message");
72 | }
73 | return {
74 | type: "control",
75 | action,
76 | };
77 | }
78 | case 0x04:
79 | return {
80 | type: "metadata",
81 | data: JSON.parse(new TextDecoder().decode(payload)),
82 | }
83 | case 0x05:
84 | return {
85 | type: "error",
86 | data: new TextDecoder().decode(payload),
87 | }
88 | case 0x06:
89 | return {
90 | type: "ping",
91 | }
92 | default: {
93 | console.log(type);
94 | throw new Error("Unknown message type");
95 | }
96 | }
97 | };
98 |
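The wire format is a single tag byte followed by the payload (0x00 handshake, 0x01 audio, 0x02 text, 0x03 control, 0x04 metadata, 0x05 error, 0x06 ping, 0x07 colored text). A round-trip example using the maps defined in types.ts:

const bytes = encodeMessage({ type: "control", action: "endTurn" });
// -> Uint8Array [0x03, 0x01]
const message = decodeMessage(bytes);
// -> { type: "control", action: "endTurn" }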
--------------------------------------------------------------------------------
/client/src/protocol/testMessages.ts:
--------------------------------------------------------------------------------
1 | import { WSMessage } from "./types";
2 |
3 | export const handshakeMessage: WSMessage = {
4 | type: "handshake",
5 | version: 0,
6 | model: 0,
7 | };
8 |
9 | export const audioMessage: WSMessage = {
10 | type: "audio",
11 | data: new Uint8Array(10),
12 | };
13 |
14 | export const textMessage: WSMessage = {
15 | type: "text",
16 | data: "Hello",
17 | };
18 |
19 | export const controlBOSMessage: WSMessage = {
20 | type: "control",
21 | action: "start",
22 | };
23 |
24 | export const controlEOSMessage: WSMessage = {
25 | type: "control",
26 | action: "endTurn",
27 | };
28 |
29 | export const metadataMessage: WSMessage = {
30 | type: "metadata",
31 | data: { key: "value" },
32 | };
33 |
--------------------------------------------------------------------------------
/client/src/protocol/types.ts:
--------------------------------------------------------------------------------
1 | export type MessageType =
2 | | "handshake"
3 | | "audio"
4 | | "text"
5 | | "coloredtext"
6 | | "control"
7 | | "metadata";
8 |
9 | export const VERSIONS_MAP = {
10 | 0: 0b00000000,
11 | } as const;
12 |
13 | export const MODELS_MAP = {
14 | 0: 0b00000000,
15 | } as const;
16 |
17 | export type VERSION = keyof typeof VERSIONS_MAP;
18 |
19 | export type MODEL = keyof typeof MODELS_MAP;
20 |
21 | export type WSMessage =
22 | | {
23 | type: "handshake";
24 | version: VERSION;
25 | model: MODEL;
26 | }
27 | | {
28 | type: "audio";
29 | data: Uint8Array;
30 | }
31 | | {
32 | type: "text";
33 | data: string;
34 | }
35 | | {
36 | type: "coloredtext";
37 | color: number;
38 | data: string;
39 | }
40 | | {
41 | type: "control";
42 | action: CONTROL_MESSAGE;
43 | }
44 | | {
45 | type: "metadata";
46 | data: unknown;
47 | }
48 | | {
49 | type: "error";
50 | data: string;
51 | }
52 | | {
53 | type: "ping";
54 | }
55 |
56 | export const CONTROL_MESSAGES_MAP = {
57 | start: 0b00000000,
58 | endTurn: 0b00000001,
59 | pause: 0b00000010,
60 | restart: 0b00000011,
61 | } as const;
62 |
63 | export type CONTROL_MESSAGE = keyof typeof CONTROL_MESSAGES_MAP;
64 |
--------------------------------------------------------------------------------
/client/tailwind.config.js:
--------------------------------------------------------------------------------
1 | /** @type {import('tailwindcss').Config} */
2 |
3 | export default {
4 | content: ["./src/**/*.{js,jsx,ts,tsx}", "./index.html"],
5 | theme: {
6 | extend: {},
7 | },
8 | plugins: [require('daisyui')],
9 | };
10 |
--------------------------------------------------------------------------------
/client/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "target": "ES2020",
4 | "useDefineForClassFields": true,
5 | "module": "ESNext",
6 | "lib": [
7 | "ES2020",
8 | "DOM",
9 | "DOM.Iterable"
10 | ],
11 | "skipLibCheck": true,
12 | "outDir": "dist",
13 | /* Bundler mode */
14 | "moduleResolution": "bundler",
15 | "allowImportingTsExtensions": true,
16 | "resolveJsonModule": true,
17 | "isolatedModules": true,
18 | "noEmit": true,
19 | "jsx": "react-jsx",
20 | /* Linting */
21 | "strict": true,
22 | "noUnusedLocals": true,
23 | "noUnusedParameters": true,
24 | "noFallthroughCasesInSwitch": true,
25 | "types": [
26 | "vite/client"
27 | ]
28 | },
29 | "include": [
30 | "src"
31 | ]
32 | }
--------------------------------------------------------------------------------
/client/vite.config.ts:
--------------------------------------------------------------------------------
1 | import { ProxyOptions, defineConfig, loadEnv } from "vite";
2 | import topLevelAwait from "vite-plugin-top-level-await";
3 |
4 | export default defineConfig(({ mode }) => {
5 | const env = loadEnv(mode, process.cwd());
6 | const proxyConf: Record<string, ProxyOptions> = env.VITE_QUEUE_API_URL ? {
7 | "/api": {
8 | target: env.VITE_QUEUE_API_URL,
9 | changeOrigin: true,
10 | },
11 | } : {};
12 | return {
13 | server: {
14 | host: "0.0.0.0",
15 | https: {
16 | cert: "./cert.pem",
17 | key: "./key.pem",
18 | },
19 | proxy: {
20 | ...proxyConf,
21 | }
22 | },
23 | plugins: [
24 | topLevelAwait({
25 | // The export name of top-level await promise for each chunk module
26 | promiseExportName: "__tla",
27 | // The function to generate import names of top-level await promise in each chunk module
28 | promiseImportName: i => `__tla_${i}`,
29 | }),
30 | ],
31 | };
32 | });
33 |
--------------------------------------------------------------------------------
/configs/moshi_7b_202409.json:
--------------------------------------------------------------------------------
1 | {
2 | "dim": 4096,
3 | "text_card": 32000,
4 | "existing_text_padding_id": 3,
5 | "n_q": 16,
6 | "dep_q": 8,
7 | "card": 2048,
8 | "num_heads": 32,
9 | "num_layers": 32,
10 | "hidden_scale": 4.125,
11 | "causal": true,
12 | "layer_scale": null,
13 | "context": 3000,
14 | "max_period": 10000,
15 | "gating": "silu",
16 | "norm": "rms_norm_f32",
17 | "positional_embedding": "rope",
18 | "depformer_dim": 1024,
19 | "depformer_dim_feedforward": 4224,
20 | "depformer_num_heads": 16,
21 | "depformer_num_layers": 6,
22 | "depformer_causal": true,
23 | "depformer_layer_scale": null,
24 | "depformer_multi_linear": true,
25 | "depformer_context": 8,
26 | "depformer_max_period": 10000,
27 | "depformer_gating": "silu",
28 | "depformer_pos_emb": "none",
29 | "depformer_weights_per_step": true,
30 | "delays": [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1]
31 | }
32 |
33 |
34 |
--------------------------------------------------------------------------------
/configs/moshi_dev_2b.json:
--------------------------------------------------------------------------------
1 | {
2 | "dim": 2560,
3 | "text_card": 48000,
4 | "existing_text_padding_id": 3,
5 | "n_q": 32,
6 | "dep_q": 16,
7 | "card": 2048,
8 | "num_heads": 20,
9 | "num_layers": 24,
10 | "hidden_scale": 4.125,
11 | "causal": true,
12 | "layer_scale": null,
13 | "context": 3000,
14 | "max_period": 100000,
15 | "gating": "silu",
16 | "norm": "rms_norm_f32",
17 | "positional_embedding": "rope",
18 | "depformer_dim": 1024,
19 | "depformer_dim_feedforward": 4224,
20 | "depformer_num_heads": 16,
21 | "depformer_num_layers": 6,
22 | "depformer_causal": true,
23 | "depformer_layer_scale": null,
24 | "depformer_multi_linear": true,
25 | "depformer_context": 16,
26 | "depformer_max_period": 10000,
27 | "depformer_gating": "silu",
28 | "depformer_pos_emb": "none",
29 | "depformer_weights_per_step": true,
30 | "delays": [0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
31 | "conditioners": {
32 | "description": {
33 | "type": "lut",
34 | "lut": {
35 | "n_bins": 31,
36 | "dim": 16,
37 | "tokenizer": "noop",
38 | "possible_values": ["very_bad", "bad", "neutral", "good", "very_good"]
39 | }
40 | }
41 | },
42 | "fuser": {
43 | "sum": ["description"]
44 | }
45 | }
46 |
--------------------------------------------------------------------------------
/configs/moshi_mlx_2b.json:
--------------------------------------------------------------------------------
1 | {
2 | "dim": 2560,
3 | "text_card": 48000,
4 | "existing_text_padding_id": 3,
5 | "n_q": 32,
6 | "dep_q": 16,
7 | "card": 2048,
8 | "num_heads": 20,
9 | "num_layers": 24,
10 | "hidden_scale": 4.125,
11 | "causal": true,
12 | "layer_scale": null,
13 | "context": 3000,
14 | "max_period": 100000,
15 | "gating": "silu",
16 | "norm": "rms_norm_f32",
17 | "positional_embedding": "rope",
18 | "depformer_dim": 1024,
19 | "depformer_dim_feedforward": 4224,
20 | "depformer_num_heads": 16,
21 | "depformer_num_layers": 6,
22 | "depformer_causal": true,
23 | "depformer_layer_scale": null,
24 | "depformer_multi_linear": true,
25 | "depformer_context": 16,
26 | "depformer_max_period": 10000,
27 | "depformer_gating": "silu",
28 | "depformer_pos_emb": "none",
29 | "depformer_weights_per_step": true,
30 | "delays": [0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
31 | "conditioners": {
32 | "description": {
33 | "type": "lut",
34 | "lut": {
35 | "n_bins": 31,
36 | "dim": 16,
37 | "tokenizer": "noop",
38 | "possible_values": ["very_bad", "bad", "neutral", "good", "very_good"]
39 | }
40 | }
41 | },
42 | "fuser": {
43 | "sum": ["description"]
44 | }
45 | }
46 |
--------------------------------------------------------------------------------
/data/sample_fr_hibiki_crepes.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/data/sample_fr_hibiki_crepes.mp3
--------------------------------------------------------------------------------
/data/sample_fr_hibiki_intro.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/data/sample_fr_hibiki_intro.mp3
--------------------------------------------------------------------------------
/data/sample_fr_hibiki_monologue_otis.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/data/sample_fr_hibiki_monologue_otis.mp3
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: "3.8"
2 |
3 | name: moshi
4 |
5 | services:
6 |
7 | moshi:
8 | build:
9 | context: ./moshi
10 | expose:
11 | - 8998/tcp
12 | restart: unless-stopped
13 | volumes:
14 | - hf-cache:/root/.cache/huggingface
15 | environment:
16 | #- HF_REPO=kyutai/moshika-pytorch-bf16
17 | - HF_REPO=kyutai/moshiko-pytorch-bf16
18 | deploy:
19 | resources:
20 | reservations:
21 | devices:
22 | - driver: nvidia
23 | capabilities: [gpu]
24 | count: all
25 |
26 | tunnel:
27 | image: cloudflare/cloudflared:latest
28 | pull_policy: always
29 | restart: unless-stopped
30 | expose:
31 | - 43337/tcp
32 | environment:
33 | TUNNEL_URL: http://moshi:8998
34 | TUNNEL_METRICS: 0.0.0.0:43337
35 | command: tunnel --no-autoupdate
36 | depends_on:
37 | - moshi
38 |
39 | volumes:
40 | hf-cache:
41 |
--------------------------------------------------------------------------------
/mimi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/mimi.png
--------------------------------------------------------------------------------
/moshi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/moshi.png
--------------------------------------------------------------------------------
/moshi/Dockerfile:
--------------------------------------------------------------------------------
1 | # Use an official Python runtime as a parent image
2 | FROM python:3.10
3 |
4 | # Set the working directory in the container
5 | WORKDIR /app
6 |
7 | # Copy the current directory contents into the container at /app
8 | COPY . /app
9 |
10 | # Install any needed packages specified in requirements.txt
11 | # Assuming you have a requirements.txt file in the moshi directory
12 | RUN pip install --no-cache-dir -r requirements.txt
13 |
14 | # Install Moshi and gradio
15 | RUN pip install --no-cache-dir moshi gradio
16 |
17 | # Expose the port used by the server
18 | EXPOSE 8998
19 |
20 | # Set environment variable for the model (with a default value)
21 | ENV HF_REPO=kyutai/moshiko-pytorch-bf16
22 |
23 | # Run the server when the container launches
24 | CMD python -m moshi.server --gradio-tunnel --hf-repo $HF_REPO
25 |
--------------------------------------------------------------------------------
/moshi/LICENSE:
--------------------------------------------------------------------------------
1 | Permission is hereby granted, free of charge, to any
2 | person obtaining a copy of this software and associated
3 | documentation files (the "Software"), to deal in the
4 | Software without restriction, including without
5 | limitation the rights to use, copy, modify, merge,
6 | publish, distribute, sublicense, and/or sell copies of
7 | the Software, and to permit persons to whom the Software
8 | is furnished to do so, subject to the following
9 | conditions:
10 |
11 | The above copyright notice and this permission notice
12 | shall be included in all copies or substantial portions
13 | of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
16 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
17 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
18 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
19 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
22 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
23 | DEALINGS IN THE SOFTWARE.
24 |
--------------------------------------------------------------------------------
/moshi/LICENSE.audiocraft:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Meta Platforms, Inc. and affiliates.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/moshi/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE*
2 | include *.md
3 | include *.cfg
4 | include requirements.txt
5 | include moshi/py.typed
6 | include tests/assets/*.safetensors
7 |
--------------------------------------------------------------------------------
/moshi/demo_moshi.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "gpuType": "T4"
8 | },
9 | "kernelspec": {
10 | "name": "python3",
11 | "display_name": "Python 3"
12 | },
13 | "language_info": {
14 | "name": "python"
15 | },
16 | "accelerator": "GPU"
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "source": [
22 | "# Moshi: a speech-text foundation model for real time dialogue\n",
23 | "\n",
24 | " [Moshi][moshi] is a speech-text foundation model and **full-duplex** spoken dialogue framework.\n",
25 | " It uses [Mimi][moshi], a state-of-the-art streaming neural audio codec. Mimi processes 24 kHz audio, down to a 12.5 Hz representation\n",
26 | " with a bandwidth of 1.1 kbps, in a fully streaming manner (latency of 80ms, the frame size),\n",
27 | " yet performs better than existing, non-streaming, codecs like\n",
28 | " [SpeechTokenizer](https://github.com/ZhangXInFD/SpeechTokenizer) (50 Hz, 4kbps), or [SemantiCodec](https://github.com/haoheliu/SemantiCodec-inference) (50 Hz, 1.3kbps).\n",
29 | "\n",
30 | " Moshi models **two streams of audio**: one corresponds to Moshi, and the other one to the user.\n",
31 | " At inference, the stream from the user is taken from the audio input,\n",
32 | "and the one for Moshi is sampled from the model's output. Along these two audio streams, Moshi predicts text tokens corresponding to its own speech, its **inner monologue**,\n",
33 | "which greatly improves the quality of its generation. A small Depth Transformer models inter codebook dependencies for a given time step,\n",
34 | "while a large, 7B parameter Temporal Transformer models the temporal dependencies. Moshi achieves a theoretical latency\n",
35 | "of 160ms (80ms for the frame size of Mimi + 80ms of acoustic delay), with a practical overall latency as low as 200ms on an L4 GPU.\n",
36 | "\n",
37 | "\n",
38 | "For more information, checkout our repo\n",
39 | "[[repo]](https://github.com/kyutai-labs/hibiki),\n",
40 | "[[samples]](https://huggingface.co/spaces/kyutai/hibiki-samples), and [[paper]](https://arxiv.org/abs/2410.00037)."
41 | ],
42 | "metadata": {
43 | "id": "iuzciRNiOznZ"
44 | }
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "metadata": {
50 | "id": "zYz-oOcaOwwr"
51 | },
52 | "outputs": [],
53 | "source": [
54 | "!pip install \"git+https://git@github.com/kyutai-labs/moshi.git#egg=moshi&subdirectory=moshi\"\n",
55 | "!pip install gradio"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "source": [
61 | "# You can also run the model through our WebUI for live translation.\n",
62 | "# Click on the gradio.live link!\n",
63 | "! python -m moshi.server --gradio-tunnel --hf-repo kyutai/moshiko-pytorch-q8 --half\n",
64 | "# Or with Moshika's voice\n",
65 | "# ! python -m moshi.server --gradio-tunnel --hf-repo kyutai/moshika-1b-pytorch-bf16 --half | grep -v 'frame handled'"
66 | ],
67 | "metadata": {
68 | "id": "jykmkKO0f7dK"
69 | },
70 | "execution_count": null,
71 | "outputs": []
72 | },
73 | {
74 | "cell_type": "code",
75 | "source": [],
76 | "metadata": {
77 | "id": "ouGqgLmWL0x0"
78 | },
79 | "execution_count": null,
80 | "outputs": []
81 | }
82 | ]
83 | }
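The figures quoted in the notebook's introduction are mutually consistent; a quick check (the 8-codebook by 11-bit split is inferred from the stated 1.1 kbps, not stated in the notebook itself):

const frameRateHz = 12.5;
const frameMs = 1000 / frameRateHz;                       // 80 ms per frame
const kbps = (frameRateHz * 8 * Math.log2(2048)) / 1000;  // 8 codebooks x 11 bits = 1.1 kbps
const theoreticalLatencyMs = frameMs + 80;                // frame size + acoustic delay = 160 ms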
--------------------------------------------------------------------------------
/moshi/moshi/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | """
6 | moshi is the inference codebase for Kyutai audio generation models.
7 |
8 | The code has been adapted from Audiocraft, see LICENSE.audiocraft
9 | Copyright (c) Meta Platforms, Inc. and affiliates.
10 | """
11 |
12 | # flake8: noqa
13 | from . import conditioners
14 | from . import models
15 | from . import modules
16 | from . import quantization
17 | from . import utils
18 |
19 | __version__ = "0.2.5a7"
20 |
--------------------------------------------------------------------------------
/moshi/moshi/conditioners/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # Copyright (c) Kyutai, all rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | """
6 | Modules to help doing generations under some fixed conditions.
7 | """
8 |
9 | from .base import (ConditionType, ConditionAttributes, ConditionFuser, ConditionProvider,
10 | BaseConditioner, TensorCondition, ConditionTensors, dropout_all_conditions)
11 |
--------------------------------------------------------------------------------
/moshi/moshi/conditioners/tensors.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 | from .base import _BaseTensorConditioner, TensorCondition, ConditionType
5 |
6 |
7 | class TensorConditioner(_BaseTensorConditioner[TensorCondition]):
8 | """Does basically nothing.
9 | """
10 |
11 | def prepare(self, tensor: TensorCondition) -> TensorCondition:
12 | device = next(iter(self.parameters())).device
13 | return TensorCondition(tensor.tensor.to(device=device), tensor.mask.to(device=device))
14 |
15 | def _get_condition(self, inputs: TensorCondition) -> ConditionType:
16 | return ConditionType(inputs.tensor, inputs.mask)
17 |
--------------------------------------------------------------------------------
/moshi/moshi/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 | """
5 | Models for the compression model Mimi and the Moshi language model.
6 | """
7 |
8 | # flake8: noqa
9 | from .compression import (
10 | CompressionModel,
11 | MimiModel,
12 | )
13 | from .lm import LMModel, LMGen
14 | from .loaders import get_mimi, get_moshi_lm
15 |
--------------------------------------------------------------------------------
/moshi/moshi/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | # Copyright (c) Meta Platforms, Inc. and affiliates.
6 | # All rights reserved.
7 | #
8 | # This source code is licensed under the license found in the
9 | # LICENSE file in the root directory of this source tree.
10 | """Modules used for building the models."""
11 |
12 | # flake8: noqa
13 | from .conv import (
14 | NormConv1d,
15 | NormConvTranspose1d,
16 | StreamingConv1d,
17 | StreamingConvTranspose1d,
18 | pad_for_conv1d,
19 | pad1d,
20 | unpad1d,
21 | )
22 | from .seanet import SEANetEncoder, SEANetDecoder
23 | from .transformer import StreamingTransformer
24 |
--------------------------------------------------------------------------------
/moshi/moshi/modules/gating.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | from contextlib import ExitStack
6 | import torch
7 | from torch import nn
8 | from torch.nn import functional as F
9 |
10 | from ..utils.compile import torch_compile_lazy, no_compile
11 |
12 |
13 | @torch_compile_lazy
14 | def gating_forward_kernel(
15 | weight_in: torch.Tensor, weight_out: torch.Tensor, activation, x: torch.Tensor
16 | ):
17 | x = F.linear(x, weight_in)
18 | B, T, _ = x.shape
19 | x = x.view(B, T, 2, -1)
20 | x = activation(x[..., 0, :]) * x[..., 1, :]
21 | x = F.linear(x, weight_out)
22 | return x
23 |
24 |
25 | def gating_forward_generic(
26 | linear_in: nn.Module,
27 | linear_out: nn.Module,
28 | activation,
29 | x: torch.Tensor
30 | ):
31 | x = linear_in(x)
32 | B, T, _ = x.shape
33 | x = x.view(B, T, 2, -1)
34 | x = activation(x[..., 0, :]) * x[..., 1, :]
35 | x = linear_out(x)
36 | return x
37 |
38 |
39 | class ActivationGating(nn.Module):
40 | """
41 | Gating FFN layer, using the given activation.
42 | Args:
43 | dim (int): dimension of the input and output of the transformer.
44 | activation (any callable Tensor to Tensor): activation function to use.
45 | **factory_kwargs: other kwargs passed to the linear layer, in particular device and dtype.
46 | """
47 |
48 | _fsdp_final = True
49 |
50 | def __init__(self, dim: int, dim_feedforward: int, activation, quantized: bool = False, **factory_kwargs):
51 | super().__init__()
52 | # We should have 8 d^2 params; instead we will have
53 | # 2 * h * d + h * d = 3 h * d = 8 d^2,
54 | # so h = 8 d / 3, but following Hervé's advice we use 21 / 8 as an approximation.
55 | if dim_feedforward == 4 * dim:
56 | hidden = (21 * dim) // 8
57 | else:
58 | hidden = (2 * dim_feedforward) // 3
59 |
60 | self.linear_in = nn.Linear(dim, 2 * hidden, bias=False, **factory_kwargs)
61 | self.linear_out = nn.Linear(hidden, dim, bias=False, **factory_kwargs)
62 |
63 | # We try to follow the default PyTorch MHA convention, to easily compare results.
64 |
65 | self.activation = activation
66 |
67 | def forward(self, x: torch.Tensor):
68 | if isinstance(self.linear_in, nn.Linear):
69 | assert isinstance(self.linear_out, nn.Linear)
70 | with ExitStack() as stack:
71 | if self.training:
72 | stack.enter_context(no_compile())
73 | return gating_forward_kernel(
74 | self.linear_in.weight, self.linear_out.weight, self.activation, x
75 | )
76 | else:
77 | return gating_forward_generic(
78 | self.linear_in,
79 | self.linear_out,
80 | self.activation,
81 | x
82 | )
83 |
84 |
85 | def _get_activation(name: str):
86 | if name in ["sigmoid", "tanh", "relu"]:
87 | return getattr(torch, name)
88 | elif name in ["leaky_relu", "elu", "gelu", "silu", "mish", "softsign"]:
89 | return getattr(torch.nn.functional, name)
90 | elif name == "identity":
91 | return torch.nn.Identity()
92 | else:
93 | raise ValueError(f"Unknown activation {name}")
94 |
95 |
96 | def _make_gating(
97 | name: str, dim: int, dim_feedforward: int,
98 | **factory_kwargs
99 | ) -> nn.Module:
100 | return ActivationGating(
101 | dim, dim_feedforward, _get_activation(name), **factory_kwargs
102 | )
103 |
104 |
105 | def make_gating(
106 | name: str, dim: int, dim_feedforward: int, **factory_kwargs
107 | ) -> nn.Module:
108 | gating = _make_gating(name, dim, dim_feedforward, **factory_kwargs)
109 | if isinstance(gating.linear_in, nn.Linear):
110 | max_params = 2 * dim * dim_feedforward
111 | params = sum(p.numel() for p in gating.parameters())
112 | assert (
113 | params <= max_params
114 | ), f"{name} gating has {params} params, max is {max_params}"
115 | return gating
116 |
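A minimal usage sketch (not from the repository) of the `make_gating` helper above, assuming the `moshi` package is installed: with `dim_feedforward == 4 * dim` the hidden size becomes `(21 * dim) // 8`, which keeps the parameter count within the `2 * dim * dim_feedforward` bound asserted in `make_gating`.

```python
import torch

from moshi.modules.gating import make_gating

# dim=16, dim_feedforward=64 -> hidden = (21 * 16) // 8 = 42, so the two
# weight matrices hold 16 * 84 + 42 * 16 = 2016 <= 2 * 16 * 64 = 2048 params.
gating = make_gating("silu", dim=16, dim_feedforward=64)
x = torch.randn(2, 10, 16)  # [B, T, dim]
y = gating(x)               # gated SiLU FFN; output keeps the input shape
assert y.shape == x.shape
```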
--------------------------------------------------------------------------------
/moshi/moshi/modules/lora.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def replace_all_linear_with_lora(module, rank: int, scaling: float, device=None, dtype=None):
6 | """ Recursively replace all Linear layers with LoRALinear layers."""
7 | for name, child in module.named_children():
8 | if isinstance(child, nn.Linear):
9 | if device is None:
10 | this_device = child.weight.device
11 | else:
12 | this_device = device
13 | if dtype is None:
14 | this_dtype = child.weight.dtype
15 | else:
16 | this_dtype = dtype
17 | lora = LoRALinear(child.in_features, child.out_features,
18 | rank, scaling, device=this_device, dtype=this_dtype)
19 | lora.frozen_W = child
20 | setattr(module, name, lora)
21 | else:
22 | replace_all_linear_with_lora(child, rank, scaling, device=device, dtype=dtype)
23 |
24 |
25 | def replace_lora_with_linear(module):
26 | """Recursively replace all LoRALinear layers with Linear layers."""
27 | for name, child in module.named_children():
28 | if isinstance(child, LoRALinear):
29 | # Compute merged weights: W' = W + scaling * B @ A
30 | merged_weight = child.frozen_W.weight.data + \
31 | child.scaling * (child.lora_B.weight @ child.lora_A.weight)
32 | # Create a standard Linear layer with the same in/out features
33 | new_linear = nn.Linear(child.frozen_W.in_features,
34 | child.frozen_W.out_features, bias=False,
35 | device=torch.device('meta'),
36 | dtype=merged_weight.dtype)
37 | new_linear.weight = nn.Parameter(
38 | merged_weight, requires_grad=merged_weight.requires_grad) # Transfer merged weights
39 | setattr(module, name, new_linear) # Replace the module
40 | else:
41 | replace_lora_with_linear(child) # Recursively process submodules
42 |
43 |
44 | class LoRALinear(nn.Module):
45 | """
46 | Implementation of:
47 | - LoRA: https://arxiv.org/abs/2106.09685
48 |
49 | Notes:
50 | - Freezing is handled at the network level, not the layer level.
51 | - Scaling factor controls relative importance of LoRA skip
52 | connection versus original frozen weight. General guidance is
53 | to keep it to 2.0 and sweep over learning rate when changing
54 | the rank.
55 | """
56 |
57 | def __init__(
58 | self,
59 | in_features: int,
60 | out_features: int,
61 | rank: int,
62 | scaling: float,
63 | bias: bool = False,
64 | device: torch.device | None = None,
65 | dtype: torch.dtype = torch.bfloat16,
66 | ):
67 | super().__init__()
68 |
69 | self.in_features = in_features
70 | self.out_features = out_features
71 | assert not bias
72 | self.bias = bias
73 | self.rank = rank
74 | self.scaling = scaling
75 |
76 | self.lora_A = nn.Linear(
77 | self.in_features,
78 | self.rank,
79 | bias=self.bias,
80 | device=device,
81 | dtype=dtype,
82 | )
83 | self.lora_B = nn.Linear(
84 | self.rank,
85 | self.out_features,
86 | bias=self.bias,
87 | device=device,
88 | dtype=dtype,
89 | )
90 |
91 | self.frozen_W = nn.Linear(self.in_features,
92 | self.out_features,
93 | bias=self.bias,
94 | device=device,
95 | dtype=dtype)
96 |
97 | self._register_load_state_dict_pre_hook(LoRALinear._load_hook, with_module=True)
98 |
99 | def merge_weight(self):
100 | with torch.no_grad():
101 | down_weight = self.lora_A.weight
102 | up_weight = self.lora_B.weight
103 |
104 | weight = up_weight.mm(down_weight) * self.scaling
105 |
106 | weight += self.frozen_W.weight
107 | return weight
108 |
109 | @staticmethod
110 | def _load_hook(module, state_dict, prefix, *_):
111 | key_name = prefix + "weight"
112 | if key_name in state_dict:
113 | w_ref = state_dict.pop(key_name)
114 | state_dict[prefix + 'frozen_W.weight'] = w_ref
115 |
116 | def forward(self, x: torch.Tensor):
117 | lora = self.lora_B(self.lora_A(x))
118 | return self.frozen_W(x) + lora * self.scaling
119 |
120 | def __repr__(self) -> str:
121 | return "{}Linear(in_features={}, out_features={}, r={})".format(
122 | "LoRA", self.in_features, self.out_features, self.rank)
123 |
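The merge path above can be checked with a small round trip; this is an illustrative sketch (not from the repository), run under `no_grad` so the merged weights stay plain leaf tensors.

```python
import torch
import torch.nn as nn

from moshi.modules.lora import (
    replace_all_linear_with_lora,
    replace_lora_with_linear,
)

net = nn.Sequential(nn.Linear(8, 8, bias=False), nn.Linear(8, 8, bias=False))
replace_all_linear_with_lora(net, rank=2, scaling=2.0, dtype=torch.float32)
x = torch.randn(3, 8)
with torch.no_grad():
    y_lora = net(x)                # frozen_W(x) + scaling * B(A(x)) per layer
    replace_lora_with_linear(net)  # folds W' = W + scaling * B @ A back in
    y_merged = net(x)
assert torch.allclose(y_lora, y_merged, atol=1e-5)
```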
--------------------------------------------------------------------------------
/moshi/moshi/modules/resample.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import typing as tp
6 |
7 | from einops import rearrange
8 | import torch
9 | from torch import nn
10 |
11 | from .conv import StreamingConv1d, StreamingConvTranspose1d
12 |
13 |
14 | class ConvDownsample1d(nn.Module):
15 | """
16 | Downsampling by some integer amount `stride` using convolutions
17 | with a kernel size of twice the stride.
18 | If `causal` is True, the output uses a causal convolution.
19 | """
20 |
21 | def __init__(
22 | self,
23 | stride: int,
24 | dimension: tp.Optional[int] = None,
25 | causal: bool = False,
26 | learnt: bool = False,
27 | channel_wise: bool = False,
28 | ):
29 | super().__init__()
30 | self.learnt = learnt
31 | self.channel_wise = channel_wise
32 | groups = 1
33 | if learnt:
34 | assert dimension is not None, "Dimension required for learnt convolutions."
35 | in_channels = dimension
36 | out_channels = dimension
37 | if channel_wise:
38 | groups = dimension
39 | else:
40 | in_channels = 1
41 | out_channels = 1
42 |
43 | self.conv = StreamingConv1d(
44 | in_channels,
45 | out_channels,
46 | kernel_size=2 * stride,
47 | stride=stride,
48 | causal=causal,
49 | groups=groups,
50 | bias=False,
51 | pad_mode="replicate",
52 | )
53 | if not learnt:
54 | actual_conv = self.conv.conv.conv
55 | actual_conv.weight.requires_grad_(False)
56 | actual_conv.weight.data.fill_(1.0 / (2 * stride))
57 |
58 | def forward(self, x: torch.Tensor):
59 | batch_size = len(x)
60 | if not self.learnt:
61 | x = rearrange(x, "b c t -> (b c) () t")
62 | y = self.conv(x)
63 | if not self.learnt:
64 | y = rearrange(y, "(b c) () t -> b c t", b=batch_size)
65 | return y
66 |
67 |
68 | class ConvTrUpsample1d(nn.Module):
69 | """
70 | Upsample by some integer amount `stride` using transposed convolutions.
71 | """
72 |
73 | def __init__(
74 | self,
75 | stride: int,
76 | dimension: tp.Optional[int] = None,
77 | causal: bool = False,
78 | learnt: bool = False,
79 | channel_wise: bool = False,
80 | ):
81 | super().__init__()
82 | self.learnt = learnt
83 | self.channel_wise = channel_wise
84 | groups = 1
85 | if learnt:
86 | assert dimension is not None, "Dimension required for learnt convolutions."
87 | in_channels = dimension
88 | out_channels = dimension
89 | if channel_wise:
90 | groups = dimension
91 | else:
92 | in_channels = 1
93 | out_channels = 1
94 |
95 | self.convtr = StreamingConvTranspose1d(
96 | in_channels,
97 | out_channels,
98 | kernel_size=2 * stride,
99 | stride=stride,
100 | causal=causal,
101 | groups=groups,
102 | bias=False,
103 | )
104 | if not learnt:
105 | actual_convtr = self.convtr.convtr.convtr
106 | actual_convtr.weight.requires_grad_(False)
107 | actual_convtr.weight.data.fill_(1.0)
108 |
109 | def forward(self, x: torch.Tensor):
110 | batch_size = len(x)
111 | if not self.learnt:
112 | x = rearrange(x, "b c t -> (b c) () t")
113 | y = self.convtr(x)
114 | if not self.learnt:
115 | x_for_normalization = torch.ones_like(x[:1])
116 | normalization = self.convtr(x_for_normalization)
117 | y = y / normalization
118 | y = rearrange(y, "(b c) () t -> b c t", b=batch_size)
119 | return y
120 |
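An illustrative shape check for the two resamplers (not part of the repository): a fixed, non-learnt stride-2 downsample followed by the matching transposed upsample should halve and then restore the time dimension under causal padding.

```python
import torch

from moshi.modules.resample import ConvDownsample1d, ConvTrUpsample1d

down = ConvDownsample1d(stride=2, causal=True)  # frozen averaging kernel
up = ConvTrUpsample1d(stride=2, causal=True)    # frozen ones kernel, normalized

x = torch.randn(1, 4, 16)  # [B, C, T]
y = down(x)
z = up(y)
print(y.shape, z.shape)  # expected: time dimension halved, then restored
```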
--------------------------------------------------------------------------------
/moshi/moshi/modules/rope.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | from torch import nn
6 | import math
7 | import torch
8 | from ..utils.compile import torch_compile_lazy
9 |
10 |
11 | @torch_compile_lazy
12 | def apply_rope(
13 | q: torch.Tensor,
14 | k: torch.Tensor,
15 | offset: torch.Tensor,
16 | max_period: float = 10_000,
17 | time_before_heads: bool = False,
18 | ):
19 | """
20 | Args:
21 | q (torch.Tensor): queries, shape `[B, T, H, D]`.
22 | k (torch.Tensor): keys, shape `[B, T, H, D]`.
23 | offset (torch.Tensor): current offset for each batch item, e.g. when streaming.
24 | max_period (float): maximum period for the cos and sin.
25 | time_before_heads (bool): if True, expected shape is [B, T, H, D], else [B, H, T, D].
26 | """
27 |
28 | if time_before_heads:
29 | B, T, H, D = q.shape
30 | else:
31 | B, H, T, D = q.shape
32 | assert k.shape == q.shape
33 | assert D > 0
34 | assert D % 2 == 0
35 | assert max_period > 0
36 |
37 | ds = torch.arange(D // 2, device=q.device, dtype=torch.float32)
38 | freqs = torch.exp(ds * (-math.log(max_period) * 2 / D))
39 | ts = offset.float().view(-1, 1) + torch.arange(T, device=q.device, dtype=torch.float32)
40 | if time_before_heads:
41 | ts = ts.view(B, -1, 1, 1)
42 | else:
43 | ts = ts.view(B, 1, -1, 1)
44 |
45 | dims = q.shape[:-1]
46 | q = q.view(*dims, D // 2, 2)
47 | k = k.view(*dims, D // 2, 2)
48 |
49 | # convention is `r` suffix is real part, `i` is imaginary.
50 | qr = q[..., 0].float()
51 | qi = q[..., 1].float()
52 |
53 | kr = k[..., 0].float()
54 | ki = k[..., 1].float()
55 |
56 | rotr = torch.cos(freqs * ts)
57 | roti = torch.sin(freqs * ts)
58 | qor = qr * rotr - qi * roti
59 | qoi = qr * roti + qi * rotr
60 |
61 | kor = kr * rotr - ki * roti
62 | koi = kr * roti + ki * rotr
63 |
64 | dtype = q.dtype
65 | qo = torch.stack([qor.to(dtype), qoi.to(dtype)], dim=-1)
66 | ko = torch.stack([kor.to(dtype), koi.to(dtype)], dim=-1)
67 |
68 | return qo.view(*dims, D), ko.view(*dims, D)
69 |
70 |
71 | class RotaryEmbedding(nn.Module):
72 | """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).
73 |
74 | Args:
75 | max_period (float): Maximum period of the rotation frequencies.
76 | """
77 |
78 | def __init__(self, max_period: float = 10000.0):
79 | super().__init__()
80 | self.max_period = max_period
81 |
82 | def forward(
83 | self,
84 | q: torch.Tensor,
85 | k: torch.Tensor,
86 | offset: torch.Tensor,
87 | time_before_heads: bool = False,
88 | ):
89 | """Apply rope rotation to query or key tensor."""
90 | return apply_rope(q, k, offset, self.max_period, time_before_heads)
91 |
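A minimal sketch (not from the repository) of calling RoPE in the default `[B, H, T, D]` layout; `offset` carries one streaming position per batch item, so the two sequences below are rotated as if starting at positions 0 and 16.

```python
import torch

from moshi.modules.rope import RotaryEmbedding

rope = RotaryEmbedding(max_period=10_000)
B, H, T, D = 2, 4, 8, 32
q = torch.randn(B, H, T, D)
k = torch.randn(B, H, T, D)
offset = torch.tensor([0, 16])  # per-batch start positions, e.g. when streaming
q_rot, k_rot = rope(q, k, offset)
assert q_rot.shape == (B, H, T, D) and k_rot.shape == (B, H, T, D)
```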
--------------------------------------------------------------------------------
/moshi/moshi/quantization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | # Copyright (c) Meta Platforms, Inc. and affiliates.
6 | # All rights reserved.
7 | #
8 | # This source code is licensed under the license found in the
9 | # LICENSE file in the root directory of this source tree.
10 | """RVQ."""
11 | # flake8: noqa
12 | from .vq import ResidualVectorQuantizer, SplitResidualVectorQuantizer
13 | from .base import BaseQuantizer, DummyQuantizer, QuantizedResult
14 |
--------------------------------------------------------------------------------
/moshi/moshi/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | # Copyright (c) Meta Platforms, Inc. and affiliates.
6 | # All rights reserved.
7 | #
8 | # This source code is licensed under the license found in the
9 | # LICENSE file in the root directory of this source tree.
10 | """Utilities."""
11 |
--------------------------------------------------------------------------------
/moshi/moshi/utils/autocast.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | # Copyright (c) Meta Platforms, Inc. and affiliates.
6 | # All rights reserved.
7 | #
8 | # This source code is licensed under the license found in the
9 | # LICENSE file in the root directory of this source tree.
10 |
11 | import torch
12 |
13 |
14 | class TorchAutocast:
15 | """TorchAutocast utility class.
16 | Allows you to enable and disable autocast. This is especially useful
17 | when dealing with different architectures and clusters with different
18 | levels of support.
19 |
20 | Args:
21 | enabled (bool): Whether to enable torch.autocast or not.
22 | args: Additional args for torch.autocast.
23 | kwargs: Additional kwargs for torch.autocast
24 | """
25 |
26 | def __init__(self, enabled: bool, *args, **kwargs):
27 | self.autocast = torch.autocast(*args, **kwargs) if enabled else None
28 |
29 | def __enter__(self):
30 | if self.autocast is None:
31 | return
32 | try:
33 | self.autocast.__enter__()
34 | except RuntimeError:
35 | device = self.autocast.device
36 | dtype = self.autocast.fast_dtype
37 | raise RuntimeError(
38 | f"There was an error autocasting with dtype={dtype} device={device}\n"
39 | "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
40 | )
41 |
42 | def __exit__(self, *args, **kwargs):
43 | if self.autocast is None:
44 | return
45 | self.autocast.__exit__(*args, **kwargs)
46 |
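A usage sketch (illustrative): since a disabled `TorchAutocast` is a no-op, call sites can pass `enabled=` instead of branching on hardware support.

```python
import torch

from moshi.utils.autocast import TorchAutocast

use_cuda = torch.cuda.is_available()
x = torch.randn(4, 4, device="cuda" if use_cuda else "cpu")
with TorchAutocast(enabled=use_cuda, device_type="cuda", dtype=torch.float16):
    y = x @ x  # float16 under autocast on CUDA, plain float32 otherwise
```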
--------------------------------------------------------------------------------
/moshi/moshi/utils/quantize.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | """Quantization based on bitsandbytes, supporting only 8 bits for now.
6 | We are taking some freedom from the intended use of bnb.
7 | """
8 |
9 | import torch
10 | from torch import nn
11 |
12 |
13 | class QLinear(nn.Module):
14 | def __init__(self, linear: nn.Linear):
15 | super().__init__()
16 | from bitsandbytes import functional as bnbF # type: ignore
17 | weight = linear.weight
18 | assert weight.data.dtype.is_floating_point
19 | assert linear.bias is None
20 | CB, SCB, _ = bnbF.int8_vectorwise_quant(weight.data.to(torch.float16)) # type: ignore
21 | self.weight = nn.Parameter(CB, requires_grad=False)
22 | self.weight_scb = nn.Parameter(SCB, requires_grad=False)
23 |
24 | def forward(self, x):
25 | import bitsandbytes as bnb # type: ignore
26 | state = bnb.MatmulLtState()
27 | state.CB = self.weight # type: ignore
28 | assert isinstance(state.CB, torch.Tensor)
29 | state.SCB = self.weight_scb # type: ignore
30 | assert isinstance(state.SCB, torch.Tensor)
31 | if state.SCB.dtype != torch.float:
32 | raise RuntimeError(
33 | "Expected `weight_scb` to have type float, but got bfloat16. "
34 | "When using quantized models, care should be taken not to change the dtype of "
35 | "the model once initialized.")
36 | assert state.SCB.dtype == torch.float, state.SCB.dtype
37 | state.has_fp16_weights = False
38 | y = bnb.matmul(x.half(), state.CB, state=state)
39 | assert isinstance(y, torch.Tensor)
40 | return y
41 |
42 |
43 | def replace_linear_with_qlinear(module):
44 | """Recursively replace all Linear layers with QLinear layers."""
45 | for name, child in module.named_children():
46 | if isinstance(child, nn.Linear):
47 | setattr(module, name, QLinear(child))
48 | elif isinstance(child, QLinear):
49 | # Slight issue with the way we implement things: the scale param
50 | # might get cast to bfloat16 with the rest of the model, although
51 | # we most likely want to keep it as float. For the LM model we might call this function twice:
52 | # first layer by layer to avoid too big a memory usage, and second, at the end
53 | # of the LM init, after all other modules are initialized and properly dtyped.
54 | # In any case that should happen before loading the state dict to avoid a loss of precision.
55 | child.float()
56 | else:
57 | replace_linear_with_qlinear(child)
58 |
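An illustrative forward pass through the int8 path (this needs a CUDA machine with `bitsandbytes` installed; the tiny module is a stand-in for a real model):

```python
import torch
from torch import nn

from moshi.utils.quantize import replace_linear_with_qlinear

model = nn.Sequential(nn.Linear(16, 16, bias=False)).cuda()
replace_linear_with_qlinear(model)  # every Linear becomes an int8 QLinear
x = torch.randn(2, 16, device="cuda")
y = model(x)  # inputs are cast to fp16 and the matmul runs through bnb.matmul
```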
--------------------------------------------------------------------------------
/moshi/moshi/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .compile import torch_compile_lazy
4 |
5 |
6 | @torch_compile_lazy
7 | def cross_entropy(
8 | logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor, dtype=torch.float32,
9 | logits_soft_clip: float | None = None) -> torch.Tensor:
10 | """Compute cross entropy between multi-codebook targets and model's logits.
11 | The cross entropy is computed per codebook to provide codebook-level cross entropy.
12 | Valid timesteps for each of the codebook are pulled from the mask, where invalid
13 | timesteps are set to 0.
14 |
15 | Args:
16 | logits (torch.Tensor): Model's logits of shape [B, K, T, card].
17 | targets (torch.Tensor): Target codes, of shape [B, K, T].
18 | mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
19 | dtype (type): Data type of the output cross entropy.
20 | logits_soft_clip (float): Clipping value for the logits to avoid numerical instability.
21 | Recommended value: 30.0.
22 | Returns:
23 | ce (torch.Tensor): Cross entropy [B, K, T] with type dtype.
24 | """
25 | output_shape = targets.shape
26 | assert logits.shape[:-1] == targets.shape
27 | assert mask.shape == targets.shape
28 | logits = logits.view(-1, logits.shape[-1])
29 | targets = targets.reshape(-1)
30 | mask = mask.reshape(-1)
31 |
32 | safe_targets = torch.where(
33 | mask,
34 | targets,
35 | torch.zeros(1, device=targets.device, dtype=targets.dtype),
36 | )
37 |
38 | # Chunking the conversion to float32 to avoid OOMs.
39 | ce_chunks = []
40 | for logits_chunk, targets_chunk in zip(torch.chunk(logits, 4), torch.chunk(safe_targets, 4)):
41 | logits_chunk = logits_chunk.to(dtype)
42 | if logits_soft_clip is not None:
43 | logits_chunk = logits_soft_clip * torch.tanh(logits_chunk / logits_soft_clip)
44 | log_partition = torch.logsumexp(logits_chunk, dim=-1, keepdim=True)
45 |
46 | # For some reason, the PyTorch cross entropy is super slow with inputs with large cardinality (e.g. 32000)
47 | # so we reimplement the cross entropy ourselves...
48 | ce_chunks.append(log_partition - logits_chunk.gather(-1, targets_chunk[..., None]))
49 | ce = torch.cat(ce_chunks, dim=0)
50 | ce = ce[..., 0]
51 | ce = torch.where(mask, ce, torch.zeros(1, device=ce.device, dtype=ce.dtype))
52 | return ce.view(output_shape)
53 |
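A small sanity check (illustrative) of the masked, per-codebook cross entropy: masked-out timesteps must come back as exact zeros.

```python
import torch

from moshi.utils.utils import cross_entropy

B, K, T, card = 2, 3, 8, 32
logits = torch.randn(B, K, T, card)
targets = torch.randint(card, (B, K, T))
mask = torch.ones(B, K, T, dtype=torch.bool)
mask[:, 0, T // 2:] = False  # mark half of the first codebook's steps invalid
ce = cross_entropy(logits, targets, mask, logits_soft_clip=30.0)
assert ce.shape == (B, K, T)
assert (ce[:, 0, T // 2:] == 0).all()  # invalid steps contribute exactly 0
```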
--------------------------------------------------------------------------------
/moshi/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "moshi"
3 | requires-python = ">= 3.10"
4 | description = "Moshi is moshi"
5 | dependencies = [
6 | "numpy >= 1.26, < 2.3",
7 | "safetensors >= 0.4.0, < 0.6",
8 | "huggingface-hub >= 0.24, < 0.29",
9 | "bitsandbytes >= 0.45, < 0.46; sys_platform == 'linux'",
10 | "einops >= 0.7, < 0.9",
11 | "sentencepiece == 0.2",
12 | "sounddevice == 0.5",
13 | "sphn >= 0.1.4, < 0.2.0",
14 | "torch >= 2.2.0, < 2.7",
15 | "aiohttp>=3.10.5, <3.12",
16 | "pytest >= 8.3.3",
17 | ]
18 | authors = [{name="Laurent Mazaré", email="laurent@kyutai.org"}]
19 | maintainers = [{name="Laurent Mazaré", email="laurent@kyutai.org"}]
20 | license = {text = "MIT"}
21 | dynamic = ["version"]
22 | readme = "README.md"
23 |
24 | [project.scripts]
25 | moshi-server = "moshi.server:main"
26 | moshi-client = "moshi.client:main"
27 |
28 | [tool.setuptools.dynamic]
29 | version = {attr = "moshi.__version__"}
30 |
31 | [build-system]
32 | requires = ["setuptools"]
33 | build-backend = "setuptools.build_meta"
34 |
35 | [project.optional-dependencies]
36 | dev = [
37 | "pyright",
38 | "pytest",
39 | "flake8",
40 | "pre-commit",
41 | "gradio-webrtc>=0.0.18"
42 | ]
43 |
--------------------------------------------------------------------------------
/moshi/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.7.0
2 | safetensors==0.4.4
3 | sentencepiece==0.2.0
4 | sounddevice==0.5.0
5 | soundfile==0.12.1
6 | sphn==0.1.4
7 | torch==2.2.0
8 | numpy==1.26.4
9 | aiohttp>=3.10.5, <3.11
10 | huggingface-hub==0.24.6
11 | pytest==8.3.3
--------------------------------------------------------------------------------
/moshi/setup.cfg:
--------------------------------------------------------------------------------
1 | [pep8]
2 | max-line-length = 120
3 |
4 | [flake8]
5 | max-line-length = 120
6 | ignore = E203,E704
7 | exclude =
8 | dist
9 | build
10 |
11 |
--------------------------------------------------------------------------------
/moshi/tests/assets/test_lm_codes.safetensors:
--------------------------------------------------------------------------------
1 | @ {"codes":{"dtype":"I64","shape":[3,4,7],"data_offsets":[0,672]}}
2 |
3 |
4 |
--------------------------------------------------------------------------------
/moshi/tests/assets/test_lm_model.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/moshi/tests/assets/test_lm_model.safetensors
--------------------------------------------------------------------------------
/moshi/tests/assets/test_lm_out.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/moshi/tests/assets/test_lm_out.safetensors
--------------------------------------------------------------------------------
/moshi/tests/test_lm.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from safetensors.torch import load_file, load_model
3 | import torch
4 |
5 | from moshi.models import lm
6 | from moshi.utils.utils import cross_entropy
7 |
8 |
9 | def _get_assets() -> Path:
10 | return Path(__file__).parent / 'assets'
11 |
12 |
13 | def _get_lm(device=None, dtype=torch.float32) -> lm.LMModel:
14 | torch.manual_seed(1234)
15 | model = lm.LMModel(
16 | delays=[0, 1, 2, 4],
17 | n_q=3,
18 | dep_q=3,
19 | card=32,
20 | text_card=48,
21 | dim=16,
22 | num_layers=2,
23 | num_heads=1,
24 | hidden_scale=1,
25 | depformer_dim=16,
26 | depformer_multi_linear=True,
27 | depformer_weights_per_step=True,
28 | depformer_weights_per_step_schedule=[0, 1, 1],
29 | depformer_low_rank_embeddings=8,
30 | depformer_num_heads=1,
31 | depformer_gating='silu',
32 | context=4,
33 | device=device,
34 | dtype=dtype,
35 | )
36 | return model
37 |
38 |
39 | def test_init():
40 | _get_lm(dtype=torch.float32)
41 | _get_lm(dtype=torch.bfloat16)
42 | _get_lm(dtype=torch.float16)
43 |
44 |
45 | @torch.no_grad
46 | def test_forward():
47 | model = _get_lm()
48 | load_model(model, _get_assets() / 'test_lm_model.safetensors')
49 | codes = load_file(_get_assets() / 'test_lm_codes.safetensors')['codes']
50 | out = model(codes)
51 | assert out.logits is not None
52 | assert out.text_logits is not None
53 | assert out.mask.shape == codes[:, 1:].shape
54 | assert out.text_mask.shape == codes[:, :1].shape
55 | assert out.logits.shape[:-1] == codes[:, 1:].shape
56 | assert out.logits.shape[-1] == model.card
57 | assert out.text_logits.shape[-1] == model.text_card
58 |
59 | ref_out = load_file(_get_assets() / 'test_lm_out.safetensors')
60 | assert (ref_out['mask'] == out.mask).all()
61 | assert (ref_out['text_mask'] == out.text_mask).all()
62 | ce = cross_entropy(out.logits, codes[:, 1:], out.mask)
63 | ce_ref = cross_entropy(ref_out['logits'], codes[:, 1:], out.mask)
64 | delta = (ce.mean(dim=(0, 2)) - ce_ref.mean(dim=(0, 2))).abs() / ce_ref.mean(dim=(0, 2))
65 | assert delta.amax() <= 1e-6, delta.amax()
66 |
67 | ce = cross_entropy(out.text_logits, codes[:, :1], out.text_mask)
68 | ce_ref = cross_entropy(ref_out['text_logits'], codes[:, :1], out.text_mask)
69 | delta = (ce.mean(dim=(0, 2)) - ce_ref.mean(dim=(0, 2))).abs() / ce_ref.mean(dim=(0, 2))
70 | assert delta.amax() <= 1e-6, delta.amax()
71 |
--------------------------------------------------------------------------------
/moshi_mlx/LICENSE:
--------------------------------------------------------------------------------
1 | Permission is hereby granted, free of charge, to any
2 | person obtaining a copy of this software and associated
3 | documentation files (the "Software"), to deal in the
4 | Software without restriction, including without
5 | limitation the rights to use, copy, modify, merge,
6 | publish, distribute, sublicense, and/or sell copies of
7 | the Software, and to permit persons to whom the Software
8 | is furnished to do so, subject to the following
9 | conditions:
10 |
11 | The above copyright notice and this permission notice
12 | shall be included in all copies or substantial portions
13 | of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
16 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
17 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
18 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
19 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
22 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
23 | DEALINGS IN THE SOFTWARE.
24 |
--------------------------------------------------------------------------------
/moshi_mlx/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE*
2 | include *.md
3 | include *.cfg
4 | include requirements.txt
5 | include moshi_mlx/py.typed
6 |
--------------------------------------------------------------------------------
/moshi_mlx/README.md:
--------------------------------------------------------------------------------
1 | # Moshi - MLX
2 |
3 | See the [top-level README.md][main_repo] for more information on Moshi.
4 |
5 | [Moshi][moshi] is a speech-text foundation model and full-duplex spoken dialogue framework.
6 | It uses [Mimi][moshi], a state-of-the-art streaming neural audio codec. Mimi operates at a frame rate of 12.5 Hz, and compresses
7 | 24 kHz audio down to 1.1 kbps, in a fully streaming manner (latency of 80 ms, the frame size), yet performs better than existing, non-streaming, codecs.
8 |
9 | This is the MLX implementation for Moshi. For Mimi, this uses our Rust based implementation through the Python binding provided in `rustymimi`, available in the [rust/](https://github.com/kyutai-labs/moshi/tree/main/rust) folder of our main repository.
10 |
11 | ## Requirements
12 |
13 | You will need at least Python 3.10; we recommend Python 3.12.
14 |
15 | ```bash
16 | pip install moshi_mlx # moshi MLX, from PyPI, best with Python 3.12.
17 | # Or the bleeding edge versions for Moshi and Moshi-MLX.
18 | pip install -e "git+https://git@github.com/kyutai-labs/moshi#egg=moshi_mlx&subdirectory=moshi_mlx"
19 | ```
20 | We have tested the MLX version with MacBook Pro M3.
21 |
22 | If you are not using Python 3.12, you might get an error when installing
23 | `moshi_mlx` or `rustymimi` (which `moshi_mlx` depends on). In that case, you will need to install the [Rust toolchain](https://rustup.rs/), or switch to Python 3.12.
24 |
25 | ## Usage
26 |
27 |
28 | Once you have installed `moshi_mlx`, you can run
29 | ```bash
30 | python -m moshi_mlx.local -q 4 # weights quantized to 4 bits
31 | python -m moshi_mlx.local -q 8 # weights quantized to 8 bits
32 | # And using a different pretrained model:
33 | python -m moshi_mlx.local -q 4 --hf-repo kyutai/moshika-mlx-q4
34 | python -m moshi_mlx.local -q 8 --hf-repo kyutai/moshika-mlx-q8
35 | # Be careful to always match the `-q` and `--hf-repo` flags.
36 | ```
37 |
38 | This uses a command line interface, which is barebones. It does not perform any echo cancellation,
39 | nor does it try to compensate for a growing lag by skipping frames.
40 |
41 | You can use `--hf-repo` to select a different pretrained model, by setting the proper Hugging Face repository.
42 | See [the model list](https://github.com/kyutai-labs/moshi?tab=readme-ov-file#models) for a reference of the available models.
43 |
44 | Alternatively you can use `python -m moshi_mlx.local_web` to use
45 | the web UI; the connection goes over HTTP, at [localhost:8998](http://localhost:8998).
46 |
47 |
48 | ## License
49 |
50 | The present code is provided under the MIT license.
51 |
52 | ## Citation
53 |
54 | If you use either Mimi or Moshi, please cite the following paper,
55 |
56 | ```
57 | @techreport{kyutai2024moshi,
58 | author = {Alexandre D\'efossez and Laurent Mazar\'e and Manu Orsini and Am\'elie Royer and
59 | Patrick P\'erez and Herv\'e J\'egou and Edouard Grave and Neil Zeghidour},
60 | title = {Moshi: a speech-text foundation model for real-time dialogue},
61 | institution = {Kyutai},
62 | year={2024},
63 | month={September},
64 | url={http://kyutai.org/Moshi.pdf},
65 | }
66 | ```
67 |
68 | [moshi]: https://kyutai.org/Moshi.pdf
69 | [main_repo]: https://github.com/kyutai-labs/moshi
70 |
--------------------------------------------------------------------------------
/moshi_mlx/moshi_mlx/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 | # flake8: noqa
5 |
6 | """
7 | moshi_mlx is the MLX inference codebase for Kyutai audio generation models.
8 | """
9 |
10 | from . import modules, models, utils
11 |
12 | __version__ = "0.2.4"
13 |
--------------------------------------------------------------------------------
/moshi_mlx/moshi_mlx/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 | # flake8: noqa
5 | """
6 | Models for Mimi and the Moshi language model, together with their configs.
7 | """
8 |
9 | from .lm import (
10 | Lm,
11 | LmConfig,
12 | config_v0_1,
13 | config1b_202412,
14 | config1b_202412_16rvq,
15 | config_helium_1_preview_2b,
16 | )
17 | from .generate import LmGen
18 | from .mimi import mimi_202407, MimiConfig
19 |
--------------------------------------------------------------------------------
/moshi_mlx/moshi_mlx/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 | # flake8: noqa
5 | """Modules used for building the models."""
6 |
7 | from .conv import Conv1d, ConvTranspose1d, StreamableConv1d, StreamableConvTranspose1d, NormConv1d, NormConvTranspose1d, ConvDownsample1d, ConvTrUpsample1d
8 | from .quantization import SplitResidualVectorQuantizer
9 | from .seanet import SeanetConfig, SeanetEncoder, SeanetDecoder
10 | from .kv_cache import KVCache, RotatingKVCache
11 | from .transformer import Transformer, TransformerConfig, ProjectedTransformer
12 |
--------------------------------------------------------------------------------
/moshi_mlx/moshi_mlx/modules/conditioner.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 | # flake8: noqa
5 | """
6 | Conditioners
7 | """
8 |
9 | from dataclasses import dataclass
10 |
11 | import mlx.core as mx
12 | import mlx.nn as nn
13 |
14 |
15 | @dataclass
16 | class LutConditionerConfig:
17 | n_bins: int
18 | dim: int
19 | tokenizer: str
20 | possible_values: dict[str, int]
21 |
22 |
23 | class LutConditioner(nn.Module):
24 | def __init__(self, output_dim: int, cfg: LutConditionerConfig):
25 | super().__init__()
26 |
27 | if cfg.tokenizer != "noop":
28 | raise ValueError(f"unsupported tokenizer {cfg.tokenizer}")
29 |
30 | self.embed = nn.Embedding(cfg.n_bins + 1, cfg.dim)
31 | self.output_proj = nn.Linear(cfg.dim, output_dim, bias=False)
32 | self.learnt_padding = mx.zeros((1, 1, output_dim))
33 | self.possible_values = { v: i for i, v in enumerate(cfg.possible_values) }
34 |
35 | def condition(self, value: str) -> mx.array:
36 | idx = self.possible_values.get(value, None)
37 | if idx is None:
38 | raise ValueError(f"unknown value {value}, possible-values: {self.possible_values}")
39 | idx = mx.array([idx])
40 | return self.output_proj(self.embed(idx))
41 |
42 | @dataclass
43 | class ConditionTensor:
44 | tensor: mx.array
45 |
46 | class ConditionProvider(nn.Module):
47 | def __init__(self, output_dim: int, cfg: dict[str, LutConditionerConfig]):
48 | super().__init__()
49 | self.conditioners = { name: LutConditioner(output_dim, c) for name, c in cfg.items() }
50 |
51 | def condition_tensor(self, name: str, value: str) -> ConditionTensor:
52 | if name not in self.conditioners:
53 | raise ValueError(f"unsupported conditioner {name}")
54 | tensor = self.conditioners[name].condition(value)
55 | return ConditionTensor(tensor)
56 |
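A hedged usage sketch (the attribute name and values below are made up for illustration): a `LutConditioner` maps a discrete attribute value to a learned embedding projected to the model dimension.

```python
from moshi_mlx.modules.conditioner import LutConditioner, LutConditionerConfig

cfg = LutConditionerConfig(
    n_bins=4,
    dim=8,
    tokenizer="noop",
    possible_values={"very_bad": 0, "bad": 1, "good": 2},
)
cond = LutConditioner(output_dim=16, cfg=cfg)
print(cond.condition("good").shape)  # (1, 16): one row in the model dimension
```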
--------------------------------------------------------------------------------
/moshi_mlx/moshi_mlx/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/moshi_mlx/moshi_mlx/py.typed
--------------------------------------------------------------------------------
/moshi_mlx/moshi_mlx/run_helium.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import argparse
6 | import sentencepiece
7 | import huggingface_hub
8 | import mlx.core as mx
9 | import mlx.nn as nn
10 | from moshi_mlx import models, utils
11 |
12 |
13 | def main():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("--tokenizer", type=str)
16 | parser.add_argument("--weights", type=str)
17 | parser.add_argument("--nsteps", type=int, default=20)
18 | parser.add_argument("--hf-repo", type=str, default="kyutai/helium-1-preview-2b-mlx")
19 | parser.add_argument("--prompt", type=str, default="Aujourd'hui, il est temps")
20 | parser.add_argument("--verbose", action="store_true")
21 | parser.add_argument("--quantize-bits", type=int)
22 | parser.add_argument("--save-quantized", type=str)
23 | parser.add_argument("--quantize-group-size", type=int, default=64)
24 | args = parser.parse_args()
25 |
26 | weights = args.weights
27 | if weights is None:
28 | weights = huggingface_hub.hf_hub_download(
29 | args.hf_repo, "helium-1-preview-2b-bf16.safetensors"
30 | )
31 | tokenizer = args.tokenizer
32 | if tokenizer is None:
33 | tokenizer = huggingface_hub.hf_hub_download(
34 | args.hf_repo, "tokenizer_spm_48k_multi6_2.model"
35 | )
36 |
37 | mx.random.seed(299792458)
38 | lm_config = models.config_helium_1_preview_2b()
39 | model = models.Lm(lm_config)
40 | model.set_dtype(mx.bfloat16)
41 | model.load_weights(weights, strict=True)
42 | if args.quantize_bits is not None:
43 | nn.quantize(model, bits=args.quantize_bits, group_size=args.quantize_group_size)
44 | if args.save_quantized is not None:
45 | print(f"saving quantized weights in {args.save_quantized}")
46 | model.save_weights(args.save_quantized)
47 | sampler = utils.Sampler()
48 | tokenizer = sentencepiece.SentencePieceProcessor(tokenizer) # type: ignore
49 | if args.verbose:
50 | print("prompt", args.prompt)
51 | else:
52 | print(args.prompt, end="", flush=True)
53 | prompt_tokens = tokenizer.encode(args.prompt) # type: ignore
54 | token = mx.array([[1] + prompt_tokens])
55 | for step_idx in range(args.nsteps):
56 | logits = model(token)
57 | token, _ = sampler(logits[:, -1])
58 | text_token = token.item()
59 | _text = tokenizer.id_to_piece(text_token) # type: ignore
60 | _text = _text.replace("▁", " ")
61 | _text = _text.replace("<0x0A>", "\n")
62 | if args.verbose:
63 | print(step_idx, token, _text)
64 | else:
65 | print(_text, end="", flush=True)
66 | token = token[None]
67 | print()
68 |
69 |
70 | if __name__ == "__main__":
71 | main()
72 |
--------------------------------------------------------------------------------
/moshi_mlx/moshi_mlx/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 | # flake8: noqa
5 | """Utilities."""
6 |
7 | from .sampling import Sampler
8 |
--------------------------------------------------------------------------------
/moshi_mlx/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "moshi_mlx"
3 | requires-python = ">= 3.10"
4 | description = "Moshi is moshi, but running on macOS"
5 | dependencies = [
6 | "numpy >= 2.1.0, < 2.3",
7 | "safetensors >= 0.4.0, < 0.6",
8 | "huggingface-hub >= 0.24, < 0.29",
9 | "rustymimi == 0.4.1",
10 | "sentencepiece == 0.2",
11 | "sounddevice == 0.5",
12 | "sphn >= 0.1.4, < 0.2.0",
13 | "mlx >= 0.24.0, < 0.25",
14 | "aiohttp>=3.10.5, <3.12",
15 | ]
16 | authors = [{name="Laurent Mazaré", email="laurent@kyutai.org"}]
17 | maintainers = [{name="Laurent Mazaré", email="laurent@kyutai.org"}]
18 | license = {text = "MIT"}
19 | dynamic = ["version"]
20 | readme = "README.md"
21 |
22 | [project.scripts]
23 | moshi-local = "moshi_mlx.local:main"
24 | moshi-local-web = "moshi_mlx.local_web:main"
25 |
26 | [build-system]
27 | requires = ["setuptools"]
28 | build-backend = "setuptools.build_meta"
29 |
30 | [tool.setuptools.dynamic]
31 | version = {attr = "moshi_mlx.__version__"}
32 |
33 | [project.optional-dependencies]
34 | dev = [
35 | "pyright",
36 | "flake8",
37 | "pre-commit",
38 | ]
39 |
--------------------------------------------------------------------------------
/moshi_mlx/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohttp>=3.10.5, <3.11
2 | blessed==1.20.0
3 | cffi==1.17.0
4 | dashing==0.1.0
5 | huggingface-hub==0.24.6
6 | mlx==0.22.0
7 | nodeenv==1.9.1
8 | numpy==2.1.0
9 | psutil==6.0.0
10 | pycparser==2.22
11 | pyright==1.1.378
12 | rustymimi==0.4.1
13 | safetensors==0.4.4
14 | sentencepiece==0.2.0
15 | six==1.16.0
16 | sounddevice==0.5.0
17 | wcwidth==0.2.13
18 |
--------------------------------------------------------------------------------
/moshi_mlx/setup.cfg:
--------------------------------------------------------------------------------
1 | [pep8]
2 | max-line-length = 120
3 |
4 | [flake8]
5 | max-line-length = 120
6 | ignore = E203,E704
7 | exclude =
8 | dist
9 | build
10 |
11 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | pre-commit>=3.8
2 | pyright>=1.1
3 | flake8>=7.1
4 |
--------------------------------------------------------------------------------
/rust/.github/workflows/rust-ci.yml:
--------------------------------------------------------------------------------
1 | on: [push, pull_request]
2 |
3 | name: Continuous integration
4 |
5 | jobs:
6 | check:
7 | name: Check
8 | defaults:
9 | run:
10 | working-directory: ./rust
11 | runs-on: ${{ matrix.os }}
12 | strategy:
13 | matrix:
14 | os: [ubuntu-latest, windows-latest, macOS-latest]
15 | rust: [stable, nightly]
16 | steps:
17 | - uses: actions/checkout@v2
18 | - uses: actions-rs/toolchain@v1
19 | with:
20 | profile: minimal
21 | toolchain: ${{ matrix.rust }}
22 | override: true
23 | - uses: actions-rs/cargo@v1
24 | with:
25 | command: check
26 |
27 | test:
28 | name: Test Suite
29 | defaults:
30 | run:
31 | working-directory: ./rust
32 | runs-on: ${{ matrix.os }}
33 | strategy:
34 | matrix:
35 | os: [ubuntu-latest, windows-latest, macOS-latest]
36 | rust: [stable, nightly]
37 | steps:
38 | - uses: actions/checkout@v2
39 | - uses: actions-rs/toolchain@v1
40 | with:
41 | profile: minimal
42 | toolchain: ${{ matrix.rust }}
43 | override: true
44 | - uses: actions-rs/cargo@v1
45 | with:
46 | command: test
47 |
48 | fmt:
49 | name: Rustfmt
50 | defaults:
51 | run:
52 | working-directory: ./rust
53 | runs-on: ubuntu-latest
54 | steps:
55 | - uses: actions/checkout@v2
56 | - uses: actions-rs/toolchain@v1
57 | with:
58 | profile: minimal
59 | toolchain: stable
60 | override: true
61 | - run: rustup component add rustfmt
62 | - uses: actions-rs/cargo@v1
63 | with:
64 | command: fmt
65 | args: --all -- --check
66 |
67 | clippy:
68 | name: Clippy
69 | defaults:
70 | run:
71 | working-directory: ./rust
72 | runs-on: ubuntu-latest
73 | steps:
74 | - uses: actions/checkout@v2
75 | - uses: actions-rs/toolchain@v1
76 | with:
77 | profile: minimal
78 | toolchain: stable
79 | override: true
80 | - run: rustup component add clippy
81 | - uses: actions-rs/cargo@v1
82 | with:
83 | command: clippy
84 | args: -- -D warnings
85 |
--------------------------------------------------------------------------------
/rust/Cargo.toml:
--------------------------------------------------------------------------------
1 | [workspace]
2 | members = [
3 | "mimi-pyo3",
4 | "moshi-backend",
5 | "moshi-cli",
6 | "moshi-core",
7 | ]
8 | resolver = "2"
9 |
10 | [workspace.package]
11 | version = "0.6.0-alpha.1"
12 | edition = "2021"
13 | license = "MIT/Apache-2.0"
14 | description = "moshi, a real-time voice AI"
15 | repository = "https://github.com/kyutai-labs/moshi"
16 | keywords = ["machine-learning", "audio"]
17 | categories = ["science"]
18 |
19 | [workspace.dependencies]
20 | anyhow = "1"
21 | axum = { version = "0.7.3", features = ["ws"] }
22 | axum-server = { version = "0.6", features = ["tls-rustls"] }
23 | base64ct = { version = "1.6.0", features = ["alloc"] }
24 | bincode = "1.3.3"
25 | byteorder = "1.5.0"
26 | candle = { git = "https://github.com/huggingface/candle.git", rev = "e3db30021fb08efbe4ee71d840f5d1230d050cd3", package = "candle-core" }
27 | candle-flash-attn = { git = "https://github.com/huggingface/candle.git", rev = "e3db30021fb08efbe4ee71d840f5d1230d050cd3" }
28 | candle-nn = { git = "https://github.com/huggingface/candle.git", rev = "e3db30021fb08efbe4ee71d840f5d1230d050cd3" }
29 | candle-transformers = { git = "https://github.com/huggingface/candle.git", rev = "e3db30021fb08efbe4ee71d840f5d1230d050cd3" }
30 | clap = { version = "4.4.12", features = ["derive"] }
31 | color-eyre = "0.6.2"
32 | cpal = "0.15.3"
33 | crossterm = { version = "0.27.0", features = ["event-stream"] }
34 | env_logger = "0.10.1"
35 | futures = "0.3.28"
36 | futures-util = "0.3.30"
37 | hf-hub = { version = "0.3.2", features = ["tokio"] }
38 | http = "1.1.0"
39 | lazy_static = "1.5.0"
40 | log = "0.4.20"
41 | moshi = { path = "./moshi-core", version = "0.6.0-alpha.1" }
42 | native-tls = "0.2.11"
43 | numpy = "0.23.0"
44 | ogg = { version = "0.9.1", features = ["async"] }
45 | opus = "0.3.0"
46 | pyo3 = "0.23.0"
47 | rand = { version = "0.8.5", features = ["getrandom"] }
48 | rand_chacha = "0.3.1"
49 | ratatui = "0.27.0"
50 | rayon = "1.8.1"
51 | rcgen = "0.13.1"
52 | regex = "1.10.3"
53 | rubato = "0.15.0"
54 | rustls = "0.23.5"
55 | sentencepiece = "0.11.2"
56 | serde = { version = "1.0", features = ["derive"] }
57 | serde_json = "1.0.115"
58 | sha3 = "0.10.8"
59 | symphonia = { version = "0.5.3", features = ["all"] }
60 | tokenizers = "0.15.2"
61 | tokio = { version = "1.35.1", features = ["full"] }
62 | tokio-rustls = "0.24.1"
63 | tokio-tungstenite = { version = "0.21.0", features = ["rustls", "native-tls"] }
64 | toml = "0.8.19"
65 | tower = "0.4.13"
66 | tower-http = { version = "0.5", features = ["full"] }
67 | tracing = "0.1.40"
68 | tracing-appender = "0.2.3"
69 | tracing-chrome = "0.7.2"
70 | tracing-subscriber = "0.3.18"
71 | tui-logger = "0.11.2"
72 | vergen = { version = "8.3.1", features = ["build", "cargo", "git", "gitcl", "rustc", "si"] }
73 |
74 | [profile.release]
75 | debug = true
76 |
77 | [profile.release-no-debug]
78 | inherits = "release"
79 | debug = false
80 |
--------------------------------------------------------------------------------
/rust/README.md:
--------------------------------------------------------------------------------
1 | # moshi - rust
2 |
3 | [](https://crates.io/crates/moshi)
4 | [](https://docs.rs/moshi)
5 | 
6 |
7 | See the [top-level README.md](../README.md) for more information.
8 |
9 | This provides the Rust backend (both Mimi and Moshi) and client implementation.
10 | The Mimi implementation is exposed through Python bindings in the `rustymimi` package.
11 |
12 | ## Requirements
13 |
14 | You will need a recent version of the [Rust toolchain](https://rustup.rs/).
15 | To compile GPU support, you will also need [CUDA](https://developer.nvidia.com/cuda-toolkit) properly installed for your GPU, in particular with `nvcc`.
16 |
17 |
18 | ## Rust based Mimi with Python bindings
19 |
20 | First, a standalone rust based implementation of Mimi is provided, along with Python bindings.
21 | This is the one used by `moshi_mlx`. It is automatically installed with `moshi_mlx`, but you
22 | can install it separately as
23 | ```bash
24 | # Install from pip:
25 | pip install rustymimi
26 | # Alternatively, to compile the package yourself, run from the root of the repo:
27 | maturin dev -r -m rust/mimi-pyo3/Cargo.toml
28 | ```
29 |
30 | ## Rust server
31 |
32 | If you don't have ssl certificates yet, generate a `key.pem` and `cert.pem` file
33 | using the following command.
34 | ```bash
35 | openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes -subj "/CN=localhost"
36 | ```
37 |
38 | In order to run the rust inference server, use the following command from within
39 | this directory:
40 |
41 | ```bash
42 | cargo run --features cuda --bin moshi-backend -r -- --config moshi-backend/config.json standalone
43 | ```
44 |
45 | When using macOS, you can replace `--features cuda` with `--features metal`.
46 |
47 | Alternatively you can use `config-q8.json` rather than `config.json` to use the
48 | quantized q8 model. You can select a different pretrained model, e.g. Moshika,
49 | by changing the `"hf_repo"` key in either file.
50 |
51 | Once the server has printed 'standalone worker listening', you can use the web
52 | UI. By default the rust version uses https so it will be at
53 | [localhost:8998](https://localhost:8998).
54 |
55 | You will get some warnings about the site being unsafe. When using chrome you
56 | can bypass it by selecting "Details" or "Advanced", then "Visit this unsafe
57 | site" or "Proceed to localhost (unsafe)".
58 |
59 | ## Rust client
60 |
61 | We recommend using the web UI as it provides some echo cancellation that helps
62 | the overall model quality. Alternatively we provide some command line interfaces
63 | for the rust and python versions; the protocol is the same as with the web UI so
64 | there is nothing to change on the server side.
65 |
66 | ### Rust Command Line
67 |
68 | From within the `rust` directory, run the following:
69 | ```bash
70 | cargo run --bin moshi-cli -r -- tui --host localhost
71 | ```
72 |
73 | ## License
74 |
75 | The present code is provided under the Apache license.
76 |
--------------------------------------------------------------------------------
/rust/mimi-pyo3/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "mimi-pyo3"
3 | version.workspace = true
4 | edition.workspace = true
5 | description.workspace = true
6 | repository.workspace = true
7 | keywords.workspace = true
8 | categories.workspace = true
9 | license.workspace = true
10 |
11 | [lib]
12 | name = "rustymimi"
13 | crate-type = ["cdylib"]
14 |
15 | [dependencies]
16 | anyhow = { workspace = true }
17 | numpy = { workspace = true }
18 | pyo3 = { workspace = true }
19 | moshi = { workspace = true }
20 |
--------------------------------------------------------------------------------
/rust/mimi-pyo3/py_src/rustymimi/__init__.py:
--------------------------------------------------------------------------------
1 | from .rustymimi import *
2 |
--------------------------------------------------------------------------------
/rust/mimi-pyo3/py_src/rustymimi/__init__.pyi:
--------------------------------------------------------------------------------
1 | # Generated content DO NOT EDIT
2 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
3 | from os import PathLike
4 |
5 | @staticmethod
6 | def write_wav(filename, data, sample_rate):
7 | """
8 | Writes an audio file using the wav format based on pcm data from a numpy array.
9 |
10 | This only supports a single channel at the moment so the input array data is expected to have a
11 | single dimension.
12 | """
13 | pass
14 |
15 | class StreamTokenizer:
16 | def __init__(self, path, *, dtype="f32", max_seq_len=None):
17 | pass
18 |
19 | def decode(self, codes):
20 | """ """
21 | pass
22 |
23 | def encode(self, pcm_data):
24 | """ """
25 | pass
26 |
27 | def get_decoded(self):
28 | """ """
29 | pass
30 |
31 | def get_encoded(self):
32 | """ """
33 | pass
34 |
35 | class Tokenizer:
36 | def __init__(self, path, *, dtype="f32", max_seq_len=None):
37 | pass
38 |
39 | def decode(self, codes):
40 | """ """
41 | pass
42 |
43 | def decode_step(self, codes):
44 | """ """
45 | pass
46 |
47 | def encode(self, pcm_data):
48 | """ """
49 | pass
50 |
51 | def encode_step(self, pcm_data):
52 | """ """
53 | pass
54 |
55 | def reset(self):
56 | """ """
57 | pass
58 |
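A hedged sketch of the streaming loop implied by this stub; the weights path and the array shapes are assumptions, not documented behavior.

```python
import numpy as np
import rustymimi

tokenizer = rustymimi.StreamTokenizer("mimi.safetensors")  # illustrative path
pcm = np.zeros((1, 1, 1920), dtype=np.float32)  # one 80 ms frame at 24 kHz
tokenizer.encode(pcm)
codes = tokenizer.get_encoded()  # may be None until a full frame is buffered
if codes is not None:
    tokenizer.decode(codes)
    pcm_out = tokenizer.get_decoded()
```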
--------------------------------------------------------------------------------
/rust/mimi-pyo3/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["maturin>=1.4,<2.0"]
3 | build-backend = "maturin"
4 |
5 | [project]
6 | name = "rustymimi"
7 | requires-python = ">=3.8"
8 | classifiers = [
9 | "Programming Language :: Rust",
10 | "Programming Language :: Python :: Implementation :: CPython",
11 | "Programming Language :: Python :: Implementation :: PyPy",
12 | ]
13 | dynamic = ["version"]
14 |
15 | [tool.maturin]
16 | python-source = "py_src"
17 | module-name = "rustymimi.rustymimi"
18 | bindings = 'pyo3'
19 | features = ["pyo3/extension-module"]
20 |
--------------------------------------------------------------------------------
/rust/moshi-backend/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "moshi-backend"
3 | version.workspace = true
4 | edition.workspace = true
5 | description.workspace = true
6 | repository.workspace = true
7 | keywords.workspace = true
8 | categories.workspace = true
9 | license.workspace = true
10 |
11 | [dependencies]
12 | anyhow = { workspace = true }
13 | axum = { workspace = true }
14 | axum-server = { workspace = true }
15 | base64ct = { workspace = true }
16 | bincode = { workspace = true }
17 | byteorder = { workspace = true }
18 | candle = { workspace = true }
19 | candle-nn = { workspace = true }
20 | candle-transformers = { workspace = true }
21 | clap = { workspace = true }
22 | env_logger = { workspace = true }
23 | futures-util = { workspace = true }
24 | hf-hub = { workspace = true }
25 | rcgen = { workspace = true }
26 | http = { workspace = true }
27 | lazy_static = { workspace = true }
28 | log = { workspace = true }
29 | moshi = { workspace = true }
30 | ogg = { workspace = true }
31 | opus = { workspace = true }
32 | rand = { workspace = true }
33 | rand_chacha = { workspace = true }
34 | regex = { workspace = true }
35 | rubato = { workspace = true }
36 | sentencepiece = { workspace = true }
37 | serde = { workspace = true }
38 | serde_json = { workspace = true }
39 | sha3 = { workspace = true }
40 | symphonia = { workspace = true }
41 | tokenizers = { workspace = true }
42 | tokio = { workspace = true }
43 | tokio-rustls = { workspace = true }
44 | tower = { workspace = true }
45 | tower-http = { workspace = true }
46 | tracing = { workspace = true }
47 | tracing-appender = { workspace = true }
48 | tracing-chrome = { workspace = true }
49 | tracing-subscriber = { workspace = true }
50 |
51 | [build-dependencies]
52 | anyhow = { workspace = true }
53 | vergen = { workspace = true }
54 |
55 | [features]
56 | default = []
57 | cuda = ["moshi/cuda", "candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
58 | metal = ["moshi/metal", "candle/metal", "candle-nn/metal", "candle-transformers/metal"]
59 |
60 | [profile.release]
61 | debug = true
62 |
63 | [profile.release-no-debug]
64 | inherits = "release"
65 | debug = false
66 |
67 |
--------------------------------------------------------------------------------
/rust/moshi-backend/build.rs:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | use anyhow::Result;
6 | use vergen::EmitBuilder;
7 |
8 | pub fn main() -> Result<()> {
9 | // NOTE: This will output everything, and requires all features enabled.
10 | // NOTE: See the EmitBuilder documentation for configuration options.
11 | EmitBuilder::builder().all_build().all_cargo().all_git().all_rustc().all_sysinfo().emit()?;
12 | Ok(())
13 | }
14 |
--------------------------------------------------------------------------------
/rust/moshi-backend/config-q8.json:
--------------------------------------------------------------------------------
1 | {
2 | "instance_name": "foo",
3 | "hf_repo": "kyutai/moshiko-candle-q8",
4 | "lm_model_file": "$HOME/tmp/moshiko_rs_301e30bf@120/model.q8.gguf",
5 | "text_tokenizer_file": "$HOME/tmp/tokenizer_spm_32k_3.model",
6 | "log_dir": "$HOME/tmp/moshi-logs",
7 | "mimi_model_file": "$HOME/tmp/tokenizer-e351c8d8-checkpoint125.safetensors",
8 | "mimi_num_codebooks": 8,
9 | "static_dir": "../client/dist",
10 | "addr": "0.0.0.0",
11 | "port": 8998,
12 | "cert_dir": "."
13 | }
14 |
15 |
--------------------------------------------------------------------------------
/rust/moshi-backend/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "instance_name": "foo",
3 | "hf_repo": "kyutai/moshiko-candle-bf16",
4 | "lm_model_file": "$HOME/tmp/moshiko_rs_301e30bf@120/model.safetensors",
5 | "text_tokenizer_file": "$HOME/tmp/tokenizer_spm_32k_3.model",
6 | "log_dir": "$HOME/tmp/moshi-logs",
7 | "mimi_model_file": "$HOME/tmp/tokenizer-e351c8d8-checkpoint125.safetensors",
8 | "mimi_num_codebooks": 8,
9 | "static_dir": "../client/dist",
10 | "addr": "0.0.0.0",
11 | "port": 8998,
12 | "cert_dir": "."
13 | }
14 |
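For reference, a minimal sketch of reading this config with serde. The field names mirror the JSON above, but the struct itself is illustrative, not the backend's actual type, which may differ in field types and optionality:

```rust
use serde::Deserialize;

// Illustrative mirror of config.json above; not the backend's actual struct.
#[derive(Debug, Deserialize)]
struct Config {
    instance_name: String,
    hf_repo: String,
    lm_model_file: String,
    text_tokenizer_file: String,
    log_dir: String,
    mimi_model_file: String,
    mimi_num_codebooks: usize,
    static_dir: String,
    addr: String,
    port: u16,
    cert_dir: String,
}

fn main() -> anyhow::Result<()> {
    let raw = std::fs::read_to_string("config.json")?;
    // The "$HOME" references are still literal at this point; the backend
    // expands them separately (see replace_env_vars in src/utils.rs below).
    let config: Config = serde_json::from_str(&raw)?;
    println!("{config:?}");
    Ok(())
}
```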
--------------------------------------------------------------------------------
/rust/moshi-backend/src/build.rs:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | use anyhow::Result;
6 | use vergen::EmitBuilder;
7 |
8 | pub fn main() -> Result<()> {
9 | // NOTE: This will output everything, and requires all features enabled.
10 | // NOTE: See the EmitBuilder documentation for configuration options.
11 | EmitBuilder::builder().all_build().all_cargo().all_git().all_rustc().all_sysinfo().emit()?;
12 | Ok(())
13 | }
14 |
--------------------------------------------------------------------------------
/rust/moshi-backend/src/utils.rs:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | #[derive(Debug, PartialEq, Clone, serde::Deserialize, serde::Serialize)]
6 | pub struct BuildInfo {
7 | build_timestamp: String,
8 | build_date: String,
9 | git_branch: String,
10 | git_timestamp: String,
11 | git_date: String,
12 | git_hash: String,
13 | git_describe: String,
14 | rustc_host_triple: String,
15 | rustc_version: String,
16 | cargo_target_triple: String,
17 | }
18 |
19 | impl BuildInfo {
20 | pub fn new() -> BuildInfo {
21 | BuildInfo {
22 | build_timestamp: String::from(env!("VERGEN_BUILD_TIMESTAMP")),
23 | build_date: String::from(env!("VERGEN_BUILD_DATE")),
24 | git_branch: String::from(env!("VERGEN_GIT_BRANCH")),
25 | git_timestamp: String::from(env!("VERGEN_GIT_COMMIT_TIMESTAMP")),
26 | git_date: String::from(env!("VERGEN_GIT_COMMIT_DATE")),
27 | git_hash: String::from(env!("VERGEN_GIT_SHA")),
28 | git_describe: String::from(env!("VERGEN_GIT_DESCRIBE")),
29 | rustc_host_triple: String::from(env!("VERGEN_RUSTC_HOST_TRIPLE")),
30 | rustc_version: String::from(env!("VERGEN_RUSTC_SEMVER")),
31 | cargo_target_triple: String::from(env!("VERGEN_CARGO_TARGET_TRIPLE")),
32 | }
33 | }
34 | }
35 |
36 | pub struct WrapJson<T>(pub anyhow::Result<T>);
37 |
38 | impl<T: serde::Serialize> axum::response::IntoResponse for WrapJson<T> {
39 | fn into_response(self) -> axum::response::Response {
40 | match self.0 {
41 | Ok(v) => axum::Json(v).into_response(),
42 | Err(err) => {
43 | tracing::error!(?err, "returning internal server error 500");
44 | (axum::http::StatusCode::INTERNAL_SERVER_ERROR, format!("{err}")).into_response()
45 | }
46 | }
47 | }
48 | }
49 |
50 | pub fn replace_env_vars(input: &str) -> String {
51 | let re = regex::Regex::new(r"\$([A-Za-z_][A-Za-z0-9_]*)").unwrap();
52 | re.replace_all(input, |caps: &regex::Captures| {
53 | let var_name = &caps[1];
54 | std::env::var(var_name).unwrap_or_else(|_| "".to_string())
55 | })
56 | .to_string()
57 | }
58 |
59 | pub struct WrapBincode<T>(pub anyhow::Result<T>);
60 |
61 | impl<T: serde::Serialize> axum::response::IntoResponse for WrapBincode<T> {
62 | fn into_response(self) -> axum::response::Response {
63 | match self.0.and_then(|v| Ok(bincode::serialize(&v)?)) {
64 | Ok(v) => (axum::http::StatusCode::OK, v).into_response(),
65 | Err(err) => {
66 | tracing::error!(?err, "returning internal server error 500");
67 | (axum::http::StatusCode::INTERNAL_SERVER_ERROR, format!("{err}")).into_response()
68 | }
69 | }
70 | }
71 | }
72 |
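The `$HOME`-style paths in the configs above are expanded by `replace_env_vars`; unset variables expand to the empty string rather than raising an error. A minimal sketch of the behavior, assuming the function above is in scope:

```rust
fn main() {
    // Hypothetical value for illustration; replace_env_vars reads the
    // real process environment.
    std::env::set_var("HOME", "/home/user");
    std::env::remove_var("UNSET_VAR");
    let expanded = replace_env_vars("$HOME/tmp/moshi-logs/$UNSET_VAR");
    // $HOME is substituted, $UNSET_VAR becomes "".
    assert_eq!(expanded, "/home/user/tmp/moshi-logs/");
}
```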
--------------------------------------------------------------------------------
/rust/moshi-cli/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "moshi-cli"
3 | version.workspace = true
4 | edition.workspace = true
5 | description.workspace = true
6 | repository.workspace = true
7 | keywords.workspace = true
8 | categories.workspace = true
9 | license.workspace = true
10 |
11 | [dependencies]
12 | anyhow = { workspace = true }
13 | byteorder = { workspace = true }
14 | candle = { workspace = true }
15 | candle-nn = { workspace = true }
16 | candle-transformers = { workspace = true }
17 | clap = { workspace = true }
18 | color-eyre = { workspace = true }
19 | cpal = { workspace = true }
20 | crossterm = { workspace = true }
21 | env_logger = { workspace = true }
22 | futures = { workspace = true }
23 | futures-util = { workspace = true }
24 | log = { workspace = true }
25 | moshi = { workspace = true }
26 | native-tls = { workspace = true }
27 | ogg = { workspace = true }
28 | opus = { workspace = true }
29 | rand = { workspace = true }
30 | ratatui = { workspace = true }
31 | rubato = { workspace = true }
32 | rustls = { workspace = true }
33 | sentencepiece = { workspace = true }
34 | serde_json = { workspace = true }
35 | symphonia = { workspace = true }
36 | tokio = { workspace = true }
37 | tokio-tungstenite = { workspace = true }
38 | toml = { workspace = true }
39 | tracing = { workspace = true }
40 | tracing-chrome = { workspace = true }
41 | tracing-subscriber = { workspace = true }
42 | tui-logger = { workspace = true }
43 |
44 | [features]
45 | default = []
46 | cuda = ["moshi/cuda", "candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
47 | metal = ["moshi/metal", "candle/metal", "candle-nn/metal", "candle-transformers/metal"]
48 |
49 | [profile.release]
50 | debug = true
51 |
52 | [profile.release-no-debug]
53 | inherits = "release"
54 | debug = false
55 |
--------------------------------------------------------------------------------
/rust/moshi-cli/src/main.rs:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | use anyhow::Result;
6 | use clap::Parser;
7 |
8 | mod audio_io;
9 | mod gen;
10 | mod multistream;
11 |
12 | use candle::Device;
13 |
14 | #[derive(Debug, Parser)]
15 | struct Args {
16 | #[command(subcommand)]
17 | command: Command,
18 |
19 | /// Enable tracing (generates a trace-timestamp.json file).
20 | #[arg(long)]
21 | tracing: bool,
22 | }
23 |
24 | #[derive(Debug, clap::Subcommand)]
25 | enum Command {
26 | Client {
27 | #[arg(long)]
28 | host: String,
29 |
30 | #[arg(long, default_value_t = 8998)]
31 | port: usize,
32 | },
33 | Tui {
34 | #[arg(long)]
35 | host: String,
36 |
37 | #[arg(long, default_value_t = 8998)]
38 | port: usize,
39 | },
40 | Gen {
41 | #[arg(long)]
42 | lm_model_file: String,
43 |
44 | #[arg(long)]
45 | mimi_model_file: String,
46 |
47 | #[arg(long)]
48 | lm_config_file: String,
49 |
50 | #[arg(long)]
51 | text_tokenizer: String,
52 |
53 | #[arg(long)]
54 | audio_input_file: String,
55 |
56 | #[arg(long)]
57 | audio_output_file: String,
58 |
59 | #[arg(long, default_value_t = 299_792_458)]
60 | seed: u64,
61 |
62 | #[arg(long)]
63 | cfg_alpha: Option<f64>,
64 |
65 | /// Run on cpu
66 | #[arg(long)]
67 | cpu: bool,
68 | },
69 | }
70 |
71 | pub fn device(cpu: bool) -> Result<Device> {
72 | if cpu {
73 | Ok(Device::Cpu)
74 | } else if candle::utils::cuda_is_available() {
75 | Ok(Device::new_cuda(0)?)
76 | } else if candle::utils::metal_is_available() {
77 | Ok(Device::new_metal(0)?)
78 | } else {
79 | Ok(Device::Cpu)
80 | }
81 | }
82 |
83 | #[tokio::main(flavor = "multi_thread", worker_threads = 10)]
84 | async fn main() -> Result<()> {
85 | use tracing_chrome::ChromeLayerBuilder;
86 | use tracing_subscriber::prelude::*;
87 |
88 | let args = Args::parse();
89 | let _guard = if args.tracing {
90 | let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
91 | tracing_subscriber::registry().with(chrome_layer).init();
92 | Some(guard)
93 | } else {
94 | None
95 | };
96 | match args.command {
97 | Command::Client { host, port } => {
98 | tracing_subscriber::fmt::init();
99 | multistream::client::run(host, port).await?
100 | }
101 | Command::Tui { host, port } => {
102 | tracing_subscriber::fmt::init();
103 | multistream::client_tui::run(host, port).await?
104 | }
105 | Command::Gen {
106 | seed,
107 | text_tokenizer,
108 | lm_model_file,
109 | lm_config_file,
110 | mimi_model_file,
111 | audio_input_file,
112 | audio_output_file,
113 | cfg_alpha,
114 | cpu,
115 | } => {
116 | let dev = device(cpu)?;
117 | tracing_subscriber::fmt::init();
118 | let args = gen::Args {
119 | lm_model_file,
120 | mimi_model_file,
121 | text_tokenizer,
122 | lm_config_file,
123 | audio_input_file,
124 | audio_output_file,
125 | seed,
126 | cfg_alpha,
127 | };
128 | gen::run(&args, &dev)?
129 | }
130 | }
131 | Ok(())
132 | }
133 |
--------------------------------------------------------------------------------
/rust/moshi-core/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "moshi"
3 | version.workspace = true
4 | edition.workspace = true
5 | description.workspace = true
6 | repository.workspace = true
7 | keywords.workspace = true
8 | categories.workspace = true
9 | license.workspace = true
10 | readme = "../README.md"
11 |
12 | [dependencies]
13 | candle = { workspace = true }
14 | candle-nn = { workspace = true }
15 | candle-transformers = { workspace = true }
16 | candle-flash-attn = { workspace = true, optional = true }
17 |
18 | rayon = { workspace = true }
19 | serde = { workspace = true }
20 | tracing = { workspace = true }
21 |
22 | [features]
23 | default = []
24 | cuda = ["candle/cuda", "candle-nn/cuda"]
25 | metal = ["candle/metal", "candle-nn/metal"]
26 | flash-attn = ["cuda", "dep:candle-flash-attn"]
27 |
--------------------------------------------------------------------------------
/rust/moshi-core/src/lib.rs:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | pub use candle;
6 | pub use candle_nn;
7 |
8 | pub mod asr;
9 | pub mod batched_transformer;
10 | pub mod conditioner;
11 | pub mod conv;
12 | pub mod kv_cache;
13 | pub mod lm;
14 | pub mod lm_generate;
15 | pub mod lm_generate_multistream;
16 | pub mod mimi;
17 | pub mod nn;
18 | pub mod quantization;
19 | pub mod seanet;
20 | pub mod streaming;
21 | pub mod transformer;
22 | pub mod tts;
23 | pub mod tts_streaming;
24 | pub mod wav;
25 |
26 | #[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
27 | pub enum NormType {
28 | RmsNorm,
29 | LayerNorm,
30 | }
31 |
32 | pub use streaming::{StreamMask, StreamTensor, StreamingModule};
33 |
--------------------------------------------------------------------------------
/rust/moshi-core/src/wav.rs:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | use std::io::prelude::*;
6 |
7 | pub trait Sample {
8 | fn to_i16(&self) -> i16;
9 | }
10 |
11 | impl Sample for f32 {
12 | fn to_i16(&self) -> i16 {
13 | (self.clamp(-1.0, 1.0) * 32767.0) as i16
14 | }
15 | }
16 |
17 | impl Sample for f64 {
18 | fn to_i16(&self) -> i16 {
19 | (self.clamp(-1.0, 1.0) * 32767.0) as i16
20 | }
21 | }
22 |
23 | impl Sample for i16 {
24 | fn to_i16(&self) -> i16 {
25 | *self
26 | }
27 | }
28 |
29 | pub fn write_pcm_as_wav<W: Write, S: Sample>(
30 | w: &mut W,
31 | samples: &[S],
32 | sample_rate: u32,
33 | ) -> std::io::Result<()> {
34 | let len = 12u32; // header
35 | let len = len + 24u32; // fmt
36 | let len = len + samples.len() as u32 * 2 + 8; // data
37 | let n_channels = 1u16;
38 | let bytes_per_second = sample_rate * 2 * n_channels as u32;
39 | w.write_all(b"RIFF")?;
40 | w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes
41 | w.write_all(b"WAVE")?;
42 |
43 | // Format block
44 | w.write_all(b"fmt ")?;
45 | w.write_all(&16u32.to_le_bytes())?; // block len minus 8 bytes
46 | w.write_all(&1u16.to_le_bytes())?; // PCM
47 | w.write_all(&n_channels.to_le_bytes())?; // one channel
48 | w.write_all(&sample_rate.to_le_bytes())?;
49 | w.write_all(&bytes_per_second.to_le_bytes())?;
50 | w.write_all(&2u16.to_le_bytes())?; // 2 bytes of data per sample
51 | w.write_all(&16u16.to_le_bytes())?; // bits per sample
52 |
53 | // Data block
54 | w.write_all(b"data")?;
55 | w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?;
56 | for sample in samples.iter() {
57 | w.write_all(&sample.to_i16().to_le_bytes())?
58 | }
59 | Ok(())
60 | }
61 |
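A brief usage sketch for the writer above: one second of a 440 Hz sine at 24 kHz, written as 16-bit mono PCM. The output file name is arbitrary; the path assumes the `moshi` crate with its public `wav` module, as declared in the Cargo.toml and lib.rs above:

```rust
use std::f32::consts::PI;

fn main() -> std::io::Result<()> {
    let sample_rate = 24_000u32;
    // One second of a 440 Hz sine in [-1, 1]; to_i16 clamps and scales to i16.
    let samples: Vec<f32> = (0..sample_rate)
        .map(|i| (2.0 * PI * 440.0 * i as f32 / sample_rate as f32).sin())
        .collect();
    let mut file = std::fs::File::create("sine.wav")?;
    moshi::wav::write_pcm_as_wav(&mut file, &samples, sample_rate)
}
```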
--------------------------------------------------------------------------------
/rust/protocol.md:
--------------------------------------------------------------------------------
1 | # Protocol
2 |
3 | The connection takes place over a websocket, which handles message framing
4 | and lengths for us. The binary protocol for messages is as follows; all
5 | multi-byte values use little-endian encoding.
6 |
7 | Each message starts with a single byte indicating the message type `MT`.
8 | The format of the rest of the message, i.e. the payload, depends on `MT`.
9 |
10 | ```
11 | - Handshake MT=0. The payload is made of two fields.
12 | 1. Protocol version (`u32`) - always 0 for now.
13 | 2. Model version (`u32`).
14 | - Audio MT=1. The payload is made of a single field.
15 | - Binary data for the ogg frames containing opus encoded audio (24kHz, mono).
16 | - Text MT=2. The payload is made of a single field.
17 | - UTF8 encoded string.
18 | - Control MT=3. The payload is made of a single field. This is not used in full
19 | streaming mode.
20 | - One byte B describing the control itself.
21 | - Start B=0.
22 | - EndTurn B=1.
23 | - Pause B=2.
24 | - Restart B=3.
25 | - MetaData MT=4. The payload is made of a single field.
26 | - UTF8 encoded string with json data.
27 | - Error MT=5. The payload is made of a single field.
28 | - UTF8 encoded string containing the error description.
29 | - Ping MT=6. No payload, this message type is currently unused.
30 | ```
31 | Messages with an unknown message type should be discarded.
32 |
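To make the framing concrete, here is a minimal parsing sketch. It is illustrative only: the enum and function names are not the repository's actual implementation. The websocket layer delivers whole frames, so no length prefix is needed; the first byte is `MT` and the rest is the payload:

```rust
/// Illustrative message type, mirroring the table above.
#[derive(Debug)]
enum Message {
    Handshake { protocol_version: u32, model_version: u32 },
    Audio(Vec<u8>),   // ogg frames with opus-encoded audio, 24kHz mono
    Text(String),     // UTF-8
    Control(u8),      // Start=0, EndTurn=1, Pause=2, Restart=3
    MetaData(String), // UTF-8 JSON
    Error(String),    // UTF-8 error description
    Ping,
}

/// Returns None for truncated frames or unknown message types,
/// which callers should discard.
fn parse(frame: &[u8]) -> Option<Message> {
    let (&mt, payload) = frame.split_first()?;
    let utf8 = |p: &[u8]| String::from_utf8(p.to_vec()).ok();
    match mt {
        0 => Some(Message::Handshake {
            // Two little-endian u32 fields.
            protocol_version: u32::from_le_bytes(payload.get(0..4)?.try_into().ok()?),
            model_version: u32::from_le_bytes(payload.get(4..8)?.try_into().ok()?),
        }),
        1 => Some(Message::Audio(payload.to_vec())),
        2 => utf8(payload).map(Message::Text),
        3 => payload.first().map(|&b| Message::Control(b)),
        4 => utf8(payload).map(Message::MetaData),
        5 => utf8(payload).map(Message::Error),
        6 => Some(Message::Ping),
        _ => None, // unknown message types are discarded
    }
}

fn main() {
    // Handshake: MT=0, protocol version 0, model version 1.
    let frame = [0u8, 0, 0, 0, 0, 1, 0, 0, 0];
    println!("{:?}", parse(&frame));
}
```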
--------------------------------------------------------------------------------
/rust/rustfmt.toml:
--------------------------------------------------------------------------------
1 | use_small_heuristics = "Max"
2 | edition = "2021"
3 |
4 |
--------------------------------------------------------------------------------
/rust/s2st-1b.toml:
--------------------------------------------------------------------------------
1 | text_in_vocab_size = 48001
2 | text_out_vocab_size = 48000
3 | audio_vocab_size = 2049
4 | audio_codebooks = 16
5 |
6 | [transformer]
7 | d_model = 2048
8 | num_heads = 16
9 | num_layers = 16
10 | dim_feedforward = 8192
11 | causal = true
12 | norm_first = true
13 | bias_ff = false
14 | bias_attn = false
15 | context = 3000
16 | max_period = 100000
17 | use_conv_block = false
18 | use_conv_bias = true
19 | gating = "silu"
20 | norm = "RmsNorm"
21 | positional_embedding = "Rope"
22 | conv_layout = false
23 | conv_kernel_size = 3
24 | kv_repeat = 1
25 | max_seq_len = 4096
26 |
27 | [depformer]
28 | num_slices = 8
29 |
30 | [depformer.transformer]
31 | d_model = 1024
32 | num_heads = 16
33 | num_layers = 6
34 | dim_feedforward = 4096
35 | causal = true
36 | norm_first = true
37 | bias_ff = false
38 | bias_attn = false
39 | context = 32
40 | max_period = 10000
41 | use_conv_block = false
42 | use_conv_bias = true
43 | gating = "silu"
44 | norm = "RmsNorm"
45 | positional_embedding = "None"
46 | conv_layout = false
47 | conv_kernel_size = 3
48 | kv_repeat = 1
49 | max_seq_len = 4096
50 |
51 | [conditioners.description]
52 | type = "Lut"
53 | n_bins = 31
54 | dim = 16
55 | possible_values = ["very_bad", "bad", "neutral", "good", "very_good"]
56 |
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/moshi/5d496822e1236211e219f419cc69837335da3a6f/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/export_quantized.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 | """Convert a repo into a quantized one (PyTorch only). Need to run from a GPU."""
5 |
6 |
7 | import argparse
8 | import json
9 | from pathlib import Path
10 | import tempfile
11 |
12 | from huggingface_hub import HfApi, hf_hub_download
13 | from safetensors.torch import save_file
14 |
15 | from moshi.models import loaders
16 | from moshi.utils.quantize import replace_linear_with_qlinear
17 |
18 |
19 | def get_api():
20 | token = input("Write token? ").strip()
21 | api = HfApi(token=token)
22 | return api
23 |
24 |
25 | def main():
26 | parser = argparse.ArgumentParser('export_quantized')
27 | parser.add_argument('hf_repo')
28 | parser.add_argument('new_hf_repo', nargs='?', default=None)
29 |
30 | args = parser.parse_args()
31 | api = get_api()
32 |
33 | repo = args.hf_repo
34 |
35 | print("Downloading base model.")
36 | info = loaders.CheckpointInfo.from_hf_repo(args.hf_repo)
37 | print("Creating model.")
38 | model = info.get_moshi(fuse_lora=True, device='cuda')
39 | print("Quantizing model.")
40 | replace_linear_with_qlinear(model)
41 |
42 | if args.new_hf_repo is None:
43 | new_repo = repo.rsplit('-', 1)[0] + '-q8'
44 | else:
45 | new_repo = args.new_hf_repo
46 | if not api.repo_exists(new_repo):
47 | api.create_repo(new_repo, repo_type='model')
48 | print("Repo created.")
49 |
50 | to_copy = ['README.md']
51 | for file in to_copy:
52 | if not api.file_exists(repo, file):
53 | continue
54 | if not api.file_exists(new_repo, file):
55 | print("File", file, "is missing")
56 | old_file = hf_hub_download(repo, file)
57 | api.upload_file(
58 | path_or_fileobj=old_file,
59 | path_in_repo=file,
60 | repo_id=new_repo,
61 | repo_type="model")
62 | with tempfile.NamedTemporaryFile(suffix='.safetensors', delete=True) as file:
63 | save_file(model.state_dict(), file.name)
64 | size = Path(file.name).stat().st_size / 1e9
65 | print(f"Checkpoint size: {size:.1f}GB")
66 | old_name, old_ext = info.moshi_weights.name.rsplit('.', 1)
67 | new_name = old_name + '.q8.' + old_ext
68 | api.upload_file(
69 | path_or_fileobj=file.name,
70 | path_in_repo=new_name,
71 | repo_id=new_repo,
72 | repo_type="model")
73 | config = json.load(open(hf_hub_download(repo, 'config.json')))
74 | config['moshi_name'] = new_name
75 | config['quantize'] = True
76 | if not config['mimi_name'].startswith('hf://'):
77 | config['mimi_name'] = f'hf://{repo}/{config["mimi_name"]}'
78 | if not config['tokenizer_name'].startswith('hf://'):
79 | config['tokenizer_name'] = f'hf://{repo}/{config["tokenizer_name"]}'
80 | with tempfile.NamedTemporaryFile(mode='w') as file:
81 | json.dump(config, file, indent=2)
82 | file.flush()
83 | api.upload_file(
84 | path_or_fileobj=file.name,
85 | path_in_repo='config.json',
86 | repo_id=new_repo,
87 | repo_type="model")
88 |
89 |
90 | if __name__ == "__main__":
91 | main()
92 |
--------------------------------------------------------------------------------
/scripts/export_torch.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 | """Export a model to HF."""
5 |
6 |
7 | import argparse
8 | import json
9 | import tempfile
10 |
11 | from huggingface_hub import HfApi
12 |
13 | from moshi.models import loaders
14 |
15 |
16 | def get_api():
17 | token = input("Write token? ").strip()
18 | api = HfApi(token=token)
19 | return api
20 |
21 |
22 | def main():
23 | parser = argparse.ArgumentParser('export_torch')
24 | parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
25 | parser.add_argument("--moshi-weight", type=str, help="Path to a local checkpoint file for Moshi.")
26 | parser.add_argument("--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi.")
27 | parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO,
28 | help="HF repo to look into, defaults Moshiko. "
29 | "Use this to select a different pre-trained model.")
30 | parser.add_argument("--config", "--lm-config", dest="config", type=str, help="The config as a json file.")
31 | parser.add_argument('new_hf_repo')
32 |
33 | args = parser.parse_args()
34 | api = get_api()
35 |
36 | info = loaders.CheckpointInfo.from_hf_repo(
37 | args.hf_repo, moshi_weights=args.moshi_weight, mimi_weights=args.mimi_weight,
38 | tokenizer=args.tokenizer, config_path=args.config)
39 |
40 | if not api.repo_exists(args.new_hf_repo):
41 | api.create_repo(args.new_hf_repo, repo_type='model', private=True)
42 | print("Repo created.")
43 |
44 | config = info.raw_config
45 | assert config is not None
46 | config['mimi_name'] = info.mimi_weights.name
47 | config['moshi_name'] = info.moshi_weights.name
48 | config['tokenizer_name'] = info.tokenizer.name
49 | for file in [info.mimi_weights, info.moshi_weights, info.tokenizer]:
50 | if not api.file_exists(args.new_hf_repo, file.name):
51 | print("Uploading file", file)
52 | api.upload_file(
53 | path_or_fileobj=file,
54 | path_in_repo=file.name,
55 | repo_id=args.new_hf_repo,
56 | repo_type="model")
57 | with tempfile.NamedTemporaryFile(mode='w') as file:
58 | json.dump(config, file, indent=2)
59 | file.flush()
60 | api.upload_file(
61 | path_or_fileobj=file.name,
62 | path_in_repo='config.json',
63 | repo_id=args.new_hf_repo,
64 | repo_type="model")
65 |
66 |
67 | if __name__ == "__main__":
68 | main()
69 |
--------------------------------------------------------------------------------
/scripts/import_helium_mlx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import argparse
6 | import torch
7 | from pathlib import Path
8 | from safetensors import safe_open
9 | from safetensors.torch import save_file
10 | from huggingface_hub import hf_hub_download
11 |
12 |
13 | def import_model(in_path: Path, out_path: Path, silent: bool = False) -> None:
14 | with safe_open(in_path, framework="pt", device="cpu") as f:
15 | tensors = {key: f.get_tensor(key) for key in f.keys()}
16 | model = {
17 | "text_emb.weight": tensors["model.embed_tokens.weight"],
18 | "text_linear.weight": tensors["lm_head.weight"],
19 | "out_norm.weight": tensors["model.norm.weight"],
20 | }
21 | n_layers = -1
22 | for key in tensors.keys():
23 | if key.startswith("model.layers."):
24 | layer_idx = int(key.split(".")[2])
25 | n_layers = max(layer_idx, n_layers)
26 | n_layers += 1
27 | if not silent:
28 | print(f"found {n_layers} layers")
29 | for layer_idx in range(n_layers):
30 | dst_prefix = f"transformer.layers.{layer_idx}."
31 | src_prefix = f"model.layers.{layer_idx}."
32 | _model = {
33 | "norm1.weight": "input_layernorm.weight",
34 | "norm2.weight": "post_attention_layernorm.weight",
35 | "self_attn.out_proj.weight": "self_attn.o_proj.weight",
36 | "gating.linear_out.weight": "mlp.down_proj.weight",
37 | }
38 | for dst, src in _model.items():
39 | model[dst_prefix + dst] = tensors[src_prefix + src]
40 | gate_proj = tensors[src_prefix + "mlp.gate_proj.weight"]
41 | up_proj = tensors[src_prefix + "mlp.up_proj.weight"]
42 | linear_in = torch.cat([gate_proj, up_proj], dim=0)
43 | model[dst_prefix + "gating.linear_in.weight"] = linear_in
44 | q = tensors[src_prefix + "self_attn.q_proj.weight"]
45 | k = tensors[src_prefix + "self_attn.k_proj.weight"]
46 | v = tensors[src_prefix + "self_attn.v_proj.weight"]
47 | in_proj = torch.cat([q, k, v], dim=0)
48 | model[dst_prefix + "self_attn.in_proj.weight"] = in_proj
49 |
50 | save_file(model, out_path)
51 |
52 |
53 | def main():
54 | parser = argparse.ArgumentParser()
55 | parser.add_argument(
56 | "--checkpoint",
57 | type=str,
58 | default="kyutai/helium-1-preview-2b",
59 | help="the transformers checkpoint to import",
60 | )
61 | parser.add_argument("--out", type=str, help="the mlx safetensors file to generate")
62 | parser.add_argument(
63 | "-s", "--silent", action="store_true", help="Only prints the checkpoint name"
64 | )
65 | args = parser.parse_args()
66 |
67 | ckpt_path = Path(args.checkpoint)
68 | if not ckpt_path.exists():
69 | ckpt_path = hf_hub_download(
70 | repo_id=args.checkpoint, filename="model.safetensors"
71 | )
72 | out_path = Path(args.out)
73 | if not out_path.exists():
74 | import_model(ckpt_path, out_path, silent=args.silent)
75 | print(out_path)
76 |
77 |
78 | if __name__ == "__main__":
79 | main()
80 |
--------------------------------------------------------------------------------
/scripts/import_lightformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 | """Import a Moshi model, in particular with support for a 'light' depth transformer
5 | with low-rank embeddings and weight sharing for some codebooks."""
6 |
7 | import argparse
8 | from pathlib import Path
9 | from safetensors.torch import save_file
10 |
11 | import omegaconf
12 | import torch
13 |
14 |
15 | def import_model(
16 | in_path: str,
17 | out_path: str,
18 | ) -> None:
19 | pkg = torch.load(in_path, map_location=torch.device("cpu"))
20 | if 'xp.cfg' in pkg:
21 | cfg = pkg['xp.cfg']
22 | else:
23 | cfg = omegaconf.OmegaConf.load(Path(in_path).parent / '.hydra/config.yaml')
24 |
25 | model = pkg["fsdp_best_state"]["model"]
26 |
27 | # Assuming the same n_q (number of codebooks) for both streams.
28 | in_n_q = cfg.compression_model_n_q * 2
29 | out_n_q = cfg.compression_model_n_q
30 | print(f"in_n_q: {in_n_q}, out_n_q: {out_n_q}")
31 | schedule = cfg.transformer_lm.get('depformer_weights_per_step_schedule', None)
32 | if schedule is None:
33 | schedule = list(range(in_n_q))
34 |
35 | num_weights = max(schedule) + 1
36 | schedule = schedule[:out_n_q]
37 | kept_weights = max(schedule) + 1
38 | print(f"Number of dep weights: {num_weights}, keeping {kept_weights}")
39 |
40 | for idx in range(cfg.transformer_lm.depformer_num_layers):
41 | in_proj_key = f"depformer.layers.{idx}.self_attn.in_proj_weight"
42 | in_proj = model[in_proj_key]
43 | in_proj = in_proj.view(num_weights, -1, *in_proj.shape[1:])
44 | model[in_proj_key] = in_proj[:kept_weights].view(-1, *in_proj.shape[2:]).contiguous()
45 | out_proj_key = f"depformer.layers.{idx}.self_attn.out_proj.weight"
46 | out_proj = model[out_proj_key]
47 | out_proj = out_proj.view(num_weights, -1, *out_proj.shape[1:])
48 | model[out_proj_key] = out_proj[:kept_weights].view(-1, *out_proj.shape[2:]).contiguous()
49 |
50 | # For mimi inference, we trim the depformer layers that are unused.
51 | for dep_idx in range(out_n_q - 1, in_n_q - 1):
52 | del model[f"depformer_emb.{dep_idx}.weight"]
53 | if cfg.transformer_lm.get('depformer_low_rank_embeddings'):
54 | del model[f"depformer_emb.{dep_idx}.low_rank.weight"]
55 | for dep_idx in range(out_n_q, in_n_q):
56 | del model[f"linears.{dep_idx}.weight"]
57 | for real_idx in range(kept_weights, num_weights):
58 | model.pop(f"depformer_in.{real_idx}.weight")
59 | for idx in range(cfg.transformer_lm.depformer_num_layers):
60 | model.pop(f"depformer.layers.{idx}.gating.{real_idx}.linear_in.weight")
61 | model.pop(f"depformer.layers.{idx}.gating.{real_idx}.linear_out.weight")
62 |
63 | schedule = schedule[:out_n_q]
64 |
65 | save_file(model, out_path)
66 |
67 |
68 | def main():
69 | parser = argparse.ArgumentParser(
70 | prog="moshi_import", description="Imports moshi checkpoints"
71 | )
72 | parser.add_argument("checkpoint", help="The checkpoint to be imported.")
73 | parser.add_argument("out", help="The safetensors out file.")
74 | args = parser.parse_args()
75 |
76 | out_path = Path(args.out)
77 |
78 | if out_path.exists():
79 | print("file already exists")
80 | else:
81 | import_model(args.checkpoint, out_path)
82 | print(out_path)
83 |
84 |
85 | if __name__ == "__main__":
86 | main()
87 |
--------------------------------------------------------------------------------
/scripts/mimi_mlx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import argparse
6 | from huggingface_hub import hf_hub_download
7 | import numpy as np
8 | import mlx.core as mx
9 | import sphn
10 | import moshi_mlx
11 |
12 |
13 | def run():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("--input", type=str)
16 | parser.add_argument("--model-file", type=str)
17 | parser.add_argument("--hf-repo", type=str, default="kyutai/moshiko-mlx-q4")
18 | parser.add_argument("--streaming", action="store_true")
19 | args = parser.parse_args()
20 |
21 | pcm_in, _ = sphn.read(args.input, sample_rate=24000)
22 | pcm_in = mx.array(pcm_in[0])[None, None]
23 | print(pcm_in.shape)
24 |
25 | if args.model_file is None:
26 | model_file = hf_hub_download(args.hf_repo, "tokenizer-e351c8d8-checkpoint125.safetensors")
27 | else:
28 | model_file = args.model_file
29 | cfg = moshi_mlx.models.mimi.mimi_202407(32)
30 | print("building model", flush=True)
31 | model = moshi_mlx.models.mimi.Mimi(cfg)
32 | print(f"loading weights {model_file}", flush=True)
33 | model.load_pytorch_weights(model_file, strict=True)
34 | print("weights loaded")
35 |
36 | if args.streaming:
37 | chunk_size = 1920
38 | pcm_out = []
39 | len_ = pcm_in.shape[-1]
40 | print("starting streaming conversion")
41 | for start_idx in range(0, len_, chunk_size):
42 | end_idx = start_idx + chunk_size
43 | if end_idx >= len_:
44 | break
45 | _pcm_in = pcm_in[..., start_idx:end_idx]
46 | codes = model.encode_step(_pcm_in)
47 | _pcm_out = model.decode_step(codes)
48 | pcm_out.append(_pcm_out)
49 | pct = int(100 * start_idx / len_)
50 | print(f"{pct}%", end="\r", flush=True)
51 | print()
52 | pcm_out = mx.concat(pcm_out, axis=-1)
53 | else:
54 | codes = model.encode(pcm_in)
55 | print(codes.shape)
56 | pcm_out = model.decode(codes)
57 | print("writing output file with audio shape", pcm_out.shape)
58 | sphn.write_wav("out.wav", np.array(pcm_out[0]), sample_rate=24000)
59 |
60 | if __name__ == "__main__":
61 | run()
62 |
--------------------------------------------------------------------------------
/scripts/mimi_streaming_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import argparse
6 | import random
7 | import time
8 |
9 | from huggingface_hub import hf_hub_download
10 | import numpy as np
11 | import sphn
12 | import torch
13 | from torch.profiler import profile, ProfilerActivity
14 |
15 | from moshi.models import loaders
16 |
17 |
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument("--mimi-weight", type=str)
20 | parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO)
21 | parser.add_argument(
22 | "--device", type=str, default="cuda" if torch.cuda.device_count() else "cpu"
23 | )
24 | parser.add_argument("--profile", action="store_true")
25 | args = parser.parse_args()
26 |
27 |
28 | def seed_all(seed):
29 | torch.manual_seed(seed)
30 | if torch.cuda.is_available():
31 | torch.cuda.manual_seed(seed)
32 | torch.cuda.manual_seed_all(seed) # for multi-GPU setups
33 | random.seed(seed)
34 | np.random.seed(seed)
35 | torch.backends.cudnn.deterministic = True
36 | torch.backends.cudnn.benchmark = False
37 |
38 |
39 | seed_all(42424242)
40 |
41 |
42 | print("loading mimi")
43 | if args.mimi_weight is None:
44 | args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME)
45 | mimi = loaders.get_mimi(args.mimi_weight, args.device)
46 | print("mimi loaded")
47 |
48 |
49 | def mimi_streaming_test(mimi, max_duration_sec=10.0):
50 | pcm_chunk_size = int(mimi.sample_rate / mimi.frame_rate)
51 | # wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
52 | sample_pcm, sample_sr = sphn.read("bria.mp3")
53 | sample_rate = mimi.sample_rate
54 | print("loaded pcm", sample_pcm.shape, sample_sr)
55 | sample_pcm = sphn.resample(
56 | sample_pcm, src_sample_rate=sample_sr, dst_sample_rate=sample_rate
57 | )
58 | sample_pcm = torch.tensor(sample_pcm, device=args.device)
59 | max_duration_len = int(sample_rate * max_duration_sec)
60 | if sample_pcm.shape[-1] > max_duration_len:
61 | sample_pcm = sample_pcm[..., :max_duration_len]
62 | print("resampled pcm", sample_pcm.shape, sample_sr)
63 | sample_pcm = sample_pcm[None].to(device=args.device)
64 |
65 | print("streaming encoding...")
66 | start_time = time.time()
67 | all_codes = []
68 |
69 | def run_loop():
70 | for start_idx in range(0, sample_pcm.shape[-1], pcm_chunk_size):
71 | end_idx = min(sample_pcm.shape[-1], start_idx + pcm_chunk_size)
72 | chunk = sample_pcm[..., start_idx:end_idx]
73 | codes = mimi.encode(chunk)
74 | if codes.shape[-1]:
75 | print(start_idx, codes.shape, end="\r")
76 | all_codes.append(codes)
77 |
78 | if args.profile:
79 | with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
80 | run_loop()
81 | prof.export_chrome_trace("trace.json")
82 | else:
83 | run_loop()
84 | all_codes_th = torch.cat(all_codes, dim=-1)
85 | print(f"codes {all_codes_th.shape} generated in {time.time() - start_time:.2f}s")
86 | print("streaming decoding...")
87 | all_pcms = []
88 | with mimi.streaming(1):
89 | for i in range(all_codes_th.shape[-1]):
90 | codes = all_codes_th[..., i : i + 1]
91 | pcm = mimi.decode(codes)
92 | print(i, pcm.shape, end="\r")
93 | all_pcms.append(pcm)
94 | all_pcms = torch.cat(all_pcms, dim=-1)
95 | print("pcm", all_pcms.shape, all_pcms.dtype)
96 | sphn.write_wav("streaming_out.wav", all_pcms[0, 0].cpu().numpy(), sample_rate)
97 | pcm = mimi.decode(all_codes_th)
98 | print("pcm", pcm.shape, pcm.dtype)
99 | sphn.write_wav("roundtrip_out.wav", pcm[0, 0].cpu().numpy(), sample_rate)
100 |
101 |
102 | with torch.no_grad():
103 | mimi_streaming_test(mimi)
104 |
--------------------------------------------------------------------------------
/scripts/quantize_mlx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import argparse
6 |
7 | import mlx.core as mx
8 | import mlx.nn as nn
9 |
10 | import moshi_mlx
11 |
12 |
13 | def main():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("original_weights", type=str)
16 | parser.add_argument("--out", type=str)
17 | parser.add_argument("--config", type=str, default="v0_1")
18 | parser.add_argument("--bits", type=int, default=8)
19 | parser.add_argument("--group-size", type=int, default=64)
20 | args = parser.parse_args()
21 |
22 | model_file = args.original_weights
23 |
24 | if args.config == "v0_1":
25 | lm_config = moshi_mlx.models.config_v0_1()
26 | elif args.config == "1b":
27 | lm_config = moshi_mlx.models.config1b_202412()
28 | elif args.config == "1b-16rvq":
29 | lm_config = moshi_mlx.models.config1b_202412_16rvq()
30 | elif args.config == "helium-2b":
31 | lm_config = moshi_mlx.models.config_helium_1_preview_2b()
32 | else:
33 | raise ValueError(f"unknown config name '{args.config}'")
34 | print(f"model config:\n{lm_config}")
35 |
36 | model = moshi_mlx.models.Lm(lm_config)
37 | model.set_dtype(mx.bfloat16)
38 | print(f"loading weights {model_file}")
39 | model.load_weights(model_file, strict=True)
40 | print("weights loaded")
41 |
42 | nn.quantize(model, bits=args.bits, group_size=args.group_size)
43 | print(f"saving the quantized q{args.bits} weights in {args.out}")
44 | model.save_weights(args.out)
45 |
46 |
47 | if __name__ == "__main__":
48 | main()
49 |
--------------------------------------------------------------------------------
/scripts/run_ci_when_installed.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # This script is used to detect if moshi or moshi_mlx are installed, and run
4 | # their CI only in that case!
5 |
6 | package=$1
7 | if python -c "from $package import models"; then
8 | # package is installed, let's run the command
9 | eval $2
10 | else
11 | echo "Package $package not installed, skipping the CI for it."
12 | fi
13 |
--------------------------------------------------------------------------------
/scripts/setup.cfg:
--------------------------------------------------------------------------------
1 | [pep8]
2 | max-line-length = 120
3 |
4 | [flake8]
5 | max-line-length = 120
6 | ignore = E203,E704
7 |
--------------------------------------------------------------------------------
/scripts/test_mimi.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import argparse
6 | import numpy as np
7 | import time
8 |
9 | import rustymimi
10 |
11 |
12 | def main():
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--model", type=str)
15 | parser.add_argument("--steps", default=100, type=int)
16 | args = parser.parse_args()
17 |
18 | steps = args.steps
19 | model = rustymimi.Tokenizer(str(args.model))
20 | print(model)
21 |
22 | start_time = 0
23 | for i in range(steps + 1):
24 | if i == 1:
25 | start_time = time.time()
26 | pcm_data = np.array([[[0.0] * 1920]]).astype(np.float32)
27 | out = model.encode_step(pcm_data)
28 | print(out.shape)
29 | pcm_data = model.decode_step(out)
30 | print(pcm_data)
31 | token_per_second = steps / (time.time() - start_time)
32 | print(f"steps: {steps}, token per sec: {token_per_second}")
33 |
34 |
35 | if __name__ == "__main__":
36 | main()
37 |
--------------------------------------------------------------------------------
/scripts/test_missing_data.py:
--------------------------------------------------------------------------------
1 | """Testing exec_mask feature of the streaming module, where each batch entry
2 | can advance at its own pace, while retaining full compat with CUDA Graph."""
3 | import sys
4 | import sphn
5 | import torch
6 |
7 | from moshi.models import loaders
8 |
9 | device = 'cuda'
10 | wnp, sr = sphn.read(sys.argv[1],
11 | start_sec=0, duration_sec=8, sample_rate=24000)
12 | wav = torch.from_numpy(wnp)
13 | ci = loaders.CheckpointInfo.from_hf_repo("kyutai/moshiko-pytorch-bf16")
14 | mimi = ci.get_mimi(device=device)
15 | mimi.eval()
16 |
17 | frame = int(mimi.sample_rate / mimi.frame_rate)
18 |
19 | total = wav.shape[-1] // frame
20 | B = 4
21 | wav = wav[..., :total * frame][None]
22 | remaining = [total for _ in range(B)]
23 |
24 | print("Ref computation")
25 | with torch.no_grad():
26 | ref_codes = mimi.encode(wav[:1].to(device=device))
27 | ref_audio = mimi.decode(ref_codes[:1].to(device=device))
28 |
29 | out_audio = torch.zeros_like(ref_audio.expand(B, -1, -1))
30 | out_codes = torch.zeros_like(ref_codes.expand(B, -1, -1))
31 |
32 | print("Going streaming")
33 | with mimi.streaming(B), torch.no_grad():
34 | while any(remaining):
35 | inputs = []
36 | exec_mask = torch.rand(B) < 0.5
37 | offsets = []
38 | for b, this_remaining in enumerate(remaining):
39 | offset = min(total - this_remaining, total - 1)
40 | offsets.append(offset)
41 | inputs.append(wav[0, :, offset * frame: offset * frame + frame])
42 | input_ = torch.stack(inputs)
43 | mimi.set_exec_mask(exec_mask)
44 | codes = mimi.encode(input_.to(device=device))
45 | assert codes.shape[-1] == 1, codes.shape
46 | w = mimi.decode(codes)
47 | assert w.shape[-1] == frame, w.shape
48 | print(remaining)
49 | for b, active in enumerate(exec_mask.tolist()):
50 | if not active or remaining[b] == 0:
51 | continue
52 | remaining[b] = max(0, remaining[b] - 1)
53 | offset = offsets[b]
54 | out_codes[b, :, offset: offset + 1] = codes[b]
55 | out_audio[b, :, offset * frame: offset * frame + frame] = w[b]
56 |
57 | print(ref_codes[0, :, -1])
58 | print(out_codes[0, :, -1])
59 | for b in range(B):
60 | print((out_codes[b] != ref_codes[:1]).any(dim=0).nonzero()[:1])
61 | d = (out_codes[..., :1] == ref_codes[..., :1]).float().mean(dim=(1, 2))
62 | print(d)
63 | d = (out_codes[..., :2] == ref_codes[..., :2]).float().mean(dim=(1, 2))
64 | print(d)
65 | d = (out_codes[..., :10] == ref_codes[..., :10]).float().mean(dim=(1, 2))
66 | print(d)
67 | d = (out_codes == ref_codes).float().mean(dim=(1, 2))
68 | print(d)
69 | d = (out_audio - ref_audio).norm(dim=-1, p=2) / ref_audio.norm(dim=-1, p=2)
70 | print(d.sum(dim=1))
71 |
--------------------------------------------------------------------------------
/scripts/test_missing_data_lm.py:
--------------------------------------------------------------------------------
1 | """Testing exec_mask feature of the streaming module, where each batch entry
2 | can advance at its own pace, while retaining full compat with CUDA Graph."""
3 | import sys
4 | import sphn
5 | import torch
6 |
7 | from moshi.models import loaders
8 | from moshi.models.lm import LMGen
9 | from moshi.conditioners import ConditionAttributes
10 |
11 | device = 'cuda'
12 | wnp, sr = sphn.read(sys.argv[1],
13 | start_sec=0, duration_sec=8, sample_rate=24000)
14 | wav = torch.from_numpy(wnp)
15 | ci = loaders.CheckpointInfo.from_hf_repo("kyutai/hibiki-2b-pytorch-bf16")
16 | mimi = ci.get_mimi(device=device)
17 | mimi.eval()
18 |
19 | lm = ci.get_moshi(device=device)
20 |
21 | B = 4
22 |
23 | with torch.no_grad():
24 | codes = mimi.encode(wav[:1].to(device=device)[None])
25 |
26 | T = codes.shape[-1]
27 | offsets = [0 for _ in range(B)]
28 |
29 | out_codes: list[list[torch.Tensor]] = [[] for _ in range(B)]
30 | assert lm.condition_provider is not None
31 | conditions = [ConditionAttributes(text={"description": "very_good"}, tensor={})] * B
32 | prepared = lm.condition_provider.prepare(conditions)
33 | condition_tensors = lm.condition_provider(prepared)
34 | lm_gen = LMGen(lm, temp=0., temp_text=0., support_out_of_sync=True, condition_tensors=condition_tensors)
35 | print("Going streaming")
36 | with torch.no_grad(), lm_gen.streaming(B):
37 | while any(o < T for o in offsets):
38 | inputs = []
39 | exec_mask = torch.rand(B) < 0.5
40 | exec_mask[0] = True
41 | for offset in offsets:
42 | inputs.append(codes[:, :, min(offset, T - 1)])
43 | input_ = torch.cat(inputs)[..., None]
44 | lm_gen.set_exec_mask(exec_mask)
45 | pred = lm_gen.step(input_.to(device=device))
46 | assert pred is not None
47 | assert pred.shape[-1] == 1, pred.shape
48 | for b, active in enumerate(exec_mask.tolist()):
49 | if not active or offsets[b] >= T:
50 | continue
51 | if offsets[b] >= 2:
52 | assert (pred[b] >= 0).all()
53 | offsets[b] += 1
54 | out_codes[b].append(pred[b])
55 |
56 | alls = []
57 | for frames in out_codes:
58 | out = torch.cat(frames, -1)
59 | alls.append(out)
60 |
61 | ys = torch.stack(alls)
62 | r = ys[:1]
63 | o = ys[1:]
64 |
65 | ma = (r == o).float().mean(dim=(0, 1))
66 | print(ma)
67 | print(r[0, :2, :5])
68 | print(o[0, :2, :5])
69 | print(o[1, :2, :5])
70 | print(o[2, :2, :5])
71 |
--------------------------------------------------------------------------------
/scripts/update_repo.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Kyutai, all rights reserved.
2 | # This source code is licensed under the license found in the
3 | # LICENSE file in the root directory of this source tree.
4 | """Create or update a HF repo."""
5 |
6 |
7 | import argparse
8 | import json
9 | from pathlib import Path
10 | import tempfile
11 |
12 | from huggingface_hub import HfApi, hf_hub_download
13 | from safetensors.torch import save_file
14 |
15 | from moshi.models import loaders
16 | from moshi.modules.transformer import quantize_transformer
17 |
18 |
19 | def get_api():
20 | token = input("Write token? ").strip()
21 | api = HfApi(token=token)
22 | return api
23 |
24 | def main():
25 | parser = argparse.ArgumentParser('update_repo')
26 | parser.add_argument('repo')
27 | parser.add_argument('-c', '--config')
28 | parser.add_argument('-r', '--readme')
29 | parser.add_argument('-m', '--mimi-weight')
30 | parser.add_argument('-M', '--moshi-weight')
31 | parser.add_argument('-t', '--tokenizer')
32 |
33 | args = parser.parse_args()
34 |
35 | api = get_api()
36 | if not api.repo_exists(args.repo):
37 | api.create_repo(args.repo, repo_type='model')
38 | print(f"Repo {args.repo} created.")
39 |
40 | old_config = None
41 | if api.file_exists(args.repo, 'config.json'):
42 | old_config = json.load(open(hf_hub_download(args.repo, 'config.json')))
43 |
44 | changes = False
45 | if args.config:
46 | changes = True
47 | new_config = json.load(open(args.config))
48 | elif old_config:
49 | new_config = old_config
50 | else:
51 | new_config = {}
52 |
53 | names = ['mimi_name', 'moshi_name', 'tokenizer_name']
54 | paths = [args.mimi_weight, args.moshi_weight, args.tokenizer]
55 | for name, path in zip(names, paths):
56 | if path is None:
57 | if old_config is not None and name in old_config:
58 | new_config[name] = old_config[name]
59 | continue
60 | filename = Path(path).name
61 | print(f"Uploading {path}")
62 | api.upload_file(
63 | path_or_fileobj=path,
64 | path_in_repo=filename,
65 | repo_id=args.repo,
66 | repo_type="model")
67 | new_config[name] = filename
68 | changes = True
69 |
70 | if changes:
71 | with tempfile.NamedTemporaryFile(mode='w') as file:
72 | json.dump(new_config, file, indent=2)
73 | file.flush()
74 | api.upload_file(
75 | path_or_fileobj=file.name,
76 | path_in_repo='config.json',
77 | repo_id=args.repo,
78 | repo_type="model")
79 |
80 |
81 | if __name__ == "__main__":
82 | main()
83 |
--------------------------------------------------------------------------------