├── .gitignore
├── README.md
├── generate.py
├── test
│   ├── data
│   │   └── helloworld.wav
│   ├── test_causal_conv.py
│   ├── test_dataloader.py
│   ├── test_encode_sound.py
│   └── test_wavenet_module.py
├── train.py
└── wavenet
    ├── config.py
    ├── exceptions.py
    ├── model.py
    ├── networks.py
    └── utils
        └── data.py

/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
__pycache__
.cache
.idea
output
datasets
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# WaveNet

Yet another WaveNet implementation in PyTorch.

This implementation aims to be well-structured, reusable, and easy to understand.

- [WaveNet Paper](https://arxiv.org/pdf/1609.03499.pdf)
- [WaveNet: A Generative Model for Raw Audio](https://deepmind.com/blog/wavenet-generative-model-raw-audio/)

## Prerequisites

- System
  - Linux or macOS
  - CPU or (NVIDIA GPU + CUDA cuDNN)
    - Runs on a single CPU/GPU or on multiple GPUs.
  - Python 3

- Libraries
  - PyTorch >= 0.3.0
  - librosa >= 0.5.1

## Training

```bash
python train.py \
    --data_dir=./test/data \
    --output_dir=./outputs
```

Use `python train.py --help` to see more options.

## Generating

This script is only for testing; you will need to adapt it for real-world use.

```bash
python generate.py \
    --model_dir=./outputs/model \
    --seed=./test/data/helloworld.wav \
    --out=./output/helloworld.wav
```

Use `python generate.py --help` to see more options.

## File structures

`networks.py` and `model.py` are the main implementations.

- wavenet
  - `config.py` : Training options
  - `networks.py` : The neural network architecture of WaveNet
  - `model.py` : Loss calculation and optimization
  - utils
    - `data.py` : Utilities for loading data
- test
  - Some tests that check model correctness (causal convolutions, dilated convolutions, ...)
- `train.py` : A script for WaveNet training
- `generate.py` : A script for generating with a pre-trained model

## TODO

- [ ] Add some nice samples
- [ ] Global conditions
- [ ] Local conditions
- [ ] Faster generating
- [ ] Parallel WaveNet
- [ ] General Generator

## References

- https://github.com/ibab/tensorflow-wavenet
- https://qiita.com/MasaEguchi/items/cd5f7e9735a120f27e2a
- https://github.com/musyoku/wavenet/issues/4
- https://github.com/vincentherrmann/pytorch-wavenet
- http://sergeiturukin.com/2017/03/02/wavenet.html

--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
"""
A script for generating audio with a pre-trained WaveNet model
"""
import torch
import librosa
import datetime
import numpy as np

import wavenet.config as config
from wavenet.model import WaveNet
import wavenet.utils.data as utils


class Generator:
    def __init__(self, args):
        self.args = args

        self.wavenet = WaveNet(args.layer_size, args.stack_size,
                               args.in_channels, args.res_channels)

        self.wavenet.load(args.model_dir, args.step)

    @staticmethod
    def _variable(data):
        tensor = torch.from_numpy(data).float()

        if torch.cuda.is_available():
            return torch.autograd.Variable(tensor.cuda())
        else:
            return torch.autograd.Variable(tensor)

    def _make_seed(self, audio):
        audio = np.pad([audio], [[0, 0], [self.wavenet.receptive_fields, 0], [0, 0]], 'constant')

        if self.args.sample_size:
            seed = audio[:, :self.args.sample_size, :]
        else:
            seed = audio[:, :self.wavenet.receptive_fields*2, :]

        return seed

    def _get_seed_from_audio(self, filepath):
        audio = utils.load_audio(filepath, self.args.sample_rate)
        audio_length = len(audio)

        audio = utils.mu_law_encode(audio, self.args.in_channels)
        audio = utils.one_hot_encode(audio, self.args.in_channels)

        seed = self._make_seed(audio)

        return self._variable(seed), audio_length

    def _save_to_audio_file(self, data):
        data = data[0].cpu().data.numpy()
        data = utils.one_hot_decode(data, axis=1)
        audio = utils.mu_law_decode(data, self.args.in_channels)

        librosa.output.write_wav(self.args.out, audio, self.args.sample_rate)
        print('Saved wav file at {}'.format(self.args.out))

        return librosa.get_duration(y=audio, sr=self.args.sample_rate)

    def generate(self):
        outputs = []
        inputs, audio_length = self._get_seed_from_audio(self.args.seed)

        while True:
            new = self.wavenet.generate(inputs)

            outputs = torch.cat((outputs, new), dim=1) if len(outputs) else new

            print('{0}/{1} samples are generated.'.format(len(outputs[0]), audio_length))

            if len(outputs[0]) >= audio_length:
                break

            # Keep the input window the same length: cut off its tail and
            # append the newly generated samples in its place.
            inputs = torch.cat((inputs[:, :-len(new[0]), :], new), dim=1)

        outputs = outputs[:, :audio_length, :]

        return self._save_to_audio_file(outputs)


if __name__ == '__main__':
    args = config.parse_args(is_training=False)

    generator = Generator(args)

    start_time = datetime.datetime.now()

    duration = generator.generate()

    print('Generating {0} seconds took {1}'.format(duration, datetime.datetime.now() - start_time))
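For clarity, here is the window-update rule from `Generator.generate()` in isolation. It keeps the input length fixed by replacing the tail of the window with the freshly generated samples; this is the simplified scheme the README warns is only for testing. A minimal sketch with toy shapes (`window` and `new` are illustrative names, not part of this repo):

```python
import torch

window = torch.zeros(1, 8, 3)   # [batch, timestep, channels], e.g. the seed
new = torch.rand(1, 2, 3)       # stand-in for wavenet.generate(window)

# same as: inputs = torch.cat((inputs[:, :-len(new[0]), :], new), dim=1)
window = torch.cat((window[:, :-new.size(1), :], new), dim=1)

assert window.size(1) == 8      # length is unchanged; only the tail was replaced
```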
--------------------------------------------------------------------------------
/test/data/helloworld.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/golbin/WaveNet/71be43d8cfbe4fdfe1699ada179674e4e485ad5b/test/data/helloworld.wav
--------------------------------------------------------------------------------
/test/test_causal_conv.py:
--------------------------------------------------------------------------------
"""
Test Dilated Causal Convolution
"""

import os
import sys

import torch
import pytest
import numpy as np

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from wavenet.networks import CausalConv1d, DilatedCausalConv1d


CAUSAL_RESULT = [
    [[[18, 38, 42, 46, 50, 54, 58, 62, 66, 70, 74, 78, 82, 86, 90, 94]]]
]

DILATED_CAUSAL_RESULT = [
    [[[56, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184]]],
    [[[144, 176, 192, 208, 224, 240, 256, 272, 288, 304, 320, 336, 352]]],
    [[[368, 416, 448, 480, 512, 544, 576, 608, 640]]],
    [[[1008]]]
]


def causal_conv(data, in_channels, out_channels, print_result=True):
    conv = CausalConv1d(in_channels, out_channels)
    conv.init_weights_for_test()

    output = conv(data)

    print('Causal convolution ---')
    if print_result:
        print(' {0}'.format(output.data.numpy().astype(int)))

    return output


def dilated_causal_conv(step, data, channels, dilation=1, print_result=True):
    conv = DilatedCausalConv1d(channels, dilation=dilation)
    conv.init_weights_for_test()

    output = conv(data)

    print('{0} step is OK: dilation={1}, size={2}'.format(step, dilation, output.shape))
    if print_result:
        print(' {0}'.format(output.data.numpy().astype(int)))

    return output


@pytest.fixture
def generate_x():
    """Generate a [1, 2, 16] test input: [batch, channel, timestep]"""
    x = np.arange(1, 33, dtype=np.float32)
    x = np.reshape(x, [1, 2, 16])  # [batch, channel, timestep]
    x = torch.autograd.Variable(torch.from_numpy(x))

    print('Input size={0}'.format(x.shape))
    print(x.data.numpy().astype(int))
    print('-'*80)

    return x


@pytest.fixture
def test_causal_conv(generate_x):
    """Test normal causal convolution; also used as a fixture by test_dilated_causal_conv"""
    result = causal_conv(generate_x, 2, 1)

    np.testing.assert_array_equal(
        result.data.numpy().astype(int),
        CAUSAL_RESULT[0]
    )

    return result


def test_dilated_causal_conv(test_causal_conv):
    """Test dilated causal convolution : dilation=[1, 2, 4, 8]"""
    result = test_causal_conv

    for i in range(0, 4):
        result = dilated_causal_conv(i+1, result, 1, dilation=2**i)

        np.testing.assert_array_equal(
            result.data.numpy().astype(int),
            DILATED_CAUSAL_RESULT[i]
        )

--------------------------------------------------------------------------------
/test/test_dataloader.py:
--------------------------------------------------------------------------------
"""
Test the DataLoader for WaveNet
"""

import os
import sys

import torch

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from wavenet.utils.data import DataLoader


RECEPTIVE_FIELDS = 1000
SAMPLE_SIZE = 2000
SAMPLE_RATE = 8000
IN_CHANNELS = 256
TEST_AUDIO_DIR = os.path.join(os.path.dirname(__file__), 'data')


def test_data_loader():
    data_loader = DataLoader(TEST_AUDIO_DIR,
                             RECEPTIVE_FIELDS, SAMPLE_SIZE, SAMPLE_RATE, IN_CHANNELS,
                             shuffle=False)

    dataset_size = []

    for dataset in data_loader:
        input_size = []
        target_size = []

        for i, t in dataset:
            input_size.append(i.shape)
            target_size.append(t.shape)

        dataset_size.append([input_size, target_size])

    assert dataset_size[0][0][0] == torch.Size([1, 2000, 256])
    assert dataset_size[0][1][0] == torch.Size([1, 1000])
    assert dataset_size[0][0][-1] == torch.Size([1, 1839, 256])
    assert dataset_size[0][1][-1] == torch.Size([1, 839])

    assert dataset_size[1][0][0] == torch.Size([1, 2000, 256])
    assert dataset_size[1][1][0] == torch.Size([1, 1000])
    assert dataset_size[1][0][-1] == torch.Size([1, 1762, 256])
    assert dataset_size[1][1][-1] == torch.Size([1, 762])

    assert len(dataset_size[0][0]) == 8
    assert len(dataset_size[1][0]) == 8

--------------------------------------------------------------------------------
/test/test_encode_sound.py:
--------------------------------------------------------------------------------
"""
Test mu-law encoding and decoding
"""

import os
import sys

import numpy as np

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from wavenet.utils.data import *


SAMPLE_RATE = 8000
QUANTIZATION_CHANNEL = 256

TEST_AUDIO_FILE = os.path.join(os.path.dirname(__file__),
                               'data', 'helloworld.wav')


def test_mu_law_encode():
    raw_audio = load_audio(TEST_AUDIO_FILE, SAMPLE_RATE)
    raw_audio = raw_audio[2007:2013, :]

    mu_law_encoded = mu_law_encode(raw_audio, QUANTIZATION_CHANNEL)
    mu_law_decoded = mu_law_decode(mu_law_encoded, QUANTIZATION_CHANNEL)
    one_hot_encoded = one_hot_encode(mu_law_encoded, QUANTIZATION_CHANNEL)
    one_hot_decoded = one_hot_decode(one_hot_encoded)
    one_hot_decoded.shape = (one_hot_decoded.size, 1)

    print('--- Raw audio ---')
    print(raw_audio)
    print('--- mu-law encoded ---')
    print(mu_law_encoded)
    print('--- mu-law decoded ---')
    print(mu_law_decoded)
    print('--- one-hot encoded ---')
    print(one_hot_encoded)
    print('--- one-hot decoded ---')
    print(one_hot_decoded)

    np.testing.assert_array_equal(mu_law_encoded, one_hot_decoded)

    # the round trip should reconstruct the waveform with only a small error
    assert np.min(raw_audio - mu_law_decoded) < 0.0011
    assert np.max(raw_audio - mu_law_decoded) < 0.012
    assert np.mean(raw_audio - mu_law_decoded) < 0.0036
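The thresholds above are empirical for this particular clip. As a rough sanity check of the expected error scale, one can measure the worst-case round-trip error over a dense ramp of amplitudes, using this repo's own encode/decode helpers (a sketch; the ramp and variable names are illustrative):

```python
import numpy as np

from wavenet.utils.data import mu_law_encode, mu_law_decode

# Dense ramp over the full amplitude range [-1, 1].
ramp = np.linspace(-1, 1, 10001).reshape(-1, 1)

decoded = mu_law_decode(mu_law_encode(ramp, 256), 256)

# mu-law allocates fine resolution near zero and coarse resolution near +/-1,
# so the worst-case absolute error is larger than the near-zero error.
print('max abs error:', np.max(np.abs(ramp - decoded)))
```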
--------------------------------------------------------------------------------
/test/test_wavenet_module.py:
--------------------------------------------------------------------------------
"""
Test the WaveNet network module
"""

import os
import sys

import torch
import numpy as np

import pytest

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from wavenet.networks import WaveNet
from wavenet.exceptions import InputSizeError


LAYER_SIZE = 5  # 10 in paper
STACK_SIZE = 2  # 5 in paper
IN_CHANNELS = 2  # 256 in paper. quantized and one-hot input.
RES_CHANNELS = 512  # 512 in paper


def generate_dummy(dummy_length):
    x = np.arange(0, dummy_length, dtype=np.float32)
    x = np.reshape(x, [1, int(dummy_length / 2), 2])  # [batch, timestep, channels]
    x = torch.autograd.Variable(torch.from_numpy(x))

    return x


@pytest.fixture
def wavenet():
    net = WaveNet(LAYER_SIZE, STACK_SIZE, IN_CHANNELS, RES_CHANNELS)

    print(net)

    return net


def test_wavenet_output_size(wavenet):
    x = generate_dummy(wavenet.receptive_fields * 2 + 2)

    output = wavenet(x)

    # input size = receptive field size + 1 (* two channels)
    # output size = input size - receptive field size
    #             = 1
    assert output.shape == torch.Size([1, 1, 2])

    x = generate_dummy(wavenet.receptive_fields * 4)

    output = wavenet(x)

    # input size = receptive field size * 2 (* two channels)
    # output size = input size - receptive field size
    #             = receptive field size
    assert output.shape == torch.Size([1, wavenet.receptive_fields, 2])


def test_wavenet_fail_with_short_input(wavenet):
    x = generate_dummy(wavenet.receptive_fields * 2)

    try:
        wavenet(x)
        pytest.fail("Expected InputSizeError: the input is too short.")
    except InputSizeError:
        pass

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
"""
A script for WaveNet training
"""
import os

import wavenet.config as config
from wavenet.model import WaveNet
from wavenet.utils.data import DataLoader


class Trainer:
    def __init__(self, args):
        self.args = args

        self.wavenet = WaveNet(args.layer_size, args.stack_size,
                               args.in_channels, args.res_channels,
                               lr=args.lr)

        self.data_loader = DataLoader(args.data_dir, self.wavenet.receptive_fields,
                                      args.sample_size, args.sample_rate, args.in_channels)

    def infinite_batch(self):
        while True:
            for dataset in self.data_loader:
                for inputs, targets in dataset:
                    yield inputs, targets

    def run(self):
        total_steps = 0

        for inputs, targets in self.infinite_batch():
            loss = self.wavenet.train(inputs, targets)

            total_steps += 1

            print('[{0}/{1}] loss: {2}'.format(total_steps, self.args.num_steps, loss))

            if total_steps > self.args.num_steps:
                break

        self.wavenet.save(self.args.model_dir)


def prepare_output_dir(args):
    args.log_dir = os.path.join(args.output_dir, 'log')
    args.model_dir = os.path.join(args.output_dir, 'model')
    args.test_output_dir = os.path.join(args.output_dir, 'test')

    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.model_dir, exist_ok=True)
    os.makedirs(args.test_output_dir, exist_ok=True)


if __name__ == '__main__':
    args = config.parse_args()

    prepare_output_dir(args)

    trainer = Trainer(args)

    trainer.run()
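Both the tests above and the training defaults depend on `WaveNet.receptive_fields`, which grows as stacks of doubling dilations. A quick back-of-the-envelope check of that size under the default config (this mirrors `WaveNet.calc_receptive_fields` in `wavenet/networks.py`):

```python
# Receptive field = stack_size * sum of dilations (1 + 2 + ... + 2**(layer_size-1))
layer_size, stack_size = 10, 5      # defaults in wavenet/config.py
receptive_fields = stack_size * (2 ** layer_size - 1)
print(receptive_fields)             # 5115 samples
print(receptive_fields / 16000)     # ~0.32 seconds at the default 16 kHz
```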
--------------------------------------------------------------------------------
/wavenet/config.py:
--------------------------------------------------------------------------------
"""
Training Options
"""
import argparse

parser = argparse.ArgumentParser()

parser.add_argument('--layer_size', type=int, default=10,
                    help='layer_size: 10 = layer[dilation=1, dilation=2, 4, 8, 16, 32, 64, 128, 256, 512]')
parser.add_argument('--stack_size', type=int, default=5,
                    help='stack_size: 5 = stack[layer1, layer2, layer3, layer4, layer5]')
parser.add_argument('--in_channels', type=int, default=256,
                    help='input channel size: number of mu-law quantization channels, i.e. the one-hot size')
parser.add_argument('--res_channels', type=int, default=512, help='number of channels for the residual blocks')

parser.add_argument('--sample_rate', type=int, default=16000, help='sampling rate of the input sound')
parser.add_argument('--sample_size', type=int, default=100000, help='sample size for training input')


def parse_args(is_training=True):
    if is_training:
        parser.add_argument('--data_dir', type=str, default='./test/data', help='Training data dir')
        parser.add_argument('--output_dir', type=str, default='./output', help='Output dir for saving the model etc.')
        parser.add_argument('--num_steps', type=int, default=100000, help='Total training steps')
        parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')
    else:
        parser.add_argument('--model_dir', type=str, required=True, help='Pre-trained model dir')
        parser.add_argument('--step', type=int, default=0, help='A specific step of the pre-trained model to use')
        parser.add_argument('--seed', type=str, help='A seed file to generate sound')
        parser.add_argument('--out', type=str, help='Output file name for the generated sound')

    return parser.parse_args()


def print_help():
    parser.print_help()

--------------------------------------------------------------------------------
/wavenet/exceptions.py:
--------------------------------------------------------------------------------
class InputSizeError(Exception):
    def __init__(self, input_size, receptive_fields, output_size):

        message = 'Input size has to be larger than receptive_fields\n'
        message += 'Input size: {0}, Receptive fields size: {1}, Output size: {2}'.format(
            input_size, receptive_fields, output_size)

        super(InputSizeError, self).__init__(message)
--------------------------------------------------------------------------------
/wavenet/model.py:
--------------------------------------------------------------------------------
"""
Main model of WaveNet
Calculates the loss and runs the optimizer
"""
import os

import torch
import torch.optim

from wavenet.networks import WaveNet as WaveNetModule


class WaveNet:
    def __init__(self, layer_size, stack_size, in_channels, res_channels, lr=0.002):

        self.net = WaveNetModule(layer_size, stack_size, in_channels, res_channels)

        self.in_channels = in_channels
        self.receptive_fields = self.net.receptive_fields

        self.lr = lr
        self.loss = self._loss()
        self.optimizer = self._optimizer()

        self._prepare_for_gpu()

    @staticmethod
    def _loss():
        loss = torch.nn.CrossEntropyLoss()

        if torch.cuda.is_available():
            loss = loss.cuda()

        return loss

    def _optimizer(self):
        return torch.optim.Adam(self.net.parameters(), lr=self.lr)

    def _prepare_for_gpu(self):
        if torch.cuda.device_count() > 1:
            print("{0} GPUs are detected.".format(torch.cuda.device_count()))
            self.net = torch.nn.DataParallel(self.net)

        if torch.cuda.is_available():
            self.net.cuda()

    def train(self, inputs, targets):
        """
        Run one training step
        :param inputs: Tensor[batch, timestep, channels]
        :param targets: Tensor[batch, timestep] of mu-law class indices
        :return: float loss
        """
        outputs = self.net(inputs)

        loss = self.loss(outputs.view(-1, self.in_channels),
                         targets.long().view(-1))

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.data[0]

    def generate(self, inputs):
        """
        Run one forward pass for generation
        :param inputs: Tensor[batch, timestep, channels]
        :return: Tensor[batch, timestep, channels]
        """
        outputs = self.net(inputs)

        return outputs

    @staticmethod
    def get_model_path(model_dir, step=0):
        basename = 'wavenet'

        if step:
            return os.path.join(model_dir, '{0}_{1}.pkl'.format(basename, step))
        else:
            return os.path.join(model_dir, '{0}.pkl'.format(basename))

    def load(self, model_dir, step=0):
        """
        Load a pre-trained model
        :param model_dir: directory containing the saved model
        :param step: checkpoint step to load (0 = unnumbered checkpoint)
        :return:
        """
        print("Loading model from {0}".format(model_dir))

        model_path = self.get_model_path(model_dir, step)

        self.net.load_state_dict(torch.load(model_path))

    def save(self, model_dir, step=0):
        print("Saving model into {0}".format(model_dir))

        model_path = self.get_model_path(model_dir, step)

        torch.save(self.net.state_dict(), model_path)

--------------------------------------------------------------------------------
/wavenet/networks.py:
--------------------------------------------------------------------------------
"""
Neural network modules for WaveNet

References :
    https://arxiv.org/pdf/1609.03499.pdf
    https://github.com/ibab/tensorflow-wavenet
    https://qiita.com/MasaEguchi/items/cd5f7e9735a120f27e2a
    https://github.com/musyoku/wavenet/issues/4
"""
import torch
import numpy as np

from wavenet.exceptions import InputSizeError


class DilatedCausalConv1d(torch.nn.Module):
    """Dilated Causal Convolution for WaveNet"""
    def __init__(self, channels, dilation=1):
        super(DilatedCausalConv1d, self).__init__()

        self.conv = torch.nn.Conv1d(channels, channels,
                                    kernel_size=2, stride=1,  # Fixed for WaveNet
                                    dilation=dilation,
                                    padding=0,  # Fixed for WaveNet dilation
                                    bias=False)  # Fixed for WaveNet but not sure

    def init_weights_for_test(self):
        for m in self.modules():
            if isinstance(m, torch.nn.Conv1d):
                m.weight.data.fill_(1)

    def forward(self, x):
        output = self.conv(x)

        return output


class CausalConv1d(torch.nn.Module):
    """Causal Convolution for WaveNet"""
    def __init__(self, in_channels, out_channels):
        super(CausalConv1d, self).__init__()

        # padding=1 keeps the same size (length) between input and output for causal convolution
        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=2, stride=1, padding=1,
                                    bias=False)  # Fixed for WaveNet but not sure

    def init_weights_for_test(self):
        for m in self.modules():
            if isinstance(m, torch.nn.Conv1d):
                m.weight.data.fill_(1)

    def forward(self, x):
        output = self.conv(x)

        # remove the last value for causal convolution
        return output[:, :, :-1]


class ResidualBlock(torch.nn.Module):
    def __init__(self, res_channels, skip_channels, dilation):
        """
        Residual block
        :param res_channels: number of residual channels for input and output
        :param skip_channels: number of skip channels for output
        :param dilation: dilation size of this block's dilated convolution
        """
        super(ResidualBlock, self).__init__()

        self.dilated = DilatedCausalConv1d(res_channels, dilation=dilation)
        self.conv_res = torch.nn.Conv1d(res_channels, res_channels, 1)
        self.conv_skip = torch.nn.Conv1d(res_channels, skip_channels, 1)

        self.gate_tanh = torch.nn.Tanh()
        self.gate_sigmoid = torch.nn.Sigmoid()

    def forward(self, x, skip_size):
        """
        :param x: input tensor [batch, channels, timestep]
        :param skip_size: The last output size for loss and prediction
        :return: (residual output, skip connection output)
        """
        output = self.dilated(x)

        # Gated activation unit (as in gated PixelCNN).
        # Note: the paper uses separate filter and gate convolutions;
        # here a single dilated convolution output feeds both nonlinearities.
        gated_tanh = self.gate_tanh(output)
        gated_sigmoid = self.gate_sigmoid(output)
        gated = gated_tanh * gated_sigmoid

        # Residual connection
        output = self.conv_res(gated)
        input_cut = x[:, :, -output.size(2):]
        output += input_cut

        # Skip connection
        skip = self.conv_skip(gated)
        skip = skip[:, :, -skip_size:]

        return output, skip
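
# For reference, the gated activation unit from the WaveNet paper is
#
#     z = tanh(W_{f,k} * x) (.) sigma(W_{g,k} * x)
#
# where * is a dilated convolution, (.) is element-wise multiplication, and
# W_{f,k}, W_{g,k} are separate filter and gate weights for layer k. The block
# above is a simplification that shares one dilated convolution for both gates;
# a paper-faithful variant would use two, e.g. (a sketch, names are illustrative):
#
#     self.dilated_filter = DilatedCausalConv1d(res_channels, dilation=dilation)
#     self.dilated_gate = DilatedCausalConv1d(res_channels, dilation=dilation)
#     ...
#     gated = self.gate_tanh(self.dilated_filter(x)) * self.gate_sigmoid(self.dilated_gate(x))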

class ResidualStack(torch.nn.Module):
    def __init__(self, layer_size, stack_size, res_channels, skip_channels):
        """
        Stack residual blocks by layer and stack size
        :param layer_size: integer, 10 = layer[dilation=1, dilation=2, 4, 8, 16, 32, 64, 128, 256, 512]
        :param stack_size: integer, 5 = stack[layer1, layer2, layer3, layer4, layer5]
        :param res_channels: number of residual channels for input and output
        :param skip_channels: number of skip channels for output
        :return:
        """
        super(ResidualStack, self).__init__()

        self.layer_size = layer_size
        self.stack_size = stack_size

        # ModuleList registers the blocks so their parameters are visible to
        # net.parameters() and thus to the optimizer; a plain Python list would not.
        self.res_blocks = torch.nn.ModuleList(self.stack_res_block(res_channels, skip_channels))

    @staticmethod
    def _residual_block(res_channels, skip_channels, dilation):
        block = ResidualBlock(res_channels, skip_channels, dilation)

        if torch.cuda.device_count() > 1:
            block = torch.nn.DataParallel(block)

        if torch.cuda.is_available():
            block.cuda()

        return block

    def build_dilations(self):
        dilations = []

        # 5 = stack[layer1, layer2, layer3, layer4, layer5]
        for s in range(0, self.stack_size):
            # 10 = layer[dilation=1, dilation=2, 4, 8, 16, 32, 64, 128, 256, 512]
            for l in range(0, self.layer_size):
                dilations.append(2 ** l)

        return dilations

    def stack_res_block(self, res_channels, skip_channels):
        """
        Prepare dilated convolution blocks by layer and stack size
        :return: list of residual blocks, one per dilation
        """
        res_blocks = []
        dilations = self.build_dilations()

        for dilation in dilations:
            block = self._residual_block(res_channels, skip_channels, dilation)
            res_blocks.append(block)

        return res_blocks

    def forward(self, x, skip_size):
        """
        :param x: input tensor [batch, channels, timestep]
        :param skip_size: The last output size for loss and prediction
        :return: stacked skip connections [num_blocks, batch, channels, skip_size]
        """
        output = x
        skip_connections = []

        for res_block in self.res_blocks:
            # the output of one block is the input of the next
            output, skip = res_block(output, skip_size)
            skip_connections.append(skip)

        return torch.stack(skip_connections)
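
# Example: build_dilations() with layer_size=3 and stack_size=2 yields
#
#     [1, 2, 4, 1, 2, 4]
#
# i.e. the dilation cycle restarts at every stack, which is what lets the
# receptive field grow to stack_size * (2 ** layer_size - 1) timesteps.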

class DensNet(torch.nn.Module):
    def __init__(self, channels):
        """
        The last network of WaveNet
        :param channels: number of channels for input and output
        :return:
        """
        super(DensNet, self).__init__()

        self.conv1 = torch.nn.Conv1d(channels, channels, 1)
        self.conv2 = torch.nn.Conv1d(channels, channels, 1)

        self.relu = torch.nn.ReLU()
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        output = self.relu(x)
        output = self.conv1(output)
        output = self.relu(output)
        output = self.conv2(output)

        output = self.softmax(output)

        return output


class WaveNet(torch.nn.Module):
    def __init__(self, layer_size, stack_size, in_channels, res_channels):
        """
        Stack residual blocks by layer and stack size
        :param layer_size: integer, 10 = layer[dilation=1, dilation=2, 4, 8, 16, 32, 64, 128, 256, 512]
        :param stack_size: integer, 5 = stack[layer1, layer2, layer3, layer4, layer5]
        :param in_channels: number of channels of the input data; the skip channel size is the same as the input channel size
        :param res_channels: number of residual channels for input and output
        :return:
        """
        super(WaveNet, self).__init__()

        self.receptive_fields = self.calc_receptive_fields(layer_size, stack_size)

        self.causal = CausalConv1d(in_channels, res_channels)

        self.res_stack = ResidualStack(layer_size, stack_size, res_channels, in_channels)

        self.densnet = DensNet(in_channels)

    @staticmethod
    def calc_receptive_fields(layer_size, stack_size):
        # e.g. layer_size=10, stack_size=5 -> 5 * (1 + 2 + ... + 512) = 5115
        layers = [2 ** i for i in range(0, layer_size)] * stack_size
        num_receptive_fields = np.sum(layers)

        return int(num_receptive_fields)

    def calc_output_size(self, x):
        output_size = int(x.size(2)) - self.receptive_fields

        self.check_input_size(x, output_size)

        return output_size

    def check_input_size(self, x, output_size):
        if output_size < 1:
            raise InputSizeError(int(x.size(2)), self.receptive_fields, output_size)

    def forward(self, x):
        """
        The timestep size (dim 1 of the input) has to be bigger than the receptive field size
        :param x: Tensor[batch, timestep, channels]
        :return: Tensor[batch, timestep, channels]
        """
        output = x.transpose(1, 2)

        output_size = self.calc_output_size(output)

        output = self.causal(output)

        skip_connections = self.res_stack(output, output_size)

        # sum the skip connections from all residual blocks, then run the final 1x1 stack
        output = torch.sum(skip_connections, dim=0)

        output = self.densnet(output)

        return output.transpose(1, 2).contiguous()

--------------------------------------------------------------------------------
/wavenet/utils/data.py:
--------------------------------------------------------------------------------
"""
Utilities that load raw audio and mu-law encode samples to make the input source
"""
import os

import librosa
import numpy as np

import torch
import torch.utils.data as data


def load_audio(filename, sample_rate=16000, trim=True, trim_frame_length=2048):
    audio, _ = librosa.load(filename, sr=sample_rate, mono=True)
    audio = audio.reshape(-1, 1)

    if trim:
        audio, _ = librosa.effects.trim(audio, frame_length=trim_frame_length)

    return audio


def one_hot_encode(data, channels=256):
    one_hot = np.zeros((data.size, channels), dtype=float)
    one_hot[np.arange(data.size), data.ravel()] = 1

    return one_hot


def one_hot_decode(data, axis=1):
    decoded = np.argmax(data, axis=axis)

    return decoded

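# For reference, mu-law companding maps an amplitude x in [-1, 1] to
#
#     f(x) = sign(x) * ln(1 + mu * |x|) / ln(1 + mu)
#
# with mu = quantization_channels - 1 (255 here), and the decoder inverts it with
#
#     f^-1(y) = sign(y) * ((1 + mu) ** |y| - 1) / mu
#
# The two functions below implement this pair, discretizing f(x) into
# `quantization_channels` bins with np.digitize.
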
def mu_law_encode(audio, quantization_channels=256):
    """
    Quantize waveform amplitudes.
    Reference: https://github.com/vincentherrmann/pytorch-wavenet/blob/master/audio_data.py
    """
    mu = float(quantization_channels - 1)
    quantize_space = np.linspace(-1, 1, quantization_channels)

    quantized = np.sign(audio) * np.log(1 + mu * np.abs(audio)) / np.log(mu + 1)
    quantized = np.digitize(quantized, quantize_space) - 1

    return quantized


def mu_law_decode(output, quantization_channels=256):
    """
    Recover the waveform from quantized values.
    Reference: https://github.com/vincentherrmann/pytorch-wavenet/blob/master/audio_data.py
    """
    mu = float(quantization_channels - 1)

    expanded = (output / quantization_channels) * 2. - 1
    waveform = np.sign(expanded) * (
        np.exp(np.abs(expanded) * np.log(mu + 1)) - 1
    ) / mu

    return waveform


class Dataset(data.Dataset):
    def __init__(self, data_dir, sample_rate=16000, in_channels=256, trim=True):
        super(Dataset, self).__init__()

        self.in_channels = in_channels
        self.sample_rate = sample_rate
        self.trim = trim

        self.root_path = data_dir
        self.filenames = [x for x in sorted(os.listdir(data_dir))]

    def __getitem__(self, index):
        filepath = os.path.join(self.root_path, self.filenames[index])

        raw_audio = load_audio(filepath, self.sample_rate, self.trim)

        encoded_audio = mu_law_encode(raw_audio, self.in_channels)
        encoded_audio = one_hot_encode(encoded_audio, self.in_channels)

        return encoded_audio

    def __len__(self):
        return len(self.filenames)


class DataLoader(data.DataLoader):
    def __init__(self, data_dir, receptive_fields,
                 sample_size=0, sample_rate=16000, in_channels=256,
                 batch_size=1, shuffle=True):
        """
        DataLoader for WaveNet
        :param data_dir: directory of the training data
        :param receptive_fields: integer. size (length) of the receptive field
        :param sample_size: integer. number of timesteps to train at once.
                            sample_size has to be bigger than receptive_fields.
                            |-- receptive field --|---------------------|
                            |------- samples -------------------|
                            |---------------------|-- outputs --|
        :param sample_rate: sound sampling rate
        :param in_channels: number of input channels
        :param batch_size:
        :param shuffle:
        """
        dataset = Dataset(data_dir, sample_rate, in_channels)

        super(DataLoader, self).__init__(dataset, batch_size, shuffle)

        if sample_size <= receptive_fields:
            raise Exception("sample_size has to be bigger than receptive_fields")

        self.sample_size = sample_size
        self.receptive_fields = receptive_fields

        self.collate_fn = self._collate_fn

    def calc_sample_size(self, audio):
        return self.sample_size if len(audio[0]) >= self.sample_size\
            else len(audio[0])

    @staticmethod
    def _variable(data):
        tensor = torch.from_numpy(data).float()

        if torch.cuda.is_available():
            return torch.autograd.Variable(tensor.cuda())
        else:
            return torch.autograd.Variable(tensor)

    def _collate_fn(self, audio):
        # pad the front with receptive_fields zeros so that every target sample
        # has a full receptive field of context
        audio = np.pad(audio, [[0, 0], [self.receptive_fields, 0], [0, 0]], 'constant')

        if self.sample_size:
            sample_size = self.calc_sample_size(audio)

            while sample_size > self.receptive_fields:
                inputs = audio[:, :sample_size, :]
                targets = audio[:, self.receptive_fields:sample_size, :]

                yield self._variable(inputs),\
                    self._variable(one_hot_decode(targets, 2))

                # advance the window, keeping receptive_fields samples of overlap
                audio = audio[:, sample_size-self.receptive_fields:, :]
                sample_size = self.calc_sample_size(audio)
        else:
            targets = audio[:, self.receptive_fields:, :]
            return self._variable(audio),\
                self._variable(one_hot_decode(targets, 2))

--------------------------------------------------------------------------------
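Taken together, `Dataset` and `DataLoader` turn a directory of wav files into overlapping (inputs, targets) chunks for training. A minimal usage sketch under the values used by `test/test_dataloader.py` (the directory path is illustrative; `receptive_fields` would normally come from `WaveNet.receptive_fields`):

```python
from wavenet.utils.data import DataLoader

data_loader = DataLoader('./test/data', receptive_fields=1000,
                         sample_size=2000, sample_rate=8000, in_channels=256,
                         shuffle=False)

for dataset in data_loader:            # one generator per wav file (batch_size=1)
    for inputs, targets in dataset:    # overlapping chunks of that file
        print(inputs.size(), targets.size())
        # inputs:  [1, <=2000, 256] one-hot encoded samples
        # targets: [1, <=1000] mu-law class indices
        break
    break
```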