├── LICENSE ├── README.md ├── config └── spt_base_cfg.json ├── example.py ├── images ├── READ.me ├── overview.png └── speechtokenizer_framework.jpg ├── samples └── example_input.wav ├── scripts ├── hubert_rep_extract.py ├── hubert_rep_extract.sh ├── train_example.py └── train_example.sh ├── setup.py └── speechtokenizer ├── __init__.py ├── discriminators.py ├── model.py ├── modules ├── __init__.py ├── conv.py ├── lstm.py ├── norm.py └── seanet.py ├── quantization ├── __init__.py ├── ac.py ├── core_vq.py ├── distrib.py └── vq.py └── trainer ├── __init__.py ├── dataset.py ├── loss.py ├── optimizer.py └── trainer.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SpeechTokenizer: Unified Speech Tokenizer for Speech Language Models 2 | 3 | 4 | 5 | ## Introduction 6 | This is the code for the SpeechTokenizer presented in the [SpeechTokenizer: Unified Speech Tokenizer for Speech Language Models](https://arxiv.org/abs/2308.16692). 
SpeechTokenizer is a unified speech tokenizer for speech language models that adopts an encoder-decoder architecture with residual vector quantization (RVQ). By unifying semantic and acoustic tokens, SpeechTokenizer disentangles different aspects of speech information hierarchically across the RVQ layers: the code indices produced by the first RVQ quantizer can be regarded as semantic tokens, while the outputs of the remaining quantizers mainly carry timbre information and supplement what the first quantizer loses. We provide two models:
7 | * A model operating at 16 kHz on monophonic speech, trained on LibriSpeech, with the average representation across all HuBERT layers as the semantic teacher.
8 | * A model with [Snake activation](https://arxiv.org/abs/2306.06546) operating at 16 kHz on monophonic speech, trained on LibriSpeech and Common Voice, with the average representation across all HuBERT layers as the semantic teacher.
9 |
10 |
11 | <p align="center">
12 | <img src="images/overview.png"> <br>
13 | Overview
14 | </p>
15 | <p align="center">
16 | <img src="images/speechtokenizer_framework.jpg"> <br>
17 | The SpeechTokenizer framework.
18 | </p>
19 |
20 |
21 |
22 | You are welcome to try our [SLMTokBench](https://github.com/0nutation/SLMTokBench),
23 | and we will also open-source our [USLM](https://github.com/0nutation/USLM)!
24 |
25 | ## Quick Links
26 | * [Release](#release)
27 | * [Samples](#samples)
28 | * [Installation](#installation)
29 | * [Model List](#model-list)
30 | * [Usage](#usage)
31 | * [Train SpeechTokenizer](#train-speechtokenizer)
32 | * [Data Preprocess](#data-preprocess)
33 | * [Train](#train)
34 | * [Quick Start](#quick-start)
35 | * [Citation](#citation)
36 | * [License](#license)
37 |
38 |
39 | ## Release
40 | - [2024/6/9] 🔥 We released the training code of SpeechTokenizer.
41 | - [2024/3] 🔥 We released a checkpoint of SpeechTokenizer with [Snake activation](https://arxiv.org/abs/2306.06546) trained on LibriSpeech and Common Voice.
42 | - [2023/9/11] 🔥 We released the code of [soundstorm_speechtokenizer](https://github.com/ZhangXInFD/soundstorm-speechtokenizer).
43 | - [2023/9/10] 🔥 We released code and checkpoints of [USLM](https://github.com/0nutation/USLM).
44 | - [2023/9/1] 🔥 We released code and checkpoints of SpeechTokenizer. Check out the [paper](https://arxiv.org/abs/2308.16692) and [demo](https://0nutation.github.io/SpeechTokenizer.github.io/).
45 |
46 | ## Samples
47 |
48 | Samples are provided on [our demo page](https://0nutation.github.io/SpeechTokenizer.github.io/).
49 |
50 | ## Installation
51 |
52 | SpeechTokenizer requires Python >= 3.8 and a reasonably recent version of PyTorch.
53 | You can install SpeechTokenizer from PyPI or from this repository:
54 | ```bash
55 | pip install -U speechtokenizer
56 |
57 | # or you can clone the repo and install locally
58 | git clone https://github.com/ZhangXInFD/SpeechTokenizer.git
59 | cd SpeechTokenizer
60 | pip install .
61 | ```
62 | ## Model List
63 | | Model | Dataset | Description |
64 | |:----|:----:|:----|
65 | |[speechtokenizer_hubert_avg](https://huggingface.co/fnlp/SpeechTokenizer/tree/main/speechtokenizer_hubert_avg)|LibriSpeech|Average representation across all HuBERT layers as semantic teacher |
66 | |[speechtokenizer_snake](https://huggingface.co/fnlp/AnyGPT-speech-modules/tree/main/speechtokenizer)|LibriSpeech + Common Voice|Snake activation, average representation across all HuBERT layers |
67 | ## Usage
68 | ### Load model
69 | ```python
70 | from speechtokenizer import SpeechTokenizer
71 |
72 | config_path = '/path/config.json'
73 | ckpt_path = '/path/SpeechTokenizer.pt'
74 | model = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
75 | model.eval()
76 | ```
77 | ### Extracting discrete representations
78 | ```python
79 | import torchaudio
80 | import torch
81 |
82 | # Load and pre-process the speech waveform
83 | wav, sr = torchaudio.load('<SPEECH_FILE_PATH>')
84 |
85 | # Ensure the waveform is monophonic
86 | if wav.shape[0] > 1:
87 |     wav = wav[:1,:]
88 |
89 | if sr != model.sample_rate:
90 |     wav = torchaudio.functional.resample(wav, sr, model.sample_rate)
91 |
92 | wav = wav.unsqueeze(0)
93 |
94 | # Extract discrete codes from SpeechTokenizer
95 | with torch.no_grad():
96 |     codes = model.encode(wav) # codes: (n_q, B, T)
97 |
98 | RVQ_1 = codes[:1, :, :] # Contains content info; can be considered as semantic tokens
99 | RVQ_supplement = codes[1:, :, :] # Contains timbre info; complements the info lost by the first quantizer
100 | ```
101 |
102 | ### Decoding discrete representations
103 | ```python
104 | # Concatenate semantic tokens (RVQ_1) and supplementary timbre tokens, then decode
105 | wav = model.decode(torch.cat([RVQ_1, RVQ_supplement], axis=0))
106 |
107 | # Decode from RVQ-i:j tokens, i.e. from the i-th to the j-th quantizer
108 | wav = model.decode(codes[i: (j + 1)], st=i)
109 | ```
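The `st` argument of `decode` indicates which quantizer the given codes start from, so any contiguous subset of RVQ layers can be decoded. The snippet below is an editorial usage sketch (not part of the original README) that reuses `model`, `codes`, `RVQ_1`, and `RVQ_supplement` from the examples above:

```python
with torch.no_grad():
    # Semantic-only reconstruction from the first quantizer (st defaults to 0)
    wav_semantic = model.decode(RVQ_1)

    # Reconstruction from the remaining quantizers only; st=1 tells the model
    # that these codes start at the second quantizer
    wav_timbre = model.decode(RVQ_supplement, st=1)

    # Full reconstruction from all codes
    wav_full = model.decode(codes)
```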
110 |
111 | ## Train SpeechTokenizer
112 | This section describes how to train a SpeechTokenizer model using our trainer.
113 | ### Data Preprocess
114 | To train SpeechTokenizer, the first step is to extract semantic teacher representations from raw audio waveforms. We provide an example of how to extract HuBERT representations in [scripts/hubert_rep_extract.sh](scripts/hubert_rep_extract.sh). The arguments are explained below:
115 | * `--config`: Config file path. An example is provided in [config/spt_base_cfg.json](config/spt_base_cfg.json). You can modify the `semantic_model_path` and `semantic_model_layer` parameters in this file to change the HuBERT model and the target layer.
116 | * `--audio_dir`: The path to the folder containing all audio files.
117 | * `--rep_dir`: The path to the folder storing all semantic representation files.
118 | * `--exts`: The file extensions of the audio files. Use ',' to separate multiple extensions if they exist.
119 | * `--split_seed`: Random seed for splitting the training set and the validation set.
120 | * `--valid_set_size`: The size of the validation set. When this number is between 0 and 1, it is interpreted as the proportion of the total dataset used for the validation set.
121 |
122 | ### Train
123 | You can use SpeechTokenizerTrainer to train a SpeechTokenizer as follows:
124 | ```python
125 | from speechtokenizer import SpeechTokenizer, SpeechTokenizerTrainer
126 | from speechtokenizer.discriminators import MultiPeriodDiscriminator, MultiScaleDiscriminator, MultiScaleSTFTDiscriminator
127 | import json
128 |
129 |
130 | # Load model and trainer config
131 | with open('<CONFIG_FILE_PATH>') as f:
132 |     cfg = json.load(f)
133 |
134 | # Initialize SpeechTokenizer
135 | generator = SpeechTokenizer(cfg)
136 |
137 | # Initialize the discriminators. You can add any discriminator that is not yet implemented in this repository, as long as the output format remains consistent with the discriminators in `speechtokenizer.discriminators`.
138 | discriminators = {'mpd':MultiPeriodDiscriminator(), 'msd':MultiScaleDiscriminator(), 'mstftd':MultiScaleSTFTDiscriminator(32)}
139 |
140 | # Initialize Trainer
141 | trainer = SpeechTokenizerTrainer(generator=generator,
142 |                                  discriminators=discriminators,
143 |                                  cfg=cfg)
144 |
145 | # Start training
146 | trainer.train()
147 |
148 | # Continue training from checkpoints
149 | trainer.continue_train()
150 | ```
151 | We provide example training scripts in [scripts/train_example.sh](scripts/train_example.sh). All arguments for SpeechTokenizerTrainer are defined in [config/spt_base_cfg.json](config/spt_base_cfg.json). Below, we explain some of the important arguments:
152 | * `train_files` and `valid_files`: Training file path and validation file path. These files should be text files listing the paths of all audio files and their corresponding semantic representation files in the training/validation set. Each line should follow the format `<audio_file_path>\t<semantic_representation_file_path>`. If you use [scripts/hubert_rep_extract.sh](scripts/hubert_rep_extract.sh) to extract semantic representations, these two files will be generated automatically (an example of the expected format is shown after this list).
153 | * `distill_type`: Use "d_axis" for D-axis distillation loss and "t_axis" for T-axis distillation loss, as mentioned in the paper.
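For reference, a minimal sketch of what such a file list looks like, with hypothetical paths (the two columns are separated by a tab character, matching what `scripts/hubert_rep_extract.py` writes):

```
/data/LibriSpeech/train-clean-100/103/1240/103-1240-0000.flac	/data/hubert_rep/train-clean-100/103/1240/103-1240-0000.hubert.npy
/data/LibriSpeech/train-clean-100/103/1240/103-1240-0001.flac	/data/hubert_rep/train-clean-100/103/1240/103-1240-0001.hubert.npy
```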
154 |
155 | ### Quick Start
156 | If you want to fully follow our experimental setup, simply set `semantic_model_path` in [config/spt_base_cfg.json](config/spt_base_cfg.json), set `AUDIO_DIR`, `REP_DIR`, `EXTS` (and other optional arguments) in [scripts/hubert_rep_extract.sh](scripts/hubert_rep_extract.sh), and then execute the following commands:
157 | ```shell
158 | cd SpeechTokenizer
159 |
160 | # Extract semantic representations
161 | bash scripts/hubert_rep_extract.sh
162 |
163 | # Train
164 | bash scripts/train_example.sh
165 | ```
166 | ## Citation
167 | If you use this code or these results in your paper, please cite our work as:
168 | ```Tex
169 | @misc{zhang2023speechtokenizer,
170 |       title={SpeechTokenizer: Unified Speech Tokenizer for Speech Language Models},
171 |       author={Xin Zhang and Dong Zhang and Shimin Li and Yaqian Zhou and Xipeng Qiu},
172 |       year={2023},
173 |       eprint={2308.16692},
174 |       archivePrefix={arXiv},
175 |       primaryClass={cs.CL}
176 | }
177 | ```
178 | ## License
179 | The code in this repository is released under the Apache 2.0 license as found in the
180 | [LICENSE](LICENSE) file.
181 |
--------------------------------------------------------------------------------
/config/spt_base_cfg.json:
--------------------------------------------------------------------------------
1 | {
2 |     "n_filters": 64,
3 |     "strides": [8,5,4,2],
4 |     "dimension": 1024,
5 |     "semantic_dimension": 768,
6 |     "bidirectional": true,
7 |     "dilation_base": 2,
8 |     "residual_kernel_size": 3,
9 |     "n_residual_layers": 1,
10 |     "lstm_layers": 2,
11 |     "activation": "ELU",
12 |     "codebook_size": 1024,
13 |     "n_q": 8,
14 |
15 |     "num_mels": 80,
16 |     "num_freq": 1025,
17 |     "n_fft": 1024,
18 |     "hop_size": 240,
19 |     "win_size": 1024,
20 |     "fmin": 0,
21 |     "fmax": 8000,
22 |     "fmax_for_loss": null,
23 |     "mel_loss_lambdas": [45, 1, 1, 1],
24 |     "recon_loss_lambda": 500,
25 |     "commitment_loss_lambda": 10,
26 |
27 |     "distill_loss_lambda": 120,
28 |     "distill_type": "d_axis",
29 |
30 |     "semantic_model_path": "/remote-home/share/models/hubert/ls_960/hubert-base-ls960",
31 |     "semantic_model_layer": "avg",
32 |
33 |     "train_files":"train_file_list.txt",
34 |     "valid_files":"dev_file_list.txt",
35 |     "results_folder": "Log/spt_base",
36 |     "sample_rate": 16000,
37 |     "batch_size": 6,
38 |     "epochs":10,
39 |     "learning_rate": 1e-4,
40 |     "intial_learning_rate":1e-4,
41 |     "num_warmup_steps":0,
42 |     "betas":[0.9, 0.99],
43 |     "adam_b2": 0.9,
44 |     "lr_decay": 0.98,
45 |     "seed": 1234,
46 |     "segment_size": 48000,
47 |     "wd": 0,
48 |     "num_workers": 8,
49 |     "log_steps": 100,
50 |     "stdout_steps": 10,
51 |     "save_model_steps": 2500,
52 |     "num_ckpt_keep": 6,
53 |     "showpiece_num": 8
54 | }
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torchaudio
3 | import torch
4 | from speechtokenizer import SpeechTokenizer
5 | from scipy.io.wavfile import write
6 | import numpy as np
7 |
8 | from huggingface_hub import snapshot_download
9 |
10 | snapshot_download(repo_id="fnlp/SpeechTokenizer", local_dir="model_hub")
11 |
12 |
13 | # Set up argument parser
14 | parser = argparse.ArgumentParser(
15 |     description="Load SpeechTokenizer model and process audio file."
16 | ) 17 | parser.add_argument( 18 | "--config_path", 19 | type=str, 20 | help="Path to the model configuration file.", 21 | default="model_hub/speechtokenizer_hubert_avg/config.json", 22 | ) 23 | parser.add_argument( 24 | "--ckpt_path", 25 | type=str, 26 | help="Path to the model checkpoint file.", 27 | default="model_hub/speechtokenizer_hubert_avg/SpeechTokenizer.pt", 28 | ) 29 | parser.add_argument( 30 | "--speech_file", 31 | type=str, 32 | required=True, 33 | help="Path to the speech file to be processed.", 34 | ) 35 | parser.add_argument( 36 | "--output_file", 37 | type=str, 38 | help="Path to save the output audio file.", 39 | default="example_output.wav", 40 | ) 41 | 42 | args = parser.parse_args() 43 | 44 | # Load model from the specified checkpoint 45 | model = SpeechTokenizer.load_from_checkpoint(args.config_path, args.ckpt_path) 46 | model.eval() 47 | 48 | # Determine the model's expected sample rate 49 | model_sample_rate = model.sample_rate 50 | 51 | # Load and preprocess speech waveform with the model's sample rate 52 | wav, sr = torchaudio.load(args.speech_file) 53 | 54 | if sr != model_sample_rate: 55 | resample_transform = torchaudio.transforms.Resample( 56 | orig_freq=sr, new_freq=model_sample_rate 57 | ) 58 | wav = resample_transform(wav) 59 | 60 | # Ensure the waveform is monophonic 61 | if wav.shape[0] > 1: 62 | wav = wav[:1, :] 63 | 64 | wav = wav.unsqueeze(0) 65 | 66 | 67 | # Extract discrete codes from SpeechTokenizer 68 | with torch.no_grad(): 69 | codes = model.encode(wav) # codes: (n_q, B, T) 70 | 71 | RVQ_1 = codes[:1, :, :] # Contain content info, can be considered as semantic tokens 72 | RVQ_supplement = codes[ 73 | 1:, :, : 74 | ] # Contain timbre info, complete info lost by the first quantizer 75 | 76 | # Concatenating semantic tokens (RVQ_1) and supplementary timbre tokens and then decoding 77 | wav_out = model.decode(torch.cat([RVQ_1, RVQ_supplement], axis=0)) 78 | 79 | # Decoding from RVQ-i:j tokens from the ith quantizers to the jth quantizers 80 | # Example: decoding from quantizer 0 to quantizer 2 81 | wav_out = wav_out.detach().numpy() 82 | write(args.output_file, model_sample_rate, wav_out.astype(np.float32)) 83 | -------------------------------------------------------------------------------- /images/READ.me: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /images/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhangXInFD/SpeechTokenizer/30c96fb32a9fc06a2258c98119e237def051e46c/images/overview.png -------------------------------------------------------------------------------- /images/speechtokenizer_framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhangXInFD/SpeechTokenizer/30c96fb32a9fc06a2258c98119e237def051e46c/images/speechtokenizer_framework.jpg -------------------------------------------------------------------------------- /samples/example_input.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhangXInFD/SpeechTokenizer/30c96fb32a9fc06a2258c98119e237def051e46c/samples/example_input.wav -------------------------------------------------------------------------------- /scripts/hubert_rep_extract.py: -------------------------------------------------------------------------------- 1 | from 
transformers import HubertModel, Wav2Vec2FeatureExtractor
2 | from pathlib import Path
3 | import torchaudio
4 | import torch
5 | import json
6 | import argparse
7 | from tqdm import tqdm
8 | import random
9 | import numpy as np
10 | import os
11 |
12 | if __name__ == '__main__':
13 |
14 |     parser = argparse.ArgumentParser()
15 |     parser.add_argument('--config', '-c', type=str, help='Config file path')
16 |     parser.add_argument('--audio_dir', type=str, help='Audio folder path')
17 |     parser.add_argument('--rep_dir', type=str, help='Path to save representation files')
18 |     parser.add_argument('--exts', type=str, help="Audio file extensions, separated by ','", default='flac')
19 |     parser.add_argument('--split_seed', type=int, help="Random seed", default=0)
20 |     parser.add_argument('--valid_set_size', type=float, default=1000)
21 |     args = parser.parse_args()
22 |     exts = args.exts.split(',')
23 |     device = 'cuda' if torch.cuda.is_available() else 'cpu'
24 |     with open(args.config) as f:
25 |         cfg = json.load(f)
26 |     sample_rate = cfg.get('sample_rate')
27 |     feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(cfg.get('semantic_model_path'))
28 |     model = HubertModel.from_pretrained(cfg.get('semantic_model_path')).eval().to(device)
29 |     target_layer = cfg.get('semantic_model_layer')
30 |     path = Path(args.audio_dir)
31 |     file_list = [str(file) for ext in exts for file in path.glob(f'**/*.{ext}')]  # collect all audio files under audio_dir
32 |     if args.valid_set_size != 0 and args.valid_set_size < 1:
33 |         valid_set_size = int(len(file_list) * args.valid_set_size)
34 |     else:
35 |         valid_set_size = int(args.valid_set_size)
36 |     train_file_list = cfg.get('train_files')
37 |     valid_file_list = cfg.get('valid_files')
38 |     segment_size = cfg.get('segment_size')
39 |     random.seed(args.split_seed)
40 |     random.shuffle(file_list)
41 |     print(f'A total of {len(file_list)} samples will be processed, and {valid_set_size} of them will be included in the validation set.')
42 |     with torch.no_grad():
43 |         for i, audio_file in tqdm(enumerate(file_list)):
44 |             wav, sr = torchaudio.load(audio_file)
45 |             if sr != sample_rate:
46 |                 wav = torchaudio.functional.resample(wav, sr, sample_rate)
47 |             if wav.size(-1) < segment_size:
48 |                 wav = torch.nn.functional.pad(wav, (0, segment_size - wav.size(-1)), 'constant')
49 |             input_values = feature_extractor(wav.squeeze(0), sampling_rate=sample_rate, return_tensors="pt").input_values
50 |             output = model(input_values.to(model.device), output_hidden_states=True)
51 |             if target_layer == 'avg':
52 |                 rep = torch.mean(torch.stack(output.hidden_states), axis=0)  # average over all returned hidden states
53 |             else:
54 |                 rep = output.hidden_states[target_layer]
55 |             rep_file = audio_file.replace(args.audio_dir, args.rep_dir).split('.')[0] + '.hubert.npy'
56 |             rep_sub_dir = '/'.join(rep_file.split('/')[:-1])
57 |             if not os.path.exists(rep_sub_dir):
58 |                 os.makedirs(rep_sub_dir)
59 |             np.save(rep_file, rep.detach().cpu().numpy())
60 |             if i < valid_set_size:
61 |                 with open(valid_file_list, 'a+') as f:
62 |                     f.write(f'{audio_file}\t{rep_file}\n')
63 |             else:
64 |                 with open(train_file_list, 'a+') as f:
65 |                     f.write(f'{audio_file}\t{rep_file}\n')
66 |
67 |
68 |
69 |
--------------------------------------------------------------------------------
/scripts/hubert_rep_extract.sh:
--------------------------------------------------------------------------------
1 | CONFIG="config/spt_base_cfg.json"
2 | AUDIO_DIR="/remote-home/share/data/SpeechPretrain/LibriSpeech/LibriSpeech"
3 | REP_DIR="/remote-home/share/data/SpeechPretrain/hubert_rep/LibriSpeech"
4 | EXTS="flac"
5 | SPLIT_SEED=0
6 |
VALID_SET_SIZE=1500 7 | 8 | 9 | 10 | CUDA_VISIBLE_DEVICES=0 python scripts/hubert_rep_extract.py\ 11 | --config ${CONFIG}\ 12 | --audio_dir ${AUDIO_DIR}\ 13 | --rep_dir ${REP_DIR}\ 14 | --exts ${EXTS}\ 15 | --split_seed ${SPLIT_SEED}\ 16 | --valid_set_size ${VALID_SET_SIZE} -------------------------------------------------------------------------------- /scripts/train_example.py: -------------------------------------------------------------------------------- 1 | from speechtokenizer import SpeechTokenizer, SpeechTokenizerTrainer 2 | from speechtokenizer.discriminators import MultiPeriodDiscriminator, MultiScaleDiscriminator, MultiScaleSTFTDiscriminator 3 | import json 4 | import argparse 5 | 6 | 7 | if __name__ == '__main__': 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--config', '-c', type=str, help='Config file path') 11 | parser.add_argument('--continue_train', action='store_true', help='Continue to train from checkpoints') 12 | args = parser.parse_args() 13 | with open(args.config) as f: 14 | cfg = json.load(f) 15 | 16 | generator = SpeechTokenizer(cfg) 17 | discriminators = {'mpd':MultiPeriodDiscriminator(), 'msd':MultiScaleDiscriminator(), 'mstftd':MultiScaleSTFTDiscriminator(32)} 18 | 19 | trainer = SpeechTokenizerTrainer(generator=generator, 20 | discriminators=discriminators, 21 | cfg=cfg) 22 | 23 | if args.continue_train: 24 | trainer.continue_train() 25 | else: 26 | trainer.train() -------------------------------------------------------------------------------- /scripts/train_example.sh: -------------------------------------------------------------------------------- 1 | 2 | CONFIG="config/spt_base_cfg.json" 3 | 4 | 5 | # NPROC_PER_NODE=4 6 | # CUDA_VISIBLE_DEVICES=1,2,6,7 torchrun \ 7 | # --nnode 1 \ 8 | # --nproc_per_node $NPROC_PER_NODE \ 9 | # --master_port 50025 \ 10 | # train_example.py \ 11 | # --config ${CONFIG} \ 12 | 13 | CUDA_VISIBLE_DEVICES=1,2,6,7 accelerate launch scripts/train_example.py\ 14 | --config ${CONFIG}\ 15 | --continue_train -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from setuptools import setup 3 | 4 | NAME = 'speechtokenizer' 5 | DESCRIPTION = 'Unified speech tokenizer for speech language model' 6 | URL = 'https://github.com/ZhangXInFD/SpeechTokenizer' 7 | EMAIL = 'xin_zhang22@m.fudan.edu.cn' 8 | AUTHOR = 'Xin Zhang, Dong Zhang, Shimin Li, Yaqian Zhou, Xipeng Qiu' 9 | REQUIRES_PYTHON = '>=3.8.0' 10 | 11 | for line in open('speechtokenizer/__init__.py'): 12 | line = line.strip() 13 | if '__version__' in line: 14 | context = {} 15 | exec(line, context) 16 | VERSION = context['__version__'] 17 | 18 | HERE = Path(__file__).parent 19 | 20 | try: 21 | with open(HERE / "README.md", encoding='utf-8') as f: 22 | long_description = '\n' + f.read() 23 | except FileNotFoundError: 24 | long_description = DESCRIPTION 25 | 26 | setup( 27 | name=NAME, 28 | version=VERSION, 29 | description=DESCRIPTION, 30 | long_description=long_description, 31 | long_description_content_type='text/markdown', 32 | author=AUTHOR, 33 | author_email=EMAIL, 34 | python_requires=REQUIRES_PYTHON, 35 | url=URL, 36 | packages=['speechtokenizer', 'speechtokenizer.quantization', 'speechtokenizer.modules', 'speechtokenizer.trainer'], 37 | install_requires=['numpy', 'torch', 'torchaudio', 'einops','scipy','huggingface-hub','soundfile', 'matplotlib', 'lion_pytorch', 'accelerate'], 38 | include_package_data=True, 
39 | license='Apache License 2.0', 40 | classifiers=[ 41 | 'Topic :: Multimedia :: Sound/Audio', 42 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 43 | 'License :: OSI Approved :: Apache Software License', 44 | ]) 45 | -------------------------------------------------------------------------------- /speechtokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import SpeechTokenizer 2 | from .trainer import SpeechTokenizerTrainer 3 | 4 | __version__ = '1.0.0' -------------------------------------------------------------------------------- /speechtokenizer/discriminators.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn import Conv1d, AvgPool1d, Conv2d 5 | from torch.nn.utils import weight_norm, spectral_norm 6 | import typing as tp 7 | import torchaudio 8 | from einops import rearrange 9 | from .modules import NormConv2d 10 | 11 | LRELU_SLOPE = 0.1 12 | 13 | def get_padding(kernel_size, dilation=1): 14 | return int((kernel_size*dilation - dilation)/2) 15 | 16 | def init_weights(m, mean=0.0, std=0.01): 17 | classname = m.__class__.__name__ 18 | if classname.find("Conv") != -1: 19 | m.weight.data.normal_(mean, std) 20 | 21 | 22 | class DiscriminatorP(torch.nn.Module): 23 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 24 | super(DiscriminatorP, self).__init__() 25 | self.period = period 26 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 27 | self.convs = nn.ModuleList([ 28 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 29 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 30 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 31 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 32 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 33 | ]) 34 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 35 | 36 | def forward(self, x): 37 | fmap = [] 38 | 39 | # 1d to 2d 40 | b, c, t = x.shape 41 | if t % self.period != 0: # pad first 42 | n_pad = self.period - (t % self.period) 43 | x = F.pad(x, (0, n_pad), "reflect") 44 | t = t + n_pad 45 | x = x.view(b, c, t // self.period, self.period) 46 | 47 | for l in self.convs: 48 | x = l(x) 49 | x = F.leaky_relu(x, LRELU_SLOPE) 50 | fmap.append(x) 51 | x = self.conv_post(x) 52 | fmap.append(x) 53 | x = torch.flatten(x, 1, -1) 54 | 55 | return x, fmap 56 | 57 | 58 | class MultiPeriodDiscriminator(torch.nn.Module): 59 | def __init__(self): 60 | super(MultiPeriodDiscriminator, self).__init__() 61 | self.discriminators = nn.ModuleList([ 62 | DiscriminatorP(2), 63 | DiscriminatorP(3), 64 | DiscriminatorP(5), 65 | DiscriminatorP(7), 66 | DiscriminatorP(11), 67 | ]) 68 | 69 | def forward(self, y, y_hat): 70 | y_d_rs = [] 71 | y_d_gs = [] 72 | fmap_rs = [] 73 | fmap_gs = [] 74 | for i, d in enumerate(self.discriminators): 75 | y_d_r, fmap_r = d(y) 76 | y_d_g, fmap_g = d(y_hat) 77 | y_d_rs.append(y_d_r) 78 | fmap_rs.append(fmap_r) 79 | y_d_gs.append(y_d_g) 80 | fmap_gs.append(fmap_g) 81 | 82 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 83 | 84 | 85 | class DiscriminatorS(torch.nn.Module): 86 | def __init__(self, use_spectral_norm=False): 87 | super(DiscriminatorS, self).__init__() 88 | norm_f = weight_norm if 
use_spectral_norm == False else spectral_norm 89 | self.convs = nn.ModuleList([ 90 | norm_f(Conv1d(1, 128, 15, 1, padding=7)), 91 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), 92 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), 93 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), 94 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), 95 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), 96 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 97 | ]) 98 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 99 | 100 | def forward(self, x): 101 | fmap = [] 102 | for l in self.convs: 103 | x = l(x) 104 | x = F.leaky_relu(x, LRELU_SLOPE) 105 | fmap.append(x) 106 | x = self.conv_post(x) 107 | fmap.append(x) 108 | x = torch.flatten(x, 1, -1) 109 | 110 | return x, fmap 111 | 112 | 113 | class MultiScaleDiscriminator(torch.nn.Module): 114 | def __init__(self): 115 | super(MultiScaleDiscriminator, self).__init__() 116 | self.discriminators = nn.ModuleList([ 117 | DiscriminatorS(use_spectral_norm=True), 118 | DiscriminatorS(), 119 | DiscriminatorS(), 120 | ]) 121 | self.meanpools = nn.ModuleList([ 122 | AvgPool1d(4, 2, padding=2), 123 | AvgPool1d(4, 2, padding=2) 124 | ]) 125 | 126 | def forward(self, y, y_hat): 127 | y_d_rs = [] 128 | y_d_gs = [] 129 | fmap_rs = [] 130 | fmap_gs = [] 131 | for i, d in enumerate(self.discriminators): 132 | if i != 0: 133 | y = self.meanpools[i-1](y) 134 | y_hat = self.meanpools[i-1](y_hat) 135 | y_d_r, fmap_r = d(y) 136 | y_d_g, fmap_g = d(y_hat) 137 | y_d_rs.append(y_d_r) 138 | fmap_rs.append(fmap_r) 139 | y_d_gs.append(y_d_g) 140 | fmap_gs.append(fmap_g) 141 | 142 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 143 | 144 | 145 | FeatureMapType = tp.List[torch.Tensor] 146 | LogitsType = torch.Tensor 147 | DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]] 148 | 149 | 150 | def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)): 151 | return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2) 152 | 153 | 154 | class DiscriminatorSTFT(nn.Module): 155 | """STFT sub-discriminator. 156 | Args: 157 | filters (int): Number of filters in convolutions 158 | in_channels (int): Number of input channels. Default: 1 159 | out_channels (int): Number of output channels. Default: 1 160 | n_fft (int): Size of FFT for each scale. Default: 1024 161 | hop_length (int): Length of hop between STFT windows for each scale. Default: 256 162 | kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)`` 163 | stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)`` 164 | dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]`` 165 | win_length (int): Window size for each scale. Default: 1024 166 | normalized (bool): Whether to normalize by magnitude after stft. Default: True 167 | norm (str): Normalization method. Default: `'weight_norm'` 168 | activation (str): Activation function. Default: `'LeakyReLU'` 169 | activation_params (dict): Parameters to provide to the activation function. 170 | growth (int): Growth factor for the filters. 
Default: 1 171 | """ 172 | def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, 173 | n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024, 174 | filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4], 175 | stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm', 176 | activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}): 177 | super().__init__() 178 | assert len(kernel_size) == 2 179 | assert len(stride) == 2 180 | self.filters = filters 181 | self.in_channels = in_channels 182 | self.out_channels = out_channels 183 | self.n_fft = n_fft 184 | self.hop_length = hop_length 185 | self.win_length = win_length 186 | self.normalized = normalized 187 | self.activation = getattr(torch.nn, activation)(**activation_params) 188 | self.spec_transform = torchaudio.transforms.Spectrogram( 189 | n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window, 190 | normalized=self.normalized, center=False, pad_mode=None, power=None) 191 | spec_channels = 2 * self.in_channels 192 | self.convs = nn.ModuleList() 193 | self.convs.append( 194 | NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size)) 195 | ) 196 | in_chs = min(filters_scale * self.filters, max_filters) 197 | for i, dilation in enumerate(dilations): 198 | out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters) 199 | self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, 200 | dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)), 201 | norm=norm)) 202 | in_chs = out_chs 203 | out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters) 204 | self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]), 205 | padding=get_2d_padding((kernel_size[0], kernel_size[0])), 206 | norm=norm)) 207 | self.conv_post = NormConv2d(out_chs, self.out_channels, 208 | kernel_size=(kernel_size[0], kernel_size[0]), 209 | padding=get_2d_padding((kernel_size[0], kernel_size[0])), 210 | norm=norm) 211 | 212 | def forward(self, x: torch.Tensor): 213 | fmap = [] 214 | # print('x ', x.shape) 215 | z = self.spec_transform(x) # [B, 2, Freq, Frames, 2] 216 | # print('z ', z.shape) 217 | z = torch.cat([z.real, z.imag], dim=1) 218 | # print('cat_z ', z.shape) 219 | z = rearrange(z, 'b c w t -> b c t w') 220 | for i, layer in enumerate(self.convs): 221 | z = layer(z) 222 | z = self.activation(z) 223 | # print('z i', i, z.shape) 224 | fmap.append(z) 225 | z = self.conv_post(z) 226 | # print('logit ', z.shape) 227 | return z, fmap 228 | 229 | 230 | class MultiScaleSTFTDiscriminator(nn.Module): 231 | """Multi-Scale STFT (MS-STFT) discriminator. 232 | Args: 233 | filters (int): Number of filters in convolutions 234 | in_channels (int): Number of input channels. Default: 1 235 | out_channels (int): Number of output channels. 
Default: 1 236 | n_ffts (Sequence[int]): Size of FFT for each scale 237 | hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale 238 | win_lengths (Sequence[int]): Window size for each scale 239 | **kwargs: additional args for STFTDiscriminator 240 | """ 241 | def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, 242 | n_ffts: tp.List[int] = [1024, 2048, 512, 256, 128], hop_lengths: tp.List[int] = [256, 512, 128, 64, 32], 243 | win_lengths: tp.List[int] = [1024, 2048, 512, 256, 128], **kwargs): 244 | super().__init__() 245 | assert len(n_ffts) == len(hop_lengths) == len(win_lengths) 246 | self.discriminators = nn.ModuleList([ 247 | DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels, 248 | n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs) 249 | for i in range(len(n_ffts)) 250 | ]) 251 | self.num_discriminators = len(self.discriminators) 252 | 253 | def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> DiscriminatorOutput: 254 | logits = [] 255 | logits_fake = [] 256 | fmaps = [] 257 | fmaps_fake = [] 258 | for disc in self.discriminators: 259 | logit, fmap = disc(y) 260 | logits.append(logit) 261 | fmaps.append(fmap) 262 | logit_fake, fmap_fake = disc(y_hat) 263 | logits_fake.append(logit_fake) 264 | fmaps_fake.append(fmap_fake) 265 | return logits, logits_fake, fmaps, fmaps_fake -------------------------------------------------------------------------------- /speechtokenizer/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Aug 30 15:47:55 2023 4 | @author: zhangxin 5 | """ 6 | 7 | from .modules.seanet import SEANetEncoder, SEANetDecoder 8 | from .quantization import ResidualVectorQuantizer 9 | import torch.nn as nn 10 | from einops import rearrange 11 | import torch 12 | import numpy as np 13 | 14 | class SpeechTokenizer(nn.Module): 15 | def __init__(self, config): 16 | ''' 17 | 18 | Parameters 19 | ---------- 20 | config : json 21 | Model Config. 
22 | 23 | ''' 24 | super().__init__() 25 | self.encoder = SEANetEncoder(n_filters=config.get('n_filters'), 26 | dimension=config.get('dimension'), 27 | ratios=config.get('strides'), 28 | lstm=config.get('lstm_layers'), 29 | bidirectional=config.get('bidirectional'), 30 | dilation_base=config.get('dilation_base'), 31 | residual_kernel_size=config.get('residual_kernel_size'), 32 | n_residual_layers=config.get('n_residual_layers'), 33 | activation=config.get('activation')) 34 | self.sample_rate = config.get('sample_rate') 35 | self.n_q = config.get('n_q') 36 | self.downsample_rate = np.prod(config.get('strides')) 37 | if config.get('dimension') != config.get('semantic_dimension'): 38 | self.transform = nn.Linear(config.get('dimension'), config.get('semantic_dimension')) 39 | else: 40 | self.transform = nn.Identity() 41 | self.quantizer = ResidualVectorQuantizer(dimension=config.get('dimension'), n_q=config.get('n_q'), bins=config.get('codebook_size')) 42 | self.decoder = SEANetDecoder(n_filters=config.get('n_filters'), 43 | dimension=config.get('dimension'), 44 | ratios=config.get('strides'), 45 | lstm=config.get('lstm_layers'), 46 | bidirectional=False, 47 | dilation_base=config.get('dilation_base'), 48 | residual_kernel_size=config.get('residual_kernel_size'), 49 | n_residual_layers=config.get('n_residual_layers'), 50 | activation=config.get('activation')) 51 | 52 | @classmethod 53 | def load_from_checkpoint(cls, 54 | config_path: str, 55 | ckpt_path: str): 56 | ''' 57 | 58 | Parameters 59 | ---------- 60 | config_path : str 61 | Path of model configuration file. 62 | ckpt_path : str 63 | Path of model checkpoint. 64 | 65 | Returns 66 | ------- 67 | model : SpeechTokenizer 68 | SpeechTokenizer model. 69 | 70 | ''' 71 | import json 72 | with open(config_path) as f: 73 | cfg = json.load(f) 74 | model = cls(cfg) 75 | params = torch.load(ckpt_path, map_location='cpu') 76 | model.load_state_dict(params) 77 | return model 78 | 79 | 80 | def forward(self, 81 | x: torch.tensor, 82 | n_q: int=None, 83 | layers: list=[0]): 84 | ''' 85 | 86 | Parameters 87 | ---------- 88 | x : torch.tensor 89 | Input wavs. Shape: (batch, channels, timesteps). 90 | n_q : int, optional 91 | Number of quantizers in RVQ used to encode. The default is all layers. 92 | layers : list[int], optional 93 | Layers of RVQ should return quantized result. The default is the first layer. 94 | 95 | Returns 96 | ------- 97 | o : torch.tensor 98 | Output wavs. Shape: (batch, channels, timesteps). 99 | commit_loss : torch.tensor 100 | Commitment loss from residual vector quantizers. 101 | feature : torch.tensor 102 | Output of RVQ's first layer. Shape: (batch, timesteps, dimension) 103 | 104 | ''' 105 | n_q = n_q if n_q else self.n_q 106 | e = self.encoder(x) 107 | quantized, codes, commit_loss, quantized_list = self.quantizer(e, n_q=n_q, layers=layers) 108 | feature = rearrange(quantized_list[0], 'b d t -> b t d') 109 | feature = self.transform(feature) 110 | o = self.decoder(quantized) 111 | return o, commit_loss, feature 112 | 113 | def forward_feature(self, 114 | x: torch.tensor, 115 | layers: list=None): 116 | ''' 117 | 118 | Parameters 119 | ---------- 120 | x : torch.tensor 121 | Input wavs. Shape should be (batch, channels, timesteps). 122 | layers : list[int], optional 123 | Layers of RVQ should return quantized result. The default is all layers. 124 | 125 | Returns 126 | ------- 127 | quantized_list : list[torch.tensor] 128 | Quantized of required layers. 
129 | 130 | ''' 131 | e = self.encoder(x) 132 | layers = layers if layers else list(range(self.n_q)) 133 | quantized, codes, commit_loss, quantized_list = self.quantizer(e, layers=layers) 134 | return quantized_list 135 | 136 | def encode(self, 137 | x: torch.tensor, 138 | n_q: int=None, 139 | st: int=None): 140 | ''' 141 | 142 | Parameters 143 | ---------- 144 | x : torch.tensor 145 | Input wavs. Shape: (batch, channels, timesteps). 146 | n_q : int, optional 147 | Number of quantizers in RVQ used to encode. The default is all layers. 148 | st : int, optional 149 | Start quantizer index in RVQ. The default is 0. 150 | 151 | Returns 152 | ------- 153 | codes : torch.tensor 154 | Output indices for each quantizer. Shape: (n_q, batch, timesteps) 155 | 156 | ''' 157 | e = self.encoder(x) 158 | if st is None: 159 | st = 0 160 | n_q = n_q if n_q else self.n_q 161 | codes = self.quantizer.encode(e, n_q=n_q, st=st) 162 | return codes 163 | 164 | def decode(self, 165 | codes: torch.tensor, 166 | st: int=0): 167 | ''' 168 | 169 | Parameters 170 | ---------- 171 | codes : torch.tensor 172 | Indices for each quantizer. Shape: (n_q, batch, timesteps). 173 | st : int, optional 174 | Start quantizer index in RVQ. The default is 0. 175 | 176 | Returns 177 | ------- 178 | o : torch.tensor 179 | Reconstruct wavs from codes. Shape: (batch, channels, timesteps) 180 | 181 | ''' 182 | quantized = self.quantizer.decode(codes, st=st) 183 | o = self.decoder(quantized) 184 | return o -------------------------------------------------------------------------------- /speechtokenizer/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Torch modules.""" 8 | 9 | # flake8: noqa 10 | from .conv import ( 11 | pad1d, 12 | unpad1d, 13 | NormConv1d, 14 | NormConvTranspose1d, 15 | NormConv2d, 16 | NormConvTranspose2d, 17 | SConv1d, 18 | SConvTranspose1d, 19 | ) 20 | from .lstm import SLSTM 21 | from .seanet import SEANetEncoder, SEANetDecoder -------------------------------------------------------------------------------- /speechtokenizer/modules/conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Convolutional layers wrappers and utilities.""" 8 | 9 | import math 10 | import typing as tp 11 | import warnings 12 | 13 | import torch 14 | from torch import nn 15 | from torch.nn import functional as F 16 | from torch.nn.utils import spectral_norm, weight_norm 17 | 18 | from .norm import ConvLayerNorm 19 | 20 | 21 | CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm', 22 | 'time_layer_norm', 'layer_norm', 'time_group_norm']) 23 | 24 | 25 | def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module: 26 | assert norm in CONV_NORMALIZATIONS 27 | if norm == 'weight_norm': 28 | return weight_norm(module) 29 | elif norm == 'spectral_norm': 30 | return spectral_norm(module) 31 | else: 32 | # We already check was in CONV_NORMALIZATION, so any other choice 33 | # doesn't need reparametrization. 
34 | return module 35 | 36 | 37 | def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module: 38 | """Return the proper normalization module. If causal is True, this will ensure the returned 39 | module is causal, or return an error if the normalization doesn't support causal evaluation. 40 | """ 41 | assert norm in CONV_NORMALIZATIONS 42 | if norm == 'layer_norm': 43 | assert isinstance(module, nn.modules.conv._ConvNd) 44 | return ConvLayerNorm(module.out_channels, **norm_kwargs) 45 | elif norm == 'time_group_norm': 46 | if causal: 47 | raise ValueError("GroupNorm doesn't support causal evaluation.") 48 | assert isinstance(module, nn.modules.conv._ConvNd) 49 | return nn.GroupNorm(1, module.out_channels, **norm_kwargs) 50 | else: 51 | return nn.Identity() 52 | 53 | 54 | def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, 55 | padding_total: int = 0) -> int: 56 | """See `pad_for_conv1d`. 57 | """ 58 | length = x.shape[-1] 59 | n_frames = (length - kernel_size + padding_total) / stride + 1 60 | ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) 61 | return ideal_length - length 62 | 63 | 64 | def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): 65 | """Pad for a convolution to make sure that the last window is full. 66 | Extra padding is added at the end. This is required to ensure that we can rebuild 67 | an output of the same length, as otherwise, even with padding, some time steps 68 | might get removed. 69 | For instance, with total padding = 4, kernel size = 4, stride = 2: 70 | 0 0 1 2 3 4 5 0 0 # (0s are padding) 71 | 1 2 3 # (output frames of a convolution, last 0 is never used) 72 | 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) 73 | 1 2 3 4 # once you removed padding, we are missing one time step ! 74 | """ 75 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) 76 | return F.pad(x, (0, extra_padding)) 77 | 78 | 79 | def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.): 80 | """Tiny wrapper around F.pad, just to allow for reflect padding on small input. 81 | If this is the case, we insert extra 0 padding to the right before the reflection happen. 82 | """ 83 | length = x.shape[-1] 84 | padding_left, padding_right = paddings 85 | assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) 86 | if mode == 'reflect': 87 | max_pad = max(padding_left, padding_right) 88 | extra_pad = 0 89 | if length <= max_pad: 90 | extra_pad = max_pad - length + 1 91 | x = F.pad(x, (0, extra_pad)) 92 | padded = F.pad(x, paddings, mode, value) 93 | end = padded.shape[-1] - extra_pad 94 | return padded[..., :end] 95 | else: 96 | return F.pad(x, paddings, mode, value) 97 | 98 | 99 | def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): 100 | """Remove padding from x, handling properly zero padding. Only for 1d!""" 101 | padding_left, padding_right = paddings 102 | assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) 103 | assert (padding_left + padding_right) <= x.shape[-1] 104 | end = x.shape[-1] - padding_right 105 | return x[..., padding_left: end] 106 | 107 | 108 | class NormConv1d(nn.Module): 109 | """Wrapper around Conv1d and normalization applied to this conv 110 | to provide a uniform interface across normalization approaches. 
111 | """ 112 | def __init__(self, *args, causal: bool = False, norm: str = 'none', 113 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): 114 | super().__init__() 115 | self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) 116 | self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) 117 | self.norm_type = norm 118 | 119 | def forward(self, x): 120 | x = self.conv(x) 121 | x = self.norm(x) 122 | return x 123 | 124 | 125 | class NormConv2d(nn.Module): 126 | """Wrapper around Conv2d and normalization applied to this conv 127 | to provide a uniform interface across normalization approaches. 128 | """ 129 | def __init__(self, *args, norm: str = 'none', 130 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): 131 | super().__init__() 132 | self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) 133 | self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) 134 | self.norm_type = norm 135 | 136 | def forward(self, x): 137 | x = self.conv(x) 138 | x = self.norm(x) 139 | return x 140 | 141 | 142 | class NormConvTranspose1d(nn.Module): 143 | """Wrapper around ConvTranspose1d and normalization applied to this conv 144 | to provide a uniform interface across normalization approaches. 145 | """ 146 | def __init__(self, *args, causal: bool = False, norm: str = 'none', 147 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): 148 | super().__init__() 149 | self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm) 150 | self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) 151 | self.norm_type = norm 152 | 153 | def forward(self, x): 154 | x = self.convtr(x) 155 | x = self.norm(x) 156 | return x 157 | 158 | 159 | class NormConvTranspose2d(nn.Module): 160 | """Wrapper around ConvTranspose2d and normalization applied to this conv 161 | to provide a uniform interface across normalization approaches. 162 | """ 163 | def __init__(self, *args, norm: str = 'none', 164 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): 165 | super().__init__() 166 | self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm) 167 | self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) 168 | 169 | def forward(self, x): 170 | x = self.convtr(x) 171 | x = self.norm(x) 172 | return x 173 | 174 | 175 | class SConv1d(nn.Module): 176 | """Conv1d with some builtin handling of asymmetric or causal padding 177 | and normalization. 
178 | """ 179 | def __init__(self, in_channels: int, out_channels: int, 180 | kernel_size: int, stride: int = 1, dilation: int = 1, 181 | groups: int = 1, bias: bool = True, causal: bool = False, 182 | norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, 183 | pad_mode: str = 'reflect'): 184 | super().__init__() 185 | # warn user on unusual setup between dilation and stride 186 | if stride > 1 and dilation > 1: 187 | warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1' 188 | f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).') 189 | self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride, 190 | dilation=dilation, groups=groups, bias=bias, causal=causal, 191 | norm=norm, norm_kwargs=norm_kwargs) 192 | self.causal = causal 193 | self.pad_mode = pad_mode 194 | 195 | def forward(self, x): 196 | B, C, T = x.shape 197 | kernel_size = self.conv.conv.kernel_size[0] 198 | stride = self.conv.conv.stride[0] 199 | dilation = self.conv.conv.dilation[0] 200 | padding_total = (kernel_size - 1) * dilation - (stride - 1) 201 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) 202 | if self.causal: 203 | # Left padding for causal 204 | x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) 205 | else: 206 | # Asymmetric padding required for odd strides 207 | padding_right = padding_total // 2 208 | padding_left = padding_total - padding_right 209 | x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode) 210 | return self.conv(x) 211 | 212 | 213 | class SConvTranspose1d(nn.Module): 214 | """ConvTranspose1d with some builtin handling of asymmetric or causal padding 215 | and normalization. 216 | """ 217 | def __init__(self, in_channels: int, out_channels: int, 218 | kernel_size: int, stride: int = 1, causal: bool = False, 219 | norm: str = 'none', trim_right_ratio: float = 1., 220 | norm_kwargs: tp.Dict[str, tp.Any] = {}): 221 | super().__init__() 222 | self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride, 223 | causal=causal, norm=norm, norm_kwargs=norm_kwargs) 224 | self.causal = causal 225 | self.trim_right_ratio = trim_right_ratio 226 | assert self.causal or self.trim_right_ratio == 1., \ 227 | "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" 228 | assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1. 229 | 230 | def forward(self, x): 231 | kernel_size = self.convtr.convtr.kernel_size[0] 232 | stride = self.convtr.convtr.stride[0] 233 | padding_total = kernel_size - stride 234 | 235 | y = self.convtr(x) 236 | 237 | # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be 238 | # removed at the very end, when keeping only the right length for the output, 239 | # as removing it here would require also passing the length at the matching layer 240 | # in the encoder. 
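        # Worked example (illustrative): with kernel_size=4 and stride=2,
        # padding_total = 2. The causal branch below with trim_right_ratio=1.0
        # trims (left=0, right=2); the non-causal branch trims (left=1, right=1).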
241 |         if self.causal:
242 |             # Trim the padding on the right according to the specified ratio
243 |             # if trim_right_ratio = 1.0, trim everything from right
244 |             padding_right = math.ceil(padding_total * self.trim_right_ratio)
245 |             padding_left = padding_total - padding_right
246 |             y = unpad1d(y, (padding_left, padding_right))
247 |         else:
248 |             # Asymmetric padding required for odd strides
249 |             padding_right = padding_total // 2
250 |             padding_left = padding_total - padding_right
251 |             y = unpad1d(y, (padding_left, padding_right))
252 |         return y
253 | 
--------------------------------------------------------------------------------
/speechtokenizer/modules/lstm.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | # All rights reserved.
 3 | #
 4 | # This source code is licensed under the license found in the
 5 | # LICENSE file in the root directory of this source tree.
 6 | 
 7 | """LSTM layers module."""
 8 | 
 9 | from torch import nn
10 | 
11 | 
12 | class SLSTM(nn.Module):
13 |     """
14 |     LSTM without worrying about the hidden state, nor the layout of the data.
15 |     Expects input as convolutional layout.
16 |     """
17 |     def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True, bidirectional: bool = False):
18 |         super().__init__()
19 |         self.bidirectional = bidirectional
20 |         self.skip = skip
21 |         self.lstm = nn.LSTM(dimension, dimension, num_layers, bidirectional=bidirectional)
22 | 
23 |     def forward(self, x):
24 |         x = x.permute(2, 0, 1)
25 |         y, _ = self.lstm(x)
26 |         if self.bidirectional:
27 |             x = x.repeat(1, 1, 2)
28 |         if self.skip:
29 |             y = y + x
30 |         y = y.permute(1, 2, 0)
31 |         return y
32 | 
--------------------------------------------------------------------------------
/speechtokenizer/modules/norm.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | # All rights reserved.
 3 | #
 4 | # This source code is licensed under the license found in the
 5 | # LICENSE file in the root directory of this source tree.
 6 | 
 7 | """Normalization modules."""
 8 | 
 9 | import typing as tp
10 | 
11 | import einops
12 | import torch
13 | from torch import nn
14 | 
15 | 
16 | class ConvLayerNorm(nn.LayerNorm):
17 |     """
18 |     Convolution-friendly LayerNorm that moves channels to last dimensions
19 |     before running the normalization and moves them back to original position right after.
20 |     """
21 |     def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
22 |         super().__init__(normalized_shape, **kwargs)
23 | 
24 |     def forward(self, x):
25 |         x = einops.rearrange(x, 'b ... t -> b t ...')
26 |         x = super().forward(x)
27 |         x = einops.rearrange(x, 'b t ... -> b ... t')
28 |         return x
29 | 
--------------------------------------------------------------------------------
/speechtokenizer/modules/seanet.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | # All rights reserved.
 3 | #
 4 | # This source code is licensed under the license found in the
 5 | # LICENSE file in the root directory of this source tree.
 6 | 
 7 | """Encodec SEANet-based encoder and decoder implementation."""
 8 | 
 9 | import typing as tp
10 | 
11 | import numpy as np
12 | import torch.nn as nn
13 | import torch
14 | 
15 | from . 
import ( 16 | SConv1d, 17 | SConvTranspose1d, 18 | SLSTM 19 | ) 20 | 21 | 22 | @torch.jit.script 23 | def snake(x, alpha): 24 | shape = x.shape 25 | x = x.reshape(shape[0], shape[1], -1) 26 | x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) 27 | x = x.reshape(shape) 28 | return x 29 | 30 | 31 | class Snake1d(nn.Module): 32 | def __init__(self, channels): 33 | super().__init__() 34 | self.alpha = nn.Parameter(torch.ones(1, channels, 1)) 35 | 36 | def forward(self, x): 37 | return snake(x, self.alpha) 38 | 39 | class SEANetResnetBlock(nn.Module): 40 | """Residual block from SEANet model. 41 | Args: 42 | dim (int): Dimension of the input/output 43 | kernel_sizes (list): List of kernel sizes for the convolutions. 44 | dilations (list): List of dilations for the convolutions. 45 | activation (str): Activation function. 46 | activation_params (dict): Parameters to provide to the activation function 47 | norm (str): Normalization method. 48 | norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. 49 | causal (bool): Whether to use fully causal convolution. 50 | pad_mode (str): Padding mode for the convolutions. 51 | compress (int): Reduced dimensionality in residual branches (from Demucs v3) 52 | true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection. 53 | """ 54 | def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1], 55 | activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, 56 | norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False, 57 | pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True): 58 | super().__init__() 59 | assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations' 60 | act = getattr(nn, activation) if activation != 'Snake' else Snake1d 61 | hidden = dim // compress 62 | block = [] 63 | for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): 64 | in_chs = dim if i == 0 else hidden 65 | out_chs = dim if i == len(kernel_sizes) - 1 else hidden 66 | block += [ 67 | act(**activation_params) if activation != 'Snake' else act(in_chs), 68 | SConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation, 69 | norm=norm, norm_kwargs=norm_params, 70 | causal=causal, pad_mode=pad_mode), 71 | ] 72 | self.block = nn.Sequential(*block) 73 | self.shortcut: nn.Module 74 | if true_skip: 75 | self.shortcut = nn.Identity() 76 | else: 77 | self.shortcut = SConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params, 78 | causal=causal, pad_mode=pad_mode) 79 | 80 | def forward(self, x): 81 | return self.shortcut(x) + self.block(x) 82 | 83 | 84 | 85 | class SEANetEncoder(nn.Module): 86 | """SEANet encoder. 87 | Args: 88 | channels (int): Audio channels. 89 | dimension (int): Intermediate representation dimension. 90 | n_filters (int): Base width for the model. 91 | n_residual_layers (int): nb of residual layers. 92 | ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of 93 | upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here 94 | that must match the decoder order 95 | activation (str): Activation function. 96 | activation_params (dict): Parameters to provide to the activation function 97 | norm (str): Normalization method. 
98 |         norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
99 |         kernel_size (int): Kernel size for the initial convolution.
100 |         last_kernel_size (int): Kernel size for the last convolution.
101 |         residual_kernel_size (int): Kernel size for the residual layers.
102 |         dilation_base (int): How much to increase the dilation with each layer.
103 |         causal (bool): Whether to use fully causal convolution.
104 |         pad_mode (str): Padding mode for the convolutions.
105 |         true_skip (bool): Whether to use true skip connection or a simple
106 |             (streamable) convolution as the skip connection in the residual network blocks.
107 |         compress (int): Reduced dimensionality in residual branches (from Demucs v3).
108 |         lstm (int): Number of LSTM layers at the end of the encoder.
109 |     """
110 |     def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1,
111 |                  ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
112 |                  norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
113 |                  last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
114 |                  pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2, bidirectional: bool = False):
115 |         super().__init__()
116 |         self.channels = channels
117 |         self.dimension = dimension
118 |         self.n_filters = n_filters
119 |         self.ratios = list(reversed(ratios))
120 |         del ratios
121 |         self.n_residual_layers = n_residual_layers
122 |         self.hop_length = np.prod(self.ratios)  # product of the stride ratios
123 | 
124 |         act = getattr(nn, activation) if activation != 'Snake' else Snake1d
125 |         mult = 1
126 |         model: tp.List[nn.Module] = [
127 |             SConv1d(channels, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params,
128 |                     causal=causal, pad_mode=pad_mode)
129 |         ]
130 |         # Downsample to raw audio scale
131 |         for i, ratio in enumerate(self.ratios):
132 |             # Add residual layers
133 |             for j in range(n_residual_layers):
134 |                 model += [
135 |                     SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
136 |                                       dilations=[dilation_base ** j, 1],
137 |                                       norm=norm, norm_params=norm_params,
138 |                                       activation=activation, activation_params=activation_params,
139 |                                       causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
140 | 
141 |             # Add downsampling layers
142 |             model += [
143 |                 act(**activation_params) if activation != 'Snake' else act(mult * n_filters),
144 |                 SConv1d(mult * n_filters, mult * n_filters * 2,
145 |                         kernel_size=ratio * 2, stride=ratio,
146 |                         norm=norm, norm_kwargs=norm_params,
147 |                         causal=causal, pad_mode=pad_mode),
148 |             ]
149 |             mult *= 2
150 | 
151 |         if lstm:
152 |             model += [SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)]
153 | 
154 |         mult = mult * 2 if bidirectional else mult
155 |         model += [
156 |             act(**activation_params) if activation != 'Snake' else act(mult * n_filters),
157 |             SConv1d(mult * n_filters, dimension, last_kernel_size, norm=norm, norm_kwargs=norm_params,
158 |                     causal=causal, pad_mode=pad_mode)
159 |         ]
160 | 
161 |         self.model = nn.Sequential(*model)
162 | 
163 |     def forward(self, x):
164 |         return self.model(x)
165 | 
166 | 
167 | class SEANetDecoder(nn.Module):
168 |     """SEANet decoder.
169 |     Args:
170 |         channels (int): Audio channels.
171 |         dimension (int): Intermediate representation dimension.
172 |         n_filters (int): Base width for the model.
173 | n_residual_layers (int): nb of residual layers. 174 | ratios (Sequence[int]): kernel size and stride ratios 175 | activation (str): Activation function. 176 | activation_params (dict): Parameters to provide to the activation function 177 | final_activation (str): Final activation function after all convolutions. 178 | final_activation_params (dict): Parameters to provide to the activation function 179 | norm (str): Normalization method. 180 | norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. 181 | kernel_size (int): Kernel size for the initial convolution. 182 | last_kernel_size (int): Kernel size for the initial convolution. 183 | residual_kernel_size (int): Kernel size for the residual layers. 184 | dilation_base (int): How much to increase the dilation with each layer. 185 | causal (bool): Whether to use fully causal convolution. 186 | pad_mode (str): Padding mode for the convolutions. 187 | true_skip (bool): Whether to use true skip connection or a simple 188 | (streamable) convolution as the skip connection in the residual network blocks. 189 | compress (int): Reduced dimensionality in residual branches (from Demucs v3). 190 | lstm (int): Number of LSTM layers at the end of the encoder. 191 | trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup. 192 | If equal to 1.0, it means that all the trimming is done at the right. 193 | """ 194 | def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1, 195 | ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, 196 | final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None, 197 | norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, 198 | last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, 199 | pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2, 200 | trim_right_ratio: float = 1.0, bidirectional:bool = False): 201 | super().__init__() 202 | self.dimension = dimension 203 | self.channels = channels 204 | self.n_filters = n_filters 205 | self.ratios = ratios 206 | del ratios 207 | self.n_residual_layers = n_residual_layers 208 | self.hop_length = np.prod(self.ratios) 209 | 210 | act = getattr(nn, activation) if activation != 'Snake' else Snake1d 211 | mult = int(2 ** len(self.ratios)) 212 | model: tp.List[nn.Module] = [ 213 | SConv1d(dimension, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, 214 | causal=causal, pad_mode=pad_mode) 215 | ] 216 | 217 | if lstm: 218 | model += [SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)] 219 | 220 | # Upsample to raw audio scale 221 | for i, ratio in enumerate(self.ratios): 222 | # Add upsampling layers 223 | model += [ 224 | act(**activation_params) if activation != 'Snake' else act(mult * n_filters), 225 | SConvTranspose1d(mult * n_filters, mult * n_filters // 2, 226 | kernel_size=ratio * 2, stride=ratio, 227 | norm=norm, norm_kwargs=norm_params, 228 | causal=causal, trim_right_ratio=trim_right_ratio), 229 | ] 230 | # Add residual layers 231 | for j in range(n_residual_layers): 232 | model += [ 233 | SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1], 234 | dilations=[dilation_base ** j, 1], 235 | activation=activation, activation_params=activation_params, 236 | 
                                      norm=norm, norm_params=norm_params, causal=causal,
237 |                                       pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
238 | 
239 |             mult //= 2
240 | 
241 |         # Add final layers
242 |         model += [
243 |             act(**activation_params) if activation != 'Snake' else act(n_filters),
244 |             SConv1d(n_filters, channels, last_kernel_size, norm=norm, norm_kwargs=norm_params,
245 |                     causal=causal, pad_mode=pad_mode)
246 |         ]
247 |         # Add optional final activation to decoder (eg. tanh)
248 |         if final_activation is not None:
249 |             final_act = getattr(nn, final_activation)
250 |             final_activation_params = final_activation_params or {}
251 |             model += [
252 |                 final_act(**final_activation_params)
253 |             ]
254 |         self.model = nn.Sequential(*model)
255 | 
256 |     def forward(self, z):
257 |         y = self.model(z)
258 |         return y
259 | 
260 | 
261 | def test():
262 |     import torch
263 |     encoder = SEANetEncoder()
264 |     decoder = SEANetDecoder()
265 |     x = torch.randn(1, 1, 24000)
266 |     z = encoder(x)
267 |     print('z ', z.shape)
268 | 
269 |     assert list(z.shape) == [1, 128, 75], z.shape
270 |     y = decoder(z)
271 |     assert y.shape == x.shape, (x.shape, y.shape)
272 | 
273 | 
274 | if __name__ == '__main__':
275 |     test()
276 | 
--------------------------------------------------------------------------------
/speechtokenizer/quantization/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | # All rights reserved.
 3 | #
 4 | # This source code is licensed under the license found in the
 5 | # LICENSE file in the root directory of this source tree.
 6 | 
 7 | # flake8: noqa
 8 | from .vq import QuantizedResult, ResidualVectorQuantizer
 9 | 
--------------------------------------------------------------------------------
/speechtokenizer/quantization/ac.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | # All rights reserved.
 3 | #
 4 | # This source code is licensed under the license found in the
 5 | # LICENSE file in the root directory of this source tree.
 6 | 
 7 | """Arithmetic coder."""
 8 | 
 9 | import io
10 | import math
11 | import random
12 | import typing as tp
13 | import torch
14 | 
15 | from ..binary import BitPacker, BitUnpacker
16 | 
17 | 
18 | def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int,
19 |                                roundoff: float = 1e-8, min_range: int = 2,
20 |                                check: bool = True) -> torch.Tensor:
21 |     """Turn the given PDF into a quantized CDF that splits
22 |     [0, 2 ** total_range_bits - 1] into chunks of size roughly proportional
23 |     to the PDF.
24 | 
25 |     Args:
26 |         pdf (torch.Tensor): probability distribution, shape should be `[N]`.
27 |         total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
28 |             during the coding process is `[0, 2 ** total_range_bits - 1]`.
29 |         roundoff (float): will round the pdf up to that level to remove difference coming
30 |             from e.g. evaluating the Language Model on different architectures.
31 |         min_range (int): minimum range width. Should always be at least 2 for numerical
32 |             stability. Use this to avoid pathological behavior if a value
33 |             that is expected to be rare actually happens in real life.
34 |         check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
35 |     """
36 |     pdf = pdf.detach()
37 |     if roundoff:
38 |         pdf = (pdf / roundoff).floor() * roundoff
39 |     # interpolate with uniform distribution to achieve desired minimum probability.
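    # Worked example (illustrative): with total_range_bits=8 (a range of 256),
    # min_range=2 and pdf=[0.75, 0.25]: alpha = 2 * 2 / 256, each symbol gets
    # floor((1 - alpha) * 256 * p) + 2 slots, i.e. [191, 65], so the quantized
    # CDF computed below becomes [191, 256].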
40 | total_range = 2 ** total_range_bits 41 | cardinality = len(pdf) 42 | alpha = min_range * cardinality / total_range 43 | assert alpha <= 1, "you must reduce min_range" 44 | ranges = (((1 - alpha) * total_range) * pdf).floor().long() 45 | ranges += min_range 46 | quantized_cdf = torch.cumsum(ranges, dim=-1) 47 | if min_range < 2: 48 | raise ValueError("min_range must be at least 2.") 49 | if check: 50 | assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1] 51 | if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range: 52 | raise ValueError("You must increase your total_range_bits.") 53 | return quantized_cdf 54 | 55 | 56 | class ArithmeticCoder: 57 | """ArithmeticCoder, 58 | Let us take a distribution `p` over `N` symbols, and assume we have a stream 59 | of random variables `s_t` sampled from `p`. Let us assume that we have a budget 60 | of `B` bits that we can afford to write on device. There are `2**B` possible numbers, 61 | corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single 62 | sequence `(s_t)` by doing the following: 63 | 64 | 1) Initialize the current range to` [0 ** 2 B - 1]`. 65 | 2) For each time step t, split the current range into contiguous chunks, 66 | one for each possible outcome, with size roughly proportional to `p`. 67 | For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks 68 | would be `{[0, 2], [3, 3]}`. 69 | 3) Select the chunk corresponding to `s_t`, and replace the current range with this. 70 | 4) When done encoding all the values, just select any value remaining in the range. 71 | 72 | You will notice that this procedure can fail: for instance if at any point in time 73 | the range is smaller than `N`, then we can no longer assign a non-empty chunk to each 74 | possible outcome. Intuitively, the more likely a value is, the less the range width 75 | will reduce, and the longer we can go on encoding values. This makes sense: for any efficient 76 | coding scheme, likely outcomes would take less bits, and more of them can be coded 77 | with a fixed budget. 78 | 79 | In practice, we do not know `B` ahead of time, but we have a way to inject new bits 80 | when the current range decreases below a given limit (given by `total_range_bits`), without 81 | having to redo all the computations. If we encode mostly likely values, we will seldom 82 | need to inject new bits, but a single rare value can deplete our stock of entropy! 83 | 84 | In this explanation, we assumed that the distribution `p` was constant. In fact, the present 85 | code works for any sequence `(p_t)` possibly different for each timestep. 86 | We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller 87 | the KL between the true distribution and `p_t`, the most efficient the coding will be. 88 | 89 | Args: 90 | fo (IO[bytes]): file-like object to which the bytes will be written to. 91 | total_range_bits (int): the range `M` described above is `2 ** total_range_bits. 92 | Any time the current range width fall under this limit, new bits will 93 | be injected to rescale the initial range. 94 | """ 95 | 96 | def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): 97 | assert total_range_bits <= 30 98 | self.total_range_bits = total_range_bits 99 | self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time. 
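        # State sketch: [low, high] is the current integer sub-range and max_bit
        # is the index of its most significant tracked bit; push() first widens
        # the range until delta >= 2 ** total_range_bits before coding a symbol.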
100 | self.low: int = 0 101 | self.high: int = 0 102 | self.max_bit: int = -1 103 | self._dbg: tp.List[tp.Any] = [] 104 | self._dbg2: tp.List[tp.Any] = [] 105 | 106 | @property 107 | def delta(self) -> int: 108 | """Return the current range width.""" 109 | return self.high - self.low + 1 110 | 111 | def _flush_common_prefix(self): 112 | # If self.low and self.high start with the sames bits, 113 | # those won't change anymore as we always just increase the range 114 | # by powers of 2, and we can flush them out to the bit stream. 115 | assert self.high >= self.low, (self.low, self.high) 116 | assert self.high < 2 ** (self.max_bit + 1) 117 | while self.max_bit >= 0: 118 | b1 = self.low >> self.max_bit 119 | b2 = self.high >> self.max_bit 120 | if b1 == b2: 121 | self.low -= (b1 << self.max_bit) 122 | self.high -= (b1 << self.max_bit) 123 | assert self.high >= self.low, (self.high, self.low, self.max_bit) 124 | assert self.low >= 0 125 | self.max_bit -= 1 126 | self.packer.push(b1) 127 | else: 128 | break 129 | 130 | def push(self, symbol: int, quantized_cdf: torch.Tensor): 131 | """Push the given symbol on the stream, flushing out bits 132 | if possible. 133 | 134 | Args: 135 | symbol (int): symbol to encode with the AC. 136 | quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` 137 | to build this from your pdf estimate. 138 | """ 139 | while self.delta < 2 ** self.total_range_bits: 140 | self.low *= 2 141 | self.high = self.high * 2 + 1 142 | self.max_bit += 1 143 | 144 | range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item() 145 | range_high = quantized_cdf[symbol].item() - 1 146 | effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))) 147 | effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))) 148 | assert self.low <= self.high 149 | self.high = self.low + effective_high 150 | self.low = self.low + effective_low 151 | assert self.low <= self.high, (effective_low, effective_high, range_low, range_high) 152 | self._dbg.append((self.low, self.high)) 153 | self._dbg2.append((self.low, self.high)) 154 | outs = self._flush_common_prefix() 155 | assert self.low <= self.high 156 | assert self.max_bit >= -1 157 | assert self.max_bit <= 61, self.max_bit 158 | return outs 159 | 160 | def flush(self): 161 | """Flush the remaining information to the stream. 162 | """ 163 | while self.max_bit >= 0: 164 | b1 = (self.low >> self.max_bit) & 1 165 | self.packer.push(b1) 166 | self.max_bit -= 1 167 | self.packer.flush() 168 | 169 | 170 | class ArithmeticDecoder: 171 | """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation. 172 | 173 | Note that this must be called with **exactly** the same parameters and sequence 174 | of quantized cdf as the arithmetic encoder or the wrong values will be decoded. 175 | 176 | If the AC encoder current range is [L, H], with `L` and `H` having the some common 177 | prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream. 178 | For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside 179 | `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained 180 | for a specific sequence of symbols and a binary-search allows us to decode those symbols. 181 | At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols, 182 | and we will need to read new bits from the stream and repeat the process. 
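
    A round-trip sketch (reusing names from the ``test()`` function at the end of
    this file; the decoder must see exactly the same quantized CDFs as the encoder):

        fo.seek(0)
        decoder = ArithmeticDecoder(fo)
        symbol = decoder.pull(q_cdf)  # returns None once the stream is exhausted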
183 | 184 | """ 185 | def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): 186 | self.total_range_bits = total_range_bits 187 | self.low: int = 0 188 | self.high: int = 0 189 | self.current: int = 0 190 | self.max_bit: int = -1 191 | self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time. 192 | # Following is for debugging 193 | self._dbg: tp.List[tp.Any] = [] 194 | self._dbg2: tp.List[tp.Any] = [] 195 | self._last: tp.Any = None 196 | 197 | @property 198 | def delta(self) -> int: 199 | return self.high - self.low + 1 200 | 201 | def _flush_common_prefix(self): 202 | # Given the current range [L, H], if both have a common prefix, 203 | # we know we can remove it from our representation to avoid handling large numbers. 204 | while self.max_bit >= 0: 205 | b1 = self.low >> self.max_bit 206 | b2 = self.high >> self.max_bit 207 | if b1 == b2: 208 | self.low -= (b1 << self.max_bit) 209 | self.high -= (b1 << self.max_bit) 210 | self.current -= (b1 << self.max_bit) 211 | assert self.high >= self.low 212 | assert self.low >= 0 213 | self.max_bit -= 1 214 | else: 215 | break 216 | 217 | def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]: 218 | """Pull a symbol, reading as many bits from the stream as required. 219 | This returns `None` when the stream has been exhausted. 220 | 221 | Args: 222 | quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` 223 | to build this from your pdf estimate. This must be **exatly** 224 | the same cdf as the one used at encoding time. 225 | """ 226 | while self.delta < 2 ** self.total_range_bits: 227 | bit = self.unpacker.pull() 228 | if bit is None: 229 | return None 230 | self.low *= 2 231 | self.high = self.high * 2 + 1 232 | self.current = self.current * 2 + bit 233 | self.max_bit += 1 234 | 235 | def bin_search(low_idx: int, high_idx: int): 236 | # Binary search is not just for coding interviews :) 237 | if high_idx < low_idx: 238 | raise RuntimeError("Binary search failed") 239 | mid = (low_idx + high_idx) // 2 240 | range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0 241 | range_high = quantized_cdf[mid].item() - 1 242 | effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))) 243 | effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))) 244 | low = effective_low + self.low 245 | high = effective_high + self.low 246 | if self.current >= low: 247 | if self.current <= high: 248 | return (mid, low, high, self.current) 249 | else: 250 | return bin_search(mid + 1, high_idx) 251 | else: 252 | return bin_search(low_idx, mid - 1) 253 | 254 | self._last = (self.low, self.high, self.current, self.max_bit) 255 | sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1) 256 | self._dbg.append((self.low, self.high, self.current)) 257 | self._flush_common_prefix() 258 | self._dbg2.append((self.low, self.high, self.current)) 259 | 260 | return sym 261 | 262 | 263 | def test(): 264 | torch.manual_seed(1234) 265 | random.seed(1234) 266 | for _ in range(4): 267 | pdfs = [] 268 | cardinality = random.randrange(4000) 269 | steps = random.randrange(100, 500) 270 | fo = io.BytesIO() 271 | encoder = ArithmeticCoder(fo) 272 | symbols = [] 273 | for step in range(steps): 274 | pdf = torch.softmax(torch.randn(cardinality), dim=0) 275 | pdfs.append(pdf) 276 | q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) 277 | symbol = torch.multinomial(pdf, 1).item() 278 | symbols.append(symbol) 279 | encoder.push(symbol, q_cdf) 280 | 
encoder.flush() 281 | 282 | fo.seek(0) 283 | decoder = ArithmeticDecoder(fo) 284 | for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)): 285 | q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) 286 | decoded_symbol = decoder.pull(q_cdf) 287 | assert decoded_symbol == symbol, idx 288 | assert decoder.pull(torch.zeros(1)) is None 289 | 290 | 291 | if __name__ == "__main__": 292 | test() 293 | -------------------------------------------------------------------------------- /speechtokenizer/quantization/core_vq.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # This implementation is inspired from 8 | # https://github.com/lucidrains/vector-quantize-pytorch 9 | # which is released under MIT License. Hereafter, the original license: 10 | # MIT License 11 | # 12 | # Copyright (c) 2020 Phil Wang 13 | # 14 | # Permission is hereby granted, free of charge, to any person obtaining a copy 15 | # of this software and associated documentation files (the "Software"), to deal 16 | # in the Software without restriction, including without limitation the rights 17 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 18 | # copies of the Software, and to permit persons to whom the Software is 19 | # furnished to do so, subject to the following conditions: 20 | # 21 | # The above copyright notice and this permission notice shall be included in all 22 | # copies or substantial portions of the Software. 23 | # 24 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 25 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 26 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 27 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 28 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 29 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 30 | # SOFTWARE. 
31 | 32 | """Core vector quantization implementation.""" 33 | import typing as tp 34 | 35 | from einops import rearrange, repeat 36 | import torch 37 | from torch import nn 38 | import torch.nn.functional as F 39 | 40 | from .distrib import broadcast_tensors, rank 41 | 42 | 43 | def default(val: tp.Any, d: tp.Any) -> tp.Any: 44 | return val if val is not None else d 45 | 46 | 47 | def ema_inplace(moving_avg, new, decay: float): 48 | moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) 49 | 50 | 51 | def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): 52 | return (x + epsilon) / (x.sum() + n_categories * epsilon) 53 | 54 | 55 | def uniform_init(*shape: int): 56 | t = torch.empty(shape) 57 | nn.init.kaiming_uniform_(t) 58 | return t 59 | 60 | 61 | def sample_vectors(samples, num: int): 62 | num_samples, device = samples.shape[0], samples.device 63 | 64 | if num_samples >= num: 65 | indices = torch.randperm(num_samples, device=device)[:num] 66 | else: 67 | indices = torch.randint(0, num_samples, (num,), device=device) 68 | 69 | return samples[indices] 70 | 71 | 72 | def kmeans(samples, num_clusters: int, num_iters: int = 10): 73 | dim, dtype = samples.shape[-1], samples.dtype 74 | 75 | means = sample_vectors(samples, num_clusters) 76 | 77 | for _ in range(num_iters): 78 | diffs = rearrange(samples, "n d -> n () d") - rearrange( 79 | means, "c d -> () c d" 80 | ) 81 | dists = -(diffs ** 2).sum(dim=-1) 82 | 83 | buckets = dists.max(dim=-1).indices 84 | bins = torch.bincount(buckets, minlength=num_clusters) 85 | zero_mask = bins == 0 86 | bins_min_clamped = bins.masked_fill(zero_mask, 1) 87 | 88 | new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) 89 | new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) 90 | new_means = new_means / bins_min_clamped[..., None] 91 | 92 | means = torch.where(zero_mask[..., None], means, new_means) 93 | 94 | return means, bins 95 | 96 | 97 | class EuclideanCodebook(nn.Module): 98 | """Codebook with Euclidean distance. 99 | Args: 100 | dim (int): Dimension. 101 | codebook_size (int): Codebook size. 102 | kmeans_init (bool): Whether to use k-means to initialize the codebooks. 103 | If set to true, run the k-means algorithm on the first training batch and use 104 | the learned centroids as initialization. 105 | kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. 106 | decay (float): Decay for exponential moving average over the codebooks. 107 | epsilon (float): Epsilon value for numerical stability. 108 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes 109 | that have an exponential moving average cluster size less than the specified threshold with 110 | randomly selected vector from the current batch. 
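
    A minimal usage sketch (dimensions are illustrative):

        codebook = EuclideanCodebook(dim=8, codebook_size=16, kmeans_init=False)
        x = torch.randn(2, 50, 8)           # (batch, time, dim)
        quantized, indices = codebook(x)    # quantized: (2, 50, 8), indices: (2, 50)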
111 | """ 112 | def __init__( 113 | self, 114 | dim: int, 115 | codebook_size: int, 116 | kmeans_init: int = False, 117 | kmeans_iters: int = 10, 118 | decay: float = 0.99, 119 | epsilon: float = 1e-5, 120 | threshold_ema_dead_code: int = 2, 121 | ): 122 | super().__init__() 123 | self.decay = decay 124 | init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros 125 | embed = init_fn(codebook_size, dim) 126 | 127 | self.codebook_size = codebook_size 128 | 129 | self.kmeans_iters = kmeans_iters 130 | self.epsilon = epsilon 131 | self.threshold_ema_dead_code = threshold_ema_dead_code 132 | 133 | self.register_buffer("inited", torch.Tensor([not kmeans_init])) 134 | self.register_buffer("cluster_size", torch.zeros(codebook_size)) 135 | self.register_buffer("embed", embed) 136 | self.register_buffer("embed_avg", embed.clone()) 137 | 138 | @torch.jit.ignore 139 | def init_embed_(self, data): 140 | if self.inited: 141 | return 142 | 143 | embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) 144 | self.embed.data.copy_(embed) 145 | self.embed_avg.data.copy_(embed.clone()) 146 | self.cluster_size.data.copy_(cluster_size) 147 | self.inited.data.copy_(torch.Tensor([True])) 148 | # Make sure all buffers across workers are in sync after initialization 149 | #broadcast_tensors(self.buffers()) 150 | 151 | def replace_(self, samples, mask): 152 | modified_codebook = torch.where( 153 | mask[..., None], sample_vectors(samples, self.codebook_size), self.embed 154 | ) 155 | self.embed.data.copy_(modified_codebook) 156 | 157 | def expire_codes_(self, batch_samples): 158 | if self.threshold_ema_dead_code == 0: 159 | return 160 | 161 | expired_codes = self.cluster_size < self.threshold_ema_dead_code 162 | if not torch.any(expired_codes): 163 | return 164 | 165 | batch_samples = rearrange(batch_samples, "... d -> (...) d") 166 | self.replace_(batch_samples, mask=expired_codes) 167 | #broadcast_tensors(self.buffers()) 168 | 169 | def preprocess(self, x): 170 | x = rearrange(x, "... d -> (...) d") 171 | return x 172 | 173 | def quantize(self, x): 174 | embed = self.embed.t() 175 | dist = -( 176 | x.pow(2).sum(1, keepdim=True) 177 | - 2 * x @ embed 178 | + embed.pow(2).sum(0, keepdim=True) 179 | ) 180 | embed_ind = dist.max(dim=-1).indices 181 | return embed_ind 182 | 183 | def postprocess_emb(self, embed_ind, shape): 184 | return embed_ind.view(*shape[:-1]) 185 | 186 | def dequantize(self, embed_ind): 187 | quantize = F.embedding(embed_ind, self.embed) 188 | return quantize 189 | 190 | def encode(self, x): 191 | shape = x.shape 192 | # pre-process 193 | x = self.preprocess(x) 194 | # quantize 195 | embed_ind = self.quantize(x) 196 | # post-process 197 | embed_ind = self.postprocess_emb(embed_ind, shape) 198 | return embed_ind 199 | 200 | def decode(self, embed_ind): 201 | quantize = self.dequantize(embed_ind) 202 | return quantize 203 | 204 | def forward(self, x): 205 | shape, dtype = x.shape, x.dtype 206 | x = self.preprocess(x) 207 | 208 | self.init_embed_(x) 209 | 210 | embed_ind = self.quantize(x) 211 | embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) 212 | embed_ind = self.postprocess_emb(embed_ind, shape) 213 | quantize = self.dequantize(embed_ind) 214 | 215 | if self.training: 216 | # We do the expiry of code at that point as buffers are in sync 217 | # and all the workers will take the same decision. 
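            # EMA update sketch of the lines below: cluster_size and embed_avg
            # decay towards the current batch statistics, then each code is
            # re-estimated as embed_avg / cluster_size with Laplace smoothing so
            # that empty clusters never divide by zero.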
218 | self.expire_codes_(x) 219 | ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) 220 | embed_sum = x.t() @ embed_onehot 221 | ema_inplace(self.embed_avg, embed_sum.t(), self.decay) 222 | cluster_size = ( 223 | laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) 224 | * self.cluster_size.sum() 225 | ) 226 | embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) 227 | self.embed.data.copy_(embed_normalized) 228 | 229 | return quantize, embed_ind 230 | 231 | 232 | class VectorQuantization(nn.Module): 233 | """Vector quantization implementation. 234 | Currently supports only euclidean distance. 235 | Args: 236 | dim (int): Dimension 237 | codebook_size (int): Codebook size 238 | codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. 239 | decay (float): Decay for exponential moving average over the codebooks. 240 | epsilon (float): Epsilon value for numerical stability. 241 | kmeans_init (bool): Whether to use kmeans to initialize the codebooks. 242 | kmeans_iters (int): Number of iterations used for kmeans initialization. 243 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes 244 | that have an exponential moving average cluster size less than the specified threshold with 245 | randomly selected vector from the current batch. 246 | commitment_weight (float): Weight for commitment loss. 247 | """ 248 | def __init__( 249 | self, 250 | dim: int, 251 | codebook_size: int, 252 | codebook_dim: tp.Optional[int] = None, 253 | decay: float = 0.99, 254 | epsilon: float = 1e-5, 255 | kmeans_init: bool = True, 256 | kmeans_iters: int = 50, 257 | threshold_ema_dead_code: int = 2, 258 | commitment_weight: float = 1., 259 | ): 260 | super().__init__() 261 | _codebook_dim: int = default(codebook_dim, dim) 262 | 263 | requires_projection = _codebook_dim != dim 264 | self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()) 265 | self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()) 266 | 267 | self.epsilon = epsilon 268 | self.commitment_weight = commitment_weight 269 | 270 | self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, 271 | kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, 272 | decay=decay, epsilon=epsilon, 273 | threshold_ema_dead_code=threshold_ema_dead_code) 274 | self.codebook_size = codebook_size 275 | 276 | @property 277 | def codebook(self): 278 | return self._codebook.embed 279 | 280 | def encode(self, x): 281 | x = rearrange(x, "b d n -> b n d") 282 | x = self.project_in(x) 283 | embed_in = self._codebook.encode(x) 284 | return embed_in 285 | 286 | def decode(self, embed_ind): 287 | quantize = self._codebook.decode(embed_ind) 288 | quantize = self.project_out(quantize) 289 | quantize = rearrange(quantize, "b n d -> b d n") 290 | return quantize 291 | 292 | def forward(self, x): 293 | device = x.device 294 | x = rearrange(x, "b d n -> b n d") 295 | x = self.project_in(x) 296 | 297 | quantize, embed_ind = self._codebook(x) 298 | 299 | if self.training: 300 | quantize = x + (quantize - x).detach() 301 | 302 | loss = torch.tensor([0.0], device=device, requires_grad=self.training) 303 | 304 | if self.training: 305 | if self.commitment_weight > 0: 306 | commit_loss = F.mse_loss(quantize.detach(), x) 307 | loss = loss + commit_loss * self.commitment_weight 308 | 309 | quantize = self.project_out(quantize) 310 | quantize = rearrange(quantize, "b n d -> b d n") 311 | return 
quantize, embed_ind, loss 312 | 313 | 314 | class ResidualVectorQuantization(nn.Module): 315 | """Residual vector quantization implementation. 316 | Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf 317 | """ 318 | def __init__(self, *, num_quantizers, **kwargs): 319 | super().__init__() 320 | self.layers = nn.ModuleList( 321 | [VectorQuantization(**kwargs) for _ in range(num_quantizers)] 322 | ) 323 | 324 | def forward(self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None): 325 | quantized_out = 0.0 326 | residual = x 327 | 328 | all_losses = [] 329 | all_indices = [] 330 | out_quantized = [] 331 | 332 | n_q = n_q or len(self.layers) 333 | 334 | for i, layer in enumerate(self.layers[:n_q]): 335 | quantized, indices, loss = layer(residual) 336 | residual = residual - quantized 337 | quantized_out = quantized_out + quantized 338 | 339 | all_indices.append(indices) 340 | all_losses.append(loss) 341 | if layers and i in layers: 342 | out_quantized.append(quantized) 343 | 344 | out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) 345 | return quantized_out, out_indices, out_losses, out_quantized 346 | 347 | def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int]= None) -> torch.Tensor: 348 | residual = x 349 | all_indices = [] 350 | n_q = n_q or len(self.layers) 351 | st = st or 0 352 | for layer in self.layers[st:n_q]: 353 | indices = layer.encode(residual) 354 | quantized = layer.decode(indices) 355 | residual = residual - quantized 356 | all_indices.append(indices) 357 | out_indices = torch.stack(all_indices) 358 | return out_indices 359 | 360 | def decode(self, q_indices: torch.Tensor, st: int=0) -> torch.Tensor: 361 | quantized_out = torch.tensor(0.0, device=q_indices.device) 362 | for i, indices in enumerate(q_indices): 363 | layer = self.layers[st + i] 364 | quantized = layer.decode(indices) 365 | quantized_out = quantized_out + quantized 366 | return quantized_out 367 | -------------------------------------------------------------------------------- /speechtokenizer/quantization/distrib.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Torch distributed utilities.""" 8 | 9 | import typing as tp 10 | 11 | import torch 12 | 13 | 14 | def rank(): 15 | if torch.distributed.is_initialized(): 16 | return torch.distributed.get_rank() 17 | else: 18 | return 0 19 | 20 | 21 | def world_size(): 22 | if torch.distributed.is_initialized(): 23 | return torch.distributed.get_world_size() 24 | else: 25 | return 1 26 | 27 | 28 | def is_distributed(): 29 | return world_size() > 1 30 | 31 | 32 | def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM): 33 | if is_distributed(): 34 | return torch.distributed.all_reduce(tensor, op) 35 | 36 | 37 | def _is_complex_or_float(tensor): 38 | return torch.is_floating_point(tensor) or torch.is_complex(tensor) 39 | 40 | 41 | def _check_number_of_params(params: tp.List[torch.Tensor]): 42 | # utility function to check that the number of params in all workers is the same, 43 | # and thus avoid a deadlock with distributed all reduce. 
44 |     if not is_distributed() or not params:
45 |         return
46 |     #print('params[0].device ', params[0].device)
47 |     tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
48 |     all_reduce(tensor)
49 |     if tensor.item() != len(params) * world_size():
50 |         # If not all the workers have the same number, for at least one of them,
51 |         # this inequality will be verified.
52 |         raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
53 |                            "at least one worker has a different one.")
54 | 
55 | 
56 | def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
57 |     """Broadcast the tensors from the given parameters to all workers.
58 |     This can be used to ensure that all workers have the same model to start with.
59 |     """
60 |     if not is_distributed():
61 |         return
62 |     tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
63 |     _check_number_of_params(tensors)
64 |     handles = []
65 |     for tensor in tensors:
66 |         # src = int(rank()) # added code
67 |         handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
68 |         handles.append(handle)
69 |     for handle in handles:
70 |         handle.wait()
71 | 
72 | 
73 | def sync_buffer(buffers, average=True):
74 |     """
75 |     Sync buffers across workers. If average is False, broadcast instead of averaging.
76 |     """
77 |     if not is_distributed():
78 |         return
79 |     handles = []
80 |     for buffer in buffers:
81 |         if torch.is_floating_point(buffer.data):
82 |             if average:
83 |                 handle = torch.distributed.all_reduce(
84 |                     buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
85 |             else:
86 |                 handle = torch.distributed.broadcast(
87 |                     buffer.data, src=0, async_op=True)
88 |             handles.append((buffer, handle))
89 |     for buffer, handle in handles:
90 |         handle.wait()
91 |         if average:
92 |             buffer.data /= world_size()
93 | 
94 | 
95 | def sync_grad(params):
96 |     """
97 |     Simpler alternative to DistributedDataParallel, that doesn't rely
98 |     on any black magic. For simple models it can also be as fast.
99 |     Just call this on your model parameters after the call to backward!
100 |     """
101 |     if not is_distributed():
102 |         return
103 |     handles = []
104 |     for p in params:
105 |         if p.grad is not None:
106 |             handle = torch.distributed.all_reduce(
107 |                 p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
108 |             handles.append((p, handle))
109 |     for p, handle in handles:
110 |         handle.wait()
111 |         p.grad.data /= world_size()
112 | 
113 | 
114 | def average_metrics(metrics: tp.Dict[str, float], count=1.):
115 |     """Average a dictionary of metrics across all workers, using the optional
116 |     `count` as unnormalized weight.
117 |     """
118 |     if not is_distributed():
119 |         return metrics
120 |     keys, values = zip(*metrics.items())
121 |     device = 'cuda' if torch.cuda.is_available() else 'cpu'
122 |     tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
123 |     tensor *= count
124 |     all_reduce(tensor)
125 |     averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
126 |     return dict(zip(keys, averaged))
127 | 
--------------------------------------------------------------------------------
/speechtokenizer/quantization/vq.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | # All rights reserved.
 3 | #
 4 | # This source code is licensed under the license found in the
 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | """Residual vector quantizer implementation.""" 8 | 9 | from dataclasses import dataclass, field 10 | import math 11 | import typing as tp 12 | 13 | import torch 14 | from torch import nn 15 | 16 | from .core_vq import ResidualVectorQuantization 17 | 18 | 19 | @dataclass 20 | class QuantizedResult: 21 | quantized: torch.Tensor 22 | codes: torch.Tensor 23 | bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. 24 | penalty: tp.Optional[torch.Tensor] = None 25 | metrics: dict = field(default_factory=dict) 26 | 27 | 28 | class ResidualVectorQuantizer(nn.Module): 29 | """Residual Vector Quantizer. 30 | Args: 31 | dimension (int): Dimension of the codebooks. 32 | n_q (int): Number of residual vector quantizers used. 33 | bins (int): Codebook size. 34 | decay (float): Decay for exponential moving average over the codebooks. 35 | kmeans_init (bool): Whether to use kmeans to initialize the codebooks. 36 | kmeans_iters (int): Number of iterations used for kmeans initialization. 37 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes 38 | that have an exponential moving average cluster size less than the specified threshold with 39 | randomly selected vector from the current batch. 40 | """ 41 | def __init__( 42 | self, 43 | dimension: int = 256, 44 | n_q: int = 8, 45 | bins: int = 1024, 46 | decay: float = 0.99, 47 | kmeans_init: bool = True, 48 | kmeans_iters: int = 50, 49 | threshold_ema_dead_code: int = 2, 50 | ): 51 | super().__init__() 52 | self.n_q = n_q 53 | self.dimension = dimension 54 | self.bins = bins 55 | self.decay = decay 56 | self.kmeans_init = kmeans_init 57 | self.kmeans_iters = kmeans_iters 58 | self.threshold_ema_dead_code = threshold_ema_dead_code 59 | self.vq = ResidualVectorQuantization( 60 | dim=self.dimension, 61 | codebook_size=self.bins, 62 | num_quantizers=self.n_q, 63 | decay=self.decay, 64 | kmeans_init=self.kmeans_init, 65 | kmeans_iters=self.kmeans_iters, 66 | threshold_ema_dead_code=self.threshold_ema_dead_code, 67 | ) 68 | 69 | def forward(self, x: torch.Tensor, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None) -> QuantizedResult: 70 | """Residual vector quantization on the given input tensor. 71 | Args: 72 | x (torch.Tensor): Input tensor. 73 | n_q (int): Number of quantizer used to quantize. Default: All quantizers. 74 | layers (list): Layer that need to return quantized. Defalt: None. 75 | Returns: 76 | QuantizedResult: 77 | The quantized (or approximately quantized) representation with 78 | the associated numbert quantizers and layer quantized required to return. 79 | """ 80 | n_q = n_q if n_q else self.n_q 81 | if layers and max(layers) >= n_q: 82 | raise ValueError(f'Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B.') 83 | quantized, codes, commit_loss, quantized_list = self.vq(x, n_q=n_q, layers=layers) 84 | return quantized, codes, torch.mean(commit_loss), quantized_list 85 | 86 | 87 | def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor: 88 | """Encode a given input tensor with the specified sample rate at the given bandwidth. 89 | The RVQ encode method sets the appropriate number of quantizer to use 90 | and returns indices for each quantizer. 91 | Args: 92 | x (torch.Tensor): Input tensor. 93 | n_q (int): Number of quantizer used to quantize. Default: All quantizers. 94 | st (int): Start to encode input from which layers. Default: 0. 
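
        A minimal sketch (tensor shapes are illustrative):

            rvq = ResidualVectorQuantizer(dimension=256, n_q=8)
            x = torch.randn(1, 256, 75)    # (batch, dimension, time)
            codes = rvq.encode(x, n_q=4)   # -> (4, 1, 75) indices
            x_hat = rvq.decode(codes)      # -> (1, 256, 75)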
95 | """ 96 | n_q = n_q if n_q else self.n_q 97 | st = st or 0 98 | codes = self.vq.encode(x, n_q=n_q, st=st) 99 | return codes 100 | 101 | def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor: 102 | """Decode the given codes to the quantized representation. 103 | Args: 104 | codes (torch.Tensor): Input indices for each quantizer. 105 | st (int): Start to decode input codes from which layers. Default: 0. 106 | """ 107 | quantized = self.vq.decode(codes, st=st) 108 | return quantized 109 | -------------------------------------------------------------------------------- /speechtokenizer/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import SpeechTokenizerTrainer -------------------------------------------------------------------------------- /speechtokenizer/trainer/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | from torch.nn.utils.rnn import pad_sequence 3 | import torchaudio 4 | import random 5 | import torch 6 | import numpy as np 7 | 8 | def collate_fn(data): 9 | # return pad_sequence(data, batch_first=True) 10 | # return pad_sequence(*data) 11 | is_one_data = not isinstance(data[0], tuple) 12 | outputs = [] 13 | if is_one_data: 14 | for datum in data: 15 | if isinstance(datum, torch.Tensor): 16 | output = datum.unsqueeze(0) 17 | else: 18 | output = torch.tensor([datum]) 19 | outputs.append(output) 20 | return tuple(outputs) 21 | for datum in zip(*data): 22 | if isinstance(datum[0], torch.Tensor): 23 | output = pad_sequence(datum, batch_first=True) 24 | else: 25 | output = torch.tensor(list(datum)) 26 | outputs.append(output) 27 | 28 | return tuple(outputs) 29 | 30 | def get_dataloader(ds, **kwargs): 31 | return DataLoader(ds, collate_fn=collate_fn, **kwargs) 32 | 33 | class audioDataset(Dataset): 34 | 35 | def __init__(self, 36 | file_list, 37 | segment_size, 38 | sample_rate, 39 | downsample_rate = 320, 40 | valid=False): 41 | super().__init__() 42 | self.file_list = file_list 43 | self.segment_size = segment_size 44 | self.sample_rate = sample_rate 45 | self.valid = valid 46 | self.downsample_rate = downsample_rate 47 | 48 | def __len__(self): 49 | return len(self.file_list) 50 | 51 | 52 | def __getitem__(self, index): 53 | file = self.file_list[index].strip() 54 | audio_file, feature_file = file.split('\t') 55 | audio, sr = torchaudio.load(audio_file) 56 | feature = torch.from_numpy(np.load(feature_file)) 57 | audio = audio.mean(axis=0) 58 | if sr != self.sample_rate: 59 | audio = torchaudio.functional.resample(audio, sr, self.sample_rate) 60 | if audio.size(-1) > self.segment_size: 61 | if self.valid: 62 | return audio[:self.segment_size], feature[:self.segment_size // self.downsample_rate] 63 | max_audio_start = audio.size(-1) - self.segment_size 64 | audio_start = random.randint(0, max_audio_start) 65 | audio = audio[audio_start:audio_start+self.segment_size] 66 | feature_start = min(int(audio_start / self.downsample_rate), feature.size(0) - self.segment_size // self.downsample_rate) 67 | feature = feature[feature_start:feature_start + self.segment_size // self.downsample_rate, :] 68 | else: 69 | if not self.valid: 70 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(-1)), 'constant') 71 | return audio, feature -------------------------------------------------------------------------------- /speechtokenizer/trainer/loss.py: 
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torchaudio
 3 | import matplotlib.pyplot as plt
 4 | 
 5 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
 6 |     return torch.log(torch.clamp(x, min=clip_val) * C)
 7 | 
 8 | 
 9 | def spectral_normalize_torch(magnitudes):
10 |     output = dynamic_range_compression_torch(magnitudes)
11 |     return output
12 | 
13 | 
14 | mel_basis = {}
15 | hann_window = {}
16 | 
17 | 
18 | def mel_spectrogram(y, n_fft, num_mels, sample_rate, hop_size, win_size, fmin, fmax, center=False):
19 | 
20 |     global mel_basis, hann_window
21 |     if fmax not in mel_basis:
22 |         mel_transform = torchaudio.transforms.MelScale(n_mels=num_mels, sample_rate=sample_rate, n_stft=n_fft//2+1, f_min=fmin, f_max=fmax, norm='slaney', mel_scale="htk")
23 |         mel_basis[str(fmax)+'_'+str(y.device)] = mel_transform.fb.float().T.to(y.device)
24 |         hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
25 | 
26 |     y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
27 |     y = y.squeeze(1)
28 | 
29 |     spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
30 |                       center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
31 |     spec = torch.abs(spec) + 1e-9
32 |     spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
33 |     spec = spectral_normalize_torch(spec)
34 | 
35 |     return spec
36 | 
37 | def plot_spectrogram(spectrogram):
38 |     fig, ax = plt.subplots(figsize=(10, 2))
39 |     im = ax.imshow(spectrogram, aspect="auto", origin="lower",
40 |                    interpolation='none')
41 |     plt.colorbar(im, ax=ax)
42 | 
43 |     fig.canvas.draw()
44 |     plt.close()
45 | 
46 |     return fig
47 | 
48 | def recon_loss(x, x_hat):
49 |     length = min(x.size(-1), x_hat.size(-1))
50 |     return torch.nn.functional.l1_loss(x[:, :, :length], x_hat[:, :, :length])
51 | 
52 | def mel_loss(x, x_hat, **kwargs):
53 |     x_mel = mel_spectrogram(x.squeeze(1), **kwargs)
54 |     x_hat_mel = mel_spectrogram(x_hat.squeeze(1), **kwargs)
55 |     length = min(x_mel.size(2), x_hat_mel.size(2))
56 |     return torch.nn.functional.l1_loss(x_mel[:, :, :length], x_hat_mel[:, :, :length])
57 | 
58 | def feature_loss(fmap_r, fmap_g):
59 |     loss = 0
60 |     for dr, dg in zip(fmap_r, fmap_g):
61 |         for rl, gl in zip(dr, dg):
62 |             loss += torch.mean(torch.abs(rl - gl))
63 | 
64 |     return loss*2
65 | 
66 | 
67 | def discriminator_loss(disc_real_outputs, disc_generated_outputs):
68 |     loss = 0
69 |     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
70 |         r_loss = torch.mean((1-dr)**2)
71 |         g_loss = torch.mean(dg**2)
72 |         loss += (r_loss + g_loss)
73 | 
74 |     return loss
75 | 
76 | 
77 | def adversarial_loss(disc_outputs):
78 |     loss = 0
79 |     for dg in disc_outputs:
80 |         l = torch.mean((1-dg)**2)
81 |         loss += l
82 | 
83 |     return loss
84 | 
85 | 
86 | def d_axis_distill_loss(feature, target_feature):
87 |     n = min(feature.size(1), target_feature.size(1))
88 |     distill_loss = - torch.log(torch.sigmoid(torch.nn.functional.cosine_similarity(feature[:, :n], target_feature[:, :n], axis=1))).mean()
89 |     return distill_loss
90 | 
91 | def t_axis_distill_loss(feature, target_feature, lambda_sim=1):
92 |     n = min(feature.size(1), target_feature.size(1))
93 |     l1_loss = torch.nn.functional.l1_loss(feature[:, :n], target_feature[:, :n], reduction='mean')
94 |     sim_loss = - torch.log(torch.sigmoid(torch.nn.functional.cosine_similarity(feature[:, :n], target_feature[:, :n], axis=-1))).mean()
95 |     distill_loss = l1_loss 
/speechtokenizer/trainer/optimizer.py:
--------------------------------------------------------------------------------
from lion_pytorch import Lion
from torch.optim import AdamW, Adam

def separate_weight_decayable_params(params):
    wd_params, no_wd_params = [], []
    for param in params:
        param_list = no_wd_params if param.ndim < 2 else wd_params
        param_list.append(param)
    return wd_params, no_wd_params

def get_optimizer(
    params,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.99),
    eps = 1e-8,
    filter_by_requires_grad = False,
    group_wd_params = True,
    use_lion = False,
    **kwargs
):
    has_wd = wd > 0

    if filter_by_requires_grad:
        params = list(filter(lambda t: t.requires_grad, params))

    if group_wd_params and has_wd:
        wd_params, no_wd_params = separate_weight_decayable_params(params)

        params = [
            {'params': wd_params},
            {'params': no_wd_params, 'weight_decay': 0},
        ]

    if use_lion:
        return Lion(params, lr = lr, betas = betas, weight_decay = wd)

    if not has_wd:
        return Adam(params, lr = lr, betas = betas, eps = eps)

    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
--------------------------------------------------------------------------------
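A small, hypothetical usage sketch for get_optimizer (the model and hyper-parameters below are placeholders; note that importing this module requires lion_pytorch to be installed even when use_lion is False):

from torch import nn

from speechtokenizer.trainer.optimizer import get_optimizer

model = nn.Sequential(nn.Linear(64, 64), nn.LayerNorm(64))
optim = get_optimizer(model.parameters(), lr=1e-4, wd=1e-2, betas=(0.9, 0.99))

# With wd > 0, parameters are split into two groups: matrices (ndim >= 2) get
# weight decay, while biases and norm scales (ndim < 2) do not.
print([len(group['params']) for group in optim.param_groups])  # -> [1, 3]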
/speechtokenizer/trainer/trainer.py:
--------------------------------------------------------------------------------
from pathlib import Path
import re
import os
import itertools

from beartype import beartype

import torch
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR

from .dataset import get_dataloader, audioDataset
from .optimizer import get_optimizer
from torch.utils import tensorboard
from .loss import *
import json
from speechtokenizer import SpeechTokenizer
import time
from tqdm import tqdm
from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs, DataLoaderConfiguration


# helpers

def exists(val):
    return val is not None

def cycle(dl):
    while True:
        for data in dl:
            yield data

def cast_tuple(t):
    return t if isinstance(t, (tuple, list)) else (t,)


def accum_log(log, new_logs):
    for key, new_value in new_logs.items():
        old_value = log.get(key, 0.)
        log[key] = old_value + new_value
    return log

def checkpoint_num_steps(checkpoint_path):
    """Returns the number of steps trained from a checkpoint based on the filename.

    Filename format is assumed to be something like
    "/path/to/SpeechTokenizerTrainer_00020000", which corresponds to 20k train
    steps. Returns 20000 in that case.
    """
    results = re.findall(r'\d+', str(checkpoint_path))

    if len(results) == 0:
        return 0

    return int(results[-1])


class SpeechTokenizerTrainer(nn.Module):
    @beartype
    def __init__(
        self,
        generator: SpeechTokenizer,
        discriminators: dict,
        cfg,
        accelerate_kwargs: dict = dict(),
    ):
        super().__init__()
        ddp_kwargs = DistributedDataParallelKwargs()
        torch.manual_seed(cfg.get('seed'))
        split_batches = cfg.get("split_batches", False)
        self.log_steps = cfg.get('log_steps')
        self.stdout_steps = cfg.get('stdout_steps')
        self.save_model_steps = cfg.get('save_model_steps')
        results_folder = cfg.get('results_folder')
        self.results_folder = Path(results_folder)
        self.num_ckpt_keep = cfg.get("num_ckpt_keep")
        self.epochs = cfg.get("epochs")
        self.num_warmup_steps = cfg.get("num_warmup_steps")
        self.batch_size = cfg.get("batch_size")
        self.sample_rate = cfg.get('sample_rate')
        self.showpiece_num = cfg.get('showpiece_num', 8)
        project_name = 'SpeechTokenizer'

        if not self.results_folder.exists():
            self.results_folder.mkdir(parents=True, exist_ok=True)

        with open(f'{str(self.results_folder)}/config.json', 'w+') as f:
            json.dump(cfg, f, ensure_ascii=False, indent=4)

        # tracker = AudioTensorBoardTracker(run_name=project_name, logging_dir=results_folder)
        dataloader_config = DataLoaderConfiguration(split_batches=split_batches)
        self.accelerator = Accelerator(
            dataloader_config=dataloader_config,
            kwargs_handlers=[ddp_kwargs],
            # log_with=tracker,
            **accelerate_kwargs
        )

        if self.is_main:
            self.writer = tensorboard.SummaryWriter(os.path.join(results_folder, 'logs'))

        self.generator = generator
        self.discriminators = discriminators

        self.register_buffer('steps', torch.Tensor([0]))

        self.mel_loss_lambdas = cfg.get('mel_loss_lambdas')
        self.commitment_loss_lambda = cfg.get('commitment_loss_lambda')
        self.recon_loss_lambda = cfg.get('recon_loss_lambda')
        self.distill_loss_lambda = cfg.get('distill_loss_lambda')
        distill_type = cfg.get('distill_type', 'd_axis')
        if distill_type == 't_axis':
            from functools import partial
            lambda_sim = cfg.get('lambda_sim', 1)
            self.distill_loss = partial(t_axis_distill_loss, lambda_sim=lambda_sim)
        else:
            self.distill_loss = d_axis_distill_loss
        # mel losses at several resolutions: each successive entry halves n_fft, hop and window
        self.mel_loss_kwargs_list = []
        mult = 1
        for i in range(len(self.mel_loss_lambdas)):
            self.mel_loss_kwargs_list.append({'n_fft': cfg.get('n_fft') // mult, 'num_mels': cfg.get('num_mels'), 'sample_rate': self.sample_rate,
                                              'hop_size': cfg.get('hop_size') // mult, 'win_size': cfg.get('win_size') // mult, 'fmin': cfg.get('fmin'),
                                              'fmax': cfg.get('fmax_for_loss')})
            mult = mult * 2
        self.mel_kwargs = {'n_fft': cfg.get('n_fft'), 'num_mels': cfg.get('num_mels'), 'sample_rate': self.sample_rate,
                           'hop_size': cfg.get('hop_size'), 'win_size': cfg.get('win_size'), 'fmin': cfg.get('fmin'),
                           'fmax': cfg.get('fmax')}

        # max grad norm
        # self.max_grad_norm = max_grad_norm

        segment_size = cfg.get("segment_size")
        train_files = cfg.get("train_files")
        batch_size = cfg.get("batch_size")
        self.batch_size = batch_size
        with open(train_files, 'r') as f:
            train_file_list = f.readlines()
        valid_files = cfg.get("valid_files")
        with open(valid_files, 'r') as f:
            valid_file_list = f.readlines()

        self.ds = audioDataset(file_list=train_file_list,
                               segment_size=segment_size,
                               downsample_rate=generator.downsample_rate,
                               sample_rate=self.sample_rate)
        self.valid_ds = audioDataset(file_list=valid_file_list,
                                     segment_size=self.sample_rate * 30,
                                     downsample_rate=generator.downsample_rate,
                                     sample_rate=self.sample_rate,
                                     valid=True)
        if self.is_main:
            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly split {len(self.valid_ds)} samples')

        assert len(self.ds) >= self.batch_size, 'dataset must have sufficient samples for training'
        assert len(self.valid_ds) >= self.batch_size, f'validation dataset must have sufficient number of samples (currently {len(self.valid_ds)}) for training'

        # dataloader
        drop_last = cfg.get("drop_last", True)
        num_workers = cfg.get("num_workers")
        self.dl = get_dataloader(self.ds, batch_size=self.batch_size, shuffle=True, drop_last=drop_last, num_workers=num_workers)
        self.valid_dl = get_dataloader(self.valid_ds, batch_size=1, shuffle=False, drop_last=False, num_workers=1)

        # lr
        self.lr = cfg.get("learning_rate")
        self.initial_lr = cfg.get("intial_learning_rate")  # note: 'intial_learning_rate' (sic) is the key name this trainer reads

        # optimizer
        self.optim_g = get_optimizer(
            generator.parameters(),
            lr=cfg.get("learning_rate"),
            wd=cfg.get("wd"),
            betas=cfg.get("betas")
        )

        self.optim_d = get_optimizer(
            itertools.chain(*[i.parameters() for i in self.discriminators.values()]),
            lr=cfg.get("learning_rate"),
            wd=cfg.get("wd"),
            betas=cfg.get("betas")
        )

        # scheduler
        # num_train_steps = epochs * self.ds.__len__() // (batch_size * grad_accum_every)
        num_train_steps = self.epochs * self.ds.__len__() // batch_size
        self.scheduler_g = CosineAnnealingLR(self.optim_g, T_max=num_train_steps)
        self.scheduler_d = CosineAnnealingLR(self.optim_d, T_max=num_train_steps)

        # prepare with accelerator
        (
            self.generator,
            self.optim_g,
            self.optim_d,
            self.scheduler_g,
            self.scheduler_d,
            self.dl,
            self.valid_dl
        ) = self.accelerator.prepare(
            self.generator,
            self.optim_g,
            self.optim_d,
            self.scheduler_g,
            self.scheduler_d,
            self.dl,
            self.valid_dl
        )
        self.discriminators = {k: self.accelerator.prepare(v) for k, v in self.discriminators.items()}

        hps = {"num_train_steps": num_train_steps, "num_warmup_steps": self.num_warmup_steps, "learning_rate": self.lr, "initial_learning_rate": self.initial_lr, "epochs": self.epochs}
        self.accelerator.init_trackers("SpeechTokenizer", config=hps)
        self.best_dev_mel_loss = float('inf')
        self.plot_gt_once = False
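    # For reference, the cfg keys read by __init__ above (values typically come
    # from a JSON config such as config/spt_base_cfg.json): seed, split_batches,
    # log_steps, stdout_steps, save_model_steps, results_folder, num_ckpt_keep,
    # epochs, num_warmup_steps, batch_size, sample_rate, showpiece_num,
    # mel_loss_lambdas, commitment_loss_lambda, recon_loss_lambda,
    # distill_loss_lambda, distill_type, lambda_sim, n_fft, num_mels, hop_size,
    # win_size, fmin, fmax, fmax_for_loss, segment_size, train_files,
    # valid_files, drop_last, num_workers, learning_rate,
    # intial_learning_rate (sic), wd and betas.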
    def save(self, path, best_dev_mel_loss):
        if best_dev_mel_loss < self.best_dev_mel_loss:
            self.best_dev_mel_loss = best_dev_mel_loss
            torch.save(self.accelerator.get_state_dict(self.generator), f'{self.results_folder}/SpeechTokenizer_best_dev.pt')
        ckpts = sorted(Path(path).parent.glob('SpeechTokenizerTrainer_*'))
        if len(ckpts) > self.num_ckpt_keep:
            # keep only the most recent num_ckpt_keep checkpoints
            for c in ckpts[:-self.num_ckpt_keep]:
                os.remove(c)
        pkg = dict(
            generator = self.accelerator.get_state_dict(self.generator),
            discriminators = {k: self.accelerator.get_state_dict(v) for k, v in self.discriminators.items()},
            optim_g = self.optim_g.state_dict(),
            optim_d = self.optim_d.state_dict(),
            scheduler_g = self.scheduler_g.state_dict(),
            scheduler_d = self.scheduler_d.state_dict(),
            best_dev_mel_loss = self.best_dev_mel_loss
        )
        torch.save(pkg, path)

    def load(self, path = None, restore_optimizer = True):
        if not exists(path):
            ckpts = sorted(self.results_folder.glob('SpeechTokenizerTrainer_*'))
            path = str(ckpts[-1])
        generator = self.accelerator.unwrap_model(self.generator)
        pkg = torch.load(path, map_location='cpu')
        generator.load_state_dict(pkg['generator'])
        discriminators = {k: self.accelerator.unwrap_model(v) for k, v in self.discriminators.items()}
        # a bare map() is lazy and would never execute; load each discriminator explicitly
        for k, disc in discriminators.items():
            disc.load_state_dict(pkg['discriminators'][k])

        if restore_optimizer:
            self.optim_d.load_state_dict(pkg['optim_d'])
            self.scheduler_d.load_state_dict(pkg['scheduler_d'])
            self.optim_g.load_state_dict(pkg['optim_g'])
            self.scheduler_g.load_state_dict(pkg['scheduler_g'])
            if 'best_dev_mel_loss' in pkg.keys():
                self.best_dev_mel_loss = pkg['best_dev_mel_loss']
                if self.is_main:
                    self.print(f'The best dev mel loss before is {self.best_dev_mel_loss}')

        # + 1 to start from the next step and avoid overwriting the last checkpoint
        self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)

    def print(self, msg):
        self.accelerator.print(msg)

    @property
    def device(self):
        return self.accelerator.device

    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    def warmup(self, step):
        if step < self.num_warmup_steps:
            return self.initial_lr + (self.lr - self.initial_lr) * step / self.num_warmup_steps
        else:
            return self.lr
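    # Worked example of the linear warm-up above (numbers are illustrative, not
    # taken from the shipped config): with initial_lr = 1e-6, lr = 2e-4 and
    # num_warmup_steps = 1000, warmup(250) = 1e-6 + (2e-4 - 1e-6) * 250 / 1000
    # ≈ 5.08e-5. Once step >= num_warmup_steps, the CosineAnnealingLR schedulers
    # take over (see train() below).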
    def log(self, values: dict, step, type=None, **kwargs):
        if type == 'figure':
            for k, v in values.items():
                self.writer.add_figure(k, v, global_step=step)
        elif type == 'audio':
            for k, v in values.items():
                self.writer.add_audio(k, v, global_step=step, **kwargs)
        else:
            for k, v in values.items():
                self.writer.add_scalar(k, v, global_step=step)

    def train(self):

        self.generator.train()
        # map() is lazy, so iterate explicitly to actually switch the discriminators to train mode
        for disc in self.discriminators.values():
            disc.train()
        step_time_log = {}

        steps = int(self.steps.item())
        if steps < self.num_warmup_steps:
            lr = self.warmup(steps)
            for param_group in self.optim_g.param_groups:
                param_group['lr'] = lr
            for param_group in self.optim_d.param_groups:
                param_group['lr'] = lr
        else:
            self.scheduler_d.step()
            self.scheduler_g.step()
            lr = self.scheduler_d.get_last_lr()[0]

        for epoch in range(self.epochs):
            if self.is_main:
                print(f'Epoch:{epoch} start...')

            for batch in self.dl:

                tic = time.time()

                x, semantic_feature = batch
                x = x.unsqueeze(1)
                x_hat, loss_q, feature = self.generator(x)

                # Discriminators
                # each discriminator is expected to return (real_outputs, fake_outputs, real_fmaps, fake_fmaps)
                self.optim_d.zero_grad()
                discriminator_outputs = list(map(lambda disc: disc(x, x_hat.detach()), self.discriminators.values()))
                loss_disc_all = sum(map(lambda out: discriminator_loss(*out[:2]), discriminator_outputs))

                self.accelerator.backward(loss_disc_all)
                self.optim_d.step()

                # Generator
                self.optim_g.zero_grad()
                discriminator_outputs = list(map(lambda disc: disc(x, x_hat), self.discriminators.values()))
                loss_recon = recon_loss(x, x_hat)
                loss_mel = sum(map(lambda mel_k: mel_k[0] * mel_loss(x, x_hat, **mel_k[1]), zip(self.mel_loss_lambdas, self.mel_loss_kwargs_list)))
                loss_feature = sum(map(lambda out: feature_loss(*out[2:]), discriminator_outputs))
                loss_adversarial = sum(map(lambda out: adversarial_loss(out[1]), discriminator_outputs))
                loss_distill = self.distill_loss(feature, semantic_feature)
                loss_generator_all = loss_feature + loss_adversarial + loss_mel + loss_q * self.commitment_loss_lambda + loss_recon * self.recon_loss_lambda + self.distill_loss_lambda * loss_distill
                self.accelerator.backward(loss_generator_all)
                # if exists(self.max_grad_norm):
                #     self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                self.optim_g.step()

                step_time_log = accum_log(step_time_log, {'time_cost': time.time() - tic})
                # self.accelerator.wait_for_everyone()

                # log (assumes log_steps is a multiple of stdout_steps, so mel_error is defined before it is logged)
                if self.is_main and not (steps % self.stdout_steps):
                    with torch.inference_mode():
                        mel_error = mel_loss(x, x_hat, **self.mel_loss_kwargs_list[0]).item()
                    self.print(f"Epoch {epoch} -- Step {steps}: Gen Loss: {loss_generator_all.item():0.3f}; Mel Error:{mel_error:0.3f}; Q Loss: {loss_q.item():0.3f}; Distill Loss: {loss_distill.item():0.3f}; Time cost per step: {step_time_log['time_cost'] / self.stdout_steps:0.3f}s")
                    step_time_log = {}
                if self.is_main and not (steps % self.log_steps):
                    self.log({"train/discriminators loss": loss_disc_all.item(), "train/generator loss": loss_generator_all.item(), "train/feature loss": loss_feature.item(),
                              "train/adversarial loss": loss_adversarial.item(), "train/quantizer loss": loss_q.item(), "train/mel loss": loss_mel.item(),
                              "train/mel error": mel_error, "train/distillation loss": loss_distill.item(), "train/learning_rate": lr}, step=steps)

                self.accelerator.wait_for_everyone()
                # validate and save model
                if self.is_main and not (steps % self.save_model_steps) and steps != 0:

                    self.print('Validation start ...')
                    # validate
                    total_mel_error = 0.0
                    total_distill_loss = 0.0
                    num = 0
                    self.generator.eval()
                    with torch.inference_mode():
                        for i, batch in tqdm(enumerate(self.valid_dl)):
                            x, semantic_feature = batch
                            x = x.unsqueeze(1)
                            x_hat, loss_q, feature = self.generator(x)
                            mel_error = mel_loss(x, x_hat, **self.mel_loss_kwargs_list[0]).item()
                            total_mel_error += mel_error
                            loss_distill = self.distill_loss(feature, semantic_feature).item()
                            total_distill_loss += loss_distill
                            num += x.size(0)
                            if i < self.showpiece_num:
                                if not self.plot_gt_once:
                                    self.log({f'groundtruth/x_{i}': x[0].cpu().detach()}, type='audio', sample_rate=self.sample_rate, step=steps)
                                    x_spec = mel_spectrogram(x.squeeze(1), **self.mel_kwargs)
                                    self.log({f'groundtruth/x_spec_{i}': plot_spectrogram(x_spec[0].cpu().numpy())}, type='figure', step=steps)

                                self.log({f'generate/x_hat_{i}': x_hat[0].cpu().detach()}, type='audio', sample_rate=self.sample_rate, step=steps)
                                x_hat_spec = mel_spectrogram(x_hat.squeeze(1), **self.mel_kwargs)
                                self.log({f'generate/x_hat_spec_{i}': plot_spectrogram(x_hat_spec[0].cpu().numpy())}, type='figure', step=steps)
                    if not self.plot_gt_once:
                        self.plot_gt_once = True
                    self.print(f'{steps}: dev mel error: {total_mel_error / num:0.3f}\tdev distill loss: {total_distill_loss / num:0.3f}')
                    self.log({'dev/mel error': total_mel_error / num, 'dev/distillation loss': total_distill_loss / num}, step=steps)

                    # save model
                    model_path = str(self.results_folder / f'SpeechTokenizerTrainer_{steps:08d}')
                    self.save(model_path, total_mel_error / num)
                    self.print(f'{steps}: saving model to {str(self.results_folder)}')
                    self.generator.train()

                # Update lr
                self.steps += 1
                steps = int(self.steps.item())
                if steps < self.num_warmup_steps:
                    lr = self.warmup(steps)
                    for param_group in self.optim_g.param_groups:
                        param_group['lr'] = lr
                    for param_group in self.optim_d.param_groups:
                        param_group['lr'] = lr
                else:
                    self.scheduler_d.step()
                    self.scheduler_g.step()
                    lr = self.scheduler_g.get_last_lr()[0]

        self.print('training complete')

    def continue_train(self):
        self.load()
        self.train()

--------------------------------------------------------------------------------
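A minimal, hypothetical sketch of the checkpoint and resume conventions implemented by the trainer above; the path is a placeholder and the heavyweight training objects appear only in comments (see scripts/train_example.py for a complete setup):

from speechtokenizer.trainer.trainer import checkpoint_num_steps

# train() writes checkpoints to <results_folder>/SpeechTokenizerTrainer_<steps:08d>
# (plus SpeechTokenizer_best_dev.pt for the best validation mel error); load()
# globs these files, takes the newest one and resumes from the step encoded in
# its filename:
print(checkpoint_num_steps('results/SpeechTokenizerTrainer_00020000'))  # -> 20000

# Resuming an interrupted run is then just:
#   trainer = SpeechTokenizerTrainer(generator, discriminators, cfg)
#   trainer.continue_train()   # load() the newest checkpoint, then train()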