├── .gitignore
├── README.md
├── cgmm_mask_estimate.m
├── cgmm_mask_visualization.m
└── img
    └── F05_447C0212_CAF.png


/.gitignore:
--------------------------------------------------------------------------------
1 | egs/
2 | *.mat
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | ## Notes
 2 |
 3 | Implementation of CGMM mask estimation in MATLAB (final version).
 4 |
 5 | The previous version is in [CGMM-MVDR](https://github.com/funcwj/CGMM-MVDR).
 6 |
 7 | ### Examples
 8 | Mask for F05_447C0212_CAF.CH{1,2,3,4,5,6} in CHiME4 et05_caf_real:
 9 |
10 | ![](img/F05_447C0212_CAF.png)
11 |
12 | ### Reference
13 | T. Higuchi, N. Ito, T. Yoshioka, and T. Nakatani, "Robust MVDR beamforming using time-frequency masks for online/offline ASR in noise," in ICASSP, 2016.
--------------------------------------------------------------------------------
/cgmm_mask_estimate.m:
--------------------------------------------------------------------------------
  1 | function cgmm_mask_estimate(pattern, output, iters)
  2 | % CGMM_MASK_ESTIMATE  Estimate a time-frequency noise mask with a CGMM.
  3 | %
  4 | %   cgmm_mask_estimate(pattern, output)
  5 | %   cgmm_mask_estimate(pattern, output, iters)
  6 | %
  7 | %   pattern  char, handed to dir() to list the per-channel audio files.
  8 | %   output   char, path prefix; lambda_noise (num_frames x num_bins) is
  9 | %            saved to [output '.mat'].
 10 | %   iters    number of EM iterations, 20 when omitted.
 11 | %
 12 | %   Only the mask is estimated and saved; no beamforming is applied here.
 13 | %   enframe() and rfft() are not defined in this file (presumably from
 14 | %   VOICEBOX -- confirm they are on the MATLAB path).
 15 |
 16 | if nargin <= 1 || nargin > 3
 17 |     error('format error: cgmm_mask_estimate(pattern, output, [iters = 20])');
 18 | end
 19 |
 20 | if nargin < 3
 21 |     iters = 20;
 22 | end
 23 |
 24 | assert(ischar(pattern));
 25 | assert(ischar(output));
 26 |
 27 | num_iters = iters;
 28 | frame_length = 1024;
 29 | fft_length = 1024;
 30 | frame_shift = 256;
 31 | % theta doubles as the rcond threshold in stab() and as the floor added to
 32 | % the per-class likelihoods below.
 33 | theta = 10^-6;
 34 | hanning_wnd = hanning(frame_length, 'periodic');
 35 |
 36 |
 37 | multi_channel_wave = dir(pattern);
 38 | num_channels = size(multi_channel_wave, 1);
 39 | assert(num_channels ~= 0, ['Could not find wave file matching pattern ' pattern]);
 40 |
 41 | [ff, ~, ~] = fileparts(pattern);
 42 | % A pattern without a directory part (e.g. '*.wav') yields an empty folder;
 43 | % read from the current folder then, instead of building '/<name>'.
 44 | if isempty(ff)
 45 |     ff = '.';
 46 | end
 47 |
 48 | for c = 1: num_channels
 49 |     wave_path = fullfile(ff, multi_channel_wave(c).name);
 50 |     fprintf('--- read audio from %s\n', wave_path);
 51 |     samples = audioread(wave_path);
 52 |     frames = enframe(samples, hanning_wnd, frame_shift);
 53 |     frames_size = size(frames);
 54 |     % zero-pad every frame from frame_length up to fft_length
 55 |     frames_padding = zeros(frames_size(1), fft_length);
 56 |     frames_padding(:, 1: frame_length) = frames;
 57 |     % rfft: T x F
 58 |     spectrums(:, :, c) = rfft(frames_padding, fft_length, 2);
 59 | end
 60 |
 61 | % T x F x C  ->  C x T x F
 62 | specs = permute(spectrums, [3, 1, 2]);
 63 | [num_channels, num_frames, num_bins] = size(specs);
 64 |
 65 | % CGMM parameters
 66 | lambda_noise = zeros(num_frames, num_bins);
 67 | lambda_noisy = zeros(num_frames, num_bins);
 68 | phi_noise = ones(num_frames, num_bins);
 69 | phi_noisy = ones(num_frames, num_bins);
 70 | R_noise = zeros(num_channels, num_channels, num_bins);
 71 | R_noisy = zeros(num_channels, num_channels, num_bins);
 72 |
 73 | % init R_noisy with the per-bin average of y * y^H, R_noise with identity
 74 | for f = 1: num_bins
 75 |     R_noisy(:, :, f) = specs(:, :, f) * specs(:, :, f)' / num_frames;
 76 |     R_noise(:, :, f) = eye(num_channels, num_channels);
 77 | end
 78 |
 79 | % precompute the outer product y * y^H for every (t, f)
 80 | yyh = zeros(num_channels, num_channels, num_frames, num_bins);
 81 |
 82 | for f = 1: num_bins
 83 |     for t = 1: num_frames
 84 |         yyh(:, :, t, f) = specs(:, t, f) * specs(:, t, f)';
 85 |     end
 86 | end
 87 |
 88 | % init phi
 89 | for f = 1: num_bins
 90 |
 91 |     R_noisy_onbin = stab(R_noisy(:, :, f), theta, num_channels);
 92 |     R_noise_onbin = stab(R_noise(:, :, f), theta, num_channels);
 93 |
 94 |     R_noisy_inv = inv(R_noisy_onbin);
 95 |     R_noise_inv = inv(R_noise_onbin);
 96 |
 97 |     for t = 1: num_frames
 98 |         corre = yyh(:, :, t, f);
 99 |         phi_noise(t, f) = real(trace(corre * R_noise_inv) / num_channels);
100 |         phi_noisy(t, f) = real(trace(corre * R_noisy_inv) / num_channels);
101 |     end
102 | end
103 |
104 | % start CGMM training
105 | p_noise = ones(num_frames, num_bins);
106 | p_noisy = ones(num_frames, num_bins);
107 |
108 | for iter = 1: num_iters
109 |
110 |     for f = 1: num_bins
111 |
112 |         R_noisy_onbin = stab(R_noisy(:, :, f), theta, num_channels);
113 |         R_noise_onbin = stab(R_noise(:, :, f), theta, num_channels);
114 |
115 |         R_noisy_inv = inv(R_noisy_onbin);
116 |         R_noise_inv = inv(R_noise_onbin);
117 |         R_noisy_accu = zeros(num_channels, num_channels);
118 |         R_noise_accu = zeros(num_channels, num_channels);
119 |
120 |         for t = 1: num_frames
121 |             corre = yyh(:, :, t, f);
122 |             obs = specs(:, t, f);
123 |
124 |             % update lambda
125 |             % NOTE(review): the normaliser multiplies by pi, not by
126 |             % pi^num_channels. The factor is shared by both classes, so it
127 |             % only interacts with the +theta floor and offsets Q. Left
128 |             % unchanged to keep results reproducible -- confirm intent.
129 |             k_noise = obs' * (R_noise_inv / phi_noise(t, f)) * obs;
130 |             det_noise = det(phi_noise(t, f) * R_noise_onbin) * pi;
131 |             % +theta: avoid NAN
132 |             p_noise(t, f) = real(exp(-k_noise) / det_noise) + theta;
133 |
134 |             k_noisy = obs' * (R_noisy_inv / phi_noisy(t, f)) * obs;
135 |             det_noisy = det(phi_noisy(t, f) * R_noisy_onbin) * pi;
136 |             p_noisy(t, f) = real(exp(-k_noisy) / det_noisy) + theta;
137 |
138 |             lambda_noise(t, f) = p_noise(t, f) / (p_noise(t, f) + p_noisy(t, f));
139 |             lambda_noisy(t, f) = p_noisy(t, f) / (p_noise(t, f) + p_noisy(t, f));
140 |
141 |             % update phi (uses the inverses of the R from before this pass)
142 |             phi_noise(t, f) = real(trace(corre * R_noise_inv) / num_channels);
143 |             phi_noisy(t, f) = real(trace(corre * R_noisy_inv) / num_channels);
144 |
145 |             % accu R
146 |             R_noise_accu = R_noise_accu + lambda_noise(t, f) / phi_noise(t, f) * corre;
147 |             R_noisy_accu = R_noisy_accu + lambda_noisy(t, f) / phi_noisy(t, f) * corre;
148 |         end
149 |         % update R
150 |         R_noise(:, :, f) = R_noise_accu / sum(lambda_noise(:, f));
151 |         R_noisy(:, :, f) = R_noisy_accu / sum(lambda_noisy(:, f));
152 |
153 |     end
154 |     % Q = sum(sum(lambda_noise .* log(p_noise) + lambda_noisy .* log(p_noisy))) / (num_frames * num_bins);
155 |     Qn = sum(sum(lambda_noise .* log(p_noise))) / (num_frames * num_bins);
156 |     Qx = sum(sum(lambda_noisy .* log(p_noisy))) / (num_frames * num_bins);
157 |     fprintf('--- iter = %2d, Q = %.4f + %.4f = %.4f\n', iter, Qn, Qx, Qn + Qx);
158 | end
159 |
160 | save([output '.mat'], 'lambda_noise');
161 |
162 | end
163 |
164 | function mat = stab(mat, theta, num_channels)
165 | % STAB  Diagonally load mat until its rcond exceeds theta.
166 | %   Adds d(i) * eye(num_channels) for d = 1e-6, 1e-5, ..., 1e-1 in turn and
167 | %   stops as soon as rcond(mat) > theta, or after all loads were applied.
168 |     d = 10 .^ (-6: 1: -1);
169 |     for i = 1: numel(d)
170 |         if rcond(mat) > theta
171 |             break;
172 |         end
173 |         mat = mat + d(i) * eye(num_channels);
174 |     end
175 | end
176 |
177 |
--------------------------------------------------------------------------------
/cgmm_mask_visualization.m:
--------------------------------------------------------------------------------
 1 | function cgmm_mask_visualization(mask_path)
 2 | % CGMM_MASK_VISUALIZATION  Plot the noise mask and its complement.
 3 | %   mask_path  file handed to load(); it must provide lambda_noise (as
 4 | %              saved by cgmm_mask_estimate). The figure is written to
 5 | %              [mask_path '.jpg'].
 6 |     screensize = get( groot, 'Screensize' );
 7 |     set(gcf, 'position', screensize);
 8 |     mask = load(mask_path);
 9 |     % flip the second dimension, then transpose, so the last bin is row 1
10 |     lambda_noise = transpose(fliplr(abs(mask.lambda_noise)));
11 |     lambda_clean = transpose(fliplr(1 - abs(mask.lambda_noise)));
12 |     [num_bins, ~] = size(lambda_clean);
13 |     freq_ticks = linspace(0, num_bins - 1, 9);
14 |     % imagesc centres row r at y = r (1 .. num_bins), so the ticks sit at
15 |     % freq_ticks + 1; at 0 .. num_bins - 1 they were one row off and the top
16 |     % tick fell outside the axis.
17 |     tick_rows = freq_ticks + 1;
18 |     % axis top is labelled 8 kHz -- presumably 16 kHz audio; confirm
19 |     tick_labels = fliplr(freq_ticks) / (num_bins - 1) * 8;
20 |     colormap gray
21 |     subplot(1, 2, 1), imagesc(lambda_noise);
22 |     yticks(tick_rows);
23 |     yticklabels(tick_labels);
24 |     ylabel('Frequency(kHz)');
25 |     xlabel('Frames');
26 |     title('noise mask');
27 |     colorbar
28 |     subplot(1, 2, 2), imagesc(lambda_clean);
29 |     yticks(tick_rows);
30 |     yticklabels(tick_labels);
31 |     ylabel('Frequency(kHz)');
32 |     xlabel('Frames');
33 |     title('clean mask');
34 |     colorbar
35 |     saveas(gcf, [mask_path '.jpg']);
36 | end
37 |
--------------------------------------------------------------------------------
/img/F05_447C0212_CAF.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/funcwj/cgmm-mask-estimator/36c95059e93a034e685a77b92b8e77d8f68d9e60/img/F05_447C0212_CAF.png
--------------------------------------------------------------------------------