├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── MANIFEST.in
├── README.md
├── requirements.txt
├── s3tokenizer
│   ├── __init__.py
│   ├── assets
│   │   ├── BAC009S0764W0121.wav
│   │   ├── BAC009S0764W0122.wav
│   │   └── mel_filters.npz
│   ├── cli.py
│   ├── model.py
│   ├── model_v2.py
│   └── utils.py
├── setup.py
└── test
    └── test_onnx.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | jobs:
8 | deploy:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v3
12 | - uses: actions-ecosystem/action-regex-match@v2
13 | id: regex-match
14 | with:
15 | text: ${{ github.event.head_commit.message }}
16 | regex: '^Release ([^ ]+)'
17 | - name: Set up Python
18 | uses: actions/setup-python@v4
19 | with:
20 | python-version: '3.8'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Release
26 | if: ${{ steps.regex-match.outputs.match != '' }}
27 | uses: softprops/action-gh-release@v1
28 | with:
29 | tag_name: v${{ steps.regex-match.outputs.group1 }}
30 | - name: Build and publish
31 | if: ${{ steps.regex-match.outputs.match != '' }}
32 | env:
33 | TWINE_USERNAME: __token__
34 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
35 | run: |
36 | python setup.py sdist
37 | twine upload dist/*
38 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.5.0
4 | hooks:
5 | - id: trailing-whitespace
6 | exclude: 's3tokenizer/assets/.*'
7 | - repo: https://github.com/pre-commit/mirrors-yapf
8 | rev: 'v0.32.0'
9 | hooks:
10 | - id: yapf
11 | - repo: https://github.com/pycqa/flake8
12 | rev: '3.8.2'
13 | hooks:
14 | - id: flake8
15 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements.txt
2 | include README.md
3 | include LICENSE
4 | include s3tokenizer/assets/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Reverse Engineering of S3Tokenizer
2 |
3 |
4 | **Supervised Semantic Speech Tokenizer (S3Tokenizer)**
5 |
6 |
7 |
8 | S3Tokenizer was initially introduced in CosyVoice [[Paper]](https://arxiv.org/abs/2407.04051v2) [[Repo]](https://github.com/FunAudioLLM/CosyVoice). It is a supervised semantic speech tokenizer built on the pre-trained SenseVoice-Large model; it strengthens the semantic relationship of the extracted tokens to textual and paralinguistic information, is robust to data noise, and reduces the reliance on clean data collection, thereby enabling a broader range of data to be used for model training.
9 |
10 | However, as indicated in this [[issue]](https://github.com/FunAudioLLM/CosyVoice/issues/70), the authors have no intention of open-sourcing the PyTorch implementation of S3Tokenizer and only plan to release an ONNX file. Additionally, users aiming to fine-tune CosyVoice must extract speech codes offline with the batch size restricted to 1, a process that is notably time-consuming (refer to [[cosyvoice/tools/extract_speech_token.py]](https://github.com/FunAudioLLM/CosyVoice/blob/main/tools/extract_speech_token.py)).
11 |
12 | This repository undertakes a reverse engineering of the S3Tokenizer, offering:
13 | 1. A pure PyTorch implementation of S3Tokenizer (see [[model.py]](https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/model.py)), compatible with initializing weights from the released ONNX file (see [[utils.py::onnx2torch()]](https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/utils.py)); a minimal sketch follows this list.
14 | 2. High-throughput (distributed) batch inference, achieving a ~790x speedup compared to the original inference pipeline in [[cosyvoice/tools/extract_speech_token.py]](https://github.com/FunAudioLLM/CosyVoice/blob/main/tools/extract_speech_token.py).
15 | 3. The capability to perform online speech code extraction during SpeechLLM training.
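For reference, a minimal sketch of point 1, assuming the ONNX checkpoint `speech_tokenizer_v1.onnx` has already been downloaded locally (the `s3tokenizer.load_model()` call used in the rest of this README performs the download, checksum verification, and ONNX-to-PyTorch conversion automatically):

```py
from s3tokenizer.model import S3Tokenizer

model = S3Tokenizer("speech_tokenizer_v1")
model.init_from_onnx("speech_tokenizer_v1.onnx")  # weights converted via utils.onnx2torch()
```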
16 |
17 | ## Supported Models 🔥
18 | - [x] [S3Tokenizer V1 50hz](https://modelscope.cn/models/iic/CosyVoice-300M)
19 | - [x] [S3Tokenizer V1 25hz](https://modelscope.cn/models/iic/CosyVoice-300M-25Hz)
20 | - [x] [S3Tokenizer V2 25hz](https://modelscope.cn/models/iic/CosyVoice2-0.5B)
21 |
22 |
23 | # Setup
24 |
25 | ```sh
26 | pip install s3tokenizer
27 | ```
28 |
29 | # Usage-1: Offline batch inference
30 |
31 | ```py
32 | import s3tokenizer
33 |
34 | tokenizer = s3tokenizer.load_model("speech_tokenizer_v1").cuda()  # or "speech_tokenizer_v1_25hz" / "speech_tokenizer_v2_25hz"
35 |
36 | mels = []
37 | wav_paths = ["s3tokenizer/assets/BAC009S0764W0121.wav", "s3tokenizer/assets/BAC009S0764W0122.wav"]
38 | for wav_path in wav_paths:
39 | audio = s3tokenizer.load_audio(wav_path)
40 | mels.append(s3tokenizer.log_mel_spectrogram(audio))
41 | mels, mels_lens = s3tokenizer.padding(mels)
42 | codes, codes_lens = tokenizer.quantize(mels.cuda(), mels_lens.cuda())
43 |
44 | for i in range(len(wav_paths)):
45 | print(codes[i, :codes_lens[i].item()])
46 | ```
47 |
48 | # Usage-2: Distributed offline batch inference via command-line tools
49 |
50 | ## 2.1 CPU batch inference
51 |
52 | ```sh
53 | s3tokenizer --wav_scp xxx.scp \
54 | --device "cpu" \
55 | --output_dir "./" \
56 | --batch_size 32 \
57 | --model "speech_tokenizer_v1" # or "speech_tokenizer_v1_25hz speech_tokenizer_v2_25hz"
58 | ```
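Here, `xxx.scp` is a plain-text list with one `wav_name wav_path` pair per line (the format expected by `s3tokenizer/cli.py`), for example:

```
BAC009S0764W0121 s3tokenizer/assets/BAC009S0764W0121.wav
BAC009S0764W0122 s3tokenizer/assets/BAC009S0764W0122.wav
```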
59 |
60 |
61 |
62 | https://github.com/user-attachments/assets/d37d10fd-0e13-46a3-86b0-4cbec309086f
63 |
64 |
65 |
66 | ## 2.2 (Multi) GPU batch inference (a.k.a Distributed inference)
67 |
68 | ```sh
69 | torchrun --nproc_per_node=8 --nnodes=1 \
70 | --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
71 | `which s3tokenizer` --wav_scp xxx.scp \
72 | --device "cuda" \
73 | --output_dir "./" \
74 | --batch_size 32 \
75 | --model "speech_tokenizer_v1" # or "speech_tokenizer_v1_25hz speech_tokenizer_v2_25hz"
76 | ```
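Each rank writes its results to `<output_dir>/part_{rank+1}_of_{world_size}`, one JSON object per line (see `s3tokenizer/cli.py`). A minimal sketch for merging the per-rank shards afterwards, assuming the merged file name `codes.jsonl`:

```py
import glob

# Concatenate the per-rank JSONL shards produced by the command above
# (with --output_dir "./" they are named ./part_1_of_8, ./part_2_of_8, ...).
with open("codes.jsonl", "w", encoding="utf-8") as out:
    for part in sorted(glob.glob("./part_*_of_*")):
        with open(part, encoding="utf-8") as f:
            out.write(f.read())
```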
77 |
78 |
79 |
80 | https://github.com/user-attachments/assets/79a3fb11-7199-4ee2-8a35-9682a3b4d94a
81 |
82 |
83 |
84 | ## 2.3 Performance Benchmark
85 |
86 | | Method | Time cost on Aishell Test Set | Relative speed up | Miss Rate |
87 | |:------:|:----------:|:--------------:|:-----:|
88 | | [[cosyvoice/tools/extract_speech_token.py]](https://github.com/FunAudioLLM/CosyVoice/blob/main/tools/extract_speech_token.py), cpu | 9 hours | ~ | ~ |
89 | | cpu, batch size 32 | 1.5 hours | ~6x | 0.00% |
90 | | 4 GPUs (3090), batch size 32 per GPU | 41s | ~790x | 0.00% |
91 |
92 | The miss rate represents the proportion of tokens that are inconsistent between the batch inference predictions and the ONNX (batch=1) inference predictions.
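For reference, a minimal sketch of how such a miss rate can be computed, assuming two dicts that map utterance keys to token lists (one from batch inference, one from ONNX batch=1 inference); this helper is illustrative and not part of the package:

```py
from typing import Dict, List


def miss_rate(batch_codes: Dict[str, List[int]],
              onnx_codes: Dict[str, List[int]]) -> float:
    """Fraction of token positions where the two predictions disagree."""
    mismatched, total = 0, 0
    for key, ref in onnx_codes.items():
        hyp = batch_codes[key]
        total += len(ref)
        mismatched += sum(a != b for a, b in zip(hyp, ref))
        mismatched += abs(len(hyp) - len(ref))  # length mismatches count as misses
    return mismatched / max(total, 1)
```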
93 |
94 | # Usage-3: Online speech code extraction
95 |
96 | **Before (extract code offline)**
97 |
98 | ```py
99 | class SpeechLLM(nn.Module):
100 |     ...
101 |     def __init__(self, ...):
102 |         ...
103 |
104 |     def forward(self, speech_codes: Tensor, text_ids: Tensor, ...):
105 |         ...
106 | ```
107 |
108 | **After (extract code online)**
109 |
110 | ```py
111 | import s3tokenizer
112 |
113 | class SpeechLLM(nn.Module):
114 |     ...
115 |     def __init__(self, ...):
116 |         ...
117 |         self.speech_tokenizer = s3tokenizer.load_model("speech_tokenizer_v1")  # or "speech_tokenizer_v1_25hz"
118 |         self.speech_tokenizer.freeze()
119 |
120 |     def forward(self, speech: Tensor, speech_lens: Tensor, text_ids: Tensor, ...):
121 |         ...
122 |         speech_codes, speech_codes_lens = self.speech_tokenizer.quantize(speech, speech_lens)
123 |         speech_codes = speech_codes.clone()  # for backward compatibility
124 |         speech_codes_lens = speech_codes_lens.clone()  # for backward compatibility
125 | ```
126 |
127 | # TODO
128 |
129 | - [x] Usage-1: Offline batch inference
130 | - [x] Usage-2: Distributed offline batch inference via command-line tools
131 | - [x] Usage-3: Online speech code extraction
132 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pre-commit
2 | numpy
3 | torch
4 | onnx
5 | tqdm
6 | torchaudio
7 | einops
8 |
--------------------------------------------------------------------------------
/s3tokenizer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 OpenAI. (authors: Whisper Team)
2 | # 2024 Tsinghua Univ. (authors: Xingchen Song)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Modified from
16 | https://github.com/openai/whisper/blob/main/whisper/__init__.py
17 | """
18 |
19 | import hashlib
20 | import os
21 | import urllib
22 | import warnings
23 | from typing import List, Union
24 |
25 | from tqdm import tqdm
26 |
27 | from s3tokenizer.model_v2 import S3TokenizerV2
28 |
29 | from .model import S3Tokenizer
30 | from .utils import (load_audio, log_mel_spectrogram, make_non_pad_mask,
31 | mask_to_bias, onnx2torch, padding)
32 |
33 | __all__ = [
34 | 'load_audio', 'log_mel_spectrogram', 'make_non_pad_mask', 'mask_to_bias',
35 | 'onnx2torch', 'padding'
36 | ]
37 | _MODELS = {
38 | "speech_tokenizer_v1":
39 | "https://www.modelscope.cn/models/iic/cosyvoice-300m/"
40 | "resolve/master/speech_tokenizer_v1.onnx",
41 | "speech_tokenizer_v1_25hz":
42 | "https://www.modelscope.cn/models/iic/CosyVoice-300M-25Hz/"
43 | "resolve/master/speech_tokenizer_v1.onnx",
44 | "speech_tokenizer_v2_25hz":
45 | "https://www.modelscope.cn/models/iic/CosyVoice2-0.5B/"
46 | "resolve/master/speech_tokenizer_v2.onnx",
47 | }
48 |
49 | _SHA256S = {
50 | "speech_tokenizer_v1":
51 | "23b5a723ed9143aebfd9ffda14ac4c21231f31c35ef837b6a13bb9e5488abb1e",
52 | "speech_tokenizer_v1_25hz":
53 | "56285ddd4a83e883ee0cb9f8d69c1089b53a94b1f78ff7e4a0224a27eb4cb486",
54 | "speech_tokenizer_v2_25hz":
55 | "d43342aa12163a80bf07bffb94c9de2e120a8df2f9917cd2f642e7f4219c6f71",
56 | }
57 |
58 |
59 | def _download(name: str, root: str) -> Union[bytes, str]:
60 | os.makedirs(root, exist_ok=True)
61 |
62 | expected_sha256 = _SHA256S[name]
63 | url = _MODELS[name]
64 | download_target = os.path.join(root, f"{name}.onnx")
65 |
66 | if os.path.exists(download_target) and not os.path.isfile(download_target):
67 | raise RuntimeError(
68 | f"{download_target} exists and is not a regular file")
69 |
70 | if os.path.isfile(download_target):
71 | with open(download_target, "rb") as f:
72 | model_bytes = f.read()
73 | if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
74 | return download_target
75 | else:
76 | warnings.warn(
77 | f"{download_target} exists, but the SHA256 checksum does not"
78 | " match; re-downloading the file")
79 |
80 | with urllib.request.urlopen(url) as source, open(download_target,
81 | "wb") as output:
82 | with tqdm(
83 | total=int(source.info().get("Content-Length")),
84 | ncols=80,
85 | unit="iB",
86 | unit_scale=True,
87 | unit_divisor=1024,
88 | desc="Downloading onnx checkpoint",
89 | ) as loop:
90 | while True:
91 | buffer = source.read(8192)
92 | if not buffer:
93 | break
94 |
95 | output.write(buffer)
96 | loop.update(len(buffer))
97 |
98 | model_bytes = open(download_target, "rb").read()
99 | if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
100 | raise RuntimeError(
101 |             "Model has been downloaded but the SHA256 checksum does not"
102 |             " match. Please retry loading the model.")
103 |
104 | return download_target
105 |
106 |
107 | def available_models() -> List[str]:
108 | """Returns the names of available models"""
109 | return list(_MODELS.keys())
110 |
111 |
112 | def load_model(
113 | name: str,
114 | download_root: str = None,
115 | ) -> S3Tokenizer:
116 | """
117 |     Load an S3Tokenizer model
118 |
119 | Parameters
120 | ----------
121 | name : str
122 | one of the official model names listed by
123 | `s3tokenizer.available_models()`, or path to a model checkpoint
124 | containing the model dimensions and the model state_dict.
125 | download_root: str
126 | path to download the model files; by default,
127 | it uses "~/.cache/s3tokenizer"
128 |
129 | Returns
130 | -------
131 | model : S3Tokenizer
132 | The S3Tokenizer model instance
133 | """
134 |
135 | if download_root is None:
136 | default = os.path.join(os.path.expanduser("~"), ".cache")
137 | download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default),
138 | "s3tokenizer")
139 |
140 | if name in _MODELS:
141 | checkpoint_file = _download(name, download_root)
142 | elif os.path.isfile(name):
143 | checkpoint_file = name
144 | else:
145 | raise RuntimeError(
146 | f"Model {name} not found; available models = {available_models()}")
147 | if 'v2' in name:
148 | model = S3TokenizerV2(name)
149 | else:
150 | model = S3Tokenizer(name)
151 | model.init_from_onnx(checkpoint_file)
152 |
153 | return model
154 |
--------------------------------------------------------------------------------
/s3tokenizer/assets/BAC009S0764W0121.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xingchensong/S3Tokenizer/dc95bac8bce0dee347c40acca90b0005d8eba711/s3tokenizer/assets/BAC009S0764W0121.wav
--------------------------------------------------------------------------------
/s3tokenizer/assets/BAC009S0764W0122.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xingchensong/S3Tokenizer/dc95bac8bce0dee347c40acca90b0005d8eba711/s3tokenizer/assets/BAC009S0764W0122.wav
--------------------------------------------------------------------------------
/s3tokenizer/assets/mel_filters.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xingchensong/S3Tokenizer/dc95bac8bce0dee347c40acca90b0005d8eba711/s3tokenizer/assets/mel_filters.npz
--------------------------------------------------------------------------------
/s3tokenizer/cli.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """ Example Usage
15 | cpu:
16 |
17 | s3tokenizer --wav_scp xxx.scp \
18 | --device "cpu" \
19 | --output_dir "./" \
20 | --batch_size 32
21 |
22 | gpu:
23 |
24 | torchrun --nproc_per_node=8 --nnodes=1 \
25 | --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
26 | `which s3tokenizer` --wav_scp xxx.scp \
27 | --device "cuda" \
28 | --output_dir "./" \
29 | --batch_size 32
30 |
31 | """
32 |
33 | import argparse
34 | import json
35 | import os
36 |
37 | import torch
38 | import torch.distributed as dist
39 | from torch.utils.data import DataLoader, Dataset, DistributedSampler
40 | from tqdm import tqdm
41 |
42 | import s3tokenizer
43 |
44 |
45 | class AudioDataset(Dataset):
46 |
47 | def __init__(self, wav_scp):
48 | self.data = []
49 | self.keys = []
50 |
51 | with open(wav_scp, 'r', encoding='utf-8') as f:
52 | for line in f:
53 | key, file_path = line.strip().split()
54 | self.data.append(file_path)
55 | self.keys.append(key)
56 |
57 | def __len__(self):
58 | return len(self.data)
59 |
60 | def __getitem__(self, idx):
61 | file_path = self.data[idx]
62 | key = self.keys[idx]
63 | audio = s3tokenizer.load_audio(file_path)
64 | if audio.shape[0] / 16000 > 30:
65 | print(
66 |                 f'extracting speech tokens is not supported for audio longer than 30s, file_path: {file_path}'  # noqa
67 | )
68 | mel = torch.zeros(128, 0)
69 | else:
70 | mel = s3tokenizer.log_mel_spectrogram(audio)
71 | return key, mel
72 |
73 |
74 | def collate_fn(batch):
75 | keys = [item[0] for item in batch]
76 | mels = [item[1] for item in batch]
77 | mels, mels_lens = s3tokenizer.padding(mels)
78 | return keys, mels, mels_lens
79 |
80 |
81 | def init_distributed():
82 | world_size = int(os.environ.get('WORLD_SIZE', 1))
83 | local_rank = int(os.environ.get('LOCAL_RANK', 0))
84 | rank = int(os.environ.get('RANK', 0))
85 | print('Inference on multiple gpus, this gpu {}'.format(local_rank) +
86 | ', rank {}, world_size {}'.format(rank, world_size))
87 | torch.cuda.set_device(local_rank)
88 | dist.init_process_group("nccl")
89 | return world_size, local_rank, rank
90 |
91 |
92 | def get_args():
93 | parser = argparse.ArgumentParser(description='extract speech code')
94 | parser.add_argument('--model',
95 | required=True,
96 | type=str,
97 | choices=[
98 | "speech_tokenizer_v1", "speech_tokenizer_v1_25hz",
99 | "speech_tokenizer_v2_25hz"
100 | ],
101 | help='model version')
102 | parser.add_argument('--wav_scp',
103 | required=True,
104 | type=str,
105 | help='each line contains `wav_name wav_path`')
106 | parser.add_argument('--device',
107 | required=True,
108 | type=str,
109 | choices=["cuda", "cpu"],
110 | help='device for inference')
111 | parser.add_argument('--output_dir',
112 | required=True,
113 | type=str,
114 | help='dir to save result')
115 | parser.add_argument('--batch_size',
116 | required=True,
117 | type=int,
118 | help='batch size (per-device) for inference')
119 | parser.add_argument('--num_workers',
120 | type=int,
121 | default=4,
122 | help='workers for dataloader')
123 | parser.add_argument('--prefetch',
124 | type=int,
125 | default=5,
126 | help='prefetch for dataloader')
127 | args = parser.parse_args()
128 | return args
129 |
130 |
131 | def main():
132 | args = get_args()
133 | os.makedirs(args.output_dir, exist_ok=True)
134 |
135 | if args.device == "cuda":
136 | assert (torch.cuda.is_available())
137 | world_size, local_rank, rank = init_distributed()
138 | else:
139 | world_size, local_rank, rank = 1, 0, 0
140 |
141 | device = torch.device(args.device)
142 | model = s3tokenizer.load_model(args.model).to(device)
143 | dataset = AudioDataset(args.wav_scp)
144 |
145 | if args.device == "cuda":
146 | model = torch.nn.parallel.DistributedDataParallel(
147 | model, device_ids=[local_rank])
148 | sampler = DistributedSampler(dataset,
149 | num_replicas=world_size,
150 | rank=rank)
151 | else:
152 | sampler = None
153 |
154 | dataloader = DataLoader(dataset,
155 | batch_size=args.batch_size,
156 | sampler=sampler,
157 | shuffle=False,
158 | num_workers=args.num_workers,
159 | prefetch_factor=args.prefetch,
160 | collate_fn=collate_fn)
161 |
162 | total_steps = len(dataset)
163 |
164 | if rank == 0:
165 | progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
166 |
167 | writer = open(f"{args.output_dir}/part_{rank + 1}_of_{world_size}", "w")
168 | for keys, mels, mels_lens in dataloader:
169 | codes, codes_lens = model(mels.to(device), mels_lens.to(device))
170 | for i, k in enumerate(keys):
171 | code = codes[i, :codes_lens[i].item()].tolist()
172 | writer.write(
173 | json.dumps({
174 | "key": k,
175 | "code": code
176 | }, ensure_ascii=False) + "\n")
177 | if rank == 0:
178 | progress_bar.update(world_size * len(keys))
179 |
180 | if rank == 0:
181 | progress_bar.close()
182 | writer.close()
183 | if args.device == "cuda":
184 | dist.barrier()
185 | dist.destroy_process_group()
186 |
187 |
188 | if __name__ == "__main__":
189 | main()
190 |
--------------------------------------------------------------------------------
/s3tokenizer/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 OpenAI. (authors: Whisper Team)
2 | # 2024 Tsinghua Univ. (authors: Xingchen Song)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Modified from https://github.com/openai/whisper/blob/main/whisper/model.py
16 | Add EuclideanCodebook & VectorQuantization
17 | """
18 |
19 | from dataclasses import dataclass
20 | from typing import Iterable, Optional, Tuple
21 |
22 | import numpy as np
23 | import torch
24 | import torch.nn.functional as F
25 | from einops import rearrange
26 | from torch import Tensor, nn
27 |
28 | from .utils import make_non_pad_mask, mask_to_bias, onnx2torch
29 |
30 |
31 | @dataclass
32 | class ModelConfig:
33 | n_mels: int = 128
34 | n_audio_ctx: int = 1500
35 | n_audio_state: int = 1280
36 | n_audio_head: int = 20
37 | n_audio_layer: int = 6
38 | n_codebook_size: int = 4096
39 |
40 | use_sdpa: bool = False
41 |
42 |
43 | class LayerNorm(nn.LayerNorm):
44 |
45 | def forward(self, x: Tensor) -> Tensor:
46 | return super().forward(x.float()).type(x.dtype)
47 |
48 |
49 | class Linear(nn.Linear):
50 |
51 | def forward(self, x: Tensor) -> Tensor:
52 | return F.linear(
53 | x,
54 | self.weight.to(x.dtype),
55 | None if self.bias is None else self.bias.to(x.dtype),
56 | )
57 |
58 |
59 | class Conv1d(nn.Conv1d):
60 |
61 | def _conv_forward(self, x: Tensor, weight: Tensor,
62 | bias: Optional[Tensor]) -> Tensor:
63 | return super()._conv_forward(
64 | x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
65 |
66 |
67 | def sinusoids(length, channels, max_timescale=10000):
68 | """Returns sinusoids for positional embedding"""
69 | assert channels % 2 == 0
70 | log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
71 | inv_timescales = torch.exp(-log_timescale_increment *
72 | torch.arange(channels // 2))
73 | scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[
74 | np.newaxis, :]
75 | return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
76 |
77 |
78 | class MultiHeadAttention(nn.Module):
79 |
80 | def __init__(self, n_state: int, n_head: int, use_sdpa: bool = False):
81 | super().__init__()
82 | self.n_head = n_head
83 | self.query = Linear(n_state, n_state)
84 | self.key = Linear(n_state, n_state, bias=False)
85 | self.value = Linear(n_state, n_state)
86 | self.out = Linear(n_state, n_state)
87 |
88 | self.use_sdpa = use_sdpa
89 |
90 | def forward(
91 | self,
92 | x: Tensor,
93 | mask: Optional[Tensor] = None,
94 | ):
95 | q = self.query(x)
96 | k = self.key(x)
97 | v = self.value(x)
98 |
99 | wv, qk = self.qkv_attention(q, k, v, mask)
100 | return self.out(wv), qk
101 |
102 | def qkv_attention(self,
103 | q: Tensor,
104 | k: Tensor,
105 | v: Tensor,
106 | mask: Optional[Tensor] = None):
107 | _, _, D = q.shape
108 | scale = (D // self.n_head)**-0.25
109 | q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
110 | k = k.view(*k.shape[:2], self.n_head, -1)
111 | v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
112 |
113 | if not self.use_sdpa:
114 | k = k.permute(0, 2, 3, 1) * scale
115 | qk = q @ k # (B, n_head, T, T)
116 | if mask is not None:
117 | qk = qk + mask
118 | qk = qk.float()
119 | w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype)
120 | return (w @ v).permute(0, 2, 1,
121 | 3).flatten(start_dim=2), qk.detach()
122 | else:
123 | k = k.permute(0, 2, 1, 3) * scale
124 | assert mask is not None
125 | output = torch.nn.functional.scaled_dot_product_attention(
126 | q,
127 | k,
128 | v,
129 | attn_mask=mask,
130 | dropout_p=0.,
131 | scale=1.,
132 | )
133 | output = (output.transpose(1,
134 | 2).contiguous().view(q.size(0), -1, D)
135 | ) # (batch, time1, d_model)
136 | return output, None
137 |
138 |
139 | class ResidualAttentionBlock(nn.Module):
140 |
141 | def __init__(self, n_state: int, n_head: int, use_sdpa: bool):
142 | super().__init__()
143 |
144 | self.attn = MultiHeadAttention(n_state, n_head, use_sdpa=use_sdpa)
145 | self.attn_ln = LayerNorm(n_state)
146 |
147 | n_mlp = n_state * 4
148 | self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(),
149 | Linear(n_mlp, n_state))
150 | self.mlp_ln = LayerNorm(n_state)
151 |
152 | def forward(
153 | self,
154 | x: Tensor,
155 | mask: Optional[Tensor] = None,
156 | ):
157 | x = x + self.attn(self.attn_ln(x), mask=mask)[0]
158 | x = x + self.mlp(self.mlp_ln(x))
159 | return x
160 |
161 |
162 | class AudioEncoder(nn.Module):
163 |
164 | def __init__(
165 | self,
166 | n_mels: int,
167 | n_ctx: int,
168 | n_state: int,
169 | n_head: int,
170 | n_layer: int,
171 | stride: int,
172 | use_sdpa: bool,
173 | ):
174 | super().__init__()
175 | self.stride = stride
176 | self.conv1 = Conv1d(n_mels,
177 | n_state,
178 | kernel_size=3,
179 | stride=stride,
180 | padding=1)
181 | self.conv2 = Conv1d(n_state,
182 | n_state,
183 | kernel_size=3,
184 | stride=2,
185 | padding=1)
186 | self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
187 |
188 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList([
189 | ResidualAttentionBlock(n_state, n_head, use_sdpa=use_sdpa)
190 | for _ in range(n_layer)
191 | ])
192 |
193 | def forward(self, x: Tensor, x_len: Tensor) -> Tuple[Tensor, Tensor]:
194 | """
195 | x : torch.Tensor, shape = (batch_size, n_mels, T)
196 | the mel spectrogram of the audio
197 | x_len: torch.Tensor, shape = (batch_size,)
198 | length of each audio in x
199 | """
200 | mask = make_non_pad_mask(x_len).unsqueeze(1)
201 | x = F.gelu(self.conv1(x * mask))
202 | x_len = (x_len + 2 - 1 * (3 - 1) - 1) // self.stride + 1
203 | mask = make_non_pad_mask(x_len).unsqueeze(1)
204 | x = F.gelu(self.conv2(x * mask))
205 | x_len = (x_len + 2 - 1 * (3 - 1) - 1) // 2 + 1
206 | mask = make_non_pad_mask(x_len).unsqueeze(1)
207 | x = x.permute(0, 2, 1) # (B, T // 2, n_state)
208 |
209 | mask = mask_to_bias(mask, x.dtype)
210 |
211 | x = (x + self.positional_embedding[:x.shape[1], :]).to(x.dtype)
212 |
213 | for block in self.blocks:
214 | x = block(x, mask.unsqueeze(1))
215 |
216 | return x, x_len
217 |
218 |
219 | class EuclideanCodebook(nn.Module):
220 | """Codebook with Euclidean distance (inference-only).
221 | Args:
222 | dim (int): Dimension.
223 | codebook_size (int): Codebook size.
224 | """
225 |
226 | def __init__(self, dim: int, codebook_size: int):
227 | super().__init__()
228 | embed = torch.zeros(codebook_size, dim)
229 | self.codebook_size = codebook_size
230 | self.register_buffer("embed", embed)
231 |
232 | @torch.inference_mode()
233 | def preprocess(self, x: Tensor) -> Tensor:
234 | x = rearrange(x, "... d -> (...) d")
235 | return x
236 |
237 | @torch.inference_mode()
238 | def quantize(self, x: Tensor) -> Tensor:
239 | embed = self.embed.t()
240 | dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed +
241 | embed.pow(2).sum(0, keepdim=True))
242 |         embed_ind = dist.max(dim=-1).indices  # nearest codebook entry
243 | return embed_ind
244 |
245 | @torch.inference_mode()
246 | def postprocess_emb(self, embed_ind, shape):
247 | return embed_ind.view(*shape[:-1])
248 |
249 | @torch.inference_mode()
250 | def dequantize(self, embed_ind: Tensor) -> Tensor:
251 | quantize = F.embedding(embed_ind, self.embed)
252 | return quantize
253 |
254 | @torch.inference_mode()
255 | def encode(self, x: Tensor) -> Tensor:
256 | shape = x.shape
257 | # pre-process
258 | x = self.preprocess(x)
259 | # quantize
260 | embed_ind = self.quantize(x)
261 | # post-process
262 | embed_ind = self.postprocess_emb(embed_ind, shape)
263 | return embed_ind
264 |
265 | @torch.inference_mode()
266 | def decode(self, embed_ind: Tensor) -> Tensor:
267 | quantize = self.dequantize(embed_ind)
268 | return quantize
269 |
270 |
271 | class VectorQuantization(nn.Module):
272 | """Vector quantization implementation (inference-only).
273 | Args:
274 | dim (int): Dimension
275 | codebook_size (int): Codebook size
276 | """
277 |
278 | def __init__(self, dim: int, codebook_size: int):
279 | super().__init__()
280 | self._codebook = EuclideanCodebook(dim=dim,
281 | codebook_size=codebook_size)
282 | self.codebook_size = codebook_size
283 |
284 | @property
285 | def codebook(self):
286 | return self._codebook.embed
287 |
288 | @torch.inference_mode()
289 | def encode(self, x: Tensor) -> Tensor:
290 | x = F.normalize(x, p=2, dim=-1)
291 | embed_in = self._codebook.encode(x)
292 | return embed_in
293 |
294 | @torch.inference_mode()
295 | def decode(self, embed_ind: Tensor) -> Tensor:
296 | quantize = self._codebook.decode(embed_ind)
297 | quantize = rearrange(quantize, "b n d -> b d n")
298 | return quantize
299 |
300 |
301 | class S3Tokenizer(nn.Module):
302 | """S3 tokenizer implementation (inference-only).
303 | Args:
304 | config (ModelConfig): Config
305 | """
306 |
307 | def __init__(self, name: str, config: ModelConfig = ModelConfig()):
308 | super().__init__()
309 | self.config = config
310 | self.encoder = AudioEncoder(
311 | self.config.n_mels,
312 | self.config.n_audio_ctx,
313 | self.config.n_audio_state,
314 | self.config.n_audio_head,
315 | self.config.n_audio_layer,
316 | 2 if name == "speech_tokenizer_v1_25hz" else 1,
317 | self.config.use_sdpa,
318 | )
319 | self.quantizer = VectorQuantization(self.config.n_audio_state,
320 | self.config.n_codebook_size)
321 |
322 | def forward(self, mel: Tensor, mel_len: Tensor) -> Tuple[Tensor, Tensor]:
323 | return self.quantize(mel, mel_len)
324 |
325 | @torch.inference_mode()
326 | def quantize(self, mel: Tensor, mel_len: Tensor) -> Tuple[Tensor, Tensor]:
327 | hidden, code_len = self.encoder(mel, mel_len)
328 | code = self.quantizer.encode(hidden)
329 | return code, code_len
330 |
331 | @property
332 | def device(self):
333 | return next(self.parameters()).device
334 |
335 | def init_from_onnx(self, onnx_path: str):
336 | ckpt = onnx2torch(onnx_path, None, False)
337 | self.load_state_dict(ckpt, strict=True)
338 |
339 | def init_from_pt(self, ckpt_path: str):
340 | ckpt = torch.load(ckpt_path, map_location="cpu", mmap=True)
341 | self.load_state_dict(ckpt, strict=True)
342 |
343 | def freeze(self):
344 | for _, param in self.named_parameters():
345 | param.requires_grad = False
346 |
--------------------------------------------------------------------------------
/s3tokenizer/model_v2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) (Mddct: Dinghao Zhou)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass
16 | from typing import Optional, Tuple
17 |
18 | import torch
19 | from einops import rearrange
20 |
21 | from s3tokenizer.model import Conv1d, LayerNorm, Linear, MultiHeadAttention
22 | from s3tokenizer.utils import make_non_pad_mask, mask_to_bias, onnx2torch
23 |
24 |
25 | @dataclass
26 | class ModelConfig:
27 | n_mels: int = 128
28 | n_audio_ctx: int = 1500
29 | n_audio_state: int = 1280
30 | n_audio_head: int = 20
31 | n_audio_layer: int = 6
32 | n_codebook_size: int = 3**8
33 |
34 | use_sdpa: bool = False
35 |
36 |
37 | def precompute_freqs_cis(dim: int,
38 | end: int,
39 | theta: float = 10000.0,
40 | scaling=None):
41 | freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
42 | t = torch.arange(end, device=freqs.device) # type: ignore
43 | if scaling is not None:
44 | t = t * scaling
45 | freqs = torch.outer(t, freqs).float() # type: ignore
46 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
47 |
48 | return torch.cat((freqs_cis, freqs_cis), dim=-1)
49 |
50 |
51 | def apply_rotary_emb(
52 | xq: torch.Tensor,
53 | xk: torch.Tensor,
54 | freqs_cis: torch.Tensor,
55 | ) -> Tuple[torch.Tensor, torch.Tensor]:
56 | real = torch.view_as_real(freqs_cis)
57 | cos, sin = real[:, :, 0], real[:, :, 1]
58 | cos = cos.unsqueeze(0).unsqueeze(2)
59 | sin = sin.unsqueeze(0).unsqueeze(2)
60 |
61 | D = xq.shape[-1]
62 | half_l, half_r = xq[:, :, :, :D // 2], xq[:, :, :, D // 2:]
63 | xq_r = torch.cat((-half_r, half_l), dim=-1)
64 |
65 | D = xk.shape[-1]
66 |
67 | half_l, half_r = xk[:, :, :, :D // 2], xk[:, :, :, D // 2:]
68 | xk_r = torch.cat((-half_r, half_l), dim=-1)
69 |
70 | return xq * cos + xq_r * sin, xk * cos + xk_r * sin
71 |
72 |
73 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
74 | ndim = x.ndim
75 | assert 0 <= 1 < ndim
76 | assert freqs_cis.shape == (x.shape[1], x.shape[-1])
77 | shape = [
78 | d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)
79 | ]
80 | return freqs_cis.view(*shape)
81 |
82 |
83 | class FSQCodebook(torch.nn.Module):
84 |
85 | def __init__(self, dim: int, level: int = 3):
86 | super().__init__()
87 | self.project_down = torch.nn.Linear(dim, 8)
88 | self.level = level
89 | self.embed = None
90 |
91 | @torch.inference_mode()
92 | def preprocess(self, x: torch.Tensor) -> torch.Tensor:
93 | x = rearrange(x, "... d -> (...) d")
94 | return x
95 |
96 | @torch.inference_mode()
97 | def encode(self, x: torch.Tensor) -> torch.Tensor:
98 | x_shape = x.shape
99 | # pre-process
100 | x = self.preprocess(x)
101 |         # quantize: 8 dims, 3 levels each -> base-3 index
102 |         h = self.project_down(x).float()
103 |         h = h.tanh()
104 |         h = h * 0.9990000128746033
105 |         h = h.round() + 1  # digits in {0, 1, 2}
106 |         # h = ((self.level - 1) * h).round() # range [-k, k]
107 |         powers = torch.pow(
108 |             self.level,
109 |             torch.arange(2**self.level, device=x.device, dtype=h.dtype))
110 |         mu = torch.sum(h * powers.unsqueeze(0), dim=-1)  # base-3 code index
111 |         ind = mu.reshape(x_shape[0], x_shape[1]).int()
112 | return ind
113 |
114 | @torch.inference_mode()
115 | def decode(self, embed_ind: torch.Tensor) -> torch.Tensor:
116 | raise NotImplementedError(
117 | 'There is no official up project component provided')
118 |
119 |
120 | class FSQVectorQuantization(torch.nn.Module):
121 | """Vector quantization implementation (inference-only).
122 | Args:
123 | dim (int): Dimension
124 | codebook_size (int): Codebook size
125 | """
126 |
127 | def __init__(
128 | self,
129 | dim: int,
130 | codebook_size: int,
131 | ):
132 | super().__init__()
133 | assert 3**8 == codebook_size
134 | self._codebook = FSQCodebook(dim=dim, level=3)
135 | self.codebook_size = codebook_size
136 |
137 | @property
138 | def codebook(self):
139 | return self._codebook.embed
140 |
141 | @torch.inference_mode()
142 | def encode(self, x: torch.Tensor) -> torch.Tensor:
143 | return self._codebook.encode(x)
144 |
145 | @torch.inference_mode()
146 | def decode(self, embed_ind: torch.Tensor) -> torch.Tensor:
147 | quantize = self._codebook.decode(embed_ind)
148 | quantize = rearrange(quantize, "b n d -> b d n")
149 | return quantize
150 |
151 |
152 | class FSMNMultiHeadAttention(MultiHeadAttention):
153 |
154 | def __init__(
155 | self,
156 | n_state: int,
157 | n_head: int,
158 | kernel_size: int = 31,
159 | use_sdpa: bool = False,
160 | ):
161 | super().__init__(n_state, n_head)
162 |
163 | self.fsmn_block = torch.nn.Conv1d(n_state,
164 | n_state,
165 | kernel_size,
166 | stride=1,
167 | padding=0,
168 | groups=n_state,
169 | bias=False)
170 | self.left_padding = (kernel_size - 1) // 2
171 | self.right_padding = kernel_size - 1 - self.left_padding
172 | self.pad_fn = torch.nn.ConstantPad1d(
173 | (self.left_padding, self.right_padding), 0.0)
174 |
175 | self.use_sdpa = use_sdpa
176 |
177 | def forward_fsmn(self,
178 | inputs: torch.Tensor,
179 | mask: Optional[torch.Tensor] = None):
180 | b, t, _, _ = inputs.size()
181 | inputs = inputs.view(b, t, -1)
182 | if mask is not None and mask.size(2) > 0: # time2 > 0
183 | inputs = inputs * mask
184 | x = inputs.transpose(1, 2)
185 | x = self.pad_fn(x)
186 | x = self.fsmn_block(x)
187 | x = x.transpose(1, 2)
188 | x += inputs
189 | return x * mask
190 |
191 | def qkv_attention(self,
192 | q: torch.Tensor,
193 | k: torch.Tensor,
194 | v: torch.Tensor,
195 | mask: Optional[torch.Tensor] = None,
196 | mask_pad: Optional[torch.Tensor] = None,
197 | freqs_cis: Optional[torch.Tensor] = None):
198 | _, _, D = q.shape
199 | scale = (D // self.n_head)**-0.25
200 | q = q.view(*q.shape[:2], self.n_head, -1)
201 | k = k.view(*k.shape[:2], self.n_head, -1)
202 | v = v.view(*v.shape[:2], self.n_head, -1)
203 |
204 | if freqs_cis is not None:
205 | q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)
206 |
207 | fsm_memory = self.forward_fsmn(v, mask_pad)
208 |
209 | q = q.permute(0, 2, 1, 3) * scale
210 | v = v.permute(0, 2, 1, 3)
211 |
212 | if not self.use_sdpa:
213 | k = k.permute(0, 2, 3, 1) * scale
214 | qk = q @ k # (B, n_head, T, T)
215 | if mask is not None:
216 | qk = qk + mask
217 | qk = qk.float()
218 | w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype)
219 | return (w @ v).permute(
220 | 0, 2, 1, 3).flatten(start_dim=2), qk.detach(), fsm_memory
221 | else:
222 | k = k.permute(0, 2, 1, 3) * scale
223 | assert mask is not None
224 | output = torch.nn.functional.scaled_dot_product_attention(
225 | q,
226 | k,
227 | v,
228 | attn_mask=mask,
229 | dropout_p=0.,
230 | scale=1.,
231 | )
232 | output = (output.transpose(1,
233 | 2).contiguous().view(q.size(0), -1, D)
234 | ) # (batch, time1, d_model)
235 | return output, None, fsm_memory
236 |
237 | def forward(self,
238 | x: torch.Tensor,
239 | mask: Optional[torch.Tensor] = None,
240 | mask_pad: Optional[torch.Tensor] = None,
241 | freqs_cis: Optional[torch.Tensor] = None):
242 |
243 | q = self.query(x)
244 | k = self.key(x)
245 | v = self.value(x)
246 |
247 | wv, qk, fsm_memory = self.qkv_attention(q, k, v, mask, mask_pad,
248 | freqs_cis)
249 | return self.out(wv) + fsm_memory, qk
250 |
251 |
252 | class ResidualAttentionBlock(torch.nn.Module):
253 |
254 | def __init__(
255 | self,
256 | n_state: int,
257 | n_head: int,
258 | kernel_size: int = 31,
259 | use_sdpa: bool = False,
260 | ):
261 | super().__init__()
262 |
263 | self.attn = FSMNMultiHeadAttention(n_state,
264 | n_head,
265 | kernel_size,
266 | use_sdpa=use_sdpa)
267 | self.attn_ln = LayerNorm(n_state, eps=1e-6)
268 |
269 | n_mlp = n_state * 4
270 |
271 | self.mlp = torch.nn.Sequential(Linear(n_state, n_mlp), torch.nn.GELU(),
272 | Linear(n_mlp, n_state))
273 | self.mlp_ln = LayerNorm(n_state)
274 |
275 | def forward(
276 | self,
277 | x: torch.Tensor,
278 | mask: Optional[torch.Tensor] = None,
279 | mask_pad: Optional[torch.Tensor] = None,
280 | freqs_cis: Optional[torch.Tensor] = None,
281 | ):
282 | x = x + self.attn(
283 | self.attn_ln(x), mask=mask, mask_pad=mask_pad,
284 | freqs_cis=freqs_cis)[0]
285 |
286 | x = x + self.mlp(self.mlp_ln(x))
287 | return x
288 |
289 |
290 | class AudioEncoderV2(torch.nn.Module):
291 |
292 | def __init__(
293 | self,
294 | n_mels: int,
295 | n_state: int,
296 | n_head: int,
297 | n_layer: int,
298 | stride: int,
299 | use_sdpa: bool,
300 | ):
301 | super().__init__()
302 | self.stride = stride
303 |
304 | self.conv1 = Conv1d(n_mels,
305 | n_state,
306 | kernel_size=3,
307 | stride=stride,
308 | padding=1)
309 | self.conv2 = Conv1d(n_state,
310 | n_state,
311 | kernel_size=3,
312 | stride=2,
313 | padding=1)
314 | self.freqs_cis = precompute_freqs_cis(64, 1024 * 2)
315 | self.blocks = torch.nn.ModuleList([
316 | ResidualAttentionBlock(n_state, n_head, use_sdpa=use_sdpa)
317 | for _ in range(n_layer)
318 | ])
319 |
320 | def forward(self, x: torch.Tensor,
321 | x_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
322 | """
323 | x : torch.Tensor, shape = (batch_size, n_mels, T)
324 | the mel spectrogram of the audio
325 | x_len: torch.Tensor, shape = (batch_size,)
326 | length of each audio in x
327 | """
328 | mask = make_non_pad_mask(x_len).unsqueeze(1)
329 | x = torch.nn.functional.gelu(self.conv1(x * mask))
330 |         x_len = (x_len + 2 - 1 * (3 - 1) - 1) // self.stride + 1  # conv1: (L+2*pad-(k-1)-1)//stride+1
331 | mask = make_non_pad_mask(x_len).unsqueeze(1)
332 | x = torch.nn.functional.gelu(self.conv2(x * mask))
333 |         x_len = (x_len + 2 - 1 * (3 - 1) - 1) // 2 + 1  # conv2: pad=1, k=3, stride=2
334 | mask = make_non_pad_mask(x_len).unsqueeze(1)
335 |         x = x.permute(0, 2, 1)  # (B, T', n_state), T' = T // (stride * 2)
336 | freqs_cis = self.freqs_cis.to(x.device)
337 | mask_pad = mask.transpose(1, 2)
338 | mask = mask_to_bias(mask, x.dtype)
339 |
340 |         tmp = torch.view_as_real(freqs_cis)  # NOTE: cos/sin below are not consumed in this forward; blocks take freqs_cis directly
341 | cos, sin = tmp[:, :, 0], tmp[:, :, 1]
342 |
343 | cos = torch.cat((cos, cos), dim=-1)
344 | sin = torch.cat((sin, sin), dim=-1)
345 | cos = cos.unsqueeze(0).unsqueeze(2)
346 | sin = sin.unsqueeze(0).unsqueeze(2)
347 |
348 | for block in self.blocks:
349 | x = block(x, mask.unsqueeze(1), mask_pad, freqs_cis[:x.size(1)])
350 |
351 | return x, x_len
352 |
353 |
354 | class S3TokenizerV2(torch.nn.Module):
355 | """S3 tokenizer v2 implementation (inference-only).
356 | Args:
357 |         name (str): model name, e.g. "speech_tokenizer_v2_25hz"; config (ModelConfig): Config
358 | """
359 |
360 | def __init__(self, name: str, config: ModelConfig = ModelConfig()):
361 | super().__init__()
362 | if 'v1' not in name:
363 | assert 'v2' in name
364 |             # TODO(Mddct): make it configurable
365 | config.n_codebook_size = 3**8
366 | self.config = config
367 | self.encoder = AudioEncoderV2(
368 | self.config.n_mels,
369 | self.config.n_audio_state,
370 | self.config.n_audio_head,
371 | self.config.n_audio_layer,
372 | 2,
373 | self.config.use_sdpa,
374 | )
375 | self.quantizer = FSQVectorQuantization(
376 | self.config.n_audio_state,
377 | self.config.n_codebook_size,
378 | )
379 |
380 | def forward(self, mel: torch.Tensor,
381 | mel_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
382 | return self.quantize(mel, mel_len)
383 |
384 | @torch.inference_mode()
385 | def quantize(self, mel: torch.Tensor,
386 | mel_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
387 | hidden, code_len = self.encoder(mel, mel_len)
388 | code = self.quantizer.encode(hidden)
389 | return code, code_len
390 |
391 | @property
392 | def device(self):
393 | return next(self.parameters()).device
394 |
395 | def init_from_onnx(self, onnx_path: str):
396 | ckpt = onnx2torch(onnx_path, None, False)
397 | self.load_state_dict(ckpt, strict=True)
398 |
399 | def init_from_pt(self, ckpt_path: str):
400 | ckpt = torch.load(ckpt_path, map_location="cpu", mmap=True)
401 | self.load_state_dict(ckpt, strict=True)
402 |
403 | def freeze(self):
404 | for _, param in self.named_parameters():
405 | param.requires_grad = False
406 |
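Annotation (not part of model_v2.py): a minimal usage sketch for S3TokenizerV2, mirroring the flow in test/test_onnx.py. The wav filename and the local ONNX path are placeholders, and log_mel_spectrogram()'s default n_mels is assumed to match config.n_mels; s3tokenizer.load_model(), used in the test script, is an alternative to the manual init_from_onnx() call.

import s3tokenizer
from s3tokenizer.model_v2 import S3TokenizerV2

tokenizer = S3TokenizerV2("speech_tokenizer_v2_25hz")
tokenizer.init_from_onnx("speech_tokenizer_v2_25hz.onnx")  # placeholder local path
tokenizer.freeze()

audio = s3tokenizer.load_audio("example.wav")  # placeholder, resampled to 16 kHz mono
mels, mels_lens = s3tokenizer.padding([s3tokenizer.log_mel_spectrogram(audio)])
codes, codes_lens = tokenizer.quantize(mels, mels_lens)  # (B, T_code), (B,)
print(codes[0, :codes_lens[0].item()])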
--------------------------------------------------------------------------------
/s3tokenizer/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 OpenAI. (authors: Whisper Team)
2 | # 2024 Tsinghua Univ. (authors: Xingchen Song)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Modified from https://github.com/openai/whisper/blob/main/whisper/audio.py
16 | Add _rename_weights() & onnx2torch() & make_non_pad_mask() & mask_to_bias()
17 | """
18 |
19 | import os
20 | from functools import lru_cache
21 | from typing import List, Optional, Union
22 |
23 | import numpy as np
24 | import onnx
25 | import torch
26 | import torch.nn.functional as F
27 | import torchaudio
28 | from torch.nn.utils.rnn import pad_sequence
29 |
30 |
31 | def _rename_weights(weights_dict: dict):
32 | """
33 | Rename onnx weights to pytorch format.
34 |
35 | Parameters
36 | ----------
37 |     weights_dict: dict
38 | The dict containing weights in onnx format
39 |
40 | Returns
41 | -------
42 | A new weight dict containing the weights in pytorch format.
43 | """
44 | new_weight_dict = {}
45 | for k in weights_dict.keys():
46 | if "quantizer" in k: # vq or fsq
47 | if k == "/quantizer/rq/model/layers.0/_codebook/Pow_1":
48 | new_weight_dict["quantizer._codebook.embed"] = weights_dict[k]
49 | elif 'project_down' in k: # v2
50 | new_weight_dict[k] = weights_dict[k]
51 | elif "positional_embedding" in k: # positional emb
52 | new_weight_dict[k] = weights_dict[k]
53 | elif "conv" in k: # 1/2 or 1/4 subsample
54 | new_weight_dict[k] = weights_dict[k]
55 | else: # transformer blocks
56 | assert "blocks" in k
57 | new_k = (k[1:].replace('/', '.').replace(
58 | 'MatMul', 'weight').replace('Add_1', 'bias').replace(
59 | 'Mul', 'weight').replace('Add', 'bias').replace(
60 | 'mlp.mlp', 'mlp')).replace('fsmn_block.Conv',
61 | 'fsmn_block.weight')
62 |
63 | new_weight_dict[f"encoder.{new_k}"] = weights_dict[k]
64 | return new_weight_dict
65 |
66 |
67 | def onnx2torch(onnx_path: str, torch_path: str = None, verbose: bool = False):
68 | """
69 | Open an onnx file and convert to pytorch format.
70 |
71 | Parameters
72 | ----------
73 | onnx_path: str
74 | The onnx file to open, typically `speech_tokenizer_v1.onnx`
75 |
76 | torch_path: str
77 |         The path to save the torch-formatted checkpoint.
78 |
79 | verbose: bool
80 | Logging info or not.
81 |
82 | Returns
83 | -------
84 | A checkpoint dict containing the weights and their names, if torch_path is
85 | None. Otherwise save checkpoint dict to the desired path.
86 | """
87 | onnx_model = onnx.load(onnx_path)
88 | weights_dict = {}
89 | initializer_map = {
90 | initializer.name: initializer
91 | for initializer in onnx_model.graph.initializer
92 | }
93 | for node in onnx_model.graph.node:
94 | for input_name in node.input:
95 | if input_name in initializer_map:
96 | ln_bias_name, ln_weight_name = None, None # for v2 ln
97 | initializer = initializer_map[input_name]
98 | if input_name in [
99 | "onnx::Conv_1519",
100 | "encoders.conv1.weight",
101 | "onnx::Conv_2216",
102 | ]: # v1_50hz, v1_25hz, v2_25hz
103 | weight_name = "encoder.conv1.weight"
104 | elif input_name in [
105 | "onnx::Conv_1520",
106 | "encoders.conv1.bias",
107 | "onnx::Conv_2217",
108 | ]: # v1_50hz, v1_25hz, v2_25hz
109 | weight_name = "encoder.conv1.bias"
110 | elif input_name in [
111 | "onnx::Conv_1521",
112 | "encoders.conv2.weight",
113 | "onnx::Conv_2218",
114 | ]:
115 | weight_name = "encoder.conv2.weight"
116 | elif input_name in [
117 | "onnx::Conv_1522",
118 | "encoders.conv2.bias",
119 | "onnx::Conv_2219",
120 | ]:
121 | weight_name = "encoder.conv2.bias"
122 | elif input_name == "encoders.positional_embedding":
123 | weight_name = "encoder.positional_embedding"
124 | elif input_name == 'quantizer.project_in.bias':
125 | weight_name = "quantizer._codebook.project_down.bias"
126 | elif input_name == 'onnx::MatMul_2536':
127 | weight_name = "quantizer._codebook.project_down.weight"
128 | else:
129 |                     if node.op_type == 'LayerNormalization':
130 | ln_name = node.name.replace('/LayerNormalization', '')
131 | ln_weight_name = ln_name + '.weight'
132 | ln_bias_name = ln_name + '.bias'
133 | else:
134 | weight_name = node.name
135 | if ln_weight_name is not None and ln_bias_name is not None:
136 | ln_inputs = node.input
137 | scale_name = ln_inputs[1]
138 | bias_name = ln_inputs[2]
139 | scale = onnx.numpy_helper.to_array(
140 | initializer_map[scale_name]).copy(
141 | ) if scale_name in initializer_map else None
142 | bias = onnx.numpy_helper.to_array(
143 | initializer_map[bias_name]).copy(
144 | ) if bias_name in initializer_map else None
145 | scale.flags.writeable = True
146 | bias.flags.writeable = True
147 | weight_tensor = torch.from_numpy(scale)
148 | bias_tensor = torch.from_numpy(bias)
149 |
150 | weights_dict[ln_bias_name] = bias_tensor
151 | weights_dict[ln_weight_name] = weight_tensor
152 | else:
153 | weight_array = onnx.numpy_helper.to_array(
154 | initializer).copy()
155 | weight_array.flags.writeable = True
156 | weight_tensor = torch.from_numpy(weight_array)
157 | if len(weight_tensor.shape) > 2 or weight_name in [
158 | "encoder.positional_embedding"
159 | ]:
160 | weights_dict[weight_name] = weight_tensor
161 | else:
162 | weights_dict[weight_name] = weight_tensor.t()
163 |
164 | new_weights_dict = _rename_weights(weights_dict)
165 | if verbose:
166 | for k, v in new_weights_dict.items():
167 | print(f"{k} : {v.shape} {v.dtype}")
168 | print(f"PyTorch weights saved to {torch_path}")
169 | del weights_dict, onnx_model
170 | if torch_path:
171 | torch.save(new_weights_dict, torch_path)
172 | else:
173 | return new_weights_dict
174 |
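Annotation (not part of utils.py): a small sketch of using onnx2torch() for a one-off conversion; both paths are placeholders. With torch_path given, the converted weights are written to disk instead of being returned, and the saved file can later be loaded with the matching model's init_from_pt() (see model_v2.py for the v2 variant).

from s3tokenizer.utils import onnx2torch

onnx2torch("speech_tokenizer_v1.onnx",  # downloaded ONNX checkpoint (placeholder path)
           "speech_tokenizer_v1.pt",    # destination for the torch-formatted weights
           verbose=True)                # print each renamed tensor's shape and dtype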
175 |
176 | def load_audio(file: str, sr: int = 16000):
177 | """
178 | Open an audio file and read as mono waveform, resampling as necessary
179 |
180 | Parameters
181 | ----------
182 | file: str
183 | The audio file to open
184 |
185 | sr: int
186 | The sample rate to resample the audio if necessary
187 |
188 | Returns
189 | -------
190 | A torch.Tensor containing the audio waveform, in float32 dtype.
191 | """
192 | audio, sample_rate = torchaudio.load(file)
193 | if sample_rate != sr:
194 | audio = torchaudio.transforms.Resample(sample_rate, sr)(audio)
195 | audio = audio[0] # get the first channel
196 | return audio
197 |
198 |
199 | @lru_cache(maxsize=None)
200 | def _mel_filters(device, n_mels: int) -> torch.Tensor:
201 | """
202 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
203 | Allows decoupling librosa dependency; saved using:
204 |
205 | np.savez_compressed(
206 | "mel_filters.npz",
207 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
208 | mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
209 | )
210 | """
211 | assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
212 |
213 | filters_path = os.path.join(os.path.dirname(__file__), "assets",
214 | "mel_filters.npz")
215 | with np.load(filters_path, allow_pickle=False) as f:
216 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
217 |
218 |
219 | def log_mel_spectrogram(
220 | audio: Union[str, np.ndarray, torch.Tensor],
221 | n_mels: int = 128,
222 | padding: int = 0,
223 | device: Optional[Union[str, torch.device]] = None,
224 | ):
225 | """
226 |     Compute the log-Mel spectrogram of the given audio.
227 |
228 | Parameters
229 | ----------
230 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
231 | The path to audio or either a NumPy array or Tensor containing the
232 | audio waveform in 16 kHz
233 |
234 | n_mels: int
235 |         The number of Mel-frequency filters, only 80 and 128 are supported
236 |
237 | padding: int
238 | Number of zero samples to pad to the right
239 |
240 | device: Optional[Union[str, torch.device]]
241 | If given, the audio tensor is moved to this device before STFT
242 |
243 | Returns
244 | -------
245 |     torch.Tensor, shape = (n_mels, n_frames)
246 | A Tensor that contains the Mel spectrogram
247 | """
248 |     if isinstance(audio, str):
249 |         audio = load_audio(audio)  # load_audio already returns a torch.Tensor
250 |     elif not torch.is_tensor(audio):
251 |         audio = torch.from_numpy(audio)
252 |
253 | if device is not None:
254 | audio = audio.to(device)
255 | if padding > 0:
256 | audio = F.pad(audio, (0, padding))
257 | window = torch.hann_window(400).to(audio.device)
258 | stft = torch.stft(audio, 400, 160, window=window, return_complex=True)
259 | magnitudes = stft[..., :-1].abs()**2
260 |
261 | filters = _mel_filters(audio.device, n_mels)
262 | mel_spec = filters @ magnitudes
263 |
264 | log_spec = torch.clamp(mel_spec, min=1e-10).log10()
265 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
266 | log_spec = (log_spec + 4.0) / 4.0
267 | return log_spec
268 |
269 |
270 | def make_non_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
271 | """Make mask tensor containing indices of non-padded part.
272 |
273 |     The sequences in a batch may have different lengths. To enable
274 |     batch computation, padding is needed to make all sequences the
275 |     same size. To prevent the padded part from passing values into
276 |     context-dependent blocks such as attention or convolution, the
277 |     padded part is masked.
278 |
279 | 1 for non-padded part and 0 for padded part.
280 |
281 | Parameters
282 | ----------
283 | lengths (torch.Tensor): Batch of lengths (B,).
284 |
285 | Returns:
286 | -------
287 |     torch.Tensor: Bool mask that is True for non-padded positions (B, max_T).
288 |
289 | Examples:
290 | >>> import torch
291 | >>> import s3tokenizer
292 | >>> lengths = torch.tensor([5, 3, 2])
293 | >>> masks = s3tokenizer.make_non_pad_mask(lengths)
294 | masks = [[1, 1, 1, 1, 1],
295 | [1, 1, 1, 0, 0],
296 | [1, 1, 0, 0, 0]]
297 | """
298 | batch_size = lengths.size(0)
299 | max_len = max_len if max_len > 0 else lengths.max().item()
300 | seq_range = torch.arange(0,
301 | max_len,
302 | dtype=torch.int64,
303 | device=lengths.device)
304 | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
305 | seq_length_expand = lengths.unsqueeze(-1)
306 | mask = seq_range_expand >= seq_length_expand
307 | return ~mask
308 |
309 |
310 | def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
311 | """Convert bool-tensor to float-tensor for flash attention.
312 |
313 | Parameters
314 | ----------
315 |     mask (torch.Tensor): Bool mask of shape (B, ?); dtype (torch.dtype): target dtype.
316 |
317 | Returns:
318 | -------
319 |     torch.Tensor: Float bias, 0.0 at valid positions and -1.0e+10 at padded ones (B, ?).
320 |
321 | Examples:
322 | >>> import torch
323 | >>> import s3tokenizer
324 | >>> lengths = torch.tensor([5, 3, 2])
325 | >>> masks = s3tokenizer.make_non_pad_mask(lengths)
326 | masks = [[1, 1, 1, 1, 1],
327 | [1, 1, 1, 0, 0],
328 | [1, 1, 0, 0, 0]]
329 | >>> new_masks = s3tokenizer.mask_to_bias(masks, torch.float32)
330 | new_masks =
331 | [[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
332 | [-0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+10, -1.0000e+10],
333 | [-0.0000e+00, -0.0000e+00, -1.0000e+10, -1.0000e+10, -1.0000e+10]]
334 | """
335 | assert mask.dtype == torch.bool
336 | assert dtype in [torch.float32, torch.bfloat16, torch.float16]
337 | mask = mask.to(dtype)
338 |
339 | # attention mask bias
340 | # NOTE(Mddct): torch.finfo jit issues
341 | # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
342 | mask = (1.0 - mask) * -1.0e+10
343 | return mask
344 |
345 |
346 | def padding(data: List[torch.Tensor]):
347 | """ Padding the data into batch data
348 |
349 | Parameters
350 | ----------
351 |     data: List[Tensor], each tensor of shape (n_mels, T)
352 |
353 | Returns:
354 | -------
355 | feats, feats lengths
356 | """
357 | sample = data
358 | assert isinstance(sample, list)
359 | feats_lengths = torch.tensor([s.size(1) for s in sample],
360 | dtype=torch.int32)
361 | feats = [s.t() for s in sample]
362 | padded_feats = pad_sequence(feats, batch_first=True, padding_value=0)
363 |
364 | return padded_feats.transpose(1, 2), feats_lengths
365 |
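Annotation (not part of utils.py): a sketch of how the helpers above compose into a padded batch plus an attention bias; the wav paths are placeholders, and the s3tokenizer.* names follow the docstring examples above.

import torch
import s3tokenizer

wav_paths = ["a.wav", "b.wav"]  # placeholder 16 kHz-compatible audio files
mels = [s3tokenizer.log_mel_spectrogram(s3tokenizer.load_audio(p)) for p in wav_paths]
feats, feats_lens = s3tokenizer.padding(mels)         # (B, n_mels, T_max), (B,)
mask = s3tokenizer.make_non_pad_mask(feats_lens)      # bool (B, T_max), True = real frame
bias = s3tokenizer.mask_to_bias(mask, torch.float32)  # 0.0 at real frames, -1e10 at padding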
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from setuptools import find_packages, setup
4 |
5 |
6 | def parse_requirements(filename):
7 | """Load requirements from a pip requirements file."""
8 | with open(filename, 'r') as file:
9 | lines = (line.strip() for line in file)
10 | return [line for line in lines if line and not line.startswith('#')]
11 |
12 |
13 | setup(
14 | name="s3tokenizer",
15 | version="0.1.7",
16 | description=\
17 | "Reverse Engineering of Supervised Semantic Speech Tokenizer (S3Tokenizer) proposed in CosyVoice", # noqa
18 | long_description=open("README.md", encoding="utf-8").read(),
19 | long_description_content_type="text/markdown",
20 | python_requires=">=3.8",
21 | author="xingchensong",
22 | url="https://github.com/xingchensong/S3Tokenizer",
23 | license="Apache2.0",
24 | packages=find_packages(),
25 | install_requires=parse_requirements(
26 | Path(__file__).with_name("requirements.txt")),
27 | entry_points={
28 | "console_scripts": ["s3tokenizer=s3tokenizer.cli:main"],
29 | },
30 | include_package_data=True,
31 | extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]},
32 | classifiers=[
33 | "Programming Language :: Python :: 3",
34 | "Operating System :: OS Independent",
35 | "Topic :: Scientific/Engineering",
36 | ],
37 | )
38 |
--------------------------------------------------------------------------------
/test/test_onnx.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # Copyright [2024-09-27]
4 |
5 | import os
6 |
7 | import numpy as np
8 | import onnxruntime
9 | import s3tokenizer
10 | import torch
11 |
12 | default = os.path.join(os.path.expanduser("~"), ".cache")
13 | download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default),
14 | "s3tokenizer")
15 | name = "speech_tokenizer_v1"
16 | tokenizer = s3tokenizer.load_model(name)
17 |
18 | mels = []
19 | wav_paths = [
20 | "s3tokenizer/assets/BAC009S0764W0121.wav",
21 | "s3tokenizer/assets/BAC009S0764W0122.wav"
22 | ]
23 | for wav_path in wav_paths:
24 | audio = s3tokenizer.load_audio(wav_path)
25 | mels.append(s3tokenizer.log_mel_spectrogram(audio))
26 | print("=========torch=============")
27 | mels, mels_lens = s3tokenizer.padding(mels)
28 | print(f"mels.size: {mels.size()}, mels_lens: {mels_lens}")
29 | codes, codes_lens = tokenizer.quantize(mels, mels_lens)
30 | print(f"codes.size: {codes.size()}, codes_lens: {codes_lens}")
31 |
32 | for i in range(len(wav_paths)):
33 | print(f"wav[{i}]")
34 | print(codes[i, :codes_lens[i].item()])
35 |
36 | print("=========onnx===============")
37 | option = onnxruntime.SessionOptions()
38 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL # noqa
39 | option.intra_op_num_threads = 1
40 | providers = ["CPUExecutionProvider"]
41 | ort_session = onnxruntime.InferenceSession(f"{download_root}/{name}.onnx",
42 | sess_options=option,
43 | providers=providers)
44 |
45 | for i in range(len(wav_paths)):
46 | speech_token = ort_session.run(
47 | None, {
48 | ort_session.get_inputs()[0].name:
49 | mels[i, :, :mels_lens[i].item()].unsqueeze(
50 | 0).detach().cpu().numpy(),
51 | ort_session.get_inputs()[1].name:
52 | np.array([mels_lens[i].item()], dtype=np.int32)
53 | })[0]
54 | if name == 'speech_tokenizer_v2_25hz':
55 | speech_token = np.expand_dims(speech_token, 0)
56 | speech_token = torch.tensor(speech_token[0, 0, :])
57 | print(f"wav[{i}]")
58 | print(speech_token)
59 | print(
60 | f"all equal: {torch.equal(speech_token, codes[i, :codes_lens[i].item()].cpu())}" # noqa
61 | )
62 | miss_num = torch.sum(
63 | ~(speech_token == codes[i, :codes_lens[i].item()].cpu()))
64 | total = speech_token.numel()
65 | print(f"miss rate: {miss_num * 100.0 / total}%")
66 |
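Annotation: test_onnx.py is a parity check rather than a pytest suite. It tokenizes the two bundled wavs with the PyTorch model and again with onnxruntime, then prints whether the two code sequences match along with a miss rate. Running it directly (for example, python test/test_onnx.py from the repository root, assuming the speech_tokenizer_v1 checkpoint can be fetched into the cache directory computed at the top of the script) serves as a quick end-to-end smoke test.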
--------------------------------------------------------------------------------