├── LICENSE
├── README.md
├── config
│   └── spt_base_cfg.json
├── example.py
├── images
│   ├── READ.me
│   ├── overview.png
│   └── speechtokenizer_framework.jpg
├── samples
│   └── example_input.wav
├── scripts
│   ├── hubert_rep_extract.py
│   ├── hubert_rep_extract.sh
│   ├── train_example.py
│   └── train_example.sh
├── setup.py
└── speechtokenizer
    ├── __init__.py
    ├── discriminators.py
    ├── model.py
    ├── modules
    │   ├── __init__.py
    │   ├── conv.py
    │   ├── lstm.py
    │   ├── norm.py
    │   └── seanet.py
    ├── quantization
    │   ├── __init__.py
    │   ├── ac.py
    │   ├── core_vq.py
    │   ├── distrib.py
    │   └── vq.py
    └── trainer
        ├── __init__.py
        ├── dataset.py
        ├── loss.py
        ├── optimizer.py
        └── trainer.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SpeechTokenizer: Unified Speech Tokenizer for Speech Language Models
2 |
3 |
4 |
5 | ## Introduction
6 | This is the code for the SpeechTokenizer model presented in [SpeechTokenizer: Unified Speech Tokenizer for Speech Language Models](https://arxiv.org/abs/2308.16692). SpeechTokenizer is a unified speech tokenizer for speech language models that adopts an encoder-decoder architecture with residual vector quantization (RVQ). By unifying semantic and acoustic tokens, SpeechTokenizer disentangles different aspects of speech information hierarchically across the RVQ layers: the code indices produced by the first quantizer can be considered semantic tokens, while the outputs of the remaining quantizers mainly contain timbre information and serve as a supplement for the information lost by the first quantizer. We provide the following models:
7 | * A model operating at 16 kHz on monophonic speech, trained on LibriSpeech, with the average representation across all HuBERT layers as the semantic teacher.
8 | * A model with [Snake activation](https://arxiv.org/abs/2306.06546) operating at 16 kHz on monophonic speech, trained on LibriSpeech and Common Voice, with the average representation across all HuBERT layers as the semantic teacher.
9 |
10 |
11 |
12 | 
13 | ![Overview](images/overview.png)
14 | 
15 | Overview
16 | 
17 | ![The SpeechTokenizer framework](images/speechtokenizer_framework.jpg)
18 | 
19 | The SpeechTokenizer framework.
20 | 
21 |
22 | You are welcome to try our [SLMTokBench](https://github.com/0nutation/SLMTokBench),
23 | and we will also open-source our [USLM](https://github.com/0nutation/USLM)!
24 |
25 | ## Quick Links
26 | * [Release](#release)
27 | * [Samples](#samples)
28 | * [Installation](#installation)
29 | * [Model List](#model-list)
30 | * [Usage](#usage)
31 | * [Train SpeechTokenizer](#train-speechtokenizer)
32 | * [Data Preprocess](#data-preprocess)
33 | * [Train](#train)
34 | * [Quick Start](#quick-start)
35 | * [Citation](#citation)
36 | * [License](#license)
37 |
38 |
39 | ## Release
40 | - [2024/6/9] 🔥 We released the training code of SpeechTokenizer.
41 | - [2024/3] 🔥 We released a checkpoint of SpeechTokenizer with [Snake activation](https://arxiv.org/abs/2306.06546) trained on LibriSpeech and Common Voice.
42 | - [2023/9/11] 🔥 We released code of [soundstorm_speechtokenizer](https://github.com/ZhangXInFD/soundstorm-speechtokenizer).
43 | - [2023/9/10] 🔥 We released code and checkpoints of [USLM](https://github.com/0nutation/USLM).
44 | - [2023/9/1] 🔥 We released code and checkpoints of SpeechTokenizer. Check out the [paper](https://arxiv.org/abs/2308.16692) and [demo](https://0nutation.github.io/SpeechTokenizer.github.io/).
45 |
46 | ## Samples
47 |
48 | Samples are provided on [our demo page](https://0nutation.github.io/SpeechTokenizer.github.io/).
49 |
50 | ## Installation
51 |
52 | SpeechTokenizer requires Python>=3.8 and a reasonably recent version of PyTorch.
53 | You can install SpeechTokenizer from PyPI, or clone this repository and install it locally:
54 | ```bash
55 | pip install -U speechtokenizer
56 |
57 | # or you can clone the repo and install locally
58 | git clone https://github.com/ZhangXInFD/SpeechTokenizer.git
59 | cd SpeechTokenizer
60 | pip install .
61 | ```
62 | ## Model List
63 | | Model | Dataset | Description |
64 | |:----|:----:|:----|
65 | |[speechtokenizer_hubert_avg](https://huggingface.co/fnlp/SpeechTokenizer/tree/main/speechtokenizer_hubert_avg)|LibriSpeech|Adopts the average representation across all HuBERT layers as the semantic teacher|
66 | |[speechtokenizer_snake](https://huggingface.co/fnlp/AnyGPT-speech-modules/tree/main/speechtokenizer)|LibriSpeech + Common Voice|Snake activation; average representation across all HuBERT layers as the semantic teacher|
67 | ## Usage
68 | ### Load model
69 | ```python
70 | from speechtokenizer import SpeechTokenizer
71 |
72 | config_path = '/path/config.json'
73 | ckpt_path = '/path/SpeechTokenizer.pt'
74 | model = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
75 | model.eval()
76 | ```
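If you do not have a checkpoint locally, you can first download one from the Hugging Face Hub, mirroring what [example.py](example.py) does (the `model_hub` directory name here is just an example location):

```python
from huggingface_hub import snapshot_download

# Download config.json and SpeechTokenizer.pt for the speechtokenizer_hubert_avg model
snapshot_download(repo_id="fnlp/SpeechTokenizer", local_dir="model_hub")

config_path = "model_hub/speechtokenizer_hubert_avg/config.json"
ckpt_path = "model_hub/speechtokenizer_hubert_avg/SpeechTokenizer.pt"
```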
77 | ### Extracting discrete representations
78 | ```python
79 | import torchaudio
80 | import torch
81 |
82 | # Load and pre-process speech waveform
83 | wav, sr = torchaudio.load('<AUDIO_FILE_PATH>')
84 |
85 | # Ensure the waveform is monophonic
86 | if wav.shape[0] > 1:
87 |     wav = wav[:1, :]
88 |
89 | if sr != model.sample_rate:
90 | wav = torchaudio.functional.resample(wav, sr, model.sample_rate)
91 |
92 | wav = wav.unsqueeze(0)
93 |
94 | # Extract discrete codes from SpeechTokenizer
95 | with torch.no_grad():
96 | codes = model.encode(wav) # codes: (n_q, B, T)
97 |
98 | RVQ_1 = codes[:1, :, :] # Contains content info; can be considered semantic tokens
99 | RVQ_supplement = codes[1:, :, :] # Contains timbre info, completing the info lost by the first quantizer
100 | ```
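If you only need the semantic stream, `model.encode` also accepts an `n_q` argument (and a start index `st`), as defined in `speechtokenizer/model.py`; a minimal sketch:

```python
# Quantize with only the first RVQ layer; codes_semantic has shape (1, B, T)
with torch.no_grad():
    codes_semantic = model.encode(wav, n_q=1)
```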
101 |
102 | ### Decoding discrete representations
103 | ```python
104 | # Concatenating semantic tokens (RVQ_1) and supplementary timbre tokens and then decoding
105 | wav = model.decode(torch.cat([RVQ_1, RVQ_supplement], axis=0))
106 |
107 | # Decoding from RVQ i:j tokens, i.e. from the i-th quantizer to the j-th quantizer
108 | wav = model.decode(codes[i: (j + 1)], st=i)
109 | ```
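For instance, decoding only the first RVQ layer reconstructs a waveform that mainly carries content information (a sketch; reconstruction quality will be lower than when using all layers):

```python
# Decode using only the semantic tokens from the first quantizer (st defaults to 0)
wav_semantic = model.decode(RVQ_1)
```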
110 |
111 | ## Train SpeechTokenizer
112 | This section describes how to train a SpeechTokenizer model using our trainer.
113 | ### Data Preprocess
114 | To train SpeechTokenizer, the first step is to extract semantic teacher representations from the raw audio waveforms. We provide an example of how to extract HuBERT representations in [scripts/hubert_rep_extract.sh](scripts/hubert_rep_extract.sh). The arguments are explained below; an example invocation is shown after the list:
115 | * `--config`: Config file path. An example is provided in [config/spt_base_cfg.json](config/spt_base_cfg.json). You can modify the `semantic_model_path` and `semantic_model_layer` parameters in this file to change the HuBERT model and the target layer.
116 | * `--audio_dir`: The path to the folder containing all audio files.
117 | * `--rep_dir`: The path to the folder storing all semantic representation files.
118 | * `--exts`: The file extension of the audio files. Use ',' to separate multiple extensions if they exist.
119 | * `--split_seed`: Random seed for splitting training set and validation set.
120 | * `--valid_set_size`: The size of the validation set. A value between 0 and 1 is interpreted as the proportion of the total dataset to use for validation.
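A minimal invocation might look like the following (the directory paths are placeholders; [scripts/hubert_rep_extract.sh](scripts/hubert_rep_extract.sh) wraps the same call):

```bash
python scripts/hubert_rep_extract.py \
    --config config/spt_base_cfg.json \
    --audio_dir /path/to/audio \
    --rep_dir /path/to/hubert_rep \
    --exts flac \
    --split_seed 0 \
    --valid_set_size 1000
```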
121 |
122 | ### Train
123 | You can use SpeechTokenizerTrainer to train a SpeechTokenizer as follows:
124 | ```python
125 | from speechtokenizer import SpeechTokenizer, SpeechTokenizerTrainer
126 | from speechtokenizer.discriminators import MultiPeriodDiscriminator, MultiScaleDiscriminator, MultiScaleSTFTDiscriminator
127 | import json
128 |
129 |
130 | # Load model and trainer config
131 | with open('<CONFIG_FILE_PATH>') as f:
132 | cfg = json.load(f)
133 |
134 | # Initialize SpeechTokenizer
135 | generator = SpeechTokenizer(cfg)
136 |
137 | # Initialize the discriminators. You can add any discriminator that is not yet implemented in this repository, as long as the output format remains consistent with the discriminators in `speechtokenizer.discriminators`.
138 | discriminators = {'mpd':MultiPeriodDiscriminator(), 'msd':MultiScaleDiscriminator(), 'mstftd':MultiScaleSTFTDiscriminator(32)}
139 |
140 | # Initialize Trainer
141 | trainer = SpeechTokenizerTrainer(generator=generator,
142 | discriminators=discriminators,
143 | cfg=cfg)
144 |
145 | # Start training
146 | trainer.train()
147 |
148 | # Continue training from checkpoints
149 | trainer.continue_train()
150 | ```
151 | We provide example training scripts in [scripts/train_example.sh](scripts/train_example.sh). All arguments for SpeechTokenizerTrainer are defined in [config/spt_base_cfg.json](config/spt_base_cfg.json). Below, we explain some of the important arguments:
152 | * `train_files` and `valid_files`: Training file path and validation file path. These should be text files listing the paths of all audio files and their corresponding semantic representation files in the training/validation set, one pair per line, in the format `<audio_file_path>\t<representation_file_path>` (see the example line after this list). If you use [scripts/hubert_rep_extract.sh](scripts/hubert_rep_extract.sh) to extract semantic representations, these two files will be generated automatically.
153 | * `distill_type`: Use "d_axis" for D-axis distillation loss and "t_axis" for T-axis distillation loss, as mentioned in the paper.
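For illustration, a line in `train_files`/`valid_files` is simply the audio path and its representation path separated by a tab, exactly as written by `hubert_rep_extract.py` (the paths below are hypothetical):

```
/path/to/LibriSpeech/train-clean-100/19/198/19-198-0001.flac	/path/to/hubert_rep/train-clean-100/19/198/19-198-0001.hubert.npy
```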
154 |
155 | ### Quick Start
156 | If you want to fully follow our experimental setup, simply set `semantic_model_path` in [config/spt_base_cfg.json](config/spt_base_cfg.json), set `AUDIO_DIR`, `REP_DIR`, and `EXTS` in [scripts/hubert_rep_extract.sh](scripts/hubert_rep_extract.sh) (along with any other optional arguments), and then execute the following code:
157 | ```shell
158 | cd SpeechTokenizer
159 |
160 | # Extract semantic representations
161 | bash scripts/hubert_rep_extract.sh
162 |
163 | # Train
164 | bash scripts/train_example.sh
165 | ```
166 | ## Citation
167 | If you use this code or these results in your paper, please cite our work as:
168 | ```Tex
169 | @misc{zhang2023speechtokenizer,
170 | title={SpeechTokenizer: Unified Speech Tokenizer for Speech Language Models},
171 | author={Xin Zhang and Dong Zhang and Shimin Li and Yaqian Zhou and Xipeng Qiu},
172 | year={2023},
173 | eprint={2308.16692},
174 | archivePrefix={arXiv},
175 | primaryClass={cs.CL}
176 | }
177 | ```
178 | ## License
179 | The code in this repository is released under the Apache 2.0 license as found in the
180 | [LICENSE](LICENSE) file.
181 |
--------------------------------------------------------------------------------
/config/spt_base_cfg.json:
--------------------------------------------------------------------------------
1 | {
2 | "n_filters": 64,
3 | "strides": [8,5,4,2],
4 | "dimension": 1024,
5 | "semantic_dimension": 768,
6 | "bidirectional": true,
7 | "dilation_base": 2,
8 | "residual_kernel_size": 3,
9 | "n_residual_layers": 1,
10 | "lstm_layers": 2,
11 | "activation": "ELU",
12 | "codebook_size": 1024,
13 | "n_q": 8,
14 |
15 | "num_mels": 80,
16 | "num_freq": 1025,
17 | "n_fft": 1024,
18 | "hop_size": 240,
19 | "win_size": 1024,
20 | "fmin": 0,
21 | "fmax": 8000,
22 | "fmax_for_loss": null,
23 | "mel_loss_lambdas": [45, 1, 1, 1],
24 | "recon_loss_lambda": 500,
25 | "commitment_loss_lambda": 10,
26 |
27 | "distill_loss_lambda": 120,
28 | "distill_type": "d_axis",
29 |
30 | "semantic_model_path": "/remote-home/share/models/hubert/ls_960/hubert-base-ls960",
31 | "semantic_model_layer": "avg",
32 |
33 | "train_files":"train_file_list.txt",
34 | "valid_files":"dev_file_list.txt",
35 | "results_folder": "Log/spt_base",
36 | "sample_rate": 16000,
37 | "batch_size": 6,
38 | "epochs":10,
39 | "learning_rate": 1e-4,
40 | "intial_learning_rate":1e-4,
41 | "num_warmup_steps":0,
42 | "betas":[0.9, 0.99],
43 | "adam_b2": 0.9,
44 | "lr_decay": 0.98,
45 | "seed": 1234,
46 | "segment_size": 48000,
47 | "wd": 0,
48 | "num_workers": 8,
49 | "log_steps": 100,
50 | "stdout_steps": 10,
51 | "save_model_steps": 2500,
52 | "num_ckpt_keep": 6,
53 | "showpiece_num": 8
54 | }
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torchaudio
3 | import torch
4 | from speechtokenizer import SpeechTokenizer
5 | from scipy.io.wavfile import write
6 | import numpy as np
7 |
8 | from huggingface_hub import snapshot_download
9 |
10 | snapshot_download(repo_id="fnlp/SpeechTokenizer", local_dir="model_hub")
11 |
12 |
13 | # Set up argument parser
14 | parser = argparse.ArgumentParser(
15 | description="Load SpeechTokenizer model and process audio file."
16 | )
17 | parser.add_argument(
18 | "--config_path",
19 | type=str,
20 | help="Path to the model configuration file.",
21 | default="model_hub/speechtokenizer_hubert_avg/config.json",
22 | )
23 | parser.add_argument(
24 | "--ckpt_path",
25 | type=str,
26 | help="Path to the model checkpoint file.",
27 | default="model_hub/speechtokenizer_hubert_avg/SpeechTokenizer.pt",
28 | )
29 | parser.add_argument(
30 | "--speech_file",
31 | type=str,
32 | required=True,
33 | help="Path to the speech file to be processed.",
34 | )
35 | parser.add_argument(
36 | "--output_file",
37 | type=str,
38 | help="Path to save the output audio file.",
39 | default="example_output.wav",
40 | )
41 |
42 | args = parser.parse_args()
43 |
44 | # Load model from the specified checkpoint
45 | model = SpeechTokenizer.load_from_checkpoint(args.config_path, args.ckpt_path)
46 | model.eval()
47 |
48 | # Determine the model's expected sample rate
49 | model_sample_rate = model.sample_rate
50 |
51 | # Load and preprocess speech waveform with the model's sample rate
52 | wav, sr = torchaudio.load(args.speech_file)
53 |
54 | if sr != model_sample_rate:
55 | resample_transform = torchaudio.transforms.Resample(
56 | orig_freq=sr, new_freq=model_sample_rate
57 | )
58 | wav = resample_transform(wav)
59 |
60 | # Ensure the waveform is monophonic
61 | if wav.shape[0] > 1:
62 | wav = wav[:1, :]
63 |
64 | wav = wav.unsqueeze(0)
65 |
66 |
67 | # Extract discrete codes from SpeechTokenizer
68 | with torch.no_grad():
69 | codes = model.encode(wav) # codes: (n_q, B, T)
70 |
71 | RVQ_1 = codes[:1, :, :]  # Contains content info; can be considered semantic tokens
72 | RVQ_supplement = codes[
73 |     1:, :, :
74 | ]  # Contains timbre info, completing the info lost by the first quantizer
75 |
76 | # Concatenating semantic tokens (RVQ_1) and supplementary timbre tokens and then decoding
77 | wav_out = model.decode(torch.cat([RVQ_1, RVQ_supplement], axis=0))
78 |
79 | # Convert the reconstructed waveform to a 1-D NumPy array and save it;
80 | # scipy.io.wavfile.write expects a 1-D (mono) or 2-D array of samples
81 | wav_out = wav_out.squeeze().detach().cpu().numpy()
82 | write(args.output_file, model_sample_rate, wav_out.astype(np.float32))
83 |
--------------------------------------------------------------------------------
/images/READ.me:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/images/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhangXInFD/SpeechTokenizer/30c96fb32a9fc06a2258c98119e237def051e46c/images/overview.png
--------------------------------------------------------------------------------
/images/speechtokenizer_framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhangXInFD/SpeechTokenizer/30c96fb32a9fc06a2258c98119e237def051e46c/images/speechtokenizer_framework.jpg
--------------------------------------------------------------------------------
/samples/example_input.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhangXInFD/SpeechTokenizer/30c96fb32a9fc06a2258c98119e237def051e46c/samples/example_input.wav
--------------------------------------------------------------------------------
/scripts/hubert_rep_extract.py:
--------------------------------------------------------------------------------
1 | from transformers import HubertModel, Wav2Vec2FeatureExtractor
2 | from pathlib import Path
3 | import torchaudio
4 | import torch
5 | import json
6 | import argparse
7 | from tqdm import tqdm
8 | import random
9 | import numpy as np
10 | import os
11 |
12 | if __name__ == '__main__':
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--config', '-c', type=str, help='Config file path')
16 | parser.add_argument('--audio_dir', type=str, help='Audio folder path')
17 | parser.add_argument('--rep_dir', type=str, help='Path to save representation files')
18 | parser.add_argument('--exts', type=str, help="Audio file extensions, splitting with ','", default='flac')
19 | parser.add_argument('--split_seed', type=int, help="Random seed", default=0)
20 | parser.add_argument('--valid_set_size', type=float, default=1000)
21 | args = parser.parse_args()
22 | exts = args.exts.split(',')
23 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
24 | with open(args.config) as f:
25 | cfg = json.load(f)
26 | sample_rate = cfg.get('sample_rate')
27 | feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(cfg.get('semantic_model_path'))
28 | model = HubertModel.from_pretrained(cfg.get('semantic_model_path')).eval().to(device)
29 | target_layer = cfg.get('semantic_model_layer')
30 | path = Path(args.audio_dir)
31 | file_list = [str(file) for ext in exts for file in path.glob(f'**/*.{ext}')]
32 | if args.valid_set_size != 0 and args.valid_set_size < 1:
33 | valid_set_size = int(len(file_list) * args.valid_set_size)
34 | else:
35 | valid_set_size = int(args.valid_set_size)
36 | train_file_list = cfg.get('train_files')
37 | valid_file_list = cfg.get('valid_files')
38 | segment_size = cfg.get('segment_size')
39 | random.seed(args.split_seed)
40 | random.shuffle(file_list)
41 | print(f'A total of {len(file_list)} samples will be processed, and {valid_set_size} of them will be included in the validation set.')
42 | with torch.no_grad():
43 | for i, audio_file in tqdm(enumerate(file_list)):
44 | wav, sr = torchaudio.load(audio_file)
45 | if sr != sample_rate:
46 | wav = torchaudio.functional.resample(wav, sr, sample_rate)
47 | if wav.size(-1) < segment_size:
48 | wav = torch.nn.functional.pad(wav, (0, segment_size - wav.size(-1)), 'constant')
49 | input_values = feature_extractor(wav.squeeze(0), sampling_rate=sample_rate, return_tensors="pt").input_values
50 |             output = model(input_values.to(model.device), output_hidden_states=True)
51 |             if target_layer == 'avg':
52 |                 rep = torch.mean(torch.stack(output.hidden_states), axis=0)
53 |             else:
54 |                 rep = output.hidden_states[target_layer]
55 | rep_file = audio_file.replace(args.audio_dir, args.rep_dir).split('.')[0] + '.hubert.npy'
56 | rep_sub_dir = '/'.join(rep_file.split('/')[:-1])
57 | if not os.path.exists(rep_sub_dir):
58 | os.makedirs(rep_sub_dir)
59 | np.save(rep_file, rep.detach().cpu().numpy())
60 | if i < valid_set_size:
61 | with open(valid_file_list, 'a+') as f:
62 | f.write(f'{audio_file}\t{rep_file}\n')
63 | else:
64 | with open(train_file_list, 'a+') as f:
65 | f.write(f'{audio_file}\t{rep_file}\n')
66 |
67 |
68 |
69 |
--------------------------------------------------------------------------------
/scripts/hubert_rep_extract.sh:
--------------------------------------------------------------------------------
1 | CONFIG="config/spt_base_cfg.json"
2 | AUDIO_DIR="/remote-home/share/data/SpeechPretrain/LibriSpeech/LibriSpeech"
3 | REP_DIR="/remote-home/share/data/SpeechPretrain/hubert_rep/LibriSpeech"
4 | EXTS="flac"
5 | SPLIT_SEED=0
6 | VALID_SET_SIZE=1500
7 |
8 |
9 |
10 | CUDA_VISIBLE_DEVICES=0 python scripts/hubert_rep_extract.py\
11 | --config ${CONFIG}\
12 | --audio_dir ${AUDIO_DIR}\
13 | --rep_dir ${REP_DIR}\
14 | --exts ${EXTS}\
15 | --split_seed ${SPLIT_SEED}\
16 | --valid_set_size ${VALID_SET_SIZE}
--------------------------------------------------------------------------------
/scripts/train_example.py:
--------------------------------------------------------------------------------
1 | from speechtokenizer import SpeechTokenizer, SpeechTokenizerTrainer
2 | from speechtokenizer.discriminators import MultiPeriodDiscriminator, MultiScaleDiscriminator, MultiScaleSTFTDiscriminator
3 | import json
4 | import argparse
5 |
6 |
7 | if __name__ == '__main__':
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--config', '-c', type=str, help='Config file path')
11 | parser.add_argument('--continue_train', action='store_true', help='Continue to train from checkpoints')
12 | args = parser.parse_args()
13 | with open(args.config) as f:
14 | cfg = json.load(f)
15 |
16 | generator = SpeechTokenizer(cfg)
17 | discriminators = {'mpd':MultiPeriodDiscriminator(), 'msd':MultiScaleDiscriminator(), 'mstftd':MultiScaleSTFTDiscriminator(32)}
18 |
19 | trainer = SpeechTokenizerTrainer(generator=generator,
20 | discriminators=discriminators,
21 | cfg=cfg)
22 |
23 | if args.continue_train:
24 | trainer.continue_train()
25 | else:
26 | trainer.train()
--------------------------------------------------------------------------------
/scripts/train_example.sh:
--------------------------------------------------------------------------------
1 |
2 | CONFIG="config/spt_base_cfg.json"
3 |
4 |
5 | # NPROC_PER_NODE=4
6 | # CUDA_VISIBLE_DEVICES=1,2,6,7 torchrun \
7 | # --nnode 1 \
8 | # --nproc_per_node $NPROC_PER_NODE \
9 | # --master_port 50025 \
10 | # train_example.py \
11 | # --config ${CONFIG} \
12 |
13 | CUDA_VISIBLE_DEVICES=1,2,6,7 accelerate launch scripts/train_example.py\
14 | --config ${CONFIG}\
15 | --continue_train
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from setuptools import setup
3 |
4 | NAME = 'speechtokenizer'
5 | DESCRIPTION = 'Unified speech tokenizer for speech language model'
6 | URL = 'https://github.com/ZhangXInFD/SpeechTokenizer'
7 | EMAIL = 'xin_zhang22@m.fudan.edu.cn'
8 | AUTHOR = 'Xin Zhang, Dong Zhang, Shimin Li, Yaqian Zhou, Xipeng Qiu'
9 | REQUIRES_PYTHON = '>=3.8.0'
10 |
11 | for line in open('speechtokenizer/__init__.py'):
12 | line = line.strip()
13 | if '__version__' in line:
14 | context = {}
15 | exec(line, context)
16 | VERSION = context['__version__']
17 |
18 | HERE = Path(__file__).parent
19 |
20 | try:
21 | with open(HERE / "README.md", encoding='utf-8') as f:
22 | long_description = '\n' + f.read()
23 | except FileNotFoundError:
24 | long_description = DESCRIPTION
25 |
26 | setup(
27 | name=NAME,
28 | version=VERSION,
29 | description=DESCRIPTION,
30 | long_description=long_description,
31 | long_description_content_type='text/markdown',
32 | author=AUTHOR,
33 | author_email=EMAIL,
34 | python_requires=REQUIRES_PYTHON,
35 | url=URL,
36 | packages=['speechtokenizer', 'speechtokenizer.quantization', 'speechtokenizer.modules', 'speechtokenizer.trainer'],
37 | install_requires=['numpy', 'torch', 'torchaudio', 'einops','scipy','huggingface-hub','soundfile', 'matplotlib', 'lion_pytorch', 'accelerate'],
38 | include_package_data=True,
39 | license='Apache License 2.0',
40 | classifiers=[
41 | 'Topic :: Multimedia :: Sound/Audio',
42 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
43 | 'License :: OSI Approved :: Apache Software License',
44 | ])
45 |
--------------------------------------------------------------------------------
/speechtokenizer/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import SpeechTokenizer
2 | from .trainer import SpeechTokenizerTrainer
3 |
4 | __version__ = '1.0.0'
--------------------------------------------------------------------------------
/speechtokenizer/discriminators.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from torch.nn import Conv1d, AvgPool1d, Conv2d
5 | from torch.nn.utils import weight_norm, spectral_norm
6 | import typing as tp
7 | import torchaudio
8 | from einops import rearrange
9 | from .modules import NormConv2d
10 |
11 | LRELU_SLOPE = 0.1
12 |
13 | def get_padding(kernel_size, dilation=1):
14 | return int((kernel_size*dilation - dilation)/2)
15 |
16 | def init_weights(m, mean=0.0, std=0.01):
17 | classname = m.__class__.__name__
18 | if classname.find("Conv") != -1:
19 | m.weight.data.normal_(mean, std)
20 |
21 |
22 | class DiscriminatorP(torch.nn.Module):
23 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
24 | super(DiscriminatorP, self).__init__()
25 | self.period = period
26 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm
27 | self.convs = nn.ModuleList([
28 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
29 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
30 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
31 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
32 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
33 | ])
34 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
35 |
36 | def forward(self, x):
37 | fmap = []
38 |
39 | # 1d to 2d
40 | b, c, t = x.shape
41 | if t % self.period != 0: # pad first
42 | n_pad = self.period - (t % self.period)
43 | x = F.pad(x, (0, n_pad), "reflect")
44 | t = t + n_pad
45 | x = x.view(b, c, t // self.period, self.period)
46 |
47 | for l in self.convs:
48 | x = l(x)
49 | x = F.leaky_relu(x, LRELU_SLOPE)
50 | fmap.append(x)
51 | x = self.conv_post(x)
52 | fmap.append(x)
53 | x = torch.flatten(x, 1, -1)
54 |
55 | return x, fmap
56 |
57 |
58 | class MultiPeriodDiscriminator(torch.nn.Module):
59 | def __init__(self):
60 | super(MultiPeriodDiscriminator, self).__init__()
61 | self.discriminators = nn.ModuleList([
62 | DiscriminatorP(2),
63 | DiscriminatorP(3),
64 | DiscriminatorP(5),
65 | DiscriminatorP(7),
66 | DiscriminatorP(11),
67 | ])
68 |
69 | def forward(self, y, y_hat):
70 | y_d_rs = []
71 | y_d_gs = []
72 | fmap_rs = []
73 | fmap_gs = []
74 | for i, d in enumerate(self.discriminators):
75 | y_d_r, fmap_r = d(y)
76 | y_d_g, fmap_g = d(y_hat)
77 | y_d_rs.append(y_d_r)
78 | fmap_rs.append(fmap_r)
79 | y_d_gs.append(y_d_g)
80 | fmap_gs.append(fmap_g)
81 |
82 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
83 |
84 |
85 | class DiscriminatorS(torch.nn.Module):
86 | def __init__(self, use_spectral_norm=False):
87 | super(DiscriminatorS, self).__init__()
88 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm
89 | self.convs = nn.ModuleList([
90 | norm_f(Conv1d(1, 128, 15, 1, padding=7)),
91 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
92 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
93 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
94 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
95 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
96 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
97 | ])
98 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
99 |
100 | def forward(self, x):
101 | fmap = []
102 | for l in self.convs:
103 | x = l(x)
104 | x = F.leaky_relu(x, LRELU_SLOPE)
105 | fmap.append(x)
106 | x = self.conv_post(x)
107 | fmap.append(x)
108 | x = torch.flatten(x, 1, -1)
109 |
110 | return x, fmap
111 |
112 |
113 | class MultiScaleDiscriminator(torch.nn.Module):
114 | def __init__(self):
115 | super(MultiScaleDiscriminator, self).__init__()
116 | self.discriminators = nn.ModuleList([
117 | DiscriminatorS(use_spectral_norm=True),
118 | DiscriminatorS(),
119 | DiscriminatorS(),
120 | ])
121 | self.meanpools = nn.ModuleList([
122 | AvgPool1d(4, 2, padding=2),
123 | AvgPool1d(4, 2, padding=2)
124 | ])
125 |
126 | def forward(self, y, y_hat):
127 | y_d_rs = []
128 | y_d_gs = []
129 | fmap_rs = []
130 | fmap_gs = []
131 | for i, d in enumerate(self.discriminators):
132 | if i != 0:
133 | y = self.meanpools[i-1](y)
134 | y_hat = self.meanpools[i-1](y_hat)
135 | y_d_r, fmap_r = d(y)
136 | y_d_g, fmap_g = d(y_hat)
137 | y_d_rs.append(y_d_r)
138 | fmap_rs.append(fmap_r)
139 | y_d_gs.append(y_d_g)
140 | fmap_gs.append(fmap_g)
141 |
142 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
143 |
144 |
145 | FeatureMapType = tp.List[torch.Tensor]
146 | LogitsType = torch.Tensor
147 | DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
148 |
149 |
150 | def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
151 | return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)
152 |
153 |
154 | class DiscriminatorSTFT(nn.Module):
155 | """STFT sub-discriminator.
156 | Args:
157 | filters (int): Number of filters in convolutions
158 | in_channels (int): Number of input channels. Default: 1
159 | out_channels (int): Number of output channels. Default: 1
160 | n_fft (int): Size of FFT for each scale. Default: 1024
161 | hop_length (int): Length of hop between STFT windows for each scale. Default: 256
162 | kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)``
163 | stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)``
164 | dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]``
165 | win_length (int): Window size for each scale. Default: 1024
166 | normalized (bool): Whether to normalize by magnitude after stft. Default: True
167 | norm (str): Normalization method. Default: `'weight_norm'`
168 | activation (str): Activation function. Default: `'LeakyReLU'`
169 | activation_params (dict): Parameters to provide to the activation function.
170 | growth (int): Growth factor for the filters. Default: 1
171 | """
172 | def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
173 | n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
174 | filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
175 | stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
176 | activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
177 | super().__init__()
178 | assert len(kernel_size) == 2
179 | assert len(stride) == 2
180 | self.filters = filters
181 | self.in_channels = in_channels
182 | self.out_channels = out_channels
183 | self.n_fft = n_fft
184 | self.hop_length = hop_length
185 | self.win_length = win_length
186 | self.normalized = normalized
187 | self.activation = getattr(torch.nn, activation)(**activation_params)
188 | self.spec_transform = torchaudio.transforms.Spectrogram(
189 | n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
190 | normalized=self.normalized, center=False, pad_mode=None, power=None)
191 | spec_channels = 2 * self.in_channels
192 | self.convs = nn.ModuleList()
193 | self.convs.append(
194 | NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
195 | )
196 | in_chs = min(filters_scale * self.filters, max_filters)
197 | for i, dilation in enumerate(dilations):
198 | out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
199 | self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
200 | dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
201 | norm=norm))
202 | in_chs = out_chs
203 | out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
204 | self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
205 | padding=get_2d_padding((kernel_size[0], kernel_size[0])),
206 | norm=norm))
207 | self.conv_post = NormConv2d(out_chs, self.out_channels,
208 | kernel_size=(kernel_size[0], kernel_size[0]),
209 | padding=get_2d_padding((kernel_size[0], kernel_size[0])),
210 | norm=norm)
211 |
212 | def forward(self, x: torch.Tensor):
213 | fmap = []
214 | # print('x ', x.shape)
215 | z = self.spec_transform(x) # [B, 2, Freq, Frames, 2]
216 | # print('z ', z.shape)
217 | z = torch.cat([z.real, z.imag], dim=1)
218 | # print('cat_z ', z.shape)
219 | z = rearrange(z, 'b c w t -> b c t w')
220 | for i, layer in enumerate(self.convs):
221 | z = layer(z)
222 | z = self.activation(z)
223 | # print('z i', i, z.shape)
224 | fmap.append(z)
225 | z = self.conv_post(z)
226 | # print('logit ', z.shape)
227 | return z, fmap
228 |
229 |
230 | class MultiScaleSTFTDiscriminator(nn.Module):
231 | """Multi-Scale STFT (MS-STFT) discriminator.
232 | Args:
233 | filters (int): Number of filters in convolutions
234 | in_channels (int): Number of input channels. Default: 1
235 | out_channels (int): Number of output channels. Default: 1
236 | n_ffts (Sequence[int]): Size of FFT for each scale
237 | hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale
238 | win_lengths (Sequence[int]): Window size for each scale
239 | **kwargs: additional args for STFTDiscriminator
240 | """
241 | def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
242 | n_ffts: tp.List[int] = [1024, 2048, 512, 256, 128], hop_lengths: tp.List[int] = [256, 512, 128, 64, 32],
243 | win_lengths: tp.List[int] = [1024, 2048, 512, 256, 128], **kwargs):
244 | super().__init__()
245 | assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
246 | self.discriminators = nn.ModuleList([
247 | DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
248 | n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs)
249 | for i in range(len(n_ffts))
250 | ])
251 | self.num_discriminators = len(self.discriminators)
252 |
253 | def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> DiscriminatorOutput:
254 | logits = []
255 | logits_fake = []
256 | fmaps = []
257 | fmaps_fake = []
258 | for disc in self.discriminators:
259 | logit, fmap = disc(y)
260 | logits.append(logit)
261 | fmaps.append(fmap)
262 | logit_fake, fmap_fake = disc(y_hat)
263 | logits_fake.append(logit_fake)
264 | fmaps_fake.append(fmap_fake)
265 | return logits, logits_fake, fmaps, fmaps_fake
--------------------------------------------------------------------------------
/speechtokenizer/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Wed Aug 30 15:47:55 2023
4 | @author: zhangxin
5 | """
6 |
7 | from .modules.seanet import SEANetEncoder, SEANetDecoder
8 | from .quantization import ResidualVectorQuantizer
9 | import torch.nn as nn
10 | from einops import rearrange
11 | import torch
12 | import numpy as np
13 |
14 | class SpeechTokenizer(nn.Module):
15 | def __init__(self, config):
16 | '''
17 |
18 | Parameters
19 | ----------
20 |         config : dict
21 |             Model configuration (parsed from a JSON config file).
22 |
23 | '''
24 | super().__init__()
25 | self.encoder = SEANetEncoder(n_filters=config.get('n_filters'),
26 | dimension=config.get('dimension'),
27 | ratios=config.get('strides'),
28 | lstm=config.get('lstm_layers'),
29 | bidirectional=config.get('bidirectional'),
30 | dilation_base=config.get('dilation_base'),
31 | residual_kernel_size=config.get('residual_kernel_size'),
32 | n_residual_layers=config.get('n_residual_layers'),
33 | activation=config.get('activation'))
34 | self.sample_rate = config.get('sample_rate')
35 | self.n_q = config.get('n_q')
36 | self.downsample_rate = np.prod(config.get('strides'))
37 | if config.get('dimension') != config.get('semantic_dimension'):
38 | self.transform = nn.Linear(config.get('dimension'), config.get('semantic_dimension'))
39 | else:
40 | self.transform = nn.Identity()
41 | self.quantizer = ResidualVectorQuantizer(dimension=config.get('dimension'), n_q=config.get('n_q'), bins=config.get('codebook_size'))
42 | self.decoder = SEANetDecoder(n_filters=config.get('n_filters'),
43 | dimension=config.get('dimension'),
44 | ratios=config.get('strides'),
45 | lstm=config.get('lstm_layers'),
46 | bidirectional=False,
47 | dilation_base=config.get('dilation_base'),
48 | residual_kernel_size=config.get('residual_kernel_size'),
49 | n_residual_layers=config.get('n_residual_layers'),
50 | activation=config.get('activation'))
51 |
52 | @classmethod
53 | def load_from_checkpoint(cls,
54 | config_path: str,
55 | ckpt_path: str):
56 | '''
57 |
58 | Parameters
59 | ----------
60 | config_path : str
61 | Path of model configuration file.
62 | ckpt_path : str
63 | Path of model checkpoint.
64 |
65 | Returns
66 | -------
67 | model : SpeechTokenizer
68 | SpeechTokenizer model.
69 |
70 | '''
71 | import json
72 | with open(config_path) as f:
73 | cfg = json.load(f)
74 | model = cls(cfg)
75 | params = torch.load(ckpt_path, map_location='cpu')
76 | model.load_state_dict(params)
77 | return model
78 |
79 |
80 | def forward(self,
81 | x: torch.tensor,
82 | n_q: int=None,
83 | layers: list=[0]):
84 | '''
85 |
86 | Parameters
87 | ----------
88 | x : torch.tensor
89 | Input wavs. Shape: (batch, channels, timesteps).
90 | n_q : int, optional
91 | Number of quantizers in RVQ used to encode. The default is all layers.
92 | layers : list[int], optional
93 |             RVQ layers whose quantized outputs should be returned. The default is the first layer.
94 |
95 | Returns
96 | -------
97 | o : torch.tensor
98 | Output wavs. Shape: (batch, channels, timesteps).
99 | commit_loss : torch.tensor
100 | Commitment loss from residual vector quantizers.
101 | feature : torch.tensor
102 | Output of RVQ's first layer. Shape: (batch, timesteps, dimension)
103 |
104 | '''
105 | n_q = n_q if n_q else self.n_q
106 | e = self.encoder(x)
107 | quantized, codes, commit_loss, quantized_list = self.quantizer(e, n_q=n_q, layers=layers)
108 | feature = rearrange(quantized_list[0], 'b d t -> b t d')
109 | feature = self.transform(feature)
110 | o = self.decoder(quantized)
111 | return o, commit_loss, feature
112 |
113 | def forward_feature(self,
114 | x: torch.tensor,
115 | layers: list=None):
116 | '''
117 |
118 | Parameters
119 | ----------
120 | x : torch.tensor
121 | Input wavs. Shape should be (batch, channels, timesteps).
122 | layers : list[int], optional
123 |             RVQ layers whose quantized outputs should be returned. The default is all layers.
124 |
125 | Returns
126 | -------
127 | quantized_list : list[torch.tensor]
128 | Quantized of required layers.
129 |
130 | '''
131 | e = self.encoder(x)
132 | layers = layers if layers else list(range(self.n_q))
133 | quantized, codes, commit_loss, quantized_list = self.quantizer(e, layers=layers)
134 | return quantized_list
135 |
136 | def encode(self,
137 | x: torch.tensor,
138 | n_q: int=None,
139 | st: int=None):
140 | '''
141 |
142 | Parameters
143 | ----------
144 | x : torch.tensor
145 | Input wavs. Shape: (batch, channels, timesteps).
146 | n_q : int, optional
147 | Number of quantizers in RVQ used to encode. The default is all layers.
148 | st : int, optional
149 | Start quantizer index in RVQ. The default is 0.
150 |
151 | Returns
152 | -------
153 | codes : torch.tensor
154 | Output indices for each quantizer. Shape: (n_q, batch, timesteps)
155 |
156 | '''
157 | e = self.encoder(x)
158 | if st is None:
159 | st = 0
160 | n_q = n_q if n_q else self.n_q
161 | codes = self.quantizer.encode(e, n_q=n_q, st=st)
162 | return codes
163 |
164 | def decode(self,
165 | codes: torch.tensor,
166 | st: int=0):
167 | '''
168 |
169 | Parameters
170 | ----------
171 | codes : torch.tensor
172 | Indices for each quantizer. Shape: (n_q, batch, timesteps).
173 | st : int, optional
174 | Start quantizer index in RVQ. The default is 0.
175 |
176 | Returns
177 | -------
178 | o : torch.tensor
179 | Reconstruct wavs from codes. Shape: (batch, channels, timesteps)
180 |
181 | '''
182 | quantized = self.quantizer.decode(codes, st=st)
183 | o = self.decoder(quantized)
184 | return o
--------------------------------------------------------------------------------
/speechtokenizer/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """Torch modules."""
8 |
9 | # flake8: noqa
10 | from .conv import (
11 | pad1d,
12 | unpad1d,
13 | NormConv1d,
14 | NormConvTranspose1d,
15 | NormConv2d,
16 | NormConvTranspose2d,
17 | SConv1d,
18 | SConvTranspose1d,
19 | )
20 | from .lstm import SLSTM
21 | from .seanet import SEANetEncoder, SEANetDecoder
--------------------------------------------------------------------------------
/speechtokenizer/modules/conv.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """Convolutional layers wrappers and utilities."""
8 |
9 | import math
10 | import typing as tp
11 | import warnings
12 |
13 | import torch
14 | from torch import nn
15 | from torch.nn import functional as F
16 | from torch.nn.utils import spectral_norm, weight_norm
17 |
18 | from .norm import ConvLayerNorm
19 |
20 |
21 | CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
22 | 'time_layer_norm', 'layer_norm', 'time_group_norm'])
23 |
24 |
25 | def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
26 | assert norm in CONV_NORMALIZATIONS
27 | if norm == 'weight_norm':
28 | return weight_norm(module)
29 | elif norm == 'spectral_norm':
30 | return spectral_norm(module)
31 | else:
32 | # We already check was in CONV_NORMALIZATION, so any other choice
33 | # doesn't need reparametrization.
34 | return module
35 |
36 |
37 | def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
38 | """Return the proper normalization module. If causal is True, this will ensure the returned
39 | module is causal, or return an error if the normalization doesn't support causal evaluation.
40 | """
41 | assert norm in CONV_NORMALIZATIONS
42 | if norm == 'layer_norm':
43 | assert isinstance(module, nn.modules.conv._ConvNd)
44 | return ConvLayerNorm(module.out_channels, **norm_kwargs)
45 | elif norm == 'time_group_norm':
46 | if causal:
47 | raise ValueError("GroupNorm doesn't support causal evaluation.")
48 | assert isinstance(module, nn.modules.conv._ConvNd)
49 | return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
50 | else:
51 | return nn.Identity()
52 |
53 |
54 | def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
55 | padding_total: int = 0) -> int:
56 | """See `pad_for_conv1d`.
57 | """
58 | length = x.shape[-1]
59 | n_frames = (length - kernel_size + padding_total) / stride + 1
60 | ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
61 | return ideal_length - length
62 |
63 |
64 | def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
65 | """Pad for a convolution to make sure that the last window is full.
66 | Extra padding is added at the end. This is required to ensure that we can rebuild
67 | an output of the same length, as otherwise, even with padding, some time steps
68 | might get removed.
69 | For instance, with total padding = 4, kernel size = 4, stride = 2:
70 | 0 0 1 2 3 4 5 0 0 # (0s are padding)
71 | 1 2 3 # (output frames of a convolution, last 0 is never used)
72 | 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
73 | 1 2 3 4 # once you removed padding, we are missing one time step !
74 | """
75 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
76 | return F.pad(x, (0, extra_padding))
77 |
78 |
79 | def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
80 | """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
81 | If this is the case, we insert extra 0 padding to the right before the reflection happen.
82 | """
83 | length = x.shape[-1]
84 | padding_left, padding_right = paddings
85 | assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
86 | if mode == 'reflect':
87 | max_pad = max(padding_left, padding_right)
88 | extra_pad = 0
89 | if length <= max_pad:
90 | extra_pad = max_pad - length + 1
91 | x = F.pad(x, (0, extra_pad))
92 | padded = F.pad(x, paddings, mode, value)
93 | end = padded.shape[-1] - extra_pad
94 | return padded[..., :end]
95 | else:
96 | return F.pad(x, paddings, mode, value)
97 |
98 |
99 | def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
100 | """Remove padding from x, handling properly zero padding. Only for 1d!"""
101 | padding_left, padding_right = paddings
102 | assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
103 | assert (padding_left + padding_right) <= x.shape[-1]
104 | end = x.shape[-1] - padding_right
105 | return x[..., padding_left: end]
106 |
107 |
108 | class NormConv1d(nn.Module):
109 | """Wrapper around Conv1d and normalization applied to this conv
110 | to provide a uniform interface across normalization approaches.
111 | """
112 | def __init__(self, *args, causal: bool = False, norm: str = 'none',
113 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
114 | super().__init__()
115 | self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
116 | self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
117 | self.norm_type = norm
118 |
119 | def forward(self, x):
120 | x = self.conv(x)
121 | x = self.norm(x)
122 | return x
123 |
124 |
125 | class NormConv2d(nn.Module):
126 | """Wrapper around Conv2d and normalization applied to this conv
127 | to provide a uniform interface across normalization approaches.
128 | """
129 | def __init__(self, *args, norm: str = 'none',
130 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
131 | super().__init__()
132 | self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
133 | self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
134 | self.norm_type = norm
135 |
136 | def forward(self, x):
137 | x = self.conv(x)
138 | x = self.norm(x)
139 | return x
140 |
141 |
142 | class NormConvTranspose1d(nn.Module):
143 | """Wrapper around ConvTranspose1d and normalization applied to this conv
144 | to provide a uniform interface across normalization approaches.
145 | """
146 | def __init__(self, *args, causal: bool = False, norm: str = 'none',
147 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
148 | super().__init__()
149 | self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
150 | self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
151 | self.norm_type = norm
152 |
153 | def forward(self, x):
154 | x = self.convtr(x)
155 | x = self.norm(x)
156 | return x
157 |
158 |
159 | class NormConvTranspose2d(nn.Module):
160 | """Wrapper around ConvTranspose2d and normalization applied to this conv
161 | to provide a uniform interface across normalization approaches.
162 | """
163 | def __init__(self, *args, norm: str = 'none',
164 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
165 | super().__init__()
166 | self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
167 | self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
168 |
169 | def forward(self, x):
170 | x = self.convtr(x)
171 | x = self.norm(x)
172 | return x
173 |
174 |
175 | class SConv1d(nn.Module):
176 | """Conv1d with some builtin handling of asymmetric or causal padding
177 | and normalization.
178 | """
179 | def __init__(self, in_channels: int, out_channels: int,
180 | kernel_size: int, stride: int = 1, dilation: int = 1,
181 | groups: int = 1, bias: bool = True, causal: bool = False,
182 | norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
183 | pad_mode: str = 'reflect'):
184 | super().__init__()
185 | # warn user on unusual setup between dilation and stride
186 | if stride > 1 and dilation > 1:
187 | warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
188 | f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
189 | self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
190 | dilation=dilation, groups=groups, bias=bias, causal=causal,
191 | norm=norm, norm_kwargs=norm_kwargs)
192 | self.causal = causal
193 | self.pad_mode = pad_mode
194 |
195 | def forward(self, x):
196 | B, C, T = x.shape
197 | kernel_size = self.conv.conv.kernel_size[0]
198 | stride = self.conv.conv.stride[0]
199 | dilation = self.conv.conv.dilation[0]
200 | padding_total = (kernel_size - 1) * dilation - (stride - 1)
201 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
202 | if self.causal:
203 | # Left padding for causal
204 | x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
205 | else:
206 | # Asymmetric padding required for odd strides
207 | padding_right = padding_total // 2
208 | padding_left = padding_total - padding_right
209 | x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
210 | return self.conv(x)
211 |
212 |
213 | class SConvTranspose1d(nn.Module):
214 | """ConvTranspose1d with some builtin handling of asymmetric or causal padding
215 | and normalization.
216 | """
217 | def __init__(self, in_channels: int, out_channels: int,
218 | kernel_size: int, stride: int = 1, causal: bool = False,
219 | norm: str = 'none', trim_right_ratio: float = 1.,
220 | norm_kwargs: tp.Dict[str, tp.Any] = {}):
221 | super().__init__()
222 | self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
223 | causal=causal, norm=norm, norm_kwargs=norm_kwargs)
224 | self.causal = causal
225 | self.trim_right_ratio = trim_right_ratio
226 | assert self.causal or self.trim_right_ratio == 1., \
227 | "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
228 | assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
229 |
230 | def forward(self, x):
231 | kernel_size = self.convtr.convtr.kernel_size[0]
232 | stride = self.convtr.convtr.stride[0]
233 | padding_total = kernel_size - stride
234 |
235 | y = self.convtr(x)
236 |
237 | # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
238 | # removed at the very end, when keeping only the right length for the output,
239 | # as removing it here would require also passing the length at the matching layer
240 | # in the encoder.
241 | if self.causal:
242 | # Trim the padding on the right according to the specified ratio
243 | # if trim_right_ratio = 1.0, trim everything from right
244 | padding_right = math.ceil(padding_total * self.trim_right_ratio)
245 | padding_left = padding_total - padding_right
246 | y = unpad1d(y, (padding_left, padding_right))
247 | else:
248 | # Asymmetric padding required for odd strides
249 | padding_right = padding_total // 2
250 | padding_left = padding_total - padding_right
251 | y = unpad1d(y, (padding_left, padding_right))
252 | return y
253 |
--------------------------------------------------------------------------------
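
The padding arithmetic above is easiest to follow on concrete numbers. The standalone sketch below (plain Python; the helper mirrors `get_extra_padding_for_conv1d` but is redefined here purely for illustration) shows how much extra right padding is needed so the last convolution window is full.

# Standalone sketch of the padding arithmetic used by SConv1d above.
import math

def extra_padding(length: int, kernel_size: int, stride: int, padding_total: int) -> int:
    # Fractional number of frames the conv would produce on the padded input.
    n_frames = (length - kernel_size + padding_total) / stride + 1
    # Smallest input length whose last window is full.
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length

kernel_size, stride, dilation = 4, 2, 1
padding_total = (kernel_size - 1) * dilation - (stride - 1)   # = 2
for length in (5, 6, 7, 8):
    extra = extra_padding(length, kernel_size, stride, padding_total)
    # With padding_total + extra applied, the conv outputs ceil(length / stride) frames.
    print(length, '->', extra)
# A causal SConv1d pads padding_total on the left and the extra amount on the right.

--------------------------------------------------------------------------------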
/speechtokenizer/modules/lstm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """LSTM layers module."""
8 |
9 | from torch import nn
10 |
11 |
12 | class SLSTM(nn.Module):
13 | """
14 | LSTM without worrying about the hidden state, nor the layout of the data.
15 | Expects input as convolutional layout.
16 | """
17 | def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True, bidirectional: bool=False):
18 | super().__init__()
19 | self.bidirectional = bidirectional
20 | self.skip = skip
21 | self.lstm = nn.LSTM(dimension, dimension, num_layers, bidirectional=bidirectional)
22 |
23 | def forward(self, x):
24 | x = x.permute(2, 0, 1)
25 | y, _ = self.lstm(x)
26 | if self.bidirectional:
27 | x = x.repeat(1, 1, 2)
28 | if self.skip:
29 | y = y + x
30 | y = y.permute(1, 2, 0)
31 | return y
32 |
--------------------------------------------------------------------------------
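
SLSTM keeps the convolutional (batch, channels, time) layout at its interface and handles the (time, batch, channels) permutation internally. A minimal shape check, assuming the package is importable as laid out in this repository:

# Shape sketch for SLSTM: input and output stay in (batch, channels, time) layout.
import torch
from speechtokenizer.modules.lstm import SLSTM

x = torch.randn(2, 128, 50)                                 # (B, C, T)
print(SLSTM(dimension=128)(x).shape)                        # torch.Size([2, 128, 50])
print(SLSTM(dimension=128, bidirectional=True)(x).shape)    # torch.Size([2, 256, 50]); channels double

--------------------------------------------------------------------------------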
/speechtokenizer/modules/norm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """Normalization modules."""
8 |
9 | import typing as tp
10 |
11 | import einops
12 | import torch
13 | from torch import nn
14 |
15 |
16 | class ConvLayerNorm(nn.LayerNorm):
17 | """
18 | Convolution-friendly LayerNorm that moves channels to the last dimension
19 | before running the normalization and moves them back to their original position right after.
20 | """
21 | def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
22 | super().__init__(normalized_shape, **kwargs)
23 |
24 | def forward(self, x):
25 | x = einops.rearrange(x, 'b ... t -> b t ...')
26 | x = super().forward(x)
27 | x = einops.rearrange(x, 'b t ... -> b ... t')
28 | return x
29 |
--------------------------------------------------------------------------------
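
With the forward above returning x, ConvLayerNorm acts as a per-time-step LayerNorm over the channel dimension while keeping the convolutional layout. A minimal check, assuming the package is importable:

# ConvLayerNorm normalizes over channels without changing the (batch, channels, time) layout.
import torch
from speechtokenizer.modules.norm import ConvLayerNorm

x = torch.randn(4, 32, 100)          # (B, C, T)
y = ConvLayerNorm(32)(x)             # normalized_shape = number of channels
print(y.shape)                       # torch.Size([4, 32, 100])
print(y.mean(dim=1).abs().max())     # ~0: each time step is normalized across its 32 channels

--------------------------------------------------------------------------------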
/speechtokenizer/modules/seanet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """Encodec SEANet-based encoder and decoder implementation."""
8 |
9 | import typing as tp
10 |
11 | import numpy as np
12 | import torch.nn as nn
13 | import torch
14 |
15 | from . import (
16 | SConv1d,
17 | SConvTranspose1d,
18 | SLSTM
19 | )
20 |
21 |
22 | @torch.jit.script
23 | def snake(x, alpha):
24 | shape = x.shape
25 | x = x.reshape(shape[0], shape[1], -1)
26 | x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
27 | x = x.reshape(shape)
28 | return x
29 |
30 |
31 | class Snake1d(nn.Module):
32 | def __init__(self, channels):
33 | super().__init__()
34 | self.alpha = nn.Parameter(torch.ones(1, channels, 1))
35 |
36 | def forward(self, x):
37 | return snake(x, self.alpha)
38 |
39 | class SEANetResnetBlock(nn.Module):
40 | """Residual block from SEANet model.
41 | Args:
42 | dim (int): Dimension of the input/output
43 | kernel_sizes (list): List of kernel sizes for the convolutions.
44 | dilations (list): List of dilations for the convolutions.
45 | activation (str): Activation function.
46 | activation_params (dict): Parameters to provide to the activation function
47 | norm (str): Normalization method.
48 | norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
49 | causal (bool): Whether to use fully causal convolution.
50 | pad_mode (str): Padding mode for the convolutions.
51 | compress (int): Reduced dimensionality in residual branches (from Demucs v3)
52 | true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection.
53 | """
54 | def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
55 | activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
56 | norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
57 | pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
58 | super().__init__()
59 | assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
60 | act = getattr(nn, activation) if activation != 'Snake' else Snake1d
61 | hidden = dim // compress
62 | block = []
63 | for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
64 | in_chs = dim if i == 0 else hidden
65 | out_chs = dim if i == len(kernel_sizes) - 1 else hidden
66 | block += [
67 | act(**activation_params) if activation != 'Snake' else act(in_chs),
68 | SConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
69 | norm=norm, norm_kwargs=norm_params,
70 | causal=causal, pad_mode=pad_mode),
71 | ]
72 | self.block = nn.Sequential(*block)
73 | self.shortcut: nn.Module
74 | if true_skip:
75 | self.shortcut = nn.Identity()
76 | else:
77 | self.shortcut = SConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
78 | causal=causal, pad_mode=pad_mode)
79 |
80 | def forward(self, x):
81 | return self.shortcut(x) + self.block(x)
82 |
83 |
84 |
85 | class SEANetEncoder(nn.Module):
86 | """SEANet encoder.
87 | Args:
88 | channels (int): Audio channels.
89 | dimension (int): Intermediate representation dimension.
90 | n_filters (int): Base width for the model.
91 | n_residual_layers (int): nb of residual layers.
92 | ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
93 | upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here,
94 | which must match the decoder order.
95 | activation (str): Activation function.
96 | activation_params (dict): Parameters to provide to the activation function
97 | norm (str): Normalization method.
98 | norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
99 | kernel_size (int): Kernel size for the initial convolution.
100 | last_kernel_size (int): Kernel size for the last convolution.
101 | residual_kernel_size (int): Kernel size for the residual layers.
102 | dilation_base (int): How much to increase the dilation with each layer.
103 | causal (bool): Whether to use fully causal convolution.
104 | pad_mode (str): Padding mode for the convolutions.
105 | true_skip (bool): Whether to use true skip connection or a simple
106 | (streamable) convolution as the skip connection in the residual network blocks.
107 | compress (int): Reduced dimensionality in residual branches (from Demucs v3).
108 | lstm (int): Number of LSTM layers at the end of the encoder.
109 | """
110 | def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1,
111 | ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
112 | norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
113 | last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
114 | pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2, bidirectional:bool = False):
115 | super().__init__()
116 | self.channels = channels
117 | self.dimension = dimension
118 | self.n_filters = n_filters
119 | self.ratios = list(reversed(ratios))
120 | del ratios
121 | self.n_residual_layers = n_residual_layers
122 | self.hop_length = np.prod(self.ratios)  # product of the downsampling ratios, i.e. the hop length
123 |
124 | act = getattr(nn, activation) if activation != 'Snake' else Snake1d
125 | mult = 1
126 | model: tp.List[nn.Module] = [
127 | SConv1d(channels, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params,
128 | causal=causal, pad_mode=pad_mode)
129 | ]
130 | # Downsample to raw audio scale
131 | for i, ratio in enumerate(self.ratios):
132 | # Add residual layers
133 | for j in range(n_residual_layers):
134 | model += [
135 | SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
136 | dilations=[dilation_base ** j, 1],
137 | norm=norm, norm_params=norm_params,
138 | activation=activation, activation_params=activation_params,
139 | causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
140 |
141 | # Add downsampling layers
142 | model += [
143 | act(**activation_params) if activation != 'Snake' else act(mult * n_filters),
144 | SConv1d(mult * n_filters, mult * n_filters * 2,
145 | kernel_size=ratio * 2, stride=ratio,
146 | norm=norm, norm_kwargs=norm_params,
147 | causal=causal, pad_mode=pad_mode),
148 | ]
149 | mult *= 2
150 |
151 | if lstm:
152 | model += [SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)]
153 |
154 | mult = mult * 2 if bidirectional else mult
155 | model += [
156 | act(**activation_params) if activation != 'Snake' else act(mult * n_filters),
157 | SConv1d(mult * n_filters, dimension, last_kernel_size, norm=norm, norm_kwargs=norm_params,
158 | causal=causal, pad_mode=pad_mode)
159 | ]
160 |
161 | self.model = nn.Sequential(*model)
162 |
163 | def forward(self, x):
164 | return self.model(x)
165 |
166 |
167 | class SEANetDecoder(nn.Module):
168 | """SEANet decoder.
169 | Args:
170 | channels (int): Audio channels.
171 | dimension (int): Intermediate representation dimension.
172 | n_filters (int): Base width for the model.
173 | n_residual_layers (int): nb of residual layers.
174 | ratios (Sequence[int]): kernel size and stride ratios
175 | activation (str): Activation function.
176 | activation_params (dict): Parameters to provide to the activation function
177 | final_activation (str): Final activation function after all convolutions.
178 | final_activation_params (dict): Parameters to provide to the activation function
179 | norm (str): Normalization method.
180 | norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
181 | kernel_size (int): Kernel size for the initial convolution.
182 | last_kernel_size (int): Kernel size for the last convolution.
183 | residual_kernel_size (int): Kernel size for the residual layers.
184 | dilation_base (int): How much to increase the dilation with each layer.
185 | causal (bool): Whether to use fully causal convolution.
186 | pad_mode (str): Padding mode for the convolutions.
187 | true_skip (bool): Whether to use true skip connection or a simple
188 | (streamable) convolution as the skip connection in the residual network blocks.
189 | compress (int): Reduced dimensionality in residual branches (from Demucs v3).
190 | lstm (int): Number of LSTM layers after the initial convolution of the decoder.
191 | trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
192 | If equal to 1.0, it means that all the trimming is done at the right.
193 | """
194 | def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1,
195 | ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
196 | final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
197 | norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
198 | last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
199 | pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2,
200 | trim_right_ratio: float = 1.0, bidirectional:bool = False):
201 | super().__init__()
202 | self.dimension = dimension
203 | self.channels = channels
204 | self.n_filters = n_filters
205 | self.ratios = ratios
206 | del ratios
207 | self.n_residual_layers = n_residual_layers
208 | self.hop_length = np.prod(self.ratios)
209 |
210 | act = getattr(nn, activation) if activation != 'Snake' else Snake1d
211 | mult = int(2 ** len(self.ratios))
212 | model: tp.List[nn.Module] = [
213 | SConv1d(dimension, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params,
214 | causal=causal, pad_mode=pad_mode)
215 | ]
216 |
217 | if lstm:
218 | model += [SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)]
219 |
220 | # Upsample to raw audio scale
221 | for i, ratio in enumerate(self.ratios):
222 | # Add upsampling layers
223 | model += [
224 | act(**activation_params) if activation != 'Snake' else act(mult * n_filters),
225 | SConvTranspose1d(mult * n_filters, mult * n_filters // 2,
226 | kernel_size=ratio * 2, stride=ratio,
227 | norm=norm, norm_kwargs=norm_params,
228 | causal=causal, trim_right_ratio=trim_right_ratio),
229 | ]
230 | # Add residual layers
231 | for j in range(n_residual_layers):
232 | model += [
233 | SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
234 | dilations=[dilation_base ** j, 1],
235 | activation=activation, activation_params=activation_params,
236 | norm=norm, norm_params=norm_params, causal=causal,
237 | pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
238 |
239 | mult //= 2
240 |
241 | # Add final layers
242 | model += [
243 | act(**activation_params) if activation != 'Snake' else act(n_filters),
244 | SConv1d(n_filters, channels, last_kernel_size, norm=norm, norm_kwargs=norm_params,
245 | causal=causal, pad_mode=pad_mode)
246 | ]
247 | # Add optional final activation to decoder (eg. tanh)
248 | if final_activation is not None:
249 | final_act = getattr(nn, final_activation)
250 | final_activation_params = final_activation_params or {}
251 | model += [
252 | final_act(**final_activation_params)
253 | ]
254 | self.model = nn.Sequential(*model)
255 |
256 | def forward(self, z):
257 | y = self.model(z)
258 | return y
259 |
260 |
261 | def test():
262 | import torch
263 | encoder = SEANetEncoder()
264 | decoder = SEANetDecoder()
265 | x = torch.randn(1, 1, 24000)
266 | z = encoder(x)
267 | print('z ', z.shape)
268 |
269 | assert list(z.shape) == [1, 128, 75], z.shape
270 | y = decoder(z)
271 | assert y.shape == x.shape, (x.shape, y.shape)
272 |
273 |
274 | if __name__ == '__main__':
275 | test()
276 |
--------------------------------------------------------------------------------
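
With the default ratios [8, 5, 4, 2] the encoder's hop length is 8 * 5 * 4 * 2 = 320 samples, so the 24000-sample input used in test() above maps to 75 frames of dimension 128. A round-trip shape sketch, here with the Snake activation instead of the default ELU:

# Round-trip shape sketch for the SEANet encoder/decoder defined above.
import torch
from speechtokenizer.modules.seanet import SEANetEncoder, SEANetDecoder

encoder = SEANetEncoder(activation='Snake')   # Snake1d in every block instead of ELU
decoder = SEANetDecoder(activation='Snake')
print(encoder.hop_length)                     # 320 = 8 * 5 * 4 * 2

x = torch.randn(1, 1, 24000)                  # 1 s of mono 24 kHz audio
z = encoder(x)                                # (1, 128, 75)
y = decoder(z)                                # (1, 1, 24000)
print(z.shape, y.shape)

--------------------------------------------------------------------------------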
/speechtokenizer/quantization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # flake8: noqa
8 | from .vq import QuantizedResult, ResidualVectorQuantizer
9 |
--------------------------------------------------------------------------------
/speechtokenizer/quantization/ac.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """Arithmetic coder."""
8 |
9 | import io
10 | import math
11 | import random
12 | import typing as tp
13 | import torch
14 |
15 | from ..binary import BitPacker, BitUnpacker
16 |
17 |
18 | def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int,
19 | roundoff: float = 1e-8, min_range: int = 2,
20 | check: bool = True) -> torch.Tensor:
21 | """Turn the given PDF into a quantized CDF that splits
22 | [0, 2 ** total_range_bits - 1] into chunks of size roughly proportional
23 | to the PDF.
24 |
25 | Args:
26 | pdf (torch.Tensor): probability distribution, shape should be `[N]`.
27 | total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
28 | during the coding process is `[0, 2 ** total_range_bits - 1]`.
29 | roundoff (float): will round the pdf up to that level to remove difference coming
30 | from e.g. evaluating the Language Model on different architectures.
31 | min_range (int): minimum range width. Should always be at least 2 for numerical
32 | stability. Use this to avoid pathological behavior when a value
33 | that is expected to be rare actually happens in real life.
34 | check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
35 | """
36 | pdf = pdf.detach()
37 | if roundoff:
38 | pdf = (pdf / roundoff).floor() * roundoff
39 | # interpolate with uniform distribution to achieve desired minimum probability.
40 | total_range = 2 ** total_range_bits
41 | cardinality = len(pdf)
42 | alpha = min_range * cardinality / total_range
43 | assert alpha <= 1, "you must reduce min_range"
44 | ranges = (((1 - alpha) * total_range) * pdf).floor().long()
45 | ranges += min_range
46 | quantized_cdf = torch.cumsum(ranges, dim=-1)
47 | if min_range < 2:
48 | raise ValueError("min_range must be at least 2.")
49 | if check:
50 | assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
51 | if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
52 | raise ValueError("You must increase your total_range_bits.")
53 | return quantized_cdf
54 |
55 |
56 | class ArithmeticCoder:
57 | """ArithmeticCoder,
58 | Let us take a distribution `p` over `N` symbols, and assume we have a stream
59 | of random variables `s_t` sampled from `p`. Let us assume that we have a budget
60 | of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
61 | corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
62 | sequence `(s_t)` by doing the following:
63 |
64 | 1) Initialize the current range to `[0, 2 ** B - 1]`.
65 | 2) For each time step t, split the current range into contiguous chunks,
66 | one for each possible outcome, with size roughly proportional to `p`.
67 | For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
68 | would be `{[0, 2], [3, 3]}`.
69 | 3) Select the chunk corresponding to `s_t`, and replace the current range with this.
70 | 4) When done encoding all the values, just select any value remaining in the range.
71 |
72 | You will notice that this procedure can fail: for instance if at any point in time
73 | the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
74 | possible outcome. Intuitively, the more likely a value is, the less the range width
75 | will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
76 | coding scheme, likely outcomes would take less bits, and more of them can be coded
77 | with a fixed budget.
78 |
79 | In practice, we do not know `B` ahead of time, but we have a way to inject new bits
80 | when the current range decreases below a given limit (given by `total_range_bits`), without
81 | having to redo all the computations. If we mostly encode likely values, we will seldom
82 | need to inject new bits, but a single rare value can deplete our stock of entropy!
83 |
84 | In this explanation, we assumed that the distribution `p` was constant. In fact, the present
85 | code works for any sequence `(p_t)` possibly different for each timestep.
86 | We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
87 | the KL between the true distribution and `p_t`, the more efficient the coding will be.
88 |
89 | Args:
90 | fo (IO[bytes]): file-like object to which the bytes will be written to.
91 | total_range_bits (int): the range `M` described above is `2 ** total_range_bits`.
92 | Any time the current range width falls under this limit, new bits will
93 | be injected to rescale the initial range.
94 | """
95 |
96 | def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
97 | assert total_range_bits <= 30
98 | self.total_range_bits = total_range_bits
99 | self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time.
100 | self.low: int = 0
101 | self.high: int = 0
102 | self.max_bit: int = -1
103 | self._dbg: tp.List[tp.Any] = []
104 | self._dbg2: tp.List[tp.Any] = []
105 |
106 | @property
107 | def delta(self) -> int:
108 | """Return the current range width."""
109 | return self.high - self.low + 1
110 |
111 | def _flush_common_prefix(self):
112 | # If self.low and self.high start with the same bits,
113 | # those won't change anymore as we always just increase the range
114 | # by powers of 2, and we can flush them out to the bit stream.
115 | assert self.high >= self.low, (self.low, self.high)
116 | assert self.high < 2 ** (self.max_bit + 1)
117 | while self.max_bit >= 0:
118 | b1 = self.low >> self.max_bit
119 | b2 = self.high >> self.max_bit
120 | if b1 == b2:
121 | self.low -= (b1 << self.max_bit)
122 | self.high -= (b1 << self.max_bit)
123 | assert self.high >= self.low, (self.high, self.low, self.max_bit)
124 | assert self.low >= 0
125 | self.max_bit -= 1
126 | self.packer.push(b1)
127 | else:
128 | break
129 |
130 | def push(self, symbol: int, quantized_cdf: torch.Tensor):
131 | """Push the given symbol on the stream, flushing out bits
132 | if possible.
133 |
134 | Args:
135 | symbol (int): symbol to encode with the AC.
136 | quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
137 | to build this from your pdf estimate.
138 | """
139 | while self.delta < 2 ** self.total_range_bits:
140 | self.low *= 2
141 | self.high = self.high * 2 + 1
142 | self.max_bit += 1
143 |
144 | range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
145 | range_high = quantized_cdf[symbol].item() - 1
146 | effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
147 | effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
148 | assert self.low <= self.high
149 | self.high = self.low + effective_high
150 | self.low = self.low + effective_low
151 | assert self.low <= self.high, (effective_low, effective_high, range_low, range_high)
152 | self._dbg.append((self.low, self.high))
153 | self._dbg2.append((self.low, self.high))
154 | outs = self._flush_common_prefix()
155 | assert self.low <= self.high
156 | assert self.max_bit >= -1
157 | assert self.max_bit <= 61, self.max_bit
158 | return outs
159 |
160 | def flush(self):
161 | """Flush the remaining information to the stream.
162 | """
163 | while self.max_bit >= 0:
164 | b1 = (self.low >> self.max_bit) & 1
165 | self.packer.push(b1)
166 | self.max_bit -= 1
167 | self.packer.flush()
168 |
169 |
170 | class ArithmeticDecoder:
171 | """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
172 |
173 | Note that this must be called with **exactly** the same parameters and sequence
174 | of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
175 |
176 | If the AC encoder's current range is [L, H], with `L` and `H` having a common
177 | prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
178 | For instance, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
179 | `[b1 b2 b3 0 ... 0, b1 b2 b3 1 ... 1]`. Now this specific sub-range can only be obtained
180 | for a specific sequence of symbols and a binary-search allows us to decode those symbols.
181 | At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
182 | and we will need to read new bits from the stream and repeat the process.
183 |
184 | """
185 | def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
186 | self.total_range_bits = total_range_bits
187 | self.low: int = 0
188 | self.high: int = 0
189 | self.current: int = 0
190 | self.max_bit: int = -1
191 | self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time.
192 | # Following is for debugging
193 | self._dbg: tp.List[tp.Any] = []
194 | self._dbg2: tp.List[tp.Any] = []
195 | self._last: tp.Any = None
196 |
197 | @property
198 | def delta(self) -> int:
199 | return self.high - self.low + 1
200 |
201 | def _flush_common_prefix(self):
202 | # Given the current range [L, H], if both have a common prefix,
203 | # we know we can remove it from our representation to avoid handling large numbers.
204 | while self.max_bit >= 0:
205 | b1 = self.low >> self.max_bit
206 | b2 = self.high >> self.max_bit
207 | if b1 == b2:
208 | self.low -= (b1 << self.max_bit)
209 | self.high -= (b1 << self.max_bit)
210 | self.current -= (b1 << self.max_bit)
211 | assert self.high >= self.low
212 | assert self.low >= 0
213 | self.max_bit -= 1
214 | else:
215 | break
216 |
217 | def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
218 | """Pull a symbol, reading as many bits from the stream as required.
219 | This returns `None` when the stream has been exhausted.
220 |
221 | Args:
222 | quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
223 | to build this from your pdf estimate. This must be **exactly**
224 | the same cdf as the one used at encoding time.
225 | """
226 | while self.delta < 2 ** self.total_range_bits:
227 | bit = self.unpacker.pull()
228 | if bit is None:
229 | return None
230 | self.low *= 2
231 | self.high = self.high * 2 + 1
232 | self.current = self.current * 2 + bit
233 | self.max_bit += 1
234 |
235 | def bin_search(low_idx: int, high_idx: int):
236 | # Binary search is not just for coding interviews :)
237 | if high_idx < low_idx:
238 | raise RuntimeError("Binary search failed")
239 | mid = (low_idx + high_idx) // 2
240 | range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
241 | range_high = quantized_cdf[mid].item() - 1
242 | effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
243 | effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
244 | low = effective_low + self.low
245 | high = effective_high + self.low
246 | if self.current >= low:
247 | if self.current <= high:
248 | return (mid, low, high, self.current)
249 | else:
250 | return bin_search(mid + 1, high_idx)
251 | else:
252 | return bin_search(low_idx, mid - 1)
253 |
254 | self._last = (self.low, self.high, self.current, self.max_bit)
255 | sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
256 | self._dbg.append((self.low, self.high, self.current))
257 | self._flush_common_prefix()
258 | self._dbg2.append((self.low, self.high, self.current))
259 |
260 | return sym
261 |
262 |
263 | def test():
264 | torch.manual_seed(1234)
265 | random.seed(1234)
266 | for _ in range(4):
267 | pdfs = []
268 | cardinality = random.randrange(4000)
269 | steps = random.randrange(100, 500)
270 | fo = io.BytesIO()
271 | encoder = ArithmeticCoder(fo)
272 | symbols = []
273 | for step in range(steps):
274 | pdf = torch.softmax(torch.randn(cardinality), dim=0)
275 | pdfs.append(pdf)
276 | q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
277 | symbol = torch.multinomial(pdf, 1).item()
278 | symbols.append(symbol)
279 | encoder.push(symbol, q_cdf)
280 | encoder.flush()
281 |
282 | fo.seek(0)
283 | decoder = ArithmeticDecoder(fo)
284 | for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
285 | q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
286 | decoded_symbol = decoder.pull(q_cdf)
287 | assert decoded_symbol == symbol, idx
288 | assert decoder.pull(torch.zeros(1)) is None
289 |
290 |
291 | if __name__ == "__main__":
292 | test()
293 |
--------------------------------------------------------------------------------
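
The test() above exercises the coder end to end with random distributions; the shorter sketch below shows the basic push/flush/pull round trip. Note that ac.py imports BitPacker/BitUnpacker from ..binary, so this only runs if that module is present in the installed package.

# Round-trip sketch: arithmetic-code a handful of symbols drawn from a fixed pdf.
import io
import torch
from speechtokenizer.quantization.ac import (
    ArithmeticCoder, ArithmeticDecoder, build_stable_quantized_cdf)

pdf = torch.tensor([0.7, 0.2, 0.1])
symbols = [0, 0, 1, 2, 0, 0]

fo = io.BytesIO()
coder = ArithmeticCoder(fo)
q_cdf = build_stable_quantized_cdf(pdf, coder.total_range_bits)
for s in symbols:
    coder.push(s, q_cdf)      # likely symbols shrink the range less, so fewer bits are emitted
coder.flush()

fo.seek(0)
decoder = ArithmeticDecoder(fo)
assert [decoder.pull(q_cdf) for _ in symbols] == symbols
print(len(fo.getvalue()), 'bytes for', len(symbols), 'symbols')

--------------------------------------------------------------------------------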
/speechtokenizer/quantization/core_vq.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | # This implementation is inspired from
8 | # https://github.com/lucidrains/vector-quantize-pytorch
9 | # which is released under MIT License. Hereafter, the original license:
10 | # MIT License
11 | #
12 | # Copyright (c) 2020 Phil Wang
13 | #
14 | # Permission is hereby granted, free of charge, to any person obtaining a copy
15 | # of this software and associated documentation files (the "Software"), to deal
16 | # in the Software without restriction, including without limitation the rights
17 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18 | # copies of the Software, and to permit persons to whom the Software is
19 | # furnished to do so, subject to the following conditions:
20 | #
21 | # The above copyright notice and this permission notice shall be included in all
22 | # copies or substantial portions of the Software.
23 | #
24 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30 | # SOFTWARE.
31 |
32 | """Core vector quantization implementation."""
33 | import typing as tp
34 |
35 | from einops import rearrange, repeat
36 | import torch
37 | from torch import nn
38 | import torch.nn.functional as F
39 |
40 | from .distrib import broadcast_tensors, rank
41 |
42 |
43 | def default(val: tp.Any, d: tp.Any) -> tp.Any:
44 | return val if val is not None else d
45 |
46 |
47 | def ema_inplace(moving_avg, new, decay: float):
48 | moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
49 |
50 |
51 | def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
52 | return (x + epsilon) / (x.sum() + n_categories * epsilon)
53 |
54 |
55 | def uniform_init(*shape: int):
56 | t = torch.empty(shape)
57 | nn.init.kaiming_uniform_(t)
58 | return t
59 |
60 |
61 | def sample_vectors(samples, num: int):
62 | num_samples, device = samples.shape[0], samples.device
63 |
64 | if num_samples >= num:
65 | indices = torch.randperm(num_samples, device=device)[:num]
66 | else:
67 | indices = torch.randint(0, num_samples, (num,), device=device)
68 |
69 | return samples[indices]
70 |
71 |
72 | def kmeans(samples, num_clusters: int, num_iters: int = 10):
73 | dim, dtype = samples.shape[-1], samples.dtype
74 |
75 | means = sample_vectors(samples, num_clusters)
76 |
77 | for _ in range(num_iters):
78 | diffs = rearrange(samples, "n d -> n () d") - rearrange(
79 | means, "c d -> () c d"
80 | )
81 | dists = -(diffs ** 2).sum(dim=-1)
82 |
83 | buckets = dists.max(dim=-1).indices
84 | bins = torch.bincount(buckets, minlength=num_clusters)
85 | zero_mask = bins == 0
86 | bins_min_clamped = bins.masked_fill(zero_mask, 1)
87 |
88 | new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
89 | new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
90 | new_means = new_means / bins_min_clamped[..., None]
91 |
92 | means = torch.where(zero_mask[..., None], means, new_means)
93 |
94 | return means, bins
95 |
96 |
97 | class EuclideanCodebook(nn.Module):
98 | """Codebook with Euclidean distance.
99 | Args:
100 | dim (int): Dimension.
101 | codebook_size (int): Codebook size.
102 | kmeans_init (bool): Whether to use k-means to initialize the codebooks.
103 | If set to true, run the k-means algorithm on the first training batch and use
104 | the learned centroids as initialization.
105 | kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
106 | decay (float): Decay for exponential moving average over the codebooks.
107 | epsilon (float): Epsilon value for numerical stability.
108 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
109 | that have an exponential moving average cluster size less than the specified threshold with
110 | randomly selected vector from the current batch.
111 | """
112 | def __init__(
113 | self,
114 | dim: int,
115 | codebook_size: int,
116 | kmeans_init: bool = False,
117 | kmeans_iters: int = 10,
118 | decay: float = 0.99,
119 | epsilon: float = 1e-5,
120 | threshold_ema_dead_code: int = 2,
121 | ):
122 | super().__init__()
123 | self.decay = decay
124 | init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
125 | embed = init_fn(codebook_size, dim)
126 |
127 | self.codebook_size = codebook_size
128 |
129 | self.kmeans_iters = kmeans_iters
130 | self.epsilon = epsilon
131 | self.threshold_ema_dead_code = threshold_ema_dead_code
132 |
133 | self.register_buffer("inited", torch.Tensor([not kmeans_init]))
134 | self.register_buffer("cluster_size", torch.zeros(codebook_size))
135 | self.register_buffer("embed", embed)
136 | self.register_buffer("embed_avg", embed.clone())
137 |
138 | @torch.jit.ignore
139 | def init_embed_(self, data):
140 | if self.inited:
141 | return
142 |
143 | embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
144 | self.embed.data.copy_(embed)
145 | self.embed_avg.data.copy_(embed.clone())
146 | self.cluster_size.data.copy_(cluster_size)
147 | self.inited.data.copy_(torch.Tensor([True]))
148 | # Make sure all buffers across workers are in sync after initialization
149 | #broadcast_tensors(self.buffers())
150 |
151 | def replace_(self, samples, mask):
152 | modified_codebook = torch.where(
153 | mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
154 | )
155 | self.embed.data.copy_(modified_codebook)
156 |
157 | def expire_codes_(self, batch_samples):
158 | if self.threshold_ema_dead_code == 0:
159 | return
160 |
161 | expired_codes = self.cluster_size < self.threshold_ema_dead_code
162 | if not torch.any(expired_codes):
163 | return
164 |
165 | batch_samples = rearrange(batch_samples, "... d -> (...) d")
166 | self.replace_(batch_samples, mask=expired_codes)
167 | #broadcast_tensors(self.buffers())
168 |
169 | def preprocess(self, x):
170 | x = rearrange(x, "... d -> (...) d")
171 | return x
172 |
173 | def quantize(self, x):
174 | embed = self.embed.t()
175 | dist = -(
176 | x.pow(2).sum(1, keepdim=True)
177 | - 2 * x @ embed
178 | + embed.pow(2).sum(0, keepdim=True)
179 | )
180 | embed_ind = dist.max(dim=-1).indices
181 | return embed_ind
182 |
183 | def postprocess_emb(self, embed_ind, shape):
184 | return embed_ind.view(*shape[:-1])
185 |
186 | def dequantize(self, embed_ind):
187 | quantize = F.embedding(embed_ind, self.embed)
188 | return quantize
189 |
190 | def encode(self, x):
191 | shape = x.shape
192 | # pre-process
193 | x = self.preprocess(x)
194 | # quantize
195 | embed_ind = self.quantize(x)
196 | # post-process
197 | embed_ind = self.postprocess_emb(embed_ind, shape)
198 | return embed_ind
199 |
200 | def decode(self, embed_ind):
201 | quantize = self.dequantize(embed_ind)
202 | return quantize
203 |
204 | def forward(self, x):
205 | shape, dtype = x.shape, x.dtype
206 | x = self.preprocess(x)
207 |
208 | self.init_embed_(x)
209 |
210 | embed_ind = self.quantize(x)
211 | embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
212 | embed_ind = self.postprocess_emb(embed_ind, shape)
213 | quantize = self.dequantize(embed_ind)
214 |
215 | if self.training:
216 | # We do the expiry of code at that point as buffers are in sync
217 | # and all the workers will take the same decision.
218 | self.expire_codes_(x)
219 | ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
220 | embed_sum = x.t() @ embed_onehot
221 | ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
222 | cluster_size = (
223 | laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
224 | * self.cluster_size.sum()
225 | )
226 | embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
227 | self.embed.data.copy_(embed_normalized)
228 |
229 | return quantize, embed_ind
230 |
231 |
232 | class VectorQuantization(nn.Module):
233 | """Vector quantization implementation.
234 | Currently supports only euclidean distance.
235 | Args:
236 | dim (int): Dimension
237 | codebook_size (int): Codebook size
238 | codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
239 | decay (float): Decay for exponential moving average over the codebooks.
240 | epsilon (float): Epsilon value for numerical stability.
241 | kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
242 | kmeans_iters (int): Number of iterations used for kmeans initialization.
243 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
244 | that have an exponential moving average cluster size less than the specified threshold with
245 | randomly selected vector from the current batch.
246 | commitment_weight (float): Weight for commitment loss.
247 | """
248 | def __init__(
249 | self,
250 | dim: int,
251 | codebook_size: int,
252 | codebook_dim: tp.Optional[int] = None,
253 | decay: float = 0.99,
254 | epsilon: float = 1e-5,
255 | kmeans_init: bool = True,
256 | kmeans_iters: int = 50,
257 | threshold_ema_dead_code: int = 2,
258 | commitment_weight: float = 1.,
259 | ):
260 | super().__init__()
261 | _codebook_dim: int = default(codebook_dim, dim)
262 |
263 | requires_projection = _codebook_dim != dim
264 | self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
265 | self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
266 |
267 | self.epsilon = epsilon
268 | self.commitment_weight = commitment_weight
269 |
270 | self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
271 | kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
272 | decay=decay, epsilon=epsilon,
273 | threshold_ema_dead_code=threshold_ema_dead_code)
274 | self.codebook_size = codebook_size
275 |
276 | @property
277 | def codebook(self):
278 | return self._codebook.embed
279 |
280 | def encode(self, x):
281 | x = rearrange(x, "b d n -> b n d")
282 | x = self.project_in(x)
283 | embed_in = self._codebook.encode(x)
284 | return embed_in
285 |
286 | def decode(self, embed_ind):
287 | quantize = self._codebook.decode(embed_ind)
288 | quantize = self.project_out(quantize)
289 | quantize = rearrange(quantize, "b n d -> b d n")
290 | return quantize
291 |
292 | def forward(self, x):
293 | device = x.device
294 | x = rearrange(x, "b d n -> b n d")
295 | x = self.project_in(x)
296 |
297 | quantize, embed_ind = self._codebook(x)
298 |
299 | if self.training:
300 | quantize = x + (quantize - x).detach()
301 |
302 | loss = torch.tensor([0.0], device=device, requires_grad=self.training)
303 |
304 | if self.training:
305 | if self.commitment_weight > 0:
306 | commit_loss = F.mse_loss(quantize.detach(), x)
307 | loss = loss + commit_loss * self.commitment_weight
308 |
309 | quantize = self.project_out(quantize)
310 | quantize = rearrange(quantize, "b n d -> b d n")
311 | return quantize, embed_ind, loss
312 |
313 |
314 | class ResidualVectorQuantization(nn.Module):
315 | """Residual vector quantization implementation.
316 | Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
317 | """
318 | def __init__(self, *, num_quantizers, **kwargs):
319 | super().__init__()
320 | self.layers = nn.ModuleList(
321 | [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
322 | )
323 |
324 | def forward(self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None):
325 | quantized_out = 0.0
326 | residual = x
327 |
328 | all_losses = []
329 | all_indices = []
330 | out_quantized = []
331 |
332 | n_q = n_q or len(self.layers)
333 |
334 | for i, layer in enumerate(self.layers[:n_q]):
335 | quantized, indices, loss = layer(residual)
336 | residual = residual - quantized
337 | quantized_out = quantized_out + quantized
338 |
339 | all_indices.append(indices)
340 | all_losses.append(loss)
341 | if layers and i in layers:
342 | out_quantized.append(quantized)
343 |
344 | out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
345 | return quantized_out, out_indices, out_losses, out_quantized
346 |
347 | def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int]= None) -> torch.Tensor:
348 | residual = x
349 | all_indices = []
350 | n_q = n_q or len(self.layers)
351 | st = st or 0
352 | for layer in self.layers[st:n_q]:
353 | indices = layer.encode(residual)
354 | quantized = layer.decode(indices)
355 | residual = residual - quantized
356 | all_indices.append(indices)
357 | out_indices = torch.stack(all_indices)
358 | return out_indices
359 |
360 | def decode(self, q_indices: torch.Tensor, st: int=0) -> torch.Tensor:
361 | quantized_out = torch.tensor(0.0, device=q_indices.device)
362 | for i, indices in enumerate(q_indices):
363 | layer = self.layers[st + i]
364 | quantized = layer.decode(indices)
365 | quantized_out = quantized_out + quantized
366 | return quantized_out
367 |
--------------------------------------------------------------------------------
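
A shape-level sketch of ResidualVectorQuantization: each layer quantizes the residual left by the previous one, so indices stack to (n_q, batch, time) while the quantized sum keeps the input layout. Codebooks are k-means-initialized on the first training forward, and with random input the codes below are only meaningful as shapes.

# Shape sketch for ResidualVectorQuantization (input layout is (batch, dim, time)).
import torch
from speechtokenizer.quantization.core_vq import ResidualVectorQuantization

rvq = ResidualVectorQuantization(num_quantizers=4, dim=128, codebook_size=1024)
x = torch.randn(2, 128, 50)

quantized, indices, losses, _ = rvq(x)     # the first forward also runs the k-means init
print(quantized.shape, indices.shape)      # (2, 128, 50) and (4, 2, 50)

codes = rvq.encode(x, n_q=4)               # (4, 2, 50) integer codebook indices
recon = rvq.decode(codes)                  # (2, 128, 50) sum of the per-layer dequantizations
print(codes.shape, recon.shape)

--------------------------------------------------------------------------------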
/speechtokenizer/quantization/distrib.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """Torch distributed utilities."""
8 |
9 | import typing as tp
10 |
11 | import torch
12 |
13 |
14 | def rank():
15 | if torch.distributed.is_initialized():
16 | return torch.distributed.get_rank()
17 | else:
18 | return 0
19 |
20 |
21 | def world_size():
22 | if torch.distributed.is_initialized():
23 | return torch.distributed.get_world_size()
24 | else:
25 | return 1
26 |
27 |
28 | def is_distributed():
29 | return world_size() > 1
30 |
31 |
32 | def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
33 | if is_distributed():
34 | return torch.distributed.all_reduce(tensor, op)
35 |
36 |
37 | def _is_complex_or_float(tensor):
38 | return torch.is_floating_point(tensor) or torch.is_complex(tensor)
39 |
40 |
41 | def _check_number_of_params(params: tp.List[torch.Tensor]):
42 | # utility function to check that the number of params in all workers is the same,
43 | # and thus avoid a deadlock with distributed all reduce.
44 | if not is_distributed() or not params:
45 | return
46 | #print('params[0].device ', params[0].device)
47 | tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
48 | all_reduce(tensor)
49 | if tensor.item() != len(params) * world_size():
50 | # If not all the workers have the same number, for at least one of them,
51 | # this inequality will be verified.
52 | raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
53 | "at least one worker has a different one.")
54 |
55 |
56 | def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
57 | """Broadcast the tensors from the given parameters to all workers.
58 | This can be used to ensure that all workers have the same model to start with.
59 | """
60 | if not is_distributed():
61 | return
62 | tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
63 | _check_number_of_params(tensors)
64 | handles = []
65 | for tensor in tensors:
66 | # src = int(rank()) # added code
67 | handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
68 | handles.append(handle)
69 | for handle in handles:
70 | handle.wait()
71 |
72 |
73 | def sync_buffer(buffers, average=True):
74 | """
75 | Sync buffers across workers. If average is False, broadcast instead of averaging.
76 | """
77 | if not is_distributed():
78 | return
79 | handles = []
80 | for buffer in buffers:
81 | if torch.is_floating_point(buffer.data):
82 | if average:
83 | handle = torch.distributed.all_reduce(
84 | buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
85 | else:
86 | handle = torch.distributed.broadcast(
87 | buffer.data, src=0, async_op=True)
88 | handles.append((buffer, handle))
89 | for buffer, handle in handles:
90 | handle.wait()
91 | if average:
92 | buffer.data /= world_size()
93 |
94 |
95 | def sync_grad(params):
96 | """
97 | Simpler alternative to DistributedDataParallel, that doesn't rely
98 | on any black magic. For simple models it can also be as fast.
99 | Just call this on your model parameters after the call to backward!
100 | """
101 | if not is_distributed():
102 | return
103 | handles = []
104 | for p in params:
105 | if p.grad is not None:
106 | handle = torch.distributed.all_reduce(
107 | p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
108 | handles.append((p, handle))
109 | for p, handle in handles:
110 | handle.wait()
111 | p.grad.data /= world_size()
112 |
113 |
114 | def average_metrics(metrics: tp.Dict[str, float], count=1.):
115 | """Average a dictionary of metrics across all workers, using the optional
116 | `count` as unnormalized weight.
117 | """
118 | if not is_distributed():
119 | return metrics
120 | keys, values = zip(*metrics.items())
121 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
122 | tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
123 | tensor *= count
124 | all_reduce(tensor)
125 | averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
126 | return dict(zip(keys, averaged))
127 |
--------------------------------------------------------------------------------
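
These helpers degrade gracefully: every function returns immediately when torch.distributed is not initialized, so the same training step runs unchanged on a single process. A sketch of where they slot in, using a toy model purely for illustration:

# Sketch: manual weight/gradient sync around a training step; all calls are no-ops single-process.
import torch
from speechtokenizer.quantization.distrib import broadcast_tensors, sync_grad, average_metrics

model = torch.nn.Linear(8, 8)
broadcast_tensors(model.parameters())          # start every worker from rank 0's weights
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(4, 8)
loss = model(x).pow(2).mean()
loss.backward()
sync_grad(model.parameters())                  # all-reduce and average gradients across workers
opt.step()

print(average_metrics({'loss': loss.item()}))  # unchanged dict on a single process

--------------------------------------------------------------------------------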
/speechtokenizer/quantization/vq.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """Residual vector quantizer implementation."""
8 |
9 | from dataclasses import dataclass, field
10 | import math
11 | import typing as tp
12 |
13 | import torch
14 | from torch import nn
15 |
16 | from .core_vq import ResidualVectorQuantization
17 |
18 |
19 | @dataclass
20 | class QuantizedResult:
21 | quantized: torch.Tensor
22 | codes: torch.Tensor
23 | bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
24 | penalty: tp.Optional[torch.Tensor] = None
25 | metrics: dict = field(default_factory=dict)
26 |
27 |
28 | class ResidualVectorQuantizer(nn.Module):
29 | """Residual Vector Quantizer.
30 | Args:
31 | dimension (int): Dimension of the codebooks.
32 | n_q (int): Number of residual vector quantizers used.
33 | bins (int): Codebook size.
34 | decay (float): Decay for exponential moving average over the codebooks.
35 | kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
36 | kmeans_iters (int): Number of iterations used for kmeans initialization.
37 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
38 | that have an exponential moving average cluster size less than the specified threshold with
39 | randomly selected vector from the current batch.
40 | """
41 | def __init__(
42 | self,
43 | dimension: int = 256,
44 | n_q: int = 8,
45 | bins: int = 1024,
46 | decay: float = 0.99,
47 | kmeans_init: bool = True,
48 | kmeans_iters: int = 50,
49 | threshold_ema_dead_code: int = 2,
50 | ):
51 | super().__init__()
52 | self.n_q = n_q
53 | self.dimension = dimension
54 | self.bins = bins
55 | self.decay = decay
56 | self.kmeans_init = kmeans_init
57 | self.kmeans_iters = kmeans_iters
58 | self.threshold_ema_dead_code = threshold_ema_dead_code
59 | self.vq = ResidualVectorQuantization(
60 | dim=self.dimension,
61 | codebook_size=self.bins,
62 | num_quantizers=self.n_q,
63 | decay=self.decay,
64 | kmeans_init=self.kmeans_init,
65 | kmeans_iters=self.kmeans_iters,
66 | threshold_ema_dead_code=self.threshold_ema_dead_code,
67 | )
68 |
69 | def forward(self, x: torch.Tensor, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None) -> QuantizedResult:
70 | """Residual vector quantization on the given input tensor.
71 | Args:
72 | x (torch.Tensor): Input tensor.
73 | n_q (int): Number of quantizers used to quantize. Default: all quantizers.
74 | layers (list): Layers whose quantized output should also be returned. Default: None.
75 | Returns:
76 | QuantizedResult:
77 | The quantized (or approximately quantized) representation with
78 | the associated number of quantizers and the quantized outputs of the requested layers.
79 | """
80 | n_q = n_q if n_q else self.n_q
81 | if layers and max(layers) >= n_q:
82 | raise ValueError(f'Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must be less than B.')
83 | quantized, codes, commit_loss, quantized_list = self.vq(x, n_q=n_q, layers=layers)
84 | return quantized, codes, torch.mean(commit_loss), quantized_list
85 |
86 |
87 | def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
88 | """Encode a given input tensor with the specified sample rate at the given bandwidth.
89 | The RVQ encode method sets the appropriate number of quantizer to use
90 | and returns indices for each quantizer.
91 | Args:
92 | x (torch.Tensor): Input tensor.
93 | n_q (int): Number of quantizers used to quantize. Default: all quantizers.
94 | st (int): Index of the first quantizer layer to encode from. Default: 0.
95 | """
96 | n_q = n_q if n_q else self.n_q
97 | st = st or 0
98 | codes = self.vq.encode(x, n_q=n_q, st=st)
99 | return codes
100 |
101 | def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
102 | """Decode the given codes to the quantized representation.
103 | Args:
104 | codes (torch.Tensor): Input indices for each quantizer.
105 | st (int): Index of the first quantizer layer to decode from. Default: 0.
106 | """
107 | quantized = self.vq.decode(codes, st=st)
108 | return quantized
109 |
--------------------------------------------------------------------------------
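
ResidualVectorQuantizer is the wrapper the model consumes: encode returns one index sequence per quantizer, and decode can start from any layer via st, so a subset of quantizers (e.g. only the first, or everything but the first) can be decoded on its own. A shape sketch with untrained codebooks, so the index values themselves are trivial:

# Shape sketch for the ResidualVectorQuantizer wrapper.
import torch
from speechtokenizer.quantization import ResidualVectorQuantizer

rvq = ResidualVectorQuantizer(dimension=256, n_q=8, bins=1024)
x = torch.randn(1, 256, 75)

codes = rvq.encode(x)                     # (8, 1, 75): one code sequence per quantizer
z_first = rvq.decode(codes[:1])           # contribution of the first quantizer only
z_rest = rvq.decode(codes[1:], st=1)      # quantizers 2..8, decoded from layer index 1
print(codes.shape, z_first.shape, z_rest.shape)   # (8, 1, 75), (1, 256, 75), (1, 256, 75)

--------------------------------------------------------------------------------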
/speechtokenizer/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer import SpeechTokenizerTrainer
--------------------------------------------------------------------------------
/speechtokenizer/trainer/dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | from torch.nn.utils.rnn import pad_sequence
3 | import torchaudio
4 | import random
5 | import torch
6 | import numpy as np
7 |
8 | def collate_fn(data):
9 | # return pad_sequence(data, batch_first=True)
10 | # return pad_sequence(*data)
11 | is_one_data = not isinstance(data[0], tuple)
12 | outputs = []
13 | if is_one_data:
14 | for datum in data:
15 | if isinstance(datum, torch.Tensor):
16 | output = datum.unsqueeze(0)
17 | else:
18 | output = torch.tensor([datum])
19 | outputs.append(output)
20 | return tuple(outputs)
21 | for datum in zip(*data):
22 | if isinstance(datum[0], torch.Tensor):
23 | output = pad_sequence(datum, batch_first=True)
24 | else:
25 | output = torch.tensor(list(datum))
26 | outputs.append(output)
27 |
28 | return tuple(outputs)
29 |
30 | def get_dataloader(ds, **kwargs):
31 | return DataLoader(ds, collate_fn=collate_fn, **kwargs)
32 |
33 | class audioDataset(Dataset):
34 |
35 | def __init__(self,
36 | file_list,
37 | segment_size,
38 | sample_rate,
39 | downsample_rate = 320,
40 | valid=False):
41 | super().__init__()
42 | self.file_list = file_list
43 | self.segment_size = segment_size
44 | self.sample_rate = sample_rate
45 | self.valid = valid
46 | self.downsample_rate = downsample_rate
47 |
48 | def __len__(self):
49 | return len(self.file_list)
50 |
51 |
52 | def __getitem__(self, index):
53 | file = self.file_list[index].strip()
54 | audio_file, feature_file = file.split('\t')
55 | audio, sr = torchaudio.load(audio_file)
56 | feature = torch.from_numpy(np.load(feature_file))
57 | audio = audio.mean(axis=0)
58 | if sr != self.sample_rate:
59 | audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
60 | if audio.size(-1) > self.segment_size:
61 | if self.valid:
62 | return audio[:self.segment_size], feature[:self.segment_size // self.downsample_rate]
63 | max_audio_start = audio.size(-1) - self.segment_size
64 | audio_start = random.randint(0, max_audio_start)
65 | audio = audio[audio_start:audio_start+self.segment_size]
66 | feature_start = min(int(audio_start / self.downsample_rate), feature.size(0) - self.segment_size // self.downsample_rate)
67 | feature = feature[feature_start:feature_start + self.segment_size // self.downsample_rate, :]
68 | else:
69 | if not self.valid:
70 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(-1)), 'constant')
71 | return audio, feature
--------------------------------------------------------------------------------
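Usage sketch (illustrative, not part of the repository): each line of the file list is expected to hold an audio path and a feature path separated by a tab, where the feature file is a NumPy array of frame-level representations (presumably produced by scripts/hubert_rep_extract.py); the file names and segment settings below are placeholders.

    from speechtokenizer.trainer.dataset import audioDataset, get_dataloader

    # Hypothetical file list: one "<wav path>\t<npy feature path>" entry per line.
    with open('train_file_list.txt') as f:
        file_list = f.readlines()

    ds = audioDataset(file_list=file_list,
                      segment_size=48000,        # e.g. 3 s at 16 kHz
                      sample_rate=16000,
                      downsample_rate=320)       # one feature frame per 320 waveform samples
    dl = get_dataloader(ds, batch_size=4, shuffle=True, drop_last=True, num_workers=2)

    audio, feature = next(iter(dl))              # (batch, segment_size), (batch, frames, feat_dim)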
/speechtokenizer/trainer/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 | import matplotlib.pyplot as plt
4 |
5 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
6 | return torch.log(torch.clamp(x, min=clip_val) * C)
7 |
8 |
9 | def spectral_normalize_torch(magnitudes):
10 | output = dynamic_range_compression_torch(magnitudes)
11 | return output
12 |
13 |
14 | mel_basis = {}
15 | hann_window = {}
16 |
17 |
18 | def mel_spectrogram(y, n_fft, num_mels, sample_rate, hop_size, win_size, fmin, fmax, center=False):
19 |
20 | global mel_basis, hann_window
21 |     if str(fmax)+'_'+str(y.device) not in mel_basis:
22 | mel_transform = torchaudio.transforms.MelScale(n_mels=num_mels, sample_rate=sample_rate, n_stft=n_fft//2+1, f_min=fmin, f_max=fmax, norm='slaney', mel_scale="htk")
23 | mel_basis[str(fmax)+'_'+str(y.device)] = mel_transform.fb.float().T.to(y.device)
24 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
25 |
26 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
27 | y = y.squeeze(1)
28 |
29 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
30 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
31 | spec = torch.abs(spec) + 1e-9
32 | spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
33 | spec = spectral_normalize_torch(spec)
34 |
35 | return spec
36 |
37 | def plot_spectrogram(spectrogram):
38 | fig, ax = plt.subplots(figsize=(10, 2))
39 | im = ax.imshow(spectrogram, aspect="auto", origin="lower",
40 | interpolation='none')
41 | plt.colorbar(im, ax=ax)
42 |
43 | fig.canvas.draw()
44 | plt.close()
45 |
46 | return fig
47 |
48 | def recon_loss(x, x_hat):
49 | length = min(x.size(-1), x_hat.size(-1))
50 | return torch.nn.functional.l1_loss(x[:, :, :length], x_hat[:, :, :length])
51 |
52 | def mel_loss(x, x_hat, **kwargs):
53 | x_mel = mel_spectrogram(x.squeeze(1), **kwargs)
54 | x_hat_mel = mel_spectrogram(x_hat.squeeze(1), **kwargs)
55 | length = min(x_mel.size(2), x_hat_mel.size(2))
56 | return torch.nn.functional.l1_loss(x_mel[:, :, :length], x_hat_mel[:, :, :length])
57 |
58 | def feature_loss(fmap_r, fmap_g):
59 | loss = 0
60 | for dr, dg in zip(fmap_r, fmap_g):
61 | for rl, gl in zip(dr, dg):
62 | loss += torch.mean(torch.abs(rl - gl))
63 |
64 | return loss*2
65 |
66 |
67 | def discriminator_loss(disc_real_outputs, disc_generated_outputs):
68 | loss = 0
69 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
70 | r_loss = torch.mean((1-dr)**2)
71 | g_loss = torch.mean(dg**2)
72 | loss += (r_loss + g_loss)
73 |
74 | return loss
75 |
76 |
77 | def adversarial_loss(disc_outputs):
78 | loss = 0
79 | for dg in disc_outputs:
80 | l = torch.mean((1-dg)**2)
81 | loss += l
82 |
83 | return loss
84 |
85 |
86 | def d_axis_distill_loss(feature, target_feature):
87 | n = min(feature.size(1), target_feature.size(1))
88 |     distill_loss = - torch.log(torch.sigmoid(torch.nn.functional.cosine_similarity(feature[:, :n], target_feature[:, :n], dim=1))).mean()
89 | return distill_loss
90 |
91 | def t_axis_distill_loss(feature, target_feature, lambda_sim=1):
92 | n = min(feature.size(1), target_feature.size(1))
93 |     l1_loss = torch.nn.functional.l1_loss(feature[:, :n], target_feature[:, :n], reduction='mean')
94 |     sim_loss = - torch.log(torch.sigmoid(torch.nn.functional.cosine_similarity(feature[:, :n], target_feature[:, :n], dim=-1))).mean()
95 | distill_loss = l1_loss + lambda_sim * sim_loss
96 | return distill_loss
--------------------------------------------------------------------------------
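Worked sketch (illustrative, not part of the repository): combining the reconstruction, mel, and distillation losses above for one batch; the tensor shapes, mel parameters, and loss weights are placeholders rather than the values in config/spt_base_cfg.json.

    import torch
    from speechtokenizer.trainer.loss import recon_loss, mel_loss, d_axis_distill_loss

    x = torch.randn(2, 1, 48000)        # ground-truth waveform (batch, 1, samples)
    x_hat = torch.randn(2, 1, 48000)    # generator reconstruction

    mel_kwargs = dict(n_fft=1024, num_mels=80, sample_rate=16000,
                      hop_size=240, win_size=1024, fmin=0, fmax=8000)

    # Waveform L1 plus a (heavily weighted) mel-spectrogram L1, HiFi-GAN style.
    loss = recon_loss(x, x_hat) + 45 * mel_loss(x, x_hat, **mel_kwargs)

    # Distillation between the RVQ-1 output and a semantic teacher feature,
    # both shaped (batch, frames, feature_dim).
    feature, target = torch.randn(2, 150, 768), torch.randn(2, 150, 768)
    loss = loss + d_axis_distill_loss(feature, target)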
/speechtokenizer/trainer/optimizer.py:
--------------------------------------------------------------------------------
1 | from lion_pytorch import Lion
2 | from torch.optim import AdamW, Adam
3 |
4 | def separate_weight_decayable_params(params):
5 | wd_params, no_wd_params = [], []
6 | for param in params:
7 | param_list = no_wd_params if param.ndim < 2 else wd_params
8 | param_list.append(param)
9 | return wd_params, no_wd_params
10 |
11 | def get_optimizer(
12 | params,
13 | lr = 1e-4,
14 | wd = 1e-2,
15 | betas = (0.9, 0.99),
16 | eps = 1e-8,
17 | filter_by_requires_grad = False,
18 | group_wd_params = True,
19 | use_lion = False,
20 | **kwargs
21 | ):
22 | has_wd = wd > 0
23 |
24 | if filter_by_requires_grad:
25 | params = list(filter(lambda t: t.requires_grad, params))
26 |
27 | if group_wd_params and has_wd:
28 | wd_params, no_wd_params = separate_weight_decayable_params(params)
29 |
30 | params = [
31 | {'params': wd_params},
32 | {'params': no_wd_params, 'weight_decay': 0},
33 | ]
34 |
35 | if use_lion:
36 | return Lion(params, lr = lr, betas = betas, weight_decay = wd)
37 |
38 | if not has_wd:
39 | return Adam(params, lr = lr, betas = betas, eps = eps)
40 |
41 | return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
--------------------------------------------------------------------------------
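Usage sketch (illustrative, not part of the repository): get_optimizer groups parameters so that 1-D tensors (biases, norm scales) are excluded from weight decay; the toy model and hyperparameters are placeholders.

    import torch
    from speechtokenizer.trainer.optimizer import get_optimizer

    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.LayerNorm(64))

    # AdamW with decoupled weight decay applied only to >=2-D parameters.
    optim = get_optimizer(model.parameters(), lr=1e-4, wd=1e-2, betas=(0.9, 0.99))

    # Lion variant (requires the lion-pytorch package):
    # optim = get_optimizer(model.parameters(), lr=1e-4, wd=1e-2, use_lion=True)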
/speechtokenizer/trainer/trainer.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import re
3 | import os
4 | import itertools
5 |
6 | from beartype import beartype
7 |
8 | import torch
9 | from torch import nn
10 | from torch.optim.lr_scheduler import CosineAnnealingLR
11 |
12 | from .dataset import get_dataloader, audioDataset
13 | from .optimizer import get_optimizer
14 | from torch.utils import tensorboard
15 | from .loss import *
16 | import json
17 | from speechtokenizer import SpeechTokenizer
18 | import time
19 | from tqdm import tqdm
20 | from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs, DataLoaderConfiguration
21 |
22 |
23 | # helpers
24 |
25 | def exists(val):
26 | return val is not None
27 |
28 | def cycle(dl):
29 | while True:
30 | for data in dl:
31 | yield data
32 |
33 | def cast_tuple(t):
34 | return t if isinstance(t, (tuple, list)) else (t,)
35 |
36 |
37 | def accum_log(log, new_logs):
38 | for key, new_value in new_logs.items():
39 | old_value = log.get(key, 0.)
40 | log[key] = old_value + new_value
41 | return log
42 |
43 | def checkpoint_num_steps(checkpoint_path):
44 | """Returns the number of steps trained from a checkpoint based on the filename.
45 |
46 |     Filename format is assumed to be something like "/path/to/SpeechTokenizerTrainer_00020000",
47 |     which corresponds to 20k train steps. Returns 20000 in that case.
48 | """
49 | results = re.findall(r'\d+', str(checkpoint_path))
50 |
51 | if len(results) == 0:
52 | return 0
53 |
54 | return int(results[-1])
55 |
56 |
57 | class SpeechTokenizerTrainer(nn.Module):
58 | @beartype
59 | def __init__(
60 | self,
61 | generator: SpeechTokenizer,
62 | discriminators: dict,
63 | cfg,
64 | accelerate_kwargs: dict = dict(),
65 | ):
66 | super().__init__()
67 | ddp_kwargs = DistributedDataParallelKwargs()
68 | torch.manual_seed(cfg.get('seed'))
69 | split_batches = cfg.get("split_batches", False)
70 | self.log_steps = cfg.get('log_steps')
71 | self.stdout_steps = cfg.get('stdout_steps')
72 | self.save_model_steps = cfg.get('save_model_steps')
73 | results_folder = cfg.get('results_folder')
74 | self.results_folder = Path(results_folder)
75 | self.num_ckpt_keep = cfg.get("num_ckpt_keep")
76 | self.epochs = cfg.get("epochs")
77 | self.num_warmup_steps = cfg.get("num_warmup_steps")
78 | self.batch_size = cfg.get("batch_size")
79 | self.sample_rate = cfg.get('sample_rate')
80 | self.showpiece_num = cfg.get('showpiece_num', 8)
81 | project_name = 'SpeechTokenizer'
82 |
83 | if not self.results_folder.exists():
84 | self.results_folder.mkdir(parents = True, exist_ok = True)
85 |
86 | with open(f'{str(self.results_folder)}/config.json', 'w+') as f:
87 | json.dump(cfg, f, ensure_ascii=False, indent=4)
88 |
89 |
90 | # tracker = AudioTensorBoardTracker(run_name=project_name, logging_dir=results_folder)
91 | dataloader_config = DataLoaderConfiguration(split_batches=split_batches)
92 | self.accelerator = Accelerator(
93 | dataloader_config=dataloader_config,
94 | kwargs_handlers=[ddp_kwargs],
95 | # log_with=tracker,
96 | **accelerate_kwargs
97 | )
98 |
99 | if self.is_main:
100 | self.writer = tensorboard.SummaryWriter(os.path.join(results_folder, 'logs'))
101 |
102 | self.generator = generator
103 | self.discriminators = discriminators
104 |
105 |
106 | self.register_buffer('steps', torch.Tensor([0]))
107 |
108 |
109 |
110 | self.mel_loss_lambdas = cfg.get('mel_loss_lambdas')
111 | self.commitment_loss_lambda = cfg.get('commitment_loss_lambda')
112 | self.recon_loss_lambda = cfg.get('recon_loss_lambda')
113 | self.distill_loss_lambda = cfg.get('distill_loss_lambda')
114 | distill_type = cfg.get('distill_type', 'd_axis')
115 | if distill_type == 't_axis':
116 | from functools import partial
117 | lambda_sim = cfg.get('lambda_sim', 1)
118 | self.distill_loss = partial(t_axis_distill_loss, lambda_sim=lambda_sim)
119 | else:
120 | self.distill_loss = d_axis_distill_loss
121 | self.mel_loss_kwargs_list = []
122 | mult = 1
123 | for i in range(len(self.mel_loss_lambdas)):
124 | self.mel_loss_kwargs_list.append({'n_fft': cfg.get('n_fft') // mult, 'num_mels':cfg.get('num_mels'),'sample_rate':self.sample_rate,
125 | 'hop_size': cfg.get('hop_size') // mult, 'win_size':cfg.get('win_size') // mult, 'fmin':cfg.get('fmin'),
126 | 'fmax':cfg.get('fmax_for_loss')})
127 | mult = mult * 2
128 | self.mel_kwargs = {'n_fft': cfg.get('n_fft'), 'num_mels':cfg.get('num_mels'),'sample_rate':self.sample_rate,
129 | 'hop_size': cfg.get('hop_size'), 'win_size':cfg.get('win_size'), 'fmin':cfg.get('fmin'),
130 | 'fmax':cfg.get('fmax')}
131 |
132 |
133 | # max grad norm
134 |
135 | # self.max_grad_norm = max_grad_norm
136 | segment_size = cfg.get("segment_size")
137 | train_files = cfg.get("train_files")
138 | batch_size = cfg.get("batch_size")
139 | self.batch_size = batch_size
140 | with open(train_files, 'r') as f:
141 | train_file_list = f.readlines()
142 | valid_files = cfg.get("valid_files")
143 | with open(valid_files, 'r') as f:
144 | valid_file_list = f.readlines()
145 |
146 | self.ds = audioDataset(file_list=train_file_list,
147 | segment_size=segment_size,
148 | downsample_rate=generator.downsample_rate,
149 | sample_rate=self.sample_rate)
150 | self.valid_ds = audioDataset(file_list=valid_file_list,
151 | segment_size=self.sample_rate * 30,
152 | downsample_rate=generator.downsample_rate,
153 | sample_rate=self.sample_rate,
154 | valid=True)
155 | if self.is_main:
156 |             self.print(f'training with a dataset of {len(self.ds)} samples and validating with {len(self.valid_ds)} held-out samples')
157 |
158 |
159 |
160 | assert len(self.ds) >= self.batch_size, 'dataset must have sufficient samples for training'
161 | assert len(self.valid_ds) >= self.batch_size, f'validation dataset must have sufficient number of samples (currently {len(self.valid_ds)}) for training'
162 |
163 | # dataloader
164 | drop_last = cfg.get("drop_last", True)
165 | num_workers = cfg.get("num_workers")
166 | self.dl = get_dataloader(self.ds, batch_size = self.batch_size, shuffle = True, drop_last = drop_last, num_workers=num_workers)
167 | self.valid_dl = get_dataloader(self.valid_ds, batch_size = 1, shuffle = False, drop_last = False, num_workers=1)
168 |
169 | # lr
170 | self.lr = cfg.get("learning_rate")
171 | self.initial_lr = cfg.get("intial_learning_rate")
172 |
173 | # optimizer
174 | self.optim_g = get_optimizer(
175 | generator.parameters(),
176 | lr = cfg.get("learning_rate"),
177 | wd = cfg.get("wd"),
178 | betas = cfg.get("betas")
179 | )
180 |
181 | self.optim_d = get_optimizer(
182 | itertools.chain(*[i.parameters() for i in self.discriminators.values()]),
183 | lr = cfg.get("learning_rate"),
184 | wd = cfg.get("wd"),
185 | betas = cfg.get("betas")
186 | )
187 |
188 | # scheduler
189 | # num_train_steps = epochs * self.ds.__len__() // (batch_size * grad_accum_every)
190 | num_train_steps = self.epochs * self.ds.__len__() // batch_size
191 | self.scheduler_g = CosineAnnealingLR(self.optim_g, T_max = num_train_steps)
192 | self.scheduler_d = CosineAnnealingLR(self.optim_d, T_max = num_train_steps)
193 |
194 |
195 |
196 |
197 | # prepare with accelerator
198 |
199 | (
200 | self.generator,
201 | self.optim_g,
202 | self.optim_d,
203 | self.scheduler_g,
204 | self.scheduler_d,
205 | self.dl,
206 | self.valid_dl
207 | ) = self.accelerator.prepare(
208 | self.generator,
209 | self.optim_g,
210 | self.optim_d,
211 | self.scheduler_g,
212 | self.scheduler_d,
213 | self.dl,
214 | self.valid_dl
215 | )
216 | self.discriminators = {k:self.accelerator.prepare(v) for k, v in self.discriminators.items()}
217 |
218 |
219 |
220 | hps = {"num_train_steps": num_train_steps, "num_warmup_steps": self.num_warmup_steps, "learning_rate": self.lr, "initial_learning_rate": self.initial_lr, "epochs": self.epochs}
221 | self.accelerator.init_trackers("SpeechTokenizer", config=hps)
222 | self.best_dev_mel_loss = float('inf')
223 | self.plot_gt_once = False
224 |
225 | def save(self, path, best_dev_mel_loss):
226 | if best_dev_mel_loss < self.best_dev_mel_loss:
227 | self.best_dev_mel_loss = best_dev_mel_loss
228 | torch.save(self.accelerator.get_state_dict(self.generator), f'{self.results_folder}/SpeechTokenizer_best_dev.pt')
229 | ckpts = sorted(Path(path).parent.glob(f'SpeechTokenizerTrainer_*'))
230 | if len(ckpts) > self.num_ckpt_keep:
231 | [os.remove(c) for c in ckpts[:-self.num_ckpt_keep]]
232 | pkg = dict(
233 | generator = self.accelerator.get_state_dict(self.generator),
234 | discriminators = {k:self.accelerator.get_state_dict(v) for k, v in self.discriminators.items()},
235 | optim_g = self.optim_g.state_dict(),
236 | optim_d = self.optim_d.state_dict(),
237 | scheduler_g = self.scheduler_g.state_dict(),
238 | scheduler_d = self.scheduler_d.state_dict(),
239 | best_dev_mel_loss = self.best_dev_mel_loss
240 | )
241 | torch.save(pkg, path)
242 |
243 | def load(self, path = None, restore_optimizer = True):
244 | if not exists(path):
245 | ckpts = sorted(self.results_folder.glob(f'SpeechTokenizerTrainer_*'))
246 | path = str(ckpts[-1])
247 | generator = self.accelerator.unwrap_model(self.generator)
248 | pkg = torch.load(path, map_location='cpu')
249 | generator.load_state_dict(pkg['generator'])
250 | discriminators = {k:self.accelerator.unwrap_model(v) for k, v in self.discriminators.items()}
251 |         for k, v in discriminators.items():
252 |             v.load_state_dict(pkg['discriminators'][k])
253 | if restore_optimizer:
254 | self.optim_d.load_state_dict(pkg['optim_d'])
255 | self.scheduler_d.load_state_dict(pkg['scheduler_d'])
256 | self.optim_g.load_state_dict(pkg['optim_g'])
257 | self.scheduler_g.load_state_dict(pkg['scheduler_g'])
258 | if 'best_dev_mel_loss' in pkg.keys():
259 | self.best_dev_mel_loss = pkg['best_dev_mel_loss']
260 | if self.is_main:
261 | self.print(f'The best dev mel loss before is {self.best_dev_mel_loss}')
262 |
263 | # + 1 to start from the next step and avoid overwriting the last checkpoint
264 | self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)
265 |
266 | def print(self, msg):
267 | self.accelerator.print(msg)
268 |
269 | @property
270 | def device(self):
271 | return self.accelerator.device
272 |
273 | @property
274 | def is_distributed(self):
275 | return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
276 |
277 | @property
278 | def is_main(self):
279 | return self.accelerator.is_main_process
280 |
281 | @property
282 | def is_local_main(self):
283 | return self.accelerator.is_local_main_process
284 |
285 | def warmup(self, step):
286 | if step < self.num_warmup_steps:
287 | return self.initial_lr + (self.lr - self.initial_lr) * step / self.num_warmup_steps
288 | else:
289 | return self.lr
290 |
291 | def log(self, values: dict, step, type=None, **kwargs):
292 | if type == 'figure':
293 | for k, v in values.items():
294 | self.writer.add_figure(k, v, global_step=step)
295 | elif type == 'audio':
296 | for k, v in values.items():
297 | self.writer.add_audio(k, v, global_step=step, **kwargs)
298 | else:
299 | for k, v in values.items():
300 | self.writer.add_scalar(k, v, global_step=step)
301 |
302 | def train(self):
303 |
304 | self.generator.train()
305 |         for disc in self.discriminators.values(): disc.train()
306 | step_time_log = {}
307 |
308 | steps = int(self.steps.item())
309 | if steps < self.num_warmup_steps:
310 | lr = self.warmup(steps)
311 |             for param_group in itertools.chain(self.optim_g.param_groups, self.optim_d.param_groups):
312 | param_group['lr'] = lr
313 | else:
314 | self.scheduler_d.step()
315 | self.scheduler_g.step()
316 | lr = self.scheduler_d.get_last_lr()[0]
317 |
318 | for epoch in range(self.epochs):
319 | if self.is_main:
320 | print(f'Epoch:{epoch} start...')
321 |
322 | for batch in self.dl:
323 |
324 | tic = time.time()
325 |
326 | x, semantic_feature = batch
327 | x = x.unsqueeze(1)
328 | x_hat, loss_q, feature = self.generator(x)
329 |
330 | # Discriminators
331 | self.optim_d.zero_grad()
332 | discriminator_outputs = list(map(lambda disc:disc(x, x_hat.detach()), self.discriminators.values()))
333 | loss_disc_all = sum(map(lambda x:discriminator_loss(*x[:2]), discriminator_outputs))
334 |
335 | self.accelerator.backward(loss_disc_all)
336 | self.optim_d.step()
337 |
338 | # Generator
339 | self.optim_g.zero_grad()
340 | discriminator_outputs = list(map(lambda disc:disc(x, x_hat), self.discriminators.values()))
341 | loss_recon = recon_loss(x, x_hat)
342 | loss_mel = sum(map(lambda mel_k:mel_k[0] * mel_loss(x, x_hat, **mel_k[1]), zip(self.mel_loss_lambdas, self.mel_loss_kwargs_list)))
343 | loss_feature = sum(map(lambda x:feature_loss(*x[2:]), discriminator_outputs))
344 | loss_adversarial = sum(map(lambda x:adversarial_loss(x[1]), discriminator_outputs))
345 | loss_distill = self.distill_loss(feature, semantic_feature)
346 | loss_generator_all = loss_feature + loss_adversarial + loss_mel + loss_q * self.commitment_loss_lambda + loss_recon * self.recon_loss_lambda + self.distill_loss_lambda * loss_distill
347 | self.accelerator.backward(loss_generator_all)
348 | # if exists(self.max_grad_norm):
349 | # self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
350 | self.optim_g.step()
351 |
352 | step_time_log = accum_log(step_time_log, {'time_cost': time.time() - tic})
353 | # self.accelerator.wait_for_everyone()
354 |
355 |
356 | # log
357 | if self.is_main and not (steps % self.stdout_steps):
358 | with torch.inference_mode():
359 | mel_error = mel_loss(x, x_hat, **self.mel_loss_kwargs_list[0]).item()
360 | self.print(f"Epoch {epoch} -- Step {steps}: Gen Loss: {loss_generator_all.item():0.3f}; Mel Error:{mel_error:0.3f}; Q Loss: {loss_q.item():0.3f}; Distill Loss: {loss_distill.item():0.3f}; Time cost per step: {step_time_log['time_cost'] / self.stdout_steps:0.3f}s")
361 | step_time_log = {}
362 | if self.is_main and not (steps % self.log_steps):
363 | self.log({"train/discriminators loss": loss_disc_all.item(), "train/generator loss": loss_generator_all.item(), "train/feature loss": loss_feature.item(),
364 | "train/adversarial loss": loss_adversarial.item(), "train/quantizer loss": loss_q.item(), "train/mel loss": loss_mel.item(),
365 | "train/mel error": mel_error, "train/distillation loss": loss_distill.item(), "train/learning_rate": lr}, step=steps)
366 |
367 | self.accelerator.wait_for_everyone()
368 |
369 | # validate and save model
370 | if self.is_main and not(steps % self.save_model_steps) and steps != 0:
371 |
372 | self.print('Validation start ...')
373 | # validate
374 | total_mel_error = 0.0
375 | total_distill_loss = 0.0
376 | num = 0
377 | self.generator.eval()
378 | with torch.inference_mode():
379 | for i, batch in tqdm(enumerate(self.valid_dl)):
380 | x, semantic_feature = batch
381 | x = x.unsqueeze(1)
382 | x_hat, loss_q, feature = self.generator(x)
383 | mel_error = mel_loss(x, x_hat, **self.mel_loss_kwargs_list[0]).item()
384 | total_mel_error += mel_error
385 | loss_distill = self.distill_loss(feature, semantic_feature).item()
386 | total_distill_loss += loss_distill
387 | num += x.size(0)
388 | if i < self.showpiece_num:
389 | if not self.plot_gt_once:
390 | self.log({f'groundtruth/x_{i}': x[0].cpu().detach()}, type='audio', sample_rate=self.sample_rate, step=steps)
391 | x_spec = mel_spectrogram(x.squeeze(1), **self.mel_kwargs)
392 | self.log({f'groundtruth/x_spec_{i}': plot_spectrogram(x_spec[0].cpu().numpy())}, type='figure', step=steps)
393 |
394 | self.log({f'generate/x_hat_{i}': x_hat[0].cpu().detach()}, type='audio', sample_rate=self.sample_rate, step=steps)
395 | x_hat_spec = mel_spectrogram(x_hat.squeeze(1), **self.mel_kwargs)
396 | self.log({f'generate/x_hat_spec_{i}': plot_spectrogram(x_hat_spec[0].cpu().numpy())}, type='figure', step=steps)
397 | if not self.plot_gt_once:
398 | self.plot_gt_once = True
399 | self.print(f'{steps}: dev mel error: {total_mel_error / num:0.3f}\tdev distill loss: {total_distill_loss / num:0.3f}')
400 | self.log({'dev/mel error': total_mel_error / num, 'dev/distillation loss': total_distill_loss / num}, step=steps)
401 |
402 |
403 | # save model
404 | model_path = str(self.results_folder / f'SpeechTokenizerTrainer_{steps:08d}')
405 | self.save(model_path, total_mel_error / num)
406 | self.print(f'{steps}: saving model to {str(self.results_folder)}')
407 | self.generator.train()
408 |
409 | # Update lr
410 | self.steps += 1
411 | steps = int(self.steps.item())
412 | if steps < self.num_warmup_steps:
413 | lr = self.warmup(steps)
414 | for param_group in self.optim_g.param_groups:
415 | param_group['lr'] = lr
416 | for param_group in self.optim_d.param_groups:
417 | param_group['lr'] = lr
418 | else:
419 | self.scheduler_d.step()
420 | self.scheduler_g.step()
421 | lr = self.scheduler_g.get_last_lr()[0]
422 |
423 | self.print('training complete')
424 |
425 | def continue_train(self):
426 | self.load()
427 | self.train()
428 |
429 |
--------------------------------------------------------------------------------
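Usage sketch (illustrative, not part of the repository): wiring SpeechTokenizerTrainer together. The SpeechTokenizer constructor call and the toy discriminator are assumptions made for illustration; scripts/train_example.py and config/spt_base_cfg.json show the intended setup, including the real discriminators from discriminators.py.

    import json
    from torch import nn
    from speechtokenizer import SpeechTokenizer
    from speechtokenizer.trainer import SpeechTokenizerTrainer

    # Config keys read by the trainer above include: results_folder, train_files,
    # valid_files, learning_rate, mel_loss_lambdas, segment_size, sample_rate, ...
    with open('config/spt_base_cfg.json') as f:
        cfg = json.load(f)

    generator = SpeechTokenizer(cfg)   # assumed constructor; see speechtokenizer/model.py

    # Stand-in discriminator showing the interface the training loop expects:
    # disc(x, x_hat) -> (real_outputs, fake_outputs, fmap_real, fmap_fake).
    class ToyDisc(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv1d(1, 1, 15, stride=4, padding=7)

        def forward(self, x, x_hat):
            real, fake = self.conv(x), self.conv(x_hat)
            return [real], [fake], [[real]], [[fake]]

    trainer = SpeechTokenizerTrainer(generator=generator,
                                     discriminators={'toy': ToyDisc()},
                                     cfg=cfg)
    trainer.train()   # or trainer.continue_train() to resume from the latest checkpoint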