├── .gitattributes ├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── Samples ├── p257_016.wav ├── p260_131.wav └── p266_015.wav ├── config └── config_v3.json ├── environment.yml ├── images └── LPC-NFS.png ├── inference_hifigan.py ├── inference_hifiglot.py ├── pyproject.toml ├── requirements.txt ├── src └── neural_formant_synthesis │ ├── __init__.py │ ├── dataset.py │ ├── feature_extraction.py │ ├── functions.py │ ├── glotnet │ ├── __init__.py │ ├── model │ │ └── feedforward │ │ │ ├── __init__.py │ │ │ ├── activations.py │ │ │ ├── convolution.py │ │ │ ├── convolution_layer.py │ │ │ ├── convolution_stack.py │ │ │ └── wavenet.py │ └── sigproc │ │ ├── __init__.py │ │ ├── allpole.py │ │ ├── biquad.py │ │ ├── emphasis.py │ │ ├── levinson.py │ │ ├── lfilter.py │ │ ├── lpc.py │ │ ├── lsf.py │ │ ├── melspec.py │ │ └── oscillator.py │ ├── models.py │ ├── sigproc │ ├── __init__.py │ ├── levinson.py │ └── lpc.py │ ├── third_party │ ├── __init__.py │ └── hifi_gan │ │ ├── LICENSE │ │ ├── README.md │ │ ├── __init__.py │ │ ├── config_v1.json │ │ ├── config_v2.json │ │ ├── config_v3.json │ │ ├── env.py │ │ ├── inference.py │ │ ├── inference_e2e.py │ │ ├── meldataset.py │ │ ├── models.py │ │ ├── requirements.txt │ │ ├── train.py │ │ └── utils.py │ └── vctk_preprocessing.py ├── tests ├── test_dataset.py ├── test_train_imports.py └── test_wavenet.py ├── train_e2e_hifigan.py └── train_e2e_hifiglot.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bz2 filter=lfs diff=lfs merge=lfs -text 5 | *.ckpt filter=lfs diff=lfs merge=lfs -text 6 | *.ftz filter=lfs diff=lfs merge=lfs -text 7 | *.gz filter=lfs diff=lfs merge=lfs -text 8 | *.h5 filter=lfs diff=lfs merge=lfs -text 9 | *.joblib filter=lfs diff=lfs merge=lfs -text 10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text 12 | *.model filter=lfs diff=lfs merge=lfs -text 13 | *.msgpack filter=lfs diff=lfs merge=lfs -text 14 | *.npy filter=lfs diff=lfs merge=lfs -text 15 | *.npz filter=lfs diff=lfs merge=lfs -text 16 | *.onnx filter=lfs diff=lfs merge=lfs -text 17 | *.ot filter=lfs diff=lfs merge=lfs -text 18 | *.parquet filter=lfs diff=lfs merge=lfs -text 19 | *.pb filter=lfs diff=lfs merge=lfs -text 20 | *.pickle filter=lfs diff=lfs merge=lfs -text 21 | *.pkl filter=lfs diff=lfs merge=lfs -text 22 | *.pt filter=lfs diff=lfs merge=lfs -text 23 | *.pth filter=lfs diff=lfs merge=lfs -text 24 | *.rar filter=lfs diff=lfs merge=lfs -text 25 | *.safetensors filter=lfs diff=lfs merge=lfs -text 26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 27 | *.tar.* filter=lfs diff=lfs merge=lfs -text 28 | *.tar filter=lfs diff=lfs merge=lfs -text 29 | *.tflite filter=lfs diff=lfs merge=lfs -text 30 | *.tgz filter=lfs diff=lfs merge=lfs -text 31 | *.wasm filter=lfs diff=lfs merge=lfs -text 32 | *.xz filter=lfs diff=lfs merge=lfs -text 33 | *.png filter=lfs diff=lfs merge=lfs -text 34 | *.zip filter=lfs diff=lfs merge=lfs -text 35 | *.zst filter=lfs diff=lfs merge=lfs -text 36 | *tfevents* filter=lfs diff=lfs merge=lfs -text 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | 
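# Project-specific outputs, experiment runs and model checkpoints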
output/* 10 | 11 | experiments/* 12 | ckpt/* 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ 166 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "checkpoints/HiFi-Glot"] 2 | path = checkpoints/HiFi-Glot 3 | url = https://huggingface.co/ljuvela/SourceFilterNeuralFormantSynthesisE2E 4 | [submodule "checkpoints/NFS"] 5 | path = checkpoints/NFS 6 | url = https://huggingface.co/ljuvela/NeuralFormantSynthesis 7 | [submodule "checkpoints/NFS-E2E"] 8 | path = checkpoints/NFS-E2E 9 | url = https://huggingface.co/ljuvela/NeuralFormantSynthesisE2E -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Lauri Juvela 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Formant Synthesis with Differentiable Resonant Filters 2 | 3 | Neural formant synthesis using differtiable resonant filters and source-filter model structure. 
4 | 5 | Authors: [Lauri Juvela][lauri_profile], [Pablo Pérez Zarazaga][pablo_profile], [Gustav Eje Henter][gustav_profile], [Zofia Malisz][zofia_profile] 6 | 7 | [HiFi_link]: https://github.com/jik876/hifi-gan 8 | [GlotNet_link]: https://github.com/ljuvela/GlotNet 9 | [arxiv_link]: http://arxiv.org/abs/placeholder_link 10 | [demopage_link]: https://perezpoz.github.io/SFNeuralFormants 11 | [gustav_profile]: https://people.kth.se/~ghe/ 12 | [pablo_profile]: https://www.kth.se/profile/pablopz 13 | [zofia_profile]: https://www.kth.se/profile/malisz 14 | [lauri_profile]: https://research.aalto.fi/en/persons/lauri-juvela 15 | 16 | [lfs_link]: https://git-lfs.com 17 | 18 | ## Table of contents 19 | 1. [Model overview](#model_struct) 20 | 1. [Sound samples](#sound_samples) 21 | 3. [Repository installation](#install) 22 | 1. [Conda environment](#conda) 23 | 2. [GlotNet](#glotnet) 24 | 3. [HiFi-GAN](#hifi) 25 | 4. [Additional libraries](#additional) 26 | 4. [Pre-trained models](#pretrained) 27 | 5. [Inference](#inference) 28 | 6. [Training](#training) 29 | 7. [Citation information](#citation) 30 | 31 | ## Model overview 32 | 33 | We present a model that performs neural speech synthesis using the structure of the source-filter model, which allows the spectral envelope and the glottal excitation to be inspected and manipulated independently: 34 | 35 | ![Neural formant pipeline following the source-filter model architecture](./images/LPC-NFS.png "Neural formant pipeline following the source-filter model architecture.") 36 | 37 | ### Sound samples 38 | 39 | A description of the presented model, together with sound samples compared against other synthesis/manipulation systems, can be found on the [project's demo webpage][demopage_link]. 40 | 41 | ## Repository installation 42 | 43 | #### Conda environment 44 | 45 | First, we need to create a conda environment to install our dependencies. Use mamba to speed up the process if possible. 46 | ```sh 47 | mamba env create -n neuralformants -f environment.yml 48 | conda activate neuralformants 49 | ``` 50 | 51 | Pre-trained models are available on Hugging Face and can be downloaded using git-lfs. If you don't have git-lfs installed (it's included in `environment.yml`), you can find it [here][lfs_link]. Use the following command to download the pre-trained models: 52 | ```sh 53 | git submodule update --init --recursive 54 | ``` 55 | 56 | Install the package in development mode: 57 | ```sh 58 | pip install -e . 59 | ``` 60 | 61 | 62 | #### GlotNet 63 | GlotNet is partially included for its WaveNet models and DSP functions. The full repository is available [here][GlotNet_link]. 64 | 65 | 66 | #### HiFi-GAN 67 | HiFi-GAN is included in the `hifi_gan` subdirectory. The original source code is available [here][HiFi_link]. 68 | 69 | ## Inference 70 | 71 | We provide scripts to run inference with the end-to-end architecture: an audio file is given as input, and a wav file with the manipulated features is written as output. 72 | 73 | The `--feature_scale` argument controls the scaling applied to pitch (F0) and the formant frequencies.
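A value of `1.0` leaves the corresponding feature unchanged; for example, passing `--feature_scale "[1.2, 1.0, 1.0, 1.0, 1.0]"` raises F0 by roughly 20% on voiced frames while leaving the formant frequencies intact (the value 1.2 is only an illustrative choice).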
The scales are provided as a list of 5 elements with the following order: 74 | ```python 75 | [F0, F1, F2, F3, F4] 76 | ``` 77 | An example with the provided audio samples from the VCTK dataset can be run using: 78 | 79 | HiFi-Glot 80 | ```sh 81 | python inference_hifiglot.py \ 82 | --input_path "./Samples" \ 83 | --output_path "./output/hifi-glot" \ 84 | --config "./checkpoints/HiFi-Glot/config_hifigan.json" \ 85 | --fm_config "./checkpoints/HiFi-Glot/config_feature_map.json" \ 86 | --checkpoint_path "./checkpoints/HiFi-Glot" \ 87 | --feature_scale "[1.0, 1.0, 1.0, 1.0, 1.0]" 88 | ``` 89 | 90 | NFS 91 | ```sh 92 | python inference_hifigan.py \ 93 | --input_path "./Samples" \ 94 | --output_path "./output/nfs" \ 95 | --config "./checkpoints/NFS/config_hifigan.json" \ 96 | --fm_config "./checkpoints/NFS/config_feature_map.json" \ 97 | --checkpoint_path "./checkpoints/NFS" \ 98 | --feature_scale "[1.0, 1.0, 1.0, 1.0, 1.0]" 99 | ``` 100 | 101 | NFS-E2E 102 | ```sh 103 | python inference_hifigan.py \ 104 | --input_path "./Samples" \ 105 | --output_path "./output/nfs-e2e" \ 106 | --config "./checkpoints/NFS-E2E/config_hifigan.json" \ 107 | --fm_config "./checkpoints/NFS-E2E/config_feature_map.json" \ 108 | --checkpoint_path "./checkpoints/NFS-E2E" \ 109 | --feature_scale "[1.0, 1.0, 1.0, 1.0, 1.0]" 110 | ``` 111 | 112 | 113 | ## Model training 114 | 115 | Training of the HiFi-GAN and HiFi-Glot models is possible with the end-to-end architecture by using the the scripts `train_e2e_hifigan.py` and `train_e2e_hifiglot.py`. 116 | 117 | 118 | ## Citation information 119 | 120 | Citation information will be added when a pre-print is available. 121 | -------------------------------------------------------------------------------- /Samples/p257_016.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/Samples/p257_016.wav -------------------------------------------------------------------------------- /Samples/p260_131.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/Samples/p260_131.wav -------------------------------------------------------------------------------- /Samples/p266_015.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/Samples/p266_015.wav -------------------------------------------------------------------------------- /config/config_v3.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "2", 3 | "num_gpus": 1, 4 | "batch_size": 16, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,4], 12 | "upsample_kernel_sizes": [16,16,8], 13 | "upsample_initial_channel": 256, 14 | "resblock_kernel_sizes": [3,5,7], 15 | "resblock_dilation_sizes": [[1,2], [2,6], [3,12]], 16 | 17 | "segment_size": 8192, 18 | "num_mels": 80, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 256, 22 | "win_size": 1024, 23 | "allpole_order": 30, 24 | "pre_emph_coeff": 0.97, 25 | 26 | "sampling_rate": 22050, 27 | 28 | "fmin": 0, 29 | "fmax": 8000, 30 | "fmax_for_loss": null, 31 | 32 | "num_workers": 4, 33 | 34 | "dist_config": { 35 | 
"dist_backend": "nccl", 36 | "dist_url": "tcp://localhost:54321", 37 | "world_size": 1 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - conda-forge 4 | dependencies: 5 | - python=3.12 6 | - pytorch=2.4 7 | - torchaudio 8 | - tensorboard 9 | - matplotlib 10 | - pandas 11 | - pytest 12 | - git-lfs 13 | - pip: 14 | - diffsptk 15 | - pyworld 16 | - tqdm 17 | - ipdb 18 | -------------------------------------------------------------------------------- /images/LPC-NFS.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a4ec577e3dd476e0972ef4053ab13fe534345973e634d72c25653a9c3af89bca 3 | size 456332 4 | -------------------------------------------------------------------------------- /inference_hifigan.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.simplefilter(action='ignore', category=FutureWarning) 3 | import os 4 | import argparse 5 | import json 6 | import torch 7 | from neural_formant_synthesis.third_party.hifi_gan.env import AttrDict, build_env 8 | from neural_formant_synthesis.third_party.hifi_gan.utils import scan_checkpoint 9 | 10 | 11 | from neural_formant_synthesis.glotnet.sigproc.lpc import LinearPredictor 12 | from neural_formant_synthesis.glotnet.sigproc.emphasis import Emphasis 13 | 14 | from neural_formant_synthesis.models import FM_Hifi_Generator, fm_config_obj, Envelope_wavenet, Envelope_conformer 15 | from neural_formant_synthesis.feature_extraction import feature_extractor, Normaliser, MedianPool1d 16 | from neural_formant_synthesis.models import NeuralFormantSynthesisGenerator 17 | 18 | 19 | from neural_formant_synthesis.glotnet.sigproc.levinson import forward_levinson 20 | 21 | import torchaudio as ta 22 | import pandas as pd 23 | from tqdm import tqdm 24 | from glob import glob 25 | 26 | 27 | torch.backends.cudnn.benchmark = True 28 | 29 | 30 | def generate_wave_list(file_list, scale_list, a, h, fm_h): 31 | 32 | torch.cuda.manual_seed(h.seed) 33 | if torch.cuda.is_available(): 34 | device = torch.device('cuda:0') 35 | else: 36 | device = torch.device('cpu') 37 | 38 | target_sr = h.sampling_rate 39 | win_size = h.win_size 40 | hop_size = h.hop_size 41 | 42 | feat_extractor = feature_extractor(sr = target_sr,window_samples = win_size, step_samples = hop_size, formant_ceiling = 10000, max_formants = 4) 43 | median_filter = MedianPool1d(kernel_size = 3, stride = 1, padding = 0, same = True) 44 | pre_emphasis_cpu = Emphasis(alpha = h.pre_emph_coeff) 45 | 46 | normalise_features = Normaliser(target_sr) 47 | 48 | generator = NeuralFormantSynthesisGenerator(fm_config=fm_h, g_config=h, 49 | pretrained_fm=None, 50 | freeze_fm=False, 51 | device=device) 52 | 53 | 54 | print("checkpoints directory : ", a.checkpoint_path) 55 | 56 | if os.path.isdir(a.checkpoint_path): 57 | cp_g = scan_checkpoint(a.checkpoint_path, 'g_') 58 | 59 | 60 | generator.load_generator_e2e_checkpoint(cp_g) 61 | 62 | generator = generator.to(device) 63 | 64 | generator.eval() 65 | 66 | 67 | # Read files from list 68 | for file in tqdm(file_list, total = len(file_list)): 69 | # Read audio and resample if necessary 70 | x, sample_rate = ta.load(file) 71 | x = x[0:1].type(torch.DoubleTensor) 72 | 73 | x = ta.functional.resample(x, sample_rate, target_sr) 74 | 75 | # Get 
features using feature extractor 76 | 77 | x_preemph = pre_emphasis_cpu(x.unsqueeze(0)) 78 | x_preemph = x_preemph.squeeze(0).squeeze(0) 79 | formants, energy, centroid, tilt, pitch, voicing_flag,_, _,_ = feat_extractor(x_preemph) 80 | 81 | # Parameter smoothing and length matching 82 | 83 | formants = median_filter(formants.T.unsqueeze(1)).squeeze(1).T 84 | 85 | pitch = pitch.squeeze(0) 86 | voicing_flag = voicing_flag.squeeze(0) 87 | 88 | # If pitch length is smaller than formants, pad pitch and voicing flag with last value 89 | if pitch.size(0) < formants.size(0): 90 | pitch = torch.nn.functional.pad(pitch, (0, formants.size(0) - pitch.size(0)), mode = 'constant', value = pitch[-1]) 91 | voicing_flag = torch.nn.functional.pad(voicing_flag, (0, formants.size(0) - voicing_flag.size(0)), mode = 'constant', value = voicing_flag[-1]) 92 | # If pitch length is larger than formants, truncate pitch and voicing flag 93 | elif pitch.size(0) > formants.size(0): 94 | pitch = pitch[:formants.size(0)] 95 | voicing_flag = voicing_flag[:formants.size(0)] 96 | 97 | # We can apply manipulation HERE 98 | 99 | log_pitch = torch.log(pitch) 100 | 101 | #pitch = pitch * scale_list[0] 102 | for i in range(voicing_flag.size(0)): 103 | if voicing_flag[i] == 1: 104 | log_pitch[i] = log_pitch[i] + torch.log(torch.tensor(scale_list[0])) 105 | formants[i,0] = formants[i,0] * scale_list[1] 106 | formants[i,1] = formants[i,1] * scale_list[2] 107 | formants[i,2] = formants[i,2] * scale_list[3] 108 | formants[i,3] = formants[i,3] * scale_list[4] 109 | 110 | # Normalise data 111 | log_pitch, formants, tilt, centroid, energy = normalise_features(log_pitch, formants, tilt, centroid, energy) 112 | 113 | #Create input data 114 | #size --> (Batch, features, sequence) 115 | norm_feat = torch.transpose(torch.cat((log_pitch.unsqueeze(1), formants, tilt.unsqueeze(1), centroid.unsqueeze(1), energy.unsqueeze(1), voicing_flag.unsqueeze(1)),dim = -1), 0, 1) 116 | 117 | norm_feat = norm_feat.type(torch.FloatTensor).unsqueeze(0).to(device) 118 | 119 | y_g_hat, _ = generator(norm_feat) 120 | 121 | output_file = os.path.splitext(os.path.basename(file))[0] + '_wave_' + str(scale_list[0]) + '_' + str(scale_list[1]) + '_' + str(scale_list[2]) + '_' + str(scale_list[3]) + '_' + str(scale_list[4]) + '.wav' 122 | output_orig = os.path.splitext(os.path.basename(file))[0] + '_orig.wav' 123 | out_path = os.path.join(a.output_path, output_file) 124 | out_orig_path = os.path.join(a.output_path, output_orig) 125 | 126 | ta.save(out_path, y_g_hat.detach().cpu().squeeze(0), target_sr) 127 | if not os.path.exists(out_orig_path): 128 | ta.save(out_orig_path, x.type(torch.FloatTensor), target_sr) 129 | 130 | def parse_file_list(list_file): 131 | """ 132 | Read text file with paths to the files to process. 133 | """ 134 | file1 = open(list_file, 'r') 135 | lines = file1.read().splitlines() 136 | return lines 137 | 138 | def str_to_list(in_str): 139 | return list(map(float, in_str.strip('[]').split(','))) 140 | 141 | def main(): 142 | 143 | parser = argparse.ArgumentParser() 144 | 145 | parser.add_argument('--input_path', default = None, help="Path to directory containing files to process.") 146 | parser.add_argument('--list_file', default = None, help="Text file containing list of files to process. 
Optional argument to use instead of input_path.") 147 | parser.add_argument('--output_path', default='test_output', help="Path to directory to save processed files") 148 | parser.add_argument('--config', default='', help="Path to HiFi-GAN config json file") 149 | parser.add_argument('--fm_config', default='', help="Path to feature mapping model config json file") 150 | parser.add_argument('--env_config', default='', help="Path to envelope estimation model config json file") 151 | parser.add_argument('--audio_ext', default = '.wav', help="Extension of the audio files to process") 152 | parser.add_argument('--checkpoint_path', help="Path to pre-trained HiFi-GAN model") 153 | parser.add_argument('--feature_scale', help="List of scales for pitch and formant frequencies -- [F0, F1, F2, F3, F4]") 154 | 155 | 156 | a = parser.parse_args() 157 | 158 | with open(a.config) as f: 159 | data = f.read() 160 | 161 | json_config = json.loads(data) 162 | h = AttrDict(json_config) 163 | 164 | with open(a.fm_config) as f: 165 | data = f.read() 166 | json_fm_config = json.loads(data) 167 | fm_h = AttrDict(json_fm_config) 168 | # fm_h = fm_config_obj(json_fm_config) 169 | 170 | # build_env(a.config, 'config.json', a.checkpoint_path) 171 | if a.input_path is not None: 172 | file_list = glob(os.path.join(a.input_path,'*' + a.audio_ext)) 173 | elif a.list_file is not None: 174 | file_list = parse_file_list(a.list_file) 175 | else: 176 | raise ValueError('Input arguments should include either input_path or file_list') 177 | 178 | if not os.path.exists(a.output_path): 179 | os.makedirs(a.output_path, exist_ok=True) 180 | 181 | scale_list = str_to_list(a.feature_scale) 182 | if len(scale_list) != 5: 183 | raise ValueError('The scaling vector must contain 5 features: [F0, F1, F2, F3, F4]') 184 | 185 | torch.manual_seed(h.seed) 186 | 187 | generate_wave_list(file_list, scale_list, a, h, fm_h) 188 | 189 | 190 | if __name__ == '__main__': 191 | main() 192 | -------------------------------------------------------------------------------- /inference_hifiglot.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.simplefilter(action='ignore', category=FutureWarning) 3 | import os 4 | import argparse 5 | import json 6 | import torch 7 | from neural_formant_synthesis.third_party.hifi_gan.env import AttrDict, build_env 8 | #from neural_formant_synthesis.third_party.hifi_gan.models import discriminator_metrics 9 | from neural_formant_synthesis.third_party.hifi_gan.utils import scan_checkpoint 10 | 11 | 12 | from neural_formant_synthesis.glotnet.sigproc.lpc import LinearPredictor 13 | from neural_formant_synthesis.glotnet.sigproc.emphasis import Emphasis 14 | 15 | from neural_formant_synthesis.models import FM_Hifi_Generator, fm_config_obj, Envelope_wavenet, Envelope_conformer 16 | from neural_formant_synthesis.feature_extraction import feature_extractor, Normaliser, MedianPool1d 17 | from neural_formant_synthesis.models import SourceFilterFormantSynthesisGenerator 18 | 19 | 20 | from neural_formant_synthesis.glotnet.sigproc.levinson import forward_levinson 21 | 22 | import torchaudio as ta 23 | import pandas as pd 24 | from tqdm import tqdm 25 | from glob import glob 26 | 27 | 28 | torch.backends.cudnn.benchmark = True 29 | 30 | 31 | def generate_wave_list(file_list, scale_list, a, h, fm_h): 32 | 33 | torch.cuda.manual_seed(h.seed) 34 | if torch.cuda.is_available(): 35 | device = torch.device('cuda:0') 36 | else: 37 | device = torch.device('cpu') 38 | 39 | 
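    # Sample rate and framing parameters come from the HiFi-GAN config (h) and are shared with the feature extractor below.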
target_sr = h.sampling_rate 40 | win_size = h.win_size 41 | hop_size = h.hop_size 42 | 43 | feat_extractor = feature_extractor(sr = target_sr,window_samples = win_size, step_samples = hop_size, formant_ceiling = 10000, max_formants = 4) 44 | median_filter = MedianPool1d(kernel_size = 3, stride = 1, padding = 0, same = True) 45 | pre_emphasis_cpu = Emphasis(alpha = h.pre_emph_coeff) 46 | 47 | normalise_features = Normaliser(target_sr) 48 | 49 | generator = SourceFilterFormantSynthesisGenerator( 50 | fm_config=fm_h, 51 | g_config=h, 52 | pretrained_fm=None, 53 | freeze_fm=False, 54 | device=device) 55 | 56 | 57 | print("checkpoints directory : ", a.checkpoint_path) 58 | 59 | if os.path.isdir(a.checkpoint_path): 60 | cp_g = scan_checkpoint(a.checkpoint_path, 'g_') 61 | 62 | 63 | generator.load_generator_e2e_checkpoint(cp_g) 64 | 65 | generator = generator.to(device) 66 | 67 | generator.eval() 68 | 69 | 70 | # Read files from list 71 | for file in tqdm(file_list, total = len(file_list)): 72 | # Read audio and resample if necessary 73 | x, sample_rate = ta.load(file) 74 | x = x[0:1].type(torch.DoubleTensor) 75 | 76 | x = ta.functional.resample(x, sample_rate, target_sr) 77 | 78 | # Get features using feature extractor 79 | 80 | x_preemph = pre_emphasis_cpu(x.unsqueeze(0)) 81 | x_preemph = x_preemph.squeeze(0).squeeze(0) 82 | formants, energy, centroid, tilt, pitch, voicing_flag,_, _,_ = feat_extractor(x_preemph) 83 | 84 | # Parameter smoothing and length matching 85 | 86 | formants = median_filter(formants.T.unsqueeze(1)).squeeze(1).T 87 | 88 | pitch = pitch.squeeze(0) 89 | voicing_flag = voicing_flag.squeeze(0) 90 | 91 | # If pitch length is smaller than formants, pad pitch and voicing flag with last value 92 | if pitch.size(0) < formants.size(0): 93 | pitch = torch.nn.functional.pad(pitch, (0, formants.size(0) - pitch.size(0)), mode = 'constant', value = pitch[-1]) 94 | voicing_flag = torch.nn.functional.pad(voicing_flag, (0, formants.size(0) - voicing_flag.size(0)), mode = 'constant', value = voicing_flag[-1]) 95 | # If pitch length is larger than formants, truncate pitch and voicing flag 96 | elif pitch.size(0) > formants.size(0): 97 | pitch = pitch[:formants.size(0)] 98 | voicing_flag = voicing_flag[:formants.size(0)] 99 | 100 | # We can apply manipulation HERE 101 | 102 | log_pitch = torch.log(pitch) 103 | 104 | #pitch = pitch * scale_list[0] 105 | for i in range(voicing_flag.size(0)): 106 | if voicing_flag[i] == 1: 107 | log_pitch[i] = log_pitch[i] + torch.log(torch.tensor(scale_list[0])) 108 | formants[i,0] = formants[i,0] * scale_list[1] 109 | formants[i,1] = formants[i,1] * scale_list[2] 110 | formants[i,2] = formants[i,2] * scale_list[3] 111 | formants[i,3] = formants[i,3] * scale_list[4] 112 | 113 | # Normalise data 114 | log_pitch, formants, tilt, centroid, energy = normalise_features(log_pitch, formants, tilt, centroid, energy) 115 | 116 | #Create input data 117 | #size --> (Batch, features, sequence) 118 | norm_feat = torch.transpose(torch.cat((log_pitch.unsqueeze(1), formants, tilt.unsqueeze(1), centroid.unsqueeze(1), energy.unsqueeze(1), voicing_flag.unsqueeze(1)),dim = -1), 0, 1) 119 | 120 | norm_feat = norm_feat.type(torch.FloatTensor).unsqueeze(0).to(device) 121 | 122 | y_g_hat, _, _ = generator(norm_feat) 123 | 124 | output_file = os.path.splitext(os.path.basename(file))[0] + '_wave_' + str(scale_list[0]) + '_' + str(scale_list[1]) + '_' + str(scale_list[2]) + '_' + str(scale_list[3]) + '_' + str(scale_list[4]) + '.wav' 125 | output_orig = 
os.path.splitext(os.path.basename(file))[0] + '_orig.wav' 126 | out_path = os.path.join(a.output_path, output_file) 127 | out_orig_path = os.path.join(a.output_path, output_orig) 128 | 129 | ta.save(out_path, y_g_hat.detach().cpu().squeeze(0), target_sr) 130 | if not os.path.exists(out_orig_path): 131 | ta.save(out_orig_path, x.type(torch.FloatTensor), target_sr) 132 | 133 | def parse_file_list(list_file): 134 | """ 135 | Read text file with paths to the files to process. 136 | """ 137 | file1 = open(list_file, 'r') 138 | lines = file1.read().splitlines() 139 | return lines 140 | 141 | def str_to_list(in_str): 142 | return list(map(float, in_str.strip('[]').split(','))) 143 | 144 | def main(): 145 | 146 | parser = argparse.ArgumentParser() 147 | 148 | parser.add_argument('--input_path', default = None, help="Path to directory containing files to process.") 149 | parser.add_argument('--list_file', default = None, help="Text file containing list of files to process. Optional argument to use instead of input_path.") 150 | parser.add_argument('--output_path', default='test_output', help="Path to directory to save processed files") 151 | parser.add_argument('--config', default='', help="Path to HiFi-GAN config json file") 152 | parser.add_argument('--fm_config', default='', help="Path to feature mapping model config json file") 153 | parser.add_argument('--env_config', default='', help="Path to envelope estimation model config json file") 154 | parser.add_argument('--audio_ext', default = '.wav', help="Extension of the audio files to process") 155 | parser.add_argument('--checkpoint_path', help="Path to pre-trained HiFi-GAN model") 156 | parser.add_argument('--feature_scale', help="List of scales for pitch and formant frequencies -- [F0, F1, F2, F3, F4]") 157 | 158 | 159 | a = parser.parse_args() 160 | 161 | with open(a.config) as f: 162 | data = f.read() 163 | 164 | json_config = json.loads(data) 165 | h = AttrDict(json_config) 166 | 167 | with open(a.fm_config) as f: 168 | data = f.read() 169 | json_fm_config = json.loads(data) 170 | fm_h = AttrDict(json_fm_config) 171 | 172 | # build_env(a.config, 'config.json', a.checkpoint_path) 173 | if a.input_path is not None: 174 | file_list = glob(os.path.join(a.input_path,'*' + a.audio_ext)) 175 | elif a.list_file is not None: 176 | file_list = parse_file_list(a.list_file) 177 | else: 178 | raise ValueError('Input arguments should include either input_path or file_list') 179 | 180 | if not os.path.exists(a.output_path): 181 | os.makedirs(a.output_path, exist_ok=True) 182 | 183 | scale_list = str_to_list(a.feature_scale) 184 | if len(scale_list) != 5: 185 | raise ValueError('The scaling vector must contain 5 features: [F0, F1, F2, F3, F4]') 186 | 187 | torch.manual_seed(h.seed) 188 | 189 | generate_wave_list(file_list, scale_list, a, h, fm_h) 190 | 191 | 192 | if __name__ == '__main__': 193 | main() 194 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "NeuralFormantSynthesis" # Required 3 | version = "1.0.0" # Required 4 | description = "" # Optional 5 | readme = "README.md" # Optional 6 | requires-python = ">=3.7" 7 | license = {file = "LICENSE.txt"} 8 | authors = [ 9 | {name = "Lauri Juvela", email = "lauri.juvela@aalto.fi" } 10 | ] 11 | 12 | [tool.setuptools] 13 | include-package-data = true 14 | 15 | [tool.setuptools.packages.find] 16 | namespaces = true 17 | where = ["src"] 18 | 19 | 
[build-system] 20 | # These are the assumed default build requirements from pip: 21 | # https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support 22 | requires = ["setuptools>=43.0.0", "wheel"] 23 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy == 1.24.4 2 | torch 3 | torchaudio 4 | diffsptk 5 | pyworld 6 | tqdm 7 | ipdb 8 | tensorboard -------------------------------------------------------------------------------- /src/neural_formant_synthesis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/src/neural_formant_synthesis/__init__.py -------------------------------------------------------------------------------- /src/neural_formant_synthesis/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from logging import warning 4 | import random 5 | import torch 6 | import torchaudio as ta 7 | from torch.utils.data import Dataset 8 | from typing import List, Tuple 9 | from .feature_extraction import Normaliser 10 | 11 | from neural_formant_synthesis.glotnet.sigproc.levinson import forward_levinson 12 | 13 | from neural_formant_synthesis.third_party.hifi_gan.meldataset import mel_spectrogram 14 | from tqdm import tqdm 15 | 16 | class FeatureDataset(Dataset): 17 | """ Dataset for audio files """ 18 | 19 | def __init__(self, 20 | dataset_dir: str, 21 | segment_len: int, 22 | sampling_rate: int, 23 | feature_ext: str = '.pt', 24 | audio_ext: str = '.wav', 25 | causal: bool = True, 26 | non_causal_segment: int = 0, 27 | normalise: bool = True, 28 | dtype: torch.dtype = torch.float32, 29 | device: str = 'cpu'): 30 | """ 31 | Args: 32 | config: Config object 33 | audio_dir: directory containing audio files 34 | audio_ext: file extension of audio files 35 | file_list: list of audio files 36 | transforms: transforms to apply to audio, output as auxiliary feature for conditioning 37 | dtype: data type of output 38 | """ 39 | 40 | self.dataset_dir = dataset_dir 41 | self.segment_len = segment_len 42 | self.sampling_rate = sampling_rate 43 | self.feature_ext = feature_ext 44 | self.audio_ext = audio_ext 45 | self.normalise = normalise 46 | self.dtype = dtype 47 | self.device = device 48 | 49 | self.causal = causal 50 | self.non_causal_segment = non_causal_segment 51 | 52 | self.total_segment_len = segment_len + non_causal_segment 53 | 54 | self.normaliser = Normaliser(self.sampling_rate) 55 | 56 | self.audio_files = glob(os.path.join( 57 | self.dataset_dir, f"*{self.feature_ext}")) 58 | 59 | # elements are (filename, start, stop) 60 | self.segment_index: List[Tuple(str, int, int)] = [] 61 | 62 | for f in tqdm(self.audio_files, total=len(self.audio_files)): 63 | self._check_feature_file(f) 64 | 65 | def _check_feature_file(self, f): 66 | 67 | 68 | features = torch.load(f) 69 | input_features = features['Energy'] 70 | num_samples = input_features.size(0) 71 | 72 | # read sements from the end (to minimize zero padding) 73 | stop = num_samples 74 | start = stop - self.segment_len # Only consider causal length to set start index. 
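        # Clamp the start index to the beginning of the file; any shortfall is replicate-padded later in __getitem__.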
75 | start = max(start, 0) 76 | while stop > 0: 77 | #if self.causal: 78 | self.segment_index.append( 79 | (os.path.realpath(f), start, stop)) 80 | #else: 81 | # stop_noncausal = min(stop + self.non_causal_segment, num_samples) 82 | # self.segment_index.append( 83 | # (os.path.realpath(f), start, stop_noncausal)) 84 | 85 | stop = stop - self.segment_len 86 | start = stop - self.segment_len 87 | start = max(start, 0) 88 | 89 | def __len__(self): 90 | return len(self.segment_index) 91 | 92 | def __getitem__(self, i): 93 | f, start, stop = self.segment_index[i] 94 | data = torch.load(f) 95 | features = torch.cat((data["Pitch"].unsqueeze(1), data["Formants"], data["Tilt"].unsqueeze(1), data["Centroid"].unsqueeze(1), data["Energy"].unsqueeze(1), data["Voicing"].unsqueeze(1)), dim = -1) 96 | out_features = data["R_Coeff"] 97 | 98 | x = features[start:stop,:] 99 | y = out_features[start:stop,:] 100 | 101 | if torch.any(torch.isnan(data["Pitch"])): 102 | raise ValueError("Pitch features are NaN before normalisation.") 103 | if torch.any(torch.isnan(data["Formants"])): 104 | raise ValueError("Formants features are NaN before normalisation.") 105 | if torch.any(torch.isnan(data["Tilt"])): 106 | raise ValueError("Tilt features are NaN before normalisation.") 107 | if torch.any(torch.isnan(data["Centroid"])): 108 | raise ValueError("Centroid features are NaN before normalisation.") 109 | if torch.any(torch.isnan(data["Energy"])): 110 | raise ValueError("Energy features are NaN before normalisation.") 111 | 112 | num_samples = features.size(0) 113 | pad_left = 0 114 | pad_right = 0 115 | 116 | if start == 0: 117 | pad_left = self.segment_len - x.size(0) 118 | 119 | # zero pad to segment_len + padding 120 | x = torch.transpose(torch.nn.functional.pad(torch.transpose(x, 0, 1), ( pad_left, 0), mode='replicate'), 0, 1) # seq_len, n_channels 121 | y = torch.transpose(torch.nn.functional.pad(torch.transpose(y, 0, 1), ( pad_left, 0), mode='replicate'), 0, 1) 122 | 123 | if not self.causal: 124 | remaining_samples = min(num_samples - stop, self.non_causal_segment) 125 | x = torch.cat((x, features[stop:stop + remaining_samples,:]), dim=0) 126 | y = torch.cat((y, out_features[stop:stop + remaining_samples,:]), dim=0) 127 | 128 | pad_right = self.non_causal_segment - remaining_samples 129 | if pad_right > 0: 130 | x = torch.transpose(torch.nn.functional.pad(torch.transpose(x, 0, 1), ( 0, pad_right), mode='replicate'), 0, 1) # seq_len, n_channels 131 | y = torch.transpose(torch.nn.functional.pad(torch.transpose(y, 0, 1), ( 0, pad_right), mode='replicate'), 0, 1) 132 | 133 | if not(x.size(0) == self.total_segment_len): 134 | raise ValueError('Padding in the wrong dimension') 135 | 136 | x = torch.transpose(torch.cat((torch.cat(self.normaliser(x[...,0:1], x[...,1:5], x[...,5:6], x[...,6:7], x[...,7:8]),dim = -1), x[...,8:9]),dim = -1), 0, 1) 137 | y = torch.transpose(y, 0, 1) 138 | 139 | if torch.any(torch.isnan(x)): 140 | raise ValueError("Output x features are NaN.") 141 | if torch.any(torch.isnan(y)): 142 | raise ValueError("Output y features are NaN.") 143 | # Set sequence len as last dimension. 
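        # x: (9, total_segment_len) — normalised pitch, formants F1–F4, tilt, centroid, energy, plus the voicing flag; y: reflection-coefficient targets with the same frame count.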
144 | return x.type(torch.FloatTensor).to(self.device), y.type(torch.FloatTensor).to(self.device) 145 | 146 | class FeatureDataset_with_Mel(Dataset): 147 | def __init__(self, 148 | dataset_dir: str, 149 | segment_len: int, 150 | sampling_rate: int, 151 | frame_size: int, 152 | hop_size:int, 153 | feature_ext: str = '.pt', 154 | audio_ext: str = '.wav', 155 | causal: bool = True, 156 | non_causal_segment: int = 0, 157 | normalise: bool = True, 158 | dtype: torch.dtype = torch.float32, 159 | device: str = 'cpu'): 160 | """ 161 | Args: 162 | config: Config object 163 | audio_dir: directory containing audio files 164 | audio_ext: file extension of audio files 165 | file_list: list of audio files 166 | transforms: transforms to apply to audio, output as auxiliary feature for conditioning 167 | dtype: data type of output 168 | """ 169 | 170 | self.dataset_dir = dataset_dir 171 | self.segment_len = segment_len 172 | self.sampling_rate = sampling_rate 173 | self.frame_size = frame_size 174 | self.hop_size = hop_size 175 | self.feature_ext = feature_ext 176 | self.audio_ext = audio_ext 177 | self.normalise = normalise 178 | self.dtype = dtype 179 | self.device = device 180 | 181 | self.causal = causal 182 | self.non_causal_segment = non_causal_segment 183 | 184 | self.total_segment_len = segment_len + non_causal_segment 185 | 186 | self.melspec = ta.transforms.MelSpectrogram(sample_rate = self.sampling_rate, n_fft = self.frame_size, win_length = self.frame_size, hop_length = self.hop_size, n_mels = 80) 187 | 188 | self.normaliser = Normaliser(self.sampling_rate) 189 | 190 | self.audio_files = glob(os.path.join( 191 | self.dataset_dir, f"*{self.feature_ext}")) 192 | 193 | # elements are (filename, start, stop) 194 | self.segment_index: List[Tuple(str, int, int)] = [] 195 | 196 | for f in tqdm(self.audio_files, total=len(self.audio_files)): 197 | self._check_feature_file(f) 198 | 199 | def _check_feature_file(self, f): 200 | 201 | 202 | features = torch.load(f) 203 | input_features = features['Energy'] 204 | num_samples = input_features.size(0) 205 | 206 | # read sements from the end (to minimize zero padding) 207 | stop = num_samples 208 | start = stop - self.segment_len # Only consider causal length to set start index. 
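        # Mirrors FeatureDataset._check_feature_file; segment indices are in frames, and __getitem__ converts them to sample offsets via hop_size.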
209 | start = max(start, 0) 210 | while stop > 0: 211 | #if self.causal: 212 | self.segment_index.append( 213 | (os.path.realpath(f), start, stop)) 214 | #else: 215 | # stop_noncausal = min(stop + self.non_causal_segment, num_samples) 216 | # self.segment_index.append( 217 | # (os.path.realpath(f), start, stop_noncausal)) 218 | 219 | stop = stop - self.segment_len 220 | start = stop - self.segment_len 221 | start = max(start, 0) 222 | 223 | def __len__(self): 224 | return len(self.segment_index) 225 | 226 | def __getitem__(self, i): 227 | 228 | f, start, stop = self.segment_index[i] 229 | 230 | audio_file = os.path.splitext(f)[0] + self.audio_ext 231 | start_samples = start * self.hop_size 232 | stop_samples = (stop - 1) * self.hop_size 233 | 234 | data = torch.load(f) 235 | features = torch.cat((torch.log(torch.clamp(data["Pitch"],50, 2000)).unsqueeze(1), data["Formants"], data["Tilt"].unsqueeze(1), data["Centroid"].unsqueeze(1), data["Energy"].unsqueeze(1), data["Voicing"].unsqueeze(1)), dim = -1) 236 | out_features = data["R_Coeff"] 237 | 238 | x = features[start:stop,:] 239 | y = out_features[start:stop,:] 240 | 241 | audio, sample_rate = ta.load(audio_file) 242 | 243 | if sample_rate != self.sampling_rate: 244 | audio = ta.functional.resample(audio, sample_rate, self.sampling_rate) 245 | audio_segment = audio[...,start_samples:stop_samples] 246 | 247 | num_samples = features.size(0) 248 | pad_left = 0 249 | pad_right = 0 250 | 251 | if start == 0: 252 | pad_left = self.segment_len - x.size(0) 253 | pad_left_audio = pad_left * self.hop_size 254 | 255 | # zero pad to segment_len + padding 256 | x = torch.transpose(torch.nn.functional.pad(torch.transpose(x, 0, 1), ( pad_left, 0), mode='replicate'), 0, 1) # seq_len, n_channels 257 | y = torch.transpose(torch.nn.functional.pad(torch.transpose(y, 0, 1), ( pad_left, 0), mode='replicate'), 0, 1) 258 | audio_segment = torch.nn.functional.pad(audio_segment, ( pad_left_audio, 0), mode='constant', value=0) 259 | 260 | if not self.causal: 261 | remaining_samples = min(num_samples - stop, self.non_causal_segment) 262 | remaining_samples_audio = remaining_samples * self.hop_size 263 | 264 | x = torch.cat((x, features[stop:stop + remaining_samples,:]), dim=0) 265 | y = torch.cat((y, out_features[stop:stop + remaining_samples,:]), dim=0) 266 | audio_segment = torch.cat((audio_segment, audio[...,stop_samples:stop_samples + remaining_samples_audio]), dim=-1) 267 | 268 | pad_right = self.non_causal_segment - remaining_samples 269 | if pad_right > 0: 270 | x = torch.transpose(torch.nn.functional.pad(torch.transpose(x, 0, 1), ( 0, pad_right), mode='replicate'), 0, 1) # seq_len, n_channels 271 | y = torch.transpose(torch.nn.functional.pad(torch.transpose(y, 0, 1), ( 0, pad_right), mode='replicate'), 0, 1) 272 | pad_right_audio = pad_right * self.hop_size 273 | audio_segment = torch.nn.functional.pad(audio_segment, ( pad_right_audio, 0), mode='constant', value=0) 274 | 275 | if not(x.size(0) == self.total_segment_len): 276 | raise ValueError('Padding in the wrong dimension') 277 | 278 | x = torch.transpose(torch.cat((torch.cat(self.normaliser(x[...,0:1], x[...,1:5], x[...,5:6], x[...,6:7], x[...,7:8]),dim = -1), x[...,8:9]),dim = -1), 0, 1) 279 | y = torch.transpose(y, 0, 1) 280 | 281 | mels = self.melspec(audio_segment).squeeze(0) 282 | 283 | if torch.any(torch.isnan(x)): 284 | raise ValueError("Output x features are NaN.") 285 | if torch.any(torch.isnan(y)): 286 | raise ValueError("Output y features are NaN.") 287 | # Set sequence len as last dimension. 
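        # Returns frame-wise features x, reflection-coefficient targets y, the time-domain audio segment, and its mel spectrogram (80 bins).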
288 | return x.type(torch.FloatTensor).to(self.device), y.type(torch.FloatTensor).to(self.device), audio_segment.squeeze(0).type(torch.FloatTensor).to(self.device), mels.type(torch.FloatTensor).to(self.device) 289 | 290 | class FeatureDataset_List(Dataset): 291 | def __init__(self, 292 | dataset_dir:str, 293 | config, 294 | sampling_rate: int, 295 | frame_size: int, 296 | hop_size:int, 297 | feature_ext: str = '.pt', 298 | audio_ext: str = '.wav', 299 | segment_length: int = None, 300 | normalise: bool = True, 301 | shuffle: bool = False, 302 | dtype: torch.dtype = torch.float32, 303 | device: str = 'cpu'): 304 | 305 | self.config = config 306 | 307 | self.dataset_dir = dataset_dir 308 | self.sampling_rate = sampling_rate 309 | self.frame_size = frame_size 310 | self.hop_size = hop_size 311 | self.feature_ext = feature_ext 312 | self.audio_ext = audio_ext 313 | self.normalise = normalise 314 | self.dtype = dtype 315 | self.device = device 316 | 317 | self.shuffle = shuffle 318 | 319 | self.segment_length = segment_length 320 | 321 | self.mel_spectrogram = ta.transforms.MelSpectrogram(sample_rate = self.sampling_rate, n_fft=self.frame_size, win_length = self.frame_size, hop_length = self.hop_size, f_min = 0.0, f_max = 8000, n_mels = 80) 322 | self.normaliser = Normaliser(self.sampling_rate) 323 | 324 | self.get_file_list() 325 | 326 | def get_file_list(self): 327 | 328 | self.file_list = glob(os.path.join(self.dataset_dir, '*' + self.feature_ext)) 329 | if self.shuffle: 330 | #indices = torch.randperm(len(file_list)) 331 | #file_list = file_list[indices] 332 | random.shuffle(self.file_list) 333 | 334 | def __len__(self): 335 | return len(self.file_list) 336 | 337 | def __getitem__(self, index): 338 | 339 | feature_file = self.file_list[index] 340 | audio_file = os.path.splitext(feature_file)[0] + self.audio_ext 341 | 342 | data = torch.load(feature_file) 343 | x = torch.cat((data["Pitch"].unsqueeze(1), data["Formants"], data["Tilt"].unsqueeze(1), data["Centroid"].unsqueeze(1), data["Energy"].unsqueeze(1), data["Voicing"].unsqueeze(1)), dim = -1) 344 | y = data["R_Coeff"] 345 | 346 | audio, sample_rate = ta.load(audio_file) 347 | 348 | if sample_rate != self.sampling_rate: 349 | audio = ta.functional.resample(audio, sample_rate, self.sampling_rate) 350 | 351 | if self.segment_length is None: 352 | self.segment_length = x.size(0) 353 | 354 | audio_total_len = int(x.size(0) * self.hop_size) 355 | 356 | if audio.size(1) < audio_total_len: 357 | audio = torch.unsqueeze(torch.nn.functional.pad(audio.squeeze(0), (0,int(audio_total_len - audio.size(1))),'constant'),0) 358 | else: 359 | audio = audio[:,:audio_total_len] 360 | 361 | if self.segment_length <= x.size(0): 362 | 363 | max_segment_start = x.size(0) - self.segment_length 364 | segment_start = random.randint(0, max_segment_start) 365 | 366 | x = x[segment_start:segment_start + self.segment_length,:] 367 | y = y[segment_start:segment_start + self.segment_length,:] 368 | 369 | audio_start = int(segment_start * self.hop_size) 370 | audio_segment_len = int(self.segment_length * self.hop_size) 371 | audio = audio[:,audio_start:audio_start + audio_segment_len] 372 | elif self.segment_length > x.size(0): 373 | diff = self.segment_length - x.size(0) 374 | x = torch.transpose(torch.nn.functional.pad(torch.transpose(x,0,1),(0,diff),'replicate'),0,1) 375 | y = torch.transpose(torch.nn.functional.pad(torch.transpose(y,0,1),(0,diff),'replicate'),0,1) 376 | 377 | audio_segment_diff = int(self.hop_size * diff) 378 | audio = 
torch.unsqueeze(torch.nn.functional.pad(audio.squeeze(0), (0,audio_segment_diff),'constant'),0) 379 | 380 | x = torch.transpose(torch.cat((torch.cat(self.normaliser(x[...,0:1], x[...,1:5], x[...,5:6], x[...,6:7], x[...,7:8]),dim = -1), x[...,8:9], y),dim = -1), 0, 1) 381 | #x = torch.transpose(torch.cat((torch.cat(self.normaliser(x[...,0:1], x[...,1:5], x[...,5:6], x[...,6:7], x[...,7:8]),dim = -1), x[...,8:9]),dim = -1), 0, 1) 382 | 383 | y = forward_levinson(y) 384 | 385 | y = torch.transpose(y, 0, 1) 386 | 387 | y_mel = mel_spectrogram(audio,sampling_rate = self.sampling_rate, n_fft=self.frame_size, win_size = self.frame_size, hop_size = self.hop_size, fmin = 0.0, fmax = self.config.fmax_for_loss, num_mels = 80)#self.mel_spectrogram(audio) 388 | 389 | return x.type(torch.FloatTensor).to(self.device), y.type(torch.FloatTensor).to(self.device), audio.squeeze(0).type(torch.FloatTensor).to(self.device), y_mel.squeeze().type(torch.FloatTensor).to(self.device) -------------------------------------------------------------------------------- /src/neural_formant_synthesis/feature_extraction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import diffsptk 3 | from neural_formant_synthesis.functions import frame_energy, spectral_centroid, tilt_levinson, root_to_formant, levinson, pitch_extraction 4 | 5 | 6 | class feature_extractor(torch.nn.Module): 7 | """ 8 | Class to extract the features for neural formant synthesis. 9 | Params: 10 | sr: sample rate 11 | window_samples: window size in samples 12 | step_samples: step size in samples 13 | formant_ceiling: formant ceiling in Hz 14 | max_formants: maximum number of formants to estimate 15 | allpole_order: allpole order for the LPC analysis 16 | """ 17 | def __init__(self, sr, window_samples, step_samples, formant_ceiling, max_formants, allpole_order = 10, r_coeff_order = 30): 18 | super().__init__() 19 | self.sr = sr 20 | self.window_samples = window_samples 21 | self.step_samples = step_samples 22 | self.formant_ceiling = formant_ceiling 23 | self.max_formants = max_formants 24 | self.allpole_order = allpole_order 25 | self.r_coeff_order = r_coeff_order 26 | 27 | self.framer = diffsptk.Frame( 28 | frame_length = self.window_samples, 29 | frame_period = self.step_samples 30 | ) 31 | self.windower = diffsptk.Window( 32 | in_length = self.window_samples 33 | ) 34 | 35 | # self.root_finder = diffsptk.DurandKernerMethod(self.allpole_order) 36 | self.root_finder = diffsptk.PolynomialToRoots(self.allpole_order) 37 | 38 | def forward(self, x): 39 | # Signal windowing 40 | x_frame = self.framer(x) 41 | x_window = self.windower(x_frame) 42 | 43 | # STFT 44 | 45 | x_spec = torch.fft.rfft(x_window, dim = -1) 46 | 47 | ds_samples = int(self.formant_ceiling/self.sr * x_spec.size(-1)) 48 | x_ds = x_spec[...,:ds_samples] 49 | 50 | x_ds_acorr = torch.fft.irfft(x_ds * torch.conj(x_ds), dim = -1) 51 | 52 | x_us_acorr = torch.fft.irfft(x_spec * torch.conj(x_spec), dim = -1) 53 | 54 | # Calculate formants 55 | 56 | ap_env, _ = levinson(x_ds_acorr, self.allpole_order) 57 | _, r_coeff_ref = levinson(x_us_acorr, self.r_coeff_order) 58 | 59 | roots_env = self.root_finder(ap_env) 60 | formants = root_to_formant(roots_env, self.formant_ceiling, self.max_formants) 61 | # Calculate other features 62 | energy = 10 * torch.log10(frame_energy(x_frame)) 63 | centroid = spectral_centroid(x_frame, self.sr) 64 | tilt = tilt_levinson(x_ds_acorr) 65 | 66 | # Extract pitch from audio signal 67 | pitch, voicing, ignored = 
pitch_extraction(x = x, sr = self.sr, window_samples = self.window_samples, step_samples = self.step_samples, fmin = 50, fmax = 500)#penn.dsp.dio.from_audio(audio = x, sample_rate = self.sr, hopsize = hopsize, fmin = 50, fmax = 500) 68 | pitch = torch.log(pitch) 69 | return formants, energy, centroid, tilt, pitch, voicing, r_coeff_ref, x_ds, ignored 70 | 71 | class MedianPool1d(torch.nn.Module): 72 | """ Median pool (usable as median filter when stride=1) module. 73 | 74 | Args: 75 | kernel_size: size of pooling kernel 76 | stride: pool stride 77 | padding: pool padding 78 | same: override padding and enforce same padding, boolean 79 | """ 80 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False): 81 | super(MedianPool1d, self).__init__() 82 | self.k = kernel_size 83 | self.stride = stride 84 | self.padding = padding 85 | self.same = same 86 | 87 | def _padding(self, x): 88 | if self.same: 89 | iw = x.size()[-1] 90 | if iw % self.stride == 0: 91 | pw = max(self.k - self.stride, 0) 92 | else: 93 | pw = max(self.k - (iw % self.stride), 0) 94 | pl = pw // 2 95 | pr = pw - pl 96 | 97 | padding = (pl, pr) 98 | else: 99 | padding = self.padding 100 | return padding 101 | 102 | def forward(self, x): 103 | x = torch.nn.functional.pad(x, self._padding(x), mode='reflect') 104 | x = x.unfold(2, self.k, self.stride) 105 | x = x.contiguous().view(x.size()[:-1] + (-1,)).median(dim=-1)[0] 106 | return x 107 | 108 | class Normaliser(torch.nn.Module): 109 | """ 110 | Normalisation of features to a set of hard limits for training with online feature extraction. 111 | Params: 112 | sample_rate: sample rate 113 | pitch_lims: tuple with pitch limits in log(Hz) (lower, upper) 114 | formant_lims: tuple with formant limits in Hz (f1_lower, f1_upper, f2_lower, f2_upper, f3_lower, f3_upper, f4_lower, f4_upper) 115 | tilt_lims: tuple with spectral tilt limits (lower, upper) 116 | centroid_lims: tuple with spectral centroid limits in Hz (lower, upper) 117 | energy_lims: tuple with energy limits in dB (lower, upper) 118 | """ 119 | def __init__(self, sample_rate, pitch_lims = None, formant_lims = None, tilt_lims = None, centroid_lims = None, energy_lims = None): 120 | super().__init__() 121 | self.sample_rate = sample_rate 122 | if pitch_lims is not None: 123 | self.pitch_lower = pitch_lims[0] 124 | self.pitch_upper = pitch_lims[1] 125 | else: 126 | self.pitch_lower = 3.9 # 50 Hz 127 | self.pitch_upper = 7.3 # 1500 Hz 128 | 129 | if formant_lims is not None: 130 | self.f1_lower = formant_lims[0] 131 | self.f1_upper = formant_lims[1] 132 | self.f2_lower = formant_lims[2] 133 | self.f2_upper = formant_lims[3] 134 | self.f3_lower = formant_lims[4] 135 | self.f3_upper = formant_lims[5] 136 | self.f4_lower = formant_lims[6] 137 | self.f4_upper = formant_lims[7] 138 | else: 139 | self.f1_lower = 200 140 | self.f1_upper = 900 141 | self.f2_lower = 550 142 | self.f2_upper = 2450 143 | self.f3_lower = 2200 144 | self.f3_upper = 2950 145 | self.f4_lower = 3000 146 | self.f4_upper = 4000 147 | 148 | if tilt_lims is not None: 149 | self.tilt_lower = tilt_lims[0] 150 | self.tilt_upper = tilt_lims[1] 151 | else: 152 | self.tilt_lower = -1 153 | self.tilt_upper = -0.9 154 | 155 | if centroid_lims is not None: 156 | self.centroid_lower = centroid_lims[0] 157 | self.centroid_upper = centroid_lims[1] 158 | else: 159 | self.centroid_lower = 0 160 | self.centroid_upper = self.sample_rate/2 161 | 162 | if energy_lims is not None: 163 | self.energy_lower = energy_lims[0] 164 | self.energy_upper = energy_lims[1] 165 | else: 
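            # Default energy limits in dB; frame energy is computed as 10*log10 of the sum of squared samples in feature_extractor.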
166 | self.energy_lower = -60 167 | self.energy_upper = 30 168 | 169 | def forward(self, pitch, formants, tilt, centroid, energy): 170 | pitch = self._scale(pitch, self.pitch_upper, self.pitch_lower) 171 | formants[...,0] = self._scale(formants[...,0],self.f1_upper, self.f1_lower) 172 | formants[...,1] = self._scale(formants[...,1],self.f2_upper, self.f2_lower) 173 | formants[...,2] = self._scale(formants[...,2],self.f3_upper, self.f3_lower) 174 | formants[...,3] = self._scale(formants[...,3],self.f4_upper, self.f4_lower) 175 | tilt = self._scale(torch.clamp(tilt, -1, -0.85), self.tilt_upper, self.tilt_lower) 176 | centroid = self._scale(centroid, self.centroid_upper, self.centroid_lower) 177 | energy = self._scale(energy, self.energy_upper, self.energy_lower) 178 | return pitch, formants, tilt, centroid, energy 179 | 180 | def _scale(self,feature, upper, lower): 181 | max_denorm = upper - lower 182 | feature = (2 * (feature - lower) / max_denorm) - 1 183 | return feature -------------------------------------------------------------------------------- /src/neural_formant_synthesis/functions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import pyworld 5 | import warnings 6 | 7 | def tilt_levinson(acorr): 8 | """ 9 | Calculate spectral tilt from predictor coefficient of first order allpole calculated from signal autocorrelation. 10 | Args: 11 | acorr: Autocorrelation of the signal. 12 | Returns: 13 | Spectral tilt value. 14 | """ 15 | 16 | # Calculate allpole coefficients 17 | a,_ = levinson(acorr, 1) 18 | 19 | # Calculate spectral tilt 20 | tilt = a[:,1] 21 | 22 | return tilt 23 | 24 | def spectral_centroid(x, sr): 25 | """ 26 | Extract spectral centroid as weighted ratio of spectral bands. 27 | Args: 28 | x: Input frames with shape (..., n_frames, frame_length) 29 | sr: Sampling rate 30 | Returns: 31 | 1D tensor with the values of estimated spectral centroid for each frame 32 | """ 33 | 34 | mag_spec = torch.abs(torch.fft.rfft(x, dim=-1)) 35 | length = x.size(-1) 36 | freqs = torch.abs(torch.fft.fftfreq(length, 1.0/sr)[:length//2+1]) 37 | centroid = torch.sum(mag_spec*freqs, dim = -1) / torch.sum(mag_spec, dim = -1) 38 | return centroid / ( sr / 2) 39 | 40 | def frame_energy(x): 41 | """ 42 | Calculate frame energy. 43 | Args: 44 | x: Input frames. size (..., n_frames, frame_length) 45 | Returns: 46 | 1D tensor with energies for each frame. 
Size(n_frames,) 47 | """ 48 | 49 | energy = torch.sum(torch.square(torch.abs(x)), dim = -1) 50 | 51 | return energy 52 | 53 | def levinson(R, M): 54 | """ Levinson-Durbin method for converting autocorrelation to predictor polynomial 55 | Args: 56 | R: autocorrelation tensor, shape=(..., M) 57 | M: filter polynomial order 58 | Returns: 59 | A: filter predictor polynomial tensor, shape=(..., M) 60 | Note: 61 | R can contain more lags than M, but R[..., 0:M] are required 62 | """ 63 | 64 | # Normalize autocorrelation and add white noise correction 65 | R = R / R[..., 0:1] 66 | R[..., 0:1] = R[..., 0:1] + 0.001 67 | 68 | E = R[..., 0:1] 69 | L = torch.cat([torch.ones_like(R[..., 0:1]), 70 | torch.zeros_like(R[..., 0:M])], dim=-1) 71 | L_prev = L 72 | rcoeff = torch.zeros_like(L) 73 | for p in torch.arange(0, M): 74 | K = torch.sum(L_prev[..., 0:p+1] * R[..., 1:p+2], dim=-1, keepdim=True) / E 75 | rcoeff[...,p:p+1] = K 76 | if (torch.any(torch.abs(K) > 1)): 77 | print(torch.argmax(torch.abs(K))) 78 | print(R[torch.argmax(torch.abs(K)), 1:M]) 79 | raise ValueError('Reflection coeff bigger than 1') 80 | 81 | pad = torch.clamp(M-p-1, min=0) 82 | if p == 0: 83 | L = torch.cat([-1.0*K, 84 | torch.ones_like(R[..., 0:1]), 85 | torch.zeros_like(R[..., 0:pad])], dim=-1) 86 | else: 87 | L = torch.cat([-1.0*K, 88 | L_prev[..., 0:p] - 1.0*K * 89 | torch.flip(L_prev[..., 0:p], dims=[-1]), 90 | torch.ones_like(R[..., 0:1]), 91 | torch.zeros_like(R[..., 0:pad])], dim=-1) 92 | L_prev = L 93 | E = E * (1.0 - K ** 2) # % order-p mean-square error 94 | L = torch.flip(L, dims=[-1]) # flip zero delay to zero:th index 95 | return L, rcoeff 96 | 97 | def root_to_formant(roots, sr, max_formants = 5): 98 | """ 99 | Extract formant frequencies from allpole roots. 100 | Args: 101 | roots: Tensor containing roots of the polynomial. 102 | sr: Sampling rate. 103 | max_formants: Maximum number of formants to search. 104 | Returns: 105 | Tensor containing formant frequencies. 106 | """ 107 | 108 | freq_tolerance = 10.0 / (sr / (2.0 * np.pi)) 109 | 110 | phases = torch.angle(roots) 111 | phases = torch.where(phases < 0, phases + 2 * np.pi, phases) 112 | 113 | phases_sort,_ = torch.sort(phases, dim = -1, descending = False) 114 | 115 | phases_slice = phases_sort[...,1:max_formants+1] 116 | phases_sort = phases_sort[...,:max_formants] 117 | 118 | condition = (phases_sort[...,0:1] > freq_tolerance).expand(-1,max_formants) #Use expand instead of repeat 119 | 120 | phase_select = torch.where(condition, phases_sort, phases_slice) 121 | formants = phase_select * (sr / (2 * np.pi)) 122 | 123 | return formants 124 | 125 | def interpolate_unvoiced(pitch, voicing_flag = None): 126 | """ 127 | Fill unvoiced regions via linear interpolation 128 | @Pablo: Function copied from PENN's repository to allow input of voicing flag array. 129 | I haven't found any robust implementations for np.interp in pytorch. 130 | With a differentiable version, we could add this functionality to the formant extractor. 
131 | """ 132 | if voicing_flag is None: 133 | unvoiced = pitch < 10 134 | else: 135 | unvoiced = ~voicing_flag 136 | 137 | # Ignore warning of log setting unvoiced regions (zeros) to nan 138 | with warnings.catch_warnings(): 139 | warnings.simplefilter('ignore') 140 | 141 | # Pitch is linear in base-2 log-space 142 | pitch = np.log2(pitch) 143 | 144 | ignored = False 145 | 146 | try: 147 | 148 | # Interpolate 149 | pitch[unvoiced] = np.interp( 150 | np.where(unvoiced)[0], 151 | np.where(~unvoiced)[0], 152 | pitch[~unvoiced]) 153 | 154 | except ValueError: 155 | ignored = True 156 | # Allow all unvoiced 157 | print('Allowing unvoiced') 158 | pass 159 | 160 | return 2 ** pitch, ~unvoiced, ignored 161 | 162 | def pitch_extraction(x, sr, window_samples, step_samples, fmin = 50, fmax = 500): 163 | """ 164 | Extract pitch using pyworld. 165 | Params: 166 | x: audio signal 167 | sr: sample rate 168 | window_samples: window size in samples 169 | step_samples: step size in samples 170 | fmin: minimum pitch frequency 171 | fmax: maximum pitch frequency 172 | Returns: 173 | pitch: tensor with pitch values 174 | voicing_flag: tensor with voicing flags 175 | """ 176 | # Convert to numpy 177 | audio = x.numpy().squeeze().astype(np.float64) 178 | 179 | hopsize = float(step_samples) / sr 180 | 181 | # Get pitch 182 | pitch, times = pyworld.dio( 183 | audio[window_samples // 2:-window_samples // 2], 184 | sr, 185 | fmin, 186 | fmax, 187 | frame_period=1000 * hopsize) 188 | 189 | # Refine pitch 190 | pitch = pyworld.stonemask( 191 | audio, 192 | pitch, 193 | times, 194 | sr) 195 | 196 | # Interpolate unvoiced tokens 197 | pitch, voicing_flag, ignored = interpolate_unvoiced(pitch) 198 | 199 | # Convert to torch 200 | return torch.from_numpy(pitch)[None], torch.tensor(voicing_flag, dtype = torch.int), ignored -------------------------------------------------------------------------------- /src/neural_formant_synthesis/glotnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/src/neural_formant_synthesis/glotnet/__init__.py -------------------------------------------------------------------------------- /src/neural_formant_synthesis/glotnet/model/feedforward/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/src/neural_formant_synthesis/glotnet/model/feedforward/__init__.py -------------------------------------------------------------------------------- /src/neural_formant_synthesis/glotnet/model/feedforward/activations.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2022 Lauri Juvela 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | import torch 18 | import glotnet.cpp_extensions as ext 19 | 20 | def _gated_activation(x: torch.Tensor) -> torch.Tensor: 21 | 22 | assert x.size(1) % 2 == 0, f"Input must have an even number of channels, shape was {x.shape}" 23 | half = x.size(1) // 2 24 | return torch.tanh(x[:, :half, :]) * torch.sigmoid(x[:, half:, :]) 25 | 26 | class Activation(torch.nn.Module): 27 | """ Activation class """ 28 | 29 | def __init__(self, activation="gated"): 30 | super().__init__() 31 | self.activation_str = activation 32 | if activation == "gated": 33 | self.activation_func = _gated_activation 34 | elif activation == "tanh": 35 | self.activation_func = torch.tanh 36 | elif activation == "linear": 37 | self.activation_func = torch.nn.Identity() 38 | 39 | def forward(self, input, use_extension=True): 40 | 41 | if use_extension: 42 | return ActivationFunction.apply(input, self.activation_str) 43 | else: 44 | return self.activation_func(input) 45 | 46 | 47 | class ActivationFunction(torch.autograd.Function): 48 | 49 | @staticmethod 50 | def forward(ctx, input: torch.Tensor, 51 | activation: str, time_major: bool = True 52 | ) -> torch.Tensor: 53 | 54 | ctx.time_major = time_major 55 | if ctx.time_major: 56 | input = input.permute(0, 2, 1) # (B, C, T) -> (B, T, C) 57 | 58 | input = input.contiguous() 59 | ctx.save_for_backward(input) 60 | ctx.activation_type = activation 61 | 62 | output, = ext.activations_forward(input, activation) 63 | 64 | if ctx.time_major: 65 | output = output.permute(0, 2, 1) # (B, T, C) -> (B, C, T) 66 | 67 | return output 68 | 69 | @staticmethod 70 | def backward(ctx, d_output): 71 | raise NotImplementedError("Backward function not implemented for sequential processing") 72 | 73 | 74 | 75 | class Tanh(torch.nn.Module): 76 | 77 | def __init__(self): 78 | super().__init__() 79 | 80 | def forward(self, input: torch.Tensor) -> torch.Tensor: 81 | return TanhFunction.apply(input) 82 | 83 | 84 | class TanhFunction(torch.autograd.Function): 85 | 86 | @staticmethod 87 | def forward(ctx, input: torch.Tensor): 88 | 89 | output = torch.tanh(input) 90 | ctx.save_for_backward(input, output) 91 | return output 92 | 93 | @staticmethod 94 | def backward(ctx, d_output: torch.Tensor): 95 | 96 | input, output = ctx.saved_tensors 97 | # d tanh(x) / dx = 1 - tanh(x) ** 2 98 | d_input = (1 - output ** 2) * d_output 99 | return d_input -------------------------------------------------------------------------------- /src/neural_formant_synthesis/glotnet/model/feedforward/convolution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class Convolution(torch.nn.Conv1d): 4 | """ Causal convolution with optional FILM conditioning """ 5 | 6 | def __init__(self, 7 | in_channels: int, 8 | out_channels: int, 9 | kernel_size: int, 10 | stride: int = 1, 11 | padding: int = 0, 12 | dilation: int = 1, 13 | groups: int = 1, 14 | bias: bool = True, 15 | padding_mode: str = 'zeros', 16 | device=None, 17 | dtype=None, 18 | causal: bool = True, 19 | use_film: bool = False 20 | ): 21 | super().__init__(in_channels, out_channels, 22 | kernel_size, stride, 23 | padding, dilation, 24 | groups, bias, padding_mode, 25 | device, dtype) 26 | self.causal = causal 27 | self.use_film = use_film 28 | 29 | @property 30 | def receptive_field(self) -> int: 31 | """ Receptive field length """ 32 | return (self.kernel_size[0] - 1) * self.dilation[0] + 1 33 | 34 | def forward(self, 35 | input: torch.Tensor, 36 | cond_input: torch.Tensor = None, 37 | sequential: bool 
= False 38 | ) -> torch.Tensor: 39 | """ 40 | Args: 41 | input shape is (batch, in_channels, time) 42 | cond_input: conditioning input 43 | 44 | shape = (batch, 2*out_channels, time) if self.use_film else (batch, out_channels, time) 45 | Returns: 46 | output shape is (batch, out_channels, time) 47 | """ 48 | 49 | if cond_input is not None: 50 | if self.use_film and cond_input.size(1) != 2 * self.out_channels: 51 | raise ValueError(f"Cond input number of channels mismatch." 52 | f"Expected {2*self.out_channels}, got {cond_input.size(1)}") 53 | if not self.use_film and cond_input.size(1) != self.out_channels: 54 | raise ValueError(f"Cond input number of channels mismatch." 55 | f"Expected {self.out_channels}, got {cond_input.size(1)}") 56 | if cond_input.size(2) != input.size(2): 57 | raise ValueError(f"Mismatching timesteps, " 58 | f"input has {input.size(2)}, cond_input has {cond_input.size(2)}") 59 | 60 | return self._forward_native(input=input, cond_input=cond_input, causal=self.causal) 61 | 62 | def _forward_native(self, input: torch.Tensor, 63 | cond_input: torch.Tensor, 64 | causal:bool=True) -> torch.Tensor: 65 | """ Native torch conv1d with causal padding 66 | 67 | Args: 68 | input shape is (batch, in_channels, time) 69 | cond_input: conditioning input 70 | shape = (batch, 2*out_channels, time) if self.use_film else (batch, out_channels, time) 71 | Returns: 72 | output shape is (batch, out_channels, time) 73 | 74 | """ 75 | 76 | if causal: 77 | padding = self.dilation[0] * self.stride[0] * (self.kernel_size[0]-1) 78 | if padding > 0: 79 | input = torch.nn.functional.pad(input, (padding, 0)) 80 | output = torch.nn.functional.conv1d( 81 | input, self.weight, bias=self.bias, 82 | stride=self.stride, padding=0, 83 | dilation=self.dilation, groups=self.groups) 84 | else: 85 | output = torch.nn.functional.conv1d( 86 | input, self.weight, bias=self.bias, 87 | stride=self.stride, padding='same', 88 | dilation=self.dilation, groups=self.groups) 89 | 90 | if cond_input is not None: 91 | if self.use_film: 92 | b, a = torch.chunk(cond_input, 2, dim=1) 93 | output = a * output + b 94 | else: 95 | output = output + cond_input 96 | return output 97 | -------------------------------------------------------------------------------- /src/neural_formant_synthesis/glotnet/model/feedforward/convolution_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple 3 | from .convolution import Convolution 4 | 5 | 6 | class ConvolutionLayer(torch.nn.Module): 7 | """ 8 | Wavenet Convolution Layer (also known as Residual Block) 9 | 10 | Uses a gated activation and a 1x1 output transformation by default 11 | 12 | """ 13 | 14 | def __init__(self, in_channels, out_channels, kernel_size, dilation=1, 15 | bias=True, device=None, dtype=None, 16 | causal=True, 17 | activation="gated", 18 | use_output_transform=True, 19 | cond_channels=None, 20 | skip_channels=None, 21 | ): 22 | super().__init__() 23 | 24 | residual_channels = out_channels 25 | self.activation = activation 26 | self.activation_fun, self.channel_mul = self._parse_activation(activation) 27 | self.use_output_transform = use_output_transform 28 | self.use_conditioning = cond_channels is not None 29 | self.cond_channels = cond_channels 30 | self.residual_channels = residual_channels 31 | self.skip_channels = skip_channels 32 | self.dilation = dilation 33 | self.conv = Convolution( 34 | in_channels=in_channels, 35 | out_channels=self.channel_mul * residual_channels, 36 | 
kernel_size=kernel_size, dilation=dilation, bias=bias, 37 | device=device, dtype=dtype, causal=causal) 38 | # TODO: make parameter alloc conditional on use_output_transform 39 | self.out = Convolution( 40 | in_channels=residual_channels, 41 | out_channels=residual_channels, 42 | kernel_size=1, dilation=1, bias=bias, device=device, dtype=dtype) 43 | if self.skip_channels is not None: 44 | self.skip = Convolution( 45 | in_channels=residual_channels, 46 | out_channels=skip_channels, 47 | kernel_size=1, dilation=1, bias=bias, device=device, dtype=dtype) 48 | if self.use_conditioning: 49 | self.cond_1x1 = torch.nn.Conv1d( 50 | cond_channels, self.channel_mul * residual_channels, 51 | kernel_size=1, bias=True, device=device, dtype=dtype) 52 | 53 | @property 54 | def receptive_field(self): 55 | return self.conv.receptive_field + self.out.receptive_field 56 | 57 | activations = { 58 | "gated": ((torch.tanh, torch.sigmoid), 2), 59 | "tanh": (torch.tanh, 1), 60 | "linear": (torch.nn.Identity(), 1) 61 | } 62 | def _parse_activation(self, activation): 63 | activation_fun, channel_mul = ConvolutionLayer.activations.get(activation, (None, None)) 64 | if channel_mul is None: 65 | raise NotImplementedError 66 | return activation_fun, channel_mul 67 | 68 | 69 | def forward(self, input, cond_input=None, sequential=False): 70 | """ 71 | Args: 72 | input, torch.Tensor of shape (batch_size, in_channels, timesteps) 73 | sequential (optional), 74 | if True, use CUDA compatible parallel implementation 75 | if False, use custom C++ sequential implementation 76 | 77 | Returns: 78 | output, torch.Tensor of shape (batch_size, out_channels, timesteps) 79 | skip, torch.Tensor of shape (batch_size, out_channels, timesteps) 80 | 81 | """ 82 | 83 | if cond_input is not None and not self.use_conditioning: 84 | raise RuntimeError("Module has not been initialized to use conditioning, \ 85 | but conditioning input was provided at forward pass") 86 | 87 | if sequential: 88 | raise NotImplementedError 89 | else: 90 | return self._forward_native(input=input, cond_input=cond_input) 91 | 92 | 93 | def _forward_native(self, input, cond_input): 94 | c = self.cond_1x1(cond_input) if self.use_conditioning else None 95 | x = self.conv(input, cond_input=c) 96 | if self.channel_mul == 2: 97 | R = self.residual_channels 98 | x = self.activation_fun[0](x[:, :R, :]) * self.activation_fun[1](x[:, R:, :]) 99 | else: 100 | x = self.activation_fun(x) 101 | 102 | if self.skip_channels is not None: 103 | skip = self.skip(x) 104 | 105 | if self.use_output_transform: 106 | output = self.out(x) 107 | else: 108 | output = x 109 | 110 | if self.skip_channels is None: 111 | return output 112 | else: 113 | return output, skip 114 | 115 | -------------------------------------------------------------------------------- /src/neural_formant_synthesis/glotnet/model/feedforward/convolution_stack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List 3 | from .convolution_layer import ConvolutionLayer 4 | 5 | 6 | class ConvolutionStack(torch.nn.Module): 7 | """ 8 | Wavenet Convolution Stack 9 | 10 | Uses a gated activation and residual connections by default 11 | """ 12 | 13 | def __init__(self, channels, skip_channels, kernel_size, dilations=[1], bias=True, device=None, dtype=None, 14 | causal=True, 15 | activation="gated", 16 | use_residual=True, 17 | use_1x1_block_out=True, 18 | cond_channels=None, 19 | ): 20 | super().__init__() 21 | 22 | self.channels = channels 23 | 
self.skip_channels = skip_channels 24 | self.activation = activation 25 | self.dilations = dilations 26 | self.use_residual = use_residual 27 | self.use_1x1_block_out = use_1x1_block_out 28 | self.use_conditioning = cond_channels is not None 29 | self.cond_channels = cond_channels 30 | self.causal = causal 31 | self.num_layers = len(dilations) 32 | 33 | self.layers = torch.nn.ModuleList() 34 | for i, d in enumerate(dilations): 35 | use_output_transform = self.use_1x1_block_out 36 | # Always disable output 1x1 for last layer 37 | if i == self.num_layers - 1: 38 | use_output_transform = False 39 | # Add ConvolutionLayer to Stack 40 | self.layers.append( 41 | ConvolutionLayer( 42 | in_channels=channels, out_channels=channels, 43 | kernel_size=kernel_size, dilation=d, bias=bias, device=device, dtype=dtype, 44 | causal=causal, 45 | activation=activation, 46 | use_output_transform=use_output_transform, 47 | cond_channels=self.cond_channels, 48 | skip_channels=self.skip_channels, 49 | ) 50 | ) 51 | 52 | @property 53 | def weights_conv(self): 54 | return [layer.conv.weight for layer in self.layers] 55 | 56 | @property 57 | def biases_conv(self): 58 | return [layer.conv.bias for layer in self.layers] 59 | 60 | @property 61 | def weights_out(self): 62 | return [layer.out.weight for layer in self.layers] 63 | 64 | @property 65 | def biases_out(self): 66 | return [layer.out.bias for layer in self.layers] 67 | 68 | @property 69 | def weights_skip(self): 70 | return [layer.skip.weight for layer in self.layers] 71 | 72 | @property 73 | def biases_skip(self): 74 | return [layer.skip.bias for layer in self.layers] 75 | 76 | @property 77 | def weights_cond(self): 78 | if self.use_conditioning: 79 | return [layer.cond_1x1.weight for layer in self.layers] 80 | else: 81 | return None 82 | 83 | @property 84 | def biases_cond(self): 85 | if self.use_conditioning: 86 | return [layer.cond_1x1.bias for layer in self.layers] 87 | else: 88 | return None 89 | 90 | @property 91 | def receptive_field(self): 92 | return sum([l.receptive_field for l in self.layers]) 93 | 94 | def forward(self, input, cond_input=None, sequential=False): 95 | """ 96 | Args: 97 | input, torch.Tensor of shape (batch_size, channels, timesteps) 98 | cond_input (optional), 99 | torch.Tensor of shape (batch_size, cond_channels, timesteps) 100 | sequential (optional), 101 | if True, use CUDA compatible parallel implementation 102 | if False, use custom C++ sequential implementation 103 | 104 | Returns: 105 | output, torch.Tensor of shape (batch_size, channels, timesteps) 106 | skips, list of torch.Tensor of shape (batch_size, out_channels, timesteps) 107 | 108 | """ 109 | 110 | if cond_input is not None and not self.use_conditioning: 111 | raise RuntimeError("Module has not been initialized to use conditioning, but conditioning input was provided at forward pass") 112 | 113 | if sequential: 114 | raise NotImplementedError("Sequential mode not implemented") 115 | else: 116 | return self._forward_native(input=input, cond_input=cond_input) 117 | 118 | def _forward_native(self, input, cond_input): 119 | x = input 120 | skips = [] 121 | for layer in self.layers: 122 | h = x 123 | x, s = layer(x, cond_input, sequential=False) 124 | x = x + h # residual connection 125 | skips.append(s) 126 | return x, skips 127 | -------------------------------------------------------------------------------- /src/neural_formant_synthesis/glotnet/model/feedforward/wavenet.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | from typing import List 3 | from .convolution_layer import ConvolutionLayer 4 | from .convolution_stack import ConvolutionStack 5 | 6 | class WaveNet(torch.nn.Module): 7 | """ Feedforward WaveNet """ 8 | 9 | def __init__(self, 10 | input_channels: int, 11 | output_channels: int, 12 | residual_channels: int, 13 | skip_channels: int, 14 | kernel_size: int, 15 | dilations: List[int] = [1, 2, 4, 8, 16, 32, 64, 128, 256], 16 | causal: bool = True, 17 | activation: str = "gated", 18 | use_residual: bool = True, 19 | cond_channels: int = None, 20 | cond_net: torch.nn.Module = None, 21 | ): 22 | super().__init__() 23 | 24 | self.input_channels = input_channels 25 | self.output_channels = output_channels 26 | self.residual_channels = residual_channels 27 | self.skip_channels = skip_channels 28 | self.cond_channels = cond_channels 29 | self.use_conditioning = cond_channels is not None 30 | self.activation = activation 31 | self.dilations = dilations 32 | self.kernel_size = kernel_size 33 | self.use_residual = use_residual 34 | self.num_layers = len(dilations) 35 | self.causal = causal 36 | 37 | # Layers 38 | self.input = ConvolutionLayer( 39 | in_channels=self.input_channels, 40 | out_channels=self.residual_channels, 41 | kernel_size=1, activation="tanh", use_output_transform=False) 42 | self.stack = ConvolutionStack( 43 | channels=self.residual_channels, 44 | skip_channels=self.skip_channels, 45 | kernel_size=self.kernel_size, 46 | dilations=dilations, 47 | activation=self.activation, 48 | use_residual=True, 49 | causal=self.causal, 50 | cond_channels=cond_channels) 51 | # TODO: output layers should be just convolution, These hanve conv and Out 52 | self.output1 = ConvolutionLayer( 53 | in_channels=self.skip_channels, 54 | out_channels=self.residual_channels, 55 | kernel_size=1, activation="tanh", use_output_transform=False) 56 | self.output2 = ConvolutionLayer( 57 | in_channels=self.residual_channels, 58 | out_channels=self.output_channels, 59 | kernel_size=1, activation="linear", use_output_transform=False) 60 | 61 | self.cond_net = cond_net 62 | 63 | @property 64 | def output_weights(self): 65 | return [self.output1.conv.weight, self.output2.conv.weight] 66 | 67 | @property 68 | def output_biases(self): 69 | return [self.output1.conv.bias, self.output2.conv.bias] 70 | 71 | @property 72 | def receptive_field(self): 73 | return self.input.receptive_field + self.stack.receptive_field \ 74 | + self.output1.receptive_field + self.output2.receptive_field 75 | 76 | 77 | def forward(self, input, cond_input=None): 78 | """ 79 | Args: 80 | input, torch.Tensor of shape (batch_size, input_channels, timesteps) 81 | cond_input (optional), 82 | torch.Tensor of shape (batch_size, cond_channels, timesteps) 83 | 84 | Returns: 85 | output, torch.Tensor of shape (batch_size, output_channels, timesteps) 86 | 87 | """ 88 | 89 | if cond_input is not None and not self.use_conditioning: 90 | raise RuntimeError("Module has not been initialized to use conditioning, but conditioning input was provided at forward pass") 91 | 92 | x = input 93 | x = self.input(x) 94 | _, skips = self.stack(x, cond_input) # TODO self.stack must be called something different, torch.stack is different 95 | x = torch.stack(skips, dim=0).sum(dim=0) 96 | x = self.output1(x) 97 | x = self.output2(x) 98 | return x 99 | 100 | -------------------------------------------------------------------------------- /src/neural_formant_synthesis/glotnet/sigproc/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/src/neural_formant_synthesis/glotnet/sigproc/__init__.py -------------------------------------------------------------------------------- /src/neural_formant_synthesis/glotnet/sigproc/allpole.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def allpole(x: torch.Tensor, a: torch.Tensor) -> torch.Tensor: 5 | """ 6 | All-pole filter 7 | :param x: input signal, 8 | shape (B, C, T) (batch, channels, timesteps) 9 | :param a: filter coefficients (denominator), 10 | shape (C, N, T) (channels, num taps, timesteps) 11 | :return: filtered signal 12 | shape (B, C, T) (batch, channels, timesteps) 13 | """ 14 | y = torch.zeros_like(x) 15 | 16 | a_normalized = a / a[:, 0:1, :] 17 | 18 | # filter order 19 | p = a.shape[1] - 1 20 | 21 | # filter coefficients 22 | a1 = a_normalized[:, 1:, :] 23 | 24 | # flip coefficients 25 | a1 = torch.flip(a1, [1]) 26 | 27 | # zero pad y by filter order 28 | y = torch.nn.functional.pad(y, (p, 0)) 29 | 30 | # filter 31 | for i in range(p, y.shape[-1]): 32 | y[..., i] = x[..., i - p] - \ 33 | torch.sum(a1[..., i - p] * y[..., i - p:i], dim=-1) 34 | 35 | return y[..., p:] 36 | 37 | class AllPoleFunction(torch.autograd.Function): 38 | 39 | @staticmethod 40 | def forward(ctx, x, a): 41 | y = allpole(x, a) 42 | ctx.save_for_backward(y, x, a) 43 | return y 44 | 45 | @staticmethod 46 | def backward(ctx, dy): 47 | y, x, a = ctx.saved_tensors 48 | dx = da = None 49 | 50 | n_batch = x.size(0) 51 | n_channels = x.size(1) 52 | p = a.size(1) - 1 53 | T = dy.size(-1) 54 | 55 | # filter or 56 | dyda = allpole(-y, a) 57 | dyda = torch.nn.functional.pad(dyda, (p, 0)) 58 | 59 | # da = torch.zeros_like(a) 60 | # for i in range(0, T): 61 | # for j in range(0, p): 62 | # da[:, p, i] = dyda[..., i:i+T] * dy 63 | # da = da.flip([1]) 64 | 65 | 66 | da = F.conv1d( 67 | dyda.view(1, n_batch * n_channels, -1), 68 | dy.view(n_batch * n_channels, 1, -1), 69 | groups=n_batch * n_channels).view(n_batch, n_channels, -1).sum(0).flip(1) 70 | 71 | dx = allpole(dy.flip(-1), a.flip(-1)).flip(-1) 72 | 73 | return dx, da 74 | 75 | class AllPole(torch.nn.Module): 76 | 77 | def __init__(self): 78 | super().__init__() 79 | 80 | def forward(self, x, a): 81 | return AllPoleFunction.apply(x, a) 82 | 83 | -------------------------------------------------------------------------------- /src/neural_formant_synthesis/glotnet/sigproc/biquad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .lfilter import LFilter 4 | from typing import Tuple, Union 5 | 6 | 7 | class BiquadBaseFunctional(torch.nn.Module): 8 | 9 | def __init__(self): 10 | """ Initialize Biquad""" 11 | 12 | super().__init__() 13 | 14 | # TODO: pass STFT arguments to LFilter 15 | self.n_fft = 2048 16 | self.hop_length = 512 17 | self.win_length = 1024 18 | self.lfilter = LFilter(n_fft=self.n_fft, 19 | hop_length=self.hop_length, 20 | win_length=self.win_length) 21 | 22 | def _check_param_sizes(self, freq, gain, Q): 23 | """ Check that parameter sizes are compatible 24 | 25 | Args: 26 | freq: center frequencies 27 | shape = (batch, out_channels, in_channels, n_frames) 28 | gain: gains in decibels, 29 | shape = (batch, out_channels, in_channels, n_frames) 30 | Q: filter resonance (quality factor) 31 | shape = 
(batch, out_channels, in_channels, n_frames) 32 | 33 | Returns: 34 | batch, out_channels, in_channels, n_frames 35 | 36 | """ 37 | 38 | # dimensions must be flat 39 | if freq.ndim != 4: 40 | raise ValueError("freq must be 4D") 41 | if gain.ndim != 4: 42 | raise ValueError("gain must be 4D") 43 | if Q.ndim != 4: 44 | raise ValueError("Q must be 4D") 45 | 46 | if freq.shape != gain.shape != Q.shape: 47 | raise ValueError("freq, gain, and Q must have the same shape") 48 | 49 | return freq.shape 50 | 51 | def _params_to_direct_form(self, 52 | freq: torch.Tensor, 53 | gain: torch.Tensor, 54 | Q: torch.Tensor 55 | ) -> Tuple[torch.Tensor, torch.Tensor]: 56 | """ 57 | Args: 58 | freq, center frequency, 59 | shape is (..., n_frames) 60 | gain, gain in decibels 61 | shape is (..., n_frames) 62 | Q, resonance sharpness 63 | shape is (..., n_frames) 64 | 65 | Returns: 66 | b, filter numerator coefficients 67 | shape is (..., n_taps==3, n_frames) 68 | a, filter denominator coefficients 69 | shape is (..., n_taps==3, n_frames) 70 | 71 | """ 72 | raise NotImplementedError("Subclasses must implement this method") 73 | 74 | 75 | def forward(self, 76 | x: torch.Tensor, 77 | freq: torch.Tensor, 78 | gain: torch.Tensor, 79 | Q: torch.Tensor, 80 | ) -> torch.Tensor: 81 | """ 82 | Args: 83 | x: input signal 84 | shape = (batch, in_channels, time) 85 | freq: center frequencies 86 | shape = (batch, out_channels, in_channels, n_frames) 87 | n_frames is expected to be (time // hop_size) 88 | gain: gains in decibels, 89 | shape = (batch, out_channels, in_channels, n_frames) 90 | Q: filter resonance (quality factor) 91 | shape = (batch, out_channels, in_channels, n_frames) 92 | 93 | Returns: 94 | y: output signal 95 | shape = (batch, channels, n_filters, time) 96 | """ 97 | 98 | batch, out_channels, in_channels, n_frames = self._check_param_sizes(freq=freq, gain=gain, Q=Q) 99 | timesteps = x.size(-1) 100 | 101 | freq = freq.reshape(batch * out_channels * in_channels, n_frames) 102 | gain = gain.reshape(batch * out_channels * in_channels, n_frames) 103 | Q = Q.reshape(batch * out_channels * in_channels, n_frames) 104 | 105 | b, a = self._params_to_direct_form(freq=freq, gain=gain, Q=Q) 106 | 107 | # expand x: (batch, in_channels, time) -> (batch, out_channels, in_channels, time) 108 | x_exp = x.unsqueeze(1).expand(-1, out_channels, -1, -1) 109 | # apply filtering 110 | x_exp = x_exp.reshape(batch * out_channels * in_channels, 1, timesteps) 111 | y = self.lfilter.forward(x_exp, b=b, a=a) 112 | # reshape 113 | y = y.reshape(batch, out_channels, in_channels, timesteps) 114 | # sum over input channels 115 | y = y.sum(dim=2) 116 | 117 | return y 118 | 119 | 120 | class BiquadPeakFunctional(BiquadBaseFunctional): 121 | 122 | def __init__(self): 123 | """ Initialize Biquad""" 124 | super().__init__() 125 | 126 | def _params_to_direct_form(self, 127 | freq: torch.Tensor, 128 | gain: torch.Tensor, 129 | Q: torch.Tensor 130 | ) -> Tuple[torch.Tensor, torch.Tensor]: 131 | """ 132 | Args: 133 | freq, center frequency, 134 | shape is (..., n_frames) 135 | gain, gain in decibels 136 | shape is (..., n_frames) 137 | Q, resonance sharpness 138 | shape is (..., n_frames) 139 | 140 | Returns: 141 | b, filter numerator coefficients 142 | shape is (..., n_taps==3, n_frames) 143 | a, filter denominator coefficients 144 | shape is (..., n_taps==3, n_frames) 145 | 146 | """ 147 | if torch.any(freq > 1.0): 148 | raise ValueError(f"Normalized frequency must be below 1.0, max was {freq.max()}") 149 | if torch.any(freq < 0.0): 150 | raise 
ValueError(f"Normalized frequency must be above 0.0, min was {freq.min()}") 151 | 152 | freq = freq.unsqueeze(-2) 153 | gain = gain.unsqueeze(-2) 154 | Q = Q.unsqueeze(-2) 155 | 156 | omega = torch.pi * freq 157 | A = torch.pow(10.0, 0.025 * gain) 158 | alpha = 0.5 * torch.sin(omega) / Q 159 | 160 | b0 = 1.0 + alpha * A 161 | b1 = -2.0 * torch.cos(omega) 162 | b2 = 1.0 - alpha * A 163 | a0 = 1 + alpha / A 164 | a1 = -2.0 * torch.cos(omega) 165 | a2 = 1 - alpha / A 166 | 167 | a = torch.cat([a0, a1, a2], dim=-2) 168 | b = torch.cat([b0, b1, b2], dim=-2) 169 | 170 | b = b / a0 171 | a = a / a0 172 | return b, a 173 | 174 | 175 | class BiquadModule(torch.nn.Module): 176 | 177 | 178 | @property 179 | def freq(self): 180 | return self._freq 181 | 182 | 183 | def set_freq(self, freq): 184 | if type(freq) != torch.Tensor: 185 | freq = torch.tensor([freq], dtype=torch.float32) 186 | 187 | # convert to normalized frequency 188 | freq = 2.0 * freq / self.fs 189 | 190 | if freq.max() > 1.0: 191 | raise ValueError( 192 | "Maximum normalized frequency is larger than 1.0. " 193 | "Please provide a sample rate or input normalized frequencies") 194 | if freq.min() < 0.0: 195 | raise ValueError( 196 | "Maximum normalized frequency is smaller than 0.0.") 197 | 198 | 199 | self._freq.data = freq.broadcast_to(self._freq.shape) 200 | 201 | @property 202 | def gain_dB(self): 203 | return self._gain_dB 204 | 205 | def set_gain_dB(self, gain): 206 | if type(gain) != torch.Tensor: 207 | gain = torch.tensor([gain], dtype=torch.float32) 208 | self._gain_dB.data = gain.broadcast_to(self._gain_dB.shape) 209 | 210 | @property 211 | def Q(self): 212 | return self._Q 213 | 214 | def set_Q(self, Q): 215 | if type(Q) != torch.Tensor: 216 | Q = torch.tensor([Q], dtype=torch.float32) 217 | self._Q.data = Q.broadcast_to(self._Q.shape) 218 | 219 | def _init_freq(self): 220 | freq = torch.rand(self.out_channels, self.in_channels) 221 | self._freq = torch.nn.Parameter(freq) 222 | 223 | def _init_gain_dB(self): 224 | gain_dB = torch.zeros(self.out_channels, self.in_channels) 225 | self._gain_dB = torch.nn.Parameter(gain_dB) 226 | 227 | def _init_Q(self): 228 | Q = torch.ones(self.out_channels, self.in_channels) 229 | self._Q = torch.nn.Parameter(0.7071 * Q) 230 | 231 | 232 | 233 | def __init__(self, 234 | in_channels: int=1, 235 | out_channels: int=1, 236 | fs: float = None, 237 | func: BiquadBaseFunctional = BiquadPeakFunctional() 238 | ): 239 | """ 240 | Args: 241 | func: BiquadBaseFunctional subclass 242 | freq: center frequency 243 | gain: gain in dB 244 | Q: quality factor determining filter resonance bandwidth 245 | fs: sample rate, if not provided freq is assumed as normalized from 0 to 1 (Nyquist) 246 | 247 | """ 248 | super().__init__() 249 | 250 | self.func = func 251 | 252 | # if no sample rate provided, assume normalized frequency 253 | if fs is None: 254 | fs = 2.0 255 | 256 | self.fs = fs 257 | 258 | self.in_channels = in_channels 259 | self.out_channels = out_channels 260 | 261 | self._init_freq() 262 | self._init_gain_dB() 263 | self._init_Q() 264 | 265 | 266 | def get_impulse_response(self, n_timesteps: int = 2048) -> torch.Tensor: 267 | """ Get impulse response of filter 268 | 269 | Args: 270 | n_timesteps: number of timesteps to evaluate 271 | 272 | Returns: 273 | h, shape is (batch, channels, n_timesteps) 274 | """ 275 | x = torch.zeros(1, 1, n_timesteps) 276 | x[:, :, 0] = 1.0 277 | h = self.forward(x) 278 | return h 279 | 280 | def get_frequency_response(self, n_timesteps: int = 2048, n_fft: int = 2048) -> 
torch.Tensor: 281 | """ Get frequency response of filter 282 | 283 | Args: 284 | n_timesteps: number of timesteps to evaluate 285 | 286 | Returns: 287 | H, shape is (batch, channels, n_timesteps) 288 | """ 289 | h = self.get_impulse_response(n_timesteps=n_timesteps) 290 | H = torch.fft.rfft(h, n=n_fft, dim=-1) 291 | H = torch.abs(H) 292 | return H 293 | 294 | 295 | def forward(self, x: torch.Tensor) -> torch.Tensor: 296 | """ 297 | Args: 298 | x, shape is (batch, in_channels, timesteps) 299 | 300 | Returns: 301 | y, shape is (batch, out_channels, timesteps) 302 | """ 303 | 304 | batch, in_channels, timesteps = x.size() 305 | 306 | num_frames = timesteps // self.func.hop_length 307 | # map size: (out_channels, in_channels) -> (batch, out_channels, in_channels, n_frames) 308 | freq = self.freq.unsqueeze(0).unsqueeze(-1).expand(batch, -1, -1, num_frames) 309 | gain = self.gain_dB.unsqueeze(0).unsqueeze(-1).expand(batch, -1, -1, num_frames) 310 | q_factor = self.Q.unsqueeze(0).unsqueeze(-1).expand(batch, -1, -1, num_frames) 311 | 312 | y = self.func.forward(x, freq, gain, q_factor) 313 | 314 | return y 315 | 316 | class BiquadParallelBankModule(torch.nn.Module): 317 | 318 | def __init__(self, 319 | num_filters:int=10, 320 | fs=None, 321 | func: BiquadBaseFunctional = BiquadPeakFunctional() 322 | ): 323 | """ 324 | Args: 325 | num_filters: number of filters in bank 326 | func: BiquadBaseFunctional subclass 327 | 328 | """ 329 | super().__init__() 330 | 331 | self.num_filters = num_filters 332 | 333 | self.fs = fs 334 | self.filter_bank = BiquadModule(in_channels=num_filters, out_channels=1, fs=fs, func=func) 335 | 336 | # flat initialization 337 | freq = torch.linspace(0.0, 1.0, num_filters+2)[1:-1] 338 | gain = torch.zeros_like(freq) 339 | Q = 0.7071 * torch.ones_like(freq) 340 | 341 | self.filter_bank.set_freq(freq) 342 | self.filter_bank.set_gain_dB(gain) 343 | self.filter_bank.set_Q(Q) 344 | 345 | def set_freq(self, freq: torch.Tensor): 346 | """ Set center frequency of each filter 347 | Args: 348 | freq, shape is (num_filters,) 349 | """ 350 | self.filter_bank.set_freq(freq) 351 | 352 | def set_gain_dB(self, gain_dB: torch.Tensor): 353 | self.filter_bank.set_gain_dB(gain_dB) 354 | 355 | def set_Q(self, Q: torch.Tensor): 356 | self.filter_bank.set_Q(Q) 357 | 358 | def forward(self, x: torch.Tensor) -> torch.Tensor: 359 | """ 360 | Args: 361 | x, shape is (batch, channels=1, timesteps) 362 | 363 | Returns: 364 | y, shape is (batch, channels=1, timesteps) 365 | """ 366 | 367 | if x.size(1) != 1: 368 | raise ValueError(f"Input must have 1 channel, got {x.size(1)}") 369 | 370 | # expand channels to match filter bank 371 | x = x.expand(-1, self.num_filters, -1) 372 | 373 | # output shape is (batch, channels, , timesteps) 374 | y = self.filter_bank(x) 375 | 376 | # normalize output by number of filters 377 | # y = y / self.num_filters 378 | 379 | return y 380 | 381 | def get_impulse_response(self, n_timesteps: int = 2048) -> torch.Tensor: 382 | """ Get impulse response of filter 383 | 384 | Args: 385 | n_timesteps: number of timesteps to evaluate 386 | 387 | Returns: 388 | h, shape is (batch, channels, n_timesteps) 389 | """ 390 | x = torch.zeros(1, 1, n_timesteps) 391 | x[:, :, 0] = 1.0 392 | h = self.forward(x) 393 | return h 394 | 395 | def get_frequency_response(self, n_timesteps: int = 2048, n_fft: int = 2048) -> torch.Tensor: 396 | """ Get frequency response of filter 397 | 398 | Args: 399 | n_timesteps: number of timesteps to evaluate 400 | 401 | Returns: 402 | H, shape is (batch, 
channels, n_timesteps) 403 | """ 404 | h = self.get_impulse_response(n_timesteps=n_timesteps) 405 | H = torch.fft.rfft(h, n=n_fft, dim=-1) 406 | H = torch.abs(H) 407 | return H 408 | 409 | 410 | 411 | class BiquadResonatorFunctional(BiquadBaseFunctional): 412 | 413 | def __init__(self): 414 | """ Initialize Biquad""" 415 | super().__init__() 416 | 417 | def _params_to_direct_form(self, 418 | freq: torch.Tensor, 419 | gain: torch.Tensor, 420 | Q: torch.Tensor 421 | ) -> Tuple[torch.Tensor, torch.Tensor]: 422 | """ 423 | Args: 424 | freq, center frequency, 425 | shape is (..., n_frames) 426 | gain, gain in decibels 427 | shape is (..., n_frames) 428 | Q, resonance sharpness 429 | shape is (..., n_frames) 430 | 431 | Returns: 432 | b, filter numerator coefficients 433 | shape is (..., n_taps==3, n_frames) 434 | a, filter denominator coefficients 435 | shape is (..., n_taps==3, n_frames) 436 | 437 | """ 438 | if torch.any(freq > 1.0): 439 | raise ValueError(f"Normalized frequency must be below 1.0, max was {freq.max()}") 440 | if torch.any(freq < 0.0): 441 | raise ValueError(f"Normalized frequency must be above 0.0, min was {freq.min()}") 442 | 443 | freq = freq.unsqueeze(-2) 444 | gain = gain.unsqueeze(-2) 445 | Q = Q.unsqueeze(-2) 446 | 447 | omega = torch.pi * freq 448 | A = torch.pow(10.0, 0.025 * gain) 449 | alpha = 0.5 * torch.sin(omega) / Q 450 | 451 | b0 = torch.ones_like(freq) 452 | b1 = torch.zeros_like(freq) 453 | b2 = torch.zeros_like(freq) 454 | a0 = 1 + alpha / A 455 | a1 = -2.0 * torch.cos(omega) 456 | a2 = 1 - alpha / A 457 | 458 | a = torch.cat([a0, a1, a2], dim=1) 459 | b = torch.cat([b0, b1, b2], dim=1) 460 | 461 | b = b / a0 462 | a = a / a0 463 | return b, a 464 | 465 | 466 | class BiquadBandpassFunctional(BiquadBaseFunctional): 467 | 468 | def __init__(self): 469 | """ Initialize Biquad""" 470 | super().__init__() 471 | 472 | def _params_to_direct_form(self, 473 | freq: torch.Tensor, 474 | gain: torch.Tensor, 475 | Q: torch.Tensor 476 | ) -> Tuple[torch.Tensor, torch.Tensor]: 477 | """ 478 | Args: 479 | freq, center frequency, 480 | shape is (..., n_frames) 481 | gain, gain in decibels 482 | shape is (..., n_frames) 483 | Q, resonance sharpness 484 | shape is (..., n_frames) 485 | 486 | Returns: 487 | b, filter numerator coefficients 488 | shape is (..., n_taps==3, n_frames) 489 | a, filter denominator coefficients 490 | shape is (..., n_taps==3, n_frames) 491 | 492 | """ 493 | if torch.any(freq > 1.0): 494 | raise ValueError(f"Normalized frequency must be below 1.0, max was {freq.max()}") 495 | if torch.any(freq < 0.0): 496 | raise ValueError(f"Normalized frequency must be above 0.0, min was {freq.min()}") 497 | 498 | freq = freq.unsqueeze(-2) 499 | gain = gain.unsqueeze(-2) 500 | Q = Q.unsqueeze(-2) 501 | 502 | omega = torch.pi * freq 503 | A = torch.pow(10.0, 0.025 * gain) 504 | alpha = 0.5 * torch.sin(omega) / Q 505 | 506 | b0 = alpha 507 | b1 = torch.zeros_like(freq) 508 | b2 = -1.0 * alpha 509 | a0 = 1 + alpha / A 510 | a1 = -2.0 * torch.cos(omega) 511 | a2 = 1 - alpha / A 512 | 513 | a = torch.cat([a0, a1, a2], dim=1) 514 | b = torch.cat([b0, b1, b2], dim=1) 515 | 516 | b = b / a0 517 | a = a / a0 518 | return b, a -------------------------------------------------------------------------------- /src/neural_formant_synthesis/glotnet/sigproc/emphasis.py: -------------------------------------------------------------------------------- 1 | from .lfilter import LFilter 2 | import torch 3 | from .lfilter import ceil_division 4 | 5 | class Emphasis(torch.nn.Module): 6 | 
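    # Editor's note: this class realises first-order pre-emphasis
    # H(z) = 1 - alpha * z^(-1) through LFilter, and deemphasis() applies the
    # matching all-pole inverse 1 / (1 - alpha * z^(-1)).
    # A minimal usage sketch (input sizes are assumptions; shapes follow the
    # docstrings below, with channels fixed to 1 as LFilter requires):
    #
    #   emph = Emphasis(alpha=0.85)
    #   x = torch.randn(1, 1, 16000)      # (batch, channels, timesteps)
    #   y = emph.emphasis(x)              # boost high frequencies
    #   x_rec = emph.deemphasis(y)        # approximately invert the emphasis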
""" Pre-emphasis and de-emphasis filter""" 7 | 8 | def __init__(self, alpha=0.85) -> None: 9 | """ 10 | Args: 11 | alpha : pre-emphasis coefficient 12 | """ 13 | super().__init__() 14 | if alpha < 0 or alpha >= 1: 15 | raise ValueError(f"alpha must be in [0, 1), got {alpha}") 16 | self.alpha = alpha 17 | self.lfilter = LFilter(n_fft=512, hop_length=256, win_length=512) 18 | self.register_buffer('coeffs', torch.tensor([1, -alpha], dtype=torch.float32)) 19 | 20 | def forward(self, x: torch.Tensor) -> torch.Tensor: 21 | """" Apply pre-emphasis to signal 22 | Args: 23 | x : input signal, shape (batch, channels, timesteps) 24 | 25 | Returns: 26 | y : output signal, shape (batch, channels, timesteps) 27 | 28 | """ 29 | if not (self.alpha > 0): 30 | return x 31 | # expand coefficients to batch size and number of frames 32 | b = self.coeffs.unsqueeze(0).unsqueeze(-1).expand(x.size(0), -1, ceil_division(x.size(-1), self.lfilter.hop_length)) 33 | # filter 34 | return self.lfilter(x, b=b, a=None) 35 | 36 | def emphasis(self, x: torch.Tensor) -> torch.Tensor: 37 | """" Apply pre-emphasis to signal 38 | Args: 39 | x : input signal, shape (batch, channels, timesteps) 40 | 41 | Returns: 42 | y : output signal, shape (batch, channels, timesteps) 43 | 44 | """ 45 | return self.forward(x) 46 | 47 | def deemphasis(self, x: torch.Tensor) -> torch.Tensor: 48 | """ Remove emphasis from signal 49 | Args: 50 | x : input signal, shape (batch, channels, timesteps) 51 | Returns: 52 | y : output signal, shape (batch, channels, timesteps) 53 | """ 54 | if not (self.alpha > 0): 55 | return x 56 | # expand coefficients to batch size and number of frames 57 | a = self.coeffs.unsqueeze(0).unsqueeze(-1).expand(x.size(0), -1, ceil_division(x.size(-1), self.lfilter.hop_length)) 58 | # filter 59 | return self.lfilter(x, b=None, a=a) -------------------------------------------------------------------------------- /src/neural_formant_synthesis/glotnet/sigproc/levinson.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def toeplitz(r: torch.Tensor): 5 | """" Construct Toeplitz matrix """ 6 | p = r.size(-1) 7 | rr = torch.cat([r, r[..., 1:].flip(dims=(-1,))], dim=-1) 8 | T = [torch.roll(rr, i, dims=(-1,))[...,:p] for i in range(p)] 9 | return torch.stack(T, dim=-1) 10 | 11 | 12 | def levinson(R:torch.Tensor, M:int, eps:float=1e-3) -> torch.Tensor: 13 | """ Levinson-Durbin method for converting autocorrelation to all-pole polynomial 14 | Args: 15 | R: autocorrelation, shape=(..., M) 16 | M: filter polynomial order 17 | Returns: 18 | A: all-pole polynomial, shape=(..., M) 19 | """ 20 | # normalize R 21 | R = R / R[..., 0:1] 22 | # white noise correction 23 | R[..., 0] = R[..., 0] + eps 24 | 25 | # zero lag 26 | K = torch.sum(R[..., 1:2], dim=-1, keepdim=True) 27 | A = torch.cat([-1.0*K, torch.ones_like(R[..., 0:1])], dim=-1) 28 | E = 1.0 - K ** 2 29 | # higher lags 30 | for p in torch.arange(1, M): 31 | K = torch.sum(A[..., 0:p+1] * R[..., 1:p+2], dim=-1, keepdim=True) / E 32 | if K.abs().max() > 1.0: 33 | raise ValueError(f"Unstable filter, |K| was {K.abs().max()}") 34 | A = torch.cat([-1.0*K, 35 | A[..., 0:p] - 1.0*K * 36 | torch.flip(A[..., 0:p], dims=[-1]), 37 | torch.ones_like(R[..., 0:1])], dim=-1) 38 | E = E * (1.0 - K ** 2) 39 | A = torch.flip(A, dims=[-1]) 40 | return A 41 | 42 | 43 | def forward_levinson(K: torch.Tensor, M: int = None) -> torch.Tensor: 44 | """ Forward Levinson converts reflection coefficients to all-pole 
polynomial 45 | 46 | Args: 47 | K: reflection coefficients, shape=(..., M) 48 | M: filter polynomial order (optional, defaults to K.size(-1)) 49 | Returns: 50 | A: all-pole polynomial, shape=(..., M+1) 51 | 52 | """ 53 | if M is None: 54 | M = K.size(-1) 55 | 56 | A = -1.0*K[..., 0:1] 57 | for p in torch.arange(1, M): 58 | A = torch.cat([-1.0*K[..., p:p+1], 59 | A[..., 0:p] - 1.0*K[..., p:p+1] * torch.flip(A[..., 0:p], dims=[-1])], dim=-1) 60 | 61 | A = torch.cat([A, torch.ones_like(A[..., 0:1])], dim=-1) 62 | A = torch.flip(A, dims=[-1]) # flip zero delay to zero:th index 63 | return A 64 | 65 | 66 | def spectrum_to_allpole(spectrum:torch.Tensor, order:int, root_scale:float=1.0): 67 | """ Convert spectrum to all-pole filter coefficients 68 | 69 | Args: 70 | spectrum: power spectrum (squared magnitude), shape=(..., K) 71 | order: filter polynomial order 72 | 73 | Returns: 74 | a: filter predictor polynomial tensor, shape=(..., order+1) 75 | g: filter gain 76 | """ 77 | r = torch.fft.irfft(spectrum, dim=-1) 78 | # add small value to diagonal to avoid singular matrix 79 | r[..., 0] = r[..., 0] + 1e-6 80 | # all pole from autocorr 81 | a = levinson(r, order) 82 | 83 | # filter gain 84 | # g = torch.sqrt(torch.dot(r[:(order+1)], a)) 85 | g = torch.sqrt(torch.sum(r[..., :(order+1)] * a, dim=-1, keepdim=True)) 86 | 87 | # scale filter roots 88 | if root_scale < 1.0: 89 | a = a * root_scale ** torch.arange(order+1, dtype=torch.float32, device=a.device) 90 | 91 | return a, g -------------------------------------------------------------------------------- /src/neural_formant_synthesis/glotnet/sigproc/lfilter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import math 4 | 5 | # TODO: wrap STFT into class 6 | # TODO: support grouped convolution (multichannel, no mixing) 7 | 8 | def ceil_division(n: int, d: int) -> int: 9 | """ Ceiling integer division """ 10 | return -(n // -d) 11 | 12 | def pad_frames(frames: torch.Tensor, target_len: int) -> torch.Tensor: 13 | n_pad = target_len - frames.size(-1) 14 | l_pad = n_pad // 2 15 | r_pad = ceil_division(n_pad, 2) 16 | return F.pad(frames, pad=(l_pad, r_pad), mode='replicate') 17 | 18 | class LFilter(torch.nn.Module): 19 | """ Linear filtering with STFT """ 20 | 21 | def __init__(self, n_fft: int, hop_length: int, win_length: int): 22 | super().__init__() 23 | self.hop_length = hop_length 24 | self.n_fft = n_fft 25 | self.win_length = win_length 26 | window = torch.hann_window(window_length=self.win_length) 27 | self.register_buffer('window', window.reshape(1, -1, 1), persistent=False) 28 | self.scale = 2.0 * self.hop_length / self.win_length 29 | 30 | self.fold_kwargs = { 31 | 'kernel_size':(self.win_length, 1), 32 | 'stride' : (self.hop_length, 1), 33 | 'padding' : 0 34 | } 35 | 36 | def forward(self, x: torch.Tensor, b: torch.Tensor = None, a: torch.Tensor = None) -> torch.Tensor: 37 | """ 38 | Args: 39 | x : input signal 40 | (batch, channels==1, timesteps) 41 | b : filter numerator coefficients 42 | (batch, b_len, n_frames) 43 | a : filter denominator coefficients 44 | (batch, a_len, n_frames) 45 | """ 46 | num_frames = ceil_division(x.size(-1), self.hop_length) 47 | 48 | left_pad = self.win_length 49 | last_frame_center = num_frames * self.hop_length 50 | right_pad = last_frame_center - x.size(-1) + self.win_length 51 | x_padded = F.pad(x, pad=(left_pad, right_pad)) 52 | 53 | if x.size(1) != 1: 54 | raise RuntimeError(f"channels must be 1, got shape 
{x.shape}") 55 | 56 | # frame 57 | x_padded = x_padded.unsqueeze(-1) 58 | fold_size = x_padded.shape[2:] 59 | x_framed = F.unfold(x_padded, # (B, C, T, 1) 60 | **self.fold_kwargs) 61 | 62 | # window 63 | x_windowed = x_framed * self.window 64 | 65 | # FFT 66 | X = torch.fft.rfft(x_windowed, n=self.n_fft, dim=1) 67 | 68 | if a is None: 69 | A = torch.ones_like(X) 70 | else: 71 | a_pad = pad_frames(a, target_len=X.size(-1)) 72 | A = torch.fft.rfft(a_pad, n=self.n_fft, dim=1) 73 | 74 | if b is None: 75 | B = torch.ones_like(X) 76 | else: 77 | b_pad = pad_frames(b, target_len=X.size(-1)) 78 | B = torch.fft.rfft(b_pad, n=self.n_fft, dim=1) 79 | 80 | # multiply 81 | Y = X * B / A 82 | 83 | # IFFT 84 | y_windowed = torch.fft.irfft(Y, n=self.n_fft, dim=1) 85 | y_windowed = y_windowed[:, :self.win_length, :] 86 | # TODO: window again for OLA 87 | # y_windowed = y_windowed[:, :self.win_length, :] * self.window 88 | 89 | # Overlap-add 90 | # TODO: change OLA fold args kernel size to fft_len 91 | y = F.fold(y_windowed, output_size=fold_size, 92 | **self.fold_kwargs) 93 | return y[:, :, left_pad:-right_pad, 0] * self.scale 94 | 95 | 96 | 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /src/neural_formant_synthesis/glotnet/sigproc/lpc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .lfilter import LFilter 3 | from .levinson import spectrum_to_allpole 4 | 5 | 6 | class LinearPredictor(torch.nn.Module): 7 | 8 | def __init__(self, 9 | n_fft: int = 512, 10 | hop_length: int = 256, 11 | win_length=512, 12 | order=10): 13 | """ 14 | Args: 15 | n_fft (int): FFT size 16 | hop_length (int): Hop length 17 | win_length (int): Window length 18 | order (int): Allpole order 19 | """ 20 | super().__init__() 21 | self.n_fft = n_fft 22 | self.hop_length = hop_length 23 | self.win_length = win_length 24 | self.order = order 25 | self.lfilter = LFilter(n_fft=n_fft, hop_length=hop_length, win_length=win_length) 26 | 27 | def estimate(self, x: torch.Tensor) -> torch.Tensor: 28 | """ 29 | Args: 30 | x (torch.Tensor): Input audio (batch, time) 31 | 32 | Returns: 33 | torch.Tensor: Allpole coefficients (batch, order+1, time) 34 | """ 35 | X = torch.stft(x, n_fft=self.n_fft, 36 | hop_length=self.hop_length, win_length=self.win_length, 37 | return_complex=True) 38 | # power spectrum 39 | X = torch.abs(X)**2 40 | 41 | # transpose to (batch, time, freq) 42 | X = X.transpose(1, 2) 43 | 44 | # allpole coefficients 45 | a, _ = spectrum_to_allpole(X, order=self.order) 46 | 47 | # transpose to (batch, order, num_frames) 48 | a = a.transpose(1, 2) 49 | 50 | return a 51 | 52 | def inverse_filter(self, 53 | x: torch.Tensor, 54 | a: torch.Tensor) -> torch.Tensor: 55 | """ 56 | Args: 57 | x: Input audio (batch, time) 58 | a: Allpole coefficients (batch, order+1, time) 59 | 60 | Returns: 61 | Inverse filtered audio (batch, time) 62 | """ 63 | # inverse filter 64 | e = self.lfilter.forward(x=x, b=a, a=None) 65 | 66 | return e 67 | 68 | def synthesis_filter(self, 69 | e: torch.Tensor, 70 | a: torch.Tensor) -> torch.Tensor: 71 | """ 72 | Args: 73 | e: Excitation signal (batch, channels, time) 74 | a: Allpole coefficients (batch, order+1, time) 75 | 76 | Returns: 77 | torch.Tensor: Synthesis filtered audio (batch, time) 78 | """ 79 | # inverse filter 80 | x = self.lfilter.forward(x=e, b=None, a=a) 81 | return x 82 | 83 | 84 | def prediction(self, 85 | x: torch.Tensor, 86 | a: torch.Tensor) -> torch.Tensor: 87 | """ 88 | Args: 89 | 
x: Input audio (batch, channels, time) 90 | a: Allpole coefficients (batch, order+1, time) 91 | 92 | Returns: 93 | p: Linear prediction signal (batch, time) 94 | """ 95 | a_pred = -a 96 | a_pred[:, 0, :] = 0.0 97 | 98 | # predictor filter 99 | p = self.lfilter.forward(x=x, b=a_pred, a=None) 100 | 101 | return p -------------------------------------------------------------------------------- /src/neural_formant_synthesis/glotnet/sigproc/lsf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def conv(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 5 | """ Convolution of two signals 6 | 7 | Args: 8 | x: First signal (batch, channels, time_x) 9 | y: Second signal (batch, channels, time_y) 10 | 11 | Returns: 12 | torch.Tensor: Convolution (batch, channels, time_x + time_y - 1) 13 | 14 | """ 15 | 16 | # if x.shape[0] != y.shape[0]: 17 | # raise ValueError("x and y must have same batch size") 18 | # if x.shape[1] != y.shape[1]: 19 | # raise ValueError("x and y must have same number of channels") 20 | 21 | length = x.shape[-1] + y.shape[-1] - 1 22 | 23 | X = torch.fft.rfft(x, n=length, dim=-1) 24 | Y = torch.fft.rfft(y, n=length, dim=-1) 25 | Z = X * Y 26 | z = torch.fft.irfft(Z, dim=-1, n=length) 27 | 28 | return z 29 | 30 | def lsf2poly(lsf: torch.Tensor) -> torch.Tensor: 31 | """ Line Spectral Frequencies to Polynomial Coefficients 32 | 33 | Args: 34 | lsf (torch.Tensor): Line spectral frequencies (batch, time, order) 35 | 36 | Returns: 37 | poly (torch.Tensor): Polynomial coefficients (batch, time, order+1) 38 | 39 | 40 | References: 41 | https://github.com/cokelaer/spectrum/blob/master/src/spectrum/linear_prediction.py 42 | """ 43 | 44 | if lsf.min() < 0: 45 | raise ValueError("lsf must be non-negative") 46 | if lsf.max() > torch.pi: 47 | raise ValueError("lsf must be less than pi") 48 | 49 | order = lsf.shape[-1] 50 | 51 | lsf, _ = torch.sort(lsf, dim=-1) 52 | 53 | # split to P and Q 54 | wP = lsf[:, :, ::2].unsqueeze(-1) 55 | wQ = lsf[:, :, 1::2].unsqueeze(-1) 56 | 57 | P_len = wP.shape[-2] 58 | Q_len = wQ.shape[-2] 59 | 60 | # compute conjugate pair polynomials 61 | pad = (1,1,0,0) 62 | Pi = F.pad(-2.0*torch.cos(wP), pad, mode='constant', value=1.0) 63 | Qi = F.pad(-2.0*torch.cos(wQ), pad, mode='constant', value=1.0) 64 | 65 | # Pi = torch.cat(1.0, -2 * torch.cos(wP), 1.0) 66 | # Qi = torch.cat(1.0, -2 * torch.cos(wQ), 1.0) 67 | 68 | # construct polynomials 69 | P = Pi[:,:, 0, :] 70 | for i in range(1, P_len): 71 | P = conv(P, Pi[:, :, i, :]) 72 | 73 | Q = Qi[:, :, 0, :] 74 | for i in range(1, Q_len): 75 | Q = conv(Q, Qi[:, :, i, :]) 76 | 77 | # add trivial zeros 78 | if order % 2 == 0: 79 | # P = conv(P, torch.tensor([1.0, -1.0]).reshape(1, 1, -1)) 80 | # Q = conv(Q, torch.tensor([1.0, 1.0]).reshape(1, 1, -1)) 81 | P = conv(P, torch.tensor([1.0, 1.0]).reshape(1, 1, -1)) 82 | Q = conv(Q, torch.tensor([-1.0, 1.0]).reshape(1, 1, -1)) 83 | else: 84 | # Q = conv(Q, torch.tensor([1.0, 0.0, -1.0]).reshape(1, 1, -1)) 85 | Q = conv(Q, torch.tensor([-1.0, 0.0, 1.0]).reshape(1, 1, -1)) 86 | 87 | # compute polynomial coefficients 88 | A = 0.5 * (P + Q) 89 | 90 | return A[:, :, 1:].flip(-1) 91 | 92 | 93 | -------------------------------------------------------------------------------- /src/neural_formant_synthesis/glotnet/sigproc/melspec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | 4 | from .levinson import levinson, 
spectrum_to_allpole 5 | 6 | class SpectralNormalization(torch.nn.Module): 7 | def forward(self, input): 8 | return torch.log(torch.clamp(input, min=1e-5)) 9 | 10 | class InverseSpectralNormalization(torch.nn.Module): 11 | def forward(self, input): 12 | return torch.exp(input) 13 | 14 | # Tacotron 2 reference configuration 15 | #https://github.com/pytorch/audio/blob/6b2b6c79ca029b4aa9bdb72d12ad061b144c2410/examples/pipeline_tacotron2/train.py#L284 16 | class LogMelSpectrogram(torch.nn.Module): 17 | """ Log Mel Spectrogram """ 18 | 19 | def __init__(self, 20 | sample_rate: int = 16000, 21 | n_fft: int = 1024, 22 | win_length: int = None, 23 | hop_length: int = None, 24 | n_mels: int = 80, 25 | f_min: float = 0.0, 26 | f_max: float = None, 27 | mel_scale: str = "slaney", 28 | normalized: bool = False, 29 | power: float = 1.0, 30 | norm: str = "slaney" 31 | ): 32 | super().__init__() 33 | 34 | self.mel_spectrogram = torchaudio.transforms.MelSpectrogram( 35 | sample_rate=sample_rate, 36 | n_fft=n_fft, 37 | win_length=win_length, 38 | hop_length=hop_length, 39 | f_min=f_min, 40 | f_max=f_max, 41 | n_mels=n_mels, 42 | mel_scale=mel_scale, 43 | normalized=normalized, 44 | power=power, 45 | norm=norm, 46 | ) 47 | self.log = SpectralNormalization() 48 | self.exp = InverseSpectralNormalization() 49 | 50 | fb_pinv = torch.linalg.pinv(self.mel_spectrogram.mel_scale.fb) 51 | self.register_buffer('fb_pinv', fb_pinv) 52 | 53 | 54 | def forward(self, x: torch.Tensor) -> torch.Tensor: 55 | """ 56 | Args: 57 | x : input signal 58 | (batch, channels, timesteps) 59 | """ 60 | X = self.mel_spectrogram(x) 61 | 62 | return self.log(X) 63 | 64 | def allpole(self, X: torch.Tensor, order:int = 20) -> torch.Tensor: 65 | """ 66 | Args: 67 | X : input mel spectrogram 68 | order : all pole model order 69 | 70 | Returns: 71 | a : allpole filter coefficients 72 | 73 | """ 74 | 75 | # invert normalization 76 | X = self.exp(X) 77 | 78 | # (..., F, T) -> (..., T, F) 79 | X = X.transpose(-1, -2) 80 | # pseudoinvert mel filterbank 81 | X = torch.matmul(X, self.fb_pinv).clamp(min=1e-9) 82 | 83 | # power spectrum (squared magnitude) spectrum 84 | X = torch.pow(X, 2.0 / self.mel_spectrogram.power) 85 | X = X.clamp(min=1e-9) 86 | 87 | g, a = spectrum_to_allpole(X, order=order) 88 | # (..., T, order) -> (..., order, T) 89 | a = a.transpose(-1, -2) 90 | return a 91 | 92 | -------------------------------------------------------------------------------- /src/neural_formant_synthesis/glotnet/sigproc/oscillator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class Oscillator(torch.nn.Module): 4 | """ Sinusoidal oscillator """ 5 | def __init__(self, 6 | audio_rate:int=48000, 7 | control_rate:int=200, 8 | shape:str='sin'): 9 | """ 10 | Args: 11 | audio_rate: audio sample rate in samples per second 12 | control_rate: control sample rate in samples per second 13 | typically equal to 1 / frame_length 14 | """ 15 | super().__init__() 16 | 17 | self.audio_rate = audio_rate 18 | self.control_rate = control_rate 19 | self.nyquist_rate = audio_rate // 2 20 | 21 | upsample_factor = self.audio_rate // self.control_rate 22 | self.upsampler = torch.nn.modules.Upsample( 23 | mode='linear', 24 | scale_factor=upsample_factor, 25 | align_corners=False) 26 | 27 | self.audio_step = 1.0 / audio_rate 28 | self.control_step = 1.0 / control_rate 29 | self.shape = shape 30 | 31 | 32 | def forward(self, f0, init_phase=None): 33 | """ 34 | Args: 35 | f0 : fundamental frequency, shape (batch_size, 
channels, num_frames) 36 | Returns: 37 | x : sinusoid, shape (batch_size, channels, num_samples) 38 | """ 39 | 40 | f0 = torch.clamp(f0, min=0.0, max=self.nyquist_rate) 41 | inst_freq = self.upsampler(f0) 42 | if_shape = inst_freq.shape 43 | if init_phase is None: 44 | # random initial phase in range [-pi, pi] 45 | init_phase = 2 * torch.pi * (torch.rand(if_shape[0], if_shape[1], 1) - 0.5) 46 | # integrate instantaneous frequency for phase 47 | phase = torch.cumsum(2 * torch.pi * inst_freq * self.audio_step, dim=-1) 48 | 49 | if self.shape == 'sin': 50 | return torch.sin(phase + init_phase) 51 | elif self.shape == 'saw': 52 | return (torch.fmod(phase + init_phase, 2 * torch.pi) - torch.pi) / torch.pi 53 | -------------------------------------------------------------------------------- /src/neural_formant_synthesis/sigproc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/src/neural_formant_synthesis/sigproc/__init__.py -------------------------------------------------------------------------------- /src/neural_formant_synthesis/sigproc/levinson.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def toeplitz(r: torch.Tensor): 5 | """" Construct Toeplitz matrix """ 6 | p = r.size(-1) 7 | rr = torch.cat([r, r[..., 1:].flip(dims=(-1,))], dim=-1) 8 | T = [torch.roll(rr, i, dims=(-1,))[...,:p] for i in range(p)] 9 | return torch.stack(T, dim=-1) 10 | 11 | 12 | def levinson(R:torch.Tensor, M:int, eps:float=1e-3) -> torch.Tensor: 13 | """ Levinson-Durbin method for converting autocorrelation to all-pole polynomial 14 | Args: 15 | R: autocorrelation, shape=(..., M) 16 | M: filter polynomial order 17 | Returns: 18 | A: all-pole polynomial, shape=(..., M) 19 | """ 20 | # normalize R 21 | R = R / R[..., 0:1] 22 | # white noise correction 23 | R[..., 0] = R[..., 0] + eps 24 | 25 | # zero lag 26 | K = torch.sum(R[..., 1:2], dim=-1, keepdim=True) 27 | A = torch.cat([-1.0*K, torch.ones_like(R[..., 0:1])], dim=-1) 28 | E = 1.0 - K ** 2 29 | # higher lags 30 | for p in torch.arange(1, M): 31 | K = torch.sum(A[..., 0:p+1] * R[..., 1:p+2], dim=-1, keepdim=True) / E 32 | if K.abs().max() > 1.0: 33 | raise ValueError(f"Unstable filter, |K| was {K.abs().max()}") 34 | A = torch.cat([-1.0*K, 35 | A[..., 0:p] - 1.0*K * 36 | torch.flip(A[..., 0:p], dims=[-1]), 37 | torch.ones_like(R[..., 0:1])], dim=-1) 38 | E = E * (1.0 - K ** 2) 39 | A = torch.flip(A, dims=[-1]) 40 | return A 41 | 42 | 43 | def forward_levinson(K: torch.Tensor, M: int = None) -> torch.Tensor: 44 | """ Forward Levinson converts reflection coefficients to all-pole polynomial 45 | 46 | Args: 47 | K: reflection coefficients, shape=(..., M) 48 | M: filter polynomial order (optional, defaults to K.size(-1)) 49 | Returns: 50 | A: all-pole polynomial, shape=(..., M+1) 51 | 52 | """ 53 | if M is None: 54 | M = K.size(-1) 55 | 56 | A = -1.0*K[..., 0:1] 57 | for p in torch.arange(1, M): 58 | A = torch.cat([-1.0*K[..., p:p+1], 59 | A[..., 0:p] - 1.0*K[..., p:p+1] * torch.flip(A[..., 0:p], dims=[-1])], dim=-1) 60 | 61 | A = torch.cat([A, torch.ones_like(A[..., 0:1])], dim=-1) 62 | A = torch.flip(A, dims=[-1]) # flip zero delay to zero:th index 63 | return A 64 | 65 | 66 | def spectrum_to_allpole(spectrum:torch.Tensor, order:int, root_scale:float=1.0): 67 | """ Convert spectrum to all-pole filter coefficients 68 | 69 
| Args: 70 | spectrum: power spectrum (squared magnitude), shape=(..., K) 71 | order: filter polynomial order 72 | 73 | Returns: 74 | a: filter predictor polynomial tensor, shape=(..., order+1) 75 | g: filter gain 76 | 77 | """ 78 | r = torch.fft.irfft(spectrum, dim=-1) 79 | # add small value to diagonal to avoid singular matrix 80 | r[..., 0] = r[..., 0] + 1e-5 81 | # all pole from autocorr 82 | a = levinson(r, order) 83 | 84 | # filter gain 85 | g = torch.sqrt(torch.sum(r[..., :(order+1)] * a, dim=-1, keepdim=True)) 86 | 87 | # scale filter roots 88 | if root_scale < 1.0: 89 | a = a * root_scale ** torch.arange(order+1, dtype=torch.float32, device=a.device) 90 | 91 | return a, g -------------------------------------------------------------------------------- /src/neural_formant_synthesis/sigproc/lpc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from neural_formant_synthesis.glotnet.sigproc.lfilter import LFilter 3 | from .levinson import spectrum_to_allpole 4 | 5 | 6 | class LinearPredictor(torch.nn.Module): 7 | 8 | def __init__(self, 9 | n_fft: int = 512, 10 | hop_length: int = 256, 11 | win_length=512, 12 | order=10): 13 | """ 14 | Args: 15 | n_fft (int): FFT size 16 | hop_length (int): Hop length 17 | win_length (int): Window length 18 | order (int): Allpole order 19 | """ 20 | super().__init__() 21 | self.n_fft = n_fft 22 | self.hop_length = hop_length 23 | self.win_length = win_length 24 | self.order = order 25 | self.lfilter = LFilter(n_fft=n_fft, hop_length=hop_length, win_length=win_length) 26 | self.register_buffer('window', torch.hann_window(win_length)) 27 | 28 | def estimate(self, x: torch.Tensor, root_scale:float=1.0) -> torch.Tensor: 29 | """ 30 | Args: 31 | x (torch.Tensor): Input audio (batch, time) 32 | 33 | Returns: 34 | torch.Tensor: Allpole coefficients (batch, order+1, time) 35 | torch.Tensor: Gain (batch, 1, time) 36 | """ 37 | X = torch.stft(x, n_fft=self.n_fft, 38 | hop_length=self.hop_length, win_length=self.win_length, 39 | window=self.window, 40 | return_complex=True) 41 | 42 | # power spectrum 43 | X = torch.abs(X)**2 44 | 45 | # transpose to (batch, time, freq) 46 | X = X.transpose(1, 2) 47 | 48 | # allpole coefficients 49 | a, g = spectrum_to_allpole(X, order=self.order, root_scale=root_scale) 50 | 51 | # transpose to (batch, order, num_frames) 52 | a = a.transpose(1, 2) 53 | # transpose to (batch, 1, num_frames) 54 | g = g.transpose(1, 2) 55 | 56 | A = torch.fft.rfft(a, n=512, dim=1).abs() 57 | H = g / (A + 1e-6) 58 | # H = 1 / (A + 1e-6) 59 | 60 | return a, g, H 61 | 62 | def inverse_filter(self, 63 | x: torch.Tensor, 64 | a: torch.Tensor) -> torch.Tensor: 65 | """ 66 | Args: 67 | x: Input audio (batch, time) 68 | a: Allpole coefficients (batch, order+1, time) 69 | 70 | Returns: 71 | Inverse filtered audio (batch, time) 72 | """ 73 | # inverse filter 74 | e = self.lfilter.forward(x=x, b=a, a=None) 75 | 76 | return e 77 | 78 | def synthesis_filter(self, 79 | e: torch.Tensor, 80 | a: torch.Tensor, 81 | g: torch.Tensor) -> torch.Tensor: 82 | """ 83 | Args: 84 | e: Excitation signal (batch, channels, time) 85 | a: Allpole coefficients (batch, order+1, time) 86 | g: Gain for filters (batch, 1, time) 87 | 88 | Returns: 89 | torch.Tensor: Synthesis filtered audio (batch, time) 90 | """ 91 | # inverse filter 92 | x = self.lfilter.forward(x=e, b=g, a=a) 93 | return x 94 | 95 | 96 | def prediction(self, 97 | x: torch.Tensor, 98 | a: torch.Tensor) -> torch.Tensor: 99 | """ 100 | Args: 101 | x: Input audio 
(batch, channels, time) 102 | a: Allpole coefficients (batch, order+1, time) 103 | 104 | Returns: 105 | p: Linear prediction signal (batch, time) 106 | """ 107 | a_pred = -a 108 | a_pred[:, 0, :] = 0.0 109 | 110 | # predictor filter 111 | p = self.lfilter.forward(x=x, b=a_pred, a=None) 112 | 113 | return p -------------------------------------------------------------------------------- /src/neural_formant_synthesis/third_party/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/src/neural_formant_synthesis/third_party/__init__.py -------------------------------------------------------------------------------- /src/neural_formant_synthesis/third_party/hifi_gan/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jungil Kong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /src/neural_formant_synthesis/third_party/hifi_gan/README.md: -------------------------------------------------------------------------------- 1 | # HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis 2 | 3 | ### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae 4 | 5 | In our [paper](https://arxiv.org/abs/2010.05646), 6 | we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.
7 | We provide our implementation and pretrained models as open source in this repository. 8 | 9 | **Abstract :** 10 | Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms. 11 | Although such methods improve the sampling efficiency and memory usage, 12 | their sample quality has not yet reached that of autoregressive and flow-based generative models. 13 | In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis. 14 | As speech audio consists of sinusoidal signals with various periods, 15 | we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality. 16 | A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method 17 | demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than 18 | real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen 19 | speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times 20 | faster than real-time on CPU with comparable quality to an autoregressive counterpart. 21 | 22 | Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples. 23 | 24 | 25 | ## Pre-requisites 26 | 1. Python >= 3.6 27 | 2. Clone this repository. 28 | 3. Install python requirements. Please refer [requirements.txt](requirements.txt) 29 | 4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/). 30 | And move all wav files to `LJSpeech-1.1/wavs` 31 | 32 | 33 | ## Training 34 | ``` 35 | python train.py --config config_v1.json 36 | ``` 37 | To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.
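For example (a usage sketch only, reusing the same entry point with the alternative configs bundled in this directory):
```
python train.py --config config_v2.json
python train.py --config config_v3.json
```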
38 | Checkpoints and a copy of the configuration file are saved in the `cp_hifigan` directory by default.
39 | You can change the path by adding the `--checkpoint_path` option. 40 | 41 | 42 | ## Pretrained Model 43 | You can also use the pretrained models we provide.
44 | [Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)
45 | Details of each folder are as follows: 46 | 47 | |Folder Name|Generator|Dataset|Fine-Tuned| 48 | |------|---|---|---| 49 | |LJ_V1|V1|LJSpeech|No| 50 | |LJ_V2|V2|LJSpeech|No| 51 | |LJ_V3|V3|LJSpeech|No| 52 | |LJ_FT_T2_V1|V1|LJSpeech|Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2))| 53 | |LJ_FT_T2_V2|V2|LJSpeech|Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2))| 54 | |LJ_FT_T2_V3|V3|LJSpeech|Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2))| 55 | |VCTK_V1|V1|VCTK|No| 56 | |VCTK_V2|V2|VCTK|No| 57 | |VCTK_V3|V3|VCTK|No| 58 | |UNIVERSAL_V1|V1|Universal|No| 59 | 60 | We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets. 61 | 62 | ## Fine-Tuning 63 | 1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.
64 | The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.
65 | Example: 66 | ``` 67 | Audio File : LJ001-0001.wav 68 | Mel-Spectrogram File : LJ001-0001.npy 69 | ``` 70 | 2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.
71 | 3. Run the following command. 72 | ``` 73 | python train.py --fine_tuning True --config config_v1.json 74 | ``` 75 | For other command line options, please refer to the training section. 76 | 77 | 78 | ## Inference from wav file 79 | 1. Make `test_files` directory and copy wav files into the directory. 80 | 2. Run the following command. 81 | ``` 82 | python inference.py --checkpoint_file [generator checkpoint file path] 83 | ``` 84 | Generated wav files are saved in `generated_files` by default.
85 | You can change the path by adding `--output_dir` option. 86 | 87 | 88 | ## Inference for end-to-end speech synthesis 89 | 1. Make `test_mel_files` directory and copy generated mel-spectrogram files into the directory.
90 | You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2), 91 | [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth. 92 | 2. Run the following command. 93 | ``` 94 | python inference_e2e.py --checkpoint_file [generator checkpoint file path] 95 | ``` 96 | Generated wav files are saved in `generated_files_from_mel` by default.
97 | You can change the path by adding `--output_dir` option. 98 | 99 | 100 | ## Acknowledgements 101 | We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips) 102 | and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this. 103 | 104 | -------------------------------------------------------------------------------- /src/neural_formant_synthesis/third_party/hifi_gan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/src/neural_formant_synthesis/third_party/hifi_gan/__init__.py -------------------------------------------------------------------------------- /src/neural_formant_synthesis/third_party/hifi_gan/config_v1.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,2,2], 12 | "upsample_kernel_sizes": [16,16,4,4], 13 | "upsample_initial_channel": 512, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "segment_size": 8192, 18 | "num_mels": 80, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 256, 22 | "win_size": 1024, 23 | 24 | "sampling_rate": 22050, 25 | 26 | "fmin": 0, 27 | "fmax": 8000, 28 | "fmax_for_loss": null, 29 | 30 | "num_workers": 4, 31 | 32 | "dist_config": { 33 | "dist_backend": "nccl", 34 | "dist_url": "tcp://localhost:54321", 35 | "world_size": 1 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/neural_formant_synthesis/third_party/hifi_gan/config_v2.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,2,2], 12 | "upsample_kernel_sizes": [16,16,4,4], 13 | "upsample_initial_channel": 128, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "segment_size": 8192, 18 | "num_mels": 80, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 256, 22 | "win_size": 1024, 23 | 24 | "sampling_rate": 22050, 25 | 26 | "fmin": 0, 27 | "fmax": 8000, 28 | "fmax_for_loss": null, 29 | 30 | "num_workers": 4, 31 | 32 | "dist_config": { 33 | "dist_backend": "nccl", 34 | "dist_url": "tcp://localhost:54321", 35 | "world_size": 1 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/neural_formant_synthesis/third_party/hifi_gan/config_v3.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "2", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,4], 12 | "upsample_kernel_sizes": [16,16,8], 13 | "upsample_initial_channel": 256, 14 | "resblock_kernel_sizes": [3,5,7], 15 | "resblock_dilation_sizes": [[1,2], [2,6], [3,12]], 16 | 17 | "segment_size": 8192, 18 | "num_mels": 80, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 256, 22 | "win_size": 1024, 23 | 24 
| "sampling_rate": 22050, 25 | 26 | "fmin": 0, 27 | "fmax": 8000, 28 | "fmax_for_loss": null, 29 | 30 | "num_workers": 4, 31 | 32 | "dist_config": { 33 | "dist_backend": "nccl", 34 | "dist_url": "tcp://localhost:54321", 35 | "world_size": 1 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/neural_formant_synthesis/third_party/hifi_gan/env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | class AttrDict(dict): 6 | def __init__(self, *args, **kwargs): 7 | super(AttrDict, self).__init__(*args, **kwargs) 8 | self.__dict__ = self 9 | 10 | 11 | def build_env(config, config_name, path): 12 | t_path = os.path.join(path, config_name) 13 | if config != t_path: 14 | os.makedirs(path, exist_ok=True) 15 | shutil.copyfile(config, os.path.join(path, config_name)) 16 | -------------------------------------------------------------------------------- /src/neural_formant_synthesis/third_party/hifi_gan/inference.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import glob 4 | import os 5 | import argparse 6 | import json 7 | import torch 8 | from env import AttrDict 9 | from meldataset import mel_spectrogram, MAX_WAV_VALUE, load_wav 10 | from models import Generator 11 | import soundfile as sf 12 | 13 | h = None 14 | device = None 15 | 16 | 17 | def load_checkpoint(filepath, device): 18 | assert os.path.isfile(filepath) 19 | print("Loading '{}'".format(filepath)) 20 | checkpoint_dict = torch.load(filepath, map_location=device) 21 | print("Complete.") 22 | return checkpoint_dict 23 | 24 | 25 | def get_mel(x): 26 | return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax) 27 | 28 | 29 | def scan_checkpoint(cp_dir, prefix): 30 | pattern = os.path.join(cp_dir, prefix + '*') 31 | cp_list = glob.glob(pattern) 32 | if len(cp_list) == 0: 33 | return '' 34 | return sorted(cp_list)[-1] 35 | 36 | 37 | def inference(a): 38 | generator = Generator(h).to(device) 39 | 40 | state_dict_g = load_checkpoint(a.checkpoint_file, device) 41 | generator.load_state_dict(state_dict_g['generator']) 42 | 43 | filelist = os.listdir(a.input_wavs_dir) 44 | 45 | os.makedirs(a.output_dir, exist_ok=True) 46 | 47 | generator.eval() 48 | generator.remove_weight_norm() 49 | with torch.no_grad(): 50 | for i, filname in enumerate(filelist): 51 | wav, sr = load_wav(os.path.join(a.input_wavs_dir, filname)) 52 | wav = torch.FloatTensor(wav).to(device) 53 | x = get_mel(wav.unsqueeze(0)) 54 | y_g_hat = generator(x) 55 | audio = y_g_hat.squeeze() 56 | audio = audio.cpu().numpy() 57 | 58 | output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + '_generated.wav') 59 | sf.write(output_file, audio, sr) 60 | print(output_file) 61 | 62 | 63 | def main(): 64 | print('Initializing Inference Process..') 65 | 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument('--input_wavs_dir', default='test_files') 68 | parser.add_argument('--output_dir', default='generated_files') 69 | parser.add_argument('--checkpoint_file', required=True) 70 | a = parser.parse_args() 71 | 72 | config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json') 73 | with open(config_file) as f: 74 | data = f.read() 75 | 76 | global h 77 | json_config = json.loads(data) 78 | h = AttrDict(json_config) 79 | 80 | torch.manual_seed(h.seed) 81 | global 
device 82 | if torch.cuda.is_available(): 83 | torch.cuda.manual_seed(h.seed) 84 | device = torch.device('cuda') 85 | else: 86 | device = torch.device('cpu') 87 | 88 | inference(a) 89 | 90 | 91 | if __name__ == '__main__': 92 | main() 93 | 94 | -------------------------------------------------------------------------------- /src/neural_formant_synthesis/third_party/hifi_gan/inference_e2e.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import glob 4 | import os 5 | import numpy as np 6 | import argparse 7 | import json 8 | import torch 9 | from env import AttrDict 10 | from meldataset import MAX_WAV_VALUE 11 | from models import Generator 12 | from scipy.io.wavfile import write  # provides write(), used below to save the generated int16 audio 13 | h = None 14 | device = None 15 | 16 | 17 | def load_checkpoint(filepath, device): 18 | assert os.path.isfile(filepath) 19 | print("Loading '{}'".format(filepath)) 20 | checkpoint_dict = torch.load(filepath, map_location=device) 21 | print("Complete.") 22 | return checkpoint_dict 23 | 24 | 25 | def scan_checkpoint(cp_dir, prefix): 26 | pattern = os.path.join(cp_dir, prefix + '*') 27 | cp_list = glob.glob(pattern) 28 | if len(cp_list) == 0: 29 | return '' 30 | return sorted(cp_list)[-1] 31 | 32 | 33 | def inference(a): 34 | generator = Generator(h).to(device) 35 | 36 | state_dict_g = load_checkpoint(a.checkpoint_file, device) 37 | generator.load_state_dict(state_dict_g['generator']) 38 | 39 | filelist = os.listdir(a.input_mels_dir) 40 | 41 | os.makedirs(a.output_dir, exist_ok=True) 42 | 43 | generator.eval() 44 | generator.remove_weight_norm() 45 | with torch.no_grad(): 46 | for i, filname in enumerate(filelist): 47 | x = np.load(os.path.join(a.input_mels_dir, filname)) 48 | x = torch.FloatTensor(x).to(device) 49 | y_g_hat = generator(x) 50 | audio = y_g_hat.squeeze() 51 | audio = audio * MAX_WAV_VALUE 52 | audio = audio.cpu().numpy().astype('int16') 53 | 54 | output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + '_generated_e2e.wav') 55 | write(output_file, h.sampling_rate, audio) 56 | print(output_file) 57 | 58 | 59 | def main(): 60 | print('Initializing Inference Process..') 61 | 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--input_mels_dir', default='test_mel_files') 64 | parser.add_argument('--output_dir', default='generated_files_from_mel') 65 | parser.add_argument('--checkpoint_file', required=True) 66 | a = parser.parse_args() 67 | 68 | config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json') 69 | with open(config_file) as f: 70 | data = f.read() 71 | 72 | global h 73 | json_config = json.loads(data) 74 | h = AttrDict(json_config) 75 | 76 | torch.manual_seed(h.seed) 77 | global device 78 | if torch.cuda.is_available(): 79 | torch.cuda.manual_seed(h.seed) 80 | device = torch.device('cuda') 81 | else: 82 | device = torch.device('cpu') 83 | 84 | inference(a) 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | 90 | -------------------------------------------------------------------------------- /src/neural_formant_synthesis/third_party/hifi_gan/meldataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import torch 5 | import torch.utils.data 6 | import numpy as np 7 | import soundfile as sf 8 | 9 | import torchaudio 10 | 11 | MAX_WAV_VALUE = 32768.0 12 | 13 | 14 | def load_wav(full_path): 15 | data, sampling_rate = sf.read(full_path) 16 | # sampling_rate, data 
= read(full_path) 17 | return data, sampling_rate 18 | 19 | 20 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 21 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 22 | 23 | 24 | def dynamic_range_decompression(x, C=1): 25 | return np.exp(x) / C 26 | 27 | 28 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 29 | return torch.log(torch.clamp(x, min=clip_val) * C) 30 | 31 | 32 | def dynamic_range_decompression_torch(x, C=1): 33 | return torch.exp(x) / C 34 | 35 | 36 | def spectral_normalize_torch(magnitudes): 37 | output = dynamic_range_compression_torch(magnitudes) 38 | return output 39 | 40 | 41 | def spectral_de_normalize_torch(magnitudes): 42 | output = dynamic_range_decompression_torch(magnitudes) 43 | return output 44 | 45 | 46 | mel_basis = {} 47 | hann_window = {} 48 | 49 | 50 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 51 | if torch.min(y) < -1.: 52 | print('min value is ', torch.min(y)) 53 | if torch.max(y) > 1.: 54 | print('max value is ', torch.max(y)) 55 | 56 | global mel_basis, hann_window 57 | if fmax not in mel_basis: 58 | mspec = torchaudio.transforms.MelSpectrogram( 59 | sample_rate=sampling_rate, 60 | n_fft=n_fft, 61 | f_min=fmin, 62 | f_max=fmax, 63 | n_mels=num_mels, 64 | norm='slaney', 65 | mel_scale='slaney', 66 | ) 67 | mel = mspec.mel_scale.fb.T.numpy() 68 | mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) 69 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 70 | 71 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 72 | y = y.squeeze(1) 73 | 74 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], 75 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) 76 | 77 | spec = spec.abs() + 1e-6 78 | 79 | spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) 80 | spec = spectral_normalize_torch(spec) 81 | 82 | return spec 83 | 84 | 85 | def get_dataset_filelist(a, ext='.wav'): 86 | if hasattr(a, 'wavefile_ext'): 87 | ext = a.wavefile_ext 88 | with open(a.input_training_file, 'r', encoding='utf-8') as fi: 89 | training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + ext) 90 | for x in fi.read().split('\n') if len(x) > 0] 91 | 92 | with open(a.input_validation_file, 'r', encoding='utf-8') as fi: 93 | validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + ext) 94 | for x in fi.read().split('\n') if len(x) > 0] 95 | return training_files, validation_files 96 | 97 | 98 | class MelDataset(torch.utils.data.Dataset): 99 | def __init__(self, training_files, segment_size, n_fft, num_mels, 100 | hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1, 101 | device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None): 102 | self.audio_files = training_files 103 | random.seed(1234) 104 | if shuffle: 105 | random.shuffle(self.audio_files) 106 | self.segment_size = segment_size 107 | self.sampling_rate = sampling_rate 108 | self.split = split 109 | self.n_fft = n_fft 110 | self.num_mels = num_mels 111 | self.hop_size = hop_size 112 | self.win_size = win_size 113 | self.fmin = fmin 114 | self.fmax = fmax 115 | self.fmax_loss = fmax_loss 116 | self.cached_wav = None 117 | self.n_cache_reuse = n_cache_reuse 118 | self._cache_ref_count = 0 119 | self.device = device 120 | self.fine_tuning = fine_tuning 121 | 
self.base_mels_path = base_mels_path 122 | 123 | def __getitem__(self, index): 124 | filename = self.audio_files[index] 125 | if self._cache_ref_count == 0: 126 | audio, sampling_rate = load_wav(filename) 127 | self.cached_wav = audio 128 | # TODO: resample if needed 129 | if sampling_rate != self.sampling_rate: 130 | raise ValueError("{} SR doesn't match target {} SR".format( 131 | sampling_rate, self.sampling_rate)) 132 | self._cache_ref_count = self.n_cache_reuse 133 | else: 134 | audio = self.cached_wav 135 | self._cache_ref_count -= 1 136 | 137 | audio = torch.FloatTensor(audio) 138 | audio = audio.unsqueeze(0) 139 | 140 | if not self.fine_tuning: 141 | if self.split: 142 | if audio.size(1) >= self.segment_size: 143 | max_audio_start = audio.size(1) - self.segment_size 144 | audio_start = random.randint(0, max_audio_start) 145 | audio = audio[:, audio_start:audio_start+self.segment_size] 146 | else: 147 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant') 148 | 149 | mel = mel_spectrogram(audio, self.n_fft, self.num_mels, 150 | self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax, 151 | center=False) 152 | else: 153 | mel = np.load( 154 | os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy')) 155 | mel = torch.from_numpy(mel) 156 | 157 | if len(mel.shape) < 3: 158 | mel = mel.unsqueeze(0) 159 | 160 | if self.split: 161 | frames_per_seg = math.ceil(self.segment_size / self.hop_size) 162 | 163 | if audio.size(1) >= self.segment_size: 164 | mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) 165 | mel = mel[:, :, mel_start:mel_start + frames_per_seg] 166 | audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size] 167 | else: 168 | mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant') 169 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant') 170 | 171 | mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels, 172 | self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss, 173 | center=False) 174 | 175 | return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) 176 | 177 | def __len__(self): 178 | return len(self.audio_files) 179 | -------------------------------------------------------------------------------- /src/neural_formant_synthesis/third_party/hifi_gan/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 5 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 6 | from neural_formant_synthesis.third_party.hifi_gan.utils import init_weights, get_padding 7 | 8 | LRELU_SLOPE = 0.1 9 | 10 | 11 | class ResBlock1(torch.nn.Module): 12 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 13 | super(ResBlock1, self).__init__() 14 | self.h = h 15 | self.convs1 = nn.ModuleList([ 16 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 17 | padding=get_padding(kernel_size, dilation[0]))), 18 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 19 | padding=get_padding(kernel_size, dilation[1]))), 20 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 21 | padding=get_padding(kernel_size, dilation[2]))) 22 | ]) 23 | self.convs1.apply(init_weights) 24 | 25 | 
self.convs2 = nn.ModuleList([ 26 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 27 | padding=get_padding(kernel_size, 1))), 28 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 29 | padding=get_padding(kernel_size, 1))), 30 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 31 | padding=get_padding(kernel_size, 1))) 32 | ]) 33 | self.convs2.apply(init_weights) 34 | 35 | def forward(self, x): 36 | for c1, c2 in zip(self.convs1, self.convs2): 37 | xt = F.leaky_relu(x, LRELU_SLOPE) 38 | xt = c1(xt) 39 | xt = F.leaky_relu(xt, LRELU_SLOPE) 40 | xt = c2(xt) 41 | x = xt + x 42 | return x 43 | 44 | def remove_weight_norm(self): 45 | for l in self.convs1: 46 | remove_weight_norm(l) 47 | for l in self.convs2: 48 | remove_weight_norm(l) 49 | 50 | 51 | class ResBlock2(torch.nn.Module): 52 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): 53 | super(ResBlock2, self).__init__() 54 | self.h = h 55 | self.convs = nn.ModuleList([ 56 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 57 | padding=get_padding(kernel_size, dilation[0]))), 58 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 59 | padding=get_padding(kernel_size, dilation[1]))) 60 | ]) 61 | self.convs.apply(init_weights) 62 | 63 | def forward(self, x): 64 | for c in self.convs: 65 | xt = F.leaky_relu(x, LRELU_SLOPE) 66 | xt = c(xt) 67 | x = xt + x 68 | return x 69 | 70 | def remove_weight_norm(self): 71 | for l in self.convs: 72 | remove_weight_norm(l) 73 | 74 | 75 | class Generator(torch.nn.Module): 76 | def __init__(self, h, input_channels=80): 77 | super(Generator, self).__init__() 78 | self.h = h 79 | self.num_kernels = len(h.resblock_kernel_sizes) 80 | self.num_upsamples = len(h.upsample_rates) 81 | self.conv_pre = weight_norm(Conv1d(input_channels, h.upsample_initial_channel, 7, 1, padding=3)) 82 | resblock = ResBlock1 if h.resblock == '1' else ResBlock2 83 | 84 | self.ups = nn.ModuleList() 85 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 86 | self.ups.append(weight_norm( 87 | ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)), 88 | k, u, padding=(k-u)//2))) 89 | 90 | self.resblocks = nn.ModuleList() 91 | for i in range(len(self.ups)): 92 | ch = h.upsample_initial_channel//(2**(i+1)) 93 | for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): 94 | self.resblocks.append(resblock(h, ch, k, d)) 95 | 96 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 97 | self.ups.apply(init_weights) 98 | self.conv_post.apply(init_weights) 99 | 100 | def forward(self, x): 101 | x = self.conv_pre(x) 102 | for i in range(self.num_upsamples): 103 | x = F.leaky_relu(x, LRELU_SLOPE) 104 | x = self.ups[i](x) 105 | xs = None 106 | for j in range(self.num_kernels): 107 | if xs is None: 108 | xs = self.resblocks[i*self.num_kernels+j](x) 109 | else: 110 | xs += self.resblocks[i*self.num_kernels+j](x) 111 | x = xs / self.num_kernels 112 | x = F.leaky_relu(x) 113 | x = self.conv_post(x) 114 | x = torch.tanh(x) 115 | 116 | return x 117 | 118 | def remove_weight_norm(self): 119 | print('Removing weight norm...') 120 | for l in self.ups: 121 | remove_weight_norm(l) 122 | for l in self.resblocks: 123 | l.remove_weight_norm() 124 | remove_weight_norm(self.conv_pre) 125 | remove_weight_norm(self.conv_post) 126 | 127 | 128 | class DiscriminatorP(torch.nn.Module): 129 | def __init__(self, period, kernel_size=5, stride=3, 
use_spectral_norm=False): 130 | super(DiscriminatorP, self).__init__() 131 | self.period = period 132 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 133 | self.convs = nn.ModuleList([ 134 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 135 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 136 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 137 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 138 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 139 | ]) 140 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 141 | 142 | def forward(self, x): 143 | fmap = [] 144 | 145 | # 1d to 2d 146 | b, c, t = x.shape 147 | if t % self.period != 0: # pad first 148 | n_pad = self.period - (t % self.period) 149 | x = F.pad(x, (0, n_pad), "reflect") 150 | t = t + n_pad 151 | x = x.view(b, c, t // self.period, self.period) 152 | 153 | for l in self.convs: 154 | x = l(x) 155 | x = F.leaky_relu(x, LRELU_SLOPE) 156 | fmap.append(x) 157 | x = self.conv_post(x) 158 | fmap.append(x) 159 | x = torch.flatten(x, 1, -1) 160 | 161 | return x, fmap 162 | 163 | 164 | class MultiPeriodDiscriminator(torch.nn.Module): 165 | def __init__(self): 166 | super(MultiPeriodDiscriminator, self).__init__() 167 | self.discriminators = nn.ModuleList([ 168 | DiscriminatorP(2), 169 | DiscriminatorP(3), 170 | DiscriminatorP(5), 171 | DiscriminatorP(7), 172 | DiscriminatorP(11), 173 | ]) 174 | 175 | def forward(self, y, y_hat): 176 | y_d_rs = [] 177 | y_d_gs = [] 178 | fmap_rs = [] 179 | fmap_gs = [] 180 | for i, d in enumerate(self.discriminators): 181 | y_d_r, fmap_r = d(y) 182 | y_d_g, fmap_g = d(y_hat) 183 | y_d_rs.append(y_d_r) 184 | fmap_rs.append(fmap_r) 185 | y_d_gs.append(y_d_g) 186 | fmap_gs.append(fmap_g) 187 | 188 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 189 | 190 | 191 | class DiscriminatorS(torch.nn.Module): 192 | def __init__(self, use_spectral_norm=False): 193 | super(DiscriminatorS, self).__init__() 194 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 195 | self.convs = nn.ModuleList([ 196 | norm_f(Conv1d(1, 128, 15, 1, padding=7)), 197 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), 198 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), 199 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), 200 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), 201 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), 202 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 203 | ]) 204 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 205 | 206 | def forward(self, x): 207 | fmap = [] 208 | for l in self.convs: 209 | x = l(x) 210 | x = F.leaky_relu(x, LRELU_SLOPE) 211 | fmap.append(x) 212 | x = self.conv_post(x) 213 | fmap.append(x) 214 | x = torch.flatten(x, 1, -1) 215 | 216 | return x, fmap 217 | 218 | 219 | class MultiScaleDiscriminator(torch.nn.Module): 220 | def __init__(self): 221 | super(MultiScaleDiscriminator, self).__init__() 222 | self.discriminators = nn.ModuleList([ 223 | DiscriminatorS(use_spectral_norm=True), 224 | DiscriminatorS(), 225 | DiscriminatorS(), 226 | ]) 227 | self.meanpools = nn.ModuleList([ 228 | AvgPool1d(4, 2, padding=2), 229 | AvgPool1d(4, 2, padding=2) 230 | ]) 231 | 232 | def forward(self, y, y_hat): 233 | y_d_rs = [] 234 | y_d_gs = [] 235 | fmap_rs = [] 236 | fmap_gs = [] 237 | for i, d in 
enumerate(self.discriminators): 238 | if i != 0: 239 | y = self.meanpools[i-1](y) 240 | y_hat = self.meanpools[i-1](y_hat) 241 | y_d_r, fmap_r = d(y) 242 | y_d_g, fmap_g = d(y_hat) 243 | y_d_rs.append(y_d_r) 244 | fmap_rs.append(fmap_r) 245 | y_d_gs.append(y_d_g) 246 | fmap_gs.append(fmap_g) 247 | 248 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 249 | 250 | 251 | def feature_loss(fmap_r, fmap_g): 252 | loss = 0 253 | for dr, dg in zip(fmap_r, fmap_g): 254 | for rl, gl in zip(dr, dg): 255 | loss += torch.mean(torch.abs(rl - gl)) 256 | 257 | return loss*2 258 | 259 | 260 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 261 | loss = 0 262 | r_losses = [] 263 | g_losses = [] 264 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 265 | r_loss = torch.mean((1-dr)**2) 266 | g_loss = torch.mean(dg**2) 267 | loss += (r_loss + g_loss) 268 | r_losses.append(r_loss.item()) 269 | g_losses.append(g_loss.item()) 270 | 271 | return loss, r_losses, g_losses 272 | 273 | class DiscriminatorMetrics(): 274 | 275 | def __init__(self): 276 | 277 | self.true_accept_total = 0 278 | self.true_reject_total = 0 279 | self.false_accept_total = 0 280 | self.false_reject_total = 0 281 | self.num_samples_total = 0 282 | 283 | @property 284 | def accuracy(self): 285 | TA = self.true_accept_total 286 | TR = self.true_reject_total 287 | N = self.num_samples_total 288 | return 1.0 * (TA + TR) / N 289 | 290 | @property 291 | def false_accept_rate(self): 292 | return 1.0 * self.false_accept_total / self.num_samples_total 293 | 294 | @property 295 | def false_reject_rate(self): 296 | return 1.0 * self.false_reject_total / self.num_samples_total 297 | 298 | @property 299 | def equal_error_rate(self): 300 | return 0.5 * (self.false_accept_rate + self.false_reject_rate)  # approximate EER as the mean of FAR and FRR 301 | 302 | def accumulate(self, disc_real_outputs, disc_generated_outputs): 303 | """ 304 | Args: 305 | disc_real_outputs: 306 | shape is (batch, channels, timesteps) 307 | disc_generated_outputs 308 | shape is (batch, channels, timesteps) 309 | """ 310 | pred_real = [] 311 | pred_gen = [] 312 | # classifications for each discriminator 313 | for d_real, d_gen in zip(disc_real_outputs, disc_generated_outputs): 314 | # mean prediction over time and channels 315 | pred_real.append(torch.mean(d_real, dim=(-1,)) > 0.5) 316 | pred_gen.append(torch.mean(d_gen, dim=(-1,)) < 0.5) 317 | 318 | # Stack classifications from different discriminators 319 | pred_real = torch.stack(pred_real, dim=0) 320 | pred_gen = torch.stack(pred_gen, dim=0) 321 | 322 | # Majority vote across discriminators (probabilities not available) 323 | pred_real_voted, _ = torch.median(pred_real * 1.0, dim=0) 324 | pred_gen_voted, _ = torch.median(pred_gen * 1.0, dim=0) 325 | 326 | if pred_real_voted.shape != pred_gen_voted.shape: 327 | raise ValueError("Real and generated batch sizes must match") 328 | 329 | N = pred_real_voted.shape[0] + pred_gen_voted.shape[0] 330 | 331 | # True Accept 332 | TA = pred_real_voted.sum() 333 | 334 | # True Reject 335 | TR = pred_gen_voted.sum() 336 | 337 | # False Accept (generated samples accepted as real) 338 | FA = pred_gen_voted.shape[0] - TR 339 | 340 | # False Reject (real samples rejected as generated) 341 | FR = pred_real_voted.shape[0] - TA 342 | 343 | self.true_accept_total += TA 344 | self.true_reject_total += TR 345 | self.false_accept_total += FA 346 | self.false_reject_total += FR 347 | self.num_samples_total += N 348 | 349 | 350 | 351 | 352 | def generator_adversarial_loss(disc_outputs): 353 | loss = 0 354 | gen_losses = [] 355 | for dg in disc_outputs: 356 | l = torch.mean((1-dg)**2) 357 | gen_losses.append(l) 358 | loss += l 359 | 360 | return loss, gen_losses 361 | 362 | def 
generator_collaborative_loss(disc_outputs): 363 | loss = 0 364 | gen_losses = [] 365 | for dg in disc_outputs: 366 | l = torch.mean((0-dg)**2) 367 | gen_losses.append(l) 368 | loss += l 369 | 370 | return loss, gen_losses 371 | 372 | 373 | -------------------------------------------------------------------------------- /src/neural_formant_synthesis/third_party/hifi_gan/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.4.0 2 | numpy==1.17.4 3 | librosa==0.7.2 4 | scipy==1.4.1 5 | tensorboard==2.0 6 | soundfile==0.10.3.post1 7 | matplotlib==3.1.3 -------------------------------------------------------------------------------- /src/neural_formant_synthesis/third_party/hifi_gan/train.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.simplefilter(action='ignore', category=FutureWarning) 3 | import itertools 4 | import os 5 | import time 6 | import argparse 7 | import json 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.utils.tensorboard import SummaryWriter 11 | from torch.utils.data import DistributedSampler, DataLoader 12 | import torch.multiprocessing as mp 13 | from torch.distributed import init_process_group 14 | from torch.nn.parallel import DistributedDataParallel 15 | from env import AttrDict, build_env 16 | from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist 17 | from models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_adversarial_loss,\ 18 | discriminator_loss 19 | from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint 20 | 21 | torch.backends.cudnn.benchmark = True 22 | 23 | 24 | def train(rank, a, h): 25 | if h.num_gpus > 1: 26 | init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'], 27 | world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank) 28 | 29 | torch.cuda.manual_seed(h.seed) 30 | if torch.cuda.is_available(): 31 | device = torch.device('cuda:{:d}'.format(rank)) 32 | else: 33 | device = torch.device('cpu') 34 | 35 | generator = Generator(h).to(device) 36 | mpd = MultiPeriodDiscriminator().to(device) 37 | msd = MultiScaleDiscriminator().to(device) 38 | 39 | if rank == 0: 40 | print(generator) 41 | os.makedirs(a.checkpoint_path, exist_ok=True) 42 | print("checkpoints directory : ", a.checkpoint_path) 43 | 44 | if os.path.isdir(a.checkpoint_path): 45 | cp_g = scan_checkpoint(a.checkpoint_path, 'g_') 46 | cp_do = scan_checkpoint(a.checkpoint_path, 'do_') 47 | 48 | steps = 0 49 | if cp_g is None or cp_do is None: 50 | state_dict_do = None 51 | last_epoch = -1 52 | else: 53 | state_dict_g = load_checkpoint(cp_g, device) 54 | state_dict_do = load_checkpoint(cp_do, device) 55 | generator.load_state_dict(state_dict_g['generator']) 56 | mpd.load_state_dict(state_dict_do['mpd']) 57 | msd.load_state_dict(state_dict_do['msd']) 58 | steps = state_dict_do['steps'] + 1 59 | last_epoch = state_dict_do['epoch'] 60 | 61 | if h.num_gpus > 1: 62 | generator = DistributedDataParallel(generator, device_ids=[rank]).to(device) 63 | mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device) 64 | msd = DistributedDataParallel(msd, device_ids=[rank]).to(device) 65 | 66 | optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]) 67 | optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()), 68 | h.learning_rate, betas=[h.adam_b1, 
h.adam_b2]) 69 | 70 | if state_dict_do is not None: 71 | optim_g.load_state_dict(state_dict_do['optim_g']) 72 | optim_d.load_state_dict(state_dict_do['optim_d']) 73 | 74 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch) 75 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch) 76 | 77 | training_filelist, validation_filelist = get_dataset_filelist(a) 78 | 79 | trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels, 80 | h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0, 81 | shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device, 82 | fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir) 83 | 84 | train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None 85 | 86 | train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False, 87 | sampler=train_sampler, 88 | batch_size=h.batch_size, 89 | pin_memory=True, 90 | drop_last=True) 91 | 92 | if rank == 0: 93 | validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels, 94 | h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0, 95 | fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning, 96 | base_mels_path=a.input_mels_dir) 97 | validation_loader = DataLoader(validset, num_workers=1, shuffle=False, 98 | sampler=None, 99 | batch_size=1, 100 | pin_memory=True, 101 | drop_last=True) 102 | 103 | sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs')) 104 | 105 | generator.train() 106 | mpd.train() 107 | msd.train() 108 | for epoch in range(max(0, last_epoch), a.training_epochs): 109 | if rank == 0: 110 | start = time.time() 111 | print("Epoch: {}".format(epoch+1)) 112 | 113 | if h.num_gpus > 1: 114 | train_sampler.set_epoch(epoch) 115 | 116 | for i, batch in enumerate(train_loader): 117 | if rank == 0: 118 | start_b = time.time() 119 | x, y, _, y_mel = batch # mel, audio, filename, mel_loss 120 | x = torch.autograd.Variable(x.to(device, non_blocking=True)) 121 | y = torch.autograd.Variable(y.to(device, non_blocking=True)) 122 | y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True)) 123 | y = y.unsqueeze(1) 124 | 125 | # TODO: add LPC parameter estimation 126 | 127 | y_g_hat = generator(x) 128 | # TODO: add glotnet synthesis filter 129 | y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, 130 | h.fmin, h.fmax_for_loss) 131 | 132 | optim_d.zero_grad() 133 | 134 | # MPD 135 | y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach()) 136 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g) 137 | 138 | # MSD 139 | y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach()) 140 | loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g) 141 | 142 | loss_disc_all = loss_disc_s + loss_disc_f 143 | 144 | loss_disc_all.backward() 145 | optim_d.step() 146 | 147 | # Generator 148 | optim_g.zero_grad() 149 | 150 | # L1 Mel-Spectrogram Loss 151 | loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45 152 | 153 | y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat) 154 | y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat) 155 | loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) 156 | loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) 157 | loss_gen_f, losses_gen_f = generator_adversarial_loss(y_df_hat_g) 158 | loss_gen_s, losses_gen_s = 
generator_adversarial_loss(y_ds_hat_g) 159 | loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel 160 | 161 | loss_gen_all.backward() 162 | optim_g.step() 163 | 164 | if rank == 0: 165 | # STDOUT logging 166 | if steps % a.stdout_interval == 0: 167 | with torch.no_grad(): 168 | mel_error = F.l1_loss(y_mel, y_g_hat_mel).item() 169 | 170 | print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'. 171 | format(steps, loss_gen_all, mel_error, time.time() - start_b)) 172 | 173 | # checkpointing 174 | if steps % a.checkpoint_interval == 0 and steps != 0: 175 | checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps) 176 | save_checkpoint(checkpoint_path, 177 | {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()}) 178 | checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps) 179 | save_checkpoint(checkpoint_path, 180 | {'mpd': (mpd.module if h.num_gpus > 1 181 | else mpd).state_dict(), 182 | 'msd': (msd.module if h.num_gpus > 1 183 | else msd).state_dict(), 184 | 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps, 185 | 'epoch': epoch}) 186 | 187 | # Tensorboard summary logging 188 | if steps % a.summary_interval == 0: 189 | sw.add_scalar("training/gen_loss_total", loss_gen_all, steps) 190 | sw.add_scalar("training/mel_spec_error", mel_error, steps) 191 | 192 | # Validation 193 | if steps % a.validation_interval == 0: # and steps != 0: 194 | generator.eval() 195 | torch.cuda.empty_cache() 196 | val_err_tot = 0 197 | with torch.no_grad(): 198 | for j, batch in enumerate(validation_loader): 199 | x, y, _, y_mel = batch 200 | y_g_hat = generator(x.to(device)) 201 | y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True)) 202 | y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, 203 | h.hop_size, h.win_size, 204 | h.fmin, h.fmax_for_loss) 205 | val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item() 206 | 207 | if j <= 4: 208 | if steps == 0: 209 | sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate) 210 | sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps) 211 | 212 | sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate) 213 | y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, 214 | h.sampling_rate, h.hop_size, h.win_size, 215 | h.fmin, h.fmax) 216 | sw.add_figure('generated/y_hat_spec_{}'.format(j), 217 | plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps) 218 | 219 | val_err = val_err_tot / (j+1) 220 | sw.add_scalar("validation/mel_spec_error", val_err, steps) 221 | 222 | generator.train() 223 | 224 | steps += 1 225 | 226 | scheduler_g.step() 227 | scheduler_d.step() 228 | 229 | if rank == 0: 230 | print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start))) 231 | 232 | 233 | def main(): 234 | print('Initializing Training Process..') 235 | 236 | parser = argparse.ArgumentParser() 237 | 238 | parser.add_argument('--group_name', default=None) 239 | parser.add_argument('--input_wavs_dir', default='LJSpeech-1.1/wavs') 240 | parser.add_argument('--input_mels_dir', default='ft_dataset') 241 | parser.add_argument('--input_training_file', default='LJSpeech-1.1/training.txt') 242 | parser.add_argument('--input_validation_file', default='LJSpeech-1.1/validation.txt') 243 | parser.add_argument('--checkpoint_path', default='cp_hifigan') 244 | parser.add_argument('--config', default='') 245 | 
parser.add_argument('--training_epochs', default=3100, type=int) 246 | parser.add_argument('--stdout_interval', default=5, type=int) 247 | parser.add_argument('--checkpoint_interval', default=5000, type=int) 248 | parser.add_argument('--summary_interval', default=100, type=int) 249 | parser.add_argument('--validation_interval', default=1000, type=int) 250 | parser.add_argument('--fine_tuning', default=False, type=bool) 251 | 252 | a = parser.parse_args() 253 | 254 | with open(a.config) as f: 255 | data = f.read() 256 | 257 | json_config = json.loads(data) 258 | h = AttrDict(json_config) 259 | build_env(a.config, 'config.json', a.checkpoint_path) 260 | 261 | torch.manual_seed(h.seed) 262 | if torch.cuda.is_available(): 263 | torch.cuda.manual_seed(h.seed) 264 | h.num_gpus = torch.cuda.device_count() 265 | h.batch_size = int(h.batch_size / h.num_gpus) 266 | print('Batch size per GPU :', h.batch_size) 267 | else: 268 | pass 269 | 270 | if h.num_gpus > 1: 271 | mp.spawn(train, nprocs=h.num_gpus, args=(a, h,)) 272 | else: 273 | train(0, a, h) 274 | 275 | 276 | if __name__ == '__main__': 277 | main() 278 | -------------------------------------------------------------------------------- /src/neural_formant_synthesis/third_party/hifi_gan/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import matplotlib 4 | import torch 5 | from torch.nn.utils import weight_norm 6 | matplotlib.use("Agg") 7 | import matplotlib.pylab as plt 8 | 9 | 10 | def plot_spectrogram(spectrogram): 11 | fig, ax = plt.subplots(figsize=(10, 2)) 12 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 13 | interpolation='none') 14 | plt.colorbar(im, ax=ax) 15 | 16 | fig.canvas.draw() 17 | plt.close() 18 | 19 | return fig 20 | 21 | 22 | def init_weights(m, mean=0.0, std=0.01): 23 | classname = m.__class__.__name__ 24 | if classname.find("Conv") != -1: 25 | m.weight.data.normal_(mean, std) 26 | 27 | 28 | def apply_weight_norm(m): 29 | classname = m.__class__.__name__ 30 | if classname.find("Conv") != -1: 31 | weight_norm(m) 32 | 33 | 34 | def get_padding(kernel_size, dilation=1): 35 | return int((kernel_size*dilation - dilation)/2) 36 | 37 | 38 | def load_checkpoint(filepath, device): 39 | assert os.path.isfile(filepath) 40 | print("Loading '{}'".format(filepath)) 41 | checkpoint_dict = torch.load(filepath, map_location=device) 42 | print("Complete.") 43 | return checkpoint_dict 44 | 45 | 46 | def save_checkpoint(filepath, obj): 47 | print("Saving checkpoint to {}".format(filepath)) 48 | torch.save(obj, filepath) 49 | print("Complete.") 50 | 51 | 52 | def scan_checkpoint(cp_dir, prefix): 53 | pattern = os.path.join(cp_dir, prefix + '????????') 54 | cp_list = glob.glob(pattern) 55 | if len(cp_list) == 0: 56 | return None 57 | return sorted(cp_list)[-1] 58 | 59 | -------------------------------------------------------------------------------- /src/neural_formant_synthesis/vctk_preprocessing.py: -------------------------------------------------------------------------------- 1 | import os, glob 2 | import numpy as np 3 | import torch 4 | import torchaudio as ta 5 | from tqdm import tqdm 6 | 7 | import argparse 8 | 9 | from feature_extraction import feature_extractor, MedianPool1d 10 | from neural_formant_synthesis.glotnet.sigproc.emphasis import Emphasis 11 | 12 | ta.set_audio_backend("sox_io") 13 | 14 | def main(vctk_path, target_dir): 15 | # Set random seed to 0 16 | np.random.seed(0) 17 | torch.manual_seed(0) 18 | 19 | # Define initial parameters 20 
| window_length = 1024 21 | step_length = 256 22 | 23 | target_sr = 22050 24 | 25 | file_ext = '.flac' 26 | 27 | # Declare feature extractor 28 | feat_extractor = feature_extractor(sr = target_sr,window_samples = window_length, step_samples = step_length, formant_ceiling = 10000, max_formants = 4) 29 | median_filter = MedianPool1d(kernel_size = 3, stride = 1, padding = 0, same = True) 30 | pre_emphasis = Emphasis(alpha=0.97) 31 | 32 | # Divide vctk dataset into train, validation and test sets 33 | divide_vctk(vctk_path, target_dir, file_ext = file_ext) 34 | 35 | # Process train, validation and test sets 36 | print("Processing separated sets") 37 | train_dir = os.path.join(target_dir, 'train') 38 | process_directory(train_dir, target_sr, feat_extractor, median_filter, pre_emphasis = pre_emphasis, file_ext = file_ext) 39 | val_dir = os.path.join(target_dir, 'val') 40 | process_directory(val_dir, target_sr, feat_extractor, median_filter, pre_emphasis = pre_emphasis, file_ext = file_ext) 41 | test_dir = os.path.join(target_dir, 'test') 42 | process_directory(test_dir, target_sr, feat_extractor, median_filter, pre_emphasis = pre_emphasis, file_ext = file_ext) 43 | 44 | def divide_vctk(vctk_path, target_dir, file_ext = '.wav'): 45 | """ 46 | Divide original vctk dataset into train, validation and test sets with a ratio of 80:10:10 with different speakers. 47 | """ 48 | #Empty target directory if it exists 49 | if os.path.exists(target_dir): 50 | os.system('rm -rf ' + target_dir) 51 | # Create target directory if it doesn't exist 52 | if not os.path.exists(target_dir): 53 | os.mkdir(target_dir) 54 | # Create train, validation and test directories if they don't exist 55 | train_dir = os.path.join(target_dir, 'train') 56 | if not os.path.exists(train_dir): 57 | os.mkdir(train_dir) 58 | val_dir = os.path.join(target_dir, 'val') 59 | if not os.path.exists(val_dir): 60 | os.mkdir(val_dir) 61 | test_dir = os.path.join(target_dir, 'test') 62 | if not os.path.exists(test_dir): 63 | os.mkdir(test_dir) 64 | 65 | # Get speakers as directory names in vctk_path 66 | speakers = os.listdir(vctk_path) 67 | 68 | #Shuffle speakers list 69 | speakers = np.random.permutation(speakers) 70 | # Divide speakers list into train, validation and test sets 71 | train_speakers = speakers[:int(len(speakers) * 0.8)] 72 | val_speakers = speakers[int(len(speakers) * 0.8):int(len(speakers) * 0.9)] 73 | test_speakers = speakers[int(len(speakers) * 0.9):] 74 | 75 | # Save train, validation and test speakers lists in one file in target_dir 76 | speakers_dict = {"train": train_speakers, "val": val_speakers, "test": test_speakers} 77 | torch.save(speakers_dict, os.path.join(target_dir, 'speakers.pt')) 78 | 79 | print("Dividing Speakers") 80 | 81 | # Copy audio files in each speaker directory to train, validation and test directories 82 | 83 | print("Processing Test Set") 84 | testfile = open(os.path.join(target_dir, "test_files.txt"), 'w') 85 | for speaker in tqdm(test_speakers,total = len(test_speakers)): 86 | speaker_dir = os.path.join(vctk_path, speaker) 87 | speaker_files = glob.glob(os.path.join(speaker_dir, '*' + file_ext)) 88 | for file in speaker_files: 89 | os.system('cp ' + file + ' ' + test_dir) 90 | testfile.writelines([str(i)+'\n' for i in speaker_files]) 91 | testfile.close() 92 | 93 | print("Processing Training Set") 94 | trainfile = open(os.path.join(target_dir, "train_files.txt"), 'w') 95 | for speaker in tqdm(train_speakers, total = len(train_speakers)): 96 | speaker_dir = os.path.join(vctk_path, speaker) 97 | 
speaker_files = glob.glob(os.path.join(speaker_dir, '*' + file_ext)) 98 | for file in speaker_files: 99 | os.system('cp ' + file + ' ' + train_dir) 100 | trainfile.writelines([str(i)+'\n' for i in speaker_files]) 101 | trainfile.close() 102 | 103 | print("Processing Validation Set") 104 | valfile = open(os.path.join(target_dir, "val_files.txt"), 'w') 105 | for speaker in tqdm(val_speakers, total = len(val_speakers)): 106 | speaker_dir = os.path.join(vctk_path, speaker) 107 | speaker_files = glob.glob(os.path.join(speaker_dir, '*' + file_ext)) 108 | for file in speaker_files: 109 | os.system('cp ' + file + ' ' + val_dir) 110 | valfile.writelines([str(i)+'\n' for i in speaker_files]) 111 | valfile.close() 112 | 113 | 114 | 115 | 116 | def process_directory(path, target_sr, feature_extractor, median_filter, pre_emphasis = None, file_ext = '.wav'): 117 | file_list = glob.glob(os.path.join(path, '*' + file_ext)) 118 | print(f"Processing directory {path}") 119 | for file in tqdm(file_list, total=len(file_list)): 120 | basename = os.path.basename(file) 121 | no_ext = os.path.splitext(basename)[0] 122 | 123 | formants, energy, centroid, tilt, log_pitch, voicing_flag, r_coeff, ignored = process_file(file, target_sr, feature_extractor, median_filter, pre_emphasis = pre_emphasis) 124 | 125 | if formants.size(0) < log_pitch.size(0): 126 | raise ValueError("Formants size is different than pitch size for file: " + file) 127 | 128 | feature_dict = {"Formants": formants, "Energy": energy, "Centroid": centroid, "Tilt": tilt, "Pitch": log_pitch, "Voicing": voicing_flag, "R_Coeff": r_coeff} 129 | if not ignored: 130 | torch.save(feature_dict, os.path.join(path, no_ext + '.pt')) 131 | else: 132 | print("File: " + basename + " ignored.") 133 | 134 | def process_file(file, target_sr, feature_extractor, median_filter, pre_emphasis = None): 135 | """ 136 | Extract features for a single audio file and return feature arrays with the same length. 
137 |     Params:
138 |         file: path to file
139 |         target_sr: target sample rate
140 |     Returns:
141 |         formants: formant frequencies (median filtered)
142 |         energy: energy in log scale
143 |         centroid: spectral centroid
144 |         tilt: spectral tilt
145 |         log_pitch: pitch in log scale
146 |         voicing_flag: voicing flag
147 |         r_coeff: reflection coefficients from reference
148 |         ignored: True if the feature extractor flagged the file to be skipped
149 |     """
150 |     # Read audio file
151 |     x, sample_rate = ta.load(file)
152 |     x = x[0:1].type(torch.DoubleTensor)
153 |     # Resample to target sr
154 |     x = ta.functional.resample(x, sample_rate, target_sr)
155 | 
156 |     if pre_emphasis is not None:
157 |         x = pre_emphasis(x.unsqueeze(0))
158 |     x = x.squeeze(0).squeeze(0)
159 |     formants, energy, centroid, tilt, pitch, voicing_flag, r_coeff, _, ignored = feature_extractor(x)
160 | 
161 |     formants = median_filter(formants.T.unsqueeze(1)).squeeze(1).T
162 | 
163 |     pitch = pitch.squeeze(0)
164 |     voicing_flag = voicing_flag.squeeze(0)
165 | 
166 |     # If pitch length is smaller than formants, pad pitch and voicing flag with last value
167 |     if pitch.size(0) < formants.size(0):
168 |         pitch = torch.nn.functional.pad(pitch, (0, formants.size(0) - pitch.size(0)), mode = 'constant', value = pitch[-1])
169 |         voicing_flag = torch.nn.functional.pad(voicing_flag, (0, formants.size(0) - voicing_flag.size(0)), mode = 'constant', value = voicing_flag[-1])
170 |     # If pitch length is larger than formants, truncate pitch and voicing flag
171 |     elif pitch.size(0) > formants.size(0):
172 |         pitch = pitch[:formants.size(0)]
173 |         voicing_flag = voicing_flag[:formants.size(0)]
174 | 
175 |     log_pitch = torch.log(pitch)
176 | 
177 |     return formants, energy, centroid, tilt, log_pitch, voicing_flag, r_coeff, ignored
178 | 
179 | if __name__ == "__main__":
180 | 
181 |     parser = argparse.ArgumentParser()
182 | 
183 |     parser.add_argument('--input_dir', default='/workspace/Dataset/wav48_silence_trimmed')
184 |     parser.add_argument('--output_dir', default='/workspace/Dataset/vctk_features')
185 | 
186 | 
187 |     a = parser.parse_args()
188 | 
189 |     vctk_path = os.path.normpath(a.input_dir)
190 |     target_dir = os.path.normpath(a.output_dir)
191 |     main(vctk_path, target_dir)
--------------------------------------------------------------------------------
/tests/test_dataset.py:
--------------------------------------------------------------------------------
1 | from neural_formant_synthesis.dataset import FeatureDataset
--------------------------------------------------------------------------------
/tests/test_train_imports.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.simplefilter(action='ignore', category=FutureWarning)
3 | import itertools
4 | import os
5 | import time
6 | import argparse
7 | 
8 | import json
9 | import torch
10 | import torch.nn.functional as F
11 | from torch.utils.tensorboard import SummaryWriter
12 | from torch.utils.data import DistributedSampler, DataLoader
13 | import torch.multiprocessing as mp
14 | from torch.distributed import init_process_group
15 | from torch.nn.parallel import DistributedDataParallel
16 | from neural_formant_synthesis.third_party.hifi_gan.env import AttrDict, build_env
17 | from neural_formant_synthesis.third_party.hifi_gan.meldataset import mel_spectrogram
18 | from neural_formant_synthesis.third_party.hifi_gan.models import MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_adversarial_loss,\
19 |     discriminator_loss
20 | from neural_formant_synthesis.third_party.hifi_gan.utils import 
plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint 21 | 22 | 23 | from neural_formant_synthesis.glotnet.sigproc.lpc import LinearPredictor 24 | from neural_formant_synthesis.glotnet.sigproc.emphasis import Emphasis 25 | 26 | from neural_formant_synthesis.dataset import FeatureDataset_List 27 | from neural_formant_synthesis.models import FM_Hifi_Generator, fm_config_obj, Envelope_wavenet, Envelope_conformer 28 | from neural_formant_synthesis.models import SourceFilterFormantSynthesisGenerator 29 | 30 | from neural_formant_synthesis.glotnet.sigproc.levinson import forward_levinson 31 | 32 | import torchaudio as ta 33 | -------------------------------------------------------------------------------- /tests/test_wavenet.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/tests/test_wavenet.py -------------------------------------------------------------------------------- /train_e2e_hifigan.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.simplefilter(action='ignore', category=FutureWarning) 3 | import itertools 4 | import os 5 | import time 6 | import argparse 7 | 8 | import json 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.tensorboard import SummaryWriter 12 | from torch.utils.data import DistributedSampler, DataLoader 13 | import torch.multiprocessing as mp 14 | from torch.distributed import init_process_group 15 | from torch.nn.parallel import DistributedDataParallel 16 | from neural_formant_synthesis.third_party.hifi_gan.env import AttrDict, build_env 17 | from neural_formant_synthesis.third_party.hifi_gan.meldataset import mel_spectrogram 18 | from neural_formant_synthesis.third_party.hifi_gan.models import MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_adversarial_loss,\ 19 | discriminator_loss 20 | #from neural_formant_synthesis.third_party.hifi_gan.models import discriminator_metrics 21 | from neural_formant_synthesis.third_party.hifi_gan.utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint 22 | 23 | 24 | from neural_formant_synthesis.glotnet.sigproc.lpc import LinearPredictor 25 | from neural_formant_synthesis.glotnet.sigproc.emphasis import Emphasis 26 | 27 | from neural_formant_synthesis.dataset import FeatureDataset_List 28 | from neural_formant_synthesis.models import FM_Hifi_Generator, fm_config_obj, Envelope_wavenet, Envelope_conformer 29 | from neural_formant_synthesis.models import NeuralFormantSynthesisGenerator 30 | 31 | 32 | from neural_formant_synthesis.glotnet.sigproc.levinson import forward_levinson 33 | 34 | import torchaudio as ta 35 | 36 | 37 | torch.backends.cudnn.benchmark = True 38 | 39 | 40 | def train(rank, a, h, fm_h): 41 | 42 | if h.num_gpus > 1: 43 | init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'], 44 | world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank) 45 | 46 | torch.cuda.manual_seed(h.seed) 47 | if torch.cuda.is_available(): 48 | device = torch.device('cuda:{:d}'.format(rank)) 49 | else: 50 | device = torch.device('cpu') 51 | 52 | # HiFi generator included in feature mapping class. 
53 | pretrained_fm = getattr(fm_h, 'model_path', None) 54 | generator = NeuralFormantSynthesisGenerator(fm_config = fm_h, g_config = h, 55 | pretrained_fm = pretrained_fm, 56 | freeze_fm = pretrained_fm is not None, 57 | device = device) 58 | 59 | def count_parameters(model): 60 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 61 | 62 | print(f"Num parameters in HiFi Generator: {count_parameters(generator.hifi_generator)}") 63 | print(f"Num parameters in Feature Mapping model: {count_parameters(generator.feature_mapping)}") 64 | 65 | generator = generator.to(device) 66 | mpd = MultiPeriodDiscriminator().to(device) 67 | msd = MultiScaleDiscriminator().to(device) 68 | 69 | if rank == 0: 70 | print(generator) 71 | os.makedirs(a.checkpoint_path, exist_ok=True) 72 | print("checkpoints directory : ", a.checkpoint_path) 73 | 74 | if os.path.isdir(a.checkpoint_path): 75 | cp_g = scan_checkpoint(a.checkpoint_path, 'g_') 76 | cp_do = scan_checkpoint(a.checkpoint_path, 'do_') 77 | 78 | pretrain_steps = 0 # always reset, never load 79 | steps = 0 80 | if cp_g is None or cp_do is None: 81 | state_dict_do = None 82 | last_epoch = -1 83 | else: 84 | generator.load_generator_e2e_checkpoint(cp_g) 85 | state_dict_do = load_checkpoint(cp_do, device) 86 | mpd.load_state_dict(state_dict_do['mpd']) 87 | msd.load_state_dict(state_dict_do['msd']) 88 | steps = state_dict_do['steps'] + 1 89 | last_epoch = state_dict_do['epoch'] 90 | 91 | if h.num_gpus > 1: 92 | generator = DistributedDataParallel(generator, device_ids=[rank]).to(device) 93 | mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device) 94 | msd = DistributedDataParallel(msd, device_ids=[rank]).to(device) 95 | 96 | optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]) 97 | optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()), 98 | h.learning_rate, betas=[h.adam_b1, h.adam_b2]) 99 | 100 | if state_dict_do is not None: 101 | optim_g.load_state_dict(state_dict_do['optim_g']) 102 | optim_d.load_state_dict(state_dict_do['optim_d']) 103 | 104 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch) 105 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch) 106 | 107 | training_path = os.path.join(a.input_wavs_dir, 'train') 108 | trainset = FeatureDataset_List(training_path, h, sampling_rate = h.sampling_rate, 109 | frame_size = h.win_size, hop_size = h.hop_size, shuffle = True, audio_ext = '.flac', 110 | segment_length = 32) 111 | 112 | train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None 113 | 114 | train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False, 115 | sampler=train_sampler, 116 | batch_size=h.batch_size, 117 | pin_memory=True, 118 | drop_last=True) 119 | 120 | if rank == 0: 121 | 122 | valid_path = os.path.join(a.input_wavs_dir, 'val') 123 | validset = FeatureDataset_List(valid_path, h, sampling_rate = h.sampling_rate, 124 | frame_size = h.win_size, hop_size = h.hop_size, audio_ext = '.flac') 125 | validation_loader = DataLoader(validset, num_workers=1, shuffle=False, 126 | sampler=None, 127 | batch_size=1, 128 | pin_memory=True, 129 | drop_last=True) 130 | 131 | sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs')) 132 | 133 | 134 | generator.train() 135 | mpd.train() 136 | msd.train() 137 | 138 | # generator = torch.compile(generator) 139 | # mpd = torch.compile(mpd) 140 | # msd = torch.compile(msd) 141 | 
for epoch in range(max(0, last_epoch), a.training_epochs): 142 | 143 | if rank == 0: 144 | start = time.time() 145 | print("Epoch: {}".format(epoch+1)) 146 | 147 | if h.num_gpus > 1: 148 | train_sampler.set_epoch(epoch) 149 | 150 | for i, batch in enumerate(train_loader): 151 | 152 | if rank == 0: 153 | start_b = time.time() 154 | 155 | #size --> (Batch, features, sequence) 156 | x, _, y, y_mel = batch 157 | x = torch.autograd.Variable(x.to(device, non_blocking=True)) 158 | y = torch.autograd.Variable(y.to(device, non_blocking=True)) 159 | y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True)) 160 | y = y.unsqueeze(1) 161 | 162 | x_feat = x[:,0:9,:] 163 | 164 | # feature_map_only = pretrain_steps < fm_h.get('pre_train_steps', 0) 165 | feature_map_only = steps < fm_h.get('pre_train_steps', 0) 166 | detach = fm_h.get('detach_feature_map_from_gan', False) 167 | if feature_map_only: 168 | mel_cond = generator.forward(x_feat, feature_map_only, detach) 169 | else: 170 | y_g_hat, mel_cond = generator.forward(x_feat, feature_map_only, detach) 171 | 172 | 173 | 174 | if not feature_map_only: 175 | y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, 176 | h.num_mels, h.sampling_rate, 177 | h.hop_size, h.win_size, 178 | h.fmin, h.fmax_for_loss) 179 | 180 | optim_d.zero_grad() 181 | 182 | # MPD 183 | y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach()) 184 | 185 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g) 186 | 187 | # MSD 188 | y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach()) 189 | 190 | loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g) 191 | 192 | loss_disc_all = loss_disc_s + loss_disc_f 193 | 194 | loss_disc_all.backward() 195 | optim_d.step() 196 | 197 | # Generator 198 | optim_g.zero_grad() 199 | loss_gen_all = 0.0 200 | 201 | if not feature_map_only: 202 | # L1 Mel-Spectrogram Loss 203 | loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45 204 | 205 | # Adversarial losses 206 | y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat) 207 | y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat) 208 | loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) 209 | loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) 210 | loss_gen_f, losses_gen_f = generator_adversarial_loss(y_df_hat_g) 211 | loss_gen_s, losses_gen_s = generator_adversarial_loss(y_ds_hat_g) 212 | loss_gen_all = loss_gen_all + loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel 213 | 214 | loss_mel_fm = F.l1_loss(y_mel, mel_cond) 215 | 216 | loss_gen_all = loss_gen_all + loss_mel_fm 217 | 218 | loss_gen_all.backward() 219 | optim_g.step() 220 | 221 | if not torch.isfinite(loss_gen_all): 222 | raise ValueError(f"Loss value is not finite, was {loss_gen_all}") 223 | 224 | if rank == 0: 225 | # STDOUT logging 226 | if steps % a.stdout_interval == 0: 227 | 228 | if feature_map_only: 229 | print('Steps : {:d}, Gen Loss Total : {:4.3f}, s/b : {:4.3f}'. 230 | format(steps, loss_gen_all, time.time() - start_b)) 231 | else: 232 | with torch.no_grad(): 233 | mel_error = F.l1_loss(y_mel, y_g_hat_mel).item() 234 | print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'. 
235 | format(steps, loss_gen_all, mel_error, time.time() - start_b)) 236 | 237 | # checkpointing 238 | if steps % a.checkpoint_interval == 0 and steps != 0: 239 | checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps) 240 | save_checkpoint(checkpoint_path, 241 | {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()}) 242 | checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps) 243 | save_checkpoint(checkpoint_path, 244 | {'mpd': (mpd.module if h.num_gpus > 1 245 | else mpd).state_dict(), 246 | 'msd': (msd.module if h.num_gpus > 1 247 | else msd).state_dict(), 248 | 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps, 249 | 'epoch': epoch}) 250 | 251 | # Tensorboard summary logging 252 | if steps % a.summary_interval == 0: 253 | sw.add_scalar("training/gen_loss_total", loss_gen_all, steps) 254 | 255 | sw.add_scalar("training/feature_mapping_mel_loss", loss_mel_fm, steps) 256 | 257 | if not feature_map_only: 258 | sw.add_scalar("training/mel_spec_error", mel_error, steps) 259 | 260 | # Framed Discriminator losses 261 | sw.add_scalar("training_gan/disc_f_r", sum(losses_disc_f_r), steps) 262 | sw.add_scalar("training_gan/disc_f_g", sum(losses_disc_f_g), steps) 263 | # Multiscale Discriminator losses 264 | sw.add_scalar("training_gan/disc_s_r", sum(losses_disc_s_r), steps) 265 | sw.add_scalar("training_gan/disc_s_g", sum(losses_disc_s_g), steps) 266 | # Framed Generator losses 267 | sw.add_scalar("training_gan/gen_f", sum(losses_gen_f), steps) 268 | # Multiscale Generator losses 269 | sw.add_scalar("training_gan/gen_s", sum(losses_gen_s), steps) 270 | # Feature Matching losses 271 | sw.add_scalar("training_gan/loss_fm_f", loss_fm_f, steps) 272 | sw.add_scalar("training_gan/loss_fm_s", loss_fm_s, steps) 273 | 274 | 275 | # Validation 276 | if steps % a.validation_interval == 0: #and steps != 0: 277 | 278 | print(f"Validation at step {steps}") 279 | generator.eval() 280 | torch.cuda.empty_cache() 281 | val_err_tot = 0 282 | max_valid_batches = 100 283 | with torch.no_grad(): 284 | for j, batch in enumerate(validation_loader): 285 | 286 | if j > max_valid_batches: 287 | break 288 | 289 | x, _, y, y_mel = batch 290 | 291 | x = x.to(device) 292 | 293 | x_feat = x[:,0:9,:] 294 | 295 | y_g_hat, mel_cond = generator(x_feat) 296 | 297 | y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True)) 298 | y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, 299 | h.hop_size, h.win_size, 300 | h.fmin, h.fmax_for_loss) 301 | 302 | val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item() 303 | 304 | # TODO: calculate discriminator EER 305 | 306 | if j <= 4: 307 | if steps == 0: 308 | sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate) 309 | sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(y_mel[0].cpu()), steps) 310 | 311 | sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate) 312 | y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, 313 | h.sampling_rate, h.hop_size, h.win_size, 314 | h.fmin, h.fmax) 315 | sw.add_figure('generated/y_hat_spec_{}'.format(j), 316 | plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps) 317 | 318 | val_err = val_err_tot / (j+1) 319 | sw.add_scalar("validation/mel_spec_error", val_err, steps) 320 | 321 | generator.train() 322 | 323 | steps += 1 324 | pretrain_steps += 1 325 | 326 | scheduler_g.step() 327 | scheduler_d.step() 328 | 329 | if rank == 0: 330 | print('Time taken for epoch {} is {} 
sec\n'.format(epoch + 1, int(time.time() - start)))
331 | 
332 | 
333 | def main():
334 |     print('Initializing Training Process..')
335 | 
336 |     parser = argparse.ArgumentParser()
337 | 
338 |     parser.add_argument('--group_name', default=None)
339 |     parser.add_argument('--input_wavs_dir', default='LJSpeech-1.1/wavs')
340 |     parser.add_argument('--input_mels_dir', default='ft_dataset')
341 |     parser.add_argument('--input_training_file', default='LJSpeech-1.1/training.txt')
342 |     parser.add_argument('--input_validation_file', default='LJSpeech-1.1/validation.txt')
343 |     parser.add_argument('--checkpoint_path', default='cp_hifigan')
344 |     parser.add_argument('--config', default='')
345 |     parser.add_argument('--fm_config', default='')
346 |     parser.add_argument('--env_config', default='')
347 |     parser.add_argument('--training_epochs', default=3100, type=int)
348 |     parser.add_argument('--stdout_interval', default=5, type=int)
349 |     parser.add_argument('--checkpoint_interval', default=5000, type=int)
350 |     parser.add_argument('--summary_interval', default=100, type=int)
351 |     parser.add_argument('--validation_interval', default=1000, type=int)
352 |     parser.add_argument('--fine_tuning', action='store_true')
353 |     parser.add_argument('--wavefile_ext', default='.wav', type=str)
354 |     parser.add_argument('--envelope_model_pretrained', action='store_true')
355 |     parser.add_argument('--envelope_model_freeze', action='store_true')
356 | 
357 |     a = parser.parse_args()
358 | 
359 |     with open(a.config) as f:
360 |         data = f.read()
361 | 
362 |     json_config = json.loads(data)
363 |     h = AttrDict(json_config)
364 | 
365 |     with open(a.fm_config) as f:
366 |         data = f.read()
367 |     json_fm_config = json.loads(data)
368 |     fm_h = AttrDict(json_fm_config)
369 | 
370 | 
371 |     build_env(a.config, 'config.json', a.checkpoint_path)
372 |     # TODO: copy configs for feature mapping and envelope models!
373 | 
374 |     torch.manual_seed(h.seed)
375 |     if torch.cuda.is_available():
376 |         torch.cuda.manual_seed(h.seed)
377 |         # Skip multi-gpu implementation until supported by all the models.
378 |         h.num_gpus = 1 #torch.cuda.device_count()
379 |         h.batch_size = int(h.batch_size / h.num_gpus)
380 |         print('Batch size per GPU :', h.batch_size)
381 |     else:
382 |         h.num_gpus = 1  # no CUDA available: run a single CPU process
383 | 
384 |     if h.num_gpus > 1:
385 |         mp.spawn(train, nprocs=h.num_gpus, args=(a, h, fm_h))
386 |     else:
387 |         train(0, a, h, fm_h)
388 | 
389 | 
390 | if __name__ == '__main__':
391 |     main()
392 | 
--------------------------------------------------------------------------------
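Read together, vctk_preprocessing.py and train_e2e_hifigan.py form a two-stage pipeline: the preprocessing script writes speaker-split 'train'/'val'/'test' directories containing .flac copies and per-utterance feature tensors, and train() above appears to load the 'train' and 'val' subdirectories of --input_wavs_dir through FeatureDataset_List with audio_ext='.flac'. The sketch below shows one plausible way to drive both stages from Python; every path and config filename in it is a placeholder, not a value defined by the repository.

# One plausible way to run the two stages end to end (a sketch, not a documented recipe).
# All paths and config filenames below are placeholders.
import subprocess

# Stage 1: speaker-level 80/10/10 split of VCTK plus per-utterance feature extraction.
subprocess.run([
    "python", "src/neural_formant_synthesis/vctk_preprocessing.py",
    "--input_dir", "/data/VCTK/wav48_silence_trimmed",   # placeholder VCTK path
    "--output_dir", "/data/vctk_features",               # placeholder output path
], check=True)

# Stage 2: end-to-end training. Both --config and --fm_config must point to JSON files,
# since main() reads them with json.loads before wrapping them in AttrDict objects.
subprocess.run([
    "python", "train_e2e_hifigan.py",
    "--config", "config.json",         # HiFi-GAN / training hyperparameters (placeholder name)
    "--fm_config", "fm_config.json",   # feature-mapping config (placeholder name, see sketch below)
    "--input_wavs_dir", "/data/vctk_features",
    "--checkpoint_path", "cp_hifigan",
], check=True)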
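Only three feature-mapping config fields are read directly in train_e2e_hifigan.py: getattr(fm_h, 'model_path', None) optionally selects a pretrained feature-mapping checkpoint (and sets freeze_fm=True when one is given), fm_h.get('pre_train_steps', 0) makes steps below that count run in feature-mapping-only mode where just the mel L1 loss on mel_cond is optimized, and fm_h.get('detach_feature_map_from_gan', False) is forwarded as the detach flag to generator.forward(). The snippet below writes a config covering just those keys as an illustration; the values are hypothetical, and the architecture fields that NeuralFormantSynthesisGenerator itself expects from fm_config are defined in the model code rather than in this script, so they are omitted.

# Illustrative only: the training-control keys read directly by train_e2e_hifigan.py.
# Values are hypothetical, and the real config also needs whatever model fields
# NeuralFormantSynthesisGenerator expects (not visible in this script).
import json

fm_config_sketch = {
    "pre_train_steps": 10000,              # steps where only the feature-mapping mel loss is optimized (hypothetical)
    "detach_feature_map_from_gan": False,  # passed as the 'detach' argument to generator.forward()
    # "model_path": "feature_mapping_pretrained.pt",  # optional: pretrained feature mapper, loaded and frozen
}

with open("fm_config.json", "w") as f:
    json.dump(fm_config_sketch, f, indent=4)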
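Each utterance handled by process_directory() in vctk_preprocessing.py ends up as a .pt file containing the feature_dict saved there, which makes it easy to sanity-check the preprocessing output. A minimal inspection sketch follows; the file path is a placeholder, while the key names match the saved dictionary ('Pitch' is already log-scaled, and the pitch and voicing tracks have been padded or truncated to the formant frame count).

# Inspect one feature file written by vctk_preprocessing.py (placeholder path and filename).
import torch

features = torch.load("/data/vctk_features/train/p225_001.pt")

# Keys follow the feature_dict saved in process_directory().
for key in ("Formants", "Energy", "Centroid", "Tilt", "Pitch", "Voicing", "R_Coeff"):
    value = features[key]
    print(f"{key:8s} shape={tuple(value.shape)} dtype={value.dtype}")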