├── .gitignore ├── README.md ├── conf ├── infer_transfer.yaml ├── ref_gst_codec_vits.yaml ├── ref_mixer_codec_vits.yaml └── ref_vits.yaml ├── infer.py ├── model ├── base.py ├── helper.py ├── splines.py ├── utils.py └── vits.py ├── module ├── monotonic_align.py ├── ref_gst.py ├── ref_mixer_codec_only.py ├── ref_vits_module.py ├── vits_losses.py └── vits_modules.py ├── preprocess └── make_manifest.py ├── requirements.txt ├── torchdata ├── data.py ├── data_total.py ├── data_type.py └── text_preprocess.py └── vits_train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # User-specific stuff 2 | .idea -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dc-comix-tts 2 | Implementation of [DCComix TTS: An End-to-End Expressive TTS with Discrete Code Collaborated with Mixer](https://arxiv.org/abs/2305.19567) 3 | Accepted to Interspeech 2023. Audio samples and a demo for this system are available [here](https://lakahaga.github.io/dc-comix-tts/) 4 | 5 | Abstract: Despite the huge successes made in neural TTS, content-leakage remains a challenge. In this paper, we propose a new input representation and a simple architecture to achieve improved prosody modeling. Inspired by the recent success in the use of discrete code in TTS, we introduce discrete code to the input of the reference encoder. Specifically, we leverage the vector quantizer from the audio compression model to exploit the diverse acoustic information it has already been trained on. In addition, we apply the modified MLP-Mixer to the reference encoder, making the architecture lighter. As a result, we train the prosody transfer TTS in an end-to-end manner. We prove the effectiveness of our method through both subjective and objective evaluations. We demonstrate in our experiments that the reference encoder learns better speaker-independent prosody when discrete code is utilized as input. In addition, we obtain comparable results even with fewer input parameters. 6 | 7 | * This repository leverages [NeMo](https://github.com/NVIDIA/NeMo) for the [VITS](https://arxiv.org/pdf/2106.06103.pdf) and [MixerTTS](https://arxiv.org/abs/2110.03584) implementations. 8 | * We use [Encodec](https://github.com/facebookresearch/encodec) for discrete code; a minimal extraction sketch is shown below
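For context, here is a minimal sketch of how Encodec discrete codes can be extracted from a reference utterance. It assumes the official `encodec` package and a hypothetical VCTK file name; the actual extraction performed by the dataset classes in `torchdata/` may differ in detail.

```
import torch
import torchaudio
from encodec import EncodecModel
from encodec.utils import convert_audio

# Pretrained 24 kHz Encodec model; 6 kbps yields 8 codebooks per frame
model = EncodecModel.encodec_model_24khz()
model.set_target_bandwidth(6.0)

# Load a reference waveform and match the model's sample rate and channel count
wav, sr = torchaudio.load("p225_001.wav")  # hypothetical example file
wav = convert_audio(wav, sr, model.sample_rate, model.channels)

# Encode to discrete codes: a LongTensor of shape [B, n_codebooks, T_frames]
with torch.no_grad():
    encoded_frames = model.encode(wav.unsqueeze(0))
codes = torch.cat([codebook for codebook, _ in encoded_frames], dim=-1)
print(codes.shape)  # e.g. torch.Size([1, 8, T])
```

The 8 codebooks per frame at this bandwidth are consistent with the `initial_dim: 8` used by the codec reference encoders in `conf/ref_gst_codec_vits.yaml` and `conf/ref_mixer_codec_vits.yaml`.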
9 | 10 | ## Installation 11 | * python ≥ 3.8 12 | * pytorch 1.11.0+cu113 13 | * nemo_toolkit 1.18.0 14 | 15 | See `requirements.txt` for the other libraries 16 | ## Training 17 | * prepare data ([VCTK](https://datashare.ed.ac.uk/handle/10283/2651)) 18 | ``` 19 | python preprocess/make_manifest.py 20 | ``` 21 | * Note that we resample the VCTK audio to 24 kHz to match Encodec's sampling rate 22 | * preprocessing 23 | * text normalization 24 | ``` 25 | python torchdata/text_preprocess.py 26 | ``` 27 | * run `vits_train.py` (supply `batch_size`, `num_workers`, and `ngpu`, which are left as `??` in the configs, e.g. via Hydra command-line overrides) 28 | * for `dc-comix-tts`: use `ref_mixer_codec_vits.yaml` 29 | 30 | ## References 31 | 32 | ```text 33 | @software{Harper_NeMo_a_toolkit, 34 | author = {Harper, Eric and Majumdar, Somshubra and Kuchaiev, Oleksii and Jason, Li and Zhang, Yang and Bakhturina, Evelina and Noroozi, Vahid and Subramanian, Sandeep and Nithin, Koluguri and Jocelyn, Huang and Jia, Fei and Balam, Jagadeesh and Yang, Xuesong and Livne, Micha and Dong, Yi and Naren, Sean and Ginsburg, Boris}, 35 | title = {{NeMo: a toolkit for Conversational AI and Large Language Models}}, 36 | url = {https://github.com/NVIDIA/NeMo} 37 | } 38 | ``` 39 | ```text 40 | @article{defossez2022highfi, 41 | title={High Fidelity Neural Audio Compression}, 42 | author={Défossez, Alexandre and Copet, Jade and Synnaeve, Gabriel and Adi, Yossi}, 43 | journal={arXiv preprint arXiv:2210.13438}, 44 | year={2022} 45 | } 46 | ``` 47 | -------------------------------------------------------------------------------- /conf/infer_transfer.yaml: -------------------------------------------------------------------------------- 1 | name: ?? 2 | device: 'cuda' 3 | 4 | checkpoint_path: ?? 5 | manifest_path: "data/test_manifest.json" 6 | sup_data_path: "sup_data" 7 | sup_data_types: [ 'audio_codec' ] 8 | codec_model: "encodec" 9 | 10 | whitelist_path: "sup_data/text/whitelist/lj_speech.tsv" 11 | phoneme_dict_path: "sup_data/text/cmudict-0.7b_nv22.10" 12 | heteronyms_path: "sup_data/text/heteronyms-052722" 13 | 14 | sample_rate: 24000 15 | n_mel_channels: 80 16 | n_window_size: 1024 17 | n_window_stride: 256 18 | n_fft: 1024 19 | lowfreq: 0 20 | highfreq: 8000 21 | window: hann 22 | 23 | text_normalizer: 24 | _target_: nemo_text_processing.text_normalization.normalize.Normalizer 25 | lang: en 26 | input_case: cased 27 | whitelist: ${whitelist_path} 28 | 29 | text_normalizer_call_kwargs: 30 | verbose: false 31 | punct_pre_process: true 32 | punct_post_process: true 33 | 34 | text_tokenizer: 35 | _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.EnglishPhonemesTokenizer 36 | punct: true 37 | stresses: true 38 | chars: true 39 | apostrophe: true 40 | pad_with_space: true 41 | g2p: 42 | _target_: nemo_text_processing.g2p.modules.EnglishG2p 43 | phoneme_dict: ${phoneme_dict_path} 44 | heteronyms: ${heteronyms_path} 45 | 46 | 47 | dataset: 48 | _target_: data_total.ExtensiveTTSDataset 49 | manifest_filepath: ${manifest_path} 50 | text_tokenizer: ${text_tokenizer} 51 | text_normalizer: ${text_normalizer} 52 | text_normalizer_call_kwargs: ${text_normalizer_call_kwargs} 53 | sample_rate: ${sample_rate} 54 | sup_data_path: ${sup_data_path} 55 | sup_data_types: ${sup_data_types} 56 | n_fft: ${n_fft} 57 | win_length: ${n_window_size} 58 | hop_length: ${n_window_stride} 59 | window: ${window} 60 | n_mels: ${n_mel_channels} 61 | lowfreq: ${lowfreq} 62 | highfreq: ${highfreq} 63 | max_duration: null 64 | min_duration: 1.0 65 | ignore_file: null 66 | trim: true
67 | top_db: 35 68 | lm_model: ${lm_model} 69 | audio_model: ${audio_model} 70 | codec_model: ${codec_model} 71 | -------------------------------------------------------------------------------- /conf/ref_gst_codec_vits.yaml: -------------------------------------------------------------------------------- 1 | # This config contains the default values for training VITS model on LJSpeech dataset. 2 | # If you want to train model on other dataset, you can change config values according to your dataset. 3 | # Most dataset-specific arguments are in the head of the config file, see below. 4 | 5 | name: VITS_GST 6 | 7 | batch_size: ?? 8 | num_workers: ?? 9 | ngpu: ?? 10 | 11 | train_dataset: "data/train_manifest.json" 12 | validation_datasets: "data/valid_manifest.json" 13 | sup_data_path: "sup_data/" 14 | sup_data_types: [ "speaker_id", "audio_codec"] 15 | 16 | whitelist_path: "sup_data/text/whitelist/lj_speech.tsv" 17 | phoneme_dict_path: "sup_data/text/cmudict-0.7b_nv22.10" 18 | heteronyms_path: "sup_data/text/heteronyms-052722" 19 | 20 | # Default values from librosa.pyin 21 | pitch_fmin: 65.40639132514966 22 | pitch_fmax: 2093.004522404789 23 | 24 | sample_rate: 24000 25 | n_mel_channels: 80 26 | n_window_size: 1024 27 | n_window_stride: 256 28 | n_fft: 1024 29 | lowfreq: 0 30 | highfreq: null 31 | window: hann 32 | 33 | 34 | lm_model: "facebook/data2vec-text-base" 35 | codec_model: "encodec" 36 | 37 | 38 | model: 39 | pitch_fmin: ${pitch_fmin} 40 | pitch_fmax: ${pitch_fmax} 41 | 42 | sample_rate: ${sample_rate} 43 | n_mel_channels: ${n_mel_channels} 44 | n_window_size: ${n_window_size} 45 | n_window_stride: ${n_window_stride} 46 | n_fft: ${n_fft} 47 | lowfreq: ${lowfreq} 48 | highfreq: ${highfreq} 49 | window: ${window} 50 | mel_fmin: 0.0 51 | mel_fmax: null 52 | 53 | n_speakers: 1048 54 | segment_size: 8192 55 | c_mel: 45 56 | c_kl: 1. 
57 | use_spectral_norm: false 58 | 59 | text_normalizer: 60 | _target_: nemo_text_processing.text_normalization.normalize.Normalizer 61 | lang: en 62 | input_case: cased 63 | whitelist: ${whitelist_path} 64 | 65 | text_normalizer_call_kwargs: 66 | verbose: false 67 | punct_pre_process: true 68 | punct_post_process: true 69 | 70 | text_tokenizer: 71 | _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer 72 | punct: true 73 | apostrophe: true 74 | pad_with_space: false 75 | g2p: 76 | _target_: nemo_text_processing.g2p.modules.IPAG2P 77 | phoneme_dict: ${phoneme_dict_path} 78 | heteronyms: ${heteronyms_path} 79 | phoneme_probability: 0.8 80 | # Relies on the heteronyms list for anything that needs to be disambiguated 81 | ignore_ambiguous_words: false 82 | use_chars: true 83 | use_stresses: true 84 | 85 | train_ds: 86 | dataset: 87 | _target_: torchdata.data_total.ExtensiveTTSDataset 88 | manifest_filepath: ${train_dataset} 89 | sample_rate: ${model.sample_rate} 90 | sup_data_path: ${sup_data_path} 91 | sup_data_types: ${sup_data_types} 92 | n_fft: ${model.n_fft} 93 | win_length: ${model.n_window_size} 94 | hop_length: ${model.n_window_stride} 95 | window: ${model.window} 96 | n_mels: ${model.n_mel_channels} 97 | lowfreq: ${model.lowfreq} 98 | highfreq: ${model.highfreq} 99 | max_duration: null 100 | min_duration: 0.7 101 | ignore_file: null 102 | trim: False 103 | pitch_fmin: ${model.pitch_fmin} 104 | pitch_fmax: ${model.pitch_fmax} 105 | lm_model: ${lm_model} 106 | codec_model: ${codec_model} 107 | codec_sum: false 108 | 109 | dataloader_params: 110 | # drop_last: false 111 | # shuffle: true 112 | # batch_size: ${batch_size} 113 | num_workers: ${num_workers} 114 | pin_memory: true 115 | batch_sampler: 116 | batch_size: ${batch_size} 117 | boundaries: [32,300,400,500,600,700,800,900,1000] 118 | num_replicas: ${trainer.devices} 119 | shuffle: true 120 | 121 | 122 | validation_ds: 123 | dataset: 124 | _target_: torchdata.data_total.ExtensiveTTSDataset 125 | manifest_filepath: ${validation_datasets} 126 | sample_rate: ${model.sample_rate} 127 | sup_data_path: ${sup_data_path} 128 | sup_data_types: ${sup_data_types} 129 | n_fft: ${model.n_fft} 130 | win_length: ${model.n_window_size} 131 | hop_length: ${model.n_window_stride} 132 | window: ${model.window} 133 | n_mels: ${model.n_mel_channels} 134 | lowfreq: ${model.lowfreq} 135 | highfreq: ${model.highfreq} 136 | max_duration: null 137 | min_duration: 0.7 138 | ignore_file: null 139 | trim: False 140 | pitch_fmin: ${model.pitch_fmin} 141 | pitch_fmax: ${model.pitch_fmax} 142 | lm_model: ${lm_model} 143 | codec_model: ${codec_model} 144 | codec_sum: false 145 | 146 | dataloader_params: 147 | drop_last: false 148 | shuffle: false 149 | batch_size: ${batch_size} 150 | num_workers: ${num_workers} 151 | pin_memory: false 152 | 153 | 154 | preprocessor: 155 | _target_: module.preprocessor.FilterbankFeatures 156 | nfilt: ${model.n_mel_channels} 157 | highfreq: ${model.highfreq} 158 | log: true 159 | log_zero_guard_type: clamp 160 | log_zero_guard_value: 1e-05 161 | lowfreq: ${model.lowfreq} 162 | n_fft: ${model.n_fft} 163 | n_window_size: ${model.n_window_size} 164 | n_window_stride: ${model.n_window_stride} 165 | pad_to: 1 166 | pad_value: 0 167 | sample_rate: ${model.sample_rate} 168 | window: ${model.window} 169 | normalize: null 170 | preemph: null 171 | dither: 0.0 172 | frame_splicing: 1 173 | stft_conv: false 174 | nb_augmentation_prob : 0 175 | mag_power: 1.0 176 | exact_pad: true 177 | use_grads: true 178 | 
179 | synthesizer: 180 | _target_: module.ref_vits_module.RefSynthesizerTrn 181 | inter_channels: 192 182 | hidden_channels: 192 183 | filter_channels: 768 184 | n_heads: 2 185 | n_layers: 6 186 | kernel_size: 3 187 | p_dropout: 0.1 188 | resblock: "1" 189 | resblock_kernel_sizes: [3,7,11] 190 | resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]] 191 | upsample_rates: [8,8,2,2] 192 | upsample_initial_channel: 512 193 | upsample_kernel_sizes: [16,16,4,4] 194 | n_speakers: ${model.n_speakers} 195 | gin_channels: 256 # for multi-speaker 196 | ref_encoder: 197 | _target_: module.ref_gst.GlobalStyleTokenForCodec 198 | cnn_filters: [ 32, 32, 64, 64, 128, 128 ] 199 | dropout: 0.2 200 | gru_hidden: ${model.synthesizer.gin_channels} 201 | gst_size: ${model.synthesizer.gin_channels} 202 | initial_dim: 8 203 | n_style_token: 10 204 | n_style_attn_head: 4 205 | 206 | optim: 207 | _target_: torch.optim.AdamW 208 | lr: 2e-4 209 | betas: [0.9, 0.99] 210 | eps: 1e-9 211 | 212 | sched: 213 | name: ExponentialLR 214 | lr_decay: 0.999875 215 | 216 | trainer: 217 | num_nodes: 1 218 | devices: ${ngpu} 219 | accelerator: gpu 220 | strategy: ddp 221 | precision: 16 222 | # amp_backend: 'apex' 223 | # amp_level: 'O2' 224 | # benchmark: true 225 | max_epochs: 1000 226 | accumulate_grad_batches: 1 227 | enable_checkpointing: false # Provided by exp_manager 228 | logger: false # Provided by exp_manager 229 | log_every_n_steps: 50 230 | check_val_every_n_epoch: 1 231 | 232 | exp_manager: 233 | exp_dir: exp_VITS 234 | name: ${name} 235 | create_tensorboard_logger: false 236 | create_checkpoint_callback: true 237 | checkpoint_callback_params: 238 | monitor: loss_gen_all 239 | mode: min 240 | create_wandb_logger: true 241 | wandb_logger_kwargs: 242 | name: ${name} 243 | project: RefVits 244 | entity: null 245 | resume_if_exists: true 246 | resume_ignore_no_checkpoint: true -------------------------------------------------------------------------------- /conf/ref_mixer_codec_vits.yaml: -------------------------------------------------------------------------------- 1 | # This config contains the default values for training VITS model on LJSpeech dataset. 2 | # If you want to train model on other dataset, you can change config values according to your dataset. 3 | # Most dataset-specific arguments are in the head of the config file, see below. 4 | 5 | 6 | name: VITS_Mixer_Codec 7 | 8 | batch_size: ?? 9 | num_workers: ?? 10 | ngpu: ?? 
11 | 12 | train_dataset: "data/train_manifest.json" 13 | validation_datasets: "data/valid_manifest.json" 14 | sup_data_path: "sup_data/" 15 | 16 | sup_data_types: [ "speaker_id", "audio_codec"] 17 | 18 | whitelist_path: "sup_data/text/whitelist/lj_speech.tsv" 19 | phoneme_dict_path: "sup_data/text/cmudict-0.7b_nv22.10" 20 | heteronyms_path: "sup_data/text/heteronyms-052722" 21 | 22 | # Default values from librosa.pyin 23 | pitch_fmin: 65.40639132514966 24 | pitch_fmax: 2093.004522404789 25 | 26 | sample_rate: 24000 27 | n_mel_channels: 80 28 | n_window_size: 1024 29 | n_window_stride: 256 30 | n_fft: 1024 31 | lowfreq: 0 32 | highfreq: null 33 | window: hann 34 | 35 | codec_model: "encodec" 36 | 37 | 38 | model: 39 | pitch_fmin: ${pitch_fmin} 40 | pitch_fmax: ${pitch_fmax} 41 | 42 | sample_rate: ${sample_rate} 43 | n_mel_channels: ${n_mel_channels} 44 | n_window_size: ${n_window_size} 45 | n_window_stride: ${n_window_stride} 46 | n_fft: ${n_fft} 47 | lowfreq: ${lowfreq} 48 | highfreq: ${highfreq} 49 | window: ${window} 50 | mel_fmin: 0.0 51 | mel_fmax: null 52 | 53 | n_speakers: 819 54 | segment_size: 8192 55 | c_mel: 45 56 | c_kl: 1. 57 | use_spectral_norm: false 58 | 59 | text_normalizer: 60 | _target_: nemo_text_processing.text_normalization.normalize.Normalizer 61 | lang: en 62 | input_case: cased 63 | whitelist: ${whitelist_path} 64 | 65 | text_normalizer_call_kwargs: 66 | verbose: false 67 | punct_pre_process: true 68 | punct_post_process: true 69 | 70 | text_tokenizer: 71 | _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer 72 | punct: true 73 | apostrophe: true 74 | pad_with_space: false 75 | g2p: 76 | _target_: nemo_text_processing.g2p.modules.IPAG2P 77 | phoneme_dict: ${phoneme_dict_path} 78 | heteronyms: ${heteronyms_path} 79 | phoneme_probability: 0.8 80 | # Relies on the heteronyms list for anything that needs to be disambiguated 81 | ignore_ambiguous_words: false 82 | use_chars: true 83 | use_stresses: true 84 | 85 | train_ds: 86 | dataset: 87 | _target_: torchdata.data_total.ExtensiveTTSDataset 88 | manifest_filepath: ${train_dataset} 89 | sample_rate: ${model.sample_rate} 90 | sup_data_path: ${sup_data_path} 91 | sup_data_types: ${sup_data_types} 92 | n_fft: ${model.n_fft} 93 | win_length: ${model.n_window_size} 94 | hop_length: ${model.n_window_stride} 95 | window: ${model.window} 96 | n_mels: ${model.n_mel_channels} 97 | lowfreq: ${model.lowfreq} 98 | highfreq: ${model.highfreq} 99 | max_duration: null 100 | min_duration: 0.7 101 | ignore_file: null 102 | trim: False 103 | pitch_fmin: ${model.pitch_fmin} 104 | pitch_fmax: ${model.pitch_fmax} 105 | codec_model: ${codec_model} 106 | codec_sum: false 107 | 108 | dataloader_params: 109 | # drop_last: false 110 | # shuffle: true 111 | # batch_size: ${batch_size} 112 | num_workers: ${num_workers} 113 | pin_memory: true 114 | batch_sampler: 115 | batch_size: ${batch_size} 116 | boundaries: [32,300,400,500,600,700,800,900,1000] 117 | num_replicas: ${trainer.devices} 118 | shuffle: true 119 | 120 | validation_ds: 121 | dataset: 122 | _target_: torchdata.data_total.ExtensiveTTSDataset 123 | manifest_filepath: ${validation_datasets} 124 | sample_rate: ${model.sample_rate} 125 | sup_data_path: ${sup_data_path} 126 | sup_data_types: ${sup_data_types} 127 | n_fft: ${model.n_fft} 128 | win_length: ${model.n_window_size} 129 | hop_length: ${model.n_window_stride} 130 | window: ${model.window} 131 | n_mels: ${model.n_mel_channels} 132 | lowfreq: ${model.lowfreq} 133 | highfreq: ${model.highfreq} 134 | 
max_duration: null 135 | min_duration: 0.7 136 | ignore_file: null 137 | trim: False 138 | pitch_fmin: ${model.pitch_fmin} 139 | pitch_fmax: ${model.pitch_fmax} 140 | codec_model: ${codec_model} 141 | codec_sum: false 142 | 143 | dataloader_params: 144 | drop_last: false 145 | shuffle: false 146 | batch_size: ${batch_size} 147 | num_workers: ${num_workers} 148 | pin_memory: false 149 | 150 | preprocessor: 151 | _target_: module.preprocessor.FilterbankFeatures 152 | nfilt: ${model.n_mel_channels} 153 | highfreq: ${model.highfreq} 154 | log: true 155 | log_zero_guard_type: clamp 156 | log_zero_guard_value: 1e-05 157 | lowfreq: ${model.lowfreq} 158 | n_fft: ${model.n_fft} 159 | n_window_size: ${model.n_window_size} 160 | n_window_stride: ${model.n_window_stride} 161 | pad_to: 1 162 | pad_value: 0 163 | sample_rate: ${model.sample_rate} 164 | window: ${model.window} 165 | normalize: null 166 | preemph: null 167 | dither: 0.0 168 | frame_splicing: 1 169 | stft_conv: false 170 | nb_augmentation_prob : 0 171 | mag_power: 1.0 172 | exact_pad: true 173 | use_grads: true 174 | 175 | synthesizer: 176 | _target_: module.ref_vits_module.RefSynthesizerTrn 177 | inter_channels: 192 178 | hidden_channels: 192 179 | filter_channels: 768 180 | n_heads: 2 181 | n_layers: 6 182 | kernel_size: 3 183 | p_dropout: 0.1 184 | resblock: "1" 185 | resblock_kernel_sizes: [3,7,11] 186 | resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]] 187 | upsample_rates: [8,8,2,2] 188 | upsample_initial_channel: 512 189 | upsample_kernel_sizes: [16,16,4,4] 190 | n_speakers: ${model.n_speakers} 191 | gin_channels: 256 # for multi-speaker 192 | ref_encoder: 193 | _target_: module.ref_mixer_codec_only.RefMixer 194 | initial_dim: 8 195 | gru_out: ${model.synthesizer.gin_channels} 196 | expansion_factor: 4 197 | num_layers: 6 198 | kernel_sizes: [11, 13, 15, 17, 19, 21] 199 | conv_type: "depth-wise" 200 | 201 | optim: 202 | _target_: torch.optim.AdamW 203 | lr: 2e-4 204 | betas: [0.9, 0.99] 205 | eps: 1e-9 206 | 207 | sched: 208 | name: ExponentialLR 209 | lr_decay: 0.999875 210 | 211 | trainer: 212 | num_nodes: 1 213 | devices: ${ngpu} 214 | accelerator: gpu 215 | strategy: ddp 216 | precision: 16 217 | # amp_backend: 'apex' 218 | # amp_level: 'O2' 219 | # benchmark: true 220 | max_epochs: 1000 221 | accumulate_grad_batches: 1 222 | enable_checkpointing: false # Provided by exp_manager 223 | logger: false # Provided by exp_manager 224 | log_every_n_steps: 50 225 | check_val_every_n_epoch: 1 226 | 227 | exp_manager: 228 | exp_dir: exp_VITS 229 | name: ${name} 230 | create_tensorboard_logger: false 231 | create_checkpoint_callback: true 232 | checkpoint_callback_params: 233 | monitor: loss_gen_all 234 | mode: min 235 | create_wandb_logger: true 236 | wandb_logger_kwargs: 237 | name: ${name} 238 | project: RefVits 239 | entity: null 240 | resume_if_exists: true 241 | resume_ignore_no_checkpoint: true -------------------------------------------------------------------------------- /conf/ref_vits.yaml: -------------------------------------------------------------------------------- 1 | # This config contains the default values for training VITS model on LJSpeech dataset. 2 | # If you want to train model on other dataset, you can change config values according to your dataset. 3 | # Most dataset-specific arguments are in the head of the config file, see below. 4 | 5 | name: VITS_GST 6 | 7 | batch_size: ?? 8 | num_workers: ?? 9 | ngpu: ?? 
10 | 11 | train_dataset: "data/train_manifest.json" 12 | validation_datasets: "data/valid_manifest.json" 13 | sup_data_path: "sup_data/" 14 | sup_data_types: [ "speaker_id"] 15 | 16 | whitelist_path: "sup_data/text/whitelist/lj_speech.tsv" 17 | phoneme_dict_path: "sup_data/text/cmudict-0.7b_nv22.10" 18 | heteronyms_path: "sup_data/text/heteronyms-052722" 19 | 20 | # Default values from librosa.pyin 21 | pitch_fmin: 65.40639132514966 22 | pitch_fmax: 2093.004522404789 23 | 24 | sample_rate: 24000 25 | n_mel_channels: 80 26 | n_window_size: 1024 27 | n_window_stride: 256 28 | n_fft: 1024 29 | lowfreq: 0 30 | highfreq: null 31 | window: hann 32 | 33 | codec_model: "encodec" 34 | 35 | 36 | model: 37 | pitch_fmin: ${pitch_fmin} 38 | pitch_fmax: ${pitch_fmax} 39 | 40 | sample_rate: ${sample_rate} 41 | n_mel_channels: ${n_mel_channels} 42 | n_window_size: ${n_window_size} 43 | n_window_stride: ${n_window_stride} 44 | n_fft: ${n_fft} 45 | lowfreq: ${lowfreq} 46 | highfreq: ${highfreq} 47 | window: ${window} 48 | mel_fmin: 0.0 49 | mel_fmax: null 50 | 51 | n_speakers: ?? 52 | segment_size: 8192 53 | c_mel: 45 54 | c_kl: 1. 55 | use_spectral_norm: false 56 | 57 | text_normalizer: 58 | _target_: nemo_text_processing.text_normalization.normalize.Normalizer 59 | lang: en 60 | input_case: cased 61 | whitelist: ${whitelist_path} 62 | 63 | text_normalizer_call_kwargs: 64 | verbose: false 65 | punct_pre_process: true 66 | punct_post_process: true 67 | 68 | text_tokenizer: 69 | _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer 70 | punct: true 71 | apostrophe: true 72 | pad_with_space: false 73 | g2p: 74 | _target_: nemo_text_processing.g2p.modules.IPAG2P 75 | phoneme_dict: ${phoneme_dict_path} 76 | heteronyms: ${heteronyms_path} 77 | phoneme_probability: 0.8 78 | # Relies on the heteronyms list for anything that needs to be disambiguated 79 | ignore_ambiguous_words: false 80 | use_chars: true 81 | use_stresses: true 82 | 83 | train_ds: 84 | dataset: 85 | _target_: torchdata.data_total.ExtensiveTTSDataset 86 | manifest_filepath: ${train_dataset} 87 | sample_rate: ${model.sample_rate} 88 | sup_data_path: ${sup_data_path} 89 | sup_data_types: ${sup_data_types} 90 | n_fft: ${model.n_fft} 91 | win_length: ${model.n_window_size} 92 | hop_length: ${model.n_window_stride} 93 | window: ${model.window} 94 | n_mels: ${model.n_mel_channels} 95 | lowfreq: ${model.lowfreq} 96 | highfreq: ${model.highfreq} 97 | max_duration: null 98 | min_duration: 0.7 99 | ignore_file: null 100 | trim: False 101 | pitch_fmin: ${model.pitch_fmin} 102 | pitch_fmax: ${model.pitch_fmax} 103 | codec_model: ${codec_model} 104 | 105 | dataloader_params: 106 | drop_last: false 107 | shuffle: true 108 | batch_size: ${batch_size} 109 | num_workers: ${num_workers} 110 | pin_memory: true 111 | 112 | validation_ds: 113 | dataset: 114 | _target_: torchdata.data_total.ExtensiveTTSDataset 115 | manifest_filepath: ${validation_datasets} 116 | sample_rate: ${model.sample_rate} 117 | sup_data_path: ${sup_data_path} 118 | sup_data_types: ${sup_data_types} 119 | n_fft: ${model.n_fft} 120 | win_length: ${model.n_window_size} 121 | hop_length: ${model.n_window_stride} 122 | window: ${model.window} 123 | n_mels: ${model.n_mel_channels} 124 | lowfreq: ${model.lowfreq} 125 | highfreq: ${model.highfreq} 126 | max_duration: null 127 | min_duration: 0.7 128 | ignore_file: null 129 | trim: False 130 | pitch_fmin: ${model.pitch_fmin} 131 | pitch_fmax: ${model.pitch_fmax} 132 | codec_model: ${codec_model} 133 | 134 | 
dataloader_params: 135 | drop_last: false 136 | shuffle: false 137 | batch_size: ${batch_size} 138 | num_workers: ${num_workers} 139 | pin_memory: false 140 | 141 | preprocessor: 142 | _target_: module.preprocessor.FilterbankFeatures 143 | nfilt: ${model.n_mel_channels} 144 | highfreq: ${model.highfreq} 145 | log: true 146 | log_zero_guard_type: clamp 147 | log_zero_guard_value: 1e-05 148 | lowfreq: ${model.lowfreq} 149 | n_fft: ${model.n_fft} 150 | n_window_size: ${model.n_window_size} 151 | n_window_stride: ${model.n_window_stride} 152 | pad_to: 1 153 | pad_value: 0 154 | sample_rate: ${model.sample_rate} 155 | window: ${model.window} 156 | normalize: null 157 | preemph: null 158 | dither: 0.0 159 | frame_splicing: 1 160 | stft_conv: false 161 | nb_augmentation_prob : 0 162 | mag_power: 1.0 163 | exact_pad: true 164 | use_grads: true 165 | 166 | synthesizer: 167 | _target_: module.ref_vits_module.RefSynthesizerTrn 168 | inter_channels: 192 169 | hidden_channels: 192 170 | filter_channels: 768 171 | n_heads: 2 172 | n_layers: 6 173 | kernel_size: 3 174 | p_dropout: 0.1 175 | resblock: "1" 176 | resblock_kernel_sizes: [3,7,11] 177 | resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]] 178 | upsample_rates: [8,8,2,2] 179 | upsample_initial_channel: 512 180 | upsample_kernel_sizes: [16,16,4,4] 181 | n_speakers: ${model.n_speakers} 182 | gin_channels: 256 # for multi-speaker 183 | ref_encoder: 184 | _target_: module.ref_gst.GlobalStyleToken 185 | cnn_filters: [ 32, 32, 64, 64, 128, 128 ] 186 | dropout: 0.2 187 | gru_hidden: ${model.synthesizer.gin_channels} 188 | gst_size: ${model.synthesizer.gin_channels} 189 | initial_dim: 513 190 | n_style_token: 10 191 | n_style_attn_head: 4 192 | 193 | optim: 194 | _target_: torch.optim.AdamW 195 | lr: 2e-4 196 | betas: [0.9, 0.99] 197 | eps: 1e-9 198 | 199 | sched: 200 | name: ExponentialLR 201 | lr_decay: 0.999875 202 | 203 | trainer: 204 | num_nodes: 1 205 | devices: ${ngpu} 206 | accelerator: gpu 207 | strategy: ddp 208 | precision: 16 209 | # amp_backend: 'apex' 210 | # amp_level: 'O2' 211 | # benchmark: true 212 | max_epochs: 1000 213 | accumulate_grad_batches: 1 214 | enable_checkpointing: false # Provided by exp_manager 215 | logger: false # Provided by exp_manager 216 | log_every_n_steps: 50 217 | check_val_every_n_epoch: 1 218 | 219 | exp_manager: 220 | exp_dir: exp_VITS 221 | name: ${name} 222 | create_tensorboard_logger: false 223 | create_checkpoint_callback: true 224 | checkpoint_callback_params: 225 | monitor: loss_gen_all 226 | mode: min 227 | create_wandb_logger: true 228 | wandb_logger_kwargs: 229 | name: ${name} 230 | project: RefVits 231 | entity: null 232 | resume_if_exists: true 233 | resume_ignore_no_checkpoint: true -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from pathlib import Path 3 | import json 4 | import os 5 | from copy import deepcopy 6 | from tqdm import tqdm 7 | import time 8 | from statistics import mean 9 | 10 | import soundfile as sf 11 | import torch 12 | import torchaudio 13 | from torch.utils.data import DataLoader 14 | from transformers import AutoModel, AutoTokenizer 15 | 16 | from model.vits import VitsModel 17 | from module.ref_gst import GlobalStyleTokenForCodec 18 | 19 | from nemo.core.config import hydra_runner 20 | from hydra.utils import instantiate 21 | 22 | def match_file_for_transfer(manifests): 23 | # 108 speakers 24 | file_ids = 
[x['audio_filepath'] for x in manifests] 25 | file_ids = [x.split("/")[-1].replace(".wav", "") for x in file_ids] 26 | p287 = [x for x in file_ids if x.split("_")[0]=='p287'] 27 | reversed_ids = deepcopy(file_ids) 28 | reversed_ids.reverse() 29 | matched = {} 30 | idx = 0 31 | for i in range(len(file_ids)): 32 | if file_ids[i].split("_")[0]=='p362': 33 | matched[file_ids[i]] = p287[idx] 34 | idx += 1 35 | else: 36 | matched[file_ids[i]] = reversed_ids[i] 37 | 38 | return matched 39 | 40 | def unseen_files(manifests): 41 | file_ids = [x['audio_filepath'] for x in manifests] 42 | file_ids = [x.split("/")[-1].replace(".wav", "") for x in file_ids] 43 | 44 | matched = {} 45 | 46 | unseen = "path-to-unseen-files" 47 | unseen = [json.loads(x) for x in open(unseen).readlines()] 48 | 49 | random_choose = torch.randint(len(unseen), (1, len(file_ids)))[0] 50 | assert len(file_ids)==random_choose.size(0) 51 | for i, idx in enumerate(random_choose): 52 | idx = idx.data.item() 53 | x = unseen[idx] 54 | matched[file_ids[i]] = x['audio_filepath'].split("/")[-1].replace(".wav", "") 55 | 56 | return matched 57 | 58 | @hydra_runner(config_path="conf", config_name='infer_transfer') 59 | def main(cfg): 60 | device = cfg.device 61 | 62 | manifests = [json.loads(x) for x in open(cfg.manifest_path).readlines()] 63 | # matching files for transfer 64 | matched_file = match_file_for_transfer(manifests) 65 | # matched_file = unseen_files(manifests) 66 | 67 | model = VitsModel.load_from_checkpoint(cfg.checkpoint_path).to(device) 68 | 69 | result_dir = Path("result") 70 | # result_dir = result_dir / cfg.checkpoint_path.split("/")[1] / cfg.checkpoint_path.split("/")[-1] 71 | result_dir = result_dir / "unseen" / cfg.checkpoint_path.split("/")[1] / cfg.checkpoint_path.split("/")[-1] 72 | 73 | if not result_dir.exists(): 74 | result_dir.mkdir(parents=True, exist_ok=True) 75 | 76 | times = [] 77 | rtfs = [] 78 | 79 | for x in tqdm(manifests): 80 | file_id = x['audio_filepath'].split("/")[-1].replace(".wav", "") 81 | tokenized = model.tokenizer(x['normalized_text']) 82 | tokens = torch.tensor(tokenized).long().unsqueeze(0).to(device) 83 | tokens_length = torch.tensor(len(tokenized)).long().unsqueeze(0).to(device) 84 | 85 | codec, codec_len, spec, spec_len = None, None, None, None 86 | if 'codec' in cfg.checkpoint_path.lower(): 87 | codec_path = matched_file[file_id] + ".pt" 88 | codec_path = Path(cfg.sup_data_path) / "encodec" / codec_path 89 | codec = torch.load(codec_path).long().unsqueeze(0).to(device) 90 | codec_len = torch.tensor(codec.size(2)).long().unsqueeze(0).to(device) 91 | else: 92 | processor = model.audio_to_melspec_processor 93 | ref_path = f"{matched_file[file_id].split('_')[0]}/{matched_file[file_id]}.wav" 94 | ref_path = Path(cfg.sup_data_path) / ref_path 95 | audio, sr = torchaudio.load(ref_path) 96 | spec, spec_len = processor(audio, torch.tensor([audio.size(1)]),linear_spec=True) 97 | spec = spec.to(device) 98 | spec_len = spec_len.to(device) 99 | 100 | spk_id = torch.tensor(x['speaker']).long().unsqueeze(0).to(device) 101 | 102 | start_time = time.perf_counter() 103 | if codec is not None: 104 | if isinstance(model.net_g.ref_encoder, GlobalStyleTokenForCodec): 105 | generator = model.net_g.float() 106 | wav, _, _, _ = generator.infer(tokens, tokens_length, speakers=spk_id, ref_spec=codec.float(), 107 | ref_spec_lens=codec_len) 108 | else: 109 | wav, _, _, _ = model.net_g.infer(tokens, tokens_length, speakers=spk_id, ref_spec=codec, ref_spec_lens=codec_len) 110 | elif spec is not None: 111 | wav, _, _, 
_ = model.net_g.infer(tokens, tokens_length, speakers=spk_id, ref_spec=spec, 112 | ref_spec_lens=spec_len) 113 | # wav_len_sec = int(wav.size(-1)) / 24000 114 | points_per_second = int(wav.size(-1)) / (time.perf_counter() - start_time) 115 | # rtf = (time.perf_counter() - start_time) / int(wav.size(-1)) 116 | # wav = wav[0][0].detach().cpu().numpy() 117 | # print(f"inference speed ={per_frame_speed:1.2f} / sec") 118 | times.append(points_per_second) 119 | # rtfs.append(rtf) 120 | # result_filename = file_id + "_" + matched_file[file_id] + ".wav" 121 | # result_filename = result_dir / result_filename 122 | # sf.write(result_filename, wav, samplerate=cfg.sample_rate, format="WAV", subtype="PCM_16") 123 | print(cfg.checkpoint_path.split("/")[1]) 124 | print(f"average inference speed: {mean(times)}") 125 | # print(f"average RTFs: {mean(rtfs)}") 126 | 127 | 128 | if __name__ == "__main__": 129 | main() -------------------------------------------------------------------------------- /model/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from contextlib import ExitStack, contextmanager 3 | from typing import List 4 | 5 | import torch 6 | 7 | from nemo.collections.tts.helpers.helpers import OperationMode 8 | from nemo.core.classes import ModelPT 9 | from nemo.core.classes.common import PretrainedModelInfo, typecheck 10 | from nemo.core.neural_types.elements import AudioSignal 11 | from nemo.core.neural_types.neural_type import NeuralType 12 | 13 | class TextToWaveform(ModelPT, ABC): 14 | """ Base class for all end-to-end TTS models that generate a waveform from text """ 15 | 16 | @abstractmethod 17 | def parse(self, str_input: str, **kwargs) -> 'torch.tensor': 18 | """ 19 | A helper function that accepts a raw python string and turns it into a tensor. The tensor should have 2 20 | dimensions. The first is the batch, which should be of size 1. The second should represent time. The tensor 21 | should represent either tokenized or embedded text, depending on the model. 22 | """ 23 | 24 | @abstractmethod 25 | def convert_text_to_waveform(self, *, tokens: 'torch.tensor', **kwargs) -> 'List[torch.tensor]': 26 | """ 27 | Accepts a batch of text and returns a list containing a batch of audio 28 | Args: 29 | tokens: A torch tensor representing the text to be converted to speech 30 | Returns: 31 | audio: A list of length batch_size containing torch tensors representing the waveform output 32 | """ 33 | 34 | @classmethod 35 | def list_available_models(cls) -> 'List[PretrainedModelInfo]': 36 | """ 37 | This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. 38 | Returns: 39 | List of available pre-trained models. 
40 | """ 41 | list_of_models = [] 42 | for subclass in cls.__subclasses__(): 43 | subclass_models = subclass.list_available_models() 44 | if subclass_models is not None and len(subclass_models) > 0: 45 | list_of_models.extend(subclass_models) 46 | return list_of_models -------------------------------------------------------------------------------- /model/helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, Tuple 3 | 4 | def slice_segments(x, ids_str, segment_size=4): 5 | """ 6 | Time-wise slicing (patching) of bathches for audio/spectrogram 7 | [B x C x T] -> [B x C x segment_size] 8 | """ 9 | ret = torch.zeros_like(x[:, :, :segment_size]) 10 | for i in range(x.size(0)): 11 | idx_str = ids_str[i] 12 | idx_end = idx_str + segment_size 13 | x_i = x[i] 14 | if idx_end >= x.size(2): 15 | # pad the sample if it is shorter than the segment size 16 | x_i = torch.nn.functional.pad(x_i, (0, (idx_end + 1) - x.size(2))) 17 | ret[i] = x_i[:, idx_str:idx_end] 18 | return ret 19 | 20 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 21 | """ 22 | Chooses random indices and slices segments from batch 23 | [B x C x T] -> [B x C x segment_size] 24 | """ 25 | b, d, t = x.size() 26 | if x_lengths is None: 27 | x_lengths = t 28 | ids_str_max = x_lengths - segment_size + 1 29 | ids_str_max = ids_str_max.to(device=x.device) 30 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 31 | 32 | ret = slice_segments(x, ids_str, segment_size) 33 | 34 | return ret, ids_str 35 | 36 | def clip_grad_value_(parameters, clip_value, norm_type=2): 37 | if isinstance(parameters, torch.Tensor): 38 | parameters = [parameters] 39 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 40 | norm_type = float(norm_type) 41 | if clip_value is not None: 42 | clip_value = float(clip_value) 43 | 44 | total_norm = 0 45 | for p in parameters: 46 | param_norm = p.grad.data.norm(norm_type) 47 | total_norm += param_norm.item() ** norm_type 48 | if clip_value is not None: 49 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 50 | total_norm = total_norm ** (1.0 / norm_type) 51 | return total_norm 52 | 53 | 54 | def convert_pad_shape(pad_shape): 55 | l = pad_shape[::-1] 56 | pad_shape = [item for sublist in l for item in sublist] 57 | return pad_shape 58 | 59 | 60 | def generate_path(duration, mask): 61 | """ 62 | duration: [b, 1, t_x] 63 | mask: [b, 1, t_y, t_x] 64 | """ 65 | b, _, t_y, t_x = mask.shape 66 | cum_duration = torch.cumsum(duration, -1) 67 | 68 | cum_duration_flat = cum_duration.view(b * t_x) 69 | path = get_mask_from_lengths(cum_duration_flat, torch.Tensor(t_y).reshape(1, 1, -1)).to(mask.dtype) 70 | path = path.view(b, t_x, t_y) 71 | path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 72 | path = path.unsqueeze(1).transpose(2, 3) * mask 73 | return path 74 | 75 | def get_mask_from_lengths(lengths: Optional[torch.Tensor] = None, x: Optional[torch.Tensor] = None,) -> torch.Tensor: 76 | """Constructs binary mask from a 1D torch tensor of input lengths 77 | Args: 78 | lengths: Optional[torch.tensor] (torch.tensor): 1D tensor with lengths 79 | x: Optional[torch.tensor] = tensor to be used on, last dimension is for mask 80 | Returns: 81 | mask (torch.tensor): num_sequences x max_length x 1 binary tensor 82 | """ 83 | if lengths is None: 84 | assert x is not None 85 | return torch.ones(x.shape[-1], dtype=torch.bool, device=x.device) 86 | else: 
87 | if x is None: 88 | max_len = torch.max(lengths) 89 | else: 90 | max_len = x.shape[-1] 91 | ids = torch.arange(0, max_len, device=lengths.device, dtype=lengths.dtype) 92 | mask = ids < lengths.unsqueeze(1) 93 | return mask 94 | 95 | -------------------------------------------------------------------------------- /model/splines.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # The MIT License (MIT) 16 | # Copyright (c) 2020, nicolas deutschmann 17 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 18 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 20 | 21 | import numpy as np 22 | import torch 23 | import torch.nn.functional as F 24 | 25 | 26 | def piecewise_linear_transform(x, q_tilde, compute_jacobian=True, outlier_passthru=True): 27 | """Apply an element-wise piecewise-linear transformation to some variables 28 | 29 | Args: 30 | x : torch.Tensor 31 | a tensor with shape (N,k) where N is the batch dimension while k is the dimension of the variable space. This variable span the k-dimensional unit 32 | hypercube 33 | 34 | q_tilde: torch.Tensor 35 | is a tensor with shape (N,k,b) where b is the number of bins. 36 | This contains the un-normalized heights of the bins of the piecewise-constant PDF for dimension k, 37 | i.e. q_tilde lives in all of R and we don't impose a constraint on their sum yet. 38 | Normalization is imposed in this function using softmax. 39 | 40 | compute_jacobian : bool, optional 41 | determines whether the jacobian should be compute or None is returned 42 | 43 | Returns: 44 | tuple of torch.Tensor 45 | pair `(y,h)`. 46 | - `y` is a tensor with shape (N,k) living in the k-dimensional unit hypercube 47 | - `j` is the jacobian of the transformation with shape (N,) if compute_jacobian==True, else None. 
48 | """ 49 | logj = None 50 | 51 | third_dimension_softmax = torch.nn.Softmax(dim=2) 52 | 53 | # Compute the bin width w 54 | N, k, b = q_tilde.shape 55 | Nx, kx = x.shape 56 | assert N == Nx and k == kx, "Shape mismatch" 57 | 58 | w = 1.0 / b 59 | 60 | # Compute normalized bin heights with softmax function on bin dimension 61 | q = 1.0 / w * third_dimension_softmax(q_tilde) 62 | # x is in the mx-th bin: x \in [0,1], 63 | # mx \in [[0,b-1]], so we clamp away the case x == 1 64 | mx = torch.clamp(torch.floor(b * x), 0, b - 1).to(torch.long) 65 | # Need special error handling because trying to index with mx 66 | # if it contains nans will lock the GPU. (device-side assert triggered) 67 | if torch.any(torch.isnan(mx)).item() or torch.any(mx < 0) or torch.any(mx >= b): 68 | raise AvertedCUDARuntimeError("NaN detected in PWLinear bin indexing") 69 | 70 | # We compute the output variable in-place 71 | out = x - mx * w # alpha (element of [0.,w], the position of x in its bin 72 | 73 | # Multiply by the slope 74 | # q has shape (N,k,b), mxu = mx.unsqueeze(-1) has shape (N,k) with entries that are a b-index 75 | # gather defines slope[i, j, k] = q[i, j, mxu[i, j, k]] with k taking only 0 as a value 76 | # i.e. we say slope[i, j] = q[i, j, mx [i, j]] 77 | slopes = torch.gather(q, 2, mx.unsqueeze(-1)).squeeze(-1) 78 | out = out * slopes 79 | # The jacobian is the product of the slopes in all dimensions 80 | 81 | # Compute the integral over the left-bins. 82 | # 1. Compute all integrals: cumulative sum of bin height * bin weight. 83 | # We want that index i contains the cumsum *strictly to the left* so we shift by 1 84 | # leaving the first entry null, which is achieved with a roll and assignment 85 | q_left_integrals = torch.roll(torch.cumsum(q, 2) * w, 1, 2) 86 | q_left_integrals[:, :, 0] = 0 87 | 88 | # 2. Access the correct index to get the left integral of each point and add it to our transformation 89 | out = out + torch.gather(q_left_integrals, 2, mx.unsqueeze(-1)).squeeze(-1) 90 | 91 | # Regularization: points must be strictly within the unit hypercube 92 | # Use the dtype information from pytorch 93 | eps = torch.finfo(out.dtype).eps 94 | out = out.clamp(min=eps, max=1.0 - eps) 95 | oob_mask = torch.logical_or(x < 0.0, x > 1.0).detach().float() 96 | if outlier_passthru: 97 | out = out * (1 - oob_mask) + x * oob_mask 98 | slopes = slopes * (1 - oob_mask) + oob_mask 99 | 100 | if compute_jacobian: 101 | # logj = torch.log(torch.prod(slopes.float(), 1)) 102 | logj = torch.sum(torch.log(slopes), 1) 103 | del slopes 104 | 105 | return out, logj 106 | 107 | 108 | def piecewise_linear_inverse_transform(y, q_tilde, compute_jacobian=True, outlier_passthru=True): 109 | """ 110 | Apply inverse of an element-wise piecewise-linear transformation to some 111 | variables 112 | 113 | Args: 114 | y : torch.Tensor 115 | a tensor with shape (N,k) where N is the batch dimension while k is the dimension of the variable space. This variable span the k-dimensional unit hypercube 116 | 117 | q_tilde: torch.Tensor 118 | is a tensor with shape (N,k,b) where b is the number of bins. This contains the un-normalized heights of the bins of the piecewise-constant PDF for dimension k, i.e. q_tilde lives in all of R and we don't impose a constraint on their sum yet. Normalization is imposed in this function using softmax. 119 | 120 | compute_jacobian : bool, optional 121 | determines whether the jacobian should be compute or None is returned 122 | 123 | Returns: 124 | tuple of torch.Tensor 125 | pair `(x,h)`. 
126 | - `x` is a tensor with shape (N,k) living in the k-dimensional unit hypercube 127 | - `j` is the jacobian of the transformation with shape (N,) if compute_jacobian==True, else None. 128 | """ 129 | 130 | third_dimension_softmax = torch.nn.Softmax(dim=2) 131 | # Compute the bin width w 132 | N, k, b = q_tilde.shape 133 | Ny, ky = y.shape 134 | assert N == Ny and k == ky, "Shape mismatch" 135 | 136 | w = 1.0 / b 137 | 138 | # Compute normalized bin heights with softmax function on the bin dimension 139 | q = 1.0 / w * third_dimension_softmax(q_tilde) 140 | 141 | # Compute the integral over the left-bins in the forward transform. 142 | # 1. Compute all integrals: cumulative sum of bin height * bin weight. 143 | # We want that index i contains the cumsum *strictly to the left*, 144 | # so we shift by 1 leaving the first entry null, 145 | # which is achieved with a roll and assignment 146 | q_left_integrals = torch.roll(torch.cumsum(q.float(), 2) * w, 1, 2) 147 | q_left_integrals[:, :, 0] = 0 148 | 149 | # Find which bin each y belongs to by finding the smallest bin such that 150 | # y - q_left_integral is positive 151 | 152 | edges = (y.unsqueeze(-1) - q_left_integrals).detach() 153 | # y and q_left_integrals are between 0 and 1, 154 | # so that their difference is at most 1. 155 | # By setting the negative values to 2., we know that the 156 | # smallest value left is the smallest positive 157 | edges[edges < 0] = 2.0 158 | edges = torch.clamp(torch.argmin(edges, dim=2), 0, b - 1).to(torch.long) 159 | 160 | # Need special error handling because trying to index with mx 161 | # if it contains nans will lock the GPU. (device-side assert triggered) 162 | if torch.any(torch.isnan(edges)).item() or torch.any(edges < 0) or torch.any(edges >= b): 163 | raise AvertedCUDARuntimeError("NaN detected in PWLinear bin indexing") 164 | 165 | # Gather the left integrals at each edge. See comment about gathering in q_left_integrals 166 | # for the unsqueeze 167 | q_left_integrals = q_left_integrals.gather(2, edges.unsqueeze(-1)).squeeze(-1) 168 | 169 | # Gather the slope at each edge. 
170 | q = q.gather(2, edges.unsqueeze(-1)).squeeze(-1) 171 | 172 | # Build the output 173 | x = (y - q_left_integrals) / q + edges * w 174 | 175 | # Regularization: points must be strictly within the unit hypercube 176 | # Use the dtype information from pytorch 177 | eps = torch.finfo(x.dtype).eps 178 | x = x.clamp(min=eps, max=1.0 - eps) 179 | oob_mask = torch.logical_or(y < 0.0, y > 1.0).detach().float() 180 | if outlier_passthru: 181 | x = x * (1 - oob_mask) + y * oob_mask 182 | q = q * (1 - oob_mask) + oob_mask 183 | 184 | # Prepare the jacobian 185 | logj = None 186 | if compute_jacobian: 187 | # logj = - torch.log(torch.prod(q, 1)) 188 | logj = -torch.sum(torch.log(q.float()), 1) 189 | return x.detach(), logj 190 | 191 | 192 | def unbounded_piecewise_quadratic_transform(x, w_tilde, v_tilde, upper=1, lower=0, inverse=False): 193 | assert upper > lower 194 | _range = upper - lower 195 | inside_interval_mask = (x >= lower) & (x < upper) 196 | outside_interval_mask = ~inside_interval_mask 197 | 198 | outputs = torch.zeros_like(x) 199 | log_j = torch.zeros_like(x) 200 | 201 | outputs[outside_interval_mask] = x[outside_interval_mask] 202 | log_j[outside_interval_mask] = 0 203 | 204 | output, _log_j = piecewise_quadratic_transform( 205 | (x[inside_interval_mask] - lower) / _range, 206 | w_tilde[inside_interval_mask, :], 207 | v_tilde[inside_interval_mask, :], 208 | inverse=inverse, 209 | ) 210 | outputs[inside_interval_mask] = output * _range + lower 211 | if not inverse: 212 | # the before and after transformation cancel out, so the log_j would be just as it is. 213 | log_j[inside_interval_mask] = _log_j 214 | else: 215 | log_j = None 216 | return outputs, log_j 217 | 218 | 219 | def weighted_softmax(v, w): 220 | # to avoid NaN... 221 | v = v - torch.max(v, dim=-1, keepdim=True)[0] 222 | v = torch.exp(v) + 1e-8 # to avoid NaN... 
223 | v_sum = torch.sum((v[..., :-1] + v[..., 1:]) / 2 * w, dim=-1, keepdim=True) 224 | return v / v_sum 225 | 226 | 227 | def piecewise_quadratic_transform(x, w_tilde, v_tilde, inverse=False): 228 | """Element-wise piecewise-quadratic transformation 229 | Args: 230 | x : torch.Tensor 231 | *, The variable spans the D-dim unit hypercube ([0,1)) 232 | w_tilde : torch.Tensor 233 | * x K defined in the paper 234 | v_tilde : torch.Tensor 235 | * x (K+1) defined in the paper 236 | inverse : bool 237 | forward or inverse 238 | Returns: 239 | c : torch.Tensor 240 | *, transformed value 241 | log_j : torch.Tensor 242 | *, log determinant of the Jacobian matrix 243 | """ 244 | w = torch.softmax(w_tilde, dim=-1) 245 | v = weighted_softmax(v_tilde, w) 246 | w_cumsum = torch.cumsum(w, dim=-1) 247 | # force sum = 1 248 | w_cumsum[..., -1] = 1.0 249 | w_cumsum_shift = F.pad(w_cumsum, (1, 0), 'constant', 0) 250 | cdf = torch.cumsum((v[..., 1:] + v[..., :-1]) / 2 * w, dim=-1) 251 | # force sum = 1 252 | cdf[..., -1] = 1.0 253 | cdf_shift = F.pad(cdf, (1, 0), 'constant', 0) 254 | 255 | if not inverse: 256 | # * x D x 1, (w_cumsum[idx-1] < x <= w_cumsum[idx]) 257 | bin_index = torch.searchsorted(w_cumsum, x.unsqueeze(-1)) 258 | else: 259 | # * x D x 1, (cdf[idx-1] < x <= cdf[idx]) 260 | bin_index = torch.searchsorted(cdf, x.unsqueeze(-1)) 261 | 262 | w_b = torch.gather(w, -1, bin_index).squeeze(-1) 263 | w_bn1 = torch.gather(w_cumsum_shift, -1, bin_index).squeeze(-1) 264 | v_b = torch.gather(v, -1, bin_index).squeeze(-1) 265 | v_bp1 = torch.gather(v, -1, bin_index + 1).squeeze(-1) 266 | cdf_bn1 = torch.gather(cdf_shift, -1, bin_index).squeeze(-1) 267 | 268 | if not inverse: 269 | alpha = (x - w_bn1) / w_b.clamp(min=torch.finfo(w_b.dtype).eps) 270 | c = (alpha ** 2) / 2 * (v_bp1 - v_b) * w_b + alpha * v_b * w_b + cdf_bn1 271 | 272 | # just sum of log pdfs 273 | log_j = torch.lerp(v_b, v_bp1, alpha).clamp(min=torch.finfo(c.dtype).eps).log() 274 | 275 | # make sure it falls into [0,1) 276 | c = c.clamp(min=torch.finfo(c.dtype).eps, max=1.0 - torch.finfo(c.dtype).eps) 277 | return c, log_j 278 | else: 279 | # quadratic equation for alpha 280 | # alpha should fall into (0, 1]. 
Since a, b > 0, the symmetry axis -b/2a < 0 and we should pick the larger root 281 | # skip calculating the log_j in inverse since we don't need it 282 | a = (v_bp1 - v_b) * w_b / 2 283 | b = v_b * w_b 284 | c = cdf_bn1 - x 285 | alpha = (-b + torch.sqrt((b ** 2) - 4 * a * c)) / (2 * a) 286 | inv = alpha * w_b + w_bn1 287 | 288 | # make sure it falls into [0,1) 289 | inv = inv.clamp(min=torch.finfo(c.dtype).eps, max=1.0 - torch.finfo(inv.dtype).eps) 290 | return inv, None 291 | 292 | 293 | def piecewise_rational_quadratic_transform( 294 | inputs, 295 | unnormalized_widths, 296 | unnormalized_heights, 297 | unnormalized_derivatives, 298 | inverse=False, 299 | tails=None, 300 | tail_bound=1.0, 301 | min_bin_width=1e-3, 302 | min_bin_height=1e-3, 303 | min_derivative=1e-3, 304 | ): 305 | 306 | if tails is None: 307 | spline_fn = rational_quadratic_spline 308 | spline_kwargs = {} 309 | else: 310 | spline_fn = unconstrained_rational_quadratic_spline 311 | spline_kwargs = {'tails': tails, 'tail_bound': tail_bound} 312 | 313 | outputs, logabsdet = spline_fn( 314 | inputs=inputs, 315 | unnormalized_widths=unnormalized_widths, 316 | unnormalized_heights=unnormalized_heights, 317 | unnormalized_derivatives=unnormalized_derivatives, 318 | inverse=inverse, 319 | min_bin_width=min_bin_width, 320 | min_bin_height=min_bin_height, 321 | min_derivative=min_derivative, 322 | **spline_kwargs 323 | ) 324 | return outputs, logabsdet 325 | 326 | 327 | def searchsorted(bin_locations, inputs, eps=1e-6): 328 | bin_locations[..., -1] += eps 329 | return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 330 | 331 | 332 | def unconstrained_rational_quadratic_spline( 333 | inputs, 334 | unnormalized_widths, 335 | unnormalized_heights, 336 | unnormalized_derivatives, 337 | inverse=False, 338 | tails='linear', 339 | tail_bound=1.0, 340 | min_bin_width=1e-3, 341 | min_bin_height=1e-3, 342 | min_derivative=1e-3, 343 | ): 344 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 345 | outside_interval_mask = ~inside_interval_mask 346 | 347 | outputs = torch.zeros_like(inputs) 348 | logabsdet = torch.zeros_like(inputs) 349 | 350 | if tails == 'linear': 351 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 352 | constant = np.log(np.exp(1 - min_derivative) - 1) 353 | unnormalized_derivatives[..., 0] = constant 354 | unnormalized_derivatives[..., -1] = constant 355 | 356 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 357 | logabsdet[outside_interval_mask] = 0 358 | else: 359 | raise RuntimeError('{} tails are not implemented.'.format(tails)) 360 | 361 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( 362 | inputs=inputs[inside_interval_mask], 363 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 364 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 365 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 366 | inverse=inverse, 367 | left=-tail_bound, 368 | right=tail_bound, 369 | bottom=-tail_bound, 370 | top=tail_bound, 371 | min_bin_width=min_bin_width, 372 | min_bin_height=min_bin_height, 373 | min_derivative=min_derivative, 374 | ) 375 | 376 | return outputs, logabsdet 377 | 378 | 379 | def rational_quadratic_spline( 380 | inputs, 381 | unnormalized_widths, 382 | unnormalized_heights, 383 | unnormalized_derivatives, 384 | inverse=False, 385 | left=0.0, 386 | right=1.0, 387 | bottom=0.0, 388 | top=1.0, 389 | min_bin_width=1e-3, 390 | 
min_bin_height=1e-3, 391 | min_derivative=1e-3, 392 | ): 393 | 394 | if torch.min(inputs) < left or torch.max(inputs) > right: 395 | raise ValueError('Input to a transform is not within its domain') 396 | 397 | num_bins = unnormalized_widths.shape[-1] 398 | 399 | if min_bin_width * num_bins > 1.0: 400 | raise ValueError('Minimal bin width too large for the number of bins') 401 | if min_bin_height * num_bins > 1.0: 402 | raise ValueError('Minimal bin height too large for the number of bins') 403 | 404 | widths = F.softmax(unnormalized_widths, dim=-1) 405 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 406 | cumwidths = torch.cumsum(widths, dim=-1) 407 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) 408 | cumwidths = (right - left) * cumwidths + left 409 | cumwidths[..., 0] = left 410 | cumwidths[..., -1] = right 411 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 412 | 413 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 414 | 415 | heights = F.softmax(unnormalized_heights, dim=-1) 416 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 417 | cumheights = torch.cumsum(heights, dim=-1) 418 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) 419 | cumheights = (top - bottom) * cumheights + bottom 420 | cumheights[..., 0] = bottom 421 | cumheights[..., -1] = top 422 | heights = cumheights[..., 1:] - cumheights[..., :-1] 423 | 424 | if inverse: 425 | bin_idx = searchsorted(cumheights, inputs)[..., None] 426 | else: 427 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 428 | 429 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 430 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 431 | 432 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 433 | delta = heights / widths 434 | input_delta = delta.gather(-1, bin_idx)[..., 0] 435 | 436 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 437 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 438 | 439 | input_heights = heights.gather(-1, bin_idx)[..., 0] 440 | 441 | if inverse: 442 | a = (inputs - input_cumheights) * ( 443 | input_derivatives + input_derivatives_plus_one - 2 * input_delta 444 | ) + input_heights * (input_delta - input_derivatives) 445 | b = input_heights * input_derivatives - (inputs - input_cumheights) * ( 446 | input_derivatives + input_derivatives_plus_one - 2 * input_delta 447 | ) 448 | c = -input_delta * (inputs - input_cumheights) 449 | 450 | discriminant = b.pow(2) - 4 * a * c 451 | assert (discriminant >= 0).all() 452 | 453 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 454 | outputs = root * input_bin_widths + input_cumwidths 455 | 456 | theta_one_minus_theta = root * (1 - root) 457 | denominator = input_delta + ( 458 | (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta 459 | ) 460 | derivative_numerator = input_delta.pow(2) * ( 461 | input_derivatives_plus_one * root.pow(2) 462 | + 2 * input_delta * theta_one_minus_theta 463 | + input_derivatives * (1 - root).pow(2) 464 | ) 465 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 466 | 467 | return outputs, -logabsdet 468 | else: 469 | theta = (inputs - input_cumwidths) / input_bin_widths 470 | theta_one_minus_theta = theta * (1 - theta) 471 | 472 | numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta) 473 | denominator = input_delta + ( 474 | (input_derivatives + input_derivatives_plus_one 
- 2 * input_delta) * theta_one_minus_theta 475 | ) 476 | outputs = input_cumheights + numerator / denominator 477 | 478 | derivative_numerator = input_delta.pow(2) * ( 479 | input_derivatives_plus_one * theta.pow(2) 480 | + 2 * input_delta * theta_one_minus_theta 481 | + input_derivatives * (1 - theta).pow(2) 482 | ) 483 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 484 | 485 | return outputs, logabsdet 486 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | from nemo.collections.tts.torch.tts_data_types import MAIN_DATA_TYPES, WithLens 2 | from torchdata.data_type import DATA_STR2DATA_CLASS 3 | def average_features(pitch, durs): 4 | durs_cums_ends = torch.cumsum(durs, dim=1).long() 5 | durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0)) 6 | pitch_nonzero_cums = torch.nn.functional.pad(torch.cumsum(pitch != 0.0, dim=2), (1, 0)) 7 | pitch_cums = torch.nn.functional.pad(torch.cumsum(pitch, dim=2), (1, 0)) 8 | bs, l = durs_cums_ends.size() 9 | n_formants = pitch.size(1) 10 | dcs = durs_cums_starts[:, None, :].expand(bs, n_formants, l) 11 | dce = durs_cums_ends[:, None, :].expand(bs, n_formants, l) 12 | 13 | pitch_sums = (torch.gather(pitch_cums, 2, dce) - torch.gather(pitch_cums, 2, dcs)).float() 14 | pitch_nelems = (torch.gather(pitch_nonzero_cums, 2, dce) - torch.gather(pitch_nonzero_cums, 2, dcs)).float() 15 | 16 | pitch_avg = torch.where(pitch_nelems == 0.0, pitch_nelems, pitch_sums / pitch_nelems) 17 | return pitch_avg 18 | 19 | def extensive_process_batch(batch_data, sup_data_types_set): 20 | batch_dict = {} 21 | batch_index = 0 22 | 23 | for name, datatype in DATA_STR2DATA_CLASS.items(): 24 | if datatype in MAIN_DATA_TYPES or datatype in sup_data_types_set: 25 | batch_dict[name] = batch_data[batch_index] 26 | batch_index = batch_index + 1 27 | if issubclass(datatype, WithLens): 28 | batch_dict[name + "_lens"] = batch_data[batch_index] 29 | batch_index = batch_index + 1 30 | return batch_dict 31 | 32 | def process_batch(batch_data, sup_data_types_set): 33 | batch_dict = {} 34 | batch_index = 0 35 | 36 | for name, datatype in DATA_STR2DATA_CLASS.items(): 37 | if datatype in MAIN_DATA_TYPES or datatype in sup_data_types_set: 38 | batch_dict[name] = batch_data[batch_index] 39 | batch_index = batch_index + 1 40 | if issubclass(datatype, WithLens): 41 | batch_dict[name + "_lens"] = batch_data[batch_index] 42 | batch_index = batch_index + 1 43 | 44 | if len(batch_data)==12 or len(batch_data)==11 or (len(batch_data)==13 and 'energy' in batch_dict.keys()): # audio, lm embedding, without xvector or with speaker id 45 | batch_dict['audio_embedding'] = batch_data[-4] 46 | batch_dict['audio_embedding_lens'] = batch_data[-3] 47 | batch_dict['lm_embedding'] = batch_data[-2] 48 | batch_dict['lm_embedding_lens'] = batch_data[-1] 49 | else: # audio, lm embedding, with xvector 50 | batch_dict['audio_embedding'] = batch_data[-5] 51 | batch_dict['audio_embedding_lens'] = batch_data[-4] 52 | batch_dict['lm_embedding'] = batch_data[-3] 53 | batch_dict['lm_embedding_lens'] = batch_data[-2] 54 | batch_dict['pretrained_spk_embedding'] = batch_data[-1] 55 | 56 | return batch_dict 57 | 58 | import torch 59 | import torch.nn.functional as F 60 | 61 | from nemo.collections.tts.modules.transformer import mask_from_lens 62 | from nemo.core.classes import Loss, typecheck 63 | from nemo.core.neural_types.elements import ( 64 | 
LengthsType, 65 | LossType, 66 | MelSpectrogramType, 67 | RegressionValuesType, 68 | TokenDurationType, 69 | TokenLogDurationType, 70 | ) 71 | from nemo.core.neural_types.neural_type import NeuralType 72 | 73 | class EnergyLoss(Loss): 74 | def __init__(self, loss_scale=0.1): 75 | super().__init__() 76 | self.loss_scale = loss_scale 77 | 78 | @property 79 | def input_types(self): 80 | return { 81 | "energy_predicted": NeuralType(('B', 'T'), RegressionValuesType()), 82 | "energy_tgt": NeuralType(('B', 'T'), RegressionValuesType()), 83 | "length": NeuralType(('B'), LengthsType()), 84 | } 85 | 86 | @property 87 | def output_types(self): 88 | return { 89 | "loss": NeuralType(elements_type=LossType()), 90 | } 91 | 92 | @typecheck() 93 | def forward(self, energy_predicted, energy_tgt, length): 94 | if energy_tgt is None: 95 | return 0.0 96 | dur_mask = mask_from_lens(length, max_len=energy_tgt.size(1)) 97 | energy_loss = F.mse_loss(energy_tgt, energy_predicted, reduction='none') 98 | energy_loss = (energy_loss * dur_mask).sum() / dur_mask.sum() 99 | energy_loss *= self.loss_scale 100 | 101 | return energy_loss 102 | 103 | if __name__=="__main__": 104 | pitch = torch.load("../test_pitch_1.pt") 105 | durs = torch.load("../test_durs_1.pt") 106 | average_features(pitch,durs) -------------------------------------------------------------------------------- /model/vits.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
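# Usage sketch for the batch unpacking helper defined in model/utils.py and
# imported below as `process_batch`: it walks DATA_STR2DATA_CLASS in order and
# turns the flat collated batch tuple into a name-keyed dict, so training and
# validation code can look fields up by name instead of by position.  The keys
# shown here mirror how training_step consumes the dict; which of them are
# present depends on the configured sup_data_types, so treat this as an
# illustrative sketch rather than a fixed contract.
#
#     batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set)
#     audio, audio_len = batch_dict["audio"], batch_dict["audio_lens"]
#     text, text_len = batch_dict["text"], batch_dict["text_lens"]
#     audio_codec = batch_dict.get("audio_codec", None)
#     audio_codec_lens = batch_dict.get("audio_codec_lens", None)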
14 | 15 | 16 | import contextlib 17 | 18 | import omegaconf 19 | import torch 20 | import wandb 21 | from hydra.utils import instantiate 22 | from omegaconf import DictConfig, OmegaConf 23 | from pytorch_lightning import Trainer 24 | from pytorch_lightning.loggers import WandbLogger 25 | from torch.cuda.amp import autocast 26 | from torch.nn import functional as F 27 | 28 | from nemo.collections.tts.helpers.helpers import plot_spectrogram_to_numpy 29 | from nemo.collections.tts.torch.tts_data_types import SpeakerID 30 | from nemo.core.classes.common import PretrainedModelInfo, typecheck 31 | from nemo.core.neural_types.elements import AudioSignal, FloatType, Index, IntType, TokenIndex 32 | from nemo.core.neural_types.neural_type import NeuralType 33 | from nemo.core.optim.lr_scheduler import CosineAnnealing 34 | from nemo.utils import logging, model_utils 35 | from nemo.utils.decorators.experimental import experimental 36 | 37 | from model.base import TextToWaveform 38 | from model.helper import clip_grad_value_, slice_segments 39 | from module.vits_losses import DiscriminatorLoss, FeatureMatchingLoss, GeneratorLoss, KlLoss 40 | from module.vits_modules import MultiPeriodDiscriminator 41 | from torchdata.data import DistributedBucketSampler 42 | from torchdata.data_type import AudioCodec, PretrainedLM 43 | from model.utils import extensive_process_batch as process_batch 44 | 45 | HAVE_WANDB = True 46 | try: 47 | import wandb 48 | except ModuleNotFoundError: 49 | HAVE_WANDB = False 50 | 51 | 52 | @experimental 53 | class VitsModel(TextToWaveform): 54 | def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): 55 | # Convert to Hydra 1.0 compatible DictConfig 56 | 57 | cfg = model_utils.convert_model_config_to_dict_config(cfg) 58 | cfg = model_utils.maybe_update_config_version(cfg) 59 | 60 | # setup normalizer 61 | self.normalizer = None 62 | self.text_normalizer_call = None 63 | self.text_normalizer_call_kwargs = {} 64 | self._setup_normalizer(cfg) 65 | 66 | # setup tokenizer 67 | self.tokenizer = None 68 | self._setup_tokenizer(cfg) 69 | assert self.tokenizer is not None 70 | 71 | num_tokens = len(self.tokenizer.tokens) 72 | self.tokenizer_pad = self.tokenizer.pad 73 | 74 | super().__init__(cfg=cfg, trainer=trainer) 75 | 76 | self.audio_to_melspec_processor = instantiate(cfg.preprocessor, highfreq=cfg.train_ds.dataset.highfreq) 77 | 78 | self.feat_matching_loss = FeatureMatchingLoss() 79 | self.disc_loss = DiscriminatorLoss() 80 | self.gen_loss = GeneratorLoss() 81 | self.kl_loss = KlLoss() 82 | 83 | self.net_g = instantiate( 84 | cfg.synthesizer, 85 | n_vocab=num_tokens, 86 | spec_channels=cfg.n_fft // 2 + 1, 87 | segment_size=cfg.segment_size // cfg.n_window_stride, 88 | padding_idx=self.tokenizer_pad, 89 | ) 90 | 91 | self.net_d = MultiPeriodDiscriminator(cfg.use_spectral_norm) 92 | 93 | self.automatic_optimization = False 94 | 95 | def _setup_normalizer(self, cfg): 96 | if "text_normalizer" in cfg: 97 | normalizer_kwargs = {} 98 | 99 | if "whitelist" in cfg.text_normalizer: 100 | normalizer_kwargs["whitelist"] = self.register_artifact( 101 | 'text_normalizer.whitelist', cfg.text_normalizer.whitelist 102 | ) 103 | 104 | self.normalizer = instantiate(cfg.text_normalizer, **normalizer_kwargs) 105 | self.text_normalizer_call = self.normalizer.normalize 106 | if "text_normalizer_call_kwargs" in cfg: 107 | self.text_normalizer_call_kwargs = cfg.text_normalizer_call_kwargs 108 | 109 | def _setup_tokenizer(self, cfg): 110 | text_tokenizer_kwargs = {} 111 | if "g2p" in cfg.text_tokenizer 
and cfg.text_tokenizer.g2p is not None: 112 | g2p_kwargs = {} 113 | 114 | if "phoneme_dict" in cfg.text_tokenizer.g2p: 115 | g2p_kwargs["phoneme_dict"] = self.register_artifact( 116 | 'text_tokenizer.g2p.phoneme_dict', cfg.text_tokenizer.g2p.phoneme_dict, 117 | ) 118 | 119 | if "heteronyms" in cfg.text_tokenizer.g2p: 120 | g2p_kwargs["heteronyms"] = self.register_artifact( 121 | 'text_tokenizer.g2p.heteronyms', cfg.text_tokenizer.g2p.heteronyms, 122 | ) 123 | 124 | text_tokenizer_kwargs["g2p"] = instantiate(cfg.text_tokenizer.g2p, **g2p_kwargs) 125 | 126 | self.tokenizer = instantiate(cfg.text_tokenizer, **text_tokenizer_kwargs) 127 | 128 | def parse(self, text: str, normalize=True) -> torch.tensor: 129 | if self.training: 130 | logging.warning("parse() is meant to be called in eval mode.") 131 | if normalize and self.text_normalizer_call is not None: 132 | text = self.text_normalizer_call(text, **self.text_normalizer_call_kwargs) 133 | 134 | eval_phon_mode = contextlib.nullcontext() 135 | if hasattr(self.tokenizer, "set_phone_prob"): 136 | eval_phon_mode = self.tokenizer.set_phone_prob(prob=1.0) 137 | 138 | with eval_phon_mode: 139 | tokens = self.tokenizer.encode(text) 140 | 141 | return torch.tensor(tokens).long().unsqueeze(0).to(self.device) 142 | 143 | def configure_optimizers(self): 144 | optim_config = self._cfg.optim.copy() 145 | OmegaConf.set_struct(optim_config, False) 146 | sched_config = optim_config.pop("sched", None) 147 | OmegaConf.set_struct(optim_config, True) 148 | 149 | optim_g = instantiate(optim_config, params=self.net_g.parameters(),) 150 | optim_d = instantiate(optim_config, params=self.net_d.parameters(),) 151 | 152 | if sched_config is not None: 153 | if sched_config.name == 'ExponentialLR': 154 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=sched_config.lr_decay) 155 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=sched_config.lr_decay) 156 | elif sched_config.name == 'CosineAnnealing': 157 | scheduler_g = CosineAnnealing( 158 | optimizer=optim_g, max_steps=sched_config.max_steps, min_lr=sched_config.min_lr, 159 | ) 160 | scheduler_d = CosineAnnealing( 161 | optimizer=optim_d, max_steps=sched_config.max_steps, min_lr=sched_config.min_lr, 162 | ) 163 | else: 164 | raise ValueError("Unknown optimizer.") 165 | 166 | scheduler_g_dict = {'scheduler': scheduler_g, 'interval': 'step'} 167 | scheduler_d_dict = {'scheduler': scheduler_d, 'interval': 'step'} 168 | return [optim_g, optim_d], [scheduler_g_dict, scheduler_d_dict] 169 | else: 170 | return [optim_g, optim_d] 171 | 172 | # for inference 173 | @typecheck( 174 | input_types={ 175 | "tokens": NeuralType(('B', 'T_text'), TokenIndex()), 176 | "speakers": NeuralType(('B',), Index(), optional=True), 177 | "noise_scale": NeuralType(('B',), FloatType(), optional=True), 178 | "length_scale": NeuralType(('B',), FloatType(), optional=True), 179 | "noise_scale_w": NeuralType(('B',), FloatType(), optional=True), 180 | "max_len": NeuralType(('B',), IntType(), optional=True), 181 | } 182 | ) 183 | def forward(self, tokens, speakers=None, noise_scale=1, length_scale=1, noise_scale_w=1.0, max_len=1000, 184 | ref_spec = None, ref_spec_lens = None, ref_codec = None, ref_codec_lens = None): 185 | text_len = torch.tensor([tokens.size(-1)]).to(int).to(tokens.device) 186 | audio_pred, attn, y_mask, (z, z_p, m_p, logs_p) = self.net_g.infer( 187 | tokens, 188 | text_len, 189 | speakers=speakers, 190 | noise_scale=noise_scale, 191 | length_scale=length_scale, 192 | noise_scale_w=noise_scale_w, 
193 | max_len=max_len, 194 | ref_spec=ref_spec, 195 | ref_spec_lens=ref_spec_lens, 196 | ref_codec=ref_codec, 197 | ref_codec_lens=ref_codec_lens 198 | ) 199 | return audio_pred, attn, y_mask, (z, z_p, m_p, logs_p) 200 | 201 | def training_step(self, batch, batch_idx): 202 | speakers, audio_codec, audio_codec_lens = None, None, None 203 | if AudioCodec in self._train_dl.dataset.sup_data_types_set or PretrainedLM in self._train_dl.dataset.sup_data_types_set: 204 | batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set) 205 | audio = batch_dict.get("audio") 206 | audio_len = batch_dict.get("audio_lens") 207 | text = batch_dict.get("text") 208 | text_len = batch_dict.get("text_lens") 209 | speakers = batch_dict.get("speaker_id", None) 210 | audio_codec = batch_dict.get("audio_codec", None) 211 | audio_codec_lens = batch_dict.get("audio_codec_lens", None) 212 | elif SpeakerID in self._train_dl.dataset.sup_data_types_set: 213 | (audio, audio_len, text, text_len, speakers) = batch 214 | else: 215 | (audio, audio_len, text, text_len) = batch 216 | 217 | spec, spec_lengths = self.audio_to_melspec_processor(audio, audio_len, linear_spec=True) 218 | 219 | ref_spec, ref_spec_lens = None, None 220 | ref_codec, ref_codec_lens = None, None 221 | if audio_codec is not None: 222 | if audio_codec.size(1) == 8: # codec-only 223 | ref_spec, ref_spec_lens = audio_codec, audio_codec_lens 224 | else: # multi modal ref input 225 | ref_codec, ref_codec_lens = audio_codec, audio_codec_lens 226 | else: # mel-only 227 | ref_spec, ref_spec_lens = spec, spec_lengths 228 | 229 | with autocast(enabled=True): 230 | audio_pred, l_length, attn, ids_slice, text_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = self.net_g( 231 | text, text_len, spec, spec_lengths, speakers, ref_spec, ref_spec_lens, ref_codec, ref_codec_lens, 232 | ) 233 | 234 | audio_pred = audio_pred.float() 235 | 236 | audio_pred_mel, _ = self.audio_to_melspec_processor(audio_pred.squeeze(1), audio_len, linear_spec=False) 237 | 238 | audio = slice_segments(audio.unsqueeze(1), ids_slice * self.cfg.n_window_stride, self._cfg.segment_size) 239 | audio_mel, _ = self.audio_to_melspec_processor(audio.squeeze(1), audio_len, linear_spec=False) 240 | 241 | with autocast(enabled=True): 242 | y_d_hat_r, y_d_hat_g, _, _ = self.net_d(audio, audio_pred.detach()) 243 | 244 | with autocast(enabled=False): 245 | loss_disc, losses_disc_r, losses_disc_g = self.disc_loss( 246 | disc_real_outputs=y_d_hat_r, disc_generated_outputs=y_d_hat_g 247 | ) 248 | loss_disc_all = loss_disc 249 | 250 | # get optimizers 251 | optim_g, optim_d = self.optimizers() 252 | 253 | # train discriminator 254 | optim_d.zero_grad() 255 | self.manual_backward(loss_disc_all) 256 | norm_d = clip_grad_value_(self.net_d.parameters(), None) 257 | optim_d.step() 258 | 259 | with autocast(enabled=True): 260 | y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.net_d(audio, audio_pred) 261 | # Generator 262 | with autocast(enabled=False): 263 | loss_dur = torch.sum(l_length.float()) 264 | loss_mel = F.l1_loss(audio_mel, audio_pred_mel) * self._cfg.c_mel 265 | loss_kl = self.kl_loss(z_p=z_p, logs_q=logs_q, m_p=m_p, logs_p=logs_p, z_mask=z_mask) * self._cfg.c_kl 266 | loss_fm = self.feat_matching_loss(fmap_r=fmap_r, fmap_g=fmap_g) 267 | loss_gen, losses_gen = self.gen_loss(disc_outputs=y_d_hat_g) 268 | loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl 269 | 270 | # train generator 271 | optim_g.zero_grad() 272 | self.manual_backward(loss_gen_all) 273 | norm_g = 
clip_grad_value_(self.net_g.parameters(), None) 274 | optim_g.step() 275 | 276 | schedulers = self.lr_schedulers() 277 | if schedulers is not None: 278 | sch1, sch2 = schedulers 279 | if ( 280 | self.trainer.is_last_batch 281 | and isinstance(sch1, torch.optim.lr_scheduler.ExponentialLR) 282 | or isinstance(sch1, CosineAnnealing) 283 | ): 284 | sch1.step() 285 | sch2.step() 286 | 287 | metrics = { 288 | "loss_gen": loss_gen, 289 | "loss_fm": loss_fm, 290 | "loss_mel": loss_mel, 291 | "loss_dur": loss_dur, 292 | "loss_kl": loss_kl, 293 | "loss_gen_all": loss_gen_all, 294 | "loss_disc_all": loss_disc_all, 295 | "grad_gen": norm_g, 296 | "grad_disc": norm_d, 297 | } 298 | 299 | for i, v in enumerate(losses_gen): 300 | metrics[f"loss_gen_i_{i}"] = v 301 | 302 | for i, v in enumerate(losses_disc_r): 303 | metrics[f"loss_disc_r_{i}"] = v 304 | 305 | for i, v in enumerate(losses_disc_g): 306 | metrics[f"loss_disc_g_{i}"] = v 307 | 308 | self.log_dict(metrics, on_step=True, sync_dist=True) 309 | 310 | def validation_step(self, batch, batch_idx): 311 | speakers, audio_codec, audio_codec_lens = None, None, None 312 | if AudioCodec in self._validation_dl.dataset.sup_data_types_set: 313 | batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set) 314 | audio = batch_dict.get("audio") 315 | audio_len = batch_dict.get("audio_lens") 316 | text = batch_dict.get("text") 317 | text_len = batch_dict.get("text_lens") 318 | speakers = batch_dict.get("speaker_id", None) 319 | audio_codec = batch_dict.get("audio_codec", None) 320 | audio_codec_lens = batch_dict.get("audio_codec_lens", None) 321 | elif self.cfg.n_speakers > 1: 322 | (audio, audio_len, text, text_len, speakers) = batch 323 | else: 324 | (audio, audio_len, text, text_len) = batch 325 | 326 | ref_spec, ref_spec_lens = None, None 327 | ref_codec, ref_codec_lens = None, None 328 | if audio_codec is not None: 329 | if audio_codec.size(1) == 8: # codec-only 330 | ref_spec, ref_spec_lens = audio_codec, audio_codec_lens 331 | else: # multi modal ref input 332 | ref_codec, ref_codec_lens = audio_codec, audio_codec_lens 333 | else: # mel-only 334 | ref_spec, ref_spec_lens = self.audio_to_melspec_processor(audio, audio_len, linear_spec=True) 335 | 336 | audio_pred, _, mask, *_ = self.net_g.infer(text, text_len, speakers, max_len=1000, 337 | ref_spec=ref_spec, ref_spec_lens=ref_spec_lens, 338 | ref_codec=ref_codec, ref_codec_lens=ref_codec_lens) 339 | 340 | audio_pred = audio_pred.squeeze() 341 | audio_pred_len = mask.sum([1, 2]).long() * self._cfg.validation_ds.dataset.hop_length 342 | 343 | mel, mel_lengths = self.audio_to_melspec_processor(audio, audio_len) 344 | audio_pred_mel, audio_pred_mel_len = self.audio_to_melspec_processor(audio_pred, audio_pred_len) 345 | 346 | 347 | # valid_mel_loss = F.l1_loss(mel.detach(), audio_pred_mel.detach()).detach() 348 | # metrics = { 349 | # 'val_mel_loss': valid_mel_loss 350 | # } 351 | # self.log_dict(metrics, on_step=True, sync_dist=True) 352 | # plot audio once per epoch 353 | if batch_idx == 0 and isinstance(self.logger, WandbLogger) and HAVE_WANDB: 354 | logger = self.logger.experiment 355 | specs = [] 356 | audios = [] 357 | specs += [ 358 | wandb.Image( 359 | plot_spectrogram_to_numpy(mel[0, :, : mel_lengths[0]].data.cpu().numpy()), 360 | caption=f"val_mel_target", 361 | ), 362 | wandb.Image( 363 | plot_spectrogram_to_numpy(audio_pred_mel[0, :, : audio_pred_mel_len[0]].data.cpu().numpy()), 364 | caption=f"val_mel_predicted", 365 | ), 366 | ] 367 | 368 | audios += [ 369 | wandb.Audio( 370 | 
audio[0, : audio_len[0]].data.cpu().to(torch.float).numpy(), 371 | caption=f"val_wav_target_1", 372 | sample_rate=self._cfg.sample_rate, 373 | ), 374 | wandb.Audio( 375 | audio_pred[0, : audio_pred_len[0]].data.cpu().to(torch.float).numpy(), 376 | caption=f"val_wav_predicted_1", 377 | sample_rate=self._cfg.sample_rate, 378 | ), 379 | wandb.Audio( 380 | audio[3, : audio_len[3]].data.cpu().to(torch.float).numpy(), 381 | caption=f"val_wav_target_2", 382 | sample_rate=self._cfg.sample_rate, 383 | ), 384 | wandb.Audio( 385 | audio_pred[3, : audio_pred_len[3]].data.cpu().to(torch.float).numpy(), 386 | caption=f"val_wav_predicted_2", 387 | sample_rate=self._cfg.sample_rate, 388 | ), 389 | wandb.Audio( 390 | audio[4, : audio_len[4]].data.cpu().to(torch.float).numpy(), 391 | caption=f"val_wav_target_3", 392 | sample_rate=self._cfg.sample_rate, 393 | ), 394 | wandb.Audio( 395 | audio_pred[4, : audio_pred_len[4]].data.cpu().to(torch.float).numpy(), 396 | caption=f"val_wav_predicted_3", 397 | sample_rate=self._cfg.sample_rate, 398 | ), 399 | ] 400 | 401 | logger.log({"specs": specs, "audios": audios}) 402 | 403 | def _loader(self, cfg, use_sampler): 404 | try: 405 | _ = cfg['dataset']['manifest_filepath'] 406 | except omegaconf.errors.MissingMandatoryValue: 407 | logging.warning("manifest_filepath was skipped. No dataset for this model.") 408 | return None 409 | 410 | dataset = instantiate( 411 | cfg.dataset, 412 | text_normalizer=self.normalizer, 413 | text_normalizer_call_kwargs=self.text_normalizer_call_kwargs, 414 | text_tokenizer=self.tokenizer, 415 | ) 416 | if use_sampler: # use sampler only in training 417 | sampler = DistributedBucketSampler(dataset, **self.cfg.train_ds.batch_sampler) 418 | 419 | dataloader = torch.utils.data.DataLoader( # noqa 420 | dataset=dataset, collate_fn=dataset.collate_fn, batch_sampler=sampler, **cfg.dataloader_params, 421 | ) 422 | else: 423 | dataloader = torch.utils.data.DataLoader( # noqa 424 | dataset=dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params, 425 | ) 426 | return dataloader 427 | 428 | def train_dataloader(self): 429 | # default used by the Trainer 430 | dataset = instantiate( 431 | self.cfg.train_ds.dataset, 432 | text_normalizer=self.normalizer, 433 | text_normalizer_call_kwargs=self.text_normalizer_call_kwargs, 434 | text_tokenizer=self.tokenizer, 435 | ) 436 | # 437 | train_sampler = DistributedBucketSampler(dataset, **self.cfg.train_ds.batch_sampler) 438 | # 439 | dataloader = torch.utils.data.DataLoader( 440 | dataset, collate_fn=dataset.collate_fn, batch_sampler=train_sampler, **self.cfg.train_ds.dataloader_params, 441 | ) 442 | return dataloader 443 | 444 | def setup_training_data(self, cfg): 445 | self._train_dl = self._loader(cfg, use_sampler=False) 446 | 447 | def setup_validation_data(self, cfg): 448 | self._validation_dl = self._loader(cfg, use_sampler=False) 449 | 450 | def setup_test_data(self, cfg): 451 | """Omitted.""" 452 | pass 453 | 454 | @classmethod 455 | def list_available_models(cls) -> 'List[PretrainedModelInfo]': 456 | list_of_models = [] 457 | model = PretrainedModelInfo( 458 | pretrained_model_name="tts_en_lj_vits", 459 | location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_lj_vits/versions/1.13.0/files/vits_ljspeech_fp16_full.nemo", 460 | description="This model is trained on LJSpeech audio sampled at 22050Hz. 
This model has been tested on generating female English " 461 | "voices with an American accent.", 462 | class_=cls, 463 | ) 464 | list_of_models.append(model) 465 | return list_of_models 466 | 467 | @typecheck( 468 | input_types={"tokens": NeuralType(('B', 'T_text'), TokenIndex(), optional=True),}, 469 | output_types={"audio": NeuralType(('B', 'T_audio'), AudioSignal())}, 470 | ) 471 | def convert_text_to_waveform(self, *, tokens, speakers=None): 472 | audio = self(tokens=tokens, speakers=speakers)[0].squeeze(1) 473 | return audio 474 | -------------------------------------------------------------------------------- /module/monotonic_align.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import numba 17 | import numpy as np 18 | import torch 19 | 20 | 21 | def maximum_path(neg_cent, mask): 22 | """ Numba version. 23 | neg_cent: [b, t_t, t_s] 24 | mask: [b, t_t, t_s] 25 | """ 26 | device = neg_cent.device 27 | dtype = neg_cent.dtype 28 | neg_cent = neg_cent.data.cpu().numpy().astype(np.float32) 29 | path = np.zeros(neg_cent.shape, dtype=np.int32) 30 | 31 | t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32) 32 | t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32) 33 | maximum_path_c(path, neg_cent, t_t_max, t_s_max) 34 | return torch.from_numpy(path).to(device=device, dtype=dtype) 35 | 36 | 37 | @numba.jit(nopython=True, boundscheck=False, parallel=True) 38 | def maximum_path_each(path, value, t_y: int, t_x: int, max_neg_val=-1e9): 39 | """ 40 | Args: 41 | path: int32[:, :] 42 | value: float32[:, :] 43 | t_y: int 44 | t_x: int 45 | max_neg_val: float 46 | """ 47 | index: int = t_x - 1 48 | 49 | for y in range(t_y): 50 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 51 | if x == y: 52 | v_cur = max_neg_val 53 | else: 54 | v_cur = value[y - 1, x] 55 | if x == 0: 56 | if y == 0: 57 | v_prev = 0.0 58 | else: 59 | v_prev = max_neg_val 60 | else: 61 | v_prev = value[y - 1, x - 1] 62 | value[y, x] += max(v_prev, v_cur) 63 | 64 | for y in range(t_y - 1, -1, -1): 65 | path[y, index] = 1 66 | if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]): 67 | index = index - 1 68 | 69 | 70 | @numba.jit(nopython=True, boundscheck=False, parallel=True) 71 | def maximum_path_c(paths, values, t_ys, t_xs): 72 | """ 73 | Args: 74 | paths: int32[:, :, :] 75 | values: float32[:, :, :] 76 | t_ys: int[:] 77 | t_xs: int[:] 78 | """ 79 | b: int = paths.shape[0] 80 | for i in numba.prange(b): 81 | maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i]) 82 | 83 | 84 | if __name__ == '__main__': 85 | pass 86 | -------------------------------------------------------------------------------- /module/ref_gst.py: -------------------------------------------------------------------------------- 1 | # from Nemo branch fastpitch_gst 2 | # 
Nemo/nemo/collections/tts/modules/speaker_modules.py 3 | 4 | import torch 5 | import torch.nn as nn 6 | from collections import OrderedDict 7 | from nemo.core.classes import NeuralModule, typecheck 8 | from nemo.core.neural_types.neural_type import NeuralType 9 | from nemo.core.neural_types.elements import ( 10 | EncodedRepresentation, 11 | MelSpectrogramType, 12 | Index, 13 | TokenDurationType, 14 | TokenIndex, 15 | LengthsType 16 | ) 17 | from module.codec_embedding import AudioCodecForEmbedding 18 | 19 | """ 20 | Weighted Sum of Pre-trained Speaker Embedding 21 | """ 22 | 23 | 24 | class Weighted_SpeakerEmbedding(torch.nn.Module): 25 | def __init__(self, pretrained_embedding): 26 | super(Weighted_SpeakerEmbedding, self).__init__() 27 | self.pretrained_embedding = torch.nn.Parameter(pretrained_embedding.weight.detach().clone()) 28 | self.pretrained_embedding.requires_grad = False 29 | self.num_embeddings = pretrained_embedding.num_embeddings 30 | self.embedding_weight = torch.nn.Parameter(torch.ones(1, self.num_embeddings)) 31 | 32 | def forward(self, speaker): 33 | weight = self.embedding_weight.repeat(len(speaker), 1) 34 | weight = torch.nn.functional.softmax(weight, dim=-1) 35 | speaker_emb = weight @ self.pretrained_embedding 36 | return speaker_emb 37 | 38 | class GlobalStyleTokenForCodec(NeuralModule): 39 | def __init__(self, 40 | cnn_filters=[32, 32, 64, 64, 128, 128], 41 | dropout=0.2, 42 | gru_hidden=128, 43 | gst_size=128, 44 | initial_dim=8, 45 | n_style_token=10, 46 | n_style_attn_head=4): 47 | super(GlobalStyleTokenForCodec, self).__init__() 48 | self.gru_hidden = gru_hidden 49 | self.reference_encoder = ReferenceEncoder_UtteranceLevel(initial_dim=initial_dim, cnn_filters=list(cnn_filters), dropout=dropout, 50 | gru_hidden=gru_hidden) 51 | self.style_attention = StyleAttention(gru_hidden=gru_hidden, gst_size=gst_size, n_style_token=n_style_token, 52 | n_style_attn_head=n_style_attn_head) 53 | 54 | def _get_mask_from_length(self, lengths, max_length=None): 55 | if max_length is None: 56 | max_length = lengths.max() 57 | ids = torch.arange(0, max_length.item(), device=lengths.device, dtype=torch.long) 58 | mask = (ids < lengths.unsqueeze(1)).bool() 59 | mask = mask.unsqueeze(-1) 60 | return mask 61 | 62 | @property 63 | def input_types(self): 64 | return { 65 | "inp": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()), 66 | "inp_mask": NeuralType(('B', 'T_spec', 1), TokenDurationType()), 67 | } 68 | 69 | @property 70 | def output_types(self): 71 | return { 72 | "gst": NeuralType(('B', 'D'), EncodedRepresentation()), 73 | } 74 | 75 | def forward(self, codec_embeddings, input_mask): 76 | style_embedding = self.reference_encoder(codec_embeddings, input_mask) # inp : (B, D, T_spec) / inp_mask : (B,T_spec, 1) 77 | gst = self.style_attention(style_embedding) 78 | return gst.unsqueeze(1) 79 | 80 | 81 | class GlobalStyleTokenForMulti(NeuralModule): 82 | def __init__(self, 83 | cnn_filters=[32, 32, 64, 64, 128, 128], 84 | dropout=0.2, 85 | gru_hidden=128, 86 | gst_size=128, 87 | lm_in = 768, 88 | codec_vocab_size=7768, 89 | codec_pad_id=1, 90 | n_style_token=10, 91 | n_style_attn_head=4): 92 | super(GlobalStyleTokenForMulti, self).__init__() 93 | self.gru_hidden = gru_hidden 94 | self.reference_encoder = ReferenceEncoder_UtteranceLevel(initial_dim=gru_hidden, cnn_filters=list(cnn_filters), dropout=dropout, 95 | gru_hidden=gru_hidden) 96 | self.style_attention = StyleAttention(gru_hidden=gru_hidden, gst_size=gst_size, n_style_token=n_style_token, 97 | 
n_style_attn_head=n_style_attn_head) 98 | self.codec_embedding = AudioCodecForEmbedding(vocab_size=codec_vocab_size, hidden_size=gru_hidden, 99 | pad_token_id=codec_pad_id) 100 | self.lm_proj = nn.Linear(lm_in, gru_hidden) 101 | 102 | def _get_mask_from_length(self, lengths, max_length): 103 | ids = torch.arange(0, max_length.item(), device=lengths.device, dtype=torch.long) 104 | mask = (ids < lengths.unsqueeze(1)).bool() 105 | mask = mask.unsqueeze(-1) 106 | return mask 107 | 108 | @property 109 | def input_types(self): 110 | return { 111 | "codec_token_ids": NeuralType(('B', 'T_spec'), TokenIndex()), 112 | "lm_embedding": NeuralType(('B', 'T_text','D'), EncodedRepresentation()), 113 | "codec_lens": NeuralType(('B'), LengthsType()), 114 | "lm_embed_lens":NeuralType(('B'), LengthsType()), 115 | } 116 | 117 | @property 118 | def output_types(self): 119 | return { 120 | "gst": NeuralType(('B', 'D'), EncodedRepresentation()), 121 | } 122 | 123 | def forward(self, codec_token_ids, lm_embeds, codec_lens, lm_embed_lens): 124 | if codec_token_ids.dtype != torch.long: 125 | codec_token_ids = codec_token_ids.long() 126 | codec_embeds = self.codec_embedding(codec_token_ids) 127 | 128 | # with torch.autograd.set_detect_anomaly(True): 129 | batch_size = lm_embeds.size(0) 130 | max_length = torch.tensor([codec_embeds.size(1) + lm_embeds.size(1)]).long() 131 | 132 | multi_embeds, multi_masks = [], [] 133 | for i in range(batch_size): 134 | codec = codec_embeds[i][:codec_lens[i]] # get rid of padded part 135 | 136 | lm = lm_embeds[i][:lm_embed_lens[i]] 137 | lm = self.lm_proj(lm) 138 | 139 | multi = torch.cat([codec, lm], dim=0) 140 | multi_length = torch.tensor([multi.size(0)]).long().to(lm_embeds.device) 141 | multi_masks.append(self._get_mask_from_length(multi_length, max_length)) 142 | 143 | multi = torch.cat([multi, torch.zeros(max_length - multi.size(0), self.gru_hidden, device=multi.device)], 144 | dim=0) 145 | multi_embeds.append(multi) 146 | 147 | multi_embeds = torch.stack(multi_embeds).transpose(1,2) 148 | multi_masks = torch.cat(multi_masks) 149 | 150 | style_embedding = self.reference_encoder(multi_embeds, multi_masks) # inp : (B, D, T_spec) / inp_mask : (B,T_spec, 1) 151 | gst = self.style_attention(style_embedding) 152 | return gst.unsqueeze(1) 153 | 154 | """ 155 | Global Style Token based Speaker Embedding 156 | """ 157 | 158 | 159 | 160 | class GlobalStyleToken(NeuralModule): 161 | def __init__(self, 162 | cnn_filters=[32, 32, 64, 64, 128, 128], 163 | dropout=0.2, 164 | gru_hidden=128, 165 | initial_dim=80, 166 | gst_size=128, 167 | n_style_token=10, 168 | n_style_attn_head=4): 169 | super(GlobalStyleToken, self).__init__() 170 | self.reference_encoder = ReferenceEncoder_UtteranceLevel(initial_dim=initial_dim, cnn_filters=list(cnn_filters), dropout=dropout, 171 | gru_hidden=gru_hidden) 172 | self.style_attention = StyleAttention(gru_hidden=gru_hidden, gst_size=gst_size, n_style_token=n_style_token, 173 | n_style_attn_head=n_style_attn_head) 174 | 175 | @property 176 | def input_types(self): 177 | return { 178 | "inp": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()), 179 | "inp_mask": NeuralType(('B', 'T_spec', 1), TokenDurationType()), 180 | } 181 | 182 | @property 183 | def output_types(self): 184 | return { 185 | "gst": NeuralType(('B', 'D'), EncodedRepresentation()), 186 | } 187 | 188 | def forward(self, inp, inp_mask): 189 | style_embedding = self.reference_encoder(inp, inp_mask) 190 | gst = self.style_attention(style_embedding) 191 | return gst.unsqueeze(1) 192 | 193 | 194 | 
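# Usage sketch for GlobalStyleToken above (shapes follow its input_types; the
# batch size, mel length and hidden sizes below are illustrative values, not
# taken from a config in this repository):
#
#     gst = GlobalStyleToken(initial_dim=80, gru_hidden=384, gst_size=384)
#     mel = torch.randn(2, 80, 120)                                        # (B, D, T_spec)
#     mel_lens = torch.tensor([120, 90])
#     mask = (torch.arange(120)[None, :] < mel_lens[:, None]).unsqueeze(-1).float()  # (B, T_spec, 1)
#     style = gst(mel, mask)                                               # -> (B, 1, gst_size)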
class ReferenceEncoder_UtteranceLevel(NeuralModule): 195 | def __init__(self, initial_dim=80, cnn_filters=[32, 32, 64, 64, 128, 128], dropout=0.2, gru_hidden=128): 196 | super(ReferenceEncoder_UtteranceLevel, self).__init__() 197 | self.filter_size = [1] + cnn_filters 198 | self.dropout = dropout 199 | self.conv = nn.Sequential( 200 | OrderedDict( 201 | [ 202 | module 203 | for i in range(len(cnn_filters)) 204 | for module in ( 205 | ( 206 | "conv2d_{}".format(i + 1), 207 | Conv2d( 208 | in_channels=int(self.filter_size[i]), 209 | out_channels=int(self.filter_size[i + 1]), 210 | kernel_size=(3, 3), 211 | stride=(2, 2), 212 | padding=(1, 1), 213 | ), 214 | ), 215 | ("relu_{}".format(i + 1), nn.ReLU()), 216 | ( 217 | "layer_norm_{}".format(i + 1), 218 | nn.LayerNorm(self.filter_size[i + 1]), 219 | ), 220 | ("dropout_{}".format(i + 1), nn.Dropout(self.dropout)), 221 | ) 222 | ] 223 | ) 224 | ) 225 | 226 | gru_input_size = initial_dim 227 | for i in range(len(cnn_filters)): 228 | gru_input_size = (gru_input_size - 3 + 2 * 1) // 2 + 1 229 | gru_input_size *= cnn_filters[-1] 230 | # from espnet2 gst 231 | # for i in range(conv_layers): 232 | # gru_in_units = ( 233 | # gru_in_units - conv_kernel_size + 2 * padding 234 | # ) // conv_stride + 1 235 | self.gru = nn.GRU( 236 | input_size=gru_input_size, 237 | hidden_size=gru_hidden, 238 | batch_first=True, 239 | ) 240 | 241 | @property 242 | def input_types(self): 243 | return { 244 | "inputs": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()), 245 | "inputs_masks": NeuralType(('B', 'T_spec', 1), TokenDurationType()), 246 | } 247 | 248 | @property 249 | def output_types(self): 250 | return { 251 | "out": NeuralType(('B', 'D'), EncodedRepresentation()), 252 | } 253 | 254 | def forward(self, inputs, inputs_masks): 255 | inputs = inputs.transpose(1, 2) 256 | 257 | inputs = inputs * inputs_masks 258 | out = inputs.unsqueeze(3) 259 | out = self.conv(out) 260 | out = out.view(out.shape[0], out.shape[1], -1).contiguous() 261 | self.gru.flatten_parameters() 262 | memory, out = self.gru(out) 263 | 264 | return out.squeeze(0) 265 | 266 | 267 | class StyleAttention(NeuralModule): 268 | def __init__(self, gru_hidden=128, gst_size=128, n_style_token=10, n_style_attn_head=4): 269 | super(StyleAttention, self).__init__() 270 | self.input_size = gru_hidden 271 | self.output_size = gst_size 272 | self.n_token = n_style_token 273 | self.n_head = n_style_attn_head 274 | self.token_size = self.output_size // self.n_head 275 | 276 | self.tokens = nn.Parameter(torch.FloatTensor(self.n_token, self.token_size)) 277 | 278 | self.q_linear = nn.Linear(self.input_size, self.output_size) 279 | self.k_linear = nn.Linear(self.token_size, self.output_size) 280 | self.v_linear = nn.Linear(self.token_size, self.output_size) 281 | 282 | self.tanh = nn.Tanh() 283 | self.softmax = nn.Softmax(dim=2) 284 | self.temperature = (self.output_size // self.n_head) ** 0.5 285 | nn.init.normal_(self.tokens) 286 | 287 | @property 288 | def input_types(self): 289 | return { 290 | "inputs": NeuralType(('B', 'D'), EncodedRepresentation()), 291 | "token_id": NeuralType(('B'), Index(), optional=True), 292 | } 293 | 294 | @property 295 | def output_types(self): 296 | return { 297 | "style_emb": NeuralType(('B', 'D'), EncodedRepresentation()), 298 | } 299 | 300 | def forward(self, inputs, token_id=None): 301 | bs = inputs.size(0) 302 | q = self.q_linear(inputs.unsqueeze(1)) 303 | k = self.k_linear(self.tanh(self.tokens).unsqueeze(0).expand(bs, -1, -1)) 304 | v = 
self.v_linear(self.tanh(self.tokens).unsqueeze(0).expand(bs, -1, -1)) 305 | 306 | q = q.view(bs, q.shape[1], self.n_head, self.token_size) 307 | k = k.view(bs, k.shape[1], self.n_head, self.token_size) 308 | v = v.view(bs, v.shape[1], self.n_head, self.token_size) 309 | 310 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, q.shape[1], q.shape[3]) 311 | k = k.permute(2, 0, 3, 1).contiguous().view(-1, k.shape[3], k.shape[1]) 312 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, v.shape[1], v.shape[3]) 313 | 314 | scores = torch.bmm(q, k) / self.temperature 315 | scores = self.softmax(scores) 316 | if token_id is not None: 317 | scores = torch.zeros_like(scores) 318 | scores[:, :, token_id] = 1 319 | 320 | style_emb = torch.bmm(scores, v).squeeze(1) 321 | style_emb = style_emb.contiguous().view(self.n_head, bs, self.token_size) 322 | style_emb = style_emb.permute(1, 0, 2).contiguous().view(bs, -1) 323 | 324 | return style_emb 325 | 326 | 327 | class Conv2d(nn.Module): 328 | """ 329 | Convolution 2D Module 330 | """ 331 | 332 | def __init__( 333 | self, 334 | in_channels, 335 | out_channels, 336 | kernel_size=(1, 1), 337 | stride=(1, 1), 338 | padding=(0, 0), 339 | dilation=(1, 1), 340 | bias=True, 341 | w_init="linear", 342 | ): 343 | """ 344 | :param in_channels: dimension of input 345 | :param out_channels: dimension of output 346 | :param kernel_size: size of kernel 347 | :param stride: size of stride 348 | :param padding: size of padding 349 | :param dilation: dilation rate 350 | :param bias: boolean. if True, bias is included. 351 | :param w_init: str. weight inits with xavier initialization. 352 | """ 353 | super(Conv2d, self).__init__() 354 | 355 | self.conv = nn.Conv2d( 356 | in_channels, 357 | out_channels, 358 | kernel_size=kernel_size, 359 | stride=stride, 360 | padding=padding, 361 | dilation=dilation, 362 | bias=bias, 363 | ) 364 | 365 | def forward(self, x): 366 | x = x.contiguous().transpose(1, 3) 367 | x = x.contiguous().transpose(2, 3) 368 | x = self.conv(x) 369 | x = x.contiguous().transpose(2, 3) 370 | x = x.contiguous().transpose(1, 3) 371 | return x 372 | 373 | 374 | if __name__=="__main__": 375 | # c1 = torch.load("../encodec_sample.pt") 376 | # c1 = torch.sum(c1, dim=0) 377 | # c2 = torch.load("../encodec_sample2.pt") 378 | # c2 = torch.sum(c2, dim=0) 379 | # codec_length = [c1.size(0), c2.size(0)] 380 | # max_length = max(c1.size(0), c2.size(0)) 381 | # c1 = torch.cat([c1, torch.zeros(max_length - c1.size(0))]) 382 | # c2 = torch.cat([c2, torch.zeros(max_length - c2.size(0))]) 383 | # codec_codes = torch.stack([c1.long(), c2.long()]) 384 | # 385 | # l1 = torch.load( 386 | # "/home/lakahaga/nemo_project/ref_mixer/sup_data/libritts/data2vec_data2vec/lm_embedding/19_198_000007_000000.pt") 387 | # l2 = torch.load( 388 | # "/home/lakahaga/nemo_project/ref_mixer/sup_data/libritts/data2vec_data2vec/lm_embedding/19_198_000010_000000.pt") 389 | # max_length = max(l1.size(0), l2.size(0)) 390 | # lm_length = [l1.size(0), l2.size(0)] 391 | # l1 = torch.cat([l1, torch.zeros(max_length - l1.size(0), l1.size(1))]) 392 | # l2 = torch.cat([l2, torch.zeros(max_length - l2.size(0), l2.size(1))]) 393 | # lm_embed = torch.stack([l1, l2]) 394 | # 395 | # codec_length = torch.as_tensor(codec_length) 396 | # lm_length = torch.as_tensor(lm_length) 397 | 398 | # ref_enc = GlobalStyleTokenForMulti( 399 | # cnn_filters=[32, 32, 64, 64, 128, 128], 400 | # dropout=0.2, 401 | # gru_hidden=384, 402 | # gst_size=384, 403 | # lm_in=768, 404 | # codec_vocab_size=7724, 405 | # codec_pad_token_id=1, 406 | 
# n_style_token=10, 407 | # n_style_attn_head=4 408 | # ) 409 | # 410 | # ref_enc = ref_enc.to('cuda') 411 | # output = ref_enc(codec_codes.to('cuda'), lm_embed.to('cuda'), codec_length.to('cuda'), lm_length.to('cuda')) 412 | 413 | c1 = torch.load("../encodec_sample.pt") 414 | c2 = torch.load("../encodec_sample2.pt") 415 | 416 | codec_length = torch.as_tensor([c1.size(1), c2.size(1)]).to('cuda') 417 | max_length = codec_length.max().item() 418 | c1 = torch.cat([c1, torch.zeros(c1.size(0), max_length-c1.size(1))], dim=1) 419 | c2 = torch.cat([c2, torch.zeros(c2.size(0), max_length-c2.size(1))], dim=1) 420 | codec_embedding = torch.stack([c1, c2]).to('cuda') 421 | 422 | ref_enc = GlobalStyleTokenForCodec( 423 | cnn_filters=[32, 32, 64, 64, 128, 128], 424 | dropout=0.2, 425 | gru_hidden=384, 426 | gst_size=384, 427 | n_style_token=10, 428 | n_style_attn_head=4 429 | ) 430 | ref_enc = ref_enc.to('cuda') 431 | output = ref_enc(codec_embedding, codec_length) 432 | -------------------------------------------------------------------------------- /module/ref_mixer_codec_only.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import torch 3 | import torch.nn as nn 4 | from module.codec_embedding import AudioCodecForEmbedding 5 | 6 | from nemo.collections.tts.modules.mixer_tts import MixerTTSBlock, \ 7 | create_time_mix_layer, create_channel_mix_layer,Mix 8 | 9 | class RefMixer(nn.Module): 10 | def __init__( 11 | self, 12 | initial_dim, 13 | gru_out, 14 | expansion_factor, 15 | conv_type, 16 | num_layers, 17 | kernel_sizes, 18 | dropout=0.1, 19 | ): 20 | super().__init__() 21 | if num_layers != len(kernel_sizes): 22 | raise ValueError 23 | self.feature_dim = initial_dim 24 | 25 | self.mixer_blocks = nn.Sequential( 26 | *[ 27 | MixerTTSBlock(initial_dim, expansion_factor, kernel_size, conv_type, dropout) 28 | for kernel_size in kernel_sizes 29 | ], 30 | ) 31 | self.norm = nn.LayerNorm(initial_dim, eps=1e-03) 32 | 33 | self.gru = nn.GRU(initial_dim, gru_out, 1, batch_first=True) 34 | 35 | 36 | def _get_mask_from_length(self, lengths): 37 | mask = ( 38 | torch.arange(lengths.max()).to(lengths.device).expand(lengths.shape[0], 39 | lengths.max()) < lengths.unsqueeze( 40 | 1)).unsqueeze(2) 41 | return mask 42 | 43 | def forward(self, codec_embedding, masks): 44 | # x = multi_embeds * multi_masks 45 | x = codec_embedding.transpose(1, 2).float() 46 | for block in self.mixer_blocks: 47 | x, lens = block(x, masks) 48 | # x += codec_embedding.transpose(1, 2).float() 49 | y = self.norm(x) 50 | 51 | self.gru.flatten_parameters() 52 | _, y = self.gru(y) # whole output , last hidden state (1, batch, featue dim) 53 | y = y.transpose(0,1) # (batch, 1, feature dim) 54 | 55 | return y 56 | 57 | if __name__=="__main__": 58 | c1 = torch.load("../encodec_sample.pt") 59 | # c1 = torch.sum(c1, dim=0) 60 | c2 = torch.load("../encodec_sample2.pt") 61 | # c2 = torch.sum(c2, dim=0) 62 | codec_length = [c1.size(1), c2.size(1)] 63 | max_length = max(c1.size(1), c2.size(1)) 64 | 65 | c1 = torch.cat([c1, torch.zeros(8, max_length - c1.size(1))], dim=1) 66 | c2 = torch.cat([c2, torch.zeros(8, max_length - c2.size(1))], dim=1) 67 | codec_codes = torch.stack([c1.long(),c2.long()]) 68 | 69 | codec_length = torch.as_tensor(codec_length) 70 | 71 | ref_enc = RefMixer( 72 | initial_dim=8, 73 | gru_out=384, 74 | expansion_factor=4, 75 | num_layers=6, 76 | kernel_sizes=[11, 13, 15, 17, 19, 21], 77 | conv_type="depth-wise", 78 | ) 79 | # output = ref_enc(audio_embed, lm_embed, 
audio_length, lm_length) 80 | ref_enc = ref_enc.to('cuda') 81 | output = ref_enc(codec_codes.to('cuda'), codec_length.to('cuda')) 82 | -------------------------------------------------------------------------------- /module/ref_vits_module.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from module.monotonic_align import maximum_path 7 | from model.helper import rand_slice_segments, generate_path, get_mask_from_lengths 8 | from module.vits_modules import TextEncoder, PosteriorEncoder, DurationPredictor, ResidualCouplingBlock, Generator, StochasticDurationPredictor 9 | 10 | # ref encoders 11 | from module.ref_gst import GlobalStyleTokenForCodec, GlobalStyleTokenForMulti, GlobalStyleToken 12 | from module.ref_mixer_v2 import RefMixer as MultiMixer 13 | from module.ref_mixer_codec_rnn import RefMixer as MultiCodecMixer 14 | from module.ref_mixer_mel import RefMixer as MelMixer 15 | from module.ref_mixer_codec_only import RefMixer as CodecMixer 16 | 17 | class RefSynthesizerTrn(nn.Module): 18 | """ 19 | Synthesizer for Training 20 | """ 21 | 22 | def __init__( 23 | self, 24 | n_vocab, 25 | spec_channels, 26 | segment_size, 27 | inter_channels, 28 | hidden_channels, 29 | filter_channels, 30 | n_heads, 31 | n_layers, 32 | kernel_size, 33 | p_dropout, 34 | padding_idx, 35 | resblock, 36 | resblock_kernel_sizes, 37 | resblock_dilation_sizes, 38 | upsample_rates, 39 | upsample_initial_channel, 40 | upsample_kernel_sizes, 41 | ref_encoder=None, 42 | n_speakers=0, 43 | gin_channels=0, 44 | use_sdp=True, 45 | **kwargs 46 | ): 47 | 48 | super().__init__() 49 | self.n_vocab = n_vocab 50 | self.spec_channels = spec_channels 51 | self.inter_channels = inter_channels 52 | self.hidden_channels = hidden_channels 53 | self.filter_channels = filter_channels 54 | self.n_heads = n_heads 55 | self.n_layers = n_layers 56 | self.kernel_size = kernel_size 57 | self.p_dropout = p_dropout 58 | self.padding_idx = padding_idx 59 | self.resblock = resblock 60 | self.resblock_kernel_sizes = resblock_kernel_sizes 61 | self.resblock_dilation_sizes = resblock_dilation_sizes 62 | self.upsample_rates = upsample_rates 63 | self.upsample_initial_channel = upsample_initial_channel 64 | self.upsample_kernel_sizes = upsample_kernel_sizes 65 | self.segment_size = segment_size 66 | self.n_speakers = n_speakers 67 | self.gin_channels = gin_channels 68 | 69 | self.use_sdp = use_sdp 70 | 71 | self.enc_p = TextEncoder( 72 | n_vocab, 73 | inter_channels, 74 | hidden_channels, 75 | filter_channels, 76 | n_heads, 77 | n_layers, 78 | kernel_size, 79 | p_dropout, 80 | padding_idx, 81 | ) 82 | self.dec = Generator( 83 | inter_channels, 84 | resblock, 85 | resblock_kernel_sizes, 86 | resblock_dilation_sizes, 87 | upsample_rates, 88 | upsample_initial_channel, 89 | upsample_kernel_sizes, 90 | gin_channels=gin_channels, 91 | ) 92 | self.enc_q = PosteriorEncoder( 93 | spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels 94 | ) 95 | self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) 96 | 97 | if use_sdp: 98 | self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) 99 | else: 100 | self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) 101 | 102 | if n_speakers > 1: 103 | self.emb_g = nn.Embedding(n_speakers, gin_channels) 104 | 105 | self.ref_encoder = ref_encoder 106 | 107 | def 
forward(self, text, text_len, spec, spec_len, speakers=None, 108 | ref_spec=None, ref_spec_lens=None, ref_codec=None, ref_codec_lens=None, lm_embedding=None, lm_embedding_lens=None): 109 | x, mean_prior, logscale_prior, text_mask = self.enc_p(text, text_len) 110 | if self.n_speakers > 1: 111 | g = self.emb_g(speakers).unsqueeze(-1) # [b, h, 1] 112 | else: 113 | g = None 114 | 115 | if self.ref_encoder is not None: 116 | style_embedding = None 117 | if isinstance(self.ref_encoder, GlobalStyleToken) or isinstance(self.ref_encoder, 118 | GlobalStyleTokenForCodec) or isinstance( 119 | self.ref_encoder, MelMixer) or isinstance(self.ref_encoder, CodecMixer): 120 | # mel or codec only 121 | if ref_spec is not None and ref_spec_lens is not None: 122 | ref_spec_mask = ( 123 | torch.arange(ref_spec_lens.max()).to(ref_spec.device).expand(ref_spec_lens.shape[0], 124 | ref_spec_lens.max()) < ref_spec_lens.unsqueeze( 125 | 1)).unsqueeze(2) 126 | style_embedding = self.ref_encoder(ref_spec, ref_spec_mask) 127 | else: 128 | raise ValueError("GST needs reference spectrogram") 129 | if isinstance(self.ref_encoder, GlobalStyleTokenForMulti) or isinstance(self.ref_encoder, 130 | MultiMixer) or isinstance( 131 | self.ref_encoder, MultiCodecMixer): 132 | # codec + lm 133 | if ref_codec is not None and ref_codec_lens is not None and lm_embedding is not None and lm_embedding_lens is not None: 134 | style_embedding = self.ref_encoder(ref_codec, lm_embedding, ref_codec_lens, lm_embedding_lens) 135 | else: 136 | raise ValueError( 137 | "Multi Reference encoder needs codec, codec length, lm embedding, lm embedding lens") 138 | g += style_embedding.transpose(1, 2) # g is [b,h,1] ans style embedding is [b,1,h] => needs transpose 139 | 140 | z, mean_posterior, logscale_posterior, spec_mask = self.enc_q(spec, spec_len, g=g) 141 | z_p = self.flow(z, spec_mask, g=g) 142 | 143 | with torch.no_grad(): 144 | # negative cross-entropy 145 | s_p_sq_r = torch.exp(-2 * logscale_prior) # [b, d, t] 146 | neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logscale_prior, [1], keepdim=True) # [b, 1, t_s] 147 | neg_cent2 = torch.matmul( 148 | -0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r 149 | ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] 150 | neg_cent3 = torch.matmul( 151 | z_p.transpose(1, 2), (mean_prior * s_p_sq_r) 152 | ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] 153 | neg_cent4 = torch.sum(-0.5 * (mean_prior ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s] 154 | neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 155 | 156 | attn_mask = torch.unsqueeze(text_mask, 2) * torch.unsqueeze(spec_mask, -1) 157 | attn = maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach() 158 | 159 | w = attn.sum(2) 160 | if self.use_sdp: 161 | l_length = self.dp(x, text_mask, w, g=g) 162 | l_length = l_length / torch.sum(text_mask) 163 | else: 164 | logw_ = torch.log(w + 1e-6) * text_mask 165 | logw = self.dp(x, text_mask, g=g) 166 | l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(text_mask) # for averaging 167 | 168 | # expand prior 169 | mean_prior = torch.matmul(attn.squeeze(1), mean_prior.transpose(1, 2)).transpose( 170 | 1, 2 171 | ) # [b, t', t], [b, t, d] -> [b, d, t'] 172 | logscale_prior = torch.matmul(attn.squeeze(1), logscale_prior.transpose(1, 2)).transpose( 173 | 1, 2 174 | ) # [b, t', t], [b, t, d] -> [b, d, t'] 175 | 176 | z_slice, ids_slice = rand_slice_segments(z, spec_len, self.segment_size) 177 | audio = self.dec(z_slice, g=g) 178 | return ( 179 | audio, 180 | l_length, 181 | attn, 182 | ids_slice, 
183 | text_mask, 184 | spec_mask, 185 | (z, z_p, mean_prior, logscale_prior, mean_posterior, logscale_posterior), 186 | ) 187 | 188 | def infer(self, text, text_len, speakers=None, noise_scale=1, length_scale=1, noise_scale_w=1.0, max_len=None, 189 | ref_spec=None, ref_spec_lens=None, ref_codec=None, ref_codec_lens=None, lm_embedding=None, lm_embedding_lens=None): 190 | x, mean_prior, logscale_prior, text_mask = self.enc_p(text, text_len) 191 | if self.n_speakers > 1 and speakers is not None: 192 | g = self.emb_g(speakers).unsqueeze(-1) # [b, h, 1] 193 | else: 194 | g = None 195 | 196 | if self.ref_encoder is not None: 197 | style_embedding = None 198 | if isinstance(self.ref_encoder, GlobalStyleToken) or isinstance(self.ref_encoder, 199 | GlobalStyleTokenForCodec) or isinstance( 200 | self.ref_encoder, MelMixer) or isinstance(self.ref_encoder, CodecMixer): 201 | # mel or codec only 202 | if ref_spec is not None and ref_spec_lens is not None: 203 | ref_spec_mask = ( 204 | torch.arange(ref_spec_lens.max()).to(ref_spec.device).expand(ref_spec_lens.shape[0], 205 | ref_spec_lens.max()) < ref_spec_lens.unsqueeze( 206 | 1)).unsqueeze(2) 207 | style_embedding = self.ref_encoder(ref_spec, ref_spec_mask) 208 | else: 209 | raise ValueError("GST needs reference spectrogram") 210 | if isinstance(self.ref_encoder, GlobalStyleTokenForMulti) or isinstance(self.ref_encoder, 211 | MultiMixer) or isinstance( 212 | self.ref_encoder, MultiCodecMixer): 213 | # codec + lm 214 | if ref_codec is not None and ref_codec_lens is not None and lm_embedding is not None and lm_embedding_lens is not None: 215 | style_embedding = self.ref_encoder(ref_codec, lm_embedding, ref_codec_lens, lm_embedding_lens) 216 | else: 217 | raise ValueError( 218 | "Multi Reference encoder needs codec, codec length, lm embedding, lm embedding lens") 219 | g += style_embedding.transpose(1, 2) # g is [b,h,1] ans style embedding is [b,1,h] => needs transpose 220 | 221 | if self.use_sdp: 222 | logw = self.dp(x, text_mask, g=g, reverse=True, noise_scale=noise_scale_w) 223 | else: 224 | logw = self.dp(x, text_mask, g=g) 225 | w = torch.exp(logw) * text_mask * length_scale 226 | w_ceil = torch.ceil(w) 227 | audio_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() 228 | audio_mask = torch.unsqueeze(get_mask_from_lengths(audio_lengths, None), 1).to(text_mask.dtype) 229 | attn_mask = torch.unsqueeze(text_mask, 2) * torch.unsqueeze(audio_mask, -1) 230 | attn = generate_path(w_ceil, attn_mask) 231 | 232 | mean_prior = torch.matmul(attn.squeeze(1), mean_prior.transpose(1, 2)).transpose( 233 | 1, 2 234 | ) # [b, t', t], [b, t, d] -> [b, d, t'] 235 | logscale_prior = torch.matmul(attn.squeeze(1), logscale_prior.transpose(1, 2)).transpose( 236 | 1, 2 237 | ) # [b, t', t], [b, t, d] -> [b, d, t'] 238 | 239 | z_p = mean_prior + torch.randn_like(mean_prior) * torch.exp(logscale_prior) * noise_scale 240 | z = self.flow(z_p, audio_mask, g=g, reverse=True) 241 | audio = self.dec((z * audio_mask)[:, :, :max_len], g=g) 242 | return audio, attn, audio_mask, (z, z_p, mean_prior, logscale_prior) 243 | 244 | # Can be used for emotions 245 | def voice_conversion(self, y, y_lengths, speaker_src, speaker_tgt): 246 | assert self.n_speakers > 1, "n_speakers have to be larger than 1." 
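# Encode the source audio conditioned on the source speaker embedding, map the
# posterior sample into the prior space with the normalizing flow, run the flow
# in reverse conditioned on the target speaker, then decode back to a waveform.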
247 | g_src = self.emb_g(speaker_src).unsqueeze(-1) 248 | g_tgt = self.emb_g(speaker_tgt).unsqueeze(-1) 249 | z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src) 250 | z_p = self.flow(z, y_mask, g=g_src) 251 | z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) 252 | o_hat = self.dec(z_hat * y_mask, g=g_tgt) 253 | return o_hat, y_mask, (z, z_p, z_hat) 254 | -------------------------------------------------------------------------------- /module/vits_losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # MIT License 16 | # 17 | # Copyright (c) 2021 Jaehyeon Kim 18 | # 19 | # Permission is hereby granted, free of charge, to any person obtaining a copy 20 | # of this software and associated documentation files (the "Software"), to deal 21 | # in the Software without restriction, including without limitation the rights 22 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 23 | # copies of the Software, and to permit persons to whom the Software is 24 | # furnished to do so, subject to the following conditions: 25 | # 26 | # The above copyright notice and this permission notice shall be included in all 27 | # copies or substantial portions of the Software. 28 | # 29 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 30 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 31 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 32 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 33 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 34 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 35 | # SOFTWARE. 
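# Note on KlLoss below: per element it computes
#     logs_p - logs_q - 1/2 + (z_p - m_p)^2 * exp(-2 * logs_p) / 2,
# i.e. a single-sample estimate of KL(q || p) between diagonal Gaussians where
# the E_q[(z - m_q)^2 / sigma_q^2] term is replaced by its expectation 1
# (hence the constant -1/2).  The per-element terms are summed over z_mask and
# divided by the mask sum, i.e. averaged over the unpadded positions.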
36 | 37 | # The forward functions of the following classes are based on code from https://github.com/jaywalnut310/vits: 38 | # KlLoss 39 | 40 | import torch 41 | 42 | from nemo.core.classes import Loss, typecheck 43 | from nemo.core.neural_types.elements import LossType, VoidType 44 | from nemo.core.neural_types.neural_type import NeuralType 45 | 46 | 47 | class KlLoss(Loss): 48 | @property 49 | def input_types(self): 50 | return { 51 | "z_p": [NeuralType(('B', 'D', 'T'), VoidType())], 52 | "logs_q": [NeuralType(('B', 'D', 'T'), VoidType())], 53 | "m_p": [NeuralType(('B', 'D', 'T'), VoidType())], 54 | "logs_p": [NeuralType(('B', 'D', 'T'), VoidType())], 55 | "z_mask": [NeuralType(('B', 'D', 'T'), VoidType())], 56 | } 57 | 58 | @property 59 | def output_types(self): 60 | return { 61 | "loss": NeuralType(elements_type=LossType()), 62 | } 63 | 64 | @typecheck() 65 | def forward(self, z_p, logs_q, m_p, logs_p, z_mask): 66 | """ 67 | z_p: Input distribution 68 | logs_q: LogVariance of target distrubution 69 | m_p: Mean of input distrubution 70 | logs_p: LogVariance of input distrubution 71 | """ 72 | z_p = z_p.float() 73 | logs_q = logs_q.float() 74 | m_p = m_p.float() 75 | logs_p = logs_p.float() 76 | z_mask = z_mask.float() 77 | 78 | kl = logs_p - logs_q - 0.5 79 | kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) 80 | kl = torch.sum(kl * z_mask) 81 | l = kl / torch.sum(z_mask) 82 | return l 83 | 84 | 85 | class FeatureMatchingLoss(Loss): 86 | """VITS Feature Matching Loss module""" 87 | 88 | @property 89 | def input_types(self): 90 | return { 91 | "fmap_r": [[NeuralType(elements_type=VoidType())]], 92 | "fmap_g": [[NeuralType(elements_type=VoidType())]], 93 | } 94 | 95 | @property 96 | def output_types(self): 97 | return { 98 | "loss": NeuralType(elements_type=LossType()), 99 | } 100 | 101 | @typecheck() 102 | def forward(self, fmap_r, fmap_g): 103 | """ 104 | fmap_r, fmap_g: List[List[Tensor]] 105 | """ 106 | loss = 0 107 | for dr, dg in zip(fmap_r, fmap_g): 108 | for rl, gl in zip(dr, dg): 109 | rl = rl.float().detach() 110 | gl = gl.float() 111 | loss += torch.mean(torch.abs(rl - gl)) 112 | 113 | return loss * 2 114 | 115 | 116 | class DiscriminatorLoss(Loss): 117 | """Discriminator Loss module""" 118 | 119 | @property 120 | def input_types(self): 121 | return { 122 | "disc_real_outputs": [NeuralType(('B', 'T'), VoidType())], 123 | "disc_generated_outputs": [NeuralType(('B', 'T'), VoidType())], 124 | } 125 | 126 | @property 127 | def output_types(self): 128 | return { 129 | "loss": NeuralType(elements_type=LossType()), 130 | "real_losses": [NeuralType(elements_type=LossType())], 131 | "fake_losses": [NeuralType(elements_type=LossType())], 132 | } 133 | 134 | @typecheck() 135 | def forward(self, disc_real_outputs, disc_generated_outputs): 136 | r_losses = [] 137 | g_losses = [] 138 | loss = 0 139 | for i, (dr, dg) in enumerate(zip(disc_real_outputs, disc_generated_outputs)): 140 | dr = dr.float() 141 | dg = dg.float() 142 | r_loss = torch.mean((1 - dr) ** 2) 143 | g_loss = torch.mean(dg ** 2) 144 | loss += r_loss + g_loss 145 | r_losses.append(r_loss.item()) 146 | g_losses.append(g_loss.item()) 147 | 148 | return loss, r_losses, g_losses 149 | 150 | 151 | class GeneratorLoss(Loss): 152 | """Generator Loss module""" 153 | 154 | @property 155 | def input_types(self): 156 | return { 157 | "disc_outputs": [NeuralType(('B', 'T'), VoidType())], 158 | } 159 | 160 | @property 161 | def output_types(self): 162 | return { 163 | "loss": NeuralType(elements_type=LossType()), 164 | 
"fake_losses": [NeuralType(elements_type=LossType())], 165 | } 166 | 167 | @typecheck() 168 | def forward(self, disc_outputs): 169 | loss = 0 170 | gen_losses = [] 171 | for dg in disc_outputs: 172 | dg = dg.float() 173 | l = torch.mean((1 - dg) ** 2) 174 | gen_losses.append(l) 175 | loss += l 176 | 177 | return loss, gen_losses 178 | -------------------------------------------------------------------------------- /preprocess/make_manifest.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import subprocess 4 | from pathlib import Path 5 | import fnmatch 6 | import multiprocessing 7 | import functools 8 | from tqdm import tqdm 9 | 10 | 11 | def speaker_mapping(): 12 | VCTK = Path("VCTK-Corpus/txt") 13 | 14 | speakers = sorted(os.listdir(VCTK)) 15 | 16 | speaker_map = {} 17 | for i, x in enumerate(speakers): 18 | speaker_map[x] = i 19 | # speaker_map = [{'id': x, 'num_id': i} for i, x in enumerate(speakers)] 20 | 21 | if not os.path.exists("data/vctk"): 22 | os.mkdir("data/vctk") 23 | 24 | json.dump(speaker_map, open("data/vctk/speaker_map.json",'w'), ensure_ascii=False, indent=1) 25 | 26 | 27 | 28 | 29 | 30 | def __process_transcript(file_path: str): 31 | entries = [] 32 | with open(file_path, encoding="utf-8") as fin: 33 | text = fin.readlines()[0].strip() 34 | 35 | wav_file = file_path.replace(".txt", ".wav") 36 | wav_file = wav_file.replace("/txt/", "/wav22/") 37 | speaker_id = file_path.split('/')[-2] 38 | assert os.path.exists(wav_file), f"{wav_file} not found!" 39 | duration = subprocess.check_output(f"soxi -D {wav_file}", shell=True) 40 | entry = { 41 | 'audio_filepath': os.path.abspath(wav_file), 42 | 'duration': float(duration), 43 | 'text': text, 44 | 'speaker': speaker_map[speaker_id], 45 | } 46 | 47 | entries.append(entry) 48 | 49 | return entries 50 | 51 | 52 | def make_manifest(): 53 | num_workers = 4 54 | 55 | VCTK = Path("VCTK-Corpus") 56 | 57 | txt_dir = VCTK / "txt" 58 | manifests = [] 59 | files = [] 60 | 61 | for root, dirnames, filenames in os.walk(txt_dir): 62 | # we will use normalized text provided by the original dataset 63 | for filename in fnmatch.filter(filenames, '*.txt'): 64 | files.append(os.path.join(root, filename)) 65 | 66 | with multiprocessing.Pool(num_workers) as p: 67 | processing_func = functools.partial(__process_transcript) 68 | results = p.imap(processing_func, files) 69 | for result in tqdm(results, total=len(files)): 70 | manifests.extend(result) 71 | 72 | if not os.path.exists("data/vctk"): 73 | os.mkdir("data/vctk") 74 | manifest_file = "data/vctk/all_manifest.json" 75 | with open(manifest_file, 'w') as fout: 76 | for m in manifests: 77 | fout.write(json.dumps(m) + '\n') 78 | 79 | 80 | def split_manifest(): 81 | all = [json.loads(x) for x in open('data/vctk/all_manifest.json').readlines()] 82 | 83 | speaker_split = {} 84 | for x in all: 85 | if x['speaker'] not in speaker_split.keys(): 86 | speaker_split[x['speaker']] = [] 87 | speaker_split[x['speaker']].append(x) 88 | print(len(speaker_split.keys())) 89 | total_train, total_dev, total_eval = [], [], [] 90 | for spk, all in tqdm(speaker_split.items()): 91 | train = int(len(all) * 0.8) 92 | deveval = len(all) - train 93 | if deveval % 2 != 0: 94 | dev = int(deveval / 2) + 1 95 | else: 96 | dev = int(deveval / 2) 97 | eval = deveval - dev 98 | assert train + dev + eval == len(all) 99 | 100 | train_manifest = all[:train] 101 | dev_manifest = all[train: train + dev] 102 | eval_manifest = all[train + dev:] 103 | assert 
len(train_manifest) + len(dev_manifest) + len(eval_manifest) == len(all) 104 | 105 | total_train.extend(train_manifest) 106 | total_dev.extend(dev_manifest) 107 | total_eval.extend(eval_manifest) 108 | 109 | if not os.path.exists("data/vctk"): 110 | os.mkdir("data/vctk") 111 | with open("data/vctk/train_manifest.json", 'w') as f: 112 | for x in total_train: 113 | f.write(json.dumps(x) + "\n") 114 | with open("data/vctk/valid_manifest.json", 'w') as f: 115 | for x in total_dev: 116 | f.write(json.dumps(x) + "\n") 117 | with open("data/vctk/test_manifest.json", 'w') as f: 118 | for x in total_eval: 119 | f.write(json.dumps(x) + "\n") 120 | print(f"Train dataset: {len(total_train)}") 121 | print(f"Valid dataset: {len(total_dev)}") 122 | print(f"Test dataset: {len(total_eval)}") 123 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nemo_toolkit[tts]==1.18.0 2 | numpy==1.23.5 3 | wandb==0.13.5 4 | pynini==2.1.5 5 | --extra-index-url https://download.pytorch.org/whl/cu113 6 | torch==1.11.0+cu113 7 | torchaudio==0.11.0+cu113 -------------------------------------------------------------------------------- /torchdata/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from tqdm import tqdm 4 | import json 5 | from typing import Callable, Dict, List, Optional, Union 6 | from hydra.utils import instantiate 7 | from nemo.core.config import hydra_runner 8 | 9 | 10 | import torch 11 | import torchaudio 12 | from transformers import AutoTokenizer, AutoModel, AutoModelForAudioXVector 13 | 14 | from nemo.collections.tts.torch.data import TTSDataset 15 | 16 | from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer 17 | from nemo.collections.asr.parts.preprocessing.segment import AudioSegment 18 | from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import ( 19 | BaseTokenizer, 20 | EnglishCharsTokenizer, 21 | EnglishPhonemesTokenizer, 22 | ) 23 | from nemo.collections.tts.torch.helpers import ( 24 | BetaBinomialInterpolator, 25 | beta_binomial_prior_distribution, 26 | general_padding, 27 | get_base_dir, 28 | ) 29 | from nemo.collections.tts.torch.tts_data_types import ( 30 | DATA_STR2DATA_CLASS, 31 | MAIN_DATA_TYPES, 32 | AlignPriorMatrix, 33 | Durations, 34 | Energy, 35 | LMTokens, 36 | LogMel, 37 | P_voiced, 38 | Pitch, 39 | SpeakerID, 40 | TTSDataType, 41 | Voiced_mask, 42 | WithLens, 43 | ) 44 | from nemo.core.classes import Dataset 45 | from nemo.utils import logging 46 | 47 | try: 48 | from nemo_text_processing.text_normalization.normalize import Normalizer 49 | 50 | PYNINI_AVAILABLE = True 51 | except (ImportError, ModuleNotFoundError): 52 | Normalizer = None 53 | PYNINI_AVAILABLE = False 54 | 55 | EPSILON = 1e-9 56 | WINDOW_FN_SUPPORTED = { 57 | 'hann': torch.hann_window, 58 | 'hamming': torch.hamming_window, 59 | 'blackman': torch.blackman_window, 60 | 'bartlett': torch.bartlett_window, 61 | 'none': None, 62 | } 63 | 64 | 65 | class MultiTTSDataset(TTSDataset): 66 | def __init__( 67 | self, 68 | manifest_filepath: Union[str, Path, List[str], List[Path]], 69 | sample_rate: int, 70 | text_tokenizer: Union[BaseTokenizer, Callable[[str], List[int]]], 71 | tokens: Optional[List[str]] = None, 72 | text_normalizer: Optional[Union[Normalizer, Callable[[str], str]]] = None, 73 | text_normalizer_call_kwargs: Optional[Dict] = None, 74 |
text_tokenizer_pad_id: Optional[int] = None, 75 | sup_data_types: Optional[List[str]] = None, 76 | sup_data_path: Optional[Union[Path, str]] = None, 77 | max_duration: Optional[float] = None, 78 | min_duration: Optional[float] = None, 79 | ignore_file: Optional[Union[str, Path]] = None, 80 | trim: bool = False, 81 | trim_ref: Optional[float] = None, 82 | trim_top_db: Optional[int] = None, 83 | trim_frame_length: Optional[int] = None, 84 | trim_hop_length: Optional[int] = None, 85 | n_fft: int = 1024, 86 | win_length: Optional[int] = None, 87 | hop_length: Optional[int] = None, 88 | window: str = "hann", 89 | n_mels: int = 80, 90 | lowfreq: int = 0, 91 | highfreq: Optional[int] = None, 92 | use_spk_id: bool = False, 93 | use_xvector: bool = False, 94 | **kwargs, 95 | ): 96 | if use_spk_id==True and use_xvector==True: 97 | raise ValueError("You either use speaker id or xvector for speaker embedding, you cannot use both") 98 | 99 | super().__init__(manifest_filepath=manifest_filepath, 100 | sample_rate=sample_rate, 101 | text_tokenizer=text_tokenizer, 102 | text_normalizer=text_normalizer, 103 | text_normalizer_call_kwargs=text_normalizer_call_kwargs, 104 | sup_data_path=sup_data_path, 105 | sup_data_types=sup_data_types, 106 | n_fft=n_fft, 107 | win_length=win_length, 108 | hop_length=hop_length, 109 | window=window, 110 | n_mels=n_mels, 111 | lowfreq=lowfreq, 112 | highfreq=highfreq, 113 | max_duration=max_duration, 114 | min_duration=min_duration, 115 | ignore_file=ignore_file, 116 | trim=trim, 117 | trim_top_db=trim_top_db, 118 | **kwargs) 119 | 120 | lm_model = kwargs.pop("lm_model") 121 | audio_model = kwargs.pop("audio_model") 122 | self.use_xvector = use_xvector 123 | self.use_spk_id = use_spk_id 124 | 125 | self.pretrained_model_name = lm_model.split("/")[-1].split("-")[0] + "_" + \ 126 | audio_model.split("/")[-1].split('-')[0] 127 | self.pretrained_model_name = Path(self.sup_data_path) / self.pretrained_model_name 128 | if not os.path.exists(self.pretrained_model_name): 129 | os.makedirs(self.pretrained_model_name, exist_ok=True) 130 | 131 | self.lm_model_tokenizer = AutoTokenizer.from_pretrained(lm_model) 132 | self.lm_model = AutoModel.from_pretrained(lm_model).eval() 133 | self.audio_model = AutoModel.from_pretrained(audio_model).eval() 134 | if self.use_xvector: 135 | xvector_model = kwargs.pop("xvector_model") 136 | self.xvector_model = AutoModelForAudioXVector.from_pretrained(xvector_model).eval() 137 | 138 | lm_folder = Path(self.pretrained_model_name) / "lm_embedding" 139 | if not os.path.exists(lm_folder): 140 | os.makedirs(lm_folder) 141 | 142 | audio_folder = Path(self.pretrained_model_name) / "audio_embedding" 143 | if not os.path.exists(audio_folder): 144 | os.makedirs(audio_folder) 145 | 146 | if self.use_xvector: 147 | speaker_folder = Path(self.pretrained_model_name) / "speaker_embedding" 148 | if not os.path.exists(speaker_folder): 149 | os.makedirs(speaker_folder) 150 | 151 | def __getitem__(self, index): 152 | try: 153 | ( 154 | audio, 155 | audio_length, 156 | text, 157 | text_length, 158 | log_mel, 159 | log_mel_length, 160 | durations, 161 | align_prior_matrix, 162 | pitch, 163 | pitch_length, 164 | energy, 165 | energy_length, 166 | speaker_id, 167 | voiced_mask, 168 | p_voiced, 169 | ) = super().__getitem__(index) 170 | except AttributeError: 171 | print(self.data[index]['audio_filepath']) 172 | sample = self.data[index] 173 | # find lm, audio embedding from sup_data path 174 | # if it exists, load if 175 | # else, run model and save it 176 | 177 | 178 | 
lm_path = self.pretrained_model_name / "lm_embedding" 179 | audio_path = self.pretrained_model_name / "audio_embedding" 180 | speaker_path = self.pretrained_model_name / "speaker_embedding" 181 | file_name = sample['audio_filepath'].split("/")[-1] 182 | file_name = file_name.replace("wav", "pt") 183 | 184 | lm_file = os.path.join(lm_path, file_name) 185 | if os.path.exists(lm_file): 186 | try: 187 | lm_embedding = torch.load(lm_file) 188 | except RuntimeError: 189 | lm_inputs = self.lm_model_tokenizer(sample['original_text'], return_tensors='pt') 190 | lm_embedding = self.lm_model(**lm_inputs).last_hidden_state 191 | lm_embedding = lm_embedding.squeeze(0) 192 | torch.save(lm_embedding, lm_file) 193 | else: 194 | lm_inputs = self.lm_model_tokenizer(sample['original_text'], return_tensors='pt') 195 | lm_embedding = self.lm_model(**lm_inputs).last_hidden_state 196 | lm_embedding = lm_embedding.squeeze(0) 197 | torch.save(lm_embedding, lm_file) 198 | 199 | audio_file = os.path.join(audio_path, file_name) 200 | if os.path.exists(audio_file): 201 | try: 202 | audio_embedding = torch.load(audio_file) 203 | except RuntimeError: 204 | audio_input, sr = torchaudio.load(sample['audio_filepath']) 205 | resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000) 206 | audio_input = resampler(audio_input) 207 | audio_embedding = self.audio_model(input_values=audio_input).last_hidden_state 208 | audio_embedding = audio_embedding.squeeze(0) 209 | torch.save(audio_embedding, audio_file) 210 | else: 211 | audio_input, sr = torchaudio.load(sample['audio_filepath']) 212 | resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000) 213 | audio_input = resampler(audio_input) 214 | audio_embedding = self.audio_model(input_values=audio_input).last_hidden_state 215 | audio_embedding = audio_embedding.squeeze(0) 216 | torch.save(audio_embedding, audio_file) 217 | 218 | speaker_embedding = None 219 | if self.use_xvector: 220 | speaker_file = os.path.join(speaker_path, file_name) 221 | if os.path.exists(speaker_file): 222 | speaker_embedding = torch.load(speaker_file) 223 | else: 224 | audio_input, sr = torchaudio.load(sample['audio_filepath']) 225 | resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000) 226 | audio_input = resampler(audio_input) 227 | speaker_embedding = self.xvector_model(input_values=audio_input).embeddings 228 | speaker_embedding = speaker_embedding.squeeze(0) 229 | torch.save(speaker_embedding, speaker_file) 230 | if speaker_embedding != None: 231 | speaker_embedding = speaker_embedding.detach() 232 | 233 | audio_embedding_length = torch.tensor(audio_embedding.size(0)).long() 234 | lm_embedding_length = torch.tensor(lm_embedding.size(0)).long() 235 | 236 | return ( 237 | audio, 238 | audio_length, 239 | text, 240 | text_length, 241 | log_mel, 242 | log_mel_length, 243 | durations, 244 | align_prior_matrix, 245 | pitch, 246 | pitch_length, 247 | energy, 248 | energy_length, 249 | speaker_id, 250 | voiced_mask, 251 | p_voiced, 252 | audio_embedding, 253 | audio_embedding_length, 254 | lm_embedding, 255 | lm_embedding_length, 256 | speaker_embedding, 257 | ) 258 | 259 | def general_collate_fn(self, batch): 260 | ( 261 | _, 262 | audio_lengths, 263 | _, 264 | tokens_lengths, 265 | _, 266 | log_mel_lengths, 267 | durations_list, 268 | align_prior_matrices_list, 269 | pitches, 270 | pitches_lengths, 271 | energies, 272 | energies_lengths, 273 | _, 274 | voiced_masks, 275 | p_voiceds, 276 | audio_embeddings, 277 | audio_embedding_lengths, 278 | lm_embeddings, 
279 | lm_embedding_lengths, 280 | speaker_embeddings, 281 | ) = zip(*batch) 282 | 283 | max_audio_len = max(audio_lengths).item() 284 | max_tokens_len = max(tokens_lengths).item() 285 | max_log_mel_len = max(log_mel_lengths) if LogMel in self.sup_data_types_set else None 286 | max_durations_len = max([len(i) for i in durations_list]) if Durations in self.sup_data_types_set else None 287 | max_pitches_len = max(pitches_lengths).item() if Pitch in self.sup_data_types_set else None 288 | max_energies_len = max(energies_lengths).item() if Energy in self.sup_data_types_set else None 289 | max_audio_embedding_len = max(audio_embedding_lengths).item() 290 | max_lm_embedding_len = max(lm_embedding_lengths).item() 291 | 292 | if LogMel in self.sup_data_types_set: 293 | log_mel_pad = torch.finfo(batch[0][4].dtype).tiny 294 | 295 | align_prior_matrices = ( 296 | torch.zeros( 297 | len(align_prior_matrices_list), 298 | max([prior_i.shape[0] for prior_i in align_prior_matrices_list]), 299 | max([prior_i.shape[1] for prior_i in align_prior_matrices_list]), 300 | ) 301 | if AlignPriorMatrix in self.sup_data_types_set 302 | else [] 303 | ) 304 | audios, tokens, log_mels, durations_list, pitches, energies, speaker_ids, voiced_masks, p_voiceds, audio_embeddings, lm_embeddings = ( 305 | [], 306 | [], 307 | [], 308 | [], 309 | [], 310 | [], 311 | [], 312 | [], 313 | [], 314 | [], 315 | [] 316 | ) 317 | 318 | for i, sample_tuple in enumerate(batch): 319 | ( 320 | audio, 321 | audio_len, 322 | token, 323 | token_len, 324 | log_mel, 325 | log_mel_len, 326 | durations, 327 | align_prior_matrix, 328 | pitch, 329 | pitch_length, 330 | energy, 331 | energy_length, 332 | speaker_id, 333 | voiced_mask, 334 | p_voiced, 335 | audio_embedding, 336 | audio_embedding_length, 337 | lm_embedding, 338 | lm_embedding_length, 339 | speaker_embedding, 340 | ) = sample_tuple 341 | 342 | audio = general_padding(audio, audio_len.item(), max_audio_len) 343 | audios.append(audio) 344 | 345 | token = general_padding(token, token_len.item(), max_tokens_len, pad_value=self.text_tokenizer_pad_id) 346 | tokens.append(token) 347 | 348 | if LogMel in self.sup_data_types_set: 349 | log_mels.append(general_padding(log_mel, log_mel_len, max_log_mel_len, pad_value=log_mel_pad)) 350 | 351 | if Durations in self.sup_data_types_set: 352 | durations_list.append(general_padding(durations, len(durations), max_durations_len)) 353 | 354 | if AlignPriorMatrix in self.sup_data_types_set: 355 | align_prior_matrices[ 356 | i, : align_prior_matrix.shape[0], : align_prior_matrix.shape[1] 357 | ] = align_prior_matrix 358 | 359 | if Pitch in self.sup_data_types_set: 360 | pitches.append(general_padding(pitch, pitch_length.item(), max_pitches_len)) 361 | 362 | if Voiced_mask in self.sup_data_types_set: 363 | voiced_masks.append(general_padding(voiced_mask, pitch_length.item(), max_pitches_len)) 364 | 365 | if P_voiced in self.sup_data_types_set: 366 | p_voiceds.append(general_padding(p_voiced, pitch_length.item(), max_pitches_len)) 367 | 368 | if Energy in self.sup_data_types_set: 369 | energies.append(general_padding(energy, energy_length.item(), max_energies_len)) 370 | 371 | if SpeakerID in self.sup_data_types_set: 372 | speaker_ids.append(speaker_id) 373 | 374 | pad = torch.zeros(max_audio_embedding_len - audio_embedding_length.item(), audio_embedding.size(-1)) 375 | audio_embedding = torch.cat([audio_embedding, pad], dim=0) 376 | audio_embeddings.append(audio_embedding.detach()) 377 | 378 | pad = torch.zeros(max_lm_embedding_len - 
lm_embedding_length.item(), lm_embedding.size(-1)) 379 | lm_embedding = torch.cat([lm_embedding, pad], dim=0) 380 | lm_embeddings.append(lm_embedding.detach()) 381 | 382 | data_dict = { 383 | "audio": torch.stack(audios), 384 | "audio_lens": torch.stack(audio_lengths), 385 | "text": torch.stack(tokens), 386 | "text_lens": torch.stack(tokens_lengths), 387 | "log_mel": torch.stack(log_mels) if LogMel in self.sup_data_types_set else None, 388 | "log_mel_lens": torch.stack(log_mel_lengths) if LogMel in self.sup_data_types_set else None, 389 | "durations": torch.stack(durations_list) if Durations in self.sup_data_types_set else None, 390 | "align_prior_matrix": align_prior_matrices if AlignPriorMatrix in self.sup_data_types_set else None, 391 | "pitch": torch.stack(pitches) if Pitch in self.sup_data_types_set else None, 392 | "pitch_lens": torch.stack(pitches_lengths) if Pitch in self.sup_data_types_set else None, 393 | "energy": torch.stack(energies) if Energy in self.sup_data_types_set else None, 394 | "energy_lens": torch.stack(energies_lengths) if Energy in self.sup_data_types_set else None, 395 | "speaker_id": torch.stack(speaker_ids) if SpeakerID in self.sup_data_types_set else None, 396 | "voiced_mask": torch.stack(voiced_masks) if Voiced_mask in self.sup_data_types_set else None, 397 | "p_voiced": torch.stack(p_voiceds) if P_voiced in self.sup_data_types_set else None, 398 | "audio_embedding": torch.stack(audio_embeddings), 399 | "audio_embedding_lens": torch.stack(audio_embedding_lengths), 400 | "lm_embedding": torch.stack(lm_embeddings), 401 | "lm_embedding_lens": torch.stack(lm_embedding_lengths), 402 | "speaker_embedding": torch.stack(speaker_embeddings) if self.use_xvector else None 403 | } 404 | return data_dict 405 | 406 | def join_data(self, data_dict): 407 | result = [] 408 | for data_type in MAIN_DATA_TYPES + self.sup_data_types: 409 | result.append(data_dict[data_type.name]) 410 | 411 | if issubclass(data_type, TTSDataType) and issubclass(data_type, WithLens): 412 | result.append(data_dict[f"{data_type.name}_lens"]) 413 | result.append(data_dict['audio_embedding']) 414 | result.append(data_dict['audio_embedding_lens']) 415 | result.append(data_dict['lm_embedding']) 416 | result.append(data_dict['lm_embedding_lens']) 417 | if data_dict['speaker_embedding'] != None: 418 | result.append(data_dict['speaker_embedding']) 419 | return tuple(result) 420 | 421 | from torch.utils.data.distributed import DistributedSampler 422 | 423 | 424 | class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): 425 | """ 426 | Maintain similar input lengths in a batch. 427 | Length groups are specified by boundaries. 428 | Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. 429 | 430 | It removes samples which are not included in the boundaries. 431 | Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. 
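    Illustrative usage (batch size and boundary values are placeholders, not values read from this
    repository's configs; the dataset must expose per-item `lengths`, as ExtensiveTTSDataset does):
        sampler = DistributedBucketSampler(dataset, batch_size=16,
                                           boundaries=[32, 300, 400, 500, 600, 700, 800, 900, 1000])
        loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler, collate_fn=dataset._collate_fn)
    Because __iter__ yields whole batches of indices, the sampler is passed as `batch_sampler`, not `sampler`.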
432 | """ 433 | 434 | def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True): 435 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 436 | self.lengths = dataset.lengths 437 | self.batch_size = batch_size 438 | self.boundaries = boundaries 439 | 440 | self.buckets, self.num_samples_per_bucket = self._create_buckets() 441 | self.total_size = sum(self.num_samples_per_bucket) 442 | self.num_samples = self.total_size // self.num_replicas 443 | 444 | def _create_buckets(self): 445 | buckets = [[] for _ in range(len(self.boundaries) - 1)] 446 | for i in range(len(self.lengths)): 447 | length = self.lengths[i] 448 | idx_bucket = self._bisect(length) 449 | if idx_bucket != -1: 450 | buckets[idx_bucket].append(i) 451 | 452 | for i in range(len(buckets) - 1, 0, -1): 453 | if len(buckets[i]) == 0: 454 | buckets.pop(i) 455 | self.boundaries.pop(i + 1) 456 | 457 | num_samples_per_bucket = [] 458 | total_batch_size = self.num_replicas * self.batch_size 459 | for i in range(len(buckets)): 460 | len_bucket = len(buckets[i]) 461 | rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size 462 | num_samples_per_bucket.append(len_bucket + rem) 463 | return buckets, num_samples_per_bucket 464 | 465 | def __iter__(self): 466 | # deterministically shuffle based on epoch 467 | g = torch.Generator() 468 | g.manual_seed(self.epoch) 469 | indices = [] 470 | if self.shuffle: 471 | for bucket in self.buckets: 472 | indices.append(torch.randperm(len(bucket), generator=g).tolist()) 473 | else: 474 | for bucket in self.buckets: 475 | indices.append(list(range(len(bucket)))) 476 | 477 | batches = [] 478 | for i in range(len(self.buckets)): 479 | bucket = self.buckets[i] 480 | len_bucket = len(bucket) 481 | ids_bucket = indices[i] 482 | num_samples_bucket = self.num_samples_per_bucket[i] 483 | 484 | # add extra samples to make it evenly divisible 485 | rem = num_samples_bucket - len_bucket 486 | ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)] 487 | 488 | # subsample 489 | ids_bucket = ids_bucket[self.rank:: self.num_replicas] 490 | 491 | # batching 492 | for j in range(len(ids_bucket) // self.batch_size): 493 | batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size: (j + 1) * self.batch_size]] 494 | batches.append(batch) 495 | 496 | if self.shuffle: 497 | batch_ids = torch.randperm(len(batches), generator=g).tolist() 498 | batches = [batches[i] for i in batch_ids] 499 | self.batches = batches 500 | 501 | assert len(self.batches) * self.batch_size == self.num_samples 502 | return iter(self.batches) 503 | 504 | def _bisect(self, x, lo=0, hi=None): 505 | if hi is None: 506 | hi = len(self.boundaries) - 1 507 | 508 | if hi > lo: 509 | mid = (hi + lo) // 2 510 | if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: 511 | return mid 512 | elif x <= self.boundaries[mid]: 513 | return self._bisect(x, lo, mid) 514 | else: 515 | return self._bisect(x, mid + 1, hi) 516 | else: 517 | return -1 518 | 519 | def __len__(self): 520 | return self.num_samples // self.batch_size 521 | 522 | def set_epoch(self, epoch: int) -> None: 523 | """ 524 | Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas 525 | use a different random ordering for each epoch. Otherwise, the next iteration of this 526 | sampler will yield the same ordering. 527 | Args: 528 | epoch (int): Epoch number. 
529 | """ 530 | self.epoch = epoch 531 | 532 | 533 | @hydra_runner(config_path="../conf/english", config_name='ds_for_data2vec_vctk.yaml') 534 | def main(cfg): 535 | torch.multiprocessing.set_sharing_strategy('file_system') 536 | dataset = instantiate(cfg.dataset) 537 | dataloader = torch.utils.data.DataLoader( 538 | dataset= dataset, batch_size=1, collate_fn=dataset._collate_fn, num_workers=0 539 | ) 540 | 541 | pitch_list = [] 542 | for batch in tqdm(dataloader, total=len(dataloader)): 543 | audios, audio_lengths, tokens, token_lengths, attn_prior, pitches, pitches_lengths, energy, energy_lens, spk_id, audio_embeddings, audio_emb_length, lm_embeddings, lm_emb_length = batch 544 | pitch = pitches.squeeze(0) 545 | pitch_list.append(pitch[pitch != 0]) 546 | 547 | pitch_tensor = torch.cat(pitch_list) 548 | pitch_mean, pitch_std = pitch_tensor.mean().item(), pitch_tensor.std().item() 549 | pitch_min, pitch_max = pitch_tensor.min().item(), pitch_tensor.max().item() 550 | print(f"PITCH_MEAN={pitch_mean}, PITCH_STD={pitch_std}") 551 | print(f"PITCH_MIN={pitch_min}, PITCH_MAX={pitch_max}") 552 | f = open(os.path.join(cfg.sup_data_path, "pitch_stats.txt"), 'w') 553 | f.write(f"PITCH MEAN : {pitch_mean}\n") 554 | f.write(f"PITCH STD : {pitch_std}\n") 555 | f.write(f"PITCH MIN : {pitch_min}\n") 556 | f.write(f"PITCH MAX : {pitch_max}\n") 557 | f.close() 558 | 559 | if __name__ == "__main__": 560 | main() -------------------------------------------------------------------------------- /torchdata/data_total.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | import json 17 | import math 18 | import pickle 19 | import random 20 | from pathlib import Path 21 | from typing import Callable, Dict, List, Optional, Union 22 | 23 | import librosa 24 | import numpy as np 25 | import torch 26 | from tqdm import tqdm 27 | 28 | import torchaudio 29 | from transformers import AutoModel, AutoTokenizer, AutoModelForAudioXVector 30 | from encodec import EncodecModel 31 | 32 | from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer 33 | from nemo.collections.asr.parts.preprocessing.segment import AudioSegment 34 | from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import ( 35 | BaseTokenizer, 36 | EnglishCharsTokenizer, 37 | EnglishPhonemesTokenizer, 38 | ) 39 | from nemo.collections.tts.torch.helpers import ( 40 | BetaBinomialInterpolator, 41 | beta_binomial_prior_distribution, 42 | general_padding, 43 | get_base_dir, 44 | ) 45 | from nemo.collections.tts.torch.tts_data_types import ( 46 | MAIN_DATA_TYPES, 47 | AlignPriorMatrix, 48 | Durations, 49 | Energy, 50 | LMTokens, 51 | LogMel, 52 | P_voiced, 53 | Pitch, 54 | SpeakerID, 55 | TTSDataType, 56 | Voiced_mask, 57 | WithLens, 58 | ) 59 | from nemo.core.classes import Dataset 60 | from nemo.utils import logging 61 | 62 | try: 63 | from nemo_text_processing.text_normalization.normalize import Normalizer 64 | 65 | PYNINI_AVAILABLE = True 66 | except (ImportError, ModuleNotFoundError): 67 | Normalizer = None 68 | PYNINI_AVAILABLE = False 69 | 70 | 71 | from data_type import AudioCodec, Xvector, DATA_STR2DATA_CLASS 72 | 73 | EPSILON = 1e-9 74 | WINDOW_FN_SUPPORTED = { 75 | 'hann': torch.hann_window, 76 | 'hamming': torch.hamming_window, 77 | 'blackman': torch.blackman_window, 78 | 'bartlett': torch.bartlett_window, 79 | 'none': None, 80 | } 81 | 82 | class ExtensiveTTSDataset(Dataset): 83 | def __init__( 84 | self, 85 | manifest_filepath: Union[str, Path, List[str], List[Path]], 86 | sample_rate: int, 87 | text_tokenizer: Union[BaseTokenizer, Callable[[str], List[int]]], 88 | tokens: Optional[List[str]] = None, 89 | text_normalizer: Optional[Union[Normalizer, Callable[[str], str]]] = None, 90 | text_normalizer_call_kwargs: Optional[Dict] = None, 91 | text_tokenizer_pad_id: Optional[int] = None, 92 | sup_data_types: Optional[List[str]] = None, 93 | sup_data_path: Optional[Union[Path, str]] = None, 94 | max_duration: Optional[float] = None, 95 | min_duration: Optional[float] = None, 96 | ignore_file: Optional[Union[str, Path]] = None, 97 | trim: bool = False, 98 | trim_ref: Optional[float] = None, 99 | trim_top_db: Optional[int] = None, 100 | trim_frame_length: Optional[int] = None, 101 | trim_hop_length: Optional[int] = None, 102 | n_fft: int = 1024, 103 | win_length: Optional[int] = None, 104 | hop_length: Optional[int] = None, 105 | window: str = "hann", 106 | n_mels: int = 80, 107 | lowfreq: int = 0, 108 | highfreq: Optional[int] = None, 109 | **kwargs, 110 | ): 111 | """Dataset which can be used for training spectrogram generators and end-to-end TTS models. 112 | It loads main data types (audio, text) and specified supplementary data types (log mel, durations, align prior matrix, pitch, energy, speaker id). 113 | Some supplementary data types will be computed on the fly and saved in the sup_data_path if they did not exist before. 114 | Saved folder can be changed for some supplementary data types (see keyword args section). 
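    For example (illustrative values, not defaults read from a config), instantiating with
    sup_data_types=['align_prior_matrix', 'speaker_id', 'audio_codec'] and codec_model='encodec'
    computes EnCodec codes on the first pass over the data and caches them as .pt files under
    <sup_data_path>/encodec/.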
115 | Arguments for supplementary data should be also specified in this class, and they will be used from kwargs (see keyword args section). 116 | Args: 117 | manifest_filepath (Union[str, Path, List[str], List[Path]]): Path(s) to the .json manifests containing information on the 118 | dataset. Each line in the .json file should be valid json. Note: the .json file itself is not valid 119 | json. Each line should contain the following: 120 | "audio_filepath": , 121 | "text": , 122 | "normalized_text": (Optional), 123 | "mel_filepath": (Optional), 124 | "duration": (Optional), 125 | sample_rate (int): The sample rate of the audio. Or the sample rate that we will resample all files to. 126 | text_tokenizer (Optional[Union[BaseTokenizer, Callable[[str], List[int]]]]): BaseTokenizer or callable which represents text tokenizer. 127 | tokens (Optional[List[str]]): Tokens from text_tokenizer. Should be specified if text_tokenizer is not BaseTokenizer. 128 | text_normalizer (Optional[Union[Normalizer, Callable[[str], str]]]): Normalizer or callable which represents text normalizer. 129 | text_normalizer_call_kwargs (Optional[Dict]): Additional arguments for text_normalizer function. 130 | text_tokenizer_pad_id (Optional[int]): Index of padding. Should be specified if text_tokenizer is not BaseTokenizer. 131 | sup_data_types (Optional[List[str]]): List of supplementary data types. 132 | sup_data_path (Optional[Union[Path, str]]): A folder that contains or will contain supplementary data (e.g. pitch). 133 | max_duration (Optional[float]): Max duration of audio clips in seconds. All samples exceeding this will be 134 | pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load 135 | audio to compute duration. Defaults to None which does not prune. 136 | min_duration (Optional[float]): Min duration of audio clips in seconds. All samples lower than this will be 137 | pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load 138 | audio to compute duration. Defaults to None which does not prune. 139 | ignore_file (Optional[Union[str, Path]]): The location of a pickle-saved list of audio paths 140 | that will be pruned prior to training. Defaults to None which does not prune. 141 | trim (bool): Whether to apply `librosa.effects.trim` to trim leading and trailing silence from an audio 142 | signal. Defaults to False. 143 | trim_ref (Optional[float]): the reference amplitude. By default, it uses `np.max` and compares to the peak 144 | amplitude in the signal. 145 | trim_top_db (Optional[int]): the threshold (in decibels) below reference to consider as silence. 146 | Defaults to 60. 147 | trim_frame_length (Optional[int]): the number of samples per analysis frame. Defaults to 2048. 148 | trim_hop_length (Optional[int]): the number of samples between analysis frames. Defaults to 512. 149 | n_fft (int): The number of fft samples. Defaults to 1024 150 | win_length (Optional[int]): The length of the stft windows. Defaults to None which uses n_fft. 151 | hop_length (Optional[int]): The hope length between fft computations. Defaults to None which uses n_fft//4. 152 | window (str): One of 'hann', 'hamming', 'blackman','bartlett', 'none'. Which corresponds to the 153 | equivalent torch window function. 154 | n_mels (int): The number of mel filters. Defaults to 80. 155 | lowfreq (int): The lowfreq input to the mel filter calculation. Defaults to 0. 156 | highfreq (Optional[int]): The highfreq input to the mel filter calculation. 
Defaults to None. 157 | Keyword Args: 158 | log_mel_folder (Optional[Union[Path, str]]): The folder that contains or will contain log mel spectrograms. 159 | pitch_folder (Optional[Union[Path, str]]): The folder that contains or will contain pitch. 160 | voiced_mask_folder (Optional[Union[Path, str]]): The folder that contains or will contain voiced mask of the pitch 161 | p_voiced_folder (Optional[Union[Path, str]]): The folder that contains or will contain p_voiced(probability) of the pitch 162 | energy_folder (Optional[Union[Path, str]]): The folder that contains or will contain energy. 163 | durs_file (Optional[str]): String path to pickled durations location. 164 | durs_type (Optional[str]): Type of durations. Currently, supported only "aligner-based". 165 | use_beta_binomial_interpolator (Optional[bool]): Whether to use beta-binomial interpolator for calculating alignment prior matrix. Defaults to False. 166 | pitch_fmin (Optional[float]): The fmin input to librosa.pyin. Defaults to librosa.note_to_hz('C2'). 167 | pitch_fmax (Optional[float]): The fmax input to librosa.pyin. Defaults to librosa.note_to_hz('C7'). 168 | pitch_mean (Optional[float]): The mean that we use to normalize the pitch. 169 | pitch_std (Optional[float]): The std that we use to normalize the pitch. 170 | pitch_norm (Optional[bool]): Whether to normalize pitch or not. If True, requires providing either 171 | pitch_stats_path or (pitch_mean and pitch_std). 172 | pitch_stats_path (Optional[Path, str]): Path to file containing speaker level pitch statistics. 173 | audio_model (Optional[str]): Name of Pretrained Audio Model 174 | lm_model (Optional[str]): Name of Pretraeind LM 175 | xvector_model (Optional[str]): Name of Xvector Model 176 | """ 177 | super().__init__() 178 | 179 | # Initialize text tokenizer 180 | self.text_tokenizer = text_tokenizer 181 | 182 | self.phoneme_probability = None 183 | if isinstance(self.text_tokenizer, BaseTokenizer): 184 | self.text_tokenizer_pad_id = text_tokenizer.pad 185 | self.tokens = text_tokenizer.tokens 186 | self.phoneme_probability = getattr(self.text_tokenizer, "phoneme_probability", None) 187 | else: 188 | if text_tokenizer_pad_id is None: 189 | raise ValueError(f"text_tokenizer_pad_id must be specified if text_tokenizer is not BaseTokenizer") 190 | 191 | if tokens is None: 192 | raise ValueError(f"tokens must be specified if text_tokenizer is not BaseTokenizer") 193 | 194 | self.text_tokenizer_pad_id = text_tokenizer_pad_id 195 | self.tokens = tokens 196 | self.cache_text = True if self.phoneme_probability is None else False 197 | 198 | # Initialize text normalizer if specified 199 | self.text_normalizer = text_normalizer 200 | if self.text_normalizer is None: 201 | self.text_normalizer_call = None 202 | elif not PYNINI_AVAILABLE: 203 | raise ImportError("pynini is not installed, please install via nemo_text_processing/install_pynini.sh") 204 | else: 205 | self.text_normalizer_call = ( 206 | self.text_normalizer.normalize 207 | if isinstance(self.text_normalizer, Normalizer) 208 | else self.text_normalizer 209 | ) 210 | self.text_normalizer_call_kwargs = ( 211 | text_normalizer_call_kwargs if text_normalizer_call_kwargs is not None else {} 212 | ) 213 | 214 | # Initialize and read manifest file(s), filter out data by duration and ignore_file, compute base dir 215 | if isinstance(manifest_filepath, str): 216 | manifest_filepath = [manifest_filepath] 217 | self.manifest_filepath = manifest_filepath 218 | self.lengths = [] # Needed for BucketSampling 219 | 220 | data = [] 221 
| total_duration = 0 222 | for manifest_file in self.manifest_filepath: 223 | with open(Path(manifest_file).expanduser(), 'r') as f: 224 | logging.info(f"Loading dataset from {manifest_file}.") 225 | for line in tqdm(f): 226 | item = json.loads(line) 227 | 228 | file_info = { 229 | "audio_filepath": item["audio_filepath"], 230 | "original_text": item["text"], 231 | "mel_filepath": item["mel_filepath"] if "mel_filepath" in item else None, 232 | "duration": item["duration"] if "duration" in item else None, 233 | "speaker_id": item["speaker"] if "speaker" in item else None, 234 | } 235 | 236 | if "normalized_text" in item: 237 | file_info["normalized_text"] = item["normalized_text"] 238 | elif "text_normalized" in item: 239 | file_info["normalized_text"] = item["text_normalized"] 240 | else: 241 | text = item["text"] 242 | if self.text_normalizer is not None: 243 | text = self.text_normalizer_call(text, **self.text_normalizer_call_kwargs) 244 | file_info["normalized_text"] = text 245 | 246 | if self.cache_text: 247 | file_info["text_tokens"] = self.text_tokenizer(file_info["normalized_text"]) 248 | 249 | data.append(file_info) 250 | # Calculating length of spectrogram from input audio for batch sampling 251 | 252 | 253 | if file_info["duration"] is None: 254 | logging.info( 255 | "Not all audio files have duration information. Duration logging will be disabled." 256 | ) 257 | total_duration = None 258 | 259 | if total_duration is not None: 260 | total_duration += item["duration"] 261 | 262 | logging.info(f"Loaded dataset with {len(data)} files.") 263 | if total_duration is not None: 264 | logging.info(f"Dataset contains {total_duration / 3600:.2f} hours.") 265 | 266 | self.data = ExtensiveTTSDataset.filter_files(data, ignore_file, min_duration, max_duration, total_duration) 267 | for x in self.data: 268 | self.lengths.append(os.path.getsize(x["audio_filepath"]) // (n_fft // 2)) 269 | self.base_data_dir = get_base_dir([item["audio_filepath"] for item in self.data]) 270 | 271 | # Initialize audio and mel related parameters 272 | self.sample_rate = sample_rate 273 | self.featurizer = WaveformFeaturizer(sample_rate=self.sample_rate) 274 | self.trim = trim 275 | self.trim_ref = trim_ref if trim_ref is not None else np.max 276 | self.trim_top_db = trim_top_db if trim_top_db is not None else 60 277 | self.trim_frame_length = trim_frame_length if trim_frame_length is not None else 2048 278 | self.trim_hop_length = trim_hop_length if trim_hop_length is not None else 512 279 | 280 | self.n_fft = n_fft 281 | self.n_mels = n_mels 282 | self.lowfreq = lowfreq 283 | self.highfreq = highfreq 284 | self.window = window 285 | self.win_length = win_length or self.n_fft 286 | self.hop_length = hop_length 287 | self.hop_len = self.hop_length or self.n_fft // 4 288 | self.fb = torch.tensor( 289 | librosa.filters.mel( 290 | sr=self.sample_rate, n_fft=self.n_fft, n_mels=self.n_mels, fmin=self.lowfreq, fmax=self.highfreq 291 | ), 292 | dtype=torch.float, 293 | ).unsqueeze(0) 294 | 295 | try: 296 | window_fn = WINDOW_FN_SUPPORTED[self.window] 297 | except KeyError: 298 | raise NotImplementedError( 299 | f"Current implementation doesn't support {self.window} window. " 300 | f"Please choose one from {list(WINDOW_FN_SUPPORTED.keys())}." 
301 | ) 302 | 303 | self.stft = lambda x: torch.stft( 304 | input=x, 305 | n_fft=self.n_fft, 306 | hop_length=self.hop_len, 307 | win_length=self.win_length, 308 | window=window_fn(self.win_length, periodic=False).to(torch.float) if window_fn else None, 309 | return_complex=True, 310 | ) 311 | 312 | self.sup_data_extraction = kwargs.get("sup_data_extraction", False) 313 | # Initialize sup_data_path, sup_data_types and run preprocessing methods for every supplementary data type 314 | if sup_data_path is not None: 315 | Path(sup_data_path).mkdir(parents=True, exist_ok=True) 316 | self.sup_data_path = sup_data_path 317 | 318 | self.pretrained_model_name = None 319 | 320 | if 'audio_codec' in sup_data_types: 321 | self.codec_token_sum = kwargs.get("codec_sum", True) 322 | 323 | lm_model = kwargs.get("lm_model", None) 324 | if 'pretrained_lm' in sup_data_types: 325 | if lm_model is None: 326 | raise ValueError( 327 | "Name of Pretrained LM should be specified if you want to use pretraiend lm embedding as sup_data") 328 | if self.pretrained_model_name is None: 329 | self.pretrained_model_name = lm_model.split("/")[-1].split("-")[0] 330 | else: 331 | self.pretrained_model_name += "_" + lm_model.split("/")[-1].split("-")[0] 332 | 333 | audio_model = kwargs.get("audio_model", None) 334 | if 'pretrained_audio' in sup_data_types: 335 | if audio_model is None: 336 | raise ValueError( 337 | "Name of Pretrained Audio model should be specified if you want to use pretraiend audio embedding as sup_data") 338 | 339 | 340 | if audio_model is not None and lm_model is not None: 341 | self.pretrained_model_name = lm_model.split("/")[-1].split("-")[0] + "_" + \ 342 | audio_model.split("/")[-1].split('-')[0] 343 | elif audio_model is not None and lm_model is None: 344 | self.pretrained_model_name = audio_model.split("/")[-1].split('-')[0] 345 | elif audio_model is None and lm_model is not None: 346 | self.pretrained_model_name = lm_model.split("/")[-1].split("-")[0] 347 | else: 348 | self.pretrained_model_name = None 349 | 350 | if self.pretrained_model_name is not None: 351 | self.pretrained_model_name = Path(self.sup_data_path) / self.pretrained_model_name 352 | if not os.path.exists(self.pretrained_model_name): 353 | os.makedirs(self.pretrained_model_name, exist_ok=True) 354 | 355 | self.sup_data_types = [] 356 | if sup_data_types is not None: 357 | for d_as_str in sup_data_types: 358 | try: 359 | sup_data_type = DATA_STR2DATA_CLASS[d_as_str] 360 | except KeyError: 361 | raise NotImplementedError(f"Current implementation doesn't support {d_as_str} type.") 362 | 363 | self.sup_data_types.append(sup_data_type) 364 | 365 | if ("voiced_mask" in sup_data_types or "p_voiced" in sup_data_types) and ("pitch" not in sup_data_types): 366 | raise ValueError( 367 | "Please add 'pitch' to sup_data_types in YAML because 'pitch' is required when using either " 368 | "'voiced_mask' or 'p_voiced' or both." 
369 | ) 370 | 371 | self.sup_data_types_set = set(self.sup_data_types) 372 | 373 | for data_type in self.sup_data_types: 374 | getattr(self, f"add_{data_type.name}")(**kwargs) 375 | 376 | @staticmethod 377 | def filter_files(data, ignore_file, min_duration, max_duration, total_duration): 378 | if ignore_file: 379 | logging.info(f"Using {ignore_file} to prune dataset.") 380 | with open(Path(ignore_file).expanduser(), "rb") as f: 381 | wavs_to_ignore = set(pickle.load(f)) 382 | 383 | filtered_data: List[Dict] = [] 384 | pruned_duration = 0 if total_duration is not None else None 385 | pruned_items = 0 386 | for item in data: 387 | audio_path = item['audio_filepath'] 388 | 389 | # Prune data according to min/max_duration & the ignore file 390 | if total_duration is not None: 391 | if (min_duration and item["duration"] < min_duration) or ( 392 | max_duration and item["duration"] > max_duration 393 | ): 394 | pruned_duration += item["duration"] 395 | pruned_items += 1 396 | continue 397 | 398 | if ignore_file and (audio_path in wavs_to_ignore): 399 | pruned_items += 1 400 | pruned_duration += item["duration"] 401 | wavs_to_ignore.remove(audio_path) 402 | continue 403 | 404 | filtered_data.append(item) 405 | 406 | logging.info(f"Pruned {pruned_items} files. Final dataset contains {len(filtered_data)} files") 407 | if pruned_duration is not None: 408 | logging.info( 409 | f"Pruned {pruned_duration / 3600:.2f} hours. Final dataset contains " 410 | f"{(total_duration - pruned_duration) / 3600:.2f} hours." 411 | ) 412 | 413 | return filtered_data 414 | 415 | def add_log_mel(self, **kwargs): 416 | self.log_mel_folder = kwargs.pop('log_mel_folder', None) 417 | 418 | if self.log_mel_folder is None: 419 | self.log_mel_folder = Path(self.sup_data_path) / LogMel.name 420 | elif isinstance(self.log_mel_folder, str): 421 | self.log_mel_folder = Path(self.log_mel_folder) 422 | 423 | self.log_mel_folder.mkdir(exist_ok=True, parents=True) 424 | 425 | def add_durations(self, **kwargs): 426 | durs_file = kwargs.pop('durs_file') 427 | durs_type = kwargs.pop('durs_type') 428 | 429 | audio_stem2durs = torch.load(durs_file) 430 | self.durs = [] 431 | 432 | for tag in [Path(d["audio_filepath"]).stem for d in self.data]: 433 | durs = audio_stem2durs[tag] 434 | if durs_type == "aligner-based": 435 | self.durs.append(durs) 436 | else: 437 | raise NotImplementedError( 438 | f"{durs_type} duration type is not supported. Only aligner-based is supported at this moment." 439 | ) 440 | 441 | def add_align_prior_matrix(self, **kwargs): 442 | self.use_beta_binomial_interpolator = kwargs.pop('use_beta_binomial_interpolator', False) 443 | if not self.cache_text: 444 | if 'use_beta_binomial_interpolator' in kwargs and not self.use_beta_binomial_interpolator: 445 | logging.warning( 446 | "phoneme_probability is not None, but use_beta_binomial_interpolator=False, we" 447 | " set use_beta_binomial_interpolator=True manually to use phoneme_probability." 
448 | ) 449 | self.use_beta_binomial_interpolator = True 450 | 451 | if self.use_beta_binomial_interpolator: 452 | self.beta_binomial_interpolator = BetaBinomialInterpolator() 453 | 454 | def add_pitch(self, **kwargs): 455 | self.pitch_folder = kwargs.pop('pitch_folder', None) 456 | 457 | if self.pitch_folder is None: 458 | self.pitch_folder = Path(self.sup_data_path) / Pitch.name 459 | elif isinstance(self.pitch_folder, str): 460 | self.pitch_folder = Path(self.pitch_folder) 461 | 462 | self.pitch_folder.mkdir(exist_ok=True, parents=True) 463 | 464 | self.pitch_fmin = kwargs.pop("pitch_fmin", librosa.note_to_hz('C2')) 465 | self.pitch_fmax = kwargs.pop("pitch_fmax", librosa.note_to_hz('C7')) 466 | self.pitch_mean = kwargs.pop("pitch_mean", None) 467 | self.pitch_std = kwargs.pop("pitch_std", None) 468 | self.pitch_norm = kwargs.pop("pitch_norm", False) 469 | pitch_stats_path = kwargs.pop("pitch_stats_path", None) 470 | 471 | if self.pitch_norm: 472 | # XOR to validate that both or neither pitch mean and std are provided 473 | assert (self.pitch_mean is None) == ( 474 | self.pitch_std is None 475 | ), f"Found only 1 of (pitch_mean, pitch_std): ({self.pitch_mean}, {self.pitch_std})" 476 | 477 | # XOR to validate that exactly 1 of (pitch_mean, pitch_std) or pitch_stats_path is provided. 478 | assert (self.pitch_mean is None) != (pitch_stats_path is None), ( 479 | f"pitch_norm requires exactly 1 of (pitch_mean, pitch_std) or pitch_stats_path. " 480 | f"Provided: ({self.pitch_mean}, {self.pitch_std}) and {pitch_stats_path}" 481 | ) 482 | 483 | if pitch_stats_path is not None: 484 | with open(Path(pitch_stats_path), 'r', encoding="utf-8") as pitch_f: 485 | self.pitch_stats = json.load(pitch_f) 486 | 487 | # saving voiced_mask and p_voiced with pitch 488 | def add_voiced_mask(self, **kwargs): 489 | self.voiced_mask_folder = kwargs.pop('voiced_mask_folder', None) 490 | 491 | if self.voiced_mask_folder is None: 492 | self.voiced_mask_folder = Path(self.sup_data_path) / Voiced_mask.name 493 | 494 | self.voiced_mask_folder.mkdir(exist_ok=True, parents=True) 495 | 496 | def add_p_voiced(self, **kwargs): 497 | self.p_voiced_folder = kwargs.pop('p_voiced_folder', None) 498 | 499 | if self.p_voiced_folder is None: 500 | self.p_voiced_folder = Path(self.sup_data_path) / P_voiced.name 501 | 502 | self.p_voiced_folder.mkdir(exist_ok=True, parents=True) 503 | 504 | def add_energy(self, **kwargs): 505 | self.energy_folder = kwargs.pop('energy_folder', None) 506 | 507 | if self.energy_folder is None: 508 | self.energy_folder = Path(self.sup_data_path) / Energy.name 509 | elif isinstance(self.energy_folder, str): 510 | self.energy_folder = Path(self.energy_folder) 511 | 512 | self.energy_folder.mkdir(exist_ok=True, parents=True) 513 | 514 | def add_speaker_id(self, **kwargs): 515 | pass 516 | 517 | def get_spec(self, audio): 518 | with torch.cuda.amp.autocast(enabled=False): 519 | spec = self.stft(audio) 520 | if spec.dtype in [torch.cfloat, torch.cdouble]: 521 | spec = torch.view_as_real(spec) 522 | spec = torch.sqrt(spec.pow(2).sum(-1) + EPSILON) 523 | return spec 524 | 525 | def get_log_mel(self, audio): 526 | with torch.cuda.amp.autocast(enabled=False): 527 | spec = self.get_spec(audio) 528 | mel = torch.matmul(self.fb.to(spec.dtype), spec) 529 | log_mel = torch.log(torch.clamp(mel, min=torch.finfo(mel.dtype).tiny)) 530 | return log_mel 531 | 532 | 533 | def add_xvector(self, **kwargs): 534 | self.speaker_path = self.pretrained_model_name / "speaker_embedding" 535 | self.xvector_model_name = 
kwargs.get('xvector_model', None) 536 | if self.xvector_model_name is None: 537 | raise ValueError("You should specify xvector model if you want to xvector as sup data, key name : xvector_model") 538 | self.xvector_model = None 539 | if self.sup_data_extraction: 540 | self.xvector_model = AutoModelForAudioXVector.from_pretrained(self.xvector_model_name).eval() 541 | self.speaker_folder = Path(self.pretrained_model_name) / "speaker_embedding" 542 | self.speaker_folder.mkdir(exist_ok=True, parents=True) 543 | 544 | def add_audio_codec(self, **kwargs): 545 | codec_model = kwargs.get("codec_model", None) 546 | if codec_model is None: 547 | raise ValueError("Audio Codec Model Name Needed, key name : codec_model") 548 | self.codec_model_name = codec_model 549 | self.codec_model = None 550 | if self.sup_data_extraction: 551 | if codec_model == 'encodec': 552 | if self.sample_rate==24000: 553 | self.codec_model = EncodecModel.encodec_model_24khz() 554 | elif self.sample_rate==48000: 555 | self.codec_model = EncodecModel.encodec_model_48khz() 556 | self.codec_model.set_target_bandwidth(6.0) 557 | else: 558 | raise ValueError("Current version only supports EnCodec Model") 559 | 560 | self.codec_path = Path(self.sup_data_path) / self.codec_model_name 561 | self.codec_path.mkdir(exist_ok=True, parents=True) 562 | 563 | def _load_pretrained_models(self, sup_datas): 564 | if AudioCodec in sup_datas: 565 | if self.codec_model_name == 'encodec': 566 | self.codec_model = EncodecModel.encodec_model_24khz() 567 | self.codec_model.set_target_bandwidth(6.0) 568 | else: 569 | raise ValueError("Current version only supports EnCodec Model") 570 | if Xvector in sup_datas: 571 | self.xvector_model = AutoModelForAudioXVector.from_pretrained(self.xvector_model_name).eval() 572 | 573 | def _load_tensor_audio(self, file_name, target_sr): 574 | audio_input, sr = torchaudio.load(file_name) 575 | if sr != target_sr: 576 | resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr) 577 | audio_input = resampler(audio_input) 578 | return audio_input 579 | 580 | def _load_audio_codec(self, audio): 581 | # input audio : [1, T], encodec model needs [1, 1, T] 582 | audio = audio.unsqueeze(0) 583 | 584 | if self.codec_model is None: 585 | self._load_pretrained_models([AudioCodec]) 586 | 587 | with torch.no_grad(): 588 | encoded_frames = self.codec_model.encode(audio) 589 | codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # [B, n_q, T] 590 | # codes = torch.sum(codes, dim=1) # [B=1, T] 591 | return codes.squeeze(0) # [T] 592 | 593 | def __getitem__(self, index): 594 | sample = self.data[index] 595 | 596 | # Let's keep audio name and all internal directories in rel_audio_path_as_text_id to avoid any collisions 597 | rel_audio_path = Path(sample["audio_filepath"]).relative_to(self.base_data_dir).with_suffix("") 598 | rel_audio_path_as_text_id = str(rel_audio_path).replace("/", "_") 599 | 600 | # Load audio 601 | features = self.featurizer.process( 602 | sample["audio_filepath"], 603 | trim=self.trim, 604 | trim_ref=self.trim_ref, 605 | trim_top_db=self.trim_top_db, 606 | trim_frame_length=self.trim_frame_length, 607 | trim_hop_length=self.trim_hop_length, 608 | ) 609 | audio, audio_length = features, torch.tensor(features.shape[0]).long() 610 | 611 | if "text_tokens" in sample: 612 | text = torch.tensor(sample["text_tokens"]).long() 613 | text_length = torch.tensor(len(sample["text_tokens"])).long() 614 | else: 615 | tokenized = self.text_tokenizer(sample["normalized_text"]) 616 | text = 
torch.tensor(tokenized).long() 617 | text_length = torch.tensor(len(tokenized)).long() 618 | 619 | # Load mel if needed 620 | log_mel, log_mel_length = None, None 621 | if LogMel in self.sup_data_types_set: 622 | mel_path = sample["mel_filepath"] 623 | 624 | if mel_path is not None and Path(mel_path).exists(): 625 | log_mel = torch.load(mel_path) 626 | else: 627 | mel_path = self.log_mel_folder / f"{rel_audio_path_as_text_id}.pt" 628 | 629 | if mel_path.exists(): 630 | log_mel = torch.load(mel_path) 631 | else: 632 | log_mel = self.get_log_mel(audio) 633 | torch.save(log_mel, mel_path) 634 | 635 | log_mel = log_mel.squeeze(0) 636 | log_mel_length = torch.tensor(log_mel.shape[1]).long() 637 | 638 | # Load durations if needed 639 | durations = None 640 | if Durations in self.sup_data_types_set: 641 | durations = self.durs[index] 642 | 643 | # Load alignment prior matrix if needed 644 | align_prior_matrix = None 645 | if AlignPriorMatrix in self.sup_data_types_set: 646 | mel_len = self.get_log_mel(audio).shape[2] 647 | if self.use_beta_binomial_interpolator: 648 | align_prior_matrix = torch.from_numpy(self.beta_binomial_interpolator(mel_len, text_length.item())) 649 | else: 650 | align_prior_matrix = torch.from_numpy(beta_binomial_prior_distribution(text_length, mel_len)) 651 | 652 | non_exist_voiced_index = [] 653 | my_var = locals() 654 | for i, voiced_item in enumerate([Pitch, Voiced_mask, P_voiced]): 655 | if voiced_item in self.sup_data_types_set: 656 | voiced_folder = getattr(self, f"{voiced_item.name}_folder") 657 | voiced_filepath = voiced_folder / f"{rel_audio_path_as_text_id}.pt" 658 | if voiced_filepath.exists(): 659 | my_var.__setitem__(voiced_item.name, torch.load(voiced_filepath).float()) 660 | else: 661 | non_exist_voiced_index.append((i, voiced_item.name, voiced_filepath)) 662 | 663 | if len(non_exist_voiced_index) != 0: 664 | voiced_tuple = librosa.pyin( 665 | audio.numpy(), 666 | fmin=self.pitch_fmin, 667 | fmax=self.pitch_fmax, 668 | frame_length=self.win_length, 669 | sr=self.sample_rate, 670 | fill_na=0.0, 671 | ) 672 | for (i, voiced_name, voiced_filepath) in non_exist_voiced_index: 673 | my_var.__setitem__(voiced_name, torch.from_numpy(voiced_tuple[i]).float()) 674 | torch.save(my_var.get(voiced_name), voiced_filepath) 675 | 676 | pitch = my_var.get('pitch', None) 677 | pitch_length = my_var.get('pitch_length', None) 678 | voiced_mask = my_var.get('voiced_mask', None) 679 | p_voiced = my_var.get('p_voiced', None) 680 | 681 | # normalize pitch if requested. 
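        # Normalization logic in the block below: when pitch_norm is enabled, a global
        # (pitch_mean, pitch_std) pair takes precedence if both were provided; otherwise per-speaker
        # statistics are looked up in pitch_stats by speaker_id, falling back to a "default" entry.
        # Frames that pyin left unvoiced (filled with 0.0) are reset to exactly zero after mean
        # subtraction so they remain neutral after scaling.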
        if pitch is not None:
            pitch_length = torch.tensor(len(pitch)).long()
            if self.pitch_norm:
                if self.pitch_mean is not None and self.pitch_std is not None:
                    sample_pitch_mean = self.pitch_mean
                    sample_pitch_std = self.pitch_std
                elif self.pitch_stats:
                    if "speaker_id" in sample and str(sample["speaker_id"]) in self.pitch_stats:
                        pitch_stats = self.pitch_stats[str(sample["speaker_id"])]
                    elif "default" in self.pitch_stats:
                        pitch_stats = self.pitch_stats["default"]
                    else:
                        raise ValueError(f"Could not find pitch stats for {sample}.")
                    sample_pitch_mean = pitch_stats["pitch_mean"]
                    sample_pitch_std = pitch_stats["pitch_std"]
                else:
                    raise ValueError("Missing statistics for pitch normalization.")

                pitch -= sample_pitch_mean
                pitch[pitch == -sample_pitch_mean] = 0.0  # Zero out values that were previously zero
                pitch /= sample_pitch_std

        # Load energy if needed
        energy, energy_length = None, None
        if Energy in self.sup_data_types_set:
            energy_path = self.energy_folder / f"{rel_audio_path_as_text_id}.pt"

            if energy_path.exists():
                energy = torch.load(energy_path).float()
            else:
                spec = self.get_spec(audio)
                energy = torch.linalg.norm(spec.squeeze(0), axis=0).float()
                torch.save(energy, energy_path)

            energy_length = torch.tensor(len(energy)).long()

        # Load speaker id if needed
        speaker_id = None
        if SpeakerID in self.sup_data_types_set:
            speaker_id = torch.tensor(sample["speaker_id"]).long()

        file_name = sample['audio_filepath'].split("/")[-1]
        file_name = file_name.replace("wav", "pt")

        speaker_embedding = None
        if Xvector in self.sup_data_types_set:
            speaker_file = os.path.join(self.speaker_path, file_name)
            if os.path.exists(speaker_file):
                speaker_embedding = torch.load(speaker_file)
            else:
                # x-vector extraction models expect 16 kHz input
                audio_input = self._load_tensor_audio(sample['audio_filepath'], 16000)
                if self.xvector_model is None:
                    self._load_pretrained_models([Xvector])
                with torch.no_grad():
                    speaker_embedding = self.xvector_model(input_values=audio_input).embeddings
                speaker_embedding = speaker_embedding.squeeze(0)
                torch.save(speaker_embedding, speaker_file)
            speaker_embedding = speaker_embedding.detach()

        codec, codec_len = None, None
        if AudioCodec in self.sup_data_types_set:
            codec_file = os.path.join(self.codec_path, file_name)
            if os.path.exists(codec_file):
                codec = torch.load(codec_file)
            else:
                # EnCodec's pretrained model used here operates on 24 kHz audio
                audio_input = self._load_tensor_audio(sample['audio_filepath'], 24000)
                codec = self._load_audio_codec(audio_input)
                torch.save(codec, codec_file)
            if self.codec_token_sum:
                codec = torch.sum(codec, dim=0)
                if isinstance(self.codec_tokenizer, CodecTokenizer):
                    codec = self.codec_tokenizer.token_to_id(codec)
                codec_len = torch.tensor(codec.size(0)).long()
            else:
                codec_len = torch.tensor(codec.size(1)).long()

        return (
            audio,
            audio_length,
            text,
            text_length,
            log_mel,
            log_mel_length,
            durations,
            align_prior_matrix,
            pitch,
            pitch_length,
            energy,
            energy_length,
            speaker_id,
            voiced_mask,
            p_voiced,
            speaker_embedding,
            codec,
            codec_len,
        )

    def __len__(self):
        return len(self.data)
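
    # Collation: every field is padded to the longest example in the batch and stacked; any
    # supplementary type that was not requested in sup_data_types is returned as None.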

    def general_collate_fn(self, batch):
        (
            _,
            audio_lengths,
            _,
            tokens_lengths,
            _,
            log_mel_lengths,
            durations_list,
            align_prior_matrices_list,
            pitches,
            pitches_lengths,
            energies,
            energies_lengths,
            _,
            voiced_masks,
            p_voiceds,
            speaker_embeddings,
            codecs,
            codec_lengths,
        ) = zip(*batch)

        max_audio_len = max(audio_lengths).item()
        max_tokens_len = max(tokens_lengths).item()
        max_log_mel_len = max(log_mel_lengths) if LogMel in self.sup_data_types_set else None
        max_durations_len = max([len(i) for i in durations_list]) if Durations in self.sup_data_types_set else None
        max_pitches_len = max(pitches_lengths).item() if Pitch in self.sup_data_types_set else None
        max_energies_len = max(energies_lengths).item() if Energy in self.sup_data_types_set else None
        max_codec_len = max(codec_lengths).item() if AudioCodec in self.sup_data_types_set else None

        if LogMel in self.sup_data_types_set:
            log_mel_pad = torch.finfo(batch[0][4].dtype).tiny

        align_prior_matrices = (
            torch.zeros(
                len(align_prior_matrices_list),
                max([prior_i.shape[0] for prior_i in align_prior_matrices_list]),
                max([prior_i.shape[1] for prior_i in align_prior_matrices_list]),
            )
            if AlignPriorMatrix in self.sup_data_types_set
            else []
        )
        audios, tokens, log_mels, durations_list, pitches, energies, speaker_ids, voiced_masks, p_voiceds, codecs = (
            [],
            [],
            [],
            [],
            [],
            [],
            [],
            [],
            [],
            [],
        )

        for i, sample_tuple in enumerate(batch):
            (
                audio,
                audio_len,
                token,
                token_len,
                log_mel,
                log_mel_len,
                durations,
                align_prior_matrix,
                pitch,
                pitch_length,
                energy,
                energy_length,
                speaker_id,
                voiced_mask,
                p_voiced,
                speaker_embedding,
                codec,
                codec_length,
            ) = sample_tuple

            audio = general_padding(audio, audio_len.item(), max_audio_len)
            audios.append(audio)

            token = general_padding(token, token_len.item(), max_tokens_len, pad_value=self.text_tokenizer_pad_id)
            tokens.append(token)

            if LogMel in self.sup_data_types_set:
                log_mels.append(general_padding(log_mel, log_mel_len, max_log_mel_len, pad_value=log_mel_pad))

            if Durations in self.sup_data_types_set:
                durations_list.append(general_padding(durations, len(durations), max_durations_len))

            if AlignPriorMatrix in self.sup_data_types_set:
                align_prior_matrices[
                    i, : align_prior_matrix.shape[0], : align_prior_matrix.shape[1]
                ] = align_prior_matrix

            if Pitch in self.sup_data_types_set:
                pitches.append(general_padding(pitch, pitch_length.item(), max_pitches_len))

            if Voiced_mask in self.sup_data_types_set:
                voiced_masks.append(general_padding(voiced_mask, pitch_length.item(), max_pitches_len))

            if P_voiced in self.sup_data_types_set:
                p_voiceds.append(general_padding(p_voiced, pitch_length.item(), max_pitches_len))

            if Energy in self.sup_data_types_set:
                energies.append(general_padding(energy, energy_length.item(), max_energies_len))

            if SpeakerID in self.sup_data_types_set:
                speaker_ids.append(speaker_id)
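
            # The 24 kHz EnCodec model at a 6.0 kbps target bandwidth emits 8 codebooks per
            # frame, hence the zero padding below assumes 8 codebook rows when codes are not summed.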
            if AudioCodec in self.sup_data_types_set:
                if self.codec_token_sum:
                    codec = self.codec_tokenizer.padding(codec, max_length=max_codec_len)
                else:
                    pad = torch.zeros(8, max_codec_len - codec_length.item())
                    codec = torch.cat([codec, pad], dim=1)
                codecs.append(codec.detach())

        data_dict = {
            "audio": torch.stack(audios),
            "audio_lens": torch.stack(audio_lengths),
            "text": torch.stack(tokens),
            "text_lens": torch.stack(tokens_lengths),
            "log_mel": torch.stack(log_mels) if LogMel in self.sup_data_types_set else None,
            "log_mel_lens": torch.stack(log_mel_lengths) if LogMel in self.sup_data_types_set else None,
            "durations": torch.stack(durations_list) if Durations in self.sup_data_types_set else None,
            "align_prior_matrix": align_prior_matrices if AlignPriorMatrix in self.sup_data_types_set else None,
            "pitch": torch.stack(pitches) if Pitch in self.sup_data_types_set else None,
            "pitch_lens": torch.stack(pitches_lengths) if Pitch in self.sup_data_types_set else None,
            "energy": torch.stack(energies) if Energy in self.sup_data_types_set else None,
            "energy_lens": torch.stack(energies_lengths) if Energy in self.sup_data_types_set else None,
            "speaker_id": torch.stack(speaker_ids) if SpeakerID in self.sup_data_types_set else None,
            "voiced_mask": torch.stack(voiced_masks) if Voiced_mask in self.sup_data_types_set else None,
            "p_voiced": torch.stack(p_voiceds) if P_voiced in self.sup_data_types_set else None,
            "xvector": torch.stack(speaker_embeddings) if Xvector in self.sup_data_types_set else None,
            "audio_codec": torch.stack(codecs) if AudioCodec in self.sup_data_types_set else None,
            "audio_codec_lens": torch.stack(codec_lengths) if AudioCodec in self.sup_data_types_set else None,
        }

        return data_dict

    def join_data(self, data_dict):
        result = []
        for data_type in MAIN_DATA_TYPES + self.sup_data_types:
            result.append(data_dict[data_type.name])

            if issubclass(data_type, TTSDataType) and issubclass(data_type, WithLens):
                result.append(data_dict[f"{data_type.name}_lens"])

        return tuple(result)

    def _collate_fn(self, batch):
        data_dict = self.general_collate_fn(batch)
        joined_data = self.join_data(data_dict)
        return joined_data
--------------------------------------------------------------------------------
/torchdata/data_type.py:
--------------------------------------------------------------------------------
from nemo.collections.tts.torch.tts_data_types import TTSDataType, WithLens, VALID_SUPPLEMENTARY_DATA_TYPES, MAIN_DATA_TYPES


class AudioCodec(TTSDataType, WithLens):
    name = 'audio_codec'

class Xvector(TTSDataType):
    name = 'xvector'

EXTENSIVE_DATA_TYPES = [
    Xvector,
    AudioCodec,
]

DATA_STR2DATA_CLASS = {d.name: d for d in MAIN_DATA_TYPES + VALID_SUPPLEMENTARY_DATA_TYPES + EXTENSIVE_DATA_TYPES}
--------------------------------------------------------------------------------
/torchdata/text_preprocess.py:
--------------------------------------------------------------------------------
import json
from tqdm import tqdm
from pathlib import Path

from nemo_text_processing.text_normalization.normalize import Normalizer

manifest_path = Path("../data/vctk/24k")
print(manifest_path)
text_normalizer = Normalizer(lang='en', input_case='cased', whitelist="../sup_data/text/whitelist/lj_speech.tsv")
text_norm_kwargs = {"verbose": False, "punct_pre_process": True, "punct_post_process": True}
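
# Normalize every manifest entry once, offline, and write the result to
# <manifest_path>/text_normalized/<split>_manifest.json so training does not pay this cost.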
# for split in ['train', 'valid']:
for split in ['test']:
    manifest_file = manifest_path / f"{split}_manifest.json"
    normalized_manifest = []
    for line in tqdm(open(manifest_file).readlines()):
        item = json.loads(line)

        # Only normalize entries that do not carry a normalized_text field yet; every entry is
        # kept in the output manifest either way.
        if 'normalized_text' not in item and 'text' in item:
            item['normalized_text'] = text_normalizer.normalize(item['text'], **text_norm_kwargs)
        normalized_manifest.append(item)

    normalized_dir = manifest_path / "text_normalized"
    normalized_dir.mkdir(exist_ok=True, parents=True)
    normalized_path = normalized_dir / f"{split}_manifest.json"

    with open(normalized_path, 'w') as fp:
        fp.writelines([json.dumps(x, ensure_ascii=False) + '\n' for x in normalized_manifest])
--------------------------------------------------------------------------------
/vits_train.py:
--------------------------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytorch_lightning as pl
import torch

from nemo.collections.common.callbacks import LogEpochTimeCallback
from nemo.core.config import hydra_runner
from nemo.utils.exp_manager import exp_manager

from model.vits import VitsModel


@hydra_runner(config_path="conf", config_name="ref_vits")
def main(cfg):
    torch.multiprocessing.set_sharing_strategy('file_system')
    trainer = pl.Trainer(replace_sampler_ddp=False, **cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    model = VitsModel(cfg=cfg.model, trainer=trainer)

    trainer.callbacks.extend([pl.callbacks.LearningRateMonitor(), LogEpochTimeCallback()])
    trainer.fit(model)


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
--------------------------------------------------------------------------------