├── LICENSE ├── README.md ├── matlab ├── config.m ├── demo_fdndlp.m ├── fdndlp.m └── lib │ ├── +util │ ├── fig.m │ ├── play.m │ └── plot.m │ ├── stftanalysis.m │ └── stftsynthesis.m ├── python ├── stft.py └── wpe.py └── wav_sample ├── drv_sample_4ch.wav └── sample_4ch.wav /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Teng Xiang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Frequency Domain Variance-normalized Delayed Linear Prediction Algorithm 2 | 3 | ## Introduction 4 | This program implements variance-normalized delayed linear prediction in the time-frequency domain for speech dereverberation, also known as the weighted prediction error (WPE) method. 5 | 6 | 7 | ## Requirements 8 | - MATLAB code 9 | - Signal Processing Toolbox 10 | - Python code 11 | - Python 3.x 12 | - NumPy 13 | - soundfile 14 | - matplotlib (optional) 15 | 16 | ## Run the Demo 17 | - MATLAB code 18 | - Run the script file `demo_fdndlp.m` in MATLAB; the audio sample in `wav_sample` will be used. 19 | - To use your own data, change `filepath` and `sample_name` in `demo_fdndlp.m`. 20 | - The configurations are gathered in `config.m`; change these settings with care. 21 | 22 | - Python code 23 | 24 | - Usage: 25 | ```bash 26 | python wpe.py [-h] [-o OUTPUT] [-m MIC_NUM] [-n OUT_NUM] [-p ORDER] filename 27 | ``` 28 | - To use the default configurations and the given audio sample, run: 29 | ```bash 30 | python wpe.py ../wav_sample/sample_4ch.wav 31 | ``` 32 | 33 | ## Layout 34 | ``` 35 | ./ 36 | +-- matlab/ matlab code files 37 | | +-- lib/ 38 | | | +-- +util/ utility functions 39 | | | |-- stftanalysis.m 40 | | | |-- stftsynthesis.m 41 | | |-- demo_fdndlp.m 42 | | |-- fdndlp.m 43 | | |-- config.m 44 | +-- python/ python code files 45 | | |-- wpe.py 46 | | |-- stft.py 47 | +-- wav_sample/ audio samples 48 | | |-- sample_4ch.wav reverberant speech 49 | | |-- drv_sample_4ch.wav dereverberated speech 50 | |-- README.md 51 | ``` 52 | 53 | 54 | ## Reference 55 | 56 | [WPE speech dereverberation](http://www.kecl.ntt.co.jp/icl/signal/wpe/) 57 | 58 | Nakatani T, Yoshioka T, Kinoshita K, et al.
Speech Dereverberation Based on Variance-Normalized Delayed Linear Prediction[J]. IEEE Transactions on Audio Speech & Language Processing, 2010, 18(7):1717-1731.

--------------------------------------------------------------------------------
/matlab/config.m:
--------------------------------------------------------------------------------
% Default parameters for fdndlp.m. fdndlp RUNs this script when its CFGS
% argument is a file name, so every variable below lands in fdndlp's workspace.
num_mic = 3;            % number of input channels used
num_out = 2;            % number of output channels
K = 512;                % the number of subbands
F = 2;                  % over-sampling rate
N = K / F;              % decimation factor
D1 = 2;                 % subband prediction delay
Lc1 = 30;               % subband prediction order
eps = 1e-4;             % lower bound of rho (normalization factor); shadows the builtin EPS inside fdndlp
max_iterations = 2;     % number of reweighting iterations
--------------------------------------------------------------------------------
/matlab/demo_fdndlp.m:
--------------------------------------------------------------------------------
clc;
clear;
close all;

% *****************************************************
% Set path
% *****************************************************

addpath(genpath('lib'));
output_dir = 'wav_out/';
if ~exist(output_dir, 'dir')
    mkdir(output_dir);
    disp(['mkdir ', output_dir])
end

%******************************************************
% Input and Output Configurations
%******************************************************

filepath = '../wav_sample/';
sample_name = 'sample_4ch.wav';
file_name = [filepath, sample_name];
out_name = [output_dir, ['drv_', sample_name]];

%******************************************************
% Set Parameters
%******************************************************
% The parameters may also be passed to fdndlp as a struct whose field names
% are the parameter names, e.g.:
% cfgs.num_mic = 3;
% cfgs.num_out = 2;
% cfgs.K = 512;                 % the number of subbands
% cfgs.F = 2;                   % over-sampling rate
% cfgs.N = cfgs.K / cfgs.F;     % decimation factor
% cfgs.D1 = 2;                  % subband prediction delay
% cfgs.Lc1 = 30;                % subband prediction order
% cfgs.eps = 1e-4;              % lower bound of rho (normalization factor)
% cfgs.max_iterations = 2;      % number of reweighting iterations

% Name of the configuration script that fdndlp runs to load its parameters.
cfgs = 'config.m';


%******************************************************
% Read Audio Files and Processing
%******************************************************
% sig_multi_mode = 1: one multi-channel wav file.
% sig_multi_mode = 0: one file per channel, named by replacing 'ch1' in
%                     file_name with 'ch<m>'; sig_num_mic must then be set.
sig_multi_mode = 1;
% sig_multi_mode = 0;
% sig_num_mic = 3;
disp('Reading Audio Files:')
if sig_multi_mode
    disp(file_name)
    [x, fs] = audioread(file_name);
else
    x = [];
    for m = 1 : sig_num_mic
        filename1 = strrep(file_name, 'ch1', ['ch',num2str(m)]);
        % Show the file that is actually read for this channel.
        disp(filename1)
        [s, fs] = audioread(filename1);
        x = [x, s];
    end
end

y = fdndlp(x, cfgs);

% ***************************************************
% Output
% ***************************************************

util.fig(x(:,1), fs);
util.fig(y(:,1), fs);
% util.play(x(:,1), fs);
% util.play(y(:,1), fs);
disp(['write to file:',32, out_name])
audiowrite(out_name, y/max(max(abs(y))), fs);

rmpath(genpath('lib'));

--------------------------------------------------------------------------------
/matlab/fdndlp.m:
--------------------------------------------------------------------------------
function y = fdndlp(x, cfgs, varargin)
%
% =============================================================================
%
% This program is an implementation of Variance-Normalized Delayed Linear
% Prediction in time-frequency domain, which is aimed at speech
% dereverberation, known as weighted prediction error (WPE) method.
%
% Inputs:
%   x          time-domain signal, one column per channel
%   cfgs       either the name of a configuration script (executed with RUN)
%              or a struct whose field names are the parameter names below
%   varargin   optional name/value pairs overriding individual parameters
%
% Output:
%   y          dereverberated signal with size(x, 1) rows and num_out
%              columns, divided by the peak magnitude of the synthesized signal
%
% Main parameters:
%   num_mic         the number of input channels used
%   num_out         the number of output channels (must not exceed num_mic)
%   K               the number of subbands
%   F               over-sampling rate
%   N               decimation factor
%   D1              subband prediction delay
%   Lc1             subband prediction order
%   eps             lower bound of normalization factor
%   max_iterations  the number of reweighting iterations
%
% Reference:
% [1] Nakatani T, Yoshioka T, Kinoshita K, et al. Speech Dereverberation
% Based on Variance-Normalized Delayed Linear Prediction[J]. IEEE
% Transactions on Audio Speech & Language Processing, 2010, 18(7):1717-1731.
%
% =============================================================================
% Created by Teng Xiang at 2017-10-14
% Current version: 2018-08-10
% =============================================================================


% =============================================================================
% Load Parameters
% =============================================================================
if ischar(cfgs)
    run(cfgs);
else
    % Create one local variable per struct field.
    varnames = fieldnames(cfgs);
    for ii = 1 : length(varnames)
        eval([varnames{ii}, '= getfield(cfgs, varnames{ii});']);
    end
end

if exist('varargin', 'var')
    for ii = 1 : 2 : length(varargin)
        eval([varargin{ii}, '= varargin{ii+1};'])
    end
end

% Samples run along the rows; LENGTH would return the larger dimension.
len = size(x, 1);

% =============================================================================
% Frequency-domain variance-normalized delayed linear prediction
% =============================================================================
sig_channels = size(x, 2);
if sig_channels > num_mic
    x = x(:,1:num_mic);
    fprintf('Only the first %d channels of input data are used\n\n', num_mic)
elseif sig_channels < num_mic
    error('The channels of input does not match the channel setting');
end
if num_out > num_mic
    error('num_out (%d) must not exceed num_mic (%d)', num_out, num_mic);
end

tic
fprintf('Procssing...')

xk = stftanalysis(x / max(max(abs(x))), K, N);
LEN = size(xk, 1);
dk = zeros(LEN, K, num_out);

% Only bins 1..K/2+1 are processed; the rest are filled by conjugate symmetry.
for k = 1 : K/2 + 1
    % Prepend Lc1 zero frames so every frame has a full prediction history.
    xk_tmp = zeros(LEN+Lc1, num_mic);
    xk_tmp(Lc1+1:end,:) = squeeze(xk(:,k,:));
    xk_tmp = xk_tmp.';
    x_buf = xk_tmp(1:num_out,Lc1+1:end).';
    % Column ii+D1 stacks the Lc1 most recent frames up to frame ii (newest
    % first, all channels), i.e. the history used to predict frame ii+D1.
    X_BUF = zeros(num_mic * Lc1, LEN);
    for ii = 1 : LEN-D1
        xn_D = xk_tmp(:,ii+Lc1:-1:ii+1);
        X_BUF(:,ii+D1) = xn_D(:);
    end
    rho2 = max(mean(abs(x_buf(:,1:num_out)).^2, 2), eps);

    for iter = 1 : max_iterations
        % X_BUF * diag(1./rho2) * X_BUF' without forming the LEN x LEN matrix.
        Phi = X_BUF * bsxfun(@rdivide, X_BUF', rho2);
        p = X_BUF * conj(bsxfun(@rdivide, x_buf, rho2));
        c = pinv(Phi)*p;
        dk(:,k,:) = (x_buf.' - c'*X_BUF).';
        % RESHAPE keeps num_out on the 2nd dimension even when LEN == 1.
        rho2 = max(mean(reshape(abs(dk(:,k,:)).^2, LEN, num_out), 2), eps);
    end
end
dk(:,K/2+2:end,:) = conj(dk(:,K/2:-1:2,:));
y = stftsynthesis(dk, K, N);
y = y(1 : len, :) / max(max(abs(y)));
disp('Done!')
toc

--------------------------------------------------------------------------------
/matlab/lib/+util/fig.m:
--------------------------------------------------------------------------------
function fig(data_in, fs)
% FIG  Plot the spectrogram of a signal in a new figure.
%   util.fig(filename)  reads FILENAME with AUDIOREAD and uses it as the title
%   util.fig(data, fs)  plots DATA sampled at FS Hz
%   Multi-channel input is reduced to its first channel.

if nargin == 1
    [data, fs] = audioread(data_in);
    filename = data_in;
else
    data = data_in;
end
if ~isvector(data)
    % SPECTROGRAM expects a vector; keep only the first channel.
    data = data(:, 1);
end
Fs = 8000;
% 256-sample window and 128-sample overlap at 8 kHz, scaled to fs. ROUND keeps
% both integers when fs is not a multiple of 8 kHz (e.g. 44.1 kHz).
noverlap = round(128 * fs / Fs);
nfft = round(256 * fs / Fs);

figure;
spectrogram(data/max(abs(data)), hamming(nfft),noverlap,nfft,fs,'yaxis')
set(gcf, 'position', [1, 235, 1366, 400]);
set(gca, 'position', [0.05, 0.12, 0.85, 0.8]);
if exist('filename','var')
    title(filename, 'interpreter', 'none');
end

--------------------------------------------------------------------------------
/matlab/lib/+util/play.m:
--------------------------------------------------------------------------------
function player = play(varargin)
% PLAY  Play audio data or a wav file and return the audioplayer object.
%   player = util.play(filename)
%   player = util.play(filename, normalizemode)
%   player = util.play(data, fs)
%   player = util.play(data, fs, normalizemode)
%   When NORMALIZEMODE is nonzero (the default), the data is scaled so that
%   its largest absolute sample is 1.

if ischar(varargin{1})
    [data, fs] = audioread(varargin{1});
    if length(varargin) == 1
        normalizemode = 1;
    else
        normalizemode = varargin{2};
    end
else

    data = varargin{1};
    fs = varargin{2};
    if length(varargin) == 2
        normalizemode = 1;
    else
        normalizemode = varargin{3};
    end
end
if normalizemode
    % Use a scalar peak over all channels: MAX(ABS(DATA)) is a row vector for
    % multi-channel data, which would turn "/" into a least-squares solve.
    data = data / max(abs(data(:)));
end
player = audioplayer(data, fs);
play(player)
--------------------------------------------------------------------------------
/matlab/lib/+util/plot.m:
--------------------------------------------------------------------------------
function handle = plot(data_in, fs)
% PLOT  Plot a waveform against time in seconds in a new figure.
%   util.plot(filename)   reads the wav file FILENAME (path included)
%   util.plot(data, fs)   plots DATA sampled at FS Hz
%   h = util.plot(...)    also returns the handle of the new figure

if nargin == 1
    [data, fs] = audioread(data_in);
else
    data = data_in;
end
if nargout == 1
    handle = figure;
else
    figure;
end
plot((0 : length(data) - 1) / fs, data);
xlim([0 , (length(data) / fs)]);
xlabel('Time(Secs)')
--------------------------------------------------------------------------------
/matlab/lib/stftanalysis.m:
--------------------------------------------------------------------------------

function y = stftanalysis(s, winsize, winshift)
%
%STFTANALYSIS short time Fourier transform analysis
% Decompose the time domain signal into the time-frequency domain signal
%
% y = STFTANALYSIS(s, winsize, winshift)
%
% s is the time domain signal. If s is a matrix, each column is treated as
% a separate channel and analysed independently.
% winsize is the frame length (and FFT size); a Hann window is applied.
% winshift is the hop size between consecutive frames.
% y is the time-frequency signal. If the number of channels equals 1, the
% return value y will be a 2-D matrix (frame_number x window_size).
% If the number of channels is more than 1, y will be a 3-D matrix
% (frame_number x window_size x channel_number).

% =============================================================================
% Created by Teng Xiang at 2018-01-12
% =============================================================================

win = hann(winsize);
channel_num = size(s, 2);
frame_num = ceil((size(s, 1) - winsize)/ winshift) + 1;
% Zero-pad along the sample dimension so the last frame is complete.
% SIZE(s, 1) agrees with frame_num above; LENGTH would return the larger
% dimension.
s = [s; zeros(winshift - mod(size(s, 1) - winsize, winshift), channel_num)];
y = zeros(frame_num, winsize, channel_num);

for l = 1 : frame_num
    index = (l-1) * winshift;
    y(l,:,:) = reshape(...
        fft(bsxfun(@times, s(index + 1:index+winsize, :), win)),...
        1, winsize, channel_num);
end
--------------------------------------------------------------------------------
/matlab/lib/stftsynthesis.m:
--------------------------------------------------------------------------------

function y = stftsynthesis(s, winsize, winshift)
%
%STFTSYNTHESIS short time Fourier transform synthesis
% Synthesize the time domain signal from the time-frequency domain signal
%
% y = STFTSYNTHESIS(s, winsize, winshift)
%
% s is the T-F domain signal, which should be arranged in a 3-D matrix
% whose size is frame_number x subband_number x channel_number. If s
% is a 2-D matrix, it is treated as a frame_number x subband_number x 1
% 3-D matrix. Each frame is inverse-FFT'd and overlap-added with hop
% winshift; no synthesis window is applied.

% =============================================================================
% Created by Teng Xiang at 2018-01-12
% =============================================================================


[frame_num, subband_number, channel_number] = size(s);
if subband_number ~= winsize
    error('The 2rd dimension of input s must agree with the winsize');
end
y = zeros((frame_num - 1) * winshift + winsize, channel_number);

for l = 1 : frame_num
    tmp = reshape(ifft(squeeze(s(l,:,:))), winsize, channel_number);
    index = (l-1) * winshift;
    y(index+1 : index + winsize, :) = y(index+1 : index + winsize, :) + tmp;
end
--------------------------------------------------------------------------------
/python/stft.py:
--------------------------------------------------------------------------------
# Created by Teng Xiang at 2018-08-10
# Current version: 2018-08-10
# https://github.com/helianvine/fdndlp
# =============================================================================

import numpy as np
from numpy.lib import stride_tricks

def stft(data, frame_size=512, overlap=0.75, window=None):
    """Multi-channel short-time Fourier transform.

    Args:
        data: A 2-dimension numpy array with shape=(channels, samples).
        frame_size: An integer number of the length of the frame.
        overlap: A float nonnegative number less than 1 indicating the overlap
            factor between adjacent frames.
        window: Optional analysis window of length ``frame_size``. A Hann
            window (``np.hanning``) is used when it is None.

    Returns:
        A 3-dimension complex numpy array with
        shape=(channels, frames, frame_size // 2 + 1). Input shorter than one
        frame is zero-padded to a single frame.

    Raises:
        ValueError: If ``data`` is not 2-dimensional.
    """
    if data.ndim != 2:
        raise ValueError("data must be a 2-D array of shape (channels, samples)")
    if window is None:
        # "window == None" compares element-wise (and then fails in the `if`)
        # when a numpy array is passed as the window.
        window = np.hanning(frame_size)
    frame_shift = int(frame_size - np.floor(overlap * frame_size))
    # At least one frame, even when the signal is shorter than frame_size
    # (the raw formula can go to zero or negative there).
    cols = max(int(np.ceil((data.shape[1] - frame_size) / frame_shift)) + 1, 1)
    # Zero-pad so that the last frame ends inside the buffer. The float32 zeros
    # keep the original dtype promotion (integer input becomes floating point).
    pad = max((cols - 1) * frame_shift + frame_size - data.shape[1], 0)
    samples = np.concatenate(
        (data, np.zeros((data.shape[0], pad), dtype=np.float32)),
        axis=1)
    # View the padded signal as overlapping frames, then copy so the in-place
    # windowing below does not write through shared (overlapping) memory.
    frames = stride_tricks.as_strided(
        samples,
        shape=(samples.shape[0], cols, frame_size),
        strides=(
            samples.strides[-2],
            samples.strides[-1] * frame_shift,
            samples.strides[-1])).copy()
    frames *= window
    return np.fft.rfft(frames)

def istft(data, frame_size=None, overlap=0.75, window=None):
    """Multi-channel inverse short-time Fourier transform by overlap-add.

    Args:
        data: A 3-dimension numpy array with
            shape=(channels, frames, frequency_bins).
        frame_size: An integer number of the length of the frame. When None,
            it is the inverse real FFT length 2 * (frequency_bins - 1); pass
            it explicitly for odd frame sizes.
        overlap: A float nonnegative number less than 1 indicating the overlap
            factor between adjacent frames.
        window: Unused; accepted for symmetry with ``stft``. Frames are
            overlap-added without a synthesis window or window-sum
            normalization.

    Returns:
        A 2-dimension numpy array with
        shape=(channels, (frames - 1) * frame_shift + frame_size).

    Raises:
        ValueError: If ``data`` is not 3-dimensional.
    """
    if data.ndim != 3:
        raise ValueError(
            "data must be a 3-D array of shape (channels, frames, bins)")
    if frame_size is None:
        real_data = np.fft.irfft(data)
        frame_size = real_data.shape[-1]
    else:
        # An explicit n recovers odd frame sizes, which irfft cannot infer.
        real_data = np.fft.irfft(data, n=frame_size)
    frame_num = data.shape[-2]
    frame_shift = int(frame_size - np.floor(frame_size * overlap))
    length = (frame_num - 1) * frame_shift + frame_size
    output = np.zeros((data.shape[0], length))
    for i in range(frame_num):
        index = i * frame_shift
        output[:, index:index + frame_size] += real_data[:, i]
    return output

def log_spectrum(raw_data, frame_length=512):
    """Log-magnitude spectrogram and phase.

    Args:
        raw_data: 1-D (single channel) or 2-D (channels, samples) array.
        frame_length: Frame length passed to ``stft`` as ``frame_size``.

    Returns:
        (log_data, phase): log10 of the magnitude (floored at 1e-8) relative
        to its smallest value, and the phase angle of the STFT.
    """
    if raw_data.ndim == 1:
        raw_data = np.reshape(raw_data, (1, -1))
    freq_data = stft(raw_data, frame_size=frame_length)
    phase = np.angle(freq_data)
    freq_data = np.abs(freq_data)
    freq_data = np.maximum(freq_data, 1e-8)
    log_data = np.log10(freq_data / freq_data.min())
    return log_data, phase

--------------------------------------------------------------------------------
/python/wpe.py:
--------------------------------------------------------------------------------
# Created by Teng Xiang at 2018-08-10
# Current version: 2018-08-10
# https://github.com/helianvine/fdndlp
# =============================================================================

"""Weighted prediction error (WPE) method for speech dereverberation."""

import stft
import argparse
import time
import os
import numpy as np
import soundfile as sf
from numpy.lib import stride_tricks
# import matplotlib.pyplot as plt

class Configrations():
    """Argument parser for WPE method configurations."""
    def __init__(self):
        self.parser = argparse.ArgumentParser()
        # Register the arguments once here so parse() can be called more than
        # once; argparse rejects option strings that are added twice.
        self.parser.add_argument('filename')
        self.parser.add_argument(
            '-o', '--output', default='drv.wav',
            help='output filename')
        self.parser.add_argument(
            '-m', '--mic_num', type=int, default=3,
            help='number of input channels')
        self.parser.add_argument(
            '-n','--out_num', type=int, default=2,
            help='number of output channels')
        self.parser.add_argument(
            '-p', '--order', type=int, default=30,
            help='predition order')

    def parse(self):
        """Parse the command line and return the argparse namespace."""
        self.cfgs = self.parser.parse_args()
        return self.cfgs


class WpeMethod(object):
    """WPE method for speech dereverberation.

    The weighted prediction error (WPE) method is a speech dereverberation
    algorithm based on multi-channel delayed linear prediction that produces
    multi-channel output.

    Attributes:
        channels: Number of input channels.
        out_num: Number of output channels.
        p: An integer number of the prediction order.
        d: An integer number of the prediction delay.
        frame_size: An integer number of the length of the frame
        overlap: A float nonnegative number less than 1 indicating the overlap
            factor between adjacent frames
        eps: Lower bound of the per-frame power used for the prediction
            weights, so silent frames cannot produce infinite weights.
        iterations: Number of reweighting iterations (a positive integer).
    """
    def __init__(self, mic_num, out_num, order=30):
        self.channels = mic_num
        self.out_num = out_num
        self.p = order
        self.d = 2
        self.frame_size = 512
        self.overlap = 0.5
        self.eps = 1e-4
        self._iterations = 2

    @property
    def iterations(self):
        return self._iterations

    @iterations.setter
    def iterations(self, value):
        value = int(value)
        if value <= 0:
            raise ValueError("iterations must be a positive integer")
        self._iterations = value

    def _display_cfgs(self):
        print('\nSettings:')
        print("Input channel: %d" % self.channels)
        print("Output channel: %d" % self.out_num)
        print("Prediction order: %d\n" % self.p)


    def run_offline(self, data):
        """Dereverberate a whole recording at once.

        Args:
            data: A 2-dimension numpy array with shape=(channels, samples).

        Returns:
            A 2-dimension numpy array with shape=(out_num, samples).
        """
        self._display_cfgs()
        time_start = time.time()
        print("Processing...")
        drv_data = self.__fdndlp(data)
        print("Done!\nTotal time: %f\n" % (time.time() - time_start))
        return drv_data

    @staticmethod
    def _normalize(signal):
        """Scale *signal* to unit peak; an all-zero signal is returned as is."""
        peak = np.abs(signal).max()
        return signal / peak if peak > 0 else signal

    def _weights(self, signal):
        """Prediction weights 1 / max(mean over channels of |signal|^2, eps).

        Args:
            signal: shape=(frames, channels) complex array.

        Returns:
            A 1-dimension array with one weight per frame.
        """
        power = np.mean(np.abs(signal) ** 2, axis=1)
        return 1.0 / np.maximum(power, self.eps)

    def __fdndlp(self, data):
        """Frequency-domain variance-normalized delayed linear prediction

        This is the core part of the WPE method. The variance-normalized
        linear prediction algorithm is implemented in each frequency bin
        separately. Both the input and output signals are in time-domain.

        Args:
            data: A 2-dimension numpy array with shape=(channels, samples)

        Returns:
            A 2-dimension numpy array with shape=(output_channels, samples),
            scaled to unit peak.
        """
        num_samples = data.shape[1]
        freq_data = stft.stft(
            self._normalize(data),
            frame_size=self.frame_size, overlap=self.overlap)
        self.freq_num = freq_data.shape[-1]
        drv_freq_data = freq_data[0:self.out_num].copy()
        for i in range(self.freq_num):
            xk = freq_data[:, :, i].T
            dk = self.__ndlp(xk)
            drv_freq_data[:, :, i] = dk.T
        drv_data = stft.istft(
            drv_freq_data,
            frame_size=self.frame_size, overlap=self.overlap)
        # istft also returns the zero padding added by stft; trim it so the
        # output has as many samples as the input.
        return self._normalize(drv_data[:, :num_samples])


    def __ndlp(self, xk):
        """Variance-normalized delayed linear prediction

        Here is the specific WPE algorithm implementation. The input should be
        the reverberant time-frequency signal in a single frequency bin and
        the output will be the dereverberated signal in the corresponding
        frequency bin.

        Args:
            xk: A 2-dimension numpy array with shape=(frames, input_channels)

        Returns:
            A 2-dimension numpy array with shape=(frames, output_channels)
        """
        cols = xk.shape[0] - self.d
        xk_buf = xk[:, 0:self.out_num]
        # Prepend p - 1 zero frames so every frame has a full history.
        xk = np.concatenate(
            (np.zeros((self.p - 1, self.channels)), xk),
            axis=0)
        xk_tmp = xk[:, ::-1].copy()
        # Column j stacks frames j, j-1, ..., j-p+1 (newest first, all
        # channels); after the delay d it predicts target frame j + d.
        frames = stride_tricks.as_strided(
            xk_tmp,
            shape=(self.channels * self.p, cols),
            strides=(xk_tmp.strides[-1], xk_tmp.strides[-1] * self.channels))
        frames = frames[::-1]
        target = xk_buf[self.d:]
        weights = self._weights(target)
        for _ in range(self.iterations):
            # frames * weights scales column t by weights[t], i.e.
            # frames @ diag(weights) without the (frames x frames) matrix.
            x_cor_m = np.dot(frames * weights, np.conj(frames.T))
            x_cor_v = np.dot(
                frames,
                np.conj(target * weights.reshape(-1, 1)))
            # pinv (as in fdndlp.m) tolerates singular correlation matrices,
            # e.g. an all-zero frequency bin.
            coeffs = np.dot(np.linalg.pinv(x_cor_m), x_cor_v)
            dk = target - np.dot(frames.T, np.conj(coeffs))
            weights = self._weights(dk)
        return np.concatenate((xk_buf[0:self.d], dk))

    def load_audio(self, filename):
        """Read *filename* and return (data, fs), data shape=(channels, samples).

        Only the first ``self.channels`` channels are kept.

        Raises:
            ValueError: If the file has fewer than ``self.channels`` channels.
        """
        data, fs = sf.read(filename, always_2d=True)
        data = data.T
        if data.shape[0] < self.channels:
            raise ValueError(
                "The input has %d channels, but %d are required."
                % (data.shape[0], self.channels))
        if data.shape[0] > self.channels:
            print(
                "The number of the input channels is %d," % data.shape[0],
                "and only the first %d channels are loaded." % self.channels)
            data = data[0: self.channels]
        return data.copy(), fs

    def write_wav(self, data, fs, filename, path='wav_out'):
        """Write data (channels, samples) as 16-bit PCM to path/filename."""
        os.makedirs(path, exist_ok=True)
        filepath = os.path.join(path, filename)
        print('Write to file: %s.' % filepath)
        sf.write(filepath, data.T, fs, subtype='PCM_16')

if __name__ == '__main__':
    cfgs = Configrations().parse()
    # cfgs.filename = '../wav_sample/sample_4ch.wav'
    wpe = WpeMethod(cfgs.mic_num, cfgs.out_num, cfgs.order)
    data, fs = wpe.load_audio(cfgs.filename)
    drv_data = wpe.run_offline(data)
    wpe.write_wav(drv_data, fs, cfgs.output)

    # plt.figure()
    # spec, _ = stft.log_spectrum(drv_data[0])
    # plt.pcolor(spec[0].T)
    # plt.show()
# --------------------------------------------------------------------------------
# /wav_sample/drv_sample_4ch.wav:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/helianvine/fdndlp/8c93c16af2006c2065c9dc00f0b9717793fdf2b8/wav_sample/drv_sample_4ch.wav
# --------------------------------------------------------------------------------
# /wav_sample/sample_4ch.wav:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/helianvine/fdndlp/8c93c16af2006c2065c9dc00f0b9717793fdf2b8/wav_sample/sample_4ch.wav
# --------------------------------------------------------------------------------