├── .gitattributes
├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── Samples
│   ├── p257_016.wav
│   ├── p260_131.wav
│   └── p266_015.wav
├── config
│   └── config_v3.json
├── environment.yml
├── images
│   └── LPC-NFS.png
├── inference_hifigan.py
├── inference_hifiglot.py
├── pyproject.toml
├── requirements.txt
├── src
│   └── neural_formant_synthesis
│       ├── __init__.py
│       ├── dataset.py
│       ├── feature_extraction.py
│       ├── functions.py
│       ├── glotnet
│       │   ├── __init__.py
│       │   ├── model
│       │   │   └── feedforward
│       │   │       ├── __init__.py
│       │   │       ├── activations.py
│       │   │       ├── convolution.py
│       │   │       ├── convolution_layer.py
│       │   │       ├── convolution_stack.py
│       │   │       └── wavenet.py
│       │   └── sigproc
│       │       ├── __init__.py
│       │       ├── allpole.py
│       │       ├── biquad.py
│       │       ├── emphasis.py
│       │       ├── levinson.py
│       │       ├── lfilter.py
│       │       ├── lpc.py
│       │       ├── lsf.py
│       │       ├── melspec.py
│       │       └── oscillator.py
│       ├── models.py
│       ├── sigproc
│       │   ├── __init__.py
│       │   ├── levinson.py
│       │   └── lpc.py
│       ├── third_party
│       │   ├── __init__.py
│       │   └── hifi_gan
│       │       ├── LICENSE
│       │       ├── README.md
│       │       ├── __init__.py
│       │       ├── config_v1.json
│       │       ├── config_v2.json
│       │       ├── config_v3.json
│       │       ├── env.py
│       │       ├── inference.py
│       │       ├── inference_e2e.py
│       │       ├── meldataset.py
│       │       ├── models.py
│       │       ├── requirements.txt
│       │       ├── train.py
│       │       └── utils.py
│       └── vctk_preprocessing.py
├── tests
│   ├── test_dataset.py
│   ├── test_train_imports.py
│   └── test_wavenet.py
├── train_e2e_hifigan.py
└── train_e2e_hifiglot.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.7z filter=lfs diff=lfs merge=lfs -text
2 | *.arrow filter=lfs diff=lfs merge=lfs -text
3 | *.bin filter=lfs diff=lfs merge=lfs -text
4 | *.bz2 filter=lfs diff=lfs merge=lfs -text
5 | *.ckpt filter=lfs diff=lfs merge=lfs -text
6 | *.ftz filter=lfs diff=lfs merge=lfs -text
7 | *.gz filter=lfs diff=lfs merge=lfs -text
8 | *.h5 filter=lfs diff=lfs merge=lfs -text
9 | *.joblib filter=lfs diff=lfs merge=lfs -text
10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text
11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text
12 | *.model filter=lfs diff=lfs merge=lfs -text
13 | *.msgpack filter=lfs diff=lfs merge=lfs -text
14 | *.npy filter=lfs diff=lfs merge=lfs -text
15 | *.npz filter=lfs diff=lfs merge=lfs -text
16 | *.onnx filter=lfs diff=lfs merge=lfs -text
17 | *.ot filter=lfs diff=lfs merge=lfs -text
18 | *.parquet filter=lfs diff=lfs merge=lfs -text
19 | *.pb filter=lfs diff=lfs merge=lfs -text
20 | *.pickle filter=lfs diff=lfs merge=lfs -text
21 | *.pkl filter=lfs diff=lfs merge=lfs -text
22 | *.pt filter=lfs diff=lfs merge=lfs -text
23 | *.pth filter=lfs diff=lfs merge=lfs -text
24 | *.rar filter=lfs diff=lfs merge=lfs -text
25 | *.safetensors filter=lfs diff=lfs merge=lfs -text
26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27 | *.tar.* filter=lfs diff=lfs merge=lfs -text
28 | *.tar filter=lfs diff=lfs merge=lfs -text
29 | *.tflite filter=lfs diff=lfs merge=lfs -text
30 | *.tgz filter=lfs diff=lfs merge=lfs -text
31 | *.wasm filter=lfs diff=lfs merge=lfs -text
32 | *.xz filter=lfs diff=lfs merge=lfs -text
33 | *.png filter=lfs diff=lfs merge=lfs -text
34 | *.zip filter=lfs diff=lfs merge=lfs -text
35 | *.zst filter=lfs diff=lfs merge=lfs -text
36 | *tfevents* filter=lfs diff=lfs merge=lfs -text
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | output/*
10 |
11 | experiments/*
12 | ckpt/*
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 | cover/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | .pybuilder/
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | # For a library or package, you might want to ignore these files since the code is
92 | # intended to run in multiple environments; otherwise, check them in:
93 | # .python-version
94 |
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # poetry
103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104 | # This is especially recommended for binary packages to ensure reproducibility, and is more
105 | # commonly ignored for libraries.
106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107 | #poetry.lock
108 |
109 | # pdm
110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111 | #pdm.lock
112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113 | # in version control.
114 | # https://pdm.fming.dev/#use-with-ide
115 | .pdm.toml
116 |
117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118 | __pypackages__/
119 |
120 | # Celery stuff
121 | celerybeat-schedule
122 | celerybeat.pid
123 |
124 | # SageMath parsed files
125 | *.sage.py
126 |
127 | # Environments
128 | .env
129 | .venv
130 | env/
131 | venv/
132 | ENV/
133 | env.bak/
134 | venv.bak/
135 |
136 | # Spyder project settings
137 | .spyderproject
138 | .spyproject
139 |
140 | # Rope project settings
141 | .ropeproject
142 |
143 | # mkdocs documentation
144 | /site
145 |
146 | # mypy
147 | .mypy_cache/
148 | .dmypy.json
149 | dmypy.json
150 |
151 | # Pyre type checker
152 | .pyre/
153 |
154 | # pytype static type analyzer
155 | .pytype/
156 |
157 | # Cython debug symbols
158 | cython_debug/
159 |
160 | # PyCharm
161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163 | # and can be added to the global gitignore or merged into this file. For a more nuclear
164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165 | #.idea/
166 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "checkpoints/HiFi-Glot"]
2 | path = checkpoints/HiFi-Glot
3 | url = https://huggingface.co/ljuvela/SourceFilterNeuralFormantSynthesisE2E
4 | [submodule "checkpoints/NFS"]
5 | path = checkpoints/NFS
6 | url = https://huggingface.co/ljuvela/NeuralFormantSynthesis
7 | [submodule "checkpoints/NFS-E2E"]
8 | path = checkpoints/NFS-E2E
9 | url = https://huggingface.co/ljuvela/NeuralFormantSynthesisE2E
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Lauri Juvela
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural Formant Synthesis with Differentiable Resonant Filters
2 |
3 | Neural formant synthesis using differentiable resonant filters and a source-filter model structure.
4 |
5 | Authors: [Lauri Juvela][lauri_profile], [Pablo Pérez Zarazaga][pablo_profile], [Gustav Eje Henter][gustav_profile], [Zofia Malisz][zofia_profile]
6 |
7 | [HiFi_link]: https://github.com/jik876/hifi-gan
8 | [GlotNet_link]: https://github.com/ljuvela/GlotNet
9 | [arxiv_link]: http://arxiv.org/abs/placeholder_link
10 | [demopage_link]: https://perezpoz.github.io/SFNeuralFormants
11 | [gustav_profile]: https://people.kth.se/~ghe/
12 | [pablo_profile]: https://www.kth.se/profile/pablopz
13 | [zofia_profile]: https://www.kth.se/profile/malisz
14 | [lauri_profile]: https://research.aalto.fi/en/persons/lauri-juvela
15 |
16 | [lfs_link]:https://git-lfs.com
17 |
18 | ## Table of contents
19 | 1. [Model overview](#model_struct)
20 |    1. [Sound samples](#sound_samples)
21 | 2. [Repository installation](#install)
22 |    1. [Conda environment](#conda)
23 |    2. [GlotNet](#glotnet)
24 |    3. [HiFi-GAN](#hifi)
25 |    4. [Additional libraries](#additional)
26 | 3. [Pre-trained models](#pretrained)
27 | 4. [Inference](#inference)
28 | 5. [Model training](#training)
29 | 6. [Citation information](#citation)
30 |
31 | ## Model overview
32 |
33 | We present a model that performs neural speech synthesis using the structure of the source-filter model, which allows the spectral envelope and the glottal excitation to be inspected and manipulated independently:
34 |
35 | 
36 |
37 | ### Sound samples
38 |
39 | A description of the presented model, together with sound samples compared against other synthesis/manipulation systems, can be found on the [project's demo webpage][demopage_link].
40 |
41 | ## Repository installation
42 |
43 | ### Conda environment
44 |
45 | First, create a conda environment with the required dependencies. Use mamba instead of conda to speed up the process if it is available.
46 | ```sh
47 | mamba env create -n neuralformants -f environment.yml
48 | conda activate neuralformants
49 | ```
50 |
51 | Pre-trained models are hosted on Hugging Face and are tracked as git submodules with Git LFS (git-lfs is included in `environment.yml`; otherwise you can find it [here][lfs_link]). Download the pre-trained models with:
52 | ```sh
53 | git submodule update --init --recursive
54 | ```
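
After the submodules are fetched, each directory under `checkpoints/` should contain a generator checkpoint and the two config files referenced in the inference commands below. A minimal sketch of a sanity check (the inference scripts look for a checkpoint file whose name starts with `g_`; the exact file names come from the Hugging Face repositories):

```python
# Sketch: verify that the checkpoint submodules were fetched via Git LFS.
import glob
import os

for name in ["HiFi-Glot", "NFS", "NFS-E2E"]:
    ckpt_dir = os.path.join("checkpoints", name)
    generator_ckpts = glob.glob(os.path.join(ckpt_dir, "g_*"))
    has_config = os.path.exists(os.path.join(ckpt_dir, "config_hifigan.json"))
    print(name, generator_ckpts, "config found:", has_config)
```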
55 |
56 | Install the package in development mode:
57 | ```sh
58 | pip install -e .
59 | ```
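
To verify that the editable install picked up the `src/` package layout, a quick import check (assuming the conda environment above is active):

```python
# These imports mirror the ones used by the inference scripts.
import neural_formant_synthesis
from neural_formant_synthesis.feature_extraction import feature_extractor, Normaliser
from neural_formant_synthesis.models import NeuralFormantSynthesisGenerator

print(neural_formant_synthesis.__name__, "imported OK")
```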
60 |
61 |
62 | ### GlotNet
63 | Parts of GlotNet are included for its WaveNet models and DSP functions. The full repository is available [here][GlotNet_link].
64 |
65 |
66 | ### HiFi-GAN
67 | HiFi-GAN is included in the `third_party/hifi_gan` subdirectory of the package. The original source code is available [here][HiFi_link].
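
Both vendored components are importable as subpackages of `neural_formant_synthesis`, for example:

```python
# DSP utilities from the partial GlotNet copy
from neural_formant_synthesis.glotnet.sigproc.emphasis import Emphasis
from neural_formant_synthesis.glotnet.sigproc.levinson import forward_levinson

# Helpers from the bundled HiFi-GAN code
from neural_formant_synthesis.third_party.hifi_gan.env import AttrDict
from neural_formant_synthesis.third_party.hifi_gan.utils import scan_checkpoint
```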
68 |
69 | ## Inference
70 |
71 | We provide scripts to run inference with the end-to-end architecture: an audio file is given as input, and a WAV file synthesised from the (optionally manipulated) features is written as output.
72 |
73 | Change the feature scaling to modify the pitch (F0) or the formant frequencies. The scales are given as a list of 5 elements in the following order:
74 | ```python
75 | [F0, F1, F2, F3, F4]
76 | ```
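
The scales are multiplicative (F0 is scaled in the log domain, which is equivalent to multiplying the pitch) and are passed to the scripts as a string, e.g. `--feature_scale "[1.2, 1.0, 1.0, 1.0, 1.0]"`. For example:

```python
# Scale factors are multiplicative; 1.0 leaves a feature unchanged.
feature_scale = [1.2, 1.0, 1.0, 1.0, 1.0]  # raise F0 by 20 %, keep F1-F4 unchanged
```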
77 | An example with the provided audio samples from the VCTK dataset can be run using:
78 |
79 | HiFi-Glot
80 | ```sh
81 | python inference_hifiglot.py \
82 | --input_path "./Samples" \
83 | --output_path "./output/hifi-glot" \
84 | --config "./checkpoints/HiFi-Glot/config_hifigan.json" \
85 | --fm_config "./checkpoints/HiFi-Glot/config_feature_map.json" \
86 | --checkpoint_path "./checkpoints/HiFi-Glot" \
87 | --feature_scale "[1.0, 1.0, 1.0, 1.0, 1.0]"
88 | ```
89 |
90 | NFS
91 | ```sh
92 | python inference_hifigan.py \
93 | --input_path "./Samples" \
94 | --output_path "./output/nfs" \
95 | --config "./checkpoints/NFS/config_hifigan.json" \
96 | --fm_config "./checkpoints/NFS/config_feature_map.json" \
97 | --checkpoint_path "./checkpoints/NFS" \
98 | --feature_scale "[1.0, 1.0, 1.0, 1.0, 1.0]"
99 | ```
100 |
101 | NFS-E2E
102 | ```sh
103 | python inference_hifigan.py \
104 | --input_path "./Samples" \
105 | --output_path "./output/nfs-e2e" \
106 | --config "./checkpoints/NFS-E2E/config_hifigan.json" \
107 | --fm_config "./checkpoints/NFS-E2E/config_feature_map.json" \
108 | --checkpoint_path "./checkpoints/NFS-E2E" \
109 | --feature_scale "[1.0, 1.0, 1.0, 1.0, 1.0]"
110 | ```
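
For reference, each script writes two files per input into `--output_path`: the resampled original and the synthesised waveform, named from the input basename and the scale values. A small sketch of the naming scheme used in `generate_wave_list`:

```python
# Naming scheme used by inference_hifigan.py / inference_hifiglot.py
scales = [1.0, 1.0, 1.0, 1.0, 1.0]
base = "p257_016"

synth_name = base + "_wave_" + "_".join(str(s) for s in scales) + ".wav"
orig_name = base + "_orig.wav"

print(synth_name)  # p257_016_wave_1.0_1.0_1.0_1.0_1.0.wav
print(orig_name)   # p257_016_orig.wav
```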
111 |
112 |
113 | ## Model training
114 |
115 | The HiFi-GAN and HiFi-Glot models can be trained with the end-to-end architecture using the scripts `train_e2e_hifigan.py` and `train_e2e_hifiglot.py`.
116 |
117 |
118 | ## Citation information
119 |
120 | Citation information will be added when a pre-print is available.
121 |
--------------------------------------------------------------------------------
/Samples/p257_016.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/Samples/p257_016.wav
--------------------------------------------------------------------------------
/Samples/p260_131.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/Samples/p260_131.wav
--------------------------------------------------------------------------------
/Samples/p266_015.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/Samples/p266_015.wav
--------------------------------------------------------------------------------
/config/config_v3.json:
--------------------------------------------------------------------------------
1 | {
2 | "resblock": "2",
3 | "num_gpus": 1,
4 | "batch_size": 16,
5 | "learning_rate": 0.0002,
6 | "adam_b1": 0.8,
7 | "adam_b2": 0.99,
8 | "lr_decay": 0.999,
9 | "seed": 1234,
10 |
11 | "upsample_rates": [8,8,4],
12 | "upsample_kernel_sizes": [16,16,8],
13 | "upsample_initial_channel": 256,
14 | "resblock_kernel_sizes": [3,5,7],
15 | "resblock_dilation_sizes": [[1,2], [2,6], [3,12]],
16 |
17 | "segment_size": 8192,
18 | "num_mels": 80,
19 | "num_freq": 1025,
20 | "n_fft": 1024,
21 | "hop_size": 256,
22 | "win_size": 1024,
23 | "allpole_order": 30,
24 | "pre_emph_coeff": 0.97,
25 |
26 | "sampling_rate": 22050,
27 |
28 | "fmin": 0,
29 | "fmax": 8000,
30 | "fmax_for_loss": null,
31 |
32 | "num_workers": 4,
33 |
34 | "dist_config": {
35 | "dist_backend": "nccl",
36 | "dist_url": "tcp://localhost:54321",
37 | "world_size": 1
38 | }
39 | }
40 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | channels:
2 | - pytorch
3 | - conda-forge
4 | dependencies:
5 | - python=3.12
6 | - pytorch=2.4
7 | - torchaudio
8 | - tensorboard
9 | - matplotlib
10 | - pandas
11 | - pytest
12 | - git-lfs
13 | - pip:
14 | - diffsptk
15 | - pyworld
16 | - tqdm
17 | - ipdb
18 |
--------------------------------------------------------------------------------
/images/LPC-NFS.png:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:a4ec577e3dd476e0972ef4053ab13fe534345973e634d72c25653a9c3af89bca
3 | size 456332
4 |
--------------------------------------------------------------------------------
/inference_hifigan.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.simplefilter(action='ignore', category=FutureWarning)
3 | import os
4 | import argparse
5 | import json
6 | import torch
7 | from neural_formant_synthesis.third_party.hifi_gan.env import AttrDict, build_env
8 | from neural_formant_synthesis.third_party.hifi_gan.utils import scan_checkpoint
9 |
10 |
11 | from neural_formant_synthesis.glotnet.sigproc.lpc import LinearPredictor
12 | from neural_formant_synthesis.glotnet.sigproc.emphasis import Emphasis
13 |
14 | from neural_formant_synthesis.models import FM_Hifi_Generator, fm_config_obj, Envelope_wavenet, Envelope_conformer
15 | from neural_formant_synthesis.feature_extraction import feature_extractor, Normaliser, MedianPool1d
16 | from neural_formant_synthesis.models import NeuralFormantSynthesisGenerator
17 |
18 |
19 | from neural_formant_synthesis.glotnet.sigproc.levinson import forward_levinson
20 |
21 | import torchaudio as ta
22 | import pandas as pd
23 | from tqdm import tqdm
24 | from glob import glob
25 |
26 |
27 | torch.backends.cudnn.benchmark = True
28 |
29 |
30 | def generate_wave_list(file_list, scale_list, a, h, fm_h):
31 |
32 | torch.cuda.manual_seed(h.seed)
33 | if torch.cuda.is_available():
34 | device = torch.device('cuda:0')
35 | else:
36 | device = torch.device('cpu')
37 |
38 | target_sr = h.sampling_rate
39 | win_size = h.win_size
40 | hop_size = h.hop_size
41 |
42 | feat_extractor = feature_extractor(sr = target_sr,window_samples = win_size, step_samples = hop_size, formant_ceiling = 10000, max_formants = 4)
43 | median_filter = MedianPool1d(kernel_size = 3, stride = 1, padding = 0, same = True)
44 | pre_emphasis_cpu = Emphasis(alpha = h.pre_emph_coeff)
45 |
46 | normalise_features = Normaliser(target_sr)
47 |
48 | generator = NeuralFormantSynthesisGenerator(fm_config=fm_h, g_config=h,
49 | pretrained_fm=None,
50 | freeze_fm=False,
51 | device=device)
52 |
53 |
54 | print("checkpoints directory : ", a.checkpoint_path)
55 |
56 | if os.path.isdir(a.checkpoint_path):
57 | cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
58 |
59 |
60 | generator.load_generator_e2e_checkpoint(cp_g)
61 |
62 | generator = generator.to(device)
63 |
64 | generator.eval()
65 |
66 |
67 | # Read files from list
68 | for file in tqdm(file_list, total = len(file_list)):
69 | # Read audio and resample if necessary
70 | x, sample_rate = ta.load(file)
71 | x = x[0:1].type(torch.DoubleTensor)
72 |
73 | x = ta.functional.resample(x, sample_rate, target_sr)
74 |
75 | # Get features using feature extractor
76 |
77 | x_preemph = pre_emphasis_cpu(x.unsqueeze(0))
78 | x_preemph = x_preemph.squeeze(0).squeeze(0)
79 | formants, energy, centroid, tilt, pitch, voicing_flag,_, _,_ = feat_extractor(x_preemph)
80 |
81 | # Parameter smoothing and length matching
82 |
83 | formants = median_filter(formants.T.unsqueeze(1)).squeeze(1).T
84 |
85 | pitch = pitch.squeeze(0)
86 | voicing_flag = voicing_flag.squeeze(0)
87 |
88 | # If pitch length is smaller than formants, pad pitch and voicing flag with last value
89 | if pitch.size(0) < formants.size(0):
90 | pitch = torch.nn.functional.pad(pitch, (0, formants.size(0) - pitch.size(0)), mode = 'constant', value = pitch[-1])
91 | voicing_flag = torch.nn.functional.pad(voicing_flag, (0, formants.size(0) - voicing_flag.size(0)), mode = 'constant', value = voicing_flag[-1])
92 | # If pitch length is larger than formants, truncate pitch and voicing flag
93 | elif pitch.size(0) > formants.size(0):
94 | pitch = pitch[:formants.size(0)]
95 | voicing_flag = voicing_flag[:formants.size(0)]
96 |
97 | # We can apply manipulation HERE
98 |
99 | log_pitch = torch.log(pitch)
100 |
101 | #pitch = pitch * scale_list[0]
102 | for i in range(voicing_flag.size(0)):
103 | if voicing_flag[i] == 1:
104 | log_pitch[i] = log_pitch[i] + torch.log(torch.tensor(scale_list[0]))
105 | formants[i,0] = formants[i,0] * scale_list[1]
106 | formants[i,1] = formants[i,1] * scale_list[2]
107 | formants[i,2] = formants[i,2] * scale_list[3]
108 | formants[i,3] = formants[i,3] * scale_list[4]
109 |
110 | # Normalise data
111 | log_pitch, formants, tilt, centroid, energy = normalise_features(log_pitch, formants, tilt, centroid, energy)
112 |
113 | #Create input data
114 | #size --> (Batch, features, sequence)
115 | norm_feat = torch.transpose(torch.cat((log_pitch.unsqueeze(1), formants, tilt.unsqueeze(1), centroid.unsqueeze(1), energy.unsqueeze(1), voicing_flag.unsqueeze(1)),dim = -1), 0, 1)
116 |
117 | norm_feat = norm_feat.type(torch.FloatTensor).unsqueeze(0).to(device)
118 |
119 | y_g_hat, _ = generator(norm_feat)
120 |
121 | output_file = os.path.splitext(os.path.basename(file))[0] + '_wave_' + str(scale_list[0]) + '_' + str(scale_list[1]) + '_' + str(scale_list[2]) + '_' + str(scale_list[3]) + '_' + str(scale_list[4]) + '.wav'
122 | output_orig = os.path.splitext(os.path.basename(file))[0] + '_orig.wav'
123 | out_path = os.path.join(a.output_path, output_file)
124 | out_orig_path = os.path.join(a.output_path, output_orig)
125 |
126 | ta.save(out_path, y_g_hat.detach().cpu().squeeze(0), target_sr)
127 | if not os.path.exists(out_orig_path):
128 | ta.save(out_orig_path, x.type(torch.FloatTensor), target_sr)
129 |
130 | def parse_file_list(list_file):
131 | """
132 | Read text file with paths to the files to process.
133 | """
134 |     with open(list_file, 'r') as f:
135 |         lines = f.read().splitlines()
136 |     return lines
137 |
138 | def str_to_list(in_str):
139 | return list(map(float, in_str.strip('[]').split(',')))
140 |
141 | def main():
142 |
143 | parser = argparse.ArgumentParser()
144 |
145 | parser.add_argument('--input_path', default = None, help="Path to directory containing files to process.")
146 | parser.add_argument('--list_file', default = None, help="Text file containing list of files to process. Optional argument to use instead of input_path.")
147 | parser.add_argument('--output_path', default='test_output', help="Path to directory to save processed files")
148 | parser.add_argument('--config', default='', help="Path to HiFi-GAN config json file")
149 | parser.add_argument('--fm_config', default='', help="Path to feature mapping model config json file")
150 | parser.add_argument('--env_config', default='', help="Path to envelope estimation model config json file")
151 | parser.add_argument('--audio_ext', default = '.wav', help="Extension of the audio files to process")
152 | parser.add_argument('--checkpoint_path', help="Path to pre-trained HiFi-GAN model")
153 | parser.add_argument('--feature_scale', help="List of scales for pitch and formant frequencies -- [F0, F1, F2, F3, F4]")
154 |
155 |
156 | a = parser.parse_args()
157 |
158 | with open(a.config) as f:
159 | data = f.read()
160 |
161 | json_config = json.loads(data)
162 | h = AttrDict(json_config)
163 |
164 | with open(a.fm_config) as f:
165 | data = f.read()
166 | json_fm_config = json.loads(data)
167 | fm_h = AttrDict(json_fm_config)
168 | # fm_h = fm_config_obj(json_fm_config)
169 |
170 | # build_env(a.config, 'config.json', a.checkpoint_path)
171 | if a.input_path is not None:
172 | file_list = glob(os.path.join(a.input_path,'*' + a.audio_ext))
173 | elif a.list_file is not None:
174 | file_list = parse_file_list(a.list_file)
175 | else:
176 |         raise ValueError('Input arguments should include either input_path or list_file')
177 |
178 | if not os.path.exists(a.output_path):
179 | os.makedirs(a.output_path, exist_ok=True)
180 |
181 | scale_list = str_to_list(a.feature_scale)
182 | if len(scale_list) != 5:
183 | raise ValueError('The scaling vector must contain 5 features: [F0, F1, F2, F3, F4]')
184 |
185 | torch.manual_seed(h.seed)
186 |
187 | generate_wave_list(file_list, scale_list, a, h, fm_h)
188 |
189 |
190 | if __name__ == '__main__':
191 | main()
192 |
--------------------------------------------------------------------------------
/inference_hifiglot.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.simplefilter(action='ignore', category=FutureWarning)
3 | import os
4 | import argparse
5 | import json
6 | import torch
7 | from neural_formant_synthesis.third_party.hifi_gan.env import AttrDict, build_env
8 | #from neural_formant_synthesis.third_party.hifi_gan.models import discriminator_metrics
9 | from neural_formant_synthesis.third_party.hifi_gan.utils import scan_checkpoint
10 |
11 |
12 | from neural_formant_synthesis.glotnet.sigproc.lpc import LinearPredictor
13 | from neural_formant_synthesis.glotnet.sigproc.emphasis import Emphasis
14 |
15 | from neural_formant_synthesis.models import FM_Hifi_Generator, fm_config_obj, Envelope_wavenet, Envelope_conformer
16 | from neural_formant_synthesis.feature_extraction import feature_extractor, Normaliser, MedianPool1d
17 | from neural_formant_synthesis.models import SourceFilterFormantSynthesisGenerator
18 |
19 |
20 | from neural_formant_synthesis.glotnet.sigproc.levinson import forward_levinson
21 |
22 | import torchaudio as ta
23 | import pandas as pd
24 | from tqdm import tqdm
25 | from glob import glob
26 |
27 |
28 | torch.backends.cudnn.benchmark = True
29 |
30 |
31 | def generate_wave_list(file_list, scale_list, a, h, fm_h):
32 |
33 | torch.cuda.manual_seed(h.seed)
34 | if torch.cuda.is_available():
35 | device = torch.device('cuda:0')
36 | else:
37 | device = torch.device('cpu')
38 |
39 | target_sr = h.sampling_rate
40 | win_size = h.win_size
41 | hop_size = h.hop_size
42 |
43 | feat_extractor = feature_extractor(sr = target_sr,window_samples = win_size, step_samples = hop_size, formant_ceiling = 10000, max_formants = 4)
44 | median_filter = MedianPool1d(kernel_size = 3, stride = 1, padding = 0, same = True)
45 | pre_emphasis_cpu = Emphasis(alpha = h.pre_emph_coeff)
46 |
47 | normalise_features = Normaliser(target_sr)
48 |
49 | generator = SourceFilterFormantSynthesisGenerator(
50 | fm_config=fm_h,
51 | g_config=h,
52 | pretrained_fm=None,
53 | freeze_fm=False,
54 | device=device)
55 |
56 |
57 | print("checkpoints directory : ", a.checkpoint_path)
58 |
59 | if os.path.isdir(a.checkpoint_path):
60 | cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
61 |
62 |
63 | generator.load_generator_e2e_checkpoint(cp_g)
64 |
65 | generator = generator.to(device)
66 |
67 | generator.eval()
68 |
69 |
70 | # Read files from list
71 | for file in tqdm(file_list, total = len(file_list)):
72 | # Read audio and resample if necessary
73 | x, sample_rate = ta.load(file)
74 | x = x[0:1].type(torch.DoubleTensor)
75 |
76 | x = ta.functional.resample(x, sample_rate, target_sr)
77 |
78 | # Get features using feature extractor
79 |
80 | x_preemph = pre_emphasis_cpu(x.unsqueeze(0))
81 | x_preemph = x_preemph.squeeze(0).squeeze(0)
82 | formants, energy, centroid, tilt, pitch, voicing_flag,_, _,_ = feat_extractor(x_preemph)
83 |
84 | # Parameter smoothing and length matching
85 |
86 | formants = median_filter(formants.T.unsqueeze(1)).squeeze(1).T
87 |
88 | pitch = pitch.squeeze(0)
89 | voicing_flag = voicing_flag.squeeze(0)
90 |
91 | # If pitch length is smaller than formants, pad pitch and voicing flag with last value
92 | if pitch.size(0) < formants.size(0):
93 | pitch = torch.nn.functional.pad(pitch, (0, formants.size(0) - pitch.size(0)), mode = 'constant', value = pitch[-1])
94 | voicing_flag = torch.nn.functional.pad(voicing_flag, (0, formants.size(0) - voicing_flag.size(0)), mode = 'constant', value = voicing_flag[-1])
95 | # If pitch length is larger than formants, truncate pitch and voicing flag
96 | elif pitch.size(0) > formants.size(0):
97 | pitch = pitch[:formants.size(0)]
98 | voicing_flag = voicing_flag[:formants.size(0)]
99 |
100 | # We can apply manipulation HERE
101 |
102 | log_pitch = torch.log(pitch)
103 |
104 | #pitch = pitch * scale_list[0]
105 | for i in range(voicing_flag.size(0)):
106 | if voicing_flag[i] == 1:
107 | log_pitch[i] = log_pitch[i] + torch.log(torch.tensor(scale_list[0]))
108 | formants[i,0] = formants[i,0] * scale_list[1]
109 | formants[i,1] = formants[i,1] * scale_list[2]
110 | formants[i,2] = formants[i,2] * scale_list[3]
111 | formants[i,3] = formants[i,3] * scale_list[4]
112 |
113 | # Normalise data
114 | log_pitch, formants, tilt, centroid, energy = normalise_features(log_pitch, formants, tilt, centroid, energy)
115 |
116 | #Create input data
117 | #size --> (Batch, features, sequence)
118 | norm_feat = torch.transpose(torch.cat((log_pitch.unsqueeze(1), formants, tilt.unsqueeze(1), centroid.unsqueeze(1), energy.unsqueeze(1), voicing_flag.unsqueeze(1)),dim = -1), 0, 1)
119 |
120 | norm_feat = norm_feat.type(torch.FloatTensor).unsqueeze(0).to(device)
121 |
122 | y_g_hat, _, _ = generator(norm_feat)
123 |
124 | output_file = os.path.splitext(os.path.basename(file))[0] + '_wave_' + str(scale_list[0]) + '_' + str(scale_list[1]) + '_' + str(scale_list[2]) + '_' + str(scale_list[3]) + '_' + str(scale_list[4]) + '.wav'
125 | output_orig = os.path.splitext(os.path.basename(file))[0] + '_orig.wav'
126 | out_path = os.path.join(a.output_path, output_file)
127 | out_orig_path = os.path.join(a.output_path, output_orig)
128 |
129 | ta.save(out_path, y_g_hat.detach().cpu().squeeze(0), target_sr)
130 | if not os.path.exists(out_orig_path):
131 | ta.save(out_orig_path, x.type(torch.FloatTensor), target_sr)
132 |
133 | def parse_file_list(list_file):
134 | """
135 | Read text file with paths to the files to process.
136 | """
137 |     with open(list_file, 'r') as f:
138 |         lines = f.read().splitlines()
139 |     return lines
140 |
141 | def str_to_list(in_str):
142 | return list(map(float, in_str.strip('[]').split(',')))
143 |
144 | def main():
145 |
146 | parser = argparse.ArgumentParser()
147 |
148 | parser.add_argument('--input_path', default = None, help="Path to directory containing files to process.")
149 | parser.add_argument('--list_file', default = None, help="Text file containing list of files to process. Optional argument to use instead of input_path.")
150 | parser.add_argument('--output_path', default='test_output', help="Path to directory to save processed files")
151 | parser.add_argument('--config', default='', help="Path to HiFi-GAN config json file")
152 | parser.add_argument('--fm_config', default='', help="Path to feature mapping model config json file")
153 | parser.add_argument('--env_config', default='', help="Path to envelope estimation model config json file")
154 | parser.add_argument('--audio_ext', default = '.wav', help="Extension of the audio files to process")
155 | parser.add_argument('--checkpoint_path', help="Path to pre-trained HiFi-GAN model")
156 | parser.add_argument('--feature_scale', help="List of scales for pitch and formant frequencies -- [F0, F1, F2, F3, F4]")
157 |
158 |
159 | a = parser.parse_args()
160 |
161 | with open(a.config) as f:
162 | data = f.read()
163 |
164 | json_config = json.loads(data)
165 | h = AttrDict(json_config)
166 |
167 | with open(a.fm_config) as f:
168 | data = f.read()
169 | json_fm_config = json.loads(data)
170 | fm_h = AttrDict(json_fm_config)
171 |
172 | # build_env(a.config, 'config.json', a.checkpoint_path)
173 | if a.input_path is not None:
174 | file_list = glob(os.path.join(a.input_path,'*' + a.audio_ext))
175 | elif a.list_file is not None:
176 | file_list = parse_file_list(a.list_file)
177 | else:
178 |         raise ValueError('Input arguments should include either input_path or list_file')
179 |
180 | if not os.path.exists(a.output_path):
181 | os.makedirs(a.output_path, exist_ok=True)
182 |
183 | scale_list = str_to_list(a.feature_scale)
184 | if len(scale_list) != 5:
185 | raise ValueError('The scaling vector must contain 5 features: [F0, F1, F2, F3, F4]')
186 |
187 | torch.manual_seed(h.seed)
188 |
189 | generate_wave_list(file_list, scale_list, a, h, fm_h)
190 |
191 |
192 | if __name__ == '__main__':
193 | main()
194 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "NeuralFormantSynthesis" # Required
3 | version = "1.0.0" # Required
4 | description = "" # Optional
5 | readme = "README.md" # Optional
6 | requires-python = ">=3.7"
7 | license = {file = "LICENSE"}
8 | authors = [
9 | {name = "Lauri Juvela", email = "lauri.juvela@aalto.fi" }
10 | ]
11 |
12 | [tool.setuptools]
13 | include-package-data = true
14 |
15 | [tool.setuptools.packages.find]
16 | namespaces = true
17 | where = ["src"]
18 |
19 | [build-system]
20 | # These are the assumed default build requirements from pip:
21 | # https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support
22 | requires = ["setuptools>=43.0.0", "wheel"]
23 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy == 1.24.4
2 | torch
3 | torchaudio
4 | diffsptk
5 | pyworld
6 | tqdm
7 | ipdb
8 | tensorboard
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/src/neural_formant_synthesis/__init__.py
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from logging import warning
4 | import random
5 | import torch
6 | import torchaudio as ta
7 | from torch.utils.data import Dataset
8 | from typing import List, Tuple
9 | from .feature_extraction import Normaliser
10 |
11 | from neural_formant_synthesis.glotnet.sigproc.levinson import forward_levinson
12 |
13 | from neural_formant_synthesis.third_party.hifi_gan.meldataset import mel_spectrogram
14 | from tqdm import tqdm
15 |
16 | class FeatureDataset(Dataset):
17 | """ Dataset for audio files """
18 |
19 | def __init__(self,
20 | dataset_dir: str,
21 | segment_len: int,
22 | sampling_rate: int,
23 | feature_ext: str = '.pt',
24 | audio_ext: str = '.wav',
25 | causal: bool = True,
26 | non_causal_segment: int = 0,
27 | normalise: bool = True,
28 | dtype: torch.dtype = torch.float32,
29 | device: str = 'cpu'):
30 | """
31 | Args:
32 |             dataset_dir: directory containing pre-extracted feature (.pt) files
33 |             segment_len: segment length in frames
34 |             sampling_rate: audio sampling rate in Hz
35 |             feature_ext, audio_ext: file extensions of feature and audio files
36 |             causal, non_causal_segment: optional non-causal context appended to each segment
37 |             normalise, dtype, device: feature normalisation flag, output data type and device
38 | """
39 |
40 | self.dataset_dir = dataset_dir
41 | self.segment_len = segment_len
42 | self.sampling_rate = sampling_rate
43 | self.feature_ext = feature_ext
44 | self.audio_ext = audio_ext
45 | self.normalise = normalise
46 | self.dtype = dtype
47 | self.device = device
48 |
49 | self.causal = causal
50 | self.non_causal_segment = non_causal_segment
51 |
52 | self.total_segment_len = segment_len + non_causal_segment
53 |
54 | self.normaliser = Normaliser(self.sampling_rate)
55 |
56 | self.audio_files = glob(os.path.join(
57 | self.dataset_dir, f"*{self.feature_ext}"))
58 |
59 | # elements are (filename, start, stop)
60 |         self.segment_index: List[Tuple[str, int, int]] = []
61 |
62 | for f in tqdm(self.audio_files, total=len(self.audio_files)):
63 | self._check_feature_file(f)
64 |
65 | def _check_feature_file(self, f):
66 |
67 |
68 | features = torch.load(f)
69 | input_features = features['Energy']
70 | num_samples = input_features.size(0)
71 |
72 |         # read segments from the end (to minimize zero padding)
73 | stop = num_samples
74 | start = stop - self.segment_len # Only consider causal length to set start index.
75 | start = max(start, 0)
76 | while stop > 0:
77 | #if self.causal:
78 | self.segment_index.append(
79 | (os.path.realpath(f), start, stop))
80 | #else:
81 | # stop_noncausal = min(stop + self.non_causal_segment, num_samples)
82 | # self.segment_index.append(
83 | # (os.path.realpath(f), start, stop_noncausal))
84 |
85 | stop = stop - self.segment_len
86 | start = stop - self.segment_len
87 | start = max(start, 0)
88 |
89 | def __len__(self):
90 | return len(self.segment_index)
91 |
92 | def __getitem__(self, i):
93 | f, start, stop = self.segment_index[i]
94 | data = torch.load(f)
95 | features = torch.cat((data["Pitch"].unsqueeze(1), data["Formants"], data["Tilt"].unsqueeze(1), data["Centroid"].unsqueeze(1), data["Energy"].unsqueeze(1), data["Voicing"].unsqueeze(1)), dim = -1)
96 | out_features = data["R_Coeff"]
97 |
98 | x = features[start:stop,:]
99 | y = out_features[start:stop,:]
100 |
101 | if torch.any(torch.isnan(data["Pitch"])):
102 | raise ValueError("Pitch features are NaN before normalisation.")
103 | if torch.any(torch.isnan(data["Formants"])):
104 | raise ValueError("Formants features are NaN before normalisation.")
105 | if torch.any(torch.isnan(data["Tilt"])):
106 | raise ValueError("Tilt features are NaN before normalisation.")
107 | if torch.any(torch.isnan(data["Centroid"])):
108 | raise ValueError("Centroid features are NaN before normalisation.")
109 | if torch.any(torch.isnan(data["Energy"])):
110 | raise ValueError("Energy features are NaN before normalisation.")
111 |
112 | num_samples = features.size(0)
113 | pad_left = 0
114 | pad_right = 0
115 |
116 | if start == 0:
117 | pad_left = self.segment_len - x.size(0)
118 |
119 | # zero pad to segment_len + padding
120 | x = torch.transpose(torch.nn.functional.pad(torch.transpose(x, 0, 1), ( pad_left, 0), mode='replicate'), 0, 1) # seq_len, n_channels
121 | y = torch.transpose(torch.nn.functional.pad(torch.transpose(y, 0, 1), ( pad_left, 0), mode='replicate'), 0, 1)
122 |
123 | if not self.causal:
124 | remaining_samples = min(num_samples - stop, self.non_causal_segment)
125 | x = torch.cat((x, features[stop:stop + remaining_samples,:]), dim=0)
126 | y = torch.cat((y, out_features[stop:stop + remaining_samples,:]), dim=0)
127 |
128 | pad_right = self.non_causal_segment - remaining_samples
129 | if pad_right > 0:
130 | x = torch.transpose(torch.nn.functional.pad(torch.transpose(x, 0, 1), ( 0, pad_right), mode='replicate'), 0, 1) # seq_len, n_channels
131 | y = torch.transpose(torch.nn.functional.pad(torch.transpose(y, 0, 1), ( 0, pad_right), mode='replicate'), 0, 1)
132 |
133 | if not(x.size(0) == self.total_segment_len):
134 | raise ValueError('Padding in the wrong dimension')
135 |
136 | x = torch.transpose(torch.cat((torch.cat(self.normaliser(x[...,0:1], x[...,1:5], x[...,5:6], x[...,6:7], x[...,7:8]),dim = -1), x[...,8:9]),dim = -1), 0, 1)
137 | y = torch.transpose(y, 0, 1)
138 |
139 | if torch.any(torch.isnan(x)):
140 | raise ValueError("Output x features are NaN.")
141 | if torch.any(torch.isnan(y)):
142 | raise ValueError("Output y features are NaN.")
143 | # Set sequence len as last dimension.
144 | return x.type(torch.FloatTensor).to(self.device), y.type(torch.FloatTensor).to(self.device)
145 |
146 | class FeatureDataset_with_Mel(Dataset):
147 | def __init__(self,
148 | dataset_dir: str,
149 | segment_len: int,
150 | sampling_rate: int,
151 | frame_size: int,
152 | hop_size:int,
153 | feature_ext: str = '.pt',
154 | audio_ext: str = '.wav',
155 | causal: bool = True,
156 | non_causal_segment: int = 0,
157 | normalise: bool = True,
158 | dtype: torch.dtype = torch.float32,
159 | device: str = 'cpu'):
160 | """
161 | Args:
162 |             dataset_dir: directory containing pre-extracted feature (.pt) and audio files
163 |             segment_len: segment length in frames
164 |             sampling_rate, frame_size, hop_size: sampling rate and mel-spectrogram frame/hop sizes
165 |             feature_ext, audio_ext: file extensions of feature and audio files
166 |             causal, non_causal_segment: optional non-causal context appended to each segment
167 |             normalise, dtype, device: feature normalisation flag, output data type and device
168 | """
169 |
170 | self.dataset_dir = dataset_dir
171 | self.segment_len = segment_len
172 | self.sampling_rate = sampling_rate
173 | self.frame_size = frame_size
174 | self.hop_size = hop_size
175 | self.feature_ext = feature_ext
176 | self.audio_ext = audio_ext
177 | self.normalise = normalise
178 | self.dtype = dtype
179 | self.device = device
180 |
181 | self.causal = causal
182 | self.non_causal_segment = non_causal_segment
183 |
184 | self.total_segment_len = segment_len + non_causal_segment
185 |
186 | self.melspec = ta.transforms.MelSpectrogram(sample_rate = self.sampling_rate, n_fft = self.frame_size, win_length = self.frame_size, hop_length = self.hop_size, n_mels = 80)
187 |
188 | self.normaliser = Normaliser(self.sampling_rate)
189 |
190 | self.audio_files = glob(os.path.join(
191 | self.dataset_dir, f"*{self.feature_ext}"))
192 |
193 | # elements are (filename, start, stop)
194 |         self.segment_index: List[Tuple[str, int, int]] = []
195 |
196 | for f in tqdm(self.audio_files, total=len(self.audio_files)):
197 | self._check_feature_file(f)
198 |
199 | def _check_feature_file(self, f):
200 |
201 |
202 | features = torch.load(f)
203 | input_features = features['Energy']
204 | num_samples = input_features.size(0)
205 |
206 |         # read segments from the end (to minimize zero padding)
207 | stop = num_samples
208 | start = stop - self.segment_len # Only consider causal length to set start index.
209 | start = max(start, 0)
210 | while stop > 0:
211 | #if self.causal:
212 | self.segment_index.append(
213 | (os.path.realpath(f), start, stop))
214 | #else:
215 | # stop_noncausal = min(stop + self.non_causal_segment, num_samples)
216 | # self.segment_index.append(
217 | # (os.path.realpath(f), start, stop_noncausal))
218 |
219 | stop = stop - self.segment_len
220 | start = stop - self.segment_len
221 | start = max(start, 0)
222 |
223 | def __len__(self):
224 | return len(self.segment_index)
225 |
226 | def __getitem__(self, i):
227 |
228 | f, start, stop = self.segment_index[i]
229 |
230 | audio_file = os.path.splitext(f)[0] + self.audio_ext
231 | start_samples = start * self.hop_size
232 | stop_samples = (stop - 1) * self.hop_size
233 |
234 | data = torch.load(f)
235 | features = torch.cat((torch.log(torch.clamp(data["Pitch"],50, 2000)).unsqueeze(1), data["Formants"], data["Tilt"].unsqueeze(1), data["Centroid"].unsqueeze(1), data["Energy"].unsqueeze(1), data["Voicing"].unsqueeze(1)), dim = -1)
236 | out_features = data["R_Coeff"]
237 |
238 | x = features[start:stop,:]
239 | y = out_features[start:stop,:]
240 |
241 | audio, sample_rate = ta.load(audio_file)
242 |
243 | if sample_rate != self.sampling_rate:
244 | audio = ta.functional.resample(audio, sample_rate, self.sampling_rate)
245 | audio_segment = audio[...,start_samples:stop_samples]
246 |
247 | num_samples = features.size(0)
248 | pad_left = 0
249 | pad_right = 0
250 |
251 | if start == 0:
252 | pad_left = self.segment_len - x.size(0)
253 | pad_left_audio = pad_left * self.hop_size
254 |
255 | # zero pad to segment_len + padding
256 | x = torch.transpose(torch.nn.functional.pad(torch.transpose(x, 0, 1), ( pad_left, 0), mode='replicate'), 0, 1) # seq_len, n_channels
257 | y = torch.transpose(torch.nn.functional.pad(torch.transpose(y, 0, 1), ( pad_left, 0), mode='replicate'), 0, 1)
258 | audio_segment = torch.nn.functional.pad(audio_segment, ( pad_left_audio, 0), mode='constant', value=0)
259 |
260 | if not self.causal:
261 | remaining_samples = min(num_samples - stop, self.non_causal_segment)
262 | remaining_samples_audio = remaining_samples * self.hop_size
263 |
264 | x = torch.cat((x, features[stop:stop + remaining_samples,:]), dim=0)
265 | y = torch.cat((y, out_features[stop:stop + remaining_samples,:]), dim=0)
266 | audio_segment = torch.cat((audio_segment, audio[...,stop_samples:stop_samples + remaining_samples_audio]), dim=-1)
267 |
268 | pad_right = self.non_causal_segment - remaining_samples
269 | if pad_right > 0:
270 | x = torch.transpose(torch.nn.functional.pad(torch.transpose(x, 0, 1), ( 0, pad_right), mode='replicate'), 0, 1) # seq_len, n_channels
271 | y = torch.transpose(torch.nn.functional.pad(torch.transpose(y, 0, 1), ( 0, pad_right), mode='replicate'), 0, 1)
272 | pad_right_audio = pad_right * self.hop_size
273 | audio_segment = torch.nn.functional.pad(audio_segment, ( pad_right_audio, 0), mode='constant', value=0)
274 |
275 | if not(x.size(0) == self.total_segment_len):
276 | raise ValueError('Padding in the wrong dimension')
277 |
278 | x = torch.transpose(torch.cat((torch.cat(self.normaliser(x[...,0:1], x[...,1:5], x[...,5:6], x[...,6:7], x[...,7:8]),dim = -1), x[...,8:9]),dim = -1), 0, 1)
279 | y = torch.transpose(y, 0, 1)
280 |
281 | mels = self.melspec(audio_segment).squeeze(0)
282 |
283 | if torch.any(torch.isnan(x)):
284 | raise ValueError("Output x features are NaN.")
285 | if torch.any(torch.isnan(y)):
286 | raise ValueError("Output y features are NaN.")
287 | # Set sequence len as last dimension.
288 | return x.type(torch.FloatTensor).to(self.device), y.type(torch.FloatTensor).to(self.device), audio_segment.squeeze(0).type(torch.FloatTensor).to(self.device), mels.type(torch.FloatTensor).to(self.device)
289 |
290 | class FeatureDataset_List(Dataset):
291 | def __init__(self,
292 | dataset_dir:str,
293 | config,
294 | sampling_rate: int,
295 | frame_size: int,
296 | hop_size:int,
297 | feature_ext: str = '.pt',
298 | audio_ext: str = '.wav',
299 | segment_length: int = None,
300 | normalise: bool = True,
301 | shuffle: bool = False,
302 | dtype: torch.dtype = torch.float32,
303 | device: str = 'cpu'):
304 |
305 | self.config = config
306 |
307 | self.dataset_dir = dataset_dir
308 | self.sampling_rate = sampling_rate
309 | self.frame_size = frame_size
310 | self.hop_size = hop_size
311 | self.feature_ext = feature_ext
312 | self.audio_ext = audio_ext
313 | self.normalise = normalise
314 | self.dtype = dtype
315 | self.device = device
316 |
317 | self.shuffle = shuffle
318 |
319 | self.segment_length = segment_length
320 |
321 | self.mel_spectrogram = ta.transforms.MelSpectrogram(sample_rate = self.sampling_rate, n_fft=self.frame_size, win_length = self.frame_size, hop_length = self.hop_size, f_min = 0.0, f_max = 8000, n_mels = 80)
322 | self.normaliser = Normaliser(self.sampling_rate)
323 |
324 | self.get_file_list()
325 |
326 | def get_file_list(self):
327 |
328 | self.file_list = glob(os.path.join(self.dataset_dir, '*' + self.feature_ext))
329 | if self.shuffle:
330 | #indices = torch.randperm(len(file_list))
331 | #file_list = file_list[indices]
332 | random.shuffle(self.file_list)
333 |
334 | def __len__(self):
335 | return len(self.file_list)
336 |
337 | def __getitem__(self, index):
338 |
339 | feature_file = self.file_list[index]
340 | audio_file = os.path.splitext(feature_file)[0] + self.audio_ext
341 |
342 | data = torch.load(feature_file)
343 | x = torch.cat((data["Pitch"].unsqueeze(1), data["Formants"], data["Tilt"].unsqueeze(1), data["Centroid"].unsqueeze(1), data["Energy"].unsqueeze(1), data["Voicing"].unsqueeze(1)), dim = -1)
344 | y = data["R_Coeff"]
345 |
346 | audio, sample_rate = ta.load(audio_file)
347 |
348 | if sample_rate != self.sampling_rate:
349 | audio = ta.functional.resample(audio, sample_rate, self.sampling_rate)
350 |
351 | if self.segment_length is None:
352 | self.segment_length = x.size(0)
353 |
354 | audio_total_len = int(x.size(0) * self.hop_size)
355 |
356 | if audio.size(1) < audio_total_len:
357 | audio = torch.unsqueeze(torch.nn.functional.pad(audio.squeeze(0), (0,int(audio_total_len - audio.size(1))),'constant'),0)
358 | else:
359 | audio = audio[:,:audio_total_len]
360 |
361 | if self.segment_length <= x.size(0):
362 |
363 | max_segment_start = x.size(0) - self.segment_length
364 | segment_start = random.randint(0, max_segment_start)
365 |
366 | x = x[segment_start:segment_start + self.segment_length,:]
367 | y = y[segment_start:segment_start + self.segment_length,:]
368 |
369 | audio_start = int(segment_start * self.hop_size)
370 | audio_segment_len = int(self.segment_length * self.hop_size)
371 | audio = audio[:,audio_start:audio_start + audio_segment_len]
372 | elif self.segment_length > x.size(0):
373 | diff = self.segment_length - x.size(0)
374 | x = torch.transpose(torch.nn.functional.pad(torch.transpose(x,0,1),(0,diff),'replicate'),0,1)
375 | y = torch.transpose(torch.nn.functional.pad(torch.transpose(y,0,1),(0,diff),'replicate'),0,1)
376 |
377 | audio_segment_diff = int(self.hop_size * diff)
378 | audio = torch.unsqueeze(torch.nn.functional.pad(audio.squeeze(0), (0,audio_segment_diff),'constant'),0)
379 |
380 | x = torch.transpose(torch.cat((torch.cat(self.normaliser(x[...,0:1], x[...,1:5], x[...,5:6], x[...,6:7], x[...,7:8]),dim = -1), x[...,8:9], y),dim = -1), 0, 1)
381 | #x = torch.transpose(torch.cat((torch.cat(self.normaliser(x[...,0:1], x[...,1:5], x[...,5:6], x[...,6:7], x[...,7:8]),dim = -1), x[...,8:9]),dim = -1), 0, 1)
382 |
383 | y = forward_levinson(y)
384 |
385 | y = torch.transpose(y, 0, 1)
386 |
387 | y_mel = mel_spectrogram(audio,sampling_rate = self.sampling_rate, n_fft=self.frame_size, win_size = self.frame_size, hop_size = self.hop_size, fmin = 0.0, fmax = self.config.fmax_for_loss, num_mels = 80)#self.mel_spectrogram(audio)
388 |
389 | return x.type(torch.FloatTensor).to(self.device), y.type(torch.FloatTensor).to(self.device), audio.squeeze(0).type(torch.FloatTensor).to(self.device), y_mel.squeeze().type(torch.FloatTensor).to(self.device)
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/feature_extraction.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import diffsptk
3 | from neural_formant_synthesis.functions import frame_energy, spectral_centroid, tilt_levinson, root_to_formant, levinson, pitch_extraction
4 |
5 |
6 | class feature_extractor(torch.nn.Module):
7 | """
8 | Class to extract the features for neural formant synthesis.
9 | Params:
10 | sr: sample rate
11 | window_samples: window size in samples
12 | step_samples: step size in samples
13 | formant_ceiling: formant ceiling in Hz
14 | max_formants: maximum number of formants to estimate
15 |         allpole_order, r_coeff_order: all-pole orders for the formant analysis and for the reference reflection coefficients
16 | """
17 | def __init__(self, sr, window_samples, step_samples, formant_ceiling, max_formants, allpole_order = 10, r_coeff_order = 30):
18 | super().__init__()
19 | self.sr = sr
20 | self.window_samples = window_samples
21 | self.step_samples = step_samples
22 | self.formant_ceiling = formant_ceiling
23 | self.max_formants = max_formants
24 | self.allpole_order = allpole_order
25 | self.r_coeff_order = r_coeff_order
26 |
27 | self.framer = diffsptk.Frame(
28 | frame_length = self.window_samples,
29 | frame_period = self.step_samples
30 | )
31 | self.windower = diffsptk.Window(
32 | in_length = self.window_samples
33 | )
34 |
35 | # self.root_finder = diffsptk.DurandKernerMethod(self.allpole_order)
36 | self.root_finder = diffsptk.PolynomialToRoots(self.allpole_order)
37 |
38 | def forward(self, x):
39 | # Signal windowing
40 | x_frame = self.framer(x)
41 | x_window = self.windower(x_frame)
42 |
43 | # STFT
44 |
45 | x_spec = torch.fft.rfft(x_window, dim = -1)
46 |
47 | ds_samples = int(self.formant_ceiling/self.sr * x_spec.size(-1))
48 | x_ds = x_spec[...,:ds_samples]
49 |
50 | x_ds_acorr = torch.fft.irfft(x_ds * torch.conj(x_ds), dim = -1)
51 |
52 | x_us_acorr = torch.fft.irfft(x_spec * torch.conj(x_spec), dim = -1)
53 |
54 | # Calculate formants
55 |
56 | ap_env, _ = levinson(x_ds_acorr, self.allpole_order)
57 | _, r_coeff_ref = levinson(x_us_acorr, self.r_coeff_order)
58 |
59 | roots_env = self.root_finder(ap_env)
60 | formants = root_to_formant(roots_env, self.formant_ceiling, self.max_formants)
61 | # Calculate other features
62 | energy = 10 * torch.log10(frame_energy(x_frame))
63 | centroid = spectral_centroid(x_frame, self.sr)
64 | tilt = tilt_levinson(x_ds_acorr)
65 |
66 | # Extract pitch from audio signal
67 | pitch, voicing, ignored = pitch_extraction(x = x, sr = self.sr, window_samples = self.window_samples, step_samples = self.step_samples, fmin = 50, fmax = 500)#penn.dsp.dio.from_audio(audio = x, sample_rate = self.sr, hopsize = hopsize, fmin = 50, fmax = 500)
68 | pitch = torch.log(pitch)
69 | return formants, energy, centroid, tilt, pitch, voicing, r_coeff_ref, x_ds, ignored
70 |
71 | class MedianPool1d(torch.nn.Module):
72 | """ Median pool (usable as median filter when stride=1) module.
73 |
74 | Args:
75 | kernel_size: size of pooling kernel
76 | stride: pool stride
77 | padding: pool padding
78 | same: override padding and enforce same padding, boolean
79 | """
80 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
81 | super(MedianPool1d, self).__init__()
82 | self.k = kernel_size
83 | self.stride = stride
84 | self.padding = padding
85 | self.same = same
86 |
87 | def _padding(self, x):
88 | if self.same:
89 | iw = x.size()[-1]
90 | if iw % self.stride == 0:
91 | pw = max(self.k - self.stride, 0)
92 | else:
93 | pw = max(self.k - (iw % self.stride), 0)
94 | pl = pw // 2
95 | pr = pw - pl
96 |
97 | padding = (pl, pr)
98 | else:
99 | padding = self.padding
100 | return padding
101 |
102 | def forward(self, x):
103 | x = torch.nn.functional.pad(x, self._padding(x), mode='reflect')
104 | x = x.unfold(2, self.k, self.stride)
105 | x = x.contiguous().view(x.size()[:-1] + (-1,)).median(dim=-1)[0]
106 | return x
107 |
108 | class Normaliser(torch.nn.Module):
109 | """
110 | Normalisation of features to a set of hard limits for training with online feature extraction.
111 | Params:
112 | sample_rate: sample rate
113 | pitch_lims: tuple with pitch limits in log(Hz) (lower, upper)
114 | formant_lims: tuple with formant limits in Hz (f1_lower, f1_upper, f2_lower, f2_upper, f3_lower, f3_upper, f4_lower, f4_upper)
115 | tilt_lims: tuple with spectral tilt limits (lower, upper)
116 | centroid_lims: tuple with spectral centroid limits in Hz (lower, upper)
117 | energy_lims: tuple with energy limits in dB (lower, upper)
118 | """
119 | def __init__(self, sample_rate, pitch_lims = None, formant_lims = None, tilt_lims = None, centroid_lims = None, energy_lims = None):
120 | super().__init__()
121 | self.sample_rate = sample_rate
122 | if pitch_lims is not None:
123 | self.pitch_lower = pitch_lims[0]
124 | self.pitch_upper = pitch_lims[1]
125 | else:
126 | self.pitch_lower = 3.9 # 50 Hz
127 | self.pitch_upper = 7.3 # 1500 Hz
128 |
129 | if formant_lims is not None:
130 | self.f1_lower = formant_lims[0]
131 | self.f1_upper = formant_lims[1]
132 | self.f2_lower = formant_lims[2]
133 | self.f2_upper = formant_lims[3]
134 | self.f3_lower = formant_lims[4]
135 | self.f3_upper = formant_lims[5]
136 | self.f4_lower = formant_lims[6]
137 | self.f4_upper = formant_lims[7]
138 | else:
139 | self.f1_lower = 200
140 | self.f1_upper = 900
141 | self.f2_lower = 550
142 | self.f2_upper = 2450
143 | self.f3_lower = 2200
144 | self.f3_upper = 2950
145 | self.f4_lower = 3000
146 | self.f4_upper = 4000
147 |
148 | if tilt_lims is not None:
149 | self.tilt_lower = tilt_lims[0]
150 | self.tilt_upper = tilt_lims[1]
151 | else:
152 | self.tilt_lower = -1
153 | self.tilt_upper = -0.9
154 |
155 | if centroid_lims is not None:
156 | self.centroid_lower = centroid_lims[0]
157 | self.centroid_upper = centroid_lims[1]
158 | else:
159 | self.centroid_lower = 0
160 | self.centroid_upper = self.sample_rate/2
161 |
162 | if energy_lims is not None:
163 | self.energy_lower = energy_lims[0]
164 | self.energy_upper = energy_lims[1]
165 | else:
166 | self.energy_lower = -60
167 | self.energy_upper = 30
168 |
169 | def forward(self, pitch, formants, tilt, centroid, energy):
170 | pitch = self._scale(pitch, self.pitch_upper, self.pitch_lower)
171 | formants[...,0] = self._scale(formants[...,0],self.f1_upper, self.f1_lower)
172 | formants[...,1] = self._scale(formants[...,1],self.f2_upper, self.f2_lower)
173 | formants[...,2] = self._scale(formants[...,2],self.f3_upper, self.f3_lower)
174 | formants[...,3] = self._scale(formants[...,3],self.f4_upper, self.f4_lower)
175 | tilt = self._scale(torch.clamp(tilt, -1, -0.85), self.tilt_upper, self.tilt_lower)
176 | centroid = self._scale(centroid, self.centroid_upper, self.centroid_lower)
177 | energy = self._scale(energy, self.energy_upper, self.energy_lower)
178 | return pitch, formants, tilt, centroid, energy
179 |
180 | def _scale(self,feature, upper, lower):
181 | max_denorm = upper - lower
182 | feature = (2 * (feature - lower) / max_denorm) - 1
183 | return feature
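
A minimal usage sketch for the Normaliser above, with illustrative feature values; the import path assumes the class lives in neural_formant_synthesis.feature_extraction:

    import torch
    from neural_formant_synthesis.feature_extraction import Normaliser  # assumed module path

    norm = Normaliser(sample_rate=16000)                 # default hard limits
    n_frames = 100
    pitch = torch.full((n_frames,), 5.0)                 # log(Hz), roughly 148 Hz
    formants = torch.tensor([500.0, 1500.0, 2500.0, 3500.0]).repeat(n_frames, 1)
    tilt = torch.full((n_frames,), -0.95)
    centroid = torch.full((n_frames,), 2000.0)           # Hz
    energy = torch.full((n_frames,), -20.0)              # dB

    # each feature is mapped to [-1, 1] relative to its hard limits
    pitch, formants, tilt, centroid, energy = norm(pitch, formants, tilt, centroid, energy)
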
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/functions.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import pyworld
5 | import warnings
6 |
7 | def tilt_levinson(acorr):
8 | """
9 |     Calculate spectral tilt as the predictor coefficient of a first-order all-pole filter estimated from the signal autocorrelation.
10 | Args:
11 | acorr: Autocorrelation of the signal.
12 | Returns:
13 | Spectral tilt value.
14 | """
15 |
16 | # Calculate allpole coefficients
17 | a,_ = levinson(acorr, 1)
18 |
19 | # Calculate spectral tilt
20 | tilt = a[:,1]
21 |
22 | return tilt
23 |
24 | def spectral_centroid(x, sr):
25 | """
26 | Extract spectral centroid as weighted ratio of spectral bands.
27 | Args:
28 | x: Input frames with shape (..., n_frames, frame_length)
29 | sr: Sampling rate
30 | Returns:
31 | 1D tensor with the values of estimated spectral centroid for each frame
32 | """
33 |
34 | mag_spec = torch.abs(torch.fft.rfft(x, dim=-1))
35 | length = x.size(-1)
36 | freqs = torch.abs(torch.fft.fftfreq(length, 1.0/sr)[:length//2+1])
37 | centroid = torch.sum(mag_spec*freqs, dim = -1) / torch.sum(mag_spec, dim = -1)
38 | return centroid / ( sr / 2)
39 |
40 | def frame_energy(x):
41 | """
42 | Calculate frame energy.
43 | Args:
44 | x: Input frames. size (..., n_frames, frame_length)
45 | Returns:
46 | 1D tensor with energies for each frame. Size(n_frames,)
47 | """
48 |
49 | energy = torch.sum(torch.square(torch.abs(x)), dim = -1)
50 |
51 | return energy
52 |
53 | def levinson(R, M):
54 | """ Levinson-Durbin method for converting autocorrelation to predictor polynomial
55 | Args:
56 | R: autocorrelation tensor, shape=(..., M)
57 | M: filter polynomial order
58 | Returns:
59 |         A: filter predictor polynomial, shape=(..., M+1); rcoeff: reflection coefficients, shape=(..., M+1)
60 |     Note:
61 |         R can contain more lags than M, but at least lags R[..., 0:M+1] are required
62 | """
63 |
64 | # Normalize autocorrelation and add white noise correction
65 | R = R / R[..., 0:1]
66 | R[..., 0:1] = R[..., 0:1] + 0.001
67 |
68 | E = R[..., 0:1]
69 | L = torch.cat([torch.ones_like(R[..., 0:1]),
70 | torch.zeros_like(R[..., 0:M])], dim=-1)
71 | L_prev = L
72 | rcoeff = torch.zeros_like(L)
73 | for p in torch.arange(0, M):
74 | K = torch.sum(L_prev[..., 0:p+1] * R[..., 1:p+2], dim=-1, keepdim=True) / E
75 | rcoeff[...,p:p+1] = K
76 | if (torch.any(torch.abs(K) > 1)):
77 | print(torch.argmax(torch.abs(K)))
78 | print(R[torch.argmax(torch.abs(K)), 1:M])
79 | raise ValueError('Reflection coeff bigger than 1')
80 |
81 | pad = torch.clamp(M-p-1, min=0)
82 | if p == 0:
83 | L = torch.cat([-1.0*K,
84 | torch.ones_like(R[..., 0:1]),
85 | torch.zeros_like(R[..., 0:pad])], dim=-1)
86 | else:
87 | L = torch.cat([-1.0*K,
88 | L_prev[..., 0:p] - 1.0*K *
89 | torch.flip(L_prev[..., 0:p], dims=[-1]),
90 | torch.ones_like(R[..., 0:1]),
91 | torch.zeros_like(R[..., 0:pad])], dim=-1)
92 | L_prev = L
93 |         E = E * (1.0 - K ** 2)  # order-p mean-square error
94 | L = torch.flip(L, dims=[-1]) # flip zero delay to zero:th index
95 | return L, rcoeff
96 |
97 | def root_to_formant(roots, sr, max_formants = 5):
98 | """
99 | Extract formant frequencies from allpole roots.
100 | Args:
101 | roots: Tensor containing roots of the polynomial.
102 | sr: Sampling rate.
103 | max_formants: Maximum number of formants to search.
104 | Returns:
105 | Tensor containing formant frequencies.
106 | """
107 |
108 | freq_tolerance = 10.0 / (sr / (2.0 * np.pi))
109 |
110 | phases = torch.angle(roots)
111 | phases = torch.where(phases < 0, phases + 2 * np.pi, phases)
112 |
113 | phases_sort,_ = torch.sort(phases, dim = -1, descending = False)
114 |
115 | phases_slice = phases_sort[...,1:max_formants+1]
116 | phases_sort = phases_sort[...,:max_formants]
117 |
118 | condition = (phases_sort[...,0:1] > freq_tolerance).expand(-1,max_formants) #Use expand instead of repeat
119 |
120 | phase_select = torch.where(condition, phases_sort, phases_slice)
121 | formants = phase_select * (sr / (2 * np.pi))
122 |
123 | return formants
124 |
125 | def interpolate_unvoiced(pitch, voicing_flag = None):
126 | """
127 | Fill unvoiced regions via linear interpolation
128 | @Pablo: Function copied from PENN's repository to allow input of voicing flag array.
129 | I haven't found any robust implementations for np.interp in pytorch.
130 | With a differentiable version, we could add this functionality to the formant extractor.
131 | """
132 | if voicing_flag is None:
133 | unvoiced = pitch < 10
134 | else:
135 | unvoiced = ~voicing_flag
136 |
137 | # Ignore warning of log setting unvoiced regions (zeros) to nan
138 | with warnings.catch_warnings():
139 | warnings.simplefilter('ignore')
140 |
141 | # Pitch is linear in base-2 log-space
142 | pitch = np.log2(pitch)
143 |
144 | ignored = False
145 |
146 | try:
147 |
148 | # Interpolate
149 | pitch[unvoiced] = np.interp(
150 | np.where(unvoiced)[0],
151 | np.where(~unvoiced)[0],
152 | pitch[~unvoiced])
153 |
154 | except ValueError:
155 | ignored = True
156 | # Allow all unvoiced
157 | print('Allowing unvoiced')
158 | pass
159 |
160 | return 2 ** pitch, ~unvoiced, ignored
161 |
162 | def pitch_extraction(x, sr, window_samples, step_samples, fmin = 50, fmax = 500):
163 | """
164 | Extract pitch using pyworld.
165 | Params:
166 | x: audio signal
167 | sr: sample rate
168 | window_samples: window size in samples
169 | step_samples: step size in samples
170 | fmin: minimum pitch frequency
171 | fmax: maximum pitch frequency
172 | Returns:
173 |         pitch: tensor with pitch values
174 |         voicing_flag: tensor with voicing flags; also returns ignored, True if unvoiced interpolation failed
175 | """
176 | # Convert to numpy
177 | audio = x.numpy().squeeze().astype(np.float64)
178 |
179 | hopsize = float(step_samples) / sr
180 |
181 | # Get pitch
182 | pitch, times = pyworld.dio(
183 | audio[window_samples // 2:-window_samples // 2],
184 | sr,
185 | fmin,
186 | fmax,
187 | frame_period=1000 * hopsize)
188 |
189 | # Refine pitch
190 | pitch = pyworld.stonemask(
191 | audio,
192 | pitch,
193 | times,
194 | sr)
195 |
196 | # Interpolate unvoiced tokens
197 | pitch, voicing_flag, ignored = interpolate_unvoiced(pitch)
198 |
199 | # Convert to torch
200 | return torch.from_numpy(pitch)[None], torch.tensor(voicing_flag, dtype = torch.int), ignored
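
The frame-based helpers above can be used on their own; a short sketch with spectral_centroid and frame_energy (importing the module also requires pyworld, which is imported at the top of the file):

    import torch
    from neural_formant_synthesis.functions import spectral_centroid, frame_energy

    sr = 16000
    x = torch.randn(1, sr)                                # (batch, samples), 1 s of noise
    frames = x.unfold(-1, 512, 256)                       # (batch, n_frames, frame_length)

    centroid = spectral_centroid(frames, sr)              # (batch, n_frames), normalised by sr / 2
    energy_db = 10.0 * torch.log10(frame_energy(frames))  # frame energies in dB
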
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/glotnet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/src/neural_formant_synthesis/glotnet/__init__.py
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/glotnet/model/feedforward/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/src/neural_formant_synthesis/glotnet/model/feedforward/__init__.py
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/glotnet/model/feedforward/activations.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2022 Lauri Juvela
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import torch
18 | import glotnet.cpp_extensions as ext
19 |
20 | def _gated_activation(x: torch.Tensor) -> torch.Tensor:
21 |
22 | assert x.size(1) % 2 == 0, f"Input must have an even number of channels, shape was {x.shape}"
23 | half = x.size(1) // 2
24 | return torch.tanh(x[:, :half, :]) * torch.sigmoid(x[:, half:, :])
25 |
26 | class Activation(torch.nn.Module):
27 | """ Activation class """
28 |
29 | def __init__(self, activation="gated"):
30 | super().__init__()
31 | self.activation_str = activation
32 | if activation == "gated":
33 | self.activation_func = _gated_activation
34 | elif activation == "tanh":
35 | self.activation_func = torch.tanh
36 | elif activation == "linear":
37 | self.activation_func = torch.nn.Identity()
38 |
39 | def forward(self, input, use_extension=True):
40 |
41 | if use_extension:
42 | return ActivationFunction.apply(input, self.activation_str)
43 | else:
44 | return self.activation_func(input)
45 |
46 |
47 | class ActivationFunction(torch.autograd.Function):
48 |
49 | @staticmethod
50 | def forward(ctx, input: torch.Tensor,
51 | activation: str, time_major: bool = True
52 | ) -> torch.Tensor:
53 |
54 | ctx.time_major = time_major
55 | if ctx.time_major:
56 | input = input.permute(0, 2, 1) # (B, C, T) -> (B, T, C)
57 |
58 | input = input.contiguous()
59 | ctx.save_for_backward(input)
60 | ctx.activation_type = activation
61 |
62 | output, = ext.activations_forward(input, activation)
63 |
64 | if ctx.time_major:
65 | output = output.permute(0, 2, 1) # (B, T, C) -> (B, C, T)
66 |
67 | return output
68 |
69 | @staticmethod
70 | def backward(ctx, d_output):
71 | raise NotImplementedError("Backward function not implemented for sequential processing")
72 |
73 |
74 |
75 | class Tanh(torch.nn.Module):
76 |
77 | def __init__(self):
78 | super().__init__()
79 |
80 | def forward(self, input: torch.Tensor) -> torch.Tensor:
81 | return TanhFunction.apply(input)
82 |
83 |
84 | class TanhFunction(torch.autograd.Function):
85 |
86 | @staticmethod
87 | def forward(ctx, input: torch.Tensor):
88 |
89 | output = torch.tanh(input)
90 | ctx.save_for_backward(input, output)
91 | return output
92 |
93 | @staticmethod
94 | def backward(ctx, d_output: torch.Tensor):
95 |
96 | input, output = ctx.saved_tensors
97 | # d tanh(x) / dx = 1 - tanh(x) ** 2
98 | d_input = (1 - output ** 2) * d_output
99 | return d_input
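
Since the module above imports glotnet.cpp_extensions at load time, here is a self-contained sketch of the same gated activation in plain torch, mirroring _gated_activation (the Activation class with use_extension=False computes the same thing):

    import torch

    def gated_activation(x: torch.Tensor) -> torch.Tensor:
        # first half of the channel dimension is the tanh branch,
        # second half is the sigmoid gate, as in _gated_activation above
        assert x.size(1) % 2 == 0
        half = x.size(1) // 2
        return torch.tanh(x[:, :half, :]) * torch.sigmoid(x[:, half:, :])

    x = torch.randn(2, 8, 100)        # (batch, 2 * residual_channels, time)
    y = gated_activation(x)           # (batch, 4, 100)
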
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/glotnet/model/feedforward/convolution.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class Convolution(torch.nn.Conv1d):
4 | """ Causal convolution with optional FILM conditioning """
5 |
6 | def __init__(self,
7 | in_channels: int,
8 | out_channels: int,
9 | kernel_size: int,
10 | stride: int = 1,
11 | padding: int = 0,
12 | dilation: int = 1,
13 | groups: int = 1,
14 | bias: bool = True,
15 | padding_mode: str = 'zeros',
16 | device=None,
17 | dtype=None,
18 | causal: bool = True,
19 | use_film: bool = False
20 | ):
21 | super().__init__(in_channels, out_channels,
22 | kernel_size, stride,
23 | padding, dilation,
24 | groups, bias, padding_mode,
25 | device, dtype)
26 | self.causal = causal
27 | self.use_film = use_film
28 |
29 | @property
30 | def receptive_field(self) -> int:
31 | """ Receptive field length """
32 | return (self.kernel_size[0] - 1) * self.dilation[0] + 1
33 |
34 | def forward(self,
35 | input: torch.Tensor,
36 | cond_input: torch.Tensor = None,
37 | sequential: bool = False
38 | ) -> torch.Tensor:
39 | """
40 | Args:
41 | input shape is (batch, in_channels, time)
42 | cond_input: conditioning input
43 |
44 | shape = (batch, 2*out_channels, time) if self.use_film else (batch, out_channels, time)
45 | Returns:
46 | output shape is (batch, out_channels, time)
47 | """
48 |
49 | if cond_input is not None:
50 | if self.use_film and cond_input.size(1) != 2 * self.out_channels:
51 |                 raise ValueError(f"Cond input number of channels mismatch. "
52 | f"Expected {2*self.out_channels}, got {cond_input.size(1)}")
53 | if not self.use_film and cond_input.size(1) != self.out_channels:
54 |                 raise ValueError(f"Cond input number of channels mismatch. "
55 | f"Expected {self.out_channels}, got {cond_input.size(1)}")
56 | if cond_input.size(2) != input.size(2):
57 | raise ValueError(f"Mismatching timesteps, "
58 | f"input has {input.size(2)}, cond_input has {cond_input.size(2)}")
59 |
60 | return self._forward_native(input=input, cond_input=cond_input, causal=self.causal)
61 |
62 | def _forward_native(self, input: torch.Tensor,
63 | cond_input: torch.Tensor,
64 | causal:bool=True) -> torch.Tensor:
65 | """ Native torch conv1d with causal padding
66 |
67 | Args:
68 | input shape is (batch, in_channels, time)
69 | cond_input: conditioning input
70 | shape = (batch, 2*out_channels, time) if self.use_film else (batch, out_channels, time)
71 | Returns:
72 | output shape is (batch, out_channels, time)
73 |
74 | """
75 |
76 | if causal:
77 | padding = self.dilation[0] * self.stride[0] * (self.kernel_size[0]-1)
78 | if padding > 0:
79 | input = torch.nn.functional.pad(input, (padding, 0))
80 | output = torch.nn.functional.conv1d(
81 | input, self.weight, bias=self.bias,
82 | stride=self.stride, padding=0,
83 | dilation=self.dilation, groups=self.groups)
84 | else:
85 | output = torch.nn.functional.conv1d(
86 | input, self.weight, bias=self.bias,
87 | stride=self.stride, padding='same',
88 | dilation=self.dilation, groups=self.groups)
89 |
90 | if cond_input is not None:
91 | if self.use_film:
92 | b, a = torch.chunk(cond_input, 2, dim=1)
93 | output = a * output + b
94 | else:
95 | output = output + cond_input
96 | return output
97 |
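
A usage sketch for the causal Convolution above; shapes are illustrative and the package is assumed to be installed:

    import torch
    from neural_formant_synthesis.glotnet.model.feedforward.convolution import Convolution

    conv = Convolution(in_channels=1, out_channels=8, kernel_size=3, dilation=2, causal=True)
    x = torch.randn(2, 1, 100)                # (batch, in_channels, time)
    y = conv(x)                               # causal padding keeps the time axis at 100
    cond = torch.randn(2, 8, 100)             # additive conditioning (use_film=False)
    y_cond = conv(x, cond_input=cond)
    print(y.shape, conv.receptive_field)      # torch.Size([2, 8, 100]) 5
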
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/glotnet/model/feedforward/convolution_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Tuple
3 | from .convolution import Convolution
4 |
5 |
6 | class ConvolutionLayer(torch.nn.Module):
7 | """
8 | Wavenet Convolution Layer (also known as Residual Block)
9 |
10 | Uses a gated activation and a 1x1 output transformation by default
11 |
12 | """
13 |
14 | def __init__(self, in_channels, out_channels, kernel_size, dilation=1,
15 | bias=True, device=None, dtype=None,
16 | causal=True,
17 | activation="gated",
18 | use_output_transform=True,
19 | cond_channels=None,
20 | skip_channels=None,
21 | ):
22 | super().__init__()
23 |
24 | residual_channels = out_channels
25 | self.activation = activation
26 | self.activation_fun, self.channel_mul = self._parse_activation(activation)
27 | self.use_output_transform = use_output_transform
28 | self.use_conditioning = cond_channels is not None
29 | self.cond_channels = cond_channels
30 | self.residual_channels = residual_channels
31 | self.skip_channels = skip_channels
32 | self.dilation = dilation
33 | self.conv = Convolution(
34 | in_channels=in_channels,
35 | out_channels=self.channel_mul * residual_channels,
36 | kernel_size=kernel_size, dilation=dilation, bias=bias,
37 | device=device, dtype=dtype, causal=causal)
38 | # TODO: make parameter alloc conditional on use_output_transform
39 | self.out = Convolution(
40 | in_channels=residual_channels,
41 | out_channels=residual_channels,
42 | kernel_size=1, dilation=1, bias=bias, device=device, dtype=dtype)
43 | if self.skip_channels is not None:
44 | self.skip = Convolution(
45 | in_channels=residual_channels,
46 | out_channels=skip_channels,
47 | kernel_size=1, dilation=1, bias=bias, device=device, dtype=dtype)
48 | if self.use_conditioning:
49 | self.cond_1x1 = torch.nn.Conv1d(
50 | cond_channels, self.channel_mul * residual_channels,
51 | kernel_size=1, bias=True, device=device, dtype=dtype)
52 |
53 | @property
54 | def receptive_field(self):
55 | return self.conv.receptive_field + self.out.receptive_field
56 |
57 | activations = {
58 | "gated": ((torch.tanh, torch.sigmoid), 2),
59 | "tanh": (torch.tanh, 1),
60 | "linear": (torch.nn.Identity(), 1)
61 | }
62 | def _parse_activation(self, activation):
63 | activation_fun, channel_mul = ConvolutionLayer.activations.get(activation, (None, None))
64 | if channel_mul is None:
65 | raise NotImplementedError
66 | return activation_fun, channel_mul
67 |
68 |
69 | def forward(self, input, cond_input=None, sequential=False):
70 | """
71 | Args:
72 | input, torch.Tensor of shape (batch_size, in_channels, timesteps)
73 | sequential (optional),
74 |                 if True, use the custom C++ sequential implementation (not implemented here)
75 |                 if False, use the native parallel torch implementation
76 |
77 | Returns:
78 | output, torch.Tensor of shape (batch_size, out_channels, timesteps)
79 | skip, torch.Tensor of shape (batch_size, out_channels, timesteps)
80 |
81 | """
82 |
83 | if cond_input is not None and not self.use_conditioning:
84 | raise RuntimeError("Module has not been initialized to use conditioning, \
85 | but conditioning input was provided at forward pass")
86 |
87 | if sequential:
88 | raise NotImplementedError
89 | else:
90 | return self._forward_native(input=input, cond_input=cond_input)
91 |
92 |
93 | def _forward_native(self, input, cond_input):
94 | c = self.cond_1x1(cond_input) if self.use_conditioning else None
95 | x = self.conv(input, cond_input=c)
96 | if self.channel_mul == 2:
97 | R = self.residual_channels
98 | x = self.activation_fun[0](x[:, :R, :]) * self.activation_fun[1](x[:, R:, :])
99 | else:
100 | x = self.activation_fun(x)
101 |
102 | if self.skip_channels is not None:
103 | skip = self.skip(x)
104 |
105 | if self.use_output_transform:
106 | output = self.out(x)
107 | else:
108 | output = x
109 |
110 | if self.skip_channels is None:
111 | return output
112 | else:
113 | return output, skip
114 |
115 |
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/glotnet/model/feedforward/convolution_stack.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import List
3 | from .convolution_layer import ConvolutionLayer
4 |
5 |
6 | class ConvolutionStack(torch.nn.Module):
7 | """
8 | Wavenet Convolution Stack
9 |
10 | Uses a gated activation and residual connections by default
11 | """
12 |
13 | def __init__(self, channels, skip_channels, kernel_size, dilations=[1], bias=True, device=None, dtype=None,
14 | causal=True,
15 | activation="gated",
16 | use_residual=True,
17 | use_1x1_block_out=True,
18 | cond_channels=None,
19 | ):
20 | super().__init__()
21 |
22 | self.channels = channels
23 | self.skip_channels = skip_channels
24 | self.activation = activation
25 | self.dilations = dilations
26 | self.use_residual = use_residual
27 | self.use_1x1_block_out = use_1x1_block_out
28 | self.use_conditioning = cond_channels is not None
29 | self.cond_channels = cond_channels
30 | self.causal = causal
31 | self.num_layers = len(dilations)
32 |
33 | self.layers = torch.nn.ModuleList()
34 | for i, d in enumerate(dilations):
35 | use_output_transform = self.use_1x1_block_out
36 | # Always disable output 1x1 for last layer
37 | if i == self.num_layers - 1:
38 | use_output_transform = False
39 | # Add ConvolutionLayer to Stack
40 | self.layers.append(
41 | ConvolutionLayer(
42 | in_channels=channels, out_channels=channels,
43 | kernel_size=kernel_size, dilation=d, bias=bias, device=device, dtype=dtype,
44 | causal=causal,
45 | activation=activation,
46 | use_output_transform=use_output_transform,
47 | cond_channels=self.cond_channels,
48 | skip_channels=self.skip_channels,
49 | )
50 | )
51 |
52 | @property
53 | def weights_conv(self):
54 | return [layer.conv.weight for layer in self.layers]
55 |
56 | @property
57 | def biases_conv(self):
58 | return [layer.conv.bias for layer in self.layers]
59 |
60 | @property
61 | def weights_out(self):
62 | return [layer.out.weight for layer in self.layers]
63 |
64 | @property
65 | def biases_out(self):
66 | return [layer.out.bias for layer in self.layers]
67 |
68 | @property
69 | def weights_skip(self):
70 | return [layer.skip.weight for layer in self.layers]
71 |
72 | @property
73 | def biases_skip(self):
74 | return [layer.skip.bias for layer in self.layers]
75 |
76 | @property
77 | def weights_cond(self):
78 | if self.use_conditioning:
79 | return [layer.cond_1x1.weight for layer in self.layers]
80 | else:
81 | return None
82 |
83 | @property
84 | def biases_cond(self):
85 | if self.use_conditioning:
86 | return [layer.cond_1x1.bias for layer in self.layers]
87 | else:
88 | return None
89 |
90 | @property
91 | def receptive_field(self):
92 | return sum([l.receptive_field for l in self.layers])
93 |
94 | def forward(self, input, cond_input=None, sequential=False):
95 | """
96 | Args:
97 | input, torch.Tensor of shape (batch_size, channels, timesteps)
98 | cond_input (optional),
99 | torch.Tensor of shape (batch_size, cond_channels, timesteps)
100 | sequential (optional),
101 |                 if True, use the custom C++ sequential implementation (not implemented)
102 |                 if False, use the native parallel torch implementation
103 |
104 | Returns:
105 | output, torch.Tensor of shape (batch_size, channels, timesteps)
106 | skips, list of torch.Tensor of shape (batch_size, out_channels, timesteps)
107 |
108 | """
109 |
110 | if cond_input is not None and not self.use_conditioning:
111 | raise RuntimeError("Module has not been initialized to use conditioning, but conditioning input was provided at forward pass")
112 |
113 | if sequential:
114 | raise NotImplementedError("Sequential mode not implemented")
115 | else:
116 | return self._forward_native(input=input, cond_input=cond_input)
117 |
118 | def _forward_native(self, input, cond_input):
119 | x = input
120 | skips = []
121 | for layer in self.layers:
122 | h = x
123 | x, s = layer(x, cond_input, sequential=False)
124 | x = x + h # residual connection
125 | skips.append(s)
126 | return x, skips
127 |
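
A usage sketch for ConvolutionStack; each layer returns a skip tensor, and the residual output keeps the input shape:

    import torch
    from neural_formant_synthesis.glotnet.model.feedforward.convolution_stack import ConvolutionStack

    stack = ConvolutionStack(channels=64, skip_channels=64, kernel_size=3, dilations=[1, 2, 4, 8])
    x = torch.randn(1, 64, 200)               # (batch, channels, time)
    out, skips = stack(x)                     # residual output and one skip tensor per layer
    print(out.shape, len(skips), skips[0].shape)
    # torch.Size([1, 64, 200]) 4 torch.Size([1, 64, 200])
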
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/glotnet/model/feedforward/wavenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import List
3 | from .convolution_layer import ConvolutionLayer
4 | from .convolution_stack import ConvolutionStack
5 |
6 | class WaveNet(torch.nn.Module):
7 | """ Feedforward WaveNet """
8 |
9 | def __init__(self,
10 | input_channels: int,
11 | output_channels: int,
12 | residual_channels: int,
13 | skip_channels: int,
14 | kernel_size: int,
15 | dilations: List[int] = [1, 2, 4, 8, 16, 32, 64, 128, 256],
16 | causal: bool = True,
17 | activation: str = "gated",
18 | use_residual: bool = True,
19 | cond_channels: int = None,
20 | cond_net: torch.nn.Module = None,
21 | ):
22 | super().__init__()
23 |
24 | self.input_channels = input_channels
25 | self.output_channels = output_channels
26 | self.residual_channels = residual_channels
27 | self.skip_channels = skip_channels
28 | self.cond_channels = cond_channels
29 | self.use_conditioning = cond_channels is not None
30 | self.activation = activation
31 | self.dilations = dilations
32 | self.kernel_size = kernel_size
33 | self.use_residual = use_residual
34 | self.num_layers = len(dilations)
35 | self.causal = causal
36 |
37 | # Layers
38 | self.input = ConvolutionLayer(
39 | in_channels=self.input_channels,
40 | out_channels=self.residual_channels,
41 | kernel_size=1, activation="tanh", use_output_transform=False)
42 | self.stack = ConvolutionStack(
43 | channels=self.residual_channels,
44 | skip_channels=self.skip_channels,
45 | kernel_size=self.kernel_size,
46 | dilations=dilations,
47 | activation=self.activation,
48 | use_residual=True,
49 | causal=self.causal,
50 | cond_channels=cond_channels)
51 |         # TODO: output layers should be plain convolutions; these have both a conv and an out transform
52 | self.output1 = ConvolutionLayer(
53 | in_channels=self.skip_channels,
54 | out_channels=self.residual_channels,
55 | kernel_size=1, activation="tanh", use_output_transform=False)
56 | self.output2 = ConvolutionLayer(
57 | in_channels=self.residual_channels,
58 | out_channels=self.output_channels,
59 | kernel_size=1, activation="linear", use_output_transform=False)
60 |
61 | self.cond_net = cond_net
62 |
63 | @property
64 | def output_weights(self):
65 | return [self.output1.conv.weight, self.output2.conv.weight]
66 |
67 | @property
68 | def output_biases(self):
69 | return [self.output1.conv.bias, self.output2.conv.bias]
70 |
71 | @property
72 | def receptive_field(self):
73 | return self.input.receptive_field + self.stack.receptive_field \
74 | + self.output1.receptive_field + self.output2.receptive_field
75 |
76 |
77 | def forward(self, input, cond_input=None):
78 | """
79 | Args:
80 | input, torch.Tensor of shape (batch_size, input_channels, timesteps)
81 | cond_input (optional),
82 | torch.Tensor of shape (batch_size, cond_channels, timesteps)
83 |
84 | Returns:
85 | output, torch.Tensor of shape (batch_size, output_channels, timesteps)
86 |
87 | """
88 |
89 | if cond_input is not None and not self.use_conditioning:
90 | raise RuntimeError("Module has not been initialized to use conditioning, but conditioning input was provided at forward pass")
91 |
92 | x = input
93 | x = self.input(x)
94 |         _, skips = self.stack(x, cond_input)  # TODO: rename self.stack to avoid confusion with torch.stack
95 | x = torch.stack(skips, dim=0).sum(dim=0)
96 | x = self.output1(x)
97 | x = self.output2(x)
98 | return x
99 |
100 |
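
A usage sketch for the feedforward WaveNet; channel sizes and dilations are illustrative:

    import torch
    from neural_formant_synthesis.glotnet.model.feedforward.wavenet import WaveNet

    model = WaveNet(input_channels=1, output_channels=1,
                    residual_channels=32, skip_channels=32,
                    kernel_size=3, dilations=[1, 2, 4, 8])
    x = torch.randn(1, 1, 1000)               # (batch, input_channels, time)
    y = model(x)                              # (batch, output_channels, time)
    print(y.shape)                            # torch.Size([1, 1, 1000])
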
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/glotnet/sigproc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/src/neural_formant_synthesis/glotnet/sigproc/__init__.py
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/glotnet/sigproc/allpole.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | def allpole(x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
5 | """
6 | All-pole filter
7 | :param x: input signal,
8 | shape (B, C, T) (batch, channels, timesteps)
9 | :param a: filter coefficients (denominator),
10 | shape (C, N, T) (channels, num taps, timesteps)
11 | :return: filtered signal
12 | shape (B, C, T) (batch, channels, timesteps)
13 | """
14 | y = torch.zeros_like(x)
15 |
16 | a_normalized = a / a[:, 0:1, :]
17 |
18 | # filter order
19 | p = a.shape[1] - 1
20 |
21 | # filter coefficients
22 | a1 = a_normalized[:, 1:, :]
23 |
24 | # flip coefficients
25 | a1 = torch.flip(a1, [1])
26 |
27 | # zero pad y by filter order
28 | y = torch.nn.functional.pad(y, (p, 0))
29 |
30 | # filter
31 | for i in range(p, y.shape[-1]):
32 | y[..., i] = x[..., i - p] - \
33 | torch.sum(a1[..., i - p] * y[..., i - p:i], dim=-1)
34 |
35 | return y[..., p:]
36 |
37 | class AllPoleFunction(torch.autograd.Function):
38 |
39 | @staticmethod
40 | def forward(ctx, x, a):
41 | y = allpole(x, a)
42 | ctx.save_for_backward(y, x, a)
43 | return y
44 |
45 | @staticmethod
46 | def backward(ctx, dy):
47 | y, x, a = ctx.saved_tensors
48 | dx = da = None
49 |
50 | n_batch = x.size(0)
51 | n_channels = x.size(1)
52 | p = a.size(1) - 1
53 | T = dy.size(-1)
54 |
55 |         # sensitivity of the output w.r.t. the coefficients: filter -y through the all-pole filter
56 | dyda = allpole(-y, a)
57 | dyda = torch.nn.functional.pad(dyda, (p, 0))
58 |
59 | # da = torch.zeros_like(a)
60 | # for i in range(0, T):
61 | # for j in range(0, p):
62 | # da[:, p, i] = dyda[..., i:i+T] * dy
63 | # da = da.flip([1])
64 |
65 |
66 | da = F.conv1d(
67 | dyda.view(1, n_batch * n_channels, -1),
68 | dy.view(n_batch * n_channels, 1, -1),
69 | groups=n_batch * n_channels).view(n_batch, n_channels, -1).sum(0).flip(1)
70 |
71 | dx = allpole(dy.flip(-1), a.flip(-1)).flip(-1)
72 |
73 | return dx, da
74 |
75 | class AllPole(torch.nn.Module):
76 |
77 | def __init__(self):
78 | super().__init__()
79 |
80 | def forward(self, x, a):
81 | return AllPoleFunction.apply(x, a)
82 |
83 |
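
A small sketch of the all-pole filter above, applying a time-invariant one-pole filter y[n] = x[n] + 0.9 * y[n-1]; the coefficient layout (channels, taps, time) follows the docstring:

    import torch
    from neural_formant_synthesis.glotnet.sigproc.allpole import AllPole

    T = 200
    x = torch.randn(1, 1, T)                                      # (batch, channels, time)
    a = torch.tensor([1.0, -0.9]).view(1, 2, 1).expand(1, 2, T)   # (channels, taps, time)
    y = AllPole()(x, a)
    print(y.shape)                                                # torch.Size([1, 1, 200])
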
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/glotnet/sigproc/biquad.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .lfilter import LFilter
4 | from typing import Tuple, Union
5 |
6 |
7 | class BiquadBaseFunctional(torch.nn.Module):
8 |
9 | def __init__(self):
10 | """ Initialize Biquad"""
11 |
12 | super().__init__()
13 |
14 | # TODO: pass STFT arguments to LFilter
15 | self.n_fft = 2048
16 | self.hop_length = 512
17 | self.win_length = 1024
18 | self.lfilter = LFilter(n_fft=self.n_fft,
19 | hop_length=self.hop_length,
20 | win_length=self.win_length)
21 |
22 | def _check_param_sizes(self, freq, gain, Q):
23 | """ Check that parameter sizes are compatible
24 |
25 | Args:
26 | freq: center frequencies
27 | shape = (batch, out_channels, in_channels, n_frames)
28 | gain: gains in decibels,
29 | shape = (batch, out_channels, in_channels, n_frames)
30 | Q: filter resonance (quality factor)
31 | shape = (batch, out_channels, in_channels, n_frames)
32 |
33 | Returns:
34 | batch, out_channels, in_channels, n_frames
35 |
36 | """
37 |
38 |         # all parameters must be 4D tensors
39 | if freq.ndim != 4:
40 | raise ValueError("freq must be 4D")
41 | if gain.ndim != 4:
42 | raise ValueError("gain must be 4D")
43 | if Q.ndim != 4:
44 | raise ValueError("Q must be 4D")
45 |
46 |         if not (freq.shape == gain.shape == Q.shape):
47 | raise ValueError("freq, gain, and Q must have the same shape")
48 |
49 | return freq.shape
50 |
51 | def _params_to_direct_form(self,
52 | freq: torch.Tensor,
53 | gain: torch.Tensor,
54 | Q: torch.Tensor
55 | ) -> Tuple[torch.Tensor, torch.Tensor]:
56 | """
57 | Args:
58 | freq, center frequency,
59 | shape is (..., n_frames)
60 | gain, gain in decibels
61 | shape is (..., n_frames)
62 | Q, resonance sharpness
63 | shape is (..., n_frames)
64 |
65 | Returns:
66 | b, filter numerator coefficients
67 | shape is (..., n_taps==3, n_frames)
68 | a, filter denominator coefficients
69 | shape is (..., n_taps==3, n_frames)
70 |
71 | """
72 | raise NotImplementedError("Subclasses must implement this method")
73 |
74 |
75 | def forward(self,
76 | x: torch.Tensor,
77 | freq: torch.Tensor,
78 | gain: torch.Tensor,
79 | Q: torch.Tensor,
80 | ) -> torch.Tensor:
81 | """
82 | Args:
83 | x: input signal
84 | shape = (batch, in_channels, time)
85 | freq: center frequencies
86 | shape = (batch, out_channels, in_channels, n_frames)
87 | n_frames is expected to be (time // hop_size)
88 | gain: gains in decibels,
89 | shape = (batch, out_channels, in_channels, n_frames)
90 | Q: filter resonance (quality factor)
91 | shape = (batch, out_channels, in_channels, n_frames)
92 |
93 | Returns:
94 | y: output signal
95 |                 shape = (batch, out_channels, time)
96 | """
97 |
98 | batch, out_channels, in_channels, n_frames = self._check_param_sizes(freq=freq, gain=gain, Q=Q)
99 | timesteps = x.size(-1)
100 |
101 | freq = freq.reshape(batch * out_channels * in_channels, n_frames)
102 | gain = gain.reshape(batch * out_channels * in_channels, n_frames)
103 | Q = Q.reshape(batch * out_channels * in_channels, n_frames)
104 |
105 | b, a = self._params_to_direct_form(freq=freq, gain=gain, Q=Q)
106 |
107 | # expand x: (batch, in_channels, time) -> (batch, out_channels, in_channels, time)
108 | x_exp = x.unsqueeze(1).expand(-1, out_channels, -1, -1)
109 | # apply filtering
110 | x_exp = x_exp.reshape(batch * out_channels * in_channels, 1, timesteps)
111 | y = self.lfilter.forward(x_exp, b=b, a=a)
112 | # reshape
113 | y = y.reshape(batch, out_channels, in_channels, timesteps)
114 | # sum over input channels
115 | y = y.sum(dim=2)
116 |
117 | return y
118 |
119 |
120 | class BiquadPeakFunctional(BiquadBaseFunctional):
121 |
122 | def __init__(self):
123 | """ Initialize Biquad"""
124 | super().__init__()
125 |
126 | def _params_to_direct_form(self,
127 | freq: torch.Tensor,
128 | gain: torch.Tensor,
129 | Q: torch.Tensor
130 | ) -> Tuple[torch.Tensor, torch.Tensor]:
131 | """
132 | Args:
133 | freq, center frequency,
134 | shape is (..., n_frames)
135 | gain, gain in decibels
136 | shape is (..., n_frames)
137 | Q, resonance sharpness
138 | shape is (..., n_frames)
139 |
140 | Returns:
141 | b, filter numerator coefficients
142 | shape is (..., n_taps==3, n_frames)
143 | a, filter denominator coefficients
144 | shape is (..., n_taps==3, n_frames)
145 |
146 | """
147 | if torch.any(freq > 1.0):
148 | raise ValueError(f"Normalized frequency must be below 1.0, max was {freq.max()}")
149 | if torch.any(freq < 0.0):
150 | raise ValueError(f"Normalized frequency must be above 0.0, min was {freq.min()}")
151 |
152 | freq = freq.unsqueeze(-2)
153 | gain = gain.unsqueeze(-2)
154 | Q = Q.unsqueeze(-2)
155 |
156 | omega = torch.pi * freq
157 | A = torch.pow(10.0, 0.025 * gain)
158 | alpha = 0.5 * torch.sin(omega) / Q
159 |
160 | b0 = 1.0 + alpha * A
161 | b1 = -2.0 * torch.cos(omega)
162 | b2 = 1.0 - alpha * A
163 | a0 = 1 + alpha / A
164 | a1 = -2.0 * torch.cos(omega)
165 | a2 = 1 - alpha / A
166 |
167 | a = torch.cat([a0, a1, a2], dim=-2)
168 | b = torch.cat([b0, b1, b2], dim=-2)
169 |
170 | b = b / a0
171 | a = a / a0
172 | return b, a
173 |
174 |
175 | class BiquadModule(torch.nn.Module):
176 |
177 |
178 | @property
179 | def freq(self):
180 | return self._freq
181 |
182 |
183 | def set_freq(self, freq):
184 | if type(freq) != torch.Tensor:
185 | freq = torch.tensor([freq], dtype=torch.float32)
186 |
187 | # convert to normalized frequency
188 | freq = 2.0 * freq / self.fs
189 |
190 | if freq.max() > 1.0:
191 | raise ValueError(
192 | "Maximum normalized frequency is larger than 1.0. "
193 | "Please provide a sample rate or input normalized frequencies")
194 | if freq.min() < 0.0:
195 | raise ValueError(
196 |                     "Minimum normalized frequency is smaller than 0.0.")
197 |
198 |
199 | self._freq.data = freq.broadcast_to(self._freq.shape)
200 |
201 | @property
202 | def gain_dB(self):
203 | return self._gain_dB
204 |
205 | def set_gain_dB(self, gain):
206 | if type(gain) != torch.Tensor:
207 | gain = torch.tensor([gain], dtype=torch.float32)
208 | self._gain_dB.data = gain.broadcast_to(self._gain_dB.shape)
209 |
210 | @property
211 | def Q(self):
212 | return self._Q
213 |
214 | def set_Q(self, Q):
215 | if type(Q) != torch.Tensor:
216 | Q = torch.tensor([Q], dtype=torch.float32)
217 | self._Q.data = Q.broadcast_to(self._Q.shape)
218 |
219 | def _init_freq(self):
220 | freq = torch.rand(self.out_channels, self.in_channels)
221 | self._freq = torch.nn.Parameter(freq)
222 |
223 | def _init_gain_dB(self):
224 | gain_dB = torch.zeros(self.out_channels, self.in_channels)
225 | self._gain_dB = torch.nn.Parameter(gain_dB)
226 |
227 | def _init_Q(self):
228 | Q = torch.ones(self.out_channels, self.in_channels)
229 | self._Q = torch.nn.Parameter(0.7071 * Q)
230 |
231 |
232 |
233 | def __init__(self,
234 | in_channels: int=1,
235 | out_channels: int=1,
236 | fs: float = None,
237 | func: BiquadBaseFunctional = BiquadPeakFunctional()
238 | ):
239 | """
240 | Args:
241 | func: BiquadBaseFunctional subclass
242 | freq: center frequency
243 | gain: gain in dB
244 | Q: quality factor determining filter resonance bandwidth
245 | fs: sample rate, if not provided freq is assumed as normalized from 0 to 1 (Nyquist)
246 |
247 | """
248 | super().__init__()
249 |
250 | self.func = func
251 |
252 | # if no sample rate provided, assume normalized frequency
253 | if fs is None:
254 | fs = 2.0
255 |
256 | self.fs = fs
257 |
258 | self.in_channels = in_channels
259 | self.out_channels = out_channels
260 |
261 | self._init_freq()
262 | self._init_gain_dB()
263 | self._init_Q()
264 |
265 |
266 | def get_impulse_response(self, n_timesteps: int = 2048) -> torch.Tensor:
267 | """ Get impulse response of filter
268 |
269 | Args:
270 | n_timesteps: number of timesteps to evaluate
271 |
272 | Returns:
273 | h, shape is (batch, channels, n_timesteps)
274 | """
275 | x = torch.zeros(1, 1, n_timesteps)
276 | x[:, :, 0] = 1.0
277 | h = self.forward(x)
278 | return h
279 |
280 | def get_frequency_response(self, n_timesteps: int = 2048, n_fft: int = 2048) -> torch.Tensor:
281 | """ Get frequency response of filter
282 |
283 | Args:
284 | n_timesteps: number of timesteps to evaluate
285 |
286 | Returns:
287 | H, shape is (batch, channels, n_timesteps)
288 | """
289 | h = self.get_impulse_response(n_timesteps=n_timesteps)
290 | H = torch.fft.rfft(h, n=n_fft, dim=-1)
291 | H = torch.abs(H)
292 | return H
293 |
294 |
295 | def forward(self, x: torch.Tensor) -> torch.Tensor:
296 | """
297 | Args:
298 | x, shape is (batch, in_channels, timesteps)
299 |
300 | Returns:
301 | y, shape is (batch, out_channels, timesteps)
302 | """
303 |
304 | batch, in_channels, timesteps = x.size()
305 |
306 | num_frames = timesteps // self.func.hop_length
307 | # map size: (out_channels, in_channels) -> (batch, out_channels, in_channels, n_frames)
308 | freq = self.freq.unsqueeze(0).unsqueeze(-1).expand(batch, -1, -1, num_frames)
309 | gain = self.gain_dB.unsqueeze(0).unsqueeze(-1).expand(batch, -1, -1, num_frames)
310 | q_factor = self.Q.unsqueeze(0).unsqueeze(-1).expand(batch, -1, -1, num_frames)
311 |
312 | y = self.func.forward(x, freq, gain, q_factor)
313 |
314 | return y
315 |
316 | class BiquadParallelBankModule(torch.nn.Module):
317 |
318 | def __init__(self,
319 | num_filters:int=10,
320 | fs=None,
321 | func: BiquadBaseFunctional = BiquadPeakFunctional()
322 | ):
323 | """
324 | Args:
325 | num_filters: number of filters in bank
326 | func: BiquadBaseFunctional subclass
327 |
328 | """
329 | super().__init__()
330 |
331 | self.num_filters = num_filters
332 |
333 | self.fs = fs
334 | self.filter_bank = BiquadModule(in_channels=num_filters, out_channels=1, fs=fs, func=func)
335 |
336 | # flat initialization
337 | freq = torch.linspace(0.0, 1.0, num_filters+2)[1:-1]
338 | gain = torch.zeros_like(freq)
339 | Q = 0.7071 * torch.ones_like(freq)
340 |
341 | self.filter_bank.set_freq(freq)
342 | self.filter_bank.set_gain_dB(gain)
343 | self.filter_bank.set_Q(Q)
344 |
345 | def set_freq(self, freq: torch.Tensor):
346 | """ Set center frequency of each filter
347 | Args:
348 | freq, shape is (num_filters,)
349 | """
350 | self.filter_bank.set_freq(freq)
351 |
352 | def set_gain_dB(self, gain_dB: torch.Tensor):
353 | self.filter_bank.set_gain_dB(gain_dB)
354 |
355 | def set_Q(self, Q: torch.Tensor):
356 | self.filter_bank.set_Q(Q)
357 |
358 | def forward(self, x: torch.Tensor) -> torch.Tensor:
359 | """
360 | Args:
361 | x, shape is (batch, channels=1, timesteps)
362 |
363 | Returns:
364 | y, shape is (batch, channels=1, timesteps)
365 | """
366 |
367 | if x.size(1) != 1:
368 | raise ValueError(f"Input must have 1 channel, got {x.size(1)}")
369 |
370 | # expand channels to match filter bank
371 | x = x.expand(-1, self.num_filters, -1)
372 |
373 |         # output shape is (batch, channels=1, timesteps)
374 | y = self.filter_bank(x)
375 |
376 | # normalize output by number of filters
377 | # y = y / self.num_filters
378 |
379 | return y
380 |
381 | def get_impulse_response(self, n_timesteps: int = 2048) -> torch.Tensor:
382 | """ Get impulse response of filter
383 |
384 | Args:
385 | n_timesteps: number of timesteps to evaluate
386 |
387 | Returns:
388 | h, shape is (batch, channels, n_timesteps)
389 | """
390 | x = torch.zeros(1, 1, n_timesteps)
391 | x[:, :, 0] = 1.0
392 | h = self.forward(x)
393 | return h
394 |
395 | def get_frequency_response(self, n_timesteps: int = 2048, n_fft: int = 2048) -> torch.Tensor:
396 | """ Get frequency response of filter
397 |
398 | Args:
399 | n_timesteps: number of timesteps to evaluate
400 |
401 | Returns:
402 | H, shape is (batch, channels, n_timesteps)
403 | """
404 | h = self.get_impulse_response(n_timesteps=n_timesteps)
405 | H = torch.fft.rfft(h, n=n_fft, dim=-1)
406 | H = torch.abs(H)
407 | return H
408 |
409 |
410 |
411 | class BiquadResonatorFunctional(BiquadBaseFunctional):
412 |
413 | def __init__(self):
414 | """ Initialize Biquad"""
415 | super().__init__()
416 |
417 | def _params_to_direct_form(self,
418 | freq: torch.Tensor,
419 | gain: torch.Tensor,
420 | Q: torch.Tensor
421 | ) -> Tuple[torch.Tensor, torch.Tensor]:
422 | """
423 | Args:
424 | freq, center frequency,
425 | shape is (..., n_frames)
426 | gain, gain in decibels
427 | shape is (..., n_frames)
428 | Q, resonance sharpness
429 | shape is (..., n_frames)
430 |
431 | Returns:
432 | b, filter numerator coefficients
433 | shape is (..., n_taps==3, n_frames)
434 | a, filter denominator coefficients
435 | shape is (..., n_taps==3, n_frames)
436 |
437 | """
438 | if torch.any(freq > 1.0):
439 | raise ValueError(f"Normalized frequency must be below 1.0, max was {freq.max()}")
440 | if torch.any(freq < 0.0):
441 | raise ValueError(f"Normalized frequency must be above 0.0, min was {freq.min()}")
442 |
443 | freq = freq.unsqueeze(-2)
444 | gain = gain.unsqueeze(-2)
445 | Q = Q.unsqueeze(-2)
446 |
447 | omega = torch.pi * freq
448 | A = torch.pow(10.0, 0.025 * gain)
449 | alpha = 0.5 * torch.sin(omega) / Q
450 |
451 | b0 = torch.ones_like(freq)
452 | b1 = torch.zeros_like(freq)
453 | b2 = torch.zeros_like(freq)
454 | a0 = 1 + alpha / A
455 | a1 = -2.0 * torch.cos(omega)
456 | a2 = 1 - alpha / A
457 |
458 | a = torch.cat([a0, a1, a2], dim=1)
459 | b = torch.cat([b0, b1, b2], dim=1)
460 |
461 | b = b / a0
462 | a = a / a0
463 | return b, a
464 |
465 |
466 | class BiquadBandpassFunctional(BiquadBaseFunctional):
467 |
468 | def __init__(self):
469 | """ Initialize Biquad"""
470 | super().__init__()
471 |
472 | def _params_to_direct_form(self,
473 | freq: torch.Tensor,
474 | gain: torch.Tensor,
475 | Q: torch.Tensor
476 | ) -> Tuple[torch.Tensor, torch.Tensor]:
477 | """
478 | Args:
479 | freq, center frequency,
480 | shape is (..., n_frames)
481 | gain, gain in decibels
482 | shape is (..., n_frames)
483 | Q, resonance sharpness
484 | shape is (..., n_frames)
485 |
486 | Returns:
487 | b, filter numerator coefficients
488 | shape is (..., n_taps==3, n_frames)
489 | a, filter denominator coefficients
490 | shape is (..., n_taps==3, n_frames)
491 |
492 | """
493 | if torch.any(freq > 1.0):
494 | raise ValueError(f"Normalized frequency must be below 1.0, max was {freq.max()}")
495 | if torch.any(freq < 0.0):
496 | raise ValueError(f"Normalized frequency must be above 0.0, min was {freq.min()}")
497 |
498 | freq = freq.unsqueeze(-2)
499 | gain = gain.unsqueeze(-2)
500 | Q = Q.unsqueeze(-2)
501 |
502 | omega = torch.pi * freq
503 | A = torch.pow(10.0, 0.025 * gain)
504 | alpha = 0.5 * torch.sin(omega) / Q
505 |
506 | b0 = alpha
507 | b1 = torch.zeros_like(freq)
508 | b2 = -1.0 * alpha
509 | a0 = 1 + alpha / A
510 | a1 = -2.0 * torch.cos(omega)
511 | a2 = 1 - alpha / A
512 |
513 | a = torch.cat([a0, a1, a2], dim=1)
514 | b = torch.cat([b0, b1, b2], dim=1)
515 |
516 | b = b / a0
517 | a = a / a0
518 | return b, a
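
A usage sketch for BiquadModule with the peak filter; because a sample rate is provided, frequencies are given in Hz (the values themselves are illustrative):

    import torch
    from neural_formant_synthesis.glotnet.sigproc.biquad import BiquadModule, BiquadPeakFunctional

    biquad = BiquadModule(in_channels=1, out_channels=1, fs=16000.0, func=BiquadPeakFunctional())
    biquad.set_freq(1000.0)                   # peak at 1 kHz
    biquad.set_gain_dB(6.0)
    biquad.set_Q(1.0)

    x = torch.randn(1, 1, 4096)               # (batch, in_channels, time)
    y = biquad(x)                             # (batch, out_channels, time)
    H = biquad.get_frequency_response()       # magnitude response, (1, 1, n_fft // 2 + 1)
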
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/glotnet/sigproc/emphasis.py:
--------------------------------------------------------------------------------
1 | from .lfilter import LFilter
2 | import torch
3 | from .lfilter import ceil_division
4 |
5 | class Emphasis(torch.nn.Module):
6 | """ Pre-emphasis and de-emphasis filter"""
7 |
8 | def __init__(self, alpha=0.85) -> None:
9 | """
10 | Args:
11 | alpha : pre-emphasis coefficient
12 | """
13 | super().__init__()
14 | if alpha < 0 or alpha >= 1:
15 | raise ValueError(f"alpha must be in [0, 1), got {alpha}")
16 | self.alpha = alpha
17 | self.lfilter = LFilter(n_fft=512, hop_length=256, win_length=512)
18 | self.register_buffer('coeffs', torch.tensor([1, -alpha], dtype=torch.float32))
19 |
20 | def forward(self, x: torch.Tensor) -> torch.Tensor:
21 |         """ Apply pre-emphasis to signal
22 | Args:
23 | x : input signal, shape (batch, channels, timesteps)
24 |
25 | Returns:
26 | y : output signal, shape (batch, channels, timesteps)
27 |
28 | """
29 | if not (self.alpha > 0):
30 | return x
31 | # expand coefficients to batch size and number of frames
32 | b = self.coeffs.unsqueeze(0).unsqueeze(-1).expand(x.size(0), -1, ceil_division(x.size(-1), self.lfilter.hop_length))
33 | # filter
34 | return self.lfilter(x, b=b, a=None)
35 |
36 | def emphasis(self, x: torch.Tensor) -> torch.Tensor:
37 |         """ Apply pre-emphasis to signal
38 | Args:
39 | x : input signal, shape (batch, channels, timesteps)
40 |
41 | Returns:
42 | y : output signal, shape (batch, channels, timesteps)
43 |
44 | """
45 | return self.forward(x)
46 |
47 | def deemphasis(self, x: torch.Tensor) -> torch.Tensor:
48 | """ Remove emphasis from signal
49 | Args:
50 | x : input signal, shape (batch, channels, timesteps)
51 | Returns:
52 | y : output signal, shape (batch, channels, timesteps)
53 | """
54 | if not (self.alpha > 0):
55 | return x
56 | # expand coefficients to batch size and number of frames
57 | a = self.coeffs.unsqueeze(0).unsqueeze(-1).expand(x.size(0), -1, ceil_division(x.size(-1), self.lfilter.hop_length))
58 | # filter
59 | return self.lfilter(x, b=None, a=a)
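
A short sketch of the pre-/de-emphasis pair; de-emphasis approximately inverts the emphasis filter, up to STFT framing effects:

    import torch
    from neural_formant_synthesis.glotnet.sigproc.emphasis import Emphasis

    emph = Emphasis(alpha=0.85)
    x = torch.randn(1, 1, 4096)               # (batch, channels, time)
    y = emph.emphasis(x)                      # pre-emphasised signal, same shape
    x_hat = emph.deemphasis(y)                # approximate reconstruction of x
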
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/glotnet/sigproc/levinson.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | def toeplitz(r: torch.Tensor):
5 | """" Construct Toeplitz matrix """
6 | p = r.size(-1)
7 | rr = torch.cat([r, r[..., 1:].flip(dims=(-1,))], dim=-1)
8 | T = [torch.roll(rr, i, dims=(-1,))[...,:p] for i in range(p)]
9 | return torch.stack(T, dim=-1)
10 |
11 |
12 | def levinson(R:torch.Tensor, M:int, eps:float=1e-3) -> torch.Tensor:
13 | """ Levinson-Durbin method for converting autocorrelation to all-pole polynomial
14 | Args:
15 | R: autocorrelation, shape=(..., M)
16 | M: filter polynomial order
17 | Returns:
18 | A: all-pole polynomial, shape=(..., M)
19 | """
20 | # normalize R
21 | R = R / R[..., 0:1]
22 | # white noise correction
23 | R[..., 0] = R[..., 0] + eps
24 |
25 |     # first reflection coefficient (from lag 1)
26 | K = torch.sum(R[..., 1:2], dim=-1, keepdim=True)
27 | A = torch.cat([-1.0*K, torch.ones_like(R[..., 0:1])], dim=-1)
28 | E = 1.0 - K ** 2
29 | # higher lags
30 | for p in torch.arange(1, M):
31 | K = torch.sum(A[..., 0:p+1] * R[..., 1:p+2], dim=-1, keepdim=True) / E
32 | if K.abs().max() > 1.0:
33 | raise ValueError(f"Unstable filter, |K| was {K.abs().max()}")
34 | A = torch.cat([-1.0*K,
35 | A[..., 0:p] - 1.0*K *
36 | torch.flip(A[..., 0:p], dims=[-1]),
37 | torch.ones_like(R[..., 0:1])], dim=-1)
38 | E = E * (1.0 - K ** 2)
39 | A = torch.flip(A, dims=[-1])
40 | return A
41 |
42 |
43 | def forward_levinson(K: torch.Tensor, M: int = None) -> torch.Tensor:
44 | """ Forward Levinson converts reflection coefficients to all-pole polynomial
45 |
46 | Args:
47 | K: reflection coefficients, shape=(..., M)
48 | M: filter polynomial order (optional, defaults to K.size(-1))
49 | Returns:
50 | A: all-pole polynomial, shape=(..., M+1)
51 |
52 | """
53 | if M is None:
54 | M = K.size(-1)
55 |
56 | A = -1.0*K[..., 0:1]
57 | for p in torch.arange(1, M):
58 | A = torch.cat([-1.0*K[..., p:p+1],
59 | A[..., 0:p] - 1.0*K[..., p:p+1] * torch.flip(A[..., 0:p], dims=[-1])], dim=-1)
60 |
61 | A = torch.cat([A, torch.ones_like(A[..., 0:1])], dim=-1)
62 | A = torch.flip(A, dims=[-1]) # flip zero delay to zero:th index
63 | return A
64 |
65 |
66 | def spectrum_to_allpole(spectrum:torch.Tensor, order:int, root_scale:float=1.0):
67 | """ Convert spectrum to all-pole filter coefficients
68 |
69 | Args:
70 | spectrum: power spectrum (squared magnitude), shape=(..., K)
71 | order: filter polynomial order
72 |
73 | Returns:
74 | a: filter predictor polynomial tensor, shape=(..., order+1)
75 | g: filter gain
76 | """
77 | r = torch.fft.irfft(spectrum, dim=-1)
78 | # add small value to diagonal to avoid singular matrix
79 | r[..., 0] = r[..., 0] + 1e-6
80 | # all pole from autocorr
81 | a = levinson(r, order)
82 |
83 | # filter gain
84 | # g = torch.sqrt(torch.dot(r[:(order+1)], a))
85 | g = torch.sqrt(torch.sum(r[..., :(order+1)] * a, dim=-1, keepdim=True))
86 |
87 | # scale filter roots
88 | if root_scale < 1.0:
89 | a = a * root_scale ** torch.arange(order+1, dtype=torch.float32, device=a.device)
90 |
91 | return a, g
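
A usage sketch for spectrum_to_allpole, fitting an all-pole envelope to per-frame power spectra (the frame contents here are illustrative):

    import torch
    from neural_formant_synthesis.glotnet.sigproc.levinson import spectrum_to_allpole

    frames = torch.randn(4, 512)                          # (n_frames, frame_length)
    spectrum = torch.abs(torch.fft.rfft(frames, dim=-1)) ** 2
    a, g = spectrum_to_allpole(spectrum, order=10)        # a: (4, 11), g: (4, 1)
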
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/glotnet/sigproc/lfilter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import math
4 |
5 | # TODO: wrap STFT into class
6 | # TODO: support grouped convolution (multichannel, no mixing)
7 |
8 | def ceil_division(n: int, d: int) -> int:
9 | """ Ceiling integer division """
10 | return -(n // -d)
11 |
12 | def pad_frames(frames: torch.Tensor, target_len: int) -> torch.Tensor:
13 | n_pad = target_len - frames.size(-1)
14 | l_pad = n_pad // 2
15 | r_pad = ceil_division(n_pad, 2)
16 | return F.pad(frames, pad=(l_pad, r_pad), mode='replicate')
17 |
18 | class LFilter(torch.nn.Module):
19 | """ Linear filtering with STFT """
20 |
21 | def __init__(self, n_fft: int, hop_length: int, win_length: int):
22 | super().__init__()
23 | self.hop_length = hop_length
24 | self.n_fft = n_fft
25 | self.win_length = win_length
26 | window = torch.hann_window(window_length=self.win_length)
27 | self.register_buffer('window', window.reshape(1, -1, 1), persistent=False)
28 | self.scale = 2.0 * self.hop_length / self.win_length
29 |
30 | self.fold_kwargs = {
31 | 'kernel_size':(self.win_length, 1),
32 | 'stride' : (self.hop_length, 1),
33 | 'padding' : 0
34 | }
35 |
36 | def forward(self, x: torch.Tensor, b: torch.Tensor = None, a: torch.Tensor = None) -> torch.Tensor:
37 | """
38 | Args:
39 | x : input signal
40 | (batch, channels==1, timesteps)
41 | b : filter numerator coefficients
42 | (batch, b_len, n_frames)
43 | a : filter denominator coefficients
44 | (batch, a_len, n_frames)
45 | """
46 | num_frames = ceil_division(x.size(-1), self.hop_length)
47 |
48 | left_pad = self.win_length
49 | last_frame_center = num_frames * self.hop_length
50 | right_pad = last_frame_center - x.size(-1) + self.win_length
51 | x_padded = F.pad(x, pad=(left_pad, right_pad))
52 |
53 | if x.size(1) != 1:
54 | raise RuntimeError(f"channels must be 1, got shape {x.shape}")
55 |
56 | # frame
57 | x_padded = x_padded.unsqueeze(-1)
58 | fold_size = x_padded.shape[2:]
59 | x_framed = F.unfold(x_padded, # (B, C, T, 1)
60 | **self.fold_kwargs)
61 |
62 | # window
63 | x_windowed = x_framed * self.window
64 |
65 | # FFT
66 | X = torch.fft.rfft(x_windowed, n=self.n_fft, dim=1)
67 |
68 | if a is None:
69 | A = torch.ones_like(X)
70 | else:
71 | a_pad = pad_frames(a, target_len=X.size(-1))
72 | A = torch.fft.rfft(a_pad, n=self.n_fft, dim=1)
73 |
74 | if b is None:
75 | B = torch.ones_like(X)
76 | else:
77 | b_pad = pad_frames(b, target_len=X.size(-1))
78 | B = torch.fft.rfft(b_pad, n=self.n_fft, dim=1)
79 |
80 | # multiply
81 | Y = X * B / A
82 |
83 | # IFFT
84 | y_windowed = torch.fft.irfft(Y, n=self.n_fft, dim=1)
85 | y_windowed = y_windowed[:, :self.win_length, :]
86 | # TODO: window again for OLA
87 | # y_windowed = y_windowed[:, :self.win_length, :] * self.window
88 |
89 | # Overlap-add
90 | # TODO: change OLA fold args kernel size to fft_len
91 | y = F.fold(y_windowed, output_size=fold_size,
92 | **self.fold_kwargs)
93 | return y[:, :, left_pad:-right_pad, 0] * self.scale
94 |
95 |
96 |
97 |
98 |
99 |
100 |
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/glotnet/sigproc/lpc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .lfilter import LFilter
3 | from .levinson import spectrum_to_allpole
4 |
5 |
6 | class LinearPredictor(torch.nn.Module):
7 |
8 | def __init__(self,
9 | n_fft: int = 512,
10 | hop_length: int = 256,
11 | win_length=512,
12 | order=10):
13 | """
14 | Args:
15 | n_fft (int): FFT size
16 | hop_length (int): Hop length
17 | win_length (int): Window length
18 | order (int): Allpole order
19 | """
20 | super().__init__()
21 | self.n_fft = n_fft
22 | self.hop_length = hop_length
23 | self.win_length = win_length
24 | self.order = order
25 | self.lfilter = LFilter(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
26 |
27 | def estimate(self, x: torch.Tensor) -> torch.Tensor:
28 | """
29 | Args:
30 | x (torch.Tensor): Input audio (batch, time)
31 |
32 | Returns:
33 | torch.Tensor: Allpole coefficients (batch, order+1, time)
34 | """
35 | X = torch.stft(x, n_fft=self.n_fft,
36 | hop_length=self.hop_length, win_length=self.win_length,
37 | return_complex=True)
38 | # power spectrum
39 | X = torch.abs(X)**2
40 |
41 | # transpose to (batch, time, freq)
42 | X = X.transpose(1, 2)
43 |
44 | # allpole coefficients
45 | a, _ = spectrum_to_allpole(X, order=self.order)
46 |
47 | # transpose to (batch, order, num_frames)
48 | a = a.transpose(1, 2)
49 |
50 | return a
51 |
52 | def inverse_filter(self,
53 | x: torch.Tensor,
54 | a: torch.Tensor) -> torch.Tensor:
55 | """
56 | Args:
57 | x: Input audio (batch, time)
58 | a: Allpole coefficients (batch, order+1, time)
59 |
60 | Returns:
61 | Inverse filtered audio (batch, time)
62 | """
63 | # inverse filter
64 | e = self.lfilter.forward(x=x, b=a, a=None)
65 |
66 | return e
67 |
68 | def synthesis_filter(self,
69 | e: torch.Tensor,
70 | a: torch.Tensor) -> torch.Tensor:
71 | """
72 | Args:
73 | e: Excitation signal (batch, channels, time)
74 | a: Allpole coefficients (batch, order+1, time)
75 |
76 | Returns:
77 | torch.Tensor: Synthesis filtered audio (batch, time)
78 | """
79 |         # synthesis filter
80 | x = self.lfilter.forward(x=e, b=None, a=a)
81 | return x
82 |
83 |
84 | def prediction(self,
85 | x: torch.Tensor,
86 | a: torch.Tensor) -> torch.Tensor:
87 | """
88 | Args:
89 | x: Input audio (batch, channels, time)
90 | a: Allpole coefficients (batch, order+1, time)
91 |
92 | Returns:
93 | p: Linear prediction signal (batch, time)
94 | """
95 | a_pred = -a
96 | a_pred[:, 0, :] = 0.0
97 |
98 | # predictor filter
99 | p = self.lfilter.forward(x=x, b=a_pred, a=None)
100 |
101 | return p
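
A usage sketch for LinearPredictor analysis and synthesis; note the explicit channel dimension added before filtering, since LFilter expects input of shape (batch, 1, time):

    import torch
    from neural_formant_synthesis.glotnet.sigproc.lpc import LinearPredictor

    lp = LinearPredictor(n_fft=512, hop_length=256, win_length=512, order=10)
    x = torch.randn(1, 16000)                     # (batch, time)
    a = lp.estimate(x)                            # all-pole coefficients, (batch, order + 1, n_frames)
    e = lp.inverse_filter(x.unsqueeze(1), a)      # residual / excitation, (batch, 1, time)
    y = lp.synthesis_filter(e, a)                 # resynthesised signal, (batch, 1, time)
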
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/glotnet/sigproc/lsf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | def conv(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
5 | """ Convolution of two signals
6 |
7 | Args:
8 | x: First signal (batch, channels, time_x)
9 | y: Second signal (batch, channels, time_y)
10 |
11 | Returns:
12 | torch.Tensor: Convolution (batch, channels, time_x + time_y - 1)
13 |
14 | """
15 |
16 | # if x.shape[0] != y.shape[0]:
17 | # raise ValueError("x and y must have same batch size")
18 | # if x.shape[1] != y.shape[1]:
19 | # raise ValueError("x and y must have same number of channels")
20 |
21 | length = x.shape[-1] + y.shape[-1] - 1
22 |
23 | X = torch.fft.rfft(x, n=length, dim=-1)
24 | Y = torch.fft.rfft(y, n=length, dim=-1)
25 | Z = X * Y
26 | z = torch.fft.irfft(Z, dim=-1, n=length)
27 |
28 | return z
29 |
30 | def lsf2poly(lsf: torch.Tensor) -> torch.Tensor:
31 | """ Line Spectral Frequencies to Polynomial Coefficients
32 |
33 | Args:
34 | lsf (torch.Tensor): Line spectral frequencies (batch, time, order)
35 |
36 | Returns:
37 | poly (torch.Tensor): Polynomial coefficients (batch, time, order+1)
38 |
39 |
40 | References:
41 | https://github.com/cokelaer/spectrum/blob/master/src/spectrum/linear_prediction.py
42 | """
43 |
44 | if lsf.min() < 0:
45 | raise ValueError("lsf must be non-negative")
46 | if lsf.max() > torch.pi:
47 | raise ValueError("lsf must be less than pi")
48 |
49 | order = lsf.shape[-1]
50 |
51 | lsf, _ = torch.sort(lsf, dim=-1)
52 |
53 | # split to P and Q
54 | wP = lsf[:, :, ::2].unsqueeze(-1)
55 | wQ = lsf[:, :, 1::2].unsqueeze(-1)
56 |
57 | P_len = wP.shape[-2]
58 | Q_len = wQ.shape[-2]
59 |
60 | # compute conjugate pair polynomials
61 | pad = (1,1,0,0)
62 | Pi = F.pad(-2.0*torch.cos(wP), pad, mode='constant', value=1.0)
63 | Qi = F.pad(-2.0*torch.cos(wQ), pad, mode='constant', value=1.0)
64 |
65 | # Pi = torch.cat(1.0, -2 * torch.cos(wP), 1.0)
66 | # Qi = torch.cat(1.0, -2 * torch.cos(wQ), 1.0)
67 |
68 | # construct polynomials
69 | P = Pi[:,:, 0, :]
70 | for i in range(1, P_len):
71 | P = conv(P, Pi[:, :, i, :])
72 |
73 | Q = Qi[:, :, 0, :]
74 | for i in range(1, Q_len):
75 | Q = conv(Q, Qi[:, :, i, :])
76 |
77 | # add trivial zeros
78 | if order % 2 == 0:
79 | # P = conv(P, torch.tensor([1.0, -1.0]).reshape(1, 1, -1))
80 | # Q = conv(Q, torch.tensor([1.0, 1.0]).reshape(1, 1, -1))
81 | P = conv(P, torch.tensor([1.0, 1.0]).reshape(1, 1, -1))
82 | Q = conv(Q, torch.tensor([-1.0, 1.0]).reshape(1, 1, -1))
83 | else:
84 | # Q = conv(Q, torch.tensor([1.0, 0.0, -1.0]).reshape(1, 1, -1))
85 | Q = conv(Q, torch.tensor([-1.0, 0.0, 1.0]).reshape(1, 1, -1))
86 |
87 | # compute polynomial coefficients
88 | A = 0.5 * (P + Q)
89 |
90 | return A[:, :, 1:].flip(-1)
91 |
92 |
93 |
--------------------------------------------------------------------------------
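A minimal usage sketch for `lsf2poly`, assuming the package is importable as `neural_formant_synthesis` (the import root used elsewhere in this repository); the input shape and order follow the docstring above:

```python
import torch
from neural_formant_synthesis.glotnet.sigproc.lsf import lsf2poly

# batch of 2 utterances, 5 frames, 10th-order line spectral frequencies in (0, pi)
lsf = torch.sort(torch.rand(2, 5, 10) * torch.pi, dim=-1).values

# predictor polynomial per frame: (batch, time, order + 1)
a = lsf2poly(lsf)
print(a.shape)  # torch.Size([2, 5, 11])
```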
/src/neural_formant_synthesis/glotnet/sigproc/melspec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 |
4 | from .levinson import levinson, spectrum_to_allpole
5 |
6 | class SpectralNormalization(torch.nn.Module):
7 | def forward(self, input):
8 | return torch.log(torch.clamp(input, min=1e-5))
9 |
10 | class InverseSpectralNormalization(torch.nn.Module):
11 | def forward(self, input):
12 | return torch.exp(input)
13 |
14 | # Tacotron 2 reference configuration
15 | #https://github.com/pytorch/audio/blob/6b2b6c79ca029b4aa9bdb72d12ad061b144c2410/examples/pipeline_tacotron2/train.py#L284
16 | class LogMelSpectrogram(torch.nn.Module):
17 | """ Log Mel Spectrogram """
18 |
19 | def __init__(self,
20 | sample_rate: int = 16000,
21 | n_fft: int = 1024,
22 | win_length: int = None,
23 | hop_length: int = None,
24 | n_mels: int = 80,
25 | f_min: float = 0.0,
26 | f_max: float = None,
27 | mel_scale: str = "slaney",
28 | normalized: bool = False,
29 | power: float = 1.0,
30 | norm: str = "slaney"
31 | ):
32 | super().__init__()
33 |
34 | self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
35 | sample_rate=sample_rate,
36 | n_fft=n_fft,
37 | win_length=win_length,
38 | hop_length=hop_length,
39 | f_min=f_min,
40 | f_max=f_max,
41 | n_mels=n_mels,
42 | mel_scale=mel_scale,
43 | normalized=normalized,
44 | power=power,
45 | norm=norm,
46 | )
47 | self.log = SpectralNormalization()
48 | self.exp = InverseSpectralNormalization()
49 |
50 | fb_pinv = torch.linalg.pinv(self.mel_spectrogram.mel_scale.fb)
51 | self.register_buffer('fb_pinv', fb_pinv)
52 |
53 |
54 | def forward(self, x: torch.Tensor) -> torch.Tensor:
55 | """
56 | Args:
57 | x : input signal
58 | (batch, channels, timesteps)
59 | """
60 | X = self.mel_spectrogram(x)
61 |
62 | return self.log(X)
63 |
64 | def allpole(self, X: torch.Tensor, order:int = 20) -> torch.Tensor:
65 | """
66 | Args:
67 | X : input mel spectrogram
68 | order : all pole model order
69 |
70 | Returns:
71 | a : allpole filter coefficients
72 |
73 | """
74 |
75 | # invert normalization
76 | X = self.exp(X)
77 |
78 | # (..., F, T) -> (..., T, F)
79 | X = X.transpose(-1, -2)
80 | # pseudoinvert mel filterbank
81 | X = torch.matmul(X, self.fb_pinv).clamp(min=1e-9)
82 |
83 |         # power spectrum (squared magnitude)
84 | X = torch.pow(X, 2.0 / self.mel_spectrogram.power)
85 | X = X.clamp(min=1e-9)
86 |
87 | g, a = spectrum_to_allpole(X, order=order)
88 | # (..., T, order) -> (..., order, T)
89 | a = a.transpose(-1, -2)
90 | return a
91 |
92 |
--------------------------------------------------------------------------------
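A short sketch of how `LogMelSpectrogram` can be used for analysis and for recovering a per-frame all-pole envelope from the mel spectrogram; the 16 kHz sample rate and input shapes are illustrative assumptions:

```python
import torch
from neural_formant_synthesis.glotnet.sigproc.melspec import LogMelSpectrogram

melspec = LogMelSpectrogram(sample_rate=16000, n_fft=1024, hop_length=256, n_mels=80)

x = torch.rand(1, 1, 16000) * 2.0 - 1.0   # one second of audio, (batch, channels, time)
X = melspec(x)                            # log-mel spectrogram, (batch, channels, n_mels, frames)
a = melspec.allpole(X, order=20)          # per-frame all-pole envelope coefficients
```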
/src/neural_formant_synthesis/glotnet/sigproc/oscillator.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class Oscillator(torch.nn.Module):
4 | """ Sinusoidal oscillator """
5 | def __init__(self,
6 | audio_rate:int=48000,
7 | control_rate:int=200,
8 | shape:str='sin'):
9 | """
10 |         Args:
11 |             audio_rate: audio sample rate in samples per second
12 |             control_rate: control (frame) rate in frames per second, typically audio_rate / hop_length
13 |             shape: oscillator waveform, either 'sin' or 'saw'
14 |         """
15 | super().__init__()
16 |
17 | self.audio_rate = audio_rate
18 | self.control_rate = control_rate
19 | self.nyquist_rate = audio_rate // 2
20 |
21 | upsample_factor = self.audio_rate // self.control_rate
22 | self.upsampler = torch.nn.modules.Upsample(
23 | mode='linear',
24 | scale_factor=upsample_factor,
25 | align_corners=False)
26 |
27 | self.audio_step = 1.0 / audio_rate
28 | self.control_step = 1.0 / control_rate
29 | self.shape = shape
30 |
31 |
32 | def forward(self, f0, init_phase=None):
33 | """
34 | Args:
35 | f0 : fundamental frequency, shape (batch_size, channels, num_frames)
36 | Returns:
37 | x : sinusoid, shape (batch_size, channels, num_samples)
38 | """
39 |
40 | f0 = torch.clamp(f0, min=0.0, max=self.nyquist_rate)
41 | inst_freq = self.upsampler(f0)
42 | if_shape = inst_freq.shape
43 | if init_phase is None:
44 | # random initial phase in range [-pi, pi]
45 |             init_phase = 2 * torch.pi * (torch.rand(if_shape[0], if_shape[1], 1, device=f0.device) - 0.5)
46 | # integrate instantaneous frequency for phase
47 | phase = torch.cumsum(2 * torch.pi * inst_freq * self.audio_step, dim=-1)
48 |
49 | if self.shape == 'sin':
50 | return torch.sin(phase + init_phase)
51 | elif self.shape == 'saw':
52 | return (torch.fmod(phase + init_phase, 2 * torch.pi) - torch.pi) / torch.pi
53 |
--------------------------------------------------------------------------------
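A usage sketch for `Oscillator`, generating a sinusoid from a frame-rate F0 contour; the rates and contour below are illustrative assumptions:

```python
import torch
from neural_formant_synthesis.glotnet.sigproc.oscillator import Oscillator

osc = Oscillator(audio_rate=16000, control_rate=200, shape='sin')

f0 = 120.0 * torch.ones(1, 1, 50)  # 50 frames of a flat 120 Hz contour
x = osc(f0)                        # upsampled by 16000 // 200 = 80 -> (1, 1, 4000) samples
```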
/src/neural_formant_synthesis/sigproc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/src/neural_formant_synthesis/sigproc/__init__.py
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/sigproc/levinson.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | def toeplitz(r: torch.Tensor):
5 |     """ Construct Toeplitz matrix """
6 | p = r.size(-1)
7 | rr = torch.cat([r, r[..., 1:].flip(dims=(-1,))], dim=-1)
8 | T = [torch.roll(rr, i, dims=(-1,))[...,:p] for i in range(p)]
9 | return torch.stack(T, dim=-1)
10 |
11 |
12 | def levinson(R:torch.Tensor, M:int, eps:float=1e-3) -> torch.Tensor:
13 | """ Levinson-Durbin method for converting autocorrelation to all-pole polynomial
14 | Args:
15 |         R: autocorrelation, shape=(..., N) with N >= M+1
16 | M: filter polynomial order
17 | Returns:
18 |         A: all-pole polynomial, shape=(..., M+1)
19 | """
20 | # normalize R
21 | R = R / R[..., 0:1]
22 | # white noise correction
23 | R[..., 0] = R[..., 0] + eps
24 |
25 |     # first reflection coefficient (from lag 1)
26 | K = torch.sum(R[..., 1:2], dim=-1, keepdim=True)
27 | A = torch.cat([-1.0*K, torch.ones_like(R[..., 0:1])], dim=-1)
28 | E = 1.0 - K ** 2
29 | # higher lags
30 | for p in torch.arange(1, M):
31 | K = torch.sum(A[..., 0:p+1] * R[..., 1:p+2], dim=-1, keepdim=True) / E
32 | if K.abs().max() > 1.0:
33 | raise ValueError(f"Unstable filter, |K| was {K.abs().max()}")
34 | A = torch.cat([-1.0*K,
35 | A[..., 0:p] - 1.0*K *
36 | torch.flip(A[..., 0:p], dims=[-1]),
37 | torch.ones_like(R[..., 0:1])], dim=-1)
38 | E = E * (1.0 - K ** 2)
39 | A = torch.flip(A, dims=[-1])
40 | return A
41 |
42 |
43 | def forward_levinson(K: torch.Tensor, M: int = None) -> torch.Tensor:
44 | """ Forward Levinson converts reflection coefficients to all-pole polynomial
45 |
46 | Args:
47 | K: reflection coefficients, shape=(..., M)
48 | M: filter polynomial order (optional, defaults to K.size(-1))
49 | Returns:
50 | A: all-pole polynomial, shape=(..., M+1)
51 |
52 | """
53 | if M is None:
54 | M = K.size(-1)
55 |
56 | A = -1.0*K[..., 0:1]
57 | for p in torch.arange(1, M):
58 | A = torch.cat([-1.0*K[..., p:p+1],
59 | A[..., 0:p] - 1.0*K[..., p:p+1] * torch.flip(A[..., 0:p], dims=[-1])], dim=-1)
60 |
61 | A = torch.cat([A, torch.ones_like(A[..., 0:1])], dim=-1)
62 | A = torch.flip(A, dims=[-1]) # flip zero delay to zero:th index
63 | return A
64 |
65 |
66 | def spectrum_to_allpole(spectrum:torch.Tensor, order:int, root_scale:float=1.0):
67 | """ Convert spectrum to all-pole filter coefficients
68 |
69 | Args:
70 | spectrum: power spectrum (squared magnitude), shape=(..., K)
71 | order: filter polynomial order
72 |
73 | Returns:
74 | a: filter predictor polynomial tensor, shape=(..., order+1)
75 | g: filter gain
76 |
77 | """
78 | r = torch.fft.irfft(spectrum, dim=-1)
79 | # add small value to diagonal to avoid singular matrix
80 | r[..., 0] = r[..., 0] + 1e-5
81 | # all pole from autocorr
82 | a = levinson(r, order)
83 |
84 | # filter gain
85 | g = torch.sqrt(torch.sum(r[..., :(order+1)] * a, dim=-1, keepdim=True))
86 |
87 | # scale filter roots
88 | if root_scale < 1.0:
89 | a = a * root_scale ** torch.arange(order+1, dtype=torch.float32, device=a.device)
90 |
91 | return a, g
--------------------------------------------------------------------------------
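A sketch of `spectrum_to_allpole` applied to frame-wise power spectra computed with `torch.stft`; the STFT settings mirror the defaults used by `LinearPredictor` in the next file:

```python
import torch
from neural_formant_synthesis.sigproc.levinson import spectrum_to_allpole

x = torch.randn(1, 16000)  # (batch, time)
X = torch.stft(x, n_fft=512, hop_length=256, win_length=512,
               window=torch.hann_window(512), return_complex=True)

power = (X.abs() ** 2).transpose(1, 2)       # (batch, frames, n_fft // 2 + 1)
a, g = spectrum_to_allpole(power, order=10)  # a: (batch, frames, 11), g: (batch, frames, 1)
```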
/src/neural_formant_synthesis/sigproc/lpc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from neural_formant_synthesis.glotnet.sigproc.lfilter import LFilter
3 | from .levinson import spectrum_to_allpole
4 |
5 |
6 | class LinearPredictor(torch.nn.Module):
7 |
8 | def __init__(self,
9 | n_fft: int = 512,
10 | hop_length: int = 256,
11 | win_length=512,
12 | order=10):
13 | """
14 | Args:
15 | n_fft (int): FFT size
16 | hop_length (int): Hop length
17 | win_length (int): Window length
18 | order (int): Allpole order
19 | """
20 | super().__init__()
21 | self.n_fft = n_fft
22 | self.hop_length = hop_length
23 | self.win_length = win_length
24 | self.order = order
25 | self.lfilter = LFilter(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
26 | self.register_buffer('window', torch.hann_window(win_length))
27 |
28 | def estimate(self, x: torch.Tensor, root_scale:float=1.0) -> torch.Tensor:
29 | """
30 |         Args:
31 |             x (torch.Tensor): Input audio (batch, time)
32 |         Returns:
33 |             a (torch.Tensor): Allpole coefficients (batch, order+1, num_frames)
34 |             g (torch.Tensor): Gain (batch, 1, num_frames)
35 |             H (torch.Tensor): Allpole envelope magnitude (batch, 257, num_frames)
36 |         """
37 | X = torch.stft(x, n_fft=self.n_fft,
38 | hop_length=self.hop_length, win_length=self.win_length,
39 | window=self.window,
40 | return_complex=True)
41 |
42 | # power spectrum
43 | X = torch.abs(X)**2
44 |
45 | # transpose to (batch, time, freq)
46 | X = X.transpose(1, 2)
47 |
48 | # allpole coefficients
49 | a, g = spectrum_to_allpole(X, order=self.order, root_scale=root_scale)
50 |
51 | # transpose to (batch, order, num_frames)
52 | a = a.transpose(1, 2)
53 | # transpose to (batch, 1, num_frames)
54 | g = g.transpose(1, 2)
55 |
56 | A = torch.fft.rfft(a, n=512, dim=1).abs()
57 | H = g / (A + 1e-6)
58 | # H = 1 / (A + 1e-6)
59 |
60 | return a, g, H
61 |
62 | def inverse_filter(self,
63 | x: torch.Tensor,
64 | a: torch.Tensor) -> torch.Tensor:
65 | """
66 | Args:
67 | x: Input audio (batch, time)
68 | a: Allpole coefficients (batch, order+1, time)
69 |
70 | Returns:
71 | Inverse filtered audio (batch, time)
72 | """
73 | # inverse filter
74 | e = self.lfilter.forward(x=x, b=a, a=None)
75 |
76 | return e
77 |
78 | def synthesis_filter(self,
79 | e: torch.Tensor,
80 | a: torch.Tensor,
81 | g: torch.Tensor) -> torch.Tensor:
82 | """
83 | Args:
84 | e: Excitation signal (batch, channels, time)
85 | a: Allpole coefficients (batch, order+1, time)
86 | g: Gain for filters (batch, 1, time)
87 |
88 | Returns:
89 | torch.Tensor: Synthesis filtered audio (batch, time)
90 | """
91 |         # synthesis filter
92 | x = self.lfilter.forward(x=e, b=g, a=a)
93 | return x
94 |
95 |
96 | def prediction(self,
97 | x: torch.Tensor,
98 | a: torch.Tensor) -> torch.Tensor:
99 | """
100 | Args:
101 | x: Input audio (batch, channels, time)
102 | a: Allpole coefficients (batch, order+1, time)
103 |
104 | Returns:
105 | p: Linear prediction signal (batch, time)
106 | """
107 | a_pred = -a
108 | a_pred[:, 0, :] = 0.0
109 |
110 | # predictor filter
111 | p = self.lfilter.forward(x=x, b=a_pred, a=None)
112 |
113 | return p
--------------------------------------------------------------------------------
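An untested usage sketch for `LinearPredictor.estimate`; the filtering methods (`inverse_filter`, `synthesis_filter`, `prediction`) then consume the estimated coefficients and gain in the shapes documented above:

```python
import torch
from neural_formant_synthesis.sigproc.lpc import LinearPredictor

lpc = LinearPredictor(n_fft=512, hop_length=256, win_length=512, order=10)

x = torch.randn(1, 16000)  # (batch, time)
a, g, H = lpc.estimate(x)  # coefficients (batch, 11, frames), gain (batch, 1, frames),
                           # and all-pole envelope magnitude (batch, 257, frames)
```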
/src/neural_formant_synthesis/third_party/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/src/neural_formant_synthesis/third_party/__init__.py
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/third_party/hifi_gan/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Jungil Kong
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/third_party/hifi_gan/README.md:
--------------------------------------------------------------------------------
1 | # HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis
2 |
3 | ### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae
4 |
5 | In our [paper](https://arxiv.org/abs/2010.05646),
6 | we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.
7 | We provide our implementation and pretrained models as open source in this repository.
8 |
9 | **Abstract :**
10 | Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms.
11 | Although such methods improve the sampling efficiency and memory usage,
12 | their sample quality has not yet reached that of autoregressive and flow-based generative models.
13 | In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis.
14 | As speech audio consists of sinusoidal signals with various periods,
15 | we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality.
16 | A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method
17 | demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than
18 | real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen
19 | speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times
20 | faster than real-time on CPU with comparable quality to an autoregressive counterpart.
21 |
22 | Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples.
23 |
24 |
25 | ## Pre-requisites
26 | 1. Python >= 3.6
27 | 2. Clone this repository.
28 | 3. Install python requirements. Please refer to [requirements.txt](requirements.txt).
29 | 4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/).
30 | Then move all wav files to `LJSpeech-1.1/wavs`.
31 |
32 |
33 | ## Training
34 | ```
35 | python train.py --config config_v1.json
36 | ```
37 | To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.
38 | Checkpoints and a copy of the configuration file are saved in the `cp_hifigan` directory by default.
39 | You can change the path by adding `--checkpoint_path` option.
40 |
41 |
42 | ## Pretrained Model
43 | You can also use pretrained models we provide.
44 | [Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)
45 | Details of each folder are as follows:
46 |
47 | |Folder Name|Generator|Dataset|Fine-Tuned|
48 | |------|---|---|---|
49 | |LJ_V1|V1|LJSpeech|No|
50 | |LJ_V2|V2|LJSpeech|No|
51 | |LJ_V3|V3|LJSpeech|No|
52 | |LJ_FT_T2_V1|V1|LJSpeech|Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2))|
53 | |LJ_FT_T2_V2|V2|LJSpeech|Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2))|
54 | |LJ_FT_T2_V3|V3|LJSpeech|Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2))|
55 | |VCTK_V1|V1|VCTK|No|
56 | |VCTK_V2|V2|VCTK|No|
57 | |VCTK_V3|V3|VCTK|No|
58 | |UNIVERSAL_V1|V1|Universal|No|
59 |
60 | We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets.
61 |
62 | ## Fine-Tuning
63 | 1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.
64 | The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.
65 | Example:
66 | ```
67 | Audio File : LJ001-0001.wav
68 | Mel-Spectrogram File : LJ001-0001.npy
69 | ```
70 | 2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.
71 | 3. Run the following command.
72 | ```
73 | python train.py --fine_tuning True --config config_v1.json
74 | ```
75 | For other command line options, please refer to the training section.
76 |
77 |
78 | ## Inference from wav file
79 | 1. Make `test_files` directory and copy wav files into the directory.
80 | 2. Run the following command.
81 | ```
82 | python inference.py --checkpoint_file [generator checkpoint file path]
83 | ```
84 | Generated wav files are saved in `generated_files` by default.
85 | You can change the path by adding `--output_dir` option.
86 |
87 |
88 | ## Inference for end-to-end speech synthesis
89 | 1. Make `test_mel_files` directory and copy generated mel-spectrogram files into the directory.
90 | You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2),
91 | [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth.
92 | 2. Run the following command.
93 | ```
94 | python inference_e2e.py --checkpoint_file [generator checkpoint file path]
95 | ```
96 | Generated wav files are saved in `generated_files_from_mel` by default.
97 | You can change the path by adding `--output_dir` option.
98 |
99 |
100 | ## Acknowledgements
101 | We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips)
102 | and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this.
103 |
104 |
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/third_party/hifi_gan/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/src/neural_formant_synthesis/third_party/hifi_gan/__init__.py
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/third_party/hifi_gan/config_v1.json:
--------------------------------------------------------------------------------
1 | {
2 | "resblock": "1",
3 | "num_gpus": 0,
4 | "batch_size": 16,
5 | "learning_rate": 0.0002,
6 | "adam_b1": 0.8,
7 | "adam_b2": 0.99,
8 | "lr_decay": 0.999,
9 | "seed": 1234,
10 |
11 | "upsample_rates": [8,8,2,2],
12 | "upsample_kernel_sizes": [16,16,4,4],
13 | "upsample_initial_channel": 512,
14 | "resblock_kernel_sizes": [3,7,11],
15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16 |
17 | "segment_size": 8192,
18 | "num_mels": 80,
19 | "num_freq": 1025,
20 | "n_fft": 1024,
21 | "hop_size": 256,
22 | "win_size": 1024,
23 |
24 | "sampling_rate": 22050,
25 |
26 | "fmin": 0,
27 | "fmax": 8000,
28 | "fmax_for_loss": null,
29 |
30 | "num_workers": 4,
31 |
32 | "dist_config": {
33 | "dist_backend": "nccl",
34 | "dist_url": "tcp://localhost:54321",
35 | "world_size": 1
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/third_party/hifi_gan/config_v2.json:
--------------------------------------------------------------------------------
1 | {
2 | "resblock": "1",
3 | "num_gpus": 0,
4 | "batch_size": 16,
5 | "learning_rate": 0.0002,
6 | "adam_b1": 0.8,
7 | "adam_b2": 0.99,
8 | "lr_decay": 0.999,
9 | "seed": 1234,
10 |
11 | "upsample_rates": [8,8,2,2],
12 | "upsample_kernel_sizes": [16,16,4,4],
13 | "upsample_initial_channel": 128,
14 | "resblock_kernel_sizes": [3,7,11],
15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16 |
17 | "segment_size": 8192,
18 | "num_mels": 80,
19 | "num_freq": 1025,
20 | "n_fft": 1024,
21 | "hop_size": 256,
22 | "win_size": 1024,
23 |
24 | "sampling_rate": 22050,
25 |
26 | "fmin": 0,
27 | "fmax": 8000,
28 | "fmax_for_loss": null,
29 |
30 | "num_workers": 4,
31 |
32 | "dist_config": {
33 | "dist_backend": "nccl",
34 | "dist_url": "tcp://localhost:54321",
35 | "world_size": 1
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/third_party/hifi_gan/config_v3.json:
--------------------------------------------------------------------------------
1 | {
2 | "resblock": "2",
3 | "num_gpus": 0,
4 | "batch_size": 16,
5 | "learning_rate": 0.0002,
6 | "adam_b1": 0.8,
7 | "adam_b2": 0.99,
8 | "lr_decay": 0.999,
9 | "seed": 1234,
10 |
11 | "upsample_rates": [8,8,4],
12 | "upsample_kernel_sizes": [16,16,8],
13 | "upsample_initial_channel": 256,
14 | "resblock_kernel_sizes": [3,5,7],
15 | "resblock_dilation_sizes": [[1,2], [2,6], [3,12]],
16 |
17 | "segment_size": 8192,
18 | "num_mels": 80,
19 | "num_freq": 1025,
20 | "n_fft": 1024,
21 | "hop_size": 256,
22 | "win_size": 1024,
23 |
24 | "sampling_rate": 22050,
25 |
26 | "fmin": 0,
27 | "fmax": 8000,
28 | "fmax_for_loss": null,
29 |
30 | "num_workers": 4,
31 |
32 | "dist_config": {
33 | "dist_backend": "nccl",
34 | "dist_url": "tcp://localhost:54321",
35 | "world_size": 1
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/third_party/hifi_gan/env.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 |
5 | class AttrDict(dict):
6 | def __init__(self, *args, **kwargs):
7 | super(AttrDict, self).__init__(*args, **kwargs)
8 | self.__dict__ = self
9 |
10 |
11 | def build_env(config, config_name, path):
12 | t_path = os.path.join(path, config_name)
13 | if config != t_path:
14 | os.makedirs(path, exist_ok=True)
15 | shutil.copyfile(config, os.path.join(path, config_name))
16 |
--------------------------------------------------------------------------------
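`AttrDict` simply exposes a JSON configuration as attributes; a sketch follows (the config path is an assumption and depends on the working directory):

```python
import json
from neural_formant_synthesis.third_party.hifi_gan.env import AttrDict

# hypothetical path to one of the bundled generator configurations
with open('src/neural_formant_synthesis/third_party/hifi_gan/config_v3.json') as f:
    h = AttrDict(json.load(f))

print(h.sampling_rate, h.upsample_rates)  # fields become attributes
```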
/src/neural_formant_synthesis/third_party/hifi_gan/inference.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 |
3 | import glob
4 | import os
5 | import argparse
6 | import json
7 | import torch
8 | from env import AttrDict
9 | from meldataset import mel_spectrogram, MAX_WAV_VALUE, load_wav
10 | from models import Generator
11 | import soundfile as sf
12 |
13 | h = None
14 | device = None
15 |
16 |
17 | def load_checkpoint(filepath, device):
18 | assert os.path.isfile(filepath)
19 | print("Loading '{}'".format(filepath))
20 | checkpoint_dict = torch.load(filepath, map_location=device)
21 | print("Complete.")
22 | return checkpoint_dict
23 |
24 |
25 | def get_mel(x):
26 | return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
27 |
28 |
29 | def scan_checkpoint(cp_dir, prefix):
30 | pattern = os.path.join(cp_dir, prefix + '*')
31 | cp_list = glob.glob(pattern)
32 | if len(cp_list) == 0:
33 | return ''
34 | return sorted(cp_list)[-1]
35 |
36 |
37 | def inference(a):
38 | generator = Generator(h).to(device)
39 |
40 | state_dict_g = load_checkpoint(a.checkpoint_file, device)
41 | generator.load_state_dict(state_dict_g['generator'])
42 |
43 | filelist = os.listdir(a.input_wavs_dir)
44 |
45 | os.makedirs(a.output_dir, exist_ok=True)
46 |
47 | generator.eval()
48 | generator.remove_weight_norm()
49 | with torch.no_grad():
50 | for i, filname in enumerate(filelist):
51 | wav, sr = load_wav(os.path.join(a.input_wavs_dir, filname))
52 | wav = torch.FloatTensor(wav).to(device)
53 | x = get_mel(wav.unsqueeze(0))
54 | y_g_hat = generator(x)
55 | audio = y_g_hat.squeeze()
56 | audio = audio.cpu().numpy()
57 |
58 | output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + '_generated.wav')
59 | sf.write(output_file, audio, sr)
60 | print(output_file)
61 |
62 |
63 | def main():
64 | print('Initializing Inference Process..')
65 |
66 | parser = argparse.ArgumentParser()
67 | parser.add_argument('--input_wavs_dir', default='test_files')
68 | parser.add_argument('--output_dir', default='generated_files')
69 | parser.add_argument('--checkpoint_file', required=True)
70 | a = parser.parse_args()
71 |
72 | config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json')
73 | with open(config_file) as f:
74 | data = f.read()
75 |
76 | global h
77 | json_config = json.loads(data)
78 | h = AttrDict(json_config)
79 |
80 | torch.manual_seed(h.seed)
81 | global device
82 | if torch.cuda.is_available():
83 | torch.cuda.manual_seed(h.seed)
84 | device = torch.device('cuda')
85 | else:
86 | device = torch.device('cpu')
87 |
88 | inference(a)
89 |
90 |
91 | if __name__ == '__main__':
92 | main()
93 |
94 |
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/third_party/hifi_gan/inference_e2e.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 |
3 | import glob
4 | import os
5 | import numpy as np
6 | import argparse
7 | import json
8 | import torch
9 | from env import AttrDict
10 | from meldataset import MAX_WAV_VALUE
11 | from models import Generator
12 | from scipy.io.wavfile import write
13 | h = None
14 | device = None
15 |
16 |
17 | def load_checkpoint(filepath, device):
18 | assert os.path.isfile(filepath)
19 | print("Loading '{}'".format(filepath))
20 | checkpoint_dict = torch.load(filepath, map_location=device)
21 | print("Complete.")
22 | return checkpoint_dict
23 |
24 |
25 | def scan_checkpoint(cp_dir, prefix):
26 | pattern = os.path.join(cp_dir, prefix + '*')
27 | cp_list = glob.glob(pattern)
28 | if len(cp_list) == 0:
29 | return ''
30 | return sorted(cp_list)[-1]
31 |
32 |
33 | def inference(a):
34 | generator = Generator(h).to(device)
35 |
36 | state_dict_g = load_checkpoint(a.checkpoint_file, device)
37 | generator.load_state_dict(state_dict_g['generator'])
38 |
39 | filelist = os.listdir(a.input_mels_dir)
40 |
41 | os.makedirs(a.output_dir, exist_ok=True)
42 |
43 | generator.eval()
44 | generator.remove_weight_norm()
45 | with torch.no_grad():
46 | for i, filname in enumerate(filelist):
47 | x = np.load(os.path.join(a.input_mels_dir, filname))
48 | x = torch.FloatTensor(x).to(device)
49 | y_g_hat = generator(x)
50 | audio = y_g_hat.squeeze()
51 | audio = audio * MAX_WAV_VALUE
52 | audio = audio.cpu().numpy().astype('int16')
53 |
54 | output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + '_generated_e2e.wav')
55 | write(output_file, h.sampling_rate, audio)
56 | print(output_file)
57 |
58 |
59 | def main():
60 | print('Initializing Inference Process..')
61 |
62 | parser = argparse.ArgumentParser()
63 | parser.add_argument('--input_mels_dir', default='test_mel_files')
64 | parser.add_argument('--output_dir', default='generated_files_from_mel')
65 | parser.add_argument('--checkpoint_file', required=True)
66 | a = parser.parse_args()
67 |
68 | config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json')
69 | with open(config_file) as f:
70 | data = f.read()
71 |
72 | global h
73 | json_config = json.loads(data)
74 | h = AttrDict(json_config)
75 |
76 | torch.manual_seed(h.seed)
77 | global device
78 | if torch.cuda.is_available():
79 | torch.cuda.manual_seed(h.seed)
80 | device = torch.device('cuda')
81 | else:
82 | device = torch.device('cpu')
83 |
84 | inference(a)
85 |
86 |
87 | if __name__ == '__main__':
88 | main()
89 |
90 |
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/third_party/hifi_gan/meldataset.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import random
4 | import torch
5 | import torch.utils.data
6 | import numpy as np
7 | import soundfile as sf
8 |
9 | import torchaudio
10 |
11 | MAX_WAV_VALUE = 32768.0
12 |
13 |
14 | def load_wav(full_path):
15 | data, sampling_rate = sf.read(full_path)
16 | # sampling_rate, data = read(full_path)
17 | return data, sampling_rate
18 |
19 |
20 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
21 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
22 |
23 |
24 | def dynamic_range_decompression(x, C=1):
25 | return np.exp(x) / C
26 |
27 |
28 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
29 | return torch.log(torch.clamp(x, min=clip_val) * C)
30 |
31 |
32 | def dynamic_range_decompression_torch(x, C=1):
33 | return torch.exp(x) / C
34 |
35 |
36 | def spectral_normalize_torch(magnitudes):
37 | output = dynamic_range_compression_torch(magnitudes)
38 | return output
39 |
40 |
41 | def spectral_de_normalize_torch(magnitudes):
42 | output = dynamic_range_decompression_torch(magnitudes)
43 | return output
44 |
45 |
46 | mel_basis = {}
47 | hann_window = {}
48 |
49 |
50 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
51 | if torch.min(y) < -1.:
52 | print('min value is ', torch.min(y))
53 | if torch.max(y) > 1.:
54 | print('max value is ', torch.max(y))
55 |
56 | global mel_basis, hann_window
57 | if fmax not in mel_basis:
58 | mspec = torchaudio.transforms.MelSpectrogram(
59 | sample_rate=sampling_rate,
60 | n_fft=n_fft,
61 | f_min=fmin,
62 | f_max=fmax,
63 | n_mels=num_mels,
64 | norm='slaney',
65 | mel_scale='slaney',
66 | )
67 | mel = mspec.mel_scale.fb.T.numpy()
68 | mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
69 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
70 |
71 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
72 | y = y.squeeze(1)
73 |
74 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
75 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
76 |
77 | spec = spec.abs() + 1e-6
78 |
79 | spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
80 | spec = spectral_normalize_torch(spec)
81 |
82 | return spec
83 |
84 |
85 | def get_dataset_filelist(a, ext='.wav'):
86 | if hasattr(a, 'wavefile_ext'):
87 | ext = a.wavefile_ext
88 | with open(a.input_training_file, 'r', encoding='utf-8') as fi:
89 | training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + ext)
90 | for x in fi.read().split('\n') if len(x) > 0]
91 |
92 | with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
93 | validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + ext)
94 | for x in fi.read().split('\n') if len(x) > 0]
95 | return training_files, validation_files
96 |
97 |
98 | class MelDataset(torch.utils.data.Dataset):
99 | def __init__(self, training_files, segment_size, n_fft, num_mels,
100 | hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
101 | device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
102 | self.audio_files = training_files
103 | random.seed(1234)
104 | if shuffle:
105 | random.shuffle(self.audio_files)
106 | self.segment_size = segment_size
107 | self.sampling_rate = sampling_rate
108 | self.split = split
109 | self.n_fft = n_fft
110 | self.num_mels = num_mels
111 | self.hop_size = hop_size
112 | self.win_size = win_size
113 | self.fmin = fmin
114 | self.fmax = fmax
115 | self.fmax_loss = fmax_loss
116 | self.cached_wav = None
117 | self.n_cache_reuse = n_cache_reuse
118 | self._cache_ref_count = 0
119 | self.device = device
120 | self.fine_tuning = fine_tuning
121 | self.base_mels_path = base_mels_path
122 |
123 | def __getitem__(self, index):
124 | filename = self.audio_files[index]
125 | if self._cache_ref_count == 0:
126 | audio, sampling_rate = load_wav(filename)
127 | self.cached_wav = audio
128 | # TODO: resample if needed
129 | if sampling_rate != self.sampling_rate:
130 | raise ValueError("{} SR doesn't match target {} SR".format(
131 | sampling_rate, self.sampling_rate))
132 | self._cache_ref_count = self.n_cache_reuse
133 | else:
134 | audio = self.cached_wav
135 | self._cache_ref_count -= 1
136 |
137 | audio = torch.FloatTensor(audio)
138 | audio = audio.unsqueeze(0)
139 |
140 | if not self.fine_tuning:
141 | if self.split:
142 | if audio.size(1) >= self.segment_size:
143 | max_audio_start = audio.size(1) - self.segment_size
144 | audio_start = random.randint(0, max_audio_start)
145 | audio = audio[:, audio_start:audio_start+self.segment_size]
146 | else:
147 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
148 |
149 | mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
150 | self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
151 | center=False)
152 | else:
153 | mel = np.load(
154 | os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
155 | mel = torch.from_numpy(mel)
156 |
157 | if len(mel.shape) < 3:
158 | mel = mel.unsqueeze(0)
159 |
160 | if self.split:
161 | frames_per_seg = math.ceil(self.segment_size / self.hop_size)
162 |
163 | if audio.size(1) >= self.segment_size:
164 | mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
165 | mel = mel[:, :, mel_start:mel_start + frames_per_seg]
166 | audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
167 | else:
168 | mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant')
169 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
170 |
171 | mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
172 | self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
173 | center=False)
174 |
175 | return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
176 |
177 | def __len__(self):
178 | return len(self.audio_files)
179 |
--------------------------------------------------------------------------------
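A sketch of calling `mel_spectrogram` directly with `config_v1`-style parameters; the input is expected as `(batch, samples)` scaled roughly to [-1, 1]:

```python
import torch
from neural_formant_synthesis.third_party.hifi_gan.meldataset import mel_spectrogram

y = torch.rand(1, 8192) * 2.0 - 1.0  # (batch, samples) in [-1, 1)
mel = mel_spectrogram(y, n_fft=1024, num_mels=80, sampling_rate=22050,
                      hop_size=256, win_size=1024, fmin=0, fmax=8000)
print(mel.shape)  # (batch, num_mels, frames)
```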
/src/neural_formant_synthesis/third_party/hifi_gan/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6 | from neural_formant_synthesis.third_party.hifi_gan.utils import init_weights, get_padding
7 |
8 | LRELU_SLOPE = 0.1
9 |
10 |
11 | class ResBlock1(torch.nn.Module):
12 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
13 | super(ResBlock1, self).__init__()
14 | self.h = h
15 | self.convs1 = nn.ModuleList([
16 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
17 | padding=get_padding(kernel_size, dilation[0]))),
18 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
19 | padding=get_padding(kernel_size, dilation[1]))),
20 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
21 | padding=get_padding(kernel_size, dilation[2])))
22 | ])
23 | self.convs1.apply(init_weights)
24 |
25 | self.convs2 = nn.ModuleList([
26 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
27 | padding=get_padding(kernel_size, 1))),
28 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
29 | padding=get_padding(kernel_size, 1))),
30 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
31 | padding=get_padding(kernel_size, 1)))
32 | ])
33 | self.convs2.apply(init_weights)
34 |
35 | def forward(self, x):
36 | for c1, c2 in zip(self.convs1, self.convs2):
37 | xt = F.leaky_relu(x, LRELU_SLOPE)
38 | xt = c1(xt)
39 | xt = F.leaky_relu(xt, LRELU_SLOPE)
40 | xt = c2(xt)
41 | x = xt + x
42 | return x
43 |
44 | def remove_weight_norm(self):
45 | for l in self.convs1:
46 | remove_weight_norm(l)
47 | for l in self.convs2:
48 | remove_weight_norm(l)
49 |
50 |
51 | class ResBlock2(torch.nn.Module):
52 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
53 | super(ResBlock2, self).__init__()
54 | self.h = h
55 | self.convs = nn.ModuleList([
56 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
57 | padding=get_padding(kernel_size, dilation[0]))),
58 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
59 | padding=get_padding(kernel_size, dilation[1])))
60 | ])
61 | self.convs.apply(init_weights)
62 |
63 | def forward(self, x):
64 | for c in self.convs:
65 | xt = F.leaky_relu(x, LRELU_SLOPE)
66 | xt = c(xt)
67 | x = xt + x
68 | return x
69 |
70 | def remove_weight_norm(self):
71 | for l in self.convs:
72 | remove_weight_norm(l)
73 |
74 |
75 | class Generator(torch.nn.Module):
76 | def __init__(self, h, input_channels=80):
77 | super(Generator, self).__init__()
78 | self.h = h
79 | self.num_kernels = len(h.resblock_kernel_sizes)
80 | self.num_upsamples = len(h.upsample_rates)
81 | self.conv_pre = weight_norm(Conv1d(input_channels, h.upsample_initial_channel, 7, 1, padding=3))
82 | resblock = ResBlock1 if h.resblock == '1' else ResBlock2
83 |
84 | self.ups = nn.ModuleList()
85 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
86 | self.ups.append(weight_norm(
87 | ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
88 | k, u, padding=(k-u)//2)))
89 |
90 | self.resblocks = nn.ModuleList()
91 | for i in range(len(self.ups)):
92 | ch = h.upsample_initial_channel//(2**(i+1))
93 | for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
94 | self.resblocks.append(resblock(h, ch, k, d))
95 |
96 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
97 | self.ups.apply(init_weights)
98 | self.conv_post.apply(init_weights)
99 |
100 | def forward(self, x):
101 | x = self.conv_pre(x)
102 | for i in range(self.num_upsamples):
103 | x = F.leaky_relu(x, LRELU_SLOPE)
104 | x = self.ups[i](x)
105 | xs = None
106 | for j in range(self.num_kernels):
107 | if xs is None:
108 | xs = self.resblocks[i*self.num_kernels+j](x)
109 | else:
110 | xs += self.resblocks[i*self.num_kernels+j](x)
111 | x = xs / self.num_kernels
112 | x = F.leaky_relu(x)
113 | x = self.conv_post(x)
114 | x = torch.tanh(x)
115 |
116 | return x
117 |
118 | def remove_weight_norm(self):
119 | print('Removing weight norm...')
120 | for l in self.ups:
121 | remove_weight_norm(l)
122 | for l in self.resblocks:
123 | l.remove_weight_norm()
124 | remove_weight_norm(self.conv_pre)
125 | remove_weight_norm(self.conv_post)
126 |
127 |
128 | class DiscriminatorP(torch.nn.Module):
129 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
130 | super(DiscriminatorP, self).__init__()
131 | self.period = period
132 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm
133 | self.convs = nn.ModuleList([
134 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
135 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
136 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
137 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
138 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
139 | ])
140 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
141 |
142 | def forward(self, x):
143 | fmap = []
144 |
145 | # 1d to 2d
146 | b, c, t = x.shape
147 | if t % self.period != 0: # pad first
148 | n_pad = self.period - (t % self.period)
149 | x = F.pad(x, (0, n_pad), "reflect")
150 | t = t + n_pad
151 | x = x.view(b, c, t // self.period, self.period)
152 |
153 | for l in self.convs:
154 | x = l(x)
155 | x = F.leaky_relu(x, LRELU_SLOPE)
156 | fmap.append(x)
157 | x = self.conv_post(x)
158 | fmap.append(x)
159 | x = torch.flatten(x, 1, -1)
160 |
161 | return x, fmap
162 |
163 |
164 | class MultiPeriodDiscriminator(torch.nn.Module):
165 | def __init__(self):
166 | super(MultiPeriodDiscriminator, self).__init__()
167 | self.discriminators = nn.ModuleList([
168 | DiscriminatorP(2),
169 | DiscriminatorP(3),
170 | DiscriminatorP(5),
171 | DiscriminatorP(7),
172 | DiscriminatorP(11),
173 | ])
174 |
175 | def forward(self, y, y_hat):
176 | y_d_rs = []
177 | y_d_gs = []
178 | fmap_rs = []
179 | fmap_gs = []
180 | for i, d in enumerate(self.discriminators):
181 | y_d_r, fmap_r = d(y)
182 | y_d_g, fmap_g = d(y_hat)
183 | y_d_rs.append(y_d_r)
184 | fmap_rs.append(fmap_r)
185 | y_d_gs.append(y_d_g)
186 | fmap_gs.append(fmap_g)
187 |
188 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
189 |
190 |
191 | class DiscriminatorS(torch.nn.Module):
192 | def __init__(self, use_spectral_norm=False):
193 | super(DiscriminatorS, self).__init__()
194 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm
195 | self.convs = nn.ModuleList([
196 | norm_f(Conv1d(1, 128, 15, 1, padding=7)),
197 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
198 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
199 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
200 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
201 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
202 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
203 | ])
204 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
205 |
206 | def forward(self, x):
207 | fmap = []
208 | for l in self.convs:
209 | x = l(x)
210 | x = F.leaky_relu(x, LRELU_SLOPE)
211 | fmap.append(x)
212 | x = self.conv_post(x)
213 | fmap.append(x)
214 | x = torch.flatten(x, 1, -1)
215 |
216 | return x, fmap
217 |
218 |
219 | class MultiScaleDiscriminator(torch.nn.Module):
220 | def __init__(self):
221 | super(MultiScaleDiscriminator, self).__init__()
222 | self.discriminators = nn.ModuleList([
223 | DiscriminatorS(use_spectral_norm=True),
224 | DiscriminatorS(),
225 | DiscriminatorS(),
226 | ])
227 | self.meanpools = nn.ModuleList([
228 | AvgPool1d(4, 2, padding=2),
229 | AvgPool1d(4, 2, padding=2)
230 | ])
231 |
232 | def forward(self, y, y_hat):
233 | y_d_rs = []
234 | y_d_gs = []
235 | fmap_rs = []
236 | fmap_gs = []
237 | for i, d in enumerate(self.discriminators):
238 | if i != 0:
239 | y = self.meanpools[i-1](y)
240 | y_hat = self.meanpools[i-1](y_hat)
241 | y_d_r, fmap_r = d(y)
242 | y_d_g, fmap_g = d(y_hat)
243 | y_d_rs.append(y_d_r)
244 | fmap_rs.append(fmap_r)
245 | y_d_gs.append(y_d_g)
246 | fmap_gs.append(fmap_g)
247 |
248 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
249 |
250 |
251 | def feature_loss(fmap_r, fmap_g):
252 | loss = 0
253 | for dr, dg in zip(fmap_r, fmap_g):
254 | for rl, gl in zip(dr, dg):
255 | loss += torch.mean(torch.abs(rl - gl))
256 |
257 | return loss*2
258 |
259 |
260 | def discriminator_loss(disc_real_outputs, disc_generated_outputs):
261 | loss = 0
262 | r_losses = []
263 | g_losses = []
264 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
265 | r_loss = torch.mean((1-dr)**2)
266 | g_loss = torch.mean(dg**2)
267 | loss += (r_loss + g_loss)
268 | r_losses.append(r_loss.item())
269 | g_losses.append(g_loss.item())
270 |
271 | return loss, r_losses, g_losses
272 |
273 | class DiscriminatorMetrics():
274 |
275 | def __init__(self):
276 |
277 | self.true_accept_total = 0
278 | self.true_reject_total = 0
279 | self.false_accept_total = 0
280 | self.false_reject_total = 0
281 | self.num_samples_total = 0
282 |
283 | @property
284 | def accuracy(self):
285 | TA = self.true_accept_total
286 | TR = self.true_reject_total
287 | N = self.num_samples_total
288 | return 1.0 * (TA + TR) / N
289 |
290 | @property
291 | def false_accept_rate(self):
292 | return 1.0 * self.false_accept_total / self.num_samples_total
293 |
294 | @property
295 | def false_reject_rate(self):
296 | return 1.0 * self.false_reject_total / self.num_samples_total
297 |
298 | @property
299 | def equal_error_rate(self):
300 |         return 0.5 * (self.false_accept_rate + self.false_reject_rate)
301 |
302 | def accumulate(self, disc_real_outputs, disc_generated_outputs):
303 | """
304 | Args:
305 | disc_real_outputs:
306 | shape is (batch, channels, timesteps)
307 | disc_generated_outputs
308 | shape is (batch, channels, timesteps)
309 | """
310 | pred_real = []
311 | pred_gen = []
312 | # classifications for each discriminator
313 | for d_real, d_gen in zip(disc_real_outputs, disc_generated_outputs):
314 | # mean prediction over time and channels
315 | pred_real.append(torch.mean(d_real, dim=(-1,)) > 0.5)
316 | pred_gen.append(torch.mean(d_gen, dim=(-1,)) < 0.5)
317 |
318 | # Stack classifications from different discriminators
319 | pred_real = torch.stack(pred_real, dim=0)
320 | pred_gen = torch.stack(pred_gen, dim=0)
321 |
322 |         # Majority vote (probabilities not available)
323 | pred_real_voted, _ = torch.median(pred_real * 1.0, dim=-1)
324 | pred_gen_voted, _ = torch.median(pred_gen * 1.0, dim=-1)
325 |
326 | if pred_real_voted.shape != pred_gen_voted.shape:
327 | raise ValueError("Real and generated batch sizes must match")
328 |
329 | N = pred_real_voted.shape[0] + pred_gen_voted.shape[0]
330 |
331 | # True Accept
332 | TA = pred_real_voted.sum()
333 |
334 | # True Reject
335 | TR = pred_gen_voted.sum()
336 |
337 | # False Accept
338 | FA = N - TR
339 |
340 | # False Reject
341 | FR = N - TA
342 |
343 | self.true_accept_total += TA
344 | self.true_reject_total += TR
345 | self.false_accept_total += FA
346 | self.false_reject_total += FR
347 | self.num_samples_total += N
348 |
349 |
350 |
351 |
352 | def generator_adversarial_loss(disc_outputs):
353 | loss = 0
354 | gen_losses = []
355 | for dg in disc_outputs:
356 | l = torch.mean((1-dg)**2)
357 | gen_losses.append(l)
358 | loss += l
359 |
360 | return loss, gen_losses
361 |
362 | def generator_collaborative_loss(disc_outputs):
363 | loss = 0
364 | gen_losses = []
365 | for dg in disc_outputs:
366 | l = torch.mean((0-dg)**2)
367 | gen_losses.append(l)
368 | loss += l
369 |
370 | return loss, gen_losses
371 |
372 |
373 |
--------------------------------------------------------------------------------
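A sketch instantiating `Generator` with hyperparameters mirroring `config_v3.json` and running a dummy mel input; only the fields the generator actually reads are included here:

```python
import torch
from neural_formant_synthesis.third_party.hifi_gan.env import AttrDict
from neural_formant_synthesis.third_party.hifi_gan.models import Generator

h = AttrDict({
    "resblock": "2",
    "upsample_rates": [8, 8, 4],
    "upsample_kernel_sizes": [16, 16, 8],
    "upsample_initial_channel": 256,
    "resblock_kernel_sizes": [3, 5, 7],
    "resblock_dilation_sizes": [[1, 2], [2, 6], [3, 12]],
})

generator = Generator(h, input_channels=80)
mel = torch.randn(1, 80, 32)   # (batch, mel bins, frames)
with torch.no_grad():
    audio = generator(mel)     # upsampled by 8 * 8 * 4 = 256 -> (1, 1, 8192)
print(audio.shape)
```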
/src/neural_formant_synthesis/third_party/hifi_gan/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.4.0
2 | numpy==1.17.4
3 | librosa==0.7.2
4 | scipy==1.4.1
5 | tensorboard==2.0
6 | soundfile==0.10.3.post1
7 | matplotlib==3.1.3
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/third_party/hifi_gan/train.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.simplefilter(action='ignore', category=FutureWarning)
3 | import itertools
4 | import os
5 | import time
6 | import argparse
7 | import json
8 | import torch
9 | import torch.nn.functional as F
10 | from torch.utils.tensorboard import SummaryWriter
11 | from torch.utils.data import DistributedSampler, DataLoader
12 | import torch.multiprocessing as mp
13 | from torch.distributed import init_process_group
14 | from torch.nn.parallel import DistributedDataParallel
15 | from env import AttrDict, build_env
16 | from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
17 | from models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_adversarial_loss,\
18 | discriminator_loss
19 | from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
20 |
21 | torch.backends.cudnn.benchmark = True
22 |
23 |
24 | def train(rank, a, h):
25 | if h.num_gpus > 1:
26 | init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
27 | world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
28 |
29 | torch.cuda.manual_seed(h.seed)
30 | if torch.cuda.is_available():
31 | device = torch.device('cuda:{:d}'.format(rank))
32 | else:
33 | device = torch.device('cpu')
34 |
35 | generator = Generator(h).to(device)
36 | mpd = MultiPeriodDiscriminator().to(device)
37 | msd = MultiScaleDiscriminator().to(device)
38 |
39 | if rank == 0:
40 | print(generator)
41 | os.makedirs(a.checkpoint_path, exist_ok=True)
42 | print("checkpoints directory : ", a.checkpoint_path)
43 |
44 | if os.path.isdir(a.checkpoint_path):
45 | cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
46 | cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
47 |
48 | steps = 0
49 | if cp_g is None or cp_do is None:
50 | state_dict_do = None
51 | last_epoch = -1
52 | else:
53 | state_dict_g = load_checkpoint(cp_g, device)
54 | state_dict_do = load_checkpoint(cp_do, device)
55 | generator.load_state_dict(state_dict_g['generator'])
56 | mpd.load_state_dict(state_dict_do['mpd'])
57 | msd.load_state_dict(state_dict_do['msd'])
58 | steps = state_dict_do['steps'] + 1
59 | last_epoch = state_dict_do['epoch']
60 |
61 | if h.num_gpus > 1:
62 | generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
63 | mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
64 | msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)
65 |
66 | optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
67 | optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
68 | h.learning_rate, betas=[h.adam_b1, h.adam_b2])
69 |
70 | if state_dict_do is not None:
71 | optim_g.load_state_dict(state_dict_do['optim_g'])
72 | optim_d.load_state_dict(state_dict_do['optim_d'])
73 |
74 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
75 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
76 |
77 | training_filelist, validation_filelist = get_dataset_filelist(a)
78 |
79 | trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
80 | h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
81 | shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
82 | fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)
83 |
84 | train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
85 |
86 | train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
87 | sampler=train_sampler,
88 | batch_size=h.batch_size,
89 | pin_memory=True,
90 | drop_last=True)
91 |
92 | if rank == 0:
93 | validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
94 | h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
95 | fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
96 | base_mels_path=a.input_mels_dir)
97 | validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
98 | sampler=None,
99 | batch_size=1,
100 | pin_memory=True,
101 | drop_last=True)
102 |
103 | sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
104 |
105 | generator.train()
106 | mpd.train()
107 | msd.train()
108 | for epoch in range(max(0, last_epoch), a.training_epochs):
109 | if rank == 0:
110 | start = time.time()
111 | print("Epoch: {}".format(epoch+1))
112 |
113 | if h.num_gpus > 1:
114 | train_sampler.set_epoch(epoch)
115 |
116 | for i, batch in enumerate(train_loader):
117 | if rank == 0:
118 | start_b = time.time()
119 | x, y, _, y_mel = batch # mel, audio, filename, mel_loss
120 | x = torch.autograd.Variable(x.to(device, non_blocking=True))
121 | y = torch.autograd.Variable(y.to(device, non_blocking=True))
122 | y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
123 | y = y.unsqueeze(1)
124 |
125 | # TODO: add LPC parameter estimation
126 |
127 | y_g_hat = generator(x)
128 | # TODO: add glotnet synthesis filter
129 | y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
130 | h.fmin, h.fmax_for_loss)
131 |
132 | optim_d.zero_grad()
133 |
134 | # MPD
135 | y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
136 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
137 |
138 | # MSD
139 | y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
140 | loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
141 |
142 | loss_disc_all = loss_disc_s + loss_disc_f
143 |
144 | loss_disc_all.backward()
145 | optim_d.step()
146 |
147 | # Generator
148 | optim_g.zero_grad()
149 |
150 | # L1 Mel-Spectrogram Loss
151 | loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
152 |
153 | y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
154 | y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
155 | loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
156 | loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
157 | loss_gen_f, losses_gen_f = generator_adversarial_loss(y_df_hat_g)
158 | loss_gen_s, losses_gen_s = generator_adversarial_loss(y_ds_hat_g)
159 | loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
160 |
161 | loss_gen_all.backward()
162 | optim_g.step()
163 |
164 | if rank == 0:
165 | # STDOUT logging
166 | if steps % a.stdout_interval == 0:
167 | with torch.no_grad():
168 | mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()
169 |
170 | print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
171 | format(steps, loss_gen_all, mel_error, time.time() - start_b))
172 |
173 | # checkpointing
174 | if steps % a.checkpoint_interval == 0 and steps != 0:
175 | checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
176 | save_checkpoint(checkpoint_path,
177 | {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
178 | checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
179 | save_checkpoint(checkpoint_path,
180 | {'mpd': (mpd.module if h.num_gpus > 1
181 | else mpd).state_dict(),
182 | 'msd': (msd.module if h.num_gpus > 1
183 | else msd).state_dict(),
184 | 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
185 | 'epoch': epoch})
186 |
187 | # Tensorboard summary logging
188 | if steps % a.summary_interval == 0:
189 | sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
190 | sw.add_scalar("training/mel_spec_error", mel_error, steps)
191 |
192 | # Validation
193 | if steps % a.validation_interval == 0: # and steps != 0:
194 | generator.eval()
195 | torch.cuda.empty_cache()
196 | val_err_tot = 0
197 | with torch.no_grad():
198 | for j, batch in enumerate(validation_loader):
199 | x, y, _, y_mel = batch
200 | y_g_hat = generator(x.to(device))
201 | y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
202 | y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
203 | h.hop_size, h.win_size,
204 | h.fmin, h.fmax_for_loss)
205 | val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
206 |
207 | if j <= 4:
208 | if steps == 0:
209 | sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
210 | sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)
211 |
212 | sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
213 | y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
214 | h.sampling_rate, h.hop_size, h.win_size,
215 | h.fmin, h.fmax)
216 | sw.add_figure('generated/y_hat_spec_{}'.format(j),
217 | plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)
218 |
219 | val_err = val_err_tot / (j+1)
220 | sw.add_scalar("validation/mel_spec_error", val_err, steps)
221 |
222 | generator.train()
223 |
224 | steps += 1
225 |
226 | scheduler_g.step()
227 | scheduler_d.step()
228 |
229 | if rank == 0:
230 | print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
231 |
232 |
233 | def main():
234 | print('Initializing Training Process..')
235 |
236 | parser = argparse.ArgumentParser()
237 |
238 | parser.add_argument('--group_name', default=None)
239 | parser.add_argument('--input_wavs_dir', default='LJSpeech-1.1/wavs')
240 | parser.add_argument('--input_mels_dir', default='ft_dataset')
241 | parser.add_argument('--input_training_file', default='LJSpeech-1.1/training.txt')
242 | parser.add_argument('--input_validation_file', default='LJSpeech-1.1/validation.txt')
243 | parser.add_argument('--checkpoint_path', default='cp_hifigan')
244 | parser.add_argument('--config', default='')
245 | parser.add_argument('--training_epochs', default=3100, type=int)
246 | parser.add_argument('--stdout_interval', default=5, type=int)
247 | parser.add_argument('--checkpoint_interval', default=5000, type=int)
248 | parser.add_argument('--summary_interval', default=100, type=int)
249 | parser.add_argument('--validation_interval', default=1000, type=int)
250 | parser.add_argument('--fine_tuning', default=False, type=bool)
251 |
252 | a = parser.parse_args()
253 |
254 | with open(a.config) as f:
255 | data = f.read()
256 |
257 | json_config = json.loads(data)
258 | h = AttrDict(json_config)
259 | build_env(a.config, 'config.json', a.checkpoint_path)
260 |
261 | torch.manual_seed(h.seed)
262 | if torch.cuda.is_available():
263 | torch.cuda.manual_seed(h.seed)
264 | h.num_gpus = torch.cuda.device_count()
265 | h.batch_size = int(h.batch_size / h.num_gpus)
266 | print('Batch size per GPU :', h.batch_size)
267 | else:
268 | pass
269 |
270 | if h.num_gpus > 1:
271 | mp.spawn(train, nprocs=h.num_gpus, args=(a, h,))
272 | else:
273 | train(0, a, h)
274 |
275 |
276 | if __name__ == '__main__':
277 | main()
278 |
--------------------------------------------------------------------------------
/src/neural_formant_synthesis/third_party/hifi_gan/utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import matplotlib
4 | import torch
5 | from torch.nn.utils import weight_norm
6 | matplotlib.use("Agg")
7 | import matplotlib.pylab as plt
8 |
9 |
10 | def plot_spectrogram(spectrogram):
11 | fig, ax = plt.subplots(figsize=(10, 2))
12 | im = ax.imshow(spectrogram, aspect="auto", origin="lower",
13 | interpolation='none')
14 | plt.colorbar(im, ax=ax)
15 |
16 | fig.canvas.draw()
17 | plt.close()
18 |
19 | return fig
20 |
21 |
22 | def init_weights(m, mean=0.0, std=0.01):
23 | classname = m.__class__.__name__
24 | if classname.find("Conv") != -1:
25 | m.weight.data.normal_(mean, std)
26 |
27 |
28 | def apply_weight_norm(m):
29 | classname = m.__class__.__name__
30 | if classname.find("Conv") != -1:
31 | weight_norm(m)
32 |
33 |
34 | def get_padding(kernel_size, dilation=1):
35 | return int((kernel_size*dilation - dilation)/2)
36 |
37 |
38 | def load_checkpoint(filepath, device):
39 | assert os.path.isfile(filepath)
40 | print("Loading '{}'".format(filepath))
41 | checkpoint_dict = torch.load(filepath, map_location=device)
42 | print("Complete.")
43 | return checkpoint_dict
44 |
45 |
46 | def save_checkpoint(filepath, obj):
47 | print("Saving checkpoint to {}".format(filepath))
48 | torch.save(obj, filepath)
49 | print("Complete.")
50 |
51 |
52 | def scan_checkpoint(cp_dir, prefix):
53 | pattern = os.path.join(cp_dir, prefix + '????????')
54 | cp_list = glob.glob(pattern)
55 | if len(cp_list) == 0:
56 | return None
57 | return sorted(cp_list)[-1]
58 |
59 |
--------------------------------------------------------------------------------
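The helpers above pair with the `g_XXXXXXXX` / `do_XXXXXXXX` checkpoint naming used by the training scripts. A minimal sketch of resuming the newest generator checkpoint with them, assuming a `cp_hifigan` directory (i.e. whatever `--checkpoint_path` was used):

    import torch
    from neural_formant_synthesis.third_party.hifi_gan.utils import scan_checkpoint, load_checkpoint

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # scan_checkpoint globs for '<prefix>????????' (the zero-padded step counter
    # written by save_checkpoint in train.py) and returns the newest match, or None.
    cp_g = scan_checkpoint('cp_hifigan', 'g_')  # 'cp_hifigan' is a placeholder path
    if cp_g is not None:
        state_dict_g = load_checkpoint(cp_g, device)
        generator_weights = state_dict_g['generator']  # state dict saved by the training loop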
/src/neural_formant_synthesis/vctk_preprocessing.py:
--------------------------------------------------------------------------------
1 | import os, glob, shutil
2 | import numpy as np
3 | import torch
4 | import torchaudio as ta
5 | from tqdm import tqdm
6 |
7 | import argparse
8 |
9 | from feature_extraction import feature_extractor, MedianPool1d
10 | from neural_formant_synthesis.glotnet.sigproc.emphasis import Emphasis
11 |
12 | ta.set_audio_backend("sox_io")
13 |
14 | def main(vctk_path, target_dir):
15 | # Set random seed to 0
16 | np.random.seed(0)
17 | torch.manual_seed(0)
18 |
19 | # Define initial parameters
20 | window_length = 1024
21 | step_length = 256
22 |
23 | target_sr = 22050
24 |
25 | file_ext = '.flac'
26 |
27 | # Declare feature extractor
28 |     feat_extractor = feature_extractor(sr = target_sr, window_samples = window_length, step_samples = step_length, formant_ceiling = 10000, max_formants = 4)
29 | median_filter = MedianPool1d(kernel_size = 3, stride = 1, padding = 0, same = True)
30 | pre_emphasis = Emphasis(alpha=0.97)
31 |
32 | # Divide vctk dataset into train, validation and test sets
33 | divide_vctk(vctk_path, target_dir, file_ext = file_ext)
34 |
35 | # Process train, validation and test sets
36 | print("Processing separated sets")
37 | train_dir = os.path.join(target_dir, 'train')
38 | process_directory(train_dir, target_sr, feat_extractor, median_filter, pre_emphasis = pre_emphasis, file_ext = file_ext)
39 | val_dir = os.path.join(target_dir, 'val')
40 | process_directory(val_dir, target_sr, feat_extractor, median_filter, pre_emphasis = pre_emphasis, file_ext = file_ext)
41 | test_dir = os.path.join(target_dir, 'test')
42 | process_directory(test_dir, target_sr, feat_extractor, median_filter, pre_emphasis = pre_emphasis, file_ext = file_ext)
43 |
44 | def divide_vctk(vctk_path, target_dir, file_ext = '.wav'):
45 | """
46 |     Divide the original VCTK dataset into train, validation and test sets with an 80:10:10 ratio, keeping the speakers in each set disjoint.
47 | """
48 |     # Empty target directory if it exists
49 |     if os.path.exists(target_dir):
50 |         shutil.rmtree(target_dir)
51 | # Create target directory if it doesn't exist
52 | if not os.path.exists(target_dir):
53 | os.mkdir(target_dir)
54 | # Create train, validation and test directories if they don't exist
55 | train_dir = os.path.join(target_dir, 'train')
56 | if not os.path.exists(train_dir):
57 | os.mkdir(train_dir)
58 | val_dir = os.path.join(target_dir, 'val')
59 | if not os.path.exists(val_dir):
60 | os.mkdir(val_dir)
61 | test_dir = os.path.join(target_dir, 'test')
62 | if not os.path.exists(test_dir):
63 | os.mkdir(test_dir)
64 |
65 | # Get speakers as directory names in vctk_path
66 | speakers = os.listdir(vctk_path)
67 |
68 | #Shuffle speakers list
69 | speakers = np.random.permutation(speakers)
70 | # Divide speakers list into train, validation and test sets
71 | train_speakers = speakers[:int(len(speakers) * 0.8)]
72 | val_speakers = speakers[int(len(speakers) * 0.8):int(len(speakers) * 0.9)]
73 | test_speakers = speakers[int(len(speakers) * 0.9):]
74 |
75 | # Save train, validation and test speakers lists in one file in target_dir
76 | speakers_dict = {"train": train_speakers, "val": val_speakers, "test": test_speakers}
77 | torch.save(speakers_dict, os.path.join(target_dir, 'speakers.pt'))
78 |
79 | print("Dividing Speakers")
80 |
81 | # Copy audio files in each speaker directory to train, validation and test directories
82 |
83 | print("Processing Test Set")
84 | testfile = open(os.path.join(target_dir, "test_files.txt"), 'w')
85 | for speaker in tqdm(test_speakers,total = len(test_speakers)):
86 | speaker_dir = os.path.join(vctk_path, speaker)
87 | speaker_files = glob.glob(os.path.join(speaker_dir, '*' + file_ext))
88 | for file in speaker_files:
89 |             shutil.copy(file, test_dir)
90 | testfile.writelines([str(i)+'\n' for i in speaker_files])
91 | testfile.close()
92 |
93 | print("Processing Training Set")
94 | trainfile = open(os.path.join(target_dir, "train_files.txt"), 'w')
95 | for speaker in tqdm(train_speakers, total = len(train_speakers)):
96 | speaker_dir = os.path.join(vctk_path, speaker)
97 | speaker_files = glob.glob(os.path.join(speaker_dir, '*' + file_ext))
98 | for file in speaker_files:
99 |             shutil.copy(file, train_dir)
100 | trainfile.writelines([str(i)+'\n' for i in speaker_files])
101 | trainfile.close()
102 |
103 | print("Processing Validation Set")
104 | valfile = open(os.path.join(target_dir, "val_files.txt"), 'w')
105 | for speaker in tqdm(val_speakers, total = len(val_speakers)):
106 | speaker_dir = os.path.join(vctk_path, speaker)
107 | speaker_files = glob.glob(os.path.join(speaker_dir, '*' + file_ext))
108 | for file in speaker_files:
109 |             shutil.copy(file, val_dir)
110 | valfile.writelines([str(i)+'\n' for i in speaker_files])
111 | valfile.close()
112 |
113 |
114 |
115 |
116 | def process_directory(path, target_sr, feature_extractor, median_filter, pre_emphasis = None, file_ext = '.wav'):
117 | file_list = glob.glob(os.path.join(path, '*' + file_ext))
118 | print(f"Processing directory {path}")
119 | for file in tqdm(file_list, total=len(file_list)):
120 | basename = os.path.basename(file)
121 | no_ext = os.path.splitext(basename)[0]
122 |
123 | formants, energy, centroid, tilt, log_pitch, voicing_flag, r_coeff, ignored = process_file(file, target_sr, feature_extractor, median_filter, pre_emphasis = pre_emphasis)
124 |
125 | if formants.size(0) < log_pitch.size(0):
126 | raise ValueError("Formants size is different than pitch size for file: " + file)
127 |
128 | feature_dict = {"Formants": formants, "Energy": energy, "Centroid": centroid, "Tilt": tilt, "Pitch": log_pitch, "Voicing": voicing_flag, "R_Coeff": r_coeff}
129 | if not ignored:
130 | torch.save(feature_dict, os.path.join(path, no_ext + '.pt'))
131 | else:
132 | print("File: " + basename + " ignored.")
133 |
134 | def process_file(file, target_sr, feature_extractor, median_filter, pre_emphasis = None):
135 | """
136 | Extract features for a single audio file and return feature arrays with the same length.
137 | Params:
138 | file: path to file
139 | target_sr: target sample rate
140 | Returns:
141 | formants: formants
142 | energy: energy in log scale
143 | centroid: spectral centroid
144 | tilt: spectral tilt
145 | tilt_ref: spectral tilt from reference
146 | pitch: pitch in log scale
147 | voicing_flag: voicing flag
148 | r_coeff: reflection coefficients from reference
149 | """
150 | # Read audio file
151 | x, sample_rate = ta.load(file)
152 | x = x[0:1].type(torch.DoubleTensor)
153 | # Resample to target sr
154 | x = ta.functional.resample(x, sample_rate, target_sr)
155 |
156 | if pre_emphasis is not None:
157 | x = pre_emphasis(x.unsqueeze(0))
158 | x = x.squeeze(0).squeeze(0)
159 |     formants, energy, centroid, tilt, pitch, voicing_flag, r_coeff, _, ignored = feature_extractor(x)
160 |
161 | formants = median_filter(formants.T.unsqueeze(1)).squeeze(1).T
162 |
163 | pitch = pitch.squeeze(0)
164 | voicing_flag = voicing_flag.squeeze(0)
165 |
166 | # If pitch length is smaller than formants, pad pitch and voicing flag with last value
167 | if pitch.size(0) < formants.size(0):
168 | pitch = torch.nn.functional.pad(pitch, (0, formants.size(0) - pitch.size(0)), mode = 'constant', value = pitch[-1])
169 | voicing_flag = torch.nn.functional.pad(voicing_flag, (0, formants.size(0) - voicing_flag.size(0)), mode = 'constant', value = voicing_flag[-1])
170 | # If pitch length is larger than formants, truncate pitch and voicing flag
171 | elif pitch.size(0) > formants.size(0):
172 | pitch = pitch[:formants.size(0)]
173 | voicing_flag = voicing_flag[:formants.size(0)]
174 |
175 | log_pitch = torch.log(pitch)
176 |
177 | return formants, energy, centroid, tilt, log_pitch, voicing_flag, r_coeff, ignored
178 |
179 | if __name__ == "__main__":
180 |
181 | parser = argparse.ArgumentParser()
182 |
183 | parser.add_argument('--input_dir', default='/workspace/Dataset/wav48_silence_trimmed')
184 | parser.add_argument('--output_dir', default='/workspace/Dataset/vctk_features')
185 |
186 |
187 | a = parser.parse_args()
188 |
189 | vctk_path = os.path.normpath(a.input_dir)
190 | target_dir = os.path.normpath(a.output_dir)
191 | main(vctk_path, target_dir)
--------------------------------------------------------------------------------
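Each utterance processed by this script ends up as a `.pt` file in the train/val/test directories, holding the dictionary assembled in `process_directory`. A short sketch for inspecting one of them, with a placeholder path under the chosen `--output_dir`:

    import torch

    # Placeholder path: substitute any utterance produced under --output_dir.
    features = torch.load('vctk_features/train/p226_001.pt')

    # Keys written by process_directory; "Pitch" holds log-scale pitch values.
    for key in ('Formants', 'Energy', 'Centroid', 'Tilt', 'Pitch', 'Voicing', 'R_Coeff'):
        print(key, tuple(features[key].shape))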
/tests/test_dataset.py:
--------------------------------------------------------------------------------
1 | from neural_formant_synthesis.dataset import FeatureDataset
--------------------------------------------------------------------------------
/tests/test_train_imports.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.simplefilter(action='ignore', category=FutureWarning)
3 | import itertools
4 | import os
5 | import time
6 | import argparse
7 |
8 | import json
9 | import torch
10 | import torch.nn.functional as F
11 | from torch.utils.tensorboard import SummaryWriter
12 | from torch.utils.data import DistributedSampler, DataLoader
13 | import torch.multiprocessing as mp
14 | from torch.distributed import init_process_group
15 | from torch.nn.parallel import DistributedDataParallel
16 | from neural_formant_synthesis.third_party.hifi_gan.env import AttrDict, build_env
17 | from neural_formant_synthesis.third_party.hifi_gan.meldataset import mel_spectrogram
18 | from neural_formant_synthesis.third_party.hifi_gan.models import MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_adversarial_loss,\
19 | discriminator_loss
20 | from neural_formant_synthesis.third_party.hifi_gan.utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
21 |
22 |
23 | from neural_formant_synthesis.glotnet.sigproc.lpc import LinearPredictor
24 | from neural_formant_synthesis.glotnet.sigproc.emphasis import Emphasis
25 |
26 | from neural_formant_synthesis.dataset import FeatureDataset_List
27 | from neural_formant_synthesis.models import FM_Hifi_Generator, fm_config_obj, Envelope_wavenet, Envelope_conformer
28 | from neural_formant_synthesis.models import SourceFilterFormantSynthesisGenerator
29 |
30 | from neural_formant_synthesis.glotnet.sigproc.levinson import forward_levinson
31 |
32 | import torchaudio as ta
33 |
--------------------------------------------------------------------------------
/tests/test_wavenet.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ljuvela/SourceFilterNeuralFormants/d0894c2aa153510e6967c3b62c47f73ce4cb3879/tests/test_wavenet.py
--------------------------------------------------------------------------------
/train_e2e_hifigan.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.simplefilter(action='ignore', category=FutureWarning)
3 | import itertools
4 | import os
5 | import time
6 | import argparse
7 |
8 | import json
9 | import torch
10 | import torch.nn.functional as F
11 | from torch.utils.tensorboard import SummaryWriter
12 | from torch.utils.data import DistributedSampler, DataLoader
13 | import torch.multiprocessing as mp
14 | from torch.distributed import init_process_group
15 | from torch.nn.parallel import DistributedDataParallel
16 | from neural_formant_synthesis.third_party.hifi_gan.env import AttrDict, build_env
17 | from neural_formant_synthesis.third_party.hifi_gan.meldataset import mel_spectrogram
18 | from neural_formant_synthesis.third_party.hifi_gan.models import MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_adversarial_loss,\
19 | discriminator_loss
20 | #from neural_formant_synthesis.third_party.hifi_gan.models import discriminator_metrics
21 | from neural_formant_synthesis.third_party.hifi_gan.utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
22 |
23 |
24 | from neural_formant_synthesis.glotnet.sigproc.lpc import LinearPredictor
25 | from neural_formant_synthesis.glotnet.sigproc.emphasis import Emphasis
26 |
27 | from neural_formant_synthesis.dataset import FeatureDataset_List
28 | from neural_formant_synthesis.models import FM_Hifi_Generator, fm_config_obj, Envelope_wavenet, Envelope_conformer
29 | from neural_formant_synthesis.models import NeuralFormantSynthesisGenerator
30 |
31 |
32 | from neural_formant_synthesis.glotnet.sigproc.levinson import forward_levinson
33 |
34 | import torchaudio as ta
35 |
36 |
37 | torch.backends.cudnn.benchmark = True
38 |
39 |
40 | def train(rank, a, h, fm_h):
41 |
42 | if h.num_gpus > 1:
43 | init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
44 | world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
45 |
46 | torch.cuda.manual_seed(h.seed)
47 | if torch.cuda.is_available():
48 | device = torch.device('cuda:{:d}'.format(rank))
49 | else:
50 | device = torch.device('cpu')
51 |
52 | # HiFi generator included in feature mapping class.
53 | pretrained_fm = getattr(fm_h, 'model_path', None)
54 | generator = NeuralFormantSynthesisGenerator(fm_config = fm_h, g_config = h,
55 | pretrained_fm = pretrained_fm,
56 | freeze_fm = pretrained_fm is not None,
57 | device = device)
58 |
59 | def count_parameters(model):
60 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
61 |
62 | print(f"Num parameters in HiFi Generator: {count_parameters(generator.hifi_generator)}")
63 | print(f"Num parameters in Feature Mapping model: {count_parameters(generator.feature_mapping)}")
64 |
65 | generator = generator.to(device)
66 | mpd = MultiPeriodDiscriminator().to(device)
67 | msd = MultiScaleDiscriminator().to(device)
68 |
69 | if rank == 0:
70 | print(generator)
71 | os.makedirs(a.checkpoint_path, exist_ok=True)
72 | print("checkpoints directory : ", a.checkpoint_path)
73 |
74 | if os.path.isdir(a.checkpoint_path):
75 | cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
76 | cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
77 |
78 | pretrain_steps = 0 # always reset, never load
79 | steps = 0
80 | if cp_g is None or cp_do is None:
81 | state_dict_do = None
82 | last_epoch = -1
83 | else:
84 | generator.load_generator_e2e_checkpoint(cp_g)
85 | state_dict_do = load_checkpoint(cp_do, device)
86 | mpd.load_state_dict(state_dict_do['mpd'])
87 | msd.load_state_dict(state_dict_do['msd'])
88 | steps = state_dict_do['steps'] + 1
89 | last_epoch = state_dict_do['epoch']
90 |
91 | if h.num_gpus > 1:
92 | generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
93 | mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
94 | msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)
95 |
96 | optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
97 | optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
98 | h.learning_rate, betas=[h.adam_b1, h.adam_b2])
99 |
100 | if state_dict_do is not None:
101 | optim_g.load_state_dict(state_dict_do['optim_g'])
102 | optim_d.load_state_dict(state_dict_do['optim_d'])
103 |
104 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
105 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
106 |
107 | training_path = os.path.join(a.input_wavs_dir, 'train')
108 | trainset = FeatureDataset_List(training_path, h, sampling_rate = h.sampling_rate,
109 | frame_size = h.win_size, hop_size = h.hop_size, shuffle = True, audio_ext = '.flac',
110 | segment_length = 32)
111 |
112 | train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
113 |
114 | train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
115 | sampler=train_sampler,
116 | batch_size=h.batch_size,
117 | pin_memory=True,
118 | drop_last=True)
119 |
120 | if rank == 0:
121 |
122 | valid_path = os.path.join(a.input_wavs_dir, 'val')
123 | validset = FeatureDataset_List(valid_path, h, sampling_rate = h.sampling_rate,
124 | frame_size = h.win_size, hop_size = h.hop_size, audio_ext = '.flac')
125 | validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
126 | sampler=None,
127 | batch_size=1,
128 | pin_memory=True,
129 | drop_last=True)
130 |
131 | sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
132 |
133 |
134 | generator.train()
135 | mpd.train()
136 | msd.train()
137 |
138 | # generator = torch.compile(generator)
139 | # mpd = torch.compile(mpd)
140 | # msd = torch.compile(msd)
141 | for epoch in range(max(0, last_epoch), a.training_epochs):
142 |
143 | if rank == 0:
144 | start = time.time()
145 | print("Epoch: {}".format(epoch+1))
146 |
147 | if h.num_gpus > 1:
148 | train_sampler.set_epoch(epoch)
149 |
150 | for i, batch in enumerate(train_loader):
151 |
152 | if rank == 0:
153 | start_b = time.time()
154 |
155 | #size --> (Batch, features, sequence)
156 | x, _, y, y_mel = batch
157 | x = torch.autograd.Variable(x.to(device, non_blocking=True))
158 | y = torch.autograd.Variable(y.to(device, non_blocking=True))
159 | y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
160 | y = y.unsqueeze(1)
161 |
162 | x_feat = x[:,0:9,:]
163 |
164 | # feature_map_only = pretrain_steps < fm_h.get('pre_train_steps', 0)
165 | feature_map_only = steps < fm_h.get('pre_train_steps', 0)
166 | detach = fm_h.get('detach_feature_map_from_gan', False)
167 | if feature_map_only:
168 | mel_cond = generator.forward(x_feat, feature_map_only, detach)
169 | else:
170 | y_g_hat, mel_cond = generator.forward(x_feat, feature_map_only, detach)
171 |
172 |
173 |
174 | if not feature_map_only:
175 | y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft,
176 | h.num_mels, h.sampling_rate,
177 | h.hop_size, h.win_size,
178 | h.fmin, h.fmax_for_loss)
179 |
180 | optim_d.zero_grad()
181 |
182 | # MPD
183 | y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
184 |
185 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
186 |
187 | # MSD
188 | y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
189 |
190 | loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
191 |
192 | loss_disc_all = loss_disc_s + loss_disc_f
193 |
194 | loss_disc_all.backward()
195 | optim_d.step()
196 |
197 | # Generator
198 | optim_g.zero_grad()
199 | loss_gen_all = 0.0
200 |
201 | if not feature_map_only:
202 | # L1 Mel-Spectrogram Loss
203 | loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
204 |
205 | # Adversarial losses
206 | y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
207 | y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
208 | loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
209 | loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
210 | loss_gen_f, losses_gen_f = generator_adversarial_loss(y_df_hat_g)
211 | loss_gen_s, losses_gen_s = generator_adversarial_loss(y_ds_hat_g)
212 | loss_gen_all = loss_gen_all + loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
213 |
214 | loss_mel_fm = F.l1_loss(y_mel, mel_cond)
215 |
216 | loss_gen_all = loss_gen_all + loss_mel_fm
217 |
218 | loss_gen_all.backward()
219 | optim_g.step()
220 |
221 | if not torch.isfinite(loss_gen_all):
222 | raise ValueError(f"Loss value is not finite, was {loss_gen_all}")
223 |
224 | if rank == 0:
225 | # STDOUT logging
226 | if steps % a.stdout_interval == 0:
227 |
228 | if feature_map_only:
229 | print('Steps : {:d}, Gen Loss Total : {:4.3f}, s/b : {:4.3f}'.
230 | format(steps, loss_gen_all, time.time() - start_b))
231 | else:
232 | with torch.no_grad():
233 | mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()
234 | print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
235 | format(steps, loss_gen_all, mel_error, time.time() - start_b))
236 |
237 | # checkpointing
238 | if steps % a.checkpoint_interval == 0 and steps != 0:
239 | checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
240 | save_checkpoint(checkpoint_path,
241 | {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
242 | checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
243 | save_checkpoint(checkpoint_path,
244 | {'mpd': (mpd.module if h.num_gpus > 1
245 | else mpd).state_dict(),
246 | 'msd': (msd.module if h.num_gpus > 1
247 | else msd).state_dict(),
248 | 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
249 | 'epoch': epoch})
250 |
251 | # Tensorboard summary logging
252 | if steps % a.summary_interval == 0:
253 | sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
254 |
255 | sw.add_scalar("training/feature_mapping_mel_loss", loss_mel_fm, steps)
256 |
257 | if not feature_map_only:
258 | sw.add_scalar("training/mel_spec_error", mel_error, steps)
259 |
260 |                         # Multi-period discriminator (MPD) losses
261 | sw.add_scalar("training_gan/disc_f_r", sum(losses_disc_f_r), steps)
262 | sw.add_scalar("training_gan/disc_f_g", sum(losses_disc_f_g), steps)
263 | # Multiscale Discriminator losses
264 | sw.add_scalar("training_gan/disc_s_r", sum(losses_disc_s_r), steps)
265 | sw.add_scalar("training_gan/disc_s_g", sum(losses_disc_s_g), steps)
266 |                         # Multi-period generator losses
267 | sw.add_scalar("training_gan/gen_f", sum(losses_gen_f), steps)
268 | # Multiscale Generator losses
269 | sw.add_scalar("training_gan/gen_s", sum(losses_gen_s), steps)
270 | # Feature Matching losses
271 | sw.add_scalar("training_gan/loss_fm_f", loss_fm_f, steps)
272 | sw.add_scalar("training_gan/loss_fm_s", loss_fm_s, steps)
273 |
274 |
275 | # Validation
276 | if steps % a.validation_interval == 0: #and steps != 0:
277 |
278 | print(f"Validation at step {steps}")
279 | generator.eval()
280 | torch.cuda.empty_cache()
281 | val_err_tot = 0
282 | max_valid_batches = 100
283 | with torch.no_grad():
284 | for j, batch in enumerate(validation_loader):
285 |
286 | if j > max_valid_batches:
287 | break
288 |
289 | x, _, y, y_mel = batch
290 |
291 | x = x.to(device)
292 |
293 | x_feat = x[:,0:9,:]
294 |
295 | y_g_hat, mel_cond = generator(x_feat)
296 |
297 | y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
298 | y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
299 | h.hop_size, h.win_size,
300 | h.fmin, h.fmax_for_loss)
301 |
302 | val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
303 |
304 | # TODO: calculate discriminator EER
305 |
306 | if j <= 4:
307 | if steps == 0:
308 | sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
309 | sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(y_mel[0].cpu()), steps)
310 |
311 | sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
312 | y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
313 | h.sampling_rate, h.hop_size, h.win_size,
314 | h.fmin, h.fmax)
315 | sw.add_figure('generated/y_hat_spec_{}'.format(j),
316 | plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)
317 |
318 | val_err = val_err_tot / (j+1)
319 | sw.add_scalar("validation/mel_spec_error", val_err, steps)
320 |
321 | generator.train()
322 |
323 | steps += 1
324 | pretrain_steps += 1
325 |
326 | scheduler_g.step()
327 | scheduler_d.step()
328 |
329 | if rank == 0:
330 | print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
331 |
332 |
333 | def main():
334 | print('Initializing Training Process..')
335 |
336 | parser = argparse.ArgumentParser()
337 |
338 | parser.add_argument('--group_name', default=None)
339 | parser.add_argument('--input_wavs_dir', default='LJSpeech-1.1/wavs')
340 | parser.add_argument('--input_mels_dir', default='ft_dataset')
341 | parser.add_argument('--input_training_file', default='LJSpeech-1.1/training.txt')
342 | parser.add_argument('--input_validation_file', default='LJSpeech-1.1/validation.txt')
343 | parser.add_argument('--checkpoint_path', default='cp_hifigan')
344 | parser.add_argument('--config', default='')
345 | parser.add_argument('--fm_config', default='')
346 | parser.add_argument('--env_config', default='')
347 | parser.add_argument('--training_epochs', default=3100, type=int)
348 | parser.add_argument('--stdout_interval', default=5, type=int)
349 | parser.add_argument('--checkpoint_interval', default=5000, type=int)
350 | parser.add_argument('--summary_interval', default=100, type=int)
351 | parser.add_argument('--validation_interval', default=1000, type=int)
352 | parser.add_argument('--fine_tuning', default=False, type=bool)
353 | parser.add_argument('--wavefile_ext', default='.wav', type=str)
354 | parser.add_argument('--envelope_model_pretrained', default=False, type=bool)
355 | parser.add_argument('--envelope_model_freeze', default=False, type=bool)
356 |
357 | a = parser.parse_args()
358 |
359 | with open(a.config) as f:
360 | data = f.read()
361 |
362 | json_config = json.loads(data)
363 | h = AttrDict(json_config)
364 |
365 | with open(a.fm_config) as f:
366 | data = f.read()
367 | json_fm_config = json.loads(data)
368 | fm_h = AttrDict(json_fm_config)
369 |
370 |
371 | build_env(a.config, 'config.json', a.checkpoint_path)
372 | # TODO: copy configs for feature mapping and envelope models!
373 |
374 | torch.manual_seed(h.seed)
375 | if torch.cuda.is_available():
376 | torch.cuda.manual_seed(h.seed)
377 | # Skip multi-gpu implementation until supported by all the models.
378 | h.num_gpus = 1 #torch.cuda.device_count()
379 | h.batch_size = int(h.batch_size / h.num_gpus)
380 | print('Batch size per GPU :', h.batch_size)
381 | else:
382 | pass
383 |
384 | if h.num_gpus > 1:
385 | mp.spawn(train, nprocs=h.num_gpus, args=(a, h, fm_h))
386 | else:
387 | train(0, a, h, fm_h)
388 |
389 |
390 | if __name__ == '__main__':
391 | main()
392 |
--------------------------------------------------------------------------------
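The `g_XXXXXXXX` checkpoints written by this script hold the full `NeuralFormantSynthesisGenerator` state dict under the `'generator'` key. A hedged sketch of restoring one for inference, mirroring the constructor call in `train()` above; the config paths and checkpoint directory are placeholders, and `inference_hifigan.py` remains the repository's reference entry point:

    import json
    import torch
    from neural_formant_synthesis.third_party.hifi_gan.env import AttrDict
    from neural_formant_synthesis.third_party.hifi_gan.utils import scan_checkpoint, load_checkpoint
    from neural_formant_synthesis.models import NeuralFormantSynthesisGenerator

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Placeholder config paths: the HiFi-GAN config (--config) and the
    # feature-mapping config (--fm_config) used when training.
    with open('config/config_v3.json') as f:
        h = AttrDict(json.load(f))
    with open('fm_config.json') as f:
        fm_h = AttrDict(json.load(f))

    generator = NeuralFormantSynthesisGenerator(fm_config=fm_h, g_config=h,
                                                pretrained_fm=None, freeze_fm=False,
                                                device=device).to(device)

    # Assumes at least one g_* checkpoint exists in the (placeholder) directory.
    cp_g = scan_checkpoint('cp_hifigan', 'g_')
    state_dict_g = load_checkpoint(cp_g, device)
    generator.load_state_dict(state_dict_g['generator'])
    generator.eval()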