├── .gitignore ├── LICENSE ├── README.md ├── data └── yesno.py ├── test ├── data │ └── david.wav ├── models.py ├── test_layers.py ├── test_modules.py ├── test_utils.py └── test_wavenet.py └── wavenet ├── __init__.py ├── layers.py ├── utils.py └── wavenet.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache 102 | 103 | # data folders 104 | data/yesno/ 105 | 106 | # specific file types 107 | *.pt 108 | *.wav 109 | *.mp3 110 | *.png 111 | 112 | # keep files 113 | !test/data/david.wav 114 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 David Pollack 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fast-wavenet.pytorch 2 | A PyTorch implementation of fast-wavenet 3 | 4 | [fast-wavenet paper](https://arxiv.org/abs/1611.09482) 5 | 6 | [tensorflow fast-wavenet implementation](https://github.com/tomlepaine/fast-wavenet) 7 | 8 | [yesno dataset](http://openslr.org/1/) 9 | 10 | ### Notes 11 | 12 | This repo is currently incomplete, although I do hope to get back to working on this. Notably, I don't have an autoregressive fast forward function. 13 | 14 | I created a [similar repo](https://github.com/dhpollack/bytenet.pytorch) for bytenet, which is a predecessor to WaveNet. This repo does have an autoregressive forward function. 15 | 16 | ### Testing 17 | 18 | ```sh 19 | python -m test.layers_test  20 | ``` 21 | -------------------------------------------------------------------------------- /data/yesno.py: -------------------------------------------------------------------------------- 1 | import torchaudio 2 | 3 | # https://github.com/dhpollack/audio, see VCTK branch 4 | torchaudio.datasets.YESNO(".", download=True, dev_mode=True) 5 | -------------------------------------------------------------------------------- /test/data/david.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhpollack/fast-wavenet.pytorch/853f6ecb1e8d23a5c01fc2455640c6637d30f2f9/test/data/david.wav -------------------------------------------------------------------------------- /test/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from wavenet.layers import * 5 | 6 | '''Simple test model to use Conv1dExt module 7 | 8 | ''' 9 | 10 | class Net(nn.Module): 11 | def __init__(self): 12 | super(Net, self).__init__() 13 | self.conv1 = Conv1dExt(in_channels=1, 14 | out_channels=4, 15 | kernel_size=1, 16 | bias=False) 17 | self.conv2 = Conv1dExt(in_channels=1, 18 | out_channels=4, 19 | kernel_size=1, 20 | bias=False) 21 | self.conv3 = Conv1dExt(in_channels=4, 22 | out_channels=4, 23 | kernel_size=1, 24 | bias=False) 25 | self.conv4 = Conv1dExt(in_channels=4, 26 | out_channels=2, 27 | kernel_size=1, 28 | bias=True) 29 | self.conv1.input_tied_modules = [self.conv3] 30 | self.conv1.output_tied_modules = [self.conv2] 31 | self.conv2.input_tied_modules = [self.conv3] 32 | self.conv2.output_tied_modules = [self.conv1] 33 | self.conv3.input_tied_modules = [self.conv4] 34 | 35 | def forward(self, x): 36 | x1 = self.conv1(x) 37 | x2 = self.conv2(x) 38 | x = nn.functional.relu(x1 + x2) 39 | x = nn.functional.relu(self.conv3(x)) 40 | x = nn.functional.relu(self.conv4(x)) 41 | return x 42 | -------------------------------------------------------------------------------- /test/test_layers.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torchaudio 4 | 5 | from wavenet.layers import * 6 | 7 | sig, sr = torchaudio.load("data/yesno/raw/waves_yesno/0_0_0_0_1_1_1_1.wav") 8 | if len(sig.size()) == 2: 9 | sig.unsqueeze_(0) 10 | print("original size: {}".format(sig.size())) 11 | sig = dilate(sig, 12) 12 | print("dilate1 size: {}".format(sig.size())) 13 | sig = dilate(sig, 8, init_dilation=sig.size(0)) 14 | print("dilate2 size: {}".format(sig.size())) 15 | -------------------------------------------------------------------------------- /test/test_modules.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | from wavenet.layers import * 6 | from test.models import * 7 | import numpy as np 8 | 9 | class Test_dilation(unittest.TestCase): 10 | def test_dilate(self): 11 | input = Variable(torch.arange(0, 13).view(1, 1, 13)) 12 | 13 | dilated, _ = dilate(input, 1) 14 | self.assertEqual(dilated.size(), (1, 1, 13)) 15 | self.assertEqual(dilated[0, 0, 4].data[0], 4) 16 | 17 | dilated, _ = dilate(input, 2) 18 | self.assertEqual(dilated.size(), (2, 1, 7)) 19 | self.assertEqual(dilated[1, 0, 2].data[0], 4) 20 | 21 | dilated, _ = dilate(input, 4) 22 | self.assertEqual(dilated.size(), (4, 1, 4)) 23 | self.assertEqual(dilated[3, 0, 1].data[0], 4) 24 | 25 | dilated, _ = dilate(dilated, 1) 26 | self.assertEqual(dilated.size(), (1, 1, 16)) 27 | self.assertEqual(dilated[0, 0, 7].data[0], 4) 28 | 29 | def test_dilate_multichannel(self): 30 | input = Variable(torch.arange(0, 36).view(2, 3, 6)) 31 | 32 | dilated, _ = dilate(input, 1) 33 | self.assertEqual(dilated.size(), (1, 3, 12)) 34 | dilated, _ = dilate(input, 2) 35 | self.assertEqual(dilated.size(), (2, 3, 6)) 36 | dilated, _ = dilate(input, 4) 37 | self.assertEqual(dilated.size(), (4, 3, 3)) 38 | 39 | def test_dilate_invalid(self): 40 | input = Variable(torch.arange(0, 36).view(2, 3, 6)) 41 | 42 | try: 43 | dilate(input, 5) 44 | except AssertionError: 45 | print("raised AssertionError") 46 | 47 | class Test_padding(unittest.TestCase): 48 | def test_constantpad1d(self): 49 | 50 | # equal padding on all 4 sides 51 | input = torch.rand(3, 2, 5) 52 | padding = 1 53 | m = ConstantPad1d(padding) # m for model 54 | output = m(input).data 55 | self.assertEqual(input[0, 0, 0], output[0, padding, padding]) 56 | self.assertTrue(np.all(output[0, :, 0].numpy()==0)) 57 | self.assertTrue(np.all(output[0, :, -1].numpy()==0)) 58 | self.assertTrue(np.all(output[0, 0, :].numpy()==0)) 59 | self.assertTrue(np.all(output[0, -1, :].numpy()==0)) 60 | 61 | # unequal padding on dimensions, but equal within dimension 62 | input = torch.rand(3, 2, 5) 63 | padding = (1, 2) 64 | m = ConstantPad1d(padding) # m for model 65 | output = m(input).data 66 | self.assertEqual(input[0, 0, 0], output[0, padding[1], padding[0]]) 67 | self.assertTrue(np.all(output[0, :, :padding[0]].numpy()==0)) 68 | self.assertTrue(np.all(output[0, :, -padding[0]:].numpy()==0)) 69 | self.assertTrue(np.all(output[0, :padding[1], :].numpy()==0)) 70 | self.assertTrue(np.all(output[0, -padding[1]:, :].numpy()==0)) 71 | 72 | # padding in one dimension, like we'll use for wavenet 73 | input = torch.rand(3, 2, 5) 74 | padding = (3, 0, 0, 0) 75 | m = ConstantPad1d(padding) # m for model 76 | output = m(input).data 77 | self.assertTrue(np.all(output[:, :, :padding[0]].numpy()==0)) 78 | 79 | # non-zero padding, possibly useful for masking 80 | input = torch.rand(3, 2, 5) 81 | padding = (3, 0, 0, 0) 82 | pad_val = -100 83 | m = ConstantPad1d(padding, pad_val) # m for model 84 | output = m(input).data 85 | self.assertTrue(np.all(output[:, :, :padding[0]].numpy()==pad_val)) 86 | 87 | class Test_conv1dext(unittest.TestCase): 88 | def test_ncc(self): 89 | module = Conv1dExt(in_channels=3, 90 | out_channels=5, 91 | kernel_size=4) 92 | rand = Variable(torch.rand(5, 3, 4)) 93 | module._parameters['weight'] = module.weight * module.weight + rand * 1 94 | ncc = module.normalized_cross_correlation() 95 | print("ncc:\n{}".format(ncc.data)) 96 | 97 | class Test_simple_models(unittest.TestCase): 98 | def test_net_forward(self): 99 | 100 | model = Net() 101 | print(model) 102 | self.assertEqual(model.conv1.out_channels, model.conv2.out_channels) 103 | self.assertEqual(model.conv1.out_channels, model.conv3.in_channels) 104 | self.assertEqual(model.conv2.out_channels, model.conv3.in_channels) 105 | self.assertEqual(model.conv3.out_channels, model.conv4.in_channels) 106 | 107 | # simple forward pass 108 | input = Variable(torch.rand(1, 1, 4) * 2 - 1) 109 | output = model(input) 110 | self.assertEqual(output.size(), (1, 2, 4)) 111 | 112 | # feature split 113 | model.conv1.split_feature(feature_i=1) 114 | model.conv2.split_feature(feature_i=3) 115 | print(model) 116 | self.assertEqual(model.conv1.out_channels, model.conv2.out_channels) 117 | self.assertEqual(model.conv1.out_channels, model.conv3.in_channels) 118 | self.assertEqual(model.conv2.out_channels, model.conv3.in_channels) 119 | self.assertEqual(model.conv3.out_channels, model.conv4.in_channels) 120 | 121 | output2 = model(input) 122 | 123 | diff = output - output2 124 | 125 | dot = torch.dot(diff.view(-1), diff.view(-1)) 126 | # should be close to 0 127 | #self.assertTrue(np.isclose(dot.data[0], 0., atol=1e-2)) 128 | print("mse: ", dot.data[0]) 129 | 130 | class Test_dilated_queue(unittest.TestCase): 131 | def test_enqueue(self): 132 | queue = DilatedQueue(max_length=8, num_channels=3) 133 | e = torch.zeros((3)) 134 | for i in range(11): 135 | e = e + 1 136 | queue.enqueue(e) 137 | 138 | data = queue.data[0, :].data 139 | #print('data: ', data) 140 | self.assertEqual(data[0], 9) 141 | self.assertEqual(data[2], 11) 142 | self.assertEqual(data[7], 8) 143 | 144 | def test_dequeue(self): 145 | queue = DilatedQueue(max_length=8, num_channels=1) 146 | e = torch.zeros((1)) 147 | for i in range(11): 148 | e = e + 1 149 | queue.enqueue(e) 150 | 151 | #print('data: ', queue.data) 152 | 153 | for i in range(9): 154 | d = queue.dequeue(num_deq=3, dilation=2) 155 | d = d.data # only using values for tests 156 | #print("dequeue size: {}".format(d.size())) 157 | 158 | self.assertEqual(d[0][0], 5) 159 | self.assertEqual(d[0][1], 7) 160 | self.assertEqual(d[0][2], 9) 161 | 162 | def test_combined(self): 163 | queue = DilatedQueue(max_length=12, num_channels=1) 164 | e = torch.zeros((1)) 165 | for i in range(30): 166 | e = e + 1 167 | queue.enqueue(e) 168 | d = queue.dequeue(num_deq=3, dilation=4) 169 | d = d.data 170 | self.assertEqual(d[0][0], max(i - 7, 0)) 171 | 172 | ''' 173 | class Test_zero_padding(unittest.TestCase): 174 | def test_end_padding(self): 175 | x = torch.ones((3, 4, 5)) 176 | 177 | p = zero_pad(x, num_pad=5, dimension=0) 178 | assert p.size() == (8, 4, 5) 179 | assert p[-1, 0, 0] == 0 180 | 181 | p = zero_pad(x, num_pad=5, dimension=1) 182 | assert p.size() == (3, 9, 5) 183 | assert p[0, -1, 0] == 0 184 | 185 | p = zero_pad(x, num_pad=5, dimension=2) 186 | assert p.size() == (3, 4, 10) 187 | assert p[0, 0, -1] == 0 188 | 189 | def test_start_padding(self): 190 | x = torch.ones((3, 4, 5)) 191 | 192 | p = zero_pad(x, num_pad=5, dimension=0, pad_start=True) 193 | assert p.size() == (8, 4, 5) 194 | assert p[0, 0, 0] == 0 195 | 196 | p = zero_pad(x, num_pad=5, dimension=1, pad_start=True) 197 | assert p.size() == (3, 9, 5) 198 | assert p[0, 0, 0] == 0 199 | 200 | p = zero_pad(x, num_pad=5, dimension=2, pad_start=True) 201 | assert p.size() == (3, 4, 10) 202 | assert p[0, 0, 0] == 0 203 | 204 | def test_narrowing(self): 205 | x = torch.ones((2, 3, 4)) 206 | x = x.narrow(2, 1, 2) 207 | print(x) 208 | 209 | x = x.narrow(0, -1, 3) 210 | print(x) 211 | 212 | assert False 213 | 214 | 215 | class Test_wav_files(unittest.TestCase): 216 | def test_wav_read(self): 217 | data = wavfile.read('trained_generated.wav')[1] 218 | print(data) 219 | # [0.1, -0.53125... 220 | assert False 221 | 222 | 223 | class Test_padding(unittest.TestCase): 224 | def test_1d(self): 225 | x = Variable(torch.ones((2, 3, 4)), requires_grad=True) 226 | 227 | pad = ConstantPad1d(5, dimension=0, pad_start=False) 228 | 229 | res = pad(x) 230 | assert res.size() == (5, 3, 4) 231 | assert res[-1, 0, 0] == 0 232 | 233 | test = gradcheck(ConstantPad1d, x, eps=1e-6, atol=1e-4) 234 | print('gradcheck', test) 235 | 236 | # torch.autograd.backward(res, ) 237 | res.backward() 238 | back = pad.backward(res) 239 | assert back.size() == (2, 3, 4) 240 | assert back[-1, 0, 0] == 1 241 | 242 | # 243 | # pad = ConstantPad1d(5, dimension=1, pad_start=True) 244 | # 245 | # res = pad(x) 246 | # assert res.size() == (2, 5, 4) 247 | # assert res[0, 4, 0] == 0 248 | # 249 | # back = pad.backward(res) 250 | # assert back.size() == (2, 3, 4) 251 | # assert back[0, 2, 0] == 1 252 | 253 | 254 | def test_2d(self): 255 | pad = ConstantPad2d((5, 0, 0, 0)) 256 | x = Variable(torch.ones((2, 3, 4, 5))) 257 | 258 | res = pad.forward(x) 259 | print(res.size()) 260 | assert False 261 | ''' 262 | 263 | def main(): 264 | unittest.main() 265 | 266 | if __name__ == '__main__': 267 | main() 268 | -------------------------------------------------------------------------------- /test/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torchaudio 4 | from wavenet.utils import * 5 | import numpy as np 6 | 7 | class Test_mu_law(unittest.TestCase): 8 | sig, sr = torchaudio.load("data/yesno/raw/waves_yesno/0_0_0_0_1_1_1_1.wav") 9 | def test1_mu_law_encoding(self): 10 | quantization_channels = 256 11 | mu = quantization_channels - 1. 12 | sig = self.sig.numpy() 13 | sig /= np.abs(sig).max() 14 | self.assertTrue(sig.min() >= -1. and sig.max() <= 1.) 15 | 16 | sig_mu = mu_law_encoding(sig, mu) 17 | print(sig_mu.ptp(), sig_mu.min(), sig_mu.max()) 18 | self.assertTrue(sig_mu.min() >= 0. and sig.max() <= quantization_channels) 19 | 20 | sig_exp = mu_law_expansion(sig_mu, mu) 21 | print(sig_exp.ptp(),sig_exp.min(), sig_exp.max(), sig_exp.shape) 22 | self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.) 23 | 24 | diff = sig - sig_exp 25 | mse = np.linalg.norm(diff) / diff.shape[0] 26 | print(mse, np.isclose(mse, 0., atol=1e-4)) 27 | 28 | def test2_prime_factorization(self): 29 | num = 100 30 | factors_true = [2, 2, 5, 5] 31 | factors_calc = list(prime_factors(num)) 32 | self.assertEqual(factors_true, factors_calc, print(factors_calc)) 33 | 34 | num = 16000 35 | factors_true = [2, 2, 2, 2, 2, 2, 2, 5, 5, 5] 36 | factors_calc = list(prime_factors(num)) 37 | self.assertEqual(factors_true, factors_calc, print(factors_calc)) 38 | 39 | if __name__ == '__main__': 40 | unittest.main() 41 | -------------------------------------------------------------------------------- /test/test_wavenet.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torchaudio # one could replace torchaudio.load with scipy.io.wavfile.read 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.optim import lr_scheduler 7 | from torch.autograd import Variable 8 | from wavenet.layers import * 9 | from wavenet.utils import * 10 | from wavenet.wavenet import FastWaveNet 11 | from test.models import * 12 | import numpy as np 13 | 14 | class Test_wavenet(unittest.TestCase): 15 | use_cuda = torch.cuda.is_available() 16 | def test1_wavenet_mono(self): 17 | num_samples = 1<<12 18 | input = torch.rand(1, 1, num_samples) 19 | m = FastWaveNet(layers=8, # less than non-unique prime factors in input size 20 | blocks=2, # number of blocks 21 | residual_channels=16, 22 | dilation_channels=32, 23 | skip_channels=16, 24 | quantization_channels=256, 25 | input_len=num_samples, 26 | kernel_size=2) 27 | if self.use_cuda: 28 | m, input = m.cuda(), input.cuda() 29 | input = Variable(input) 30 | print(input.size()) 31 | # forward pass 32 | output = m(input) 33 | # tests on output 34 | print(output.size()) 35 | self.assertEqual(input.size(-1), output.size(-1)) 36 | print(m.sizes) 37 | 38 | def test2_wavenet_stereo(self): 39 | # TODO fix with stereo signals. For now, just split into multiple monos. 40 | num_samples = 1<<12 41 | input = torch.rand(1, 2, num_samples) 42 | batch_size, audio_channels, _ = input.size() 43 | m = FastWaveNet(layers=8, 44 | blocks=2, # number of blocks 45 | residual_channels=16, 46 | dilation_channels=32, 47 | skip_channels=16, 48 | quantization_channels=256, 49 | input_len=num_samples, 50 | audio_channels=audio_channels, 51 | kernel_size=2) 52 | if self.use_cuda: 53 | m, input = m.cuda(), input.cuda() 54 | input = Variable(input) 55 | # forward pass 56 | output = m(input) 57 | # tests on output 58 | print(input.size(), output.size()) 59 | print(m.sizes) 60 | self.assertEqual(input.size(-1), output.size(-1)) 61 | 62 | def test3_wavenet_dummy(self): 63 | try: 64 | import matplotlib.pyplot as plt 65 | except ImportError: 66 | print("install matplotlib for plot of signals") 67 | plt = None 68 | # setup inputs and labels 69 | num_samples = 1<<10 70 | #for some reason the network has a lot more trouble with the sinewave 71 | input = torch.linspace(0, 20*np.pi, num_samples) 72 | input = torch.sin(input) 73 | #input = torch.rand(1, 1, num_samples) * 2. - 1. 74 | input = input.view(1, 1, -1) 75 | labels = input.numpy() 76 | labels = mu_law_encoding(labels, 256) 77 | labels = torch.from_numpy(labels).squeeze().long() 78 | batch_size, audio_channels, _ = input.size() 79 | print(input.size(), labels.size()) 80 | # build network and optimizer 81 | m = FastWaveNet(layers=10, 82 | blocks=6, # number of blocks 83 | residual_channels=16, 84 | dilation_channels=32, 85 | skip_channels=16, 86 | quantization_channels=256, 87 | input_len=num_samples, 88 | audio_channels=audio_channels, 89 | kernel_size=2) 90 | criterion = torch.nn.CrossEntropyLoss() 91 | optimizer = torch.optim.Adam(m.parameters(), lr=0.01) 92 | if self.use_cuda: 93 | m = m.cuda() 94 | criterion = criterion.cuda() 95 | input, labels = input.cuda(), labels.cuda() 96 | input, labels = Variable(input), Variable(labels) 97 | 98 | epochs = 100 99 | losses = [] 100 | for epoch in range(epochs): 101 | m.zero_grad() 102 | output = m(input) 103 | if epoch == 0: 104 | print(m.sizes) 105 | output.squeeze_() 106 | output = output.t() 107 | loss = criterion(output, labels) 108 | losses.append(loss.data[0]) 109 | if epoch % (epochs // 10) == 0: 110 | print("loss of {} at epoch {}".format(losses[-1], epoch+1)) 111 | loss.backward() 112 | optimizer.step() 113 | print("final loss of {} after {} epochs".format(losses[-1], epoch+1)) 114 | 115 | if plt is not None: 116 | input = input.data.float().numpy().ravel() 117 | 118 | output = F.softmax(output.t()) 119 | output = output.max(0)[1].data.float().numpy() 120 | output = mu_law_expansion(output, 256) 121 | print(input.shape, output.shape) 122 | print(input.min(), input.max(), output.min(), output.max()) 123 | 124 | f, ax = plt.subplots(2, sharex=True) 125 | ax[0].plot(input) 126 | ax[1].plot(output) 127 | f.savefig("test/test_wavenet_dummy.png") 128 | 129 | # j = i + 1 130 | for l_i, l_j, in zip(losses[-1:], losses[1:]): 131 | self.assertTrue(l_j >= l_i) 132 | 133 | def test4_wavenet_audio(self): 134 | try: 135 | import matplotlib.pyplot as plt 136 | except ImportError: 137 | print("install matplotlib for plot of signals") 138 | plt = None 139 | 140 | num_samples = 1 << 15 141 | 142 | sig, sr = torchaudio.load("test/data/david.wav") 143 | sig = sig[:-(sig.size(0)%3):3] 144 | input = sig[16000:(16000+num_samples)].contiguous() 145 | # write sample for qualitative test 146 | torchaudio.save("test/data/david_16000hz_input_sample.wav", input, sr//3) 147 | input /= torch.abs(input).max() 148 | assert input.min() >= -1. and input.max() <= 1. 149 | input = input.view(1, 1, -1) 150 | labels = input.numpy() 151 | labels = mu_law_encoding(labels, 256) 152 | labels = torch.from_numpy(labels).squeeze().long() 153 | 154 | # build network and optimizer 155 | m = FastWaveNet(layers=10, 156 | blocks=4, # number of blocks 157 | residual_channels=16, 158 | dilation_channels=32, 159 | skip_channels=16, 160 | quantization_channels=256, 161 | input_len=num_samples, 162 | audio_channels=1, 163 | kernel_size=2) 164 | 165 | epochs = 250 166 | 167 | criterion = torch.nn.CrossEntropyLoss() 168 | optimizer = torch.optim.Adam(m.parameters(), lr=0.01) 169 | scheduler = lr_scheduler.StepLR(optimizer, step_size=epochs//3) 170 | 171 | if self.use_cuda: 172 | m = m.cuda() 173 | criterion = criterion.cuda() 174 | input, labels = input.cuda(), labels.cuda() 175 | input, labels = Variable(input), Variable(labels) 176 | 177 | losses = [] 178 | for epoch in range(epochs): 179 | scheduler.step() 180 | m.zero_grad() 181 | output = m(input) 182 | output.squeeze_() 183 | output = output.t() 184 | loss = criterion(output, labels) 185 | losses.append(loss.data[0]) 186 | if epoch % (epochs // 10) == 0: 187 | print("loss of {} at epoch {}".format(losses[-1], epoch+1)) 188 | loss.backward() 189 | optimizer.step() 190 | print("final loss of {} after {} epochs".format(losses[-1], epoch+1)) 191 | 192 | if plt is not None: 193 | if self.use_cuda: 194 | input = input.data.cpu() 195 | else: 196 | input = input.data 197 | input = input.float().numpy().ravel() 198 | output = F.softmax(output.t()) 199 | if self.use_cuda: 200 | output = output.data.cpu() 201 | else: 202 | output = output.data 203 | output = output.max(0)[1].float().numpy() 204 | output = mu_law_expansion(output, 256) 205 | print(input.shape, output.shape) 206 | print(input.min(), input.max(), output.min(), output.max()) 207 | 208 | f, ax = plt.subplots(2, sharex=True) 209 | ax[0].plot(input) 210 | ax[1].plot(output) 211 | f.savefig("test/test_wavenet_audio.png") 212 | 213 | plt.figure() 214 | plt.plot(losses) 215 | plt.savefig("test/test_wavenet_audio_loss.png") 216 | 217 | output = torch.from_numpy(output) * (1 << 30) 218 | output = output.unsqueeze(1).long() 219 | #output = output.float() 220 | 221 | torchaudio.save("test/data/david_16000hz_output_sample.wav", output, sr//3) 222 | 223 | if __name__ == '__main__': 224 | unittest.main() 225 | -------------------------------------------------------------------------------- /wavenet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhpollack/fast-wavenet.pytorch/853f6ecb1e8d23a5c01fc2455640c6637d30f2f9/wavenet/__init__.py -------------------------------------------------------------------------------- /wavenet/layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn import Parameter 6 | from torch.autograd import Variable 7 | 8 | '''based on code from: https://github.com/vincentherrmann/pytorch-wavenet 9 | also see: 10 | https://github.com/musyoku/wavenet 11 | https://github.com/ibab/tensorflow-wavenet 12 | https://github.com/tomlepaine/fast-wavenet 13 | 14 | 15 | ''' 16 | 17 | class Conv1dExt(nn.Conv1d): 18 | def __init__(self, *args, **kwargs): 19 | super(Conv1dExt, self).__init__(*args, **kwargs) 20 | self.init_ncc() 21 | self.input_tied_modules = [] # modules whose inputs sensitive to output size 22 | self.output_tied_modules = [] # modules whose outputs sensitive to input size 23 | 24 | def init_ncc(self): 25 | w = self.weight.view(self.weight.size(0), -1) # (G, F*J) what are these? 26 | #mean = torch.mean(w, dim=1).unsqueeze(1).expand_as(w) 27 | mean = torch.mean(w, dim=1).unsqueeze(1) # 0.2 broadcasting 28 | self.t0_factor = w - mean 29 | self.t0_norm = torch.norm(w, p=2, dim=1) # p=2 is the L2 norm 30 | self.start_ncc = Variable(torch.zeros(self.out_channels)) 31 | self.start_ncc = self.normalized_cross_correlation() 32 | 33 | def normalized_cross_correlation(self): 34 | w = self.weight.view(self.weight.size(0), -1) 35 | t_norm = torch.norm(w, p=2, dim=1) 36 | if self.in_channels == 1 & sum(self.kernel_size) == 1: 37 | ncc = w.squeeze() / torch.norm(self.t0_norm, p=2) 38 | ncc = ncc - self.start_ncc 39 | return ncc 40 | #mean = torch.mean(w, dim=1).unsqueeze(1).expand_as(w) 41 | mean = torch.mean(w, dim=1).unsqueeze(1) # 0.2 broadcasting 42 | t_factor = w - mean 43 | h_product = self.t0_factor * t_factor 44 | cov = torch.sum(h_product, dim=1) # (w.size(1) - 1) 45 | # had normalization code commented out 46 | denom = self.t0_norm * t_norm 47 | 48 | ncc = cov / denom 49 | ncc = ncc - self.start_ncc 50 | return ncc 51 | 52 | def split_output_channel(self, channel_i): 53 | '''Split one output channel (a feature) into two, but retain summed value 54 | 55 | Args: 56 | channel_i: (int) number of channel to be split. the ith channel 57 | ''' 58 | 59 | # weight tensor: (out_channels, in_channels, kernel_size) 60 | self.out_channels += 1 61 | 62 | orig_weight = self.weight.data 63 | split_pos = 2 * torch.rand(self.in_channels, self.kernel_size[0]) 64 | 65 | new_weight = torch.zeros(self.out_channels, self.in_channels, self.kernel_size[0]) 66 | if channel_i > 0: 67 | new_weight[:channel_i, :, :] = orig_weight[:channel_i,:, :] 68 | new_weight[channel_i, :, :] = orig_weight[channel_i, :, :] * split_pos 69 | new_weight[channel_i + 1, :, :] = orig_weight[channel_i, :, :] * (2 - split_pos) 70 | if channel_i + 2 < self.out_channels: 71 | new_weight[channel_i + 2, :, :] = orig_weight[channel_i+1, :, :] 72 | if self.bias is not None: 73 | orig_bias = self.bias.data 74 | new_bias = torch.zeros(self.out_channels) 75 | new_bias[:(channel_i + 1)] = orig_bias[:(channel_i + 1)] 76 | new_bias[(channel_i + 1):] = orig_bias[channel_i:] # why no +1? 77 | self.bias = Parameter(new_bias) 78 | 79 | self.weight = Parameter(new_weight) 80 | self.init_ncc() 81 | 82 | def split_input_channel(self, channel_i): 83 | 84 | if channel_i > self.in_channels: 85 | print("cannot split channel {} of {}".format(channel_i, self.in_channels)) 86 | return 87 | 88 | self.in_channels += 1 89 | orig_weight = self.weight.data 90 | dup_slice = orig_weight[:, channel_i, :] * .5 91 | 92 | new_weight = torch.zeros(self.out_channels, self.in_channels, self.kernel_size[0]) 93 | if channel_i > 0: 94 | new_weight[:, :channel_i, :] = orig_weight[:, :channel_i, :] 95 | new_weight[:, channel_i, :] = dup_slice 96 | new_weight[:, channel_i + 1, :] = dup_slice 97 | if channel_i + 1 < self.in_channels: 98 | new_weight[:, channel_i + 2, :] = orig_weight[:, channel_i + 1, :] 99 | self.weight = Parameter(new_weight) 100 | self.init_ncc() 101 | 102 | def split_feature(self, feature_i): 103 | '''Splits feature in output and input channels 104 | 105 | Args: 106 | feature_i: (int) 107 | ''' 108 | self.split_output_channel(channel_i=feature_i) 109 | for dep in self.input_tied_modules: 110 | dep.split_input_channel(channel_i=feature_i) 111 | for dep in self.output_tied_modules: 112 | dep.split_output_channel(channel_i=feature_i) 113 | 114 | def split_features(self, threshold): 115 | '''Decides which features to split if they are below a specific threshold 116 | 117 | Args: 118 | threshold: (float?) less than 1. 119 | ''' 120 | ncc = self.normalized_cross_correlation() 121 | for i, ncc_val in enumerate(ncc): 122 | if ncc_val < threshold: 123 | print("ncc (feature {}): {}".format(i, ncc_val)) 124 | self.split_feature(i) 125 | 126 | class DilatedQueue: 127 | '''This is the queue to do the fast-wavenet implementation 128 | arXiv 1611.09482 129 | ''' 130 | def __init__(self, 131 | max_length, 132 | data=None, 133 | dilation=1, 134 | num_deq=1, 135 | num_channels=1, 136 | dtype=torch.FloatTensor): 137 | self.in_pos = 0 138 | self.out_pos = 0 139 | self.num_deq = num_deq 140 | self.num_channels = num_channels 141 | self.dilation = dilation 142 | self.max_length = max_length 143 | self.data = data 144 | self.dtype = dtype 145 | if data is None: 146 | self.data = Variable(dtype(num_channels, max_length).zero_()) 147 | 148 | def enqueue(self, input): 149 | self.data[:, self.in_pos] = input 150 | self.in_pos = (self.in_pos + 1) % self.max_length 151 | 152 | def dequeue(self, num_deq=1, dilation=1): 153 | start = self.out_pos - ((num_deq - 1) * dilation) 154 | if start < 0: 155 | t1 = self.data[:, start::dilation] 156 | t2 = self.data[:, self.out_pos % dilation:self.out_pos + 1:dilation] 157 | t = torch.cat((t1, t2), 1) 158 | else: 159 | t = self.data[:, start:self.out_pos + 1:dilation] 160 | self.out_pos = (self.out_pos + 1) % self.max_length 161 | return t 162 | 163 | def reset(self): 164 | self.data = Variable(self.dtype(self.num_channels, self.max_length).zero_()) 165 | 166 | def dilate(sigs, dilation): 167 | """ 168 | 169 | Note this will fail if the dilation doesn't allow a whole number amount of padding 170 | 171 | :param x: Tensor or Variable of size (N, L, C), where N is the input dilation, C is the number of channels, and L is the input length 172 | :param dilation: Target dilation. Will be the size of the first dimension of the output tensor. 173 | :param pad_start: If the input length is not compatible with the specified dilation, zero padding is used. This parameter determines wether the zeros are added at the start or at the end. 174 | :return: The dilated Tensor or Variable of size (dilation, C, L*N / dilation). The output might be zero padded at the start 175 | """ 176 | 177 | n, c, l = sigs.size() 178 | dilation_factor = dilation / n 179 | if dilation_factor == 1: 180 | return sigs, 0. 181 | 182 | # zero padding for reshaping 183 | new_n = int(dilation) 184 | new_l = int(np.ceil(l*n/dilation)) 185 | pad_len = (new_n*new_l-n*l)/n 186 | if pad_len > 0: 187 | print("Padding: {}, {}, {}".format(new_n, new_l, pad_len)) 188 | # TODO pad output tensor unevenly for indivisible dilations 189 | assert pad_len == int(pad_len) 190 | # "squeeze" then "unsqueeze" due to limitation of pad function 191 | # which only works with 4d/5d tensors 192 | padding = (int(pad_len), 0, 0, 0) # (d3_St, d3_End, d2_St, d2_End), d0 and d1 unpadded 193 | sigs = pad1d(sigs, padding) 194 | 195 | # reshape according to dilation 196 | sigs = sigs.permute(1, 2, 0).contiguous() # (n, c, l) -> (c, l, n) 197 | sigs = sigs.view(c, new_l, new_n) 198 | sigs = sigs.permute(2, 0, 1).contiguous() # (c, l, n) -> (n, c, l) 199 | 200 | return sigs, pad_len 201 | 202 | class ConstantPad1d(nn.Module): 203 | r"""Pads the input tensor boundaries with a constant value. 204 | 205 | Accepts 3d, 4d, 5d tensors, which is different than the normal PadXd functions 206 | 207 | Args: 208 | padding (int, tuple): the size of the padding. 209 | If is int, uses the same padding in all boundaries. 210 | if a 2-tuple, uses: (d2_padding, d1_padding), equal on both sides 211 | If a 4-tuple, uses 212 | (d2_paddingFront, d2_paddingBack, 213 | d1_paddingFront, d1_paddingBack) 214 | 215 | 216 | Shape: 217 | - Input: :math:`(d0, d1_{in}, d2_{in})` 218 | - Output: :math:`(d0, d1_{out}, d2_{out})` where 219 | :math:`d2_{out} = d2_{in} + d2_paddingFront + d2_paddingBack` 220 | :math:`d1_{out} = d1_{in} + d1_paddingFront + d1_paddingBack` 221 | 222 | Examples:: 223 | 224 | >>> m = nn.ConstantPad1d(3, 3.5) 225 | >>> input = autograd.Variable(torch.randn(3, 320, 480)) 226 | >>> output = m(input) 227 | >>> # using different paddings 228 | >>> m = nn.ConstantPad1d((3, 3, 6, 6), 3.5) 229 | >>> output = m(input) 230 | 231 | """ 232 | 233 | def __init__(self, padding, value=0): 234 | super(ConstantPad1d, self).__init__() 235 | self.padding = self._quadruple(padding) 236 | self.value = value 237 | 238 | def forward(self, input): 239 | x = input 240 | if len(x.size()) == 3: 241 | x = x.view((1,)+x.size()) 242 | x = F.pad(x, self.padding, 'constant', self.value) 243 | x = x.view(x.size()[1:]) 244 | return x 245 | 246 | def __repr__(self): 247 | return self.__class__.__name__ + ' ' + str(self.padding) 248 | 249 | def _quadruple(self, padding): 250 | if isinstance(padding, int): 251 | padding = tuple([padding]*4) 252 | elif len(padding) == 2: 253 | padding = tuple([padding[0]]*2+[padding[1]]*2) 254 | assert len(padding) == 4 255 | return padding 256 | 257 | def pad1d(input,padding,pad_value=0): 258 | return ConstantPad1d(padding, pad_value)(input) 259 | -------------------------------------------------------------------------------- /wavenet/utils.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import torch 3 | import numpy as np 4 | 5 | def prime_factors(n): 6 | i = 2 7 | factors = [] 8 | while i * i <= n: 9 | if n % i: 10 | i += 1 11 | else: 12 | n //= i 13 | factors.append(i) 14 | if n > 1: 15 | factors.append(n) 16 | return deque(factors) 17 | 18 | def mu_law_encoding(x, quantization_channels): 19 | mu = quantization_channels - 1. 20 | x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu) 21 | return ((x_mu + 1) / 2 * mu + 0.5).astype(int) 22 | 23 | def mu_law_expansion(x_mu, quantization_channels): 24 | mu = quantization_channels - 1. 25 | x = ((x_mu) / mu) * 2 - 1. 26 | return np.sign(x) * (np.exp(np.abs(x) * np.log1p(mu)) - 1.) / mu 27 | 28 | 29 | 30 | def time_to_batch(sigs, sr): 31 | '''Adds zero padding to inputs and reshapes by sample rate. This essentially 32 | rebatches the input into one second batches. 33 | 34 | Used to perform 1D dilated convolution 35 | 36 | Args: 37 | sig: (tensor) in (Bx)SxC; B = # batches, S = # samples, C = # channels 38 | sr: (int) sample rate of audio signal 39 | Outputs: 40 | sig: (tensor) also in SecBatchesx(B x sr)xC, SecBatches = # of seconds in 41 | padded sample 42 | ''' 43 | 44 | unsqueezed = False 45 | 46 | # check if sig is a batch, if not make a batch of 1 47 | if len(sigs.size()) == 1: 48 | sigs.unsqueeze_(0) 49 | unsqueezed = True 50 | 51 | assert len(sigs.size()) == 3 52 | 53 | # pad to the second (i.e. sample rate) 54 | b_num, s_num, c_num = sigs.size() 55 | width_pad = int(sr * np.ceil(s_num / sr + 1)) 56 | lpad_len = width_pad - s_num 57 | lpad = torch.zeros(b_num, pad_left_len, c_num) 58 | sigs = torch.cat((lpad, sigs), 1) # concat on sample dimension 59 | 60 | # reshape to batches of one second each 61 | secs_num = width_pad // sr 62 | sigs = sigs.view(secs_num, -1, c_num) # seconds x (batches*rate) x channels 63 | 64 | return sigs 65 | 66 | def batch_to_time(sigs, sr, lcrop=0): 67 | ''' Reshape to 1d signal from batches of 1 second. 68 | 69 | I'm using the same variable names as above as opposed to the original 70 | author's variables 71 | 72 | Used to perform dilated conv1d 73 | 74 | Args: 75 | sig: (tensor) second_batches_num x (batch_size x sr) x channels 76 | sr: (int) 77 | lcrop: (int) 78 | Outputs: 79 | sig: (tensor) batch_size x # of samples x channels 80 | ''' 81 | 82 | assert len(sigs.size()) == 3 83 | 84 | secs_num, bxsr, c_num = sigs.size() 85 | b_num = bxsr // sr 86 | width_pad = int(secs_num * sr) 87 | 88 | sigs = sigs.view(-1, width_pad, c_num) # missing dim should be b_num 89 | 90 | assert sigs.size(0) == b_num 91 | 92 | if lcrop > 0: 93 | sigs = sigs[:,lcrop:, :] 94 | 95 | return sigs 96 | -------------------------------------------------------------------------------- /wavenet/wavenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Parameter 5 | from torch.autograd import Variable 6 | from wavenet.layers import * 7 | from wavenet.utils import * 8 | 9 | import numpy as np 10 | 11 | class FastWaveNet(nn.Module): 12 | """Default values found from various sources 13 | 14 | Original WaveNet Tensorflow (https://github.com/ibab/tensorflow-wavenet) 15 | Usage (with the architecture as in the DeepMind paper): 16 | dilations = [2**i for i in range(N)] * M 17 | filter_width = 2 # Convolutions just use 2 samples. 18 | residual_channels = 16 # Not specified in the paper. 19 | dilation_channels = 32 # Not specified in the paper. 20 | skip_channels = 16 # Not specified in the paper. 21 | net = WaveNetModel(batch_size, dilations, filter_width, 22 | residual_channels, dilation_channels, 23 | skip_channels) 24 | loss = net.loss(input_batch) 25 | 26 | FastWaveNet Tensorflow (https://github.com/tomlepaine/fast-wavenet) 27 | num_blocks=2, 28 | num_layers=14, 29 | 30 | Args: 31 | layers: number of layers per blocks 32 | blocks: number of blocks 33 | residual_channels: How many filters to learn for the residual. 34 | dilation_channels: How many filters to learn for the dilated 35 | convolution. 36 | skip_channels: How many filters to learn that contribute to the 37 | quantized softmax output. 38 | quantization_channels: mu-law quantization channels. Essentially the 39 | number of output classes in the final softmax layer 40 | Potential Args: not used yet, but perhaps I should 41 | initial_filter_width: The width of the initial filter of the 42 | convolution applied to the scalar input. This is only relevant 43 | if scalar_input=True. 44 | 45 | Author Notes: 46 | the original paper dilates by a factor of two always. However, an 47 | invalid dilation will occur if the dilation width is not a factor of 48 | the input length. thus a rotating deque of dilation widths which are 49 | prime factors of the input length are used instead of 2. 50 | 51 | i.e. 16000 = 2 * 2 * 2 * 2 * 2 * 2 * 5 * 5 * 5 52 | 53 | """ 54 | def __init__(self, 55 | blocks=2, 56 | layers=None, # depends on number of prime factors of input_len 57 | residual_channels=16, 58 | dilation_channels=32, 59 | skip_channels=16, 60 | quantization_channels=256, 61 | input_len=16000, 62 | audio_channels=1, 63 | kernel_size=2): 64 | super(FastWaveNet, self).__init__() 65 | # variables 66 | self.scope_mul = prime_factors(input_len) # used to insure valid dilations 67 | if layers is None or layers > len(self.scope_mul): 68 | print("setting # of layers to {}".format(len(self.scope_mul))) 69 | self.layers = len(self.scope_mul) 70 | else: 71 | self.layers = layers 72 | self.blocks = blocks 73 | self.audio_channels = audio_channels 74 | self.residual_channels = residual_channels * audio_channels 75 | self.dilation_channels = dilation_channels * audio_channels 76 | self.skip_channels = skip_channels * audio_channels 77 | self.kernel_size = kernel_size 78 | self.quantization_channels = quantization_channels 79 | 80 | # debugging 81 | self.sizes = [] 82 | 83 | # build model 84 | receptive_field = 1 85 | init_dilation = 1 86 | 87 | self.dilations = [] 88 | self.dilated_queues = [] 89 | self.filter_convs = nn.ModuleList() 90 | self.gate_convs = nn.ModuleList() 91 | self.residual_convs = [] 92 | self.skip_convs = [] 93 | 94 | # filter non-linearity 95 | self.filter_act = F.tanh 96 | 97 | # non-linearity in "pre-softmax" layers 98 | self.nl_out = nn.ReLU() 99 | self.nl_end = nn.ReLU() 100 | #self.nl1 = nn.SELU 101 | #self.nl1 = nn.Hardtanh 102 | #self.nl2 = nn.PReLU(self.quantization_channels) 103 | 104 | # initial convolution 105 | self.conv0 = Conv1dExt(in_channels=audio_channels, 106 | out_channels=self.residual_channels, 107 | kernel_size=1, 108 | bias=False) 109 | # convolution out of blocks 110 | self.conv_out = Conv1dExt(in_channels=self.skip_channels, 111 | out_channels=self.quantization_channels * self.audio_channels, 112 | kernel_size=1, 113 | bias=True) 114 | # final convolution before leaving the network 115 | self.conv_end = Conv1dExt(in_channels=self.quantization_channels * self.audio_channels, 116 | out_channels=self.quantization_channels * self.audio_channels, 117 | kernel_size=1, 118 | bias=False) 119 | for b in range(blocks): 120 | additional_scope = kernel_size - 1 121 | new_dilation = 1 122 | for l in range(self.layers): 123 | self.dilations.append((new_dilation, init_dilation)) 124 | self.dilated_queues.append(DilatedQueue(max_length=int((kernel_size - 1) * new_dilation + 1), 125 | num_channels=self.residual_channels, 126 | dilation=new_dilation)) 127 | self.filter_convs.append(Conv1dExt(in_channels=self.residual_channels, 128 | out_channels=self.dilation_channels, 129 | kernel_size=kernel_size, 130 | bias=False)) 131 | self.gate_convs.append(Conv1dExt(in_channels=self.residual_channels, 132 | out_channels=self.dilation_channels, 133 | kernel_size=kernel_size, 134 | bias=False)) 135 | self.residual_convs.append(Conv1dExt(in_channels=self.dilation_channels, 136 | out_channels=self.residual_channels, 137 | kernel_size=1, 138 | bias=False)) 139 | self.skip_convs.append(Conv1dExt(in_channels=self.dilation_channels, 140 | out_channels=self.skip_channels, 141 | kernel_size=1, 142 | bias=False)) 143 | receptive_field += additional_scope 144 | #print("receptive field (layer: {}, block: {}): {}".format(l+1, b+1, receptive_field)) 145 | additional_scope *= self.scope_mul[0] 146 | init_dilation = new_dilation 147 | new_dilation *= self.scope_mul[0] 148 | self.scope_mul.rotate(-1) 149 | 150 | # define deps 151 | self.conv0.input_tied_modules.append(self.filter_convs[0]) 152 | for i in range(int(blocks*self.layers)): 153 | 154 | self.filter_convs[i].input_tied_modules.append(self.residual_convs[i]) 155 | self.filter_convs[i].input_tied_modules.append(self.skip_convs[i]) 156 | self.filter_convs[i].output_tied_modules.append(self.gate_convs[i]) 157 | 158 | self.gate_convs[i].input_tied_modules.append(self.residual_convs[i]) 159 | self.gate_convs[i].input_tied_modules.append(self.skip_convs[i]) 160 | self.gate_convs[i].output_tied_modules.append(self.filter_convs[i]) 161 | 162 | self.skip_convs[i].input_tied_modules.append(self.conv_out) 163 | self.skip_convs[i].output_tied_modules = [skip for ind, skip in enumerate(self.skip_convs) if ind != i] 164 | if i < blocks*self.layers-1: 165 | # final layer 166 | self.residual_convs[i].input_tied_modules.append(self.filter_convs[i + 1]) 167 | self.residual_convs[i].input_tied_modules.append(self.gate_convs[i + 1]) 168 | if i > 0: 169 | # all but first layer 170 | self.residual_convs[i].output_tied_modules.append(self.residual_convs[i-1]) 171 | self.residual_convs[i].output_tied_modules.append(self.filter_convs[i-1]) 172 | self.residual_convs[i].output_tied_modules.append(self.gate_convs[i-1]) 173 | self.residual_convs[i].input_tied_modules.append(self.skip_convs[i-1]) 174 | self.residual_convs[i].input_tied_modules.append(self.filter_convs[i]) 175 | self.residual_convs[i].input_tied_modules.append(self.gate_convs[i]) 176 | def forward(self, input): 177 | ob, oc, ol = input.size() 178 | self.sizes.append(input.size()) 179 | 180 | x = self.conv0(input) 181 | self.sizes.append(x.size()) 182 | skip = 0 183 | 184 | for i in range(int(self.blocks*self.layers)): 185 | (dil, init_dil) = self.dilations[i] 186 | res, pad_res = dilate(x, dil) 187 | 188 | # dilation Convolutions 189 | if pad_res == 0: 190 | res = pad1d(res,(1, 0, 0, 0),pad_value=0) 191 | else: 192 | print("pad_res: {}".format(pad_res)) 193 | pass 194 | fil = self.filter_convs[i](res) 195 | fil = self.filter_act(fil) 196 | 197 | gate = self.gate_convs[i](res) 198 | gate = F.sigmoid(gate) 199 | self.sizes.append(x.size()) 200 | 201 | x = fil * gate 202 | self.sizes.append(x.size()) 203 | 204 | s = x 205 | if x.size(2) != 1: 206 | s, pad_skip = dilate(x, self.audio_channels) 207 | s = self.skip_convs[i](s) 208 | 209 | try: 210 | # this is designed to remove the front padding 211 | skip = skip[:, :, -s.size(2):] 212 | except: 213 | skip = 0 214 | 215 | skip = s + skip # Note the skip is ultimately part of the output 216 | self.sizes.append(("skip size:",)+skip.size()) 217 | 218 | x = self.residual_convs[i](x) 219 | x = x + res[:, :, (self.kernel_size - 1):] 220 | self.sizes.append(x.size()) 221 | 222 | # the multiple non-linearities modeled after tensorflow version 223 | x = self.nl_out(skip) 224 | x = self.conv_out(x) 225 | self.sizes.append("last conv before resize") 226 | self.sizes.append(x.size()) 227 | x, _ = dilate(x, self.audio_channels) # only works with 1 channel for now 228 | x = self.nl_end(x) 229 | x = self.conv_end(x) 230 | self.sizes.append(x.size()) 231 | 232 | return x 233 | 234 | def parameter_count(self): 235 | par = list(self.parameters()) 236 | s = sum([np.prod(list(d.size())) for d in par]) 237 | return s 238 | --------------------------------------------------------------------------------