├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── NOTICE
├── README.md
├── hubconf.py
├── requirements.txt
├── sample.wav
├── setup.cfg
└── src
    ├── codec.py
    ├── dequantizer.py
    ├── quantizer.py
    ├── utils.py
    └── vocoder.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 | .pytest_cache/
51 | 
52 | # Translations
53 | *.mo
54 | *.pot
55 | 
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 | db.sqlite3
60 | 
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 | 
65 | # Scrapy stuff:
66 | .scrapy
67 | 
68 | # Sphinx documentation
69 | docs/_build/
70 | 
71 | # PyBuilder
72 | target/
73 | 
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 | 
77 | # IPython
78 | profile_default/
79 | ipython_config.py
80 | 
81 | # pyenv
82 | .python-version
83 | 
84 | # celery beat schedule file
85 | celerybeat-schedule
86 | 
87 | # SageMath parsed files
88 | *.sage.py
89 | 
90 | # Environments
91 | .env
92 | .venv
93 | env/*
94 | venv/
95 | ENV/
96 | env.bak/
97 | venv.bak/
98 | 
99 | # Spyder project settings
100 | .spyderproject
101 | .spyproject
102 | 
103 | # Rope project settings
104 | .ropeproject
105 | 
106 | # mkdocs documentation
107 | /site
108 | 
109 | # mypy
110 | .mypy_cache/
111 | .dmypy.json
112 | dmypy.json
113 | 
114 | # Pyre type checker
115 | .pyre/
116 | 
117 | # PyCharm
118 | .idea/
119 | 
120 | # Misc
121 | reconstruction.wav
122 | 
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/pre-commit/pre-commit-hooks
3 |     rev: v2.5.0
4 |     hooks:
5 |       - id: trailing-whitespace
6 |         types: [file, text]
7 |       - id: end-of-file-fixer
8 |         types: [python]
9 |       - id: mixed-line-ending
10 |         types: [python]
11 |         args: ["--fix=lf"]
12 | 
13 |   - repo: https://github.com/psf/black
14 |     rev: 24.3.0
15 |     hooks:
16 |       - id: black
17 |         types: [python]
18 | 
19 |   - repo: https://github.com/pycqa/isort
20 |     rev: 5.13.2
21 |     hooks:
22 |       - id: isort
23 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | Copyright 2024 Luca Della Libera. 179 | 180 | Licensed under the Apache License, Version 2.0 (the "License"); 181 | you may not use this file except in compliance with the License. 182 | You may obtain a copy of the License at 183 | 184 | https://www.apache.org/licenses/LICENSE-2.0 185 | 186 | Unless required by applicable law or agreed to in writing, software 187 | distributed under the License is distributed on an "AS IS" BASIS, 188 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 189 | See the License for the specific language governing permissions and 190 | limitations under the License. 
-------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | This project incorporates components from the projects listed below. The original copyright notices are set forth below. 2 | 3 | ############################################################################################################################################################# 4 | 5 | 1. Code in src/quantizer.py adapted from: 6 | https://github.com/jokofa/torch_kmeans/tree/be7d2b78664e81a985ddfa6d21d94917a8b49fe6 7 | 8 | The MIT License (MIT) 9 | 10 | Copyright (c) 2022 Jonas K. Falkner 11 | 12 | Permission is hereby granted, free of charge, to any person obtaining a copy 13 | of this software and associated documentation files (the "Software"), to deal 14 | in the Software without restriction, including without limitation the rights 15 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 16 | copies of the Software, and to permit persons to whom the Software is 17 | furnished to do so, subject to the following conditions: 18 | 19 | The above copyright notice and this permission notice shall be included in all 20 | copies or substantial portions of the Software. 21 | 22 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 23 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 24 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 25 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 26 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 27 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 28 | SOFTWARE. 29 | 30 | ############################################################################################################################################################# -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Discrete WavLM Codec 2 | 3 | A speech codec obtained by quantizing WavLM representations via K-means clustering (see https://arxiv.org/abs/2312.09747). 4 | 5 | --------------------------------------------------------------------------------------------------------- 6 | 7 | ## 🛠️️ Installation 8 | 9 | First of all, install [Python 3.8 or later](https://www.python.org). 
Open a terminal and run:
10 | 
11 | ```
12 | pip install huggingface-hub safetensors speechbrain torch torchaudio transformers
13 | ```
14 | 
15 | ---------------------------------------------------------------------------------------------------------
16 | 
17 | ## ▶️ Quickstart
18 | 
19 | We use `torch.hub` to make loading the model easy (no need to clone the repository):
20 | 
21 | ```python
22 | import torch
23 | import torchaudio
24 | 
25 | dwavlm = torch.hub.load("lucadellalib/discrete-wavlm-codec", "discrete_wavlm_large", layer_ids=[6])
26 | dwavlm.eval().requires_grad_(False)
27 | sig, sample_rate = torchaudio.load("sample.wav")
28 | sig = torchaudio.functional.resample(sig, sample_rate, dwavlm.sample_rate)
29 | feats = dwavlm.sig_to_feats(sig)
30 | toks = dwavlm.feats_to_toks(feats)
31 | qfeats = dwavlm.toks_to_qfeats(toks)
32 | rec_feats = dwavlm.qfeats_to_feats(qfeats)
33 | rec_sig = dwavlm.feats_to_sig(rec_feats)
34 | torchaudio.save("reconstruction.wav", rec_sig[:, 0], dwavlm.sample_rate)
35 | ```
36 | 
37 | ---------------------------------------------------------------------------------------------------------
38 | 
39 | ## 📧 Contact
40 | 
41 | [luca.dellalib@gmail.com](mailto:luca.dellalib@gmail.com)
42 | 
43 | ---------------------------------------------------------------------------------------------------------
44 | 
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
1 | # ==============================================================================
2 | # Copyright 2024 Luca Della Libera.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | 
17 | """PyTorch Hub entry point."""
18 | 
19 | import huggingface_hub
20 | import torch
21 | from safetensors import safe_open
22 | from speechbrain.lobes.models.huggingface_transformers.wavlm import WavLM
23 | from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR
24 | 
25 | from src.codec import Codec
26 | from src.dequantizer import Dequantizer
27 | from src.quantizer import KMeansMultiQuantizer
28 | from src.utils import SBWav2Vec2ForwardWrapper
29 | from src.vocoder import HifiganVocoder
30 | 
31 | 
32 | dependencies = [
33 |     "huggingface_hub",
34 |     "safetensors",
35 |     "speechbrain",
36 |     "torch",
37 |     "transformers",
38 | ]
39 | 
40 | 
41 | def discrete_wavlm_large(
42 |     layer_ids=(6,),
43 |     pretrained=True,
44 |     cache_dir=huggingface_hub.constants.HUGGINGFACE_HUB_CACHE,
45 | ) -> "Codec":
46 |     """Load discrete WavLM codec.
47 | 
48 |     Arguments
49 |     ---------
50 |     layer_ids:
51 |         The WavLM layer indices.
52 |     pretrained:
53 |         True to load the pretrained model weights, False otherwise.
54 |     cache_dir:
55 |         The model cache directory.
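    Returns
    -------
    The codec, i.e. a `Codec` module chaining the WavLM encoder, K-means
    quantizer, dequantizer, and HiFi-GAN vocoder (with pretrained weights
    loaded from the Hugging Face Hub when `pretrained=True`).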
56 | 57 | """ 58 | encoder = WavLM( 59 | source="microsoft/wavlm-large", 60 | save_path=cache_dir, 61 | output_all_hiddens=True, 62 | output_norm=False, 63 | ) 64 | encoder = SBWav2Vec2ForwardWrapper(encoder, layer_ids) 65 | 66 | num_features = 1024 67 | num_clusters = [512] * len(layer_ids) 68 | quantizer = KMeansMultiQuantizer(num_features, num_clusters) 69 | 70 | dropout = 0.1 71 | activation = torch.nn.GELU 72 | d_model = 512 73 | nhead = 4 74 | num_layers = 6 75 | d_ffn = 512 76 | max_length = 2000 77 | causal = False 78 | dequantizer = Dequantizer( 79 | frontend=torch.nn.Linear(in_features=len(layer_ids), out_features=1), 80 | backbone=TransformerASR( 81 | input_size=num_features, 82 | tgt_vocab=-1, 83 | d_model=d_model, 84 | nhead=nhead, 85 | num_encoder_layers=num_layers, 86 | num_decoder_layers=0, 87 | d_ffn=d_ffn, 88 | dropout=dropout, 89 | activation=activation, 90 | max_length=max_length, 91 | encoder_module="conformer", 92 | normalize_before=True, 93 | causal=causal, 94 | ), 95 | head=torch.nn.Linear(in_features=d_model, out_features=num_features), 96 | backend=torch.nn.Linear(in_features=1, out_features=len(layer_ids)), 97 | ) 98 | 99 | resblock_type = 1 100 | resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 101 | resblock_kernel_sizes = [3, 7, 11] 102 | upsample_kernel_sizes = [20, 16, 4, 4] 103 | upsample_initial_channel = 512 104 | upsample_factors = [10, 8, 2, 2] 105 | vocoder = HifiganVocoder( 106 | embedding_dim=[num_features] * len(layer_ids), 107 | out_channels=1, 108 | resblock_type=str(resblock_type), 109 | resblock_dilation_sizes=resblock_dilation_sizes, 110 | resblock_kernel_sizes=resblock_kernel_sizes, 111 | upsample_kernel_sizes=upsample_kernel_sizes, 112 | upsample_initial_channel=upsample_initial_channel, 113 | upsample_factors=upsample_factors, 114 | ) 115 | 116 | if pretrained: 117 | repo_id = "lucadellalib/discrete-wavlm-codec" 118 | variant = "_" + "-".join([str(x) for x in layer_ids]) + ".safetensors" 119 | for module, ckpt_file in zip( 120 | [quantizer, dequantizer, vocoder], 121 | [f"quantizer{variant}", f"dequantizer{variant}", f"vocoder{variant}"], 122 | ): 123 | local_path = huggingface_hub.hf_hub_download( 124 | repo_id, ckpt_file, cache_dir=cache_dir 125 | ) 126 | with safe_open(local_path, framework="pt", device="cpu") as f: 127 | module.load_state_dict({k: f.get_tensor(k) for k in f.keys()}) 128 | 129 | codec = Codec(encoder, quantizer, dequantizer, vocoder) 130 | codec.sample_rate = 16000 131 | 132 | return codec 133 | 134 | 135 | if __name__ == "__main__": 136 | try: 137 | import torchaudio 138 | except ImportError: 139 | raise ImportError("`pip install torchaudio` to run this script") 140 | 141 | codec = discrete_wavlm_large(pretrained=True, layer_ids=[1, 3, 6]) 142 | print( 143 | f"Total number of parameters: {sum([x.numel() for x in codec.state_dict().values()]) / 1e6} M" 144 | ) 145 | codec.eval().requires_grad_(False) 146 | sig, sample_rate = torchaudio.load("sample.wav") 147 | sig = torchaudio.functional.resample(sig, sample_rate, codec.sample_rate) 148 | feats = codec.sig_to_feats(sig) 149 | toks = codec.feats_to_toks(feats) 150 | qfeats = codec.toks_to_qfeats(toks) 151 | rec_feats = codec.qfeats_to_feats(qfeats) 152 | rec_sig = codec.feats_to_sig(rec_feats) 153 | torchaudio.save("reconstruction.wav", rec_sig[:, 0], codec.sample_rate) 154 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | 
huggingface-hub 2 | safetensors 3 | speechbrain 4 | torch 5 | transformers 6 | -------------------------------------------------------------------------------- /sample.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucadellalib/discrete-wavlm-codec/349d49446b2ada4f73f11a0b2911fd6193ef5b93/sample.wav -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | force_grid_wrap = 0 3 | include_trailing_comma = True 4 | line_length = 88 5 | lines_after_imports = 2 6 | multi_line_output = 3 7 | skip_gitignore = True 8 | use_parentheses = True -------------------------------------------------------------------------------- /src/codec.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2024 Luca Della Libera. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """Neural codec.""" 18 | 19 | from typing import Optional, Sequence, Union 20 | 21 | from torch import Tensor, nn 22 | 23 | 24 | __all__ = ["Codec"] 25 | 26 | 27 | class Codec(nn.Module): 28 | """Neural codec. 29 | 30 | Arguments 31 | --------- 32 | encoder: 33 | The encoder, i.e. a module that receives as an input a waveform and returns 34 | the corresponding continuous hidden representations. 35 | quantizer: 36 | The quantizer, i.e. a module that receives as an input continuous hidden representations 37 | and returns the corresponding tokens and quantized hidden representations. 38 | dequantizer: 39 | The dequantizer, i.e. a module that receives as an input quantized hidden representations 40 | and returns the corresponding continuous hidden representations. 41 | vocoder: 42 | The vocoder, i.e. a module that receives as an input continuous hidden representations 43 | and returns the corresponding waveform. 44 | freeze: 45 | The names of the modules to freeze (e.g. `["encoder", "vocoder"]`). 46 | 47 | """ 48 | 49 | def __init__( 50 | self, 51 | encoder: "Optional[nn.Module]" = None, 52 | quantizer: "Optional[nn.Module]" = None, 53 | dequantizer: "Optional[nn.Module]" = None, 54 | vocoder: "Optional[nn.Module]" = None, 55 | freeze: "Union[Sequence[str], bool]" = False, 56 | ) -> "None": 57 | super().__init__() 58 | self.encoder = encoder 59 | self.quantizer = quantizer 60 | self.dequantizer = dequantizer 61 | self.vocoder = vocoder 62 | self.freeze = freeze 63 | if isinstance(freeze, bool): 64 | if freeze: 65 | self.requires_grad_(False).eval() 66 | return 67 | for key in self.freeze: 68 | self._modules[key].requires_grad_(False).eval() 69 | 70 | def forward(self, sig: "Tensor", length: "Optional[Tensor]" = None) -> "Tensor": 71 | """Forward pass. 
72 | 
73 |         Arguments
74 |         ---------
75 |         sig:
76 |             The input waveform, shape: [B, T].
77 |         length:
78 |             The relative length, shape: [B].
79 | 
80 |         Returns
81 |         -------
82 |         The reconstructed waveform, shape: [B, C, T].
83 | 
84 |         """
85 |         feats = self.sig_to_feats(sig, length)
86 |         qfeats = self.feats_to_qfeats(feats)
87 |         rec_feats = self.qfeats_to_feats(qfeats, length)
88 |         rec_sig = self.feats_to_sig(rec_feats)
89 |         return rec_sig
90 | 
91 |     def sig_to_feats(
92 |         self, sig: "Tensor", length: "Optional[Tensor]" = None
93 |     ) -> "Tensor":
94 |         if self.encoder is None:
95 |             raise NotImplementedError
96 |         feats = self.encoder(sig, length)  # (K, B, N, H)
97 |         return feats.movedim(0, -1)  # (B, N, H, K)
98 | 
99 |     def feats_to_sig(self, feats: "Tensor") -> "Tensor":
100 |         if self.vocoder is None:
101 |             raise NotImplementedError
102 |         # (B, N, H, K)
103 |         sig = self.vocoder(feats)
104 |         return sig  # (B, C, T)
105 | 
106 |     def feats_to_toks(self, feats: "Tensor") -> "Tensor":
107 |         if self.quantizer is None:
108 |             raise NotImplementedError
109 |         # (B, N, H, K)
110 |         toks, _ = self.quantizer(feats)
111 |         return toks  # (B, N, K)
112 | 
113 |     def feats_to_qfeats(self, feats: "Tensor") -> "Tensor":
114 |         if self.quantizer is None:
115 |             raise NotImplementedError
116 |         # (B, N, H, K)
117 |         _, qfeats = self.quantizer(feats)
118 |         return qfeats  # (B, N, H, K)
119 | 
120 |     def qfeats_to_feats(
121 |         self, qfeats: "Tensor", length: "Optional[Tensor]" = None
122 |     ) -> "Tensor":
123 |         if self.dequantizer is None:
124 |             raise NotImplementedError
125 |         # (B, N, H, K)
126 |         feats = self.dequantizer(qfeats, length)
127 |         return feats  # (B, N, H, K)
128 | 
129 |     def toks_to_qfeats(self, toks: "Tensor") -> "Tensor":
130 |         if self.quantizer is None:
131 |             raise NotImplementedError
132 |         # (B, N, K)
133 |         _, qfeats = self.quantizer(toks)
134 |         return qfeats  # (B, N, H, K)
135 | 
136 | 
137 | if __name__ == "__main__":
138 |     import torch
139 |     from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
140 |     from speechbrain.lobes.models.huggingface_transformers.wavlm import WavLM
141 |     from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR
142 | 
143 |     from dequantizer import Dequantizer
144 |     from quantizer import KMeansMultiQuantizer
145 |     from utils import SBWav2Vec2ForwardWrapper
146 |     from vocoder import HifiganVocoder
147 | 
148 |     layer_ids = [6, 7]
149 |     num_features = 768
150 |     num_clusters = [300, 300]
151 |     encoder = WavLM(
152 |         source="microsoft/wavlm-base",
153 |         save_path=HUGGINGFACE_HUB_CACHE,
154 |         output_all_hiddens=True,
155 |     )
156 |     quantizer = KMeansMultiQuantizer(num_features, num_clusters)
157 |     dequantizer = Dequantizer(
158 |         frontend=torch.nn.Linear(in_features=len(layer_ids), out_features=1),
159 |         backbone=TransformerASR(
160 |             input_size=num_features,
161 |             tgt_vocab=-1,
162 |             d_model=128,
163 |             nhead=4,
164 |             num_encoder_layers=6,
165 |             num_decoder_layers=0,
166 |             d_ffn=512,
167 |         ),
168 |         head=torch.nn.Linear(in_features=128, out_features=num_features),
169 |         backend=torch.nn.Linear(in_features=1, out_features=len(layer_ids)),
170 |     )
171 |     vocoder = HifiganVocoder(
172 |         embedding_dim=[num_features] * len(layer_ids),
173 |         out_channels=1,
174 |         resblock_type="1",
175 |         resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
176 |         resblock_kernel_sizes=[3, 7, 11],
177 |         upsample_kernel_sizes=[16, 16, 4, 4],
178 |         upsample_initial_channel=512,
179 |         upsample_factors=[8, 8, 2, 2],
180 |     )
181 |     codec = Codec(
182 |         SBWav2Vec2ForwardWrapper(encoder, layer_ids), quantizer,
dequantizer, vocoder
183 |     )
184 |     sigs = torch.rand([10, 16000])
185 |     rec_sig = codec(sigs)
186 |     print(rec_sig.shape)
187 | 
--------------------------------------------------------------------------------
/src/dequantizer.py:
--------------------------------------------------------------------------------
1 | # ==============================================================================
2 | # Copyright 2024 Luca Della Libera.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | 
17 | """Neural dequantizer."""
18 | 
19 | from typing import Optional
20 | 
21 | import torch
22 | from torch import Tensor, nn
23 | 
24 | 
25 | __all__ = ["Dequantizer"]
26 | 
27 | 
28 | class Dequantizer(nn.Module):
29 |     """Dequantizer.
30 | 
31 |     Arguments
32 |     ---------
33 |     backbone:
34 |         The transformer backbone.
35 |     embedding:
36 |         The transformer embedding layer.
37 |     frontend:
38 |         The transformer frontend.
39 |     head:
40 |         The transformer head.
    backend:
        The transformer backend.
41 | 
42 |     Examples
43 |     --------
44 |     >>> from speechbrain.lobes.models.convolution import ConvolutionFrontEnd
45 |     >>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR
46 |     >>> from torch import nn
47 |     >>>
48 |     >>> input_size = 256
49 |     >>> d_model = 128
50 |     >>> out_channels = (28, 28, 28)
51 |     >>> strides = [1, 2, 2]
52 |     >>> frontend = ConvolutionFrontEnd((None, None, input_size), out_channels=out_channels, strides=strides)
53 |     >>> backbone = TransformerASR(
54 |     ...     input_size=input_size // torch.Size(strides).numel() * out_channels[-1],
55 |     ...     tgt_vocab=-1,
56 |     ...     num_decoder_layers=0,
57 |     ...     d_model=d_model,
58 |     ...
) 59 | >>> head = nn.Linear(d_model, input_size) 60 | >>> model = Dequantizer(backbone, frontend=frontend, head=head) 61 | >>> 62 | >>> input = torch.rand([10, 200, input_size]) 63 | >>> length = torch.ones(10) 64 | >>> output = model(input, length) 65 | 66 | """ 67 | 68 | def __init__( 69 | self, 70 | backbone: "nn.Module", 71 | embedding: "Optional[nn.Module]" = None, 72 | frontend: "Optional[nn.Module]" = None, 73 | head: "Optional[nn.Module]" = None, 74 | backend: "Optional[nn.Module]" = None, 75 | ) -> "None": 76 | super().__init__() 77 | self.backbone = backbone 78 | self.embedding = embedding 79 | self.frontend = frontend 80 | self.head = head 81 | self.backend = backend 82 | 83 | def forward(self, src: "Tensor", length: "Optional[Tensor]" = None) -> "Tensor": 84 | if self.embedding is not None: 85 | src = self.embedding(src) 86 | 87 | if self.frontend is not None: 88 | src = self.frontend(src) 89 | 90 | src_shape = src.shape 91 | if len(src_shape) > 3: 92 | # assert src_shape[-1] == 1 93 | src = src.squeeze(dim=-1) 94 | 95 | if hasattr(self.backbone, "encode"): 96 | # Transformer ASR 97 | src = self.backbone.encode(src, length) 98 | else: 99 | src = self.backbone(src, length) 100 | if self.head is not None: 101 | src = self.head(src) 102 | 103 | if len(src_shape) > 3: 104 | src = src.unsqueeze(dim=-1) 105 | 106 | if self.backend is not None: 107 | src = self.backend(src) 108 | 109 | return src 110 | 111 | 112 | if __name__ == "__main__": 113 | import torch 114 | from speechbrain.lobes.models.convolution import ConvolutionFrontEnd 115 | from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR 116 | from torch import nn 117 | 118 | input_size = 256 119 | d_model = 128 120 | out_channels = (28, 28, 28) 121 | strides = [1, 2, 2] 122 | 123 | frontend = ConvolutionFrontEnd( 124 | (None, None, input_size), out_channels=out_channels, strides=strides 125 | ) 126 | 127 | backbone = TransformerASR( 128 | input_size=input_size // torch.Size(strides).numel() * out_channels[-1], 129 | tgt_vocab=-1, 130 | num_decoder_layers=0, 131 | d_model=d_model, 132 | ) 133 | 134 | head = nn.Linear(d_model, input_size) 135 | 136 | model = Dequantizer(backbone, frontend=frontend, head=head) 137 | 138 | input = torch.rand([10, 200, input_size]) 139 | length = torch.ones(10) 140 | output = model(input, length) 141 | -------------------------------------------------------------------------------- /src/quantizer.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2024 Luca Della Libera. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | 17 | """K-means quantizer.""" 18 | 19 | # Adapted from: 20 | # https://github.com/jokofa/torch_kmeans/tree/be7d2b78664e81a985ddfa6d21d94917a8b49fe6 21 | 22 | import logging 23 | from typing import List, Optional, Tuple, Union 24 | 25 | import torch 26 | from torch import Tensor, nn 27 | 28 | 29 | __all__ = [ 30 | "KMeansMultiQuantizer", 31 | "KMeansQuantizer", 32 | ] 33 | 34 | 35 | _LOGGER = logging.getLogger(__file__) 36 | 37 | 38 | class KMeansQuantizer(nn.Module): 39 | """K-means quantizer. 40 | 41 | Arguments 42 | --------- 43 | num_features: 44 | The number of features. 45 | num_clusters: 46 | The number of clusters. 47 | init: 48 | Method to initialize cluster centroids. One of ["random"]. 49 | normalize: 50 | Method to use to normalize input. One of [None, "mean", "minmax", "unit"]. 51 | 52 | References 53 | ---------- 54 | .. [1] Stuart P. Lloyd. 55 | "Least squares quantization in PCM". 56 | In: IEEE Trans. Information Theory. 1982, pp. 129-137. 57 | URL: https://doi.org/10.1109/TIT.1982.1056489 58 | 59 | Examples 60 | -------- 61 | >>> import torch 62 | >>> 63 | >>> batch_size = 8 64 | >>> seq_length = 200 65 | >>> num_features = 64 66 | >>> num_clusters = 4 67 | >>> kmeans = KMeansQuantizer(num_features, num_clusters) 68 | >>> input = torch.randn(batch_size, seq_length, num_features) 69 | >>> labels, centroids = kmeans(input) 70 | >>> drift = kmeans.step(input, labels) 71 | 72 | """ 73 | 74 | _INIT_METHODS = ["random"] 75 | _NORM_METHODS = ["mean", "minmax", "unit"] 76 | 77 | def __init__( 78 | self, 79 | num_features: "int", 80 | num_clusters: "int", 81 | init: "str" = "random", 82 | normalize: "Optional[Union[str, bool]]" = None, 83 | ) -> "None": 84 | super().__init__() 85 | self.num_features = num_features 86 | self.num_clusters = num_clusters 87 | self.init = init.lower() 88 | self.normalize = normalize 89 | self._check_params() 90 | 91 | # Register centroids as a buffer, with "inf" indicating uninitialized centroids 92 | self.register_buffer( 93 | "centroids", 94 | torch.full((self.num_clusters, self.num_features), float("inf")), 95 | ) 96 | 97 | def reset_parameters(self, feats: "Tensor") -> "None": 98 | """Reset parameters. 99 | 100 | Arguments 101 | --------- 102 | feats: 103 | The input features, shape: [*batch_shape, num_features]. 104 | 105 | """ 106 | feats = feats.reshape(-1, self.num_features) 107 | if feats.shape[0] >= self.num_clusters: 108 | self.centroids = _init_centroids( 109 | feats, self.num_clusters, self.init 110 | ).float() 111 | return 112 | _LOGGER.warning( 113 | "The first batch contains less samples than centroids, skipping initialization..." 114 | ) 115 | 116 | def forward( 117 | self, 118 | feats: "Tensor", 119 | return_centroids: "bool" = True, 120 | ) -> "Tuple[Tensor, Optional[Tensor]]": 121 | """Forward pass. 122 | 123 | Arguments 124 | --------- 125 | feats: 126 | The input features, shape: [*batch_shape, num_features]. 127 | Alternatively, the cluster assignments, shape: [*batch_shape] 128 | (useful to retrieve the assigned centroids). 129 | return_centroids: 130 | True to additionally return the assigned centroids, False otherwise. 131 | 132 | Returns 133 | ------- 134 | - The cluster assignments, shape: [*batch_shape]. 135 | - If `return_centroids=True`, the assigned centroids. 
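          shape: [*batch_shape, num_features] (`None` if `return_centroids=False`).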
136 | 
137 |         """
138 |         if return_centroids and feats.shape[-1] != self.num_features:
139 |             # Assume a cluster assignment is given as input
140 |             labels = feats
141 |             batch_shape = labels.shape
142 |             labels = labels.flatten()
143 |             assigned_centroids = self.centroids.gather(
144 |                 0, labels[:, None].expand(-1, self.num_features)
145 |             ).clone()
146 |             labels = labels.reshape(batch_shape)
147 |             assigned_centroids = assigned_centroids.reshape(
148 |                 *batch_shape, self.num_features
149 |             )
150 |             return labels, assigned_centroids
151 | 
152 |         batch_shape = feats.shape[:-1]
153 |         feats = feats.reshape(-1, self.num_features)
154 | 
155 |         if self.centroids[0, 0].isinf():
156 |             # Initialize centroids
157 |             self.reset_parameters(feats)
158 | 
159 |         if self.normalize is not None:
160 |             feats = _normalize(feats, self.normalize)
161 | 
162 |         # Handle mixed precision
163 |         centroids = self.centroids.to(feats)
164 | 
165 |         dist = _compute_pairwise_distance(feats, centroids)
166 | 
167 |         # Get cluster assignments (index of the closest centroid)
168 |         labels = dist.argmin(dim=-1)
169 |         if return_centroids:
170 |             assigned_centroids = centroids.gather(
171 |                 0, labels[:, None].expand(-1, self.num_features)
172 |             ).clone()
173 |             labels = labels.reshape(batch_shape)
174 |             assigned_centroids = assigned_centroids.reshape(
175 |                 *batch_shape, self.num_features
176 |             )
177 |             return labels, assigned_centroids
178 | 
179 |         labels = labels.reshape(batch_shape)
180 |         return labels, None
181 | 
182 |     def step(
183 |         self,
184 |         feats: "Tensor",
185 |         labels: "Optional[Tensor]" = None,
186 |         return_drift: "bool" = True,
187 |     ) -> "Optional[Tensor]":
188 |         """Lloyd's K-means update.
189 | 
190 |         Arguments
191 |         ---------
192 |         feats:
193 |             The input features, shape: [*batch_shape, num_features].
194 |         labels:
195 |             The corresponding labels, shape: [*batch_shape].
196 |         return_drift:
197 |             True to return the drift between current and previous centroids, False otherwise.
198 | 
199 |         Returns
200 |         -------
201 |         If `return_drift=True`, the drift between current and previous centroids.
202 | 
203 |         """
204 |         feats = feats.reshape(-1, self.num_features)
205 | 
206 |         if feats.shape[0] < self.num_clusters:
207 |             _LOGGER.warning(
208 |                 f"Number of samples ({feats.shape[0]}) is less than the number "
209 |                 f"of clusters ({self.num_clusters}), skipping this batch",
210 |             )
211 |             return torch.zeros(1, device=feats.device) if return_drift else None
212 | 
213 |         if labels is None:
214 |             labels, _ = self.forward(feats, return_centroids=False)
215 |         labels = labels.flatten()
216 | 
217 |         # Update cluster centroids
218 |         old_centroids = self.centroids.clone()
219 |         self.centroids = _group_by_label_mean(feats, labels, self.num_clusters)
220 | 
221 |         if return_drift:
222 |             # Compute centroid drift
223 |             drift = _compute_drift(self.centroids, old_centroids)
224 |             return drift
225 | 
226 |     def evaluate(self, feats: "Tensor", labels: "Optional[Tensor]" = None) -> "Tensor":
227 |         """Compute inertia for the current batch, i.e. the sum of squared distances
228 |         of samples to their closest cluster centroid.
229 | 
230 |         Arguments
231 |         ---------
232 |         feats:
233 |             The input features, shape: [*batch_shape, num_features].
234 |         labels:
235 |             The corresponding labels, shape: [*batch_shape].
236 | 
237 |         Returns
238 |         -------
239 |         The inertia for the current batch.
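        Note that the value is computed per sample (shape: [*batch_shape]), i.e. the
        squared distance of each sample to its assigned centroid; sum over all
        elements to obtain the usual scalar inertia.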
240 | 241 | """ 242 | batch_shape = feats.shape[:-1] 243 | feats = feats.reshape(-1, self.num_features) 244 | if labels is None: 245 | labels, _ = self.forward(feats, return_centroids=False) 246 | labels = labels.flatten() 247 | inertia = _compute_inertia(feats, self.centroids, labels) 248 | inertia = inertia.reshape(batch_shape) 249 | return inertia 250 | 251 | def _check_params(self): 252 | """Check initialization parameters.""" 253 | if self.num_features < 1: 254 | raise ValueError(f"`num_features` ({self.num_features}) must be > 0") 255 | if self.num_clusters < 2: 256 | raise ValueError(f"`num_clusters` ({self.num_clusters}) must be > 1") 257 | if self.init not in self._INIT_METHODS: 258 | raise ValueError( 259 | f"`init` ({self.init}) must be one of {self._INIT_METHODS}" 260 | ) 261 | if isinstance(self.normalize, bool): 262 | if self.normalize: 263 | self.normalize = "mean" 264 | else: 265 | self.normalize = None 266 | if self.normalize is not None and self.normalize not in self._NORM_METHODS: 267 | raise ValueError( 268 | f"`normalize` ({self.normalize}) must be one of {self._NORM_METHODS}" 269 | ) 270 | 271 | def __repr__(self): 272 | return ( 273 | f"{type(self).__name__}(" 274 | f"num_features: {self.num_features}, " 275 | f"num_clusters: {self.num_clusters}, " 276 | f"init: {self.init}, " 277 | f"normalize: {self.normalize})" 278 | ) 279 | 280 | 281 | class KMeansMultiQuantizer(nn.Module): 282 | """K-means quantizer with multiple instances.""" 283 | 284 | def __init__(self, *args, **kwargs) -> "None": 285 | super().__init__() 286 | max_length = max( 287 | len(v) 288 | for v in args + tuple(kwargs.values()) 289 | if isinstance(v, (list, tuple)) 290 | ) 291 | args = [v if isinstance(v, (list, tuple)) else [v] * max_length for v in args] 292 | kwargs = { 293 | k: v if isinstance(v, (list, tuple)) else [v] * max_length 294 | for k, v in kwargs.items() 295 | } 296 | all_args = list(zip(*args)) 297 | all_kwargs_values = list(zip(*kwargs.values())) 298 | all_kwargs = [dict(zip(kwargs.keys(), values)) for values in all_kwargs_values] 299 | if not all_args: 300 | all_args = [[] for _ in range(len(all_kwargs))] 301 | if not all_kwargs: 302 | all_kwargs = [{} for _ in range(len(all_args))] 303 | assert len(all_args) == len(all_kwargs) 304 | 305 | kmeanss = [ 306 | KMeansQuantizer(*args, **kwargs) 307 | for args, kwargs in zip(all_args, all_kwargs) 308 | ] 309 | self.kmeanss = nn.ModuleList(kmeanss) 310 | 311 | @property 312 | def num_features(self) -> "List[int]": 313 | return [kmeans.num_features for kmeans in self.kmeanss] 314 | 315 | @property 316 | def num_clusters(self) -> "List[int]": 317 | return [kmeans.num_clusters for kmeans in self.kmeanss] 318 | 319 | @property 320 | def init(self) -> "List[str]": 321 | return [kmeans.init for kmeans in self.kmeanss] 322 | 323 | @property 324 | def normalize(self) -> "List[str]": 325 | return [kmeans.normalize for kmeans in self.kmeanss] 326 | 327 | @property 328 | def centroids(self) -> "Union[Tensor, List[Tensor]]": 329 | if len(self.kmeanss) == 1: 330 | # Fast path 331 | return self.kmeanss[0].centroids[..., None] 332 | centroids_list = [kmeans.centroids for kmeans in self.kmeanss] 333 | try: 334 | centroids = torch.stack(centroids_list).movedim(0, -1) 335 | except RuntimeError: 336 | centroids = centroids_list 337 | return centroids 338 | 339 | def reset_parameters(self, feats: "Tensor") -> "None": 340 | assert feats.shape[-1] == len(self.kmeanss) 341 | for i, kmeans in enumerate(self.kmeanss): 342 | kmeans.reset_parameters(feats[..., i]) 
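    # NOTE: the trailing dimension of the inputs indexes the quantizer instances:
    # feats[..., i] is routed to the i-th K-means quantizer (e.g. one per SSL layer).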
343 | 344 | def forward( 345 | self, feats: "Tensor", return_centroids: "bool" = True 346 | ) -> "Tuple[Tensor, Optional[Tensor]]": 347 | assert feats.shape[-1] == len(self.kmeanss) 348 | 349 | if len(self.kmeanss) == 1: 350 | # Fast path 351 | labels, assigned_centroids = self.kmeanss[0]( 352 | feats[..., 0], return_centroids 353 | ) 354 | labels = labels[..., None] 355 | if return_centroids: 356 | assigned_centroids = assigned_centroids[..., None] 357 | return labels, assigned_centroids 358 | return labels 359 | 360 | labels_list, assigned_centroids_list = [], [] 361 | for i, kmeans in enumerate(self.kmeanss): 362 | labels, assigned_centroids = kmeans(feats[..., i], return_centroids) 363 | labels_list.append(labels) 364 | assigned_centroids_list.append(assigned_centroids) 365 | labels = torch.stack(labels_list).movedim(0, -1) 366 | if return_centroids: 367 | assigned_centroids = torch.stack(assigned_centroids_list).movedim(0, -1) 368 | return labels, assigned_centroids 369 | return labels 370 | 371 | def step( 372 | self, 373 | feats: "Tensor", 374 | labels: "Optional[Tensor]" = None, 375 | return_drift: "bool" = True, 376 | ) -> "Optional[Tensor]": 377 | assert feats.shape[-1] == len(self.kmeanss) 378 | total_drift = 0.0 379 | for i, kmeans in enumerate(self.kmeanss): 380 | drift = kmeans.step( 381 | feats[..., i], 382 | labels[..., i] if labels is not None else None, 383 | return_drift, 384 | ) 385 | if return_drift: 386 | total_drift += drift 387 | if return_drift: 388 | return total_drift / len(self.kmeanss) 389 | 390 | def evaluate(self, feats: "Tensor", labels: "Optional[Tensor]" = None) -> "Tensor": 391 | assert feats.shape[-1] == len(self.kmeanss) 392 | total_inertia = 0.0 393 | for i, kmeans in enumerate(self.kmeanss): 394 | inertia = kmeans.evaluate( 395 | feats[..., i], labels[..., i] if labels is not None else None 396 | ) 397 | total_inertia += inertia 398 | return total_inertia / len(self.kmeanss) 399 | 400 | 401 | @torch.jit.script 402 | def _init_centroids(feats: "Tensor", k: "int", init: "str" = "random") -> "Tensor": 403 | """Initialize centroids according to specified method: 404 | 405 | - "random": random initialization. 406 | 407 | """ 408 | if init == "random": 409 | b = feats.shape[0] 410 | rnd_idx = torch.multinomial( 411 | torch.full((b,), 1 / b, device=feats.device), k, replacement=k > b 412 | ) 413 | return feats[rnd_idx].reshape(k, -1) 414 | else: 415 | raise NotImplementedError 416 | 417 | 418 | @torch.jit.script 419 | def _normalize( 420 | feats: "Tensor", normalize: "str" = "mean", eps: "float" = 1e-8 421 | ) -> "Tensor": 422 | """Normalize input features according to specified method: 423 | 424 | - "mean": subtract sample mean. 425 | - "minmax": min-max normalization subtracting sample min and divide by sample max. 426 | - "unit": normalize features to lie on D-dimensional unit sphere. 
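    Note that "mean" centers each feature dimension over the batch (dim 0), whereas
    "minmax" and "unit" normalize each sample across its features (last dim).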
427 | 428 | """ 429 | if normalize == "mean": 430 | feats -= feats.mean(dim=0)[None] 431 | return feats 432 | elif normalize == "minmax": 433 | feats -= feats.min(dim=-1).values[:, None] 434 | feats /= feats.max(dim=-1).values[:, None] 435 | return feats 436 | elif normalize == "unit": 437 | z_msk = feats == 0 438 | feats = feats.clone() 439 | feats[z_msk] = eps 440 | feats = (1.0 / (feats.norm(p=2.0, dim=-1))).diag_embed() @ feats 441 | return feats 442 | else: 443 | raise NotImplementedError 444 | 445 | 446 | @torch.jit.script 447 | def _compute_pairwise_distance(feats: "Tensor", centroids: "Tensor") -> "Tensor": 448 | """Compute pairwise distances between features and centroids.""" 449 | # Approximate implementation (time and memory-efficient) 450 | feats_norm = (feats**2).sum(dim=-1)[:, None] 451 | centroids_norm = (centroids**2).sum(dim=-1)[None] 452 | dist = (feats_norm + centroids_norm - 2 * feats @ centroids.T).clamp(min=0.0).sqrt() 453 | return dist 454 | 455 | 456 | # @torch.jit.script 457 | # def _compute_pairwise_distance(feats: "Tensor", centroids: "Tensor") -> "Tensor": 458 | # Exact implementation (time and memory-inefficient) 459 | # b, d = feats.shape 460 | # k, d = centroids.shape 461 | # x = feats[:, None].expand(b, k, d).reshape(-1, d) 462 | # centroids = centroids.expand(b, k, d).reshape(-1, d) 463 | # return nn.functional.pairwise_distance(x, centroids, p=2.0).reshape(b, k) 464 | 465 | 466 | @torch.jit.script 467 | def _group_by_label_mean(feats: "Tensor", labels: "Tensor", k: "int") -> "Tensor": 468 | """Group features by label and compute group mean.""" 469 | M = nn.functional.one_hot(labels, num_classes=k).T.to(feats.dtype) 470 | M = nn.functional.normalize(M, p=1.0, dim=-1) 471 | return M @ feats 472 | 473 | 474 | @torch.jit.script 475 | def _compute_drift(centroids: "Tensor", old_centroids: "Tensor") -> "Tensor": 476 | """Compute drift between current and previous centroids.""" 477 | dist = (centroids - old_centroids).norm(p=2.0, dim=-1) 478 | dist[dist.isinf()] = 0.0 479 | return dist.mean(dim=-1) 480 | 481 | 482 | @torch.jit.script 483 | def _compute_inertia( 484 | feats: "Tensor", 485 | centroids: "Tensor", 486 | labels: "Tensor", 487 | ) -> "Tensor": 488 | """Compute inertia, i.e. 
the sum of squared distances of samples to their closest cluster centroid.""" 489 | b, d = feats.shape 490 | # Select assigned centroid by label and compute squared distance 491 | assigned_centroids = centroids.gather(0, labels[:, None].expand(-1, d)) 492 | # Squared distance to closest centroid 493 | dist = (feats - assigned_centroids).norm(p=2.0, dim=-1) ** 2 494 | dist[dist.isinf()] = 0 495 | return dist 496 | 497 | 498 | # Test 499 | if __name__ == "__main__": 500 | try: 501 | import matplotlib 502 | except ImportError: 503 | raise ImportError("`pip install matplotlib` to run this script") 504 | 505 | try: 506 | import numpy 507 | except ImportError: 508 | raise ImportError("`pip install numpy` to run this script") 509 | 510 | try: 511 | import sklearn 512 | except ImportError: 513 | raise ImportError("`pip install scikit-learn` to run this script") 514 | 515 | import time 516 | 517 | import matplotlib.pyplot as plt 518 | import numpy as np 519 | from sklearn.cluster import MiniBatchKMeans 520 | from sklearn.datasets import make_blobs 521 | from sklearn.metrics.pairwise import pairwise_distances_argmin 522 | 523 | np.random.seed(0) 524 | torch.manual_seed(0) 525 | 526 | n_samples = 30000 527 | batch_size = 1024 528 | centers = [[1, 1], [-1, -1], [1, -1]] 529 | n_clusters = len(centers) 530 | max_iter = 100 531 | X, labels_true = make_blobs(n_samples=n_samples, centers=centers, cluster_std=0.7) 532 | 533 | # PyTorch 534 | k_means_torch = KMeansQuantizer(2, n_clusters) 535 | X_torch = torch.from_numpy(X) 536 | t0 = time.time() 537 | for epoch in range(max_iter): 538 | for i in range(n_samples // batch_size): 539 | batch = X_torch[i * batch_size : (i + 1) * batch_size] 540 | k_means_torch.step(batch) 541 | t_batch = time.time() - t0 542 | 543 | # Scikit-learn 544 | mbk = MiniBatchKMeans( 545 | init="random", 546 | n_clusters=n_clusters, 547 | batch_size=batch_size, 548 | max_iter=max_iter, 549 | n_init=1, 550 | max_no_improvement=10000, 551 | reassignment_ratio=0.0, 552 | verbose=0, 553 | ) 554 | t0 = time.time() 555 | mbk.fit(X) 556 | t_mini_batch = time.time() - t0 557 | 558 | k_means_torch_cluster_centers = k_means_torch.centroids.numpy() 559 | order = pairwise_distances_argmin( 560 | k_means_torch_cluster_centers, mbk.cluster_centers_ 561 | ) 562 | mbk_means_cluster_centers = mbk.cluster_centers_[order] 563 | 564 | k_means_labels = pairwise_distances_argmin(X, k_means_torch_cluster_centers) 565 | mbk_means_labels = pairwise_distances_argmin(X, mbk_means_cluster_centers) 566 | 567 | fig = plt.figure(figsize=(8, 3)) 568 | fig.subplots_adjust(left=0.02, right=0.98, bottom=0.05, top=0.9) 569 | colors = ["#4EACC5", "#FF9C34", "#4E9A06"] 570 | 571 | # PyTorch 572 | ax = fig.add_subplot(1, 3, 1) 573 | for k, col in zip(range(n_clusters), colors): 574 | my_members = k_means_labels == k 575 | cluster_center = k_means_torch_cluster_centers[k] 576 | ax.plot( 577 | X[my_members, 0], 578 | X[my_members, 1], 579 | "w", 580 | markerfacecolor=col, 581 | marker=".", 582 | ) 583 | ax.plot( 584 | cluster_center[0], 585 | cluster_center[1], 586 | "o", 587 | markerfacecolor=col, 588 | markeredgecolor="k", 589 | markersize=6, 590 | ) 591 | ax.set_title("MiniBatchKMeans PyTorch") 592 | ax.set_xticks(()) 593 | ax.set_yticks(()) 594 | plt.text( 595 | -3.5, 596 | 1.8, 597 | "train time: %.2fs\ninertia: %f" 598 | % (t_batch, k_means_torch.evaluate(X_torch).sum().item()), 599 | ) 600 | 601 | # Scikit-learn 602 | ax = fig.add_subplot(1, 3, 2) 603 | for k, col in zip(range(n_clusters), colors): 604 | my_members = 
mbk_means_labels == k 605 | cluster_center = mbk_means_cluster_centers[k] 606 | ax.plot( 607 | X[my_members, 0], 608 | X[my_members, 1], 609 | "w", 610 | markerfacecolor=col, 611 | marker=".", 612 | ) 613 | ax.plot( 614 | cluster_center[0], 615 | cluster_center[1], 616 | "o", 617 | markerfacecolor=col, 618 | markeredgecolor="k", 619 | markersize=6, 620 | ) 621 | ax.set_title("MiniBatchKMeans Scikit-learn") 622 | ax.set_xticks(()) 623 | ax.set_yticks(()) 624 | plt.text( 625 | -3.5, 626 | 1.8, 627 | "train time: %.2fs\ninertia: %f" % (t_mini_batch, mbk.inertia_), 628 | ) 629 | 630 | # Initialize the different array to all False 631 | different = mbk_means_labels == 4 632 | ax = fig.add_subplot(1, 3, 3) 633 | 634 | for k in range(n_clusters): 635 | different += (k_means_labels == k) != (mbk_means_labels == k) 636 | 637 | identical = np.logical_not(different) 638 | ax.plot( 639 | X[identical, 0], 640 | X[identical, 1], 641 | "w", 642 | markerfacecolor="#bbbbbb", 643 | marker=".", 644 | ) 645 | ax.plot(X[different, 0], X[different, 1], "w", markerfacecolor="m", marker=".") 646 | ax.set_title("Difference") 647 | ax.set_xticks(()) 648 | ax.set_yticks(()) 649 | 650 | plt.show() 651 | 652 | 653 | if __name__ == "__main__": 654 | quantizer = KMeansMultiQuantizer(num_features=[300, 300], num_clusters=[10, 10]) 655 | input = torch.randn(5, 200, 300, 2) 656 | labels, _ = quantizer(input) 657 | print(labels.shape) 658 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # Copyright 2024 Luca Della Libera. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """Common utilities.""" 18 | 19 | from typing import Optional, Sequence 20 | 21 | import torch 22 | from torch import Tensor, nn 23 | from transformers.models.hubert.modeling_hubert import HubertEncoderStableLayerNorm 24 | from transformers.models.wav2vec2.modeling_wav2vec2 import ( 25 | Wav2Vec2EncoderStableLayerNorm, 26 | ) 27 | from transformers.models.wavlm.modeling_wavlm import WavLMEncoderStableLayerNorm 28 | 29 | 30 | __all__ = ["SBWav2Vec2ForwardWrapper"] 31 | 32 | 33 | class SBWav2Vec2ForwardWrapper(nn.Module): 34 | """SpeechBrain wav2vec 2.0 wrapper that returns the hidden representations from the specified layer IDs. 35 | 36 | Arguments 37 | --------- 38 | wav2vec2: 39 | The SpeechBrain wav2vec 2.0 module. 40 | layer_ids: 41 | The layer IDs from which the hidden representations are extracted. 
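        Note: the wrapped model is modified in place: encoder layers above
        max(layer_ids) are discarded, so hidden states that would never be
        returned are not computed.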
42 | 
43 |     Examples
44 |     --------
45 |     >>> import torch
46 |     >>> from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
47 |     >>> from speechbrain.lobes.models.huggingface_transformers.wavlm import WavLM
48 |     >>>
49 |     >>> encoder = WavLM(source="microsoft/wavlm-large", save_path=HUGGINGFACE_HUB_CACHE)
50 |     >>> encoder = SBWav2Vec2ForwardWrapper(encoder, layer_ids=[6, 7])
51 |     >>>
52 |     >>> input = torch.rand([10, 16000])
53 |     >>> length = torch.ones(10)
54 |     >>> output = encoder(input, length)
55 | 
56 |     """
57 | 
58 |     def __init__(self, wav2vec2: "nn.Module", layer_ids: "Sequence[int]") -> "None":
59 |         super().__init__()
60 |         self.wav2vec2 = wav2vec2
61 |         # Workaround to deal with hardcoded class name in discrete SSL
62 |         # https://github.com/speechbrain/speechbrain/blob/60062c2536e8122253d6ad0e681208f554528950/speechbrain/lobes/models/huggingface_transformers/discrete_ssl.py#L88
63 |         self.__class__.__name__ = self.wav2vec2.__class__.__name__
64 |         self.layer_ids = sorted(layer_ids)
65 |         assert hasattr(self.wav2vec2, "model")
66 |         assert hasattr(self.wav2vec2.model, "encoder")
67 |         assert hasattr(self.wav2vec2.model.encoder, "layers")
68 |         # Workaround for early exiting to avoid the computational overhead of forwarding through the whole model
69 |         # NOTE: the model is modified in-place
70 |         self.wav2vec2.output_all_hiddens = True
71 |         self.wav2vec2.model.encoder.layers = self.wav2vec2.model.encoder.layers[
72 |             : max(self.layer_ids)
73 |         ]
74 |         # NOTE: workaround to account for layer norm applied to the last hidden states when StableLayerNorm variant is used:
75 |         # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/wavlm/modeling_wavlm.py#L816
76 |         if isinstance(
77 |             self.wav2vec2.model.encoder,
78 |             (
79 |                 HubertEncoderStableLayerNorm,
80 |                 Wav2Vec2EncoderStableLayerNorm,
81 |                 WavLMEncoderStableLayerNorm,
82 |             ),
83 |         ):
84 |             self.wav2vec2.model.encoder.layer_norm = torch.nn.Identity()
85 | 
86 |     def extract_features(
87 |         self, wav: "Tensor", length: "Optional[Tensor]" = None
88 |     ) -> "Tensor":
89 |         feats = self.wav2vec2(wav, length)  # (K, B, N, H)
90 |         return feats[self.layer_ids]
91 | 
92 |     def forward(self, wav: "Tensor", length: "Optional[Tensor]" = None) -> "Tensor":
93 |         return self.extract_features(wav, length)
94 | 
95 | 
96 | if __name__ == "__main__":
97 |     from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
98 |     from speechbrain.lobes.models.huggingface_transformers.wav2vec2 import Wav2Vec2
99 | 
100 |     # Sanity check: the early-exit wrapper must match the full model's hidden states
101 |     for source in [
102 |         "facebook/wav2vec2-large-960h-lv60-self",
103 |         "facebook/hubert-large-ll60k",
104 |         "microsoft/wavlm-large",
105 |     ]:
106 |         layer_ids = [3, 7]
107 |         encoder1 = Wav2Vec2(
108 |             source=source,
109 |             save_path=HUGGINGFACE_HUB_CACHE,
110 |             output_norm=True,
111 |         )
112 |         encoder1 = SBWav2Vec2ForwardWrapper(encoder1, layer_ids=layer_ids).eval()
113 | 
114 |         encoder2 = Wav2Vec2(
115 |             source=source,
116 |             save_path=HUGGINGFACE_HUB_CACHE,
117 |             output_norm=True,
118 |             output_all_hiddens=True,
119 |         ).eval()
120 | 
121 |         input = torch.ones([1, 16000])
122 |         with torch.no_grad():
123 |             output1 = encoder1(input)
124 |             output2 = encoder2(input)[layer_ids]
125 | 
126 |         print((output1 == output2).all())
127 | 
--------------------------------------------------------------------------------
/src/vocoder.py:
--------------------------------------------------------------------------------
1 | # ==============================================================================
2 | # Copyright 2024 Luca Della Libera.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | 
17 | """HiFi-GAN vocoder."""
18 | 
19 | import torch
20 | from speechbrain.lobes.models.HifiGAN import HifiganGenerator
21 | from torch import nn
22 | 
23 | 
24 | __all__ = ["HifiganVocoder"]
25 | 
26 | 
27 | # Use default parameters from:
28 | # https://github.com/bshall/knn-vc/blob/848302a262f7299c738af49d74209790ed442a9f/hifigan/config_v1_wavlm.json
29 | class HifiganVocoder(HifiganGenerator):
30 |     """HiFi-GAN vocoder that synthesizes a waveform from (multi-codebook) embeddings.
31 | 
32 |     Expects inputs of shape (batch, time, embedding_dim, num_codebooks): a learned
33 |     linear projection first collapses the codebook axis, then the underlying
34 |     HiFi-GAN generator upsamples the result to a waveform.
35 |     """
36 | 
37 |     def __init__(
38 |         self,
39 |         embedding_dim,
40 |         out_channels=1,
41 |         resblock_type=1,
42 |         resblock_dilation_sizes=((1, 3, 5), (1, 3, 5), (1, 3, 5)),
43 |         resblock_kernel_sizes=(3, 7, 11),
44 |         upsample_kernel_sizes=(20, 16, 4, 4),
45 |         upsample_initial_channel=512,
46 |         upsample_factors=(10, 8, 2, 2),
47 |         inference_padding=5,
48 |         cond_channels=0,
49 |         conv_post_bias=True,
50 |     ):
51 |         if isinstance(embedding_dim, (list, tuple)):
52 |             assert all(
53 |                 x == embedding_dim[0] for x in embedding_dim
54 |             ), "All codebooks must share the same embedding dimension"
55 |             self.embedding_dim = embedding_dim[0]
56 |             self.num_codebooks = len(embedding_dim)
57 |         else:
58 |             self.embedding_dim = embedding_dim
59 |             self.num_codebooks = 1
60 |         super().__init__(
61 |             in_channels=self.embedding_dim,
62 |             out_channels=out_channels,
63 |             resblock_type=str(resblock_type),
64 |             resblock_dilation_sizes=resblock_dilation_sizes,
65 |             resblock_kernel_sizes=resblock_kernel_sizes,
66 |             upsample_kernel_sizes=upsample_kernel_sizes,
67 |             upsample_initial_channel=upsample_initial_channel,
68 |             upsample_factors=upsample_factors,
69 |             inference_padding=inference_padding,
70 |             cond_channels=cond_channels,
71 |             conv_post_bias=conv_post_bias,
72 |         )
73 |         self.in_proj = nn.Linear(self.num_codebooks, 1)
74 | 
75 |     def forward(self, x, g=None):
76 |         # (batch, time, embedding_dim, num_codebooks)
77 |         x = self.in_proj(x)
78 |         # (batch, time, embedding_dim, 1)
79 |         x = x.squeeze(dim=-1)
80 |         # (batch, time, embedding_dim)
81 |         x = x.movedim(-1, -2)
82 |         # (batch, embedding_dim, time)
83 |         return super().forward(x, g)
84 | 
85 |     @torch.no_grad()
86 |     def inference(self, x, g=None, **kwargs):
87 |         # kwargs are accepted for interface compatibility and intentionally ignored
88 |         return self.forward(x, g)
89 | 
90 | 
91 | if __name__ == "__main__":
92 |     from copy import deepcopy
93 | 
94 |     embedding_dim = 200
95 |     x = torch.randn(2, 49, embedding_dim, 1)
96 |     model = HifiganVocoder(embedding_dim)
97 |     output = model(x)
98 |     print(output.shape)
99 |     with torch.no_grad():
100 |         model(x)
101 |     # Ensure the model remains deep-copyable after a forward pass
102 |     deepcopy(model)
103 | 
--------------------------------------------------------------------------------
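A minimal multi-codebook usage sketch for the vocoder (not part of the repository). It assumes the script is run from the repository root with src/ added to the import path (no package __init__.py is listed in the tree), and the expected output shape is inferred from the constructor defaults: with upsample_factors=(10, 8, 2, 2), each input frame should be upsampled by 10 * 8 * 2 * 2 = 320, so 49 frames would yield 49 * 320 = 15680 samples.

# Usage sketch; assumptions: run from the repository root, src/ modules importable
# as flat modules, and SpeechBrain's HifiganGenerator returning (batch, out_channels, samples)
import sys

sys.path.insert(0, "src")

import torch

from vocoder import HifiganVocoder

# Two codebooks with identical embedding width (required by the assert in __init__)
model = HifiganVocoder(embedding_dim=[200, 200]).eval()

# (batch, time, embedding_dim, num_codebooks); in_proj mixes the codebook axis down to 1
x = torch.randn(2, 49, 200, 2)

wav = model.inference(x)  # runs forward under torch.no_grad()
print(wav.shape)  # expected: torch.Size([2, 1, 15680])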