├── .gitignore ├── .pylintrc ├── LICENSE ├── README.md ├── environment.yaml ├── litfass ├── __init__.py ├── data │ └── wada_values.npy ├── dataset │ ├── __init__.py │ ├── audio_utils.py │ ├── cwt.py │ ├── datasets.py │ ├── metrics.py │ └── snr.py ├── fastspeech2 │ ├── __init__.py │ ├── fastdiff_variances.py │ ├── fastspeech2.py │ ├── log_gmm.py │ ├── loss.py │ ├── model.py │ └── noam.py ├── generate.py ├── plot.py ├── synthesis │ ├── __init__.py │ ├── g2p.py │ └── generator.py ├── third_party │ ├── __init__.py │ ├── argutils │ │ └── __init__.py │ ├── dvectors │ │ ├── dvector.pt │ │ └── wav2mel.py │ ├── fastdiff │ │ ├── FastDiff.py │ │ └── module │ │ │ ├── modules.py │ │ │ └── util.py │ ├── hifigan │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── config.json │ │ ├── generator_LJSpeech.pth.tar │ │ ├── generator_universal.pth.tar │ │ └── models.py │ ├── softdtw │ │ └── __init__.py │ └── stochastic_duration_predictor │ │ ├── normalization.py │ │ ├── sdp.py │ │ └── transforms.py └── train.py ├── pyproject.toml └── scripts ├── generate.sh ├── generate_ab_train_splits.py └── train.sh /.gitignore: -------------------------------------------------------------------------------- 1 | .cache 2 | __pycache__ 3 | variance_encoders 4 | wandb 5 | lightning_logs 6 | LightningFastSpeech 7 | notebooks 8 | logs 9 | None 10 | .vscode 11 | models 12 | __pypackages__ 13 | .pdm.toml 14 | *.log 15 | run_longjob.sh 16 | *.lprof 17 | examples/*.png 18 | !src/data/* 19 | sampling_values 20 | *.pkl 21 | speechbrain 22 | .ipynb_checkpoints 23 | examples 24 | fastdiff_model 25 | *.lock -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [TYPECHECK] 2 | # List of members which are set dynamically and missed by Pylint inference 3 | # system, and so shouldn't trigger E1101 when accessed. 4 | generated-members=numpy.*, torch.* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Christoph Minixhofer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LightningFastSpeech 2 | 3 | **WARNING: This is a work in progress and until version 0.1 (which will be out very soon), it might be hard to get running on your own machine. Thanks for your patience.** 4 | 5 | ## Large Pretrained TTS 6 | 7 | In the NLP community, and more recently in speech recognition, large pre-trained models and how they can be used for downstream tasks have become an exciting area of research. 8 | 9 | In TTS, however, little similar work exists. With this project, I hope to take a first step towards bringing pretrained models to TTS. 10 | The original FastSpeech 2 model has 27M parameters and models a single speaker; our version models more than 2,000 speakers and would have almost 1B parameters without the improvements from LightSpeech, which bring its size down to a manageable 76M. 11 | 12 | A big upside of this implementation is that it is based on [PyTorch Lightning](https://www.pytorchlightning.ai/), which makes it easy to do multi-GPU training, load pre-trained models and a lot more. 13 | 14 | LightningFastSpeech couldn't exist without the amazing open source work of many others; for a full list see [Attribution](#attribution). 15 | 16 | ## Current Status 17 | 18 | This library is a work in progress, and until v1.0, updates might break things occasionally. 19 | 20 | # Goals 21 | 22 | ## v0.1 23 | 24 | **0.1** is right around the corner! For this version, the core functionality is already there, and what's missing are mostly quality-of-life improvements that we should get out of the way now. 25 | 26 | - [x] Replicate original FastSpeech 2 architecture 27 | - [x] Include depth-wise separable convolutions found in LightSpeech 28 | - [x] Dataloader which computes prosody features online 29 | - [x] Synthesis of both individual utterances and whole datasets 30 | - [x] Configurable training script. 31 | - [ ] Configurable synthesis script. 32 | - [ ] First large pre-trained model (LibriTTS, 2k speakers, 76M parameters). 33 | - [ ] Documentation & tutorials. 34 | - [ ] Allow reporting other than wandb. 35 | - [ ] Configurable metrics. 36 | - [ ] LJSpeech support. 37 | - [ ] PyPI package. 38 | - [ ] HiFi-GAN fine-tuning (during and after training) 39 | 40 | ## v1.0 41 | 42 | It will take a while to get to 1.0 -- the goal is to allow everyone to easily fine-tune our models and to easily do controllable synthesis of utterances. 43 | 44 | - [x] Allow models to be loaded from the [Huggingface hub](https://huggingface.co/models). 45 | - [ ] [Streamlit](https://streamlit.io/) interface for synthesising utterances and generating datasets. 46 | - [ ] [Tract](https://github.com/sonos/tract) and [tractjs](https://bminixhofer.github.io/tractjs/) integration to export models for on-device and web use. 47 | - [ ] Make it easy to add new datasets and to fine-tune models with them. 48 | - [ ] Add HiFi-GAN fine-tuning to the pipeline. 49 | - [ ] A range of pre-trained models with different domains and sizes (e.g. multi-lingual, noisy/clean) 50 | 51 | # Attribution 52 | 53 | This would not be possible without a lot of amazing open source projects already present in the TTS space -- please cite their work when appropriate! 54 | 55 | - [Chung-Ming Chien's FastSpeech 2 implementation](https://github.com/ming024/FastSpeech2), which was used as a reference implementation. 
56 | - [yistLin's public d-vector implementation](https://github.com/yistLin/dvector), which is used for multi-speaker training. 57 | - [Aidan Pine's fork of FastSpeech 2](https://github.com/roedoejet/FastSpeech2), which served as the basis for the implementation of the depth-wise convolutions used in LightSpeech. 58 | - [Coqui AI's excellent TTS toolkit](https://github.com/coqui-ai/TTS), which was used for the Stochastic Duration Predictor and inspired the loss weighing we do. 59 | - [Jungil Kong's HiFi-GAN implementation](https://github.com/jik876/hifi-gan), which is used vocoding mel spectrograms produced by our TTS system. 60 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: root 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - asn1crypto=1.0.1=py37_0 9 | - blas=1.0=mkl 10 | - bottleneck=1.3.2=py37heb32a55_1 11 | - brotli=1.0.9=he1b5a44_3 12 | - bzip2=1.0.8=h7b6447c_0 13 | - ca-certificates=2021.10.8=ha878542_0 14 | - certifi=2021.10.8=py37h89c1867_1 15 | - cffi=1.12.3=py37h2e261b9_0 16 | - chardet=3.0.4=py37_1003 17 | - click=7.1.2=pyh9f0ad1d_0 18 | - conda=4.11.0=py37h89c1867_0 19 | - conda-package-handling=1.6.0=py37h7b6447c_0 20 | - configparser=5.2.0=pyhd8ed1ab_0 21 | - cryptography=2.7=py37h1ba5d50_0 22 | - cudatoolkit=10.2.89=hfd86e86_1 23 | - cycler=0.11.0=pyhd8ed1ab_0 24 | - dataclasses=0.8=pyhc8e2a94_3 25 | - dbus=1.13.6=he372182_0 26 | - docker-pycreds=0.4.0=py_0 27 | - expat=2.4.1=h2531618_2 28 | - ffmpeg=4.3=hf484d3e_0 29 | - fontconfig=2.13.1=h86ecdb6_1001 30 | - fonttools=4.25.0=pyhd3eb1b0_0 31 | - freetype=2.11.0=h70c0345_0 32 | - gettext=0.19.8.1=hf34092f_1004 33 | - giflib=5.2.1=h7b6447c_0 34 | - gitdb=4.0.9=pyhd8ed1ab_0 35 | - gitpython=3.1.24=pyhd8ed1ab_0 36 | - glib=2.58.3=py37he00f558_1004 37 | - gmp=6.2.1=h2531618_2 38 | - gnutls=3.6.15=he1e5248_0 39 | - gst-plugins-base=1.14.5=h0935bb2_2 40 | - gstreamer=1.14.5=h36ae1b5_2 41 | - icu=64.2=he1b5a44_1 42 | - idna=2.8=py37_0 43 | - intel-openmp=2021.4.0=h06a4308_3561 44 | - jpeg=9d=h7f8727e_0 45 | - kiwisolver=1.3.1=py37h2531618_0 46 | - lame=3.100=h7b6447c_0 47 | - lcms2=2.12=h3be6417_0 48 | - libedit=3.1.20181209=hc058e9b_0 49 | - libffi=3.2.1=hd88cf55_4 50 | - libgcc-ng=9.1.0=hdf63c60_0 51 | - libgfortran-ng=7.5.0=h14aa051_19 52 | - libgfortran4=7.5.0=h14aa051_19 53 | - libiconv=1.15=h63c8f33_5 54 | - libidn2=2.3.2=h7f8727e_0 55 | - libpng=1.6.37=hbc83047_0 56 | - libprotobuf=3.17.2=h4ff587b_1 57 | - libstdcxx-ng=9.1.0=hdf63c60_0 58 | - libtasn1=4.16.0=h27cfd23_0 59 | - libtiff=4.2.0=h85742a9_0 60 | - libunistring=0.9.10=h27cfd23_0 61 | - libuuid=2.32.1=h14c3975_1000 62 | - libuv=1.40.0=h7b6447c_0 63 | - libwebp=1.2.0=h89dd481_0 64 | - libwebp-base=1.2.0=h27cfd23_0 65 | - libxcb=1.13=h14c3975_1002 66 | - libxml2=2.9.10=hee79883_0 67 | - lz4-c=1.9.3=h295c915_1 68 | - matplotlib=3.5.0=py37h89c1867_0 69 | - matplotlib-base=3.5.0=py37h3ed280b_0 70 | - mkl=2021.4.0=h06a4308_640 71 | - mkl-service=2.4.0=py37h7f8727e_0 72 | - mkl_fft=1.3.1=py37hd3c417c_0 73 | - mkl_random=1.2.2=py37h51133e4_0 74 | - munkres=1.1.4=pyh9f0ad1d_0 75 | - ncurses=6.1=he6710b0_1 76 | - nettle=3.7.3=hbbd107a_1 77 | - numexpr=2.7.3=py37h22e1b3c_1 78 | - numpy=1.21.2=py37h20f2e39_0 79 | - numpy-base=1.21.2=py37h79a1101_0 80 | - olefile=0.46=py37_0 81 | - openh264=2.1.0=hd408876_0 82 | - openssl=1.1.1h=h516909a_0 83 | - 
packaging=21.3=pyhd8ed1ab_0 84 | - pandas=1.3.4=py37h8c16a72_0 85 | - pathtools=0.1.2=py_1 86 | - patsy=0.5.2=pyhd8ed1ab_0 87 | - pcre=8.45=h295c915_0 88 | - pillow=8.4.0=py37h5aabda8_0 89 | - pip=21.0.1=py37h06a4308_0 90 | - plotly=5.4.0=pyhd8ed1ab_0 91 | - promise=2.3=py37h89c1867_5 92 | - protobuf=3.17.2=py37h295c915_0 93 | - psutil=5.8.0=py37h27cfd23_1 94 | - pthread-stubs=0.4=h36c2ea0_1001 95 | - pycosat=0.6.3=py37h14c3975_0 96 | - pycparser=2.19=py37_0 97 | - pyopenssl=19.0.0=py37_0 98 | - pyparsing=3.0.6=pyhd8ed1ab_0 99 | - pyqt=5.9.2=py37hcca6a23_4 100 | - pysocks=1.7.1=py37_0 101 | - python=3.7.4=h265db76_1 102 | - python-dateutil=2.8.2=pyhd3eb1b0_0 103 | - python_abi=3.7=2_cp37m 104 | - pytorch=1.10.0=py3.7_cuda10.2_cudnn7.6.5_0 105 | - pytorch-mutex=1.0=cuda 106 | - pytz=2021.3=pyhd3eb1b0_0 107 | - pyyaml=5.1.2=py37h516909a_0 108 | - qt=5.9.7=h0c104cb_3 109 | - readline=7.0=h7b6447c_5 110 | - requests=2.22.0=py37_0 111 | - ruamel_yaml=0.15.46=py37h14c3975_0 112 | - scipy=1.7.1=py37h292c36d_2 113 | - seaborn=0.11.2=hd8ed1ab_0 114 | - seaborn-base=0.11.2=pyhd8ed1ab_0 115 | - sentry-sdk=1.5.0=pyhd8ed1ab_0 116 | - setuptools=41.4.0=py37_0 117 | - shortuuid=1.0.8=py37h89c1867_0 118 | - sip=4.19.8=py37hf484d3e_0 119 | - six=1.16.0=pyh6c4a22f_0 120 | - smmap=3.0.5=pyh44b312d_0 121 | - sqlite=3.30.0=h7b6447c_0 122 | - statsmodels=0.12.1=py37ha21ca33_1 123 | - subprocess32=3.5.4=py_1 124 | - tenacity=8.0.1=pyhd8ed1ab_0 125 | - termcolor=1.1.0=py_2 126 | - tk=8.6.8=hbc83047_0 127 | - torchaudio=0.10.0=py37_cu102 128 | - torchvision=0.11.1=py37_cu102 129 | - tornado=6.1=py37h4abf009_0 130 | - tqdm=4.61.2=pyhd8ed1ab_1 131 | - typing_extensions=3.10.0.2=pyh06a4308_0 132 | - urllib3=1.24.2=py37_0 133 | - wandb=0.12.7=pyhd8ed1ab_0 134 | - wheel=0.33.6=py37_0 135 | - xorg-libxau=1.0.9=h14c3975_0 136 | - xorg-libxdmcp=1.1.3=h516909a_0 137 | - xz=5.2.5=h7b6447c_0 138 | - yaml=0.1.7=had09818_2 139 | - yaspin=2.1.0=pyhd8ed1ab_0 140 | - zlib=1.2.11=h7b6447c_3 141 | - zstd=1.4.9=haebb681_0 142 | - pip: 143 | - dill==0.3.4 144 | - pandarallel==1.5.4 145 | prefix: /home/cdminix/.pyenv/versions/miniconda3-4.7.12 146 | -------------------------------------------------------------------------------- /litfass/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiniXC/LightningFastSpeech2/23a83937c09d8cb2590d3b27212c7ac497e2a120/litfass/__init__.py -------------------------------------------------------------------------------- /litfass/data/wada_values.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiniXC/LightningFastSpeech2/23a83937c09d8cb2590d3b27212c7ac497e2a120/litfass/data/wada_values.npy -------------------------------------------------------------------------------- /litfass/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiniXC/LightningFastSpeech2/23a83937c09d8cb2590d3b27212c7ac497e2a120/litfass/dataset/__init__.py -------------------------------------------------------------------------------- /litfass/dataset/audio_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | 6 | # clip_val 1e-7, log10=False 7 | # todo: make configureable without discarding cache 8 | def dynamic_range_compression(x, C=1, clip_val=1e-6, log10=True): 9 | if log10: 10 | 
return torch.log10(torch.clamp(x, min=clip_val) * C) 11 | else: 12 | return torch.log(torch.clamp(x, min=clip_val) * C) 13 | 14 | 15 | def dynamic_range_decompression(x, C=1): 16 | return torch.exp(x) / C 17 | 18 | 19 | def smooth(y, box_pts): 20 | box = np.ones(box_pts) / box_pts 21 | y_smooth = np.convolve(y, box, mode="same") 22 | return y_smooth 23 | 24 | 25 | def remove_outliers(values): 26 | values = np.array(values) 27 | p25 = np.percentile(values, 25) 28 | p75 = np.percentile(values, 75) 29 | lower = p25 - 1.5 * (p75 - p25) 30 | upper = p75 + 1.5 * (p75 - p25) 31 | normal_indices = np.logical_and(values > lower, values < upper) 32 | values[~normal_indices] = 0 33 | return values 34 | 35 | 36 | def get_alignment(tier, sampling_rate, hop_length): 37 | sil_phones = ["sil", "sp", "spn", ""] 38 | 39 | phones = [] 40 | durations = [] 41 | start_time = 0 42 | end_time = 0 43 | end_idx = 0 44 | counter = 0 45 | for t in tier._objects: 46 | s, e, p = t.start_time, t.end_time, t.text 47 | 48 | # add silence phone if timestamp gap occurs 49 | if s != end_time and len(phones) > 0: 50 | phones.append("sil") 51 | durations.append( 52 | int( 53 | np.round(s * sampling_rate / hop_length) 54 | - np.round(end_time * sampling_rate / hop_length) 55 | ) 56 | ) 57 | 58 | # Trim leading silences 59 | if phones == []: 60 | if p in sil_phones: 61 | continue 62 | else: 63 | start_time = s 64 | 65 | if p not in sil_phones: 66 | # For ordinary phones 67 | phones.append(p) 68 | end_time = e 69 | end_idx = len(phones) 70 | else: 71 | # For silent phones 72 | phones.append("sil") 73 | end_time = e 74 | 75 | durations.append( 76 | int( 77 | np.round(e * sampling_rate / hop_length) 78 | - np.round(s * sampling_rate / hop_length) 79 | ) 80 | ) 81 | 82 | # Trim tailing silences 83 | phones = phones[:end_idx] 84 | durations = durations[:end_idx] 85 | 86 | true_dur = int(np.ceil(((end_time - start_time) * sampling_rate - 1) / hop_length)) 87 | pred_dur = sum(durations) 88 | if pred_dur != true_dur: 89 | durations[-1] += true_dur - pred_dur 90 | 91 | return phones, durations, start_time, end_time 92 | -------------------------------------------------------------------------------- /litfass/dataset/cwt.py: -------------------------------------------------------------------------------- 1 | from scipy.signal import cwt, ricker 2 | import numpy as np 3 | 4 | # https://www.isca-speech.org/archive_v0/ssw8/papers/ssw8_285.pdf 2.3 5 | # implementation https://github.com/ming024/FastSpeech2/issues/136 6 | 7 | 8 | def wavelet_decomposition(signal, wavelet, n_scales=10, tau=0.2833425): 9 | widths = [2 ** (i + 1) * tau for i in range(1, n_scales + 1)] 10 | cwtmatr = cwt(signal, wavelet, widths) 11 | 12 | constant = [(i + 2.5) ** (-5 / 2) for i in range(1, n_scales + 1)] 13 | constant = np.array(constant)[:, None] 14 | cwtmatr = cwtmatr * constant 15 | return cwtmatr, widths 16 | 17 | 18 | def wavelet_recomposition(wavelet_matrix): 19 | signal = wavelet_matrix.sum(axis=0) 20 | signal = (signal - signal.mean()) / (signal.std() + 1e-7) 21 | return signal 22 | 23 | 24 | class CWT: 25 | def __init__(self, wavelet=ricker, n_scales=10, tau=0.2833425): 26 | self.wavelet = wavelet 27 | self.n_scales = n_scales 28 | self.tau = tau 29 | 30 | def decompose(self, signal): 31 | signal[signal == 0] = 1e-7 32 | original_signal = signal.copy() 33 | signal = np.log(signal) 34 | cwtmatr, widths = wavelet_decomposition( 35 | (signal - signal.mean()) / (signal.std() + 1e-7), 36 | self.wavelet, 37 | self.n_scales, 38 | self.tau, 39 | ) 40 | return { 
41 | "signal": signal, 42 | "original_signal": original_signal, 43 | "spectrogram": cwtmatr.T, 44 | "mean": signal.mean(), 45 | "std": signal.std(), 46 | } 47 | 48 | def recompose(self, spectrogram, mean, std): 49 | signal = wavelet_recomposition(spectrogram) 50 | return signal * std + mean 51 | -------------------------------------------------------------------------------- /litfass/dataset/metrics.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import pyworld as pw 5 | from srmrpy import srmr 6 | import torch 7 | from torchaudio import functional as F 8 | 9 | from litfass.dataset.snr import SNR 10 | 11 | class SpeechMetric(ABC): 12 | @abstractmethod 13 | def __init__(self, sample_rate, win_length, hop_length): 14 | self.sample_rate = sample_rate 15 | self.win_length = win_length 16 | self.hop_length = hop_length 17 | 18 | @abstractmethod 19 | def get_metric(self, audio, silence_mask): 20 | pass 21 | 22 | def get_metric_value(self, audio, silence_mask): 23 | return self.get_metric(audio, silence_mask)[~silence_mask].mean() 24 | 25 | @abstractmethod 26 | def __str__(self): 27 | pass 28 | 29 | def _interpolate(self, x): 30 | def nan_helper(y): 31 | return np.isnan(y), lambda z: z.nonzero()[0] 32 | 33 | nans, y = nan_helper(x) 34 | x[nans] = np.interp(y(nans), y(~nans), x[~nans]) 35 | return x 36 | 37 | class WADA(SpeechMetric): 38 | def __init__(self, sample_rate, win_length, hop_length): 39 | super().__init__(sample_rate, win_length, hop_length) 40 | self.name = "WADA" 41 | 42 | def get_metric(self, audio, silence_mask): 43 | audio = audio.astype(np.float32) 44 | snr = SNR(audio, self.sample_rate) 45 | wada = snr.windowed_wada( 46 | window=self.win_length, 47 | stride=self.hop_length / self.win_length, 48 | use_samples=True, 49 | ) 50 | if len(silence_mask) < len(wada): 51 | wada = wada[:len(silence_mask)] 52 | wada[silence_mask] = np.nan 53 | if all(np.isnan(wada)): 54 | wada = np.zeros_like(wada) 55 | else: 56 | wada = self._interpolate(wada) 57 | return wada 58 | 59 | def get_metric_value(self, audio, silence_mask): 60 | audio = audio.astype(np.float32) 61 | audio = audio[~silence_mask] 62 | # WADA SNR computed over the non-silent samples only 63 | snr = SNR(audio, self.sample_rate) 64 | return snr.wada 65 | 66 | def __str__(self): 67 | return self.name 68 | 69 | class Pitch(SpeechMetric): 70 | def __init__(self, sample_rate, win_length, hop_length): 71 | super().__init__(sample_rate, win_length, hop_length) 72 | self.name = "Pitch" 73 | 74 | def get_metric(self, audio, silence_mask): 75 | pitch, t = pw.dio( 76 | audio.astype(np.float64), 77 | self.sample_rate, 78 | frame_period=self.hop_length / self.sample_rate * 1000, 79 | speed=1, # pyworld's default estimation speed (no downsampling) 80 | ) 81 | pitch = pw.stonemask( 82 | audio.astype(np.float64), pitch, t, self.sample_rate 83 | ).astype(np.float32) 84 | pitch[pitch == 0] = np.nan 85 | if len(silence_mask) < len(pitch): 86 | pitch = pitch[:len(silence_mask)] 87 | pitch[silence_mask] = np.nan 88 | if np.isnan(pitch).all(): 89 | pitch[:] = 1e-7 90 | pitch = self._interpolate(pitch) 91 | return pitch 92 | def __str__(self): 93 | return self.name 94 | 95 | class Energy(SpeechMetric): 96 | def __init__(self, sample_rate, win_length, hop_length): 97 | super().__init__(sample_rate, win_length, hop_length) 98 | self.name = "Energy" 99 | 100 | def get_metric(self, audio, silence_mask): 101 | energy = np.array( 102 | [ 103 | np.sqrt( 104 | np.sum( 105 | ( 106 | audio[ 107 | x * self.hop_length : (x * self.hop_length) 108 | + self.win_length 109 | ] 110 | ** 2 111 | ) 112 | ) 113 | / self.win_length 114 | ) 115 | for x in range(int(np.ceil(len(audio) / self.hop_length))) 116 | ] 117 | ) 118 | if len(silence_mask) < len(energy): 119 | energy = energy[:len(silence_mask)] 120 | return energy 121 | def __str__(self): 122 | return self.name 123 | 124 | class SRMR(SpeechMetric): 125 | def __init__(self, sample_rate, win_length, hop_length): 126 | super().__init__(sample_rate, win_length, hop_length) 127 | self.name = "SRMR" 128 | 129 | def get_metric(self, audio, silence_mask): 130 | if self.sample_rate != 16000: 131 | audio = F.resample( 132 | torch.from_numpy(audio), 133 | self.sample_rate, 134 | 16000, 135 | ).numpy() 136 | 137 | srmr_values = np.array( 138 | [ 139 | srmr( 140 | audio[ 141 | x * self.hop_length : (x * self.hop_length) 142 | + self.win_length 143 | ], 16000 144 | ) 145 | for x in range(int(np.ceil(len(audio) / self.hop_length))) 146 | ] 147 | ) 148 | if len(silence_mask) < len(srmr_values): 149 | srmr_values = srmr_values[:len(silence_mask)] 150 | return srmr_values 151 | def get_metric_value(self, audio, silence_mask): 152 | if self.sample_rate != 16000: 153 | audio = F.resample( 154 | torch.from_numpy(audio), 155 | self.sample_rate, 156 | 16000, 157 | ).numpy() 158 | return srmr(audio, 16000) 159 | 160 | def __str__(self): 161 | return self.name -------------------------------------------------------------------------------- /litfass/dataset/snr.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Iterable, Union 2 | import json 3 | import librosa 4 | import numpy as np 5 | import soundfile as sf 6 | from textgrid import TextGrid 7 | from pathlib import Path 8 | class SNR: 9 | def __init__( 10 | self, 11 | values: Iterable[float], 12 | rate: int, 13 | rms_window: Optional[int] = None, 14 | rms_stride: Optional[float] = 0.5, 15 | vad: Optional[Union[str, Iterable[Tuple[float, float]]]] = None, 16 | ): 17 | """ 18 | Creates a new ``SNR`` object tied to a specific audio array (given as ``values``) and sampling rate (as ``rate``). 19 | RMS window and stride can be set on initialisation using ``rms_window`` and ``rms_stride``, or later using the ``SNR.rms`` endpoint. 20 | When setting RMS on initialisation, this can be reverted to sample-based using the ``SNR.samples`` endpoint. 21 | A vad file in ``.json`` or ``.TextGrid`` format or a tuple array of the format ``[[word_start_in_seconds: float, word_duration_in_seconds: float],...]`` can be given as well to enable windowed measures on voiced parts only or to use ``SNR.vad_ratio``. 22 | """ 23 | self._values = values 24 | self.rate = rate 25 | self._rms = rms_window 26 | self._rms_stride = rms_stride 27 | if isinstance(vad, str): 28 | vad = SNR.load_vad(vad) 29 | self.vad = vad 30 | 31 | def __iter__(self): 32 | return self._values.__iter__() 33 | 34 | def __getitem__(self, key): 35 | return SNR(self._values[key], self.rate, self._rms, self._rms_stride, self.vad) 36 | 37 | def __len__(self): 38 | return len(self._values) 39 | 40 | @property 41 | def duration(self) -> float: 42 | """ 43 | The duration in fractional seconds. 44 | """ 45 | return len(self) / self.rate 46 | 47 | def seconds(self, start: float, end: float) -> "SNR": 48 | """ 49 | Returns a new ``SNR`` object which only spans from second ``start`` to second ``end``. 
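For example, ``.seconds(0.5, 1.5)`` returns the audio between 0.5 s and 1.5 s; the slice indices are computed as ``int(rate * start)`` and ``int(rate * end)``.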
50 | """ 51 | return self[int(self.rate * start) : int(self.rate * end)] 52 | 53 | def rms(self, rms_window: int = 20, rms_stride: float = 0.5): 54 | """ 55 | Returns a new ``SNR`` object with rms-based audio. 56 | ``rms_window`` and ``rms_stride`` determine the window width and stride as a percentage of the width, respectively. 57 | """ 58 | return SNR(self._values, self.rate, rms_window, rms_stride, self.vad) 59 | 60 | @staticmethod 61 | def load_vad(file_path: str) -> Iterable[Tuple[float, float]]: 62 | """ 63 | Loads voice activity detection intervals for words from either a ``.TextGrid`` or ``.json`` file. 64 | The resulting tuple array can be added to an ``SNR`` object using ``SNR.add_vad`` or be passed on initialisation of a new SNR class. 65 | """ 66 | if ".wav" in file_path: 67 | file_path = file_path.replace(".wav", ".json") 68 | result = [] 69 | if ".textgrid" in file_path.lower(): 70 | tg = TextGrid.fromFile(file_path) 71 | for t in tg: 72 | if t.name == "words": 73 | for w in t: 74 | if len(w.mark) > 0: 75 | result.append([w.minTime, w.maxTime - w.minTime]) 76 | elif ".json" in file_path.lower(): 77 | result = [ 78 | [w["startTime"], w["duration"]] 79 | for w in json.load(open(file_path))["asrResult"]["words"] 80 | ] 81 | return result 82 | 83 | @property 84 | def samples(self) -> "SNR": 85 | """ 86 | Returns a new ``SNR`` object with sample-based audio. 87 | """ 88 | return SNR(self._values, self.rate, None, self._rms_stride, self.vad) 89 | 90 | @property 91 | def notebook(self) -> "SNRNotebook": 92 | """ 93 | The ``SNRNotebook`` endpoint for Jupyter Notebook visualisation. 94 | """ 95 | return SNRNotebook(self) 96 | 97 | @classmethod 98 | def from_file(cls, file_path: str, **kwargs) -> "SNR": 99 | """ 100 | Uses ``librosa.load`` to create a ``SNR`` object from an audio file path (``file_path``). 101 | Keyword arguments for ``SNR`` can be passed as well. 102 | """ 103 | data, rate = librosa.load(file_path) 104 | snr = cls(data, rate, **kwargs) 105 | return snr 106 | 107 | def to_file(self, file_path: str): 108 | """ 109 | Writes the audio to the given file path. 110 | """ 111 | Path("/".join(file_path.split("/")[:-1])).mkdir(parents=True, exist_ok=True) 112 | sf.write(file_path, self.samples.values, self.rate) 113 | 114 | def add_vad(self, vad: Union[str, Iterable[Tuple[float, float]]]): 115 | """ 116 | Adds the given VAD intervals to this ``SNR`` object. 117 | Intervals can be given as ``[[word_start_in_seconds: float, word_duration_in_seconds: float],...]``. 118 | If ``vad`` is a string, this will fall back on ``SNR.load_vad``. 119 | """ 120 | if isinstance(vad, str): 121 | vad = SNR.load_vad(vad) 122 | self.vad = vad 123 | return self 124 | 125 | @property 126 | def values(self): 127 | """ 128 | Returns the audio values. If set to ``SNR.rms`` these will be rms values, if set to ``SNR.samples``, the original audio array will be returned. 129 | """ 130 | if self._rms is None: 131 | return self._values 132 | else: 133 | rms_arr = [] 134 | step = int(self.rate * (self._rms / 1000)) 135 | for index in np.arange(0, len(self._values), int(step * self._rms_stride)): 136 | window_values = self._values[index : index + step] 137 | rms_arr.append(np.sum(window_values**2) / len(window_values)) 138 | return np.array(rms_arr) 139 | 140 | @property 141 | def power(self) -> float: 142 | """ 143 | The RMS power in dB. 
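Computed as ``20 * log10(sqrt(mean(values ** 2)))`` over the current values (sample-based or RMS-based, depending on the active endpoint).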
144 | """ 145 | return 20 * np.log10(np.sqrt(np.sum(self.values**2) / len(self.values))) 146 | 147 | @staticmethod 148 | def normalize(values: Iterable[float]) -> Iterable[float]: 149 | """ 150 | Returns a normalized version of the given float array. 151 | """ 152 | a = np.sqrt(len(values) / np.sum(values**2)) 153 | return values * a 154 | 155 | def get_augmented(self, noise: "SNR", snr: int = 0): 156 | """ 157 | Combines this ``SNR`` object with the given ``noise`` object at the desired ``snr`` and returns the new noisy ``SNR`` object. 158 | """ 159 | # get the most noisy area of the noise 160 | n, s = noise, self 161 | ns_diff = len(n) - len(s) 162 | # edge cases 163 | if n._rms != s._rms: 164 | raise ValueError("Noise and signal rms window must match.") 165 | if n.rate != s.rate: 166 | raise ValueError("Noise and signal rates must match.") 167 | if ns_diff < 0: 168 | raise ValueError("Noise shorter than signal, use longer noise.") 169 | if ns_diff == 0 and np.allclose(n.values, s.values): 170 | raise ValueError("Noise identical to signal.") 171 | rms_l = [] 172 | # look at 100 different evenly spaced points in the audio 173 | for i in range(100): 174 | start = ns_diff // 100 * i 175 | rms_l.append((start, n[start : start + len(s)].power)) 176 | rms_l = np.array(rms_l) 177 | # take the noise segment with maximum power 178 | start = int(rms_l[rms_l[:, 1].argmax(), 0]) 179 | n = n[start : start + len(s)] 180 | # normalize 181 | std = s.values.std() 182 | n_audio = SNR.normalize(s._values) 183 | n_noise = SNR.normalize(n._values) 184 | # actual SNR computation 185 | factor = 10 ** (-snr / 20) 186 | return SNR( 187 | ((n_audio * std) + (n_noise * std * factor)), 188 | self.rate, 189 | self._rms, 190 | self._rms_stride, 191 | self.vad, 192 | ) 193 | 194 | def _windowed_measure(self, measure, window, stride, use_vad, use_samples): 195 | windows = self.get_windows( 196 | window, stride, return_slices=True, use_samples=use_samples 197 | ) 198 | index_arr = [] 199 | value_arr = [] 200 | for index_slice in windows: 201 | if use_vad: 202 | start_in = any( 203 | [ 204 | (v[0] <= index_slice.start / self.rate <= v[0] + v[1]) 205 | for v in self.vad 206 | ] 207 | ) 208 | stop_in = any( 209 | [ 210 | (v[0] <= index_slice.stop / self.rate <= v[0] + v[1]) 211 | for v in self.vad 212 | ] 213 | ) 214 | if not (start_in or stop_in): 215 | continue 216 | value_arr.append(getattr(self[index_slice], measure)) 217 | index_arr.append(index_slice) 218 | return np.array(index_arr), np.array(value_arr) 219 | 220 | def get_windows( 221 | self, 222 | window: int = 100, 223 | stride: float = 0.5, 224 | return_slices: bool = False, 225 | use_samples: bool = False, 226 | ): 227 | """ 228 | Used to get the windowed values of this ``SNR`` object. 229 | If ``return_slices`` is set to ``True``, slices which can be used to index an SNR object are returned (for example, use ``SNR[SNR.get_windows(return_slices=True)[0]]`` to get the first window). 230 | If ``return_slices`` is set to ``False``, the values present in each window are returned instead. 231 | ``window`` and ``stride`` determine the window width and stride as a percentage of the width, respectively. 
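If ``use_samples`` is ``True``, ``window`` is interpreted directly as a number of samples; otherwise it is given in milliseconds and converted using the sampling rate.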
232 | """ 233 | index_arr = [] 234 | if use_samples: 235 | step = window 236 | else: 237 | step = int(self.rate * (window / 1000)) 238 | for index in np.arange( 239 | 0, int(np.ceil(len(self._values) / step) * step), int(step * stride) 240 | ): 241 | if index > len(self._values) - 1: 242 | break 243 | index_slice = slice(index, min(index + step, len(self._values))) 244 | if return_slices: 245 | index_arr.append(index_slice) 246 | else: 247 | index_arr.append(self[index_slice]) 248 | if return_slices: 249 | return np.array(index_arr) 250 | else: 251 | return index_arr 252 | 253 | @property 254 | def wada(self): 255 | """ 256 | Over the entire audio: Return the wada measure as defined in http://www.cs.cmu.edu/~robust/Papers/KimSternIS08.pdf using open-source code provided here: https://gist.github.com/johnmeade/d8d2c67b87cda95cd253f55c21387e75#file-snr-py-L7 257 | """ 258 | return _wada(self.values) 259 | 260 | def windowed_wada(self, window, stride=0.5, use_vad=False, use_samples=False): 261 | """ 262 | ``window`` and ``stride`` determine the window width and stride as a percentage of the width, respectively. 263 | """ 264 | value_arr = [] 265 | result = self._windowed_measure("wada", window, stride, use_vad, use_samples) 266 | for i, v in zip(*result): 267 | if v > -20 and v < 100: 268 | value_arr.append(v + 20) 269 | else: 270 | value_arr.append(np.nan) 271 | return np.array(value_arr) 272 | 273 | @property 274 | def r(self): 275 | """ 276 | Over the entire audio: Returns the r measure defined as the log10 of the ratio of the 95th and 5th percentile after taking the absolute value of each sample or RMS and adding a floor at 10e-10. 277 | """ 278 | return _r(self.values) 279 | 280 | def windowed_r(self, window, stride=0.5, use_vad=False): 281 | """ 282 | ``window`` and ``stride`` determine the window width and stride as a percentage of the width, respectively. 283 | """ 284 | index_arr = [] 285 | value_arr = [] 286 | result = self._windowed_measure("r", window, stride, use_vad) 287 | for i, v in zip(*result): 288 | if v > 0: 289 | index_arr.append(i) 290 | value_arr.append(v) 291 | return np.array(index_arr), np.array(value_arr) 292 | 293 | def vad_ratio(self, padding: int = 10): 294 | """ 295 | Over the entire audio: Calculate the ratio of the mean power in voice vs. unvoiced regions. This can be infinity when the power in unvoiced regions is zero. 296 | ``padding`` (given in milliseconds) can make the voice regions smaller when positive, or larger when negative. 
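The returned value is the duration-weighted mean power (in dB) of the voiced regions minus that of the unvoiced regions.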
297 | """ 298 | v_factors = [] 299 | v_powers = [] 300 | s_factors = [] 301 | s_powers = [] 302 | last_i = 0 303 | for v in self.vad: 304 | v0 = v[0] - padding / 1000 305 | v1 = v[1] - padding / 1000 306 | if v0 - last_i > 0: 307 | selection = self.seconds(last_i, v0) 308 | if len(selection.values) > 0: 309 | s_factors.append(v0 - last_i) 310 | s_powers.append(selection.power) 311 | selection = self.seconds(v0, v0 + v1) 312 | if len(selection.values) > 0: 313 | v_factors.append(v1) 314 | v_powers.append(selection.power) 315 | last_i = v0 + v1 316 | v_factors, s_factors = np.array(v_factors), np.array(s_factors) 317 | v_powers, s_powers = np.array(v_powers), np.array(s_powers) 318 | v_factors /= v_factors.sum() 319 | s_factors /= s_factors.sum() 320 | s_result = (s_powers * s_factors).sum() 321 | v_result = (v_powers * v_factors).sum() 322 | return v_result - s_result 323 | 324 | 325 | g_vals = np.load(Path(__file__).parent.parent / "data" / "wada_values.npy") 326 | 327 | 328 | def _wada(wav): 329 | global g_vals 330 | # Direct blind estimation of the SNR of a speech signal. 331 | # 332 | # Paper on WADA SNR: 333 | # http://www.cs.cmu.edu/~robust/Papers/KimSternIS08.pdf 334 | # 335 | # This function was adapted from this matlab code: 336 | # https://labrosa.ee.columbia.edu/projects/snreval/#9 337 | # 338 | # MIT license, John Meade, 2020 339 | # init 340 | eps = 1e-20 341 | # next 2 lines define a fancy curve derived from a gamma distribution -- see paper 342 | db_vals = np.arange(-20, 101) 343 | # peak normalize, get magnitude, clip lower bound 344 | abs_wav = np.abs(wav) 345 | if np.sum(abs_wav) == 0: 346 | return np.nan 347 | abs_wav[abs_wav < eps] = eps 348 | # calcuate statistics 349 | v1 = max(eps, abs_wav.mean()) 350 | v2 = np.log(abs_wav).mean() 351 | v3 = np.log(v1) - v2 352 | # table interpolation 353 | wav_snr_idx = None 354 | if any(g_vals < v3): 355 | wav_snr_idx = np.where(g_vals < v3)[0].max() 356 | # handle edge cases or interpolate 357 | if wav_snr_idx is None: 358 | wav_snr = db_vals[0] 359 | elif wav_snr_idx == len(db_vals) - 1: 360 | wav_snr = db_vals[-1] 361 | else: 362 | wav_snr = db_vals[wav_snr_idx] + (v3 - g_vals[wav_snr_idx]) / ( 363 | g_vals[wav_snr_idx + 1] - g_vals[wav_snr_idx] 364 | ) * (db_vals[wav_snr_idx + 1] - db_vals[wav_snr_idx]) 365 | # Calculate SNR 366 | dEng = sum(wav**2) 367 | dFactor = 10 ** (wav_snr / 10) 368 | dNoiseEng = dEng / (1 + dFactor) # Noise energy 369 | dSigEng = dEng * dFactor / (1 + dFactor) # Signal energy 370 | snr = 10 * np.log10(dSigEng / dNoiseEng) 371 | return snr 372 | -------------------------------------------------------------------------------- /litfass/fastspeech2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiniXC/LightningFastSpeech2/23a83937c09d8cb2590d3b27212c7ac497e2a120/litfass/fastspeech2/__init__.py -------------------------------------------------------------------------------- /litfass/fastspeech2/fastdiff_variances.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from litfass.fastspeech2.model import VarianceConvolutionLayer, LengthRegulator 4 | from litfass.third_party.fastdiff.module.util import calc_diffusion_step_embedding, std_normal, compute_hyperparams_given_schedule, sampling_given_noise_schedule 5 | from litfass.third_party.fastdiff.FastDiff import swish 6 | 7 | 8 | class FastDiffVarianceAdaptor(nn.Module): 9 | """FastSpeech2 Variance Adaptor 
with FastDiff diffusion. Limited to 1d, frame-level predictions.""" 10 | def __init__( 11 | self, 12 | stats, 13 | variances, 14 | variance_nlayers, 15 | variance_kernel_size, 16 | variance_dropout, 17 | variance_filter_size, 18 | variance_nbins, 19 | variance_depthwise_conv, 20 | duration_nlayers, 21 | duration_kernel_size, 22 | duration_dropout, 23 | duration_filter_size, 24 | duration_depthwise_conv, 25 | encoder_hidden, 26 | max_length, 27 | diffusion_step_embed_dim_in=128, 28 | diffusion_step_embed_dim_mid=512, 29 | diffusion_step_embed_dim_out=512, 30 | beta_0=1e-6, 31 | beta_T=0.01, 32 | T=1000, 33 | ): 34 | super().__init__() 35 | 36 | self.max_length = max_length 37 | self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in 38 | 39 | self.noise_schedule = torch.linspace(beta_0, beta_T, T) 40 | self.diffusion_hyperparams = compute_hyperparams_given_schedule(self.noise_schedule) 41 | 42 | self.duration_predictor = FastDiffVariancePredictor( 43 | duration_nlayers, 44 | encoder_hidden, 45 | duration_filter_size, 46 | duration_kernel_size, 47 | duration_dropout, 48 | duration_depthwise_conv, 49 | self.diffusion_hyperparams, 50 | diffusion_step_embed_dim_in, 51 | diffusion_step_embed_dim_mid, 52 | diffusion_step_embed_dim_out, 53 | ) 54 | 55 | self.length_regulator = LengthRegulator(pad_to_multiple_of=64) # TODO: change this to use target length 56 | 57 | self.variances = variances 58 | 59 | self.encoders = {} 60 | for var in self.variances: 61 | self.encoders[var] = FastDiffVarianceEncoder( 62 | variance_nlayers[variances.index(var)], 63 | encoder_hidden, 64 | variance_filter_size, 65 | variance_kernel_size[variances.index(var)], 66 | variance_dropout[variances.index(var)], 67 | variance_depthwise_conv, 68 | stats[var]["min"], 69 | stats[var]["max"], 70 | stats[var]["mean"], 71 | stats[var]["std"], 72 | variance_nbins, 73 | self.diffusion_hyperparams, 74 | diffusion_step_embed_dim_in, 75 | diffusion_step_embed_dim_mid, 76 | diffusion_step_embed_dim_out, 77 | ) 78 | self.encoders = nn.ModuleDict(self.encoders) 79 | 80 | 81 | def forward( 82 | self, 83 | x, 84 | src_mask, 85 | targets, 86 | inference=False, 87 | N=4, 88 | ): 89 | if not inference: 90 | duration = targets["duration"] + 1 + torch.rand(size=targets["duration"].shape, device=targets["duration"].device)*0.49 91 | duration = (torch.log(duration) - 1.08) / 0.96 92 | duration_pred, duration_z = self.duration_predictor( 93 | duration.to(x.dtype), 94 | x.transpose(1, 2), 95 | mask=src_mask 96 | ) 97 | else: 98 | duration_pred = self.duration_predictor.inference(x, N=N) 99 | duration_z = None 100 | 101 | result = {} 102 | 103 | out_val = None 104 | 105 | if not inference: 106 | duration_rounded = targets["duration"] 107 | else: 108 | duration_pred = duration_pred * 0.96 + 1.08 109 | duration_rounded = torch.round((torch.exp(duration_pred) - 1)) 110 | duration_rounded = torch.clamp(duration_rounded, min=0).int() 111 | for i in range(len(duration_rounded)): 112 | if duration_rounded[i][~src_mask[i]].sum() <= (~src_mask[i]).sum() // 2: 113 | duration_rounded[i][~src_mask[i]] = 1 114 | print("Zero duration, setting to 1") 115 | duration_rounded[i][src_mask[i]] = 0 116 | 117 | x, tgt_mask = self.length_regulator(x, duration_rounded, self.max_length) 118 | if out_val is not None: 119 | out_val, _ = self.length_regulator(out_val, duration_rounded, self.max_length) 120 | 121 | for i, var in enumerate(self.variances): 122 | if not inference: 123 | (pred, z), out = self.encoders[var]( 124 | x.transpose(1, 2), 
targets[f"variances_{var}"], tgt_mask 125 | ) 126 | else: 127 | pred, out = self.encoders[var](x, None, tgt_mask) 128 | z = None 129 | result[f"variances_{var}"] = pred 130 | result[f"variances_{var}_z"] = z 131 | if out_val is None: 132 | out_val = out 133 | else: 134 | out_val = out_val + out 135 | x = x + out 136 | 137 | result["x"] = x 138 | result["duration_prediction"] = duration_pred 139 | result["duration_z"] = duration_z 140 | result["duration_rounded"] = duration_rounded 141 | result["tgt_mask"] = tgt_mask 142 | result["out"] = out_val 143 | 144 | return result 145 | 146 | 147 | class FastDiffVariancePredictor(nn.Module): 148 | def __init__( 149 | self, 150 | nlayers, 151 | in_channels, 152 | filter_size, 153 | kernel_size, 154 | dropout, 155 | depthwise, 156 | diffusion_hyperparams, 157 | diffusion_step_embed_dim_in, 158 | diffusion_step_embed_dim_mid, 159 | diffusion_step_embed_dim_out, 160 | ): 161 | super().__init__() 162 | 163 | self.diffusion_hyperparams = diffusion_hyperparams 164 | 165 | self.linear_in = nn.Linear(1, in_channels) 166 | 167 | self.layers = nn.Sequential( 168 | *[ 169 | VarianceConvolutionLayer( 170 | in_channels, filter_size, kernel_size, dropout, depthwise 171 | ) 172 | for _ in range(nlayers) 173 | ] 174 | ) 175 | 176 | self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in 177 | 178 | self.fc_t = nn.ModuleList() 179 | self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid) 180 | self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, diffusion_step_embed_dim_out) 181 | 182 | self.linear = nn.Linear(filter_size, 1) 183 | 184 | self.linear_noise = nn.Linear(diffusion_step_embed_dim_out, in_channels) 185 | 186 | def forward(self, x, c, ts=None, mask=None): 187 | # print(x.shape, c.shape, "input shapes") 188 | 189 | if len(x.shape) == 2: 190 | B, L = x.shape # B is batchsize, C=1, L is audio length 191 | x = x.unsqueeze(1) 192 | if len(c.shape) == 2: 193 | c = c.unsqueeze(0) 194 | 195 | B, C, L = c.shape 196 | 197 | if ts is None: 198 | no_ts = True 199 | T, alpha = self.diffusion_hyperparams["T"], self.diffusion_hyperparams["alpha"].to(x.device) 200 | ts = torch.randint(T, size=(B, 1, 1)).to(x.device) # randomly sample steps from 1~T 201 | z = std_normal(x.shape, device=x.device).to(x.dtype) 202 | delta = (1 - alpha[ts] ** 2.).sqrt() 203 | alpha_cur = alpha[ts] 204 | noisy_audio = alpha_cur * x + delta * z # compute x_t from q(x_t|x_0) 205 | x = noisy_audio 206 | ts = ts.view(B, 1) 207 | else: 208 | no_ts = False 209 | 210 | # embed diffusion step t 211 | diffusion_step_embed = calc_diffusion_step_embedding(ts, self.diffusion_step_embed_dim_in) 212 | diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed)) 213 | diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed)) 214 | 215 | x = x.transpose(1,2) 216 | x = self.linear_in(x.to(diffusion_step_embed.dtype)) 217 | x = x.transpose(1,2) 218 | c = c.to(diffusion_step_embed.dtype) 219 | noise_embed = self.linear_noise(diffusion_step_embed).unsqueeze(1).transpose(1, 2) 220 | 221 | # print(x.shape, c.shape, noise_embed.shape, "forward shapes") 222 | # print(x.dtype, c.dtype, diffusion_step_embed.dtype) 223 | 224 | out_conv = self.layers( 225 | (x+c+noise_embed).transpose(1, 2) 226 | ) 227 | out = self.linear(out_conv) 228 | out = out.squeeze(-1) 229 | if mask is not None: 230 | out = out.masked_fill(mask, 0) 231 | 232 | if no_ts: 233 | return out, z 234 | else: 235 | return out 236 | 237 | def inference(self, c, N=4): 238 | """Inference with the given local conditioning 
auxiliary features. 239 | Args: 240 | c (Tensor): Local conditioning auxiliary features (B, C, T'). 241 | Returns: 242 | Tensor: Output tensor (B, out_channels, T) 243 | """ 244 | c = c.transpose(1, 2) 245 | 246 | reverse_step = N 247 | if reverse_step == 1000: 248 | noise_schedule = torch.linspace(0.000001, 0.01, 1000) 249 | elif reverse_step == 200: 250 | noise_schedule = torch.linspace(0.0001, 0.02, 200) 251 | elif reverse_step == 8: 252 | noise_schedule = [6.689325005027058e-07, 1.0033881153503899e-05, 0.00015496854030061513, 253 | 0.002387222135439515, 0.035597629845142365, 0.3681158423423767, 0.4735414385795593, 0.5] 254 | elif reverse_step == 6: 255 | noise_schedule = [1.7838445955931093e-06, 2.7984189728158526e-05, 0.00043231004383414984, 256 | 0.006634317338466644, 0.09357017278671265, 0.6000000238418579] 257 | elif reverse_step == 4: 258 | noise_schedule = [3.2176e-04, 2.5743e-03, 2.5376e-02, 7.0414e-01] 259 | elif reverse_step == 3: 260 | noise_schedule = [9.0000e-05, 9.0000e-03, 6.0000e-01] 261 | else: 262 | raise ValueError("Reverse step should be 3, 4, 6, 8, 200 or 1000.") 263 | 264 | if not isinstance(noise_schedule, torch.Tensor): 265 | noise_schedule = torch.FloatTensor(noise_schedule).to(c.dtype) 266 | noise_schedule = noise_schedule.to(c.device) 267 | 268 | audio_length = c.shape[-1] 269 | 270 | # print(c.shape, "c shape, inference") 271 | 272 | pred_wav = sampling_given_noise_schedule( 273 | self, 274 | (c.shape[0], audio_length), 275 | self.diffusion_hyperparams, 276 | noise_schedule, 277 | condition=c, 278 | ddim=False, 279 | return_sequence=False, 280 | device=c.device, 281 | ) 282 | 283 | # pred_wav = pred_wav / pred_wav.abs().max(axis=1, keepdim=True)[0] 284 | # pred_wav = pred_wav.view(-1) 285 | return pred_wav 286 | 287 | class FastDiffVarianceEncoder(nn.Module): 288 | def __init__( 289 | self, 290 | nlayers, 291 | in_channels, 292 | filter_size, 293 | kernel_size, 294 | dropout, 295 | depthwise, 296 | min, 297 | max, 298 | mean, 299 | std, 300 | nbins, 301 | diffusion_hyperparams, 302 | diffusion_step_embed_dim_in, 303 | diffusion_step_embed_dim_mid, 304 | diffusion_step_embed_dim_out, 305 | ): 306 | super().__init__() 307 | self.predictor = FastDiffVariancePredictor( 308 | nlayers, 309 | in_channels, 310 | filter_size, 311 | kernel_size, 312 | dropout, 313 | depthwise, 314 | diffusion_hyperparams, 315 | diffusion_step_embed_dim_in, 316 | diffusion_step_embed_dim_mid, 317 | diffusion_step_embed_dim_out, 318 | ) 319 | self.bins = nn.Parameter( 320 | torch.linspace(min, max, nbins - 1), 321 | requires_grad=False, 322 | ) 323 | self.embedding = nn.Embedding(nbins, in_channels) 324 | self.mean = mean 325 | self.std = std 326 | 327 | def forward(self, x, tgt, mask, N=4, control=1.0): 328 | if tgt is not None: 329 | # training 330 | noise_pred, z = self.predictor(tgt, x, mask=mask) 331 | tgt = tgt * self.std + self.mean 332 | embedding = self.embedding(torch.bucketize(tgt, self.bins).to(x.device)) 333 | return (noise_pred, z), embedding 334 | else: 335 | # inference 336 | prediction = self.predictor.inference(x, N=N) 337 | bucket_prediction = prediction * self.std + self.mean 338 | prediction = prediction * control 339 | embedding = self.embedding( 340 | torch.bucketize(bucket_prediction, self.bins).to(x.device) 341 | ) 342 | return prediction, embedding 343 | 344 | class FastDiffSpeakerGenerator(nn.Module): 345 | def __init__( 346 | self, 347 | hidden_dim, 348 | c_dim, 349 | speaker_embed_dim, 350 | diffusion_step_embed_dim_in=128, 351 | 
diffusion_step_embed_dim_mid=512, 352 | diffusion_step_embed_dim_out=512, 353 | beta_0=1e-6, 354 | beta_T=0.01, 355 | T=1000, 356 | ): 357 | super().__init__() 358 | self.noise_schedule = torch.linspace(beta_0, beta_T, T) 359 | self.diffusion_hyperparams = compute_hyperparams_given_schedule(self.noise_schedule) 360 | 361 | self.predictor = FastDiffSpeakerPredictor( 362 | hidden_dim, 363 | c_dim, 364 | speaker_embed_dim, 365 | diffusion_step_embed_dim_in, 366 | diffusion_step_embed_dim_mid, 367 | diffusion_step_embed_dim_out, 368 | beta_0, 369 | beta_T, 370 | T, 371 | ) 372 | 373 | def forward( 374 | self, 375 | x, 376 | dvec=None, 377 | inference=False, 378 | N=4, 379 | ): 380 | if inference: 381 | # inference 382 | prediction = self.predictor.inference(x, N=N) 383 | return prediction 384 | else: 385 | # training 386 | noise_pred, z = self.predictor(dvec, x) 387 | return noise_pred, z 388 | 389 | 390 | class FastDiffSpeakerPredictor(nn.Module): 391 | def __init__( 392 | self, 393 | hidden_dim, 394 | c_dim, 395 | speaker_embed_dim, 396 | diffusion_step_embed_dim_in=128, 397 | diffusion_step_embed_dim_mid=512, 398 | diffusion_step_embed_dim_out=512, 399 | beta_0=1e-6, 400 | beta_T=0.01, 401 | T=1000, 402 | ): 403 | super().__init__() 404 | 405 | self.noise_schedule = torch.linspace(beta_0, beta_T, T) 406 | diffusion_hyperparams = compute_hyperparams_given_schedule(self.noise_schedule) 407 | 408 | self.diffusion_hyperparams = diffusion_hyperparams 409 | 410 | self.mlp = nn.Sequential( 411 | *[ 412 | nn.Linear(speaker_embed_dim, hidden_dim), 413 | nn.ReLU(), 414 | nn.Linear(hidden_dim, hidden_dim), 415 | nn.ReLU(), 416 | ] 417 | ) 418 | 419 | self.conditional_in = nn.Linear(c_dim, speaker_embed_dim) 420 | self.linear_out = nn.Linear(hidden_dim, speaker_embed_dim) 421 | 422 | self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in 423 | 424 | self.fc_t = nn.ModuleList() 425 | self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid) 426 | self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, diffusion_step_embed_dim_out) 427 | 428 | self.linear_noise = nn.Linear(diffusion_step_embed_dim_out, speaker_embed_dim) 429 | 430 | def forward(self, x, c, ts=None, mask=None): 431 | # print(x.shape, c.shape, "input shapes") 432 | 433 | B, C = c.shape 434 | 435 | if ts is None: 436 | no_ts = True 437 | T, alpha = self.diffusion_hyperparams["T"], self.diffusion_hyperparams["alpha"].to(x.device) 438 | ts = torch.randint(T, size=(B, 1, 1)).to(x.device) # randomly sample steps from 1~T 439 | x = x.unsqueeze(-1) 440 | z = std_normal(x.shape, device=x.device).to(x.dtype) 441 | delta = (1 - alpha[ts] ** 2.).sqrt() 442 | alpha_cur = alpha[ts] 443 | noisy_audio = alpha_cur * x + delta * z # compute x_t from q(x_t|x_0) 444 | x = noisy_audio 445 | x = x.squeeze(-1) 446 | z = z.squeeze(-1) 447 | ts = ts.view(B, 1) 448 | else: 449 | no_ts = False 450 | 451 | # embed diffusion step t 452 | diffusion_step_embed = calc_diffusion_step_embedding(ts, self.diffusion_step_embed_dim_in) 453 | diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed)) 454 | diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed)) 455 | 456 | # x = x.transpose(1,2) 457 | # x = self.mlp(x.to(diffusion_step_embed.dtype)) 458 | # x = x.transpose(1,2) 459 | c = self.conditional_in(c.to(diffusion_step_embed.dtype)) # TODO: investigate attention here 460 | noise_embed = self.linear_noise(diffusion_step_embed) 461 | 462 | # print(x.shape, c.shape, noise_embed.shape, "forward shapes") 463 | # print(x.dtype, 
c.dtype, diffusion_step_embed.dtype) 464 | 465 | out_conv = self.mlp( 466 | (x+c+noise_embed) 467 | ) 468 | out = self.linear_out(out_conv) 469 | out = out.squeeze(-1) 470 | if mask is not None: 471 | out = out.masked_fill(mask, 0) 472 | 473 | if no_ts: 474 | return out, z 475 | else: 476 | return out 477 | 478 | def inference(self, c, N=4): 479 | """Inference with the given local conditioning auxiliary features. 480 | Args: 481 | c (Tensor): Local conditioning auxiliary features (B, C, T'). 482 | Returns: 483 | Tensor: Output tensor (B, out_channels, T) 484 | """ 485 | 486 | reverse_step = N 487 | if reverse_step == 1000: 488 | noise_schedule = torch.linspace(0.000001, 0.01, 1000) 489 | elif reverse_step == 200: 490 | noise_schedule = torch.linspace(0.0001, 0.02, 200) 491 | elif reverse_step == 8: 492 | noise_schedule = [6.689325005027058e-07, 1.0033881153503899e-05, 0.00015496854030061513, 493 | 0.002387222135439515, 0.035597629845142365, 0.3681158423423767, 0.4735414385795593, 0.5] 494 | elif reverse_step == 6: 495 | noise_schedule = [1.7838445955931093e-06, 2.7984189728158526e-05, 0.00043231004383414984, 496 | 0.006634317338466644, 0.09357017278671265, 0.6000000238418579] 497 | elif reverse_step == 4: 498 | noise_schedule = [3.2176e-04, 2.5743e-03, 2.5376e-02, 7.0414e-01] 499 | elif reverse_step == 3: 500 | noise_schedule = [9.0000e-05, 9.0000e-03, 6.0000e-01] 501 | else: 502 | raise ValueError("Reverse step should be 3, 4, 6, 8, 200 or 1000.") 503 | 504 | if not isinstance(noise_schedule, torch.Tensor): 505 | noise_schedule = torch.FloatTensor(noise_schedule).to(c.dtype) 506 | noise_schedule = noise_schedule.to(c.device) 507 | 508 | audio_length = c.shape[-1] 509 | 510 | # print(c.shape, "c shape, inference") 511 | 512 | pred_wav = sampling_given_noise_schedule( 513 | self, 514 | (c.shape[0], audio_length), 515 | self.diffusion_hyperparams, 516 | noise_schedule, 517 | condition=c, 518 | ddim=False, 519 | return_sequence=False, 520 | device=c.device 521 | ) 522 | 523 | # pred_wav = pred_wav / pred_wav.abs().max(axis=1, keepdim=True)[0] 524 | # pred_wav = pred_wav.view(-1) 525 | return pred_wav -------------------------------------------------------------------------------- /litfass/fastspeech2/log_gmm.py: -------------------------------------------------------------------------------- 1 | from sklearn.mixture import GaussianMixture 2 | import numpy as np 3 | from copy import copy 4 | from sklearn.preprocessing import StandardScaler 5 | 6 | # TODO: add standard scaler 7 | 8 | class LogGMM(): 9 | def __init__(self, *args, **kwargs): 10 | if "logs" in kwargs: 11 | self.logs = kwargs["logs"] 12 | del kwargs["logs"] 13 | else: 14 | self.logs = [] 15 | if "eps" in kwargs: 16 | self.eps = kwargs["eps"] 17 | del kwargs["eps"] 18 | else: 19 | self.eps = 1e-10 20 | self.max_vals = None 21 | self.gmm = GaussianMixture(*args, **kwargs) 22 | 23 | def _create_log_x(self, X): 24 | X = np.array(copy(X)) 25 | if self.max_vals is None: 26 | self.max_vals = np.max(X, axis=0) 27 | X = X / self.max_vals + self.eps 28 | for i in range(X.shape[1]): 29 | if i in self.logs: 30 | X[:, i] = np.log(X[:, i]) 31 | return X 32 | 33 | def fit(self, X, y=None): 34 | X = self._create_log_x(X) 35 | return self.gmm.fit(X, y) 36 | 37 | def fit_predict(self, X, y=None): 38 | X = self._create_log_x(X) 39 | return self.gmm.fit_predict(X, y) 40 | 41 | def predict(self, X): 42 | X = self._create_log_x(X) 43 | return self.gmm.predict(X) 44 | 45 | def predict_proba(self, X): 46 | X = self._create_log_x(X) 47 | return 
self.gmm.predict_proba(X) 48 | 49 | def score_samples(self, X): 50 | X = self._create_log_x(X) 51 | return self.gmm.score_samples(X) 52 | 53 | def score(self, X, y=None): 54 | X = self._create_log_x(X) 55 | return self.gmm.score(X, y) 56 | 57 | def bic(self, X): 58 | X = self._create_log_x(X) 59 | return self.gmm.bic(X) 60 | 61 | def aic(self, X): 62 | X = self._create_log_x(X) 63 | return self.gmm.aic(X) 64 | 65 | def sample(self, n_samples=1, random_state=None): 66 | np.random.seed(random_state) 67 | X, comp = self.gmm.sample(n_samples) 68 | for i in range(X.shape[1]): 69 | if i in self.logs: 70 | X[:, i] = (np.exp(X[:, i])-self.eps)*self.max_vals[i] 71 | else: 72 | X[:, i] = (X[:, i]-self.eps)*self.max_vals[i] 73 | return X, comp -------------------------------------------------------------------------------- /litfass/fastspeech2/loss.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | 4 | #from litfass.third_party.softdtw.sdtw_cuda_loss import SoftDTW 5 | from pysdtw import SoftDTW 6 | 7 | class FastSpeech2Loss(nn.Module): 8 | def __init__( 9 | self, 10 | variances=["energy", "pitch", "snr"], 11 | variance_levels=["phone", "phone", "phone"], 12 | variance_transforms=["cwt", "none", "none"], 13 | variance_losses=["mse", "mse", "mse"], 14 | mel_loss="l1", 15 | duration_loss="mse", 16 | duration_stochastic=False, 17 | max_length=4096, 18 | loss_alphas={ 19 | "mel": 1.0, 20 | "pitch": 1e-1, 21 | "energy": 1e-1, 22 | "snr": 1e-1, 23 | "duration": 1e-4, 24 | "fastdiff": 1e-1, 25 | "speakers": 1, 26 | }, 27 | soft_dtw_gamma=0.01, 28 | soft_dtw_chunk_size=256, 29 | fastdiff_loss=None, 30 | fastdiff_variances=False, 31 | ): 32 | super().__init__() 33 | self.losses = { 34 | "mse": nn.MSELoss(), 35 | "l1": nn.L1Loss(), 36 | "soft_dtw": SoftDTW(use_cuda=True, gamma=soft_dtw_gamma), 37 | } 38 | self.variances = variances 39 | self.variance_levels = variance_levels 40 | self.variance_transforms = variance_transforms 41 | self.variance_losses = variance_losses 42 | self.duration_stochastic = duration_stochastic 43 | self.mel_loss = mel_loss 44 | self.bce_loss = nn.BCEWithLogitsLoss() 45 | self.duration_loss = duration_loss 46 | self.max_length = max_length 47 | self.loss_alphas = loss_alphas 48 | self.fastdiff_loss = fastdiff_loss 49 | self.fastdiff_variances = fastdiff_variances 50 | self.soft_dtw_chunk_size = soft_dtw_chunk_size 51 | for i, var in enumerate(self.variances): 52 | if self.variance_transforms[i] == "cwt": 53 | self.loss_alphas[var + "_cwt"] = self.loss_alphas[var] 54 | self.loss_alphas[var + "_mean"] = self.loss_alphas[var] 55 | self.loss_alphas[var + "_std"] = self.loss_alphas[var] 56 | 57 | def get_loss(self, pred, truth, loss, mask, unsqueeze=False): 58 | truth.requires_grad = False 59 | if loss == "soft_dtw" and len(pred.shape) == 2: 60 | pred = pred.unsqueeze(-1) 61 | truth = truth.unsqueeze(-1) 62 | if unsqueeze or loss == "soft_dtw": 63 | mask = mask.unsqueeze(-1) 64 | if loss != "soft_dtw": 65 | pred = pred.masked_select(mask) 66 | truth = truth.masked_select(mask) 67 | loss_func = self.losses[loss] 68 | if loss == "soft_dtw": 69 | pred = pred.masked_fill_(~mask, 0) 70 | truth = truth.masked_fill_(~mask, 0) 71 | pred_chunks = pred.split(self.soft_dtw_chunk_size, dim=1) 72 | truth_chunks = truth.split(self.soft_dtw_chunk_size, dim=1) 73 | for i, (pred_chunk, truth_chunk) in enumerate(zip(pred_chunks, truth_chunks)): 74 | if i == 0: 75 | loss = loss_func(pred_chunk, truth_chunk) 76 | else: 77 | loss 
+= loss_func(pred_chunk, truth_chunk) 78 | loss = loss.sum() 79 | else: 80 | loss = loss_func(pred, truth) 81 | return loss 82 | 83 | def forward(self, result, target, frozen_components=[]): 84 | 85 | losses = {} 86 | 87 | variances_pred = {var: result[f"variances_{var}"] for var in self.variances} 88 | 89 | variances_target = { 90 | var: target[f"variances_{var}"] 91 | for var in self.variances 92 | if f"variances_{var}" in target 93 | } 94 | 95 | src_mask = ~result["src_mask"] 96 | tgt_mask = ~result["tgt_mask"] 97 | 98 | # VARIANCE LOSSES 99 | if self.max_length is not None: 100 | assert target["mel"].shape[1] <= self.max_length 101 | 102 | for variance, level, transform, loss in zip( 103 | self.variances, self.variance_levels, self.variance_transforms, self.variance_losses 104 | ): 105 | if self.fastdiff_variances: 106 | variances_target[variance] = result[f"variances_{variance}_z"].squeeze(1) 107 | variances_pred[variance] = result[f"variances_{variance}"] 108 | # print(variances_pred[variance].shape, variances_target[variance].shape, tgt_mask.shape, "variance loss shapes") 109 | losses[variance] = self.get_loss( 110 | variances_pred[variance], 111 | variances_target[variance].to(dtype=result["mel"].dtype), 112 | "mse", 113 | tgt_mask, 114 | ) 115 | continue 116 | if transform == "cwt": 117 | variances_target[variance] = target[ 118 | f"variances_{variance}_spectrogram" 119 | ] 120 | variances_pred[variance] = result[f"variances_{variance}"][ 121 | "spectrogram" 122 | ] 123 | if level == "frame": 124 | if transform != "cwt": 125 | variances_target[variance] = variances_target[variance][ 126 | :, :int(self.max_length) 127 | ] 128 | variance_mask = tgt_mask 129 | elif level == "phone": 130 | variance_mask = src_mask 131 | else: 132 | raise ValueError("Unknown variance level: {}".format(level)) 133 | if transform == "cwt": 134 | losses[variance + "_cwt"] = self.get_loss( 135 | variances_pred[variance], 136 | variances_target[variance].to(dtype=result["mel"].dtype), 137 | loss, 138 | variance_mask, 139 | unsqueeze=True, 140 | ) 141 | losses[variance + "_mean"] = self.mse_loss( 142 | result[f"variances_{variance}"]["mean"], 143 | torch.tensor(target[f"variances_{variance}_mean"]).to( 144 | result[f"variances_{variance}"]["mean"].device, 145 | dtype=result["mel"].dtype, 146 | ), 147 | ) 148 | losses[variance + "_std"] = self.mse_loss( 149 | result[f"variances_{variance}"]["std"], 150 | torch.tensor(target[f"variances_{variance}_std"]).to( 151 | result[f"variances_{variance}"]["std"].device, 152 | dtype=result["mel"].dtype, 153 | ), 154 | ) 155 | else: 156 | losses[variance] = self.get_loss( 157 | variances_pred[variance], 158 | variances_target[variance].to(dtype=result["mel"].dtype), 159 | loss, 160 | variance_mask, 161 | ) 162 | 163 | # MEL SPECTROGRAM LOSS 164 | losses["mel"] = self.get_loss( 165 | result["mel"], 166 | target["mel"].to(dtype=result["mel"].dtype), 167 | self.mel_loss, 168 | tgt_mask, 169 | unsqueeze=True, 170 | ) 171 | 172 | # DURATION LOSS 173 | if self.fastdiff_variances: 174 | # print(result["duration_prediction"].shape, result["duration_z"].shape, src_mask.shape, "duration loss shapes") 175 | losses["duration"] = self.get_loss( 176 | result["duration_prediction"].to(dtype=result["mel"].dtype), 177 | result["duration_z"].to(dtype=result["mel"].dtype).squeeze(1), 178 | "mse", 179 | src_mask, 180 | ) 181 | elif not self.duration_stochastic: 182 | losses["duration"] = self.get_loss( 183 | result["duration_prediction"], 184 | torch.log(target["duration"] + 
1).to(dtype=result["mel"].dtype), 185 | self.duration_loss, 186 | src_mask, 187 | ) 188 | else: 189 | losses["duration"] = torch.sum(result["duration_prediction"]) 190 | 191 | # FASTDIFF LOSS 192 | if self.fastdiff_loss is not None: 193 | losses["fastdiff"] = self.get_loss( 194 | result["fastdiff"][0], 195 | result["fastdiff"][1], 196 | self.fastdiff_loss, 197 | result["wav_mask"], 198 | ) 199 | 200 | if "speaker_z" in result: 201 | losses["speakers"] = self.losses[self.fastdiff_loss](result["speaker_pred"], result["speaker_z"]) 202 | 203 | # TOTAL LOSS 204 | total_loss = sum( 205 | [ 206 | v * self.loss_alphas[k] 207 | for k, v in losses.items() 208 | if not any(f in k for f in frozen_components) 209 | ] 210 | ) 211 | losses["total"] = total_loss 212 | 213 | return losses 214 | 215 | # TODO_NEXT_TIME: just get it to not crash and burn like it does now -------------------------------------------------------------------------------- /litfass/fastspeech2/model.py: -------------------------------------------------------------------------------- 1 | from cmath import inf 2 | import math 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | 8 | from torch.nn import TransformerEncoderLayer 9 | from torch.nn.utils.rnn import pad_sequence 10 | from einops.layers.torch import Rearrange, Reduce 11 | 12 | from litfass.third_party.stochastic_duration_predictor.sdp import StochasticDurationPredictor 13 | from litfass.dataset.cwt import CWT 14 | 15 | 16 | def generate_square_subsequent_mask(sz): 17 | mask = (torch.triu(torch.ones((sz, sz))) == 1).transpose(0, 1) 18 | mask = ( 19 | mask.float() 20 | .masked_fill(mask == 0, float("-inf")) 21 | .masked_fill(mask == 1, float(0.0)) 22 | ) 23 | return mask 24 | 25 | 26 | def create_mask(src, tgt, pad_idx): 27 | src_seq_len = src.shape[1] 28 | tgt_seq_len = tgt.shape[1] 29 | 30 | tgt_mask = generate_square_subsequent_mask(tgt_seq_len) 31 | src_mask = torch.zeros((src_seq_len, src_seq_len)).type(torch.bool) 32 | 33 | src_padding_mask = src == pad_idx 34 | tgt_padding_mask = tgt == pad_idx 35 | return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask 36 | 37 | 38 | class PositionalEncoding(nn.Module): 39 | def __init__(self, d_model, max_len=5000, dropout=0.1): 40 | super().__init__() 41 | self.dropout = nn.Dropout(p=dropout) 42 | 43 | pe = torch.zeros(max_len, d_model) 44 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 45 | div_term = torch.exp( 46 | torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) 47 | ) 48 | pe[:, 0::2] = torch.sin(position * div_term) 49 | pe[:, 1::2] = torch.cos(position * div_term) 50 | pe = pe.unsqueeze(0) 51 | self.register_buffer("pe", pe) 52 | 53 | def forward(self, x): 54 | x = x + self.pe[:, : x.size(1), :].to(x.dtype) 55 | return self.dropout(x) 56 | 57 | 58 | class Transpose(nn.Module): 59 | def __init__(self, module): 60 | super().__init__() 61 | self.module = module 62 | 63 | def forward(self, x): 64 | return self.module(x.transpose(1, 2)).transpose(1, 2) 65 | 66 | 67 | class ConformerEncoderLayer(TransformerEncoderLayer): 68 | def __init__(self, *args, **kwargs): 69 | old_kwargs = {k: v for k, v in kwargs.items() if "conv_" not in k} 70 | super().__init__(*args, **old_kwargs) 71 | del self.linear1 72 | del self.linear2 73 | if "conv_depthwise" in kwargs and kwargs["conv_depthwise"]: 74 | self.conv1 = nn.Sequential( 75 | nn.Conv1d( 76 | kwargs["conv_in"], 77 | kwargs["conv_in"], 78 | kernel_size=kwargs["conv_kernel"][0], 79 | padding="same", 80 | 
groups=kwargs["conv_in"], 81 | ), 82 | nn.Conv1d(kwargs["conv_in"], kwargs["conv_filter_size"], 1), 83 | ) 84 | self.conv2 = nn.Sequential( 85 | nn.Conv1d( 86 | kwargs["conv_filter_size"], 87 | kwargs["conv_filter_size"], 88 | kernel_size=kwargs["conv_kernel"][1], 89 | padding="same", 90 | groups=kwargs["conv_in"], 91 | ), 92 | nn.Conv1d(kwargs["conv_filter_size"], kwargs["conv_in"], 1), 93 | ) 94 | else: 95 | self.conv1 = nn.Conv1d( 96 | kwargs["conv_in"], 97 | kwargs["conv_filter_size"], 98 | kernel_size=kwargs["conv_kernel"][0], 99 | padding="same", 100 | ) 101 | self.conv2 = nn.Conv1d( 102 | kwargs["conv_filter_size"], 103 | kwargs["conv_in"], 104 | kernel_size=kwargs["conv_kernel"][1], 105 | padding="same", 106 | ) 107 | 108 | def forward(self, src, src_mask=None, src_key_padding_mask=None): 109 | x = src 110 | if self.norm_first: 111 | x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) 112 | x = x + self._ff_block(self.norm2(x)) 113 | else: 114 | x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask)) 115 | x = self.norm2(x + self._ff_block(x)) 116 | return x 117 | 118 | def _ff_block(self, x): 119 | x = self.conv2( 120 | self.dropout(self.activation(self.conv1(x.transpose(1, 2)))) 121 | ).transpose(1, 2) 122 | return self.dropout2(x) 123 | 124 | 125 | class SpeakerEmbedding(nn.Module): 126 | def __init__(self, embedding_dim, speaker_type, nspeakers=None): 127 | super().__init__() 128 | self.speaker_type = speaker_type 129 | self.embedding_dim = embedding_dim 130 | if "dvector" in speaker_type: 131 | self.projection = nn.Linear(256, embedding_dim) 132 | self.has_projection = True 133 | elif speaker_type == "id": 134 | self.speaker_embedding = nn.Embedding(nspeakers, embedding_dim) 135 | self.relu = nn.ReLU() 136 | 137 | def forward(self, x, input_length, output_shape): 138 | if self.has_projection: 139 | out = self.projection(x) 140 | else: 141 | out = self.speaker_embedding(x) 142 | out = self.relu(out) 143 | return out.reshape(-1, 1, output_shape).repeat_interleave(input_length, dim=1) 144 | 145 | 146 | class PriorEmbedding(nn.Module): 147 | def __init__(self, embedding_dim, nbins, stats): 148 | super().__init__() 149 | self.embedding_dim = embedding_dim 150 | self.bins = nn.Parameter( 151 | torch.linspace(stats["min"], stats["max"], nbins - 1), 152 | requires_grad=False, 153 | ) 154 | self.embedding = nn.Embedding( 155 | nbins, 156 | embedding_dim, 157 | ) 158 | self.relu = nn.ReLU() 159 | 160 | def forward(self, x, input_length): 161 | out = self.relu(self.embedding(torch.bucketize(x, self.bins))) 162 | return out.reshape(-1, 1, self.embedding_dim).repeat_interleave( 163 | input_length, dim=1 164 | ) 165 | 166 | 167 | class VarianceAdaptor(nn.Module): 168 | def __init__( 169 | self, 170 | stats, 171 | variances, 172 | variance_levels, 173 | variance_transforms, 174 | variance_nlayers, 175 | variance_kernel_size, 176 | variance_dropout, 177 | variance_filter_size, 178 | variance_nbins, 179 | variance_depthwise_conv, 180 | duration_nlayers, 181 | duration_stochastic, 182 | duration_kernel_size, 183 | duration_dropout, 184 | duration_filter_size, 185 | duration_depthwise_conv, 186 | encoder_hidden, 187 | max_length, 188 | ): 189 | super().__init__() 190 | self.variances = variances 191 | self.variance_levels = variance_levels 192 | self.variance_transforms = variance_transforms 193 | self.duration_stochastic = duration_stochastic 194 | self.max_length = max_length 195 | 196 | if self.duration_stochastic: 197 | if duration_depthwise_conv: 198 | raise 
NotImplementedError( 199 | "Depthwise convolution not implemented for Flow-Based duration prediction" 200 | ) 201 | self.duration_predictor = StochasticDurationPredictorWrapper( 202 | duration_nlayers, 203 | encoder_hidden, 204 | duration_filter_size, 205 | duration_kernel_size, 206 | duration_dropout, 207 | ) 208 | else: 209 | self.duration_predictor = VariancePredictor( 210 | duration_nlayers, 211 | encoder_hidden, 212 | duration_filter_size, 213 | duration_kernel_size, 214 | duration_dropout, 215 | duration_depthwise_conv, 216 | ) 217 | 218 | self.length_regulator = LengthRegulator() 219 | 220 | self.encoders = {} 221 | for var in self.variances: 222 | self.encoders[var] = VarianceEncoder( 223 | variance_nlayers[variances.index(var)], 224 | encoder_hidden, 225 | variance_filter_size, 226 | variance_kernel_size[variances.index(var)], 227 | variance_dropout[variances.index(var)], 228 | variance_depthwise_conv, 229 | stats[var]["min"], 230 | stats[var]["max"], 231 | stats[var]["mean"], 232 | stats[var]["std"], 233 | variance_nbins, 234 | cwt=variance_transforms[variances.index(var)] == "cwt", 235 | ) 236 | self.encoders = nn.ModuleDict(self.encoders) 237 | 238 | self.frozen_components = [] 239 | 240 | def freeze(self, component): 241 | if component == "duration": 242 | for param in self.duration_predictor.parameters(): 243 | param.requires_grad = False 244 | else: 245 | for param in self.encoders[component].parameters(): 246 | param.requires_grad = False 247 | self.frozen_components.append(component) 248 | 249 | def forward( 250 | self, 251 | x, 252 | src_mask, 253 | targets, 254 | inference=False, 255 | tf_ratio=1.0, 256 | oracles=[], 257 | ): 258 | if not self.duration_stochastic: 259 | duration_pred = self.duration_predictor(x, src_mask) 260 | else: 261 | if not inference: 262 | duration_pred = self.duration_predictor( 263 | x.detach(), src_mask, targets["duration"] 264 | ) 265 | else: 266 | duration_pred = self.duration_predictor( 267 | x.detach(), src_mask, inference=True 268 | ) 269 | 270 | result = {} 271 | 272 | tf_val = np.random.uniform(0, 1) <= tf_ratio 273 | 274 | out_val = None 275 | 276 | for i, var in enumerate(self.variances): 277 | if self.variance_levels[i] == "phone": 278 | if (not inference and tf_val) or var in oracles: 279 | if self.variance_transforms[i] == "cwt": 280 | pred, out = self.encoders[var]( 281 | x, targets[f"variances_{var}_signal"], src_mask 282 | ) 283 | else: 284 | pred, out = self.encoders[var]( 285 | x, targets[f"variances_{var}"], src_mask 286 | ) 287 | else: 288 | pred, out = self.encoders[var](x, None, src_mask) 289 | result[f"variances_{var}"] = pred 290 | if out_val is None: 291 | out_val = out 292 | else: 293 | out_val = out_val + out 294 | x = x + out 295 | 296 | if not inference: 297 | duration_rounded = targets["duration"] 298 | else: 299 | if not self.duration_stochastic: 300 | duration_rounded = torch.round((torch.exp(duration_pred) - 1)) 301 | else: 302 | duration_rounded = torch.ceil( 303 | (torch.exp(duration_pred + 1e-9)) 304 | ).masked_fill(duration_pred == 0, 0) 305 | duration_rounded = torch.clamp(duration_rounded, min=0).int() 306 | for i in range(len(duration_rounded)): 307 | if duration_rounded[i][~src_mask[i]].sum() <= (~src_mask[i]).sum() // 2: 308 | duration_rounded[i][~src_mask[i]] = 1 309 | print("Zero duration, setting to 1") 310 | 311 | x, tgt_mask = self.length_regulator(x, duration_rounded, self.max_length) 312 | if out_val is not None: 313 | out_val, _ = self.length_regulator(out_val, duration_rounded, self.max_length) 
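# At this point the phone-level hidden states have been expanded to frame level:
# ground-truth durations are used during training, while at inference the predicted
# log-durations are converted back (exp, then rounded/clamped) before the length
# regulator. The accumulated variance embeddings (out_val) are expanded the same
# way so the frame-level variance encoders below receive consistent inputs.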
314 | 315 | for i, var in enumerate(self.variances): 316 | if self.variance_levels[i] == "frame": 317 | if (not inference and tf_val) or var in oracles: 318 | if self.variance_transforms[i] == "cwt": 319 | pred, out = self.encoders[var]( 320 | x, targets[f"variances_{var}_signal"], tgt_mask 321 | ) 322 | else: 323 | pred, out = self.encoders[var]( 324 | x, targets[f"variances_{var}"], tgt_mask 325 | ) 326 | else: 327 | pred, out = self.encoders[var](x, None, tgt_mask) 328 | result[f"variances_{var}"] = pred 329 | if out_val is None: 330 | out_val = out 331 | else: 332 | out_val = out_val + out 333 | x = x + out 334 | 335 | result["x"] = x 336 | result["duration_prediction"] = duration_pred 337 | result["duration_rounded"] = duration_rounded 338 | result["tgt_mask"] = tgt_mask 339 | result["out"] = out_val 340 | 341 | return result 342 | 343 | 344 | class LengthRegulator(nn.Module): 345 | def __init__(self, pad_to_multiple_of=None): 346 | super().__init__() 347 | self.pad_to_multiple_of = pad_to_multiple_of 348 | 349 | def forward(self, x, durations, max_length=None): 350 | repeated_list = [ 351 | torch.repeat_interleave(x[i], durations[i], dim=0) 352 | for i in range(x.shape[0]) 353 | ] 354 | lengths = torch.tensor([t.shape[0] for t in repeated_list]).long() 355 | max_length = min(lengths.max(), int(max_length)) 356 | if self.pad_to_multiple_of is not None: 357 | max_length = int((np.ceil(max_length / self.pad_to_multiple_of) * self.pad_to_multiple_of).item()) 358 | mask = ~( 359 | torch.arange(max_length).expand(len(lengths), max_length) 360 | < lengths.unsqueeze(1) 361 | ).to(x.device) 362 | if self.pad_to_multiple_of is not None: 363 | if len(repeated_list[0].shape) == 1: 364 | repeated_list[0] = nn.ConstantPad1d((0, max_length - repeated_list[0].shape[0]), 0)(repeated_list[0]) 365 | elif len(repeated_list[0].shape) == 2: 366 | repeated_list[0] = nn.ConstantPad2d((0, 0, 0, max_length - repeated_list[0].shape[0]), 0)(repeated_list[0]) 367 | out = pad_sequence(repeated_list, batch_first=True, padding_value=0) 368 | if max_length is not None: 369 | out = out[:, :max_length] 370 | return out, mask 371 | 372 | 373 | class VarianceEncoder(nn.Module): 374 | def __init__( 375 | self, 376 | nlayers, 377 | in_channels, 378 | filter_size, 379 | kernel_size, 380 | dropout, 381 | depthwise, 382 | min, 383 | max, 384 | mean, 385 | std, 386 | nbins, 387 | cwt, 388 | ): 389 | super().__init__() 390 | self.cwt = cwt 391 | self.predictor = VariancePredictor( 392 | nlayers, in_channels, filter_size, kernel_size, dropout, depthwise, cwt 393 | ) 394 | if cwt: 395 | min = np.log(min) 396 | max = np.log(max) 397 | self.bins = nn.Parameter( 398 | torch.linspace(min, max, nbins - 1), 399 | requires_grad=False, 400 | ) 401 | self.embedding = nn.Embedding(nbins, in_channels) 402 | if cwt: 403 | self.mean_std_linear = nn.Linear(filter_size, 2) 404 | self.cwt_obj = CWT() 405 | 406 | self.mean = mean 407 | self.std = std 408 | 409 | def forward(self, x, tgt, mask, control=1.0): 410 | if not self.cwt: 411 | prediction = self.predictor(x, mask) 412 | else: 413 | prediction, out_conv = self.predictor(x, mask, return_conv=True) 414 | mean_std = self.mean_std_linear(torch.mean(out_conv, axis=1)) 415 | mean, std = mean_std[:, 0], mean_std[:, 1] 416 | 417 | if tgt is not None: 418 | if self.cwt: 419 | tgt = torch.log(tgt) 420 | else: 421 | tgt = tgt * self.std + self.mean 422 | embedding = self.embedding(torch.bucketize(tgt, self.bins).to(x.device)) 423 | else: 424 | if self.cwt: 425 | tmp_prediction = [] 426 | for i in 
range(len(prediction)): 427 | tmp_prediction.append( 428 | self.cwt_obj.recompose(prediction[i].T, mean[i], std[i]) 429 | ) 430 | spectrogram = prediction 431 | prediction = torch.stack(tmp_prediction) 432 | bucket_prediction = prediction 433 | else: 434 | bucket_prediction = prediction * self.std + self.mean 435 | prediction = prediction * control 436 | embedding = self.embedding( 437 | torch.bucketize(bucket_prediction, self.bins).to(x.device) 438 | ) 439 | 440 | if not self.cwt: 441 | return prediction, embedding 442 | else: 443 | if tgt is not None: 444 | return ( 445 | { 446 | "spectrogram": prediction, 447 | "mean": mean, 448 | "std": std, 449 | }, 450 | embedding, 451 | ) 452 | else: 453 | return ( 454 | { 455 | "reconstructed_signal": torch.exp(prediction), 456 | "spectrogram": spectrogram, 457 | "mean": mean, 458 | "std": std, 459 | }, 460 | embedding, 461 | ) 462 | 463 | class StochasticDurationPredictorWrapper(nn.Module): 464 | def __init__(self, nlayers, in_channels, filter_size, kernel_size, dropout): 465 | super().__init__() 466 | 467 | self.sdp = StochasticDurationPredictor( 468 | in_channels, 469 | filter_size, 470 | kernel_size, 471 | dropout, 472 | nlayers, 473 | ) 474 | 475 | def forward(self, x, mask, tgt=None, sigma=1.0, inference=False): 476 | out = self.sdp(x, mask, tgt, reverse=inference, noise_scale=sigma) 477 | if mask is not None and inference: 478 | out = out.masked_fill(mask, 0) 479 | return out 480 | 481 | 482 | class VariancePredictor(nn.Module): 483 | def __init__( 484 | self, 485 | nlayers, 486 | in_channels, 487 | filter_size, 488 | kernel_size, 489 | dropout, 490 | depthwise=False, 491 | cwt=False, 492 | ): 493 | super().__init__() 494 | 495 | self.layers = nn.Sequential( 496 | *[ 497 | VarianceConvolutionLayer( 498 | in_channels, filter_size, kernel_size, dropout, depthwise 499 | ) 500 | for _ in range(nlayers) 501 | ] 502 | ) 503 | 504 | self.cwt = cwt 505 | if not self.cwt: 506 | self.linear = nn.Linear(filter_size, 1) 507 | else: 508 | self.linear = nn.Linear(filter_size, 10) 509 | 510 | def forward(self, x, mask=None, return_conv=False): 511 | out_conv = self.layers(x) 512 | out = self.linear(out_conv) 513 | if not self.cwt: 514 | out = out.squeeze(-1) 515 | else: 516 | mask = torch.stack([mask] * 10, dim=-1) 517 | if mask is not None: 518 | out = out.masked_fill(mask, 0) 519 | if return_conv: 520 | return out, out_conv 521 | else: 522 | return out 523 | 524 | class VarianceConvolutionLayer(nn.Module): 525 | def __init__(self, in_channels, filter_size, kernel_size, dropout, depthwise): 526 | super().__init__() 527 | if not depthwise: 528 | self.layers = nn.Sequential( 529 | Transpose( 530 | nn.Conv1d( 531 | in_channels, 532 | filter_size, 533 | kernel_size, 534 | padding=(kernel_size - 1) // 2, 535 | ) 536 | ), 537 | nn.ReLU(), 538 | nn.LayerNorm(filter_size), 539 | nn.Dropout(dropout), 540 | ) 541 | else: 542 | self.layers = nn.Sequential( 543 | Transpose( 544 | nn.Sequential( 545 | nn.Conv1d( 546 | in_channels, 547 | in_channels, 548 | kernel_size, 549 | padding=(kernel_size - 1) // 2, 550 | groups=in_channels, 551 | ), 552 | nn.Conv1d(in_channels, filter_size, 1), 553 | ) 554 | ), 555 | nn.ReLU(), 556 | nn.LayerNorm(filter_size), 557 | nn.Dropout(dropout), 558 | ) 559 | 560 | def forward(self, x): 561 | return self.layers(x) -------------------------------------------------------------------------------- /litfass/fastspeech2/noam.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler 
import _LRScheduler 2 | 3 | 4 | class NoamLR(_LRScheduler): 5 | """ 6 | Implements the Noam Learning rate schedule. This corresponds to increasing the learning rate 7 | linearly for the first ``warmup_steps`` training steps, and decreasing it thereafter proportionally 8 | to the inverse square root of the step number, scaled by the inverse square root of the 9 | dimensionality of the model. Time will tell if this is just madness or it's actually important. 10 | Parameters 11 | ---------- 12 | warmup_steps: ``int``, required. 13 | The number of steps to linearly increase the learning rate. 14 | """ 15 | 16 | def __init__(self, optimizer, warmup_steps): 17 | self.warmup_steps = warmup_steps 18 | super().__init__(optimizer) 19 | 20 | def get_lr(self): 21 | last_epoch = max(1, self.last_epoch) 22 | scale = self.warmup_steps**0.5 * min( 23 | last_epoch ** (-0.5), last_epoch * self.warmup_steps ** (-1.5) 24 | ) 25 | return [base_lr * scale for base_lr in self.base_lrs] 26 | -------------------------------------------------------------------------------- /litfass/generate.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import pickle 3 | from pathlib import Path 4 | import inspect 5 | 6 | from alignments.datasets.libritts import LibrittsDataset 7 | from audiomentations import Compose, RoomSimulator, AddGaussianSNR, PitchShift 8 | from huggingface_hub import hf_hub_download 9 | from torch.utils.data import DataLoader 10 | import torch 11 | from tqdm.auto import tqdm 12 | import json 13 | import hashlib 14 | import numpy as np 15 | 16 | from litfass.fastspeech2.fastspeech2 import FastSpeech2 17 | from litfass.synthesis.generator import SpeechGenerator 18 | from litfass.synthesis.g2p import EnglishG2P 19 | from litfass.dataset.datasets import TTSDataset 20 | from litfass.third_party.argutils import str2bool 21 | from litfass.third_party.fastdiff.FastDiff import FastDiff 22 | 23 | if __name__ == "__main__": 24 | parser = ArgumentParser() 25 | parser.add_argument("--dataset", type=str, default=None) 26 | parser.add_argument("--sentence", type=str, default=None) 27 | 28 | parser.add_argument("--checkpoint_path", type=str, default=None) 29 | parser.add_argument("--hub", type=str, default=None) 30 | 31 | parser.add_argument("--output_path", type=str) 32 | parser.add_argument("--batch_size", type=int, default=6) 33 | parser.add_argument("--device", type=str, default="cuda:0") 34 | parser.add_argument("--tts_device", type=str, default=None) 35 | parser.add_argument("--hifigan_device", type=str, default=None) 36 | parser.add_argument("--use_voicefixer", type=str2bool, default=True) 37 | parser.add_argument("--use_fastdiff", type=str2bool, default=False) 38 | parser.add_argument("--fastdiff_n", type=int, default=4) 39 | parser.add_argument("--num_workers", type=int, default=-1) 40 | 41 | parser.add_argument("--cache_path", type=str, default=None) 42 | 43 | # override priors 44 | parser.add_argument( 45 | "--prior_values", nargs="+", type=float, default=[-1, -1, -1, -1] 46 | ) 47 | 48 | parser.add_argument("--augment_pitch", type=str2bool, default=False) 49 | for pitch_arg in inspect.signature(PitchShift).parameters: 50 | parser.add_argument(f"--pitch_{pitch_arg}", type=float, default=None) 51 | parser.add_argument("--augment_room", type=str2bool, default=False) 52 | for room_arg in inspect.signature(RoomSimulator).parameters: 53 | if room_arg == "use_ray_tracing": 54 | parser.add_argument(f"--room_{room_arg}", type=str2bool, 
default=False) 55 | else: 56 | parser.add_argument(f"--room_{room_arg}", type=float, default=None) 57 | parser.add_argument("--augment_noise", type=str2bool, default=False) 58 | for noise_arg in inspect.signature(AddGaussianSNR).parameters: 59 | parser.add_argument(f"--noise_{noise_arg}", type=float, default=None) 60 | 61 | parser.add_argument("--copy", type=str2bool, default=False) 62 | 63 | # speakers 64 | parser.add_argument("--speaker", type=str, default=None) # can be "random", "dataset" or a speaker name 65 | parser.add_argument("--min_samples_per_speaker", type=int, default=0) 66 | 67 | # number of hours to generate 68 | parser.add_argument("--hours", type=float, default=1.0) 69 | 70 | args = parser.parse_args() 71 | 72 | args_dict = vars(args) 73 | 74 | if args.dataset is not None and args.sentence is not None: 75 | raise ValueError("You can only specify one of --dataset and --sentence!") 76 | 77 | if args.tts_device is None: 78 | args.tts_device = args.device 79 | if args.hifigan_device is None: 80 | args.hifigan_device = args.device 81 | 82 | augment_list = [] 83 | if args.augment_pitch: 84 | pitch_args = {} 85 | for pitch_arg in inspect.signature(PitchShift).parameters: 86 | pitch_args[pitch_arg] = args_dict[f"pitch_{pitch_arg}"] 87 | augment_list.append(PitchShift(**pitch_args)) 88 | if args.augment_room: 89 | room_args = {} 90 | for room_arg in inspect.signature(RoomSimulator).parameters: 91 | room_args[room_arg] = args_dict[f"room_{room_arg}"] 92 | augment_list.append(RoomSimulator(**room_args)) 93 | if args.augment_noise: 94 | noise_args = {} 95 | for noise_arg in inspect.signature(AddGaussianSNR).parameters: 96 | noise_args[noise_arg] = args_dict[f"noise_{noise_arg}"] 97 | augment_list.append(AddGaussianSNR(**noise_args)) 98 | 99 | if len(augment_list) > 0: 100 | augmentations = Compose( 101 | augment_list 102 | ) 103 | else: 104 | augmentations = None # pylint: disable=invalid-name 105 | 106 | if args.hub is not None: 107 | args.checkpoint_path = hf_hub_download(args.hub, filename="lit_model.ckpt") 108 | 109 | if args.checkpoint_path is None: 110 | raise ValueError("No checkpoint path or hub identifier specified!") 111 | 112 | model = FastSpeech2.load_from_checkpoint(args.checkpoint_path) 113 | 114 | generator = SpeechGenerator( 115 | model, 116 | EnglishG2P(), 117 | device=args.tts_device, 118 | synth_device=args.hifigan_device, 119 | augmentations=augmentations, 120 | voicefixer=args.use_voicefixer, 121 | fastdiff=args.use_fastdiff, 122 | fastdiff_n=args.fastdiff_n, 123 | ) 124 | 125 | if args.sentence is not None: 126 | if args.speaker is not None: 127 | args.speaker = Path(args.speaker) 128 | audio = generator.generate_from_text(args.sentence, args.speaker, random_seed=0, prior_strategy="gmm", prior_values=args.prior_values) 129 | if args.output_path is None: 130 | raise ValueError("No output path specified!") 131 | Path(args.output_path).parent.mkdir(parents=True, exist_ok=True) 132 | generator.save_audio(audio, Path(args.output_path) / f"{args.sentence.replace(' ', '_').lower()}.wav") 133 | 134 | if args.dataset is not None: 135 | ds = None 136 | if args.cache_path is not None: 137 | cache_path = Path(args.cache_path) 138 | tts_kwargs = { 139 | "speaker_type":model.hparams.speaker_type, 140 | "min_length":model.hparams.min_length, 141 | "max_length":model.hparams.max_length, 142 | "variances":model.hparams.variances, 143 | "variance_transforms":model.hparams.variance_transforms, 144 | "variance_levels":model.hparams.variance_levels, 145 | 
"priors":model.hparams.priors, 146 | "n_mels":model.hparams.n_mels, 147 | "n_fft":model.hparams.n_fft, 148 | "win_length":model.hparams.win_length, 149 | "hop_length":model.hparams.hop_length, 150 | "min_samples_per_speaker":args.min_samples_per_speaker, 151 | "_stats": model.stats, 152 | } 153 | hash_kwargs = tts_kwargs.copy() 154 | hash_kwargs["dataset"] = args.dataset 155 | ds_hash = hashlib.md5( 156 | json.dumps(hash_kwargs, sort_keys=True).encode("utf-8") 157 | ).hexdigest() 158 | cache_path = cache_path / (ds_hash + ".pt") 159 | if cache_path.exists(): 160 | print("Loading from cache...") 161 | with cache_path.open("rb") as f: 162 | ds = pickle.load(f) 163 | if ds is None: 164 | ds = TTSDataset( 165 | LibrittsDataset(target_directory=args.dataset, chunk_size=10_000), 166 | speaker_type=model.hparams.speaker_type, 167 | min_length=model.hparams.min_length, 168 | max_length=model.hparams.max_length, 169 | variances=model.hparams.variances, 170 | variance_transforms=model.hparams.variance_transforms, 171 | variance_levels=model.hparams.variance_levels, 172 | priors=model.hparams.priors, 173 | n_mels=model.hparams.n_mels, 174 | n_fft=model.hparams.n_fft, 175 | win_length=model.hparams.win_length, 176 | hop_length=model.hparams.hop_length, 177 | min_samples_per_speaker=args.min_samples_per_speaker, 178 | _stats=model.stats, 179 | num_workers=args.num_workers, 180 | ) 181 | if args.cache_path is not None and not cache_path.exists(): 182 | cache_path.parent.mkdir(parents=True, exist_ok=True) 183 | with cache_path.open("wb") as f: 184 | pickle.dump(ds, f) 185 | 186 | dl = DataLoader( 187 | ds, 188 | batch_size=args.batch_size, 189 | num_workers=4, 190 | collate_fn=ds._collate_fn, 191 | shuffle=False, 192 | ) 193 | 194 | with tqdm(total=args.hours) as pbar: 195 | for batch in dl: 196 | skip_speaker = False 197 | for i, speaker in enumerate(batch["speaker_path"]): 198 | speaker_dvec = Path(str(speaker).replace("-b", "-a")) 199 | speaker = speaker.name 200 | if speaker_dvec not in model.speaker2dvector: 201 | #skip_speaker = True 202 | #print(f"The speaker {speaker} is not present in the d-vector collection!") 203 | #break 204 | speaker = list(model.speaker2dvector.keys())[0] 205 | 206 | # if hasattr(model, "speaker_gmms"): 207 | # if speaker not in model.speaker_gmms: 208 | # skip_speaker = True 209 | # print(f"The speaker {speaker} is not present in the GMM collection!") 210 | # break 211 | # else: 212 | # model.speaker_gmms = pickle.load(open("speaker_gmms.pkl", "rb")) 213 | # p_sample = model.speaker_gmms[speaker].sample()[0][0] 214 | # for h, p in enumerate(model.hparams.priors): 215 | # batch[f"priors_{p}"][i] = p_sample[h] 216 | # if hasattr(model, "dvector_gmms"): 217 | # dvec = model.dvector_gmms[speaker_dvec].sample()[0][0] 218 | # batch["speaker"][i] = torch.tensor(dvec) 219 | # else: 220 | # batch["speaker"][i] = torch.tensor(model.speaker2dvector[speaker_dvec]).to(model.device) 221 | if skip_speaker: 222 | continue 223 | results = generator.generate_samples( 224 | batch, 225 | return_original=True, 226 | return_duration=True, 227 | ) 228 | i = 0 229 | stop_loop = False 230 | for audio, speaker, id in zip(results["audios"], batch["speaker_key"], batch["id"]): 231 | if args.output_path is None: 232 | raise ValueError("No output path specified!") 233 | output_path = Path(args.output_path) / speaker 234 | output_path.mkdir(parents=True, exist_ok=True) 235 | generator.save_audio(audio, output_path / id) 236 | id_name = id.replace(".wav", "") 237 | generator.save_audio( 238 | 
results["original_audios"][i], 239 | output_path / f"{id_name}_original.wav", 240 | fs=model.hparams.sampling_rate, 241 | ) 242 | with open(output_path / f"{id_name}.meta", "wb") as f: 243 | pickle.dump({"phones": batch["phones"], "durations": results["durations"]}, f) 244 | with open(output_path / f"{id_name}.lab", "w", encoding="utf-8") as f: 245 | f.write(batch["text"][i]) 246 | pbar.update(audio.shape[0] / results["fs"] / 3600) 247 | if pbar.n >= args.hours: 248 | stop_loop = True 249 | break 250 | i += 1 251 | if stop_loop: 252 | break 253 | 254 | 255 | -------------------------------------------------------------------------------- /litfass/plot.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import pickle 3 | import matplotlib.pyplot as plt 4 | 5 | from tqdm.auto import tqdm 6 | from alignments.datasets.libritts import LibrittsDataset 7 | from dataset.datasets import TTSDataset 8 | 9 | parser = ArgumentParser() 10 | 11 | if __name__ == "__main__": 12 | # train_ud = UnprocessedDataset( 13 | # "../Data/LibriTTS/train-clean-100-aligned", 14 | # max_entries=10_000, 15 | # pitch_quality=0.25, 16 | # ) 17 | # train_ud.plot(1_000) 18 | # ds = TTSDataset( 19 | # LibrittsDataset(target_directory="../Data/LibriTTS/train-clean-100-aligned", chunk_size=10_000), 20 | # priors=[], 21 | # variances=[], 22 | # variance_transforms=["none", "none", "none"], 23 | # denoise=False, 24 | # ) 25 | # ds = TTSDataset( 26 | # LibrittsDataset(target_directory="../Data/LibriTTS/train-clean-360-aligned", chunk_size=10_000), 27 | # priors=[], 28 | # variances=[], 29 | # variance_transforms=["none", "none", "none"], 30 | # denoise=False, 31 | # ) 32 | ds = TTSDataset( 33 | LibrittsDataset(target_directory="../data/train-clean-a", chunk_size=10_000), 34 | priors=["pitch", "energy", "snr", "duration"], 35 | variances=["pitch", "energy", "snr"], 36 | variance_transforms=["none", "none", "none"], 37 | variance_levels=["phone", "phone", "phone"], 38 | denoise=False, 39 | overwrite_stats=True, 40 | ) 41 | min_len = float("inf") 42 | for i, item in tqdm(enumerate(ds), total=10): 43 | fig = ds.plot(i, show=False) 44 | fig.save(f"test{i}.png") 45 | if i > 10: 46 | break 47 | 48 | # print(train_ud[3398]) 49 | # train_ud.plot(3398) 50 | # train_ud.plot(3692) 51 | -------------------------------------------------------------------------------- /litfass/synthesis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiniXC/LightningFastSpeech2/23a83937c09d8cb2590d3b27212c7ac497e2a120/litfass/synthesis/__init__.py -------------------------------------------------------------------------------- /litfass/synthesis/g2p.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import unicodedata 3 | 4 | from g2p_en import G2p 5 | from phones.convert import Converter 6 | 7 | 8 | class G2P(ABC): 9 | def __init__(self, lexicon_path): 10 | self.lexicon_path = lexicon_path 11 | self.lexicon = self.load_lexicon() 12 | 13 | @abstractmethod 14 | def __call__(self, text): 15 | raise NotImplementedError 16 | 17 | @abstractmethod 18 | def load_lexicon(self): 19 | raise NotImplementedError 20 | 21 | 22 | class EnglishG2P(G2P): 23 | def __init__(self, lexicon_path=None): 24 | super().__init__(lexicon_path) 25 | self.g2p = G2p() 26 | self.converter = Converter() 27 | 28 | def __call__(self, text): 29 | text = 
unicodedata.normalize("NFKD", text) 30 | text = text.lower() 31 | words = text.split(" ") 32 | phones = [] 33 | for word in words: 34 | if word[-1] in [".", ",", "!", "?"]: 35 | punctuation = word[-1] 36 | word = word[:-1] 37 | else: 38 | punctuation = "" 39 | if word in self.lexicon: 40 | add_phones = self.lexicon[word] 41 | else: 42 | add_phones = self.g2p(word) 43 | for phone in add_phones: 44 | phone.replace("ˌ", "") 45 | phone = phone.replace("0", "").replace("1", "") 46 | phone = self.converter(phone, "arpabet", lang=None) 47 | phones += phone 48 | if punctuation != "": 49 | phones.append("[" + unicodedata.name(punctuation) + "]") 50 | else: 51 | phones.append("[SILENCE]") 52 | return phones 53 | 54 | def load_lexicon(self): 55 | lexicon = {} 56 | if self.lexicon_path is not None: 57 | with open(self.lexicon_path, "r", encoding="utf-8") as file: 58 | for line in file: 59 | line = line.strip() 60 | if len(line) == 0: 61 | continue 62 | word, phones = line.split("\t") 63 | phones = phones.split(" ") 64 | lexicon[word.lower()] = phones 65 | return lexicon 66 | -------------------------------------------------------------------------------- /litfass/synthesis/generator.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import random 3 | import shutil 4 | import multiprocessing 5 | import pickle 6 | from copy import deepcopy 7 | import random 8 | 9 | 10 | import torch 11 | import torchaudio 12 | import torchaudio.functional as F 13 | from torch.utils.data import DataLoader 14 | import numpy as np 15 | from tqdm.auto import tqdm 16 | from voicefixer import VoiceFixer 17 | 18 | 19 | from litfass.third_party.hifigan import Synthesiser 20 | from litfass.fastspeech2.fastspeech2 import FastSpeech2 21 | from litfass.synthesis.g2p import G2P 22 | 23 | 24 | def int16_samples_to_float32(y): 25 | """Convert int16 numpy array of audio samples to float32.""" 26 | if y.dtype != np.int16: 27 | if y.dtype == np.float32: 28 | return y 29 | elif isinstance(y, torch.Tensor): 30 | return y.numpy() 31 | else: 32 | raise ValueError(f"input samples not int16 or float32, but {y.dtype}") 33 | return y.astype(np.float32) / np.iinfo(np.int16).max 34 | 35 | 36 | class SpeechGenerator: 37 | def __init__( 38 | self, 39 | model: FastSpeech2, 40 | g2p_model: G2P, 41 | device: str = "cuda:0", 42 | synth_device: str = None, 43 | overwrite: bool = False, 44 | voicefixer: bool = True, 45 | sampling_path: str = None, 46 | augmentations=None, 47 | speaker_dict=None, 48 | fastdiff=False, 49 | fastdiff_n=4, 50 | ): 51 | if not fastdiff: 52 | if synth_device is None: 53 | self.synth = Synthesiser(device=device) 54 | else: 55 | self.synth = Synthesiser(device=synth_device) 56 | self.fastdiff = fastdiff 57 | self.fastdiff_n = fastdiff_n 58 | self.model = model 59 | self.model.eval() 60 | self.g2p = g2p_model 61 | self.device = device 62 | self.model.to(self.device) 63 | self.overwrite = overwrite 64 | self.sampling_path = sampling_path 65 | self.augmentations = augmentations 66 | self.speaker_dict = speaker_dict 67 | if voicefixer: 68 | self.voicefixer = VoiceFixer() 69 | else: 70 | self.voicefixer = None 71 | 72 | @property 73 | def speakers(self): 74 | if self.model.hparams.speaker_type == "dvector": 75 | return self.model.speaker2dvector.keys() 76 | elif self.model.hparams.speaker_type == "id": 77 | return self.model.speaker2id.keys() 78 | else: 79 | return None 80 | 81 | def save_audio(self, audio, path, fs=None): 82 | if fs is None: 83 | if self.voicefixer: 84 
| sampling_rate = 44100 85 | else: 86 | sampling_rate = self.model.hparams.sampling_rate 87 | else: 88 | sampling_rate = fs 89 | # make 2D if mono 90 | if len(audio.shape) == 1: 91 | audio = torch.tensor(audio).unsqueeze(0) 92 | else: 93 | audio = torch.tensor(audio) 94 | torchaudio.save(path, audio, sampling_rate) 95 | 96 | def generate_from_text(self, text, speaker=None, random_seed=None, prior_strategy="sample", prior_values=[-1, -1, -1, -1]): 97 | ids = [ 98 | self.model.phone2id[x] for x in self.g2p(text) if x in self.model.phone2id 99 | ] 100 | batch = {} 101 | speaker_name = None 102 | if self.model.hparams.speaker_type == "dvector": 103 | if speaker is None: 104 | while True: 105 | # pylint: disable=invalid-sequence-index 106 | speaker = list(self.model.speaker2dvector.keys())[ 107 | np.random.randint(len(self.model.speaker2dvector)) 108 | ] 109 | # pylint: enable=invalid-sequence-index 110 | # TODO: remove this when all models are fixed 111 | if len(self.model.hparams.priors) > 0: 112 | if isinstance(speaker, Path): 113 | speaker_name = speaker.name 114 | if speaker_name in self.model.speaker2priors: 115 | break 116 | else: 117 | speaker_name = speaker.name 118 | batch["speaker"] = torch.tensor([self.model.speaker2dvector[speaker]]).to( 119 | self.device 120 | ) 121 | print("Using speaker", speaker) 122 | if self.model.hparams.speaker_type == "id": 123 | batch["speaker"] = torch.tensor([self.model.speaker2id[speaker]]).to( 124 | self.device 125 | ) 126 | print("Using speaker", speaker) 127 | if len(self.model.hparams.priors) > 0: 128 | if speaker_name is None: 129 | speaker_name = speaker 130 | if random_seed is not None: 131 | np.random.seed(random_seed) 132 | if prior_strategy == "sample": 133 | priors = self.model.speaker2priors[speaker_name] 134 | prior_len = len(priors[self.model.hparams.priors[0]]) 135 | random_index = np.random.randint(prior_len) 136 | for prior in self.model.hparams.priors: 137 | batch[f"priors_{prior}"] = torch.tensor([priors[prior][random_index]]).to(self.device) 138 | print(f"Using prior {prior} with value {priors[prior][random_index]:.2f}") 139 | elif prior_strategy == "gmm": 140 | gmm = self.model.speaker_gmms[speaker_name] 141 | values = gmm.sample()[0][0] 142 | for i, prior in enumerate(self.model.hparams.priors): 143 | batch[f"priors_{prior}"] = torch.tensor([values[i]]).to(self.device) 144 | print(f"Using prior {prior} with value {values[i]:.2f}") 145 | batch["phones"] = torch.tensor([ids]).to(self.device) 146 | for i, prior in enumerate(self.model.hparams.priors): 147 | if prior_values[i] != -1: 148 | batch[f"priors_{prior}"] = torch.tensor([prior_values[i]]).to(self.device) 149 | print(f"Overriding prior {prior} with value {prior_values[i]:.2f}") 150 | return self.generate_samples(batch)[1][0] 151 | 152 | def generate_samples( 153 | self, 154 | batch, 155 | return_original=False, 156 | return_duration=False, 157 | ): 158 | result = self.model(batch, inference=True) 159 | fs = self.model.hparams.sampling_rate 160 | 161 | audios = [] 162 | durations = [] 163 | for i in range(len(result["mel"])): 164 | mel = result["mel"][i][~result["tgt_mask"][i]].cpu() 165 | durations.append(result["duration_rounded"][i].cpu()) 166 | if self.fastdiff: 167 | mel = mel + result["fastdiff_var"][i][~result["tgt_mask"][i]].cpu() 168 | pred_audio = self.model.fastdiff_model.inference(mel.to(self.device), N=self.fastdiff_n).cpu() 169 | else: 170 | pred_audio = self.synth(mel)[0] 171 | audios.append(int16_samples_to_float32(pred_audio)) 172 | 173 | if self.voicefixer is 
not None: 174 | fs_new = None 175 | fixed_audios = [] 176 | for i, audio in enumerate(audios): 177 | tmp_dir = Path("/tmp/voicefixer") 178 | tmp_dir.mkdir(exist_ok=True) 179 | tmp_hash = str(random.getrandbits(128)) 180 | if fs != 22050: 181 | audio = F.resample(torch.tensor(audio), fs, 22050) 182 | fs = 22050 183 | pad_width = int(fs * 0.1) 184 | audio = np.pad(audio, (pad_width, pad_width), constant_values=(0, 0)) 185 | torchaudio.save(tmp_dir / f"{tmp_hash}.wav", torch.tensor([audio]), fs) 186 | self.voicefixer.restore( 187 | input=tmp_dir / f"{tmp_hash}.wav", 188 | output=tmp_dir / f"{tmp_hash}_fixed.wav", 189 | cuda=True, 190 | mode=1, 191 | ) 192 | fixed_audio, fs_new = torchaudio.load(tmp_dir / f"{tmp_hash}_fixed.wav") 193 | # remove padding 194 | fixed_audio = fixed_audio[0].numpy()[pad_width:-pad_width] 195 | fixed_audios.append(fixed_audio) 196 | 197 | if self.augmentations is not None: 198 | audios = [ 199 | self.augmentations(audio, sample_rate=self.model.hparams.sampling_rate) 200 | for audio in audios 201 | ] 202 | 203 | if return_original and self.voicefixer is not None: 204 | result = { 205 | "original_fs": fs, 206 | "original_audios": audios, 207 | "fs": fs_new, 208 | "audios": fixed_audios, 209 | } 210 | elif self.voicefixer is not None: 211 | result = { 212 | "fs": fs_new, 213 | "audios": fixed_audios, 214 | } 215 | else: 216 | result = { 217 | "fs": fs, 218 | "audios": audios, 219 | } 220 | 221 | if return_duration: 222 | result["durations"] = durations 223 | 224 | return result 225 | -------------------------------------------------------------------------------- /litfass/third_party/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiniXC/LightningFastSpeech2/23a83937c09d8cb2590d3b27212c7ac497e2a120/litfass/third_party/__init__.py -------------------------------------------------------------------------------- /litfass/third_party/argutils/__init__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def str2bool(v): 4 | if isinstance(v, bool): 5 | return v 6 | if v.lower() in ("yes", "true", "t", "y", "1"): 7 | return True 8 | elif v.lower() in ("no", "false", "f", "n", "0"): 9 | return False 10 | else: 11 | raise argparse.ArgumentTypeError("Boolean value expected.") 12 | -------------------------------------------------------------------------------- /litfass/third_party/dvectors/dvector.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiniXC/LightningFastSpeech2/23a83937c09d8cb2590d3b27212c7ac497e2a120/litfass/third_party/dvectors/dvector.pt -------------------------------------------------------------------------------- /litfass/third_party/dvectors/wav2mel.py: -------------------------------------------------------------------------------- 1 | """Wav2Mel for processing audio data.""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchaudio.sox_effects import apply_effects_tensor 6 | from torchaudio.transforms import MelSpectrogram 7 | 8 | 9 | class Wav2Mel(nn.Module): 10 | """Transform audio file into mel spectrogram tensors.""" 11 | 12 | def __init__( 13 | self, 14 | sample_rate: int = 16000, 15 | norm_db: float = -3.0, 16 | sil_threshold: float = 1.0, 17 | sil_duration: float = 0.1, 18 | fft_window_ms: float = 25.0, 19 | fft_hop_ms: float = 10.0, 20 | f_min: float = 50.0, 21 | n_mels: int = 40, 22 | ): 23 | super().__init__() 24 | 25 | 
self.sample_rate = sample_rate 26 | self.norm_db = norm_db 27 | self.sil_threshold = sil_threshold 28 | self.sil_duration = sil_duration 29 | self.fft_window_ms = fft_window_ms 30 | self.fft_hop_ms = fft_hop_ms 31 | self.f_min = f_min 32 | self.n_mels = n_mels 33 | 34 | self.sox_effects = SoxEffects(sample_rate, norm_db, sil_threshold, sil_duration) 35 | self.log_melspectrogram = LogMelspectrogram( 36 | sample_rate, fft_window_ms, fft_hop_ms, f_min, n_mels 37 | ) 38 | 39 | def forward(self, wav_tensor: torch.Tensor, sample_rate: int) -> torch.Tensor: 40 | wav_tensor = self.sox_effects(wav_tensor, sample_rate) 41 | mel_tensor = self.log_melspectrogram(wav_tensor) 42 | return mel_tensor 43 | 44 | 45 | class SoxEffects(nn.Module): 46 | """Transform waveform tensors.""" 47 | 48 | def __init__( 49 | self, 50 | sample_rate: int, 51 | norm_db: float, 52 | sil_threshold: float, 53 | sil_duration: float, 54 | ): 55 | super().__init__() 56 | self.effects = [ 57 | ["channels", "1"], # convert to mono 58 | ["rate", f"{sample_rate}"], # resample 59 | ["norm", f"{norm_db}"], # normalize to -3 dB 60 | [ 61 | "silence", 62 | "1", 63 | f"{sil_duration}", 64 | f"{sil_threshold}%", 65 | "-1", 66 | f"{sil_duration}", 67 | f"{sil_threshold}%", 68 | ], # remove silence throughout the file 69 | ] 70 | 71 | def forward(self, wav_tensor: torch.Tensor, sample_rate: int) -> torch.Tensor: 72 | wav_tensor, _ = apply_effects_tensor(wav_tensor, sample_rate, self.effects) 73 | return wav_tensor 74 | 75 | 76 | class LogMelspectrogram(nn.Module): 77 | """Transform waveform tensors into log mel spectrogram tensors.""" 78 | 79 | def __init__( 80 | self, 81 | sample_rate: int, 82 | fft_window_ms: float, 83 | fft_hop_ms: float, 84 | f_min: float, 85 | n_mels: int, 86 | ): 87 | super().__init__() 88 | self.melspectrogram = MelSpectrogram( 89 | sample_rate=sample_rate, 90 | hop_length=int(sample_rate * fft_hop_ms / 1000), 91 | n_fft=int(sample_rate * fft_window_ms / 1000), 92 | f_min=f_min, 93 | n_mels=n_mels, 94 | ) 95 | 96 | def forward(self, wav_tensor: torch.Tensor) -> torch.Tensor: 97 | mel_tensor = self.melspectrogram(wav_tensor).squeeze(0).T # (time, n_mels) 98 | return torch.log(torch.clamp(mel_tensor, min=1e-9)) 99 | -------------------------------------------------------------------------------- /litfass/third_party/fastdiff/FastDiff.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import logging 4 | from .module.modules import DiffusionDBlock, TimeAware_LVCBlock 5 | from .module.util import calc_diffusion_step_embedding, std_normal, compute_hyperparams_given_schedule, sampling_given_noise_schedule 6 | from litfass.third_party.argutils import str2bool 7 | 8 | def swish(x): 9 | return x * torch.sigmoid(x) 10 | 11 | class FastDiff(nn.Module): 12 | """FastDiff module.""" 13 | 14 | def __init__( 15 | self, 16 | audio_channels=1, 17 | inner_channels=32, 18 | cond_channels=80, 19 | upsample_ratios=[8, 8, 4], 20 | lvc_layers_each_block=4, 21 | lvc_kernel_size=3, 22 | kpnet_hidden_channels=64, 23 | kpnet_conv_size=3, 24 | dropout=0.0, 25 | diffusion_step_embed_dim_in=128, 26 | diffusion_step_embed_dim_mid=512, 27 | diffusion_step_embed_dim_out=512, 28 | use_weight_norm=True, 29 | beta_0=1e-6, 30 | beta_T=0.01, 31 | T=1000, 32 | ): 33 | super().__init__() 34 | 35 | self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in 36 | 37 | self.audio_channels = audio_channels 38 | self.cond_channels = cond_channels 39 | self.lvc_block_nums = 
len(upsample_ratios) 40 | self.first_audio_conv = nn.Conv1d(1, inner_channels, 41 | kernel_size=7, padding=(7 - 1) // 2, 42 | dilation=1, bias=True) 43 | 44 | # define residual blocks 45 | self.lvc_blocks = nn.ModuleList() 46 | self.downsample = nn.ModuleList() 47 | 48 | # the layer-specific fc for noise scale embedding 49 | self.fc_t = nn.ModuleList() 50 | self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid) 51 | self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, diffusion_step_embed_dim_out) 52 | 53 | cond_hop_length = 1 54 | for n in range(self.lvc_block_nums): 55 | cond_hop_length = cond_hop_length * upsample_ratios[n] 56 | lvcb = TimeAware_LVCBlock( 57 | in_channels=inner_channels, 58 | cond_channels=cond_channels, 59 | upsample_ratio=upsample_ratios[n], 60 | conv_layers=lvc_layers_each_block, 61 | conv_kernel_size=lvc_kernel_size, 62 | cond_hop_length=cond_hop_length, 63 | kpnet_hidden_channels=kpnet_hidden_channels, 64 | kpnet_conv_size=kpnet_conv_size, 65 | kpnet_dropout=dropout, 66 | noise_scale_embed_dim_out=diffusion_step_embed_dim_out 67 | ) 68 | self.lvc_blocks += [lvcb] 69 | self.downsample.append(DiffusionDBlock(inner_channels, inner_channels, upsample_ratios[self.lvc_block_nums-n-1])) 70 | 71 | 72 | # define output layers 73 | self.final_conv = nn.Sequential( 74 | nn.Conv1d( 75 | inner_channels, 76 | audio_channels, 77 | kernel_size=7, 78 | padding=(7 - 1) // 2, 79 | dilation=1, 80 | bias=True 81 | ) 82 | ) 83 | 84 | # apply weight norm 85 | if use_weight_norm: 86 | self.apply_weight_norm() 87 | 88 | self.noise_schedule = torch.linspace(beta_0, beta_T, T) 89 | self.diffusion_hyperparams = compute_hyperparams_given_schedule(self.noise_schedule) 90 | 91 | def forward(self, x, c, ts=None, reverse=False, mask=None): 92 | """Calculate forward propagation. 93 | Args: 94 | x (Tensor): Input noise signal (B, 1, T). 95 | c (Tensor): Local conditioning auxiliary features (B, C ,T'). 
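ts (Tensor, optional): Diffusion step indices (B, 1). If None, steps are sampled
    at random and x is noised internally via q(x_t | x_0) (the training path).
reverse (bool): If True, additionally return the denoised estimate x0 recovered
    from the predicted noise.
mask (Tensor, optional): Padding mask; masked positions of the output are zeroed.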
96 | Returns: 97 | Tensor: Output tensor (B, out_channels, T) 98 | """ 99 | 100 | if len(x.shape) == 2: 101 | B, L = x.shape # B is batchsize, C=1, L is audio length 102 | x = x.unsqueeze(1) 103 | if len(c.shape) == 2: 104 | c = c.unsqueeze(0) 105 | B, C, L = c.shape # B is batchsize, C=80, L is audio length 106 | 107 | if ts is None: 108 | no_ts = True 109 | T, alpha = self.diffusion_hyperparams["T"], self.diffusion_hyperparams["alpha"].to(x.device) 110 | ts = torch.randint(T, size=(B, 1, 1)).to(x.device) # randomly sample steps from 1~T 111 | z = std_normal(x.shape, device=x.device) 112 | delta = (1 - alpha[ts] ** 2.).sqrt() 113 | alpha_cur = alpha[ts] 114 | noisy_audio = alpha_cur * x + delta * z # compute x_t from q(x_t|x_0) 115 | x = noisy_audio 116 | ts = ts.view(B, 1) 117 | else: 118 | no_ts = False 119 | 120 | # embed diffusion step t 121 | diffusion_step_embed = calc_diffusion_step_embedding(ts, self.diffusion_step_embed_dim_in, device=x.device) 122 | diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed)) 123 | diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed)) 124 | 125 | x = self.first_audio_conv(x) 126 | downsample = [] 127 | for down_layer in self.downsample: 128 | downsample.append(x) 129 | x = down_layer(x) 130 | 131 | for n, audio_down in enumerate(reversed(downsample)): 132 | x = self.lvc_blocks[n]((x, audio_down, c, diffusion_step_embed)) 133 | 134 | # apply final layers 135 | x = self.final_conv(x) 136 | 137 | if mask is not None: 138 | x = x.masked_fill(mask, 0) 139 | 140 | if not reverse: 141 | if no_ts: 142 | return x, z 143 | else: 144 | return x 145 | else: 146 | x0 = (noisy_audio - delta * x) / alpha_cur 147 | return (x, z), x0 148 | 149 | def inference(self, c, N=4, hop_size=256): 150 | """Inference with the given local conditioning auxiliary features. 151 | Args: 152 | c (Tensor): Local conditioning auxiliary features (B, C, T'). 
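N (int): Number of reverse diffusion steps; must be 3, 4, 6, 8, 200 or 1000.
hop_size (int): Hop length used to map mel frames to output audio samples.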
153 | Returns: 154 | Tensor: Output tensor (B, out_channels, T) 155 | """ 156 | c = c.transpose(0, 1) 157 | 158 | reverse_step = N 159 | if reverse_step == 1000: 160 | noise_schedule = torch.linspace(0.000001, 0.01, 1000) 161 | elif reverse_step == 200: 162 | noise_schedule = torch.linspace(0.0001, 0.02, 200) 163 | elif reverse_step == 8: 164 | noise_schedule = [6.689325005027058e-07, 1.0033881153503899e-05, 0.00015496854030061513, 165 | 0.002387222135439515, 0.035597629845142365, 0.3681158423423767, 0.4735414385795593, 0.5] 166 | elif reverse_step == 6: 167 | noise_schedule = [1.7838445955931093e-06, 2.7984189728158526e-05, 0.00043231004383414984, 168 | 0.006634317338466644, 0.09357017278671265, 0.6000000238418579] 169 | elif reverse_step == 4: 170 | noise_schedule = [3.2176e-04, 2.5743e-03, 2.5376e-02, 7.0414e-01] 171 | elif reverse_step == 3: 172 | noise_schedule = [9.0000e-05, 9.0000e-03, 6.0000e-01] 173 | else: 174 | raise ValueError("Reverse step should be 3, 4, 6, 8, 200 or 1000.") 175 | 176 | if not isinstance(noise_schedule, torch.Tensor): 177 | noise_schedule = torch.FloatTensor(noise_schedule) 178 | noise_schedule = noise_schedule.to(c.device) 179 | 180 | audio_length = c.shape[-1] * hop_size 181 | 182 | pred_wav = sampling_given_noise_schedule( 183 | self, 184 | (1, 1, audio_length), 185 | self.diffusion_hyperparams, 186 | noise_schedule, 187 | condition=c, 188 | ddim=False, 189 | return_sequence=False, 190 | device=c.device 191 | ) 192 | 193 | pred_wav = pred_wav / pred_wav.abs().max() 194 | pred_wav = pred_wav.view(-1) 195 | return pred_wav 196 | 197 | def remove_weight_norm(self): 198 | """Remove weight normalization module from all of the layers.""" 199 | def _remove_weight_norm(m): 200 | try: 201 | logging.debug(f"Weight norm is removed from {m}.") 202 | torch.nn.utils.remove_weight_norm(m) 203 | except ValueError: # this module didn't have weight norm 204 | return 205 | 206 | self.apply(_remove_weight_norm) 207 | 208 | def apply_weight_norm(self): 209 | """Apply weight normalization module from all of the layers.""" 210 | def _apply_weight_norm(m): 211 | if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d): 212 | torch.nn.utils.weight_norm(m) 213 | logging.debug(f"Weight norm is applied to {m}.") 214 | 215 | self.apply(_apply_weight_norm) 216 | 217 | @staticmethod 218 | def add_model_specific_args(parent_parser): 219 | """Add model specific arguments.""" 220 | parser = parent_parser.add_argument_group("FastDiff model setting") 221 | # network structure related 222 | parser.add_argument("--fastdiff_audio_channels", default=1, type=int, 223 | help="Number of audio channels") 224 | parser.add_argument("--fastdiff_inner_channels", default=32, type=int, 225 | help="Number of inner channels") 226 | parser.add_argument("--fastdiff_cond_channels", default=80, type=int, 227 | help="Number of conditional channels") 228 | parser.add_argument("--fastdiff_upsample_ratios", default=[8, 8, 4], type=int, nargs="+", 229 | help="Upsampling ratios") 230 | parser.add_argument("--fastdiff_lvc_layers_each_block", default=4, type=int, 231 | help="Number of layers in each LVC block") 232 | parser.add_argument("--fastdiff_lvc_kernel_size", default=3, type=int, 233 | help="Kernel size in each LVC block") 234 | parser.add_argument("--fastdiff_kpnet_hidden_channels", default=64, type=int, 235 | help="Number of hidden channels in keypoint network") 236 | parser.add_argument("--fastdiff_kpnet_conv_size", default=3, type=int, 237 | help="Kernel size in keypoint network") 238 | 
parser.add_argument("--fastdiff_dropout", default=0.0, type=float, 239 | help="Dropout rate") 240 | parser.add_argument("--fastdiff_diffusion_step_embed_dim_in", default=128, type=int, 241 | help="Dimension of diffusion step embedding") 242 | parser.add_argument("--fastdiff_diffusion_step_embed_dim_mid", default=512, type=int, 243 | help="Dimension of diffusion step embedding") 244 | parser.add_argument("--fastdiff_diffusion_step_embed_dim_out", default=512, type=int, 245 | help="Dimension of diffusion step embedding") 246 | parser.add_argument("--fastdiff_use_weight_norm", default=True, type=str2bool, 247 | help="Whether to use weight normalization") 248 | # training related 249 | parser.add_argument("--fastdiff_beta_0", default=1e-6, type=float, 250 | help="Initial noise scale") 251 | parser.add_argument("--fastdiff_beta_T", default=0.01, type=float, 252 | help="Final noise scale") 253 | parser.add_argument("--fastdiff_T", default=1000, type=int, 254 | help="Number of diffusion steps") 255 | return parent_parser -------------------------------------------------------------------------------- /litfass/third_party/fastdiff/module/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from torch.nn import Conv1d 8 | 9 | LRELU_SLOPE = 0.1 10 | 11 | 12 | 13 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 14 | ''' Sinusoid position encoding table ''' 15 | 16 | def cal_angle(position, hid_idx): 17 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 18 | 19 | def get_posi_angle_vec(position): 20 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 21 | 22 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) 23 | for pos_i in range(n_position)]) 24 | 25 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 26 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 27 | 28 | if padding_idx is not None: 29 | # zero vector for padding dimension 30 | sinusoid_table[padding_idx] = 0. 31 | 32 | return torch.FloatTensor(sinusoid_table) 33 | 34 | 35 | def overlap_and_add(signal, frame_step): 36 | """Reconstructs a signal from a framed representation. 37 | 38 | Adds potentially overlapping frames of a signal with shape 39 | `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`. 40 | The resulting tensor has shape `[..., output_size]` where 41 | 42 | output_size = (frames - 1) * frame_step + frame_length 43 | 44 | Args: 45 | signal: A [..., frames, frame_length] Tensor. All dimensions may be unknown, and rank must be at least 2. 46 | frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length. 47 | 48 | Returns: 49 | A Tensor with shape [..., output_size] containing the overlap-added frames of signal's inner-most two dimensions. 
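        For example (illustrative numbers only), 3 frames of length 400 with frame_step=200
        overlap-add to (3 - 1) * 200 + 400 = 800 output samples.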
50 | output_size = (frames - 1) * frame_step + frame_length 51 | 52 | Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py 53 | """ 54 | outer_dimensions = signal.size()[:-2] 55 | frames, frame_length = signal.size()[-2:] 56 | 57 | # gcd=Greatest Common Divisor 58 | subframe_length = math.gcd(frame_length, frame_step) 59 | subframe_step = frame_step // subframe_length 60 | subframes_per_frame = frame_length // subframe_length 61 | output_size = frame_step * (frames - 1) + frame_length 62 | output_subframes = output_size // subframe_length 63 | 64 | subframe_signal = signal.view(*outer_dimensions, -1, subframe_length) 65 | 66 | frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame, subframe_step) 67 | frame = signal.new_tensor(frame).long() # signal may in GPU or CPU 68 | frame = frame.contiguous().view(-1) 69 | 70 | result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length) 71 | device_of_result = result.device 72 | result.index_add_(-2, frame.to(device_of_result), subframe_signal) 73 | result = result.view(*outer_dimensions, -1) 74 | return result 75 | 76 | 77 | class LastLayer(nn.Module): 78 | def __init__(self, in_channels, out_channels, 79 | nonlinear_activation, nonlinear_activation_params, 80 | pad, kernel_size, pad_params, bias): 81 | super(LastLayer, self).__init__() 82 | self.activation = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params) 83 | self.pad = getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params) 84 | self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size, bias=bias) 85 | 86 | def forward(self, x): 87 | x = self.activation(x) 88 | x = self.pad(x) 89 | x = self.conv(x) 90 | return x 91 | 92 | 93 | class WeightConv1d(Conv1d): 94 | """Conv1d module with customized initialization.""" 95 | 96 | def __init__(self, *args, **kwargs): 97 | """Initialize Conv1d module.""" 98 | super(Conv1d, self).__init__(*args, **kwargs) 99 | 100 | def reset_parameters(self): 101 | """Reset parameters.""" 102 | torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu") 103 | if self.bias is not None: 104 | torch.nn.init.constant_(self.bias, 0.0) 105 | 106 | 107 | class Conv1d1x1(Conv1d): 108 | """1x1 Conv1d with customized initialization.""" 109 | 110 | def __init__(self, in_channels, out_channels, bias): 111 | """Initialize 1x1 Conv1d module.""" 112 | super(Conv1d1x1, self).__init__(in_channels, out_channels, 113 | kernel_size=1, padding=0, 114 | dilation=1, bias=bias) 115 | 116 | class DiffusionDBlock(nn.Module): 117 | def __init__(self, input_size, hidden_size, factor): 118 | super().__init__() 119 | self.factor = factor 120 | self.residual_dense = Conv1d(input_size, hidden_size, 1) 121 | self.conv = nn.ModuleList([ 122 | Conv1d(input_size, hidden_size, 3, dilation=1, padding=1), 123 | Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2), 124 | Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4), 125 | ]) 126 | 127 | def forward(self, x): 128 | size = x.shape[-1] // self.factor 129 | 130 | residual = self.residual_dense(x) 131 | residual = F.interpolate(residual, size=size) 132 | 133 | x = F.interpolate(x, size=size) 134 | for layer in self.conv: 135 | x = F.leaky_relu(x, 0.2) 136 | x = layer(x) 137 | 138 | return x + residual 139 | 140 | 141 | class TimeAware_LVCBlock(torch.nn.Module): 142 | ''' time-aware location-variable convolutions 143 | ''' 144 | def __init__(self, 145 | in_channels, 146 | cond_channels, 147 | upsample_ratio, 
148 | conv_layers=4, 149 | conv_kernel_size=3, 150 | cond_hop_length=256, 151 | kpnet_hidden_channels=64, 152 | kpnet_conv_size=3, 153 | kpnet_dropout=0.0, 154 | noise_scale_embed_dim_out=512 155 | ): 156 | super().__init__() 157 | 158 | self.cond_hop_length = cond_hop_length 159 | self.conv_layers = conv_layers 160 | self.conv_kernel_size = conv_kernel_size 161 | self.convs = torch.nn.ModuleList() 162 | 163 | self.upsample = torch.nn.ConvTranspose1d(in_channels, in_channels, 164 | kernel_size=upsample_ratio*2, stride=upsample_ratio, 165 | padding=upsample_ratio // 2 + upsample_ratio % 2, 166 | output_padding=upsample_ratio % 2) 167 | 168 | 169 | self.kernel_predictor = KernelPredictor( 170 | cond_channels=cond_channels, 171 | conv_in_channels=in_channels, 172 | conv_out_channels=2 * in_channels, 173 | conv_layers=conv_layers, 174 | conv_kernel_size=conv_kernel_size, 175 | kpnet_hidden_channels=kpnet_hidden_channels, 176 | kpnet_conv_size=kpnet_conv_size, 177 | kpnet_dropout=kpnet_dropout 178 | ) 179 | 180 | # the layer-specific fc for noise scale embedding 181 | self.fc_t = torch.nn.Linear(noise_scale_embed_dim_out, cond_channels) 182 | 183 | for i in range(conv_layers): 184 | padding = (3 ** i) * int((conv_kernel_size - 1) / 2) 185 | conv = torch.nn.Conv1d(in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3 ** i) 186 | 187 | self.convs.append(conv) 188 | 189 | 190 | def forward(self, data): 191 | ''' forward propagation of the time-aware location-variable convolutions. 192 | Args: 193 | x (Tensor): the input sequence (batch, in_channels, in_length) 194 | c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) 195 | 196 | Returns: 197 | Tensor: the output sequence (batch, in_channels, in_length) 198 | ''' 199 | x, audio_down, c, noise_embedding = data 200 | batch, in_channels, in_length = x.shape 201 | 202 | noise = (self.fc_t(noise_embedding)).unsqueeze(-1) # (B, 80) 203 | condition = c + noise # (B, 80, T) 204 | kernels, bias = self.kernel_predictor(condition) 205 | x = F.leaky_relu(x, 0.2) 206 | x = self.upsample(x) 207 | 208 | for i in range(self.conv_layers): 209 | x += audio_down 210 | y = F.leaky_relu(x, 0.2) 211 | y = self.convs[i](y) 212 | y = F.leaky_relu(y, 0.2) 213 | 214 | k = kernels[:, i, :, :, :, :] 215 | b = bias[:, i, :, :] 216 | y = self.location_variable_convolution(y, k, b, 1, self.cond_hop_length) 217 | x = x + torch.sigmoid(y[:, :in_channels, :]) * torch.tanh(y[:, in_channels:, :]) 218 | return x 219 | 220 | def location_variable_convolution(self, x, kernel, bias, dilation, hop_size): 221 | ''' perform location-variable convolution operation on the input sequence (x) using the local convolution kernl. 222 | Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100. 223 | Args: 224 | x (Tensor): the input sequence (batch, in_channels, in_length). 225 | kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length) 226 | bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length) 227 | dilation (int): the dilation of convolution. 228 | hop_size (int): the hop_size of the conditioning sequence. 229 | Returns: 230 | (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length). 
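        Illustrative shape note: with hop_size=256 and kernel_length=32, the input must satisfy
        in_length = 32 * 256 = 8192, and each of the 32 predicted kernels is applied to its own
        hop-sized segment of x (plus padding context).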
231 | ''' 232 | batch, in_channels, in_length = x.shape 233 | batch, in_channels, out_channels, kernel_size, kernel_length = kernel.shape 234 | 235 | 236 | assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched" 237 | 238 | padding = dilation * int((kernel_size - 1) / 2) 239 | x = F.pad(x, (padding, padding), 'constant', 0) # (batch, in_channels, in_length + 2*padding) 240 | x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding) 241 | 242 | if hop_size < dilation: 243 | x = F.pad(x, (0, dilation), 'constant', 0) 244 | x = x.unfold(3, dilation, 245 | dilation) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) 246 | x = x[:, :, :, :, :hop_size] 247 | x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) 248 | x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size) 249 | 250 | o = torch.einsum('bildsk,biokl->bolsd', x, kernel) 251 | o = o + bias.unsqueeze(-1).unsqueeze(-1) 252 | o = o.contiguous().view(batch, out_channels, -1) 253 | return o 254 | 255 | 256 | 257 | class KernelPredictor(torch.nn.Module): 258 | ''' Kernel predictor for the time-aware location-variable convolutions 259 | ''' 260 | 261 | def __init__(self, 262 | cond_channels, 263 | conv_in_channels, 264 | conv_out_channels, 265 | conv_layers, 266 | conv_kernel_size=3, 267 | kpnet_hidden_channels=64, 268 | kpnet_conv_size=3, 269 | kpnet_dropout=0.0, 270 | kpnet_nonlinear_activation="LeakyReLU", 271 | kpnet_nonlinear_activation_params={"negative_slope": 0.1} 272 | ): 273 | ''' 274 | Args: 275 | cond_channels (int): number of channel for the conditioning sequence, 276 | conv_in_channels (int): number of channel for the input sequence, 277 | conv_out_channels (int): number of channel for the output sequence, 278 | conv_layers (int): 279 | kpnet_ 280 | ''' 281 | super().__init__() 282 | 283 | self.conv_in_channels = conv_in_channels 284 | self.conv_out_channels = conv_out_channels 285 | self.conv_kernel_size = conv_kernel_size 286 | self.conv_layers = conv_layers 287 | 288 | l_w = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers 289 | l_b = conv_out_channels * conv_layers 290 | 291 | padding = (kpnet_conv_size - 1) // 2 292 | self.input_conv = torch.nn.Sequential( 293 | torch.nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=(5 - 1) // 2, bias=True), 294 | getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), 295 | ) 296 | 297 | self.residual_conv = torch.nn.Sequential( 298 | torch.nn.Dropout(kpnet_dropout), 299 | torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), 300 | getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), 301 | torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), 302 | getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), 303 | torch.nn.Dropout(kpnet_dropout), 304 | torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), 305 | getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), 306 | torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), 307 | getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), 308 | 
torch.nn.Dropout(kpnet_dropout), 309 | torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), 310 | getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), 311 | torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), 312 | getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), 313 | ) 314 | 315 | self.kernel_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_w, kpnet_conv_size, 316 | padding=padding, bias=True) 317 | self.bias_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_b, kpnet_conv_size, padding=padding, 318 | bias=True) 319 | 320 | def forward(self, c): 321 | ''' 322 | Args: 323 | c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) 324 | Returns: 325 | ''' 326 | batch, cond_channels, cond_length = c.shape 327 | 328 | c = self.input_conv(c) 329 | c = c + self.residual_conv(c) 330 | k = self.kernel_conv(c) 331 | b = self.bias_conv(c) 332 | 333 | kernels = k.contiguous().view(batch, 334 | self.conv_layers, 335 | self.conv_in_channels, 336 | self.conv_out_channels, 337 | self.conv_kernel_size, 338 | cond_length) 339 | bias = b.contiguous().view(batch, 340 | self.conv_layers, 341 | self.conv_out_channels, 342 | cond_length) 343 | return kernels, bias 344 | -------------------------------------------------------------------------------- /litfass/third_party/fastdiff/module/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import copy 6 | 7 | def flatten(v): 8 | """ 9 | Flatten a list of lists/tuples 10 | """ 11 | 12 | return [x for y in v for x in y] 13 | 14 | 15 | def rescale(x): 16 | """ 17 | Rescale a tensor to 0-1 18 | """ 19 | 20 | return (x - x.min()) / (x.max() - x.min()) 21 | 22 | 23 | def find_max_epoch(path): 24 | """ 25 | Find maximum epoch/iteration in path, formatted ${n_iter}.pkl 26 | E.g. 100000.pkl 27 | 28 | Parameters: 29 | path (str): checkpoint path 30 | 31 | Returns: 32 | maximum iteration, -1 if there is no (valid) checkpoint 33 | """ 34 | 35 | files = os.listdir(path) 36 | epoch = -1 37 | for f in files: 38 | if len(f) <= 4: 39 | continue 40 | if f[-4:] == '.pkl': 41 | try: 42 | epoch = max(epoch, int(f[:-4])) 43 | except: 44 | continue 45 | #print(path, epoch, flush=True) 46 | return epoch 47 | 48 | 49 | def print_size(net): 50 | """ 51 | Print the number of parameters of a network 52 | """ 53 | 54 | if net is not None and isinstance(net, torch.nn.Module): 55 | module_parameters = filter(lambda p: p.requires_grad, net.parameters()) 56 | params = sum([np.prod(p.size()) for p in module_parameters]) 57 | print("{} Parameters: {:.6f}M".format( 58 | net.__class__.__name__, params / 1e6), flush=True) 59 | 60 | 61 | # Utilities for diffusion models 62 | 63 | def std_normal(size, device="cuda:0"): 64 | """ 65 | Generate the standard Gaussian variable of a certain size 66 | """ 67 | 68 | return torch.normal(0, 1, size=size).to(device) 69 | 70 | 71 | def calc_noise_scale_embedding(noise_scales, noise_scale_embed_dim_in, device="cuda:0"): 72 | """ 73 | Embed a noise scale $t$ into a higher dimensional space 74 | E.g. the embedding vector in the 128-dimensional space is 75 | [sin(t * 10^(0*4/63)), ... , sin(t * 10^(63*4/63)), cos(t * 10^(0*4/63)), ... 
, cos(t * 10^(63*4/63))] 76 | 77 | Parameters: 78 | noise_scales (torch.long tensor, shape=(batchsize, 1)): 79 | noise scales for batch data 80 | noise_scale_embed_dim_in (int, default=128): 81 | dimensionality of the embedding space for discrete noise scales 82 | 83 | Returns: 84 | the embedding vectors (torch.tensor, shape=(batchsize, noise_scale_embed_dim_in)): 85 | """ 86 | 87 | assert noise_scale_embed_dim_in % 2 == 0 88 | 89 | half_dim = noise_scale_embed_dim_in // 2 90 | _embed = np.log(10000) / (half_dim - 1) 91 | _embed = torch.exp(torch.arange(half_dim) * -_embed).to(device) 92 | _embed = noise_scales * _embed 93 | noise_scale_embed = torch.cat((torch.sin(_embed), 94 | torch.cos(_embed)), 1) 95 | 96 | return noise_scale_embed 97 | 98 | 99 | def calc_diffusion_hyperparams_given_beta(beta): 100 | """ 101 | Compute diffusion process hyperparameters 102 | 103 | Parameters: 104 | beta (tensor): beta schedule 105 | 106 | Returns: 107 | a dictionary of diffusion hyperparameters including: 108 | T (int), beta/alpha/sigma (torch.tensor on cpu, shape=(T, )) 109 | These cpu tensors are changed to cuda tensors on each individual gpu 110 | """ 111 | 112 | T = len(beta) 113 | alpha = 1 - beta 114 | sigma = beta + 0 115 | for t in range(1, T): 116 | alpha[t] *= alpha[t-1] # \alpha^2_t = \prod_{s=1}^t (1-\beta_s) 117 | sigma[t] *= (1-alpha[t-1]) / (1-alpha[t]) # \sigma^2_t = \beta_t * (1-\alpha_{t-1}) / (1-\alpha_t) 118 | alpha = torch.sqrt(alpha) 119 | sigma = torch.sqrt(sigma) 120 | 121 | _dh = {} 122 | _dh["T"], _dh["beta"], _dh["alpha"], _dh["sigma"] = T, beta, alpha, sigma 123 | diffusion_hyperparams = _dh 124 | return diffusion_hyperparams 125 | 126 | 127 | def calc_diffusion_hyperparams(T, beta_0, beta_T, tau, N, beta_N, alpha_N, rho): 128 | """ 129 | Compute diffusion process hyperparameters 130 | 131 | Parameters: 132 | T (int): number of noise scales 133 | beta_0 and beta_T (float): beta schedule start/end value, 134 | where any beta_t in the middle is linearly interpolated 135 | 136 | Returns: 137 | a dictionary of diffusion hyperparameters including: 138 | T (int), beta/alpha/sigma (torch.tensor on cpu, shape=(T, )) 139 | These cpu tensors are changed to cuda tensors on each individual gpu 140 | """ 141 | 142 | beta = torch.linspace(beta_0, beta_T, T) 143 | alpha = 1 - beta 144 | sigma = beta + 0 145 | for t in range(1, T): 146 | alpha[t] *= alpha[t-1] # \alpha^2_t = \prod_{s=1}^t (1-\beta_s) 147 | sigma[t] *= (1-alpha[t-1]) / (1-alpha[t]) # \sigma^2_t = \beta_t * (1-\alpha_{t-1}) / (1-\alpha_t) 148 | alpha = torch.sqrt(alpha) 149 | sigma = torch.sqrt(sigma) 150 | 151 | _dh = {} 152 | _dh["T"], _dh["beta"], _dh["alpha"], _dh["sigma"] = T, beta, alpha, sigma 153 | _dh["tau"], _dh["N"], _dh["betaN"], _dh["alphaN"], _dh["rho"] = tau, N, beta_N, alpha_N, rho 154 | diffusion_hyperparams = _dh 155 | return diffusion_hyperparams 156 | 157 | 158 | def sampling_given_noise_schedule( 159 | net, 160 | size, 161 | diffusion_hyperparams, 162 | inference_noise_schedule, 163 | condition=None, 164 | ddim=False, 165 | return_sequence=False, 166 | device="cuda:0" 167 | ): 168 | """ 169 | Perform the complete sampling step according to p(x_0|x_T) = \prod_{t=1}^T p_{\theta}(x_{t-1}|x_t) 170 | Parameters: 171 | net (torch network): the wavenet models 172 | size (tuple): size of tensor to be generated, 173 | usually is (number of audios to generate, channels=1, length of audio) 174 | diffusion_hyperparams (dict): dictionary of diffusion hyperparameters returned by calc_diffusion_hyperparams 175 | note, 
the tensors need to be cuda tensors 176 | condition (torch.tensor): ground truth mel spectrogram read from disk 177 | None if used for unconditional generation 178 | Returns: 179 | the generated audio(s) in torch.tensor, shape=size 180 | """ 181 | 182 | _dh = diffusion_hyperparams 183 | T, alpha = _dh["T"], _dh["alpha"] 184 | assert len(alpha) == T 185 | assert len(size) == 3 or len(size) == 2 186 | 187 | N = len(inference_noise_schedule) 188 | beta_infer = inference_noise_schedule 189 | alpha_infer = 1 - beta_infer 190 | sigma_infer = beta_infer + 0 191 | for n in range(1, N): 192 | alpha_infer[n] *= alpha_infer[n - 1] 193 | sigma_infer[n] *= (1 - alpha_infer[n - 1]) / (1 - alpha_infer[n]) 194 | alpha_infer = torch.sqrt(alpha_infer) 195 | sigma_infer = torch.sqrt(sigma_infer) 196 | 197 | # Mapping noise scales to time steps 198 | steps_infer = [] 199 | for n in range(N): 200 | step = map_noise_scale_to_time_step(alpha_infer[n], alpha) 201 | if step >= 0: 202 | steps_infer.append(step) 203 | # print(steps_infer, flush=True) 204 | steps_infer = torch.FloatTensor(steps_infer) 205 | 206 | # N may change since alpha_infer can be out of the range of alpha 207 | N = len(steps_infer) 208 | 209 | #print('begin sampling, total number of reverse steps = %s' % N) 210 | 211 | x = std_normal(size, device=device) 212 | if return_sequence: 213 | x_ = copy.deepcopy(x) 214 | xs = [x_] 215 | with torch.no_grad(): 216 | for n in range(N - 1, -1, -1): 217 | diffusion_steps = (steps_infer[n] * torch.ones((size[0], 1))).to(device) 218 | # print(x.shape, condition.shape, diffusion_steps.shape) 219 | # raise 220 | epsilon_theta = net(x, condition, diffusion_steps) #net((x, condition, diffusion_steps,)) 221 | if ddim: 222 | alpha_next = alpha_infer[n] / (1 - beta_infer[n]).sqrt() 223 | c1 = alpha_next / alpha_infer[n] 224 | c2 = -(1 - alpha_infer[n] ** 2.).sqrt() * c1 225 | c3 = (1 - alpha_next ** 2.).sqrt() 226 | x = c1 * x + c2 * epsilon_theta + c3 * epsilon_theta # std_normal(size) 227 | else: 228 | x -= beta_infer[n] / torch.sqrt(1 - alpha_infer[n] ** 2.) 
* epsilon_theta 229 | x /= torch.sqrt(1 - beta_infer[n]) 230 | if n > 0: 231 | x = x + sigma_infer[n] * std_normal(size, device=device) 232 | if return_sequence: 233 | x_ = copy.deepcopy(x) 234 | xs.append(x_) 235 | if return_sequence: 236 | return xs 237 | return x 238 | 239 | def theta_timestep_loss(net, X, diffusion_hyperparams, reverse=False, device="cuda:0"): 240 | """ 241 | Compute the training loss for learning theta 242 | 243 | Parameters: 244 | net (torch network): the wavenet models 245 | X (tuple, shape=(2,)): training data in tuple form (mel_spectrograms, audios) 246 | mel_spectrograms: torch.tensor, shape is batchsize followed by each mel_spectrogram shape 247 | audios: torch.tensor, shape=(batchsize, 1, length of audio) 248 | diffusion_hyperparams (dict): dictionary of diffusion hyperparameters returned by calc_diffusion_hyperparams 249 | note, the tensors need to be cuda tensors 250 | 251 | Returns: 252 | theta loss 253 | """ 254 | assert type(X) == tuple and len(X) == 2 255 | loss_fn = nn.MSELoss() 256 | 257 | _dh = diffusion_hyperparams 258 | T, alpha = _dh["T"], _dh["alpha"] 259 | 260 | mel_spectrogram, audio = X 261 | B, C, L = audio.shape # B is batchsize, C=1, L is audio length 262 | ts = torch.randint(T, size=(B, 1, 1)).to(device) # randomly sample steps from 1~T 263 | z = std_normal(audio.shape, device=device) 264 | delta = (1 - alpha[ts] ** 2.).sqrt() 265 | alpha_cur = alpha[ts] 266 | noisy_audio = alpha_cur * audio + delta * z # compute x_t from q(x_t|x_0) 267 | epsilon_theta = net((noisy_audio, mel_spectrogram, ts.view(B, 1),)) 268 | 269 | if reverse: 270 | x0 = (noisy_audio - delta * epsilon_theta) / alpha_cur 271 | return loss_fn(epsilon_theta, z), x0 272 | 273 | return loss_fn(epsilon_theta, z) 274 | 275 | 276 | def compute_hyperparams_given_schedule(beta): 277 | """ 278 | Compute diffusion process hyperparameters 279 | 280 | Parameters: 281 | beta (tensor): beta schedule 282 | 283 | Returns: 284 | a dictionary of diffusion hyperparameters including: 285 | T (int), beta/alpha/sigma (torch.tensor on cpu, shape=(T, )) 286 | These cpu tensors are changed to cuda tensors on each individual gpu 287 | """ 288 | 289 | T = len(beta) 290 | alpha = 1 - beta 291 | sigma = beta + 0 292 | for t in range(1, T): 293 | alpha[t] *= alpha[t - 1] # \alpha^2_t = \prod_{s=1}^t (1-\beta_s) 294 | sigma[t] *= (1 - alpha[t - 1]) / (1 - alpha[t]) # \sigma^2_t = \beta_t * (1-\alpha_{t-1}) / (1-\alpha_t) 295 | alpha = torch.sqrt(alpha) 296 | sigma = torch.sqrt(sigma) 297 | 298 | _dh = {} 299 | _dh["T"], _dh["beta"], _dh["alpha"], _dh["sigma"] = T, beta, alpha, sigma 300 | diffusion_hyperparams = _dh 301 | return diffusion_hyperparams 302 | 303 | 304 | 305 | def map_noise_scale_to_time_step(alpha_infer, alpha): 306 | if alpha_infer < alpha[-1]: 307 | return len(alpha) - 1 308 | if alpha_infer > alpha[0]: 309 | return 0 310 | for t in range(len(alpha) - 1): 311 | if alpha[t+1] <= alpha_infer <= alpha[t]: 312 | step_diff = alpha[t] - alpha_infer 313 | step_diff /= alpha[t] - alpha[t+1] 314 | return t + step_diff.item() 315 | return -1 316 | 317 | 318 | def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in, device="cuda:0"): 319 | """ 320 | Embed a diffusion step $t$ into a higher dimensional space 321 | E.g. the embedding vector in the 128-dimensional space is 322 | [sin(t * 10^(0*4/63)), ... , sin(t * 10^(63*4/63)), cos(t * 10^(0*4/63)), ... 
, cos(t * 10^(63*4/63))] 323 | 324 | Parameters: 325 | diffusion_steps (torch.long tensor, shape=(batchsize, 1)): 326 | diffusion steps for batch data 327 | diffusion_step_embed_dim_in (int, default=128): 328 | dimensionality of the embedding space for discrete diffusion steps 329 | 330 | Returns: 331 | the embedding vectors (torch.tensor, shape=(batchsize, diffusion_step_embed_dim_in)): 332 | """ 333 | 334 | assert diffusion_step_embed_dim_in % 2 == 0 335 | 336 | half_dim = diffusion_step_embed_dim_in // 2 337 | _embed = np.log(10000) / (half_dim - 1) 338 | _embed = torch.exp(torch.arange(half_dim) * -_embed).to(device) 339 | _embed = diffusion_steps * _embed 340 | diffusion_step_embed = torch.cat((torch.sin(_embed), 341 | torch.cos(_embed)), 1) 342 | 343 | return diffusion_step_embed -------------------------------------------------------------------------------- /litfass/third_party/hifigan/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jungil Kong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /litfass/third_party/hifigan/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from pathlib import Path 4 | import os 5 | 6 | from .models import Generator 7 | 8 | 9 | class AttrDict(dict): 10 | def __init__(self, *args, **kwargs): 11 | super(AttrDict, self).__init__(*args, **kwargs) 12 | self.__dict__ = self 13 | 14 | 15 | # TODO: decompress dynamic range 16 | 17 | 18 | class Synthesiser: 19 | def __init__( 20 | self, 21 | device="cuda:0", 22 | model="universal", 23 | ): 24 | with open(Path(__file__).parent / "config.json", "r") as f: 25 | config = json.load(f) 26 | config = AttrDict(config) 27 | vocoder = Generator(config) 28 | ckpt = torch.load(Path(__file__).parent / f"generator_{model}.pth.tar") 29 | vocoder.load_state_dict(ckpt["generator"]) 30 | vocoder.eval() 31 | vocoder.remove_weight_norm() 32 | self.device = device 33 | vocoder.to(self.device) 34 | self.vocoder = vocoder 35 | 36 | def __call__(self, mel): 37 | mel = torch.unsqueeze(mel.T, 0) 38 | result = ( 39 | self.vocoder(mel.to(self.device)).squeeze(1).cpu().detach().numpy() 40 | * 32768.0 41 | ).astype("int16") 42 | return result 43 | -------------------------------------------------------------------------------- /litfass/third_party/hifigan/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,2,2], 12 | "upsample_kernel_sizes": [16,16,4,4], 13 | "upsample_initial_channel": 512, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "segment_size": 8192, 18 | "num_mels": 80, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 256, 22 | "win_size": 1024, 23 | "sampling_rate": 22050, 24 | 25 | "fmin": 0, 26 | "fmax": 8000, 27 | "fmax_for_loss": null, 28 | 29 | "num_workers": 4, 30 | 31 | "dist_config": { 32 | "dist_backend": "nccl", 33 | "dist_url": "tcp://localhost:54321", 34 | "world_size": 1 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /litfass/third_party/hifigan/generator_LJSpeech.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiniXC/LightningFastSpeech2/23a83937c09d8cb2590d3b27212c7ac497e2a120/litfass/third_party/hifigan/generator_LJSpeech.pth.tar -------------------------------------------------------------------------------- /litfass/third_party/hifigan/generator_universal.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiniXC/LightningFastSpeech2/23a83937c09d8cb2590d3b27212c7ac497e2a120/litfass/third_party/hifigan/generator_universal.pth.tar -------------------------------------------------------------------------------- /litfass/third_party/hifigan/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Conv1d, ConvTranspose1d 5 | from torch.nn.utils import weight_norm, remove_weight_norm 6 | 7 | LRELU_SLOPE = 0.1 8 | 9 | 10 | def init_weights(m, mean=0.0, std=0.01): 11 | classname = m.__class__.__name__ 
12 | if classname.find("Conv") != -1: 13 | m.weight.data.normal_(mean, std) 14 | 15 | 16 | def get_padding(kernel_size, dilation=1): 17 | return int((kernel_size * dilation - dilation) / 2) 18 | 19 | 20 | class ResBlock(torch.nn.Module): 21 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 22 | super(ResBlock, self).__init__() 23 | self.h = h 24 | self.convs1 = nn.ModuleList( 25 | [ 26 | weight_norm( 27 | Conv1d( 28 | channels, 29 | channels, 30 | kernel_size, 31 | 1, 32 | dilation=dilation[0], 33 | padding=get_padding(kernel_size, dilation[0]), 34 | ) 35 | ), 36 | weight_norm( 37 | Conv1d( 38 | channels, 39 | channels, 40 | kernel_size, 41 | 1, 42 | dilation=dilation[1], 43 | padding=get_padding(kernel_size, dilation[1]), 44 | ) 45 | ), 46 | weight_norm( 47 | Conv1d( 48 | channels, 49 | channels, 50 | kernel_size, 51 | 1, 52 | dilation=dilation[2], 53 | padding=get_padding(kernel_size, dilation[2]), 54 | ) 55 | ), 56 | ] 57 | ) 58 | self.convs1.apply(init_weights) 59 | 60 | self.convs2 = nn.ModuleList( 61 | [ 62 | weight_norm( 63 | Conv1d( 64 | channels, 65 | channels, 66 | kernel_size, 67 | 1, 68 | dilation=1, 69 | padding=get_padding(kernel_size, 1), 70 | ) 71 | ), 72 | weight_norm( 73 | Conv1d( 74 | channels, 75 | channels, 76 | kernel_size, 77 | 1, 78 | dilation=1, 79 | padding=get_padding(kernel_size, 1), 80 | ) 81 | ), 82 | weight_norm( 83 | Conv1d( 84 | channels, 85 | channels, 86 | kernel_size, 87 | 1, 88 | dilation=1, 89 | padding=get_padding(kernel_size, 1), 90 | ) 91 | ), 92 | ] 93 | ) 94 | self.convs2.apply(init_weights) 95 | 96 | def forward(self, x): 97 | for c1, c2 in zip(self.convs1, self.convs2): 98 | xt = F.leaky_relu(x, LRELU_SLOPE) 99 | xt = c1(xt) 100 | xt = F.leaky_relu(xt, LRELU_SLOPE) 101 | xt = c2(xt) 102 | x = xt + x 103 | return x 104 | 105 | def remove_weight_norm(self): 106 | for l in self.convs1: 107 | remove_weight_norm(l) 108 | for l in self.convs2: 109 | remove_weight_norm(l) 110 | 111 | 112 | class Generator(torch.nn.Module): 113 | def __init__(self, h): 114 | super(Generator, self).__init__() 115 | self.h = h 116 | self.num_kernels = len(h.resblock_kernel_sizes) 117 | self.num_upsamples = len(h.upsample_rates) 118 | self.conv_pre = weight_norm( 119 | Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3) 120 | ) 121 | resblock = ResBlock 122 | 123 | self.ups = nn.ModuleList() 124 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 125 | self.ups.append( 126 | weight_norm( 127 | ConvTranspose1d( 128 | h.upsample_initial_channel // (2**i), 129 | h.upsample_initial_channel // (2 ** (i + 1)), 130 | k, 131 | u, 132 | padding=(k - u) // 2, 133 | ) 134 | ) 135 | ) 136 | 137 | self.resblocks = nn.ModuleList() 138 | for i in range(len(self.ups)): 139 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 140 | for j, (k, d) in enumerate( 141 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) 142 | ): 143 | self.resblocks.append(resblock(h, ch, k, d)) 144 | 145 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 146 | self.ups.apply(init_weights) 147 | self.conv_post.apply(init_weights) 148 | 149 | def forward(self, x): 150 | x = self.conv_pre(x) 151 | for i in range(self.num_upsamples): 152 | x = F.leaky_relu(x, LRELU_SLOPE) 153 | x = self.ups[i](x) 154 | xs = None 155 | for j in range(self.num_kernels): 156 | if xs is None: 157 | xs = self.resblocks[i * self.num_kernels + j](x) 158 | else: 159 | xs += self.resblocks[i * self.num_kernels + j](x) 160 | x = xs / self.num_kernels 161 | x = F.leaky_relu(x) 
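        # NOTE: this final activation uses F.leaky_relu's default slope rather than LRELU_SLOPE,
        # matching the reference HiFi-GAN generator.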
162 | x = self.conv_post(x) 163 | x = torch.tanh(x) 164 | 165 | return x 166 | 167 | def remove_weight_norm(self): 168 | print("Removing weight norm...") 169 | for l in self.ups: 170 | remove_weight_norm(l) 171 | for l in self.resblocks: 172 | l.remove_weight_norm() 173 | remove_weight_norm(self.conv_pre) 174 | remove_weight_norm(self.conv_post) 175 | -------------------------------------------------------------------------------- /litfass/third_party/softdtw/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from numba import jit 4 | from torch.autograd import Function 5 | 6 | 7 | @jit(nopython=True) 8 | def compute_softdtw(D, gamma): 9 | B = D.shape[0] 10 | N = D.shape[1] 11 | M = D.shape[2] 12 | R = np.ones((B, N + 2, M + 2)) * np.inf 13 | R[:, 0, 0] = 0 14 | for k in range(B): 15 | for j in range(1, M + 1): 16 | for i in range(1, N + 1): 17 | r0 = -R[k, i - 1, j - 1] / gamma 18 | r1 = -R[k, i - 1, j] / gamma 19 | r2 = -R[k, i, j - 1] / gamma 20 | rmax = max(max(r0, r1), r2) 21 | rsum = np.exp(r0 - rmax) + np.exp(r1 - rmax) + np.exp(r2 - rmax) 22 | softmin = -gamma * (np.log(rsum) + rmax) 23 | R[k, i, j] = D[k, i - 1, j - 1] + softmin 24 | return R 25 | 26 | 27 | @jit(nopython=True) 28 | def compute_softdtw_backward(D_, R, gamma): 29 | B = D_.shape[0] 30 | N = D_.shape[1] 31 | M = D_.shape[2] 32 | D = np.zeros((B, N + 2, M + 2)) 33 | E = np.zeros((B, N + 2, M + 2)) 34 | D[:, 1 : N + 1, 1 : M + 1] = D_ 35 | E[:, -1, -1] = 1 36 | R[:, :, -1] = -np.inf 37 | R[:, -1, :] = -np.inf 38 | R[:, -1, -1] = R[:, -2, -2] 39 | for k in range(B): 40 | for j in range(M, 0, -1): 41 | for i in range(N, 0, -1): 42 | a0 = (R[k, i + 1, j] - R[k, i, j] - D[k, i + 1, j]) / gamma 43 | b0 = (R[k, i, j + 1] - R[k, i, j] - D[k, i, j + 1]) / gamma 44 | c0 = (R[k, i + 1, j + 1] - R[k, i, j] - D[k, i + 1, j + 1]) / gamma 45 | a = np.exp(a0) 46 | b = np.exp(b0) 47 | c = np.exp(c0) 48 | E[k, i, j] = ( 49 | E[k, i + 1, j] * a + E[k, i, j + 1] * b + E[k, i + 1, j + 1] * c 50 | ) 51 | return E[:, 1 : N + 1, 1 : M + 1] 52 | 53 | 54 | class _SoftDTW(Function): 55 | @staticmethod 56 | def forward(ctx, D, gamma): 57 | dev = D.device 58 | dtype = D.dtype 59 | gamma = torch.Tensor([gamma]).to(dev).type(dtype) # dtype fixed 60 | D_ = D.detach().cpu().numpy() 61 | g_ = gamma.item() 62 | R = torch.Tensor(compute_softdtw(D_, g_)).to(dev).type(dtype) 63 | ctx.save_for_backward(D, R, gamma) 64 | return R[:, -2, -2] 65 | 66 | @staticmethod 67 | def backward(ctx, grad_output): 68 | dev = grad_output.device 69 | dtype = grad_output.dtype 70 | D, R, gamma = ctx.saved_tensors 71 | D_ = D.detach().cpu().numpy() 72 | R_ = R.detach().cpu().numpy() 73 | g_ = gamma.item() 74 | E = torch.Tensor(compute_softdtw_backward(D_, R_, g_)).to(dev).type(dtype) 75 | return grad_output.view(-1, 1, 1).expand_as(E) * E, None 76 | 77 | 78 | class SoftDTW(torch.nn.Module): 79 | def __init__(self, gamma=1.0, normalize=False): 80 | super(SoftDTW, self).__init__() 81 | self.normalize = normalize 82 | self.gamma = gamma 83 | self.func_dtw = _SoftDTW.apply 84 | 85 | def calc_distance_matrix(self, x, y): 86 | n = x.size(1) 87 | m = y.size(1) 88 | d = x.size(2) 89 | x = x.unsqueeze(2).expand(-1, n, m, d) 90 | y = y.unsqueeze(1).expand(-1, n, m, d) 91 | dist = torch.pow(x - y, 2).sum(3) 92 | return dist 93 | 94 | def forward(self, x, y): 95 | assert len(x.shape) == len(y.shape) 96 | squeeze = False 97 | if len(x.shape) < 3: 98 | x = x.unsqueeze(0) 99 | y = y.unsqueeze(0) 100 | 
squeeze = True 101 | if self.normalize: 102 | D_xy = self.calc_distance_matrix(x, y) 103 | out_xy = self.func_dtw(D_xy, self.gamma) 104 | D_xx = self.calc_distance_matrix(x, x) 105 | out_xx = self.func_dtw(D_xx, self.gamma) 106 | D_yy = self.calc_distance_matrix(y, y) 107 | out_yy = self.func_dtw(D_yy, self.gamma) 108 | result = out_xy - 1 / 2 * (out_xx + out_yy) # distance 109 | else: 110 | D_xy = self.calc_distance_matrix(x, y) 111 | out_xy = self.func_dtw(D_xy, self.gamma) 112 | result = out_xy # discrepancy 113 | return result.squeeze(0) if squeeze else result 114 | -------------------------------------------------------------------------------- /litfass/third_party/stochastic_duration_predictor/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LayerNorm2(nn.Module): 6 | """Layer norm for the 2nd dimension of the input using torch primitive. 7 | Args: 8 | channels (int): number of channels (2nd dimension) of the input. 9 | eps (float): to prevent 0 division 10 | Shapes: 11 | - input: (B, C, T) 12 | - output: (B, C, T) 13 | """ 14 | 15 | def __init__(self, channels, eps=1e-5): 16 | super().__init__() 17 | self.channels = channels 18 | self.eps = eps 19 | 20 | self.gamma = nn.Parameter(torch.ones(channels)) 21 | self.beta = nn.Parameter(torch.zeros(channels)) 22 | 23 | def forward(self, x): 24 | x = x.transpose(1, -1) 25 | x = torch.nn.functional.layer_norm( 26 | x, (self.channels,), self.gamma, self.beta, self.eps 27 | ) 28 | return x.transpose(1, -1) 29 | -------------------------------------------------------------------------------- /litfass/third_party/stochastic_duration_predictor/sdp.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from .normalization import LayerNorm2 8 | from .transforms import piecewise_rational_quadratic_transform 9 | 10 | 11 | class DilatedDepthSeparableConv(nn.Module): 12 | def __init__( 13 | self, channels, kernel_size, num_layers, dropout_p=0.0 14 | ) -> torch.tensor: 15 | """Dilated Depth-wise Separable Convolution module. 16 | 17 | :: 18 | x |-> DDSConv(x) -> LayerNorm(x) -> GeLU(x) -> Conv1x1(x) -> LayerNorm(x) -> GeLU(x) -> + -> o 19 | |-------------------------------------------------------------------------------------^ 20 | 21 | Args: 22 | channels ([type]): [description] 23 | kernel_size ([type]): [description] 24 | num_layers ([type]): [description] 25 | dropout_p (float, optional): [description]. Defaults to 0.0. 26 | 27 | Returns: 28 | torch.tensor: Network output masked by the input sequence mask. 
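        Illustrative note: with kernel_size=3 and num_layers=3, the per-layer dilations grow as
        kernel_size**i = 1, 3, 9, so the receptive field expands geometrically with depth.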
29 | """ 30 | super().__init__() 31 | self.num_layers = num_layers 32 | 33 | self.convs_sep = nn.ModuleList() 34 | self.convs_1x1 = nn.ModuleList() 35 | self.norms_1 = nn.ModuleList() 36 | self.norms_2 = nn.ModuleList() 37 | for i in range(num_layers): 38 | dilation = kernel_size**i 39 | padding = (kernel_size * dilation - dilation) // 2 40 | self.convs_sep.append( 41 | nn.Conv1d( 42 | channels, 43 | channels, 44 | kernel_size, 45 | groups=channels, 46 | dilation=dilation, 47 | padding=padding, 48 | ) 49 | ) 50 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) 51 | self.norms_1.append(LayerNorm2(channels)) 52 | self.norms_2.append(LayerNorm2(channels)) 53 | self.dropout = nn.Dropout(dropout_p) 54 | 55 | def forward(self, x, x_mask, g=None): 56 | """ 57 | Shapes: 58 | - x: :math:`[B, C, T]` 59 | - x_mask: :math:`[B, 1, T]` 60 | """ 61 | if g is not None: 62 | x = x + g 63 | for i in range(self.num_layers): 64 | y = self.convs_sep[i](x * x_mask) 65 | y = self.norms_1[i](y) 66 | y = F.gelu(y) 67 | y = self.convs_1x1[i](y) 68 | y = self.norms_2[i](y) 69 | y = F.gelu(y) 70 | y = self.dropout(y) 71 | x = x + y 72 | return x * x_mask 73 | 74 | 75 | class ElementwiseAffine(nn.Module): 76 | """Element-wise affine transform like no-population stats BatchNorm alternative. 77 | 78 | Args: 79 | channels (int): Number of input tensor channels. 80 | """ 81 | 82 | def __init__(self, channels): 83 | super().__init__() 84 | self.translation = nn.Parameter(torch.zeros(channels, 1)) 85 | self.log_scale = nn.Parameter(torch.zeros(channels, 1)) 86 | 87 | def forward( 88 | self, x, x_mask, reverse=False, **kwargs 89 | ): # pylint: disable=unused-argument 90 | if not reverse: 91 | y = (x * torch.exp(self.log_scale) + self.translation) * x_mask 92 | logdet = torch.sum(self.log_scale * x_mask, [1, 2]) 93 | return y, logdet 94 | x = (x - self.translation) * torch.exp(-self.log_scale) * x_mask 95 | return x 96 | 97 | 98 | class ConvFlow(nn.Module): 99 | """Dilated depth separable convolutional based spline flow. 100 | 101 | Args: 102 | in_channels (int): Number of input tensor channels. 103 | hidden_channels (int): Number of in network channels. 104 | kernel_size (int): Convolutional kernel size. 105 | num_layers (int): Number of convolutional layers. 106 | num_bins (int, optional): Number of spline bins. Defaults to 10. 107 | tail_bound (float, optional): Tail bound for PRQT. Defaults to 5.0. 108 | """ 109 | 110 | def __init__( 111 | self, 112 | in_channels: int, 113 | hidden_channels: int, 114 | kernel_size: int, 115 | num_layers: int, 116 | num_bins=10, 117 | tail_bound=5.0, 118 | ): 119 | super().__init__() 120 | self.num_bins = num_bins 121 | self.tail_bound = tail_bound 122 | self.hidden_channels = hidden_channels 123 | self.half_channels = in_channels // 2 124 | 125 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) 126 | self.convs = DilatedDepthSeparableConv( 127 | hidden_channels, kernel_size, num_layers, dropout_p=0.0 128 | ) 129 | self.proj = nn.Conv1d( 130 | hidden_channels, self.half_channels * (num_bins * 3 - 1), 1 131 | ) 132 | self.proj.weight.data.zero_() 133 | self.proj.bias.data.zero_() 134 | 135 | def forward(self, x, x_mask, g=None, reverse=False): 136 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1) 137 | h = self.pre(x0) 138 | h = self.convs(h, x_mask, g=g) 139 | h = self.proj(h) * x_mask 140 | 141 | b, c, t = x0.shape 142 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 
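        # The projection produced (num_bins * 3 - 1) values per half-channel and time step; below they
        # are split into num_bins bin widths, num_bins bin heights and (num_bins - 1) knot derivatives
        # for the piecewise rational-quadratic spline transform.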
143 | 144 | unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.hidden_channels) 145 | unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt( 146 | self.hidden_channels 147 | ) 148 | unnormalized_derivatives = h[..., 2 * self.num_bins :] 149 | 150 | x1, logabsdet = piecewise_rational_quadratic_transform( 151 | x1, 152 | unnormalized_widths, 153 | unnormalized_heights, 154 | unnormalized_derivatives, 155 | inverse=reverse, 156 | tails="linear", 157 | tail_bound=self.tail_bound, 158 | ) 159 | 160 | x = torch.cat([x0, x1], 1) * x_mask 161 | logdet = torch.sum(logabsdet * x_mask, [1, 2]) 162 | if not reverse: 163 | return x, logdet 164 | return x 165 | 166 | 167 | class StochasticDurationPredictor(nn.Module): 168 | """Stochastic duration predictor with Spline Flows. 169 | 170 | It applies Variational Dequantization and Variationsl Data Augmentation. 171 | 172 | Paper: 173 | SDP: https://arxiv.org/pdf/2106.06103.pdf 174 | Spline Flow: https://arxiv.org/abs/1906.04032 175 | 176 | :: 177 | ## Inference 178 | 179 | x -> TextCondEncoder() -> Flow() -> dr_hat 180 | noise ----------------------^ 181 | 182 | ## Training 183 | |---------------------| 184 | x -> TextCondEncoder() -> + -> PosteriorEncoder() -> split() -> z_u, z_v -> (d - z_u) -> concat() -> Flow() -> noise 185 | d -> DurCondEncoder() -> ^ | 186 | |------------------------------------------------------------------------------| 187 | 188 | Args: 189 | in_channels (int): Number of input tensor channels. 190 | hidden_channels (int): Number of hidden channels. 191 | kernel_size (int): Kernel size of convolutional layers. 192 | dropout_p (float): Dropout rate. 193 | num_flows (int, optional): Number of flow blocks. Defaults to 4. 194 | cond_channels (int, optional): Number of channels of conditioning tensor. Defaults to 0. 
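        language_emb_dim (int, optional): Number of language embedding channels. Defaults to 0.

    Example (an illustrative sketch; the channel sizes below are hypothetical):
        >>> sdp = StochasticDurationPredictor(
        ...     in_channels=256, hidden_channels=256, kernel_size=3, dropout_p=0.5)
        >>> # training: sdp(x, x_mask, dr=durations) returns the duration negative log-likelihood
        >>> # inference: sdp(x, x_mask, reverse=True, noise_scale=1.0) samples log-durations from noise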
195 | """ 196 | 197 | def __init__( 198 | self, 199 | in_channels: int, 200 | hidden_channels: int, 201 | kernel_size: int, 202 | dropout_p: float, 203 | num_flows=4, 204 | cond_channels=0, 205 | language_emb_dim=0, 206 | ): 207 | super().__init__() 208 | 209 | # add language embedding dim in the input 210 | if language_emb_dim: 211 | in_channels += language_emb_dim 212 | 213 | # condition encoder text 214 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1) 215 | self.convs = DilatedDepthSeparableConv( 216 | hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p 217 | ) 218 | self.proj = nn.Conv1d(hidden_channels, hidden_channels, 1) 219 | 220 | # posterior encoder 221 | self.flows = nn.ModuleList() 222 | self.flows.append(ElementwiseAffine(2)) 223 | self.flows += [ 224 | ConvFlow(2, hidden_channels, kernel_size, num_layers=3) 225 | for _ in range(num_flows) 226 | ] 227 | 228 | # condition encoder duration 229 | self.post_pre = nn.Conv1d(1, hidden_channels, 1) 230 | self.post_convs = DilatedDepthSeparableConv( 231 | hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p 232 | ) 233 | self.post_proj = nn.Conv1d(hidden_channels, hidden_channels, 1) 234 | 235 | # flow layers 236 | self.post_flows = nn.ModuleList() 237 | self.post_flows.append(ElementwiseAffine(2)) 238 | self.post_flows += [ 239 | ConvFlow(2, hidden_channels, kernel_size, num_layers=3) 240 | for _ in range(num_flows) 241 | ] 242 | 243 | if cond_channels != 0 and cond_channels is not None: 244 | self.cond = nn.Conv1d(cond_channels, hidden_channels, 1) 245 | 246 | if language_emb_dim != 0 and language_emb_dim is not None: 247 | self.cond_lang = nn.Conv1d(language_emb_dim, hidden_channels, 1) 248 | 249 | def forward( 250 | self, x, x_mask, dr=None, g=None, lang_emb=None, reverse=False, noise_scale=1.0 251 | ): 252 | """ 253 | Shapes: 254 | - x: :math:`[B, C, T]` 255 | - x_mask: :math:`[B, 1, T]` 256 | - dr: :math:`[B, 1, T]` 257 | - g: :math:`[B, C]` 258 | """ 259 | 260 | # condition encoder text 261 | x = x.transpose(1, 2) 262 | x_mask = (~x_mask).unsqueeze(1).int() 263 | if dr is not None: 264 | dr = dr.unsqueeze(1).to(x.dtype) 265 | x = self.pre(x) 266 | if g is not None: 267 | g = g.transpose(1, 2) 268 | x = x + self.cond(g) 269 | 270 | if lang_emb is not None: 271 | x = x + self.cond_lang(lang_emb) 272 | 273 | x = self.convs(x, x_mask) 274 | x = self.proj(x) * x_mask 275 | 276 | if not reverse: 277 | flows = self.flows 278 | assert dr is not None 279 | 280 | # condition encoder duration 281 | h = self.post_pre(dr) 282 | h = self.post_convs(h, x_mask) 283 | h = self.post_proj(h) * x_mask 284 | noise = ( 285 | torch.randn(dr.size(0), 2, dr.size(2)).to( 286 | device=x.device, dtype=x.dtype 287 | ) 288 | * x_mask 289 | ) 290 | z_q = noise 291 | 292 | # posterior encoder 293 | logdet_tot_q = 0.0 294 | for idx, flow in enumerate(self.post_flows): 295 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h)) 296 | logdet_tot_q = logdet_tot_q + logdet_q 297 | if idx > 0: 298 | z_q = torch.flip(z_q, [1]) 299 | 300 | z_u, z_v = torch.split(z_q, [1, 1], 1) 301 | u = torch.sigmoid(z_u) * x_mask 302 | z0 = (dr - u) * x_mask 303 | 304 | # posterior encoder - neg log likelihood 305 | logdet_tot_q += torch.sum( 306 | (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2] 307 | ) 308 | nll_posterior_encoder = ( 309 | torch.sum( 310 | -0.5 * (math.log(2 * math.pi) + (noise**2)) * x_mask, [1, 2] 311 | ) 312 | - logdet_tot_q 313 | ) 314 | 315 | z0 = torch.log(torch.clamp_min(z0, 1e-5)) * x_mask 316 | logdet_tot = torch.sum(-z0, 
[1, 2]) 317 | z = torch.cat([z0, z_v], 1) 318 | 319 | # flow layers 320 | for idx, flow in enumerate(flows): 321 | z, logdet = flow(z, x_mask, g=x, reverse=reverse) 322 | logdet_tot = logdet_tot + logdet 323 | if idx > 0: 324 | z = torch.flip(z, [1]) 325 | 326 | # flow layers - neg log likelihood 327 | nll_flow_layers = ( 328 | torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) 329 | - logdet_tot 330 | ) 331 | return nll_flow_layers + nll_posterior_encoder 332 | 333 | flows = list(reversed(self.flows)) 334 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow 335 | z = ( 336 | torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) 337 | * noise_scale 338 | ) 339 | for flow in flows: 340 | z = torch.flip(z, [1]) 341 | z = flow(z, x_mask, g=x, reverse=reverse) 342 | 343 | z0, _ = torch.split(z, [1, 1], 1) 344 | 345 | logw = z0 346 | if not reverse: 347 | return logw 348 | else: 349 | return logw.transpose(1, 2).squeeze(2) 350 | -------------------------------------------------------------------------------- /litfass/third_party/stochastic_duration_predictor/transforms.py: -------------------------------------------------------------------------------- 1 | # adopted from https://github.com/bayesiains/nflows 2 | 3 | import numpy as np 4 | import torch 5 | from torch.nn import functional as F 6 | 7 | DEFAULT_MIN_BIN_WIDTH = 1e-3 8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 9 | DEFAULT_MIN_DERIVATIVE = 1e-3 10 | 11 | 12 | def piecewise_rational_quadratic_transform( 13 | inputs, 14 | unnormalized_widths, 15 | unnormalized_heights, 16 | unnormalized_derivatives, 17 | inverse=False, 18 | tails=None, 19 | tail_bound=1.0, 20 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 21 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 22 | min_derivative=DEFAULT_MIN_DERIVATIVE, 23 | ): 24 | 25 | if tails is None: 26 | spline_fn = rational_quadratic_spline 27 | spline_kwargs = {} 28 | else: 29 | spline_fn = unconstrained_rational_quadratic_spline 30 | spline_kwargs = {"tails": tails, "tail_bound": tail_bound} 31 | 32 | outputs, logabsdet = spline_fn( 33 | inputs=inputs, 34 | unnormalized_widths=unnormalized_widths, 35 | unnormalized_heights=unnormalized_heights, 36 | unnormalized_derivatives=unnormalized_derivatives, 37 | inverse=inverse, 38 | min_bin_width=min_bin_width, 39 | min_bin_height=min_bin_height, 40 | min_derivative=min_derivative, 41 | **spline_kwargs, 42 | ) 43 | return outputs, logabsdet 44 | 45 | 46 | def searchsorted(bin_locations, inputs, eps=1e-6): 47 | bin_locations[..., -1] += eps 48 | return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 49 | 50 | 51 | def unconstrained_rational_quadratic_spline( 52 | inputs, 53 | unnormalized_widths, 54 | unnormalized_heights, 55 | unnormalized_derivatives, 56 | inverse=False, 57 | tails="linear", 58 | tail_bound=1.0, 59 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 60 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 61 | min_derivative=DEFAULT_MIN_DERIVATIVE, 62 | ): 63 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 64 | outside_interval_mask = ~inside_interval_mask 65 | 66 | outputs = torch.zeros_like(inputs) 67 | logabsdet = torch.zeros_like(inputs) 68 | 69 | if tails == "linear": 70 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 71 | constant = np.log(np.exp(1 - min_derivative) - 1) 72 | unnormalized_derivatives[..., 0] = constant 73 | unnormalized_derivatives[..., -1] = constant 74 | 75 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 76 | logabsdet[outside_interval_mask] = 0 77 
| else: 78 | raise RuntimeError("{} tails are not implemented.".format(tails)) 79 | 80 | out, logabs = rational_quadratic_spline( 81 | inputs=inputs[inside_interval_mask], 82 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 83 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 84 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 85 | inverse=inverse, 86 | left=-tail_bound, 87 | right=tail_bound, 88 | bottom=-tail_bound, 89 | top=tail_bound, 90 | min_bin_width=min_bin_width, 91 | min_bin_height=min_bin_height, 92 | min_derivative=min_derivative, 93 | ) 94 | 95 | out = out.to(inputs.dtype) 96 | logabs = logabs.to(inputs.dtype) 97 | 98 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = out, logabs 99 | 100 | return outputs, logabsdet 101 | 102 | 103 | def rational_quadratic_spline( 104 | inputs, 105 | unnormalized_widths, 106 | unnormalized_heights, 107 | unnormalized_derivatives, 108 | inverse=False, 109 | left=0.0, 110 | right=1.0, 111 | bottom=0.0, 112 | top=1.0, 113 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 114 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 115 | min_derivative=DEFAULT_MIN_DERIVATIVE, 116 | ): 117 | if torch.min(inputs) < left or torch.max(inputs) > right: 118 | raise ValueError("Input to a transform is not within its domain") 119 | 120 | num_bins = unnormalized_widths.shape[-1] 121 | 122 | if min_bin_width * num_bins > 1.0: 123 | raise ValueError("Minimal bin width too large for the number of bins") 124 | if min_bin_height * num_bins > 1.0: 125 | raise ValueError("Minimal bin height too large for the number of bins") 126 | 127 | widths = F.softmax(unnormalized_widths, dim=-1) 128 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 129 | cumwidths = torch.cumsum(widths, dim=-1) 130 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) 131 | cumwidths = (right - left) * cumwidths + left 132 | cumwidths[..., 0] = left 133 | cumwidths[..., -1] = right 134 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 135 | 136 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 137 | 138 | heights = F.softmax(unnormalized_heights, dim=-1) 139 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 140 | cumheights = torch.cumsum(heights, dim=-1) 141 | cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) 142 | cumheights = (top - bottom) * cumheights + bottom 143 | cumheights[..., 0] = bottom 144 | cumheights[..., -1] = top 145 | heights = cumheights[..., 1:] - cumheights[..., :-1] 146 | 147 | if inverse: 148 | bin_idx = searchsorted(cumheights, inputs)[..., None] 149 | else: 150 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 151 | 152 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 153 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 154 | 155 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 156 | delta = heights / widths 157 | input_delta = delta.gather(-1, bin_idx)[..., 0] 158 | 159 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 160 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 161 | 162 | input_heights = heights.gather(-1, bin_idx)[..., 0] 163 | 164 | if inverse: 165 | a = (inputs - input_cumheights) * ( 166 | input_derivatives + input_derivatives_plus_one - 2 * input_delta 167 | ) + input_heights * (input_delta - input_derivatives) 168 | b = input_heights * input_derivatives - (inputs - input_cumheights) * ( 169 | input_derivatives + 
input_derivatives_plus_one - 2 * input_delta 170 | ) 171 | c = -input_delta * (inputs - input_cumheights) 172 | 173 | discriminant = b.pow(2) - 4 * a * c 174 | assert (discriminant >= 0).all() 175 | 176 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 177 | outputs = root * input_bin_widths + input_cumwidths 178 | 179 | theta_one_minus_theta = root * (1 - root) 180 | denominator = input_delta + ( 181 | (input_derivatives + input_derivatives_plus_one - 2 * input_delta) 182 | * theta_one_minus_theta 183 | ) 184 | derivative_numerator = input_delta.pow(2) * ( 185 | input_derivatives_plus_one * root.pow(2) 186 | + 2 * input_delta * theta_one_minus_theta 187 | + input_derivatives * (1 - root).pow(2) 188 | ) 189 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 190 | 191 | return outputs, -logabsdet 192 | else: 193 | theta = (inputs - input_cumwidths) / input_bin_widths 194 | theta_one_minus_theta = theta * (1 - theta) 195 | 196 | numerator = input_heights * ( 197 | input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta 198 | ) 199 | denominator = input_delta + ( 200 | (input_derivatives + input_derivatives_plus_one - 2 * input_delta) 201 | * theta_one_minus_theta 202 | ) 203 | outputs = input_cumheights + numerator / denominator 204 | 205 | derivative_numerator = input_delta.pow(2) * ( 206 | input_derivatives_plus_one * theta.pow(2) 207 | + 2 * input_delta * theta_one_minus_theta 208 | + input_derivatives * (1 - theta).pow(2) 209 | ) 210 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 211 | 212 | return outputs, logabsdet 213 | -------------------------------------------------------------------------------- /litfass/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script used for training the model. 
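Parses command-line arguments for the datasets, model, vocoder and logging, builds the FastSpeech2 model (optionally with a FastDiff vocoder), prepares cached LibriTTS train/validation datasets, and runs a PyTorch Lightning Trainer with W&B logging, checkpointing, optional SWA and early stopping.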
3 | """ 4 | 5 | from argparse import ArgumentParser 6 | import os 7 | import inspect 8 | from pathlib import Path 9 | import json 10 | import hashlib 11 | import pickle 12 | 13 | import torch 14 | import torch.multiprocessing 15 | from pytorch_lightning import Trainer 16 | from pytorch_lightning.loggers import WandbLogger 17 | from pytorch_lightning.callbacks import LearningRateMonitor 18 | from pytorch_lightning.callbacks import ModelCheckpoint 19 | from pytorch_lightning.callbacks import StochasticWeightAveraging 20 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping 21 | from alignments.datasets.libritts import LibrittsDataset 22 | 23 | from litfass.third_party.argutils import str2bool 24 | from litfass.fastspeech2.fastspeech2 import FastSpeech2 25 | from litfass.third_party.fastdiff.FastDiff import FastDiff 26 | 27 | torch.multiprocessing.set_sharing_strategy("file_system") 28 | 29 | if __name__ == "__main__": 30 | parser = ArgumentParser() 31 | 32 | parser = Trainer.add_argparse_args(parser) 33 | 34 | parser.add_argument("--early_stopping", type=str2bool, default=True) 35 | parser.add_argument("--early_stopping_patience", type=int, default=4) 36 | 37 | parser.add_argument("--swa_lr", type=float, default=None) 38 | 39 | lr_monitor = LearningRateMonitor(logging_interval="step") 40 | callbacks = [lr_monitor] 41 | 42 | parser.add_argument("--dataset_cache_path", type=str, default="../dataset_cache") 43 | parser.add_argument("--no_cache", type=str2bool, default=False) 44 | 45 | parser.add_argument( 46 | "--train_target_path", 47 | type=str, 48 | nargs="+", 49 | default=["../data/train-clean-360-aligned"], 50 | ) 51 | parser.add_argument( 52 | "--train_source_path", type=str, nargs="+", default=["../data/train-clean-360"] 53 | ) 54 | parser.add_argument( 55 | "--train_source_url", 56 | type=str, 57 | nargs="+", 58 | default=["https://www.openslr.org/resources/60/train-clean-360.tar.gz"], 59 | ) 60 | parser.add_argument("--train_tmp_path", type=str, default="../tmp") 61 | 62 | parser.add_argument( 63 | "--valid_target_path", type=str, default="../data/dev-clean-aligned" 64 | ) 65 | parser.add_argument("--valid_source_path", type=str, default="../data/dev-clean") 66 | parser.add_argument( 67 | "--valid_source_url", 68 | type=str, 69 | default="https://www.openslr.org/resources/60/dev-clean.tar.gz", 70 | ) 71 | parser.add_argument("--valid_tmp_path", type=str, default="../tmp") 72 | 73 | parser = FastSpeech2.add_model_specific_args(parser) 74 | parser = FastSpeech2.add_dataset_specific_args(parser) 75 | 76 | parser.add_argument("--wandb_project", type=str, default="fastspeech2") 77 | parser.add_argument("--wandb_mode", type=str, default="online") 78 | parser.add_argument("--wandb_name", type=str, default=None) 79 | parser.add_argument("--checkpoint", type=str2bool, default=True) 80 | parser.add_argument("--checkpoint_key", type=str, default="eval/mel_loss") 81 | parser.add_argument("--checkpoint_mode", type=str, default="min") 82 | parser.add_argument("--checkpoint_path", type=str, default="models") 83 | parser.add_argument("--checkpoint_filename", type=str, default=None) 84 | parser.add_argument("--from_checkpoint", type=str, default=None) 85 | 86 | parser.add_argument("--visible_gpus", type=int, default=0) 87 | 88 | # fastdiff vocoder 89 | parser.add_argument("--fastdiff_vocoder", type=str2bool, default=False) 90 | parser.add_argument("--fastdiff_vocoder_checkpoint", type=str, default=None) 91 | parser = FastDiff.add_model_specific_args(parser) 92 | 93 | args = 
parser.parse_args() 94 | var_args = vars(args) 95 | 96 | os.environ["WANDB_MODE"] = var_args["wandb_mode"] 97 | if var_args["wandb_name"] is None: 98 | wandb_logger = WandbLogger(project=var_args["wandb_project"]) 99 | else: 100 | wandb_logger = WandbLogger( 101 | project=var_args["wandb_project"], name=var_args["wandb_name"] 102 | ) 103 | 104 | train_ds = [] 105 | 106 | train_ds_kwargs = { 107 | k.replace("train_", ""): v 108 | for k, v in var_args.items() 109 | if k.startswith("train_") 110 | } 111 | 112 | valid_ds_kwargs = { 113 | k.replace("valid_", ""): v 114 | for k, v in var_args.items() 115 | if k.startswith("valid_") 116 | } 117 | 118 | if var_args["fastdiff_vocoder"]: 119 | train_ds_kwargs["load_wav"] = True 120 | valid_ds_kwargs["load_wav"] = True 121 | fastdiff_args = { 122 | k.replace("fastdiff_", ""): v 123 | for k, v in var_args.items() 124 | if ( 125 | k.startswith("fastdiff_") and 126 | "schedule" not in k and 127 | "vocoder" not in k and 128 | "variances" not in k and 129 | "speaker" not in k 130 | ) 131 | } 132 | fastdiff_model = FastDiff(**fastdiff_args) 133 | if var_args["fastdiff_vocoder_checkpoint"] is not None: 134 | state_dict = torch.load(var_args["fastdiff_vocoder_checkpoint"])["state_dict"]["model"] 135 | fastdiff_model.load_state_dict(state_dict, strict=True) 136 | else: 137 | fastdiff_model = None 138 | 139 | if not var_args["no_cache"]: 140 | Path(var_args["dataset_cache_path"]).mkdir(parents=True, exist_ok=True) 141 | cache_path = Path(var_args["dataset_cache_path"]) 142 | else: 143 | cache_path = None 144 | 145 | for i in range(len(var_args["train_target_path"])): 146 | if not var_args["no_cache"]: 147 | kwargs = train_ds_kwargs 148 | kwargs.update({"target_directory": var_args["train_target_path"][i]}) 149 | ds_hash = hashlib.md5( 150 | json.dumps(kwargs, sort_keys=True).encode("utf-8") 151 | ).hexdigest() 152 | cache_path_alignments = ( 153 | Path(var_args["dataset_cache_path"]) / f"train-alignments-{ds_hash}.pt" 154 | ) 155 | if ( 156 | var_args["no_cache"] 157 | or next(Path(var_args["train_target_path"][i]).rglob("**/*.TextGrid"), -1) 158 | == -1 159 | or not cache_path_alignments.exists() 160 | ): 161 | if len(var_args["train_source_path"]) > i: 162 | source_path = var_args["train_source_path"][i] 163 | else: 164 | source_path = None 165 | if len(var_args["train_source_url"]) > i: 166 | source_url = var_args["train_source_url"][i] 167 | else: 168 | source_url = None 169 | train_ds += [ 170 | LibrittsDataset( 171 | target_directory=var_args["train_target_path"][i], 172 | source_directory=source_path, 173 | source_url=source_url, 174 | verbose=True, 175 | tmp_directory=var_args["train_tmp_path"], 176 | chunk_size=10_000, 177 | ) 178 | ] 179 | if not var_args["no_cache"]: 180 | train_ds[-1].hash = ds_hash 181 | with open(cache_path_alignments, "wb") as f: 182 | pickle.dump(train_ds[-1], f) 183 | else: 184 | if cache_path_alignments.exists(): 185 | with open(cache_path_alignments, "rb") as f: 186 | train_ds += [pickle.load(f)] 187 | 188 | if not var_args["no_cache"]: 189 | kwargs = valid_ds_kwargs 190 | kwargs.update({"target_directory": var_args["valid_target_path"]}) 191 | ds_hash = hashlib.md5( 192 | json.dumps(kwargs, sort_keys=True).encode("utf-8") 193 | ).hexdigest() 194 | cache_path_alignments = ( 195 | Path(var_args["dataset_cache_path"]) / f"valid-alignments-{ds_hash}.pt" 196 | ) 197 | if ( 198 | var_args["no_cache"] 199 | or next(Path(var_args["valid_target_path"]).rglob("**/*.TextGrid"),-1) == -1 200 | or not cache_path_alignments.exists() 201 
| ): 202 | valid_ds = LibrittsDataset( 203 | target_directory=var_args["valid_target_path"], 204 | source_directory=var_args["valid_source_path"], 205 | source_url=var_args["valid_source_url"], 206 | verbose=True, 207 | tmp_directory=var_args["valid_tmp_path"], 208 | chunk_size=10_000, 209 | ) 210 | if not var_args["no_cache"]: 211 | valid_ds.hash = ds_hash 212 | with open(cache_path_alignments, "wb") as f: 213 | pickle.dump(valid_ds, f) 214 | else: 215 | if cache_path_alignments.exists(): 216 | with open(cache_path_alignments, "rb") as f: 217 | valid_ds = pickle.load(f) 218 | 219 | model_args = { 220 | k: v 221 | for k, v in var_args.items() 222 | if k in inspect.signature(FastSpeech2).parameters 223 | } 224 | 225 | del train_ds_kwargs["target_path"] 226 | del train_ds_kwargs["target_directory"] 227 | del train_ds_kwargs["source_path"] 228 | del train_ds_kwargs["source_url"] 229 | del train_ds_kwargs["tmp_path"] 230 | del valid_ds_kwargs["target_path"] 231 | del valid_ds_kwargs["target_directory"] 232 | del valid_ds_kwargs["source_path"] 233 | del valid_ds_kwargs["source_url"] 234 | del valid_ds_kwargs["nexamples"] 235 | del valid_ds_kwargs["example_directory"] 236 | del valid_ds_kwargs["tmp_path"] 237 | if "load_wav" in valid_ds_kwargs: 238 | del valid_ds_kwargs["load_wav"] 239 | 240 | if args.from_checkpoint is not None: 241 | model = FastSpeech2.load_from_checkpoint( 242 | args.from_checkpoint, 243 | train_ds=train_ds, 244 | valid_ds=valid_ds, 245 | train_ds_kwargs=train_ds_kwargs, 246 | valid_ds_kwargs=valid_ds_kwargs, 247 | strict=False, 248 | fastdiff_model=fastdiff_model, 249 | **model_args, 250 | ) 251 | else: 252 | model_args["cache_path"] = cache_path 253 | model = FastSpeech2( 254 | train_ds, 255 | valid_ds, 256 | train_ds_kwargs=train_ds_kwargs, 257 | valid_ds_kwargs=valid_ds_kwargs, 258 | fastdiff_model=fastdiff_model, 259 | **model_args, 260 | ) 261 | 262 | if var_args["checkpoint_filename"] is None and var_args["wandb_name"] is not None: 263 | var_args["checkpoint_filename"] = var_args["wandb_name"] 264 | 265 | if var_args["checkpoint"]: 266 | callbacks.append( 267 | ModelCheckpoint( 268 | monitor=var_args["checkpoint_key"], 269 | mode=var_args["checkpoint_mode"], 270 | filename=var_args["checkpoint_filename"], 271 | dirpath=var_args["checkpoint_path"], 272 | ) 273 | ) 274 | 275 | if var_args["early_stopping"]: 276 | callbacks.append( 277 | EarlyStopping( 278 | monitor="eval/mel_loss", patience=var_args["early_stopping_patience"] 279 | ) 280 | ) 281 | 282 | if var_args["swa_lr"] is not None: 283 | callbacks.append(StochasticWeightAveraging(swa_lrs=var_args["swa_lr"])) 284 | 285 | trainer = Trainer.from_argparse_args( 286 | args, 287 | callbacks=callbacks, 288 | default_root_dir="logs", 289 | logger=wandb_logger, 290 | ) 291 | 292 | trainer.fit(model) 293 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "LightningFastSpeech" 3 | version = "0.1" 4 | description = "" 5 | authors = [ 6 | {name = "Christoph Minixhofer", email = "christoph.minixhofer@gmail.com"}, 7 | ] 8 | dependencies = [ 9 | "pytorch-lightning>=1.5.10", 10 | "wandb>=0.12.11", 11 | "pillow>=9.0.1", 12 | "matplotlib>=3.5.1", 13 | "scikit-learn>=1.0.2", 14 | "seaborn>=0.11.2", 15 | "phones>=0.0.2", 16 | "pandarallel>=1.5.7", 17 | "tgt>=1.4.4", 18 | "pyworld>=0.3.0", 19 | "llvmlite>=0.38.0", 20 | "librosa>=0.9.1", 21 | "click>=8.0.4", 22 | "diskcache>=5.4.0", 
23 | "plotly>=5.7.0", 24 | "kaleido==0.2.1", 25 | "snreval>=1.2", 26 | "audiomentations[extras]>=0.25.1", 27 | "crepe>=0.0.12", 28 | "textgrid>=1.5", 29 | "rich>=12.4.4", 30 | "alignments>=0.0.9", 31 | "g2p-en>=2.1.0", 32 | "voicefixer>=0.0.17", 33 | "huggingface-hub>=0.7.0rc0", 34 | "pysdtw>=0.0.5", 35 | "numba>=0.56.2", 36 | "einops>=0.5.0", 37 | "SRMRpy @ git+https://github.com/MiniXC/SRMRpy.git", 38 | "nnAudio>=0.3.2", 39 | ] 40 | license = {text = "MIT"} 41 | requires-python = ">=3.8,<3.10" 42 | [project.urls] 43 | Homepage = "" 44 | 45 | [project.optional-dependencies] 46 | dev = [ 47 | "line-profiler>=3.5.1", 48 | "viztracer>=0.15.3", 49 | "pylint>=2.15.3", 50 | "black>=22.3.0", 51 | ] 52 | extras = [ 53 | "speechbrain>=0.5.10", 54 | "huggingface-hub==0.7.0rc0", 55 | ] 56 | deepspeed = [ 57 | "deepspeed>=0.7.2", 58 | "fairscale>=0.4.9", 59 | "cupy>=11.2.0", 60 | ] 61 | [tool.pdm] 62 | [tool.pdm.dev-dependencies] 63 | 64 | [build-system] 65 | requires = ["pdm-pep517"] 66 | build-backend = "pdm.pep517.api" 67 | -------------------------------------------------------------------------------- /scripts/generate.sh: -------------------------------------------------------------------------------- 1 | pdm run python litfass/generate.py \ 2 | --checkpoint_path "models/fastdiff_nopretrain.ckpt" \ 3 | --dataset "../data/dev-clean-aligned" \ 4 | --output_path "../generated/fastdiff_nopretrain" \ 5 | --hours .5 \ 6 | --batch_size 1 \ 7 | --use_voicefixer True \ 8 | --cache_path "../dataset_cache" \ 9 | --tts_device "cuda:0" \ 10 | --hifigan_device "cuda:1" \ 11 | --use_fastdiff True \ 12 | --fastdiff_n 4 \ 13 | --min_samples_per_speaker 0 \ 14 | --num_workers 16 -------------------------------------------------------------------------------- /scripts/generate_ab_train_splits.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tqdm.auto import tqdm 3 | import random 4 | 5 | a_speaker_counts = {} 6 | b_speaker_counts = {} 7 | 8 | a_path = Path("../data/train-clean-a") 9 | b_path = Path("../data/train-clean-b") 10 | 11 | extensions = [".lab", ".npy", ".TextGrid"] 12 | 13 | file_list = list(Path("../data/train-clean-360-aligned").rglob("*.wav")) + list(Path("../data/train-clean-100-aligned").rglob("*.wav")) 14 | # sort 15 | file_list = sorted(file_list) 16 | # random shuffle with seed 17 | random.Random(42).shuffle(file_list) 18 | 19 | for wavfile in tqdm(file_list): 20 | speaker = wavfile.parent.name 21 | basename = wavfile.name.replace(".wav", "") 22 | 23 | if speaker not in a_speaker_counts: 24 | a_speaker_counts[speaker] = 0 25 | if speaker not in b_speaker_counts: 26 | b_speaker_counts[speaker] = 0 27 | 28 | if a_speaker_counts[speaker] < b_speaker_counts[speaker]: 29 | a_speaker_counts[speaker] += 1 30 | tgt_path = a_path / speaker 31 | else: 32 | b_speaker_counts[speaker] += 1 33 | tgt_path = b_path / speaker 34 | 35 | tgt_path.mkdir(parents=True, exist_ok=True) 36 | (tgt_path / wavfile.name).symlink_to(wavfile.resolve()) 37 | for ext in extensions: 38 | (tgt_path / (basename + ext)).symlink_to(wavfile.with_suffix(ext).resolve()) -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | python3 litfass/train.py \ 4 | --accelerator tpu \ 5 | --devices 1 \ 6 | --precision 16 \ 7 | --batch_size 4 \ 8 | --accumulate_grad_batches 12 \ 9 | --val_check_interval 1.0 \ 10 | 
--log_every_n_steps 10 \ 11 | --layer_dropout 0.00 \ 12 | --duration_dropout 0.1 \ 13 | --variance_dropout 0.1 0.1 0.1 0.1 \ 14 | --soft_dtw_gamma 0.01 \ 15 | --max_epochs 40 \ 16 | --gradient_clip_val 1.0 \ 17 | --encoder_hidden 256 \ 18 | --encoder_conv_filter_size 1024 \ 19 | --variance_filter_size 256 \ 20 | --duration_filter_size 256 \ 21 | --decoder_hidden 256 \ 22 | --decoder_conv_filter_size 1024 \ 23 | --encoder_head 2 \ 24 | --decoder_head 2 \ 25 | --variance_loss_weights 1 1 1 1 \ 26 | --duration_loss_weight 1 \ 27 | --duration_nlayers 5 \ 28 | --variances pitch energy snr srmr \ 29 | --variance_levels frame frame frame frame \ 30 | --variance_transforms none none none none \ 31 | --variance_losses mse mse mse mse \ 32 | --variance_early_stopping none \ 33 | --early_stopping False \ 34 | --decoder_layers 6 \ 35 | --decoder_kernel_sizes 9 9 9 9 9 9 \ 36 | --speaker_embedding_every_layer False \ 37 | --prior_embedding_every_layer False \ 38 | --wandb_name "fastdiff_nopretrain_variances_fixed" \ 39 | --wandb_mode "offline" \ 40 | --speaker_type "dvector" \ 41 | --train_target_path "../data/train-clean-100-aligned" \ 42 | --train_source_path "../data/train-clean-100" \ 43 | --train_source_url "https://www.openslr.org/resources/60/train-clean-100.tar.gz" \ 44 | --train_min_samples_per_speaker 50 \ 45 | --priors_gmm True \ 46 | --priors_gmm_max_components 2 \ 47 | --dvector_gmm False \ 48 | --priors energy duration snr pitch srmr \ 49 | --sort_data_by_length True \ 50 | --train_pad_to_multiple_of 64 \ 51 | --fastdiff_vocoder True \ 52 | --fastdiff_schedule 1 1 \ 53 | --fastdiff_variances True \ 54 | --fastdiff_speakers True \ 55 | --num_sanity_val_steps 0 56 | #--from_checkpoint "models/fastdiff_nopretrain_variances_fixed-v1.ckpt" 57 | 58 | #--fastdiff_vocoder_checkpoint "fastdiff_model/model_ckpt_steps_1000000.ckpt" \ 59 | #--from_checkpoint "models/fastdiff_fixed_inf-v2.ckpt" 60 | 61 | # --priors energy duration snr pitch \ 62 | # --train_target_path "../data/train-clean-100-aligned" "../data/train-clean-360-aligned" "../data/train-other-500-aligned" \ 63 | 64 | # --devices 4 \ 65 | # --strategy "ddp" \ 66 | 67 | # --valid_example_directory "examples" 68 | --------------------------------------------------------------------------------