├── .gitignore ├── LICENSE ├── README.md ├── config.py ├── dataset ├── __init__.py ├── config.py └── ljspeech.py ├── expr.ipynb ├── inference.py ├── model ├── __init__.py ├── config.py └── wavenet.py ├── requirements.txt ├── rsrc └── loss.png ├── train.py └── utils └── noam_schedule.py /.gitignore: -------------------------------------------------------------------------------- 1 | # python 2 | __pycache__ 3 | 4 | # train 5 | ckpt 6 | log 7 | 8 | # test sample 9 | sample 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 YoungJoong Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tf-diffwave 2 | (Unofficial) TensorFlow implementation of DiffWave (Zhifeng Kong et al., 2020) 3 | 4 | - DiffWave: A Versatile Diffusion Model for Audio Synthesis, Zhifeng Kong et al., 2020. [[arXiv:2009.09761](https://arxiv.org/abs/2009.09761)] 5 | 6 | ## Requirements 7 | 8 | Tested in a Python 3.7.3 conda environment; see [requirements.txt](./requirements.txt). 9 | 10 | ## Usage 11 | 12 | To download the LJ-Speech dataset, run the following script. 13 | 14 | The dataset will be downloaded to '~/tensorflow_datasets' in tfrecord format. To change the download directory, specify the `data_dir` parameter of the `LJSpeech` initializer. 15 | 16 | ```python 17 | from dataset import LJSpeech 18 | from dataset.config import Config 19 | 20 | config = Config() 21 | # lj = LJSpeech(config, data_dir=path, download=True) 22 | lj = LJSpeech(config, download=True) 23 | ``` 24 | 25 | To train the model, run [train.py](./train.py). 26 | 27 | Checkpoints will be written to `TrainConfig.ckpt` and TensorBoard summaries to `TrainConfig.log`. 28 | 29 | ```bash 30 | python train.py 31 | tensorboard --logdir ./log/ 32 | ``` 33 | 34 | To train the model from raw audio, specify the audio directory and turn on the `--from-raw` flag. 35 | 36 | ```bash 37 | python .\train.py --data-dir D:\LJSpeech-1.1\wavs --from-raw 38 | ``` 39 | 40 | To resume training from a previous checkpoint, use `--load-step`. 41 | 42 | ```bash 43 | python .\train.py --load-step 416 --config ./ckpt/q1.json 44 | ``` 45 | 46 | For experiments, refer to [expr.ipynb](./expr.ipynb). 47 | 48 | To run inference on the test set, run [inference.py](./inference.py). 49 | 50 | ```bash 51 | python .\inference.py 52 | ``` 53 | 54 | Pretrained checkpoints are released on [releases](https://github.com/revsic/tf-diffwave/releases).
55 | 56 | To use a pretrained model, download the files and unzip them. Check out the git repository at the commit tag matching the checkpoint; a sample script follows. 57 | 58 | ```py 59 | with open('l1.json') as f: 60 | config = Config.load(json.load(f)) 61 | 62 | diffwave = DiffWave(config.model) 63 | diffwave.restore('./l1/l1_1000000.ckpt-1').expect_partial() 64 | ``` 65 | 66 | ## Learning Curve 67 | 68 | Residual channels=64, T=20, trained for 1M steps. 69 | 70 | ![loss](./rsrc/loss.png) 71 | 72 | ## Samples 73 | 74 | See [https://revsic.github.io/tf-diffwave](https://revsic.github.io/tf-diffwave). 75 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from dataset.config import Config as DataConfig 2 | from model.config import Config as ModelConfig 3 | from utils.noam_schedule import NoamScheduler 4 | 5 | 6 | class TrainConfig: 7 | """Configuration for training loop. 8 | """ 9 | def __init__(self): 10 | # optimizer 11 | self.lr_policy = 'fixed' 12 | self.learning_rate = 2e-4 13 | # self.lr_policy = 'noam' 14 | # self.learning_rate = 1 15 | # self.lr_params = { 16 | # 'warmup_steps': 4000, 17 | # 'channels': 64 18 | # } 19 | 20 | self.beta1 = 0.9 21 | self.beta2 = 0.98 22 | self.eps = 1e-9 23 | 24 | # 13000:100 25 | self.split = 13000 26 | self.bufsiz = 48 27 | 28 | self.epoch = 10000 29 | 30 | # path config 31 | self.log = './log' 32 | self.ckpt = './ckpt' 33 | 34 | # model name 35 | self.name = 'l1' 36 | 37 | # interval configuration 38 | self.eval_intval = 5000 39 | self.ckpt_intval = 10000 40 | 41 | def lr(self): 42 | """Generate proper learning rate scheduler.
43 | """ 44 | mapper = { 45 | 'noam': NoamScheduler 46 | } 47 | if self.lr_policy == 'fixed': 48 | return self.learning_rate 49 | if self.lr_policy in mapper: 50 | return mapper[self.lr_policy](self.learning_rate, **self.lr_params) 51 | raise ValueError('invalid lr_policy') 52 | 53 | class Config: 54 | """Integrated configuration of data, model and train sections. 55 | """ 56 | def __init__(self): 57 | self.data = DataConfig() 58 | self.model = ModelConfig() 59 | self.train = TrainConfig() 60 | 61 | def dump(self): 62 | """Dump configurations into a serializable dictionary of per-section attribute dicts. 63 | """ 64 | return {k: vars(v) for k, v in vars(self).items()} 65 | 66 | @staticmethod 67 | def load(dump_): 68 | """Load dumped configurations into a new configuration; sections that are not attributes of Config are ignored. 69 | """ 70 | conf = Config() 71 | for k, v in dump_.items(): 72 | if hasattr(conf, k): 73 | obj = getattr(conf, k) 74 | load_state(obj, v) 75 | return conf 76 | 77 | 78 | def load_state(obj, dump_): 79 | """Load dictionary items to attributes of obj; keys that are not already attributes of obj are ignored. 80 | """ 81 | for k, v in dump_.items(): 82 | if hasattr(obj, k): 83 | setattr(obj, k, v) 84 | return obj 85 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .ljspeech import LJSpeech -------------------------------------------------------------------------------- /dataset/config.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from .ljspeech import LJSpeech 4 | 5 | 6 | class Config: 7 | """Configuration for dataset construction.
8 | """ 9 | def __init__(self): 10 | # audio config 11 | self.sr = LJSpeech.SR 12 | self.maxval = LJSpeech.MAXVAL 13 | 14 | # stft 15 | self.hop = 256 16 | self.win = 1024 17 | self.fft = self.win 18 | self.win_fn = 'hann' 19 | 20 | # mel-scale filter bank 21 | self.mel = 80 22 | self.fmin = 0 23 | self.fmax = 8000 24 | 25 | self.eps = 1e-5 26 | 27 | # sample size: segment length (frames) and batch size 28 | self.frames = 6400 # 16000 29 | self.batch = 8 # 16 30 | 31 | def window_fn(self): 32 | """Return window generator. 33 | Returns: 34 | Callable, window function of tf.signal 35 | , which corresponds to self.win_fn; raises ValueError if it is unknown. 36 | """ 37 | mapper = { 38 | 'hann': tf.signal.hann_window, 39 | 'hamming': tf.signal.hamming_window 40 | } 41 | if self.win_fn in mapper: 42 | return mapper[self.win_fn] 43 | 44 | raise ValueError('invalid window function: ' + self.win_fn) 45 | -------------------------------------------------------------------------------- /dataset/ljspeech.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import librosa 4 | import numpy as np 5 | import tensorflow as tf 6 | import tensorflow_datasets as tfds 7 | 8 | 9 | class LJSpeech: 10 | """LJ Speech dataset loader. 11 | Uses the settings of other open-source vocoders: 16-bit audio, sr 22050. 12 | """ 13 | SR = 22050 14 | MAXVAL = 32767. 15 | 16 | def __init__(self, config, data_dir=None, download=False, from_tfds=True): 17 | """Initializer. 18 | Args: 19 | config: Config, dataset configuration. 20 | data_dir: str, dataset directory 21 | , defaults to '~/tensorflow_datasets'. 22 | download: bool, download dataset or not. 23 | from_tfds: bool, load from tfrecord generated by tfds or read raw audio.
24 | """ 25 | self.config = config 26 | self.rawset, self.info = self.load_data(data_dir, download, from_tfds) 27 | # [fft // 2 + 1, mel] 28 | melfilter = librosa.filters.mel( 29 | config.sr, config.fft, config.mel, config.fmin, config.fmax).T 30 | self.melfilter = tf.convert_to_tensor(melfilter) 31 | 32 | self.normalized = None 33 | 34 | def load_data(self, data_dir=None, download=False, from_tfds=True): 35 | """Load dataset from tfrecord or raw audio files. 36 | Args: 37 | data_dir: str, dataset directory. 38 | For from_tfds, None is acceptable 39 | and set to default value '~/tensorflow_datasets'. 40 | For from raw audio, None is not acceptable. 41 | download: bool, download dataset or not, for from_tfds. 42 | from_tfds: bool, whether use tfds or read raw audio. 43 | Returns: 44 | tf.data.Dataset, data loader. 45 | """ 46 | if from_tfds: 47 | dataset, info = tfds.load( 48 | 'ljspeech', split='train', 49 | data_dir=data_dir, download=download, with_info=True) 50 | # filter only audio 51 | return dataset.map(LJSpeech._preproc_tfds), info 52 | # generate file lists 53 | files = tf.data.Dataset.from_tensor_slices( 54 | [os.path.join(data_dir, n) for n in os.listdir(data_dir)]) 55 | # read audio 56 | return files.map(LJSpeech._load_audio), None 57 | 58 | @staticmethod 59 | def _load_audio(path): 60 | """Load audio with tf apis. 61 | Args: 62 | path: str, wavfile path to read. 63 | Returns: 64 | tf.Tensor, [T], mono audio in range (-1, 1). 65 | """ 66 | raw = tf.io.read_file(path) 67 | audio, _ = tf.audio.decode_wav(raw, desired_channels=1) 68 | return tf.squeeze(audio, axis=-1) 69 | 70 | @staticmethod 71 | def _preproc_tfds(datum): 72 | """Preprocess datum from tfds. 73 | Args: 74 | datum: Dict[str, tf.Tensor], 75 | id: [], string, string id. 76 | speech: [T], int64, audio signal in range (-MAXVALUE - 1, MAXVALUE). 77 | text: [], string, text. 78 | text_normalized: [], string, normalized text. 79 | Returns: 80 | tf.Tensor, [T], mono audio in range (-1, 1). 
81 | """ 82 | return tf.cast(datum['speech'], tf.float32) / LJSpeech.MAXVAL 83 | 84 | def normalizer(self, frames=16000): 85 | """Create LJSpeech normalizer, make fixed size segment in range(-1, 1). 86 | Args: 87 | frames: int, segment size, frame unit. 88 | from_tfds: bool, whether use tfds tfrecord or raw audio. 89 | Returns: 90 | Callable, normalizer. 91 | """ 92 | def normalize(speech): 93 | """Normalize datum. 94 | Args: 95 | speech: tf.Tensor, [T], mono audio in range (-1, 1). 96 | Returns: 97 | tf.Tensor, [frames], fixed size speech signal in range (-1, 1). 98 | """ 99 | nonlocal frames 100 | frames = frames // self.config.hop * self.config.hop 101 | start = tf.random.uniform( 102 | (), 0, tf.shape(speech)[0] - frames, dtype=tf.int32) 103 | return speech[start:start + frames] 104 | 105 | return normalize 106 | 107 | def mel_fn(self, signal): 108 | """Generate log mel-spectrogram from input audio segment. 109 | Args: 110 | signal: tf.Tensor, [B, T], audio segment. 111 | Returns: 112 | tuple, 113 | signal: tf.Tensor, [B, T], identity to inputs. 114 | logmel: tf.Tensor, [B, T // hop, mel], log mel-spectrogram. 115 | """ 116 | padlen = self.config.win // 2 117 | # [B, T + win - 1] 118 | center_pad = tf.pad(signal, [[0, 0], [padlen, padlen - 1]], mode='reflect') 119 | # [B, T // hop, fft // 2 + 1] 120 | stft = tf.signal.stft( 121 | center_pad, 122 | frame_length=self.config.win, 123 | frame_step=self.config.hop, 124 | fft_length=self.config.fft, 125 | window_fn=self.config.window_fn()) 126 | # [B, T // hop, mel] 127 | mel = tf.abs(stft) @ self.melfilter 128 | # [B, T // hop, mel] 129 | logmel = tf.math.log(tf.maximum(mel, self.config.eps)) 130 | return signal, logmel 131 | 132 | def dataset(self): 133 | """Generate dataset. 
134 | """ 135 | if self.normalized is None: 136 | self.normalized = self.rawset \ 137 | .map(self.normalizer(self.config.frames)) \ 138 | .batch(self.config.batch) \ 139 | .map(self.mel_fn) 140 | return self.normalized 141 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | 5 | import librosa 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | from config import Config 10 | from dataset import LJSpeech 11 | from model import DiffWave 12 | 13 | LJ_DATA_SIZE = 13100 14 | 15 | 16 | def main(args): 17 | # prepare directory for samples 18 | if not os.path.exists(args.sample_dir): 19 | os.makedirs(args.sample_dir) 20 | 21 | # load checkpoint 22 | with open(args.config) as f: 23 | config = Config.load(json.load(f)) 24 | 25 | diffwave = DiffWave(config.model) 26 | diffwave.restore(args.ckpt).expect_partial() 27 | 28 | # open dataset 29 | lj = LJSpeech(config.data) 30 | if args.offset is None: 31 | args.offset = config.train.split + \ 32 | np.random.randint(LJ_DATA_SIZE - config.train.split) 33 | 34 | # sample 35 | print('[*] offset: ', args.offset) 36 | speech = next(iter(lj.rawset.skip(args.offset))) 37 | speech = speech[:speech.shape[0] // config.data.hop * config.data.hop] 38 | 39 | librosa.output.write_wav( 40 | os.path.join(args.sample_dir, str(args.offset) + '_gt.wav'), 41 | speech.numpy(), 42 | config.data.sr) 43 | 44 | # inference 45 | noise = tf.random.normal(tf.shape(speech[None])) 46 | librosa.output.write_wav( 47 | os.path.join(args.sample_dir, str(args.offset) + '_noise.wav'), 48 | noise[0].numpy(), 49 | config.data.sr) 50 | 51 | _, logmel = lj.mel_fn(speech[None]) 52 | _, ir = diffwave(logmel, noise) 53 | for i, sample in enumerate(ir): 54 | librosa.output.write_wav( 55 | os.path.join(args.sample_dir, '{}_{}step.wav'.format(args.offset, i)), 56 | sample[0], 57 | 
config.data.sr) 58 | 59 | print('[*] done') 60 | 61 | if __name__ == '__main__': 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--sample-dir', default='./sample') 64 | parser.add_argument('--config', default='./ckpt/l1.json') 65 | parser.add_argument('--ckpt', default='./ckpt/l1/l1_1000000.ckpt-1') 66 | parser.add_argument('--offset', default=None, type=int) 67 | args = parser.parse_args() 68 | main(args) 69 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from .wavenet import WaveNet 5 | 6 | 7 | class DiffWave(tf.keras.Model): 8 | """DiffWave: A Versatile Diffusion Model for Audio Synthesis. 9 | Zhifeng Kong et al., 2020. 10 | """ 11 | def __init__(self, config): 12 | """Initializer. 13 | Args: 14 | config: Config, model configuration. 15 | """ 16 | super(DiffWave, self).__init__() 17 | self.config = config 18 | self.wavenet = WaveNet(config) 19 | 20 | def call(self, mel, noise=None): 21 | """Generate denoised audio by running the reverse chain from step iter down to 1. 22 | Args: 23 | mel: tf.Tensor, [B, T // hop, M], conditional mel-spectrogram. 24 | noise: Optional[tf.Tensor], [B, T], starting noise, sampled from a standard normal if None. 25 | Returns: 26 | tuple, 27 | signal: tf.Tensor, [B, T], predicted output. 28 | ir: List[np.ndarray: [B, T]], intermediate outputs, one per denoising step.
29 | """ 30 | if noise is None: 31 | # [B, T // hop, M] 32 | b, t, _ = tf.shape(mel) 33 | # [B, T] 34 | noise = tf.random.normal([b, t * self.config.hop]) 35 | 36 | # [iter], alpha = 1 - beta and its cumulative product alpha_bar 37 | alpha = 1 - self.config.beta() 38 | alpha_bar = np.cumprod(alpha) 39 | # [B] 40 | base = tf.ones([tf.shape(noise)[0]], dtype=tf.int32) 41 | 42 | ir, signal = [], noise 43 | for t in range(self.config.iter, 0, -1): 44 | # [B, T] 45 | eps = self.pred_noise(signal, base * t, mel) 46 | # [B, T], [], mean and stddev of the denoised signal 47 | mu, sigma = self.pred_signal(signal, eps, alpha[t - 1], alpha_bar[t - 1]) 48 | # [B, T] 49 | signal = mu + tf.random.normal(tf.shape(signal)) * sigma 50 | ir.append(signal.numpy()) 51 | # [B, T], iter x [B, T] 52 | return signal, ir 53 | 54 | def diffusion(self, signal, alpha_bar, eps=None): 55 | """Diffuse signal directly to the state with cumulative noise level alpha_bar. 56 | Args: 57 | signal: tf.Tensor, [B, T], signal. 58 | alpha_bar: Union[float, tf.Tensor: [B]], cumprod(1 - beta). 59 | eps: Optional[tf.Tensor: [B, T]], noise, sampled from a standard normal if None. 60 | Returns: 61 | tuple, 62 | noised: tf.Tensor, [B, T], sqrt(alpha_bar) * signal + sqrt(1 - alpha_bar) * eps. 63 | eps: tf.Tensor, [B, T], noise. 64 | """ 65 | if eps is None: 66 | eps = tf.random.normal(tf.shape(signal)) 67 | if isinstance(alpha_bar, tf.Tensor): 68 | alpha_bar = alpha_bar[:, None] 69 | return tf.sqrt(alpha_bar) * signal + tf.sqrt(1 - alpha_bar) * eps, eps 70 | 71 | def pred_noise(self, signal, timestep, mel): 72 | """Predict noise from signal. 73 | Args: 74 | signal: tf.Tensor, [B, T], noised signal. 75 | timestep: tf.Tensor, [B], timesteps of current markov chain, in [1, iter]. 76 | mel: tf.Tensor, [B, T // hop, M], conditional mel-spectrogram. 77 | Returns: 78 | tf.Tensor, [B, T], predicted noise. 79 | """ 80 | return self.wavenet(signal, timestep, mel) 81 | 82 | def pred_signal(self, signal, eps, alpha, alpha_bar): 83 | """Compute mean and stddev of denoised signal. 84 | Args: 85 | signal: tf.Tensor, [B, T], noised signal. 86 | eps: tf.Tensor, [B, T], estimated noise. 87 | alpha: float, 1 - beta.
88 | alpha_bar: float, cumprod(1 - beta). 89 | Returns: 90 | tuple, 91 | mean: tf.Tensor, [B, T], estimated mean of denoised signal. 92 | stddev: float, estimated stddev. 93 | """ 94 | # [B, T], (signal - beta / sqrt(1 - alpha_bar) * eps) / sqrt(alpha) with beta = 1 - alpha 95 | mean = (signal - (1 - alpha) / np.sqrt(1 - alpha_bar) * eps) / np.sqrt(alpha) 96 | # [], sqrt(beta * (1 - alpha_bar_prev) / (1 - alpha_bar)) with alpha_bar_prev = alpha_bar / alpha 97 | stddev = np.sqrt((1 - alpha_bar / alpha) / (1 - alpha_bar) * (1 - alpha)) 98 | return mean, stddev 99 | 100 | def write(self, path, optim=None): 101 | """Write checkpoint with `tf.train.Checkpoint.save`, which appends a save counter (e.g. '-1') to path. 102 | Args: 103 | path: str, path to write. 104 | optim: Optional[tf.keras.optimizers.Optimizer] 105 | , optional optimizer. 106 | """ 107 | kwargs = {'model': self} 108 | if optim is not None: 109 | kwargs['optim'] = optim 110 | ckpt = tf.train.Checkpoint(**kwargs) 111 | ckpt.save(path) 112 | 113 | def restore(self, path, optim=None): 114 | """Restore checkpoint with `tf.train.Checkpoint`, returning the restore status object. 115 | Args: 116 | path: str, path to restore. 117 | optim: Optional[tf.keras.optimizers.Optimizer] 118 | , optional optimizer. 119 | """ 120 | kwargs = {'model': self} 121 | if optim is not None: 122 | kwargs['optim'] = optim 123 | ckpt = tf.train.Checkpoint(**kwargs) 124 | return ckpt.restore(path) 125 | -------------------------------------------------------------------------------- /model/config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Config: 5 | """Configuration for DiffWave implementation.
6 | """ 7 | def __init__(self): 8 | # leaky relu coefficient 9 | self.leak = 0.4 10 | 11 | # embedding config 12 | self.embedding_size = 128 13 | self.embedding_proj = 512 14 | self.embedding_layers = 2 15 | self.embedding_factor = 4 16 | 17 | # upsampler config 18 | self.upsample_stride = [16, 1] 19 | self.upsample_kernel = [32, 3] 20 | self.upsample_layers = 2 21 | # computed hop size, time-axis stride ** upsample_layers (16 ** 2 = 256) 22 | self.hop = self.upsample_stride[0] ** self.upsample_layers 23 | 24 | # block config 25 | self.channels = 64 26 | self.kernel_size = 3 27 | self.dilation_rate = 2 28 | self.num_layers = 30 29 | self.num_cycles = 3 30 | 31 | # noise schedule 32 | self.iter = 20 # 20, 40, 50 33 | self.noise_policy = 'linear' 34 | self.noise_start = 1e-4 35 | self.noise_end = 0.05 # 0.02 for 200 36 | 37 | def beta(self): 38 | """Generate beta-sequence. 39 | Returns: 40 | np.ndarray, [iter], float32 beta values. 41 | """ 42 | mapper = { 43 | 'linear': self._linear_sched, 44 | } 45 | if self.noise_policy not in mapper: 46 | raise ValueError('invalid beta policy') 47 | return mapper[self.noise_policy]() 48 | 49 | def _linear_sched(self): 50 | """Linearly generated noise. 51 | Returns: 52 | np.ndarray, [iter], float32 beta values from noise_start to noise_end inclusive. 53 | """ 54 | return np.linspace( 55 | self.noise_start, self.noise_end, self.iter, dtype=np.float32) 56 | -------------------------------------------------------------------------------- /model/wavenet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class DilatedConv1d(tf.keras.layers.Layer): 5 | """Custom implementation of dilated convolution 1D 6 | because of the issue https://github.com/tensorflow/tensorflow/issues/26797. 7 | """ 8 | def __init__(self, 9 | in_channels, 10 | out_channels, 11 | kernel_size, 12 | dilation_rate): 13 | """Initializer. 14 | Args: 15 | in_channels: int, input channels. 16 | out_channels: int, output channels. 17 | kernel_size: int, size of the kernel.
18 | dilation_rate: int, dilation rate. 19 | """ 20 | super(DilatedConv1d, self).__init__() 21 | self.dilations = dilation_rate 22 | 23 | init = tf.keras.initializers.GlorotUniform() 24 | self.kernel = tf.Variable( 25 | init([kernel_size, in_channels, out_channels], dtype=tf.float32), 26 | trainable=True) 27 | self.bias = tf.Variable( 28 | tf.zeros([1, 1, out_channels], dtype=tf.float32), 29 | trainable=True) 30 | 31 | def call(self, inputs): 32 | """Pass to dilated convolution 1d. 33 | Args: 34 | inputs: tf.Tensor, [B, T, Cin], input tensor. 35 | Returns: 36 | outputs: tf.Tensor, [B, T, Cout], output tensor, same length as inputs (stride 1, SAME padding). 37 | """ 38 | conv = tf.nn.conv1d( 39 | inputs, self.kernel, 1, padding='SAME', dilations=self.dilations) 40 | return conv + self.bias 41 | 42 | 43 | class Block(tf.keras.Model): 44 | """WaveNet Block. 45 | """ 46 | def __init__(self, channels, kernel_size, dilation, last=False): 47 | """Initializer. 48 | Args: 49 | channels: int, basic channel size. 50 | kernel_size: int, kernel size of the dilated convolution. 51 | dilation: int, dilation rate. 52 | last: bool, last block or not, the last block has no residual projection. 53 | """ 54 | super(Block, self).__init__() 55 | self.channels = channels 56 | self.last = last 57 | 58 | self.proj_embed = tf.keras.layers.Dense(channels) 59 | self.conv = DilatedConv1d( 60 | channels, channels * 2, kernel_size, dilation) 61 | self.proj_mel = tf.keras.layers.Conv1D(channels * 2, 1) 62 | 63 | if not last: 64 | self.proj_res = tf.keras.layers.Conv1D(channels, 1) 65 | self.proj_skip = tf.keras.layers.Conv1D(channels, 1) 66 | 67 | def call(self, inputs, embedding, mel): 68 | """Pass wavenet block. 69 | Args: 70 | inputs: tf.Tensor, [B, T, C(=channels)], input tensor. 71 | embedding: tf.Tensor, [B, E], embedding tensor for noise schedules. 72 | mel: tf.Tensor, [B, T, M], mel-spectrogram conditions, already upsampled to T. 73 | Returns: 74 | residual: Optional[tf.Tensor], [B, T, C], output tensor for residual connection, None for the last block. 75 | skip: tf.Tensor, [B, T, C], output tensor for skip connection.
76 | """ 77 | # [B, C] 78 | embedding = self.proj_embed(embedding) 79 | # [B, T, C] 80 | x = inputs + embedding[:, None] 81 | # [B, T, Cx2] 82 | x = self.conv(x) + self.proj_mel(mel) 83 | # [B, T, C], gated activation tanh(context) * sigmoid(gate) 84 | context = tf.math.tanh(x[..., :self.channels]) 85 | gate = tf.math.sigmoid(x[..., self.channels:]) 86 | x = context * gate 87 | # [B, T, C], residual scaled by 1 / sqrt(2), None for the last block 88 | residual = (self.proj_res(x) + inputs) / 2 ** 0.5 if not self.last else None 89 | skip = self.proj_skip(x) 90 | return residual, skip 91 | 92 | 93 | class WaveNet(tf.keras.Model): 94 | """WaveNet noise estimator conditioned on timestep embedding and mel-spectrogram. 95 | """ 96 | def __init__(self, config): 97 | """Initializer. 98 | Args: 99 | config: Config, model configuration. 100 | """ 101 | super(WaveNet, self).__init__() 102 | self.config = config 103 | # signal proj 104 | self.proj = tf.keras.layers.Conv1D(config.channels, 1) 105 | # sinusoidal timestep embedding table, [iter, embedding_size] 106 | self.embed = self.embedding(config.iter) 107 | self.proj_embed = [ 108 | tf.keras.layers.Dense(config.embedding_proj) 109 | for _ in range(config.embedding_layers)] 110 | # mel-upsampler, each layer stretches the time axis by upsample_stride[0] 111 | self.upsample = [ 112 | tf.keras.layers.Conv2DTranspose( 113 | 1, 114 | config.upsample_kernel, 115 | config.upsample_stride, 116 | padding='same') 117 | for _ in range(config.upsample_layers)] 118 | # wavenet blocks, dilation = dilation_rate ** (i % layers_per_cycle) 119 | self.blocks = [] 120 | layers_per_cycle = config.num_layers // config.num_cycles 121 | for i in range(config.num_layers): 122 | dilation = config.dilation_rate ** (i % layers_per_cycle) 123 | self.blocks.append( 124 | Block( 125 | config.channels, 126 | config.kernel_size, 127 | dilation, 128 | last=i == config.num_layers - 1)) 129 | # for output 130 | self.proj_out = [ 131 | tf.keras.layers.Conv1D(config.channels, 1, activation=tf.nn.relu), 132 | tf.keras.layers.Conv1D(1, 1)] 133 | 134 | def call(self, signal, timestep, mel): 135 | """Generate output signal. 136 | Args: 137 | signal: tf.Tensor, [B, T], noised signal. 138 | timestep: tf.Tensor, [B], int, timesteps of current markov chain.
139 | mel: tf.Tensor, [B, T // hop, M], mel-spectrogram. 140 | Returns: 141 | tf.Tensor, [B, T], estimated noise. 142 | """ 143 | # [B, T, C(=channels)] 144 | x = tf.nn.relu(self.proj(signal[..., None])) 145 | # [B, E'], embedding rows for 1-based timesteps 146 | embed = tf.gather(self.embed, timestep - 1) 147 | # [B, E] 148 | for proj in self.proj_embed: 149 | embed = tf.nn.swish(proj(embed)) 150 | # [B, T // hop, M, 1], treat as 2D tensor and upsample time to [B, T, M, 1]. 151 | mel = mel[..., None] 152 | for upsample in self.upsample: 153 | mel = tf.nn.leaky_relu(upsample(mel), self.config.leak) 154 | # [B, T, M] 155 | mel = tf.squeeze(mel, axis=-1) 156 | 157 | context = [] 158 | for block in self.blocks: 159 | # [B, T, C], [B, T, C] 160 | x, skip = block(x, embed, mel) 161 | context.append(skip) 162 | # [B, T, C], sum of skip outputs scaled by 1 / sqrt(num_layers) 163 | scale = self.config.num_layers ** 0.5 164 | context = tf.reduce_sum(context, axis=0) / scale 165 | # [B, T, 1] 166 | for proj in self.proj_out: 167 | context = proj(context) 168 | # [B, T] 169 | return tf.squeeze(context, axis=-1) 170 | 171 | def embedding(self, iter): 172 | """Generate sinusoidal timestep embedding, frequencies 10 ** (linspace(0, 1) * embedding_factor). 173 | Args: 174 | iter: int, maximum iteration. 175 | Returns: 176 | tf.Tensor, [iter, E(=embedding_size)], embedding vectors.
177 | """ 178 | # [E // 2] 179 | logit = tf.linspace(0., 1., self.config.embedding_size // 2) 180 | exp = tf.pow(10, logit * self.config.embedding_factor) 181 | # [iter] 182 | timestep = tf.range(1, iter + 1) 183 | # [iter, E // 2] 184 | comp = exp[None] * tf.cast(timestep[:, None], tf.float32) 185 | # [iter, E] 186 | return tf.concat([tf.sin(comp), tf.cos(comp)], axis=-1) 187 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | librosa==0.7.0 2 | matplotlib==3.3.1 3 | numba==0.48 # for preventing librosa conflict 4 | numpy==1.19.1 5 | pydub==0.23.1 # for tensorflow_datasets 6 | tensorflow>=2.1.0 7 | tensorflow_datasets==3.1.0 8 | tqdm==4.48.2 -------------------------------------------------------------------------------- /rsrc/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/revsic/tf-diffwave/32b0b403e7ca157f015f9af9f7dcdfa79e312a6a/rsrc/loss.png -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import tensorflow as tf 8 | import tqdm 9 | 10 | from config import Config 11 | from dataset import LJSpeech 12 | from model import DiffWave 13 | 14 | 15 | class Trainer: 16 | """DiffWave trainer. 17 | """ 18 | def __init__(self, model, lj, config): 19 | """Initializer. 20 | Args: 21 | model: DiffWave, diffwave model. 22 | lj: LJSpeech, LJ-Speech dataset 23 | which provides already batched and normalized speech dataset. 24 | config: Config, unified configurations.
25 | """ 26 | self.model = model 27 | self.lj = lj 28 | self.config = config 29 | 30 | self.split = config.train.split // config.data.batch 31 | self.trainset = self.lj.dataset().take(self.split) \ 32 | .shuffle(config.train.bufsiz) \ 33 | .prefetch(tf.data.experimental.AUTOTUNE) 34 | self.testset = self.lj.dataset().skip(self.split) \ 35 | .prefetch(tf.data.experimental.AUTOTUNE) 36 | 37 | self.optim = tf.keras.optimizers.Adam( 38 | config.train.lr(), 39 | config.train.beta1, 40 | config.train.beta2, 41 | config.train.eps) 42 | 43 | self.eval_intval = config.train.eval_intval // config.data.batch 44 | self.ckpt_intval = config.train.ckpt_intval // config.data.batch 45 | 46 | self.train_log = tf.summary.create_file_writer( 47 | os.path.join(config.train.log, config.train.name, 'train')) 48 | self.test_log = tf.summary.create_file_writer( 49 | os.path.join(config.train.log, config.train.name, 'test')) 50 | 51 | self.ckpt_path = os.path.join( 52 | config.train.ckpt, config.train.name, config.train.name) 53 | 54 | self.alpha_bar = np.cumprod(1 - config.model.beta()) 55 | self.cmap = tf.constant(plt.get_cmap('viridis').colors, dtype=tf.float32) 56 | 57 | def compute_loss(self, signal, logmel): 58 | """Compute loss for noise estimation. 59 | Args: 60 | signal: tf.Tensor, [B, T], raw audio signal segment. 61 | logmel: tf.Tensor, [B, T // hop, mel], mel-spectrogram. 62 | Returns: 63 | loss: tf.Tensor, [], L1-loss between noise and estimation. 64 | """ 65 | # [], batch size 66 | bsize = tf.shape(signal)[0] 67 | # [B], timesteps drawn uniformly from [1, iter] 68 | timesteps = tf.random.uniform( 69 | [bsize], 1, self.config.model.iter + 1, dtype=tf.int32) 70 | # [B], alpha_bar of each sampled timestep 71 | noise_level = tf.gather(self.alpha_bar, timesteps - 1) 72 | # [B, T], [B, T] 73 | noised, noise = self.model.diffusion(signal, noise_level) 74 | # [B, T] 75 | eps = self.model.pred_noise(noised, timesteps, logmel) 76 | # [] 77 | loss = tf.reduce_mean(tf.abs(eps - noise)) 78 | return loss 79 | 80 | def train(self, step=0, ir_unit=5): 81 | """Train diffwave.
82 | Args: 83 | step: int, starting step, also used to derive the starting epoch. 84 | ir_unit: int, log every ir_unit-th intermediate output of the evaluation sample. 85 | """ 86 | for _ in tqdm.trange(step // self.split, self.config.train.epoch): 87 | with tqdm.tqdm(total=self.split, leave=False) as pbar: 88 | for signal, logmel in self.trainset: 89 | with tf.GradientTape() as tape: 90 | tape.watch(self.model.trainable_variables) 91 | loss = self.compute_loss(signal, logmel) 92 | 93 | grad = tape.gradient(loss, self.model.trainable_variables) 94 | self.optim.apply_gradients( 95 | zip(grad, self.model.trainable_variables)) 96 | 97 | norm = tf.reduce_mean([tf.norm(g) for g in grad]) 98 | del grad 99 | 100 | step += 1 101 | pbar.update() 102 | pbar.set_postfix( 103 | {'loss': loss.numpy().item(), 104 | 'step': step, 105 | 'grad': norm.numpy().item()}) 106 | 107 | with self.train_log.as_default(): 108 | tf.summary.scalar('loss', loss, step) 109 | tf.summary.scalar('grad norm', norm, step) 110 | if step % self.eval_intval == 0: 111 | pred, _ = self.model(logmel) 112 | tf.summary.audio( 113 | 'train', pred[..., None], self.config.data.sr, step) 114 | tf.summary.image( 115 | 'train mel', self.mel_img(pred), step) 116 | 117 | del pred 118 | 119 | if step % self.ckpt_intval == 0: 120 | self.model.write( 121 | '{}_{}.ckpt'.format(self.ckpt_path, step), 122 | self.optim) 123 | 124 | loss = [ 125 | self.compute_loss(signal, logmel).numpy().item() 126 | for signal, logmel in self.testset 127 | ] 128 | loss = sum(loss) / len(loss) 129 | with self.test_log.as_default(): 130 | tf.summary.scalar('loss', loss, step) 131 | 132 | gt, pred, ir = self.eval() 133 | tf.summary.audio( 134 | 'gt', gt[None, :, None], self.config.data.sr, step) 135 | tf.summary.audio( 136 | 'eval', pred[None, :, None], self.config.data.sr, step) 137 | 138 | tf.summary.image( 139 | 'gt mel', self.mel_img(gt[None]), step) 140 | tf.summary.image( 141 | 'eval mel', self.mel_img(pred[None]), step) 142 | 143 | for i in range(0, len(ir), ir_unit): 144 | tf.summary.audio( 145 | 'ir_{}'.format(i), 146 |
np.clip(ir[i][None, :, None], -1., 1.), 147 | self.config.data.sr, step) 148 | 149 | del gt, pred, ir 150 | 151 | def mel_img(self, signal): 152 | """Generate mel-spectrogram images. 153 | Args: 154 | signal: tf.Tensor, [B, T], speech signal. 155 | Returns: 156 | tf.Tensor, [B, mel, T // hop, 3], mel-spectrogram in viridis color map. 157 | """ 158 | # [B, T // hop, mel] 159 | _, mel = self.lj.mel_fn(signal) 160 | # [B, mel, T // hop] 161 | mel = tf.transpose(mel, [0, 2, 1]) 162 | # minmax norm in range(0, 1) 163 | mel = (mel - tf.reduce_min(mel)) / (tf.reduce_max(mel) - tf.reduce_min(mel)) 164 | # in range(0, 255) 165 | mel = tf.cast(mel * 255, tf.int32) 166 | # [B, mel, T // hop, 3] 167 | mel = tf.gather(self.cmap, mel) 168 | # make origin lower 169 | mel = tf.image.flip_up_down(mel) 170 | return mel 171 | 172 | def eval(self): 173 | """Generate evaluation purpose audio. 174 | Returns: 175 | speech: np.ndarray, [T], ground truth. 176 | pred: np.ndarray, [T], predicted. 177 | ir: List[np.ndarray], config.model.iter x [T], 178 | intermediate represnetations. 
179 | """ 180 | # [T] 181 | speech = next(iter(lj.rawset)) 182 | # [1, T // hop, mel] 183 | _, logmel = lj.mel_fn(speech[None]) 184 | # [1, T], iter x [1, T] 185 | pred, ir = self.model(logmel) 186 | # [T] 187 | pred = tf.squeeze(pred, axis=0).numpy() 188 | # config.model.iter x [T] 189 | ir = [np.squeeze(i, axis=0) for i in ir] 190 | return speech.numpy(), pred, ir 191 | 192 | if __name__ == '__main__': 193 | parser = argparse.ArgumentParser() 194 | parser.add_argument('--config', default=None) 195 | parser.add_argument('--load-step', default=0, type=int) 196 | parser.add_argument('--ir-unit', default=10, type=int) 197 | parser.add_argument('--data-dir', default=None) 198 | parser.add_argument('--download', default=False, action='store_true') 199 | parser.add_argument('--from-raw', default=False, action='store_true') 200 | args = parser.parse_args() 201 | 202 | config = Config() 203 | if args.config is not None: 204 | print('[*] load config: ' + args.config) 205 | with open(args.config) as f: 206 | config = Config.load(json.load(f)) 207 | 208 | log_path = os.path.join(config.train.log, config.train.name) 209 | if not os.path.exists(log_path): 210 | os.makedirs(log_path) 211 | 212 | ckpt_path = os.path.join(config.train.ckpt, config.train.name) 213 | if not os.path.exists(ckpt_path): 214 | os.makedirs(ckpt_path) 215 | 216 | lj = LJSpeech(config.data, args.data_dir, args.download, not args.from_raw) 217 | diffwave = DiffWave(config.model) 218 | trainer = Trainer(diffwave, lj, config) 219 | 220 | if args.load_step > 0: 221 | super_path = os.path.join(config.train.ckpt, config.train.name) 222 | ckpt_path = '{}_{}.ckpt'.format(config.train.name, args.load_step) 223 | ckpt_path = next( 224 | name for name in os.listdir(super_path) 225 | if name.startswith(ckpt_path) and name.endswith('.index')) 226 | ckpt_path = os.path.join(super_path, ckpt_path[:-6]) 227 | 228 | print('[*] load checkpoint: ' + ckpt_path) 229 | trainer.model.restore(ckpt_path, trainer.optim) 230 | 231 
| with open(os.path.join(config.train.ckpt, config.train.name + '.json'), 'w') as f: 232 | json.dump(config.dump(), f) 233 | 234 | trainer.train(args.load_step, args.ir_unit) 235 | -------------------------------------------------------------------------------- /utils/noam_schedule.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class NoamScheduler(tf.keras.optimizers.schedules.LearningRateSchedule): 5 | """Noam learning rate scheduler from Vaswani et al., 2017. 6 | """ 7 | def __init__(self, learning_rate, warmup_steps, channels): 8 | """Initializer. 9 | Args: 10 | learning_rate: float, initial learning rate. 11 | warmup_steps: int, warmup steps. 12 | channels: int, base hidden size of the model. 13 | """ 14 | super(NoamScheduler, self).__init__() 15 | self.learning_rate = learning_rate 16 | self.warmup_steps = warmup_steps 17 | self.channels = channels 18 | 19 | def __call__(self, step): 20 | """Compute learning rate. 21 | """ 22 | return self.learning_rate * self.channels ** -0.5 * \ 23 | tf.minimum(step ** -0.5, step * self.warmup_steps ** -1.5) 24 | 25 | def get_config(self): 26 | """Serialize configurations. 27 | """ 28 | return { 29 | 'learning_rate': self.learning_rate, 30 | 'warmup_steps': self.warmup_steps, 31 | 'channels': self.channels, 32 | } 33 | --------------------------------------------------------------------------------