├── .gitignore ├── LICENSE ├── README.md ├── config ├── ablations │ ├── ablate-all.py │ ├── ablate-augment.py │ ├── ablate-loudness.py │ ├── ablate-multiloud.py │ ├── ablate-sppg.py │ ├── ablate-variable-pitch.py │ └── ablate-viterbi.py ├── baselines │ ├── mels.py │ ├── vocos.py │ └── world.py ├── fargan-advlr1e6-warmup.py ├── fargan-fdisc.py ├── fargan-long-noadv.py ├── fargan-zeroshot-shuffle.py ├── fargan-zeroshot.py ├── fargan.py ├── hparams │ ├── bands │ │ ├── 16band.py │ │ ├── 2band.py │ │ ├── 32band.py │ │ ├── 4band.py │ │ └── 8band.py │ └── sppg │ │ ├── sppg-constant-0025.py │ │ ├── sppg-constant-005.py │ │ ├── sppg-constant-0075.py │ │ ├── sppg-constant-010.py │ │ ├── sppg-percentile-080.py │ │ ├── sppg-percentile-085.py │ │ ├── sppg-percentile-090.py │ │ ├── sppg-percentile-095.py │ │ ├── sppg-top-3.py │ │ ├── sppg-top-4.py │ │ ├── sppg-top-5.py │ │ └── sppg-top-6.py ├── promonet-fdisc.py ├── promonet-zeroshot-shuffle.py ├── promonet-zeroshot.py └── promonet.py ├── data ├── cache │ └── .gitkeep └── datasets │ └── .gitkeep ├── notebooks ├── parse_results.ipynb ├── ppgs │ ├── ppg_interpolation.ipynb │ ├── ppgs_objective_eval.ipynb │ └── website_examples.ipynb ├── select-speakers.ipynb └── website_examples.ipynb ├── promonet ├── __init__.py ├── adapt │ ├── __init__.py │ ├── __main__.py │ └── core.py ├── assets │ ├── augmentations │ │ ├── daps-loudness.json │ │ ├── daps-pitch.json │ │ ├── libritts-loudness.json │ │ ├── libritts-pitch.json │ │ ├── vctk-loudness.json │ │ └── vctk-pitch.json │ ├── configs │ │ ├── reconstruction-quality-mushra.yaml │ │ ├── reconstruction-similarity-abx.yaml │ │ ├── shifting-quality-mushra.yaml │ │ └── stretching-quality-mushra.yaml │ ├── partitions │ │ ├── adaptation │ │ │ ├── daps.json │ │ │ ├── libritts.json │ │ │ └── vctk.json │ │ └── multispeaker │ │ │ ├── daps.json │ │ │ └── vctk.json │ └── stats │ │ ├── .gitkeep │ │ ├── vctk-256-loudness-pitch-viterbi.pt │ │ ├── vctk-256-loudness-pitch.pt │ │ ├── vctk-256-viterbi.pt │ │ └── vctk-train-speaker-averages-viterbi.json ├── baseline │ ├── __init__.py │ ├── mels.py │ └── world.py ├── config │ ├── __init__.py │ ├── defaults.py │ └── static.py ├── convert.py ├── data │ ├── __init__.py │ ├── augment │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── core.py │ │ ├── loudness.py │ │ └── pitch.py │ ├── collate.py │ ├── dataset.py │ ├── download │ │ ├── __init__.py │ │ ├── __main__.py │ │ └── core.py │ ├── loader.py │ ├── pack │ │ ├── __init__.py │ │ ├── __main__.py │ │ └── core.py │ ├── preprocess │ │ ├── __init__.py │ │ ├── __main__.py │ │ └── core.py │ └── sampler.py ├── edit │ ├── __init__.py │ ├── __main__.py │ ├── core.py │ └── grid.py ├── evaluate │ ├── __init__.py │ ├── __main__.py │ ├── core.py │ └── metrics.py ├── load.py ├── model │ ├── __init__.py │ ├── cargan.py │ ├── core.py │ ├── discriminator.py │ ├── export │ │ ├── __init__.py │ │ ├── __main__.py │ │ └── core.py │ ├── fargan.py │ ├── generator.py │ ├── hifigan.py │ └── vocos.py ├── partition │ ├── __init__.py │ ├── __main__.py │ └── core.py ├── plot │ ├── __init__.py │ ├── __main__.py │ ├── core.py │ └── speaker │ │ ├── __init__.py │ │ └── core.py ├── preprocess │ ├── __init__.py │ ├── __main__.py │ ├── core.py │ ├── harmonics.py │ ├── loudness.py │ ├── speaker.py │ ├── spectrogram.py │ └── text.py ├── synthesize │ ├── __init__.py │ ├── __main__.py │ └── core.py └── train │ ├── __init__.py │ ├── __main__.py │ ├── core.py │ └── loss.py ├── results └── .gitkeep ├── run.sh └── setup.py /.gitignore: 
-------------------------------------------------------------------------------- 1 | *.egg-info 2 | __pycache__/ 3 | .ipynb_checkpoints 4 | build/ 5 | dist/ 6 | 7 | data/cache/* 8 | !data/cache/.gitkeep 9 | data/datasets/* 10 | !data/datasets/.gitkeep 11 | eval/* 12 | !eval/objective/.gitkeep 13 | results/* 14 | !results/.gitkeep 15 | runs/* 16 | !runs/.gitkeep 17 | notebooks/* 18 | !notebooks/*.ipynb 19 | 20 | .vscode/ 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Max Morrison 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /config/ablations/ablate-all.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'ablate-all' 5 | 6 | # Whether to use loudness augmentation 7 | AUGMENT_LOUDNESS = False 8 | 9 | # Whether to use pitch augmentation 10 | AUGMENT_PITCH = False 11 | 12 | # Number of bands of A-weighted loudness 13 | LOUDNESS_BANDS = 1 14 | 15 | # Type of sparsification used for ppgs 16 | # One of ['constant', 'percentile', 'topk', None] 17 | SPARSE_PPG_METHOD = None 18 | 19 | # Whether to use variable-width pitch bins 20 | VARIABLE_PITCH_BINS = False 21 | 22 | # Whether to perform Viterbi decoding on pitch features 23 | VITERBI_DECODE_PITCH = False 24 | -------------------------------------------------------------------------------- /config/ablations/ablate-augment.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'ablate-augment' 5 | 6 | # Whether to use loudness augmentation 7 | AUGMENT_LOUDNESS = False 8 | 9 | # Whether to use pitch augmentation 10 | AUGMENT_PITCH = False 11 | -------------------------------------------------------------------------------- /config/ablations/ablate-loudness.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'ablate-loudness' 5 | 6 | # Whether to use loudness augmentation 7 | AUGMENT_LOUDNESS = False 8 | -------------------------------------------------------------------------------- /config/ablations/ablate-multiloud.py: 
-------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'ablate-multiloud' 5 | 6 | # Number of bands of A-weighted loudness 7 | LOUDNESS_BANDS = 1 8 | -------------------------------------------------------------------------------- /config/ablations/ablate-sppg.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'ablate-sppg' 5 | 6 | # Type of sparsification used for ppgs 7 | # One of ['constant', 'percentile', 'topk', None] 8 | SPARSE_PPG_METHOD = None 9 | -------------------------------------------------------------------------------- /config/ablations/ablate-variable-pitch.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'ablate-variable-pitch' 5 | 6 | # Whether to use variable-width pitch bins 7 | VARIABLE_PITCH_BINS = False 8 | -------------------------------------------------------------------------------- /config/ablations/ablate-viterbi.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'ablate-viterbi' 5 | 6 | # Whether to perform Viterbi decoding on pitch features 7 | VITERBI_DECODE_PITCH = False 8 | -------------------------------------------------------------------------------- /config/baselines/mels.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'mels' 5 | 6 | # Batch size 7 | BATCH_SIZE = 64 8 | 9 | # Input features 10 | INPUT_FEATURES = ['spectrogram'] 11 | 12 | # Type of sparsification used for ppgs 13 | # One of ['constant', 'percentile', 'topk', None] 14 | SPARSE_PPG_METHOD = None 15 | 16 | # Only use spectral features 17 | SPECTROGRAM_ONLY = True 18 | -------------------------------------------------------------------------------- /config/baselines/vocos.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | 5 | MODULE = 'promonet' 6 | 7 | # Configuration name 8 | CONFIG = 'vocos' 9 | 10 | # Whether to use hinge loss instead of L2 11 | ADVERSARIAL_HINGE_LOSS = True 12 | 13 | # Whether to use loudness augmentation 14 | AUGMENT_LOUDNESS = False 15 | 16 | # Whether to use pitch augmentation 17 | AUGMENT_PITCH = False 18 | 19 | # Batch size 20 | BATCH_SIZE = 16 21 | 22 | # Whether to use the complex multi-band discriminator from RVQGAN 23 | COMPLEX_MULTIBAND_DISCRIMINATOR = False 24 | 25 | # Input features 26 | INPUT_FEATURES = ['spectrogram'] 27 | 28 | # The model to use. One of ['fargan', 'hifigan', 'vocos', 'world'].
29 | MODEL = 'vocos' 30 | 31 | # Whether to use the multi-resolution spectrogram discriminator from UnivNet 32 | MULTI_RESOLUTION_DISCRIMINATOR = True 33 | 34 | # Training optimizer 35 | OPTIMIZER = functools.partial( 36 | torch.optim.AdamW, 37 | lr=2e-4, 38 | betas=(.9, .999), 39 | eps=1e-9) 40 | 41 | # Type of sparsification used for ppgs 42 | # One of ['constant', 'percentile', 'topk', None] 43 | SPARSE_PPG_METHOD = None 44 | 45 | # Only use spectral features 46 | SPECTROGRAM_ONLY = True 47 | 48 | # Number of training steps 49 | STEPS = 400000 50 | 51 | # Number of neural network layers in Vocos 52 | VOCOS_LAYERS = 8 53 | -------------------------------------------------------------------------------- /config/baselines/world.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'world' 5 | 6 | # The model to use. 7 | # One of ['fargan', 'hifigan', 'vocos', 'world']. 8 | MODEL = 'world' 9 | -------------------------------------------------------------------------------- /config/fargan-advlr1e6-warmup.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | 5 | MODULE = 'promonet' 6 | 7 | # Configuration name 8 | CONFIG = 'fargan-advlr1e6-warmup' 9 | 10 | # The model to use. 11 | # One of ['fargan', 'hifigan', 'vocos', 'world']. 12 | MODEL = 'fargan' 13 | 14 | # Step to start using adversarial loss 15 | ADVERSARIAL_LOSS_START_STEP = 270000 16 | 17 | # Step to start training discriminator 18 | DISCRIMINATOR_START_STEP = 240000 19 | 20 | # Training batch size 21 | BATCH_SIZE = 128 22 | 23 | # Training sequence length 24 | CHUNK_SIZE = 16384 # samples 25 | 26 | # Whether to use mel spectrogram loss 27 | MEL_LOSS = False 28 | 29 | # Training optimizer 30 | OPTIMIZER = functools.partial( 31 | torch.optim.AdamW, 32 | lr=2e-6, 33 | betas=(.9, .999), 34 | eps=1e-9) 35 | 36 | # Whether to use multi-resolution spectral convergence loss 37 | SPECTRAL_CONVERGENCE_LOSS = True 38 | -------------------------------------------------------------------------------- /config/fargan-fdisc.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'fargan-fdisc' 5 | 6 | # The model to use. 7 | # One of ['fargan', 'hifigan', 'vocos', 'world']. 
8 | MODEL = 'fargan' 9 | 10 | # Step to start using adversarial loss 11 | ADVERSARIAL_LOSS_START_STEP = 300000 12 | 13 | # Whether to use the complex multi-band discriminator from RVQGAN 14 | COMPLEX_MULTIBAND_DISCRIMINATOR = False 15 | 16 | # Step to start training discriminator 17 | DISCRIMINATOR_START_STEP = 300000 18 | 19 | # Training batch size 20 | BATCH_SIZE = 256 21 | 22 | # Training sequence length 23 | CHUNK_SIZE = 4096 # samples 24 | 25 | # Whether to use the same discriminator as FARGAN 26 | FARGAN_DISCRIMINATOR = True 27 | 28 | # Whether to use mel spectrogram loss 29 | MEL_LOSS = False 30 | 31 | # Whether to use the multi-period waveform discriminator from HiFi-GAN 32 | MULTI_PERIOD_DISCRIMINATOR = False 33 | 34 | # Whether to use multi-resolution spectral convergence loss 35 | SPECTRAL_CONVERGENCE_LOSS = True 36 | -------------------------------------------------------------------------------- /config/fargan-long-noadv.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'fargan-long-noadv' 5 | 6 | # The model to use. 7 | # One of ['fargan', 'hifigan', 'vocos', 'world']. 8 | MODEL = 'fargan' 9 | 10 | # Step to start using adversarial loss 11 | ADVERSARIAL_LOSS_START_STEP = 1000000 12 | 13 | # Training batch size 14 | BATCH_SIZE = 1024 15 | 16 | # Training sequence length 17 | CHUNK_SIZE = 4096 # samples 18 | 19 | # Whether to use mel spectrogram loss 20 | MEL_LOSS = False 21 | 22 | # Whether to use multi-resolution spectral convergence loss 23 | SPECTRAL_CONVERGENCE_LOSS = True 24 | -------------------------------------------------------------------------------- /config/fargan-zeroshot-shuffle.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'fargan-zeroshot-shuffle' 5 | 6 | # The model to use. 7 | # One of ['fargan', 'hifigan', 'vocos', 'world']. 8 | MODEL = 'fargan' 9 | 10 | # Step to start using adversarial loss 11 | ADVERSARIAL_LOSS_START_STEP = 250000 12 | 13 | # Training batch size 14 | BATCH_SIZE = 256 15 | 16 | # Training sequence length 17 | CHUNK_SIZE = 4096 # samples 18 | 19 | # Whether to use mel spectrogram loss 20 | MEL_LOSS = False 21 | 22 | # Whether to use multi-resolution spectral convergence loss 23 | SPECTRAL_CONVERGENCE_LOSS = True 24 | 25 | # Whether to use WavLM x-vectors for zero-shot speaker conditioning 26 | ZERO_SHOT = True 27 | 28 | # Whether to shuffle speaker embeddings during training 29 | ZERO_SHOT_SHUFFLE = True 30 | -------------------------------------------------------------------------------- /config/fargan-zeroshot.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'fargan-zeroshot' 5 | 6 | # The model to use. 7 | # One of ['fargan', 'hifigan', 'vocos', 'world']. 
8 | MODEL = 'fargan' 9 | 10 | # Step to start using adversarial loss 11 | ADVERSARIAL_LOSS_START_STEP = 250000 12 | 13 | # Training batch size 14 | BATCH_SIZE = 256 15 | 16 | # Training sequence length 17 | CHUNK_SIZE = 4096 # samples 18 | 19 | # Whether to use mel spectrogram loss 20 | MEL_LOSS = False 21 | 22 | # Whether to use multi-resolution spectral convergence loss 23 | SPECTRAL_CONVERGENCE_LOSS = True 24 | 25 | # Whether to use WavLM x-vectors for zero-shot speaker conditioning 26 | ZERO_SHOT = True 27 | -------------------------------------------------------------------------------- /config/fargan.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'fargan' 5 | 6 | # The model to use. 7 | # One of ['fargan', 'hifigan', 'vocos', 'world']. 8 | MODEL = 'fargan' 9 | 10 | # Step to start using adversarial loss 11 | ADVERSARIAL_LOSS_START_STEP = 250000 12 | 13 | # Training batch size 14 | BATCH_SIZE = 256 15 | 16 | # Training sequence length 17 | CHUNK_SIZE = 4096 # samples 18 | 19 | # Whether to use mel spectrogram loss 20 | MEL_LOSS = False 21 | 22 | # Whether to use multi-resolution spectral convergence loss 23 | SPECTRAL_CONVERGENCE_LOSS = True 24 | -------------------------------------------------------------------------------- /config/hparams/bands/16band.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = '16band' 5 | 6 | # Number of bands of A-weighted loudness 7 | LOUDNESS_BANDS = 16 8 | -------------------------------------------------------------------------------- /config/hparams/bands/2band.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = '2band' 5 | 6 | # Number of bands of A-weighted loudness 7 | LOUDNESS_BANDS = 2 8 | -------------------------------------------------------------------------------- /config/hparams/bands/32band.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = '32band' 5 | 6 | # Number of bands of A-weighted loudness 7 | LOUDNESS_BANDS = 32 8 | -------------------------------------------------------------------------------- /config/hparams/bands/4band.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = '4band' 5 | 6 | # Number of bands of A-weighted loudness 7 | LOUDNESS_BANDS = 4 8 | -------------------------------------------------------------------------------- /config/hparams/bands/8band.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = '8band' 5 | 6 | # Number of bands of A-weighted loudness 7 | LOUDNESS_BANDS = 8 8 | -------------------------------------------------------------------------------- /config/hparams/sppg/sppg-constant-0025.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'sppg-constant-0025' 5 | 6 | # Type of sparsification used for ppgs 7 | # One of ['constant', 'percentile', 'topk', None] 8 | SPARSE_PPG_METHOD = 'constant' 9 | 10 | # Threshold for ppg sparsification. 11 | # In [0, 1] for 'constant' and 'percentile'; integer > 0 for 'topk'. 
12 | SPARSE_PPG_THRESHOLD = 0.025 13 | -------------------------------------------------------------------------------- /config/hparams/sppg/sppg-constant-005.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'sppg-constant-005' 5 | 6 | # Type of sparsification used for ppgs 7 | # One of ['constant', 'percentile', 'topk', None] 8 | SPARSE_PPG_METHOD = 'constant' 9 | 10 | # Threshold for ppg sparsification. 11 | # In [0, 1] for 'constant' and 'percentile'; integer > 0 for 'topk'. 12 | SPARSE_PPG_THRESHOLD = 0.05 13 | -------------------------------------------------------------------------------- /config/hparams/sppg/sppg-constant-0075.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'sppg-constant-0075' 5 | 6 | # Type of sparsification used for ppgs 7 | # One of ['constant', 'percentile', 'topk', None] 8 | SPARSE_PPG_METHOD = 'constant' 9 | 10 | # Threshold for ppg sparsification. 11 | # In [0, 1] for 'constant' and 'percentile'; integer > 0 for 'topk'. 12 | SPARSE_PPG_THRESHOLD = 0.075 13 | -------------------------------------------------------------------------------- /config/hparams/sppg/sppg-constant-010.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'sppg-constant-010' 5 | 6 | # Type of sparsification used for ppgs 7 | # One of ['constant', 'percentile', 'topk', None] 8 | SPARSE_PPG_METHOD = 'constant' 9 | 10 | # Threshold for ppg sparsification. 11 | # In [0, 1] for 'constant' and 'percentile'; integer > 0 for 'topk'. 12 | SPARSE_PPG_THRESHOLD = 0.10 13 | -------------------------------------------------------------------------------- /config/hparams/sppg/sppg-percentile-080.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'sppg-percentile-080' 5 | 6 | # Type of sparsification used for ppgs 7 | # One of ['constant', 'percentile', 'topk', None] 8 | SPARSE_PPG_METHOD = 'percentile' 9 | 10 | # Threshold for ppg sparsification. 11 | # In [0, 1] for 'constant' and 'percentile'; integer > 0 for 'topk'. 12 | SPARSE_PPG_THRESHOLD = 0.80 13 | -------------------------------------------------------------------------------- /config/hparams/sppg/sppg-percentile-085.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'sppg-percentile-085' 5 | 6 | # Type of sparsification used for ppgs 7 | # One of ['constant', 'percentile', 'topk', None] 8 | SPARSE_PPG_METHOD = 'percentile' 9 | 10 | # Threshold for ppg sparsification. 11 | # In [0, 1] for 'constant' and 'percentile'; integer > 0 for 'topk'. 12 | SPARSE_PPG_THRESHOLD = 0.85 13 | -------------------------------------------------------------------------------- /config/hparams/sppg/sppg-percentile-090.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'sppg-percentile-090' 5 | 6 | # Type of sparsification used for ppgs 7 | # One of ['constant', 'percentile', 'topk', None] 8 | SPARSE_PPG_METHOD = 'percentile' 9 | 10 | # Threshold for ppg sparsification. 11 | # In [0, 1] for 'constant' and 'percentile'; integer > 0 for 'topk'. 
12 | SPARSE_PPG_THRESHOLD = 0.90 13 | -------------------------------------------------------------------------------- /config/hparams/sppg/sppg-percentile-095.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'sppg-percentile-095' 5 | 6 | # Type of sparsification used for ppgs 7 | # One of ['constant', 'percentile', 'topk', None] 8 | SPARSE_PPG_METHOD = 'percentile' 9 | 10 | # Threshold for ppg sparsification. 11 | # In [0, 1] for 'constant' and 'percentile'; integer > 0 for 'topk'. 12 | SPARSE_PPG_THRESHOLD = 0.95 13 | -------------------------------------------------------------------------------- /config/hparams/sppg/sppg-top-3.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'sppg-top-3' 5 | 6 | # Type of sparsification used for ppgs 7 | # One of ['constant', 'percentile', 'topk', None] 8 | SPARSE_PPG_METHOD = 'topk' 9 | 10 | # Threshold for ppg sparsification. 11 | # In [0, 1] for 'constant' and 'percentile'; integer > 0 for 'topk'. 12 | SPARSE_PPG_THRESHOLD = 3 13 | -------------------------------------------------------------------------------- /config/hparams/sppg/sppg-top-4.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'sppg-top-4' 5 | 6 | # Type of sparsification used for ppgs 7 | # One of ['constant', 'percentile', 'topk', None] 8 | SPARSE_PPG_METHOD = 'topk' 9 | 10 | # Threshold for ppg sparsification. 11 | # In [0, 1] for 'constant' and 'percentile'; integer > 0 for 'topk'. 12 | SPARSE_PPG_THRESHOLD = 4 13 | -------------------------------------------------------------------------------- /config/hparams/sppg/sppg-top-5.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'sppg-top-5' 5 | 6 | # Type of sparsification used for ppgs 7 | # One of ['constant', 'percentile', 'topk', None] 8 | SPARSE_PPG_METHOD = 'topk' 9 | 10 | # Threshold for ppg sparsification. 11 | # In [0, 1] for 'constant' and 'percentile'; integer > 0 for 'topk'. 12 | SPARSE_PPG_THRESHOLD = 5 13 | -------------------------------------------------------------------------------- /config/hparams/sppg/sppg-top-6.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'sppg-top-6' 5 | 6 | # Type of sparsification used for ppgs 7 | # One of ['constant', 'percentile', 'topk', None] 8 | SPARSE_PPG_METHOD = 'topk' 9 | 10 | # Threshold for ppg sparsification. 11 | # In [0, 1] for 'constant' and 'percentile'; integer > 0 for 'topk'. 
12 | SPARSE_PPG_THRESHOLD = 6 13 | -------------------------------------------------------------------------------- /config/promonet-fdisc.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'promonet-fdisc' 5 | 6 | # Whether to use the complex multi-band discriminator from RVQGAN 7 | COMPLEX_MULTIBAND_DISCRIMINATOR = False 8 | 9 | # Whether to use the multi-period waveform discriminator from HiFi-GAN 10 | MULTI_PERIOD_DISCRIMINATOR = False 11 | 12 | # Whether to use the same discriminator as FARGAN 13 | FARGAN_DISCRIMINATOR = True 14 | -------------------------------------------------------------------------------- /config/promonet-zeroshot-shuffle.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'promonet-zeroshot-shuffle' 5 | 6 | # Whether to use WavLM x-vectors for zero-shot speaker conditioning 7 | ZERO_SHOT = True 8 | 9 | # Whether to shuffle speaker embeddings during training 10 | ZERO_SHOT_SHUFFLE = True 11 | -------------------------------------------------------------------------------- /config/promonet-zeroshot.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'promonet-zeroshot' 5 | 6 | # Whether to use WavLM x-vectors for zero-shot speaker conditioning 7 | ZERO_SHOT = True 8 | -------------------------------------------------------------------------------- /config/promonet.py: -------------------------------------------------------------------------------- 1 | MODULE = 'promonet' 2 | 3 | # Configuration name 4 | CONFIG = 'promonet' 5 | -------------------------------------------------------------------------------- /data/cache/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxrmorrison/promonet/86b815e8511d526da42cd8333a9f5a2872b11981/data/cache/.gitkeep -------------------------------------------------------------------------------- /data/datasets/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxrmorrison/promonet/86b815e8511d526da42cd8333a9f5a2872b11981/data/datasets/.gitkeep -------------------------------------------------------------------------------- /notebooks/parse_results.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "37044920-d03d-48fd-ba8e-bfbc648b912a", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "b9499e03-9923-4c5f-8122-0c078281bcce", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import json\n", 22 | "from pathlib import Path\n", 23 | "\n", 24 | "import IPython.display as ipd\n", 25 | "import torch\n", 26 | "\n", 27 | "import promonet" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "id": "79abb39c-7907-4dcf-a308-6f72fd807f28", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "# Conditions to consider\n", 38 | "conditions = [\n", 39 | " 'promonet',\n", 40 | " # 'ablate-augment',\n", 41 | " # 'ablate-multiloud',\n", 42 | " # 'ablate-sppg',\n", 43 | " # 
'ablate-variable-pitch',\n", 44 | " # 'ablate-viterbi',\n", 45 | " # 'ablate-all',\n", 46 | " # 'mels',\n", 47 | " # 'mels-ours',\n", 48 | " # 'world'\n", 49 | "]\n", 50 | "edits = [\n", 51 | " 'reconstructed-100',\n", 52 | " 'scaled-050',\n", 53 | " 'scaled-200',\n", 54 | " 'shifted-071',\n", 55 | " 'shifted-141',\n", 56 | " 'stretched-071',\n", 57 | " 'stretched-141'\n", 58 | "]\n", 59 | "metrics = [\n", 60 | " 'pitch',\n", 61 | " 'periodicity',\n", 62 | " 'loudness-loud',\n", 63 | " 'ppg',\n", 64 | " # 'wer',\n", 65 | " # 'speaker_similarity',\n", 66 | " # 'formant-average',\n", 67 | "]" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "id": "db775fd1-bdb6-463a-87e5-ceec528d1b5d", 73 | "metadata": {}, 74 | "source": [ 75 | "## Parse objective results on a set of conditions" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "id": "241788e2-3336-4982-aadb-ede2c45617f0", 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "def parse_results(conditions, edits, metric, dataset):\n", 86 | " results = {condition: {} for condition in conditions}\n", 87 | " for condition in conditions:\n", 88 | " with open(f'/repos/promonet/results/{condition}/{dataset}/results.json') as file:\n", 89 | " for edit, metrics in json.load(file).items():\n", 90 | " if edit not in edits:\n", 91 | " continue\n", 92 | " # print(edit, json.dumps(metrics, indent=4, sort_keys=True))\n", 93 | " try:\n", 94 | " results[condition][edit] = metrics[metric]\n", 95 | " except KeyError:\n", 96 | " pass\n", 97 | " for condition in conditions:\n", 98 | " values = list(results[condition].values())\n", 99 | " results[condition]['average'] = sum(values) / len(values)\n", 100 | " print(\n", 101 | " json.dumps(\n", 102 | " {condition: results[condition]['average'] for condition in conditions},\n", 103 | " indent=4,\n", 104 | " sort_keys=True))\n" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "id": "6646534a-1961-42b8-9640-847093b82e3b", 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "for metric in metrics:\n", 115 | " print(metric)\n", 116 | " parse_results(conditions, edits, metric, 'vctk')" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "id": "5c0fad86-7b4d-43ee-aaee-97e0093d3c54", 122 | "metadata": {}, 123 | "source": [ 124 | "## File-level inspection of objective results" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "id": "3413a0d5-019a-4947-aa74-6957f8311d80", 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "# Load fine-grained objective results\n", 135 | "condition = 'ablate-all'\n", 136 | "results = {}\n", 137 | "for file in Path(f'/repos/promonet/results/{condition}/vctk').glob('0*.json'):\n", 138 | " with open(file) as file:\n", 139 | " results |= json.load(file)['objective']['raw']" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "id": "2cb8012b-146a-457a-802a-6d5c5ecafe01", 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "# Sort files by a specific metric\n", 150 | "metric = 'ppg'\n", 151 | "metric_results = {}\n", 152 | "for key, edit_metrics in results.items():\n", 153 | " edit = list(edit_metrics.keys())[0]\n", 154 | " if 'shifted-' not in key and 'scaled-' not in key and 'stretched-' not in key and 'original-' not in key:\n", 155 | " continue\n", 156 | " metric_results[key] = edit_metrics[edit][metric]\n", 157 | "metric_results = dict(sorted(metric_results.items(), 
key=lambda item: item[1], reverse=True))" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "id": "6b398f67-c283-4856-949d-2c9e97c6998e", 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "import numpy as np\n", 168 | "import scipy.stats\n", 169 | "\n", 170 | "def mean_confidence_interval(data, confidence=0.95):\n", 171 | " a = 1.0 * np.array(data)\n", 172 | " n = len(a)\n", 173 | " m, se = np.mean(a), scipy.stats.sem(a)\n", 174 | " return m, se * scipy.stats.t.ppf((1 + confidence) / 2., n-1)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "id": "179c9209-be9e-4a39-9aa8-5c7d27c2f159", 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "mean_confidence_interval(list(metric_results.values()))" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "id": "0bf3c59f-76f4-45ca-840f-631ea84b2c03", 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "subjective_directory = Path('/repos/promonet/eval/subjective')\n", 195 | "objective_directory = Path('/repos/promonet/eval/objective')\n", 196 | "for i, stem in enumerate(metric_results):\n", 197 | "\n", 198 | " if i > 10:\n", 199 | " break\n", 200 | " print(stem, metric_results[stem])\n", 201 | " predicted = promonet.load.audio(subjective_directory / condition / f'{stem}.wav')\n", 202 | " ipd.display(ipd.Audio(predicted, rate=promonet.SAMPLE_RATE))\n", 203 | " parts = stem.split('-')\n", 204 | " file = subjective_directory / 'original' / f'{\"-\".join(parts[:3])}-original-100.wav'\n", 205 | " print(file)\n", 206 | " ipd.display(ipd.Audio(file))\n", 207 | " print(promonet.load.text(objective_directory / condition / f'{stem}.txt'))\n", 208 | " print(promonet.load.text(objective_directory / 'original' / f'{stem}.txt'))\n", 209 | " frames = promonet.convert.samples_to_frames(predicted.shape[-1])\n", 210 | " figure = promonet.plot.from_features(\n", 211 | " predicted,\n", 212 | " torch.load(objective_directory / condition / f'{stem}-viterbi-pitch.pt'),\n", 213 | " torch.load(objective_directory / condition / f'{stem}-viterbi-periodicity.pt'),\n", 214 | " promonet.preprocess.loudness.band_average(torch.load(objective_directory / condition / f'{stem}-loudness.pt'), 1),\n", 215 | " promonet.load.ppg(objective_directory / condition / f'{stem}-ppg.pt', frames),\n", 216 | " torch.load(objective_directory / 'original' / f'{stem}-viterbi-pitch.pt'),\n", 217 | " torch.load(objective_directory / 'original' / f'{stem}-viterbi-periodicity.pt'),\n", 218 | " promonet.preprocess.loudness.band_average(torch.load(objective_directory / 'original' / f'{stem}-loudness.pt'), 1),\n", 219 | " promonet.load.ppg(objective_directory / 'original' / f'{stem}-ppg.pt', frames))\n", 220 | " figure.show()" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "id": "6345683d-fe8b-4716-84f8-5a6da5ca5db2", 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [] 230 | } 231 | ], 232 | "metadata": { 233 | "kernelspec": { 234 | "display_name": "promonet", 235 | "language": "python", 236 | "name": "promonet" 237 | }, 238 | "language_info": { 239 | "codemirror_mode": { 240 | "name": "ipython", 241 | "version": 3 242 | }, 243 | "file_extension": ".py", 244 | "mimetype": "text/x-python", 245 | "name": "python", 246 | "nbconvert_exporter": "python", 247 | "pygments_lexer": "ipython3", 248 | "version": "3.10.13" 249 | } 250 | }, 251 | "nbformat": 4, 252 | "nbformat_minor": 5 253 | } 254 | 
-------------------------------------------------------------------------------- /notebooks/ppgs/ppgs_objective_eval.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "e2a709c9-6dcc-491f-9cb1-02893138ce90", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "30a60a2d-2952-4aa1-8133-9519047bb0c7", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import json\n", 22 | "\n", 23 | "import scipy\n", 24 | "import torch\n", 25 | "import ppgs\n", 26 | "\n", 27 | "import promonet" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "id": "6885cd2b-97a8-4858-b6ec-f180c28a00f4", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "gpu = 0\n", 38 | "device = 'cpu' if gpu is None else f'cuda:{gpu}'" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "id": "dfd9a9ab-886a-40eb-84b5-dfae87269f8b", 44 | "metadata": {}, 45 | "source": [ 46 | "### Pitch and WER evaluation" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "id": "3b97fb0a-64ef-4dd6-bf3a-e126949da23b", 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "configs = [\n", 57 | " 'bottleneck',\n", 58 | " 'encodec',\n", 59 | " 'mel',\n", 60 | " 'w2v2fb',\n", 61 | " 'w2v2fc',\n", 62 | " 'bottleneck-latent',\n", 63 | " 'encodec-latent',\n", 64 | " 'mel-latent',\n", 65 | " 'w2v2fb-latent',\n", 66 | " 'w2v2fc-latent']" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "id": "f143c50e-d031-4d54-bce9-30f740c6c98d", 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "pitch_results, wer_results, jsd_results = {}, {}, {}\n", 77 | "for config in configs:\n", 78 | " with open(promonet.RESULTS_DIR / config / 'results.json') as file:\n", 79 | " result = json.load(file)\n", 80 | " pitch_results[config] = .5 * (result['shifted-089']['pitch'] + result['shifted-112']['pitch'])\n", 81 | " wer_results[config] = .5 * (result['shifted-089']['wer'] + result['shifted-112']['wer'])\n", 82 | " jsd_results[config] = .5 * (result['shifted-089']['ppg'] + result['shifted-112']['ppg'])\n", 83 | "print('Pitch', json.dumps(pitch_results, indent=4, sort_keys=True))\n", 84 | "print('WER', json.dumps(wer_results, indent=4, sort_keys=True))\n", 85 | "print('JSD', json.dumps(jsd_results, indent=4, sort_keys=True))" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "id": "6062ba22-f36e-44ce-a990-a75d80f25963", 91 | "metadata": {}, 92 | "source": [ 93 | "## PPG JSD evaluation" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "id": "06e93e7b-5caf-4096-8ca5-b45257c16a69", 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "class JSDs:\n", 104 | " \"\"\"PPG distances at multiple exponents\"\"\"\n", 105 | "\n", 106 | " def __init__(self):\n", 107 | " self.jsds = [\n", 108 | " promonet.evaluate.metrics.PPG(exponent)\n", 109 | " for exponent in torch.round(torch.arange(0.0, 2.0, 0.05), decimals=2)]\n", 110 | "\n", 111 | " def __call__(self):\n", 112 | " return {f'{jsd.exponent:02f}': jsd() for jsd in self.jsds}\n", 113 | "\n", 114 | " def update(self, predicted, target):\n", 115 | " # Compute PPG\n", 116 | " gpu = (\n", 117 | " None if predicted.device.type == 'cpu'\n", 118 | " else predicted.device.index)\n", 119 | " 
predicted = ppgs.from_audio(\n", 120 | " predicted,\n", 121 | " promonet.SAMPLE_RATE,\n", 122 | " ppgs.RUNS_DIR / 'mel' / '00200000.pt',\n", 123 | " gpu)\n", 124 | " target = ppgs.from_audio(\n", 125 | " target,\n", 126 | " promonet.SAMPLE_RATE,\n", 127 | " ppgs.RUNS_DIR / 'mel' / '00200000.pt',\n", 128 | " gpu)\n", 129 | " \n", 130 | " # Update metrics\n", 131 | " for jsd in self.jsds:\n", 132 | " jsd.update(predicted, target)\n", 133 | "\n", 134 | " def reset(self):\n", 135 | " for jsd in self.jsds:\n", 136 | " jsd.reset()" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "id": "e2d7f39e-9b64-4bd4-b9d6-c2b699fe0d14", 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "jsd_results = {}\n", 147 | "jsd_file_results = {}\n", 148 | "jsds = JSDs()\n", 149 | "file_jsds = JSDs()\n", 150 | "original_files = sorted(list(\n", 151 | " (promonet.EVAL_DIR / 'subjective' / 'original').glob('vctk*.wav')))\n", 152 | "for config in configs:\n", 153 | " jsds.reset()\n", 154 | " jsd_file_results[config] = {}\n", 155 | " eval_directory = promonet.EVAL_DIR / 'subjective' / config\n", 156 | " shift089_files = sorted(list(eval_directory.glob('*shifted-089.wav')))\n", 157 | " shift112_files = sorted(list(eval_directory.glob('*shifted-112.wav')))\n", 158 | " for original, shift089, shift112 in zip(\n", 159 | " original_files,\n", 160 | " shift089_files,\n", 161 | " shift112_files\n", 162 | " ):\n", 163 | " jsds.update(\n", 164 | " promonet.load.audio(shift089).to(device),\n", 165 | " promonet.load.audio(original).to(device))\n", 166 | " jsds.update(\n", 167 | " promonet.load.audio(shift112).to(device),\n", 168 | " promonet.load.audio(original).to(device))\n", 169 | " file_jsds.reset()\n", 170 | " file_jsds.update(\n", 171 | " promonet.load.audio(shift089).to(device),\n", 172 | " promonet.load.audio(original).to(device))\n", 173 | " jsd_file_results[config][shift089.stem] = file_jsds()\n", 174 | " file_jsds.reset()\n", 175 | " file_jsds.update(\n", 176 | " promonet.load.audio(shift112).to(device),\n", 177 | " promonet.load.audio(original).to(device))\n", 178 | " jsd_file_results[config][shift112.stem] = file_jsds()\n", 179 | " jsd_results[config] = jsds()\n", 180 | "print('JSD', json.dumps(jsd_results, indent=4, sort_keys=True))" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "id": "232fed95-1ae8-420d-9420-5b650209d07a", 186 | "metadata": {}, 187 | "source": [ 188 | "## Select exponent with highest correlation with WER" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "id": "27e7fcf6-59c5-4298-b0da-b84a6813cb5d", 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "wer_file_results = {}\n", 199 | "for config in configs:\n", 200 | " wer_file_results[config] = {}\n", 201 | " results_dir = promonet.RESULTS_DIR / config / 'vctk'\n", 202 | " for file in results_dir.glob('*.json'):\n", 203 | " if file.stem == 'results':\n", 204 | " continue\n", 205 | " with open(file) as file:\n", 206 | " result = json.load(file)\n", 207 | " for stem, scores in result['objective']['raw'].items():\n", 208 | " if 'shifted' in stem:\n", 209 | " wer_file_results[config][stem] = scores['-'.join(stem.split('-')[-2:])]['wer']\n", 210 | "\n", 211 | "exponents = jsd_results['mel'].keys()\n", 212 | "stems = wer_file_results['mel'].keys()\n", 213 | "\n", 214 | "correlations = {}\n", 215 | "for exponent in exponents:\n", 216 | " jsd_values, wer_values = [], []\n", 217 | " for config in configs:\n", 218 | " for stem in 
stems:\n", 219 | " jsd_values.append(jsd_file_results[config][stem][exponent])\n", 220 | " wer_values.append(wer_file_results[config][stem])\n", 221 | " correlations[exponent] = scipy.stats.pearsonr(jsd_values, wer_values)\n", 222 | "print('Correlations', json.dumps(correlations, indent=4, sort_keys=True))" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "id": "0d5d9976-deb1-405c-8374-a3ec3047206b", 229 | "metadata": { 230 | "scrolled": true 231 | }, 232 | "outputs": [], 233 | "source": [ 234 | "optimal = '1.200000'\n", 235 | "jsd_results_optim = {config: value[optimal] for config, value in jsd_results.items()}\n", 236 | "print('JSD', json.dumps(jsd_results_optim, indent=4, sort_keys=True))" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "id": "71a6ad5f-511e-4b7f-babd-cc7213c5caf0", 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [] 246 | } 247 | ], 248 | "metadata": { 249 | "kernelspec": { 250 | "display_name": "Python 3 (ipykernel)", 251 | "language": "python", 252 | "name": "python3" 253 | }, 254 | "language_info": { 255 | "codemirror_mode": { 256 | "name": "ipython", 257 | "version": 3 258 | }, 259 | "file_extension": ".py", 260 | "mimetype": "text/x-python", 261 | "name": "python", 262 | "nbconvert_exporter": "python", 263 | "pygments_lexer": "ipython3", 264 | "version": "3.10.13" 265 | } 266 | }, 267 | "nbformat": 4, 268 | "nbformat_minor": 5 269 | } 270 | -------------------------------------------------------------------------------- /notebooks/select-speakers.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "fdd04fe0-72aa-4db3-a501-e95e07410b34", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "ca6e9a33-b2e5-417b-a5ff-026320813e58", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import IPython.display as ipd\n", 22 | "import torchaudio\n", 23 | "\n", 24 | "import promonet" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "855a1782-bcae-468e-a603-41f34e941fe4", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "# Choose a dataset. One of promonet.DATASETS.\n", 35 | "dataset = 'daps'" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "id": "478fc78b-2f64-46fd-9db8-84a5855ccd32", 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "# Get total duration of each speaker in seconds\n", 46 | "directory = promonet.CACHE_DIR / dataset\n", 47 | "files = list(directory.glob('*.wav'))\n", 48 | "speakers = sorted(list(set(file.stem.split('-')[0] for file in files)))\n", 49 | "speaker_sizes = {speaker: 0. 
for speaker in speakers}\n", 50 | "for file in files:\n", 51 | " info = torchaudio.info(file)\n", 52 | " size = info.num_frames / info.sample_rate\n", 53 | " speaker_sizes[file.stem.split('-')[0]] += size\n", 54 | "\n", 55 | "# Sort speakers and total durations in descending order of duration\n", 56 | "candidates = sorted(\n", 57 | " speaker_sizes.items(),\n", 58 | " key=lambda item: item[1],\n", 59 | " reverse=True)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "id": "bf113fac-89ad-4bd6-8985-8a4fd308bce5", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "# Print speakers in descending order of duration for manual selection\n", 70 | "candidates" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "id": "b05ce6a6-a8cc-42bb-97c3-ec40f2244548", 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "# Listen to a sample audio file to check fidelity and gender\n", 81 | "for index, _ in sorted(candidates):\n", 82 | " file = directory / f'{index}-000023.wav'\n", 83 | " print(index)\n", 84 | " ipd.display(ipd.Audio(file))" 85 | ] 86 | } 87 | ], 88 | "metadata": { 89 | "kernelspec": { 90 | "display_name": "Python 3 (ipykernel)", 91 | "language": "python", 92 | "name": "python3" 93 | }, 94 | "language_info": { 95 | "codemirror_mode": { 96 | "name": "ipython", 97 | "version": 3 98 | }, 99 | "file_extension": ".py", 100 | "mimetype": "text/x-python", 101 | "name": "python", 102 | "nbconvert_exporter": "python", 103 | "pygments_lexer": "ipython3", 104 | "version": "3.10.13" 105 | }, 106 | "vscode": { 107 | "interpreter": { 108 | "hash": "e661ca2f247ba03d88bed293db733ca5edb23c05adbd6829a2eef4272a9ed78d" 109 | } 110 | } 111 | }, 112 | "nbformat": 4, 113 | "nbformat_minor": 5 114 | } 115 | -------------------------------------------------------------------------------- /promonet/__init__.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Configuration 3 | ############################################################################### 4 | 5 | 6 | # Default configuration parameters to be modified 7 | from .config import defaults 8 | 9 | # Modify configuration 10 | import yapecs 11 | yapecs.configure('promonet', defaults) 12 | 13 | # Import configuration parameters 14 | from .config.defaults import * 15 | from .config.static import * 16 | 17 | 18 | ############################################################################### 19 | # Module imports 20 | ############################################################################### 21 | 22 | 23 | from .train import loss, train 24 | from . import adapt 25 | from . import baseline 26 | from . import convert 27 | from . import data 28 | from . import edit 29 | from . import evaluate 30 | from . import load 31 | from . import model 32 | from . import partition 33 | from . import plot 34 | from . import preprocess 35 | from . 
import synthesize 36 | -------------------------------------------------------------------------------- /promonet/adapt/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * -------------------------------------------------------------------------------- /promonet/adapt/__main__.py: -------------------------------------------------------------------------------- 1 | import yapecs 2 | from pathlib import Path 3 | 4 | import promonet 5 | 6 | 7 | ############################################################################### 8 | # Speaker adaptation 9 | ############################################################################### 10 | 11 | 12 | def parse_args(): 13 | """Parse command-line arguments""" 14 | parser = yapecs.ArgumentParser(description='Perform speaker adaptation') 15 | parser.add_argument( 16 | '--name', 17 | required=True, 18 | help='The name of the speaker') 19 | parser.add_argument( 20 | '--files', 21 | type=Path, 22 | nargs='+', 23 | required=True, 24 | help='The audio files to use for adaptation') 25 | parser.add_argument( 26 | '--checkpoint', 27 | type=Path, 28 | help='The model checkpoint directory') 29 | parser.add_argument( 30 | '--gpu', 31 | type=int, 32 | help='The gpu to run adaptation on') 33 | return parser.parse_args() 34 | 35 | 36 | promonet.adapt.speaker(**vars(parse_args())) 37 | -------------------------------------------------------------------------------- /promonet/adapt/core.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List, Optional 3 | 4 | import huggingface_hub 5 | import torch 6 | import torchaudio 7 | import torchutil 8 | 9 | import promonet 10 | 11 | 12 | ############################################################################### 13 | # Speaker adaptation API 14 | ############################################################################### 15 | 16 | 17 | def speaker( 18 | name: str, 19 | files: List[Path], 20 | checkpoint: Optional[Path] = None, 21 | gpu: Optional[int] = None 22 | ) -> Path: 23 | """Perform speaker adaptation 24 | 25 | Args: 26 | name: The name of the speaker 27 | files: The audio files to use for adaptation 28 | checkpoint: The model checkpoint directory 29 | gpu: The gpu to run adaptation on 30 | 31 | Returns: 32 | checkpoint: The file containing the trained generator checkpoint 33 | """ 34 | # Make a new cache directory 35 | cache = promonet.CACHE_DIR / 'adapt' / name 36 | cache.mkdir(exist_ok=True, parents=True) 37 | 38 | # Preprocess audio 39 | for i, file in enumerate(files): 40 | 41 | # Convert to 22.05k 42 | audio = promonet.load.audio(file) 43 | 44 | # If audio is too quiet, increase the volume 45 | maximum = torch.abs(audio).max() 46 | if maximum < .35: 47 | audio *= .35 / maximum 48 | 49 | # Save to cache 50 | torchaudio.save( 51 | cache / f'{i:06d}-100.wav', 52 | audio, 53 | promonet.SAMPLE_RATE) 54 | 55 | if promonet.AUGMENT_PITCH or promonet.AUGMENT_LOUDNESS: 56 | 57 | # Augment and get augmentation ratios 58 | promonet.data.augment.from_files_to_files(files, name) 59 | 60 | # Preprocess features 61 | promonet.data.preprocess.from_files_to_files( 62 | cache, 63 | cache.rglob('*.wav'), 64 | gpu=gpu) 65 | 66 | # Partition (all files are used for training) 67 | promonet.partition.dataset(name) 68 | 69 | # Directory to save configuration, checkpoints, and logs 70 | directory = promonet.RUNS_DIR / promonet.CONFIG / 'adapt' / name 71 | 
directory.mkdir(exist_ok=True, parents=True) 72 | 73 | # Maybe resume adaptation 74 | generator_path = torchutil.checkpoint.latest_path( 75 | directory, 76 | 'generator-*.pt') 77 | discriminator_path = torchutil.checkpoint.latest_path( 78 | directory, 79 | 'discriminator-*.pt') 80 | if generator_path and discriminator_path: 81 | checkpoint = directory 82 | 83 | # Maybe download checkpoint 84 | if checkpoint is None: 85 | generator_checkpoint = huggingface_hub.hf_hub_download( 86 | 'maxrmorrison/promonet', 87 | f'generator-00{promonet.STEPS}.pt') 88 | huggingface_hub.hf_hub_download( 89 | 'maxrmorrison/promonet', 90 | f'discriminator-00{promonet.STEPS}.pt') 91 | checkpoint = Path(generator_checkpoint).parent 92 | 93 | # Perform adaptation and return generator checkpoint 94 | return promonet.train( 95 | directory, 96 | name, 97 | adapt_from=checkpoint, 98 | gpu=gpu) 99 | -------------------------------------------------------------------------------- /promonet/assets/partitions/adaptation/daps.json: -------------------------------------------------------------------------------- 1 | { 2 | "train-adapt-00": [ 3 | "0002/000001", 4 | "0002/000002", 5 | "0002/000003", 6 | "0002/000004", 7 | "0002/000005", 8 | "0002/000006", 9 | "0002/000008", 10 | "0002/000010", 11 | "0002/000012", 12 | "0002/000014", 13 | "0002/000015", 14 | "0002/000016", 15 | "0002/000017", 16 | "0002/000019", 17 | "0002/000020", 18 | "0002/000021", 19 | "0002/000022", 20 | "0002/000023", 21 | "0002/000024", 22 | "0002/000026", 23 | "0002/000029", 24 | "0002/000030", 25 | "0002/000031", 26 | "0002/000033", 27 | "0002/000034" 28 | ], 29 | "test-adapt-00": [ 30 | "0002/000007", 31 | "0002/000009", 32 | "0002/000011", 33 | "0002/000013", 34 | "0002/000018", 35 | "0002/000025", 36 | "0002/000027", 37 | "0002/000028", 38 | "0002/000032", 39 | "0002/000035" 40 | ], 41 | "train-adapt-01": [ 42 | "0007/000001", 43 | "0007/000002", 44 | "0007/000003", 45 | "0007/000005", 46 | "0007/000007", 47 | "0007/000008", 48 | "0007/000009", 49 | "0007/000010", 50 | "0007/000011", 51 | "0007/000012", 52 | "0007/000014", 53 | "0007/000015", 54 | "0007/000016", 55 | "0007/000017", 56 | "0007/000018", 57 | "0007/000019", 58 | "0007/000020", 59 | "0007/000021", 60 | "0007/000023", 61 | "0007/000024", 62 | "0007/000028", 63 | "0007/000029", 64 | "0007/000031", 65 | "0007/000033", 66 | "0007/000034" 67 | ], 68 | "test-adapt-01": [ 69 | "0007/000004", 70 | "0007/000006", 71 | "0007/000013", 72 | "0007/000022", 73 | "0007/000025", 74 | "0007/000026", 75 | "0007/000027", 76 | "0007/000030", 77 | "0007/000032", 78 | "0007/000035" 79 | ], 80 | "train-adapt-02": [ 81 | "0010/000001", 82 | "0010/000002", 83 | "0010/000003", 84 | "0010/000005", 85 | "0010/000006", 86 | "0010/000008", 87 | "0010/000009", 88 | "0010/000010", 89 | "0010/000014", 90 | "0010/000015", 91 | "0010/000017", 92 | "0010/000018", 93 | "0010/000019", 94 | "0010/000022", 95 | "0010/000023", 96 | "0010/000024", 97 | "0010/000025", 98 | "0010/000026", 99 | "0010/000027", 100 | "0010/000028", 101 | "0010/000029", 102 | "0010/000031", 103 | "0010/000033", 104 | "0010/000034", 105 | "0010/000035" 106 | ], 107 | "test-adapt-02": [ 108 | "0010/000004", 109 | "0010/000007", 110 | "0010/000011", 111 | "0010/000012", 112 | "0010/000013", 113 | "0010/000016", 114 | "0010/000020", 115 | "0010/000021", 116 | "0010/000030", 117 | "0010/000032" 118 | ], 119 | "train-adapt-03": [ 120 | "0013/000001", 121 | "0013/000002", 122 | "0013/000003", 123 | "0013/000004", 124 | "0013/000006", 125 | "0013/000007", 
126 | "0013/000009", 127 | "0013/000010", 128 | "0013/000011", 129 | "0013/000012", 130 | "0013/000013", 131 | "0013/000014", 132 | "0013/000016", 133 | "0013/000017", 134 | "0013/000019", 135 | "0013/000022", 136 | "0013/000023", 137 | "0013/000024", 138 | "0013/000025", 139 | "0013/000027", 140 | "0013/000029", 141 | "0013/000030", 142 | "0013/000031", 143 | "0013/000033", 144 | "0013/000034" 145 | ], 146 | "test-adapt-03": [ 147 | "0013/000005", 148 | "0013/000008", 149 | "0013/000015", 150 | "0013/000018", 151 | "0013/000020", 152 | "0013/000021", 153 | "0013/000026", 154 | "0013/000028", 155 | "0013/000032", 156 | "0013/000035" 157 | ], 158 | "train-adapt-04": [ 159 | "0019/000002", 160 | "0019/000003", 161 | "0019/000004", 162 | "0019/000005", 163 | "0019/000006", 164 | "0019/000007", 165 | "0019/000010", 166 | "0019/000011", 167 | "0019/000013", 168 | "0019/000014", 169 | "0019/000015", 170 | "0019/000016", 171 | "0019/000017", 172 | "0019/000019", 173 | "0019/000020", 174 | "0019/000021", 175 | "0019/000023", 176 | "0019/000024", 177 | "0019/000025", 178 | "0019/000026", 179 | "0019/000027", 180 | "0019/000030", 181 | "0019/000031", 182 | "0019/000032", 183 | "0019/000033" 184 | ], 185 | "test-adapt-04": [ 186 | "0019/000001", 187 | "0019/000008", 188 | "0019/000009", 189 | "0019/000012", 190 | "0019/000018", 191 | "0019/000022", 192 | "0019/000028", 193 | "0019/000029", 194 | "0019/000034", 195 | "0019/000035" 196 | ], 197 | "train-adapt-05": [ 198 | "0003/000001", 199 | "0003/000002", 200 | "0003/000003", 201 | "0003/000005", 202 | "0003/000006", 203 | "0003/000007", 204 | "0003/000010", 205 | "0003/000011", 206 | "0003/000014", 207 | "0003/000017", 208 | "0003/000018", 209 | "0003/000019", 210 | "0003/000020", 211 | "0003/000023", 212 | "0003/000024", 213 | "0003/000025", 214 | "0003/000026", 215 | "0003/000027", 216 | "0003/000028", 217 | "0003/000029", 218 | "0003/000030", 219 | "0003/000031", 220 | "0003/000032", 221 | "0003/000033", 222 | "0003/000035" 223 | ], 224 | "test-adapt-05": [ 225 | "0003/000004", 226 | "0003/000008", 227 | "0003/000009", 228 | "0003/000012", 229 | "0003/000013", 230 | "0003/000015", 231 | "0003/000016", 232 | "0003/000021", 233 | "0003/000022", 234 | "0003/000034" 235 | ], 236 | "train-adapt-06": [ 237 | "0005/000001", 238 | "0005/000002", 239 | "0005/000003", 240 | "0005/000004", 241 | "0005/000005", 242 | "0005/000006", 243 | "0005/000007", 244 | "0005/000009", 245 | "0005/000010", 246 | "0005/000013", 247 | "0005/000014", 248 | "0005/000017", 249 | "0005/000018", 250 | "0005/000019", 251 | "0005/000020", 252 | "0005/000021", 253 | "0005/000023", 254 | "0005/000024", 255 | "0005/000026", 256 | "0005/000027", 257 | "0005/000028", 258 | "0005/000030", 259 | "0005/000031", 260 | "0005/000033", 261 | "0005/000035" 262 | ], 263 | "test-adapt-06": [ 264 | "0005/000008", 265 | "0005/000011", 266 | "0005/000012", 267 | "0005/000015", 268 | "0005/000016", 269 | "0005/000022", 270 | "0005/000025", 271 | "0005/000029", 272 | "0005/000032", 273 | "0005/000034" 274 | ], 275 | "train-adapt-07": [ 276 | "0014/000001", 277 | "0014/000002", 278 | "0014/000003", 279 | "0014/000005", 280 | "0014/000006", 281 | "0014/000007", 282 | "0014/000010", 283 | "0014/000012", 284 | "0014/000014", 285 | "0014/000015", 286 | "0014/000016", 287 | "0014/000017", 288 | "0014/000019", 289 | "0014/000020", 290 | "0014/000021", 291 | "0014/000023", 292 | "0014/000024", 293 | "0014/000026", 294 | "0014/000028", 295 | "0014/000029", 296 | "0014/000031", 297 | "0014/000032", 298 | 
"0014/000033", 299 | "0014/000034", 300 | "0014/000035" 301 | ], 302 | "test-adapt-07": [ 303 | "0014/000004", 304 | "0014/000008", 305 | "0014/000009", 306 | "0014/000011", 307 | "0014/000013", 308 | "0014/000018", 309 | "0014/000022", 310 | "0014/000025", 311 | "0014/000027", 312 | "0014/000030" 313 | ], 314 | "train-adapt-08": [ 315 | "0015/000001", 316 | "0015/000002", 317 | "0015/000003", 318 | "0015/000004", 319 | "0015/000005", 320 | "0015/000007", 321 | "0015/000009", 322 | "0015/000010", 323 | "0015/000011", 324 | "0015/000013", 325 | "0015/000014", 326 | "0015/000015", 327 | "0015/000017", 328 | "0015/000018", 329 | "0015/000019", 330 | "0015/000020", 331 | "0015/000023", 332 | "0015/000024", 333 | "0015/000025", 334 | "0015/000026", 335 | "0015/000027", 336 | "0015/000029", 337 | "0015/000031", 338 | "0015/000033", 339 | "0015/000034" 340 | ], 341 | "test-adapt-08": [ 342 | "0015/000006", 343 | "0015/000008", 344 | "0015/000012", 345 | "0015/000016", 346 | "0015/000021", 347 | "0015/000022", 348 | "0015/000028", 349 | "0015/000030", 350 | "0015/000032", 351 | "0015/000035" 352 | ], 353 | "train-adapt-09": [ 354 | "0017/000001", 355 | "0017/000002", 356 | "0017/000003", 357 | "0017/000005", 358 | "0017/000006", 359 | "0017/000007", 360 | "0017/000008", 361 | "0017/000010", 362 | "0017/000011", 363 | "0017/000012", 364 | "0017/000014", 365 | "0017/000015", 366 | "0017/000017", 367 | "0017/000019", 368 | "0017/000020", 369 | "0017/000021", 370 | "0017/000022", 371 | "0017/000023", 372 | "0017/000024", 373 | "0017/000026", 374 | "0017/000027", 375 | "0017/000031", 376 | "0017/000032", 377 | "0017/000033", 378 | "0017/000034" 379 | ], 380 | "test-adapt-09": [ 381 | "0017/000004", 382 | "0017/000009", 383 | "0017/000013", 384 | "0017/000016", 385 | "0017/000018", 386 | "0017/000025", 387 | "0017/000028", 388 | "0017/000029", 389 | "0017/000030", 390 | "0017/000035" 391 | ] 392 | } -------------------------------------------------------------------------------- /promonet/assets/stats/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxrmorrison/promonet/86b815e8511d526da42cd8333a9f5a2872b11981/promonet/assets/stats/.gitkeep -------------------------------------------------------------------------------- /promonet/assets/stats/vctk-256-loudness-pitch-viterbi.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxrmorrison/promonet/86b815e8511d526da42cd8333a9f5a2872b11981/promonet/assets/stats/vctk-256-loudness-pitch-viterbi.pt -------------------------------------------------------------------------------- /promonet/assets/stats/vctk-256-loudness-pitch.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxrmorrison/promonet/86b815e8511d526da42cd8333a9f5a2872b11981/promonet/assets/stats/vctk-256-loudness-pitch.pt -------------------------------------------------------------------------------- /promonet/assets/stats/vctk-256-viterbi.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxrmorrison/promonet/86b815e8511d526da42cd8333a9f5a2872b11981/promonet/assets/stats/vctk-256-viterbi.pt -------------------------------------------------------------------------------- /promonet/assets/stats/vctk-train-speaker-averages-viterbi.json: -------------------------------------------------------------------------------- 1 | { 
2 | "0000": 170.50909350156448, 3 | "0001": 108.05634884311132, 4 | "0002": 112.98379369785437, 5 | "0003": 187.73156422115358, 6 | "0004": 168.43632728325326, 7 | "0005": 181.36740336509175, 8 | "0006": 165.02675924771938, 9 | "0007": 111.3088675313178, 10 | "0008": 194.6425612323579, 11 | "0009": 175.3831922073543, 12 | "0010": 204.52543989076057, 13 | "0011": 82.0053420343216, 14 | "0012": 197.5725004522476, 15 | "0013": 177.04341861701505, 16 | "0014": 218.63597987686623, 17 | "0015": 108.65834632950205, 18 | "0016": 114.13538748708585, 19 | "0017": 194.39792827624015, 20 | "0018": 91.61959776183905, 21 | "0019": 100.22989821194251, 22 | "0020": 130.54939119402817, 23 | "0021": 223.90044434421634, 24 | "0022": 161.67108642049354, 25 | "0023": 195.3145179638834, 26 | "0024": 118.59526430280677, 27 | "0025": 113.14717625149157, 28 | "0026": 212.0594390070102, 29 | "0027": 79.12518609689728, 30 | "0028": 139.0370568079839, 31 | "0029": 91.073702491899, 32 | "0030": 196.59118101585938, 33 | "0031": 112.53526078104967, 34 | "0032": 112.94078250308847, 35 | "0033": 105.74920742087461, 36 | "0034": 229.4376787060624, 37 | "0035": 156.06591725579463, 38 | "0036": 103.81631165328098, 39 | "0037": 184.25529723803535, 40 | "0038": 182.63751402716167, 41 | "0039": 174.59498950858432, 42 | "0040": 173.76185020247883, 43 | "0041": 195.99980701534733, 44 | "0042": 178.40119822878384, 45 | "0043": 100.15386559317466, 46 | "0044": 111.44868224962579, 47 | "0045": 116.4088663992755, 48 | "0046": 144.6717632643517, 49 | "0047": 99.11680231904514, 50 | "0048": 102.6019883145031, 51 | "0049": 209.8892238390337, 52 | "0050": 186.65081995768662, 53 | "0051": 108.10661097719151, 54 | "0052": 120.01537300771115, 55 | "0053": 171.78893597724323, 56 | "0054": 90.89059859906492, 57 | "0055": 178.63414347192787, 58 | "0056": 197.967118168292, 59 | "0057": 96.78919528319207, 60 | "0058": 116.12752950531446, 61 | "0059": 131.5469035383589, 62 | "0060": 97.13428382467303, 63 | "0061": 181.83119231094136, 64 | "0062": 103.65266389399328, 65 | "0063": 173.28662402725752, 66 | "0064": 162.58037683371643, 67 | "0065": 176.61933456011465, 68 | "0066": 182.5478335932092, 69 | "0067": 116.26102297072265, 70 | "0068": 162.25613638558548, 71 | "0069": 186.44352826824436, 72 | "0070": 162.8076229715075, 73 | "0071": 115.23258185918912, 74 | "0072": 207.30382039282253, 75 | "0073": 103.06589660674328, 76 | "0074": 218.93745281165096, 77 | "0075": 178.59087103122354, 78 | "0076": 235.93989623417627, 79 | "0077": 173.591190283374, 80 | "0078": 227.9173501515729, 81 | "0079": 100.20651270046613, 82 | "0080": 195.69958045674724, 83 | "0081": 168.34149010724437, 84 | "0082": 174.05137348674236, 85 | "0083": 95.14694175575815, 86 | "0084": 232.97430279693407, 87 | "0085": 183.53710907157944, 88 | "0086": 207.47875136137878, 89 | "0087": 80.07636902751864, 90 | "0088": 193.10333813393453, 91 | "0089": 174.68076294186403, 92 | "0090": 179.1302513413154, 93 | "0091": 93.78072085098682, 94 | "0092": 201.85447017892298, 95 | "0093": 194.84634590710175, 96 | "0094": 192.94588425296746, 97 | "0095": 195.47371095017337, 98 | "0096": 179.30700869416546, 99 | "0097": 150.17523198645787, 100 | "0098": 96.24032711974795, 101 | "0099": 90.7838955352854, 102 | "0100": 202.98198241841865, 103 | "0101": 99.84826532541551, 104 | "0102": 180.24154380708757, 105 | "0103": 193.80520876093127, 106 | "0104": 107.75444725166328, 107 | "0105": 98.4518217502216, 108 | "0106": 112.51205351510836, 109 | "0107": 98.93230657244858, 110 | "0108": 
202.57462316542095 111 | } -------------------------------------------------------------------------------- /promonet/baseline/__init__.py: -------------------------------------------------------------------------------- 1 | from . import mels 2 | from . import world 3 | -------------------------------------------------------------------------------- /promonet/baseline/mels.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | import torchaudio 5 | import torchutil 6 | 7 | import promonet 8 | 9 | 10 | ############################################################################### 11 | # Mel spectrogram reconstruction 12 | ############################################################################### 13 | 14 | 15 | def from_audio( 16 | audio, 17 | sample_rate=promonet.SAMPLE_RATE, 18 | speaker=0, 19 | spectral_balance_ratio: float = 1., 20 | loudness_ratio: float = 1., 21 | checkpoint=None, 22 | gpu=None 23 | ): 24 | """Perform Mel spectrogram reconstruction""" 25 | device = 'cpu' if gpu is None else f'cuda:{gpu}' 26 | 27 | # Resample 28 | audio = resample(audio.to(device), sample_rate) 29 | 30 | # Preprocess 31 | spectrogram = promonet.preprocess.spectrogram.from_audio(audio) 32 | 33 | # Reconstruct 34 | return from_features( 35 | spectrogram, 36 | speaker, 37 | spectral_balance_ratio, 38 | loudness_ratio, 39 | checkpoint) 40 | 41 | 42 | def from_features( 43 | spectrogram, 44 | speaker=0, 45 | spectral_balance_ratio: float = 1., 46 | loudness_ratio: float = 1., 47 | checkpoint=None 48 | ): 49 | """Perform Mel spectrogram reconstruction""" 50 | device = spectrogram.device 51 | 52 | with torchutil.time.context('load'): 53 | 54 | # Cache model 55 | if ( 56 | not hasattr(from_features, 'model') or 57 | from_features.checkpoint != checkpoint or 58 | from_features.device != device 59 | ): 60 | model = promonet.model.MelGenerator().to(device) 61 | if type(checkpoint) is str: 62 | checkpoint = Path(checkpoint) 63 | if checkpoint.is_dir(): 64 | checkpoint = torchutil.checkpoint.latest_path( 65 | checkpoint, 66 | 'generator-*.pt') 67 | model, *_ = torchutil.checkpoint.load(checkpoint, model) 68 | from_features.model = model 69 | from_features.checkpoint = checkpoint 70 | from_features.device = device 71 | 72 | with torchutil.time.context('generate'): 73 | 74 | # Default length is the entire sequence 75 | lengths = torch.tensor( 76 | (spectrogram.shape[-1],), 77 | dtype=torch.long, 78 | device=device) 79 | 80 | # Specify speaker 81 | speakers = torch.full((1,), speaker, dtype=torch.long, device=device) 82 | 83 | # Format ratio 84 | spectral_balance_ratio = torch.tensor( 85 | [spectral_balance_ratio], 86 | dtype=torch.float, 87 | device=device) 88 | 89 | # Loudness ratio 90 | loudness_ratio = torch.tensor( 91 | [loudness_ratio], 92 | dtype=torch.float, 93 | device=device) 94 | 95 | # Reconstruct 96 | with torchutil.inference.context(from_features.model): 97 | return from_features.model( 98 | spectrogram[None], 99 | speakers, 100 | spectral_balance_ratio, 101 | loudness_ratio 102 | )[0].to(torch.float32) 103 | 104 | 105 | def from_file( 106 | audio_file, 107 | speaker=0, 108 | spectral_balance_ratio: float = 1., 109 | loudness_ratio: float = 1., 110 | checkpoint=None, 111 | gpu=None 112 | ): 113 | """Perform Mel reconstruction from audio file""" 114 | return from_audio( 115 | promonet.load.audio(audio_file), 116 | speaker=speaker, 117 | spectral_balance_ratio=spectral_balance_ratio, 118 | 
loudness_ratio=loudness_ratio, 119 | checkpoint=checkpoint, 120 | gpu=gpu) 121 | 122 | 123 | def from_file_to_file( 124 | audio_file, 125 | output_file, 126 | speaker=0, 127 | spectral_balance_ratio: float = 1., 128 | loudness_ratio: float = 1., 129 | checkpoint=None, 130 | gpu=None 131 | ): 132 | """Perform Mel reconstruction from audio file and save""" 133 | # Reconstruct 134 | reconstructed = from_file( 135 | audio_file, 136 | speaker, 137 | spectral_balance_ratio, 138 | loudness_ratio, 139 | checkpoint, 140 | gpu) 141 | 142 | # Save 143 | torchaudio.save(output_file, reconstructed.cpu(), promonet.SAMPLE_RATE) 144 | 145 | 146 | def from_files_to_files( 147 | audio_files, 148 | output_files, 149 | speakers=None, 150 | spectral_balance_ratio: float = 1., 151 | loudness_ratio: float = 1., 152 | checkpoint=None, 153 | gpu=None 154 | ): 155 | """Perform Mel reconstruction from audio files and save""" 156 | if speakers is None: 157 | speakers = [0] * len(audio_files) 158 | 159 | # Generate 160 | for item in zip(audio_files, output_files, speakers): 161 | from_file_to_file( 162 | *item, 163 | spectral_balance_ratio=spectral_balance_ratio, 164 | loudness_ratio=loudness_ratio, 165 | checkpoint=checkpoint, 166 | gpu=gpu) 167 | 168 | 169 | ############################################################################### 170 | # Utilities 171 | ############################################################################### 172 | 173 | 174 | def resample(audio, sample_rate): 175 | """Resample audio to ProMoNet sample rate""" 176 | # Cache resampling filter 177 | key = str(sample_rate) 178 | if not hasattr(resample, key): 179 | setattr( 180 | resample, 181 | key, 182 | torchaudio.transforms.Resample(sample_rate, promonet.SAMPLE_RATE)) 183 | 184 | # Resample 185 | return getattr(resample, key)(audio) 186 | -------------------------------------------------------------------------------- /promonet/baseline/world.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pyworld 3 | import scipy 4 | import torch 5 | import torchaudio 6 | import torchutil 7 | 8 | import promonet 9 | 10 | 11 | ############################################################################### 12 | # Constants 13 | ############################################################################### 14 | 15 | 16 | ALLOWED_RANGE = .8 17 | 18 | 19 | ############################################################################### 20 | # World speech editing 21 | ############################################################################### 22 | 23 | 24 | def from_audio( 25 | audio, 26 | sample_rate=promonet.SAMPLE_RATE, 27 | grid=None, 28 | loudness=None, 29 | pitch=None, 30 | periodicity=None 31 | ): 32 | """Perform World vocoding""" 33 | # Maybe resample 34 | if sample_rate != promonet.SAMPLE_RATE: 35 | resampler = torchaudio.transforms.Resample( 36 | sample_rate, 37 | promonet.SAMPLE_RATE) 38 | audio = resampler(audio) 39 | 40 | # Get target number of frames 41 | if grid is not None: 42 | frames = grid.shape[-1] 43 | elif loudness is not None: 44 | frames = loudness.shape[-1] 45 | elif pitch is not None: 46 | frames = pitch.shape[-1] 47 | else: 48 | frames = promonet.convert.samples_to_frames(audio.shape[-1]) 49 | 50 | # World parameterization 51 | target_pitch, spectrogram, aperiodicity = analyze( 52 | audio.squeeze().numpy(), frames) 53 | 54 | # Maybe time-stretch 55 | if grid is not None: 56 | ( 57 | target_pitch, 58 | spectrogram, 59 | aperiodicity 60 | ) = linear_time_stretch(
61 | target_pitch, 62 | spectrogram, 63 | aperiodicity, 64 | grid.numpy() 65 | ) 66 | 67 | # Maybe pitch-shift 68 | if pitch is not None: 69 | pitch = pitch.squeeze().numpy().astype(np.float64) 70 | 71 | # In WORLD, unvoiced frames are masked with zeros 72 | if periodicity is not None: 73 | unvoiced = \ 74 | periodicity.squeeze().numpy() < promonet.VOICING_THRESHOLD 75 | pitch[unvoiced] = 0. 76 | else: 77 | pitch = target_pitch 78 | 79 | # Synthesize using modified parameters 80 | vocoded = pyworld.synthesize( 81 | pitch, 82 | spectrogram, 83 | aperiodicity, 84 | promonet.SAMPLE_RATE, 85 | promonet.HOPSIZE / promonet.SAMPLE_RATE * 1000.) 86 | 87 | # Convert to torch 88 | vocoded = torch.from_numpy(vocoded)[None] 89 | 90 | # Ensure correct length 91 | length = promonet.convert.frames_to_samples(len(pitch)) 92 | if vocoded.shape[1] != length: 93 | temp = torch.zeros((1, length)) 94 | crop_point = min(length, vocoded.shape[1]) 95 | temp[:, :crop_point] = vocoded[:, :crop_point] 96 | vocoded = temp 97 | 98 | # Maybe scale loudness 99 | if loudness is not None: 100 | vocoded = promonet.preprocess.loudness.scale( 101 | vocoded, 102 | promonet.preprocess.loudness.band_average(loudness, 1)) 103 | 104 | return vocoded.to(torch.float32) 105 | 106 | 107 | def from_file( 108 | audio_file, 109 | grid_file=None, 110 | loudness_file=None, 111 | pitch_file=None, 112 | periodicity_file=None 113 | ): 114 | """Perform World vocoding on an audio file""" 115 | return from_audio( 116 | promonet.load.audio(audio_file), 117 | promonet.SAMPLE_RATE, 118 | None if grid_file is None else torch.load(grid_file), 119 | None if loudness_file is None else torch.load(loudness_file), 120 | None if pitch_file is None else torch.load(pitch_file), 121 | None if periodicity_file is None else torch.load(periodicity_file)) 122 | 123 | 124 | def from_file_to_file( 125 | audio_file, 126 | output_file, 127 | grid_file=None, 128 | loudness_file=None, 129 | pitch_file=None, 130 | periodicity_file=None 131 | ): 132 | """Perform World vocoding on an audio file and save""" 133 | vocoded = from_file( 134 | audio_file, 135 | grid_file, 136 | loudness_file, 137 | pitch_file, 138 | periodicity_file) 139 | torchaudio.save(output_file, vocoded, promonet.SAMPLE_RATE) 140 | 141 | 142 | def from_files_to_files( 143 | audio_files, 144 | output_files, 145 | grid_files=None, 146 | loudness_files=None, 147 | pitch_files=None, 148 | periodicity_files=None 149 | ): 150 | """Perform World vocoding on multiple files and save""" 151 | if grid_files is None: 152 | grid_files = [None] * len(audio_files) 153 | if loudness_files is None: 154 | loudness_files = [None] * len(audio_files) 155 | if pitch_files is None: 156 | pitch_files = [None] * len(audio_files) 157 | if periodicity_files is None: 158 | periodicity_files = [None] * len(audio_files) 159 | iterator = zip( 160 | audio_files, 161 | output_files, 162 | grid_files, 163 | loudness_files, 164 | pitch_files, 165 | periodicity_files) 166 | for item in torchutil.iterator(iterator, 'world', total=len(audio_files)): 167 | from_file_to_file(*item) 168 | 169 | 170 | ############################################################################### 171 | # Utilities 172 | ############################################################################### 173 | 174 | 175 | def analyze(audio, frames): 176 | """Convert an audio signal to WORLD parameter representation 177 | Arguments 178 | audio : np.array(shape=(samples,)) 179 | The audio being analyzed 180 | Returns 181 | pitch : np.array(shape=(frames,)) 182 | The 
pitch contour 183 | spectrogram : np.array(shape=(frames, channels)) 184 | The audio spectrogram 185 | aperiodicity : np.array(shape=(frames, channels)) 186 | The band aperiodicity envelope 187 | """ 188 | # Cast to double 189 | audio = audio.astype(np.float64) 190 | 191 | # Hopsize in milliseconds 192 | frame_period = promonet.HOPSIZE / promonet.SAMPLE_RATE * 1000. 193 | 194 | # Extract pitch 195 | samples = promonet.convert.frames_to_samples(frames) 196 | pitch, time = pyworld.dio( 197 | audio, 198 | promonet.SAMPLE_RATE, 199 | frame_period=frame_period, 200 | f0_floor=promonet.FMIN, 201 | f0_ceil=promonet.FMAX, 202 | allowed_range=ALLOWED_RANGE) 203 | pitch = pitch[:frames] 204 | time = time[:frames] 205 | 206 | # Postprocess pitch 207 | pitch = pyworld.stonemask(audio, pitch, time, promonet.SAMPLE_RATE) 208 | 209 | # Extract spectrogram 210 | spectrogram = pyworld.cheaptrick(audio, pitch, time, promonet.SAMPLE_RATE) 211 | 212 | # Extract aperiodicity 213 | aperiodicity = pyworld.d4c(audio, pitch, time, promonet.SAMPLE_RATE) 214 | 215 | return pitch, spectrogram, aperiodicity 216 | 217 | 218 | def linear_time_stretch( 219 | prev_pitch, 220 | prev_spectrogram, 221 | prev_aperiodicity, 222 | grid 223 | ): 224 | """Apply time stretch in WORLD parameter space""" 225 | grid = grid[0] if grid.ndim == 2 else grid 226 | 227 | # Number of frames before and after 228 | prev_frames = len(prev_pitch) 229 | next_frames = len(grid) 230 | 231 | # Time-aligned grid before and after 232 | prev_grid = np.linspace(0, prev_frames - 1, prev_frames) 233 | 234 | # Apply time stretch to pitch 235 | pitch = linear_time_stretch_pitch( 236 | prev_pitch, prev_grid, grid, next_frames) 237 | 238 | # Allocate spectrogram and aperiodicity buffers 239 | frequencies = prev_spectrogram.shape[1] 240 | spectrogram = np.zeros((next_frames, frequencies)) 241 | aperiodicity = np.zeros((next_frames, frequencies)) 242 | 243 | # Apply time stretch to all channels of spectrogram and aperiodicity 244 | for i in range(frequencies): 245 | spectrogram[:, i] = np.interp( 246 | grid, prev_grid, prev_spectrogram[:, i]) 247 | aperiodicity[:, i] = np.interp( 248 | grid, prev_grid, prev_aperiodicity[:, i]) 249 | 250 | return pitch, spectrogram, aperiodicity 251 | 252 | 253 | def linear_time_stretch_pitch(pitch, prev_grid, grid, next_frames): 254 | """Perform time-stretching on pitch features""" 255 | if (pitch == 0.).all(): 256 | return np.zeros(next_frames) 257 | 258 | # Get unvoiced tokens 259 | unvoiced = pitch == 0. 260 | 261 | # Linearly interpolate unvoiced regions 262 | pitch[unvoiced] = np.interp( 263 | np.where(unvoiced)[0], np.where(~unvoiced)[0], pitch[~unvoiced]) 264 | 265 | # Apply time stretch to pitch in base-2 log-space 266 | pitch = 2 ** np.interp(grid, prev_grid, np.log2(pitch)) 267 | 268 | # Apply time stretch to unvoiced sequence 269 | unvoiced = np.interp(grid, prev_grid, unvoiced) 270 | 271 | # Reapply unvoiced tokens 272 | pitch[unvoiced > .5] = 0.
273 | 274 | return pitch 275 | -------------------------------------------------------------------------------- /promonet/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxrmorrison/promonet/86b815e8511d526da42cd8333a9f5a2872b11981/promonet/config/__init__.py -------------------------------------------------------------------------------- /promonet/config/static.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import promonet 4 | 5 | 6 | ############################################################################### 7 | # Audio parameters 8 | ############################################################################### 9 | 10 | 11 | # Threshold to sparsify Mel spectrograms 12 | LOG_DYNAMIC_RANGE_COMPRESSION_THRESHOLD = ( 13 | None if promonet.DYNAMIC_RANGE_COMPRESSION_THRESHOLD is None else 14 | math.log(promonet.DYNAMIC_RANGE_COMPRESSION_THRESHOLD)) 15 | 16 | # Base-2 log of pitch range boundaries 17 | LOG_FMIN = math.log2(promonet.FMIN) 18 | LOG_FMAX = math.log2(promonet.FMAX) 19 | 20 | 21 | ############################################################################### 22 | # Directories 23 | ############################################################################### 24 | 25 | 26 | # Location to save data augmentation information 27 | AUGMENT_DIR = promonet.ASSETS_DIR / 'augmentations' 28 | 29 | # Location to save dataset partitions 30 | PARTITION_DIR = ( 31 | promonet.ASSETS_DIR / 32 | 'partitions' / 33 | ('adaptation' if promonet.ADAPTATION else 'multispeaker')) 34 | 35 | 36 | ############################################################################### 37 | # Model parameters 38 | ############################################################################### 39 | 40 | 41 | # Global input channels 42 | GLOBAL_CHANNELS = ( 43 | promonet.SPEAKER_CHANNELS + 44 | promonet.AUGMENT_PITCH + 45 | promonet.AUGMENT_LOUDNESS) 46 | 47 | # Number of input features to the generator 48 | NUM_FEATURES = promonet.NUM_MELS if promonet.SPECTROGRAM_ONLY else ( 49 | promonet.PPG_CHANNELS + 50 | ('loudness' in promonet.INPUT_FEATURES) * promonet.LOUDNESS_BANDS + 51 | ('periodicity' in promonet.INPUT_FEATURES) + 52 | ('pitch' in promonet.INPUT_FEATURES) * ( 53 | promonet.PITCH_EMBEDDING_SIZE if promonet.PITCH_EMBEDDING else 1)) 54 | 55 | # Number of input features to the discriminator 56 | NUM_FEATURES_DISCRIM = 1 57 | 58 | # Number of speakers 59 | if promonet.TRAINING_DATASET == 'daps': 60 | NUM_SPEAKERS = 20 61 | elif promonet.TRAINING_DATASET == 'libritts': 62 | NUM_SPEAKERS = 1230 63 | elif promonet.TRAINING_DATASET == 'vctk': 64 | NUM_SPEAKERS = 109 65 | else: 66 | raise ValueError(f'Dataset {promonet.TRAINING_DATASET} is not defined') 67 | 68 | # Number of previous samples 69 | if promonet.MODEL == 'cargan': 70 | NUM_PREVIOUS_SAMPLES = promonet.CARGAN_INPUT_SIZE 71 | elif promonet.MODEL == 'fargan': 72 | NUM_PREVIOUS_SAMPLES = promonet.HOPSIZE * promonet.FARGAN_PREVIOUS_FRAMES 73 | else: 74 | NUM_PREVIOUS_SAMPLES = 1 75 | -------------------------------------------------------------------------------- /promonet/convert.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | 4 | import torch 5 | 6 | import promonet 7 | 8 | 9 | ############################################################################### 10 | # Loudness conversions 11 | 
############################################################################### 12 | 13 | 14 | def db_to_ratio(db): 15 | """Convert decibels to perceptual loudness ratio""" 16 | return 2 ** (db / 10) 17 | 18 | 19 | def ratio_to_db(ratio): 20 | """Convert perceptual loudness ratio to decibels""" 21 | if isinstance(ratio, torch.Tensor): 22 | return 10 * torch.log2(ratio) 23 | else: 24 | return 10 * math.log2(ratio) 25 | 26 | 27 | ############################################################################### 28 | # Pitch conversions 29 | ############################################################################### 30 | 31 | 32 | def bins_to_hz( 33 | bins, 34 | num_bins=promonet.PITCH_BINS, 35 | fmin=promonet.FMIN, 36 | fmax=promonet.FMAX): 37 | """Convert pitch in bin indices to hz""" 38 | if promonet.VARIABLE_PITCH_BINS: 39 | # Get bin boundaries 40 | distribution = torch.cat([ 41 | promonet.load.pitch_distribution(), 42 | torch.tensor([promonet.FMAX])]) 43 | 44 | # Return the geometric mean of the two bin boundaries 45 | offset = 2 ** ( 46 | ( 47 | torch.log2(distribution[bins + 1]) - 48 | torch.log2(distribution[bins]) 49 | ) / 2) 50 | return distribution[bins] * offset 51 | 52 | # Normalize to [0, 1] 53 | logfmin = torch.log2(torch.tensor(fmin)) 54 | logfmax = torch.log2(torch.tensor(fmax)) 55 | normalized = bins.to(torch.float) / (num_bins - 1) 56 | 57 | # Convert to hz 58 | hz = 2 ** ((normalized * (logfmax - logfmin)) + logfmin) 59 | 60 | # Clip to bounds 61 | return torch.clip(hz, fmin, fmax) 62 | 63 | 64 | def cents_to_ratio(cents): 65 | """Convert pitch ratio in cents to linear ratio""" 66 | return 2 ** (cents / 1200) 67 | 68 | 69 | def hz_to_bins( 70 | hz, 71 | num_bins=promonet.PITCH_BINS, 72 | fmin=promonet.FMIN, 73 | fmax=promonet.FMAX): 74 | """Convert pitch in hz to bins""" 75 | # Clip to bounds 76 | hz = torch.clip(hz, fmin, fmax) 77 | 78 | # Maybe size bins according to count 79 | if promonet.VARIABLE_PITCH_BINS: 80 | distribution = promonet.load.pitch_distribution().to(hz.device) 81 | bins = torch.searchsorted(distribution, hz) 82 | return torch.clip(bins, 0, num_bins - 1) 83 | 84 | # Normalize to [0, 1] 85 | logfmin = torch.log2(torch.tensor(fmin)) 86 | logfmax = torch.log2(torch.tensor(fmax)) 87 | centered = torch.log2(hz) - logfmin 88 | normalized = centered / (logfmax - logfmin) 89 | 90 | # Convert to integer bin 91 | return ((num_bins - 1) * normalized).to(torch.long) 92 | 93 | 94 | def ratio_to_cents(ratio): 95 | """Convert linear pitch ratio to cents""" 96 | return 1200 * math.log2(ratio) 97 | 98 | 99 | ############################################################################### 100 | # Time conversions 101 | ############################################################################### 102 | 103 | 104 | def seconds_to_frames(seconds): 105 | """Convert seconds to frames""" 106 | return int(seconds * promonet.SAMPLE_RATE / promonet.HOPSIZE) 107 | 108 | 109 | def frames_to_samples(frames): 110 | """Convert number of frames to samples""" 111 | return frames * promonet.HOPSIZE 112 | 113 | 114 | def frames_to_seconds(frames): 115 | """Convert number of frames to seconds""" 116 | return frames * samples_to_seconds(promonet.HOPSIZE) 117 | 118 | 119 | def samples_to_seconds(samples, sample_rate=promonet.SAMPLE_RATE): 120 | """Convert time in samples to seconds""" 121 | return samples / sample_rate 122 | 123 | 124 | def samples_to_frames(samples): 125 | """Convert time in samples to frames""" 126 | with warnings.catch_warnings(): 127 | warnings.simplefilter('ignore') 128 | return samples // promonet.HOPSIZE 129 |
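The conversion helpers above are small pure functions of the configuration constants, so they are easy to sanity-check in isolation. Below is a minimal usage sketch; values are illustrative, and it assumes only the promonet.convert functions shown above and the promonet.HOPSIZE constant they reference:

import promonet

# Loudness: +3 dB corresponds to a perceptual ratio of 2 ** (3 / 10) ~ 1.23,
# and ratio_to_db inverts db_to_ratio exactly
ratio = promonet.convert.db_to_ratio(3.)
assert abs(promonet.convert.ratio_to_db(ratio) - 3.) < 1e-6

# Pitch: +200 cents (two equal-tempered semitones) is a linear ratio of
# 2 ** (200 / 1200) ~ 1.12
shift = promonet.convert.cents_to_ratio(200.)
assert abs(promonet.convert.ratio_to_cents(shift) - 200.) < 1e-6

# Time: frames and samples are related by the configured hopsize
frames = promonet.convert.seconds_to_frames(1.)
assert promonet.convert.frames_to_samples(frames) == frames * promonet.HOPSIZE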
-------------------------------------------------------------------------------- /promonet/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import augment 2 | from . import dataset 3 | from . import download 4 | from . import preprocess 5 | from . import sampler 6 | from .collate import collate 7 | from .dataset import Dataset 8 | from .loader import loader 9 | from .sampler import sampler 10 | -------------------------------------------------------------------------------- /promonet/data/augment/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | from . import loudness 3 | from . import pitch 4 | -------------------------------------------------------------------------------- /promonet/data/augment/__main__.py: -------------------------------------------------------------------------------- 1 | import yapecs 2 | 3 | import promonet 4 | 5 | 6 | ############################################################################### 7 | # Data augmentation 8 | ############################################################################### 9 | 10 | 11 | def parse_args(): 12 | """Parse command-line arguments""" 13 | parser = yapecs.ArgumentParser(description='Perform data augmentation') 14 | parser.add_argument( 15 | '--datasets', 16 | nargs='+', 17 | default=promonet.DATASETS, 18 | help='The name of the datasets to augment') 19 | return parser.parse_args() 20 | 21 | 22 | promonet.data.augment.datasets(**vars(parse_args())) 23 | -------------------------------------------------------------------------------- /promonet/data/augment/core.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torch 4 | import torchutil 5 | 6 | import promonet 7 | 8 | 9 | ############################################################################### 10 | # Data augmentation 11 | ############################################################################### 12 | 13 | 14 | @torchutil.notify('augment') 15 | def datasets(datasets): 16 | """Perform data augmentation on cached datasets""" 17 | for dataset in datasets: 18 | 19 | # Remove cached metadata that may become stale 20 | for stats_file in (promonet.ASSETS_DIR / 'stats').glob('*.pt'): 21 | stats_file.unlink() 22 | 23 | # Get cache directory 24 | directory = promonet.CACHE_DIR / dataset 25 | 26 | # Get files 27 | audio_files = sorted(directory.rglob('*-100.wav')) 28 | 29 | # Augment 30 | from_files_to_files(audio_files, dataset) 31 | 32 | 33 | def from_files_to_files(audio_files, name): 34 | """Perform data augmentation on audio files""" 35 | torch.manual_seed(promonet.RANDOM_SEED) 36 | 37 | # Get augmentation ratios 38 | ratios = sample(len(audio_files)) 39 | 40 | # Get locations to save output 41 | output_files = [ 42 | file.parent / 43 | f'{file.stem.split("-")[0]}-p{int(ratio * 100):03d}.wav' 44 | for file, ratio in zip(audio_files, ratios)] 45 | 46 | # Augment 47 | promonet.data.augment.pitch.from_files_to_files( 48 | audio_files, 49 | output_files, 50 | ratios) 51 | 52 | # Save augmentation ratios 53 | save(promonet.AUGMENT_DIR / f'{name}-pitch.json', audio_files, ratios) 54 | 55 | # Get augmentation ratios 56 | ratios = sample(len(audio_files)) 57 | 58 | # Get locations to save output 59 | output_files = [ 60 | file.parent / 61 | f'{file.stem.split("-")[0]}-l{int(ratio * 100):03d}.wav' 62 | for file, ratio in zip(audio_files, ratios)] 63 | 64 | # Augment 65 | # N.B. 
Ratios that cause clipping will be resampled 66 | ratios = promonet.data.augment.loudness.from_files_to_files( 67 | audio_files, 68 | output_files, 69 | ratios) 70 | 71 | # Save augmentation ratios 72 | save( 73 | promonet.AUGMENT_DIR / f'{name}-loudness.json', 74 | audio_files, 75 | ratios) 76 | 77 | 78 | ############################################################################### 79 | # Data augmentation 80 | ############################################################################### 81 | 82 | 83 | def sample(n): 84 | """Sample data augmentation ratios""" 85 | distribution = torch.distributions.uniform.Uniform( 86 | torch.log2(torch.tensor(promonet.AUGMENTATION_RATIO_MIN)), 87 | torch.log2(torch.tensor(promonet.AUGMENTATION_RATIO_MAX))) 88 | ratios = 2 ** distribution.sample([n]) 89 | 90 | # Prevent duplicates 91 | ratios[(ratios * 100).to(torch.int) == 100] += 1 92 | 93 | return ratios 94 | 95 | 96 | def save(json_file, audio_files, ratios): 97 | """Cache augmentation ratios""" 98 | ratio_dict = {} 99 | for audio_file, ratio in zip(audio_files, ratios): 100 | key = f'{audio_file.parent.name}/{audio_file.stem.split("-")[0]}' 101 | ratio_dict[key] = f'{int(ratio * 100):03d}' 102 | with open(json_file, 'w') as file: 103 | json.dump(ratio_dict, file, indent=4) 104 | -------------------------------------------------------------------------------- /promonet/data/augment/loudness.py: -------------------------------------------------------------------------------- 1 | import resampy 2 | import soundfile 3 | import torchutil 4 | 5 | import promonet 6 | 7 | 8 | ############################################################################### 9 | # Loudness data augmentation 10 | ############################################################################### 11 | 12 | 13 | def from_audio(audio, sample_rate, ratio): 14 | """Perform volume data augmentation on audio""" 15 | # Augment audio 16 | augmented = promonet.preprocess.loudness.shift( 17 | audio, 18 | promonet.convert.ratio_to_db(ratio)) 19 | 20 | # Resample ratio if the audio clips 21 | while ((augmented <= -1.) 
| (augmented >= 1.)).any(): 22 | ratio = promonet.data.augment.sample(1)[0] 23 | augmented = promonet.preprocess.loudness.shift( 24 | audio, 25 | promonet.convert.ratio_to_db(ratio)) 26 | 27 | # Resample to promonet sample rate 28 | augmented = resampy.resample(augmented, sample_rate, promonet.SAMPLE_RATE) 29 | 30 | return augmented, ratio 31 | 32 | 33 | def from_file(audio_file, ratio): 34 | """Perform volume data augmentation on audio file""" 35 | return from_audio(*soundfile.read(str(audio_file)), ratio) 36 | 37 | 38 | def from_file_to_file(audio_file, output_file, ratio): 39 | """Perform volume data augmentation on audio file and save""" 40 | augmented, new_ratio = from_file(audio_file, ratio) 41 | if new_ratio != ratio: 42 | output_file = ( 43 | output_file.parent / output_file.name.replace( 44 | f'{int(ratio * 100):03d}', 45 | f'{int(new_ratio * 100):03d}')) 46 | ratio = new_ratio 47 | soundfile.write(str(output_file), augmented, promonet.SAMPLE_RATE) 48 | return ratio 49 | 50 | 51 | def from_files_to_files(audio_files, output_files, ratios): 52 | """Perform volume data augmentation on audio files and save""" 53 | return torchutil.multiprocess_iterator( 54 | wrapper, 55 | zip(audio_files, output_files, ratios), 56 | 'Augmenting loudness', 57 | total=len(audio_files), 58 | num_workers=promonet.NUM_WORKERS) 59 | 60 | 61 | ############################################################################### 62 | # Loudness data augmentation 63 | ############################################################################### 64 | 65 | 66 | def wrapper(item): 67 | return from_file_to_file(*item) 68 | -------------------------------------------------------------------------------- /promonet/data/augment/pitch.py: -------------------------------------------------------------------------------- 1 | import resampy 2 | import soundfile 3 | import torchutil 4 | 5 | import promonet 6 | 7 | 8 | ############################################################################### 9 | # Pitch-shifting data augmentation 10 | ############################################################################### 11 | 12 | 13 | def from_audio(audio, sample_rate, ratio): 14 | """Perform pitch-shifting data augmentation on audio""" 15 | # Augment audio 16 | augmented = resampy.resample(audio, int(ratio * sample_rate), sample_rate) 17 | 18 | # Resample to promonet sample rate 19 | return resampy.resample(augmented, sample_rate, promonet.SAMPLE_RATE) 20 | 21 | 22 | def from_file(audio_file, ratio): 23 | """Perform pitch-shifting data augmentation on audio file""" 24 | return from_audio(*soundfile.read(str(audio_file)), ratio) 25 | 26 | 27 | def from_file_to_file(audio_file, output_file, ratio): 28 | """Perform pitch-shifting data augmentation on audio file and save""" 29 | augmented = from_file(audio_file, ratio) 30 | soundfile.write(str(output_file), augmented, promonet.SAMPLE_RATE) 31 | 32 | 33 | def from_files_to_files(audio_files, output_files, ratios): 34 | """Perform pitch-shifting data augmentation on audio files and save""" 35 | torchutil.multiprocess_iterator( 36 | wrapper, 37 | zip(audio_files, output_files, ratios), 38 | 'Augmenting pitch', 39 | total=len(audio_files), 40 | num_workers=promonet.NUM_WORKERS) 41 | 42 | 43 | ############################################################################### 44 | # Utilities 45 | ############################################################################### 46 | 47 | 48 | def wrapper(item): 49 | from_file_to_file(*item) 50 | 
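The pitch augmentation above relies on a classic resampling trick: declaring that audio recorded at sample_rate was actually recorded at ratio * sample_rate, then resampling back to sample_rate, scales every frequency by ratio and the duration by 1 / ratio. A standalone sketch of the idea on a synthetic tone (hypothetical values; assumes only numpy and the resampy package imported above):

import numpy as np
import resampy

# One second of a 220 Hz sine at an illustrative 22050 Hz sample rate
sample_rate = 22050
t = np.arange(sample_rate) / sample_rate
audio = np.sin(2 * np.pi * 220. * t)

# Reinterpret the sample rate, then resample back down; the result is
# approximately a 275 Hz sine lasting 0.8 seconds
ratio = 1.25
shifted = resampy.resample(audio, int(ratio * sample_rate), sample_rate)

Because this shifts pitch, formants, and duration together, it would make a poor interactive pitch editor, but it is a cheap way to manufacture pitch-shifted training data, which is how it is used here.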
-------------------------------------------------------------------------------- /promonet/data/collate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import promonet 4 | 5 | 6 | ############################################################################### 7 | # Batch collation 8 | ############################################################################### 9 | 10 | 11 | def collate(batch): 12 | """Collate from features, spectrograms, audio, and speaker identities""" 13 | # Unpack 14 | ( 15 | text, 16 | loudness, 17 | pitch, 18 | periodicity, 19 | phonemes, 20 | spectrograms, 21 | audio, 22 | speakers, 23 | spectral_balance_ratios, 24 | loudness_ratios, 25 | stems 26 | ) = zip(*batch) 27 | 28 | # Get lengths in samples 29 | lengths = torch.tensor([a.shape[1] for a in audio], dtype=torch.long) 30 | 31 | # Get batch indices sorted by length 32 | _, sorted_indices = torch.sort(lengths, dim=0, descending=True) 33 | 34 | # Get tensor size in frames and samples 35 | max_length_phonemes = max([p.shape[-1] for p in phonemes]) 36 | max_length_samples = lengths.max().item() 37 | max_length_frames = promonet.convert.samples_to_frames(max_length_samples) 38 | 39 | # We store original lengths for, e.g., loss evaluation 40 | feature_lengths = torch.empty((len(batch),), dtype=torch.long) 41 | 42 | # Initialize padded tensors 43 | padded_phonemes = torch.zeros( 44 | (len(batch), promonet.PPG_CHANNELS, max_length_phonemes), 45 | dtype=torch.float) 46 | padded_pitch = torch.zeros( 47 | (len(batch), max_length_frames), 48 | dtype=torch.float) 49 | padded_periodicity = torch.zeros( 50 | (len(batch), max_length_frames), 51 | dtype=torch.float) 52 | padded_loudness = torch.zeros( 53 | (len(batch), promonet.NUM_FFT // 2 + 1, max_length_frames), 54 | dtype=torch.float) 55 | padded_spectrograms = torch.zeros( 56 | (len(batch), promonet.NUM_FFT // 2 + 1, max_length_frames), 57 | dtype=torch.float) 58 | padded_audio = torch.zeros( 59 | (len(batch), 1, max_length_samples), 60 | dtype=torch.float) 61 | for i, index in enumerate(sorted_indices): 62 | 63 | # Get lengths 64 | feature_lengths[i] = phonemes[index].shape[-1] 65 | 66 | # Prepare phoneme features 67 | padded_phonemes[i, :, :feature_lengths[i]] = phonemes[index] 68 | 69 | # Prepare prosody features 70 | padded_pitch[i, :feature_lengths[i]] = pitch[index] 71 | padded_periodicity[i, :feature_lengths[i]] = periodicity[index] 72 | padded_loudness[i, :, :feature_lengths[i]] = loudness[index] 73 | 74 | # Prepare spectrogram 75 | padded_spectrograms[i, :, :feature_lengths[i]] = \ 76 | spectrograms[index] 77 | 78 | # Prepare audio 79 | padded_audio[i, :, :lengths[index]] = audio[index] 80 | 81 | # Collate speaker IDs or embeddings 82 | if promonet.ZERO_SHOT: 83 | speakers = torch.stack(speakers) 84 | else: 85 | speakers = torch.tensor(speakers, dtype=torch.long) 86 | 87 | # Sort stuff 88 | text = [text[i] for i in sorted_indices] 89 | stems = [stems[i] for i in sorted_indices] 90 | speakers = speakers[sorted_indices] 91 | spectral_balance_ratios = torch.tensor( 92 | spectral_balance_ratios, dtype=torch.float)[sorted_indices] 93 | loudness_ratios = torch.tensor( 94 | loudness_ratios, dtype=torch.float)[sorted_indices] 95 | 96 | return ( 97 | text, 98 | padded_loudness, 99 | padded_pitch, 100 | padded_periodicity, 101 | padded_phonemes, 102 | speakers, 103 | spectral_balance_ratios, 104 | loudness_ratios, 105 | padded_spectrograms, 106 | padded_audio, 107 | stems) 108 | 
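The collate function above sorts each batch longest-first and zero-pads every field to a shared length while recording the true lengths, so that downstream losses can mask out padding. A toy sketch of that sort-and-pad pattern on dummy single-channel signals (hypothetical shapes; the real function operates on the full eleven-field tuples produced by the dataset):

import torch

# Three dummy variable-length signals of shape (1, samples)
audio = [torch.ones(1, n) for n in (300, 500, 400)]

# True lengths, and batch indices sorted longest-first
lengths = torch.tensor([a.shape[1] for a in audio], dtype=torch.long)
_, sorted_indices = torch.sort(lengths, dim=0, descending=True)

# Zero-pad everything to the longest item
padded = torch.zeros((len(audio), 1, int(lengths.max())))
for i, index in enumerate(sorted_indices):
    padded[i, :, :lengths[index]] = audio[index]

assert padded.shape == (3, 1, 500)  # row 0 holds the 500-sample signal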
-------------------------------------------------------------------------------- /promonet/data/dataset.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import json 3 | import random 4 | 5 | import torch 6 | 7 | import promonet 8 | import ppgs 9 | 10 | 11 | ############################################################################### 12 | # Dataset 13 | ############################################################################### 14 | 15 | 16 | class Dataset(torch.utils.data.Dataset): 17 | 18 | def __init__(self, dataset, partition, adapt=promonet.ADAPTATION): 19 | super().__init__() 20 | self.cache = promonet.CACHE_DIR / dataset 21 | self.partition = partition 22 | self.viterbi = '-viterbi' if promonet.VITERBI_DECODE_PITCH else '' 23 | 24 | # Get stems corresponding to partition 25 | partition_dict = promonet.load.partition(dataset, adapt) 26 | if partition is not None: 27 | stems = partition_dict[partition] 28 | else: 29 | stems = sum(partition_dict.values(), start=[]) 30 | self.stems = [f'{stem}-100' for stem in stems] 31 | 32 | # For training, maybe add augmented data 33 | # This also applies to adaptation partitions: train-adapt-xx 34 | if 'train' in partition: 35 | if promonet.AUGMENT_PITCH: 36 | with open( 37 | promonet.AUGMENT_DIR / f'{dataset}-pitch.json' 38 | ) as file: 39 | ratios = json.load(file) 40 | self.stems.extend([f'{stem}-p{ratios[stem]}' for stem in stems]) 41 | if promonet.AUGMENT_LOUDNESS: 42 | with open( 43 | promonet.AUGMENT_DIR / f'{dataset}-loudness.json' 44 | ) as file: 45 | ratios = json.load(file) 46 | self.stems.extend([ 47 | f'{stem}-l{ratios[stem]}' for stem in stems 48 | if (self.cache / f'{stem}-l{ratios[stem]}.wav').exists()]) 49 | 50 | # Omit files where the 50 Hz hum dominates the pitch estimation 51 | self.stems = [ 52 | stem for stem in self.stems 53 | if ( 54 | 2 ** torch.log2( 55 | torch.load(self.cache / f'{stem}{self.viterbi}-pitch.pt') 56 | ).mean() 57 | ) > 60.] 
58 | self.speaker_stems = {} 59 | for stem in self.stems: 60 | speaker = stem.split('/')[0] 61 | if speaker not in self.speaker_stems: 62 | self.speaker_stems[speaker] = [stem] 63 | else: 64 | self.speaker_stems[speaker].append(stem) 65 | 66 | def __getitem__(self, index): 67 | stem = self.stems[index] 68 | text = promonet.load.text(self.cache / f'{stem.split("-")[0]}.txt') 69 | audio = promonet.load.audio( 70 | self.cache / f'{stem}.wav').to(torch.float32) 71 | pitch = torch.load( 72 | self.cache / f'{stem}{self.viterbi}-pitch.pt').to(torch.float32) 73 | periodicity = torch.load( 74 | self.cache / f'{stem}{self.viterbi}-periodicity.pt' 75 | ).to(torch.float32) 76 | spectrogram = torch.load( 77 | self.cache / f'{stem}-spectrogram.pt').to(torch.float32) 78 | phonemes = promonet.load.ppg( 79 | self.cache / f'{stem}{ppgs.representation_file_extension()}', 80 | resample_length=spectrogram.shape[-1] 81 | ).to(torch.float32) 82 | 83 | # For loudness augmentation, use original loudness to disentangle 84 | if stem.split('-')[-1].startswith('l'): 85 | loudness_file = self.cache / f'{stem[:-4]}100-loudness.pt' 86 | else: 87 | loudness_file = self.cache / f'{stem}-loudness.pt' 88 | loudness = torch.load(loudness_file).to(torch.float32) 89 | 90 | # Chunk during training 91 | if self.partition.startswith('train'): 92 | frames = promonet.CHUNK_SIZE // promonet.HOPSIZE 93 | if audio.shape[1] < promonet.CHUNK_SIZE: 94 | audio = torch.nn.functional.pad( 95 | audio, 96 | (0, promonet.CHUNK_SIZE - audio.shape[1]), 97 | mode='reflect') 98 | pad_frames = frames - pitch.shape[1] 99 | pad_fn = functools.partial( 100 | torch.nn.functional.pad, 101 | pad=(0, pad_frames), 102 | mode='reflect') 103 | pitch = pad_fn(pitch) 104 | periodicity = pad_fn(periodicity) 105 | loudness = pad_fn(loudness) 106 | spectrogram = pad_fn(spectrogram) 107 | phonemes = pad_fn(phonemes) 108 | else: 109 | start_frame = torch.randint(pitch.shape[-1] - frames + 1, (1,)).item() 110 | start_sample = start_frame * promonet.HOPSIZE 111 | audio = audio[ 112 | :, start_sample:start_sample + promonet.CHUNK_SIZE] 113 | pitch = pitch[:, start_frame:start_frame + frames] 114 | periodicity = periodicity[:, start_frame:start_frame + frames] 115 | loudness = loudness[:, start_frame:start_frame + frames] 116 | spectrogram = spectrogram[:, start_frame:start_frame + frames] 117 | phonemes = phonemes[:, start_frame:start_frame + frames] 118 | 119 | if promonet.ZERO_SHOT: 120 | 121 | # Load speaker embedding 122 | if promonet.ZERO_SHOT_SHUFFLE and 'train' in self.partition: 123 | random_speaker_stem = stem 124 | while random_speaker_stem == stem: 125 | random_speaker_stem = random.choice(self.speaker_stems[stem.split('/')[0]]) 126 | speaker = torch.load(self.cache / f'{random_speaker_stem}-speaker.pt') 127 | else: 128 | speaker = torch.load(self.cache / f'{stem}-speaker.pt') 129 | 130 | else: 131 | 132 | # Get speaker index. Non-integer speaker names are assumed to be 133 | # for speaker adaptation and therefore default to index zero. 134 | if 'adapt' not in self.partition: 135 | speaker = int(stem.split('/')[0]) 136 | else: 137 | speaker = 0 138 | speaker = torch.tensor(speaker, dtype=torch.long) 139 | 140 | # Data augmentation ratios 141 | augmentation = stem[-4:] 142 | if augmentation.startswith('-'): 143 | spectral_balance_ratios, loudness_ratio = 1., 1. 144 | elif augmentation.startswith('p'): 145 | spectral_balance_ratios = int(stem[-3:]) / 100. 146 | loudness_ratio = 1. 147 | elif augmentation.startswith('l'): 148 | spectral_balance_ratios = 1. 
149 | loudness_ratio = int(stem[-3:]) / 100. 150 | else: 151 | raise ValueError( 152 | f'Unrecognized augmentation string {augmentation}') 153 | 154 | return ( 155 | text, 156 | loudness, 157 | pitch, 158 | periodicity, 159 | phonemes, 160 | spectrogram, 161 | audio, 162 | speaker, 163 | spectral_balance_ratios, 164 | loudness_ratio, 165 | stem) 166 | 167 | def __len__(self): 168 | return len(self.stems) 169 | -------------------------------------------------------------------------------- /promonet/data/download/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | -------------------------------------------------------------------------------- /promonet/data/download/__main__.py: -------------------------------------------------------------------------------- 1 | import yapecs 2 | 3 | import promonet 4 | 5 | 6 | ############################################################################### 7 | # Download datasets 8 | ############################################################################### 9 | 10 | 11 | def parse_args(): 12 | """Parse command-line arguments""" 13 | parser = yapecs.ArgumentParser(description='Download datasets') 14 | parser.add_argument( 15 | '--datasets', 16 | nargs='+', 17 | choices=promonet.DATASETS, 18 | default=promonet.DATASETS, 19 | help='The datasets to download') 20 | return parser.parse_args() 21 | 22 | 23 | promonet.data.download.datasets(**vars(parse_args())) 24 | -------------------------------------------------------------------------------- /promonet/data/loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import promonet 4 | 5 | 6 | ############################################################################### 7 | # Setup data loaders 8 | ############################################################################### 9 | 10 | 11 | def loader(dataset, partition, adapt=promonet.ADAPTATION, gpu=None): 12 | """Setup data loader""" 13 | # Get dataset 14 | dataset = promonet.data.Dataset(dataset, partition, adapt) 15 | 16 | # Create loader 17 | return torch.utils.data.DataLoader( 18 | dataset, 19 | num_workers=promonet.NUM_WORKERS, 20 | pin_memory=gpu is not None, 21 | collate_fn=promonet.data.collate, 22 | batch_sampler=promonet.data.sampler(dataset, partition)) 23 | -------------------------------------------------------------------------------- /promonet/data/pack/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | -------------------------------------------------------------------------------- /promonet/data/pack/__main__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import yapecs 4 | 5 | import promonet 6 | 7 | 8 | ############################################################################### 9 | # Pack features 10 | ############################################################################### 11 | 12 | 13 | def parse_args(): 14 | """Parse command-line arguments""" 15 | parser = yapecs.ArgumentParser( 16 | description='Pack features in a single tensor') 17 | parser.add_argument( 18 | '--audio_file', 19 | type=Path, 20 | help='The audio file to convert to a packed feature tensor') 21 | parser.add_argument( 22 | '--output_file', 23 | type=Path, 24 | help=( 25 | 'File to save packed tensor. 
' 26 | 'Default is audio_file with .csv extension')) 27 | parser.add_argument( 28 | '--speaker', 29 | type=int, 30 | default=0, 31 | help='The speaker index') 32 | parser.add_argument( 33 | '--spectral_balance_ratio', 34 | type=float, 35 | default=1., 36 | help='> 1 for Alvin and the Chipmunks; < 1 for Patrick Star') 37 | parser.add_argument( 38 | '--gpu', 39 | type=int, 40 | help='The GPU index') 41 | return parser.parse_args() 42 | 43 | 44 | promonet.data.pack.from_file_to_file(**vars(parse_args())) 45 | -------------------------------------------------------------------------------- /promonet/data/pack/core.py: -------------------------------------------------------------------------------- 1 | import csv 2 | 3 | import numpy as np 4 | import ppgs 5 | import torch 6 | 7 | import promonet 8 | 9 | 10 | ############################################################################### 11 | # Pack features 12 | ############################################################################### 13 | 14 | 15 | def from_audio(audio, speaker=0, spectral_balance_ratio=1., gpu=None): 16 | """Convert audio to packed features""" 17 | # Preprocess audio 18 | loudness, pitch, periodicity, ppg = promonet.preprocess.from_audio( 19 | audio, 20 | gpu=gpu) 21 | 22 | # Pack features 23 | return from_features( 24 | loudness[None].cpu(), 25 | pitch[None].cpu(), 26 | periodicity[None].cpu(), 27 | ppg.cpu(), 28 | speaker, 29 | spectral_balance_ratio, 30 | 1.) 31 | 32 | 33 | def from_features( 34 | loudness, 35 | pitch, 36 | periodicity, 37 | ppg, 38 | speaker=0, 39 | spectral_balance_ratio=1., 40 | loudness_ratio=1.): 41 | """Pack features into a single tensor""" 42 | features = torch.zeros((loudness.shape[0], 0, loudness.shape[2])) 43 | 44 | # Loudness 45 | averaged = promonet.preprocess.loudness.band_average(loudness) 46 | features = torch.cat((features, averaged), dim=1) 47 | 48 | # Pitch 49 | features = torch.cat((features, pitch), dim=1) 50 | 51 | # Periodicity 52 | features = torch.cat((features, periodicity), dim=1) 53 | 54 | # PPG 55 | if ( 56 | promonet.SPARSE_PPG_METHOD is not None and 57 | ppgs.REPRESENTATION_KIND == 'ppg' 58 | ): 59 | threshold = torch.tensor( 60 | promonet.SPARSE_PPG_THRESHOLD, 61 | dtype=torch.float) 62 | ppg = ppgs.sparsify(ppg, promonet.SPARSE_PPG_METHOD, threshold) 63 | features = torch.cat((features, ppg), dim=1) 64 | 65 | # Speaker 66 | speaker = torch.tensor([speaker])[:, None, None] 67 | speaker = speaker.repeat(1, 1, features.shape[-1]).to(torch.float) 68 | features = torch.cat((features, speaker), dim=1) 69 | 70 | # Spectral balance 71 | spectral_balance_ratio = \ 72 | torch.tensor([spectral_balance_ratio])[:, None, None].repeat( 73 | 1, 1, features.shape[-1]) 74 | features = torch.cat((features, spectral_balance_ratio), dim=1) 75 | 76 | # Loudness ratio 77 | loudness_ratio = \ 78 | torch.tensor([loudness_ratio])[:, None, None].repeat( 79 | 1, 1, features.shape[-1]) 80 | features = torch.cat((features, loudness_ratio), dim=1) 81 | 82 | return features 83 | 84 | 85 | def from_file_to_file( 86 | audio_file, 87 | output_file=None, 88 | speaker=0, 89 | spectral_balance_ratio=1., 90 | gpu=None): 91 | """Convert audio file to packed features and save""" 92 | # Default to audio_file with .csv extension 93 | if output_file is None: 94 | output_format = 'csv' 95 | output_file = audio_file.with_suffix(f'.{output_format}') 96 | else: 97 | output_format = output_file.suffix[1:] 98 | if output_format not in ['csv', 'pt']: 99 | raise ValueError(f'Output format "{output_format}" is not supported') 100 | 101 | # Pack features 102 | audio = promonet.load.audio(audio_file) 103 | features = from_audio( 104 | audio, 105 | speaker, 106 | spectral_balance_ratio, 107 | gpu) 108 | 109 | # Save 110 | if output_format == 'pt': 111 | torch.save(features, output_file) 112 | elif output_format == 'csv': 113 | features = features.cpu().numpy()[0] 114 | 115 | # Representation labels for header 116 | labels = [ 117 | *[f'loudness-{i}' for i in range(promonet.LOUDNESS_BANDS)], # Loudness (8) 118 | 'pitch', # Pitch 119 | 'periodicity', # Periodicity 120 | *[f'ppg-{i} ({ppgs.PHONEMES[i]})' for i in range(promonet.PPG_CHANNELS)], # PPG (40) 121 | 'speaker', # Speaker id 122 | 'spectral balance', # Spectral balance 123 | 'loudness ratio' # Loudness ratio 124 | ] 125 | labels = ['timecode', *labels] # Start of frame time (seconds) 126 | 127 | # Generate timecode information (frame beginning) 128 | timecodes = np.arange(0.0, audio.shape[-1] / promonet.SAMPLE_RATE, promonet.HOPSIZE / promonet.SAMPLE_RATE) 129 | timecodes = timecodes[:features.shape[-1]] 130 | 131 | # Save to CSV 132 | with open(output_file, 'w') as csv_file: 133 | csv_writer = csv.writer(csv_file) 134 | csv_writer.writerow(labels) 135 | for i in range(features.shape[-1]): 136 | row = [timecodes[i], *features[:, i].tolist()] 137 | row = [(f"{int(r)}" if j == 51 else f"{r:.8f}") for j, r in enumerate(row)] # column 51 is the integer speaker id 138 | csv_writer.writerow(row) -------------------------------------------------------------------------------- /promonet/data/preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | -------------------------------------------------------------------------------- /promonet/data/preprocess/__main__.py: -------------------------------------------------------------------------------- 1 | import yapecs 2 | 3 | import promonet 4 | 5 | 6 | ############################################################################### 7 | # Preprocess datasets 8 | ############################################################################### 9 | 10 | 11 | def parse_args(): 12 | parser = yapecs.ArgumentParser(description='Preprocess datasets') 13 | parser.add_argument( 14 | '--datasets', 15 | nargs='+', 16 | default=promonet.DATASETS, 17 | choices=promonet.DATASETS, 18 | help='The datasets to preprocess') 19 | parser.add_argument( 20 | '--features', 21 | default=['loudness', 'pitch', 'periodicity', 'ppg'], 22 | choices=promonet.ALL_FEATURES, 23 | nargs='+', 24 | help='The features to preprocess') 25 | parser.add_argument( 26 | '--gpu', 27 | type=int, 28 | help='The index of the gpu to use') 29 | return parser.parse_args() 30 | 31 | 32 | promonet.data.preprocess.datasets(**vars(parse_args())) 33 | -------------------------------------------------------------------------------- /promonet/data/preprocess/core.py: -------------------------------------------------------------------------------- 1 | import torchutil 2 | 3 | import promonet 4 | 5 | 6 | ############################################################################### 7 | # Preprocess datasets 8 | ############################################################################### 9 | 10 | 11 | @torchutil.notify('preprocess') 12 | def datasets(datasets, features=promonet.ALL_FEATURES, gpu=None): 13 | """Preprocess a dataset""" 14 | for dataset in datasets: 15 | 16 | # Get cache directory 17 | directory = promonet.CACHE_DIR / dataset 18 | 19 | # Get text and audio files for this speaker 20 |
audio_files = sorted(list(directory.rglob('*.wav'))) 21 | audio_files = [file for file in audio_files if '-' in file.stem] 22 | 23 | # Preprocess input features 24 | if any(feature in features for feature in [ 25 | 'loudness', 26 | 'pitch', 27 | 'periodicity', 28 | 'ppg', 29 | 'text', 30 | 'harmonics', 31 | 'speaker' 32 | ]): 33 | promonet.preprocess.from_files_to_files( 34 | audio_files, 35 | gpu=gpu, 36 | features=[f for f in features if f != 'spectrogram'], 37 | loudness_bands=None) 38 | 39 | # Preprocess spectrograms 40 | if 'spectrogram' in features: 41 | spectrogram_files = [ 42 | file.parent / f'{file.stem}-spectrogram.pt' 43 | for file in audio_files] 44 | promonet.preprocess.spectrogram.from_files_to_files( 45 | audio_files, 46 | spectrogram_files) 47 | -------------------------------------------------------------------------------- /promonet/data/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import promonet 4 | 5 | 6 | ############################################################################### 7 | # Sampler selection 8 | ############################################################################### 9 | 10 | 11 | def sampler(dataset, partition): 12 | """Create batch sampler""" 13 | # Deterministic random sampler for training 14 | if partition.startswith('train'): 15 | return Sampler(dataset) 16 | 17 | # Sample validation and test data sequentially 18 | elif partition.startswith('test') or partition.startswith('valid'): 19 | return torch.utils.data.BatchSampler( 20 | torch.utils.data.SequentialSampler(dataset), 21 | 1, 22 | False) 23 | 24 | else: 25 | raise ValueError(f'Partition {partition} is not defined') 26 | 27 | 28 | ############################################################################### 29 | # Samplers 30 | ############################################################################### 31 | 32 | 33 | class Sampler: 34 | 35 | def __init__(self, dataset): 36 | self.epoch = 0 37 | self.length = len(dataset) 38 | 39 | def __iter__(self): 40 | return iter(self.batch()) 41 | 42 | def __len__(self): 43 | return len(self.batch()) 44 | 45 | def batch(self): 46 | """Produces batch indices for one epoch""" 47 | # Deterministic shuffling based on epoch 48 | generator = torch.Generator() 49 | generator.manual_seed(promonet.RANDOM_SEED + self.epoch) 50 | 51 | # Shuffle 52 | indices = torch.randperm(self.length, generator=generator).tolist() 53 | 54 | # Make batches 55 | return [ 56 | indices[i:i + promonet.BATCH_SIZE] 57 | for i in range(0, self.length, promonet.BATCH_SIZE)] 58 | 59 | def set_epoch(self, epoch): 60 | self.epoch = epoch 61 | -------------------------------------------------------------------------------- /promonet/edit/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | from . 
import grid 3 | -------------------------------------------------------------------------------- /promonet/edit/__main__.py: -------------------------------------------------------------------------------- 1 | import yapecs 2 | from pathlib import Path 3 | 4 | import promonet 5 | 6 | 7 | ############################################################################### 8 | # Edit 9 | ############################################################################### 10 | 11 | 12 | def parse_args(): 13 | """Parse command-line arguments""" 14 | parser = yapecs.ArgumentParser(description='Edit speech representation') 15 | parser.add_argument( 16 | '--loudness_files', 17 | type=Path, 18 | nargs='+', 19 | required=True, 20 | help='The loudness files to edit') 21 | parser.add_argument( 22 | '--pitch_files', 23 | type=Path, 24 | nargs='+', 25 | required=True, 26 | help='The pitch files to edit') 27 | parser.add_argument( 28 | '--periodicity_files', 29 | type=Path, 30 | nargs='+', 31 | required=True, 32 | help='The periodicity files to edit') 33 | parser.add_argument( 34 | '--ppg_files', 35 | type=Path, 36 | nargs='+', 37 | required=True, 38 | help='The ppg files to edit') 39 | parser.add_argument( 40 | '--output_prefixes', 41 | required=True, 42 | type=Path, 43 | nargs='+', 44 | help='The locations to save output files, minus extension') 45 | parser.add_argument( 46 | '--pitch_shift_cents', 47 | type=float, 48 | help='Amount of pitch-shifting in cents') 49 | parser.add_argument( 50 | '--time_stretch_ratio', 51 | type=float, 52 | help='Amount of time-stretching. Faster when above one.') 53 | parser.add_argument( 54 | '--loudness_scale_db', 55 | type=float, 56 | help='Amount of loudness scaling in dB') 57 | parser.add_argument( 58 | '--stretch_unvoiced', 59 | action='store_true', 60 | help='If provided, applies time-stretching to unvoiced frames') 61 | parser.add_argument( 62 | '--stretch_silence', 63 | action='store_true', 64 | help='If provided, applies time-stretching to silence frames') 65 | parser.add_argument( 66 | '--save_grid', 67 | action='store_true', 68 | help='If provided, also saves the time-stretch grid') 69 | return parser.parse_args() 70 | 71 | 72 | promonet.edit.from_files_to_files(**vars(parse_args())) 73 | -------------------------------------------------------------------------------- /promonet/edit/core.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import ppgs 6 | import pypar 7 | import torch 8 | 9 | import promonet 10 | 11 | 12 | ############################################################################### 13 | # Edit speech features 14 | ############################################################################### 15 | 16 | 17 | def from_features( 18 | loudness: torch.Tensor, 19 | pitch: torch.Tensor, 20 | periodicity: torch.Tensor, 21 | ppg: torch.Tensor, 22 | pitch_shift_cents: Optional[float] = None, 23 | time_stretch_ratio: Optional[float] = None, 24 | loudness_scale_db: Optional[float] = None, 25 | stretch_unvoiced: bool = True, 26 | stretch_silence: bool = True, 27 | return_grid: bool = False 28 | ) -> Union[ 29 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], 30 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] 31 | ]: 32 | """Edit speech representation 33 | 34 | Arguments 35 | loudness: Loudness contour to edit 36 | pitch: Pitch contour to edit 37 | periodicity: Periodicity contour to edit 38 | ppg: 
PPG to edit 39 | pitch_shift_cents: Amount of pitch-shifting in cents 40 | time_stretch_ratio: Amount of time-stretching. Faster when above one. 41 | loudness_scale_db: Loudness ratio editing in dB (not recommended; use loudness) 42 | stretch_unvoiced: If true, applies time-stretching to unvoiced frames 43 | stretch_silence: If true, applies time-stretching to silent frames 44 | return_grid: If true, also returns the time-stretch grid 45 | 46 | Returns 47 | edited_loudness, edited_pitch, edited_periodicity, edited_ppg 48 | """ 49 | # Maybe time-stretch 50 | if time_stretch_ratio is not None: 51 | 52 | # Create time-stretch grid 53 | if stretch_unvoiced and stretch_silence: 54 | grid = promonet.edit.grid.constant( 55 | ppg, 56 | time_stretch_ratio) 57 | else: 58 | 59 | # Get voiced phoneme indices 60 | indices = [ 61 | ppgs.PHONEME_TO_INDEX_MAPPING[phoneme] 62 | for phoneme in ppgs.VOICED] 63 | 64 | # Maybe add silence 65 | if stretch_silence: 66 | indices.append(ppgs.PHONEME_TO_INDEX_MAPPING[pypar.SILENCE]) 67 | 68 | # Maybe add unvoiced 69 | if stretch_unvoiced: 70 | indices.extend( 71 | list( 72 | set(ppgs.PHONEMES) - 73 | set(ppgs.VOICED) - 74 | set([pypar.SILENCE]) 75 | ) 76 | ) 77 | 78 | # Get selection probabilities 79 | selected = ppg[torch.tensor(indices)].sum(dim=0) 80 | 81 | # Get number of output frames 82 | target_frames = round(ppg.shape[-1] / time_stretch_ratio) 83 | 84 | # Adjust ratio based on selection probabilities 85 | total_selected = selected.sum() 86 | total_unselected = ppg.shape[-1] - total_selected 87 | effective_ratio = (target_frames - total_unselected) / total_selected 88 | 89 | # Create time-stretch grid 90 | grid = torch.zeros(round(target_frames)) 91 | i = 0. 92 | for j in range(1, target_frames): 93 | 94 | # Get time-varying interpolation weight 95 | left = math.floor(i) 96 | if left + 1 < len(selected): 97 | offset = i - left 98 | probability = ( 99 | offset * selected[left + 1] + 100 | (1 - offset) * selected[left]) 101 | else: 102 | probability = selected[left] 103 | 104 | # Get time-varying step size 105 | ratio = probability * effective_ratio + (1 - probability) 106 | step = 1. 
/ ratio 107 | 108 | # Take a step 109 | grid[j] = grid[j - 1] + step 110 | i += step 111 | 112 | # Time-stretch 113 | pitch = 2 ** promonet.edit.grid.sample(torch.log2(pitch), grid) 114 | periodicity = promonet.edit.grid.sample(periodicity, grid) 115 | loudness = promonet.edit.grid.sample(loudness, grid) 116 | ppg = promonet.edit.grid.sample(ppg, grid, promonet.PPG_INTERP_METHOD) 117 | elif return_grid: 118 | grid = None 119 | 120 | # Maybe pitch-shift 121 | if pitch_shift_cents is not None: 122 | pitch = pitch.clone() * promonet.convert.cents_to_ratio( 123 | pitch_shift_cents) 124 | pitch = torch.clip(pitch, promonet.FMIN, promonet.FMAX) 125 | 126 | # Maybe loudness-scale 127 | if loudness_scale_db is not None: 128 | loudness += loudness_scale_db 129 | 130 | if return_grid: 131 | return loudness, pitch, periodicity, ppg, grid 132 | return loudness, pitch, periodicity, ppg 133 | 134 | 135 | def from_file( 136 | loudness_file: Union[str, bytes, os.PathLike], 137 | pitch_file: Union[str, bytes, os.PathLike], 138 | periodicity_file: Union[str, bytes, os.PathLike], 139 | ppg_file: Union[str, bytes, os.PathLike], 140 | pitch_shift_cents: Optional[float] = None, 141 | time_stretch_ratio: Optional[float] = None, 142 | loudness_scale_db: Optional[float] = None, 143 | stretch_unvoiced: bool = True, 144 | stretch_silence: bool = True, 145 | return_grid: bool = False 146 | ) -> Union[ 147 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], 148 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] 149 | ]: 150 | """Edit speech representation on disk 151 | 152 | Arguments 153 | loudness_file: Loudness file to edit 154 | pitch_file: Pitch file to edit 155 | periodicity_file: Periodicity file to edit 156 | ppg_file: PPG file to edit 157 | pitch_shift_cents: Amount of pitch-shifting in cents 158 | time_stretch_ratio: Amount of time-stretching. Faster when above one. 
159 | loudness_scale_db: Loudness ratio editing in dB (not recommended; use loudness) 160 | stretch_unvoiced: If true, applies time-stretching to unvoiced frames 161 | stretch_silence: If true, applies time-stretching to silent frames 162 | return_grid: If true, also returns the time-stretch grid 163 | 164 | Returns 165 | edited_loudness, edited_pitch, edited_periodicity, edited_ppg 166 | """ 167 | pitch = torch.load(pitch_file) 168 | return from_features( 169 | torch.load(loudness_file), 170 | pitch, 171 | torch.load(periodicity_file), 172 | promonet.load.ppg(ppg_file, pitch.shape[-1]), 173 | pitch_shift_cents, 174 | time_stretch_ratio, 175 | loudness_scale_db, 176 | stretch_unvoiced, 177 | stretch_silence, 178 | return_grid) 179 | 180 | 181 | def from_file_to_file( 182 | loudness_file: Union[str, bytes, os.PathLike], 183 | pitch_file: Union[str, bytes, os.PathLike], 184 | periodicity_file: Union[str, bytes, os.PathLike], 185 | ppg_file: Union[str, bytes, os.PathLike], 186 | output_prefix: Union[str, bytes, os.PathLike], 187 | pitch_shift_cents: Optional[float] = None, 188 | time_stretch_ratio: Optional[float] = None, 189 | loudness_scale_db: Optional[float] = None, 190 | stretch_unvoiced: bool = True, 191 | stretch_silence: bool = True, 192 | save_grid: bool = False 193 | ) -> None: 194 | """Edit speech representation on disk and save to disk 195 | 196 | Arguments 197 | loudness_file: Loudness file to edit 198 | pitch_file: Pitch file to edit 199 | periodicity_file: Periodicity file to edit 200 | ppg_file: PPG file to edit 201 | output_prefix: File to save output, minus extension 202 | pitch_shift_cents: Amount of pitch-shifting in cents 203 | time_stretch_ratio: Amount of time-stretching. Faster when above one. 204 | loudness_scale_db: Loudness ratio editing in dB (not recommended; use loudness) 205 | stretch_unvoiced: If true, applies time-stretching to unvoiced frames 206 | stretch_silence: If true, applies time-stretching to silent frames 207 | save_grid: If true, also saves the time-stretch grid 208 | """ 209 | # Edit 210 | results = from_file( 211 | loudness_file, 212 | pitch_file, 213 | periodicity_file, 214 | ppg_file, 215 | pitch_shift_cents, 216 | time_stretch_ratio, 217 | loudness_scale_db, 218 | stretch_unvoiced, 219 | stretch_silence, 220 | save_grid) 221 | 222 | # Save 223 | viterbi = '-viterbi' if promonet.VITERBI_DECODE_PITCH else '' 224 | torch.save(results[0], f'{output_prefix}-loudness.pt') 225 | torch.save(results[1], f'{output_prefix}{viterbi}-pitch.pt') 226 | torch.save(results[2], f'{output_prefix}{viterbi}-periodicity.pt') 227 | torch.save(results[3], f'{output_prefix}{ppgs.representation_file_extension()}') 228 | if save_grid: 229 | torch.save(results[4], f'{output_prefix}-grid.pt') 230 | 231 | 232 | def from_files_to_files( 233 | loudness_files: List[Union[str, bytes, os.PathLike]], 234 | pitch_files: List[Union[str, bytes, os.PathLike]], 235 | periodicity_files: List[Union[str, bytes, os.PathLike]], 236 | ppg_files: List[Union[str, bytes, os.PathLike]], 237 | output_prefixes: List[Union[str, bytes, os.PathLike]], 238 | pitch_shift_cents: Optional[float] = None, 239 | time_stretch_ratio: Optional[float] = None, 240 | loudness_scale_db: Optional[float] = None, 241 | stretch_unvoiced: bool = True, 242 | stretch_silence: bool = True, 243 | save_grid: bool = False 244 | ) -> None: 245 | """Edit speech representations on disk and save to disk 246 | 247 | Arguments 248 | loudness_files: Loudness files to edit 249 | pitch_files: Pitch files to edit 250 | 
periodicity_files: Periodicity files to edit 251 | ppg_files: Phonetic posteriorgram files to edit 252 | output_prefixes: Files to save output, minus extension 253 | pitch_shift_cents: Amount of pitch-shifting in cents 254 | time_stretch_ratio: Amount of time-stretching. Faster when above one. 255 | loudness_scale_db: Loudness ratio editing in dB (not recommended; use loudness) 256 | stretch_unvoiced: If true, applies time-stretching to unvoiced frames 257 | stretch_silence: If true, applies time-stretching to silent frames 258 | save_grid: If true, also saves the time-stretch grid 259 | """ 260 | for loudness_file, pitch_file, periodicity_file, ppg_file, prefix in zip( 261 | loudness_files, 262 | pitch_files, 263 | periodicity_files, 264 | ppg_files, 265 | output_prefixes 266 | ): 267 | from_file_to_file( 268 | loudness_file, 269 | pitch_file, 270 | periodicity_file, 271 | ppg_file, 272 | prefix, 273 | pitch_shift_cents, 274 | time_stretch_ratio, 275 | loudness_scale_db, 276 | stretch_unvoiced, 277 | stretch_silence, 278 | save_grid) 279 | -------------------------------------------------------------------------------- /promonet/edit/grid.py: -------------------------------------------------------------------------------- 1 | import ppgs 2 | import torch 3 | 4 | import promonet 5 | 6 | 7 | ############################################################################### 8 | # Grid sampling 9 | ############################################################################### 10 | 11 | 12 | def sample(sequence, grid, method='linear'): 13 | """Perform 1D grid-based sampling""" 14 | # Linear grid interpolation 15 | if method == 'linear': 16 | x = grid 17 | fp = sequence 18 | 19 | # Input indices 20 | xp = torch.arange(fp.shape[-1], device=fp.device) 21 | 22 | # Output indices 23 | i = torch.searchsorted(xp, x, side='right') 24 | 25 | # Replicate final frame 26 | # "replication_pad1d_cpu" not implemented for 'Half' 27 | if fp.dtype == torch.float16: 28 | fp = torch.nn.functional.pad( 29 | fp.to(torch.float32), 30 | (0, 1), 31 | mode='replicate' 32 | ).to(torch.float16) 33 | else: 34 | fp = torch.nn.functional.pad(fp, (0, 1), mode='replicate') 35 | xp = torch.cat((xp, xp[-1:] + 1)) 36 | 37 | # Interpolate 38 | return fp[..., i - 1] * (xp[i] - x) + fp[..., i] * (x - xp[i - 1]) 39 | 40 | # Nearest neighbors grid interpolation 41 | elif method == 'nearest': 42 | return sequence[..., torch.round(grid).to(torch.long)] 43 | 44 | else: 45 | raise ValueError(f'Grid sampling method {method} is not defined') 46 | 47 | 48 | ############################################################################### 49 | # Interpolation grids 50 | ############################################################################### 51 | 52 | 53 | def constant(tensor, ratio): 54 | """Create a grid for constant-ratio time-stretching""" 55 | return ppgs.edit.grid.constant(tensor, ratio) 56 | 57 | 58 | def from_alignments(source, target): 59 | """Create time-stretch grid to convert source alignment to target""" 60 | return ppgs.edit.grid.from_alignments( 61 | source, 62 | target, 63 | sample_rate=promonet.SAMPLE_RATE, 64 | hopsize=promonet.HOPSIZE) 65 | 66 | 67 | def of_length(tensor, length): 68 | """Create time-stretch grid of a specified length""" 69 | return ppgs.edit.grid.of_length(tensor, length) 70 | -------------------------------------------------------------------------------- /promonet/evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | from 
.metrics import Metrics, SpectralBalance, spectral_centroid 3 | -------------------------------------------------------------------------------- /promonet/evaluate/__main__.py: -------------------------------------------------------------------------------- 1 | import yapecs 2 | from pathlib import Path 3 | 4 | import promonet 5 | 6 | 7 | ############################################################################### 8 | # Evaluate 9 | ############################################################################### 10 | 11 | 12 | def parse_args(): 13 | """Parse command-line arguments""" 14 | parser = yapecs.ArgumentParser(description='Perform evaluation') 15 | parser.add_argument( 16 | '--datasets', 17 | nargs='+', 18 | default=promonet.DATASETS, 19 | help='The datasets to evaluate') 20 | parser.add_argument( 21 | '--adapt', 22 | action='store_true', 23 | help='Whether to perform speaker adaptation') 24 | parser.add_argument( 25 | '--gpu', 26 | type=int, 27 | help='The index of the gpu to use for evaluation') 28 | return parser.parse_args() 29 | 30 | 31 | promonet.evaluate.datasets(**vars(parse_args())) 32 | -------------------------------------------------------------------------------- /promonet/load.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import ppgs 4 | import torch 5 | import torchaudio 6 | import torchutil 7 | 8 | import promonet 9 | 10 | 11 | ############################################################################### 12 | # Loading utilities 13 | ############################################################################### 14 | 15 | 16 | def audio(file): 17 | """Load audio from disk""" 18 | # Load 19 | audio, sample_rate = torchaudio.load(file) 20 | 21 | # Resample 22 | audio = torchaudio.functional.resample( 23 | audio, 24 | sample_rate, 25 | promonet.SAMPLE_RATE) 26 | 27 | # Ensure mono 28 | return audio.mean(dim=0, keepdims=True) 29 | 30 | 31 | def features(prefix): 32 | """Load input features from file prefix""" 33 | if promonet.VITERBI_DECODE_PITCH: 34 | pitch_prefix = f'{prefix}-viterbi' 35 | else: 36 | pitch_prefix = prefix 37 | return ( 38 | torch.load(f'{prefix}-loudness.pt'), 39 | torch.load(f'{pitch_prefix}-pitch.pt'), 40 | torch.load(f'{pitch_prefix}-periodicity.pt'), 41 | torch.load(f'{prefix}-ppg.pt')) 42 | 43 | 44 | def partition(dataset, adapt=promonet.ADAPTATION): 45 | """Load partitions for dataset""" 46 | partition_dir = ( 47 | promonet.ASSETS_DIR / 48 | 'partitions' / 49 | ('adaptation' if adapt else 'multispeaker')) 50 | with open(partition_dir / f'{dataset}.json') as file: 51 | return json.load(file) 52 | 53 | 54 | def pitch_distribution(dataset=promonet.TRAINING_DATASET, partition='train'): 55 | """Load pitch distribution""" 56 | if not hasattr(pitch_distribution, 'distribution'): 57 | 58 | # Location on disk 59 | key = '' 60 | if promonet.AUGMENT_LOUDNESS: 61 | key += '-loudness' 62 | if promonet.AUGMENT_PITCH: 63 | key += '-pitch' 64 | if promonet.VITERBI_DECODE_PITCH: 65 | key += '-viterbi' 66 | file = ( 67 | promonet.ASSETS_DIR / 68 | 'stats' / 69 | f'{dataset}-{promonet.PITCH_BINS}{key}.pt') 70 | 71 | if file.exists(): 72 | 73 | # Load and cache distribution 74 | pitch_distribution.distribution = torch.load(file) 75 | 76 | else: 77 | 78 | # Get all voiced pitch frames 79 | allpitch = [] 80 | dataset = promonet.data.Dataset(dataset, partition) 81 | viterbi = '-viterbi' if promonet.VITERBI_DECODE_PITCH else '' 82 | for stem in torchutil.iterator( 83 | dataset.stems, 84 | 
'promonet.load.pitch_distribution' 85 | ): 86 | pitch = torch.load( 87 | dataset.cache / f'{stem}{viterbi}-pitch.pt') 88 | periodicity = torch.load( 89 | dataset.cache / f'{stem}{viterbi}-periodicity.pt') 90 | allpitch.append( 91 | pitch[ 92 | torch.logical_and( 93 | ~torch.isnan(pitch), 94 | periodicity > promonet.VOICING_THRESHOLD)]) 95 | 96 | # Sort 97 | pitch, _ = torch.sort(torch.cat(allpitch)) 98 | 99 | # Bucket 100 | indices = torch.linspace( 101 | len(pitch) / promonet.PITCH_BINS, 102 | len(pitch) - 1, 103 | promonet.PITCH_BINS, 104 | dtype=torch.float64 105 | ).to(torch.long) 106 | pitch_distribution.distribution = pitch[indices] 107 | 108 | # Save 109 | torch.save(pitch_distribution.distribution, file) 110 | 111 | return pitch_distribution.distribution 112 | 113 | 114 | def per_speaker_averages(dataset=promonet.TRAINING_DATASET, partition='train'): 115 | """Load the average pitch in voiced regions for each speaker""" 116 | if not hasattr(per_speaker_averages, 'averages'): 117 | 118 | # Location on disk 119 | key = '' 120 | if promonet.VITERBI_DECODE_PITCH: 121 | key += '-viterbi' 122 | file = ( 123 | promonet.ASSETS_DIR / 124 | 'stats' / 125 | f'{dataset}-{partition}-speaker-averages{key}.json') 126 | 127 | try: 128 | 129 | # Load and cache averages 130 | with open(file) as json_file: 131 | per_speaker_averages.averages = json.load(json_file) 132 | 133 | except FileNotFoundError: 134 | 135 | # Get all voiced pitch frames 136 | allpitch = {} 137 | dataset = promonet.data.Dataset(dataset, partition) 138 | viterbi = '-viterbi' if promonet.VITERBI_DECODE_PITCH else '' 139 | for stem in torchutil.iterator( 140 | dataset.stems, 141 | 'promonet.load.pitch_distribution' 142 | ): 143 | pitch = torch.load( 144 | dataset.cache / f'{stem}{viterbi}-pitch.pt') 145 | periodicity = torch.load( 146 | dataset.cache / f'{stem}{viterbi}-periodicity.pt') 147 | speaker = stem.split('/')[0] 148 | if speaker not in allpitch: 149 | allpitch[speaker] = [] 150 | allpitch[speaker].append( 151 | pitch[ 152 | torch.logical_and( 153 | ~torch.isnan(pitch), 154 | periodicity > promonet.VOICING_THRESHOLD)]) 155 | 156 | # Cache 157 | per_speaker_averages.averages = { 158 | speaker: 2 ** torch.log2(torch.cat(values)).mean().item() 159 | for speaker, values in allpitch.items()} 160 | 161 | # Save 162 | with open(file, 'w') as json_file: 163 | json.dump( 164 | per_speaker_averages.averages, 165 | json_file, 166 | indent=4, 167 | sort_keys=True) 168 | 169 | return per_speaker_averages.averages 170 | 171 | 172 | def ppg(file, resample_length=None): 173 | """Load a PPG file and maybe resample""" 174 | # Load 175 | result = torch.load(file) 176 | 177 | # Maybe resample 178 | if resample_length is not None and result.shape[-1] != resample_length: 179 | result = promonet.edit.grid.sample( 180 | result, 181 | promonet.edit.grid.of_length(result, resample_length), 182 | promonet.PPG_INTERP_METHOD) 183 | 184 | # Preserve distribution 185 | if ppgs.REPRESENTATION_KIND == 'ppgs': 186 | return torch.softmax(torch.log(result + 1e-8), -2) 187 | 188 | return result 189 | 190 | 191 | def text(file): 192 | """Load text file""" 193 | with open(file, encoding='utf-8') as file: 194 | return file.read() 195 | -------------------------------------------------------------------------------- /promonet/model/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import export 2 | from .core import * 3 | from .discriminator import Discriminator 4 | from .fargan import FARGAN 5 | from .generator import Generator, MelGenerator 6 | from .hifigan import HiFiGAN 7 | from .vocos import Vocos 8 | -------------------------------------------------------------------------------- /promonet/model/cargan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import promonet 4 | 5 | 6 | ############################################################################### 7 | # Chunked autoregressive GAN 8 | ############################################################################### 9 | 10 | 11 | class CARGAN(torch.nn.Module): 12 | 13 | def __init__(self, initial_channel, gin_channels): 14 | super().__init__() 15 | self.model = promonet.model.HiFiGAN( 16 | initial_channel + promonet.CARGAN_OUTPUT_SIZE, 17 | gin_channels) 18 | self.ar = Autoregressive() 19 | 20 | # Inference buffer 21 | self.buffer = torch.zeros((1, 1, promonet.CARGAN_INPUT_SIZE)) 22 | 23 | def forward(self, x, g=None, ar=None): 24 | if not self.training and ar is None: 25 | ar = self.buffer 26 | ar = self.ar(ar) 27 | ar = ar.unsqueeze(2).repeat(1, 1, x.shape[2]) 28 | y = self.model(torch.cat((x, ar), dim=1)) 29 | if not self.training: 30 | self.buffer = y[..., -promonet.CARGAN_INPUT_SIZE:] 31 | return y 32 | 33 | 34 | class Autoregressive(torch.nn.Module): 35 | 36 | def __init__(self): 37 | super().__init__() 38 | model = [ 39 | torch.nn.Linear( 40 | promonet.CARGAN_INPUT_SIZE, 41 | promonet.CARGAN_HIDDEN_SIZE), 42 | torch.nn.LeakyReLU(.1)] 43 | for _ in range(3): 44 | model.extend([ 45 | torch.nn.Linear( 46 | promonet.CARGAN_HIDDEN_SIZE, 47 | promonet.CARGAN_HIDDEN_SIZE), 48 | torch.nn.LeakyReLU(.1)]) 49 | model.append( 50 | torch.nn.Linear( 51 | promonet.CARGAN_HIDDEN_SIZE, 52 | promonet.CARGAN_OUTPUT_SIZE)) 53 | self.model = torch.nn.Sequential(*model) 54 | 55 | def forward(self, x): 56 | return self.model(x.squeeze(1)) 57 | -------------------------------------------------------------------------------- /promonet/model/core.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | ############################################################################### 5 | # Shared model utilities 6 | ############################################################################### 7 | 8 | 9 | def get_padding(kernel_size, dilation=1, stride=1): 10 | """Compute the padding needed to perform same-size convolution""" 11 | return int((kernel_size * dilation - dilation - stride + 1) / 2) 12 | 13 | 14 | def random_slice_segments(segments, lengths, segment_size): 15 | """Randomly slice segments along last dimension""" 16 | max_start_indices = lengths - segment_size + 1 17 | start_indices = torch.rand((len(segments),), device=segments.device) 18 | start_indices = (start_indices * max_start_indices).to(dtype=torch.long) 19 | segments = slice_segments(segments, start_indices, segment_size) 20 | return segments, start_indices 21 | 22 | 23 | def slice_segments(segments, start_indices, segment_size, fill_value=0.): 24 | """Slice segments along last dimension""" 25 | slices = torch.full_like(segments[..., :segment_size], fill_value) 26 | iterator = enumerate(zip(segments, start_indices)) 27 | for i, (segment, start_index) in iterator: 28 | end_index = start_index + segment_size 29 | 30 | # Pad negative indices 31 | if start_index <= -segment_size: 32 | continue 33 | elif start_index < 0: 34 | start_index = 0 35
| 36 | # Slice 37 | slices[i, ..., -(end_index - start_index):] = \ 38 | segment[..., start_index:end_index] 39 | 40 | return slices 41 | 42 | 43 | def weight_norm_conv1d(*args, **kwargs): 44 | """Construct Conv1d layer with weight normalization""" 45 | return torch.nn.utils.weight_norm(torch.nn.Conv1d(*args, **kwargs)) 46 | -------------------------------------------------------------------------------- /promonet/model/export/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | -------------------------------------------------------------------------------- /promonet/model/export/__main__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import yapecs 4 | 5 | import promonet 6 | 7 | 8 | ############################################################################### 9 | # Model export CLI 10 | ############################################################################### 11 | 12 | 13 | def parse_args(): 14 | """Parse command-line arguments""" 15 | parser = yapecs.ArgumentParser(description='Export torchscript model') 16 | parser.add_argument( 17 | '--checkpoint', 18 | type=Path, 19 | help='The generator checkpoint') 20 | parser.add_argument( 21 | '--output_file', 22 | type=Path, 23 | default='promonet-export.ts', 24 | help='The torch file to write the exported model') 25 | 26 | return parser.parse_args() 27 | 28 | 29 | if __name__ == '__main__': 30 | promonet.model.export.from_file_to_file(**vars(parse_args())) 31 | -------------------------------------------------------------------------------- /promonet/model/export/core.py: -------------------------------------------------------------------------------- 1 | import torchutil 2 | 3 | import promonet 4 | 5 | 6 | ############################################################################### 7 | # Model exporting 8 | ############################################################################### 9 | 10 | 11 | def from_file_to_file(checkpoint=None, output_file='promonet-export.ts'): 12 | """Load model from checkpoint and export to torchscript""" 13 | # Load model 14 | model = promonet.model.Generator() 15 | if checkpoint is not None: 16 | model, *_ = torchutil.checkpoint.load(checkpoint, model) 17 | 18 | # Switch to evaluation mode 19 | with torchutil.inference.context(model): 20 | 21 | # Export 22 | model.export(output_file) 23 | -------------------------------------------------------------------------------- /promonet/model/hifigan.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | 5 | import promonet 6 | 7 | 8 | ############################################################################### 9 | # HiFi-GAN vocoder 10 | ############################################################################### 11 | 12 | 13 | class HiFiGAN(torch.nn.Module): 14 | 15 | def __init__(self, initial_channel, gin_channels): 16 | super().__init__() 17 | 18 | # Input layer 19 | self.input_feature_conv = torch.nn.Conv1d( 20 | initial_channel, 21 | promonet.HIFIGAN_UPSAMPLE_INITIAL_SIZE, 22 | 7, 23 | 1, 24 | padding=3) 25 | 26 | # Speaker conditioning 27 | self.input_speaker_conv = torch.nn.Conv1d( 28 | gin_channels, 29 | promonet.HIFIGAN_UPSAMPLE_INITIAL_SIZE, 30 | 1) 31 | 32 | # Rest of the model 33 | output_channels = ( 34 | promonet.HIFIGAN_UPSAMPLE_INITIAL_SIZE // 35 | (2 ** len(promonet.HIFIGAN_UPSAMPLE_RATES))) 36 | self.model = 
torch.nn.Sequential( 37 | 38 | # MRF blocks 39 | *[ 40 | MultiReceptiveFieldFusion( 41 | promonet.HIFIGAN_UPSAMPLE_INITIAL_SIZE // (2 ** i), 42 | promonet.HIFIGAN_UPSAMPLE_INITIAL_SIZE // (2 ** (i + 1)), 43 | upsample_kernel_size, 44 | upsample_rate 45 | ) 46 | for i, ( 47 | upsample_kernel_size, 48 | upsample_rate 49 | ) in enumerate(zip( 50 | promonet.HIFIGAN_UPSAMPLE_KERNEL_SIZES, 51 | promonet.HIFIGAN_UPSAMPLE_RATES 52 | )) 53 | ], 54 | 55 | # Last layer 56 | torch.nn.LeakyReLU(promonet.LRELU_SLOPE), 57 | torch.nn.Conv1d(output_channels, 1, 7, 1, 3, bias=False), 58 | 59 | # Output activation 60 | torch.nn.Tanh() 61 | ) 62 | 63 | def forward(self, x, g, p): 64 | # Input layer 65 | x = self.input_feature_conv(x) 66 | 67 | # Speaker conditioning 68 | x = x + self.input_speaker_conv(g) 69 | 70 | return self.model(x) 71 | 72 | def remove_weight_norm(self): 73 | """Remove weight norm for scriptable inference""" 74 | for layer in self.model: 75 | if isinstance(layer, MultiReceptiveFieldFusion): 76 | layer.remove_weight_norm() 77 | 78 | 79 | ############################################################################### 80 | # HiFi-GAN outermost block 81 | ############################################################################### 82 | 83 | 84 | class MultiReceptiveFieldFusion(torch.nn.Module): 85 | 86 | def __init__( 87 | self, 88 | input_channels, 89 | output_channels, 90 | upsample_kernel_size, 91 | upsample_rate 92 | ): 93 | super().__init__() 94 | self.model = torch.nn.Sequential( 95 | 96 | # Input activation 97 | torch.nn.LeakyReLU(promonet.LRELU_SLOPE), 98 | 99 | # Upsampling layer 100 | torch.nn.utils.weight_norm( 101 | torch.nn.ConvTranspose1d( 102 | input_channels, 103 | output_channels, 104 | upsample_kernel_size, 105 | upsample_rate, 106 | padding=(upsample_kernel_size - upsample_rate) // 2)), 107 | 108 | # Residual block 109 | ResidualBlock(output_channels)) 110 | 111 | # Weight initialization 112 | self.model[1].apply(init_weights) 113 | 114 | def forward(self, x): 115 | return self.model(x) 116 | 117 | def remove_weight_norm(self): 118 | """Remove weight norm for scriptable inference""" 119 | torch.nn.utils.remove_weight_norm(self.model[1]) 120 | self.model[2].remove_weight_norm() 121 | 122 | 123 | ############################################################################### 124 | # HiFi-GAN residual block 125 | ############################################################################### 126 | 127 | 128 | class ResidualBlock(torch.nn.Module): 129 | 130 | def __init__(self, channels): 131 | super().__init__() 132 | self.num_kernels = len(promonet.HIFIGAN_RESBLOCK_KERNEL_SIZES) 133 | self.model = torch.nn.ModuleList([ 134 | Block(channels, kernel_size, dilation_rate) 135 | for kernel_size, dilation_rate in zip( 136 | promonet.HIFIGAN_RESBLOCK_KERNEL_SIZES, 137 | promonet.HIFIGAN_RESBLOCK_DILATION_SIZES 138 | ) 139 | ]) 140 | 141 | def forward(self, x): 142 | xs = None 143 | for layer in self.model: 144 | xs = layer(x) if xs is None else xs + layer(x) 145 | return xs / self.num_kernels 146 | 147 | def remove_weight_norm(self): 148 | for layer in self.model: 149 | layer.remove_weight_norm() 150 | 151 | 152 | ############################################################################### 153 | # HiFi-GAN inner block 154 | ############################################################################### 155 | 156 | 157 | class Block(torch.nn.Module): 158 | 159 | def __init__( 160 | self, 161 | channels, 162 | kernel_size=3, 163 | dilation=(1, 3, 5)): 164 | super().__init__() 
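# Added commentary: each Block below pairs a dilated convolution (convs1,
# dilations 1, 3, and 5 by default) with an undilated one (convs2), each
# preceded by a leaky ReLU, and wraps every pair in a residual connection.
# A worked example of the same-size padding arithmetic, using get_padding
# from promonet/model/core.py (kernel size 3, stride 1):
#
#     get_padding(3, dilation=1)  # int((3 * 1 - 1 - 1 + 1) / 2) = 1
#     get_padding(3, dilation=3)  # int((3 * 3 - 3 - 1 + 1) / 2) = 3
#     get_padding(3, dilation=5)  # int((3 * 5 - 5 - 1 + 1) / 2) = 5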
165 | 166 | # Convolutions 167 | conv_fn = functools.partial( 168 | promonet.model.weight_norm_conv1d, 169 | channels, 170 | channels, 171 | kernel_size, 172 | 1) 173 | pad_fn = functools.partial(promonet.model.get_padding, kernel_size) 174 | self.convs1 = torch.nn.ModuleList([ 175 | conv_fn(pad_fn(dilation[0]), dilation[0]), 176 | conv_fn(pad_fn(dilation[1]), dilation[1]), 177 | conv_fn(pad_fn(dilation[2]), dilation[2])]) 178 | self.convs1.apply(init_weights) 179 | self.convs2 = torch.nn.ModuleList([ 180 | conv_fn(pad_fn()), 181 | conv_fn(pad_fn()), 182 | conv_fn(pad_fn())]) 183 | self.convs2.apply(init_weights) 184 | 185 | # Activations 186 | activation_fn = functools.partial( 187 | torch.nn.LeakyReLU, 188 | negative_slope=promonet.LRELU_SLOPE) 189 | self.activations1 = torch.nn.ModuleList([ 190 | activation_fn(), 191 | activation_fn(), 192 | activation_fn()]) 193 | self.activations2 = torch.nn.ModuleList([ 194 | activation_fn(), 195 | activation_fn(), 196 | activation_fn()]) 197 | 198 | def forward(self, x): 199 | iterator = zip( 200 | self.convs1, 201 | self.convs2, 202 | self.activations1, 203 | self.activations2) 204 | for c1, c2, a1, a2 in iterator: 205 | xt = a1(x) 206 | xt = c1(xt) 207 | xt = a2(xt) 208 | xt = c2(xt) 209 | x = xt + x 210 | return x 211 | 212 | def remove_weight_norm(self): 213 | """Remove weight norm for scriptable inference""" 214 | for layer in self.convs1: 215 | torch.nn.utils.remove_weight_norm(layer) 216 | for layer in self.convs2: 217 | torch.nn.utils.remove_weight_norm(layer) 218 | 219 | 220 | def init_weights(m, mean=0.0, std=0.01): 221 | classname = m.__class__.__name__ 222 | if classname.find("Conv") != -1: 223 | m.weight.data.normal_(mean, std) 224 | -------------------------------------------------------------------------------- /promonet/model/vocos.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import promonet 4 | 5 | 6 | ############################################################################### 7 | # Vocos vocoder 8 | ############################################################################### 9 | 10 | 11 | class Vocos(torch.nn.Module): 12 | 13 | def __init__(self, initial_channel, gin_channels): 14 | super().__init__() 15 | 16 | # Input feature projection 17 | self.conv_pre = torch.nn.Conv1d( 18 | initial_channel, 19 | promonet.VOCOS_CHANNELS, 20 | 7, 21 | 1, 22 | padding='same') 23 | 24 | # Model architecture 25 | self.backbone = VocosBackbone( 26 | promonet.VOCOS_CHANNELS, 27 | promonet.VOCOS_CHANNELS, 28 | promonet.VOCOS_LAYERS) 29 | 30 | # Differentiable iSTFT 31 | self.head = ISTFTHead( 32 | promonet.VOCOS_CHANNELS, 33 | promonet.NUM_FFT, 34 | promonet.HOPSIZE) 35 | 36 | # Speaker conditioning 37 | self.cond = torch.nn.Conv1d( 38 | gin_channels, 39 | promonet.VOCOS_CHANNELS, 40 | 1) 41 | 42 | def forward(self, x, g=None): 43 | # Initial conv 44 | x = self.conv_pre(x) 45 | 46 | # Speaker conditioning 47 | if g is not None: 48 | g = self.cond(g) 49 | x += g 50 | 51 | # Infer complex STFT 52 | x = self.backbone(x, g) 53 | 54 | # Perform iSTFT to get waveform 55 | return self.head(x) 56 | 57 | 58 | ############################################################################### 59 | # Vocos architecture 60 | ############################################################################### 61 | 62 | 63 | class VocosBackbone(torch.nn.Module): 64 | 65 | def __init__( 66 | self, 67 | input_channels: int, 68 | dim: int, 69 | num_layers: int, 70 | ): 71 | super().__init__() 72 | 
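# Added commentary: the backbone embeds its input with a width-7
# convolution, applies LayerNorm (which expects channels last, hence the
# transpose pairs in forward()), runs num_layers ConvNeXt blocks, and
# normalizes once more. A shape sketch under hypothetical sizes (512
# channels, 100 frames; the real values come from promonet.VOCOS_CHANNELS
# and the input length):
#
#     x = torch.zeros(1, 512, 100)        # (batch, channels, frames)
#     x = self.embed(x)                   # (1, 512, 100); kernel 7, pad 3
#     x = self.norm(x.transpose(1, 2))    # (1, 100, 512); channels last
#     x = x.transpose(1, 2)               # (1, 512, 100)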
self.input_channels = input_channels 73 | self.embed = torch.nn.Conv1d( 74 | input_channels, 75 | dim, 76 | kernel_size=7, 77 | padding=3) 78 | self.norm = torch.nn.LayerNorm(dim, eps=1e-6) 79 | self.convnext = torch.nn.ModuleList( 80 | [ 81 | ConvNeXtBlock( 82 | dim=dim, 83 | layer_scale_init_value=1 / num_layers) 84 | for _ in range(num_layers) 85 | ] 86 | ) 87 | self.final_layer_norm = torch.nn.LayerNorm(dim, eps=1e-6) 88 | self.apply(self._init_weights) 89 | 90 | def _init_weights(self, m): 91 | if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)): 92 | torch.nn.init.trunc_normal_(m.weight, std=0.02) 93 | torch.nn.init.constant_(m.bias, 0) 94 | 95 | def forward(self, x, g): 96 | x = self.embed(x) 97 | x = self.norm(x.transpose(1, 2)) 98 | x = x.transpose(1, 2) 99 | for conv_block in self.convnext: 100 | x = conv_block(x, g) 101 | x = self.final_layer_norm(x.transpose(1, 2)) 102 | return x.transpose(1, 2) 103 | 104 | 105 | ############################################################################### 106 | # ConvNeXt block 107 | ############################################################################### 108 | 109 | 110 | class ConvNeXtBlock(torch.nn.Module): 111 | 112 | def __init__(self, dim, layer_scale_init_value): 113 | super().__init__() 114 | self.dwconv = torch.nn.Conv1d( 115 | dim, 116 | dim, 117 | kernel_size=7, 118 | padding=3, 119 | groups=dim) 120 | self.norm = torch.nn.LayerNorm(dim, eps=1e-6) 121 | self.pwconv1 = torch.nn.Linear(dim, promonet.VOCOS_POINTWISE_CHANNELS) 122 | self.act = torch.nn.GELU() 123 | self.pwconv2 = torch.nn.Linear(promonet.VOCOS_POINTWISE_CHANNELS, dim) 124 | self.gamma = ( 125 | torch.nn.Parameter( 126 | layer_scale_init_value * torch.ones(dim), 127 | requires_grad=True)) 128 | 129 | def forward(self, x, g): 130 | residual = x 131 | x = self.dwconv(x) 132 | x = x.transpose(1, 2) 133 | x = self.norm(x) 134 | x = self.pwconv1(x) 135 | x = self.act(x) 136 | x = self.pwconv2(x) 137 | if self.gamma is not None: 138 | x = self.gamma * x 139 | x = x.transpose(1, 2) 140 | return residual + x 141 | 142 | 143 | ############################################################################### 144 | # Vocos ISTFT head 145 | ############################################################################### 146 | 147 | 148 | class ISTFTHead(torch.nn.Module): 149 | 150 | def __init__(self, dim, n_fft, hop_length): 151 | super().__init__() 152 | self.out = torch.nn.Linear(dim, n_fft + 2) 153 | self.istft = ISTFT( 154 | n_fft=n_fft, 155 | hop_length=hop_length, 156 | win_length=n_fft) 157 | 158 | def forward(self, x): 159 | x = self.out(x.transpose(1, 2)).transpose(1, 2) 160 | mag, p = x.chunk(2, dim=1) 161 | mag = torch.exp(mag) 162 | mag = torch.clip(mag, max=1e2) 163 | x = torch.cos(p) 164 | y = torch.sin(p) 165 | S = mag * (x + 1j * y) 166 | return self.istft(S).unsqueeze(1) 167 | 168 | 169 | class ISTFT(torch.nn.Module): 170 | 171 | def __init__(self, n_fft, hop_length, win_length): 172 | super().__init__() 173 | self.n_fft = n_fft 174 | self.hop_length = hop_length 175 | self.win_length = win_length 176 | window = torch.hann_window(win_length) 177 | self.register_buffer('window', window) 178 | 179 | def forward(self, spec: torch.Tensor) -> torch.Tensor: 180 | B, N, T = spec.shape 181 | pad = (self.win_length - self.hop_length) // 2 182 | 183 | # Inverse FFT 184 | ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm='backward') 185 | ifft = ifft * self.window[None, :, None] 186 | 187 | # Overlap and Add 188 | output_size = (T - 1) * self.hop_length + 
self.win_length 189 | y = torch.nn.functional.fold( 190 | ifft, 191 | output_size=(1, output_size), 192 | kernel_size=(1, self.win_length), 193 | stride=(1, self.hop_length), 194 | )[:, 0, 0, pad:-pad] 195 | 196 | # Window envelope 197 | window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) 198 | window_envelope = torch.nn.functional.fold( 199 | window_sq, 200 | output_size=(1, output_size), 201 | kernel_size=(1, self.win_length), 202 | stride=(1, self.hop_length), 203 | ).squeeze()[pad:-pad] 204 | 205 | # Normalize 206 | return y / window_envelope 207 | -------------------------------------------------------------------------------- /promonet/partition/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | -------------------------------------------------------------------------------- /promonet/partition/__main__.py: -------------------------------------------------------------------------------- 1 | import yapecs 2 | 3 | import promonet 4 | 5 | 6 | ############################################################################### 7 | # Partition datasets 8 | ############################################################################### 9 | 10 | 11 | def parse_args(): 12 | """Parse command-line arguments""" 13 | parser = yapecs.ArgumentParser(description='Partition datasets') 14 | parser.add_argument( 15 | '--datasets', 16 | default=promonet.DATASETS, 17 | nargs='+', 18 | help='The datasets to partition') 19 | return parser.parse_args() 20 | 21 | 22 | promonet.partition.datasets(**vars(parse_args())) 23 | -------------------------------------------------------------------------------- /promonet/partition/core.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data partitions 3 | 4 | DAPS 5 | ==== 6 | * train_adapt_{:02d} - Training dataset for speaker adaptation (10 speakers) 7 | * test_adapt_{:02d} - Test dataset for speaker adaptation 8 | (10 speakers; 10 examples per speaker; 4-10 seconds) 9 | 10 | LibriTTS 11 | ======== 12 | * train - Training data 13 | * valid - Validation set of seen speakers for debugging and tensorboard 14 | (64 examples) 15 | * train_adapt_{:02d} - Training dataset for speaker adaptation (10 speakers) 16 | * test_adapt_{:02d} - Test dataset for speaker adaptation 17 | (10 speakers; 10 examples per speaker; 4-10 seconds) 18 | 19 | VCTK 20 | ==== 21 | * train - Training data 22 | * valid - Validation set of seen speakers for debugging and tensorboard 23 | (64 examples; 4-10 seconds) 24 | * train_adapt_{:02d} - Training dataset for speaker adaptation (10 speakers) 25 | * test_adapt_{:02d} - Test dataset for speaker adaptation 26 | (10 speakers; 10 examples per speaker; 4-10 seconds) 27 | """ 28 | import functools 29 | import itertools 30 | import json 31 | import random 32 | 33 | import torchaudio 34 | 35 | import promonet 36 | 37 | 38 | ############################################################################### 39 | # Constants 40 | ############################################################################### 41 | 42 | 43 | # Range of allowable test sample lengths in seconds 44 | MAX_TEST_SAMPLE_LENGTH = 10. 45 | MIN_TEST_SAMPLE_LENGTH = 4. 
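# Added commentary: these bounds are enforced by meets_length_criteria()
# at the bottom of this file. A quick arithmetic sketch (the 22.05 kHz
# sample rate is an assumption for illustration; the check uses each
# file's own rate):
#
#     info = torchaudio.info('example.wav')          # hypothetical file
#     duration = info.num_frames / info.sample_rate  # 132300 / 22050 = 6.0
#     MIN_TEST_SAMPLE_LENGTH <= duration <= MAX_TEST_SAMPLE_LENGTH  # True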
46 | 47 | 48 | ############################################################################### 49 | # Adaptation speaker IDs 50 | ############################################################################### 51 | 52 | 53 | # We manually select test speakers to ensure gender balance 54 | DAPS_ADAPTATION_SPEAKERS = [ 55 | # Female 56 | '0002', 57 | '0007', 58 | '0010', 59 | '0013', 60 | '0019', 61 | 62 | # Male 63 | '0003', 64 | '0005', 65 | '0014', 66 | '0015', 67 | '0017'] 68 | 69 | # Speakers selected by sorting the train-clean-100 speakers by longest total 70 | # recording duration and manually selecting speakers with more natural, 71 | # conversational (as opposed to read) prosody 72 | LIBRITTS_ADAPTATION_SPEAKERS = [ 73 | # Female 74 | '40', 75 | '669', 76 | '4362', 77 | '5022', 78 | '8123', 79 | 80 | # Male 81 | '196', 82 | '460', 83 | '1355', 84 | '3664', 85 | '7067'] 86 | 87 | # Gender-balanced VCTK speakers 88 | VCTK_ADAPTATION_SPEAKERS = [ 89 | # Female 90 | '0013', 91 | '0037', 92 | '0070', 93 | '0082', 94 | '0108', 95 | 96 | # Male 97 | '0016', 98 | '0032', 99 | '0047', 100 | '0073', 101 | '0083'] 102 | 103 | 104 | ############################################################################### 105 | # Partition 106 | ############################################################################### 107 | 108 | 109 | def adaptation(name): 110 | """Partition dataset for speaker adaptation""" 111 | directory = promonet.CACHE_DIR / name 112 | train = [ 113 | f'{file.parent.name}/{file.stem}' 114 | for file in directory.rglob('*.wav')] 115 | return {'train': train, 'valid': []} 116 | 117 | 118 | def datasets(datasets): 119 | """Partition datasets and save to disk""" 120 | for name in datasets: 121 | 122 | # Remove cached training statistics that may become stale 123 | for stats_file in (promonet.ASSETS_DIR / 'stats').glob('*.pt'): 124 | stats_file.unlink() 125 | 126 | # Partition 127 | if name == 'vctk': 128 | partition = vctk() 129 | elif name == 'daps': 130 | partition = daps() 131 | elif name == 'libritts': 132 | partition = libritts() 133 | 134 | # All other datasets are assumed to be for speaker adaptation 135 | else: 136 | partition = adaptation(name) 137 | 138 | # Sort partitions 139 | partition = {key: sorted(value) for key, value in partition.items()} 140 | 141 | # Save to disk 142 | file = promonet.PARTITION_DIR / f'{name}.json' 143 | file.parent.mkdir(exist_ok=True, parents=True) 144 | with open(file, 'w') as file: 145 | json.dump(partition, file, indent=4) 146 | 147 | 148 | def daps(): 149 | """Partition the DAPS dataset""" 150 | # Get stems 151 | directory = promonet.CACHE_DIR / 'daps' 152 | stems = [ 153 | f'{file.parent.name}/{file.stem[:6]}' 154 | for file in directory.rglob('*.txt')] 155 | 156 | # Create speaker adaptation partitions 157 | return adaptation_partitions( 158 | directory, 159 | stems, 160 | DAPS_ADAPTATION_SPEAKERS) 161 | 162 | 163 | def libritts(): 164 | """Partition libritts dataset""" 165 | # Get list of speakers 166 | directory = promonet.CACHE_DIR / 'libritts' 167 | stems = { 168 | f'{file.parent.name}/{file.stem[:6]}' 169 | for file in directory.rglob('*.txt')} 170 | 171 | # Get speaker map 172 | with open(directory / 'speakers.json') as file: 173 | speaker_map = json.load(file) 174 | 175 | # Get adaptation speakers 176 | speakers = [ 177 | f'{speaker_map[speaker][0]:04d}' 178 | for speaker in LIBRITTS_ADAPTATION_SPEAKERS] 179 | 180 | # Create speaker adaptation partitions 181 | adapt_partitions = adaptation_partitions( 182 | directory, 183 | 
stems, 184 | speakers) 185 | 186 | # Get test partition indices 187 | test_stems = list( 188 | itertools.chain.from_iterable(adapt_partitions.values())) 189 | 190 | # Get residual indices 191 | residual = [stem for stem in stems if stem not in test_stems] 192 | random.shuffle(residual) 193 | 194 | # Get validation stems 195 | filter_fn = functools.partial(meets_length_criteria, directory) 196 | valid_stems = list(filter(filter_fn, residual))[:64] 197 | 198 | # Get training stems 199 | train_stems = [stem for stem in residual if stem not in valid_stems] 200 | 201 | # Merge training and adaptation partitions 202 | partition = {'train': sorted(train_stems), 'valid': sorted(valid_stems)} 203 | return {**partition, **adapt_partitions} 204 | 205 | 206 | def vctk(): 207 | """Partition the vctk dataset""" 208 | # Get list of speakers 209 | directory = promonet.CACHE_DIR / 'vctk' 210 | stems = { 211 | f'{file.parent.name}/{file.stem[:6]}' 212 | for file in directory.rglob('*.txt')} 213 | 214 | # Get file stem correspondence 215 | with open(directory / 'correspondence.json') as file: 216 | correspondence = json.load(file) 217 | 218 | # Create speaker adaptation partitions 219 | if promonet.ADAPTATION: 220 | adapt_partitions = adaptation_partitions( 221 | directory, 222 | stems, 223 | VCTK_ADAPTATION_SPEAKERS) 224 | 225 | # Get test partition indices 226 | test_stems = list( 227 | itertools.chain.from_iterable(adapt_partitions.values())) 228 | test_correspondence = [correspondence[stem][:-1] for stem in test_stems] 229 | 230 | # Get residual indices 231 | residual = [ 232 | stem for stem in stems 233 | if stem not in test_stems and 234 | correspondence[stem][:-1] not in test_correspondence] 235 | random.shuffle(residual) 236 | 237 | # Get validation stems 238 | filter_fn = functools.partial(meets_length_criteria, directory) 239 | valid_stems = list(filter(filter_fn, residual))[:64] 240 | 241 | # Get training stems 242 | train_stems = [stem for stem in residual if stem not in valid_stems] 243 | 244 | # Merge training and adaptation partitions 245 | partition = {'train': train_stems, 'valid': valid_stems} 246 | return {**partition, **adapt_partitions} 247 | else: 248 | test_speaker_stems = { 249 | speaker: [stem for stem in stems if stem.split('/')[0] == speaker] 250 | for speaker in VCTK_ADAPTATION_SPEAKERS} 251 | filter_fn = functools.partial(meets_length_criteria, directory) 252 | test_stems = [] 253 | for speaker, speaker_stems in test_speaker_stems.items(): 254 | random.shuffle(speaker_stems) 255 | test_stems += list(filter(filter_fn, speaker_stems))[:10] 256 | test_correspondence = [correspondence[stem][:-1] for stem in test_stems] 257 | 258 | residual = [ 259 | stem for stem in stems 260 | if stem not in test_stems and 261 | correspondence[stem][:-1] not in test_correspondence] 262 | random.shuffle(residual) 263 | 264 | # Get validation stems 265 | filter_fn = functools.partial(meets_length_criteria, directory) 266 | valid_stems = list(filter(filter_fn, residual))[:64] 267 | 268 | # Get training stems 269 | train_stems = [stem for stem in residual if stem not in valid_stems] 270 | 271 | return {'train': train_stems, 'valid': valid_stems, 'test': test_stems} 272 | 273 | 274 | ############################################################################### 275 | # Utilities 276 | ############################################################################### 277 | 278 | 279 | def adaptation_partitions(directory, stems, speakers): 280 | """Create the speaker adaptation partitions""" 281 | # Get 
adaptation data 282 | adaptation_stems = { 283 | speaker: [stem for stem in stems if stem.split('/')[0] == speaker] 284 | for speaker in speakers} 285 | 286 | # Get length filter 287 | filter_fn = functools.partial(meets_length_criteria, directory) 288 | 289 | # Partition adaptation data 290 | adaptation_partition = {} 291 | random.seed(promonet.RANDOM_SEED) 292 | for i, speaker in enumerate(speakers): 293 | random.shuffle(adaptation_stems[speaker]) 294 | 295 | # Partition speaker data 296 | test_adapt_stems = list( 297 | filter(filter_fn, adaptation_stems[speaker]))[:10] 298 | train_adapt_stems = [ 299 | stem for stem in adaptation_stems[speaker] 300 | if stem not in test_adapt_stems] 301 | 302 | # Save partition 303 | adaptation_partition[f'train-adapt-{i:02d}'] = train_adapt_stems 304 | adaptation_partition[f'test-adapt-{i:02d}'] = test_adapt_stems 305 | 306 | return adaptation_partition 307 | 308 | 309 | def meets_length_criteria(directory, stem): 310 | """Returns True if the audio file duration is within the length criteria""" 311 | info = torchaudio.info(directory / f'{stem}.wav') 312 | duration = info.num_frames / info.sample_rate 313 | return MIN_TEST_SAMPLE_LENGTH <= duration <= MAX_TEST_SAMPLE_LENGTH 314 | -------------------------------------------------------------------------------- /promonet/plot/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | from . import speaker 3 | -------------------------------------------------------------------------------- /promonet/plot/__main__.py: -------------------------------------------------------------------------------- 1 | import yapecs 2 | from pathlib import Path 3 | 4 | import promonet 5 | 6 | 7 | ############################################################################### 8 | # Plot speech representation 9 | ############################################################################### 10 | 11 | 12 | def parse_args(): 13 | """Parse command-line arguments""" 14 | parser = yapecs.ArgumentParser(description='Plot speech representation') 15 | parser.add_argument( 16 | '--audio_file', 17 | type=Path, 18 | required=True, 19 | help='The speech audio') 20 | parser.add_argument( 21 | '--output_file', 22 | type=Path, 23 | required=True, 24 | help='The file to save the output figure') 25 | parser.add_argument( 26 | '--target_file', 27 | type=Path, 28 | help='Optional corresponding ground truth to compare to') 29 | parser.add_argument( 30 | '--features', 31 | nargs='+', 32 | choices=promonet.DEFAULT_PLOT_FEATURES, 33 | default=promonet.DEFAULT_PLOT_FEATURES, 34 | help='The features to plot' ) 35 | parser.add_argument( 36 | '--gpu', 37 | type=int, 38 | help='The GPU index') 39 | return parser.parse_args() 40 | 41 | 42 | promonet.plot.from_file_to_file(**vars(parse_args())) 43 | -------------------------------------------------------------------------------- /promonet/plot/speaker/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | -------------------------------------------------------------------------------- /promonet/plot/speaker/core.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from umap import UMAP 6 | 7 | 8 | ############################################################################### 9 | # Constants 10 | 
############################################################################### 11 | 12 | 13 | COLORS = np.array([ 14 | [0, 127, 70], 15 | [255, 0, 0], 16 | [255, 217, 38], 17 | [0, 135, 255], 18 | [165, 0, 165], 19 | [255, 167, 255], 20 | [97, 142, 151], 21 | [0, 255, 255], 22 | [255, 96, 38], 23 | [142, 76, 0], 24 | [33, 0, 127], 25 | [0, 0, 0], 26 | [183, 183, 183], 27 | [76, 255, 0], 28 | ], dtype=float) / 255 29 | 30 | 31 | ############################################################################### 32 | # Plot speaker embeddings 33 | ############################################################################### 34 | 35 | 36 | def from_embeddings( 37 | centers, 38 | embeddings, 39 | ax=None, 40 | markers=None, 41 | legend=True, 42 | title='', 43 | file=None): 44 | # Maybe create figure; if an axis is given, draw on its figure 45 | fig = None if ax is None else ax.figure 46 | if ax is None: fig, ax = plt.subplots(figsize=(6, 6)) 47 | 48 | # Add title 49 | ax.set_title(title) 50 | 51 | # Format; speakers must be a list, not a generator, since it is reused below 52 | center_speakers, center_embeddings = zip(*centers.items()) 53 | center_embeddings = np.array([item.numpy() for item in center_embeddings]) 54 | speakers, embeddings = zip(*embeddings.items()) 55 | speakers = list(itertools.chain.from_iterable([ 56 | [index] * len(embed) for index, embed in zip(speakers, embeddings)])) 57 | embeddings = np.array([ 58 | item.numpy() for item in itertools.chain.from_iterable( 59 | embeddings)]) 60 | 61 | # Compute 2D projections 62 | projections = UMAP().fit_transform( 63 | np.append(center_embeddings, embeddings, axis=0)) 64 | center_projections = projections[:center_embeddings.shape[0]] 65 | projections = projections[center_embeddings.shape[0]:] 66 | 67 | # Iterate over speakers 68 | for i, speaker in enumerate(center_speakers): 69 | 70 | # Get projections 71 | center_projection = center_projections[i] 72 | speaker_projections = projections[ 73 | np.array([index == speaker for index in speakers])] 74 | 75 | # Style 76 | marker = 'o' if markers is None else markers[i] 77 | label = str(speaker) if legend else '' 78 | 79 | # Plot 80 | ax.scatter( 81 | *center_projection.T, 82 | c=[COLORS[i]], 83 | marker=marker, 84 | label=label + ' GT') 85 | ax.scatter( 86 | *speaker_projections.T, 87 | c=[COLORS[i] * 0.5], 88 | marker=marker, 89 | label=label + ' reconstruct') 90 | 91 | # Add legend 92 | if legend: 93 | ax.legend(title='Speakers', ncol=2) 94 | 95 | # Equal aspect ratio 96 | ax.set_aspect('equal') 97 | 98 | # Save to disk 99 | if file: 100 | fig.savefig(file, bbox_inches='tight', pad_inches=0) 101 | 102 | return fig 103 | -------------------------------------------------------------------------------- /promonet/preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | from . import harmonics 3 | from . import loudness 4 | from . import speaker 5 | from . import spectrogram 6 | from .
import text 7 | -------------------------------------------------------------------------------- /promonet/preprocess/__main__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import yapecs 4 | 5 | import promonet 6 | 7 | 8 | ############################################################################### 9 | # Preprocess 10 | ############################################################################### 11 | 12 | 13 | def parse_args(): 14 | parser = yapecs.ArgumentParser(description='Preprocess') 15 | parser.add_argument( 16 | '--files', 17 | nargs='+', 18 | type=Path, 19 | required=True, 20 | help='Audio files to preprocess') 21 | parser.add_argument( 22 | '--output_prefixes', 23 | nargs='+', 24 | type=Path, 25 | help='Files to save features, minus extension') 26 | parser.add_argument( 27 | '--features', 28 | default=promonet.INPUT_FEATURES, 29 | choices=promonet.INPUT_FEATURES, 30 | nargs='+', 31 | help='The features to preprocess') 32 | parser.add_argument( 33 | '--gpu', 34 | type=int, 35 | help='The index of the gpu to use') 36 | return parser.parse_args() 37 | 38 | 39 | promonet.preprocess.from_files_to_files(**vars(parse_args())) 40 | -------------------------------------------------------------------------------- /promonet/preprocess/loudness.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import multiprocessing as mp 3 | import warnings 4 | 5 | import librosa 6 | import numpy as np 7 | import torch 8 | 9 | import promonet 10 | 11 | 12 | ############################################################################### 13 | # Loudness feature 14 | ############################################################################### 15 | 16 | 17 | def from_audio(audio, bands=1): 18 | """Compute A-weighted loudness""" 19 | # Pad 20 | padding = (promonet.WINDOW_SIZE - promonet.HOPSIZE) // 2 21 | audio = torch.nn.functional.pad( 22 | audio[None], 23 | (padding, padding), 24 | mode='reflect' 25 | ).squeeze(0) 26 | 27 | # Save device 28 | device = audio.device 29 | 30 | # Convert to numpy 31 | audio = audio.detach().cpu().numpy().squeeze(0) 32 | 33 | # Cache weights 34 | if not hasattr(from_audio, 'weights'): 35 | from_audio.weights = perceptual_weights() 36 | 37 | # Take stft 38 | stft = librosa.stft( 39 | audio, 40 | n_fft=promonet.WINDOW_SIZE, 41 | hop_length=promonet.HOPSIZE, 42 | win_length=promonet.WINDOW_SIZE, 43 | center=False) 44 | 45 | # Apply A-weighting in units of dB 46 | weighted = librosa.amplitude_to_db(np.abs(stft)) + from_audio.weights 47 | 48 | # Threshold 49 | weighted[weighted < promonet.MIN_DB] = promonet.MIN_DB 50 | 51 | # Multiband loudness 52 | loudness = torch.from_numpy(weighted).float().to(device) 53 | 54 | # Maybe average 55 | return band_average(loudness, bands) if bands is not None else loudness 56 | 57 | 58 | def from_file(audio_file, bands=promonet.LOUDNESS_BANDS): 59 | """Compute A-weighted loudness from audio file""" 60 | return from_audio(promonet.load.audio(audio_file), bands) 61 | 62 | 63 | def from_file_to_file(audio_file, output_file, bands=promonet.LOUDNESS_BANDS): 64 | """Compute A-weighted loudness from audio file and save""" 65 | torch.save(from_file(audio_file, bands), output_file) 66 | 67 | 68 | def from_files_to_files( 69 | audio_files, 70 | output_files, 71 | bands=promonet.LOUDNESS_BANDS 72 | ): 73 | """Compute A-weighted loudness from audio files and save""" 74 | loudness_fn = functools.partial(from_file_to_file, 
bands=bands) 75 | with mp.get_context('spawn').Pool(promonet.NUM_WORKERS) as pool: 76 | pool.starmap(loudness_fn, zip(audio_files, output_files)) 77 | 78 | 79 | ############################################################################### 80 | # Loudness utilities 81 | ############################################################################### 82 | 83 | 84 | def band_average(loudness, bands=promonet.LOUDNESS_BANDS): 85 | """Average over frequency bands""" 86 | if bands is not None: 87 | 88 | if bands == 1: 89 | 90 | # Average over all weighted frequencies 91 | loudness = loudness.mean(dim=-2, keepdim=True) 92 | 93 | else: 94 | 95 | # Average over loudness frequency bands 96 | step = loudness.shape[-2] / bands 97 | if loudness.ndim == 2: 98 | loudness = torch.stack( 99 | [ 100 | loudness[int(band * step):int((band + 1) * step)].mean(dim=-2) 101 | for band in range(bands) 102 | ]) 103 | else: 104 | loudness = torch.stack( 105 | [ 106 | loudness[:, int(band * step):int((band + 1) * step)].mean(dim=-2) 107 | for band in range(bands) 108 | ], 109 | dim=1) 110 | 111 | return loudness 112 | 113 | 114 | def limit(audio, delay=40, attack_coef=.9, release_coef=.9995, threshold=.99): 115 | """Apply a lookahead limiter to prevent clipping""" 116 | # Delay compensation 117 | audio = torch.nn.functional.pad(audio, (0, delay - 1)) 118 | 119 | current_gain = 1. 120 | delay_index = 0 121 | delay_line = torch.zeros(delay, dtype=audio.dtype, device=audio.device) 122 | envelope = 0. 123 | 124 | for idx, sample in enumerate(audio[0]): 125 | 126 | # Update signal history 127 | delay_line[delay_index] = sample 128 | delay_index = (delay_index + 1) % delay 129 | 130 | # Calculate envelope 131 | envelope = max(abs(sample), envelope * release_coef) 132 | 133 | # Calculate gain; attenuate only when the envelope exceeds the threshold 134 | target_gain = threshold / envelope if envelope > threshold else 1. 135 | current_gain = \ 136 | current_gain * attack_coef + target_gain * (1 - attack_coef) 137 | 138 | # Apply gain 139 | audio[:, idx] = delay_line[delay_index] * current_gain 140 | 141 | return audio[:, delay - 1:] 142 | 143 | 144 | def normalize(loudness): 145 | """Normalize loudness from [MIN_DB, REF_DB] to [0., 1.]""" 146 | return (loudness - promonet.MIN_DB) / (promonet.REF_DB - promonet.MIN_DB) 147 | 148 | 149 | def perceptual_weights(): 150 | """A-weighted frequency-dependent perceptual loudness weights""" 151 | frequencies = librosa.fft_frequencies( 152 | sr=promonet.SAMPLE_RATE, 153 | n_fft=promonet.WINDOW_SIZE) 154 | 155 | # A warning is raised for nearly inaudible frequencies, but it ends up 156 | # defaulting to -100 dB. That default is fine for our purposes.
157 | with warnings.catch_warnings(): 158 | warnings.simplefilter('ignore', RuntimeWarning) 159 | return ( 160 | librosa.A_weighting(frequencies)[:, None] - float(promonet.REF_DB)) 161 | 162 | 163 | def scale(audio, target_loudness): 164 | """Scale the audio to the target loudness""" 165 | # Maybe average to get scalar loudness 166 | if target_loudness.shape[-2] > 1: 167 | target_loudness = target_loudness.mean(dim=-2, keepdim=True) 168 | 169 | # Get current loudness 170 | loudness = from_audio(audio.to(torch.float64)) 171 | 172 | # Take difference and convert from dB to ratio 173 | gain = promonet.convert.db_to_ratio(target_loudness - loudness) 174 | 175 | # Apply gain and prevent clipping 176 | return limit(shift(audio, gain)) 177 | 178 | 179 | def shift(audio, value): 180 | """Shift loudness by target value in decibels""" 181 | # Convert from dB to ratio 182 | gain = promonet.convert.db_to_ratio(value) 183 | 184 | # Linearly interpolate to the audio resolution 185 | if isinstance(gain, torch.Tensor) and gain.numel() > 1: 186 | gain = torch.nn.functional.interpolate( 187 | gain[None], 188 | size=audio.shape[1], 189 | mode='linear', 190 | align_corners=False)[0] 191 | 192 | # Scale 193 | return gain * audio 194 | -------------------------------------------------------------------------------- /promonet/preprocess/speaker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | import torchutil 4 | from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector 5 | 6 | import promonet 7 | 8 | 9 | ############################################################################### 10 | # Constants 11 | ############################################################################### 12 | 13 | 14 | # Maximum batch size for batched WavLM inference 15 | WAVLM_MAX_BATCH_SIZE = 16 16 | 17 | # Sample rate of the WavLM model audio input 18 | WAVLM_SAMPLE_RATE = 16000 19 | 20 | 21 | ############################################################################### 22 | # WavLM x-vector speaker embedding 23 | ############################################################################### 24 | 25 | 26 | def from_audio(audio, sample_rate=promonet.SAMPLE_RATE, gpu=None): 27 | """Compute speaker embedding from audio""" 28 | # Resample to the rate expected by WavLM 29 | audio = torchaudio.functional.resample(audio, sample_rate, WAVLM_SAMPLE_RATE) 30 | 31 | # Embed 32 | return infer(audio[0], gpu) 33 | 34 | 35 | def from_file(file, gpu=None): 36 | """Compute speaker embedding from file""" 37 | return from_audio(promonet.load.audio(file), gpu=gpu) 38 | 39 | 40 | def from_file_to_file(file, output_file, gpu=None): 41 | """Compute speaker embedding from file and save""" 42 | # Embed 43 | embedding = from_file(file, gpu).cpu() 44 | 45 | # Save 46 | torch.save(embedding, output_file) 47 | 48 | 49 | def from_files_to_files(files, output_files, gpu=None): 50 | """Compute speaker embedding from files and save""" 51 | for file, output_file in torchutil.iterator( 52 | zip(files, output_files), 53 | 'WavLM x-vectors', 54 | total=len(files) 55 | ): 56 | from_file_to_file(file, output_file, gpu) 57 | 58 | 59 | ############################################################################### 60 | # Utilities 61 | ############################################################################### 62 | 63 | 64 | def infer(audio, gpu=None): 65 | """Infer speaker embedding from audio""" 66 | # Cache networks 67 | if not hasattr(infer, 'feature_extractor'): 68 | infer.feature_extractor = 
Wav2Vec2FeatureExtractor.from_pretrained( 69 | 'microsoft/wavlm-base-plus-sv') 70 | if not hasattr(infer, 'model'): 71 | infer.model = WavLMForXVector.from_pretrained( 72 | 'microsoft/wavlm-base-plus-sv') 73 | 74 | # Place on device (no-op if devices match) 75 | device = torch.device('cpu' if gpu is None else f'cuda:{gpu}') 76 | infer.model.to(device) 77 | 78 | # Preprocess 79 | features = infer.feature_extractor( 80 | audio, 81 | padding=True, 82 | return_tensors="pt") 83 | 84 | # Embed 85 | embeddings = infer.model( 86 | features['input_values'].to(device), 87 | features['attention_mask'].to(device) 88 | ).embeddings.detach() 89 | 90 | # Normalize 91 | return torch.nn.functional.normalize(embeddings, dim=-1) 92 | -------------------------------------------------------------------------------- /promonet/preprocess/spectrogram.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import multiprocessing as mp 3 | 4 | import torch 5 | import librosa 6 | 7 | import promonet 8 | 9 | 10 | ############################################################################### 11 | # Spectrogram computation 12 | ############################################################################### 13 | 14 | 15 | def from_audio( 16 | audio, 17 | mels=False, 18 | log_dynamic_range_compression_threshold=\ 19 | promonet.LOG_DYNAMIC_RANGE_COMPRESSION_THRESHOLD 20 | ): 21 | """Compute spectrogram from audio""" 22 | # Cache hann window 23 | if ( 24 | not hasattr(from_audio, 'window') or 25 | from_audio.dtype != audio.dtype or 26 | from_audio.device != audio.device 27 | ): 28 | from_audio.window = torch.hann_window( 29 | promonet.WINDOW_SIZE, 30 | dtype=audio.dtype, 31 | device=audio.device) 32 | from_audio.dtype = audio.dtype 33 | from_audio.device = audio.device 34 | 35 | # Pad audio 36 | size = (promonet.NUM_FFT - promonet.HOPSIZE) // 2 37 | audio = torch.nn.functional.pad(audio, (size, size), mode='reflect') 38 | 39 | # Compute stft 40 | stft = torch.stft( 41 | audio.squeeze(1), 42 | promonet.NUM_FFT, 43 | hop_length=promonet.HOPSIZE, 44 | window=from_audio.window, 45 | center=False, 46 | normalized=False, 47 | onesided=True, 48 | return_complex=True) 49 | stft = torch.view_as_real(stft) 50 | 51 | # Compute magnitude 52 | spectrogram = torch.sqrt(stft.pow(2).sum(-1) + 1e-6) 53 | 54 | # Maybe convert to mels 55 | if mels: 56 | spectrogram = linear_to_mel( 57 | spectrogram, 58 | log_dynamic_range_compression_threshold) 59 | 60 | return spectrogram.squeeze(0) 61 | 62 | 63 | def from_file( 64 | audio_file, 65 | mels=False, 66 | log_dynamic_range_compression_threshold=\ 67 | promonet.LOG_DYNAMIC_RANGE_COMPRESSION_THRESHOLD 68 | ): 69 | """Compute spectrogram from audio file""" 70 | audio = promonet.load.audio(audio_file) 71 | return from_audio(audio, mels, log_dynamic_range_compression_threshold) 72 | 73 | 74 | def from_file_to_file( 75 | audio_file, 76 | output_file, 77 | mels=False, 78 | log_dynamic_range_compression_threshold=\ 79 | promonet.LOG_DYNAMIC_RANGE_COMPRESSION_THRESHOLD 80 | ): 81 | """Compute spectrogram from audio file and save to disk""" 82 | output = from_file( 83 | audio_file, 84 | mels, 85 | log_dynamic_range_compression_threshold) 86 | torch.save(output, output_file) 87 | 88 | 89 | def from_files_to_files( 90 | audio_files, 91 | output_files, 92 | mels=False, 93 | log_dynamic_range_compression_threshold=\ 94 | promonet.LOG_DYNAMIC_RANGE_COMPRESSION_THRESHOLD 95 | ): 96 | """Compute spectrogram from audio files and save to disk""" 97 | 
preprocess_fn = functools.partial( 98 | from_file_to_file, 99 | mels=mels, 100 | log_dynamic_range_compression_threshold=\ 101 | log_dynamic_range_compression_threshold) 102 | with mp.get_context('spawn').Pool(promonet.NUM_WORKERS) as pool: 103 | pool.starmap(preprocess_fn, zip(audio_files, output_files)) 104 | 105 | 106 | ############################################################################### 107 | # Utilities 108 | ############################################################################### 109 | 110 | 111 | def linear_to_mel( 112 | spectrogram, 113 | log_dynamic_range_compression_threshold=\ 114 | promonet.LOG_DYNAMIC_RANGE_COMPRESSION_THRESHOLD 115 | ): 116 | # Create and cache mel basis on first call 117 | if not hasattr(linear_to_mel, 'basis'): 118 | basis = librosa.filters.mel( 119 | sr=promonet.SAMPLE_RATE, 120 | n_fft=promonet.NUM_FFT, 121 | n_mels=promonet.NUM_MELS) 122 | basis = torch.from_numpy(basis) 123 | basis = basis.to(spectrogram.dtype).to(spectrogram.device) 124 | linear_to_mel.basis = basis 125 | 126 | # Convert to log-mels 127 | melspectrogram = torch.log(torch.matmul(linear_to_mel.basis, spectrogram)) 128 | 129 | # Maybe apply dynamic range compression 130 | if log_dynamic_range_compression_threshold is not None: 131 | return torch.clamp( 132 | melspectrogram, 133 | min=log_dynamic_range_compression_threshold) 134 | 135 | return melspectrogram 136 | -------------------------------------------------------------------------------- /promonet/preprocess/text.py: -------------------------------------------------------------------------------- 1 | import string 2 | 3 | import torch 4 | from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline 5 | from whisper.normalizers import EnglishTextNormalizer 6 | 7 | import promonet 8 | 9 | 10 | ############################################################################### 11 | # Constants 12 | ############################################################################### 13 | 14 | 15 | # Whisper model identifier 16 | MODEL_ID = "openai/whisper-large-v3" 17 | 18 | 19 | ############################################################################### 20 | # Whisper ASR 21 | ############################################################################### 22 | 23 | 24 | def from_audio(audio, sample_rate=promonet.SAMPLE_RATE, gpu=None): 25 | """Perform ASR from audio""" 26 | # Infer text 27 | results = infer( 28 | { 29 | 'sampling_rate': sample_rate, 30 | 'raw': audio.to(torch.float32).squeeze(dim=0).cpu().numpy() 31 | }, 32 | gpu) 33 | 34 | # Lint 35 | return lint(results['text']) 36 | 37 | 38 | 39 | 40 | def from_file(audio_file, gpu=None): 41 | """Perform Whisper ASR on an audio file""" 42 | # Infer text 43 | results = infer([str(audio_file)], gpu) 44 | 45 | # Lint 46 | return lint(results[0]['text']) 47 | 48 | 49 | def from_file_to_file(audio_file, output_file, gpu=None): 50 | """Perform Whisper ASR and save""" 51 | from_files_to_files([audio_file], [output_file], gpu) 52 | 53 | 54 | def from_files_to_files(audio_files, output_files, gpu=None): 55 | """Perform batched Whisper ASR from files and save""" 56 | # Infer text 57 | results = infer([str(audio_file) for audio_file in audio_files], gpu) 58 | 59 | # Lint 60 | results = [lint(result['text']) for result in results] 61 | 62 | # Save 63 | for result, output_file in zip(results, output_files): 64 | with open(output_file, 'w', encoding='utf-8') as file: 65 | file.write(result) 66 | 67 | 68 |
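As a quick aside, a hedged usage sketch of the ASR functions above (the audio path is hypothetical; gpu=None falls back to CPU):

import promonet

# Transcribe one file; the result is normalized, lowercase words suitable for WER
transcript = promonet.preprocess.text.from_file('speech.wav', gpu=0)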
############################################################################### 69 | # Utilities 70 | ############################################################################### 71 | 72 | 73 | def infer(audio, gpu=None): 74 | """Batched Whisper ASR""" 75 | device = f'cuda:{gpu}' if gpu is not None else 'cpu' 76 | 77 | # Cache model 78 | if not hasattr(infer, 'pipe') or infer.device != device: 79 | model = AutoModelForSpeechSeq2Seq.from_pretrained( 80 | MODEL_ID, 81 | torch_dtype=torch.float16, 82 | low_cpu_mem_usage=True, 83 | use_safetensors=True, 84 | ).to(device) 85 | processor = AutoProcessor.from_pretrained(MODEL_ID) 86 | infer.pipe = pipeline( 87 | "automatic-speech-recognition", 88 | model=model, 89 | tokenizer=processor.tokenizer, 90 | feature_extractor=processor.feature_extractor, 91 | max_new_tokens=128, 92 | chunk_length_s=30, 93 | batch_size=64, 94 | return_timestamps=False, 95 | torch_dtype=torch.float16, 96 | device=device) 97 | infer.device = device 98 | 99 | return infer.pipe(audio) 100 | 101 | 102 | def lint(text): 103 | """Formats text to only words for use in WER""" 104 | if not hasattr(lint, 'normalizer'): 105 | lint.normalizer = EnglishTextNormalizer() 106 | return lint.normalizer(text).lower() 107 | -------------------------------------------------------------------------------- /promonet/synthesize/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | -------------------------------------------------------------------------------- /promonet/synthesize/__main__.py: -------------------------------------------------------------------------------- 1 | import yapecs 2 | from pathlib import Path 3 | 4 | import promonet 5 | 6 | 7 | ############################################################################### 8 | # Entry point 9 | ############################################################################### 10 | 11 | 12 | def parse_args(): 13 | """Parse command-line arguments""" 14 | parser = yapecs.ArgumentParser( 15 | description='Synthesize speech from features') 16 | parser.add_argument( 17 | '--loudness_files', 18 | type=Path, 19 | nargs='+', 20 | required=True, 21 | help='The loudness files') 22 | parser.add_argument( 23 | '--pitch_files', 24 | type=Path, 25 | nargs='+', 26 | required=True, 27 | help='The pitch files') 28 | parser.add_argument( 29 | '--periodicity_files', 30 | type=Path, 31 | nargs='+', 32 | required=True, 33 | help='The periodicity files') 34 | parser.add_argument( 35 | '--ppg_files', 36 | type=Path, 37 | nargs='+', 38 | required=True, 39 | help='The phonetic posteriorgram files') 40 | parser.add_argument( 41 | '--output_files', 42 | type=Path, 43 | nargs='+', 44 | required=True, 45 | help='The files to save the edited audio') 46 | parser.add_argument( 47 | '--speakers', 48 | type=int, 49 | nargs='+', 50 | help='The IDs of the speakers for voice conversion') 51 | parser.add_argument( 52 | '--spectral_balance_ratio', 53 | type=float, 54 | default=1., 55 | help='> 1 for Alvin and the Chipmunks; < 1 for Patrick Star') 56 | parser.add_argument( 57 | '--checkpoint', 58 | type=Path, 59 | help='The generator checkpoint') 60 | parser.add_argument( 61 | '--gpu', 62 | type=int, 63 | help='The GPU index') 64 | return parser.parse_args() 65 | 66 | 67 | promonet.synthesize.from_files_to_files(**vars(parse_args())) 68 | -------------------------------------------------------------------------------- /promonet/synthesize/core.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional, Union 3 | from pathlib import Path 4 | 5 | import huggingface_hub 6 | import torch 7 | import torchaudio 8 | import torchutil 9 | 10 | import promonet 11 | 12 | 13 | ############################################################################### 14 | # Editing API 15 | ############################################################################### 16 | 17 | 18 | def from_features( 19 | loudness: torch.Tensor, 20 | pitch: torch.Tensor, 21 | periodicity: torch.Tensor, 22 | ppg: torch.Tensor, 23 | speaker: Union[int, torch.Tensor] = 0, 24 | spectral_balance_ratio: float = 1., 25 | loudness_ratio: float = 1., 26 | checkpoint: Optional[Union[str, os.PathLike]] = None, 27 | gpu: Optional[int] = None 28 | ) -> torch.Tensor: 29 | """Perform speech synthesis 30 | 31 | Args: 32 | loudness: The loudness contour 33 | pitch: The pitch contour 34 | periodicity: The periodicity contour 35 | ppg: The phonetic posteriorgram 36 | speaker: The speaker index or embedding 37 | spectral_balance_ratio: > 1 for Alvin and the Chipmunks; < 1 for Patrick Star 38 | loudness_ratio: > 1 for louder; < 1 for quieter 39 | checkpoint: The generator checkpoint 40 | gpu: The GPU index 41 | 42 | Returns 43 | generated: The generated speech 44 | """ 45 | device = torch.device('cpu' if gpu is None else f'cuda:{gpu}') 46 | 47 | if loudness.ndim == 2: 48 | loudness = loudness[None] 49 | 50 | return generate( 51 | loudness.to(device), 52 | pitch.to(device), 53 | periodicity.to(device), 54 | ppg.to(device), 55 | speaker, 56 | spectral_balance_ratio, 57 | loudness_ratio, 58 | checkpoint 59 | ).to(torch.float32) 60 | 61 | 62 | def from_file( 63 | loudness_file: Union[str, os.PathLike], 64 | pitch_file: Union[str, os.PathLike], 65 | periodicity_file: Union[str, os.PathLike], 66 | ppg_file: Union[str, os.PathLike], 67 | speaker: Union[int, torch.Tensor, Path, str] = 0, 68 | spectral_balance_ratio: float = 1., 69 | loudness_ratio: float = 1., 70 | checkpoint: Optional[Union[str, os.PathLike]] = None, 71 | gpu: Optional[int] = None 72 | ) -> torch.Tensor: 73 | """Perform speech synthesis from features on disk 74 | 75 | Args: 76 | loudness_file: The loudness file 77 | pitch_file: The pitch file 78 | periodicity_file: The periodicity file 79 | ppg_file: The phonetic posteriorgram file 80 | speaker: The speaker index or embedding 81 | spectral_balance_ratio: > 1 for Alvin and the Chipmunks; < 1 for Patrick Star 82 | loudness_ratio: > 1 for louder; < 1 for quieter 83 | checkpoint: The generator checkpoint 84 | gpu: The GPU index 85 | 86 | Returns 87 | generated: The generated speech 88 | """ 89 | device = torch.device('cpu' if gpu is None else f'cuda:{gpu}') 90 | 91 | # Load features 92 | loudness = torch.load(loudness_file) 93 | pitch = torch.load(pitch_file) 94 | periodicity = torch.load(periodicity_file) 95 | ppg = promonet.load.ppg(ppg_file, resample_length=pitch.shape[-1])[None] 96 | 97 | # Maybe load speaker embedding 98 | if promonet.ZERO_SHOT: 99 | speaker = torch.load(speaker).to(device) 100 | 101 | # Generate 102 | return from_features( 103 | loudness.to(device), 104 | pitch.to(device), 105 | periodicity.to(device), 106 | ppg.to(device), 107 | speaker, 108 | spectral_balance_ratio, 109 | loudness_ratio, 110 | checkpoint, 111 | gpu) 112 | 113 | 114 | def from_file_to_file( 115 | loudness_file: Union[str, os.PathLike], 116 | pitch_file: Union[str, os.PathLike], 117 | periodicity_file: Union[str, 
os.PathLike], 118 | ppg_file: Union[str, os.PathLike], 119 | output_file: Union[str, os.PathLike], 120 | speaker: Union[int, torch.Tensor, Path, str] = 0, 121 | spectral_balance_ratio: float = 1., 122 | loudness_ratio: float = 1., 123 | checkpoint: Optional[Union[str, os.PathLike]] = None, 124 | gpu: Optional[int] = None 125 | ) -> None: 126 | """Perform speech synthesis from features on disk and save 127 | 128 | Args: 129 | loudness_file: The loudness file 130 | pitch_file: The pitch file 131 | periodicity_file: The periodicity file 132 | ppg_file: The phonetic posteriorgram file 133 | output_file: The file to save generated speech audio 134 | speaker: The speaker index or embedding 135 | spectral_balance_ratio: > 1 for Alvin and the Chipmunks; < 1 for Patrick Star 136 | loudness_ratio: > 1 for louder; < 1 for quieter 137 | checkpoint: The generator checkpoint 138 | gpu: The GPU index 139 | """ 140 | # Generate 141 | generated = from_file( 142 | loudness_file, 143 | pitch_file, 144 | periodicity_file, 145 | ppg_file, 146 | speaker, 147 | spectral_balance_ratio, 148 | loudness_ratio, 149 | checkpoint, 150 | gpu 151 | ).to('cpu') 152 | 153 | # Save (output_file may be a str or a Path) 154 | Path(output_file).parent.mkdir(exist_ok=True, parents=True) 155 | torchaudio.save(output_file, generated, promonet.SAMPLE_RATE) 156 | 157 | 158 | def from_files_to_files( 159 | loudness_files: List[Union[str, os.PathLike]], 160 | pitch_files: List[Union[str, os.PathLike]], 161 | periodicity_files: List[Union[str, os.PathLike]], 162 | ppg_files: List[Union[str, os.PathLike]], 163 | output_files: List[Union[str, os.PathLike]], 164 | speakers: Optional[Union[List[int], torch.Tensor, Path, str]] = None, 165 | spectral_balance_ratio: float = 1., 166 | loudness_ratio: float = 1., 167 | checkpoint: Optional[Union[str, os.PathLike]] = None, 168 | gpu: Optional[int] = None 169 | ) -> None: 170 | """Perform batched speech synthesis from features on disk and save 171 | 172 | Args: 173 | loudness_files: The loudness files 174 | pitch_files: The pitch files 175 | periodicity_files: The periodicity files 176 | ppg_files: The phonetic posteriorgram files 177 | output_files: The files to save generated speech audio 178 | speakers: The speaker indices or embeddings 179 | spectral_balance_ratio: > 1 for Alvin and the Chipmunks; < 1 for Patrick Star 180 | loudness_ratio: > 1 for louder; < 1 for quieter 181 | checkpoint: The generator checkpoint 182 | gpu: The GPU index 183 | """ 184 | if speakers is None: 185 | speakers = [0] * len(pitch_files) 186 | 187 | # Generate 188 | iterator = zip( 189 | loudness_files, 190 | pitch_files, 191 | periodicity_files, 192 | ppg_files, 193 | output_files, 194 | speakers) 195 | for item in iterator: 196 | from_file_to_file( 197 | *item, 198 | spectral_balance_ratio=spectral_balance_ratio, 199 | loudness_ratio=loudness_ratio, 200 | checkpoint=checkpoint, 201 | gpu=gpu) 202 | 203 | 204 | ############################################################################### 205 | # Pipeline 206 | ############################################################################### 207 | 208 | 209 | def generate( 210 | loudness, 211 | pitch, 212 | periodicity, 213 | ppg, 214 | speaker=0, 215 | spectral_balance_ratio: float = 1., 216 | loudness_ratio: float = 1., 217 | checkpoint=None 218 | ) -> torch.Tensor: 219 | """Generate speech from phoneme and prosody features""" 220 | device = pitch.device 221 | 222 | with torchutil.time.context('load'): 223 | 224 | # Cache model 225 | if ( 226 | not hasattr(generate, 'model') or 227 | generate.checkpoint != 
checkpoint or 228 | generate.device != device 229 | ): 230 | if promonet.SPECTROGRAM_ONLY: 231 | model = promonet.model.MelGenerator().to(device) 232 | else: 233 | model = promonet.model.Generator().to(device) 234 | if checkpoint is None: 235 | checkpoint = huggingface_hub.hf_hub_download( 236 | 'maxrmorrison/promonet', 237 | f'generator-00{promonet.STEPS}.pt') 238 | else: 239 | if type(checkpoint) is str: 240 | checkpoint = Path(checkpoint) 241 | if checkpoint.is_dir(): 242 | checkpoint = torchutil.checkpoint.latest_path( 243 | checkpoint, 244 | 'generator-*.pt') 245 | model, *_ = torchutil.checkpoint.load(checkpoint, model) 246 | generate.model = model 247 | generate.checkpoint = checkpoint 248 | generate.device = device 249 | 250 | with torchutil.time.context('generate'): 251 | 252 | # Specify speaker 253 | if promonet.ZERO_SHOT: 254 | speakers = speaker.to(device) 255 | else: 256 | speakers = torch.full((1,), speaker, dtype=torch.long, device=device) 257 | 258 | # Format ratio 259 | spectral_balance_ratio = torch.tensor( 260 | [spectral_balance_ratio], 261 | dtype=torch.float, 262 | device=device) 263 | 264 | # Loudness ratio 265 | loudness_ratio = torch.tensor( 266 | [loudness_ratio], 267 | dtype=torch.float, 268 | device=device) 269 | 270 | # Generate 271 | with torchutil.inference.context(generate.model): 272 | return generate.model( 273 | loudness, 274 | pitch, 275 | periodicity, 276 | ppg, 277 | speakers, 278 | spectral_balance_ratio, 279 | loudness_ratio, 280 | generate.model.default_previous_samples 281 | )[0] 282 | -------------------------------------------------------------------------------- /promonet/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | from . 
import loss 3 | -------------------------------------------------------------------------------- /promonet/train/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import shutil 3 | from pathlib import Path 4 | 5 | import promonet 6 | 7 | 8 | ############################################################################### 9 | # Training 10 | ############################################################################### 11 | 12 | 13 | def main( 14 | config, 15 | dataset=promonet.TRAINING_DATASET, 16 | train_partition='train', 17 | valid_partition='valid', 18 | adapt_from=False, 19 | gpu=None 20 | ): 21 | # Create output directory 22 | directory = promonet.RUNS_DIR / promonet.CONFIG 23 | directory.mkdir(parents=True, exist_ok=True) 24 | 25 | # Save configuration 26 | if config is not None: 27 | shutil.copyfile(config, directory / config.name) 28 | 29 | # Train 30 | promonet.train( 31 | directory, 32 | dataset, 33 | train_partition, 34 | valid_partition, 35 | adapt_from, 36 | gpu) 37 | 38 | 39 | def parse_args(): 40 | """Parse command-line arguments""" 41 | parser = argparse.ArgumentParser(description='Train a model') 42 | parser.add_argument( 43 | '--config', 44 | type=Path, 45 | nargs='+', 46 | help='The configuration file(s)') 47 | parser.add_argument( 48 | '--dataset', 49 | default=promonet.TRAINING_DATASET, 50 | help='The dataset to train on') 51 | parser.add_argument( 52 | '--train_partition', 53 | default='train', 54 | help='The data partition to train on') 55 | parser.add_argument( 56 | '--valid_partition', 57 | default='valid', 58 | help='The data partition to perform validation on') 59 | parser.add_argument( 60 | '--adapt_from', 61 | type=Path, 62 | help='A checkpoint to perform adaptation from') 63 | parser.add_argument( 64 | '--gpu', 65 | type=int, 66 | help='The gpu to run training on') 67 | 68 | # Keep only the first configuration file 69 | args = parser.parse_args() 70 | if args.config is not None: 71 | args.config = args.config[0] 72 | return args 73 | 74 | 75 | main(**vars(parse_args())) 76 | -------------------------------------------------------------------------------- /promonet/train/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import promonet 4 | 5 | 6 | ############################################################################### 7 | # Adversarial loss functions 8 | ############################################################################### 9 | 10 | 11 | def feature_matching(real_feature_maps, fake_feature_maps): 12 | """Feature matching loss""" 13 | loss = 0. 14 | iterator = zip(real_feature_maps, fake_feature_maps) 15 | for real_feature_map, fake_feature_map in iterator: 16 | 17 | # Maybe omit first activation layers from feature matching loss 18 | if promonet.FEATURE_MATCHING_OMIT_FIRST: 19 | real_feature_map = real_feature_map[1:] 20 | fake_feature_map = fake_feature_map[1:] 21 | 22 | # Aggregate 23 | for real, fake in zip(real_feature_map, fake_feature_map): 24 | loss += torch.mean(torch.abs(real.float().detach() - fake.float())) 25 | 26 | return loss 27 | 28 | 29 | def discriminator(real_outputs, fake_outputs): 30 | """Discriminator loss""" 31 | real_losses = [] 32 | fake_losses = [] 33 | for real_output, fake_output in zip(real_outputs, fake_outputs): 34 | if promonet.ADVERSARIAL_HINGE_LOSS: 35 | real_losses.append(torch.mean(torch.clamp(1.
- real_output, min=0.))) 36 | fake_losses.append(torch.mean(torch.clamp(1. + fake_output, min=0.))) 37 | else: 38 | real_losses.append(torch.mean((1. - real_output) ** 2.)) 39 | fake_losses.append(torch.mean(fake_output ** 2.)) 40 | return sum(real_losses) + sum(fake_losses), real_losses, fake_losses 41 | 42 | 43 | def generator(discriminator_outputs): 44 | """Generator adversarial loss""" 45 | if promonet.ADVERSARIAL_HINGE_LOSS: 46 | losses = [ 47 | torch.mean(torch.clamp(1. - output, min=0.)) 48 | for output in discriminator_outputs] 49 | else: 50 | losses = [ 51 | torch.mean((1. - output) ** 2.) 52 | for output in discriminator_outputs] 53 | return sum(losses), losses 54 | 55 | 56 | ############################################################################### 57 | # Spectral loss functions 58 | ############################################################################### 59 | 60 | 61 | def stft(x, fft_size, hop_size, win_length, window): 62 | """Perform STFT and convert to magnitude spectrogram. 63 | Args: 64 | x (Tensor): Input signal tensor (B, T). 65 | fft_size (int): FFT size. 66 | hop_size (int): Hop size. 67 | win_length (int): Window length. 68 | window (Tensor): Window function tensor. 69 | Returns: 70 | Tensor: Square-root-compressed magnitude spectrogram (B, fft_size // 2 + 1, #frames). 71 | """ 72 | magnitude = torch.abs( 73 | torch.stft( 74 | x, 75 | fft_size, 76 | hop_size, 77 | win_length, 78 | window, 79 | return_complex=True)) 80 | return torch.sqrt(torch.clamp(magnitude, min=1e-7)) 81 | 82 | 83 | class SpectralConvergence(torch.nn.Module): 84 | """Spectral convergence loss module.""" 85 | 86 | def __init__( 87 | self, 88 | device, 89 | fft_size=1024, 90 | shift_size=120, 91 | win_length=600, 92 | window='hann_window' 93 | ): 94 | super().__init__() 95 | self.fft_size = fft_size 96 | self.shift_size = shift_size 97 | self.win_length = win_length 98 | self.window = getattr(torch, window)(win_length).to(device) 99 | 100 | def forward(self, x, y): 101 | """Calculate forward propagation. 102 | Args: 103 | x (Tensor): Predicted signal (B, 1, T). 104 | y (Tensor): Groundtruth signal (B, 1, T). 105 | Returns: 106 | Tensor: Spectral convergence loss value. 107 | 108 | """ 109 | x_mag = stft( 110 | x.squeeze(1), 111 | self.fft_size, 112 | self.shift_size, 113 | self.win_length, 114 | self.window) 115 | y_mag = stft( 116 | y.squeeze(1), 117 | self.fft_size, 118 | self.shift_size, 119 | self.win_length, 120 | self.window) 121 | return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1) 122 | 123 | 124 | class MultiResolutionSpectralConvergence(torch.nn.Module): 125 | 126 | def __init__( 127 | self, 128 | device, 129 | fft_sizes=[2560, 1280, 640, 320, 160, 80], 130 | hop_sizes=[640, 320, 160, 80, 40, 20], 131 | win_lengths=[2560, 1280, 640, 320, 160, 80], 132 | window='hann_window' 133 | ): 134 | super().__init__() 135 | self.stft_losses = torch.nn.ModuleList() 136 | for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): 137 | self.stft_losses += [SpectralConvergence(device, fs, ss, wl, window)] 138 | 139 | def forward(self, x, y): 140 | """Calculate forward propagation. 141 | Args: 142 | x (Tensor): Predicted signal (B, 1, T). 143 | y (Tensor): Groundtruth signal (B, 1, T).
144 | Returns: 145 | Tensor: Multi resolution spectral convergence loss value 146 | """ 147 | sc_loss = 0.0 148 | for stft_loss in self.stft_losses: 149 | sc_loss += stft_loss(x, y) 150 | return sc_loss / len(self.stft_losses) 151 | 152 | 153 | ############################################################################### 154 | # Time-domain loss functions 155 | ############################################################################### 156 | 157 | 158 | def signal(y_true, y_pred): 159 | """Waveform loss function""" 160 | t = y_true / (1e-15 + torch.norm(y_true, dim=-1, p=2, keepdim=True)) 161 | p = y_pred / (1e-15 + torch.norm(y_pred, dim=-1, p=2, keepdim=True)) 162 | return torch.mean(1. - torch.sum(p * t, dim=-1)) 163 | -------------------------------------------------------------------------------- /results/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxrmorrison/promonet/86b815e8511d526da42cd8333a9f5a2872b11981/results/.gitkeep -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | # Runs experiments from the paper 2 | # "Fine-Grained and Interpretable Neural Speech Editing" 3 | 4 | # Args 5 | # $1 - index of GPU to use 6 | 7 | 8 | ############################################################################### 9 | # Best model 10 | ############################################################################### 11 | 12 | 13 | # Data pipeline 14 | python -m promonet.data.download --datasets vctk 15 | python -m promonet.data.augment --datasets vctk 16 | python -m promonet.data.preprocess --datasets vctk --gpu $1 17 | python -m promonet.partition --datasets vctk 18 | 19 | # Train 20 | python -m promonet.train --gpu $1 21 | 22 | # Evaluate 23 | python -m promonet.evaluate --datasets vctk --gpu $1 24 | 25 | 26 | ############################################################################### 27 | # Ablations 28 | ############################################################################### 29 | 30 | 31 | # Data pipeline 32 | python -m promonet.data.preprocess \ 33 | --config config/ablations/ablate-viterbi.py \ 34 | --features pitch \ 35 | --datasets vctk \ 36 | --gpu $1 37 | 38 | # Train 39 | python -m promonet.train --config config/ablations/ablate-all.py --gpu $1 40 | python -m promonet.train --config config/ablations/ablate-augment.py --gpu $1 41 | python -m promonet.train --config config/ablations/ablate-multiloud.py --gpu $1 42 | python -m promonet.train --config config/ablations/ablate-sppg.py --gpu $1 43 | python -m promonet.train --config config/ablations/ablate-variable-pitch.py --gpu $1 44 | python -m promonet.train --config config/ablations/ablate-viterbi.py --gpu $1 45 | 46 | # Evaluate 47 | python -m promonet.evaluate \ 48 | --config config/ablations/ablate-all.py \ 49 | --datasets vctk \ 50 | --gpu $1 51 | python -m promonet.evaluate \ 52 | --config config/ablations/ablate-augment.py \ 53 | --datasets vctk \ 54 | --gpu $1 55 | python -m promonet.evaluate \ 56 | --config config/ablations/ablate-multiloud.py \ 57 | --datasets vctk \ 58 | --gpu $1 59 | python -m promonet.evaluate \ 60 | --config config/ablations/ablate-sppg.py \ 61 | --datasets vctk \ 62 | --gpu $1 63 | python -m promonet.evaluate \ 64 | --config config/ablations/ablate-variable-pitch.py \ 65 | --datasets vctk \ 66 | --gpu $1 67 | python -m promonet.evaluate \ 68 | --config config/ablations/ablate-viterbi.py 
\ 69 | --datasets vctk \ 70 | --gpu $1 71 | 72 | 73 | ############################################################################### 74 | # Baselines 75 | ############################################################################### 76 | 77 | 78 | # Train 79 | python -m promonet.train --config config/baselines/mels.py --gpu $1 80 | 81 | # Evaluate 82 | python -m promonet.evaluate \ 83 | --config config/baselines/mels.py \ 84 | --datasets vctk \ 85 | --gpu $1 86 | python -m promonet.evaluate \ 87 | --config config/baselines/world.py \ 88 | --datasets vctk \ 89 | --gpu $1 90 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | 4 | with open('README.md', encoding='utf8') as file: 5 | long_description = file.read() 6 | 7 | 8 | setup( 9 | name='promonet', 10 | description='Prosody Modification Network', 11 | version='0.0.1', 12 | author='Interactive Audio Lab', 13 | author_email='interactiveaudiolab@gmail.com', 14 | url='https://github.com/maxrmorrison/promonet', 15 | install_requires=[ 16 | 'GPUtil', 17 | 'huggingface-hub', 18 | 'jiwer', 19 | 'librosa', 20 | 'matplotlib', 21 | 'numpy', 22 | 'openai-whisper', 23 | 'penn', 24 | 'ppgs', 25 | 'pypar', 26 | 'pyworld', 27 | 'resampy', 28 | 'scipy', 29 | 'soundfile', 30 | 'transformers', 31 | 'torch', 32 | 'torchaudio', 33 | 'torchutil', 34 | 'umap-learn', 35 | 'vocos[train]', 36 | 'yapecs', 37 | ], 38 | packages=find_packages(), 39 | package_data={'promonet': ['assets/*', 'assets/*/*', 'assets/*/*/*']}, 40 | long_description=long_description, 41 | long_description_content_type='text/markdown', 42 | keywords=['speech', 'prosody', 'editing', 'synthesis', 'pronunciation'], 43 | license='MIT') 44 | --------------------------------------------------------------------------------
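For orientation, a minimal end-to-end sketch of the Python API that run.sh ultimately exercises (the paths, the feature filename suffixes, and the speaker index are hypothetical assumptions, not confirmed conventions):

import promonet

# Analyze speech into editable features (loudness, pitch, periodicity, PPG)
promonet.preprocess.from_files_to_files(
    files=['speech.wav'],
    output_prefixes=['features/speech'],
    gpu=0)

# Resynthesize speech from the (possibly edited) features
promonet.synthesize.from_files_to_files(
    loudness_files=['features/speech-loudness.pt'],
    pitch_files=['features/speech-pitch.pt'],
    periodicity_files=['features/speech-periodicity.pt'],
    ppg_files=['features/speech-ppg.pt'],
    output_files=['edited/speech.wav'],
    speakers=[0],
    gpu=0)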