├── README.md
├── dataset.py
├── main.py
├── misc
│   ├── audio_lpc.py
│   ├── audio_mfcc.py
│   └── combine.py
├── models.py
├── models_testae.py
├── synthesis.py
├── test.py
└── test_model.py
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | ## Audio to Face Blendshape
3 | Implementation with PyTorch.
4 | 
5 | Reimplemented by: 刘宇昂 (Liu Yuang)
6 | - Base models
7 |     - LSTM using MFCC audio features
8 |     - CNN ([ref](http://research.nvidia.com/publication/2017-07_Audio-Driven-Facial-Animation), simplified version) with LPC features
9 | 
10 | 
11 | ## Prerequisites
12 | - Python 3
13 | - PyTorch v0.3.0
14 | - numpy
15 | - python_speech_features & audiolazy (for MFCC/LPC feature extraction)
16 | - scipy
17 | - tqdm, etc.
18 | 
19 | ## Files
20 | - Scripts to run
21 |     - `main.py`: change the net name and set the checkpoint folder to train different models
22 |     - `test_model.py`: generate blendshape sequences from extracted audio features (needs audio features as input)
23 |     - `synthesis.py`: generate blendshapes directly from an input wav (needs the input audio path as an argument)
24 | 
25 | - Classes
26 |     - `models.py`: LSTM and CNN (simplified NvidiaNet) models.
27 |     - `models_testae.py`: Advanced models with an autoencoder design.
28 |     - `dataset.py`: Dataset class for loading feature/blendshape pairs.
29 | 
30 | - Input preprocessing
31 |     - `misc/audio_mfcc.py`: extract MFCC features from input wav files
32 |     - `misc/audio_lpc.py`: extract LPC features
33 |     - `misc/combine.py`: combine per-clip audio feature/blendshape files into a single file for data loading
34 | 
35 | ## Usage
36 | ### Input
37 | To build your own dataset, preprocess your wav/blendshape pairs with `misc/audio_mfcc.py` or `misc/audio_lpc.py`. Then combine the resulting feature/blendshape files with `misc/combine.py` into a single feature/blendshape file.
38 | 
39 | ### Training
40 | Modify `main.py`: set the model to the one you need and specify the checkpoint folder.
41 | 
42 | ### Evaluation
43 | - Both `test_model.py` and `synthesis.py` can be used to generate blendshape sequences.
44 |     - `test_model.py` accepts extracted audio features (MFCC/LPC).
45 |     - `synthesis.py` takes a raw wav file as input.
46 | - Pass the required arguments and either script will produce a blendshape test file (see the inference sketch below).
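Putting the pieces together, here is a minimal inference sketch built from the repo's own `audio_mfcc` helper and the plain `FullyLSTM` model. The wav path and checkpoint path are placeholders, and the batching/CUDA handling that `synthesis.py` performs is omitted; the code follows the same PyTorch 0.3-era `Variable` idiom used throughout the repo.

```python
import numpy as np
import torch
from torch.autograd import Variable

from misc.audio_mfcc import audio_mfcc   # returns a (win_count, 64, 39) sliding-window MFCC array
from models import FullyLSTM

# placeholder paths -- substitute your own wav file and trained checkpoint
features = audio_mfcc('my_audio.wav')                     # no feature file is written when only a path is given
model = FullyLSTM(num_features=39)                        # 39 = 13 MFCC + deltas + delta-deltas
checkpoint = torch.load('checkpoint-lstm-mfcc39/model_best.pth.tar',
                        map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['state_dict'])
model.eval()

# one forward pass over all windows; for long clips, batch the windows as synthesis.py does
input_var = Variable(torch.from_numpy(features).float())
blendshapes = model(input_var).data.numpy() * 100.0       # training targets were divided by 100
np.savetxt('my_audio-blendshape.txt', blendshapes, fmt='%.6f')
```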
47 | 48 | 49 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | 4 | class BlendshapeDataset(Dataset): 5 | 6 | def __init__(self, feature_file, target_file): 7 | self.wav_feature = np.load(feature_file) 8 | # reshape to avoid automatic conversion to doubletensor 9 | self.blendshape_target = np.loadtxt(target_file) / 100.0 10 | 11 | self._align() 12 | 13 | def __len__(self): 14 | return len(self.wav_feature) 15 | 16 | def _align(self): 17 | """ 18 | align audio feature with blendshape feature 19 | generally, number of audio feature is less 20 | """ 21 | 22 | n_audioframe, n_videoframe = len(self.wav_feature), len(self.blendshape_target) 23 | print('Current dataset -- n_videoframe: {}, n_audioframe:{}'.format(n_videoframe, n_audioframe)) 24 | assert n_videoframe - n_audioframe <= 40 25 | if n_videoframe != n_audioframe: 26 | start_videoframe = 16 27 | self.blendshape_target = self.blendshape_target[start_videoframe : start_videoframe+n_audioframe] 28 | 29 | def __getitem__(self, index): 30 | return self.wav_feature[index], self.blendshape_target[index] 31 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | ''' 2 | virtualenv: 63server pytorch 3 | author: Yachun Li (liyachun@outlook.com) 4 | ''' 5 | import torch 6 | import torch.autograd as autograd 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | 11 | import os 12 | import shutil 13 | import time 14 | from datetime import datetime 15 | 16 | from dataset import BlendshapeDataset 17 | from models import NvidiaNet, LSTMNvidiaNet, FullyLSTM 18 | from models_testae import * 19 | 20 | # gpu setting 21 | gpu_id = 1 22 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id) 23 | 24 | # hyper-parameters 25 | n_blendshape = 51 26 | learning_rate = 0.0001 27 | batch_size = 100 28 | epochs = 500 29 | 30 | print_freq = 20 31 | best_loss = 10000000 32 | 33 | # data path 34 | dataroot = '/home/liyachun/data/audio2bs' 35 | # data_path = os.path.join(dataroot, audio2bs) 36 | data_path = dataroot 37 | checkpoint_path = './checkpoint-lstmae-2distconcat_kl001/' 38 | if not os.path.isdir(checkpoint_path): os.mkdir(checkpoint_path) 39 | 40 | # Reconstruction + KL divergence losses summed over all elements and batch 41 | def loss_function(recon_x, x, mu, logvar): 42 | # BCE = F.binary_cross_entropy(recon_x, x, size_average=False) 43 | MSE = F.mse_loss(recon_x, x) 44 | 45 | # see Appendix B from VAE paper: 46 | # Kingma and Welling. Auto-Encoding Variational Bayes. 
ICLR, 2014 47 | # https://arxiv.org/abs/1312.6114 48 | # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) 49 | KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 50 | 51 | # print('loss percent: MSE %.4f(%.6f), KLD %.4f, total %.4f' 52 | # % (MSE.data[0], MSE.data[0]/(MSE.data[0]+KLD.data[0]), KLD.data[0], MSE.data[0]+KLD.data[0])) 53 | 54 | return MSE + 0.01*KLD 55 | 56 | def main(): 57 | global best_loss 58 | model = LSTMAE2dist(is_concat=True) 59 | print(model) 60 | # model = nn.DataParallel(model) 61 | 62 | # get data 63 | train_loader = torch.utils.data.DataLoader( 64 | BlendshapeDataset(feature_file=os.path.join(data_path, 'feature/0201train-mfcc39.npy'), 65 | target_file=os.path.join(data_path, 'blendshape/0201train.txt')), 66 | batch_size=batch_size, shuffle=True, num_workers=2 67 | ) 68 | val_loader = torch.utils.data.DataLoader( 69 | BlendshapeDataset(feature_file=os.path.join(data_path, 'test/feature/0201_06-1min-mfcc39.npy'), 70 | target_file=os.path.join(data_path, 'test/blendshape/0201_06-1min.txt')), 71 | batch_size=batch_size, shuffle=False, num_workers=2 72 | ) 73 | 74 | # define loss and optimiser 75 | criterion = nn.MSELoss() #??.cuda() 76 | optimizer = optim.Adam(model.parameters(), lr=learning_rate) 77 | 78 | if torch.cuda.is_available(): 79 | model = model.cuda() 80 | 81 | # training 82 | print('------------\n Training begin at %s' % datetime.now()) 83 | for epoch in range(epochs): 84 | start_time = time.time() 85 | 86 | model.train() 87 | train_loss = 0. 88 | for i, (input, target) in enumerate(train_loader): 89 | target = target.cuda(async=True) 90 | input_var = autograd.Variable(input.float()).cuda() 91 | target_var = autograd.Variable(target.float()) 92 | 93 | # compute model output 94 | # audio_z, bs_z, output = model(input_var, target_var) 95 | # loss = criterion(output, target_var) 96 | audio_z, bs_z, output, mu, logvar = model(input_var, target_var) # method2: loss change 97 | loss = loss_function(output, target_var, mu, logvar) 98 | 99 | train_loss += loss.data[0] 100 | 101 | # compute gradient and do the backpropagate 102 | optimizer.zero_grad() 103 | loss.backward() 104 | optimizer.step() 105 | # if i % print_freq == 0: 106 | # print('Training -- epoch: {} | iteration: {}/{} | loss: {:.6f} \r' 107 | # .format(epoch+1, i, len(train_loader), loss.data[0])) 108 | 109 | train_loss /= len(train_loader) 110 | print('Glance at training z: max/min of hidden audio(%.4f/%.4f), blendshape(%.4f/%.4f)' 111 | % (max(audio_z.data[0]), min(audio_z.data[0]), max(bs_z.data[0]), min(bs_z.data[0]))) 112 | 113 | model.eval() 114 | eval_loss = 0. 115 | for input, target in val_loader: 116 | target = target.cuda(async=True) 117 | input_var = autograd.Variable(input.float(), volatile=True).cuda() 118 | target_var = autograd.Variable(target.float(), volatile=True) 119 | 120 | # compute output temporal?!! 
121 | # audio_z, bs_z, output = model(input_var, target_var) 122 | # loss = criterion(output, target_var) 123 | audio_z, bs_z, output, mu, logvar = model(input_var, target_var) # method2: loss change 124 | loss = loss_function(output, target_var, mu, logvar) 125 | 126 | eval_loss += loss.data[0] 127 | 128 | eval_loss /= len(val_loader) 129 | 130 | # count time of 1 epoch 131 | past_time = time.time() - start_time 132 | 133 | print('Glance at validating z: max/min of hidden audio(%.4f/%.4f), blendshape(%.4f/%.4f)' 134 | % (max(audio_z.data[0]), min(audio_z.data[0]), max(bs_z.data[0]), min(bs_z.data[0]))) 135 | 136 | # print('Evaluating -- epoch: {} | loss: {:.6f} \r'.format(epoch+1, eval_loss/len(val_loader))) 137 | print('epoch: {:03} | train_loss: {:.6f} | eval_loss: {:.6f} | {:.4f} sec/epoch \r' 138 | .format(epoch+1, train_loss, eval_loss, past_time)) 139 | 140 | # save best model on val 141 | is_best = eval_loss < best_loss 142 | best_loss = min(eval_loss, best_loss) 143 | if is_best: 144 | torch.save({ 145 | 'epoch': epoch + 1, 146 | 'state_dict': model.state_dict(), 147 | 'eval_loss': best_loss, 148 | }, checkpoint_path+'model_best.pth.tar') 149 | 150 | # save models every 100 epoch 151 | if (epoch+1) % 100 == 0: 152 | torch.save({ 153 | 'epoch': epoch + 1, 154 | 'state_dict': model.state_dict(), 155 | 'eval_loss': eval_loss, 156 | }, checkpoint_path+'checkpoint-epoch'+str(epoch+1)+'.pth.tar') 157 | 158 | print('Training finished at %s' % datetime.now()) 159 | 160 | # def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 161 | # torch.save(state, checkpoint_path+filename) 162 | # if is_best: 163 | # shutil.copyfile(checkpoint_path+filename, checkpoint_path+'model_best.pth.tar') 164 | 165 | if __name__ == '__main__': 166 | main() 167 | -------------------------------------------------------------------------------- /misc/audio_lpc.py: -------------------------------------------------------------------------------- 1 | from audiolazy import lpc 2 | # from audiolazy import lazy_lpc 3 | import numpy as np 4 | import scipy.signal as signal 5 | import scipy.io.wavfile as wav 6 | import os 7 | from tqdm import tqdm 8 | from multiprocessing import Pool 9 | 10 | _num_worker = 8 11 | 12 | # data path 13 | dataroot = '/home/liyachun/data/audio2bs' 14 | wav_path = os.path.join(dataroot, 'sandbox/wav/') 15 | feature_path = os.path.join(dataroot, 'sandbox/feature-lpc-mp/') 16 | if not os.path.isdir(feature_path): os.mkdir(feature_path) 17 | 18 | def audio_lpc(wav_file, feature_file): 19 | (rate, sig) = wav.read(wav_file) 20 | print(rate) 21 | 22 | videorate = 30 23 | nw = int(rate/videorate) # time 24 | inc = int(nw/2) # 2* overlap 25 | # assert type(inc) == int 26 | winfunc = np.hanning(nw) 27 | 28 | def enframe(signal, nw, inc, winfunc): 29 | """turn audio signal to frame. 
30 | parameters: 31 | signal: original audio signal 32 | nw: length of each frame(audio samples = audio sample rate * time interval) 33 | inc: intervals of consecutive frames 34 | """ 35 | signal_length=len(signal) #length of audio signal 36 | if signal_length<=nw: 37 | nf=1 38 | else: #otherwise, compute the length of audio frame 39 | nf=int(np.ceil((1.0*signal_length-nw+inc)/inc)) 40 | pad_length=int((nf-1)*inc+nw) #length of flatten all frames 41 | zeros=np.zeros((pad_length-signal_length,)) # 42 | pad_signal=np.concatenate((signal,zeros)) #after padding 43 | #nf*nw matrix 44 | indices=np.tile(np.arange(0,nw),(nf,1))+np.tile(np.arange(0,nf*inc,inc),(nw,1)).T 45 | indices=np.array(indices,dtype=np.int32) #turn indices to frames 46 | frames=pad_signal[indices] #get frames 47 | win=np.tile(winfunc,(nf,1)) #window function 48 | return frames*win #return frame matrix 49 | 50 | sig = signal.detrend(sig, type= 'constant') 51 | 52 | frame = enframe(sig, nw, inc, winfunc) 53 | assert len(frame) >= 64 54 | 55 | win_size = 64 56 | K = 32 # number of coefficients 57 | win_count = int((len(frame)-win_size)/2)+1 # len(frame) or frame.shape[0] 58 | lpc_feature = np.zeros(shape=(frame.shape[0], K)) 59 | output = np.zeros(shape=(win_count, win_size, K)) 60 | # print(output.shape, mfcc_feature.shape) 61 | 62 | # pbar = tqdm(total=len(frame)) 63 | pool = Pool(_num_worker) 64 | 65 | # for i in range(len(frame)): 66 | # filt = lpc.nautocor(frame[i], order=K) 67 | # lpc_feature[i] = filt.numerator[1:] 68 | # 69 | # if i > 0 and i % 10 == 0: 70 | # pbar.update(10) 71 | filt_coef = pool.map(lpc_K, tqdm(frame)) 72 | lpc_feature[:] = filt_coef 73 | 74 | # pbar.close() 75 | print(type(lpc_feature), lpc_feature.shape) 76 | 77 | for win in range(win_count): 78 | output[win] = lpc_feature[2*win : 2*win+win_size] 79 | 80 | np.save(feature_file, output) 81 | print("LPC extraction finished {}".format(output.shape)) 82 | 83 | def lpc_K(frame, order=32): 84 | filt = lpc.nautocor(frame, order=order) 85 | return filt.numerator[1:] # List of coefficients 86 | 87 | def main(): 88 | wav_files = os.listdir(wav_path) 89 | print(wav_files) 90 | for wav_file in wav_files: 91 | feature_file = wav_file.split('.')[0] + '-lpc' + '.npy' 92 | print('-------------------\n', feature_file) 93 | audio_lpc(wav_path+wav_file, feature_path+feature_file) 94 | 95 | if __name__ == '__main__': 96 | main() 97 | -------------------------------------------------------------------------------- /misc/audio_mfcc.py: -------------------------------------------------------------------------------- 1 | import python_speech_features as psf 2 | import numpy as np 3 | # import scipy.signal as signal 4 | import scipy.io.wavfile as wav 5 | import os 6 | from tqdm import tqdm 7 | 8 | 9 | # data path 10 | dataroot = '/home/liyachun/data/audio2bs' 11 | wav_path = os.path.join(dataroot, 'test/wav/') 12 | feature_path = os.path.join(dataroot, 'test/feature-mfcc39/') 13 | 14 | if not os.path.isdir(feature_path): os.mkdir(feature_path) 15 | 16 | def audio_mfcc(wav_file, feature_file=None): 17 | # load wav 18 | (rate, sig) = wav.read(wav_file) 19 | print(rate) 20 | 21 | # parameterss 22 | videorate = 30 23 | winlen = 1./videorate # time 24 | winstep = 0.5/videorate # 2* overlap 25 | # numcep = 32 # number of cepstrum to return, 0--power 1:31--features, nfilt caide 26 | numcep = 13 # typical value 27 | winfunc = np.hanning 28 | 29 | mfcc = psf.mfcc(sig, rate, winlen=winlen, winstep=winstep, 30 | numcep=numcep, nfilt=numcep*2, nfft=int(rate/videorate), winfunc = 
winfunc) 31 | # print(mfcc_feature[0:5]) 32 | mfcc_delta = psf.base.delta(mfcc, 2) 33 | mfcc_delta2 = psf.base.delta(mfcc_delta, 2) 34 | 35 | mfcc_all = np.concatenate((mfcc, mfcc_delta, mfcc_delta2), axis=1) 36 | print(type(mfcc_all), mfcc_all.shape) 37 | 38 | win_size = 64 39 | win_count = int((len(mfcc_all)-win_size)/2)+1 40 | output = np.zeros(shape=(win_count, win_size, numcep*3)) 41 | # print(output.shape, mfcc_feature.shape) 42 | for win in tqdm(range(win_count)): 43 | output[win] = mfcc_all[2*win : 2*win+win_size] 44 | 45 | if feature_file: 46 | np.save(feature_file, output) 47 | 48 | print("MPCC extraction finished {}".format(output.shape)) 49 | 50 | return output 51 | 52 | 53 | def main(): 54 | wav_files = os.listdir(wav_path) 55 | print(wav_files) 56 | for wav_file in wav_files: 57 | feature_file = wav_file.split('.')[0] + '-mfcc39.npy' 58 | print('-------------------\n', feature_file) 59 | audio_mfcc(os.path.join(wav_path, wav_file), 60 | os.path.join(feature_path, feature_file)) 61 | 62 | if __name__ == '__main__': 63 | main() 64 | -------------------------------------------------------------------------------- /misc/combine.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | dataroot = '/home/liyachun/data/audio2bs' 5 | feature_path = os.path.join(dataroot, 'sandbox/feature-lpc/') 6 | target_path = os.path.join(dataroot, 'sandbox/blendshape/') 7 | 8 | test_ind = ['06', '17'] 9 | 10 | def combine(feature_path, target_path): 11 | feature_files = sorted(os.listdir(feature_path)) 12 | print('feature: ', feature_files) 13 | 14 | blendshape_files = sorted(os.listdir(target_path)) 15 | print('bs: ', blendshape_files) 16 | 17 | for i in range(len(feature_files)): 18 | # skip test files 19 | if blendshape_files[i].split('.')[0].split('_')[1] in test_ind: 20 | continue 21 | 22 | if i == 0: 23 | feature = np.load(feature_path+feature_files[i]) 24 | feature_combine_file = feature_files[i].split('_')[0] + '.npy' 25 | 26 | # blendshape is shorter, need cut 27 | blendshape = np.loadtxt(target_path+blendshape_files[i]) 28 | blendshape = cut(feature, blendshape) 29 | 30 | blendshape_combine_file = blendshape_files[i].split('_')[0] + '.txt' 31 | 32 | else: 33 | feature_temp = np.load(feature_path+feature_files[i]) 34 | feature = np.concatenate((feature, feature_temp), 0) 35 | 36 | # blendshape is shorter 37 | blendshape_temp = np.loadtxt(target_path+blendshape_files[i]) 38 | blendshape_temp = cut(feature_temp, blendshape_temp) 39 | 40 | blendshape = np.concatenate((blendshape, blendshape_temp), 0) 41 | 42 | print(i, blendshape_files[i], feature.shape, blendshape.shape) 43 | 44 | np.save(os.path.join(feature_path, feature_combine_file), feature) 45 | np.savetxt(os.path.join(target_path, blendshape_combine_file), blendshape, fmt='%.8f') 46 | 47 | def cut(wav_feature, blendshape_target): 48 | n_audioframe, n_videoframe = len(wav_feature), len(blendshape_target) 49 | print('--------\n', 'Current dataset -- n_audioframe: {}, n_videoframe:{}'.format(n_audioframe, n_videoframe)) 50 | assert n_videoframe - n_audioframe == 32 51 | start_videoframe = 16 52 | blendshape_target = blendshape_target[start_videoframe : start_videoframe+n_audioframe] 53 | 54 | return blendshape_target 55 | 56 | def main(): 57 | combine(feature_path, target_path) 58 | 59 | if __name__ == '__main__': 60 | main() 61 | -------------------------------------------------------------------------------- /models.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | # n_blendshape = 51 5 | 6 | # audio2blendshape model 7 | class A2BNet(nn.Module): 8 | 9 | def __init__(self, base_model='lstm'): 10 | super(A2BNet, self).__init__() 11 | 12 | self.base_model = base_model 13 | print('Initialising with %s' % (base_model)) 14 | 15 | 16 | def _prepare_model(self): 17 | 18 | if self.base_model == 'lstm': 19 | self.base_model = FullyLSTM() 20 | 21 | elif self.base_model == 'nvidia': 22 | self.base_model = NvidiaNet() 23 | 24 | # lstm concat before nvidia network 25 | elif self.base_model == 'lstm-nvidia': 26 | self.base_model = LSTMNvidiaNet() 27 | 28 | def forward(self, input): 29 | output = self.base_model(input) 30 | 31 | return output 32 | 33 | # LSTM model 34 | class FullyLSTM(nn.Module): 35 | 36 | def __init__(self, num_features=32, num_blendshapes=51): 37 | super(FullyLSTM, self).__init__() 38 | self.rnn = nn.LSTM(input_size=num_features, hidden_size=256, num_layers=2, 39 | batch_first=True, dropout=0.5, bidirectional=True) 40 | self.out = nn.Linear(256*2, num_blendshapes) 41 | 42 | def forward(self, input): 43 | # self.rnn.flatten_parameters() 44 | output, _ = self.rnn(input) 45 | output = self.out(output[:, -1, :]) 46 | return output 47 | 48 | # LSTM-AE model 49 | class LSTMAE(nn.Module): 50 | 51 | def __init__(self, num_timeseries=32, num_audiof=128, num_bsf=2, num_blendshapes=51): 52 | super(LSTMAE, self).__init__() 53 | ## encoder 54 | # audio part with LSTM 55 | self.rnn = nn.LSTM(input_size=num_features, hidden_size=256, num_layers=2, 56 | batch_first=True, dropout=0.5, bidirectional=True) 57 | self.audio_fc = nn.Linear(256*2, num_audiof) 58 | 59 | # blendshape part with fc 60 | self.bs_fc = nn.Sequential( 61 | nn.Linear(num_blendshapes, 24), 62 | nn.ReLU(True), 63 | 64 | nn.Linear(24, num_bsf), 65 | nn.Sigmoid() 66 | ) 67 | 68 | ## decoder? 
69 | self.decoder_fc = nn.Sequential( 70 | nn.Linear(num_audiof+num_bsf, 64), 71 | nn.ReLU(True), 72 | nn.Linear(64, num_blendshapes), 73 | nn.Sigmoid() 74 | ) 75 | 76 | def reparameterize(self, mu, logvar): 77 | if self.training: 78 | std = torch.exp(0.5*logvar) 79 | # eps = torch.randn_like(std) 80 | # eps = torch.randn(std.size(), dtype=std.dtype, layout=std.layout, device=std.device) 81 | eps = torch.randn(std.size()) 82 | return eps.mul(std).add_(mu) 83 | else: 84 | return mu 85 | 86 | def decode(self, z): 87 | return self.decoder_fc(z) 88 | 89 | 90 | def forward(self, audio, blendshape): 91 | # encoder 92 | audio_rnn, _ = self.rnn(audio) 93 | audio_fc = self.audio_fc1(audio_rnn[:, -1, :]) 94 | 95 | bs_fc = self.bs_fc(blendshape) 96 | 97 | z = torch.cat((audio_fc, bs_fc), dim=1) 98 | output = self.decode(z) 99 | 100 | return audio_fc, bs_fc, output 101 | 102 | 103 | # nvidia model 104 | class NvidiaNet(nn.Module): 105 | 106 | def __init__(self, num_blendshapes=51): 107 | super(NvidiaNet, self).__init__() 108 | # formant analysis network 109 | self.num_blendshapes = num_blendshapes 110 | self.formant = nn.Sequential( 111 | nn.Conv2d(1, 72, kernel_size=(1,3), stride=(1,2), padding=(0,1)), 112 | nn.ReLU(), 113 | nn.Conv2d(72, 108, kernel_size=(1,3), stride=(1,2), padding=(0,1)), 114 | nn.ReLU(), 115 | nn.Conv2d(108, 162, kernel_size=(1,3), stride=(1,2), padding=(0,1)), 116 | nn.ReLU(), 117 | nn.Conv2d(162, 243, kernel_size=(1,3), stride=(1,2), padding=(0,1)), 118 | nn.ReLU(), 119 | nn.Conv2d(243, 256, kernel_size=(1,3), stride=(1,2), padding=(0,1)), 120 | nn.ReLU() 121 | ) 122 | 123 | # articulation network 124 | self.articulation = nn.Sequential( 125 | nn.Conv2d(256, 256, kernel_size=(3,1), stride=(2,1), padding=(1,0)), 126 | nn.ReLU(), 127 | nn.Conv2d(256, 256, kernel_size=(3,1), stride=(2,1), padding=(1,0)), 128 | nn.ReLU(), 129 | nn.Conv2d(256, 256, kernel_size=(3,1), stride=(2,1), padding=(1,0)), 130 | nn.ReLU(), 131 | nn.Conv2d(256, 256, kernel_size=(3,1), stride=(2,1), padding=(1,0)), 132 | nn.ReLU(), 133 | nn.Conv2d(256, 256, kernel_size=(4,1), stride=(4,1)), 134 | nn.ReLU() 135 | ) 136 | 137 | # output network 138 | self.output = nn.Sequential( 139 | nn.Linear(256, 150), 140 | nn.ReLU(), 141 | nn.Dropout(p=0.5), 142 | nn.Linear(150, self.num_blendshapes) 143 | ) 144 | 145 | def forward(self, x): 146 | x = torch.unsqueeze(x, dim=1) # (-1, channel, height, width) 147 | # or x = x.view(-1, 1, 64, 32) 148 | 149 | # convolution 150 | x = self.formant(x) 151 | x = self.articulation(x) 152 | 153 | # fully connected 154 | x = x.view(-1, num_flat_features(x)) 155 | x = self.output(x) 156 | 157 | return x 158 | 159 | class LSTMNvidiaNet(nn.Module): 160 | 161 | def __init__(self, num_blendshapes=51, num_emotions=16): 162 | super(LSTMNvidiaNet, self).__init__() 163 | 164 | self.num_blendshapes = num_blendshapes 165 | self.num_emotions = num_emotions 166 | 167 | # emotion network with LSTM 168 | self.emotion = nn.LSTM(input_size=32, hidden_size=128, num_layers=1, 169 | batch_first=True, dropout=0.5, bidirectional=True) 170 | self.dense = nn.Sequential( 171 | nn.Linear(128*2, 150), 172 | nn.ReLU(), 173 | nn.Linear(150, self.num_emotions) 174 | ) 175 | 176 | 177 | # formant analysis network 178 | self.formant = nn.Sequential( 179 | nn.Conv2d(1, 72, kernel_size=(1,3), stride=(1,2), padding=(0,1)), 180 | nn.ReLU(), 181 | nn.Conv2d(72, 108, kernel_size=(1,3), stride=(1,2), padding=(0,1)), 182 | nn.ReLU(), 183 | nn.Conv2d(108, 162, kernel_size=(1,3), stride=(1,2), padding=(0,1)), 184 | nn.ReLU(), 
185 | nn.Conv2d(162, 243, kernel_size=(1,3), stride=(1,2), padding=(0,1)), 186 | nn.ReLU(), 187 | nn.Conv2d(243, 256, kernel_size=(1,3), stride=(1,2), padding=(0,1)), 188 | nn.ReLU() 189 | ) 190 | 191 | # articulation network 192 | self.conv1 = nn.Conv2d(256, 256, kernel_size=(3,1), stride=(2,1), padding=(1,0)) 193 | self.conv2 = nn.Conv2d(256+self.num_emotions, 256, kernel_size=(3,1), stride=(2,1), padding=(1,0)) 194 | self.conv5 = nn.Conv2d(256+self.num_emotions, 256, kernel_size=(4,1), stride=(4,1)) 195 | self.relu = nn.ReLU() 196 | 197 | # output network 198 | self.output = nn.Sequential( 199 | nn.Linear(256+self.num_emotions, 150), 200 | nn.ReLU(), 201 | nn.Dropout(p=0.5), 202 | nn.Linear(150, self.num_blendshapes) 203 | ) 204 | 205 | def forward(self, x): 206 | # extract emotion state 207 | e_state, _ = self.emotion(x[:, ::2]) # input features are 2* overlapping 208 | e_state = self.dense(e_state[:, -1, :]) # last 209 | e_state = e_state.view(-1, self.num_emotions, 1, 1) 210 | 211 | x = torch.unsqueeze(x, dim=1) 212 | # convolution 213 | x = self.formant(x) 214 | 215 | # conv+concat 216 | x = self.relu(self.conv1(x)) 217 | x = torch.cat((x, e_state.repeat(1, 1, 32, 1)), 1) 218 | 219 | x = self.relu(self.conv2(x)) 220 | x = torch.cat((x, e_state.repeat(1, 1, 16, 1)), 1) 221 | 222 | x = self.relu(self.conv2(x)) 223 | x = torch.cat((x, e_state.repeat(1, 1, 8, 1)), 1) 224 | 225 | x = self.relu(self.conv2(x)) 226 | x = torch.cat((x, e_state.repeat(1, 1, 4, 1)), 1) 227 | 228 | x = self.relu(self.conv5(x)) 229 | x = torch.cat((x, e_state), 1) 230 | 231 | # fully connected 232 | x = x.view(-1, num_flat_features(x)) 233 | x = self.output(x) 234 | 235 | return x 236 | 237 | 238 | 239 | 240 | def num_flat_features(x): 241 | size = x.size()[1:] # all dimensions except the batch dimension 242 | num_features = 1 243 | for s in size: 244 | num_features *= s 245 | return num_features 246 | -------------------------------------------------------------------------------- /models_testae.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.autograd import Variable 4 | 5 | ## LSTM-AE model 6 | # concat 7 | class LSTMAEsigm(nn.Module): 8 | 9 | def __init__(self, num_features=39, num_audiof=128, num_bsf=2, num_blendshapes=51, is_concat=True): 10 | super(LSTMAEsigm, self).__init__() 11 | # assign 12 | self.is_concat = is_concat 13 | 14 | ## encoder 15 | # audio part with LSTM 16 | self.rnn = nn.LSTM(input_size=num_features, hidden_size=256, num_layers=2, 17 | batch_first=True, dropout=0.5, bidirectional=True) 18 | self.audio_fc = nn.Linear(256*2, num_audiof) 19 | 20 | # blendshape part with fc 21 | self.bs_fc = nn.Sequential( 22 | nn.Linear(num_blendshapes, 24), 23 | nn.ReLU(True), 24 | 25 | nn.Linear(24, num_bsf), 26 | nn.Sigmoid() # constrain to 0~1, control variable 27 | ) 28 | 29 | ## decoder? 
30 | if self.is_concat: 31 | self.decoder_fc = nn.Sequential( 32 | nn.Linear(num_audiof+num_bsf, 64), 33 | nn.ReLU(True), 34 | nn.Linear(64, num_blendshapes), 35 | nn.Sigmoid() 36 | ) 37 | else: 38 | self.bilinear = nn.Bilinear(num_bsf, num_audiof, num_audiof) 39 | self.decoder_fc = nn.Sequential( 40 | nn.Linear(num_audiof, 64), 41 | nn.ReLU(True), 42 | nn.Linear(64, num_blendshapes), 43 | nn.Sigmoid() 44 | ) 45 | 46 | def fuse(self, audio_z, bs_z): 47 | # concat or bilinear 48 | if self.is_concat: 49 | return torch.cat((audio_z, bs_z), dim=1) 50 | else: 51 | return self.bilinear(bs_z, audio_z) 52 | 53 | def decode(self, z): 54 | return self.decoder_fc(z) 55 | 56 | def decode_audio(self, audio, bs_z): 57 | audio_rnn, _ = self.rnn(audio) 58 | audio_z = self.audio_fc(audio_rnn[:, -1, :]) 59 | bs_z = bs_z.repeat(audio_z.size()[0], 1) # to batch size 60 | 61 | z = self.fuse(audio_z, bs_z) 62 | return self.decode(z) 63 | 64 | def forward(self, audio, blendshape): 65 | # encode 66 | audio_rnn, _ = self.rnn(audio) 67 | audio_z = self.audio_fc(audio_rnn[:, -1, :]) 68 | 69 | bs_z = self.bs_fc(blendshape) 70 | 71 | # decode 72 | z = self.fuse(audio_z, bs_z) 73 | output = self.decode(z) 74 | 75 | return audio_z, bs_z, output 76 | 77 | class LSTMAEdist(nn.Module): 78 | 79 | def __init__(self, num_features=39, num_audiof=128, num_bsf=2, num_blendshapes=51, is_concat=True): 80 | super(LSTMAEdist, self).__init__() 81 | # assign 82 | self.is_concat = is_concat 83 | 84 | ## encoder 85 | # audio part with LSTM 86 | self.rnn = nn.LSTM(input_size=num_features, hidden_size=256, num_layers=2, 87 | batch_first=True, dropout=0.5, bidirectional=True) 88 | self.audio_fc = nn.Linear(256*2, num_audiof) 89 | 90 | # blendshape part with fc 91 | self.bs_fc1 = nn.Sequential( 92 | nn.Linear(num_blendshapes, 24), 93 | nn.ReLU(True), 94 | ) 95 | self.bs_fc21 = nn.Linear(24, num_bsf) 96 | self.bs_fc22 = nn.Linear(24, num_bsf) 97 | 98 | ## decoder? 
99 | if self.is_concat: 100 | self.decoder_fc = nn.Sequential( 101 | nn.Linear(num_audiof+num_bsf, 64), 102 | nn.ReLU(True), 103 | nn.Linear(64, num_blendshapes), 104 | nn.Sigmoid() 105 | ) 106 | else: 107 | self.bilinear = nn.Bilinear(num_bsf, num_audiof, num_audiof) 108 | self.decoder_fc = nn.Sequential( 109 | nn.Linear(num_audiof, 64), 110 | nn.ReLU(True), 111 | nn.Linear(64, num_blendshapes), 112 | nn.Sigmoid() 113 | ) 114 | 115 | def encode(self, audio, blendshape): 116 | audio_rnn, _ = self.rnn(audio) 117 | audio_z = self.audio_fc(audio_rnn[:, -1, :]) 118 | 119 | bs_h1 = self.bs_fc1(blendshape) 120 | return audio_z, self.bs_fc21(bs_h1), self.bs_fc22(bs_h1) 121 | 122 | def fuse(self, audio_z, bs_z): 123 | # concat or bilinear 124 | if self.is_concat: 125 | return torch.cat((audio_z, bs_z), dim=1) 126 | else: 127 | return self.bilinear(bs_z, audio_z) 128 | 129 | def reparameterize(self, mu, logvar): 130 | if self.training: 131 | std = torch.exp(0.5*logvar) 132 | # eps = torch.randn_like(std) 133 | # eps = torch.randn(std.size(), dtype=std.dtype, layout=std.layout, device=std.device) 134 | eps = Variable(torch.randn(std.size())).cuda() 135 | return eps.mul(std).add_(mu) 136 | else: 137 | return mu 138 | 139 | def decode(self, z): 140 | return self.decoder_fc(z) 141 | 142 | def decode_audio(self, audio, bs_z): 143 | audio_rnn, _ = self.rnn(audio) 144 | audio_z = self.audio_fc(audio_rnn[:, -1, :]) 145 | bs_z = bs_z.repeat(audio_z.size()[0], 1) 146 | 147 | z = self.fuse(audio_z, bs_z) 148 | return self.decode(z) 149 | 150 | def forward(self, audio, blendshape): 151 | # encode 152 | audio_z, bs_mu, bs_logvar = self.encode(audio, blendshape) 153 | bs_z = self.reparameterize(bs_mu, bs_logvar) 154 | 155 | # decode 156 | z = self.fuse(audio_z, bs_z) 157 | output = self.decode(z) 158 | 159 | return audio_z, bs_z, output, bs_mu, bs_logvar 160 | 161 | class LSTMAE2dist(nn.Module): 162 | 163 | def __init__(self, num_features=39, num_audiof=128, num_bsf=2, num_blendshapes=51, is_concat=True): 164 | super(LSTMAE2dist, self).__init__() 165 | # assign 166 | self.is_concat = is_concat 167 | self.num_audiof = num_audiof 168 | self.num_bsf = num_bsf 169 | 170 | ## encoder 171 | # audio part with LSTM 172 | self.rnn = nn.LSTM(input_size=num_features, hidden_size=256, num_layers=2, 173 | batch_first=True, dropout=0.5, bidirectional=True) 174 | self.audio_fc11 = nn.Linear(256*2, num_audiof) 175 | self.audio_fc12 = nn.Linear(256*2, num_audiof) 176 | 177 | # blendshape part with fc 178 | self.bs_fc1 = nn.Sequential( 179 | nn.Linear(num_blendshapes, 24), 180 | nn.ReLU(True), 181 | ) 182 | self.bs_fc21 = nn.Linear(24, num_bsf) 183 | self.bs_fc22 = nn.Linear(24, num_bsf) 184 | 185 | ## decoder? 
186 | if self.is_concat: 187 | self.decoder_fc = nn.Sequential( 188 | nn.Linear(num_audiof+num_bsf, 64), 189 | nn.ReLU(True), 190 | nn.Linear(64, num_blendshapes), 191 | nn.Sigmoid() 192 | ) 193 | else: 194 | # self.bilinear = nn.Bilinear(num_bsf, num_audiof, num_audiof) 195 | # self.decoder_fc = nn.Sequential( 196 | # nn.Linear(num_audiof, 64), 197 | # nn.ReLU(True), 198 | # nn.Linear(64, num_blendshapes), 199 | # nn.Sigmoid() 200 | # ) 201 | print('NO bilinear combination in 2dist model') 202 | 203 | def encode(self, audio, blendshape): 204 | 205 | audio_rnn, _ = self.rnn(audio) 206 | audio_h = audio_rnn[:, -1, :] 207 | 208 | bs_h1 = self.bs_fc1(blendshape) 209 | 210 | return self.audio_fc11(audio_h), self.audio_fc12(audio_h), self.bs_fc21(bs_h1), self.bs_fc22(bs_h1) 211 | 212 | # def fuse(self, audio_z, bs_z): 213 | # # concat or bilinear 214 | # if self.is_concat: 215 | # return torch.cat((audio_z, bs_z), dim=1) 216 | # else: 217 | # return self.bilinear(bs_z, audio_z) 218 | 219 | def reparameterize(self, mu, logvar): 220 | if self.training: 221 | std = torch.exp(0.5*logvar) 222 | # eps = torch.randn_like(std) 223 | # eps = torch.randn(std.size(), dtype=std.dtype, layout=std.layout, device=std.device) 224 | eps = Variable(torch.randn(std.size())).cuda() 225 | return eps.mul(std).add_(mu) 226 | else: 227 | return mu 228 | 229 | def decode(self, z): 230 | return self.decoder_fc(z) 231 | 232 | def decode_audio(self, audio, bs_z): 233 | audio_rnn, _ = self.rnn(audio) 234 | audio_h = audio_rnn[:, -1, :] 235 | 236 | audio_mu = self.audio_fc11(audio_h) 237 | audio_logvar = self.audio_fc12(audio_h) 238 | 239 | audio_z = self.reparameterize(audio_mu, audio_logvar) 240 | 241 | bs_z = bs_z.repeat(audio_z.size()[0], 1) 242 | 243 | z = torch.cat((audio_z, bs_z), dim=1) 244 | return self.decode(z) 245 | 246 | def forward(self, audio, blendshape): 247 | # encode 248 | audio_mu, audio_logvar, bs_mu, bs_logvar = self.encode(audio, blendshape) 249 | mu = torch.cat((audio_mu, bs_mu), dim=1) 250 | logvar = torch.cat((audio_logvar, bs_logvar), dim=1) 251 | 252 | z = self.reparameterize(mu, logvar) 253 | 254 | # decode 255 | # z = self.fuse(audio_z, bs_z) 256 | output = self.decode(z) 257 | 258 | return z[:, :self.num_audiof], z[:, self.num_audiof:], output, mu, logvar 259 | -------------------------------------------------------------------------------- /synthesis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import argparse 4 | import os 5 | import numpy as np 6 | from torch.autograd import Variable 7 | from torch.utils.data import DataLoader, TensorDataset 8 | from scipy.signal import savgol_filter 9 | from misc.audio_mfcc import audio_mfcc 10 | from misc.audio_lpc import audio_lpc 11 | from models import FullyLSTM, NvidiaNet 12 | from models_testae import * 13 | 14 | # parser arguments 15 | parser = argparse.ArgumentParser(description='Synthesize wave to blendshape') 16 | parser.add_argument('wav', type=str, help='wav to synthesize') 17 | parser.add_argument('--smooth', type=bool, default=True) 18 | parser.add_argument('--pad', type=bool, default=True) 19 | parser.add_argument('--control', type=bool, default=False) 20 | 21 | args = parser.parse_args() 22 | print(args) 23 | 24 | n_blendshape = 51 25 | ckp = './checkpoint-lstmae-2distconcat_kl0/checkpoint-epoch500.pth.tar' 26 | result_path = './synthesis' 27 | result_file = 'test'+args.wav.split('/')[-1].split('.')[0]+'-lstmae-2distconcat_kl0.txt' 28 | 29 | def main(): 30 | global 
result_file 31 | 32 | start_time = time.time() 33 | ## process audio 34 | feature = torch.from_numpy(audio_mfcc(args.wav)) 35 | # print('Feature extracted ', feature.shape) 36 | target = torch.from_numpy(np.zeros(feature.shape[0])) 37 | 38 | ## load model 39 | model = LSTMAE2dist(is_concat=True) 40 | 41 | # restore checkpoint model 42 | print("=> loading checkpoint '{}'".format(ckp)) 43 | checkpoint = torch.load(ckp) 44 | print("model epoch {} loss: {}".format(checkpoint['epoch'], checkpoint['eval_loss'])) 45 | 46 | model.load_state_dict(checkpoint['state_dict']) 47 | 48 | if torch.cuda.is_available(): 49 | model = model.cuda() 50 | 51 | # evaluation for audio feature 52 | model.eval() 53 | 54 | ## build dataset 55 | test_loader = DataLoader(TensorDataset(feature, target), 56 | batch_size=100, shuffle=False, num_workers=2) 57 | 58 | for i, (input, target) in enumerate(test_loader): 59 | # target = target.cuda(async=True) 60 | input_var = Variable(input.float(), volatile=True).cuda() 61 | # target_var = Variable(target.float(), volatile=True) 62 | 63 | # compute output 64 | if args.control: 65 | bs_z = Variable(torch.Tensor([0.5, 0.5]), volatile=True).cuda() # control variable 66 | # print('Control with', bs_z.data[0]) 67 | output = model.decode_audio(input_var, bs_z) 68 | else: 69 | output = model(input_var) 70 | 71 | if i == 0: 72 | output_cat = output.data 73 | else: 74 | output_cat = torch.cat((output_cat, output.data), 0) 75 | # print(type(output_cat.cpu().numpy()), output_cat.cpu().numpy().shape) 76 | 77 | # convert back *100 78 | output_cat = output_cat.cpu().numpy()*100.0 79 | 80 | if args.smooth: 81 | #smooth3--savgol_filter 82 | win = 9; polyorder = 3 83 | for i in range(n_blendshape): 84 | power = output_cat[:,i] 85 | power_smooth = savgol_filter(power, win, polyorder, mode='nearest') 86 | output_cat[:, i] = power_smooth 87 | result_file = 'smooth-' + result_file 88 | 89 | # pad blendshape 90 | if args.pad: 91 | output_cat = pad_blendshape(output_cat) 92 | result_file = 'pad-' + result_file 93 | 94 | # count time for synthesis 95 | past_time = time.time() - start_time 96 | print("Synthesis finished in {:.4f} sec! 
Saved in {}".format(past_time, result_file)) 97 | 98 | with open(os.path.join(result_path, result_file), 'wb') as f: 99 | np.savetxt(f, output_cat, fmt='%.6f') 100 | 101 | def pad_blendshape(blendshape): 102 | return np.pad(blendshape, [(16, 16), (0, 0)], mode='constant', constant_values=0.0) 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | def convert(s, numRows): 2 | """ 3 | :type s: str 4 | :type numRows: int 5 | :rtype: str 6 | """ 7 | 8 | 9 | 10 | if __name__ == '__main__': 11 | pass 12 | s = "PAHNAPLSIIGYIR" 13 | convert(s,3) 14 | -------------------------------------------------------------------------------- /test_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | 4 | import os 5 | import time 6 | import numpy as np 7 | import argparse 8 | from scipy.signal import savgol_filter 9 | 10 | from dataset import BlendshapeDataset 11 | from models import A2BNet, NvidiaNet, LSTMNvidiaNet, FullyLSTM 12 | 13 | # options 14 | parser = argparse.ArgumentParser(description="PyTorch testing of LSTM") 15 | parser.add_argument('ckp', type=str) 16 | parser.add_argument('--smooth', type=bool, default=False) 17 | parser.add_argument('--pad', type=bool, default=False) 18 | parser.add_argument('--epoch', type=int, default=None) 19 | parser.add_argument('--net', type=str, default='lstm') 20 | 21 | args = parser.parse_args() 22 | 23 | # parameters 24 | n_blendshape = 51 25 | batch_size = 100 26 | 27 | # data path 28 | dataroot = '/home/liyachun/data/audio2bs' 29 | data_path = os.path.join(dataroot, 'test') 30 | # data_path = './data/test/' 31 | checkpoint_path = './checkpoint-'+args.net+'-mfcc39/' 32 | result_path = './results/' 33 | 34 | result_file = 'test0201_06-0201train-'+args.net+'-mfcc39.txt' 35 | 36 | if args.epoch != None: 37 | ckp = 'checkpoint-epoch'+str(args.epoch)+'.pth.tar' 38 | result_file = str(args.epoch)+'-'+result_file 39 | else: 40 | ckp = args.ckp+'.pth.tar' 41 | 42 | def pad_blendshape(blendshape): 43 | return np.pad(blendshape, [(16, 16), (0, 0)], mode='constant', constant_values=0.0) 44 | 45 | 46 | model = FullyLSTM(num_features=39) 47 | 48 | # restore checkpoint model 49 | print("=> loading checkpoint '{}'".format(ckp)) 50 | checkpoint = torch.load(os.path.join(checkpoint_path, ckp)) 51 | print("model epoch {} loss: {}".format(checkpoint['epoch'], checkpoint['eval_loss'])) 52 | 53 | model.load_state_dict(checkpoint['state_dict']) 54 | 55 | # load data 56 | val_loader = torch.utils.data.DataLoader( 57 | BlendshapeDataset(feature_file=os.path.join(data_path, 'feature/0201_06-1min-mfcc39.npy'), 58 | target_file=os.path.join(data_path, 'blendshape/0201_06-1min.txt')), 59 | batch_size=batch_size, shuffle=False, num_workers=2 60 | ) 61 | 62 | if torch.cuda.is_available(): 63 | model = model.cuda() 64 | 65 | # run test features 66 | model.eval() 67 | 68 | start_time = time.time() 69 | for i, (input, target) in enumerate(val_loader): 70 | target = target.cuda(async=True) 71 | input_var = autograd.Variable(input.float(), volatile=True).cuda() 72 | target_var = autograd.Variable(target.float(), volatile=True) 73 | 74 | # compute output 75 | output = model(input_var) 76 | 77 | if i == 0: 78 | output_cat = output.data 79 | else: 80 | output_cat = torch.cat((output_cat, output.data), 0) 81 | # 
print(type(output_cat.cpu().numpy()), output_cat.cpu().numpy().shape) 82 | 83 | # convert back to the original 0-100 blendshape scale 84 | output_cat = output_cat.cpu().numpy()*100.0 85 | 86 | if args.smooth: 87 | # smooth each blendshape channel with a Savitzky-Golay filter 88 | win = 9; polyorder = 3 89 | for i in range(n_blendshape): 90 | power = output_cat[:,i] 91 | power_smooth = savgol_filter(power, win, polyorder, mode='nearest') 92 | output_cat[:, i] = power_smooth 93 | result_file = 'smooth-' + result_file 94 | 95 | # pad with zeros so the output covers the same frames as the input wav 96 | if args.pad: 97 | output_cat = pad_blendshape(output_cat) 98 | result_file = 'pad-' + result_file 99 | 100 | # count time for testing 101 | past_time = time.time() - start_time 102 | print("Test finished in {:.4f} sec! Saved in {}".format(past_time, result_file)) 103 | 104 | with open(os.path.join(result_path, result_file), 'wb') as f: 105 | np.savetxt(f, output_cat, fmt='%.6f') 106 | --------------------------------------------------------------------------------
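Neither `test_model.py` nor `synthesis.py` reports a numeric error against ground truth; they only write the predicted sequence to disk. Below is a minimal sketch of such a comparison, assuming the prediction was produced by `test_model.py` without `--pad` and that both files are plain-text N×51 matrices on the 0-100 scale; the file names are examples taken from the defaults hard-coded above and should be adjusted to your own run.

```python
import numpy as np

# example file names from test_model.py's defaults; adjust to your own run
pred = np.loadtxt('results/test0201_06-0201train-lstm-mfcc39.txt')
gt = np.loadtxt('/home/liyachun/data/audio2bs/test/blendshape/0201_06-1min.txt')

# without --pad the prediction is 32 frames shorter than the video track;
# align the same way dataset.py does (skip the first 16 ground-truth frames)
if len(gt) != len(pred):
    gt = gt[16:16 + len(pred)]

mse = np.mean((pred - gt) ** 2)
mae = np.mean(np.abs(pred - gt))
print('MSE: {:.4f}, MAE: {:.4f} (blendshape scale 0-100)'.format(mse, mae))
```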