├── .gitignore
├── README.md
├── dataset.py
├── hparams.py
├── model
│   ├── __init__.py
│   ├── module.py
│   ├── upsample.py
│   └── wavenet.py
├── preprocess.py
├── synthesize.py
├── train.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__
result
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# WaveNet TensorFlow v2

[WaveNet](https://arxiv.org/abs/1609.03499) with TensorFlow 2.0

## Train

### Libraries

```
tensorflow
numpy
librosa
scipy
```

```bash
ls wavs/*.wav | tail -n+10 > train_files.txt
ls wavs/*.wav | head -n10 > test_files.txt
python preprocess.py
python train.py
```

## Eval

```bash
python synthesize.py [input_path] [output_path] [weight_path]
```

## Result

The synthesis result of a model trained for one day is below:
0001

A link to a pre-trained model is still a little way off.

## Book

![4a948f86-b96f-42ea-927f-14232d57589c_base_resized](https://user-images.githubusercontent.com/33972190/76142918-7feff200-60b5-11ea-9569-0423f8bb3fe9.jpg)

https://techbookfest.org/product/5743005264773120
https://otakuassembly.booth.pm/items/1834753

## References

[WaveNet](https://arxiv.org/abs/1609.03499)

The original paper.

[r9y9/wavenet_vocoder](https://github.com/r9y9/wavenet_vocoder)
[LESS IS MORE/WaveNet vocoder をやってみましたので、その記録です](https://r9y9.github.io/blog/2018/01/28/wavenet_vocoder/)

An implementation used in a number of papers. PyTorch.

[Rayhane-mamah/Tacotron-2](https://github.com/Rayhane-mamah/Tacotron-2)

A Tacotron 2 + WaveNet implementation by someone at DeepMind. TensorFlow v1.

[Monthly Hacker's Blog/VQ-VAEの追試で得たWaveNetのノウハウをまとめてみた。](https://www.monthly-hack.com/entry/2018/02/23/203208)

Collects practical findings about WaveNet.

[Synthesize Human Speech with WaveNet](https://chainer-colab-notebook.readthedocs.io/ja/latest/notebook/official_example/wavenet.html)

A walkthrough on Colab. Chainer.

[The LJ Speech Dataset](https://keithito.com/LJ-Speech-Dataset/)

A single-speaker English dataset.

[JSUT](https://sites.google.com/site/shinnosuketakamichi/publication/jsut)

A single-speaker Japanese dataset.
--------------------------------------------------------------------------------

/dataset.py:
--------------------------------------------------------------------------------
import tensorflow as tf

import hparams


def parse_function(example_proto):
    features = {
        'wav': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
        'mel_sp': tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True),
        'mel_sp_frames': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed_features = tf.io.parse_single_example(example_proto, features)

    mel_sp = tf.reshape(parsed_features['mel_sp'],
                        [hparams.num_mels, parsed_features['mel_sp_frames']])

    return parsed_features['wav'], mel_sp


def adjust_time_resolution(wav, mel_sp):
    # Crop a training segment whose length is a multiple of hop_size so that
    # the waveform and the mel spectrogram stay aligned.
    if hparams.seq_len % hparams.hop_size == 0:
        max_steps = hparams.seq_len
    else:
        max_steps = hparams.seq_len - hparams.seq_len % hparams.hop_size

    max_time_frames = max_steps // hparams.hop_size

    mel_offset = tf.random.uniform([1], minval=0,
                                   maxval=tf.shape(mel_sp)[1] - max_time_frames,
                                   dtype=tf.int32)[0]
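As a quick sanity check of the input pipeline, a sketch like the following (assuming `preprocess.py` has already written `result/train_data.tfrecord`) prints the batch shapes produced with the default hyperparameters:

```python
# Sanity-check sketch: assumes preprocess.py has already written
# result/train_data.tfrecord.
from dataset import get_train_data

for x, mel_sp, y in get_train_data().take(1):
    print(x.shape)       # (8, 10240, 256): one-hot mu-law inputs
    print(mel_sp.shape)  # (8, 80, 40): conditioning mel segment
    print(y.shape)       # (8, 10240): next-sample targets
```

With `seq_len = 10240` and `hop_size = 256`, each mel segment covers exactly `10240 / 256 = 40` frames.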
    wav_offset = mel_offset * hparams.hop_size

    mel_sp = mel_sp[:, mel_offset:mel_offset + max_time_frames]
    # Teacher forcing: the input is the one-hot sample at time t,
    # the target is the raw sample at time t + 1.
    x = wav[wav_offset:wav_offset + max_steps]
    x = tf.one_hot(x, 256, axis=-1, dtype=tf.float32)
    y = wav[wav_offset + 1:wav_offset + max_steps + 1]

    return x, mel_sp, y


def get_train_data():
    train_data = tf.data.TFRecordDataset(filenames=hparams.result_dir + "train_data.tfrecord")\
        .shuffle(300)\
        .map(parse_function, num_parallel_calls=tf.data.experimental.AUTOTUNE)\
        .map(adjust_time_resolution, num_parallel_calls=tf.data.experimental.AUTOTUNE)\
        .batch(hparams.batch_size)\
        .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    return train_data
--------------------------------------------------------------------------------

/hparams.py:
--------------------------------------------------------------------------------
upsample_scales = [4, 8, 8]

seq_len = 10240
sampling_rate = 24000
num_mels = 80
n_fft = 1024
hop_size = 256
win_size = 1024

learning_rate = 1e-3
beta_1 = 0.9
exponential_decay_rate = 0.5
exponential_decay_steps = 200000
epoch = 2000
batch_size = 8

n_test_samples = 1
save_interval = 50

train_files = "./train_files.txt"
test_files = "./test_files.txt"
result_dir = "./result/"
load_path = None
--------------------------------------------------------------------------------

/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kokeshing/WaveNet-tf2/36c35ee93cf9c93716b50553f54a01e3cbdeb067/model/__init__.py
--------------------------------------------------------------------------------

/model/module.py:
--------------------------------------------------------------------------------
import tensorflow as tf


class ReLU(tf.keras.layers.ReLU):
    """ReLU that accepts (and ignores) the is_synthesis flag, so every
    layer in the final stack can be called with the same signature."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def call(self, inputs, is_synthesis=False):
        return super().call(inputs)


class Conv1D(tf.keras.layers.Conv1D):
    def __init__(self, filters, kernel_size, strides=1, padding='causal',
                 dilation_rate=1, residual_channels=None, *args, **kwargs):
        super().__init__(filters, kernel_size, strides=strides, padding=padding,
                         dilation_rate=dilation_rate)

        self.k = kernel_size
        self.d = dilation_rate

        if kernel_size > 1:
            # Buffer of past inputs for fast incremental inference.
            self.queue_len = kernel_size + (kernel_size - 1) * (dilation_rate - 1)
            self.queue_dim = residual_channels
            self.init_queue()

    def build(self, input_shape):
        super().build(input_shape)

        # Flattened kernel used by the single-matmul synthesis path.
        self.linearized_weights = tf.cast(tf.reshape(self.kernel, [-1, self.filters]),
                                          dtype=tf.float32)

    def call(self, inputs, is_synthesis=False):
        if not is_synthesis:
            return super().call(inputs)

        if self.k > 1:
            # Shift the queue one step and append the newest sample.
            self.queue = self.queue[:, 1:, :]
            self.queue = tf.concat([self.queue, tf.expand_dims(inputs[:, -1, :], axis=1)], axis=1)

            if self.d > 1:
                inputs = self.queue[:, 0::self.d, :]
            else:
                inputs = self.queue

        # One matmul against the flattened kernel replaces the convolution.
        outputs = tf.matmul(tf.reshape(inputs, [1, -1]), self.linearized_weights)
        outputs = tf.nn.bias_add(outputs, self.bias)

        # [batch_size, 1(time_len), channels]
        return tf.reshape(outputs, [-1, 1, self.filters])

    def init_queue(self):
        self.queue = tf.zeros([1, self.queue_len, self.queue_dim], dtype=tf.float32)


class ResidualConv1DGLU(tf.keras.Model):
    """
    conv1d + GLU => add condition => residual add + skip connection
    """

    def __init__(self, residual_channels, gate_channels, kernel_size,
                 skip_out_channels=None, dilation_rate=1, **kwargs):
        super().__init__()

        self.residual_channels = residual_channels

        if skip_out_channels is None:
            skip_out_channels = residual_channels

        self.dilated_conv = Conv1D(gate_channels,
                                   kernel_size=kernel_size,
                                   padding='causal',
                                   dilation_rate=dilation_rate,
                                   residual_channels=residual_channels)

        self.conv_c = Conv1D(gate_channels,
                             kernel_size=1,
                             padding='causal')

        self.conv_skip = Conv1D(skip_out_channels,
                                kernel_size=1,
                                padding='causal')
        self.conv_out = Conv1D(residual_channels,
                               kernel_size=1,
                               padding='causal')

    @tf.function
    def call(self, inputs, c):
        x = self.dilated_conv(inputs)
        x_tanh, x_sigmoid = tf.split(x, num_or_size_splits=2, axis=2)

        c = self.conv_c(c)
        c_tanh, c_sigmoid = tf.split(c, num_or_size_splits=2, axis=2)

        x_tanh, x_sigmoid = x_tanh + c_tanh, x_sigmoid + c_sigmoid
        x = tf.nn.tanh(x_tanh) * tf.nn.sigmoid(x_sigmoid)

        s = self.conv_skip(x)
        x = self.conv_out(x)

        x = x + inputs

        return x, s

    def init_queue(self):
        self.dilated_conv.init_queue()

    def synthesis_feed(self, inputs, c):
        x = self.dilated_conv(inputs, is_synthesis=True)
        x_tanh, x_sigmoid = tf.split(x, num_or_size_splits=2, axis=2)

        c = self.conv_c(c, is_synthesis=True)
        c_tanh, c_sigmoid = tf.split(c, num_or_size_splits=2, axis=2)

        x_tanh, x_sigmoid = x_tanh + c_tanh, x_sigmoid + c_sigmoid
        x = tf.nn.tanh(x_tanh) * tf.nn.sigmoid(x_sigmoid)

        s = self.conv_skip(x, is_synthesis=True)
        x = self.conv_out(x, is_synthesis=True)

        x = x + inputs

        return x, s


class CrossEntropyLoss(tf.keras.losses.Loss):
    def __init__(self, num_classes=256, name=None):
        super().__init__(name=name)
        self.num_classes = num_classes

    def call(self, targets, outputs):
        targets_ = tf.one_hot(targets, depth=self.num_classes)
        losses = tf.nn.softmax_cross_entropy_with_logits(labels=targets_, logits=outputs)

        return tf.reduce_mean(losses)
--------------------------------------------------------------------------------
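The queue in `Conv1D` exists so that synthesis can reuse past inputs instead of re-running the full convolution at every step. A minimal equivalence sketch (assuming `model.module` is importable from the project root) compares the two paths:

```python
# Equivalence sketch: the queue-based synthesis path should reproduce
# the ordinary causal dilated convolution sample by sample (batch 1).
import numpy as np
import tensorflow as tf

from model.module import Conv1D

conv = Conv1D(16, kernel_size=3, dilation_rate=2, residual_channels=8)

x = tf.random.normal([1, 20, 8])
full = conv(x)  # whole sequence at once; also builds the layer

conv.init_queue()
steps = [conv(x[:, t:t + 1, :], is_synthesis=True) for t in range(20)]
print(np.allclose(full.numpy(), tf.concat(steps, axis=1).numpy(), atol=1e-5))
```

Both paths should agree because the queue is initialized to zeros, mirroring the left zero-padding of the causal convolution.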
/model/upsample.py:
--------------------------------------------------------------------------------
import tensorflow as tf


class UpsampleCond(tf.keras.Model):
    def __init__(self, rate, **kwargs):
        super().__init__()

        self.upsampling = tf.keras.layers.UpSampling2D((1, rate), interpolation='nearest')

        self.conv = tf.keras.layers.Conv2D(1, kernel_size=(1, rate * 2 + 1),
                                           padding='same', use_bias=False,
                                           kernel_initializer=tf.constant_initializer(1. / (rate * 2 + 1)))

    @tf.function
    def call(self, x):
        return self.conv(self.upsampling(x))


class UpsampleNetwork(tf.keras.Model):
    def __init__(self, upsample_scales, **kwargs):
        super().__init__()

        self.upsample_layers = [UpsampleCond(scale) for scale in upsample_scales]

    @tf.function
    def call(self, feat):
        for layer in self.upsample_layers:
            feat = layer(feat)

        return feat
--------------------------------------------------------------------------------
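Note that the upsampling scales must multiply out to `hop_size` (here `4 * 8 * 8 = 256`) so that one mel frame is stretched to exactly one hop of audio samples. A quick shape check, under the default hparams:

```python
# Shape check: upsample_scales multiply out to hop_size (4 * 8 * 8 = 256),
# so each mel frame becomes exactly one hop of audio samples.
import tensorflow as tf

import hparams
from model.upsample import UpsampleNetwork

net = UpsampleNetwork(hparams.upsample_scales)
mel = tf.random.normal([1, hparams.num_mels, 40, 1])  # [batch, mels, frames, 1]
print(net(mel).shape)  # (1, 80, 10240, 1) -> 40 frames * 256 samples/frame
```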
/model/wavenet.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np

from .module import Conv1D, ReLU, ResidualConv1DGLU
from .upsample import UpsampleNetwork
from utils import mulaw_quantize


class WaveNet(tf.keras.Model):
    def __init__(self, num_mels, upsample_scales):
        super().__init__()

        self.upsample_network = UpsampleNetwork(upsample_scales)

        self.first_layer = Conv1D(128,
                                  kernel_size=1,
                                  padding='causal')

        # Two stacks of ten residual blocks with dilations 1, 2, ..., 512.
        self.residual_blocks = []
        for _ in range(2):
            for i in range(10):
                self.residual_blocks.append(
                    ResidualConv1DGLU(128,
                                      256,
                                      kernel_size=3,
                                      skip_out_channels=128,
                                      dilation_rate=2 ** i)
                )

        self.final_layers = [
            ReLU(),
            Conv1D(128,
                   kernel_size=1,
                   padding='causal'),
            ReLU(),
            Conv1D(256,
                   kernel_size=1,
                   padding='causal')
        ]

    @tf.function
    def call(self, inputs, c):
        c = tf.expand_dims(c, axis=-1)
        c = self.upsample_network(c)
        c = tf.transpose(tf.squeeze(c, axis=-1), perm=[0, 2, 1])

        x = self.first_layer(inputs)
        skips = None
        for block in self.residual_blocks:
            x, h = block(x, c)
            if skips is None:
                skips = h
            else:
                skips = skips + h

        x = skips
        for layer in self.final_layers:
            x = layer(x)

        return x

    def init_queue(self):
        for block in self.residual_blocks:
            block.init_queue()

    def synthesis(self, c):
        # Autoregressive generation, one sample per upsampled conditioning
        # frame. The Conv1D queues are allocated with batch size 1, so
        # synthesis assumes a single utterance.
        c = tf.expand_dims(c, axis=-1)
        c = self.upsample_network(c)
        c = tf.transpose(tf.squeeze(c, axis=-1), perm=[0, 2, 1])

        batch_size, time_len, _ = c.shape
        initial_value = mulaw_quantize(0, 256)
        inputs = tf.one_hot(indices=initial_value,
                            depth=256, dtype=tf.float32)
        inputs = tf.tile(tf.reshape(inputs, [1, 1, 256]),
                         [batch_size, 1, 1])

        outputs = []
        for i in range(time_len):
            c_t = tf.expand_dims(c[:, i, :], axis=1)

            x = self.first_layer(inputs, is_synthesis=True)

            skips = None
            for block in self.residual_blocks:
                x, h = block.synthesis_feed(x, c_t)

                if skips is not None:
                    skips = skips + h
                else:
                    skips = h

            x = skips
            for layer in self.final_layers:
                x = layer(x, is_synthesis=True)

            # Greedy decoding: take the most likely class and feed it back in.
            x = tf.argmax(tf.squeeze(x, axis=1), axis=-1)
            x = tf.one_hot(x, depth=256)
            inputs = tf.expand_dims(x, axis=1)  # back to [batch, 1, 256]

            outputs.append(tf.argmax(x, axis=1).numpy())

        outputs = np.array(outputs)

        return np.transpose(outputs, [1, 0])
--------------------------------------------------------------------------------
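For reference, the receptive field of the residual stack above (kernel size 3, dilations 1 through 512, two stacks) can be computed directly:

```python
# Receptive field of the stack: kernel 3, two stacks of dilations 1..512.
kernel_size = 3
dilations = [2 ** i for i in range(10)] * 2

receptive_field = 1 + sum((kernel_size - 1) * d for d in dilations)
print(receptive_field)                 # 4093 samples
print(1000 * receptive_field / 24000)  # ~170.5 ms at 24 kHz
```

So each generated sample is conditioned on roughly the previous 170 ms of audio.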
/preprocess.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import os

from utils import *
import hparams


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _int64_array_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _float32_array_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def audio_preprocess(wav_path):
    wav = load_wav(wav_path, hparams.sampling_rate)
    wav = trim_silence(wav, top_db=40, fft_size=2048, hop_size=512)
    wav = normalize(wav) * 0.95

    mel_sp = melspectrogram(wav, hparams.sampling_rate, hparams.num_mels,
                            n_fft=hparams.n_fft, hop_size=hparams.hop_size,
                            win_size=hparams.win_size)

    # Pad the waveform so its length is an exact multiple of hop_size,
    # keeping it aligned with the mel spectrogram frames.
    pad = (wav.shape[0] // hparams.hop_size + 1) * hparams.hop_size - len(wav)
    wav = np.pad(wav, (0, pad), mode='constant', constant_values=0.0)
    assert len(wav) % hparams.hop_size == 0

    wav = mulaw_quantize(wav, 255)

    mel_sp_channels, mel_sp_frames = mel_sp.shape
    mel_sp = mel_sp.flatten()
    record = tf.train.Example(features=tf.train.Features(feature={
        'wav': _int64_array_feature(wav),
        'mel_sp': _float32_array_feature(mel_sp),
        'mel_sp_frames': _int64_feature(mel_sp_frames),
    }))

    return record


def create_tfrecord():
    os.makedirs(hparams.result_dir, exist_ok=True)

    train_files = files_to_list(hparams.train_files)
    with tf.io.TFRecordWriter(hparams.result_dir + "train_data.tfrecord") as writer:
        for wav_path in train_files:
            record = audio_preprocess(wav_path)
            writer.write(record.SerializeToString())


if __name__ == '__main__':
    create_tfrecord()
--------------------------------------------------------------------------------

/synthesize.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import argparse


from model.wavenet import WaveNet
from utils import load_wav, normalize, melspectrogram, inv_mulaw_quantize, save_wav
import hparams

parser = argparse.ArgumentParser()
parser.add_argument('input_path', help="Path of input audio")
parser.add_argument('output_path', help="Path of synthesized audio")
parser.add_argument('weight_path', help="Path of checkpoint (ex: ./result/weights/wavenet_0800)")
args = parser.parse_args()


def synthesize(mel_sp, save_path, weight_path):
    wavenet = WaveNet(hparams.num_mels, hparams.upsample_scales)
    wavenet.load_weights(weight_path)
    mel_sp = tf.expand_dims(mel_sp, axis=0)

    outputs = wavenet.synthesis(mel_sp)
    outputs = np.squeeze(outputs)
    outputs = inv_mulaw_quantize(outputs)

    save_wav(outputs, save_path, hparams.sampling_rate)


if __name__ == '__main__':
    wav = load_wav(args.input_path, hparams.sampling_rate)
    wav = normalize(wav) * 0.95

    mel_sp = melspectrogram(wav, hparams.sampling_rate, hparams.num_mels,
                            n_fft=hparams.n_fft, hop_size=hparams.hop_size,
                            win_size=hparams.win_size)

    synthesize(mel_sp, args.output_path, args.weight_path)
--------------------------------------------------------------------------------
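To confirm that the waveform and mel features stay aligned on disk, the record can be parsed back with `parse_function` from `dataset.py` (a sketch; it assumes the TFRecord above has been written):

```python
# Read-back sketch: parse the first record and check wav/mel alignment.
import tensorflow as tf

import hparams
from dataset import parse_function

ds = tf.data.TFRecordDataset(hparams.result_dir + "train_data.tfrecord")
wav, mel_sp = next(iter(ds.map(parse_function)))
print(wav.shape, mel_sp.shape)  # (hop_size * frames,) and (80, frames)
assert wav.shape[0] == hparams.hop_size * mel_sp.shape[1]
```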
/train.py:
--------------------------------------------------------------------------------
import tensorflow as tf

import numpy as np
import os

from model.wavenet import WaveNet
from model.module import CrossEntropyLoss
from dataset import get_train_data
import hparams


@tf.function
def train_step(model, x, mel_sp, y, loss_fn, optimizer):
    with tf.GradientTape() as tape:
        y_hat = model(x, mel_sp)
        loss = loss_fn(y, y_hat)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    return loss


def train():
    os.makedirs(hparams.result_dir + "weights/", exist_ok=True)

    summary_writer = tf.summary.create_file_writer(hparams.result_dir)

    wavenet = WaveNet(hparams.num_mels, hparams.upsample_scales)

    loss_fn = CrossEntropyLoss(num_classes=256)

    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        hparams.learning_rate,
        decay_steps=hparams.exponential_decay_steps,
        decay_rate=hparams.exponential_decay_rate)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule,
                                         beta_1=hparams.beta_1)

    if hparams.load_path is not None:
        wavenet.load_weights(hparams.load_path)
        step = int(np.load(hparams.result_dir + "weights/step.npy"))
        print(f"weights load: {hparams.load_path}")
    else:
        step = 0

    for epoch in range(hparams.epoch):
        train_data = get_train_data()
        for x, mel_sp, y in train_data:
            loss = train_step(wavenet, x, mel_sp, y, loss_fn, optimizer)
            with summary_writer.as_default():
                tf.summary.scalar('train/loss', loss, step=step)

            step += 1

        if epoch % hparams.save_interval == 0:
            print(f'Step {step}, Loss: {loss}')
            np.save(hparams.result_dir + "weights/step.npy", np.array(step))
            wavenet.save_weights(hparams.result_dir + f"weights/wavenet_{epoch:04}")

    np.save(hparams.result_dir + "weights/step.npy", np.array(step))
    wavenet.save_weights(hparams.result_dir + f"weights/wavenet_{epoch:04}")


if __name__ == '__main__':
    train()
--------------------------------------------------------------------------------
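With the `hparams.py` defaults, the schedule decays the learning rate continuously, halving it every 200,000 steps; a short check:

```python
# The ExponentialDecay schedule used in train.py, evaluated at a few steps:
# lr(step) = 1e-3 * 0.5 ** (step / 200000) with the hparams.py defaults.
import tensorflow as tf

schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    1e-3, decay_steps=200000, decay_rate=0.5)

for step in [0, 100000, 200000, 400000]:
    print(step, float(schedule(step)))  # 1e-3, ~7.1e-4, 5e-4, 2.5e-4
```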
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import librosa
from scipy.io import wavfile


def files_to_list(filename):
    with open(filename, encoding="utf-8") as f:
        files = f.readlines()

    files = [file.rstrip() for file in files]

    return files


def load_wav(path, sampling_rate):
    wav = librosa.core.load(path, sr=sampling_rate)[0]

    return wav


def trim_silence(wav, top_db=40, fft_size=2048, hop_size=512):
    return librosa.effects.trim(wav, top_db=top_db, frame_length=fft_size, hop_length=hop_size)[0]


def normalize(wav):
    return librosa.util.normalize(wav)


def mulaw(x, mu=255):
    return np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)


def mulaw_quantize(x, mu=255):
    x = mulaw(x, mu)
    x = (x + 1) / 2 * mu

    return x.astype(np.int64)


def inv_mulaw(x, mu=255):
    return np.sign(x) * (1.0 / mu) * ((1.0 + mu) ** np.abs(x) - 1.0)


def inv_mulaw_quantize(x, mu=255):
    x = 2 * x.astype(np.float32) / mu - 1

    return inv_mulaw(x, mu)


def save_wav(wav, path, sr):
    wav *= 32767 / max(0.0001, np.max(np.abs(wav)))
    wavfile.write(path, sr, wav.astype(np.int16))


def melspectrogram(wav, sampling_rate, num_mels, n_fft, hop_size, win_size):
    d = librosa.stft(y=wav, n_fft=n_fft, hop_length=hop_size,
                     win_length=win_size, pad_mode='constant')
    mel_filter = librosa.filters.mel(sr=sampling_rate, n_fft=n_fft,
                                     n_mels=num_mels)
    s = np.dot(mel_filter, np.abs(d))

    return np.log10(np.maximum(s, 1e-5))
--------------------------------------------------------------------------------
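Finally, a small round trip through the companding functions above illustrates the 8-bit mu-law encoding used throughout the project:

```python
# Mu-law round trip: encode to 256 classes, decode, inspect the error.
import numpy as np

from utils import mulaw_quantize, inv_mulaw_quantize

x = np.linspace(-0.95, 0.95, 1001)
q = mulaw_quantize(x, 255)        # integer classes within [0, 255]
x_hat = inv_mulaw_quantize(q, 255)

print(np.max(np.abs(x - x_hat)))  # small; mu-law is fine-grained near zero
```

Mu-law spends its resolution near zero, so the reconstruction error grows toward full scale.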