├── .gitignore ├── LICENSE ├── README.md ├── defaults.ini ├── pyproject.toml ├── setup.py ├── stable_codec ├── __init__.py ├── ctc_loss.py ├── data │ ├── Text2Phone │ │ ├── Text2PhoneTokenizer.py │ │ ├── __init__.py │ │ ├── abs_tokenizer.py │ │ ├── alignment_dict │ │ ├── dict_phone.txt │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── data_gen_utils.py │ │ │ └── txt_processors │ │ │ │ ├── base_text_processor.py │ │ │ │ ├── en.py │ │ │ │ ├── en_syl.py │ │ │ │ ├── zh.py │ │ │ │ ├── zh_g2pM.py │ │ │ │ ├── zh_g2pM_song_seg.py │ │ │ │ └── zh_song_seg.py │ │ ├── phone_tokenizer.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── audio.py │ │ │ ├── ckpt_utils.py │ │ │ ├── pitch_utils.py │ │ │ ├── plot.py │ │ │ └── text_encoder.py │ └── dataset.py ├── fsq.py ├── model.py ├── residual_fsq.py ├── training_demo.py └── training_module.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | venv/ 3 | 4 | *.wav 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Stability AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Stable Codec 2 | 3 | This repository contains training and inference scripts for models in the Stable Codec series, starting with `stable-codec-speech-16k` - introduced in the paper titled Scaling Transformers for Low-bitrate High-Quality Speech Coding. 4 | 5 | Paper: https://arxiv.org/abs/2411.19842 6 | 7 | Sound demos: https://stability-ai.github.io/stable-codec-demo/ 8 | 9 | Model weights: https://huggingface.co/stabilityai/stable-codec-speech-16k 10 | 11 | ## Changelog 12 | 13 | ### [v0.1.2] 14-01-25 14 | - __New__ added hooks for `stable-codec-speech-16k-base`. 15 | - __Fix__ fixed major issue with precision in FSQ token calculation, which was degrading results. Fix is currently local, will be upstreamed to `stable-audio-tools` later. 
16 | ### [v0.1.1] 10-01-25 17 | - Release 18 | 19 | 20 | ## 21 | 22 | Note that whilst this code is MIT licensed, the model weights are covered by the [Stability AI Community License](https://huggingface.co/stabilityai/stable-codec-speech-16k/blob/main/LICENSE.md) 23 | 24 | ## Variants 25 | The model is currently available in two variants: 26 | - `stable-codec-speech-16k` is an improved finetune, with boosted latent semantics. __It should be used in 99% of use-cases.__ 27 | - `stable-codec-speech-16k-base` is the weights corresponding to the results in our [publication](https://arxiv.org/abs/2411.19842), provided for reproducibility. 28 | 29 | ### Additional Training 30 | 31 | In addition to the training described in the paper, the weights for `stable-codec-speech-16k` have undergone 500k steps of finetuning with force-aligned data from LibriLight and the English portion Multilingual LibriSpeech. This was performed by using a CTC head to regress the force-aligned phoneme tags from pre-bottleneck latents. We found that this additional training significantly boosted the applicability of the codec tokens to downstream tasks like TTS, at a small cost to objective reconstruction metrics. 32 | 33 | ## Install 34 | 35 | The model itself is defined in [stable-audio-tools](https://github.com/Stability-AI/stable-audio-tools) package. 36 | 37 | To install `stable-codec`: 38 | 39 | ```bash 40 | pip install stable-codec 41 | pip install -U flash-attn --no-build-isolation 42 | ``` 43 | 44 | **IMPORTANT NOTE:** This model currently has a hard requirement for FlashAttention due to its use of sliding window attention. Inference without FlashAttention will likely be greatly degraded. This also means that the model currently does not support CPU inference. We will relax the dependency on FlashAttention in the future. 45 | 46 | ## Encoding and decoding 47 | 48 | To encode audio or decode tokens, the `StableCodec` class provides a convenient wrapper for the model. It can be used with a local checkpoint and config as follows: 49 | 50 | ```python 51 | import torch 52 | import torchaudio 53 | from stable_codec import StableCodec 54 | 55 | model = StableCodec( 56 | model_config_path="", 57 | ckpt_path="", # optional, can be `None`, 58 | device = torch.device("cuda") 59 | ) 60 | 61 | audiopath = "audio.wav" 62 | 63 | latents, tokens = model.encode(audiopath) 64 | decoded_audio = model.decode(tokens) 65 | 66 | torchaudio.save("decoded.wav", decoded_audio, model.sample_rate) 67 | ``` 68 | 69 | To download the model weights automatically from HuggingFace, simply provide the model name: 70 | 71 | ```python 72 | model = StableCodec( 73 | pretrained_model = 'stabilityai/stable-codec-speech-16k' 74 | ) 75 | ``` 76 | ### Posthoc bottleneck configuration 77 | 78 | Most usecases will benefit from replacing the training-time FSQ bottleneck with a post-hoc FSQ bottleneck, as described in the paper. This allows token dictionary size to be reduced to a reasonable level for modern language models. 
This is achieved by calling the `set_posthoc_bottleneck` function, and setting a flag to the encode/decode calls: 79 | 80 | ```python 81 | model.set_posthoc_bottleneck("2x15625_700bps") 82 | latents, tokens = model.encode(audiopath, posthoc_bottleneck = True) 83 | decoded_audio = model.decode(tokens, posthoc_bottleneck = True) 84 | ``` 85 | `set_posthoc_bottleneck` can take a string as argument, which allows selection a number of recommended preset settings for the bottleneck: 86 | 87 | | Bottleneck Preset | Number of Tokens per step | Dictionary Size | Bits Per Second (bps) | 88 | |-------------------|------------------|-----------------|-----------------------| 89 | | `1x46656_400bps` | 1 | 46656 | 400 | 90 | | `2x15625_700bps` | 2 | 15625 | 700 | 91 | | `4x729_1000bps` | 4 | 729 | 1000 | 92 | 93 | Alternatively, the bottleneck stages can be specified directly. The format for specifying this can be seen in the definition of the `StableCodec` class in `model.py`. 94 | 95 | ### Normalization 96 | 97 | The model is trained with utterances normalized to -20 +-5 LUFS. The `encode` function normalizes to -20 LUFS by default, but it can be disabled by setting `normalize = False` when calling the function. 98 | 99 | ## Finetune 100 | 101 | To finetune a model given its config and checkpoint, execute `train.py` file: 102 | 103 | ```bash 104 | python train.py \ 105 | --project "stable-codec" \ 106 | --name "finetune" \ 107 | --config-file "defaults.ini" \ 108 | --save-dir "" \ 109 | --model-config "" \ 110 | --dataset-config "" \ 111 | --val-dataset-config "" \ 112 | --pretrained-ckpt-path "" \ 113 | --ckpt-path "$CKPT_PATH" \ 114 | --num-nodes $SLURM_JOB_NUM_NODES \ 115 | --num-workers 16 --batch-size 10 --precision "16-mixed" \ 116 | --checkpoint-every 10000 \ 117 | --logger "wandb" 118 | ``` 119 | 120 | For dataset configuration, refer to `stable-audio-tools` [dataset docs](https://github.com/Stability-AI/stable-audio-tools/blob/main/docs/datasets.md). 121 | 122 | 123 | ### Using CTC loss 124 | 125 | To use [CTC loss](https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html) 126 | during training you have to enable it in the training configuration file 127 | and in the training dataset configuration. 128 | 129 | 1. Modifying training configuration: 130 | - Enable CTC projection head and set its hidden dimension: 131 | ```python 132 | config["model"]["use_proj_head"] = True 133 | config["model"]["proj_head_dim"] = 81 134 | ``` 135 | - Enable CTC in the training part of the config: 136 | ```python 137 | config["training"]["use_ctc"] = True 138 | ``` 139 | - And set its loss config: 140 | ```python 141 | config["training"]["loss_configs"]["ctc"] = { 142 | "blank_idx": 80, 143 | "decay": 1.0, 144 | "weights": {"ctc": 1.0} 145 | } 146 | ``` 147 | - Optionally, you can enable computation of the Phone-Error-Rate (PER) during validation: 148 | ```python 149 | config["training"]["eval_loss_configs"]["per"] = {} 150 | ``` 151 | 152 | 2. Configuring dataset (only WebDataset format is supported for CTC): 153 | - The dataset configuration should have one additional field set to it (see [dataset docs](https://github.com/Stability-AI/stable-audio-tools/blob/main/docs/datasets.md) for other options): 154 | ```python 155 | config["force_align_text"] = True 156 | ``` 157 | - And the JSON metadata file for each sample should contain force aligned transcript under `force_aligned_text` entry in the format specified below (besides other metadata). 
158 | Where `transcript` is a list of word-level alignments with `start` and `end` fields specifying range **in seconds** of each word. 159 | ```json 160 | "normalized_text":"and i feel" 161 | "force_aligned_text":{ 162 | "transcript":[ 163 | { 164 | "word":"and", 165 | "start":0.2202, 166 | "end":0.3403 167 | }, 168 | { 169 | "word":"i", 170 | "start":0.4604, 171 | "end":0.4804 172 | }, 173 | { 174 | "word":"feel", 175 | "start":0.5204, 176 | "end":0.7006 177 | } 178 | ] 179 | } 180 | ``` 181 | ## Objective Metrics 182 | 183 | | Model | SI-SDR | Mel Dis | STFT Dis | PESQ | STOI | 184 | |---------------------------|-------:|--------:|---------:|-----:|-----:| 185 | | `stable-codec-speech-16k-base` | 4.73 | 0.86 | 1.26 | 3.09 | 0.92 | 186 | | `stable-codec-speech-16k` | 3.58 | 0.90 | 1.30 | 3.01 | 0.90 | 187 | 188 | -------------------------------------------------------------------------------- /defaults.ini: -------------------------------------------------------------------------------- 1 | [DEFAULTS] 2 | 3 | #name of the run 4 | name = stable_codec 5 | 6 | # name of the project 7 | project = None 8 | 9 | # the batch size 10 | batch_size = 4 11 | 12 | # Save top K model checkpoints during training. 13 | save_top_k = -1 14 | 15 | # number of nodes to use for training 16 | num_nodes = 1 17 | 18 | # Multi-GPU strategy for PyTorch Lightning 19 | strategy = "auto" 20 | 21 | # Precision to use for training 22 | precision = "16-mixed" 23 | 24 | # number of CPU workers for the DataLoader 25 | num_workers = 6 26 | 27 | # the random seed 28 | seed = 42 29 | 30 | # Batches for gradient accumulation 31 | accum_batches = 1 32 | 33 | # Number of steps between checkpoints 34 | checkpoint_every = 10000 35 | 36 | # Number of steps between validation runs 37 | val_every = -1 38 | 39 | # trainer checkpoint file to restart training from 40 | ckpt_path = '' 41 | 42 | # model checkpoint file to start a new training run from 43 | pretrained_ckpt_path = '' 44 | 45 | # Checkpoint path for the pretransform model if needed 46 | pretransform_ckpt_path = '' 47 | 48 | # configuration model specifying model hyperparameters 49 | model_config = '' 50 | 51 | # configuration for datasets 52 | dataset_config = '' 53 | 54 | # configuration for validation datasets 55 | val_dataset_config = '' 56 | 57 | # directory to save the checkpoints in 58 | save_dir = '' 59 | 60 | # gradient_clip_val passed into PyTorch Lightning Trainer 61 | gradient_clip_val = 0.0 62 | 63 | # remove the weight norm from the pretransform model 64 | remove_pretransform_weight_norm = '' 65 | 66 | # Logger type to use 67 | logger = 'wandb' 68 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools","wheel"] 3 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='stable-codec', 5 | version='0.1.2', 6 | author='Stability AI', 7 | author_email='julian.parker@stability.ai', 8 | description='Stable Codec: A series of codec models for speech and audio', 9 | long_description=open('README.md').read(), 10 | long_description_content_type='text/markdown', 11 | url='https://github.com/Stability-AI/stable-codec/', 12 | packages=find_packages(), 13 | 
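# NOTE (added for clarity): flash-attn is a hard runtime requirement for inference (sliding-window attention)
# but is intentionally not listed in install_requires; install it separately with
# `pip install -U flash-attn --no-build-isolation` (see README).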
python_requires='>=3.9', 14 | install_requires=['packaging', 15 | 'wheel', 16 | 'torch==2.4', 17 | 'torchaudio==2.4', 18 | 'stable-audio-tools==0.0.17', 19 | 'pytorch-lightning==2.1', 20 | 'prefigure==0.0.9'] 21 | ) -------------------------------------------------------------------------------- /stable_codec/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_codec.model import StableCodec -------------------------------------------------------------------------------- /stable_codec/ctc_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch.nn import functional as F 4 | from torch import nn 5 | 6 | from stable_audio_tools.training.losses import LossModule 7 | 8 | # https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html 9 | class CTCLossModule(LossModule): 10 | def __init__( 11 | self, 12 | name: str, 13 | input_key: str, 14 | target_key: str, 15 | weight: float = 1.0, 16 | decay: float = 1.0, 17 | blank_idx: int = 0, 18 | padding_idx: int = None, 19 | input_lengths_key: str = None, 20 | ): 21 | super().__init__(name=name, weight=weight, decay=decay) 22 | self.ctc_loss = nn.CTCLoss(blank=blank_idx, reduction='mean', zero_infinity=True) 23 | self.input_key = input_key 24 | self.target_key = target_key 25 | self.input_lengths_key = input_lengths_key 26 | self.blank_idx = blank_idx 27 | self.padding_idx = padding_idx if padding_idx is not None else blank_idx + 1 28 | 29 | def forward(self, info): 30 | """ 31 | Computes the CTC loss. 32 | 33 | Args: 34 | info (dict): Dictionary containing model outputs and other relevant data. 35 | - info[self.input_key]: Model logits of shape (batch_size, sequence_length, num_classes). 36 | - info[self.target_key]: Target data (list of dicts with 'phone' key). 37 | - info[self.input_lengths_key]: (Optional) Actual lengths of the input sequences. 38 | 39 | Returns: 40 | loss (Tensor): The computed CTC loss, scaled by the weight. 
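Note: if `info[self.input_lengths_key]` is not provided, every sequence in the batch is assumed to span the full logit length when computing the loss.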
41 | """ 42 | # Build targets and target lengths 43 | padded_targets, target_lengths = build_target(info[self.target_key], self.padding_idx) 44 | 45 | # Get logits from the model output 46 | logits = info[self.input_key] # Expected shape: (batch_size, sequence_length, num_classes) 47 | 48 | # Move logits to the device of phonemes 49 | device = padded_targets.device 50 | logits = logits.to(device) 51 | 52 | # Apply log_softmax to obtain log probabilities 53 | log_probs = F.log_softmax(logits, dim=-1) # Shape: (batch_size, seq_length, num_classes) 54 | 55 | # Transpose log_probs to match (seq_length, batch_size, num_classes) 56 | log_probs = log_probs.permute(1, 0, 2) # Now shape is (seq_length, batch_size, num_classes) 57 | 58 | # Determine input lengths 59 | if self.input_lengths_key and self.input_lengths_key in info: 60 | input_lengths = info[self.input_lengths_key].to(device) 61 | else: 62 | # Assume all input sequences have the same length 63 | input_lengths = torch.full( 64 | (log_probs.size(1),), # batch_size 65 | log_probs.size(0), # seq_length 66 | dtype=torch.long, 67 | device=device 68 | ) 69 | 70 | # Compute the CTC loss 71 | loss = self.ctc_loss(log_probs, padded_targets, input_lengths, target_lengths) 72 | 73 | loss = self.weight * loss 74 | 75 | return loss 76 | 77 | class PERModule(nn.Module): 78 | def __init__( 79 | self, 80 | input_key: str, 81 | target_key: str, 82 | blank_idx: int = 0, 83 | padding_idx: int = None, 84 | ): 85 | super().__init__() 86 | self.input_key = input_key 87 | self.target_key = target_key 88 | self.blank_idx = blank_idx 89 | self.padding_idx = padding_idx if padding_idx is not None else blank_idx + 1 90 | 91 | def decode_predictions(self, predicted_ids): 92 | """ 93 | Decodes the model predictions by collapsing repeats and removing blanks. 94 | 95 | Args: 96 | predicted_ids (Tensor): Tensor of shape (seq_length,) containing predicted token IDs. 97 | 98 | Returns: 99 | List[int]: Decoded sequence of token IDs. 100 | """ 101 | predicted_sequence = [] 102 | previous_id = None 103 | for id in predicted_ids: 104 | id = id.item() 105 | if id != self.blank_idx and id != previous_id: 106 | predicted_sequence.append(id) 107 | previous_id = id 108 | return predicted_sequence 109 | 110 | def forward(self, info): 111 | """ 112 | Computes the CTC loss. 113 | 114 | Args: 115 | info (dict): Dictionary containing model outputs and other relevant data. 116 | - info[self.input_key]: Model logits of shape (batch_size, sequence_length, num_classes). 117 | - info[self.target_key]: Target data (list of dicts with 'phone' key). 118 | - info[self.input_lengths_key]: (Optional) Actual lengths of the input sequences. 119 | 120 | Returns: 121 | loss (Tensor): The computed CTC loss, scaled by the weight. 
122 | """ 123 | with torch.no_grad(): 124 | # Build targets and target lengths 125 | padded_targets, target_lengths = build_target(info[self.target_key], self.padding_idx) 126 | 127 | # Get logits from the model output 128 | logits = info[self.input_key] # Expected shape: (batch_size, sequence_length, num_classes) 129 | 130 | # Move logits to the device of phonemes 131 | device = padded_targets.device 132 | logits = logits.to(device) 133 | 134 | # Apply log_softmax to obtain log probabilities 135 | log_probs = F.log_softmax(logits, dim=-1) # Shape: (batch_size, seq_length, num_classes) 136 | 137 | # Transpose log_probs to match (seq_length, batch_size, num_classes) 138 | log_probs = log_probs.permute(1, 0, 2) # Now shape is (seq_length, batch_size, num_classes) 139 | 140 | # Get predictions via greedy decoding 141 | predicted_ids = torch.argmax(logits, dim=-1) # Shape: (batch_size, seq_length) 142 | 143 | batch_size = predicted_ids.size(0) 144 | pers = [] 145 | 146 | for i in range(batch_size): 147 | # Decode predictions 148 | pred_ids = predicted_ids[i] # Tensor of shape (seq_length,) 149 | pred_sequence = self.decode_predictions(pred_ids) 150 | 151 | # Get target sequence 152 | target_ids = padded_targets[i] # Tensor of shape (max_target_length,) 153 | target_length = target_lengths[i] 154 | target_sequence = target_ids[:target_length].tolist() 155 | 156 | # Remove padding tokens from target sequence 157 | target_sequence = [id for id in target_sequence if id != self.padding_idx] 158 | 159 | # Compute edit distance using the editdistance package 160 | # distance = editdistance.eval(pred_sequence, target_sequence) 161 | distance = edit_distance(pred_sequence, target_sequence) 162 | 163 | # Compute PER 164 | per = distance / max(len(target_sequence), 1) 165 | pers.append(per) 166 | 167 | # Compute average PER over the batch 168 | average_per = sum(pers) / len(pers) 169 | 170 | return average_per 171 | 172 | def edit_distance(seq1, seq2): 173 | """ 174 | Computes the edit distance between two sequences. 175 | 176 | Args: 177 | seq1 (List[int]): First sequence. 178 | seq2 (List[int]): Second sequence. 179 | 180 | Returns: 181 | int: The edit distance between seq1 and seq2. 182 | """ 183 | m = len(seq1) 184 | n = len(seq2) 185 | # Create a DP table 186 | dp = [[0] * (n + 1) for _ in range(m + 1)] 187 | # Initialize 188 | for i in range(m + 1): 189 | dp[i][0] = i 190 | for j in range(n + 1): 191 | dp[0][j] = j 192 | # Compute dp table 193 | for i in range(1, m + 1): 194 | for j in range(1, n + 1): 195 | if seq1[i - 1] == seq2[j - 1]: 196 | cost = 0 197 | else: 198 | cost = 1 199 | dp[i][j] = min( 200 | dp[i - 1][j] + 1, # deletion 201 | dp[i][j - 1] + 1, # insertion 202 | dp[i - 1][j - 1] + cost # substitution 203 | ) 204 | return dp[m][n] 205 | 206 | def build_target(batch, padding_idx): 207 | """ 208 | Builds padded targets and computes target lengths. 209 | 210 | Args: 211 | batch (list): A list of dictionaries, each containing a 'phone' key with tensor values. 212 | 213 | Returns: 214 | padded_targets (Tensor): Padded target sequences of shape (batch_size, max_target_length). 215 | target_lengths (Tensor): Lengths of each target sequence in the batch. 
216 | """ 217 | # Extract phoneme sequences 218 | phoneme_sequences = [item['phone'] for item in batch] 219 | 220 | # Determine device from the phoneme sequences 221 | device = phoneme_sequences[0].device 222 | 223 | # Ensure phoneme sequences are 1D tensors 224 | phoneme_sequences = [seq.view(-1) if seq.ndim > 1 else seq for seq in phoneme_sequences] 225 | 226 | # Compute target lengths 227 | target_lengths = torch.tensor([seq.size(0) for seq in phoneme_sequences], dtype=torch.long, device=device) 228 | 229 | # Pad sequences 230 | padded_targets = nn.utils.rnn.pad_sequence( 231 | phoneme_sequences, 232 | batch_first=True, 233 | padding_value=padding_idx 234 | ).to(device) 235 | 236 | return padded_targets, target_lengths 237 | -------------------------------------------------------------------------------- /stable_codec/data/Text2Phone/Text2PhoneTokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | 5 | from .abs_tokenizer import AbsTokenizer 6 | from .modules.txt_processors.en import TxtProcessor 7 | 8 | class Text2PhoneTokenizer(AbsTokenizer): 9 | def __init__(self, duplicate=False): 10 | "Transfer the text input to the phone sequence" 11 | super(Text2PhoneTokenizer, self).__init__() 12 | self.txt_processor = TxtProcessor() # init the text processor 13 | self.phone_dict_path = os.path.join( 14 | os.path.dirname(os.path.abspath(__file__)), 15 | "dict_phone.txt") 16 | self.phone_dict = self.load_dict(self.phone_dict_path) 17 | self.duplicate = duplicate 18 | 19 | def load_dict(self, path): 20 | f = open(path, 'r') 21 | idx = 0 22 | phone_dict = {} 23 | for line in f: 24 | tmp = line.split(' ') 25 | phone = tmp[0] 26 | phone_dict[phone] = idx 27 | idx += 1 28 | return phone_dict 29 | 30 | def get_phone_sequence(self, text): 31 | # input the speech text, such as "I am talking with you". 
output the phone sequence 32 | phs, txt = self.txt_processor.process(text, {'use_tone': True}) 33 | return phs 34 | 35 | @property 36 | def is_discrete(self): 37 | return True 38 | 39 | def find_length(self, x): 40 | return len(self.tokenize(x)) 41 | 42 | def tokenize(self, x, task=None, cache=None): 43 | if isinstance(x, torch.Tensor): 44 | x = torch.unique_consecutive(x) if not self.duplicate else x 45 | return x 46 | elif isinstance(x, str): 47 | phs = self.get_phone_sequence(x) 48 | idxs = [self.phone_dict[id] for id in phs] 49 | idxs = np.array(idxs) 50 | idxs = torch.from_numpy(idxs).to(torch.int16) 51 | return idxs 52 | else: 53 | raise NotImplementedError 54 | 55 | @property 56 | def codebook_length(self): 57 | return len(self.phone_dict.keys()) 58 | 59 | if __name__ == '__main__': 60 | T2P_tokenizer = Text2PhoneTokenizer() 61 | text = "I am talking with you" 62 | phone = T2P_tokenizer.tokenize(text) 63 | print(phone) # AY1 | AE1 M | T AO1 K IH0 NG | W IH1 DH | Y UW1 64 | -------------------------------------------------------------------------------- /stable_codec/data/Text2Phone/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-codec/853554419944dbec11c39a29cf803e492aade836/stable_codec/data/Text2Phone/__init__.py -------------------------------------------------------------------------------- /stable_codec/data/Text2Phone/abs_tokenizer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | class AbsTokenizer(torch.nn.Module): 5 | """ 6 | This is the virtual tokenizer class. 7 | Other tokenizers should inherit this class. 8 | typicially: 9 | Text -> BPE 10 | Text -> Phone 11 | Audio -> Codec 12 | Image -> Codec 13 | ... 14 | """ 15 | 16 | @property 17 | def is_discrete(self): 18 | """ 19 | Return True if the results are discrete token-ids: e.g., BPE / Phone / Codec 20 | Return False if the results are continuous embeddings: e.g., RoBERTa embeddings 21 | """ 22 | raise NotImplementedError 23 | 24 | @property 25 | def codebook_length(self): 26 | """ 27 | Return 0 if "self.is_discrete is False", 28 | otherwise returns the length of codebook. 29 | e.g., for audio codec that adopts 4 codebooks, each of which is in size of 1024, 30 | this is 4 * 1024 31 | This is used to create the shared vocabulary for softmax 32 | """ 33 | raise NotImplementedError 34 | 35 | def find_length(self, x): 36 | """ 37 | This method quickly returns the length of the output (usually without tokenization) 38 | This method is used in batchfying process: measure the whole length of the example 39 | typically: 40 | number of BPE / Frames / Codec sequence / Embedding lengths 41 | """ 42 | raise NotImplementedError 43 | 44 | def tokenize(self, x): 45 | """ Do tokenization. 46 | typically, x can be any input type, e.g., 47 | text: which is a path of the audio 48 | text: which is the exact text data for BPE / G2P 49 | Tensor: the loaded data. e.g., audio 50 | Returns 1-D LONG tensor when this is discrete 51 | Returns 2-D FLOAT tensor when this is continuous: [length, embedding_size] 52 | """ 53 | raise NotImplementedError 54 | 55 | def tokenize_batch(self, xs, lengths=None): 56 | """ batch version of tokenization 57 | Implementation of this method is optional, as it will only be used offline. 
58 | 59 | warning: you should verify that the results of 'tokenize_batch' and 'tokenize' 60 | are actually (or roughly) identical (i.g., padding will not effect the results) 61 | 62 | return: list of 'tokenize' results. do NOT make it as a batched Tensor 63 | """ 64 | raise NotImplementedError 65 | 66 | def detokenize(self, x): 67 | """ This method recovers the original input based on the 'tokenize' result 68 | Implementation of this method is optional, as some tokenization process 69 | is not recoverable. i.g., hubert 70 | """ 71 | raise NotImplementedError 72 | -------------------------------------------------------------------------------- /stable_codec/data/Text2Phone/alignment_dict: -------------------------------------------------------------------------------- 1 | 2 | SIL_s 3 | SIL_m 4 | SIL_l 5 | SIL 6 | CH_B 7 | AE1_I 8 | P_I 9 | T_I 10 | ER0_I 11 | Z_E 12 | W_B 13 | AH1_I 14 | N_E 15 | TH_B 16 | R_I 17 | UW1_E 18 | S_B 19 | EH1_I 20 | V_I 21 | AH0_I 22 | AH1_B 23 | V_E 24 | JH_B 25 | N_I 26 | S_I 27 | S_E 28 | F_B 29 | M_E 30 | DH_B 31 | AH0_E 32 | HH_B 33 | OW1_I 34 | L_I 35 | IY0_E 36 | B_B 37 | AY1_I 38 | B_I 39 | L_E 40 | IH1_B 41 | M_B 42 | AA1_I 43 | D_I 44 | NG_I 45 | IH0_I 46 | SH_E 47 | IH1_I 48 | AH0_S 49 | L_B 50 | K_I 51 | R_B 52 | AO1_I 53 | NG_E 54 | AO1_B 55 | V_B 56 | ER0_S 57 | IH0_B 58 | P_B 59 | K_E 60 | D_B 61 | OW0_I 62 | M_I 63 | EY1_I 64 | R_E 65 | IH2_B 66 | F_I 67 | SH_I 68 | T_B 69 | AA2_I 70 | IY1_I 71 | Z_I 72 | T_E 73 | G_E 74 | AY1_E 75 | OW2_I 76 | D_E 77 | CH 78 | AE1 79 | P 80 | T 81 | ER0 82 | Z 83 | W 84 | AH1 85 | N 86 | TH 87 | R 88 | UW1 89 | S 90 | EH1 91 | V 92 | AH0 93 | JH 94 | F 95 | M 96 | DH 97 | HH 98 | OW1 99 | L 100 | IY0 101 | B 102 | AY1 103 | IH1 104 | AA1 105 | D 106 | NG 107 | IH0 108 | SH 109 | K 110 | AO1 111 | OW0 112 | EY1 113 | IH2 114 | AA2 115 | IY1 116 | G 117 | OW2 118 | G_I 119 | AE0_I 120 | ER0_E 121 | AE0 122 | G_B 123 | JH_E 124 | UW1_I 125 | EH2_I 126 | IY0_I 127 | JH_I 128 | K_B 129 | EY1_E 130 | AE1_B 131 | N_B 132 | AH0_B 133 | OW1_E 134 | EH1_B 135 | ER1_B 136 | TH_E 137 | EH2 138 | ER1 139 | ZH_I 140 | W_I 141 | AO1_E 142 | Y_I 143 | EY1_B 144 | ZH 145 | Y 146 | IY1_E 147 | F_E 148 | ER1_I 149 | P_E 150 | AY0_I 151 | OW0_E 152 | AY0 153 | AY2_E 154 | AY2 155 | AA0_I 156 | AA0 157 | AA1_S 158 | AW1_B 159 | SH_B 160 | OW2_E 161 | AW1 162 | OW1_S 163 | OW1_B 164 | UW2_I 165 | EY2_I 166 | UW2 167 | EY2 168 | CH_E 169 | AY2_I 170 | AY1_S 171 | Y_B 172 | AO2_I 173 | ER0_B 174 | AO2 175 | ER1_E 176 | UH1_I 177 | AO2_B 178 | B_E 179 | IY1_B 180 | UH1 181 | EY1_S 182 | CH_I 183 | AW1_I 184 | DH_I 185 | AA1_B 186 | AH2_I 187 | AH2 188 | AW1_E 189 | AY1_B 190 | OY1_I 191 | OY1 192 | AA0_B 193 | AE0_B 194 | IH2_I 195 | AW2_I 196 | AW2 197 | UW0_I 198 | UW0 199 | HH_I 200 | ZH_E 201 | TH_I 202 | EH0_I 203 | HH_E 204 | Z_B 205 | EH0 206 | AA0_E 207 | IY2_E 208 | IY2 209 | IY2_I 210 | UH0_I 211 | UH0 212 | OW2_B 213 | EY0_I 214 | EY0 215 | UH2_I 216 | UH2 217 | AE2_I 218 | AE2 219 | IH0_E 220 | EH2_B 221 | EY2_E 222 | AH2_B 223 | AE2_B 224 | EH1_S 225 | AA1_E 226 | DH_E 227 | EH0_B 228 | ER2_I 229 | ER2 230 | OY1_E 231 | AO0_I 232 | AO0 233 | AH1_E 234 | AY0_B 235 | UW0_E 236 | IY0_B 237 | AW2_B 238 | IY1_S 239 | AO0_B 240 | UH1_B 241 | OW0_B 242 | AO2_E 243 | UW1_B 244 | OY1_B 245 | AA2_B 246 | AW0_I 247 | EY2_B 248 | AW0 249 | AY0_E 250 | AW0_B 251 | OY2_I 252 | OY2 253 | AH2_E 254 | AY2_B 255 | UW2_E 256 | ZH_B 257 | EY0_E 258 | UW1_S 259 | IY2_B 260 | Y_E 261 | EH1_E 262 | K_S 263 | AW2_E 264 | AO1_S 265 | OY0_E 266 | 
OY0 267 | EY0_B 268 | OY2_E 269 | AA2_E 270 | ER2_E 271 | OY1_S 272 | M_S 273 | AW0_E 274 | OY0_I 275 | AO0_E 276 | IH2_E 277 | OY0_B 278 | AW1_S 279 | UW0_B 280 | AE1_E 281 | JH_S 282 | SH_S 283 | CH_S 284 | S_S 285 | ER1_S 286 | L_S 287 | UH1_E 288 | EH0_E 289 | IH1_E 290 | EH2_E 291 | W_E 292 | UW2_B 293 | AE0_E 294 | AW2_S 295 | OW0_S 296 | G_S 297 | AH1_S 298 | TH_S 299 | D_S 300 | F_S 301 | ZH_S 302 | IY0_S 303 | R_S 304 | Z_S 305 | N_S 306 | -------------------------------------------------------------------------------- /stable_codec/data/Text2Phone/dict_phone.txt: -------------------------------------------------------------------------------- 1 | | 3074372 2 | AH0 780464 3 | T 733022 4 | N 714038 5 | D 505406 6 | S 485556 7 | R 436603 8 | L 402851 9 | IH1 330354 10 | DH 328795 11 | M 298727 12 | K 289169 13 | Z 280531 14 | EH1 277636 15 | AE1 262226 16 | IH0 259084 17 | AH1 252593 18 | W 238394 19 | HH 218537 20 | , 212347 21 | ER0 206850 22 | P 199148 23 | IY1 198226 24 | V 197271 25 | F 192464 26 | B 182706 27 | UW1 181468 28 | AA1 178815 29 | AY1 170877 30 | AO1 149112 31 | EY1 143429 32 | . 142949 33 | IY0 140939 34 | OW1 113531 35 | NG 108070 36 | G 90454 37 | SH 82222 38 | Y 69774 39 | AW1 60297 40 | CH 57051 41 | ER1 53272 42 | TH 50546 43 | UH1 49529 44 | JH 45787 45 | ' 15058 46 | OW0 14618 47 | ? 13792 48 | EH2 13602 49 | ! 11626 50 | IH2 11207 51 | OY1 10220 52 | EY2 10117 53 | AY2 9117 54 | EH0 8194 55 | AE2 7370 56 | UW0 7245 57 | AA2 7211 58 | - 6158 59 | OW2 6064 60 | AO2 5015 61 | AH2 4861 62 | AE0 4744 63 | ZH 4296 64 | AA0 4081 65 | UW2 3583 66 | IY2 3097 67 | AO0 2789 68 | AY0 2751 69 | AW2 2486 70 | EY0 1571 71 | UH2 1313 72 | ER2 1092 73 | AW0 1017 74 | ... 843 75 | UH0 337 76 | OY2 262 77 | OY0 78 78 | .. 10 79 | ; 11 80 | : 12 -------------------------------------------------------------------------------- /stable_codec/data/Text2Phone/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stability-AI/stable-codec/853554419944dbec11c39a29cf803e492aade836/stable_codec/data/Text2Phone/modules/__init__.py -------------------------------------------------------------------------------- /stable_codec/data/Text2Phone/modules/data_gen_utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings("ignore") 4 | 5 | # import parselmouth 6 | import os 7 | import torch 8 | from skimage.transform import resize 9 | from tools.tokenizer.Text2Phone.utils.text_encoder import TokenTextEncoder 10 | from tools.tokenizer.Text2Phone.utils.pitch_utils import f0_to_coarse 11 | import struct 12 | import webrtcvad 13 | from scipy.ndimage.morphology import binary_dilation 14 | import librosa 15 | import numpy as np 16 | from tools.tokenizer.Text2Phone.utils import audio 17 | import pyloudnorm as pyln 18 | import re 19 | import json 20 | from collections import OrderedDict 21 | 22 | PUNCS = '!,.?;:' 23 | 24 | int16_max = (2 ** 15) - 1 25 | 26 | 27 | def trim_long_silences(path, sr=None, return_raw_wav=False, norm=True, vad_max_silence_length=12): 28 | """ 29 | Ensures that segments without voice in the waveform remain no longer than a 30 | threshold determined by the VAD parameters in params.py. 31 | :param wav: the raw waveform as a numpy array of floats 32 | :param vad_max_silence_length: Maximum number of consecutive silent frames a segment can have. 
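:param path: path of the audio file to load (source of the `wav` referred to above).
:param sr: sample rate to load the audio at; `None` keeps the file's native rate.
:param return_raw_wav: if True, also return the full untrimmed waveform alongside the voice-activity mask.
:param norm: if True, loudness-normalize the audio to -20 LUFS (and peak-normalize if it exceeds 1.0) before running VAD.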
33 | :return: the same waveform with silences trimmed away (length <= original wav length) 34 | """ 35 | 36 | ## Voice Activation Detection 37 | # Window size of the VAD. Must be either 10, 20 or 30 milliseconds. 38 | # This sets the granularity of the VAD. Should not need to be changed. 39 | sampling_rate = 16000 40 | wav_raw, sr = librosa.core.load(path, sr=sr) 41 | 42 | if norm: 43 | meter = pyln.Meter(sr) # create BS.1770 meter 44 | loudness = meter.integrated_loudness(wav_raw) 45 | wav_raw = pyln.normalize.loudness(wav_raw, loudness, -20.0) 46 | if np.abs(wav_raw).max() > 1.0: 47 | wav_raw = wav_raw / np.abs(wav_raw).max() 48 | 49 | wav = librosa.resample(wav_raw, sr, sampling_rate, res_type='kaiser_best') 50 | 51 | vad_window_length = 30 # In milliseconds 52 | # Number of frames to average together when performing the moving average smoothing. 53 | # The larger this value, the larger the VAD variations must be to not get smoothed out. 54 | vad_moving_average_width = 8 55 | 56 | # Compute the voice detection window size 57 | samples_per_window = (vad_window_length * sampling_rate) // 1000 58 | 59 | # Trim the end of the audio to have a multiple of the window size 60 | wav = wav[:len(wav) - (len(wav) % samples_per_window)] 61 | 62 | # Convert the float waveform to 16-bit mono PCM 63 | pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16)) 64 | 65 | # Perform voice activation detection 66 | voice_flags = [] 67 | vad = webrtcvad.Vad(mode=3) 68 | for window_start in range(0, len(wav), samples_per_window): 69 | window_end = window_start + samples_per_window 70 | voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2], 71 | sample_rate=sampling_rate)) 72 | voice_flags = np.array(voice_flags) 73 | 74 | # Smooth the voice detection with a moving average 75 | def moving_average(array, width): 76 | array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2))) 77 | ret = np.cumsum(array_padded, dtype=float) 78 | ret[width:] = ret[width:] - ret[:-width] 79 | return ret[width - 1:] / width 80 | 81 | audio_mask = moving_average(voice_flags, vad_moving_average_width) 82 | audio_mask = np.round(audio_mask).astype(np.bool) 83 | 84 | # Dilate the voiced regions 85 | audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1)) 86 | audio_mask = np.repeat(audio_mask, samples_per_window) 87 | audio_mask = resize(audio_mask, (len(wav_raw),)) > 0 88 | if return_raw_wav: 89 | return wav_raw, audio_mask, sr 90 | return wav_raw[audio_mask], audio_mask, sr 91 | 92 | 93 | def process_utterance(wav_path, 94 | fft_size=1024, 95 | hop_size=256, 96 | win_length=1024, 97 | window="hann", 98 | num_mels=80, 99 | fmin=80, 100 | fmax=7600, 101 | eps=1e-6, 102 | sample_rate=22050, 103 | loud_norm=False, 104 | min_level_db=-100, 105 | return_linear=False, 106 | trim_long_sil=False, vocoder='pwg'): 107 | if isinstance(wav_path, str): 108 | if trim_long_sil: 109 | wav, _, _ = trim_long_silences(wav_path, sample_rate) 110 | else: 111 | wav, _ = librosa.core.load(wav_path, sr=sample_rate) 112 | else: 113 | wav = wav_path 114 | 115 | if loud_norm: 116 | meter = pyln.Meter(sample_rate) # create BS.1770 meter 117 | loudness = meter.integrated_loudness(wav) 118 | wav = pyln.normalize.loudness(wav, loudness, -22.0) 119 | if np.abs(wav).max() > 1: 120 | wav = wav / np.abs(wav).max() 121 | 122 | # get amplitude spectrogram 123 | x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size, 124 | win_length=win_length, window=window, 
pad_mode="constant") 125 | spc = np.abs(x_stft) # (n_bins, T) 126 | 127 | # get mel basis 128 | fmin = 0 if fmin == -1 else fmin 129 | fmax = sample_rate / 2 if fmax == -1 else fmax 130 | mel_basis = librosa.filters.mel(sample_rate, fft_size, num_mels, fmin, fmax) 131 | mel = mel_basis @ spc 132 | 133 | if vocoder == 'pwg': 134 | mel = np.log10(np.maximum(eps, mel)) # (n_mel_bins, T) 135 | else: 136 | assert False, f'"{vocoder}" is not in ["pwg"].' 137 | 138 | l_pad, r_pad = audio.librosa_pad_lr(wav, fft_size, hop_size, 1) 139 | wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0) 140 | wav = wav[:mel.shape[1] * hop_size] 141 | 142 | if not return_linear: 143 | return wav, mel 144 | else: 145 | spc = audio.amp_to_db(spc) 146 | spc = audio.normalize(spc, {'min_level_db': min_level_db}) 147 | return wav, mel, spc 148 | 149 | 150 | def get_pitch(wav_data, mel, hparams): 151 | """ 152 | 153 | :param wav_data: [T] 154 | :param mel: [T, 80] 155 | :param hparams: 156 | :return: 157 | """ 158 | time_step = hparams['hop_size'] / hparams['audio_sample_rate'] * 1000 159 | f0_min = 80 160 | f0_max = 750 161 | 162 | if hparams['pitch_extractor'] == 'harvest': 163 | import pyworld as pw 164 | f0, t = pw.harvest(wav_data.astype(np.double), hparams['audio_sample_rate'], 165 | frame_period=hparams['hop_size'] / hparams['audio_sample_rate'] * 1000) 166 | if hparams['pitch_extractor'] == 'dio': 167 | _f0, t = pw.dio(wav_data.astype(np.double), hparams['audio_sample_rate'], 168 | frame_period=hparams['hop_size'] / hparams['audio_sample_rate'] * 1000) 169 | f0 = pw.stonemask(wav_data.astype(np.double), _f0, t, hparams['audio_sample_rate']) # pitch refinement 170 | elif hparams['pitch_extractor'] == 'parselmouth': 171 | if hparams['hop_size'] == 128: 172 | pad_size = 4 173 | elif hparams['hop_size'] == 256: 174 | pad_size = 2 175 | else: 176 | assert False 177 | f0 = parselmouth.Sound(wav_data, hparams['audio_sample_rate']).to_pitch_ac( 178 | time_step=time_step / 1000, voicing_threshold=0.6, 179 | pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency'] 180 | lpad = pad_size * 2 181 | rpad = len(mel) - len(f0) - lpad 182 | f0 = np.pad(f0, [[lpad, rpad]], mode='constant') 183 | 184 | # mel和f0是2个库抽的 需要保证两者长度一致 185 | delta_l = len(mel) - len(f0) 186 | assert np.abs(delta_l) <= 8 187 | if delta_l > 0: 188 | f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0) 189 | f0 = f0[:len(mel)] 190 | pitch_coarse = f0_to_coarse(f0) 191 | return f0, pitch_coarse 192 | 193 | 194 | def remove_empty_lines(text): 195 | """remove empty lines""" 196 | assert (len(text) > 0) 197 | assert (isinstance(text, list)) 198 | text = [t.strip() for t in text] 199 | if "" in text: 200 | text.remove("") 201 | return text 202 | 203 | 204 | class TextGrid(object): 205 | def __init__(self, text): 206 | text = remove_empty_lines(text) 207 | self.text = text 208 | self.line_count = 0 209 | self._get_type() 210 | self._get_time_intval() 211 | self._get_size() 212 | self.tier_list = [] 213 | self._get_item_list() 214 | 215 | def _extract_pattern(self, pattern, inc): 216 | """ 217 | Parameters 218 | ---------- 219 | pattern : regex to extract pattern 220 | inc : increment of line count after extraction 221 | Returns 222 | ------- 223 | group : extracted info 224 | """ 225 | try: 226 | group = re.match(pattern, self.text[self.line_count]).group(1) 227 | self.line_count += inc 228 | except AttributeError: 229 | raise ValueError("File format error at line %d:%s" % (self.line_count, self.text[self.line_count])) 230 | return 
group 231 | 232 | def _get_type(self): 233 | self.file_type = self._extract_pattern(r"File type = \"(.*)\"", 2) 234 | 235 | def _get_time_intval(self): 236 | self.xmin = self._extract_pattern(r"xmin = (.*)", 1) 237 | self.xmax = self._extract_pattern(r"xmax = (.*)", 2) 238 | 239 | def _get_size(self): 240 | self.size = int(self._extract_pattern(r"size = (.*)", 2)) 241 | 242 | def _get_item_list(self): 243 | """Only supports IntervalTier currently""" 244 | for itemIdx in range(1, self.size + 1): 245 | tier = OrderedDict() 246 | item_list = [] 247 | tier_idx = self._extract_pattern(r"item \[(.*)\]:", 1) 248 | tier_class = self._extract_pattern(r"class = \"(.*)\"", 1) 249 | if tier_class != "IntervalTier": 250 | raise NotImplementedError("Only IntervalTier class is supported currently") 251 | tier_name = self._extract_pattern(r"name = \"(.*)\"", 1) 252 | tier_xmin = self._extract_pattern(r"xmin = (.*)", 1) 253 | tier_xmax = self._extract_pattern(r"xmax = (.*)", 1) 254 | tier_size = self._extract_pattern(r"intervals: size = (.*)", 1) 255 | for i in range(int(tier_size)): 256 | item = OrderedDict() 257 | item["idx"] = self._extract_pattern(r"intervals \[(.*)\]", 1) 258 | item["xmin"] = self._extract_pattern(r"xmin = (.*)", 1) 259 | item["xmax"] = self._extract_pattern(r"xmax = (.*)", 1) 260 | item["text"] = self._extract_pattern(r"text = \"(.*)\"", 1) 261 | item_list.append(item) 262 | tier["idx"] = tier_idx 263 | tier["class"] = tier_class 264 | tier["name"] = tier_name 265 | tier["xmin"] = tier_xmin 266 | tier["xmax"] = tier_xmax 267 | tier["size"] = tier_size 268 | tier["items"] = item_list 269 | self.tier_list.append(tier) 270 | 271 | def toJson(self): 272 | _json = OrderedDict() 273 | _json["file_type"] = self.file_type 274 | _json["xmin"] = self.xmin 275 | _json["xmax"] = self.xmax 276 | _json["size"] = self.size 277 | _json["tiers"] = self.tier_list 278 | return json.dumps(_json, ensure_ascii=False, indent=2) 279 | 280 | 281 | def get_mel2ph(tg_fn, ph, mel, hparams): 282 | ph_list = ph.split(" ") 283 | with open(tg_fn, "r") as f: 284 | tg = f.readlines() 285 | tg = remove_empty_lines(tg) 286 | tg = TextGrid(tg) 287 | tg = json.loads(tg.toJson()) 288 | split = np.ones(len(ph_list) + 1, np.float) * -1 289 | tg_idx = 0 290 | ph_idx = 0 291 | tg_align = [x for x in tg['tiers'][-1]['items']] 292 | tg_align_ = [] 293 | for x in tg_align: 294 | x['xmin'] = float(x['xmin']) 295 | x['xmax'] = float(x['xmax']) 296 | if x['text'] in ['sil', 'sp', '', 'SIL', 'PUNC']: 297 | x['text'] = '' 298 | if len(tg_align_) > 0 and tg_align_[-1]['text'] == '': 299 | tg_align_[-1]['xmax'] = x['xmax'] 300 | continue 301 | tg_align_.append(x) 302 | tg_align = tg_align_ 303 | tg_len = len([x for x in tg_align if x['text'] != '']) 304 | ph_len = len([x for x in ph_list if not is_sil_phoneme(x)]) 305 | assert tg_len == ph_len, (tg_len, ph_len, tg_align, ph_list, tg_fn) 306 | while tg_idx < len(tg_align) or ph_idx < len(ph_list): 307 | if tg_idx == len(tg_align) and is_sil_phoneme(ph_list[ph_idx]): 308 | split[ph_idx] = 1e8 309 | ph_idx += 1 310 | continue 311 | x = tg_align[tg_idx] 312 | if x['text'] == '' and ph_idx == len(ph_list): 313 | tg_idx += 1 314 | continue 315 | assert ph_idx < len(ph_list), (tg_len, ph_len, tg_align, ph_list, tg_fn) 316 | ph = ph_list[ph_idx] 317 | if x['text'] == '' and not is_sil_phoneme(ph): 318 | assert False, (ph_list, tg_align) 319 | if x['text'] != '' and is_sil_phoneme(ph): 320 | ph_idx += 1 321 | else: 322 | assert (x['text'] == '' and is_sil_phoneme(ph)) \ 323 | or 
x['text'].lower() == ph.lower() \ 324 | or x['text'].lower() == 'sil', (x['text'], ph) 325 | split[ph_idx] = x['xmin'] 326 | if ph_idx > 0 and split[ph_idx - 1] == -1 and is_sil_phoneme(ph_list[ph_idx - 1]): 327 | split[ph_idx - 1] = split[ph_idx] 328 | ph_idx += 1 329 | tg_idx += 1 330 | assert tg_idx == len(tg_align), (tg_idx, [x['text'] for x in tg_align]) 331 | assert ph_idx >= len(ph_list) - 1, (ph_idx, ph_list, len(ph_list), [x['text'] for x in tg_align], tg_fn) 332 | mel2ph = np.zeros([mel.shape[0]], np.int) 333 | split[0] = 0 334 | split[-1] = 1e8 335 | for i in range(len(split) - 1): 336 | assert split[i] != -1 and split[i] <= split[i + 1], (split[:-1],) 337 | split = [int(s * hparams['audio_sample_rate'] / hparams['hop_size'] + 0.5) for s in split] 338 | for ph_idx in range(len(ph_list)): 339 | mel2ph[split[ph_idx]:split[ph_idx + 1]] = ph_idx + 1 340 | mel2ph_torch = torch.from_numpy(mel2ph) 341 | T_t = len(ph_list) 342 | dur = mel2ph_torch.new_zeros([T_t + 1]).scatter_add(0, mel2ph_torch, torch.ones_like(mel2ph_torch)) 343 | dur = dur[1:].numpy() 344 | return mel2ph, dur 345 | 346 | 347 | def build_phone_encoder(data_dir): 348 | phone_list_file = os.path.join(data_dir, 'phone_set.json') 349 | phone_list = json.load(open(phone_list_file)) 350 | return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',') 351 | 352 | 353 | def is_sil_phoneme(p): 354 | return not p[0].isalpha() 355 | -------------------------------------------------------------------------------- /stable_codec/data/Text2Phone/modules/txt_processors/base_text_processor.py: -------------------------------------------------------------------------------- 1 | class BaseTxtProcessor: 2 | @staticmethod 3 | def sp_phonemes(): 4 | return ['|'] 5 | 6 | @classmethod 7 | def process(cls, txt, pre_align_args): 8 | raise NotImplementedError 9 | -------------------------------------------------------------------------------- /stable_codec/data/Text2Phone/modules/txt_processors/en.py: -------------------------------------------------------------------------------- 1 | import re 2 | # from tools.tokenizer.Text2Phone.modules.data_gen_utils import PUNCS 3 | PUNCS = '!,.?;:' # code above raise import error 4 | from g2p_en import G2p 5 | import unicodedata 6 | from g2p_en.expand import normalize_numbers 7 | from nltk import pos_tag 8 | from nltk.tokenize import TweetTokenizer 9 | import nltk 10 | # nltk.download('averaged_perceptron_tagger_eng') 11 | from .base_text_processor import BaseTxtProcessor 12 | 13 | class EnG2p(G2p): 14 | word_tokenize = TweetTokenizer().tokenize 15 | 16 | def __call__(self, text): 17 | # preprocessing 18 | words = EnG2p.word_tokenize(text) 19 | tokens = pos_tag(words) # tuples of (word, tag) 20 | 21 | # steps 22 | prons = [] 23 | for word, pos in tokens: 24 | if re.search("[a-z]", word) is None: 25 | pron = [word] 26 | 27 | elif word in self.homograph2features: # Check homograph 28 | pron1, pron2, pos1 = self.homograph2features[word] 29 | if pos.startswith(pos1): 30 | pron = pron1 31 | else: 32 | pron = pron2 33 | elif word in self.cmu: # lookup CMU dict 34 | pron = self.cmu[word][0] 35 | else: # predict for oov 36 | pron = self.predict(word) 37 | 38 | prons.extend(pron) 39 | prons.extend([" "]) 40 | 41 | return prons[:-1] 42 | 43 | 44 | class TxtProcessor(BaseTxtProcessor): 45 | g2p = EnG2p() 46 | 47 | @staticmethod 48 | def preprocess_text(text): 49 | text = normalize_numbers(text) 50 | text = ''.join(char for char in unicodedata.normalize('NFD', text) 51 | if unicodedata.category(char) != 'Mn') 
# Strip accents 52 | text = text.lower() 53 | text = re.sub("[\'\"()]+", "", text) 54 | text = re.sub("[-]+", " ", text) 55 | text = re.sub(f"[^ a-z{PUNCS}]", "", text) 56 | text = re.sub(f" ?([{PUNCS}]) ?", r"\1", text) # !! -> ! 57 | text = re.sub(f"([{PUNCS}])+", r"\1", text) # !! -> ! 58 | text = text.replace("i.e.", "that is") 59 | text = text.replace("i.e.", "that is") 60 | text = text.replace("etc.", "etc") 61 | text = re.sub(f"([{PUNCS}])", r" \1 ", text) 62 | text = re.sub(rf"\s+", r" ", text) 63 | return text 64 | 65 | @classmethod 66 | def process(cls, txt, pre_align_args): 67 | txt = cls.preprocess_text(txt).strip() 68 | phs = cls.g2p(txt) 69 | phs_ = [] 70 | n_word_sep = 0 71 | for p in phs: 72 | if p.strip() == '': 73 | phs_ += ['|'] 74 | n_word_sep += 1 75 | else: 76 | phs_ += p.split(" ") 77 | phs = phs_ 78 | assert n_word_sep + 1 == len(txt.split(" ")), (phs, f"\"{txt}\"") 79 | return phs, txt 80 | -------------------------------------------------------------------------------- /stable_codec/data/Text2Phone/modules/txt_processors/en_syl.py: -------------------------------------------------------------------------------- 1 | from syllabipy.sonoripy import SonoriPy 2 | from modules.txt_processors import en 3 | 4 | 5 | class TxtProcessor(en.TxtProcessor): 6 | @classmethod 7 | def process(cls, txt, pre_align_args): 8 | txt = cls.preprocess_text(txt) 9 | phs = [] 10 | for p in txt.split(" "): 11 | if len(p) == 0: 12 | continue 13 | syl = SonoriPy(p) 14 | if len(syl) == 0: 15 | phs += list(p) 16 | else: 17 | for x in syl: 18 | phs += list(x) 19 | phs += ['|'] 20 | return phs, txt 21 | -------------------------------------------------------------------------------- /stable_codec/data/Text2Phone/modules/txt_processors/zh.py: -------------------------------------------------------------------------------- 1 | import re 2 | import jieba 3 | from pypinyin import pinyin, Style 4 | from modules.data_gen_utils import PUNCS 5 | from modules.txt_processors.base_text_processor import BaseTxtProcessor 6 | from utils.text_norm import NSWNormalizer 7 | 8 | 9 | ALL_SHENMU = ['zh', 'ch', 'sh', 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'j', 10 | 'q', 'x', 'r', 'z', 'c', 's', 'y', 'w'] 11 | 12 | 13 | class TxtProcessor(BaseTxtProcessor): 14 | table = {ord(f): ord(t) for f, t in zip( 15 | u':,。!?【】()%#@&1234567890', 16 | u':,.!?[]()%#@&1234567890')} 17 | 18 | @staticmethod 19 | def sp_phonemes(): 20 | return ['|', '#'] 21 | 22 | @staticmethod 23 | def preprocess_text(text): 24 | text = text.translate(TxtProcessor.table) 25 | text = NSWNormalizer(text).normalize(remove_punc=False).lower() 26 | text = re.sub("[\'\"()]+", "", text) 27 | text = re.sub("[-]+", " ", text) 28 | text = re.sub(f"[^ A-Za-z\u4e00-\u9fff{PUNCS}]", "", text) 29 | text = re.sub(f"([{PUNCS}])+", r"\1", text) # !! -> ! 
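# remaining normalization below: space-pad punctuation, strip all whitespace, and replace runs of
# Latin letters with a '$' placeholder (later emitted as 'ENG' by pinyin_with_en)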
30 | text = re.sub(f"([{PUNCS}])", r" \1 ", text) 31 | text = re.sub(rf"\s+", r"", text) 32 | text = re.sub(rf"[A-Za-z]+", r"$", text) 33 | return text 34 | 35 | @classmethod 36 | def pinyin_with_en(cls, txt, style): 37 | x = pinyin(txt, style) 38 | x = [t[0] for t in x] 39 | x_ = [] 40 | for t in x: 41 | if '$' not in t: 42 | x_.append(t) 43 | else: 44 | x_ += list(t) 45 | x_ = [t if t != '$' else 'ENG' for t in x_] 46 | return x_ 47 | 48 | @classmethod 49 | def process(cls, txt, pre_align_args): 50 | txt = cls.preprocess_text(txt) 51 | 52 | # https://blog.csdn.net/zhoulei124/article/details/89055403 53 | shengmu = cls.pinyin_with_en(txt, style=Style.INITIALS) 54 | yunmu = cls.pinyin_with_en(txt, style= 55 | Style.FINALS_TONE3 if pre_align_args['use_tone'] else Style.FINALS) 56 | assert len(shengmu) == len(yunmu) 57 | ph_list = [] 58 | for a, b in zip(shengmu, yunmu): 59 | if a == b: 60 | ph_list += [a] 61 | else: 62 | ph_list += [a + "%" + b] 63 | seg_list = '#'.join(jieba.cut(txt)) 64 | assert len(ph_list) == len([s for s in seg_list if s != '#']), (ph_list, seg_list) 65 | 66 | # 加入词边界'#' 67 | ph_list_ = [] 68 | seg_idx = 0 69 | for p in ph_list: 70 | if seg_list[seg_idx] == '#': 71 | ph_list_.append('#') 72 | seg_idx += 1 73 | elif len(ph_list_) > 0: 74 | ph_list_.append("|") 75 | seg_idx += 1 76 | finished = False 77 | if not finished: 78 | ph_list_ += [x for x in p.split("%") if x != ''] 79 | 80 | ph_list = ph_list_ 81 | 82 | # 去除静音符号周围的词边界标记 [..., '#', ',', '#', ...] 83 | sil_phonemes = list(PUNCS) + TxtProcessor.sp_phonemes() 84 | ph_list_ = [] 85 | for i in range(0, len(ph_list), 1): 86 | if ph_list[i] != '#' or (ph_list[i - 1] not in sil_phonemes and ph_list[i + 1] not in sil_phonemes): 87 | ph_list_.append(ph_list[i]) 88 | ph_list = ph_list_ 89 | return ph_list, txt 90 | 91 | 92 | if __name__ == '__main__': 93 | t = 'simon演唱过后,simon还进行了simon精彩的文艺演出simon.' 
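# mixed Chinese/English test sentence, roughly: "after simon's singing, simon also gave a wonderful simon arts performance simon."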
94 | phs, txt = TxtProcessor.process(t, {'use_tone': True}) 95 | print(phs, txt) 96 | -------------------------------------------------------------------------------- /stable_codec/data/Text2Phone/modules/txt_processors/zh_g2pM.py: -------------------------------------------------------------------------------- 1 | import re 2 | import jieba 3 | from pypinyin import pinyin, Style 4 | from modules.data_gen_utils import PUNCS 5 | from modules.txt_processors import zh 6 | from g2pM import G2pM 7 | 8 | ALL_SHENMU = ['zh', 'ch', 'sh', 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'j', 9 | 'q', 'x', 'r', 'z', 'c', 's', 'y', 'w'] 10 | 11 | 12 | class TxtProcessor(zh.TxtProcessor): 13 | model = G2pM() 14 | 15 | @staticmethod 16 | def sp_phonemes(): 17 | return ['|', '#'] 18 | 19 | @classmethod 20 | def process(cls, txt, pre_align_args): 21 | txt = cls.preprocess_text(txt) 22 | ph_list = cls.model(txt, tone=pre_align_args['use_tone'], char_split=True) 23 | seg_list = '#'.join(jieba.cut(txt)) 24 | assert len(ph_list) == len([s for s in seg_list if s != '#']), (ph_list, seg_list) 25 | 26 | # 加入词边界'#' 27 | ph_list_ = [] 28 | seg_idx = 0 29 | for p in ph_list: 30 | p = p.replace("u:", "v") 31 | if seg_list[seg_idx] == '#': 32 | ph_list_.append('#') 33 | seg_idx += 1 34 | else: 35 | ph_list_.append("|") 36 | seg_idx += 1 37 | if re.findall('[\u4e00-\u9fff]', p): 38 | if pre_align_args['use_tone']: 39 | p = pinyin(p, style=Style.TONE3, strict=True)[0][0] 40 | if p[-1] not in ['1', '2', '3', '4', '5']: 41 | p = p + '5' 42 | else: 43 | p = pinyin(p, style=Style.NORMAL, strict=True)[0][0] 44 | 45 | finished = False 46 | if len([c.isalpha() for c in p]) > 1: 47 | for shenmu in ALL_SHENMU: 48 | if p.startswith(shenmu) and not p.lstrip(shenmu).isnumeric(): 49 | ph_list_ += [shenmu, p.lstrip(shenmu)] 50 | finished = True 51 | break 52 | if not finished: 53 | ph_list_.append(p) 54 | 55 | ph_list = ph_list_ 56 | 57 | # 去除静音符号周围的词边界标记 [..., '#', ',', '#', ...] 58 | sil_phonemes = list(PUNCS) + TxtProcessor.sp_phonemes() 59 | ph_list_ = [] 60 | for i in range(0, len(ph_list), 1): 61 | if ph_list[i] != '#' or (ph_list[i - 1] not in sil_phonemes and ph_list[i + 1] not in sil_phonemes): 62 | ph_list_.append(ph_list[i]) 63 | ph_list = ph_list_ 64 | return ph_list, txt 65 | 66 | 67 | if __name__ == '__main__': 68 | phs, txt = TxtProcessor.process('他来到了,网易杭研大厦', {'use_tone': True}) 69 | print(phs) 70 | -------------------------------------------------------------------------------- /stable_codec/data/Text2Phone/modules/txt_processors/zh_g2pM_song_seg.py: -------------------------------------------------------------------------------- 1 | import re 2 | from modules.data_gen_utils import PUNCS 3 | from modules.txt_processors import zh_g2pM 4 | 5 | from utils.text_norm import NSWNormalizer 6 | 7 | 8 | class TxtProcessor(zh_g2pM.TxtProcessor): 9 | @staticmethod 10 | def preprocess_text(text): 11 | text = text.translate(TxtProcessor.table) 12 | text = NSWNormalizer(text).normalize(remove_punc=False) 13 | text = re.sub("[\'\"()]+", "", text) 14 | text = re.sub("[-]+", " ", text) 15 | text = re.sub(f"[^ A-Za-z\u4e00-\u9fff{PUNCS}&]", "", text) 16 | text = re.sub(f"([{PUNCS}])+", r"\1", text) # !! -> ! 
17 | text = re.sub(f"([{PUNCS}])", r" \1 ", text) 18 | text = re.sub(rf"\s+", r"", text) 19 | return text 20 | 21 | @staticmethod 22 | def sp_phonemes(): 23 | return ['|', '#', '&'] 24 | 25 | @classmethod 26 | def process(cls, txt, pre_align_args): 27 | txt = txt.replace('SEP', '&') 28 | ph_list, txt = super().process(txt, pre_align_args) 29 | txt = txt.replace('&', ' SEP ') 30 | ph_list = [p if p != '&' else 'SEP' for p in ph_list if p not in ['|', '#', '', '']] 31 | return ph_list, txt 32 | -------------------------------------------------------------------------------- /stable_codec/data/Text2Phone/modules/txt_processors/zh_song_seg.py: -------------------------------------------------------------------------------- 1 | import re 2 | from modules.data_gen_utils import PUNCS 3 | from modules.txt_processors import zh 4 | from utils.text_norm import NSWNormalizer 5 | 6 | 7 | class TxtProcessor(zh.TxtProcessor): 8 | @staticmethod 9 | def preprocess_text(text): 10 | text = text.translate(TxtProcessor.table) 11 | text = NSWNormalizer(text).normalize(remove_punc=False) 12 | text = re.sub("[\'\"()]+", "", text) 13 | text = re.sub("[-]+", " ", text) 14 | text = re.sub(f"[^ A-Za-z\u4e00-\u9fff{PUNCS}&]", "", text) 15 | text = re.sub(f"([{PUNCS}])+", r"\1", text) # !! -> ! 16 | text = re.sub(f"([{PUNCS}])", r" \1 ", text) 17 | text = re.sub(rf"\s+", r"", text) 18 | return text 19 | 20 | @staticmethod 21 | def sp_phonemes(): 22 | return ['|', '#', '&'] 23 | 24 | @classmethod 25 | def process(cls, txt, pre_align_args): 26 | txt = txt.replace('SEP', '&') 27 | ph_list, txt = super().process(txt, pre_align_args) 28 | txt = txt.replace('&', ' SEP ') 29 | ph_list = [p if p != '&' else 'SEP' for p in ph_list if p not in ['|', '#', '', '']] 30 | return ph_list, txt 31 | -------------------------------------------------------------------------------- /stable_codec/data/Text2Phone/phone_tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | from valle.tools.tokenizer.abs_tokenizer import AbsTokenizer 4 | 5 | default_phone_dict = "tools/tokenizer/Text2Phone/alignment_dict" 6 | 7 | class PhoneTokenizer(AbsTokenizer): 8 | """ 9 | This is the virtual tokenizer class. 10 | Other tokenizers should inherit this class. 11 | typicially: 12 | Text -> BPE 13 | Text -> Phone 14 | Audio -> Codec 15 | Image -> Codec 16 | ... 17 | """ 18 | 19 | def __init__(self, phone_table=default_phone_dict, duplicate=False, unk_ph=None): 20 | super(PhoneTokenizer, self).__init__() 21 | 22 | phone_dict = open(phone_table, encoding="utf-8").readlines() 23 | phone_dict = [line.strip().split() for line in phone_dict] 24 | phone_dict = {line[0]: None for line in phone_dict} 25 | keys = list(phone_dict.keys()) 26 | for i, k in enumerate(keys): 27 | phone_dict[k] = i 28 | self.phone_dict = phone_dict 29 | 30 | if unk_ph is None: 31 | self.unk_ph = "" 32 | logging.info("No unknown phone provided. 
Set it as .") 33 | else: 34 | self.unk_ph = unk_ph 35 | 36 | if unk_ph not in self.phone_dict: 37 | logging.info(f"Set unknown phone with number: {len(self.phone_dict)}") 38 | self.phone_dict[unk_ph] = len(self.phone_dict) 39 | self.unk_id = phone_dict[unk_ph] 40 | 41 | self.duplicate = duplicate 42 | 43 | @property 44 | def is_discrete(self): 45 | return True 46 | 47 | @property 48 | def codebook_length(self): 49 | return len(self.phone_dict) 50 | 51 | def find_length(self, x): 52 | return len(self.tokenize(x)) 53 | 54 | def tokenize(self, x, task=None, cache=None): 55 | if isinstance(x, torch.Tensor): 56 | assert x.dim() == 1 57 | x = torch.unique_consequtive(x) if not self.duplicate else x 58 | return x.to(torch.int16) 59 | elif isinstance(x, str): 60 | x = [self.phone_dict.get(ph, self.unk_id) for ph in x.strip().split()] 61 | return torch.Tensor(x).to(torch.int16) 62 | else: 63 | raise NotImplementedError 64 | -------------------------------------------------------------------------------- /stable_codec/data/Text2Phone/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import time 2 | import sys 3 | import types 4 | 5 | import chardet 6 | import numpy as np 7 | import torch 8 | import torch.distributed as dist 9 | from .ckpt_utils import load_ckpt 10 | 11 | 12 | def reduce_tensors(metrics): 13 | new_metrics = {} 14 | for k, v in metrics.items(): 15 | if isinstance(v, torch.Tensor): 16 | dist.all_reduce(v) 17 | v = v / dist.get_world_size() 18 | if type(v) is dict: 19 | v = reduce_tensors(v) 20 | new_metrics[k] = v 21 | return new_metrics 22 | 23 | 24 | def tensors_to_scalars(tensors): 25 | if isinstance(tensors, torch.Tensor): 26 | tensors = tensors.item() 27 | return tensors 28 | elif isinstance(tensors, dict): 29 | new_tensors = {} 30 | for k, v in tensors.items(): 31 | v = tensors_to_scalars(v) 32 | new_tensors[k] = v 33 | return new_tensors 34 | elif isinstance(tensors, list): 35 | return [tensors_to_scalars(v) for v in tensors] 36 | else: 37 | return tensors 38 | 39 | 40 | def tensors_to_np(tensors): 41 | if isinstance(tensors, dict): 42 | new_np = {} 43 | for k, v in tensors.items(): 44 | if isinstance(v, torch.Tensor): 45 | v = v.cpu().numpy() 46 | if type(v) is dict: 47 | v = tensors_to_np(v) 48 | new_np[k] = v 49 | elif isinstance(tensors, list): 50 | new_np = [] 51 | for v in tensors: 52 | if isinstance(v, torch.Tensor): 53 | v = v.cpu().numpy() 54 | if type(v) is dict: 55 | v = tensors_to_np(v) 56 | new_np.append(v) 57 | elif isinstance(tensors, torch.Tensor): 58 | v = tensors 59 | if isinstance(v, torch.Tensor): 60 | v = v.cpu().numpy() 61 | if type(v) is dict: 62 | v = tensors_to_np(v) 63 | new_np = v 64 | else: 65 | raise Exception(f'tensors_to_np does not support type {type(tensors)}.') 66 | return new_np 67 | 68 | 69 | def move_to_cpu(tensors): 70 | ret = {} 71 | for k, v in tensors.items(): 72 | if isinstance(v, torch.Tensor): 73 | v = v.cpu() 74 | if type(v) is dict: 75 | v = move_to_cpu(v) 76 | ret[k] = v 77 | return ret 78 | 79 | 80 | def move_to_cuda(batch, gpu_id=0): 81 | # base case: object can be directly moved using `cuda` or `to` 82 | if callable(getattr(batch, 'cuda', None)): 83 | return batch.cuda(gpu_id, non_blocking=True) 84 | elif callable(getattr(batch, 'to', None)): 85 | return batch.to(torch.device('cuda', gpu_id), non_blocking=True) 86 | elif isinstance(batch, list): 87 | for i, x in enumerate(batch): 88 | batch[i] = move_to_cuda(x, gpu_id) 89 | return batch 90 | elif isinstance(batch, tuple): 91 | 
batch = list(batch) 92 | for i, x in enumerate(batch): 93 | batch[i] = move_to_cuda(x, gpu_id) 94 | return tuple(batch) 95 | elif isinstance(batch, dict): 96 | for k, v in batch.items(): 97 | batch[k] = move_to_cuda(v, gpu_id) 98 | return batch 99 | return batch 100 | 101 | 102 | class AvgrageMeter(object): 103 | 104 | def __init__(self): 105 | self.reset() 106 | 107 | def reset(self): 108 | self.avg = 0 109 | self.sum = 0 110 | self.cnt = 0 111 | 112 | def update(self, val, n=1): 113 | self.sum += val * n 114 | self.cnt += n 115 | self.avg = self.sum / self.cnt 116 | 117 | 118 | def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1): 119 | """Convert a list of 1d tensors into a padded 2d tensor.""" 120 | size = max(v.size(0) for v in values) if max_len is None else max_len 121 | res = values[0].new(len(values), size).fill_(pad_idx) 122 | 123 | def copy_tensor(src, dst): 124 | assert dst.numel() == src.numel() 125 | if shift_right: 126 | dst[1:] = src[:-1] 127 | dst[0] = shift_id 128 | else: 129 | dst.copy_(src) 130 | 131 | for i, v in enumerate(values): 132 | copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) 133 | return res 134 | 135 | 136 | def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None): 137 | """Convert a list of 2d tensors into a padded 3d tensor.""" 138 | size = max(v.size(0) for v in values) if max_len is None else max_len 139 | res = values[0].new(len(values), size, values[0].shape[1]).fill_(pad_idx) 140 | 141 | def copy_tensor(src, dst): 142 | assert dst.numel() == src.numel() 143 | if shift_right: 144 | dst[1:] = src[:-1] 145 | else: 146 | dst.copy_(src) 147 | 148 | for i, v in enumerate(values): 149 | copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) 150 | return res 151 | 152 | 153 | def _is_batch_full(batch, num_tokens, max_tokens, max_sentences): 154 | if len(batch) == 0: 155 | return 0 156 | if len(batch) == max_sentences: 157 | return 1 158 | if num_tokens > max_tokens: 159 | return 1 160 | return 0 161 | 162 | 163 | def batch_by_size( 164 | indices, num_tokens_fn, max_tokens=None, max_sentences=None, 165 | required_batch_size_multiple=1, distributed=False 166 | ): 167 | """ 168 | Yield mini-batches of indices bucketed by size. Batches may contain 169 | sequences of different lengths. 170 | 171 | Args: 172 | indices (List[int]): ordered list of dataset indices 173 | num_tokens_fn (callable): function that returns the number of tokens at 174 | a given index 175 | max_tokens (int, optional): max number of tokens in each batch 176 | (default: None). 177 | max_sentences (int, optional): max number of sentences in each 178 | batch (default: None). 179 | required_batch_size_multiple (int, optional): require batch size to 180 | be a multiple of N (default: 1). 
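    Example (illustrative sketch; the sample lengths and the 10-token budget
        are made-up values, not defaults from this repo):

        >>> lengths = [3, 7, 2, 9, 4]
        >>> batches = batch_by_size(list(range(len(lengths))),
        ...                         lambda i: lengths[i], max_tokens=10)
        >>> # every batch keeps len(batch) * (longest sample in it) <= max_tokens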
181 | """ 182 | max_tokens = max_tokens if max_tokens is not None else sys.maxsize 183 | max_sentences = max_sentences if max_sentences is not None else sys.maxsize 184 | bsz_mult = required_batch_size_multiple 185 | 186 | if isinstance(indices, types.GeneratorType): 187 | indices = np.fromiter(indices, dtype=np.int64, count=-1) 188 | 189 | sample_len = 0 190 | sample_lens = [] 191 | batch = [] 192 | batches = [] 193 | for i in range(len(indices)): 194 | idx = indices[i] 195 | num_tokens = num_tokens_fn(idx) 196 | sample_lens.append(num_tokens) 197 | sample_len = max(sample_len, num_tokens) 198 | 199 | assert sample_len <= max_tokens, ( 200 | "sentence at index {} of size {} exceeds max_tokens " 201 | "limit of {}!".format(idx, sample_len, max_tokens) 202 | ) 203 | num_tokens = (len(batch) + 1) * sample_len 204 | 205 | if _is_batch_full(batch, num_tokens, max_tokens, max_sentences): 206 | mod_len = max( 207 | bsz_mult * (len(batch) // bsz_mult), 208 | len(batch) % bsz_mult, 209 | ) 210 | batches.append(batch[:mod_len]) 211 | batch = batch[mod_len:] 212 | sample_lens = sample_lens[mod_len:] 213 | sample_len = max(sample_lens) if len(sample_lens) > 0 else 0 214 | batch.append(idx) 215 | if len(batch) > 0: 216 | batches.append(batch) 217 | return batches 218 | 219 | def unpack_dict_to_list(samples): 220 | samples_ = [] 221 | bsz = samples.get('outputs').size(0) 222 | for i in range(bsz): 223 | res = {} 224 | for k, v in samples.items(): 225 | try: 226 | res[k] = v[i] 227 | except: 228 | pass 229 | samples_.append(res) 230 | return samples_ 231 | 232 | 233 | def remove_padding(x, padding_idx=0): 234 | if x is None: 235 | return None 236 | assert len(x.shape) in [1, 2] 237 | if len(x.shape) == 2: # [T, H] 238 | return x[np.abs(x).sum(-1) != padding_idx] 239 | elif len(x.shape) == 1: # [T] 240 | return x[x != padding_idx] 241 | 242 | 243 | class Timer: 244 | timer_map = {} 245 | 246 | def __init__(self, name, enable=False): 247 | if name not in Timer.timer_map: 248 | Timer.timer_map[name] = 0 249 | self.name = name 250 | self.enable = enable 251 | 252 | def __enter__(self): 253 | if self.enable: 254 | if torch.cuda.is_available(): 255 | torch.cuda.synchronize() 256 | self.t = time.time() 257 | 258 | def __exit__(self, exc_type, exc_val, exc_tb): 259 | if self.enable: 260 | if torch.cuda.is_available(): 261 | torch.cuda.synchronize() 262 | Timer.timer_map[self.name] += time.time() - self.t 263 | if self.enable: 264 | print(f'[Timer] {self.name}: {Timer.timer_map[self.name]}') 265 | 266 | 267 | def print_arch(model, model_name='model'): 268 | print(f"| {model_name} Arch: ", model) 269 | num_params(model, model_name=model_name) 270 | 271 | 272 | def num_params(model, print_out=True, model_name="model"): 273 | parameters = filter(lambda p: p.requires_grad, model.parameters()) 274 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 275 | if print_out: 276 | print(f'| {model_name} Trainable Parameters: %.3fM' % parameters) 277 | return parameters 278 | 279 | 280 | def get_encoding(file): 281 | with open(file, 'rb') as f: 282 | encoding = chardet.detect(f.read())['encoding'] 283 | if encoding == 'GB2312': 284 | encoding = 'GB18030' 285 | return encoding 286 | -------------------------------------------------------------------------------- /stable_codec/data/Text2Phone/utils/audio.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import librosa 3 | import librosa.filters 4 | import numpy as np 5 | import torch 6 | from 
scipy import signal 7 | from scipy.io import wavfile 8 | import torch.nn.functional as F 9 | 10 | 11 | def save_wav(wav, path, sr, norm=False): 12 | if norm: 13 | wav = wav / np.abs(wav).max() 14 | wav *= 32767 15 | # proposed by @dsmiller 16 | wavfile.write(path, sr, wav.astype(np.int16)) 17 | 18 | 19 | def to_mp3(out_path): 20 | subprocess.check_call( 21 | f'ffmpeg -threads 1 -loglevel error -i "{out_path}.wav" -vn -ar 44100 -ac 1 -b:a 192k -y -hide_banner "{out_path}.mp3"', 22 | shell=True, stdin=subprocess.PIPE) 23 | subprocess.check_call(f'rm -f "{out_path}.wav"', shell=True) 24 | 25 | 26 | def get_hop_size(hparams): 27 | hop_size = hparams['hop_size'] 28 | if hop_size is None: 29 | assert hparams['frame_shift_ms'] is not None 30 | hop_size = int(hparams['frame_shift_ms'] / 1000 * hparams['audio_sample_rate']) 31 | return hop_size 32 | 33 | 34 | ########################################################################################### 35 | def griffin_lim(S, hparams, angles=None): 36 | angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) if angles is None else angles 37 | S_complex = np.abs(S).astype(np.complex) 38 | y = _istft(S_complex * angles, hparams) 39 | for i in range(hparams['griffin_lim_iters']): 40 | angles = np.exp(1j * np.angle(_stft(y, hparams))) 41 | y = _istft(S_complex * angles, hparams) 42 | return y 43 | 44 | 45 | def preemphasis(wav, k, preemphasize=True): 46 | if preemphasize: 47 | return signal.lfilter([1, -k], [1], wav) 48 | return wav 49 | 50 | 51 | def inv_preemphasis(wav, k, inv_preemphasize=True): 52 | if inv_preemphasize: 53 | return signal.lfilter([1], [1, -k], wav) 54 | return wav 55 | 56 | 57 | def _stft(y, hparams): 58 | return librosa.stft(y=y, n_fft=hparams['fft_size'], hop_length=get_hop_size(hparams), 59 | win_length=hparams['win_size'], pad_mode='constant') 60 | 61 | 62 | def _istft(y, hparams): 63 | return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams['win_size']) 64 | 65 | 66 | 67 | def librosa_pad_lr(x, fsize, fshift, pad_sides=1): 68 | '''compute right padding (final frame) or both sides padding (first and final frames) 69 | ''' 70 | assert pad_sides in (1, 2) 71 | # return int(fsize // 2) 72 | pad = (x.shape[0] // fshift + 1) * fshift - x.shape[0] 73 | if pad_sides == 1: 74 | return 0, pad 75 | else: 76 | return pad // 2, pad // 2 + pad % 2 77 | 78 | 79 | # Conversions 80 | _mel_basis = None 81 | _inv_mel_basis = None 82 | 83 | 84 | def _linear_to_mel(spectogram, hparams): 85 | global _mel_basis 86 | if _mel_basis is None: 87 | _mel_basis = _build_mel_basis(hparams) 88 | return np.dot(_mel_basis, spectogram) 89 | 90 | 91 | def _mel_to_linear(mel_spectrogram, hparams): 92 | global _inv_mel_basis 93 | if _inv_mel_basis is None: 94 | _inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams)) 95 | return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram)) 96 | 97 | 98 | def _build_mel_basis(hparams): 99 | assert hparams['fmax'] <= hparams['audio_sample_rate'] // 2 100 | return librosa.filters.mel(hparams['audio_sample_rate'], hparams['fft_size'], n_mels=hparams['audio_num_mel_bins'], 101 | fmin=hparams['fmin'], fmax=hparams['fmax']) 102 | 103 | 104 | def amp_to_db(x): 105 | return 20 * np.log10(np.maximum(1e-5, x)) 106 | 107 | 108 | def db_to_amp(x): 109 | return 10.0 ** (x * 0.05) 110 | 111 | 112 | def normalize(S, hparams): 113 | return (S - hparams['min_level_db']) / -hparams['min_level_db'] 114 | 115 | 116 | def denormalize(D, hparams): 117 | return (D * -hparams['min_level_db']) + hparams['min_level_db'] 118 
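# Typical chaining of the helpers above into a normalized log-mel spectrogram
# (illustrative only; `wav` and `hparams` are assumed inputs following the keys
# used throughout this module):
#
#   spec = np.abs(_stft(wav, hparams))        # linear magnitude spectrogram
#   mel = _linear_to_mel(spec, hparams)       # apply the mel filterbank
#   mel = normalize(amp_to_db(mel), hparams)  # to dB, then min_level_db scaling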
| 119 | 120 | #### torch audio 121 | 122 | 123 | def istft(amp, ang, hparams, pad=False, window=None): 124 | spec = amp * torch.exp(1j * ang) 125 | spec_r = spec.real 126 | spec_i = spec.imag 127 | spec = torch.stack([spec_r, spec_i], -1) 128 | if window is None: 129 | window = torch.hann_window(hparams['win_size']).to(amp.device) 130 | if pad: 131 | spec = F.pad(spec, [0, 0, 0, 1], mode='reflect') 132 | wav = torch.istft(spec, hparams['fft_size'], hparams['hop_size'], hparams['win_size']) 133 | return wav 134 | 135 | 136 | def griffin_lim_torch(amp, ang, hparams, n_iters=30): 137 | """ 138 | 139 | Examples: 140 | >>> x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size, win_length=win_length, pad_mode="constant") 141 | >>> x_stft = x_stft[None, ...] 142 | >>> amp = np.abs(x_stft) 143 | >>> angle_init = np.exp(2j * np.pi * np.random.rand(*x_stft.shape)) 144 | >>> amp = torch.FloatTensor(amp) 145 | >>> wav = griffin_lim_torch(amp, angle_init, hparams) 146 | 147 | :param amp: [B, n_fft, T] 148 | :param ang: [B, n_fft, T] 149 | :return: [B, T_wav] 150 | """ 151 | window = torch.hann_window(hparams['win_size']).to(amp.device) 152 | y = istft(amp, ang, hparams, window=window) 153 | for i in range(n_iters): 154 | x_stft = torch.stft(y, hparams['fft_size'], hparams['hop_size'], hparams['win_size'], window) 155 | x_stft = x_stft[..., 0] + 1j * x_stft[..., 1] 156 | ang = torch.angle(x_stft) 157 | y = istft(amp, ang, hparams, window=window) 158 | return y 159 | -------------------------------------------------------------------------------- /stable_codec/data/Text2Phone/utils/ckpt_utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import os 4 | import re 5 | import torch 6 | 7 | 8 | def get_last_checkpoint(work_dir, steps=None): 9 | checkpoint = None 10 | last_ckpt_path = None 11 | ckpt_paths = get_all_ckpts(work_dir, steps) 12 | if len(ckpt_paths) > 0: 13 | last_ckpt_path = ckpt_paths[0] 14 | checkpoint = torch.load(last_ckpt_path, map_location='cpu') 15 | logging.info(f'load module from checkpoint: {last_ckpt_path}') 16 | return checkpoint, last_ckpt_path 17 | 18 | 19 | def get_all_ckpts(work_dir, steps=None): 20 | if steps is None: 21 | ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt' 22 | else: 23 | ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt' 24 | return sorted(glob.glob(ckpt_path_pattern), 25 | key=lambda x: -int(re.findall('.*steps\_(\d+)\.ckpt', x)[0])) 26 | 27 | 28 | def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True): 29 | if os.path.isfile(ckpt_base_dir): 30 | base_dir = os.path.dirname(ckpt_base_dir) 31 | ckpt_path = ckpt_base_dir 32 | checkpoint = torch.load(ckpt_base_dir, map_location='cpu') 33 | else: 34 | base_dir = ckpt_base_dir 35 | checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir) 36 | if checkpoint is not None: 37 | state_dict = checkpoint["state_dict"] 38 | if len([k for k in state_dict.keys() if '.' in k]) > 0: 39 | state_dict = {k[len(model_name) + 1:]: v for k, v in state_dict.items() 40 | if k.startswith(f'{model_name}.')} 41 | else: 42 | if '.' 
not in model_name: 43 | state_dict = state_dict[model_name] 44 | else: 45 | base_model_name = model_name.split('.')[0] 46 | rest_model_name = model_name[len(base_model_name) + 1:] 47 | state_dict = { 48 | k[len(rest_model_name) + 1:]: v for k, v in state_dict[base_model_name].items() 49 | if k.startswith(f'{rest_model_name}.')} 50 | if not strict: 51 | cur_model_state_dict = cur_model.state_dict() 52 | unmatched_keys = [] 53 | for key, param in state_dict.items(): 54 | if key in cur_model_state_dict: 55 | new_param = cur_model_state_dict[key] 56 | if new_param.shape != param.shape: 57 | unmatched_keys.append(key) 58 | print("| Unmatched keys: ", key, new_param.shape, param.shape) 59 | for key in unmatched_keys: 60 | del state_dict[key] 61 | cur_model.load_state_dict(state_dict, strict=strict) 62 | print(f"| load '{model_name}' from '{ckpt_path}'.") 63 | else: 64 | e_msg = f"| ckpt not found in {base_dir}." 65 | if force: 66 | assert False, e_msg 67 | else: 68 | print(e_msg) 69 | -------------------------------------------------------------------------------- /stable_codec/data/Text2Phone/utils/pitch_utils.py: -------------------------------------------------------------------------------- 1 | ########## 2 | # world 3 | ########## 4 | import librosa 5 | import numpy as np 6 | import copy 7 | 8 | import torch 9 | 10 | gamma = 0 11 | mcepInput = 3 # 0 for dB, 3 for magnitude 12 | alpha = 0.45 13 | en_floor = 10 ** (-80 / 20) 14 | FFT_SIZE = 2048 15 | 16 | 17 | def code_harmonic(sp, order): 18 | import pysptk 19 | # get mcep 20 | mceps = np.apply_along_axis(pysptk.mcep, 1, sp, order - 1, alpha, itype=mcepInput, threshold=en_floor) 21 | 22 | # do fft and take real 23 | scale_mceps = copy.copy(mceps) 24 | scale_mceps[:, 0] *= 2 25 | scale_mceps[:, -1] *= 2 26 | mirror = np.hstack([scale_mceps[:, :-1], scale_mceps[:, -1:0:-1]]) 27 | mfsc = np.fft.rfft(mirror).real 28 | 29 | return mfsc 30 | 31 | 32 | def decode_harmonic(mfsc, fftlen=FFT_SIZE): 33 | import pysptk 34 | # get mcep back 35 | mceps_mirror = np.fft.irfft(mfsc) 36 | mceps_back = mceps_mirror[:, :60] 37 | mceps_back[:, 0] /= 2 38 | mceps_back[:, -1] /= 2 39 | 40 | # get sp 41 | spSm = np.exp(np.apply_along_axis(pysptk.mgc2sp, 1, mceps_back, alpha, gamma, fftlen=fftlen).real) 42 | 43 | return spSm 44 | 45 | 46 | def to_lf0(f0): 47 | f0[f0 < 1.0e-5] = 1.0e-6 48 | lf0 = f0.log() if isinstance(f0, torch.Tensor) else np.log(f0) 49 | lf0[f0 < 1.0e-5] = - 1.0E+10 50 | return lf0 51 | 52 | 53 | def to_f0(lf0): 54 | f0 = np.where(lf0 <= 0, 0.0, np.exp(lf0)) 55 | return f0.flatten() 56 | 57 | 58 | def formant_enhancement(coded_spectrogram, beta, fs): 59 | alpha_dict = { 60 | 8000: 0.31, 61 | 16000: 0.58, 62 | 22050: 0.65, 63 | 44100: 0.76, 64 | 48000: 0.77 65 | } 66 | alpha = alpha_dict[fs] 67 | datad = np.zeros((coded_spectrogram.shape[1],)) 68 | sp_dim = coded_spectrogram.shape[1] 69 | for i in range(coded_spectrogram.shape[0]): 70 | datad = mc2b(coded_spectrogram[i], datad, sp_dim - 1, alpha) 71 | datad[1] = datad[1] - alpha * beta * datad[2] 72 | for j in range(2, sp_dim): 73 | datad[j] *= 1 + beta 74 | coded_spectrogram[i] = b2mc(datad, coded_spectrogram[i], sp_dim - 1, alpha) 75 | return coded_spectrogram 76 | 77 | 78 | def mc2b(mc, b, m, a): 79 | """ 80 | Transform Mel Cepstrum to MLSA Digital Filter Coefficients 81 | 82 | void mc2b(mc, b, m, a) 83 | 84 | double *mc : mel cepstral coefficients 85 | double *b : MLSA digital filter coefficients 86 | int m : order of mel cepstrum 87 | double a : all-pass constant 88 | 89 | 
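    The coefficients are computed by the backward recursion
    b[m] = mc[m] and b[i] = mc[i] - a * b[i + 1] for i = m - 1, ..., 0.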
http://www.asel.udel.edu/icslp/cdrom/vol1/725/a725.pdf 90 | CELP coding system based on mel-generalized cepstral analysis 91 | :param mc: 92 | :param b: 93 | :param m: 94 | :param a: 95 | :return: 96 | """ 97 | b[m] = mc[m] 98 | for i in range(1, m + 1): 99 | b[m - i] = mc[m - i] - a * b[m - i + 1] 100 | return b 101 | 102 | 103 | def b2mc(b, mc, m, a): 104 | """ 105 | Transform MLSA Digital Filter Coefficients to Mel Cepstrum 106 | 107 | void b2mc(b, mc, m, a) 108 | 109 | double *b : MLSA digital filter coefficients 110 | double *mc : mel cepstral coefficients 111 | int m : order of mel cepstrum 112 | double a : all-pass constant 113 | 114 | http://www.asel.udel.edu/icslp/cdrom/vol1/725/a725.pdf 115 | CELP coding system based on mel-generalized cepstral analysis 116 | :param b: 117 | :param mc: 118 | :param m: 119 | :param a: 120 | :return: 121 | """ 122 | d = mc[m] = b[m] 123 | for i in range(1, m + 1): 124 | o = b[m - i] + a * d 125 | d = b[m - i] 126 | mc[m - i] = o 127 | return mc 128 | 129 | 130 | f0_bin = 256 131 | f0_max = 1100.0 132 | f0_min = 50.0 133 | f0_mel_min = 1127 * np.log(1 + f0_min / 700) 134 | f0_mel_max = 1127 * np.log(1 + f0_max / 700) 135 | 136 | 137 | def f0_to_coarse(f0): 138 | is_torch = isinstance(f0, torch.Tensor) 139 | f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700) 140 | f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1 141 | 142 | f0_mel[f0_mel <= 1] = 1 143 | f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1 144 | f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int) 145 | assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min(), f0.min(), f0.max()) 146 | return f0_coarse 147 | 148 | 149 | def norm_f0(f0, uv, hparams): 150 | is_torch = isinstance(f0, torch.Tensor) 151 | if hparams['pitch_norm'] == 'standard': 152 | f0 = (f0 - hparams['f0_mean']) / hparams['f0_std'] 153 | if hparams['pitch_norm'] == 'log': 154 | f0 = torch.log2(f0 + 1e-8) if is_torch else np.log2(f0 + 1e-8) 155 | if uv is not None and hparams['use_uv']: 156 | f0[uv > 0] = 0 157 | return f0 158 | 159 | 160 | def norm_interp_f0(f0, hparams): 161 | is_torch = isinstance(f0, torch.Tensor) 162 | if is_torch: 163 | device = f0.device 164 | f0 = f0.data.cpu().numpy() 165 | uv = f0 == 0 166 | f0 = norm_f0(f0, uv, hparams) 167 | if sum(uv) == len(f0): 168 | f0[uv] = 0 169 | elif sum(uv) > 0: 170 | f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv]) 171 | if is_torch: 172 | uv = torch.FloatTensor(uv) 173 | f0 = torch.FloatTensor(f0) 174 | f0 = f0.to(device) 175 | uv = uv.to(device) 176 | return f0, uv 177 | 178 | 179 | def denorm_f0(f0, uv, hparams, pitch_padding=None, min=None, max=None): 180 | is_torch = isinstance(f0, torch.Tensor) 181 | if hparams['pitch_norm'] == 'standard': 182 | f0 = f0 * hparams['f0_std'] + hparams['f0_mean'] 183 | if hparams['pitch_norm'] == 'log': 184 | f0 = 2 ** f0 185 | if min is None: 186 | min = 0 187 | if max is None: 188 | max = f0_max 189 | f0 = f0.clamp(min=min) if is_torch else np.clip(f0, min=min) 190 | f0 = f0.clamp(max=max) if is_torch else np.clip(f0, max=max) 191 | if uv is not None and hparams['use_uv']: 192 | f0[uv > 0] = 0 193 | if pitch_padding is not None: 194 | f0[pitch_padding] = 0 195 | return f0 196 | 197 | 198 | def pitchfeats(wav, sampling_rate, fft_size, hop_size, win_length, fmin, fmax): 199 | pitches, magnitudes = librosa.piptrack(wav, sampling_rate, 200 | n_fft=fft_size, win_length=win_length, 
hop_length=hop_size, 201 | fmin=fmin, fmax=fmax) 202 | pitches = pitches.T 203 | magnitudes = magnitudes.T 204 | assert pitches.shape == magnitudes.shape 205 | 206 | pitches = [pitches[i][find_f0(magnitudes[i])] for i, _ in enumerate(pitches)] 207 | 208 | return np.asarray(pitches) 209 | 210 | 211 | def find_f0(mags): 212 | tmp = 0 213 | mags = list(mags) 214 | for i, mag in enumerate(mags): 215 | if mag < tmp: 216 | # return i-1 217 | if tmp - mag > 2: 218 | # return i-1 219 | return mags.index(max(mags[0:i])) 220 | else: 221 | return 0 222 | else: 223 | tmp = mag 224 | return 0 225 | -------------------------------------------------------------------------------- /stable_codec/data/Text2Phone/utils/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | 3 | matplotlib.use('Agg') 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | 8 | LINE_COLORS = ['w', 'r', 'y', 'cyan', 'm', 'b', 'lime'] 9 | 10 | 11 | def spec_to_figure(spec, vmin=None, vmax=None, title=''): 12 | if isinstance(spec, torch.Tensor): 13 | spec = spec.cpu().numpy() 14 | fig = plt.figure(figsize=(12, 6)) 15 | plt.title(title) 16 | plt.pcolor(spec.T, vmin=vmin, vmax=vmax) 17 | return fig 18 | 19 | 20 | def spec_f0_to_figure(spec, f0s, figsize=None): 21 | max_y = spec.shape[1] 22 | if isinstance(spec, torch.Tensor): 23 | spec = spec.detach().cpu().numpy() 24 | f0s = {k: f0.detach().cpu().numpy() for k, f0 in f0s.items()} 25 | f0s = {k: f0 / 10 for k, f0 in f0s.items()} 26 | fig = plt.figure(figsize=(12, 6) if figsize is None else figsize) 27 | plt.pcolor(spec.T) 28 | for i, (k, f0) in enumerate(f0s.items()): 29 | plt.plot(f0.clip(0, max_y), label=k, c=LINE_COLORS[i], linewidth=1, alpha=0.8) 30 | plt.legend() 31 | return fig 32 | 33 | 34 | def dur_to_figure(dur_gt, dur_pred, txt, mels=None, vmin=-5.5, vmax=1): 35 | dur_gt = dur_gt.cpu().numpy() 36 | dur_pred = dur_pred.cpu().numpy() 37 | dur_gt = np.cumsum(dur_gt).astype(int) 38 | dur_pred = np.cumsum(dur_pred).astype(int) 39 | fig = plt.figure(figsize=(12, 6)) 40 | for i in range(len(dur_gt)): 41 | shift = (i % 8) + 1 42 | plt.text(dur_gt[i], shift * 4, txt[i]) 43 | plt.text(dur_pred[i], 40 + shift * 4, txt[i]) 44 | plt.vlines(dur_gt[i], 0, 40, colors='b') # blue is gt 45 | plt.vlines(dur_pred[i], 40, 80, colors='r') # red is pred 46 | plt.xlim(0, max(dur_gt[-1], dur_pred[-1])) 47 | if mels is not None: 48 | mels = mels.cpu().numpy() 49 | plt.pcolor(mels.T, vmin=vmin, vmax=vmax) 50 | return fig 51 | 52 | 53 | def f0_to_figure(f0_gt, f0_cwt=None, f0_pred=None): 54 | fig = plt.figure(figsize=(12, 8)) 55 | f0_gt = f0_gt.cpu().numpy() 56 | plt.plot(f0_gt, color='r', label='gt') 57 | if f0_cwt is not None: 58 | f0_cwt = f0_cwt.cpu().numpy() 59 | plt.plot(f0_cwt, color='b', label='cwt') 60 | if f0_pred is not None: 61 | f0_pred = f0_pred.cpu().numpy() 62 | plt.plot(f0_pred, color='green', label='pred') 63 | plt.legend() 64 | return fig 65 | -------------------------------------------------------------------------------- /stable_codec/data/Text2Phone/utils/text_encoder.py: -------------------------------------------------------------------------------- 1 | import re 2 | import six 3 | from six.moves import range # pylint: disable=redefined-builtin 4 | 5 | PAD = "" 6 | EOS = "" 7 | UNK = "" 8 | SEG = "|" 9 | RESERVED_TOKENS = [PAD, EOS, UNK] 10 | NUM_RESERVED_TOKENS = len(RESERVED_TOKENS) 11 | PAD_ID = RESERVED_TOKENS.index(PAD) # Normally 0 12 | EOS_ID = RESERVED_TOKENS.index(EOS) # Normally 1 13 
| UNK_ID = RESERVED_TOKENS.index(UNK) # Normally 2 14 | 15 | if six.PY2: 16 | RESERVED_TOKENS_BYTES = RESERVED_TOKENS 17 | else: 18 | RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")] 19 | 20 | # Regular expression for unescaping token strings. 21 | # '\u' is converted to '_' 22 | # '\\' is converted to '\' 23 | # '\213;' is converted to unichr(213) 24 | _UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);") 25 | _ESCAPE_CHARS = set(u"\\_u;0123456789") 26 | 27 | 28 | def strip_ids(ids, ids_to_strip): 29 | """Strip ids_to_strip from the end ids.""" 30 | ids = list(ids) 31 | while ids and ids[-1] in ids_to_strip: 32 | ids.pop() 33 | return ids 34 | 35 | 36 | class TextEncoder(object): 37 | """Base class for converting from ints to/from human readable strings.""" 38 | 39 | def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS): 40 | self._num_reserved_ids = num_reserved_ids 41 | 42 | @property 43 | def num_reserved_ids(self): 44 | return self._num_reserved_ids 45 | 46 | def encode(self, s): 47 | """Transform a human-readable string into a sequence of int ids. 48 | 49 | The ids should be in the range [num_reserved_ids, vocab_size). Ids [0, 50 | num_reserved_ids) are reserved. 51 | 52 | EOS is not appended. 53 | 54 | Args: 55 | s: human-readable string to be converted. 56 | 57 | Returns: 58 | ids: list of integers 59 | """ 60 | return [int(w) + self._num_reserved_ids for w in s.split()] 61 | 62 | def decode(self, ids, strip_extraneous=False): 63 | """Transform a sequence of int ids into a human-readable string. 64 | 65 | EOS is not expected in ids. 66 | 67 | Args: 68 | ids: list of integers to be converted. 69 | strip_extraneous: bool, whether to strip off extraneous tokens 70 | (EOS and PAD). 71 | 72 | Returns: 73 | s: human-readable string. 74 | """ 75 | if strip_extraneous: 76 | ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) 77 | return " ".join(self.decode_list(ids)) 78 | 79 | def decode_list(self, ids): 80 | """Transform a sequence of int ids into a their string versions. 81 | 82 | This method supports transforming individual input/output ids to their 83 | string versions so that sequence to/from text conversions can be visualized 84 | in a human readable format. 85 | 86 | Args: 87 | ids: list of integers to be converted. 88 | 89 | Returns: 90 | strs: list of human-readable string. 91 | """ 92 | decoded_ids = [] 93 | for id_ in ids: 94 | if 0 <= id_ < self._num_reserved_ids: 95 | decoded_ids.append(RESERVED_TOKENS[int(id_)]) 96 | else: 97 | decoded_ids.append(id_ - self._num_reserved_ids) 98 | return [str(d) for d in decoded_ids] 99 | 100 | @property 101 | def vocab_size(self): 102 | raise NotImplementedError() 103 | 104 | 105 | class ByteTextEncoder(TextEncoder): 106 | """Encodes each byte to an id. 
For 8-bit strings only.""" 107 | 108 | def encode(self, s): 109 | numres = self._num_reserved_ids 110 | if six.PY2: 111 | if isinstance(s, unicode): 112 | s = s.encode("utf-8") 113 | return [ord(c) + numres for c in s] 114 | # Python3: explicitly convert to UTF-8 115 | return [c + numres for c in s.encode("utf-8")] 116 | 117 | def decode(self, ids, strip_extraneous=False): 118 | if strip_extraneous: 119 | ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) 120 | numres = self._num_reserved_ids 121 | decoded_ids = [] 122 | int2byte = six.int2byte 123 | for id_ in ids: 124 | if 0 <= id_ < numres: 125 | decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)]) 126 | else: 127 | decoded_ids.append(int2byte(id_ - numres)) 128 | if six.PY2: 129 | return "".join(decoded_ids) 130 | # Python3: join byte arrays and then decode string 131 | return b"".join(decoded_ids).decode("utf-8", "replace") 132 | 133 | def decode_list(self, ids): 134 | numres = self._num_reserved_ids 135 | decoded_ids = [] 136 | int2byte = six.int2byte 137 | for id_ in ids: 138 | if 0 <= id_ < numres: 139 | decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)]) 140 | else: 141 | decoded_ids.append(int2byte(id_ - numres)) 142 | # Python3: join byte arrays and then decode string 143 | return decoded_ids 144 | 145 | @property 146 | def vocab_size(self): 147 | return 2**8 + self._num_reserved_ids 148 | 149 | 150 | class ByteTextEncoderWithEos(ByteTextEncoder): 151 | """Encodes each byte to an id and appends the EOS token.""" 152 | 153 | def encode(self, s): 154 | return super(ByteTextEncoderWithEos, self).encode(s) + [EOS_ID] 155 | 156 | 157 | class TokenTextEncoder(TextEncoder): 158 | """Encoder based on a user-supplied vocabulary (file or list).""" 159 | 160 | def __init__(self, 161 | vocab_filename, 162 | reverse=False, 163 | vocab_list=None, 164 | replace_oov=None, 165 | num_reserved_ids=NUM_RESERVED_TOKENS): 166 | """Initialize from a file or list, one token per line. 167 | 168 | Handling of reserved tokens works as follows: 169 | - When initializing from a list, we add reserved tokens to the vocab. 170 | - When initializing from a file, we do not add reserved tokens to the vocab. 171 | - When saving vocab files, we save reserved tokens to the file. 172 | 173 | Args: 174 | vocab_filename: If not None, the full filename to read vocab from. If this 175 | is not None, then vocab_list should be None. 176 | reverse: Boolean indicating if tokens should be reversed during encoding 177 | and decoding. 178 | vocab_list: If not None, a list of elements of the vocabulary. If this is 179 | not None, then vocab_filename should be None. 180 | replace_oov: If not None, every out-of-vocabulary token seen when 181 | encoding will be replaced by this string (which must be in vocab). 182 | num_reserved_ids: Number of IDs to save for reserved tokens like . 
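    Example (illustrative; the phone vocabulary below is made up):

        >>> enc = TokenTextEncoder(None, vocab_list=["AH0", "B", "K", "|"])
        >>> ids = enc.encode("B AH0 K")
        >>> enc.decode(ids)
        'B AH0 K'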
183 | """ 184 | super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids) 185 | self._reverse = reverse 186 | self._replace_oov = replace_oov 187 | if vocab_filename: 188 | self._init_vocab_from_file(vocab_filename) 189 | else: 190 | assert vocab_list is not None 191 | self._init_vocab_from_list(vocab_list) 192 | self.pad_index = self._token_to_id[PAD] 193 | self.eos_index = self._token_to_id[EOS] 194 | self.unk_index = self._token_to_id[UNK] 195 | self.seg_index = self._token_to_id[SEG] if SEG in self._token_to_id else self.eos_index 196 | 197 | def encode(self, s): 198 | """Converts a space-separated string of tokens to a list of ids.""" 199 | sentence = s 200 | tokens = sentence.strip().split() 201 | if self._replace_oov is not None: 202 | tokens = [t if t in self._token_to_id else self._replace_oov 203 | for t in tokens] 204 | ret = [self._token_to_id[tok] for tok in tokens] 205 | return ret[::-1] if self._reverse else ret 206 | 207 | def decode(self, ids, strip_eos=False, strip_padding=False): 208 | if strip_padding and self.pad() in list(ids): 209 | pad_pos = list(ids).index(self.pad()) 210 | ids = ids[:pad_pos] 211 | if strip_eos and self.eos() in list(ids): 212 | eos_pos = list(ids).index(self.eos()) 213 | ids = ids[:eos_pos] 214 | return " ".join(self.decode_list(ids)) 215 | 216 | def decode_list(self, ids): 217 | seq = reversed(ids) if self._reverse else ids 218 | return [self._safe_id_to_token(i) for i in seq] 219 | 220 | @property 221 | def vocab_size(self): 222 | return len(self._id_to_token) 223 | 224 | def __len__(self): 225 | return self.vocab_size 226 | 227 | def _safe_id_to_token(self, idx): 228 | return self._id_to_token.get(idx, "ID_%d" % idx) 229 | 230 | def _init_vocab_from_file(self, filename): 231 | """Load vocab from a file. 232 | 233 | Args: 234 | filename: The file to load vocabulary from. 235 | """ 236 | with open(filename) as f: 237 | tokens = [token.strip() for token in f.readlines()] 238 | 239 | def token_gen(): 240 | for token in tokens: 241 | yield token 242 | 243 | self._init_vocab(token_gen(), add_reserved_tokens=False) 244 | 245 | def _init_vocab_from_list(self, vocab_list): 246 | """Initialize tokens from a list of tokens. 247 | 248 | It is ok if reserved tokens appear in the vocab list. They will be 249 | removed. The set of tokens in vocab_list should be unique. 250 | 251 | Args: 252 | vocab_list: A list of tokens. 253 | """ 254 | def token_gen(): 255 | for token in vocab_list: 256 | if token not in RESERVED_TOKENS: 257 | yield token 258 | 259 | self._init_vocab(token_gen()) 260 | 261 | def _init_vocab(self, token_generator, add_reserved_tokens=True): 262 | """Initialize vocabulary with tokens from token_generator.""" 263 | 264 | self._id_to_token = {} 265 | non_reserved_start_index = 0 266 | 267 | if add_reserved_tokens: 268 | self._id_to_token.update(enumerate(RESERVED_TOKENS)) 269 | non_reserved_start_index = len(RESERVED_TOKENS) 270 | 271 | self._id_to_token.update( 272 | enumerate(token_generator, start=non_reserved_start_index)) 273 | 274 | # _token_to_id is the reverse of _id_to_token 275 | self._token_to_id = dict((v, k) 276 | for k, v in six.iteritems(self._id_to_token)) 277 | 278 | def pad(self): 279 | return self.pad_index 280 | 281 | def eos(self): 282 | return self.eos_index 283 | 284 | def unk(self): 285 | return self.unk_index 286 | 287 | def seg(self): 288 | return self.seg_index 289 | 290 | def store_to_file(self, filename): 291 | """Write vocab file to disk. 292 | 293 | Vocab files have one token per line. 
The file ends in a newline. Reserved 294 | tokens are written to the vocab file as well. 295 | 296 | Args: 297 | filename: Full path of the file to store the vocab to. 298 | """ 299 | with open(filename, "w") as f: 300 | for i in range(len(self._id_to_token)): 301 | f.write(self._id_to_token[i] + "\n") 302 | 303 | def sil_phonemes(self): 304 | return [p for p in self._id_to_token.values() if not p[0].isalpha()] 305 | -------------------------------------------------------------------------------- /stable_codec/data/dataset.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import random 3 | import torch 4 | import torchaudio 5 | import webdataset as wds 6 | 7 | from typing import List 8 | from torchaudio import transforms as T 9 | 10 | from stable_audio_tools.data.dataset import ( 11 | S3DatasetConfig, LocalWebDatasetConfig, log_and_continue, audio_decoder, npy_decoder, 12 | is_valid_sample, collation_fn, AUDIO_KEYS, remove_long_silence, 13 | ) 14 | from stable_audio_tools.data.utils import ( 15 | Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T, VolumeNorm, 16 | ) 17 | 18 | from .Text2Phone.Text2PhoneTokenizer import Text2PhoneTokenizer 19 | 20 | class WebDatasetDataLoader(): 21 | def __init__( 22 | self, 23 | datasets: List[S3DatasetConfig], 24 | batch_size, 25 | sample_size, 26 | sample_rate=48000, 27 | num_workers=8, 28 | epoch_steps=1000, 29 | random_crop=True, 30 | force_channels="stereo", 31 | augment_phase=True, 32 | remove_silence=True, 33 | silence_threshold=[0.01, 0.5], 34 | max_silence_duration=0.2, 35 | volume_norm=False, 36 | volume_norm_param=(-16, 2), 37 | pre_encoded=False, 38 | resampled_shards=True, 39 | force_align_text=False, 40 | **data_loader_kwargs 41 | ): 42 | 43 | self.datasets = datasets 44 | 45 | self.sample_size = sample_size 46 | self.sample_rate = sample_rate 47 | self.random_crop = random_crop 48 | self.force_channels = force_channels 49 | self.augment_phase = augment_phase 50 | self.pre_encoded = pre_encoded 51 | self.volume_norm = volume_norm 52 | self.volume_norm_param = volume_norm_param 53 | self.remove_silence = remove_silence 54 | self.silence_threshold = silence_threshold 55 | self.max_silence_duration = max_silence_duration 56 | 57 | self.force_align_text = force_align_text 58 | self.phonemizer = Text2PhoneTokenizer() 59 | 60 | urls = [dataset.load_data_urls() for dataset in datasets] 61 | 62 | # Flatten the list of lists of URLs 63 | urls = [url for dataset_urls in urls for url in dataset_urls] 64 | 65 | # Shuffle the urls 66 | random.shuffle(urls) 67 | 68 | self.dataset = wds.DataPipeline( 69 | wds.ResampledShards(urls) if resampled_shards else wds.SimpleShardList(urls), 70 | wds.tarfile_to_samples(handler=log_and_continue), 71 | wds.decode(audio_decoder, handler=log_and_continue) if not self.pre_encoded else wds.decode(npy_decoder, handler=log_and_continue), 72 | wds.map(self.wds_preprocess, handler=log_and_continue), 73 | wds.select(is_valid_sample), 74 | wds.to_tuple("audio", "json", handler=log_and_continue), 75 | #wds.shuffle(bufsize=1000, initial=5000), 76 | wds.batched(batch_size, partial=False, collation_fn=collation_fn), 77 | ) 78 | 79 | if resampled_shards: 80 | self.dataset = self.dataset.with_epoch(epoch_steps//num_workers if num_workers > 0 else epoch_steps) 81 | 82 | self.data_loader = wds.WebLoader(self.dataset, num_workers=num_workers, **data_loader_kwargs) 83 | 84 | def wds_preprocess(self, sample): 85 | 86 | if self.pre_encoded: 87 | audio = 
torch.from_numpy(sample["npy"]) 88 | del sample["npy"] 89 | sample["__pre_encoded__"] = True 90 | sample["json"]["padding_mask"] = torch.tensor(sample["json"]["padding_mask"]) 91 | else: 92 | found_key, rewrite_key = '', '' 93 | for k, v in sample.items(): # print the all entries in dict 94 | for akey in AUDIO_KEYS: 95 | if k.endswith(akey): 96 | # to rename long/weird key with its simpler counterpart 97 | found_key, rewrite_key = k, akey 98 | break 99 | if '' != found_key: 100 | break 101 | if '' == found_key: # got no audio! 102 | return None # try returning None to tell WebDataset to skip this one 103 | 104 | audio, in_sr = sample[found_key] 105 | if in_sr != self.sample_rate: 106 | resample_tf = T.Resample(in_sr, self.sample_rate) 107 | audio = resample_tf(audio) 108 | 109 | # Replace the long silence by the short for the mono audios 110 | if audio.shape[0] == 1 and self.remove_silence: 111 | audio = remove_long_silence(audio, self.sample_rate, self.silence_threshold, self.max_silence_duration) 112 | 113 | original_length = audio.shape[-1] 114 | 115 | if self.sample_size is not None: 116 | # Pad/crop and get the relative timestamp 117 | pad_crop = PadCrop_Normalized_T( 118 | self.sample_size, randomize=self.random_crop, sample_rate=self.sample_rate) 119 | audio, t_start, t_end, seconds_start, seconds_total, padding_mask = pad_crop( 120 | audio) 121 | sample["json"]["seconds_start"] = seconds_start 122 | sample["json"]["seconds_total"] = seconds_total 123 | sample["json"]["padding_mask"] = padding_mask 124 | else: 125 | t_start, t_end = 0, 1 126 | 127 | start_time = (original_length * t_start) / self.sample_rate 128 | end_time = (original_length * t_end) / self.sample_rate 129 | 130 | # Check if audio is length zero, initialize to a single zero if so 131 | if audio.shape[-1] == 0: 132 | audio = torch.zeros(1, 1) 133 | 134 | # Make the audio stereo and augment by randomly inverting phase 135 | augs = torch.nn.Sequential( 136 | Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), 137 | Mono() if self.force_channels == "mono" else torch.nn.Identity(), 138 | VolumeNorm(self.volume_norm_param, self.sample_rate) if self.volume_norm else torch.nn.Identity(), 139 | PhaseFlipper() if self.augment_phase else torch.nn.Identity() 140 | ) 141 | 142 | audio = augs(audio) 143 | 144 | sample["json"]["timestamps"] = (t_start, t_end) 145 | 146 | if found_key != rewrite_key: # rename long/weird key with its simpler counterpart 147 | del sample[found_key] 148 | 149 | if "text" in sample["json"]: 150 | sample["json"]["prompt"] = sample["json"]["text"] 151 | 152 | # Check for custom metadata functions 153 | for dataset in self.datasets: 154 | if dataset.custom_metadata_fn is None: 155 | continue 156 | 157 | if dataset.path in sample["__url__"]: 158 | custom_metadata = dataset.custom_metadata_fn(sample["json"], audio) 159 | sample["json"].update(custom_metadata) 160 | 161 | sample["audio"] = audio 162 | # Add audio to the metadata as well for conditioning 163 | sample["json"]["audio"] = audio 164 | 165 | if self.force_align_text and self.sample_size is not None: 166 | # Chunk the original transcriptions according to (start_time, end_time) 167 | chunked_text_list = [] 168 | for entry in sample["json"]['force_aligned_text']['transcript']: 169 | word_start = entry['start'] 170 | word_end = entry['end'] 171 | # Check if the word's start or end time falls within the time range 172 | if (word_start >= start_time and word_start <= end_time) or (word_end >= start_time and word_end <= end_time): 173 | 
chunked_text_list.append(entry['word']) 174 | 175 | chunked_text = ' '.join(chunked_text_list) 176 | chunked_phone = self.phonemizer.tokenize(chunked_text) 177 | 178 | sample["json"]["phone"] = chunked_phone 179 | sample["json"]["aligned_text"] = chunked_text 180 | 181 | return sample 182 | 183 | def create_dataloader_from_config(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4, shuffle = True): 184 | dataset_type = dataset_config.get("dataset_type", None) 185 | assert dataset_type is not None, "Dataset type must be specified in dataset config" 186 | assert dataset_type in ("s3", "wds") 187 | 188 | force_channels = "mono" if audio_channels == 1 else "stereo" 189 | 190 | wds_configs = [] 191 | for wds_config in dataset_config["datasets"]: 192 | custom_metadata_fn = None 193 | custom_metadata_module_path = wds_config.get("custom_metadata_module", None) 194 | 195 | if custom_metadata_module_path is not None: 196 | spec = importlib.util.spec_from_file_location( 197 | "metadata_module", custom_metadata_module_path) 198 | metadata_module = importlib.util.module_from_spec(spec) 199 | spec.loader.exec_module(metadata_module) 200 | custom_metadata_fn = metadata_module.get_custom_metadata 201 | 202 | if "s3_path" in wds_config: 203 | wds_configs.append(S3DatasetConfig( 204 | id=wds_config["id"], 205 | s3_path=wds_config["s3_path"], 206 | custom_metadata_fn=custom_metadata_fn, 207 | profile=wds_config.get("profile", None), 208 | )) 209 | elif "path" in wds_config: 210 | wds_configs.append(LocalWebDatasetConfig( 211 | id=wds_config["id"], 212 | path=wds_config["path"], 213 | custom_metadata_fn=custom_metadata_fn 214 | )) 215 | 216 | return WebDatasetDataLoader( 217 | wds_configs, 218 | sample_rate=sample_rate, 219 | sample_size=sample_size, 220 | batch_size=batch_size, 221 | remove_silence=dataset_config.get("remove_silence", False), 222 | silence_threshold=dataset_config.get("silence_threshold", [0.01, 0.5]), 223 | max_silence_duration=dataset_config.get("max_silence_duration", 0.25), 224 | random_crop=dataset_config.get("random_crop", True), 225 | volume_norm=dataset_config.get("volume_norm", False), 226 | volume_norm_param=dataset_config.get("volume_norm_param", [-16, 2]), 227 | num_workers=num_workers, 228 | persistent_workers=True, 229 | pin_memory=True, 230 | force_channels=force_channels, 231 | epoch_steps=dataset_config.get("epoch_steps", 2000), 232 | pre_encoded=dataset_config.get("pre_encoded", False), 233 | resampled_shards=dataset_config.get("resampled_shards", True), 234 | force_align_text=dataset_config.get("force_align_text", False) 235 | ).data_loader 236 | -------------------------------------------------------------------------------- /stable_codec/fsq.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dithered Finite Scalar Quantization 3 | Code adapted from https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/finite_scalar_quantization.py 4 | """ 5 | 6 | from typing import List, Tuple 7 | import random 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import Module 12 | from torch import Tensor, int32 13 | from torch.amp import autocast 14 | 15 | from einops import rearrange 16 | 17 | 18 | def leaky_hard_clip(x: Tensor, alpha: float = 1e-3) -> Tensor: 19 | return (1-alpha) * torch.clamp(x, -1, 1) + alpha * x 20 | 21 | def round_ste(z: Tensor) -> Tensor: 22 | """Round with straight through gradients.""" 23 | zhat = z.round() 24 | return z + (zhat - 
z).detach() 25 | 26 | class DitheredFSQ(Module): 27 | def __init__( 28 | self, 29 | levels: List[int], 30 | dither_inference: bool = False, 31 | num_codebooks: int = 1, 32 | noise_dropout: float = 0.5, 33 | scale: float = 1.0, 34 | ): 35 | super().__init__() 36 | self.levels = levels 37 | 38 | _levels = torch.tensor(levels, dtype=torch.int64) 39 | self.register_buffer("_levels", _levels, persistent = False) 40 | 41 | _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int64) 42 | self.register_buffer("_basis", _basis, persistent = False) 43 | 44 | codebook_dim = len(levels) 45 | self.codebook_dim = codebook_dim 46 | 47 | self.codebook_size = _levels.prod().item() 48 | 49 | self.num_codebooks = num_codebooks 50 | 51 | self.dim = codebook_dim * num_codebooks 52 | 53 | self.dither_inference = dither_inference 54 | 55 | self.scale = scale 56 | 57 | half_l = self.scale * 2 / (self._levels - 1) 58 | self.register_buffer("half_l", half_l, persistent = False) 59 | 60 | self.allowed_dtypes = (torch.float32, torch.float64) 61 | 62 | self.noise_dropout = noise_dropout 63 | 64 | def quantize(self, z, skip_tanh: bool = False): 65 | if not skip_tanh: z = torch.tanh(z) 66 | 67 | if not self.training: 68 | quantized = self._scale_and_shift_inverse(round_ste(self._scale_and_shift(z))) 69 | else: 70 | quantized = z 71 | mask = torch.bernoulli(torch.full([z.shape[0],1,1,1], self.noise_dropout, device = z.device)).bool().expand_as(z) 72 | quantized = torch.where(mask, quantized, self._scale_and_shift_inverse(round_ste(self._scale_and_shift(quantized)))) 73 | mask = torch.bernoulli(torch.full([z.shape[0],1,1,1], self.noise_dropout, device = z.device)).bool().expand_as(z) 74 | quantized = torch.where(mask, quantized, z + (torch.rand_like(z) - 0.5) * self.half_l) 75 | 76 | return quantized 77 | 78 | def _scale_and_shift(self, z): 79 | level_indices = (z + 1 * self.scale) / self.half_l 80 | return level_indices 81 | 82 | def _scale_and_shift_inverse(self, level_indices): 83 | z = level_indices * self.half_l - 1 * self.scale 84 | return z 85 | 86 | def _indices_to_codes(self, indices): 87 | level_indices = self._indices_to_level_indices(indices) 88 | codes = self._scale_and_shift_inverse(level_indices) 89 | return codes 90 | 91 | def _codes_to_indices(self, zhat): 92 | zhat = self._scale_and_shift(zhat) 93 | zhat = zhat.round().to(torch.int64) 94 | out = (zhat * self._basis).sum(dim=-1) 95 | return out 96 | 97 | def _indices_to_level_indices(self, indices): 98 | indices = rearrange(indices, '... -> ... 1') 99 | codes_non_centered = (indices // self._basis) % self._levels 100 | return codes_non_centered 101 | 102 | def indices_to_codes(self, indices): 103 | # Expects input of batch x sequence x num_codebooks 104 | assert indices.shape[-1] == self.num_codebooks, f'expected last dimension of {self.num_codebooks} but found last dimension of {indices.shape[-1]}' 105 | codes = self._indices_to_codes(indices.to(torch.int64)) 106 | codes = rearrange(codes, '... c d -> ... 
(c d)') 107 | return codes 108 | 109 | @autocast(device_type="cuda", enabled = False) 110 | def forward(self, z, skip_tanh: bool = False): 111 | 112 | orig_dtype = z.dtype 113 | 114 | assert z.shape[-1] == self.dim, f'expected dimension of {self.num_codebooks * self.dim} but found dimension of {z.shape[-1]}' 115 | 116 | z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks) 117 | 118 | # make sure allowed dtype before quantizing 119 | 120 | if z.dtype not in self.allowed_dtypes: 121 | z = z.to(torch.float64) 122 | 123 | codes = self.quantize(z, skip_tanh=skip_tanh) 124 | indices = self._codes_to_indices(codes) 125 | codes = rearrange(codes, 'b n c d -> b n (c d)') 126 | 127 | # cast codes back to original dtype 128 | 129 | if codes.dtype != orig_dtype: 130 | codes = codes.type(orig_dtype) 131 | 132 | # return quantized output and indices 133 | 134 | return codes, indices 135 | -------------------------------------------------------------------------------- /stable_codec/model.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import torch.nn as nn 4 | import torchaudio 5 | 6 | from typing import Optional, List, Tuple, Union 7 | from einops import rearrange 8 | from stable_audio_tools.models import create_model_from_config 9 | from stable_audio_tools.models.fsq import DitheredFSQ 10 | from stable_audio_tools.models.utils import load_ckpt_state_dict 11 | from stable_audio_tools.training.utils import copy_state_dict 12 | from stable_audio_tools.data.utils import VolumeNorm 13 | 14 | from .residual_fsq import ResidualFSQBottleneck 15 | from stable_audio_tools import get_pretrained_model 16 | 17 | class StableCodec(nn.Module): 18 | def __init__(self, 19 | model_config_path: Optional[str] = None, ckpt_path: Optional[str] = None, pretrained_model: Optional[str] = None, device = torch.device("cpu"), 20 | ): 21 | super().__init__() 22 | self.device = device 23 | 24 | if pretrained_model is not None: 25 | print(f"Loading pretrained model `{pretrained_model}`.\n") 26 | self.model, model_config = get_pretrained_model(pretrained_model) 27 | else: 28 | if model_config_path is None: 29 | raise ValueError("Either `model_config_path` or `pretrained_model` should be provided.") 30 | print(f"Loading config from `{model_config_path}`.\n") 31 | with open(model_config_path) as f: 32 | model_config = json.load(f) 33 | self.model = create_model_from_config(model_config) 34 | if ckpt_path is not None: 35 | print(f"Loading weights from `{ckpt_path}`.\n") 36 | state = load_ckpt_state_dict(ckpt_path) 37 | copy_state_dict(self.model, state) 38 | 39 | self.model = self.model.to(self.device).eval().requires_grad_(False) 40 | 41 | self.residual_fsq: Optional[ResidualFSQBottleneck] = None 42 | 43 | self.sample_rate = model_config["sample_rate"] 44 | self.volume_norm = VolumeNorm([-20, 0], self.sample_rate) 45 | 46 | self.preset_bottleneck_configs = { 47 | "1x46656_400bps": [ 48 | ([6, 6, 6, 6, 6, 6], 1.0) 49 | ], 50 | "2x15625_700bps": [ 51 | ([5, 5, 5, 5, 5, 5], 1.0), 52 | ([5, 5, 5, 5, 5, 5], 0.25), 53 | ], 54 | "4x729_1000bps": [ 55 | ([3, 3, 3, 3, 3, 3], 1.0), 56 | ([3, 3, 3, 3, 3, 3], 0.5), 57 | ([3, 3, 3, 3, 3, 3], 0.25), 58 | ([3, 3, 3, 3, 3, 3], 0.125), 59 | ] 60 | } 61 | 62 | def set_posthoc_bottleneck(self, stages): 63 | if isinstance(stages,str): 64 | if stages in self.preset_bottleneck_configs: 65 | stages = self.preset_bottleneck_configs[stages] 66 | else: 67 | raise ValueError(f"Unsupported preset bottleneck configuration 
`{stages}`.")
68 |
69 | self.residual_fsq = ResidualFSQBottleneck(stages).to(self.device).eval().requires_grad_(False)
70 |
71 | def encode(self, audio: Union[str, torch.Tensor], posthoc_bottleneck: bool = False, normalize: bool = True, **kwargs):
72 | """
73 | Encode audio into latents and tokens.
74 |
75 | Args:
76 |
77 | audio : Union[str, torch.Tensor]
78 | Path to an audio file or a `Tensor` of the audio itself.
79 | posthoc_bottleneck : bool
80 | Whether to inject a posthoc FSQ instead of the FSQ used during training.
81 | If `True`, its configuration must have been set beforehand with the `self.set_posthoc_bottleneck` method.
82 | normalize : bool
83 | Whether to normalize the audio to -20 LUFS before encoding (recommended).
84 | Remaining `kwargs` are forwarded to the `AudioAutoencoder.encode_audio` method.
85 |
86 | Returns:
87 |
88 | Tuple of `(continuous_latents, tokens)`.
89 |
90 | continuous_latents : torch.Tensor
91 | Pre-bottleneck latents in the `(B, H, S)` shape.
92 | tokens : torch.Tensor
93 | Bottleneck tokens in the `(B, S, 1)` shape.
94 |
95 | Where `B` is the batch size, `H` is the hidden dimension and `S` is the sequence length.
96 | """
97 | if isinstance(audio, str):
98 | audio, sample_rate = torchaudio.load(audio)
99 | audio = self.model.preprocess_audio_for_encoder(audio.to(self.device), sample_rate)
100 | if normalize:
101 | audio = self.volume_norm(audio.squeeze(0)).unsqueeze(0)
102 |
103 | latents, info = self.model.encode_audio(audio,
104 | return_info=True, skip_bottleneck=posthoc_bottleneck, **kwargs)
105 | if posthoc_bottleneck:
106 | tokens = self.residual_fsq.encode(latents)
107 | else:
108 | tokens = info["quantizer_indices"]
109 |
110 | return info["pre_bottleneck_latents"], tokens
111 |
112 | def decode(self, tokens: torch.Tensor, posthoc_bottleneck: bool = False, **kwargs):
113 | """
114 | Decode audio from tokens.
115 |
116 | Args:
117 |
118 | tokens : torch.Tensor
119 | Integer tokens produced by the `encode` stage, in `(B, S, 1)` shape.
120 | posthoc_bottleneck : bool
121 | Whether to inject a posthoc FSQ instead of the FSQ used during training.
122 | If `True`, its configuration must have been set beforehand with the `self.set_posthoc_bottleneck` method.
123 |
124 | Returns:
125 |
126 | Decoded audio in the `(B, C, L)` shape.
127 | Where `B` is the batch size, `C` is the number of channels and `L` is the number of frames.
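        Example (illustrative sketch; `codec` is an already-constructed
        `StableCodec` instance, the wav path is a placeholder, and
        `set_posthoc_bottleneck` must have been called before using the
        posthoc path, as in `main()` below):

            >>> latents, tokens = codec.encode("speech.wav", posthoc_bottleneck=True)
            >>> decoded = codec.decode(tokens, posthoc_bottleneck=True)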
128 | """ 129 | if posthoc_bottleneck: 130 | latents = self.residual_fsq.decode(tokens) 131 | else: 132 | latents = self.model.bottleneck.decode_tokens(tokens) 133 | latents = rearrange(latents, "b c n -> b n c") 134 | 135 | audio = self.model.decode_audio(latents, **kwargs) 136 | return audio 137 | 138 | def main(): 139 | sc = StableCodec( 140 | pretrained_model="stabilityai/stable-codec-speech-16k", 141 | device = torch.device("cuda") 142 | ) 143 | 144 | sc.set_posthoc_bottleneck("2x15625_700bps") 145 | 146 | wavfile = "test.wav" 147 | 148 | posthoc_bottleneck = False 149 | latents, tokens = sc.encode(wavfile, posthoc_bottleneck=posthoc_bottleneck) 150 | decoded = sc.decode(tokens, posthoc_bottleneck=posthoc_bottleneck) 151 | torchaudio.save("decode.wav", decoded.squeeze(0).cpu(), 16000) 152 | 153 | posthoc_bottleneck = True 154 | latents, tokens = sc.encode(wavfile, posthoc_bottleneck=posthoc_bottleneck) 155 | decoded = sc.decode(tokens, posthoc_bottleneck=posthoc_bottleneck) 156 | torchaudio.save("decode-res.wav", decoded.squeeze(0).cpu(), 16000) 157 | 158 | if __name__ == "__main__": 159 | main() 160 | -------------------------------------------------------------------------------- /stable_codec/residual_fsq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from typing import List, Tuple 5 | from einops import rearrange 6 | from .fsq import DitheredFSQ 7 | 8 | class ResidualFSQBottleneck(nn.Module): 9 | def __init__(self, stages: List[Tuple[List[int], float]]): 10 | super().__init__() 11 | 12 | # 1st for single_tokens, others - residuals. 13 | self.quantizers = nn.ModuleList([ 14 | DitheredFSQ(levels=levels, scale=scale).eval().requires_grad_(False) 15 | for (levels, scale) in stages]) 16 | 17 | self.n_codebooks = len(stages) 18 | self.codebook_size = sum(map(len, stages)) * self.n_codebooks 19 | 20 | def encode(self, x): 21 | input_dtype = x.dtype 22 | z = torch.tanh(x.to(torch.float64)) 23 | z = rearrange(z, "b c n -> b n c") 24 | 25 | r = z 26 | res_ids = [] 27 | for quantizer in self.quantizers: 28 | q, ids = quantizer(r, skip_tanh=True) 29 | r = r - q.to(torch.float64) 30 | res_ids.append(ids) 31 | 32 | return res_ids 33 | 34 | def decode(self, res_ids): 35 | z = sum([ 36 | q.indices_to_codes(res_ids[i]) 37 | for (i, q) in enumerate(self.quantizers) 38 | ]) 39 | return rearrange(z, "b n c -> b c n") 40 | 41 | if __name__ == "__main__": 42 | fsq = DitheredFSQ([17, 17, 17, 17, 17, 17]).eval().requires_grad_(False) 43 | # res_fsq = ResidualFSQBottleneck(stages=[ 44 | # ([5, 5, 5, 5, 5, 5], 1.0), 45 | # ([5, 5, 5, 5, 5, 5], 0.25), 46 | # ]).eval().requires_grad_(False) 47 | res_fsq = ResidualFSQBottleneck(stages=[ 48 | ([3, 3, 3, 3, 3, 3], 1.0), 49 | ([3, 3, 3, 3, 3, 3], 0.5), 50 | ([3, 3, 3, 3, 3, 3], 0.25), 51 | ([3, 3, 3, 3, 3, 3], 0.125), 52 | ]).eval().requires_grad_(False) 53 | 54 | x = torch.rand(1, 6, 1) 55 | 56 | z1 = res_fsq.decode(res_fsq.encode(x)) 57 | 58 | _, y2 = fsq(rearrange(x, "b c n -> b n c")) 59 | z2 = rearrange(fsq.indices_to_codes(y2), "b n c -> b c n") 60 | 61 | print(z1) 62 | print(z2) 63 | assert (z1 == z2).all() 64 | -------------------------------------------------------------------------------- /stable_codec/training_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchaudio 4 | import pytorch_lightning as pl 5 | 6 | from einops import rearrange 7 | from 
pytorch_lightning.utilities.rank_zero import rank_zero_only 8 | 9 | from stable_audio_tools.models.autoencoders import ( 10 | fold_channels_into_batch, unfold_channels_from_batch, 11 | ) 12 | from stable_audio_tools.training.utils import ( 13 | log_image, log_point_cloud, logger_project_name, log_audio, 14 | ) 15 | from stable_audio_tools.interface.aeiou import ( 16 | audio_spectrogram_image, tokens_spectrogram_image, 17 | ) 18 | 19 | def trim_to_shortest(a, b): 20 | """Trim the longer of two tensors to the length of the shorter one.""" 21 | if a.shape[-1] > b.shape[-1]: 22 | return a[:,:,:b.shape[-1]], b 23 | elif b.shape[-1] > a.shape[-1]: 24 | return a, b[:,:,:a.shape[-1]] 25 | return a, b 26 | 27 | class AutoencoderDemoCallback(pl.Callback): 28 | def __init__( 29 | self, 30 | demo_dl, 31 | demo_every = 2000, 32 | sample_size = 65536, 33 | sample_rate = 16000, 34 | max_demos = 8, 35 | ): 36 | super().__init__() 37 | self.demo_every = demo_every 38 | self.demo_samples = sample_size 39 | self.demo_dl = demo_dl 40 | self.sample_rate = sample_rate 41 | self.last_demo_step = -1 42 | self.max_demos = max_demos 43 | 44 | @rank_zero_only 45 | def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): 46 | if ( 47 | (trainer.global_step - 1) % self.demo_every != 0 or 48 | self.last_demo_step == trainer.global_step 49 | ): 50 | return 51 | 52 | self.last_demo_step = trainer.global_step 53 | module.eval() 54 | 55 | try: 56 | demo_iter = iter(self.demo_dl) 57 | demo_reals, _ = next(demo_iter) 58 | 59 | # Remove extra dimension added by WebDataset 60 | if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: 61 | demo_reals = demo_reals[0] 62 | 63 | # Limit the number of demo samples 64 | if demo_reals.shape[0] > self.max_demos: 65 | demo_reals = demo_reals[:self.max_demos,...] 66 | 67 | encoder_input = demo_reals 68 | encoder_input = encoder_input.to(module.device) 69 | 70 | if module.force_input_mono: 71 | encoder_input = encoder_input.mean(dim=1, keepdim=True) 72 | 73 | demo_reals = demo_reals.to(module.device) 74 | 75 | with torch.no_grad(): 76 | if module.use_ema: 77 | latents = module.autoencoder_ema.ema_model.encode(encoder_input) 78 | fakes = module.autoencoder_ema.ema_model.decode(latents) 79 | else: 80 | latents = module.autoencoder.encode(encoder_input) 81 | fakes = module.autoencoder.decode(latents) 82 | 83 | #Trim output to remove post-padding. 84 | fakes, demo_reals = trim_to_shortest(fakes.detach(), demo_reals) 85 | 86 | # Visualize discriminator sensitivity. 
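            # The real and fake audio are converted to STFTs and resynthesised with iSTFT before
            # being passed through the discriminator loss. Backpropagating that loss onto the
            # fake STFT (which has `requires_grad` enabled) yields per-bin gradient magnitudes,
            # which are logged as a spectrogram-style image showing the time-frequency regions
            # the discriminator reacts to most strongly. The optimizer gradients are zeroed
            # afterwards, since this backward pass exists purely for visualization.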
87 |             if module.discriminator is not None:
88 |                 window = torch.kaiser_window(512).to(fakes.device)
89 |                 stft_kwargs = {
90 |                     "n_fft": 512,
91 |                     "hop_length": 128,
92 |                     "win_length": 512,
93 |                     "window": window,
94 |                     "center": True,
95 |                 }
96 | 
97 |                 fakes_stft = torch.stft(fold_channels_into_batch(fakes),
98 |                     return_complex=True, **stft_kwargs)
99 |                 fakes_stft.requires_grad = True
100 |                 fakes_signal = unfold_channels_from_batch(
101 |                     torch.istft(fakes_stft, **stft_kwargs), fakes.shape[1])
102 | 
103 |                 real_stft = torch.stft(fold_channels_into_batch(demo_reals),
104 |                     return_complex=True, **stft_kwargs)
105 |                 reals_signal = unfold_channels_from_batch(
106 |                     torch.istft(real_stft, **stft_kwargs), demo_reals.shape[1])
107 | 
108 |                 _, loss, _ = module.discriminator.loss(reals_signal, fakes_signal)
109 |                 fakes_stft.retain_grad()
110 |                 loss.backward()
111 |                 grads = unfold_channels_from_batch(fakes_stft.grad.detach().abs(), fakes.shape[1])
112 | 
113 |                 log_image(trainer.logger, 'discriminator_sensitivity',
114 |                     tokens_spectrogram_image(grads.mean(dim=1).log10(),
115 |                     title='Discriminator Sensitivity', symmetric=False))
116 |                 opts = module.optimizers()
117 |                 opts[0].zero_grad()
118 |                 opts[1].zero_grad()
119 | 
120 |             # Interleave reals and fakes
121 |             reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n')
122 |             # Put the demos together
123 |             reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)')
124 | 
125 |             data_dir = os.path.join(
126 |                 trainer.logger.save_dir, logger_project_name(trainer.logger),
127 |                 trainer.logger.experiment.id, "media")
128 |             os.makedirs(data_dir, exist_ok=True)
129 |             filename = os.path.join(data_dir, f'recon_{trainer.global_step:08}.wav')
130 | 
131 |             reals_fakes = reals_fakes.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
132 |             torchaudio.save(filename, reals_fakes, self.sample_rate)
133 | 
134 |             log_audio(trainer.logger, 'recon', filename, self.sample_rate)
135 |             log_point_cloud(trainer.logger, 'embeddings_3dpca', latents)
136 |             log_image(trainer.logger, 'embeddings_spec', tokens_spectrogram_image(latents))
137 |             log_image(trainer.logger, 'recon_melspec_left', audio_spectrogram_image(reals_fakes))
138 |         except Exception as e:
139 |             print(f'{type(e).__name__}: {e}')
140 |             raise e
141 |         finally:
142 |             module.train()
143 | 
144 | def create_demo_callback_from_config(model_config, **kwargs):
145 |     model_type = model_config.get('model_type', None)
146 |     assert model_type is not None, 'model_type must be specified in model config'
147 | 
148 |     training_config = model_config.get('training', None)
149 |     assert training_config is not None, 'training config must be specified in model config'
150 | 
151 |     demo_config = training_config.get("demo", {})
152 |     return AutoencoderDemoCallback(
153 |         demo_every=demo_config.get("demo_every", 2000),
154 |         sample_size=model_config["sample_size"],
155 |         sample_rate=model_config["sample_rate"],
156 |         **kwargs
157 |     )
158 | 
--------------------------------------------------------------------------------
/stable_codec/training_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import pytorch_lightning as pl
4 | 
5 | from typing import Optional, Literal
6 | from ema_pytorch import EMA
7 | from torch.nn import Parameter
8 | from einops import rearrange
9 | from safetensors.torch import save_model  # used by `export_model` when saving with safetensors
10 | from stable_audio_tools.models import create_model_from_config
11 | from stable_audio_tools.models.autoencoders import AudioAutoencoder
12 | from 
stable_audio_tools.models.discriminators import ( 13 | EncodecDiscriminator, OobleckDiscriminator, DACGANLoss, 14 | ) 15 | from stable_audio_tools.models.bottleneck import ( 16 | VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, 17 | RVQVAEBottleneck, WassersteinBottleneck, 18 | ) 19 | from stable_audio_tools.training.losses import ( 20 | MelSpectrogramLoss, MultiLoss, AuralossLoss, ValueLoss, L1Loss, 21 | LossWithTarget, MSELoss, HubertLoss, 22 | # PESQMetric, # TODO move PESQ here? 23 | ) 24 | from stable_audio_tools.training.losses import auraloss as auraloss 25 | from stable_audio_tools.training.utils import ( 26 | create_optimizer_from_config, create_scheduler_from_config, log_metric, 27 | ) 28 | 29 | from .ctc_loss import CTCLossModule, PERModule 30 | 31 | def trim_to_shortest(a, b): 32 | """Trim the longer of two tensors to the length of the shorter one.""" 33 | if a.shape[-1] > b.shape[-1]: 34 | return a[:,:,:b.shape[-1]], b 35 | elif b.shape[-1] > a.shape[-1]: 36 | return a, b[:,:,:a.shape[-1]] 37 | return a, b 38 | 39 | class ProjectionHead(nn.Module): 40 | def __init__(self, latent_dim, proj_head_dim, mid_dim=256): 41 | super(ProjectionHead, self).__init__() 42 | self.proj_head = nn.Sequential( 43 | nn.Tanh(), 44 | nn.Linear(latent_dim, mid_dim), 45 | nn.ReLU(), 46 | nn.Linear(mid_dim, mid_dim), 47 | nn.ReLU(), 48 | nn.Linear(mid_dim, proj_head_dim) 49 | ) 50 | 51 | def forward(self, x): 52 | return self.proj_head(x) 53 | 54 | class AutoencoderTrainingWrapper(pl.LightningModule): 55 | def __init__(self, 56 | autoencoder: AudioAutoencoder, 57 | loss_config: dict, 58 | eval_loss_config: dict, 59 | optimizer_configs: dict, 60 | sample_rate: int = 16000, 61 | lr: float = 1e-4, 62 | warmup_steps: int = 0, 63 | warmup_mode: Literal["adv", "full"] = "adv", 64 | encoder_freeze_on_warmup: bool = False, 65 | use_ema: bool = True, 66 | ema_copy = None, 67 | force_input_mono = False, 68 | latent_mask_ratio = 0.0, 69 | teacher_model: Optional[AudioAutoencoder] = None, 70 | clip_grad_norm = 0.0, 71 | encoder_mask_ratio = 0.0, 72 | use_ctc: bool = False, 73 | proj_head_dim: Optional[int] = None, 74 | detach_proj_head: bool = False, 75 | ): 76 | super().__init__() 77 | 78 | self.automatic_optimization = False 79 | self.autoencoder = autoencoder 80 | 81 | self.warmed_up = False 82 | self.warmup_steps = warmup_steps 83 | self.warmup_mode = warmup_mode 84 | self.encoder_freeze_on_warmup = encoder_freeze_on_warmup 85 | self.lr = lr 86 | self.clip_grad_norm = clip_grad_norm 87 | 88 | self.force_input_mono = force_input_mono 89 | self.teacher_model = teacher_model 90 | 91 | self.use_ctc = use_ctc 92 | self.proj_head_dim = proj_head_dim 93 | self.detach_proj_head = detach_proj_head 94 | self.projection_head = ( 95 | ProjectionHead(self.autoencoder.latent_dim, self.proj_head_dim) 96 | if self.use_ctc and self.proj_head_dim is not None else 97 | nn.Identity() 98 | ) 99 | 100 | self.optimizer_configs = optimizer_configs 101 | self.loss_config = loss_config 102 | 103 | # Spectral reconstruction loss 104 | self.sdstft = auraloss.MultiResolutionSTFTLoss( 105 | sample_rate=sample_rate, **loss_config['spectral']['config']) 106 | 107 | # Discriminator 108 | self.use_disc = True if 'discriminator' in loss_config else False 109 | self.discriminator = None 110 | if self.use_disc: 111 | if loss_config['discriminator']['type'] == 'oobleck': 112 | self.discriminator = OobleckDiscriminator(**loss_config['discriminator']['config']) 113 | elif loss_config['discriminator']['type'] == 'encodec': 114 | 
self.discriminator = EncodecDiscriminator( 115 | in_channels=self.autoencoder.out_channels, 116 | **loss_config['discriminator']['config']) 117 | elif loss_config['discriminator']['type'] == 'dac': 118 | self.discriminator = DACGANLoss( 119 | channels=self.autoencoder.out_channels, 120 | sample_rate=sample_rate, 121 | **loss_config['discriminator']['config']) 122 | 123 | gen_loss_modules = [] 124 | if self.use_disc: 125 | # Discriminator loss. 126 | self.losses_disc = MultiLoss([ 127 | ValueLoss(key='loss_dis', weight=1.0, name='discriminator_loss'), 128 | ]) 129 | 130 | # Adversarial and feature matching losses. 131 | gen_loss_modules += [ 132 | ValueLoss( 133 | key='loss_adv', 134 | weight=self.loss_config['discriminator']['weights']['adversarial'], 135 | name='loss_adv'), 136 | ValueLoss( 137 | key='feature_matching_distance', 138 | weight=self.loss_config['discriminator']['weights']['feature_matching'], 139 | name='feature_matching_loss'), 140 | ] 141 | 142 | # Reconstruction loss 143 | gen_loss_modules += [AuralossLoss(self.sdstft, 144 | target_key='reals', input_key='decoded', name='mrstft_loss', 145 | weight=self.loss_config['spectral']['weights']['mrstft'], 146 | decay=self.loss_config['spectral'].get('decay', 1.0), 147 | )] 148 | 149 | if "mrmel" in loss_config: 150 | mrmel_weight = loss_config["mrmel"]["weights"]["mrmel"] 151 | if mrmel_weight > 0: 152 | mrmel_config = loss_config["mrmel"]["config"] 153 | self.mrmel = MelSpectrogramLoss(sample_rate, 154 | n_mels=mrmel_config["n_mels"], 155 | window_lengths=mrmel_config["window_lengths"], 156 | pow=mrmel_config["pow"], 157 | log_weight=mrmel_config["log_weight"], 158 | mag_weight=mrmel_config["mag_weight"], 159 | ) 160 | gen_loss_modules.append(LossWithTarget( 161 | self.mrmel, "reals", "decoded", 162 | name="mrmel_loss", weight=mrmel_weight, 163 | )) 164 | 165 | if "hubert" in loss_config: 166 | hubert_weight = loss_config["hubert"]["weights"]["hubert"] 167 | if hubert_weight > 0: 168 | hubert_cfg = ( 169 | loss_config["hubert"]["config"] 170 | if "config" in loss_config["hubert"] else 171 | dict() 172 | ) 173 | self.hubert = HubertLoss(weight=1.0, **hubert_cfg) 174 | 175 | gen_loss_modules.append(LossWithTarget( 176 | self.hubert, target_key = "reals", input_key = "decoded", 177 | name="hubert_loss", weight=hubert_weight, 178 | decay = loss_config["hubert"].get("decay", 1.0) 179 | )) 180 | 181 | if "l1" in loss_config["time"]["weights"]: 182 | if self.loss_config['time']['weights']['l1'] > 0.0: 183 | gen_loss_modules.append(L1Loss( 184 | key_a='reals', key_b='decoded', 185 | weight=self.loss_config['time']['weights']['l1'], 186 | name='l1_time_loss', 187 | decay = self.loss_config['time'].get('decay', 1.0), 188 | )) 189 | 190 | if "l2" in loss_config["time"]["weights"]: 191 | if self.loss_config['time']['weights']['l2'] > 0.0: 192 | gen_loss_modules.append(MSELoss( 193 | key_a='reals', key_b='decoded', 194 | weight=self.loss_config['time']['weights']['l2'], 195 | name='l2_time_loss', 196 | decay = self.loss_config['time'].get('decay', 1.0), 197 | )) 198 | 199 | if self.autoencoder.bottleneck is not None: 200 | gen_loss_modules += create_loss_modules_from_bottleneck( 201 | self.autoencoder.bottleneck, self.loss_config) 202 | 203 | self.encoder_mask_ratio = encoder_mask_ratio 204 | if encoder_mask_ratio > 0.0: 205 | gen_loss_modules.append(L1Loss( 206 | key_a='detached_latents', key_b='masked_latents', 207 | weight=1.0, 208 | name='encoder_mask_loss', 209 | decay = 1.0, 210 | )) 211 | 212 | if "ctc" in loss_config: 213 | 
ctc_weight = loss_config["ctc"]["weights"]["ctc"] 214 | if ctc_weight > 0: 215 | gen_loss_modules.append(CTCLossModule( 216 | name = "ctc_loss", 217 | target_key = "ctc_tgt", 218 | input_key = "log_probs", 219 | weight = ctc_weight, 220 | decay = loss_config["ctc"].get("decay", 1.0), 221 | blank_idx = loss_config["ctc"].get("blank_idx", 80) 222 | )) 223 | 224 | self.losses_gen = MultiLoss(gen_loss_modules) 225 | 226 | # Set up EMA for model weights 227 | self.autoencoder_ema = None 228 | self.use_ema = use_ema 229 | if self.use_ema: 230 | self.autoencoder_ema = EMA( 231 | self.autoencoder, 232 | ema_model=ema_copy, 233 | beta=0.9999, 234 | power=3/4, 235 | update_every=1, 236 | update_after_step=1 237 | ) 238 | 239 | self.latent_mask_ratio = latent_mask_ratio 240 | 241 | # evaluation losses & metrics 242 | self.eval_losses = torch.nn.ModuleDict() 243 | if eval_loss_config is not None: 244 | # if "pesq" in eval_loss_config: 245 | # self.eval_losses["pesq"] = PESQMetric(sample_rate) 246 | if "stft"in eval_loss_config: 247 | self.eval_losses["stft"] = auraloss.STFTLoss(**eval_loss_config["stft"]) 248 | if "sisdr" in eval_loss_config: 249 | self.eval_losses["sisdr"] = auraloss.SISDRLoss(**eval_loss_config["sisdr"]) 250 | if "mel" in eval_loss_config: 251 | self.eval_losses["mel"] = auraloss.MelSTFTLoss( 252 | sample_rate, **eval_loss_config["mel"]) 253 | if "per" in eval_loss_config: 254 | self.eval_losses["per"] = PERModule( 255 | target_key = "ctc_tgt", 256 | input_key = "log_probs", 257 | blank_idx = loss_config["ctc"].get("blank_idx", 80)) 258 | 259 | self.validation_step_outputs = [] 260 | 261 | def configure_optimizers(self): 262 | gen_params = list(self.autoencoder.parameters()) 263 | 264 | if not self.use_disc: 265 | opt_gen = create_optimizer_from_config( 266 | self.optimizer_configs['autoencoder']['optimizer'], gen_params) 267 | if "scheduler" in self.optimizer_configs['autoencoder']: 268 | sched_gen = create_scheduler_from_config( 269 | self.optimizer_configs['autoencoder']['scheduler'], opt_gen) 270 | return [opt_gen], [sched_gen] 271 | return [opt_gen] 272 | 273 | # Using discriminator. 
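        # With a discriminator, separate optimizers (and, if both configs define one, schedulers)
        # are created for the autoencoder (generator) and the discriminator. The return order,
        # generator first and discriminator second, is what `training_step` relies on when
        # unpacking `self.optimizers()` and `self.lr_schedulers()`.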
274 | opt_gen = create_optimizer_from_config( 275 | self.optimizer_configs['autoencoder']['optimizer'], gen_params) 276 | opt_disc = create_optimizer_from_config( 277 | self.optimizer_configs['discriminator']['optimizer'], 278 | self.discriminator.parameters()) 279 | 280 | use_scheduler = ( 281 | "scheduler" in self.optimizer_configs['autoencoder'] and 282 | "scheduler" in self.optimizer_configs['discriminator'] 283 | ) 284 | if use_scheduler: 285 | sched_gen = create_scheduler_from_config( 286 | self.optimizer_configs['autoencoder']['scheduler'], opt_gen) 287 | sched_disc = create_scheduler_from_config( 288 | self.optimizer_configs['discriminator']['scheduler'], opt_disc) 289 | return [opt_gen, opt_disc], [sched_gen, sched_disc] 290 | return [opt_gen, opt_disc] 291 | 292 | def forward(self, reals): 293 | latents, encoder_info = self.autoencoder.encode(reals, return_info=True) 294 | decoded = self.autoencoder.decode(latents) 295 | return decoded 296 | 297 | def validation_step(self, batch, batch_idx): 298 | reals, _ = batch 299 | # Remove extra dimension added by WebDataset 300 | if reals.ndim == 4 and reals.shape[0] == 1: 301 | reals = reals[0] 302 | 303 | if len(reals.shape) == 2: 304 | reals = reals.unsqueeze(1) 305 | 306 | loss_info = {} 307 | loss_info["reals"] = reals 308 | 309 | encoder_input = reals 310 | if self.force_input_mono and encoder_input.shape[1] > 1: 311 | encoder_input = encoder_input.mean(dim=1, keepdim=True) 312 | 313 | loss_info["encoder_input"] = encoder_input 314 | 315 | with torch.no_grad(): 316 | if self.use_ctc: 317 | latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) 318 | continuous_latents = encoder_info["pre_bottleneck_latents"] 319 | proj_features = rearrange(continuous_latents, "b c n -> b n c") 320 | proj_features = self.projection_head( 321 | proj_features.detach() 322 | if self.detach_proj_head else 323 | proj_features 324 | ) 325 | 326 | loss_info['log_probs'] = proj_features 327 | loss_info['ctc_tgt'] = batch[1] 328 | else: 329 | latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) 330 | 331 | loss_info["latents"] = latents 332 | loss_info.update(encoder_info) 333 | 334 | decoded = self.autoencoder.decode(latents) 335 | #Trim output to remove post-padding. 336 | decoded, reals = trim_to_shortest(decoded, reals) 337 | 338 | # Run evaluation metrics. 
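        # Each configured metric is computed on the `(decoded, reals)` pair, except PER, which
        # reads the CTC log-probs and targets from `loss_info`. The SI-SDR loss value is negated
        # before logging, tensors are converted to plain floats, and the per-step results are
        # averaged (and gathered across ranks) in `on_validation_epoch_end`.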
339 | val_loss_dict = {} 340 | for eval_key, eval_fn in self.eval_losses.items(): 341 | if eval_key == 'per': 342 | loss_value = eval_fn(loss_info) 343 | else: 344 | loss_value = eval_fn(decoded, reals) 345 | if eval_key == "sisdr": loss_value = -loss_value 346 | 347 | if isinstance(loss_value, torch.Tensor): 348 | loss_value = loss_value.item() 349 | 350 | val_loss_dict[eval_key] = loss_value 351 | 352 | self.validation_step_outputs.append(val_loss_dict) 353 | return val_loss_dict 354 | 355 | def on_validation_epoch_end(self): 356 | sum_loss_dict = {} 357 | for loss_dict in self.validation_step_outputs: 358 | for key, value in loss_dict.items(): 359 | if key not in sum_loss_dict: 360 | sum_loss_dict[key] = value 361 | else: 362 | sum_loss_dict[key] += value 363 | 364 | for key, value in sum_loss_dict.items(): 365 | val_loss = value / len(self.validation_step_outputs) 366 | val_loss = self.all_gather(val_loss).mean().item() 367 | log_metric(self.logger, f"val/{key}", val_loss) 368 | 369 | self.validation_step_outputs.clear() # free memory 370 | 371 | def training_step(self, batch, batch_idx): 372 | reals, _ = batch 373 | 374 | log_dict = {} 375 | # Remove extra dimension added by WebDataset 376 | if reals.ndim == 4 and reals.shape[0] == 1: 377 | reals = reals[0] 378 | 379 | if len(reals.shape) == 2: 380 | reals = reals.unsqueeze(1) 381 | 382 | if self.global_step >= self.warmup_steps: 383 | self.warmed_up = True 384 | 385 | loss_info = {} 386 | loss_info["reals"] = reals 387 | encoder_input = reals 388 | 389 | if self.force_input_mono and encoder_input.shape[1] > 1: 390 | encoder_input = encoder_input.mean(dim=1, keepdim=True) 391 | 392 | loss_info["encoder_input"] = encoder_input 393 | data_std = encoder_input.std() 394 | 395 | if self.warmed_up and self.encoder_freeze_on_warmup: 396 | with torch.no_grad(): 397 | latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) 398 | else: 399 | if self.use_ctc: 400 | latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) 401 | continuous_latents = encoder_info["pre_bottleneck_latents"] 402 | proj_features = rearrange(continuous_latents, "b c n -> b n c") 403 | proj_features = self.projection_head( 404 | proj_features.detach() 405 | if self.detach_proj_head else 406 | proj_features 407 | ) 408 | 409 | loss_info['log_probs'] = proj_features 410 | loss_info['ctc_tgt'] = batch[1] 411 | else: 412 | latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) 413 | 414 | if self.encoder_mask_ratio > 0.0: 415 | masked_latents = self.autoencoder.encode( 416 | encoder_input, return_info=False, encoder_mask_ratio=self.encoder_mask_ratio) 417 | detached_latents = latents.detach() 418 | loss_info["masked_latents"] = masked_latents 419 | loss_info["detached_latents"] = detached_latents 420 | 421 | loss_info["latents"] = latents 422 | loss_info.update(encoder_info) 423 | 424 | # Encode with teacher model for distillation 425 | if self.teacher_model is not None: 426 | with torch.no_grad(): 427 | teacher_latents = self.teacher_model.encode(encoder_input, return_info=False) 428 | loss_info['teacher_latents'] = teacher_latents 429 | 430 | # Optionally mask out some latents for noise resistance 431 | if self.latent_mask_ratio > 0.0: 432 | mask = torch.rand_like(latents) < self.latent_mask_ratio 433 | latents = torch.where(mask, torch.zeros_like(latents), latents) 434 | 435 | decoded = self.autoencoder.decode(latents) 436 | #Trim output to remove post-padding 437 | decoded, reals = 
trim_to_shortest(decoded, reals) 438 | 439 | loss_info["decoded"] = decoded 440 | loss_info["reals"] = reals 441 | 442 | if self.autoencoder.out_channels == 2: 443 | loss_info["decoded_left"] = decoded[:, 0:1, :] 444 | loss_info["decoded_right"] = decoded[:, 1:2, :] 445 | loss_info["reals_left"] = reals[:, 0:1, :] 446 | loss_info["reals_right"] = reals[:, 1:2, :] 447 | 448 | # Distillation 449 | if self.teacher_model is not None: 450 | with torch.no_grad(): 451 | teacher_decoded = self.teacher_model.decode(teacher_latents) 452 | own_latents_teacher_decoded = self.teacher_model.decode(latents) #Distilled model's latents decoded by teacher 453 | teacher_latents_own_decoded = self.autoencoder.decode(teacher_latents) #Teacher's latents decoded by distilled model 454 | 455 | loss_info['teacher_decoded'] = teacher_decoded 456 | loss_info['own_latents_teacher_decoded'] = own_latents_teacher_decoded 457 | loss_info['teacher_latents_own_decoded'] = teacher_latents_own_decoded 458 | 459 | if self.use_disc: 460 | if self.warmed_up: 461 | loss_dis, loss_adv, feature_matching_distance = self.discriminator.loss(reals=reals, fakes=decoded) 462 | else: 463 | loss_adv = torch.tensor(0.).to(reals) 464 | feature_matching_distance = torch.tensor(0.).to(reals) 465 | 466 | if self.warmup_mode == "adv": 467 | loss_dis, _, _ = self.discriminator.loss(reals=reals, fakes=decoded) 468 | else: 469 | loss_dis = torch.tensor(0.0).to(reals) 470 | 471 | loss_info["loss_dis"] = loss_dis 472 | loss_info["loss_adv"] = loss_adv 473 | loss_info["feature_matching_distance"] = feature_matching_distance 474 | 475 | opt_gen = None 476 | opt_disc = None 477 | if self.use_disc: 478 | opt_gen, opt_disc = self.optimizers() 479 | else: 480 | opt_gen = self.optimizers() 481 | 482 | lr_schedulers = self.lr_schedulers() 483 | sched_gen = None 484 | sched_disc = None 485 | 486 | if lr_schedulers is not None: 487 | if self.use_disc: 488 | sched_gen, sched_disc = lr_schedulers 489 | else: 490 | sched_gen = lr_schedulers 491 | 492 | # Train the discriminator 493 | use_disc = ( 494 | self.use_disc 495 | and self.global_step % 2 496 | # Check warmup mode and if it is time to use discriminator. 
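            # The discriminator is updated only on odd global steps, alternating with the
            # generator. In "full" warmup mode it starts training after the warmup period,
            # whereas in "adv" mode it trains from the start and only the adversarial and
            # feature-matching terms are gated by the warmup.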
497 | and ( 498 | (self.warmup_mode == "full" and self.warmed_up) 499 | or self.warmup_mode == "adv") 500 | ) 501 | if use_disc: 502 | loss, losses = self.losses_disc(loss_info) 503 | log_dict['train/disc_lr'] = opt_disc.param_groups[0]['lr'] 504 | opt_disc.zero_grad() 505 | self.manual_backward(loss) 506 | 507 | if self.clip_grad_norm > 0.0: 508 | torch.nn.utils.clip_grad_norm_( 509 | self.discriminator.parameters(), self.clip_grad_norm) 510 | 511 | opt_disc.step() 512 | if sched_disc is not None: 513 | # sched step every step 514 | sched_disc.step() 515 | 516 | # Train the generator 517 | else: 518 | loss, losses = self.losses_gen(loss_info) 519 | if self.use_ema: 520 | self.autoencoder_ema.update() 521 | 522 | opt_gen.zero_grad() 523 | self.manual_backward(loss) 524 | if self.clip_grad_norm > 0.0: 525 | torch.nn.utils.clip_grad_norm_( 526 | self.autoencoder.parameters(), self.clip_grad_norm) 527 | 528 | opt_gen.step() 529 | if sched_gen is not None: 530 | # scheduler step every step 531 | sched_gen.step() 532 | 533 | log_dict['train/loss'] = loss.detach().item() 534 | log_dict['train/latent_std'] = latents.std().detach().item() 535 | log_dict['train/data_std'] = data_std.detach().item() 536 | log_dict['train/gen_lr'] = opt_gen.param_groups[0]['lr'] 537 | 538 | for loss_name, loss_value in losses.items(): 539 | log_dict[f'train/{loss_name}'] = loss_value.detach().item() 540 | 541 | self.log_dict(log_dict, prog_bar=True, on_step=True) 542 | return loss 543 | 544 | def export_model(self, path, use_safetensors=False): 545 | if self.autoencoder_ema is not None: 546 | model = self.autoencoder_ema.ema_model 547 | else: 548 | model = self.autoencoder 549 | 550 | if use_safetensors: 551 | save_model(model, path) 552 | else: 553 | torch.save({"state_dict": model.state_dict()}, path) 554 | 555 | def create_loss_modules_from_bottleneck(bottleneck, loss_config): 556 | losses = [] 557 | 558 | if ( 559 | isinstance(bottleneck, VAEBottleneck) or 560 | isinstance(bottleneck, DACRVQVAEBottleneck) or 561 | isinstance(bottleneck, RVQVAEBottleneck) 562 | ): 563 | try: 564 | kl_weight = loss_config['bottleneck']['weights']['kl'] 565 | except: 566 | kl_weight = 1e-6 567 | 568 | kl_loss = ValueLoss(key='kl', weight=kl_weight, name='kl_loss') 569 | losses.append(kl_loss) 570 | 571 | if ( 572 | isinstance(bottleneck, RVQBottleneck) or 573 | isinstance(bottleneck, RVQVAEBottleneck) 574 | ): 575 | quantizer_loss = ValueLoss(key='quantizer_loss', weight=1.0, name='quantizer_loss') 576 | losses.append(quantizer_loss) 577 | 578 | if isinstance(bottleneck, DACRVQBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck): 579 | codebook_loss = ValueLoss(key='vq/codebook_loss', weight=1.0, name='codebook_loss') 580 | commitment_loss = ValueLoss(key='vq/commitment_loss', weight=0.25, name='commitment_loss') 581 | losses.append(codebook_loss) 582 | losses.append(commitment_loss) 583 | 584 | if isinstance(bottleneck, WassersteinBottleneck): 585 | try: 586 | mmd_weight = loss_config['bottleneck']['weights']['mmd'] 587 | except: 588 | mmd_weight = 100 589 | 590 | mmd_loss = ValueLoss(key='mmd', weight=mmd_weight, name='mmd_loss') 591 | losses.append(mmd_loss) 592 | 593 | return losses 594 | 595 | def create_training_wrapper_from_config(model_config, model): 596 | model_type = model_config.get('model_type', None) 597 | assert model_type is not None, 'model_type must be specified in model config' 598 | 599 | training_config = model_config.get('training', None) 600 | assert training_config is not None, 'training config must be 
specified in model config' 601 | 602 | ema_copy = None 603 | if training_config.get("use_ema", False): 604 | ema_copy = create_model_from_config(model_config) 605 | # Copy each weight to the ema copy 606 | for name, param in model.state_dict().items(): 607 | if isinstance(param, Parameter): 608 | # backwards compatibility for serialized parameters 609 | param = param.data 610 | ema_copy.state_dict()[name].copy_(param) 611 | 612 | use_ema = training_config.get("use_ema", False) 613 | latent_mask_ratio = training_config.get("latent_mask_ratio", 0.0) 614 | 615 | teacher_model = training_config.get("teacher_model", None) 616 | if teacher_model is not None: 617 | teacher_model = create_model_from_config(teacher_model) 618 | teacher_model = teacher_model.eval().requires_grad_(False) 619 | 620 | teacher_model_ckpt = training_config.get("teacher_model_ckpt", None) 621 | if teacher_model_ckpt is not None: 622 | teacher_model.load_state_dict(torch.load(teacher_model_ckpt)["state_dict"]) 623 | else: 624 | raise ValueError("teacher_model_ckpt must be specified if teacher_model is specified") 625 | 626 | return AutoencoderTrainingWrapper( 627 | model, 628 | lr=training_config.get("learning_rate", None), 629 | warmup_steps=training_config.get("warmup_steps", 0), 630 | encoder_freeze_on_warmup=training_config.get("encoder_freeze_on_warmup", False), 631 | sample_rate=model_config["sample_rate"], 632 | loss_config=training_config.get("loss_configs", None), 633 | eval_loss_config=training_config.get("eval_loss_configs", None), 634 | optimizer_configs=training_config.get("optimizer_configs", None), 635 | use_ema=use_ema, 636 | ema_copy=ema_copy if use_ema else None, 637 | force_input_mono=training_config.get("force_input_mono", False), 638 | latent_mask_ratio=latent_mask_ratio, 639 | teacher_model=teacher_model, 640 | encoder_mask_ratio=training_config.get("encoder_mask_ratio", 0.0), 641 | use_ctc=training_config.get("use_ctc", False), 642 | proj_head_dim=model_config["model"].get("proj_head_dim", False), 643 | detach_proj_head=model_config["model"].get("detach_proj_head", None), 644 | ) 645 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import os 4 | import pytorch_lightning as pl 5 | 6 | from typing import Optional 7 | from prefigure.prefigure import get_all_args, push_wandb_config 8 | from stable_audio_tools.models import create_model_from_config 9 | from stable_audio_tools.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model 10 | from stable_audio_tools.training.utils import copy_state_dict 11 | 12 | from stable_codec.training_module import create_training_wrapper_from_config 13 | from stable_codec.training_demo import create_demo_callback_from_config 14 | from stable_codec.data.dataset import create_dataloader_from_config 15 | 16 | class ExceptionCallback(pl.Callback): 17 | def on_exception(self, trainer, module, err): 18 | print(f'{type(err).__name__}: {err}') 19 | 20 | class ModelConfigEmbedderCallback(pl.Callback): 21 | def __init__(self, model_config): 22 | self.model_config = model_config 23 | 24 | def on_save_checkpoint(self, trainer, pl_module, checkpoint): 25 | checkpoint["model_config"] = self.model_config 26 | 27 | def main(): 28 | args = get_all_args() 29 | seed = args.seed 30 | 31 | # Set a different seed for each process if using SLURM 32 | if os.environ.get("SLURM_PROCID") is not None: 33 | seed += 
int(os.environ.get("SLURM_PROCID")) 34 | 35 | print(f"Setting random seed: `{seed}`.") 36 | pl.seed_everything(seed, workers=True) 37 | 38 | save_dir = args.save_dir 39 | ckpt_path: Optional[str] = None 40 | if args.ckpt_path: 41 | ckpt_path = args.ckpt_path 42 | print(f"Using user-provided checkpoint: `{ckpt_path}`.") 43 | 44 | with open(args.model_config) as f: 45 | model_config = json.load(f) 46 | with open(args.dataset_config) as f: 47 | dataset_config = json.load(f) 48 | 49 | train_dl = create_dataloader_from_config( 50 | dataset_config, 51 | batch_size=args.batch_size, 52 | num_workers=args.num_workers, 53 | sample_rate=model_config["sample_rate"], 54 | sample_size=model_config["sample_size"], 55 | audio_channels=model_config.get("audio_channels", 2), 56 | ) 57 | 58 | val_dl = None 59 | val_dataset_config = None 60 | 61 | if args.val_dataset_config: 62 | with open(args.val_dataset_config) as f: 63 | val_dataset_config = json.load(f) 64 | 65 | val_dl = create_dataloader_from_config( 66 | val_dataset_config, 67 | batch_size=args.batch_size, 68 | num_workers=args.num_workers, 69 | sample_rate=model_config["sample_rate"], 70 | sample_size=model_config["sample_size"], 71 | audio_channels=model_config.get("audio_channels", 2), 72 | shuffle=False, 73 | ) 74 | 75 | model = create_model_from_config(model_config) 76 | if args.pretrained_ckpt_path: 77 | copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path)) 78 | 79 | if args.remove_pretransform_weight_norm == "pre_load": 80 | remove_weight_norm_from_model(model.pretransform) 81 | if args.pretransform_ckpt_path: 82 | model.pretransform.load_state_dict(load_ckpt_state_dict(args.pretransform_ckpt_path)) 83 | if args.remove_pretransform_weight_norm == "post_load": 84 | remove_weight_norm_from_model(model.pretransform) 85 | 86 | training_wrapper = create_training_wrapper_from_config(model_config, model) 87 | 88 | if args.project is None: 89 | project_name = args.name 90 | run_name = None 91 | else: 92 | project_name = args.project 93 | run_name = args.name 94 | 95 | exc_callback = ExceptionCallback() 96 | 97 | logger = None 98 | ckpt_dir = save_dir 99 | if args.logger == 'wandb': 100 | logger = pl.loggers.WandbLogger( 101 | name=run_name, project=project_name, 102 | save_dir=save_dir) 103 | logger.watch(training_wrapper, log_freq=1000) 104 | 105 | ckpt_dir = os.path.join( 106 | save_dir, logger.experiment.project, 107 | logger.experiment.id, "checkpoints") 108 | elif args.logger == 'comet': 109 | logger = pl.loggers.CometLogger( 110 | api_key=os.environ.get("COMET_API_KEY"), 111 | experiment_name=run_name, project_name=project_name, 112 | save_dir=save_dir) 113 | 114 | ckpt_dir = os.path.join( 115 | save_dir, project_name, logger.experiment.id, "checkpoints") 116 | 117 | print(f"Checkpoint dir: `{ckpt_dir}`.") 118 | ckpt_callback = pl.callbacks.ModelCheckpoint( 119 | every_n_train_steps=args.checkpoint_every, 120 | dirpath=ckpt_dir, save_top_k=args.save_top_k) 121 | save_model_config_callback = ModelConfigEmbedderCallback(model_config) 122 | 123 | demo_dl = copy.deepcopy(val_dl if args.val_dataset_config else train_dl) 124 | demo_callback = create_demo_callback_from_config(model_config, demo_dl=demo_dl) 125 | 126 | #Combine args and config dicts 127 | args_dict = vars(args) 128 | args_dict.update({"model_config": model_config}) 129 | args_dict.update({"dataset_config": dataset_config}) 130 | args_dict.update({"val_dataset_config": val_dataset_config}) 131 | 132 | if args.logger == 'wandb': 133 | push_wandb_config(logger, args_dict) 
134 | elif args.logger == 'comet': 135 | logger.log_hyperparams(args_dict) 136 | 137 | #Set multi-GPU strategy if specified 138 | if args.strategy: 139 | strategy = args.strategy 140 | else: 141 | strategy = 'ddp_find_unused_parameters_true' if args.num_gpus > 1 else "auto" 142 | 143 | trainer = pl.Trainer( 144 | devices="auto", 145 | accelerator="gpu", 146 | num_nodes = args.num_nodes, 147 | strategy=strategy, 148 | precision=args.precision, 149 | accumulate_grad_batches=args.accum_batches, 150 | callbacks=[ckpt_callback, demo_callback, exc_callback, save_model_config_callback], 151 | logger=logger, 152 | log_every_n_steps=1, 153 | max_epochs=10000000, 154 | default_root_dir=save_dir, 155 | gradient_clip_val=args.gradient_clip_val, 156 | reload_dataloaders_every_n_epochs=0, 157 | ) 158 | trainer.fit(training_wrapper, train_dl, val_dl, ckpt_path=ckpt_path) 159 | 160 | if __name__ == '__main__': 161 | main() 162 | --------------------------------------------------------------------------------