├── README.md ├── dump_lpcnet.py ├── lpcnet.py ├── main.py ├── test.py ├── train.py ├── ulaw.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # LPCNet_pytorch 2 | A Pytorch version of LPCNet, including dump weight 3 | 4 | - J.-M. Valin, J. Skoglund, [A Real-Time Wideband Neural Vocoder at 1.6 kb/s Using LPCNet](https://jmvalin.ca/papers/lpcnet_codec.pdf), *Submitted for INTERSPEECH 2019*. 5 | - J.-M. Valin, J. Skoglund, [LPCNet: Improving Neural Speech Synthesis Through Linear Prediction](https://jmvalin.ca/papers/lpcnet_icassp2019.pdf), *Proc. International Conference on Acoustics, Speech and Signal Processing (ICASSP)*, arXiv:1810.11846, 2019. 6 | 7 | # Download Data 8 | Use together with the C code of this [repo](https://github.com/mozilla/LPCNet). 9 | Suitable training material can be obtained from the [McGill University Telecommunications & Signal Processing Laboratory](http://www-mmsp.ece.mcgill.ca/Documents/Data/). Download the ISO and extract the 16k-LP7 directory, the src/concat.sh script can be used to generate a headerless file of training samples. 10 | ``` 11 | cd 16k-LP7 12 | sh /path/to/concat.sh 13 | ``` 14 | # Dump Training Data 15 | Use together with this [repo](https://github.com/mozilla/LPCNet). 16 | ``` 17 | ./dump_data -train input.s16 features.f32 data.u8 18 | ``` 19 | 20 | # Training 21 | ``` 22 | python main.py --feat features.f32 --data data.u8 23 | ``` 24 | 25 | # Testing 26 | ``` 27 | ./dump_data -test test_input.s16 test_features.f32 28 | python test.py --feat test_features.f32 29 | ``` 30 | 31 | # Dump LPCNet weights 32 | ``` 33 | python dump_lpcnet.py --load 34 | ``` 35 | -------------------------------------------------------------------------------- /dump_lpcnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from lpcnet import * 5 | from torch.nn import Embedding, Linear, Conv1d, GRU, Module 6 | import argparse 7 | 8 | parser = argparse.ArgumentParser(description='Dump LPCNet weights') 9 | parser.add_argument('--load', default='./ckpts/120.pkl', type=str, help='file of state dict') 10 | 11 | max_rnn_neurons = 1 12 | max_conv_inputs = 1 13 | max_mdense_tmp = 1 14 | 15 | 16 | CNAME = { 17 | 'emb_pcm' : 'emb_sig', 18 | 'emb_pitch' : 'emb_pitch', 19 | 'conv1' : 'feature_conv1', 20 | 'conv2' : 'feature_conv2', 21 | 'dense1' : 'feature_dense1', 22 | 'dense2' : 'feature_dense2', 23 | 'gru1' : 'gru_a', 24 | 'gru2' : 'gru_b', 25 | 'md' : 'dual_fc', 26 | } 27 | 28 | # helper 29 | def printVector(f, vector, name, dtype='float'): 30 | v = np.reshape(vector, (-1)); 31 | #print('static const float ', name, '[', len(v), '] = \n', file=f) 32 | f.write('static const {} {}[{}] = {{\n '.format(dtype, name, len(v))) 33 | for i in range(0, len(v)): 34 | f.write('{}'.format(v[i])) 35 | if (i!=len(v)-1): 36 | f.write(',') 37 | else: 38 | break; 39 | if (i%8==7): 40 | f.write("\n ") 41 | else: 42 | f.write(" ") 43 | #print(v, file=f) 44 | f.write('\n};\n\n') 45 | return; 46 | 47 | # helper 48 | def printSparseVector(f, A, name): 49 | N = A.shape[0] 50 | W = np.zeros((0,)) 51 | diag = np.concatenate([np.diag(A[:,:N]), np.diag(A[:,N:2*N]), np.diag(A[:,2*N:])]) 52 | A[:,:N] = A[:,:N] - np.diag(np.diag(A[:,:N])) 53 | A[:,N:2*N] = A[:,N:2*N] - np.diag(np.diag(A[:,N:2*N])) 54 | A[:,2*N:] = A[:,2*N:] - np.diag(np.diag(A[:,2*N:])) 55 | printVector(f, diag, name + '_diag') 56 | idx = np.zeros((0,), dtype='int') 57 | for i in range(3*N//16): 58 | pos = idx.shape[0] 59 | idx = np.append(idx, -1) 60 | nb_nonzero = 0 61 | for j in range(N): 62 | if np.sum(np.abs(A[j, i*16:(i+1)*16])) > 1e-10: 63 | nb_nonzero = nb_nonzero + 1 64 | idx = np.append(idx, j) 65 | W = np.concatenate([W, A[j, i*16:(i+1)*16]]) 66 | idx[pos] = nb_nonzero 67 | printVector(f, W, name) 68 | #idx = np.tile(np.concatenate([np.array([N]), np.arange(N)]), 3*N//16) 69 | printVector(f, idx, name + '_idx', dtype='int') 70 | return; 71 | 72 | # default 73 | def dump_layer_ignore(self, f, hf, key): 74 | print("ignoring layer " + self.name + " of type " + self.__class__.__name__) 75 | return False 76 | Module.dump_layer = dump_layer_ignore 77 | 78 | # Sparse GRU, only use once 79 | def dump_sparse_gru(self, f, hf, key): 80 | global max_rnn_neurons 81 | name = 'sparse_' + CNAME[key] 82 | print("printing layer " + name + " of type sparse " + self.__class__.__name__) 83 | W1 = self.weight_ih_l0.data.transpose(1,0).detach().numpy() 84 | W2 = self.weight_hh_l0.data.transpose(1,0).detach().numpy() 85 | b1 = self.bias_ih_l0.data.unsqueeze(0) 86 | b2 = self.bias_hh_l0.data.unsqueeze(0) 87 | b = torch.cat((b1, b2), dim=0).detach().numpy() 88 | printSparseVector(f, W2, name + '_recurrent_weights') 89 | printVector(f, b, name + '_bias') 90 | activation = 'TANH' 91 | reset_after = 1 92 | neurons = W.shape[1]//3 93 | max_rnn_neurons = max(max_rnn_neurons, neurons) 94 | f.write('const SparseGRULayer {} = {{\n {}_bias,\n {}_recurrent_weights_diag,\n {}_recurrent_weights,\n {}_recurrent_weights_idx,\n {}, ACTIVATION_{}, {}\n}};\n\n' 95 | .format(name, name, name, name, name, W.shape[1]//3, activation, reset_after)) 96 | hf.write('#define {}_OUT_SIZE {}\n'.format(name.upper(), W.shape[1]//3)) 97 | hf.write('#define {}_STATE_SIZE {}\n'.format(name.upper(), W.shape[1]//3)) 98 | hf.write('extern const SparseGRULayer {};\n\n'.format(name)); 99 | return True 100 | 101 | # GRU 102 | def dump_gru_layer(self, f, hf, key): 103 | global max_rnn_neurons 104 | name = CNAME[key] 105 | print("printing layer " + name + " of type " + self.__class__.__name__) 106 | W1 = self.weight_ih_l0.data.transpose(1,0).detach().numpy() 107 | W2 = self.weight_hh_l0.data.transpose(1,0).detach().numpy() 108 | b1 = self.bias_ih_l0.data.unsqueeze(0) 109 | b2 = self.bias_hh_l0.data.unsqueeze(0) 110 | b = torch.cat((b1, b2), dim=0).detach().numpy() 111 | printVector(f, W1, name + '_weights') 112 | printVector(f, W2, name + '_recurrent_weights') 113 | printVector(f, b, name + '_bias') 114 | activation = 'TANH' 115 | reset_after = 1 116 | neurons = W1.shape[1]//3 117 | max_rnn_neurons = max(max_rnn_neurons, neurons) 118 | f.write('const GRULayer {} = {{\n {}_bias,\n {}_weights,\n {}_recurrent_weights,\n {}, {}, ACTIVATION_{}, {}\n}};\n\n' 119 | .format(name, name, name, name, W1.shape[0], W1.shape[1]//3, activation, reset_after)) 120 | hf.write('#define {}_OUT_SIZE {}\n'.format(name.upper(), W1.shape[1]//3)) 121 | hf.write('#define {}_STATE_SIZE {}\n'.format(name.upper(), W1.shape[1]//3)) 122 | hf.write('extern const GRULayer {};\n\n'.format(name)) 123 | return True 124 | GRU.dump_layer = dump_gru_layer 125 | 126 | def dump_dense_layer_impl(name, weights, bias, activation, f, hf): 127 | printVector(f, weights, name + '_weights') 128 | printVector(f, bias, name + '_bias') 129 | f.write('const DenseLayer {} = {{\n {}_bias,\n {}_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n' 130 | .format(name, name, name, weights.shape[0], weights.shape[1], activation)) 131 | hf.write('#define {}_OUT_SIZE {}\n'.format(name.upper(), weights.shape[1])) 132 | hf.write('extern const DenseLayer {};\n\n'.format(name)); 133 | 134 | def dump_dense_layer(self, f, hf, key): 135 | name = CNAME[key] 136 | print("printing layer " + name + " of type " + self.__class__.__name__) 137 | w, b = self.weight.data.transpose(1, 0), self.bias.data 138 | w, b = w.detach().numpy(), b.detach().numpy() 139 | dump_dense_layer_impl(name, w, b, 'TANH', f, hf) 140 | return False 141 | 142 | Linear.dump_layer = dump_dense_layer 143 | 144 | def dump_mdense_layer(self, f, hf, key): 145 | global max_mdense_tmp 146 | name = CNAME[key] 147 | print("printing layer " + name + " of type " + self.__class__.__name__) 148 | w1, w2 = self.fc1.weight.data, self.fc2.weight.data 149 | W = torch.cat((w1.unsqueeze(2), w2.unsqueeze(2)), dim=2).detach().numpy() 150 | b1, b2 = self.fc1.bias.data, self.fc2.bias.data 151 | b = torch.cat((b1.unsqueeze(1), b2.unsqueeze(1)), dim=1).detach().numpy() 152 | gamma1, gamma2 = self.gamma1.data, self.gamma2.data 153 | gamma = torch.cat((gamma1.unsqueeze(1), gamma2.unsqueeze(1)), dim=1).detach().numpy() 154 | printVector(f, np.transpose(W, (1, 2, 0)), name + '_weights') 155 | printVector(f, np.transpose(b, (1, 0)), name + '_bias') 156 | printVector(f, np.transpose(gamma, (1, 0)), name + '_factor') 157 | activation = 'TANH' 158 | max_mdense_tmp = max(max_mdense_tmp, W.shape[0] * W.shape[2]) 159 | f.write('const MDenseLayer {} = {{\n {}_bias,\n {}_weights,\n {}_factor,\n {}, {}, {}, ACTIVATION_{}\n}};\n\n' 160 | .format(name, name, name, name, W.shape[1], W.shape[0], W.shape[2], activation)) 161 | hf.write('#define {}_OUT_SIZE {}\n'.format(name.upper(), W.shape[0])) 162 | hf.write('extern const MDenseLayer {};\n\n'.format(name)); 163 | return False 164 | MDense.dump_layer = dump_mdense_layer 165 | 166 | def dump_conv1d_layer(self, f, hf, key): 167 | global max_conv_inputs 168 | name = CNAME[key] 169 | print("printing layer " + name + " of type " + self.__class__.__name__) 170 | W = self.weight.data.transpose(2, 0) 171 | b = self.bias.data 172 | W, b = W.detach().numpy(), b.detach().numpy() 173 | printVector(f, W, name + '_weights') 174 | printVector(f, b, name + '_bias') 175 | activation = 'TANH' 176 | max_conv_inputs = max(max_conv_inputs, W.shape[1]*W.shape[0]) 177 | f.write('const Conv1DLayer {} = {{\n {}_bias,\n {}_weights,\n {}, {}, {}, ACTIVATION_{}\n}};\n\n' 178 | .format(name, name, name, W.shape[1], W.shape[0], W.shape[2], activation)) 179 | hf.write('#define {}_OUT_SIZE {}\n'.format(name.upper(), W.shape[2])) 180 | hf.write('#define {}_STATE_SIZE ({}*{})\n'.format(name.upper(), W.shape[1], (W.shape[0]-1))) 181 | hf.write('#define {}_DELAY {}\n'.format(name.upper(), (W.shape[0]-1)//2)) 182 | hf.write('extern const Conv1DLayer {};\n\n'.format(name)); 183 | return True 184 | Conv1d.dump_layer = dump_conv1d_layer 185 | 186 | # helper 187 | def dump_embedding_layer_impl(name, weights, f, hf): 188 | printVector(f, weights, name + '_weights') 189 | f.write('const EmbeddingLayer {} = {{\n {}_weights,\n {}, {}\n}};\n\n' 190 | .format(name, name, weights.shape[0], weights.shape[1])) 191 | hf.write('#define {}_OUT_SIZE {}\n'.format(name.upper(), weights.shape[1])) 192 | hf.write('extern const EmbeddingLayer {};\n\n'.format(name)); 193 | 194 | def dump_embedding_layer(self, f, hf, key): 195 | name = CNAME[key] 196 | print("printing layer " + name + " of type " + self.__class__.__name__) 197 | W = self.weight.data.detach().numpy() 198 | dump_embedding_layer_impl(name, W, f, hf) 199 | return False 200 | Embedding.dump_layer = dump_embedding_layer 201 | 202 | 203 | model = torch.nn.DataParallel(LPCNet()) 204 | model.load_state_dict(torch.load(args.load, map_location=torch.device('cpu'))) 205 | model = model.cpu() 206 | net = model.module 207 | 208 | cfile = 'nnet_data.c' 209 | hfile = 'nnet_data.h' 210 | 211 | 212 | f = open(cfile, 'w') 213 | hf = open(hfile, 'w') 214 | 215 | 216 | f.write('/*This file is automatically generated from a Keras model*/\n\n') 217 | f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "nnet.h"\n#include "{}"\n\n'.format(hfile)) 218 | 219 | hf.write('/*This file is automatically generated from a Keras model*/\n\n') 220 | hf.write('#ifndef RNN_DATA_H\n#define RNN_DATA_H\n\n#include "nnet.h"\n\n') 221 | 222 | E = net.emb_pcm.weight.data.detach().numpy() 223 | W = net.gru1.weight_ih_l0.data.transpose(1, 0)[:embed_size, :].detach().numpy() 224 | dump_embedding_layer_impl('gru_a_embed_sig', np.dot(E, W), f, hf) 225 | W = net.gru1.weight_ih_l0.data.transpose(1, 0)[embed_size:2*embed_size,:].detach().numpy() 226 | dump_embedding_layer_impl('gru_a_embed_pred', np.dot(E, W), f, hf) 227 | W = net.gru1.weight_ih_l0.transpose(1, 0)[2*embed_size:3*embed_size,:].detach().numpy() 228 | dump_embedding_layer_impl('gru_a_embed_exc', np.dot(E, W), f, hf) 229 | b1, b2 = net.gru1.bias_ih_l0.data, net.gru1.bias_hh_l0.data 230 | b = torch.cat((b1.unsqueeze(0), b2.unsqueeze(0)), dim=0).detach().numpy() 231 | dump_dense_layer_impl('gru_a_dense_feature', W, b, 'LINEAR', f, hf) 232 | 233 | 234 | layer_list = [] 235 | for key in net._modules: 236 | if net._modules[key].dump_layer(f, hf, key): 237 | layer_list.append(key) 238 | 239 | dump_sparse_gru(net.gru1, f, hf, 'gru1') 240 | 241 | hf.write('#define MAX_RNN_NEURONS {}\n\n'.format(max_rnn_neurons)) 242 | hf.write('#define MAX_CONV_INPUTS {}\n\n'.format(max_conv_inputs)) 243 | hf.write('#define MAX_MDENSE_TMP {}\n\n'.format(max_mdense_tmp)) 244 | 245 | 246 | hf.write('typedef struct {\n') 247 | for i, key in enumerate(layer_list): 248 | name = CNAME[key] 249 | hf.write(' float {}_state[{}_STATE_SIZE];\n'.format(name, name.upper())) 250 | hf.write('} NNetState;\n') 251 | 252 | hf.write('\n\n#endif\n') 253 | 254 | f.close() 255 | hf.close() 256 | -------------------------------------------------------------------------------- /lpcnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | frame_size = 160 6 | pcm_bits = 8 7 | embed_size = 128 8 | pcm_levels = 2**pcm_bits 9 | 10 | 11 | def _init_linear_conv_(layer): 12 | nn.init.xavier_uniform_(layer.weight) 13 | nn.init.zeros_(layer.bias) 14 | 15 | def _init_gru_(layer): 16 | nn.init.xavier_uniform_(layer.weight_ih_l0) 17 | nn.init.orthogonal_(layer.weight_hh_l0) 18 | nn.init.zeros_(layer.bias_ih_l0) 19 | nn.init.zeros_(layer.bias_hh_l0) 20 | 21 | class MDense(nn.Module): 22 | def __init__(self, in_dim=16, out_dim=256): 23 | super(MDense, self).__init__() 24 | self.fc1 = nn.Linear(in_dim, out_dim) 25 | self.fc2 = nn.Linear(in_dim, out_dim) 26 | self.gamma1 = nn.Parameter(torch.ones(out_dim), requires_grad=True) 27 | self.gamma2 = nn.Parameter(torch.ones(out_dim), requires_grad=True) 28 | _init_linear_conv_(self.fc1) 29 | _init_linear_conv_(self.fc2) 30 | def forward(self, x): 31 | y1 = self.fc1(x).tanh() 32 | y2 = self.fc2(x).tanh() 33 | p = self.gamma1 * y1 + self.gamma2 * y2 34 | return p.sigmoid() 35 | 36 | class LPCNet(nn.Module): 37 | def __init__(self, rnn_units1=384, rnn_units2=16, nb_used_features=38, mode='prepad'): 38 | super(LPCNet, self).__init__() 39 | self.rnn_units1 = rnn_units1 40 | self.rnn_units2 = rnn_units2 41 | padding = 0 if mode=='prepad' else 1 42 | self.emb_pcm = nn.Embedding(256, 128) 43 | self.emb_pitch = nn.Embedding(256, 64) 44 | self.conv1 = nn.Conv1d(102, 128, 3, padding=padding) 45 | self.conv2 = nn.Conv1d(128, 128, 3, padding=padding) 46 | self.dense1 = nn.Linear(128, 128) 47 | self.dense2 = nn.Linear(128, 128) 48 | self.gru1 = nn.GRU(input_size=512, hidden_size=rnn_units1, batch_first=True) 49 | self.gru2 = nn.GRU(input_size=512, hidden_size=rnn_units2, batch_first=True) 50 | self.md = MDense(in_dim=16, out_dim=256) 51 | self.pcminit(self.emb_pcm) 52 | nn.init.uniform_(self.emb_pitch.weight, a=-0.05, b=0.05) 53 | _init_linear_conv_(self.conv1) 54 | _init_linear_conv_(self.conv2) 55 | _init_linear_conv_(self.dense1) 56 | _init_linear_conv_(self.dense2) 57 | _init_gru_(self.gru1) 58 | _init_gru_(self.gru2) 59 | 60 | def pcminit(self, layer): 61 | w = layer.state_dict()['weight'] 62 | shape = w.shape 63 | num_rows, num_cols = shape 64 | flat_shape = (num_rows, num_cols) 65 | p = torch.rand(shape).add_(-0.5).mul_(1.7321 * 2) 66 | r = torch.arange(-.5*num_rows+.5,.5*num_rows-.4).mul(math.sqrt(12)/num_rows).reshape(num_rows, 1) 67 | w[:] = p + r 68 | 69 | def encode(self, feat, pitch): 70 | # (bs, 15/19, 1) --> (bs, 15/19, 1, 64) --> (bs, (15/19), 64) 71 | pitch = self.emb_pitch(pitch).squeeze(2) 72 | # (bs, (15/19), 38+64) --> (bs, 102, 15/19) 73 | feat = torch.cat((feat, pitch), dim=2).permute(0, 2, 1) 74 | # (bs, 102, 15/19) --> (bs, 128, 15/17) 75 | feat = self.conv1(feat).tanh() 76 | # (bs, 128, 15/17) --> (bs, 128, 15) --> (bs, 15, 128) 77 | feat = self.conv2(feat).tanh().permute(0, 2, 1) 78 | # (bs, 15, 128) --> (bs, 15, 128) 79 | feat = self.dense2(self.dense1(feat).tanh()).tanh() 80 | return feat 81 | 82 | def decode(self, pcm, feat, hid1, hid2, frame_size=160): 83 | # bs, cs = batch size, chunk size 84 | bs, cs = pcm.shape[0], pcm.shape[1] 85 | # (bs, 2400, 3) --> (bs, 2400, 3, 128) --> (bs, 2400, 384) 86 | pcm = self.emb_pcm(pcm).reshape(bs, cs, -1) 87 | # (bs, 15, 128) --> (bs, 2400, 128) 88 | rfeat = torch.repeat_interleave(feat, frame_size, dim=1) if frame_size > 1 else feat 89 | # (bs, 2400, 512) 90 | rnn_in = torch.cat((pcm, rfeat), dim=-1) 91 | # (bs, 2400, 512) --> (bs, 2400, 384) 92 | self.gru1.flatten_parameters() 93 | hid1 = hid1.to(rnn_in) 94 | out, hid1 = self.gru1(rnn_in, hid1) 95 | # (bs, 2400, 384+128) 96 | rnn_in = torch.cat((out, rfeat), dim=-1) 97 | # (bs, 2400, 16) 98 | self.gru2.flatten_parameters() 99 | hid2 = hid2.to(rnn_in) 100 | out, hid2 = self.gru2(rnn_in, hid2) 101 | prob = self.md(out) 102 | return prob, hid1, hid2 103 | 104 | def forward(self, pcm, feat, pitch): 105 | # bs, cs = batch size, chunk size 106 | bs, cs = pcm.shape[0], pcm.shape[1] 107 | feat = self.encode(feat, pitch) 108 | # (1, bs, 384) 109 | zeros1 = torch.zeros(1, bs, self.rnn_units1) 110 | # (1, bs, 16) 111 | zeros2 = torch.zeros(1, bs, self.rnn_units2) 112 | prob, hid1, hid2 = self.decode(pcm, feat, zeros1, zeros2) 113 | return prob 114 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils import get_data 4 | import numpy as np 5 | from lpcnet import LPCNet 6 | from train import train 7 | import sys 8 | import argparse 9 | 10 | 11 | parser = argparse.ArgumentParser(description='Train LPCNet from scratch') 12 | parser.add_argument('--feat', default='../features.f32', type=str, help='input feature') 13 | parser.add_argument('--data', default='../data.u8', type=str, help='output wav file') 14 | 15 | args = parser.parse_args() 16 | 17 | feature_file = args.feat 18 | pcm_file = args.data 19 | 20 | torch.cuda._lazy_init() 21 | torch.backends.cudnn.benchmark = True 22 | 23 | 24 | print('Initialize Model....') 25 | model = nn.DataParallel(LPCNet()).cuda() 26 | print('Read Training Data....') 27 | dataloader = get_data(pcm_file, feature_file) 28 | loss = nn.CrossEntropyLoss().cuda() 29 | if __name__ == '__main__': 30 | print('Start Training!!') 31 | train(model, dataloader, loss) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from lpcnet import LPCNet 3 | import numpy as np 4 | from ulaw import * 5 | from tqdm import tqdm 6 | import sys 7 | import argparse 8 | from scipy import io 9 | from scipy.io import wavfile 10 | 11 | parser = argparse.ArgumentParser(description='Test LPCNet') 12 | parser.add_argument('--load', default='./ckpts/120.pkl', type=str, help='file of state dict') 13 | parser.add_argument('--feat', default='../test_features.f32', type=str, help='input feature') 14 | parser.add_argument('--file', default='./test.wav', type=str, help='output wav file') 15 | 16 | args = parser.parse_args() 17 | 18 | feature_file = args.feat 19 | out_file = args.file 20 | 21 | model = torch.nn.DataParallel(LPCNet(mode='test')) 22 | model.load_state_dict(torch.load(args.load, map_location=torch.device('cpu'))) 23 | model = model.cpu() 24 | 25 | 26 | frame_size = 160 27 | nb_features = 55 28 | nb_used_features = 38 29 | 30 | features = np.fromfile(feature_file, dtype='float32') 31 | features = np.resize(features, (-1, nb_features)) 32 | nb_frames = 1 33 | feature_chunk_size = features.shape[0] 34 | pcm_chunk_size = frame_size*feature_chunk_size 35 | 36 | features = np.reshape(features, (nb_frames, feature_chunk_size, nb_features)) 37 | features[:,:,18:36] = 0 38 | #periods = (.1 + 50*features[:,:,36:37]+100).astype('int16') 39 | 40 | periods = torch.Tensor((.1 + 50*features[:,:,36:37]+100)).type(torch.LongTensor) 41 | 42 | order = 16 43 | 44 | #pcm = np.zeros((nb_frames*pcm_chunk_size, )) 45 | #fexc = np.zeros((1, 1, 3), dtype='int16')+128 46 | pcm = np.zeros(nb_frames*pcm_chunk_size) 47 | fexc = np.zeros((1, 1, 3), dtype='int16')+128 48 | #state1 = np.zeros((1, model.rnn_units1), dtype='float32') 49 | #state2 = np.zeros((1, model.rnn_units2), dtype='float32') 50 | state1 = torch.zeros((1, 1, model.module.rnn_units1)) 51 | state2 = torch.zeros((1, 1, model.module.rnn_units2)) 52 | 53 | mem = 0 54 | coef = 0.85 55 | 56 | #fout = open(out_file, 'wb') 57 | 58 | skip = order + 1 59 | res = [] 60 | with torch.no_grad(): 61 | for c in range(0, nb_frames): 62 | feat = torch.FloatTensor(features[c:c+1, :, :nb_used_features]) 63 | pitch = periods[c:c+1, :, :] 64 | cfeat = model.module.encode(feat, pitch) 65 | for fr in tqdm(range(0, feature_chunk_size)): 66 | f = c*feature_chunk_size + fr 67 | a = features[c, fr, nb_features-order:] 68 | for i in range(skip, frame_size): 69 | t = f*frame_size + i 70 | # float 71 | pred = -sum(a * pcm[t-1 : t-1-order:-1]) 72 | fexc[0, 0, 1] = lin2ulaw(pred) 73 | p, state1, state2 = model.module.decode(torch.LongTensor(fexc), torch.Tensor(cfeat[:, fr:fr+1, :]), state1, state2, frame_size=1) 74 | p = p.softmax(dim=-1).numpy() 75 | p *= np.power(p, np.maximum(0, 1.5*features[c, fr, 37] - .5)) 76 | p = p/(1e-18 + np.sum(p)) 77 | p = np.maximum(p-0.002, 0).astype('float64') 78 | p = p/(1e-8 + np.sum(p)) 79 | fexc[0, 0, 2] = np.argmax(np.random.multinomial(1, p[0,0,:], 1)) 80 | pcm[t] = pred + ulaw2lin(fexc[0, 0, 2]) 81 | fexc[0, 0, 0] = lin2ulaw(pcm[t]) 82 | mem = coef*mem + pcm[t] 83 | res.append(round(mem)) 84 | skip = 0 85 | 86 | wavfile.write(out_file, rate=16000, data=np.array(res, dtype='int16')) 87 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | t_start = 2000 7 | t_end = 40000 8 | interval= 400 9 | final_density = (0.05, 0.05, 0.2) 10 | 11 | def sparsity(net, i): 12 | if i < t_start: 13 | return 14 | if (i - t_start) % interval != 0 and i < t_end: 15 | return 16 | p = net.module.gru1.state_dict()['weight_hh_l0'].transpose(1, 0) 17 | nb = p.shape[1] // p.shape[0] 18 | N = p.shape[0] 19 | for k in range(nb): 20 | density = final_density[k] 21 | if i < t_end: 22 | r = 1.0 - (i - t_start) / (t_end - t_start) 23 | density = 1 - (1 - final_density[k]) * (1-r*r*r) 24 | A = p[:, k*N:(k+1)*N] 25 | A = A - torch.diag(torch.diag(A)) 26 | L = torch.reshape(A, (N, N//16, 16)) 27 | S = torch.sum(L*L, axis=-1) 28 | SS, _ = torch.sort(torch.reshape(S, (-1,))) 29 | thresh = SS[round(N*N//16*(1-density))] 30 | mask = (S>=thresh).float() 31 | mask = torch.repeat_interleave(mask, 16, axis=1) 32 | mask.add_(torch.eye(N).cuda()).clamp_(max=1) 33 | p[:, k*N:(k+1)*N] = p[:, k*N:(k+1)*N]*mask 34 | net.module.gru1.state_dict()['weight_hh_l0'][:] = p.transpose(1, 0) 35 | 36 | def train(net, dataloader, loss, lr=0.001, epochs=120): 37 | opt = torch.optim.Adam(net.parameters(), betas=(0.9, 0.99), eps=1e-7, lr=lr) 38 | scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda it: lr/(1. + 5e-5 * it)) 39 | iteration = 0 40 | for epoch in range(1, epochs+1): 41 | print('Epoch:\t'+str(epoch)+'/'+str(epochs)) 42 | total_loss = [] 43 | for i, (pcm, feat, pitch, target) in enumerate(tqdm(dataloader), 1): 44 | iteration += 1 45 | pcm = pcm.type(torch.LongTensor).cuda() 46 | feat = feat.cuda() 47 | pitch = pitch.type(torch.LongTensor).cuda() 48 | target = target.type(torch.LongTensor).reshape(-1).cuda() 49 | prob = net(pcm, feat, pitch).reshape(-1, 256) 50 | L = loss(prob, target) 51 | opt.zero_grad() 52 | L.backward() 53 | opt.step() 54 | scheduler.step() 55 | sparsity(net, iteration) 56 | total_loss.append(L.item()) 57 | avg_loss = sum(total_loss)/len(total_loss) 58 | print('\nEpoch Loss %.4f' % avg_loss) 59 | torch.save(net.state_dict(), './ckpts/%03d.pkl' % epoch) 60 | -------------------------------------------------------------------------------- /ulaw.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import math 4 | 5 | scale = 255.0/32768.0 6 | scale_1 = 32768.0/255.0 7 | def ulaw2lin(u): 8 | u = u - 128 9 | s = np.sign(u) 10 | u = np.abs(u) 11 | return s*scale_1*(np.exp(u/128.*math.log(256))-1) 12 | 13 | 14 | def lin2ulaw(x): 15 | s = np.sign(x) 16 | x = np.abs(x) 17 | u = (s*(128*np.log(1+scale*x)/math.log(256))) 18 | u = np.clip(128 + np.round(u), 0, 255) 19 | return u.astype('int16') 20 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | from lpcnet import frame_size 3 | import numpy as np 4 | 5 | batch_size = 64 6 | nb_epochs = 120 7 | nb_features = 55 8 | nb_used_features = 38 9 | feature_chunk_size = 15 10 | pcm_chunk_size = frame_size*feature_chunk_size 11 | 12 | class LPCData(Dataset): 13 | def __init__(self, in_data, features, periods, out_exc): 14 | self.in_data = in_data 15 | self.features = features 16 | self.periods = periods 17 | self.out_exc = out_exc 18 | def __getitem__(self, i): 19 | return (self.in_data[i], self.features[i], self.periods[i], self.out_exc[i]) 20 | 21 | def __len__(self): 22 | return self.in_data.shape[0] 23 | 24 | def get_data(pcm_file, feature_file): 25 | data = np.fromfile(pcm_file, dtype='uint8') 26 | nb_frames = len(data)//(4*pcm_chunk_size) 27 | features = np.fromfile(feature_file, dtype='float32') 28 | data = data[:nb_frames*4*pcm_chunk_size] 29 | features = features[:nb_frames*feature_chunk_size*nb_features] 30 | features = np.reshape(features, (nb_frames*feature_chunk_size, nb_features)) 31 | sig = np.reshape(data[0::4], (nb_frames, pcm_chunk_size, 1)) 32 | pred = np.reshape(data[1::4], (nb_frames, pcm_chunk_size, 1)) 33 | in_exc = np.reshape(data[2::4], (nb_frames, pcm_chunk_size, 1)) 34 | out_exc = np.reshape(data[3::4], (nb_frames, pcm_chunk_size, 1)) 35 | #print("ulaw std = ", np.std(out_exc)) 36 | features = np.reshape(features, (nb_frames, feature_chunk_size, nb_features)) 37 | features = features[:, :, :nb_used_features] 38 | features[:,:,18:36] = 0 39 | fpad1 = np.concatenate([features[0:1, 0:2, :], features[:-1, -2:, :]], axis=0) 40 | fpad2 = np.concatenate([features[1:, :2, :], features[0:1, -2:, :]], axis=0) 41 | features = np.concatenate([fpad1, features, fpad2], axis=1) 42 | periods = (.1 + 50*features[:,:,36:37]+100).astype('int16') 43 | #periods = np.minimum(periods, 255) 44 | in_data = np.concatenate([sig, pred, in_exc], axis=-1) 45 | dataset = LPCData(in_data, features, periods, out_exc) 46 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True, pin_memory=True) 47 | del data, sig, pred, in_exc 48 | return dataloader 49 | --------------------------------------------------------------------------------